scmidas.model

class scmidas.model.Decoder(dims_x: Dict[str, list], dims_h: Dict[str, list], dim_z: int, norm: str, out_trans: str, drop: float, **kwargs)

Bases: Module

Decoder class for multi-modal data with shared and modality-specific decoding layers.

Parameters:
  • dims_x – Dict[str, list] Output dimensions for each modality.

  • dims_h – Dict[str, list] Hidden dimensions for each modality.

  • dim_z – int Latent dimension size.

  • norm – str Normalization type (e.g., ‘ln’ for LayerNorm).

  • out_trans – str Output activation function (e.g., ‘relu’).

  • drop – float Dropout rate.

  • kwargs – Dict Additional modality-specific configurations.

forward(latent_data: Tensor) -> Dict[str, Tensor]

Forward pass for the decoder.

Parameters:

latent_data – torch.Tensor Latent variable input tensor of shape (batch_size, dim_z).

Returns:

Dict[str, torch.Tensor]

Decoded outputs for each modality.

class scmidas.model.Discriminator(dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs)

Bases: Module

Discriminator class for multi-modal latent variables.

Parameters:
  • dims_x – dict Input dimensions for each modality.

  • dims_s – dict Dimensions of the classes for each modality.

  • kwargs – dict Additional configurations, such as hidden layer sizes, dropout rate, and normalization type.

calculate_loss(predictions: Dict[str, Tensor], targets: Dict[str, Tensor]) -> Tensor

Calculate cross-entropy loss for all modalities.

Parameters:
  • predictions – dict Dictionary of predicted logits for each modality.

  • targets – dict Dictionary of ground truth labels for each modality.

Returns:

torch.Tensor

Total normalized loss.
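
As a rough illustration of summing cross-entropy over modalities, the following dependency-free sketch works on plain Python lists rather than tensors. The function names are hypothetical, and the normalization shown (mean over samples within each modality, then a sum across modalities) is an assumption; the source only says the total is "normalized".

```python
import math

def cross_entropy(logits, label):
    """Negative log-softmax probability of the true class for one sample."""
    peak = max(logits)  # subtract the max for numerical stability
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_norm - logits[label]

def total_cross_entropy(predictions, targets):
    """Sum over modalities of the per-modality mean cross-entropy (assumed normalization)."""
    total = 0.0
    for mod, logit_rows in predictions.items():
        labels = targets[mod]
        per_sample = [cross_entropy(row, y) for row, y in zip(logit_rows, labels)]
        total += sum(per_sample) / len(labels)
    return total

# Two equally likely classes give a loss of log(2) per sample.
loss = total_cross_entropy({"rna": [[0.0, 0.0]]}, {"rna": [0]})
```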

forward(latent_inputs: Dict[str, Tensor]) -> Dict[str, Tensor]

Forward pass for the discriminator.

Parameters:

latent_inputs – dict Dictionary of latent inputs for each modality, where keys are modality names and values are tensors of shape (batch_size, dim_c).

Returns:

dict

Dictionary of logits for each modality, where keys are modality names and values are tensors of shape (batch_size, dims_s[modality]).

class scmidas.model.Encoder(dims_x: Dict[str, list], dims_h: Dict[str, list], dim_z: int, norm: str, out_trans: str, drop: float, **kwargs)

Bases: Module

Encoder class for multi-modal data with modality-specific pre-processing, encoding, and shared encoding layers.

Parameters:
  • dims_x – Dict[str, list] Input dimensions for each modality.

  • dims_h – Dict[str, list] Hidden dimensions for each modality after pre-encoding.

  • dim_z – int Latent dimension size.

  • norm – str Normalization type (e.g., ‘ln’ for LayerNorm).

  • out_trans – str Output activation function (e.g., ‘mish’).

  • drop – float Dropout rate.

  • kwargs – dict Additional modality-specific configurations.

Notes

By default, RNA and ADT data are log1p-transformed before encoding, and the decoded outputs are exponentiated to undo the transform. To skip this step, modify the configuration file; see the parameter trsf_before_enc_.
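
The round trip described in this note can be pictured with the standard library alone. This is only a sketch of the log1p/expm1 pair on a plain list of counts; the helper names are hypothetical, and the library applies the transform to tensors internally.

```python
import math

def log1p_transform(counts):
    """Compress raw counts with log(1 + x), as done before encoding."""
    return [math.log1p(c) for c in counts]

def expm1_inverse(values):
    """Undo the transform with exp(x) - 1, as done after decoding."""
    return [math.expm1(v) for v in values]

raw_counts = [0.0, 1.0, 9.0, 250.0]
encoded_scale = log1p_transform(raw_counts)
restored = expm1_inverse(encoded_scale)  # matches raw_counts up to rounding
```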

forward(data: Dict[str, Tensor], mask: Dict[str, Tensor])

Forward pass for the encoder.

Parameters:
  • data – dict Input data for each modality.

  • mask – dict Masks for each modality.

Returns:

  • z_x_mu (dict): Mean values of the latent space for each modality.

  • z_x_logvar (dict): Log-variance values of the latent space for each modality.

Return type:

Tuple

class scmidas.model.MIDAS

Bases: LightningModule

MIDAS processes mosaic single-cell data into imputed and batch-corrected data for multimodal analysis.

net

VAE Variational Autoencoder for multi-modal data encoding and decoding.

dsc

Discriminator Discriminator for distinguishing latent variables across batches.

configs

dict Model and training configurations dynamically set as attributes.

automatic_optimization

bool Controls whether optimization is automatic or manually defined. Always True.

static calc_consistency_loss(z_uni: dict)

Calculate the consistency loss for unified latent variables across modalities.

Parameters:

z_uni – dict Dictionary of unified latent variables for each modality, where each value is a tensor of shape (batch_size x latent_dim).

Returns:

consistency_loss – Consistency loss computed as the variance of the unified latent variables.

Return type:

float
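
One way to read "variance of the unified latent variables" is sketched below on nested lists of shape (batch_size x latent_dim). The reduction is an assumption, not taken from the source: the population variance across modalities is summed over latent dimensions and averaged over cells. The function name is hypothetical.

```python
def consistency_loss_sketch(z_uni):
    """Variance across modalities, summed over latent dims and averaged over the batch."""
    mods = list(z_uni)
    n_mod = len(mods)
    batch_size = len(z_uni[mods[0]])
    latent_dim = len(z_uni[mods[0]][0])
    total = 0.0
    for i in range(batch_size):
        for j in range(latent_dim):
            vals = [z_uni[m][i][j] for m in mods]
            mean = sum(vals) / n_mod
            total += sum((v - mean) ** 2 for v in vals) / n_mod
    return total / batch_size

# Modalities that agree on every latent value give zero loss.
agree = consistency_loss_sketch({"rna": [[1.0, 2.0]], "adt": [[1.0, 2.0]]})
differ = consistency_loss_sketch({"rna": [[0.0, 0.0]], "adt": [[2.0, 0.0]]})
```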

static calc_dsc_loss(pred: dict, true: dict)

Calculate the discriminator loss using cross-entropy.

Parameters:
  • pred – dict Predicted logits for each modality.

  • true – dict Ground truth labels for each modality.

Returns:

float

Computed discriminator loss.

static calc_kld_loss(mu: Tensor, logvar: Tensor)

Calculate the KLD loss for a single latent space.

Parameters:
  • mu – torch.Tensor Mean of the latent variable distribution (batch_size x latent_dim).

  • logvar – torch.Tensor Log-variance of the latent variable distribution (batch_size x latent_dim).

Returns:

kld_loss – KLD loss for the latent space, normalized by batch size.

Return type:

float
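
For reference, the closed-form KL divergence between a diagonal Gaussian and a standard normal prior can be sketched as follows. The standard normal prior is the usual VAE choice but is an assumption here, since the source does not name the prior; the function name is hypothetical and the inputs are nested lists instead of tensors.

```python
import math

def kld_standard_normal(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over all entries, divided by batch size.

    mu and logvar are nested lists of shape (batch_size x latent_dim).
    """
    batch_size = len(mu)
    total = 0.0
    for mu_row, logvar_row in zip(mu, logvar):
        for m, lv in zip(mu_row, logvar_row):
            total += -0.5 * (1.0 + lv - m * m - math.exp(lv))
    return total / batch_size

# A posterior equal to the prior has zero divergence.
zero = kld_standard_normal([[0.0, 0.0]], [[0.0, 0.0]])
shifted = kld_standard_normal([[1.0, 0.0]], [[0.0, 0.0]])
```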

static calc_kld_z_loss(dim_c: int, dim_u: int, lam_kld_c: float, lam_kld_u: float, mu: Tensor, logvar: Tensor)

Calculate the Kullback-Leibler Divergence (KLD) loss for latent variables z.

Parameters:
  • dim_c – int Dimension of the biological latent space.

  • dim_u – int Dimension of the technical latent space.

  • lam_kld_c – float Weight for KLD loss of the biological latent space.

  • lam_kld_u – float Weight for KLD loss of the technical latent space.

  • mu – torch.Tensor Mean of the latent variable distribution (batch_size x (dim_c + dim_u)).

  • logvar – torch.Tensor Log-variance of the latent variable distribution (batch_size x (dim_c + dim_u)).

Returns:

kld_z_loss – Weighted sum of KLD losses for the biological and technical latent spaces.

Return type:

float
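
The weighted sum can be sketched by splitting each latent vector into a biological part and a technical part. Two things are assumed here and not stated in the source: the first dim_c entries are the biological ones, and the prior is a standard normal. The function names are hypothetical and the inputs are nested lists.

```python
import math

def kld_row(mu_row, logvar_row):
    """KL to a standard normal for one sample, summed over its dimensions."""
    return sum(-0.5 * (1.0 + lv - m * m - math.exp(lv))
               for m, lv in zip(mu_row, logvar_row))

def weighted_kld_z(dim_c, dim_u, lam_kld_c, lam_kld_u, mu, logvar):
    """lam_kld_c * KLD(c) + lam_kld_u * KLD(u), each normalized by batch size."""
    if any(len(row) != dim_c + dim_u for row in mu):
        raise ValueError("each row must have dim_c + dim_u entries")
    batch_size = len(mu)
    # Assumed layout: biological dimensions first, technical dimensions after.
    kld_c = sum(kld_row(m[:dim_c], lv[:dim_c]) for m, lv in zip(mu, logvar)) / batch_size
    kld_u = sum(kld_row(m[dim_c:], lv[dim_c:]) for m, lv in zip(mu, logvar)) / batch_size
    return lam_kld_c * kld_c + lam_kld_u * kld_u

loss = weighted_kld_z(2, 1, 1.0, 5.0, [[1.0, 0.0, 2.0]], [[0.0, 0.0, 0.0]])
```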

static calc_recon_loss(x: dict, s: Tensor, e: dict, x_r_pre: dict, s_r_pre: dict, dist: dict, lam: dict)

Calculate the reconstruction loss for input data and predicted outputs.

Parameters:
  • x – dict Original input data for each modality (x^m).

  • s – torch.Tensor Ground truth batch labels.

  • e – dict Masks for each modality.

  • x_r_pre – dict Reconstructed predictions for each modality (x_r^m).

  • s_r_pre – dict Reconstructed predictions for batch labels.

  • dist – dict Dictionary specifying the distribution type for each modality’s decoder.

  • lam – dict Dictionary containing reconstruction loss weights for each modality and for s.

Returns:

  • total_loss (float): Total reconstruction loss, normalized by batch size.

  • losses (dict): Dictionary containing reconstruction losses for each modality and for batch labels.

Return type:

Tuple

classmethod configure_data(configs: dict, datalist: List[Dataset], dims_x: Dict[str, list], dims_s: Dict[str, int], s_joint: List[Dict[str, int]], combs: List[List[str]], batch_size: int = 256, n_save: int = 500, save_model_path: str = './saved_models/', sampler_type: str = 'auto')

Configure the data and model parameters for training.

Parameters:
  • configs – dict Configurations of the model.

  • datalist – List[Dataset] List of datasets to be used for training.

  • dims_x – Dict[str, list] Dictionary specifying the dimensions of input features for each modality.

  • dims_s – Dict[str, int] Dimensions of the classes for each modality.

  • s_joint – List[Dict[str, int]] Modality indices for each batch.

  • combs – List[List[str]] Combinations of modalities.

  • batch_size – int, optional Size of each training batch, by default 256.

  • n_save – int, optional Interval (in epochs) for saving model checkpoints, by default 500.

  • save_model_path – str, optional Directory path for saving model checkpoints, by default ‘./saved_models/’.

  • sampler_type – str, optional Type of sampler to use, by default ‘auto’. Set to ‘ddp’ to use a distributed sampler.

Returns:

cls

Returns the configured class instance.

static configure_data_from_csv(data: dict, mask: dict, transform: dict = None)

Configure data from a CSV input.

Parameters:
  • data – list of dict List of data dictionaries, where keys are modalities and values are file paths.

  • mask – list of dict List of mask dictionaries, where keys are modalities and values are mask file paths.

  • transform – dict, optional Transformations to apply to specific modalities, default is binarization for ‘atac’, ‘met’, and ‘acc’.

Returns:

  • datasets (list): List of initialized MultiModalDataset objects.

  • dims_s (dict): Dimensions for batch correction for each modality.

  • s_joint (list): Modality indices for each batch.

  • combs (list): List of modality combinations for each batch.

Return type:

Tuple

classmethod configure_data_from_dir(configs: dict, dir_path: str, transform: dict = None, sampler_type: str = 'auto', **kwargs: dict)

Configure data from a directory and apply optional transformations.

Parameters:
  • configs – dict Configurations of the model.

  • dir_path – str Path to the directory containing data files.

  • transform – dict, optional A dictionary specifying transformations to apply to specific modalities. Example: {‘atac’: ‘binarize’} Default is None, which uses the default transformation settings.

  • sampler_type – str, optional Type of sampler to use, by default ‘auto’. Set to ‘ddp’ to use a distributed sampler.

  • kwargs – dict Additional parameters passed to configure_data().

Returns:

cls

Returns the configured class instance.

Raises:

ValueError – If transform is not a dictionary.

Examples

>>> from scmidas.model import MIDAS
>>> from scmidas.config import load_config
>>> configs = load_config()
>>> dir_path = './data_processed/xxx'
>>> transform = {'atac': 'binarize'}
>>> model = MIDAS.configure_data_from_dir(configs, dir_path, transform)
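
The ‘binarize’ transform named in this example can be pictured as a simple thresholding step. The sketch below is an assumption about its behavior (entries above zero become 1), written on a plain list; the function and its threshold parameter are hypothetical and not part of the documented API.

```python
def binarize(values, threshold=0.0):
    """Map every entry above the threshold to 1 and everything else to 0."""
    return [1 if v > threshold else 0 for v in values]

peaks = binarize([0, 3, 0, 1, 7])  # -> [0, 1, 0, 1, 1]
```
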

configure_optimizers()

Configure optimizers for the network and discriminator.

Returns:

list

List of configured optimizers.

get_emb_umap(pred_dir: str, save_dir='.', save_fig=True, **kwargs)

Generate UMAP embeddings for biological (c) and technical (u) latent variables.

Parameters:
  • pred_dir – str Directory containing predicted data.

  • save_dir – str, optional Directory to save UMAP plots, by default ‘.’.

  • save_fig – bool, optional Whether to save the UMAP figures, by default True.

  • kwargs – dict Additional configurations for sc.pl.umap().

Returns:

tuple

List of AnnData objects and UMAP figures.

static get_info_from_dir(dir_path: str)

Extract data, mask, and feature dimensions from a directory.

Parameters:

dir_path – str Path to the directory containing data and mask files.

Returns:

  • data (list of dict): List of dictionaries where keys are modalities and values are file paths.

  • mask (list of dict): List of dictionaries where keys are modalities and values are mask file paths.

  • dims_x (dict): Dictionary containing feature dimensions for each modality.

Return type:

Tuple

Notes

The directory should be organized as:

dataset/
    feat/
        # Dimensions of each modality: {mod1=[...], mod2=[...]}.
        # Split the data into chunks if the length of the list is greater than 1.
        # For instance, you can split the ATAC data by chromosomes.
        feat_dims.toml
    batch_0/
        mask/mod1.csv
        mask/mod2.csv
        vec/mod1/0000.csv # the first sample
        vec/mod1/0001.csv # the second sample
        ....
        vec/mod2/0000.csv
        vec/mod2/0001.csv
        ....
    batch_1/
        mask/mod1.csv
        mask/mod2.csv
        vec/mod1/0000.csv
        vec/mod1/0001.csv
        ....
        vec/mod2/0000.csv
        vec/mod2/0001.csv
    ....
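
To make the layout above concrete, the following sketch walks such a directory with the standard library and collects per-batch paths keyed by modality. It is a hypothetical stand-in, not the implementation of get_info_from_dir(): the function name is invented, it does not read feat_dims.toml, and pointing each data entry at the modality's vec directory is an assumption.

```python
import tempfile
from pathlib import Path

def scan_dataset_dir(dir_path):
    """Collect per-batch vec directories and mask files, keyed by modality."""
    root = Path(dir_path)
    data, mask = [], []
    batch_dirs = sorted(
        (p for p in root.glob("batch_*") if p.is_dir()),
        key=lambda p: int(p.name.split("_")[1]),  # numeric order: batch_2 before batch_10
    )
    for batch_dir in batch_dirs:
        vec_dirs = sorted(p for p in (batch_dir / "vec").iterdir() if p.is_dir())
        data.append({p.name: str(p) for p in vec_dirs})
        mask.append({p.stem: str(p) for p in sorted((batch_dir / "mask").glob("*.csv"))})
    return data, mask

# Build a tiny placeholder tree with empty files, then scan it.
with tempfile.TemporaryDirectory() as tmp:
    for batch in ("batch_0", "batch_1"):
        for mod in ("mod1", "mod2"):
            vec_dir = Path(tmp) / batch / "vec" / mod
            vec_dir.mkdir(parents=True)
            (vec_dir / "0000.csv").touch()
            mask_dir = Path(tmp) / batch / "mask"
            mask_dir.mkdir(exist_ok=True)
            (mask_dir / f"{mod}.csv").touch()
    data, mask = scan_dataset_dir(tmp)
```

Each list has one entry per batch_* directory, in numeric batch order.
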

load_checkpoint(checkpoint_path: str)

Load model and optimizer states from a checkpoint file.

Parameters:

checkpoint_path – str Path to the checkpoint file containing saved model and optimizer states.

Raises:

AssertionError – If the provided checkpoint path does not exist.

log_losses(recon_loss: Tensor, kld_loss, consistency_loss: Tensor, loss_net: Tensor, loss_dsc: Tensor, recon_dict: Dict[str, Tensor])

Log losses for monitoring and debugging during training.

Parameters:
  • recon_loss – torch.Tensor Reconstruction loss.

  • kld_loss – torch.Tensor KLD loss.

  • consistency_loss – torch.Tensor Consistency loss.

  • recon_dict – dict Per-modality reconstruction losses.

  • loss_net – torch.Tensor Total VAE loss.

  • loss_dsc – torch.Tensor Discriminator loss.

on_train_end()

Save the final model checkpoint at the end of training.

on_train_epoch_end()

Save a model checkpoint at the end of each training epoch with a meaningful filename.

predict(pred_dir: str, joint_latent: bool = True, mod_latent: bool = False, impute: bool = False, batch_correct: bool = False, translate: bool = False, input: bool = False)

Predict and save results for multiple modes, including joint latent, imputation, batch correction, and translation.

Parameters:
  • pred_dir – str Directory for saving prediction results.

  • joint_latent – bool, optional Whether to calculate and save joint latent representations.

  • mod_latent – bool, optional Whether to calculate and save modality-specific latent representations.

  • impute – bool, optional Whether to perform data imputation.

  • batch_correct – bool, optional Whether to apply batch correction.

  • translate – bool, optional Whether to perform modality translation.

  • input – bool, optional Whether to save input data.

Notes

See labomics/midas#7.

static print_info(mask: List[Dict[str, str]], datalist: List[Dataset])

Print summary of mask density and dataset information.

Parameters:
  • mask – list of dict List of mask dictionaries, where keys are modalities and values are mask file paths.

  • datalist – list List of datasets.

save_checkpoint(checkpoint_path: str)

Save the current model and optimizer states to a checkpoint file.

Parameters:

checkpoint_path – str Path to save the checkpoint file.

Raises:

ValueError – If checkpoint_path is an invalid or empty string.

train_dataloader()

Create a DataLoader for training, using the appropriate sampler.

Returns:

DataLoader

Configured DataLoader instance for training.

train_discriminator(c_all: Dict[str, Tensor], targets: Dict[str, Tensor])

Train the discriminator with modality-specific latent representations.

Parameters:
  • c_all – dict Dictionary of latent representations for each modality.

  • targets – dict Ground truth batch labels for each modality.

training_step(batch: Dict[str, Dict[str, Tensor]], batch_idx: int) -> Tensor

Executes a single training step for MIDAS.

Parameters:
  • batch – dict Input batch containing modality data, batch indices, and masks.

  • batch_idx – int Index of the current training batch.

Returns:

torch.Tensor

Total VAE loss for the current batch.

static update_model(loss: Tensor, model: Module, optimizer: Optimizer, grad_clip=-1)

Update model parameters using backpropagation.

Parameters:
  • loss – torch.Tensor Computed loss for backpropagation.

  • model – torch.nn.Module Model to update.

  • optimizer – torch.optim.Optimizer Optimizer for parameter updates.

  • grad_clip – optional Controls gradient clipping, by default -1.

class scmidas.model.S_Decoder(n_batches: int, dims_dec_s: List[int], dim_u: int, norm: str, drop: float)

Bases: Module

Decoder for reconstructing batch ID.

Parameters:
  • n_batches – int Number of distinct batches.

  • dims_dec_s – List[int] List of dimensions for hidden layers in the decoder.

  • dim_u – int Latent dimension size for the input.

  • norm – str Normalization type (e.g., ‘ln’ for LayerNorm).

  • drop – float Dropout rate.

forward(data: Tensor) -> Tensor

Forward pass for S_Decoder.

Parameters:

data – torch.Tensor Latent input tensor of shape (batch_size, dim_u).

Returns:

torch.Tensor

Reconstructed tensor of shape (batch_size, n_batches).

class scmidas.model.S_Encoder(n_batches: int, dims_enc_s: List[int], dim_z: int, norm: str, drop: float)

Bases: Module

Encoder for batch-specific latent variables.

Parameters:
  • n_batches – int Number of distinct batches.

  • dims_enc_s – List[int] List of dimensions for hidden layers in the encoder.

  • dim_z – int Latent dimension size for the output.

  • norm – str Normalization type (e.g., ‘ln’ for LayerNorm).

  • drop – float Dropout rate.

forward(data: Tensor) -> Tensor

Forward pass for S_Encoder.

Parameters:

data – torch.Tensor Input tensor of shape (batch_size, 1), containing batch indices.

Returns:

torch.Tensor

Encoded tensor of shape (batch_size, dim_z * 2).

class scmidas.model.VAE(dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs)

Bases: Module

Variational Autoencoder (VAE) for multi-modal data, supporting batch correction and sampling.

Parameters:
  • dims_x – dict Input dimensions for each modality.

  • dims_s – dict Dimensions of the classes for each modality.

  • kwargs – dict Additional configurations for encoders, decoders, and other modules.

encode_batch(s: Tensor) -> Tuple[list, list]

Encode batch indices into latent variables.

Parameters:

s – torch.Tensor Batch indices.

Returns:

  • List[torch.Tensor]: Mean of batch indices latent variables.

  • List[torch.Tensor]: Log-variance of batch indices latent variables.

Return type:

Tuple

forward(data: Dict[str, Tensor]) -> tuple

Forward pass for the VAE.

Parameters:

data – dict Input data dictionary containing:

  • ‘x’: Dict[str, torch.Tensor], modality-specific input data.

  • ‘e’: Dict[str, torch.Tensor], modality-specific masks.

  • ‘s’ (optional): torch.Tensor, batch indices.

Returns:

  • x_r_pre (dict): Reconstructed modality-specific data.

  • s_r_pre (torch.Tensor or None): Reconstructed batch indices.

  • z_mu (torch.Tensor): Mean of the combined latent variables.

  • z_logvar (torch.Tensor): Log-variance of the combined latent variables.

  • z (torch.Tensor): Sampled latent variables.

  • c (torch.Tensor): Biological information variables.

  • u (torch.Tensor): Technical noise variables.

  • z_uni (dict): Unified latent variables for each modality.

  • c_all (dict): Modality-specific biological information variables.

Return type:

Tuple

gen_real_data(x_r_pre: Dict[str, Tensor], sampling: bool = True) -> Dict[str, Tensor]

Generate real data from reconstructed data.

Parameters:
  • x_r_pre – dict Dictionary of reconstructed data tensors for each modality.

  • sampling – bool, optional Whether to sample the output (default: True).

Returns:

dict

Generated real data for each modality.

generate_unified_latent(z_x_mu: Dict[str, Tensor], z_x_logvar: Dict[str, Tensor], z_s_mu: List[Tensor], z_s_logvar: List[Tensor], c: Tensor) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]

Generate unified latent variables and modality-specific representations.

Parameters:
  • z_x_mu – dict Mean of modality-specific latent variables.

  • z_x_logvar – dict Log-variance of modality-specific latent variables.

  • z_s_mu – list Mean of modality-specific batch indices latent variables.

  • z_s_logvar – list Log-variance of modality-specific batch indices latent variables.

  • c – torch.Tensor Biological information.

Returns:

Tuple
  • Unified latent variables (z_uni) for each modality.

  • Modality-specific shared representations (c_all).

get_dim_h() -> Dict[str, List[int]]

Compute hidden dimensions for each modality.

Returns:

dict

A dictionary containing the hidden dimensions for each modality.

static poe(mus: List[Tensor], logvars: List[Tensor]) -> Tuple[Tensor, Tensor]

Product of Experts (PoE) for combining Gaussian distributions.

Parameters:
  • mus – list of torch.Tensor List of mean tensors for each Gaussian.

  • logvars – list of torch.Tensor List of log-variance tensors for each Gaussian.

Returns:

Tuple
  • Mean of the combined Gaussian distribution.

  • Log-variance of the combined Gaussian distribution.
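
The textbook rule for a product of Gaussian experts is that precisions add and the mean is precision-weighted. The scalar sketch below shows that arithmetic for a single latent dimension. It is not the library's tensor implementation, and whether the library adds a prior expert or a numerical-stability term is not stated in the source; the function name is hypothetical.

```python
import math

def poe_scalar(mus, logvars):
    """Combine independent Gaussian experts over one variable.

    Precision of the product = sum of expert precisions;
    mean of the product = precision-weighted average of expert means.
    """
    precisions = [math.exp(-lv) for lv in logvars]
    total_precision = sum(precisions)
    mu = sum(m * p for m, p in zip(mus, precisions)) / total_precision
    logvar = -math.log(total_precision)
    return mu, logvar

# Two unit-variance experts centered at 0 and 2 combine to mean 1, variance 1/2.
mu, logvar = poe_scalar([0.0, 2.0], [0.0, 0.0])
```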

static sample(name: str, data: Tensor, sampling: bool) -> Tensor

Apply the sampling function that matches the distribution name.

Parameters:
  • name – str Name of the distribution.

  • data – torch.Tensor Input data tensor.

  • sampling – bool Whether to apply sampling.

Returns:

torch.Tensor

Sampled or original data tensor.

static sample_gaussian(mu: Tensor, logvar: Tensor) -> Tensor

Sample from a Gaussian distribution using the reparameterization trick.

Parameters:
  • mu – torch.Tensor Mean of the Gaussian distribution.

  • logvar – torch.Tensor Log-variance of the Gaussian distribution.

Returns:

torch.Tensor

Sampled tensor.
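
The reparameterization trick writes a Gaussian sample as z = mu + sigma * eps, with eps drawn from a standard normal and sigma = exp(0.5 * logvar). A minimal standard-library sketch on plain lists follows; the function name and the rng argument are hypothetical, and the library version operates on tensors.

```python
import math
import random

def sample_gaussian_sketch(mu, logvar, rng=random):
    """Draw z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, 1), per dimension."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, logvar)]

rng = random.Random(0)  # seeded generator so the draw is repeatable
z = sample_gaussian_sketch([0.0, 10.0], [0.0, -40.0], rng)
# The second dimension has a tiny variance, so its sample stays close to 10.
```

Keeping the randomness in eps is what lets gradients flow through mu and logvar when the same expression is written with tensors.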

sample_latent(z_mu: Tensor, z_logvar: Tensor) -> Tensor

Sample latent variables from a Gaussian distribution.

Parameters:
  • z_mu – torch.Tensor Mean of the latent variables of shape (batch_size, latent_dim).

  • z_logvar – torch.Tensor Log-variance of the latent variables of shape (batch_size, latent_dim).

Returns:

Sampled latent variables of shape (batch_size, latent_dim).

Return type:

torch.Tensor