scmidas.model
- class scmidas.model.Decoder(dims_x: Dict[str, list], dims_h: Dict[str, list], dim_z: int, norm: str, out_trans: str, drop: float, **kwargs)
Bases: Module
Decoder class for multi-modal data with shared and modality-specific decoding layers.
- Parameters:
dims_x – Dict[str, list] Output dimensions for each modality.
dims_h – Dict[str, list] Hidden dimensions for each modality.
dim_z – int Latent dimension size.
norm – str Normalization type (e.g., ‘ln’ for LayerNorm).
out_trans – str Output activation function (e.g., ‘relu’).
drop – float Dropout rate.
kwargs – Dict Additional modality-specific configurations.
- class scmidas.model.Discriminator(dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs)
Bases: Module
Discriminator class for multi-modal latent variables.
- Parameters:
dims_x – dict Input dimensions for each modality.
dims_s – dict Number of classes (batches) for each modality.
kwargs – dict Additional configurations, such as hidden layer sizes, dropout rate, and normalization type.
- calculate_loss(predictions: Dict[str, Tensor], targets: Dict[str, Tensor]) → Tensor
Calculate cross-entropy loss for all modalities.
- Parameters:
predictions – dict Dictionary of predicted logits for each modality.
targets – dict Dictionary of ground truth labels for each modality.
- Returns:
- torch.Tensor
Total normalized loss.
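Below is a minimal sketch of a summed cross-entropy over modalities. It assumes `predictions[m]` holds logits of shape (batch_size, dims_s[m]) and `targets[m]` holds integer labels of shape (batch_size,). The helper name `total_ce_loss` is hypothetical, and normalizing by the total number of samples is an assumption. This is not the library's implementation.

```python
import torch
import torch.nn.functional as F
from typing import Dict


def total_ce_loss(predictions: Dict[str, torch.Tensor],
                  targets: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Hypothetical helper: cross-entropy summed over modalities, divided by sample count."""
    total = torch.zeros(())
    n = 0
    for mod, logits in predictions.items():
        # reduction="sum" so every sample contributes equally, whatever its modality
        total = total + F.cross_entropy(logits, targets[mod], reduction="sum")
        n += logits.shape[0]
    return total / max(n, 1)


logits = {"rna": torch.tensor([[10.0, -10.0], [-10.0, 10.0]])}
labels = {"rna": torch.tensor([0, 1])}
print(total_ce_loss(logits, labels))  # close to 0: both predictions are confident and correct
```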
- forward(latent_inputs: Dict[str, Tensor]) → Dict[str, Tensor]
Forward pass for the discriminator.
- Parameters:
latent_inputs – dict Dictionary of latent inputs for each modality, where keys are modality names and values are tensors of shape (batch_size, dim_c).
- Returns:
- dict
Dictionary of logits for each modality, where keys are modality names and values are tensors of shape (batch_size, dims_s[modality]).
- class scmidas.model.Encoder(dims_x: Dict[str, list], dims_h: Dict[str, list], dim_z: int, norm: str, out_trans: str, drop: float, **kwargs)
Bases: Module
Encoder class for multi-modal data with modality-specific pre-processing, encoding, and shared encoding layers.
- Parameters:
dims_x – Dict[str, list] Input dimensions for each modality.
dims_h – Dict[str, list] Hidden dimensions for each modality after pre-encoding.
dim_z – int Latent dimension size.
norm – str Normalization type (e.g., ‘ln’ for LayerNorm).
out_trans – str Output activation function (e.g., ‘mish’).
drop – float Dropout rate.
kwargs – dict Additional modality-specific configurations.
Notes
By default, RNA and ADT data are log1p-transformed and will be exponentiated after decoding. To skip this step, modify the configuration file. See parameter trsf_before_enc_.
- forward(data: Dict[str, Tensor], mask: Dict[str, Tensor])
Forward pass for the encoder.
- Parameters:
data – dict Input data for each modality.
mask – dict Masks for each modality.
- Returns:
- z_x_mu (dict) – Mean values of the latent space for each modality.
- z_x_logvar (dict) – Log-variance values of the latent space for each modality.
- class scmidas.model.MIDAS
Bases: LightningModule
MIDAS processes mosaic single-cell data into imputed and batch-corrected data for multimodal analysis.
- net (VAE)
Variational Autoencoder for multi-modal data encoding and decoding.
- dsc (Discriminator)
Discriminator for distinguishing latent variables across batches.
- configs (dict)
Model and training configurations, dynamically set as attributes.
- automatic_optimization (bool)
Controls whether optimization is automatic or manually defined. Always True.
- static calc_consistency_loss(z_uni: dict)
Calculate the consistency loss for unified latent variables across modalities.
- Parameters:
z_uni – dict Dictionary of unified latent variables for each modality, where each value is a tensor of shape (batch_size x latent_dim).
- Returns:
- consistency_loss (float) – Consistency loss, computed as the variance of the unified latent variables across modalities.
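Below is a minimal sketch of "variance across modalities." The reduction is an assumption: the population variance is taken over the modality axis, summed over latent dimensions, and averaged over cells. The name `consistency_loss` is used here only for illustration.

```python
import torch
from typing import Dict


def consistency_loss(z_uni: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Hypothetical sketch: variance of unified latents across modalities."""
    # Stack to shape (n_modalities, batch_size, latent_dim)
    z = torch.stack(list(z_uni.values()), dim=0)
    # Population variance over the modality axis -> (batch_size, latent_dim)
    var = ((z - z.mean(dim=0, keepdim=True)) ** 2).mean(dim=0)
    # Sum over latent dimensions, then average over cells (assumed reduction)
    return var.sum(dim=1).mean()
```

When every modality yields the same unified latent for a cell, the loss is zero. Any disagreement between modalities increases it.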
- static calc_dsc_loss(pred: dict, true: dict)
Calculate the discriminator loss using cross-entropy.
- Parameters:
pred – dict Predicted logits for each modality.
true – dict Ground truth labels for each modality.
- Returns:
- float
Computed discriminator loss.
- static calc_kld_loss(mu: Tensor, logvar: Tensor)
Calculate the KLD loss for a single latent space.
- Parameters:
mu – torch.Tensor Mean of the latent variable distribution (batch_size x latent_dim).
logvar – torch.Tensor Log-variance of the latent variable distribution (batch_size x latent_dim).
- Returns:
- kld_loss (float) – KLD loss for the latent space, normalized by batch size.
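The following sketch shows the closed-form KL divergence between a diagonal Gaussian N(mu, exp(logvar)) and a standard normal prior, summed over all elements and divided by the batch size. The standard normal prior is an assumption, and `kld_loss` is an illustrative name.

```python
import torch


def kld_loss(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over dims and divided by batch size."""
    # Per-element closed form: -0.5 * (1 + log sigma^2 - mu^2 - sigma^2)
    kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    return kld.sum() / mu.shape[0]
```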
- static calc_kld_z_loss(dim_c: int, dim_u: int, lam_kld_c: float, lam_kld_u: float, mu: Tensor, logvar: Tensor)
Calculate the Kullback-Leibler Divergence (KLD) loss for latent variables z.
- Parameters:
dim_c – int Dimension of the biological latent space.
dim_u – int Dimension of the technical latent space.
lam_kld_c – float Weight for KLD loss of the biological latent space.
lam_kld_u – float Weight for KLD loss of the technical latent space.
mu – torch.Tensor Mean of the latent variable distribution (batch_size x (dim_c + dim_u)).
logvar – torch.Tensor Log-variance of the latent variable distribution (batch_size x (dim_c + dim_u)).
- Returns:
- kld_z_loss (float) – Weighted sum of the KLD losses for the biological and technical latent spaces.
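Below is a minimal sketch of the weighted split. It assumes the first `dim_c` columns of `mu`/`logvar` hold the biological part c and the remaining `dim_u` columns hold the technical part u. That column order is an assumption, not something this page states.

```python
import torch


def _kld(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # KL(N(mu, exp(logvar)) || N(0, I)), summed and divided by batch size
    return (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp())).sum() / mu.shape[0]


def kld_z_loss(dim_c: int, dim_u: int, lam_kld_c: float, lam_kld_u: float,
               mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # Assumed layout: columns [0, dim_c) are c, columns [dim_c, dim_c + dim_u) are u
    mu_c, mu_u = mu.split([dim_c, dim_u], dim=1)
    lv_c, lv_u = logvar.split([dim_c, dim_u], dim=1)
    return lam_kld_c * _kld(mu_c, lv_c) + lam_kld_u * _kld(mu_u, lv_u)
```

Separate weights let the technical part u be regularized more or less strongly than the biological part c.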
- static calc_recon_loss(x: dict, s: Tensor, e: dict, x_r_pre: dict, s_r_pre: dict, dist: dict, lam: dict)
Calculate the reconstruction loss for input data and predicted outputs.
- Parameters:
x – dict Original input data for each modality (x^m).
s – torch.Tensor Ground truth batch labels.
e – dict Masks for each modality.
x_r_pre – dict Reconstructed predictions for each modality (x_r^m).
s_r_pre – dict Reconstructed predictions for batch labels.
dist – dict Dictionary specifying the distribution type for each modality’s decoder.
lam – dict Dictionary containing reconstruction loss weights for each modality and for s.
- Returns:
- total_loss (float) – Total reconstruction loss, normalized by batch size.
- losses (dict) – Reconstruction losses for each modality and for the batch labels.
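The sketch below shows a masked, weighted reconstruction loss. Several parts are assumptions, not the library's code:
- The distribution names "POISSON" and "BERNOULLI" are hypothetical.
- Predictions are taken as log-rates and logits respectively.
- Any other name falls back to squared error.
- `s_r_pre` is treated as a logits tensor (the page describes it both as a dict and as a tensor).

```python
import torch
import torch.nn.functional as F
from typing import Dict, Optional, Tuple


def recon_loss(x: Dict[str, torch.Tensor], s: torch.Tensor, e: Dict[str, torch.Tensor],
               x_r_pre: Dict[str, torch.Tensor], s_r_pre: Optional[torch.Tensor],
               dist: Dict[str, str], lam: Dict[str, float]
               ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    losses = {}
    for mod, target in x.items():
        pred = x_r_pre[mod]
        if dist[mod] == "POISSON":      # hypothetical name; pred is a log-rate
            elem = F.poisson_nll_loss(pred, target, log_input=True, reduction="none")
        elif dist[mod] == "BERNOULLI":  # hypothetical name; pred is a logit
            elem = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
        else:                           # fallback: squared error
            elem = F.mse_loss(pred, target, reduction="none")
        # The mask zeroes out features that are missing in this batch
        losses[mod] = lam[mod] * (elem * e[mod]).sum()
    if s_r_pre is not None:
        # Batch-label reconstruction as cross-entropy on logits
        losses["s"] = lam["s"] * F.cross_entropy(s_r_pre, s, reduction="sum")
    total = sum(losses.values()) / s.shape[0]  # normalize by batch size
    return total, losses
```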
- classmethod configure_data(configs: dict, datalist: List[Dataset], dims_x: Dict[str, list], dims_s: Dict[str, int], s_joint: List[Dict[str, int]], combs: List[List[str]], batch_size: int = 256, n_save: int = 500, save_model_path: str = './saved_models/', sampler_type: str = 'auto')
Configure the data and model parameters for training.
- Parameters:
configs – dict Configurations of the model.
datalist – List[Dataset] List of datasets to be used for training.
dims_x – Dict[str, list] Dictionary specifying the dimensions of input features for each modality.
dims_s – Dict[str, int] Number of classes (batches) for each modality.
s_joint – List[Dict[str, int]] Modality indices for each batch.
combs – List[List[str]] Combinations of modalities.
batch_size – int, optional Size of each training batch, by default 256.
n_save – int, optional Interval (in epochs) for saving model checkpoints, by default 500.
save_model_path – str, optional Directory path for saving model checkpoints, by default ‘./saved_models/’.
sampler_type – str, optional Type of sampler to use, by default ‘auto’. If ‘ddp’, a distributed sampler is used.
- Returns:
- cls
Returns the configured class instance.
- static configure_data_from_csv(data: dict, mask: dict, transform: dict = None)
Configure data from a CSV input.
- Parameters:
data – list of dict List of data dictionaries, where keys are modalities and values are file paths.
mask – list of dict List of mask dictionaries, where keys are modalities and values are mask file paths.
transform – dict, optional Transformations to apply to specific modalities; by default, ‘atac’, ‘met’, and ‘acc’ are binarized.
- Returns:
- datasets (list) – List of initialized MultiModalDataset objects.
- dims_s (dict) – Dimensions for batch correction for each modality.
- s_joint (list) – Modality indices for each batch.
- combs (list) – List of modality combinations for each batch.
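The default binarization for ‘atac’, ‘met’, and ‘acc’ can be illustrated with a one-line sketch. It assumes that "binarize" maps every positive value to 1 and leaves zeros at 0. The helper name `binarize` is hypothetical.

```python
import torch


def binarize(x: torch.Tensor) -> torch.Tensor:
    # Any positive value becomes 1; zeros stay 0; the input dtype is preserved
    return (x > 0).to(x.dtype)
```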
- classmethod configure_data_from_dir(configs: dict, dir_path: str, transform: dict = None, sampler_type: str = 'auto', **kwargs: dict)
Configure data from a directory and apply optional transformations.
- Parameters:
configs – dict Configurations of the model.
dir_path – str Path to the directory containing data files.
transform – dict, optional A dictionary specifying transformations to apply to specific modalities, e.g. {‘atac’: ‘binarize’}. Default is None, which uses the default transformation settings.
sampler_type – str, optional Type of sampler to use, by default ‘auto’. If ‘ddp’, a distributed sampler is used.
kwargs – dict Additional parameters passed to configure_data().
- Returns:
- cls
Returns the configured class instance.
- Raises:
ValueError – If transform is not a dictionary.
Examples
>>> from scmidas.model import MIDAS
>>> from scmidas.config import load_config
>>> configs = load_config()
>>> dir_path = './data_processed/xxx'
>>> transform = {'atac': 'binarize'}
>>> model = MIDAS.configure_data_from_dir(configs, dir_path, transform)
- configure_optimizers()
Configure optimizers for the network and discriminator.
- Returns:
- list
List of configured optimizers.
- get_emb_umap(pred_dir: str, save_dir='.', save_fig=True, **kwargs)
Generate UMAP embeddings for biological (c) and technical (u) latent variables.
- Parameters:
pred_dir – str Directory containing predicted data.
save_dir – str, optional Directory to save UMAP plots, by default ‘.’.
save_fig – bool, optional Whether to save the UMAP figures, by default True.
kwargs – dict Additional configurations for sc.pl.umap().
- Returns:
- tuple
List of AnnData objects and UMAP figures.
- static get_info_from_dir(dir_path: str)
Extract data, mask, and feature dimensions from a directory.
- Parameters:
dir_path – str Path to the directory containing data and mask files.
- Returns:
- data (list of dict) – List of dictionaries where keys are modalities and values are file paths.
- mask (list of dict) – List of dictionaries where keys are modalities and values are mask file paths.
- dims_x (dict) – Dictionary containing feature dimensions for each modality.
Notes
The directory should be organized as:
    dataset/
        feat/
            # Dimensions of each modality: {mod1=[...], mod2=[...]}.
            # Split the data into chunks if the length of the list is greater than 1.
            # For instance, you can split the ATAC data by chromosomes.
            feat_dims.toml
        batch_0/
            mask/mod1.csv
            mask/mod2.csv
            vec/mod1/0000.csv  # the first sample
            vec/mod1/0001.csv  # the second sample
            ....
            vec/mod2/0000.csv
            vec/mod2/0001.csv
            ....
        batch_1/
            mask/mod1.csv
            mask/mod2.csv
            vec/mod1/0000.csv
            vec/mod1/0001.csv
            ....
            vec/mod2/0000.csv
            vec/mod2/0001.csv
            ....
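The sketch below shows how a directory following this layout could be walked to collect per-batch paths. It uses only `pathlib`, and the helper `collect_paths` is hypothetical. It assumes batch folders are named `batch_<int>` and sorts them numerically. It records each modality's `vec/<mod>` directory rather than individual sample files, and it does not read `feat_dims.toml`.

```python
from pathlib import Path
from typing import Dict, List, Tuple


def collect_paths(dir_path: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
    """Hypothetical helper: per-batch {modality: vec dir} and {modality: mask csv}."""
    batch_dirs = sorted(
        (p for p in Path(dir_path).glob("batch_*") if p.is_dir()),
        key=lambda p: int(p.name.split("_")[1]),  # batch_2 before batch_10
    )
    data, mask = [], []
    for batch_dir in batch_dirs:
        # vec/<mod>/ holds one CSV per sample
        data.append({p.name: str(p) for p in sorted((batch_dir / "vec").iterdir()) if p.is_dir()})
        # mask/<mod>.csv holds the feature mask for that modality
        mask.append({p.stem: str(p) for p in sorted((batch_dir / "mask").glob("*.csv"))})
    return data, mask
```

Because each batch is its own folder, a modality missing from a batch simply has no entry in that batch's dictionary. This is how mosaic inputs stay representable.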
- load_checkpoint(checkpoint_path: str)
Load model and optimizer states from a checkpoint file.
- Parameters:
checkpoint_path – str Path to the checkpoint file containing saved model and optimizer states.
- Raises:
AssertionError – If the provided checkpoint path does not exist.
- log_losses(recon_loss: Tensor, kld_loss, consistency_loss: Tensor, loss_net: Tensor, loss_dsc: Tensor, recon_dict: Dict[str, Tensor])
Log losses for monitoring and debugging during training.
- Parameters:
recon_loss – torch.Tensor Reconstruction loss.
kld_loss – torch.Tensor KLD loss.
consistency_loss – torch.Tensor Consistency loss.
loss_net – torch.Tensor Total VAE loss.
loss_dsc – torch.Tensor Discriminator loss.
recon_dict – dict Per-modality reconstruction losses.
- on_train_epoch_end()
Save a model checkpoint at the end of each training epoch with a meaningful filename.
- predict(pred_dir: str, joint_latent: bool = True, mod_latent: bool = False, impute: bool = False, batch_correct: bool = False, translate: bool = False, input: bool = False)
Predict and save results for multiple modes, including joint latent, imputation, batch correction, and translation.
- Parameters:
pred_dir – str Directory for saving prediction results.
joint_latent – bool, optional Whether to calculate and save joint latent representations, by default True.
mod_latent – bool, optional Whether to calculate and save modality-specific latent representations, by default False.
impute – bool, optional Whether to perform data imputation, by default False.
batch_correct – bool, optional Whether to apply batch correction, by default False.
translate – bool, optional Whether to perform modality translation, by default False.
input – bool, optional Whether to save the input data, by default False.
Notes
See labomics/midas#7.
- static print_info(mask: List[Dict[str, str]], datalist: List[Dataset])
Print a summary of mask density and dataset information.
- Parameters:
mask – list of dict List of masks.
datalist – list List of datasets.
- save_checkpoint(checkpoint_path: str)
Save the current model and optimizer states to a checkpoint file.
- Parameters:
checkpoint_path – str Path to save the checkpoint file.
- Raises:
ValueError – If checkpoint_path is an invalid or empty string.
- train_dataloader()
Create a DataLoader for training, using the appropriate sampler.
- Returns:
- DataLoader
Configured DataLoader instance for training.
- train_discriminator(c_all: Dict[str, Tensor], targets: Dict[str, Tensor])
Train the discriminator with modality-specific latent representations.
- Parameters:
c_all – dict Dictionary of latent representations for each modality.
targets – dict Ground truth batch labels for each modality.
- training_step(batch: Dict[str, Dict[str, Tensor]], batch_idx: int) → Tensor
Executes a single training step for MIDAS.
- Parameters:
batch – dict Input batch containing modality data, batch indices, and masks.
batch_idx – int Index of the current training batch.
- Returns:
- torch.Tensor
Total VAE loss for the current batch.
- static update_model(loss: Tensor, model: Module, optimizer: Optimizer, grad_clip=-1)
Update model parameters using backpropagation.
- Parameters:
loss – torch.Tensor Computed loss for backpropagation.
model – torch.nn.Module Model to update.
optimizer – torch.optim.Optimizer Optimizer for parameter updates.
grad_clip – bool True to enable gradient clipping.
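Below is a minimal sketch of one update step. `grad_clip` is documented as bool, but its default is -1, so this sketch assumes a numeric maximum gradient norm where a non-positive value disables clipping. That reading is an assumption. Only standard PyTorch calls are used.

```python
import torch
from torch import nn


def update_model(loss: torch.Tensor, model: nn.Module,
                 optimizer: torch.optim.Optimizer, grad_clip: float = -1) -> None:
    optimizer.zero_grad()  # clear gradients from the previous step
    loss.backward()        # backpropagate
    # Assumption: grad_clip > 0 is a max L2 norm; non-positive values skip clipping
    if grad_clip > 0:
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()       # apply the update
```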
- class scmidas.model.S_Decoder(n_batches: int, dims_dec_s: List[int], dim_u: int, norm: str, drop: float)
Bases: Module
Decoder for reconstructing batch ID.
- Parameters:
n_batches – int Number of distinct batches.
dims_dec_s – List[int] List of dimensions for hidden layers in the decoder.
dim_u – int Latent dimension size for the input.
norm – str Normalization type (e.g., ‘ln’ for LayerNorm).
drop – float Dropout rate.
- class scmidas.model.S_Encoder(n_batches: int, dims_enc_s: List[int], dim_z: int, norm: str, drop: float)
Bases: Module
Encoder for batch-specific latent variables.
- Parameters:
n_batches – int Number of distinct batches.
dims_enc_s – List[int] List of dimensions for hidden layers in the encoder.
dim_z – int Latent dimension size for the output.
norm – str Normalization type (e.g., ‘ln’ for LayerNorm).
drop – float Dropout rate.
- class scmidas.model.VAE(dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs)
Bases: Module
Variational Autoencoder (VAE) for multi-modal data, supporting batch correction and sampling.
- Parameters:
dims_x – dict Input dimensions for each modality.
dims_s – dict Number of classes (batches) for each modality.
kwargs – dict Additional configurations for encoders, decoders, and other modules.
- encode_batch(s: Tensor) → Tuple[list, list]
Encode batch indices into latent variables.
- Parameters:
s – torch.Tensor Batch indices.
- Returns:
List[torch.Tensor]: Mean of the batch latent variables.
List[torch.Tensor]: Log-variance of the batch latent variables.
- Return type:
Tuple
- forward(data: Dict[str, Tensor]) → tuple
Forward pass for the VAE.
- Parameters:
data – dict Input data dictionary containing:
  - ‘x’: Dict[str, torch.Tensor], modality-specific input data.
  - ‘e’: Dict[str, torch.Tensor], modality-specific masks.
  - ‘s’ (optional): torch.Tensor, batch indices.
- Returns:
x_r_pre (dict): Reconstructed modality-specific data.
s_r_pre (torch.Tensor or None): Reconstructed batch indices.
z_mu (torch.Tensor): Mean of the combined latent variables.
z_logvar (torch.Tensor): Log-variance of the combined latent variables.
z (torch.Tensor): Sampled latent variables.
c (torch.Tensor): Biological information variables.
u (torch.Tensor): Technical noise variables.
z_uni (dict): Unified latent variables for each modality.
c_all (dict): Modality-specific biological information variables.
- Return type:
Tuple
- gen_real_data(x_r_pre: Dict[str, Tensor], sampling: bool = True) → Dict[str, Tensor]
Generate real data from reconstructed data.
- Parameters:
x_r_pre – dict Dictionary of reconstructed data tensors for each modality.
sampling – bool, optional Whether to sample the output (default: True).
- Returns:
- dict
Generated real data for each modality.
- generate_unified_latent(z_x_mu: Dict[str, Tensor], z_x_logvar: Dict[str, Tensor], z_s_mu: List[Tensor], z_s_logvar: List[Tensor], c: Tensor) → Tuple[Dict[str, Tensor], Dict[str, Tensor]]
Generate unified latent variables and modality-specific representations.
- Parameters:
z_x_mu – dict Mean of modality-specific latent variables.
z_x_logvar – dict Log-variance of modality-specific latent variables.
z_s_mu – list Mean of the modality-specific latent variables encoding the batch indices.
z_s_logvar – list Log-variance of the modality-specific latent variables encoding the batch indices.
c – torch.Tensor Biological information.
- Returns:
- Tuple
Unified latent variables (z_uni) for each modality.
Modality-specific shared representations (c_all).
- get_dim_h() → Dict[str, List[int]]
Compute hidden dimensions for each modality.
- Returns:
- dict
A dictionary containing the hidden dimensions for each modality.
- static poe(mus: List[Tensor], logvars: List[Tensor]) → Tuple[Tensor, Tensor]
Product of Experts (PoE) for combining Gaussian distributions.
- Parameters:
mus – list of torch.Tensor List of mean tensors for each Gaussian.
logvars – list of torch.Tensor List of log-variance tensors for each Gaussian.
- Returns:
- Tuple
Mean of the combined Gaussian distribution.
Log-variance of the combined Gaussian distribution.
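For Gaussian experts, the product is again Gaussian. Its precision is the sum of the experts' precisions, and its mean is the precision-weighted average of their means. The sketch below implements that rule. It assumes no extra prior expert, such as N(0, I), is added to the product; whether scmidas includes one is not stated on this page.

```python
import torch
from typing import List, Tuple


def poe(mus: List[torch.Tensor], logvars: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    mu = torch.stack(mus, dim=0)            # (n_experts, batch, dim)
    logvar = torch.stack(logvars, dim=0)
    precision = torch.exp(-logvar)          # 1 / sigma^2 for each expert
    var = 1.0 / precision.sum(dim=0)        # combined variance
    mu_poe = var * (mu * precision).sum(dim=0)  # precision-weighted mean
    return mu_poe, torch.log(var)
```

Experts with small variance dominate the combined mean. This is what lets PoE merge whatever modalities happen to be present.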
- static sample(name: str, data: Tensor, sampling: bool) → Tensor
Select and apply a sampling function based on the distribution name.
- Parameters:
name – str Name of the distribution.
data – torch.Tensor Input data tensor.
sampling – bool Whether to apply sampling.
- Returns:
- torch.Tensor
Sampled or original data tensor.
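Below is a sketch of name-based dispatch, under stated assumptions. The distribution names "POISSON" and "BERNOULLI" are hypothetical. `data` is taken to hold Poisson rates or Bernoulli probabilities respectively, and unknown names return the data unchanged.

```python
import torch


def sample(name: str, data: torch.Tensor, sampling: bool) -> torch.Tensor:
    if not sampling:
        return data                  # no sampling: return the input as-is
    if name == "POISSON":            # hypothetical name; data are rates
        return torch.poisson(data)
    if name == "BERNOULLI":          # hypothetical name; data are probabilities
        return torch.bernoulli(data)
    return data                      # unknown distributions pass through
```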
- static sample_gaussian(mu: Tensor, logvar: Tensor) → Tensor
Sample from a Gaussian distribution using the reparameterization trick.
- Parameters:
mu – torch.Tensor Mean of the Gaussian distribution.
logvar – torch.Tensor Log-variance of the Gaussian distribution.
- Returns:
- torch.Tensor
Sampled tensor.
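The reparameterization trick writes a Gaussian sample as z = mu + sigma * eps with eps ~ N(0, I), so gradients flow through `mu` and `logvar`. The following is a minimal sketch of that standard technique. The name mirrors the method, but the code is illustrative.

```python
import torch


def sample_gaussian(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    std = torch.exp(0.5 * logvar)  # sigma = exp(logvar / 2)
    eps = torch.randn_like(std)    # noise is drawn independently of the parameters
    return mu + eps * std          # differentiable w.r.t. mu and logvar
```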
- sample_latent(z_mu: Tensor, z_logvar: Tensor) → Tensor
Sample latent variables from a Gaussian distribution.
- Parameters:
z_mu – torch.Tensor Mean of the latent variables of shape (batch_size, latent_dim).
z_logvar – torch.Tensor Log-variance of the latent variables of shape (batch_size, latent_dim).
- Returns:
Sampled latent variables of shape (batch_size, latent_dim).
- Return type:
torch.Tensor