Development Instructions#
MIDAS currently supports integration of RNA, ADT, and ATAC data. If you’d like to extend or modify the model, follow the instructions below.
Framework Overview#
MIDAS is configured via the scmidas/model_config.toml file and primarily employs Multi-Layer Perceptrons (MLPs). Below are the key components of the MIDAS framework:
Key Components#
Data Encoder: Encodes each modality into Gaussian-distributed latent features, parameterized by their means and log-variances.
Data Decoder: Reconstructs counts for each modality using the joint latent features as input.
Batch Indices Encoder: Encodes batch indices for each modality into Gaussian-distributed latent features.
Batch Indices Decoder: Reconstructs batch indices for each modality using the joint latent features.
Discriminator: A group of classifiers that categorizes modality-specific latents and joint latents. Only the biological part of the latents is used for this classification.
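As a rough illustration of the encoder outputs described above, the sketch below draws a latent sample from per-dimension means and log-variances using the standard reparameterization trick, then separates the biological part c from the technical part u. It is plain Python, not the model's actual layers: the helper names and the c-before-u ordering are assumptions, and only the sizes 32 and 2 come from the default configuration.

```python
import math
import random

DIM_C = 32  # default dim_c (biological latent)
DIM_U = 2   # default dim_u (technical latent)


def sample_latent(mean, logvar, rng):
    """Reparameterized draw: z = mean + exp(0.5 * logvar) * eps, eps ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mean, logvar)]


def split_latent(z, dim_c=DIM_C, dim_u=DIM_U):
    """Split a joint latent into (c, u). The c-then-u layout is an assumption."""
    if len(z) != dim_c + dim_u:
        raise ValueError("unexpected latent size")
    return z[:dim_c], z[dim_c:]


rng = random.Random(0)
mean = [0.0] * (DIM_C + DIM_U)
logvar = [-2.0] * (DIM_C + DIM_U)  # variance exp(-2) for every dimension
z = sample_latent(mean, logvar, rng)
c, u = split_latent(z)  # only c would be passed to the discriminator
```

Storing log-variances rather than variances keeps the standard deviation positive without any constraint on the network output.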
Neural network architecture for MIDAS:
Transformation and Distribution Functions#
MIDAS includes a range of pre-defined transformation and distribution functions. These can be customized or extended to support new modalities, providing flexibility for various data types and workflows.
Transformation Functions#
MIDAS provides several transformation pairs designed to prepare data for training. These include:
binarize
Input Transformation: Converts data into binary form.
Output Transformation: None.
log1p
Input Transformation: Applies the log1p transformation, log(x + 1).
Output Transformation: Applies the exponential function (exp).
Note
Transformation functions specified via the transform parameter in scmidas.MIDAS.configure_data_from_dir()
are applied exclusively when retrieving items from the dataset (via get_item()).
If transformations are defined using trsf_before_enc_{mod} in the configs,
both the transformation and its inverse are applied, ensuring consistency throughout the training process.
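The two transformation pairs can be sketched in plain Python as follows. This is only an illustration of the idea, not the library's implementation: the function names are hypothetical, binarizing at "greater than zero" is an assumption, and the inverse shown here is `expm1` (exp(y) - 1), the exact mathematical inverse of log1p. The text above names exp as the output transformation, and whether MIDAS subtracts the 1 is not stated here.

```python
import math


def binarize(values):
    """Input transformation: map any positive count to 1, everything else to 0."""
    return [1 if v > 0 else 0 for v in values]


def log1p_forward(values):
    """Input transformation: log(x + 1), computed stably for small x."""
    return [math.log1p(v) for v in values]


def log1p_inverse(values):
    """Output transformation used in this sketch: exp(y) - 1, the exact inverse."""
    return [math.expm1(v) for v in values]


counts = [0, 1, 5, 120]
encoded = log1p_forward(counts)     # values fed to the encoder
restored = log1p_inverse(encoded)   # values mapped back to the count scale
```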
Distribution Functions#
Distribution functions in MIDAS are defined using a combination of the loss function, sampling function, and activation function. These functions enable the modeling of data distributions during training:
Loss Function: Defines the reconstruction loss function.
Sampling: Specifies how to draw samples from the predicted distribution parameters.
Activation: Configures the output layer’s activation function in the decoder.
Pre-defined distribution functions include:
POISSON
Loss Function: Poisson loss
Sampling: Poisson sampling
Activation: None
BERNOULLI
Loss Function: Binary cross-entropy loss
Sampling: Bernoulli sampling
Activation: Sigmoid
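To make the loss / sampling / activation triple concrete, here is a minimal scalar sketch of both pre-defined distributions in plain Python. All function names and the dictionary layout are hypothetical. It also assumes the Poisson loss receives a rate (not a log-rate) and drops the constant log(x!) term, which may differ from what MIDAS does internally.

```python
import math
import random


def sigmoid(x):
    """Activation for BERNOULLI: squashes a logit into a probability."""
    return 1.0 / (1.0 + math.exp(-x))


def bce_loss(prob, target, eps=1e-12):
    """Binary cross-entropy between a predicted probability and a 0/1 target."""
    prob = min(max(prob, eps), 1.0 - eps)
    return -(target * math.log(prob) + (1 - target) * math.log(1.0 - prob))


def poisson_loss(rate, target, eps=1e-12):
    """Poisson negative log-likelihood without the constant log(target!) term."""
    return rate - target * math.log(rate + eps)


def bernoulli_sample(prob, rng):
    """Draw 1 with probability prob, else 0."""
    return 1 if rng.random() < prob else 0


def poisson_sample(rate, rng):
    """Knuth's multiplication method; adequate for small rates."""
    limit, k, p = math.exp(-rate), 0, 1.0
    while True:
        p *= rng.random()
        if p <= limit:
            return k
        k += 1


DISTRIBUTIONS = {
    "POISSON": {"loss": poisson_loss, "sample": poisson_sample, "activation": None},
    "BERNOULLI": {"loss": bce_loss, "sample": bernoulli_sample, "activation": sigmoid},
}
```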
Default Configurations#
Here, we show the default settings of the model:
Embeddings#
| Key | Value | Description |
|---|---|---|
| dim_c | 32 | Latent dimension for biological information c. |
| dim_u | 2 | Latent dimension for technical information u (keep this small to avoid capturing biological information). |
Basic Network Structure (MLP)#
| Key | Value | Description |
|---|---|---|
| norm | ‘ln’ | Normalization layer: ‘bn’, ‘ln’, or False. The default ‘ln’ uses layer normalization. |
| drop | 0.2 | Dropout rate. |
| out_trans | ‘mish’ | Activation function for the output. Supported: ‘tanh’, ‘relu’, ‘silu’, ‘mish’, ‘sigmoid’, ‘softmax’, ‘log_softmax’. |
RNA#
| Key | Value | Description |
|---|---|---|
| trsf_before_enc_rna | ‘log1p’ | Apply log1p transformation before encoding. Exponential transformation will be applied after decoding. |
| distribution_dec_rna | ‘POISSON’ | Poisson distribution assumption for decoder. |
ADT#
| Key | Value | Description |
|---|---|---|
| trsf_before_enc_adt | ‘log1p’ | Apply log1p transformation before encoding. Exponential transformation will be applied after decoding. |
| distribution_dec_adt | ‘POISSON’ | Poisson distribution assumption for decoder. |
ATAC#
| Key | Value | Description |
|---|---|---|
| dims_before_enc_atac | [512, 128] | Independent MLP structure before shared encoder. It is used to compress the data chunks of the ATAC modality. |
| dims_after_dec_atac | [128, 512] | Independent MLP structure after shared decoder. It expands the embeddings to reconstruct the ATAC modality. |
| distribution_dec_atac | ‘BERNOULLI’ | Bernoulli distribution assumption for decoder. Use BCE loss. |
Batch Indices#
| Key | Value | Description |
|---|---|---|
| s_drop_rate | 0.1 | Rate to drop batch indices during training. |
| dims_enc_s | [16, 16] | Encoder structure. |
| dims_dec_s | [16, 16] | Decoder structure. |
| dims_dsc | [128, 64] | Structure of the discriminator. |
Training#
| Key | Value | Description |
|---|---|---|
| optim_net | ‘AdamW’ | Optimizer for the main network. |
| lr_net | 1e-4 | Learning rate for the main network. |
| optim_dsc | ‘AdamW’ | Optimizer for the discriminator. |
| lr_dsc | 1e-4 | Learning rate for the discriminator. |
| grad_clip | -1 | Gradient clipping ( |
Loss Weights#
| Key | Value | Description |
|---|---|---|
| lam_kld_c | 1 | Weight for variable c’s KLD loss. |
| lam_kld_u | 5 | Weight for variable u’s KLD loss. |
| lam_kld | 1 | Weight for total KLD loss. |
| lam_recon | 1 | Weight for reconstruction loss. |
| lam_dsc | 30 | Weight for discriminator loss (for training the discriminator). |
| lam_adv | 1 | Weight for adversarial loss. loss = VAE_loss - disc_loss * lam_adv |
| lam_alignment | 50 | Weight for modality alignment loss. |
| lam_recon_rna | 1 | Weight for RNA reconstruction loss. |
| lam_recon_adt | 1 | Weight for ADT reconstruction loss. |
| lam_recon_atac | 1 | Weight for ATAC reconstruction loss. |
| lam_recon_s | 1000 | Weight for batch indices reconstruction loss. |
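One plausible way these weights combine is sketched below. Only the last step, loss = VAE_loss - disc_loss * lam_adv, is stated in the table. The grouping of the reconstruction and KLD terms is an assumption made for illustration, the function names are hypothetical, and the alignment and lam_dsc terms are left out because their exact form is not described here.

```python
DEFAULT_WEIGHTS = {
    "lam_kld_c": 1, "lam_kld_u": 5, "lam_kld": 1,
    "lam_recon": 1, "lam_adv": 1,
    "lam_recon_rna": 1, "lam_recon_adt": 1, "lam_recon_atac": 1,
    "lam_recon_s": 1000,
}


def vae_loss(recon_losses, kld_c, kld_u, w=DEFAULT_WEIGHTS):
    """Weighted sum of reconstruction and KLD terms (grouping is an assumption)."""
    recon = sum(w[f"lam_recon_{name}"] * value
                for name, value in recon_losses.items())
    kld = w["lam_kld_c"] * kld_c + w["lam_kld_u"] * kld_u
    return w["lam_recon"] * recon + w["lam_kld"] * kld


def generator_loss(vae, disc_loss, w=DEFAULT_WEIGHTS):
    """Adversarial objective as stated in the table: VAE_loss - disc_loss * lam_adv."""
    return vae - disc_loss * w["lam_adv"]


recon_losses = {"rna": 2.0, "adt": 1.0, "s": 0.001}  # made-up per-term values
loss = generator_loss(vae_loss(recon_losses, kld_c=0.5, kld_u=0.1), disc_loss=0.7)
```

With the default lam_recon_s of 1000, even a small batch-index reconstruction error contributes as much as the modality terms in this toy example.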
Discriminator Training#
| Key | Value | Description |
|---|---|---|
| n_iter_disc | 3 | Number of discriminator training iterations before training the VAE. |
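The alternation implied by n_iter_disc can be pictured with a small generator that emits the order of updates. This is a schematic only: the function name is hypothetical, and whether the discriminator iterations reuse the same mini-batch is not specified in this section.

```python
def training_schedule(n_steps, n_iter_disc=3):
    """Yield 'disc' n_iter_disc times, then 'vae' once, for each training step."""
    for _ in range(n_steps):
        for _ in range(n_iter_disc):
            yield "disc"  # update the discriminator only
        yield "vae"       # then update the main network


schedule = list(training_schedule(n_steps=2))
```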
Data Loader#
| Key | Value | Description |
|---|---|---|
| num_workers | 20 | Number of worker threads for data loading. |
| pin_memory | true | Load data into pinned memory. |
| persistent_workers | true | Persistent worker threads. |
| n_max | 10000 | Maximum number of samples per batch. |
Extending MIDAS to More Modalities#
Step 1: Defining New Modality#
To integrate new modalities into the MIDAS framework, you need to define several key components, including the Data Encoder, Data Decoder, Loss and Distribution functions that are specific to the new modality. This allows MIDAS to process and reconstruct data from diverse biological data types.
Before making any modifications, you need to load the model configurations. You can do this using the following command:
from scmidas.config import load_config
configs = load_config()
Once the configuration is loaded, you can customize the encoder, decoder, and other settings for the new modality.
Data Encoder#
The data encoder transforms input data through modality-specific and shared layers to produce latent representations. Configure it as follows:
(Optional) Transformation Before Encoding: Specify the transformation function to be applied before encoding.
Example:
configs['trsf_before_enc_{new_mod}'] = 'log1p'
Attention
If the specified transformation is not registered, an error will occur. Refer to Registering Transformations for details.
(Optional) Dimensionality Reduction Layer: If the data is split into chunks, define the modality-specific layers for encoding each chunk individually before merging them.
Example:
configs['dims_before_enc_{new_mod}'] = [512, 128] # First encode to 512 dimensions, then to 128
Data Decoder#
The data decoder reconstructs original data from latent features. Configure the shared layers and dimensionality expansion layers as follows:
(Optional) Dimensionality Expansion Layer: If the data is split into chunks, define the dimensionality expansion layers after the shared layers.
Example:
configs['dims_after_dec_{new_mod}'] = [128, 512]
Output Distribution: Set the output distribution for each modality.
Example:
configs['distribution_dec_{new_mod}'] = 'POISSON'
Attention
If the specified distribution is not registered, an error will occur. Refer to Registering Distributions for guidance.
Reconstruction Loss Weight#
Adjust the weight for reconstruction loss as needed:
configs['lam_recon_{new_mod}'] = 1 # Adjust as needed
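Putting Step 1 together, the per-modality keys can be filled in by a small helper. The sketch below is not part of scmidas: `add_modality` and the modality name "newmod" are hypothetical, and a plain dict stands in for the object returned by `load_config()`, which is assumed here to support item assignment as in the examples above.

```python
def add_modality(configs, mod, *, distribution, transform=None,
                 dims_before_enc=None, dims_after_dec=None, lam_recon=1):
    """Fill in the per-modality keys described in Step 1 for modality `mod`."""
    if transform is not None:
        configs[f"trsf_before_enc_{mod}"] = transform
    if dims_before_enc is not None:
        configs[f"dims_before_enc_{mod}"] = list(dims_before_enc)
    if dims_after_dec is not None:
        configs[f"dims_after_dec_{mod}"] = list(dims_after_dec)
    configs[f"distribution_dec_{mod}"] = distribution
    configs[f"lam_recon_{mod}"] = lam_recon
    return configs


# Stand-in for the dict returned by load_config(); "newmod" is a hypothetical name.
configs = {}
add_modality(configs, "newmod", distribution="POISSON", transform="log1p")
```

Optional keys are only written when a value is supplied, so a modality that is not split into chunks simply omits the dims_before_enc and dims_after_dec entries.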
Step 2: (Optional) Registering New Functions#
To add new functionality, register transformation and distribution functions as follows:
Registering New Transformation Functions#
from scmidas.nn import transform_registry
transform_registry.register(name, fn, inverse_fn)
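The following toy registry illustrates the pattern behind the `register(name, fn, inverse_fn)` call: a name maps to a forward function and its inverse. It is a stand-in written for this example, not the scmidas implementation. The class, the "sqrt" transformation, and the scalar inputs are all assumptions, since the real functions most likely operate on tensors.

```python
import math


class TransformRegistry:
    """Toy registry mirroring the register(name, fn, inverse_fn) call shape."""

    def __init__(self):
        self._pairs = {}

    def register(self, name, fn, inverse_fn=None):
        if name in self._pairs:
            raise KeyError(f"transformation {name!r} is already registered")
        self._pairs[name] = (fn, inverse_fn)

    def get(self, name):
        return self._pairs[name]


registry = TransformRegistry()
# Hypothetical square-root transform with squaring as its inverse.
registry.register("sqrt", math.sqrt, lambda y: y * y)

forward, inverse = registry.get("sqrt")
```

Once a name is registered, it could be referenced from the configuration in the same way as 'log1p', for example through a trsf_before_enc_{new_mod} entry.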
Registering New Distribution Functions#
from scmidas.nn import distribution_registry
distribution_registry.register(name, loss_fn, sampling_fn, activate_fn)
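A distribution entry bundles three callables in the order shown in the `register(name, loss_fn, sampling_fn, activate_fn)` call. The sketch below uses a stand-in registry to show how a new entry might be assembled. The class, the "GAUSSIAN" name, and the scalar function signatures are hypothetical and chosen only for illustration; the signatures scmidas actually expects are not documented in this section.

```python
import random


class DistributionRegistry:
    """Toy registry mirroring register(name, loss_fn, sampling_fn, activate_fn)."""

    def __init__(self):
        self._entries = {}

    def register(self, name, loss_fn, sampling_fn, activate_fn=None):
        self._entries[name] = {
            "loss": loss_fn, "sample": sampling_fn, "activation": activate_fn,
        }

    def get(self, name):
        return self._entries[name]


def squared_error(pred, target):
    """Reconstruction loss for a unit-variance Gaussian, up to constants."""
    return 0.5 * (pred - target) ** 2


def gaussian_sample(mean, rng):
    """Draw from N(mean, 1)."""
    return rng.gauss(mean, 1.0)


registry = DistributionRegistry()
# "GAUSSIAN" is a hypothetical name, not one of the pre-defined distributions.
registry.register("GAUSSIAN", squared_error, gaussian_sample, None)
entry = registry.get("GAUSSIAN")
```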
Calling for Contributions#
We encourage you to contribute to MIDAS by submitting pull requests for new features, enhancements, or bug fixes. Contributions will be reviewed and, if suitable, integrated into the main repository. Thank you for helping us improve MIDAS!