Mosaic Integration of RNA+ADT

In this tutorial, we demonstrate how to use MIDAS to integrate a mosaic dataset containing RNA and ADT (antibody-derived tag) data. We perform data integration and imputation, then evaluate the imputed counts by computing Pearson's correlation coefficient (r) between the imputed counts and the ground-truth counts.

Step 1: Downloading the Demo Data

[ ]:
from scmidas.data import download_data
download_data('wnn_mosaic_3batch', './dataset')
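Later cells read files from ./dataset/wnn_mosaic_3batch/data and ./dataset/wnn_mosaic_3batch/label/<batch>.csv. As a quick check that the download completed, the sketch below lists which of those paths are missing. It only assumes the paths used elsewhere in this tutorial; missing_paths is a hypothetical helper, not part of scmidas.

```python
import os


def missing_paths(root, task, batches):
    """Return the dataset paths used later in this tutorial that do not exist (hypothetical helper)."""
    expected = [os.path.join(root, task, 'data')]
    expected += [os.path.join(root, task, 'label', f'{b}.csv') for b in batches]
    return [p for p in expected if not os.path.exists(p)]


# An empty list means every expected path is present.
print(missing_paths('./dataset', 'wnn_mosaic_3batch', ['p1_0', 'p5_0', 'p8_0']))
```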

Step 2: Setting Up the Environment

Next, import the required packages. The CUDA_VISIBLE_DEVICES setting selects which GPU is used (GPU 1 here); change the index to match your machine.

[ ]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='1'

from scmidas.config import load_config
from scmidas.model import MIDAS
from scmidas.utils import load_predicted
import lightning as L

import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
from scipy.stats import pearsonr

sc.set_figure_params(figsize=(4, 4))
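CUDA_VISIBLE_DEVICES is read when CUDA is first initialized, so it has to be set before any GPU work starts; setting it before importing torch or lightning, as above, is the safe choice. Inside the process, the listed GPU is renumbered as device 0. If you prefer not to override a value already exported in your shell, a minimal variant is sketched below; pin_gpu is a hypothetical helper, not part of scmidas.

```python
import os


def pin_gpu(default_index):
    """Use `default_index` unless CUDA_VISIBLE_DEVICES is already set; return the effective value."""
    # setdefault leaves an existing value untouched and only fills in a missing one
    return os.environ.setdefault('CUDA_VISIBLE_DEVICES', str(default_index))


# Call this before importing torch or lightning.
pin_gpu(1)
```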

Step 3: Configuring the Model

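We load the default MIDAS configuration and build the model from the dataset directory. The printed summary lists, for each batch, the number of cells and the number of features in each modality; NaN marks a modality that is missing from that batch.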

[ ]:
configs = load_config() # load basic configurations
[ ]:
task = 'wnn_mosaic_3batch'
model = MIDAS.configure_data_from_dir(configs, './dataset/'+task+'/data')
INFO:root:Input data:
         #CELL    #RNA   #ADT  #VALID_RNA  #VALID_ADT
BATCH 0   6378  3617.0    NaN      3617.0         NaN
BATCH 1   6952     NaN  224.0         NaN       224.0
BATCH 2   8908  3617.0  224.0      3617.0       224.0
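This summary is what makes the dataset a mosaic: batch 0 (p1_0) has only RNA, batch 1 (p5_0) has only ADT, and batch 2 (p8_0) has both. The sketch below re-creates the printed counts by hand (copied from the output above, not read from scmidas) and derives a modality-availability mask, which shows at a glance which modality has to be imputed for each batch.

```python
import numpy as np
import pandas as pd

# Counts copied from the summary printed above; NaN = modality absent in that batch.
summary = pd.DataFrame(
    {'#CELL': [6378, 6952, 8908],
     '#RNA': [3617, np.nan, 3617],
     '#ADT': [np.nan, 224, 224]},
    index=['BATCH 0', 'BATCH 1', 'BATCH 2'],
)

# True where a batch was measured with a given modality.
available = summary[['#RNA', '#ADT']].notna()
available.columns = ['rna', 'adt']

for batch, row in available.iterrows():
    missing = [m for m, present in row.items() if not present]
    print(batch, 'missing:', missing or 'none')

print('total cells:', int(summary['#CELL'].sum()))
```

The total, 22238 cells, matches the sample count reported at the start of training.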

Step 4: Training the Model (~3.5h)

After configuring the model, we train it. Training for 2000 epochs takes about 3.5 hours on a single V100 GPU; the actual time depends on your hardware. For a quicker, still reasonable result, set max_epochs=500 instead of the max_epochs=2000 used here, which gives the best result.

[4]:
trainer = L.Trainer(max_epochs=2000)
trainer.fit(model=model)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name | Type          | Params | Mode
-----------------------------------------------
0 | net  | VAE           | 8.2 M  | train
1 | dsc  | Discriminator | 39.0 K | train
-----------------------------------------------
8.2 M     Trainable params
0         Non-trainable params
8.2 M     Total params
32.817    Total estimated model params size (MB)
154       Modules in train mode
0         Modules in eval mode
INFO:root:Total number of samples: 22238 from 3 datasets.
INFO:root:Using MultiBatchSampler for data loading.
INFO:root:DataLoader created with batch size 256 and 20 workers.
Epoch 500: 100%|██████████| 105/105 [00:05<00:00, 18.01it/s, v_num=8, loss_/recon_loss_step=1.27e+3, loss_/kld_loss_step=56.30, loss_/consistency_loss_step=0.000, loss/net_step=1.27e+3, loss/dsc_step=49.40, loss_/recon_loss_epoch=1.48e+3, loss_/kld_loss_epoch=72.70, loss_/consistency_loss_epoch=17.00, loss/net_epoch=1.51e+3, loss/dsc_epoch=56.40]
INFO:root:Checkpoint successfully saved to "./saved_models/model_epoch500_20241217-030941.pt".
INFO:root:Checkpoint saved for epoch "500" at "./saved_models/model_epoch500_20241217-030941.pt".
Epoch 1000: 100%|██████████| 105/105 [00:06<00:00, 15.33it/s, v_num=8, loss_/recon_loss_step=2.45e+3, loss_/kld_loss_step=95.80, loss_/consistency_loss_step=36.80, loss/net_step=2.51e+3, loss/dsc_step=67.40, loss_/recon_loss_epoch=1.46e+3, loss_/kld_loss_epoch=69.40, loss_/consistency_loss_epoch=15.10, loss/net_epoch=1.49e+3, loss/dsc_epoch=55.00]
INFO:root:Checkpoint successfully saved to "./saved_models/model_epoch1000_20241217-040044.pt".
INFO:root:Checkpoint saved for epoch "1000" at "./saved_models/model_epoch1000_20241217-040044.pt".
Epoch 1500: 100%|██████████| 105/105 [00:06<00:00, 16.22it/s, v_num=8, loss_/recon_loss_step=2.43e+3, loss_/kld_loss_step=95.60, loss_/consistency_loss_step=36.00, loss/net_step=2.49e+3, loss/dsc_step=66.80, loss_/recon_loss_epoch=1.45e+3, loss_/kld_loss_epoch=68.50, loss_/consistency_loss_epoch=15.20, loss/net_epoch=1.48e+3, loss/dsc_epoch=55.30]
INFO:root:Checkpoint successfully saved to "./saved_models/model_epoch1500_20241217-045310.pt".
INFO:root:Checkpoint saved for epoch "1500" at "./saved_models/model_epoch1500_20241217-045310.pt".
Epoch 1999: 100%|██████████| 105/105 [00:05<00:00, 17.61it/s, v_num=8, loss_/recon_loss_step=2.41e+3, loss_/kld_loss_step=95.70, loss_/consistency_loss_step=35.00, loss/net_step=2.47e+3, loss/dsc_step=69.20, loss_/recon_loss_epoch=1.44e+3, loss_/kld_loss_epoch=68.30, loss_/consistency_loss_epoch=12.90, loss/net_epoch=1.46e+3, loss/dsc_epoch=54.60]
`Trainer.fit` stopped: `max_epochs=2000` reached.
INFO:root:Checkpoint successfully saved to "./saved_models/model_epoch2000_20241217-054513.pt".
INFO:root:Checkpoint saved for epoch "2000" at ./saved_models/model_epoch2000_20241217-054513.pt".
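In this run, a checkpoint was written to ./saved_models/ every 500 epochs, with the epoch and a timestamp in the file name (model_epoch<N>_<YYYYMMDD-HHMMSS>.pt, judging from the log above). The sketch below parses such names to find the most recent checkpoint and to estimate the time per epoch, which helps when choosing between max_epochs=500 and 2000 on your own hardware. The file-name pattern is inferred from this log only, and reloading a checkpoint into MIDAS is not covered here.

```python
import re
from datetime import datetime

# Pattern inferred from the log above, e.g. model_epoch500_20241217-030941.pt
CKPT_RE = re.compile(r'model_epoch(\d+)_(\d{8}-\d{6})\.pt$')


def parse_checkpoints(names):
    """Return (epoch, time, name) tuples sorted by epoch, skipping non-matching names."""
    parsed = []
    for name in names:
        m = CKPT_RE.search(name)
        if m:
            t = datetime.strptime(m.group(2), '%Y%m%d-%H%M%S')
            parsed.append((int(m.group(1)), t, name))
    return sorted(parsed)


def seconds_per_epoch(parsed):
    """Average wall-clock seconds per epoch between the first and last checkpoint."""
    (e0, t0, _), (e1, t1, _) = parsed[0], parsed[-1]
    return (t1 - t0).total_seconds() / (e1 - e0)


names = ['model_epoch500_20241217-030941.pt', 'model_epoch1000_20241217-040044.pt',
         'model_epoch1500_20241217-045310.pt', 'model_epoch2000_20241217-054513.pt']
ckpts = parse_checkpoints(names)
spe = seconds_per_epoch(ckpts)
print('latest:', ckpts[-1][2])
print(f'{spe:.2f} s/epoch, about {spe * 500 / 60:.0f} min per 500 epochs')
```

For the run above this gives roughly 6.2 s per epoch, consistent with the ~3.5 h quoted for 2000 epochs. In your own notebook, pass os.listdir('./saved_models') as names.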

Step 5: Predicting

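Once the model is trained, we run predict() to write the MIDAS outputs to ./predict/<task>. The flags select which outputs are produced; the ones used below are the joint embeddings (joint_latent), the modality-specific embeddings (mod_latent), the imputed counts (impute), and the batch-corrected counts (batch_correct).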

[ ]:
model.predict('./predict/'+task,
        joint_latent=True,
        mod_latent=True,
        impute=True,
        batch_correct=True,
        translate=True,
        input=True)

Outputs: Joint Embeddings

We now explore the outputs generated by MIDAS. First, we load the cell-type labels of each batch and build a matching array of batch names.

[ ]:
label = []
batch_id = []
for i in ['p1_0', 'p5_0', 'p8_0']:
    label.append(pd.read_csv('./dataset/'+task+'/label/%s.csv'%i, index_col=0).values.flatten())
    batch_id.append([i] * len(label[-1]))
labels = np.concatenate(label)
batch_ids = np.concatenate(batch_id)
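Rows of the concatenated outputs follow the batch order used above (p1_0, p5_0, p8_0), and later cells rely on this through hard-coded offsets such as [:6378] and [6378:(6378+6952)]. The sketch below derives these offsets from the per-batch cell counts instead; batch_slices is a hypothetical helper, not part of scmidas.

```python
import numpy as np


def batch_slices(batch_sizes):
    """Map each batch name to the row slice it occupies after concatenation (insertion order)."""
    sizes = np.array(list(batch_sizes.values()))
    ends = np.cumsum(sizes)
    starts = ends - sizes
    return {name: slice(int(s), int(e)) for name, s, e in zip(batch_sizes, starts, ends)}


# Cell counts per batch, as printed by configure_data_from_dir above.
slices = batch_slices({'p1_0': 6378, 'p5_0': 6952, 'p8_0': 8908})
print(slices)
```

In the notebook, the sizes can come from the loaded labels, e.g. batch_slices({b: len(l) for b, l in zip(['p1_0', 'p5_0', 'p8_0'], label)}), after which imputed['x_impt']['adt'][slices['p1_0']] replaces the hard-coded [:6378].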

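Each joint embedding is the concatenation of a biological component (c, the first model.dim_c dimensions) and a technical component (u, the remaining dimensions). We split the two and visualize each with UMAP. Ideally, c mixes the batches while separating cell types, whereas u mainly reflects the batch.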

[ ]:
joint_embeddings = load_predicted('./predict/'+task, model.combs, joint_latent=True)

adata_bio = sc.AnnData(joint_embeddings['z']['joint'][:, :model.dim_c])
adata_tech = sc.AnnData(joint_embeddings['z']['joint'][:, model.dim_c:])

adata_bio.obs['batch'] = batch_ids
adata_bio.obs['label'] = labels
adata_tech.obs['batch'] = batch_ids
adata_tech.obs['label'] = labels

for adata in [adata_bio, adata_tech]:
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)
    # shuffle
    sc.pp.subsample(adata, fraction=1)
    sc.pl.umap(adata, color=['batch', 'label'], ncols=2)
INFO:root:Loading predicted variables ...
INFO:root:Loading batch 0: z, joint
100%|██████████| 25/25 [00:00<00:00, 279.39it/s]
INFO:root:Loading batch 1: z, joint
100%|██████████| 28/28 [00:00<00:00, 324.17it/s]
INFO:root:Loading batch 2: z, joint
100%|██████████| 35/35 [00:00<00:00, 273.56it/s]
INFO:root:Converting to numpy ...
INFO:root:Converting batch 0: s, joint
INFO:root:Converting batch 0: z, joint
INFO:root:Converting batch 1: s, joint
INFO:root:Converting batch 1: z, joint
INFO:root:Converting batch 2: s, joint
INFO:root:Converting batch 2: z, joint
... storing 'batch' as categorical
... storing 'label' as categorical
Figure: UMAP of the biological component (c), colored by batch and by cell-type label.
... storing 'batch' as categorical
... storing 'label' as categorical
Figure: UMAP of the technical component (u), colored by batch and by cell-type label.
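The cell above relies on the split at model.dim_c and on sc.pp.subsample(adata, fraction=1) to randomize the plotting order. The sketch below makes both steps explicit with plain NumPy on a stand-in matrix. The sizes dim_c=32 and dim_u=2 are assumptions for illustration only; use model.dim_c in practice. For an AnnData object, adata[order].copy() applies the same permutation to the embeddings and to obs together.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for joint_embeddings['z']['joint']; dim_c and dim_u are assumed values.
dim_c, dim_u, n_cells = 32, 2, 1000
z_joint = rng.normal(size=(n_cells, dim_c + dim_u))

# Biological part c (first dim_c columns) and technical part u (the rest).
c, u = z_joint[:, :dim_c], z_joint[:, dim_c:]

# Random plotting order, so that no batch is drawn entirely on top of the others.
# Apply the same `order` to any per-cell label arrays.
order = rng.permutation(n_cells)
c_shuffled = c[order]
```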

Outputs: Modality-specific Embeddings

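Next, we check how well the modalities are aligned. For every batch, we take the biological component (the first model.dim_c dimensions) of each available modality-specific embedding and of the joint embedding, embed them all together with UMAP, and plot one panel per modality and batch.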

[ ]:
# load the modality-specific embeddings, grouped by batch
mod_embeddings = load_predicted('./predict/'+task, model.combs, mod_latent=True, group_by='batch')
batch_names = ['p1_0', 'p5_0', 'p8_0']
adata_list = []
# for each batch, collect the biological part (first dim_c dims) of every available
# modality-specific embedding and of the joint embedding
for i in range(model.dims_s['joint']):  # one iteration per batch
    for m in model.mods+['joint']:
        if m in mod_embeddings[i]['z']:
            adata = sc.AnnData(mod_embeddings[i]['z'][m][:, :model.dim_c])
            adata.obs['batch'] = batch_names[i]
            adata.obs['modality'] = m
            adata.obs['label'] = label[i]
            adata_list.append(adata)
adata_mod_concat = sc.concat(adata_list)
for i in adata_mod_concat.obs:
    adata_mod_concat.obs[i] = adata_mod_concat.obs[i].astype('category')
sc.pp.neighbors(adata_mod_concat)
# shuffle
sc.pp.subsample(adata_mod_concat, fraction=1)
sc.tl.umap(adata_mod_concat)
[10]:
# setup figure
nrows = len(model.mods) + 1
ncols = model.dims_s['joint']
point_size = 10

fig, ax = plt.subplots(nrows, ncols, figsize=[2 * ncols, 2 * nrows])

# set up the name of modalities and batch
mod_names = model.mods + ['joint']

# iteratively scatter the data
for i, mod in enumerate(mod_names):
    for b in range(model.dims_s['joint']):
        # filter data
        adata = adata_mod_concat[
            (adata_mod_concat.obs['modality'] == mod) &
            (adata_mod_concat.obs['batch'] == batch_names[b])
        ].copy()
        if len(adata):
            sc.pl.umap(adata, color='label', show=False, ax=ax[i, b], s=point_size)
            ax[i, b].get_legend().set_visible(False)
            handles, labels_ = ax[i, b].get_legend_handles_labels()
        ax[i, b].set_xticks([])
        ax[i, b].set_yticks([])
        ax[i, b].set_xlabel('')
        if b==0:
            ax[i, b].set_ylabel(mod.upper())
        else:
            ax[i, b].set_ylabel('')
        if i==0:
            ax[i, b].set_title(batch_names[b])
        else:
            ax[i, b].set_title('')
# create global legend
fig.legend(handles, labels_, loc='center', bbox_to_anchor=(0.5, -0.02), ncol=len(labels_), fontsize=10)

# adjust the figure
plt.tight_layout(rect=[0.1, 0.05, 1, 1])
plt.show()
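Figure: UMAP of the biological component of each modality-specific embedding and of the joint embedding (rows: modalities and joint; columns: batches p1_0, p5_0, p8_0), colored by cell-type label.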
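The UMAP grid gives a visual impression of alignment. As a rough numeric complement (not part of MIDAS), one can compare, for cells measured with both modalities, the distance between the same cell's RNA and ADT embeddings with the distance to a randomly chosen other cell: a ratio well below 1 suggests that both modalities place each cell in the same region of the c space. The sketch uses NumPy only. In this tutorial, the inputs would presumably be the first model.dim_c columns of mod_embeddings[2]['z']['rna'] and mod_embeddings[2]['z']['adt'] (batch p8_0, the only batch with both modalities); that layout is an assumption based on how load_predicted's output is indexed above.

```python
import numpy as np


def paired_distance_ratio(a, b, seed=0):
    """Mean distance between matched rows of a and b, divided by the mean distance
    between a and randomly permuted rows of b (lower = better aligned)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    rng = np.random.default_rng(seed)
    matched = np.linalg.norm(a - b, axis=1).mean()
    shuffled = np.linalg.norm(a - b[rng.permutation(len(b))], axis=1).mean()
    return matched / shuffled
```

A value near 1 means a cell's RNA embedding is no closer to its own ADT embedding than to a random cell's.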

Outputs: Imputed Counts

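Because this is a mosaic dataset, some batches lack a modality: p1_0 has no ADT and p5_0 has no RNA. MIDAS imputes the missing modalities, so the imputed output contains complete count matrices for all cells. Since the ground-truth counts of these missing modalities are available in the dataset directory, we evaluate the imputation by computing Pearson's r between the imputed and ground-truth counts.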

[ ]:
imputed = load_predicted('./predict/'+task, model.combs, impute=True)

Compute Pearson's r between the ground-truth and imputed counts, flattening each matrix into a single vector.

[ ]:
# ground-truth ADT of p1_0 (batch 0), whose ADT was imputed; p1_0 occupies the first 6378 rows
ref_adt = pd.read_csv('./dataset/'+task+'/data/p1_0/mat/adt.csv', index_col=0).values
print('Pearson\'s r for ADT (p1_0)', pearsonr(ref_adt.reshape(-1), imputed['x_impt']['adt'][:6378].reshape(-1))[0])
Pearson's r for ADT (p1_0) 0.8608445224365799
[ ]:
# ground-truth RNA of p5_0 (batch 1), whose RNA was imputed; p5_0 occupies the next 6952 rows
ref_rna = pd.read_csv('./dataset/'+task+'/data/p5_0/mat/rna.csv', index_col=0).values
print('Pearson\'s r for RNA (p5_0)', pearsonr(ref_rna.reshape(-1), imputed['x_impt']['rna'][6378:(6378+6952)].reshape(-1))[0])
Pearson's r for RNA (p5_0) 0.9249721667452995
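The values above flatten each matrix into one vector, so part of the correlation can come from differences in average level between features (a highly expressed protein versus a barely detected one) rather than from agreement across cells. A complementary view is Pearson's r computed separately for each feature (column) and then summarized, for example by the median. The NumPy sketch below is an illustrative addition, not part of the MIDAS evaluation shown above.

```python
import numpy as np


def per_feature_pearson(ref, pred):
    """Pearson's r between matching columns of two (cells x features) matrices.
    Columns with zero variance in either matrix get NaN."""
    ref = np.asarray(ref, dtype=float)
    pred = np.asarray(pred, dtype=float)
    ref_c = ref - ref.mean(axis=0)
    pred_c = pred - pred.mean(axis=0)
    num = (ref_c * pred_c).sum(axis=0)
    den = np.sqrt((ref_c ** 2).sum(axis=0) * (pred_c ** 2).sum(axis=0))
    # 0/0 for constant columns yields NaN; silence the corresponding warnings
    with np.errstate(divide='ignore', invalid='ignore'):
        return num / den
```

For example, np.nanmedian(per_feature_pearson(ref_adt, imputed['x_impt']['adt'][:6378])) summarizes the per-protein agreement for p1_0.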

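Outputs: Batch-corrected Counts

MIDAS also outputs batch-corrected counts for all cells in both modalities. We load them here and analyze them in R with Seurat.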

[ ]:
batch_corrected_counts = load_predicted('./predict/'+task, model.combs, batch_correct=True)

Save the batch-corrected counts to CSV files, transposed to features x cells as Seurat expects, so they can be loaded in R.

[ ]:
pd.DataFrame(batch_corrected_counts['x_bc']['rna']).T.to_csv('temp_rna.csv', index=True)
pd.DataFrame(batch_corrected_counts['x_bc']['adt']).T.to_csv('temp_adt.csv', index=True)
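The R code below filters cells with subset(obj, subset = nCount_rna > 0 & nCount_adt > 0). If any cell is removed there, the UMAP coordinates returned from R no longer line up with labels and batch_ids on the Python side. The sketch below counts, before export, the cells whose total is zero in either modality; zero_total_cells is a hypothetical helper, and it approximates nCount as the per-cell sum of the exported values.

```python
import numpy as np


def zero_total_cells(*matrices):
    """Number of cells (rows) whose total is zero in at least one of the given
    cells x features matrices."""
    zero = np.zeros(len(matrices[0]), dtype=bool)
    for m in matrices:
        zero |= np.asarray(m).sum(axis=1) == 0
    return int(zero.sum())
```

A result of 0 for zero_total_cells(batch_corrected_counts['x_bc']['rna'], batch_corrected_counts['x_bc']['adt']) means the R-side filter keeps every cell.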

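PCA+WNN: we run PCA on each modality separately, combine the two with Seurat's weighted nearest neighbor (WNN) analysis using 32 principal components per modality, and compute a UMAP from the resulting weighted graph.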

[ ]:
from rpy2.robjects.packages import importr
import rpy2.robjects as ro
importr('Seurat')
importr('SeuratDisk')
importr('dplyr')
importr('Signac')
ro.r('''
rna <- read.csv('./temp_rna.csv', header=TRUE, row.names=1)
adt <- read.csv('./temp_adt.csv', header=TRUE, row.names=1)
obj <- CreateSeuratObject(counts = rna, assay = "rna")
obj[["adt"]] <- CreateAssayObject(counts = adt)
obj <- subset(obj, subset = nCount_rna > 0 & nCount_adt > 0)
print(obj)
DefaultAssay(obj) <- 'rna'
VariableFeatures(obj) <- rownames(obj)
obj <-  NormalizeData(obj) %>%
        # FindVariableFeatures(nfeatures = 2000) %>%
        ScaleData() %>%
        RunPCA(reduction.name = "pca_rna", verbose = F)
print('finish rna')
DefaultAssay(obj) <- 'adt'
VariableFeatures(obj) <- rownames(obj)
obj <-  NormalizeData(obj, normalization.method = "CLR", margin = 2) %>%
        ScaleData() %>%
        RunPCA(reduction.name = "pca_adt", verbose = F)
print('finish adt')
print('WNN ...')
obj <- FindMultiModalNeighbors(obj, list("pca_rna", "pca_adt"), list(1:32, 1:32))
obj <- RunUMAP(obj, nn.name = "weighted.nn", reduction.name = "umap")
''')
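For the ADT assay, NormalizeData(obj, normalization.method = "CLR", margin = 2) applies a centered log-ratio transform to each cell. To our understanding, Seurat's CLR variant is log1p(x / g), where g is the exponential of the sum of log1p over the nonzero entries divided by the total number of entries. The NumPy sketch below mirrors that definition on a cells x features matrix for readers who want to reproduce the normalization in Python; seurat_style_clr is a hypothetical name, and the formula should be checked against your Seurat version before relying on it.

```python
import numpy as np


def seurat_style_clr(counts):
    """CLR-normalize each cell (row) of a non-negative cells x features count matrix,
    following the log1p-based variant described above (assumed, not verified per version)."""
    counts = np.asarray(counts, dtype=float)
    # log1p(0) == 0, so for non-negative counts summing over all entries
    # equals summing over the nonzero ones
    g = np.exp(np.log1p(counts).sum(axis=1, keepdims=True) / counts.shape[1])
    return np.log1p(counts / g)
```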
[ ]:
# X is not used for plotting, so fill it with a single placeholder column of zeros
adata = sc.AnnData(np.zeros([len(batch_corrected_counts['x_bc']['rna']), 1]))
adata.obs['label'] = labels
adata.obs['batch'] = batch_ids
# DimPlot returns a ggplot object whose first element is the plotted data frame;
# its first two columns (rows after this conversion, hence .T) are the UMAP coordinates
f = ro.r('''DimPlot(obj, reduction='umap')''')
adata.obsm['umap'] = pd.DataFrame(f[0]).iloc[:2].T.values
# shuffle
sc.pp.subsample(adata, fraction=1)
sc.pl.umap(adata, color=['batch', 'label'], ncols=1)
... storing 'label' as categorical
... storing 'batch' as categorical
Figure: UMAP of the WNN graph built from the batch-corrected counts, colored by batch and by cell-type label.