# Standard Library Imports
import os
import csv
import shutil
import itertools
import math
from glob import glob
from collections import defaultdict
from typing import List, Tuple, Dict, Union, Any
# Third-Party Library Imports
import numpy as np
import torch
from tqdm import tqdm
import logging
logging.basicConfig(level=logging.INFO)
[docs]
def load_csv(filename: str) -> list:
"""
Load a CSV file and return its contents as a list of rows.
Parameters:
filename : str
Path to the CSV file.
Returns:
list
A list of rows, where each row is a list of strings.
"""
with open(filename, 'r') as file:
reader = csv.reader(file)
data = list(reader)
return data
[docs]
def exp(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
"""
Compute a numerically stable exponential transformation.
Handles negative and positive values to avoid numerical instability.
Parameters:
x : torch.Tensor
Input tensor.
eps : float, optional
A small epsilon value to avoid division by zero, by default 1e-12.
Returns:
torch.Tensor
Transformed tensor with the exponential applied.
"""
return (x < 0) * (x.clamp(max=0)).exp() + (x >= 0) / ((-x.clamp(min=0)).exp() + eps)
[docs]
def log(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
"""
Compute a numerically stable logarithm transformation.
Ensures numerical stability by adding a small epsilon.
Parameters:
x : torch.Tensor
Input tensor.
eps : float, optional
A small epsilon value to avoid log(0), by default 1e-12.
Returns:
torch.Tensor
Transformed tensor with the logarithm applied.
"""
return (x + eps).log()
[docs]
def ref_sort(x: List[str], ref: List[str]) -> List[str]:
"""
Sort the elements of `x` based on the order defined in `ref`.
Parameters:
x : list of str
List of elements to be sorted.
ref : list of str
Reference list defining the sort order.
Returns:
list of str
A sorted list of elements from `x` that appear in `ref`,
maintaining the order of `ref`.
"""
return [v for v in ref if v in x]
[docs]
def reverse_dict(original_dict: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
"""
Reverse the keys and sub-keys of a nested dictionary.
Parameters:
original_dict : dict
The original nested dictionary to be reversed.
Returns:
dict
A reconstructed dictionary where the keys and sub-keys are swapped.
"""
reconstructed_dict = defaultdict(dict)
for key, subdict in original_dict.items():
for sub_key, value in subdict.items():
reconstructed_dict[sub_key][key] = value
return dict(reconstructed_dict)
[docs]
def detach_tensors(x: Dict[str, Any]) -> Dict[str, Any]:
"""
Recursively detach all tensors in a dictionary.
Parameters:
x : dict
Dictionary containing tensors or nested dictionaries.
Returns:
dict
A new dictionary with all tensors detached.
"""
y = {}
for kw, arg in x.items():
if torch.is_tensor(arg):
y[kw] = arg.detach()
else:
y[kw] = detach_tensors(arg)
return y
[docs]
def convert_tensor_to_list(data: Union[torch.Tensor, List[List[Any]]]) -> List[List[Any]]:
"""
Convert a 2D tensor or list into a 2D list.
Parameters:
data : torch.Tensor or list of lists
Input data to be converted.
Returns:
list of lists
Converted 2D list.
"""
if torch.is_tensor(data):
return data.cpu().detach().numpy().tolist()
else:
return [list(line) for line in data]
[docs]
def save_list_to_csv(data: List[List[Any]], filename: str, delimiter: str = ','):
"""
Save a 2D list to a CSV file.
Parameters:
data : list of lists
Input data to be saved.
filename : str
Path to the CSV file.
delimiter : str, optional
Delimiter to separate values in the CSV file, by default ','.
"""
with open(filename, 'w') as file:
writer = csv.writer(file, delimiter=delimiter)
writer.writerows(data)
[docs]
def save_tensor_to_csv(data: torch.Tensor, filename: str, delimiter: str = ','):
"""
Save a 2D tensor to a CSV file.
Parameters:
data : torch.Tensor
Input tensor to be saved.
filename : str
Path to the CSV file.
delimiter : str, optional
Delimiter to separate values in the CSV file, by default ','.
"""
data_list = convert_tensor_to_list(data)
save_list_to_csv(data_list, filename, delimiter)
[docs]
def get_name_fmt(file_num: int) -> str:
"""
Generate a format string for filenames based on the total number of files.
Parameters:
file_num : int
Total number of files to be named.
Returns:
str
Format string for filenames, e.g., '%03d' for three-digit naming.
"""
digits = math.floor(math.log10(file_num)) + 1
return f'%0{digits}d'
[docs]
def convert_tensors_to_cuda(x: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
"""
Recursively convert all tensors in a dictionary to CUDA.
Parameters:
x : dict
Dictionary containing tensors or nested dictionaries.
device : torch.device
Device to move the tensors to (e.g., CUDA or CPU).
Returns:
dict
A new dictionary with all tensors moved to the specified device.
"""
y = {}
for kw, arg in x.items():
if torch.is_tensor(arg):
y[kw] = arg.to(device)
else:
y[kw] = convert_tensors_to_cuda(arg, device)
return y
[docs]
def filter_keys(d: Dict[str, Any], substring: str) -> Dict[str, Any]:
"""
Filter a dictionary to include only keys that contain a specific substring.
Parameters:
d : Dict[str, Any]
The input dictionary to filter.
substring : str
The substring to look for in the keys.
Returns:
Dict[str, Any]
A new dictionary containing only the keys from the original
dictionary that include the specified substring.
"""
return {k: v for k, v in d.items() if substring in k}
[docs]
def get_filenames(directory: str, extension: str) -> List[str]:
"""
Get sorted filenames with the given extension in the specified directory.
Parameters:
directory : str
The directory to search for files.
extension : str
The file extension to filter by.
Returns:
list of str
Sorted list of filenames with the specified extension.
"""
filenames = glob(os.path.join(directory, f'*.{extension}'))
filenames = [os.path.basename(filename) for filename in filenames]
filenames.sort()
return filenames
[docs]
def load_predicted(
pred_dir: str,
combs: List[List[str]],
joint_latent: bool = True,
mod_latent: bool = False,
impute: bool = False,
batch_correct: bool = False,
translate: bool = False,
input: bool = False,
group_by: str = 'modality'
) -> Union[Dict[int, Dict[str, Any]], Dict[str, Dict[str, np.ndarray]]]:
"""
Load predicted variables from a specified directory.
Parameters:
pred_dir : str
Path to the prediction directory.
combs : list of list of str
Combinations of modalities for each batch. Example: [['rna'], ['rna', 'adt']].
joint_latent : bool, optional
Whether to include joint latent variables, by default True.
mod_latent : bool, optional
Whether to include modality-specific latent variables, by default False.
impute : bool, optional
Whether to include imputed data, by default False.
batch_correct : bool, optional
Whether to include batch-corrected data, by default False.
translate : bool, optional
Whether to include translated data, by default False.
input : bool, optional
Whether to include input data, by default False.
group_by : str, optional
Grouping method for the data, either 'modality' or 'batch', by default 'modality'.
Returns:
dict
Loaded predicted data grouped by the specified method.
"""
logging.info('Loading predicted variables ...')
dirs = get_pred_dirs(pred_dir,
combs,
joint_latent,
mod_latent,
impute,
batch_correct,
translate,
input)
data = {}
# Load data from directories
for batch_id, batch_dirs in dirs.items():
data[batch_id] = {'s': {}}
for variable, variable_dirs in batch_dirs.items():
data[batch_id][variable] = {}
for mod, dir_path in variable_dirs.items():
logging.info(f'Loading batch {batch_id}: {variable}, {mod}')
data[batch_id][variable][mod] = []
if variable == 'z':
data[batch_id]['s'][mod] = []
filenames = get_filenames(dir_path, 'csv')
for filename in tqdm(filenames):
v = load_csv(os.path.join(dir_path, filename))
data[batch_id][variable][mod] += v
if variable == 'z':
data[batch_id]['s'][mod] += [batch_id] * len(v)
logging.info('Converting to numpy ...')
for batch_id, batch_data in data.items():
for variable, variable_data in batch_data.items():
for mod, values in variable_data.items():
logging.info(f'Converting batch {batch_id}: {variable}, {mod}')
if variable in ['z', 'x_trans', 'x_impt', 's', 'x', 'x_bc']:
data[batch_id][variable][mod] = values
# Group data by modality if required
if group_by == 'batch':
for batch_id, batch_data in data.items():
for variable, variable_data in batch_data.items():
for mod, values in variable_data.items():
data[batch_id][variable][mod] = np.array(data[batch_id][variable][mod]).astype(np.float32)
return data
elif group_by == 'modality':
data_m = {}
for batch_id, batch_data in data.items():
for variable, variable_data in batch_data.items():
if variable not in data_m:
data_m[variable] = {}
for mod, values in variable_data.items():
if mod not in data_m[variable]:
data_m[variable][mod] = {}
data_m[variable][mod][batch_id] = values
# Concatenate data across batches
for variable, variable_data in data_m.items():
for mod, mod_data in variable_data.items():
data_m[variable][mod] = np.concatenate(list(mod_data.values()), axis=0).astype(np.float32)
return data_m
[docs]
def get_s_joint_mods(combs: List[List[str]]):
"""
Generate `s_joint` and `mods` from a list of modality combinations.
Parameters:
----------
combs : list
A list where each element is a list of strings representing combinations
of modalities for a specific batch.
Returns:
-------
tuple
A tuple containing:
- `s_joint`: A list of dictionaries, where each dictionary maps the modalities
to their corresponding indices for each batch.
- `mods`: A list of all unique modalities across the dataset.
"""
s_joint = []
mods = {}
for b in combs:
t = {}
for m in b + ['joint']:
if m in mods:
mods[m] += 1
else:
mods[m] = 0
t[m] = mods[m]
s_joint.append(t)
mods = list(mods.keys())
mods.remove('joint')
return s_joint, mods
[docs]
def get_pred_dirs(
pred_dir: str,
combs: List[List[str]],
joint_latent: bool,
mod_latent: bool,
impute: bool,
batch_correct: bool,
translate: bool,
input: bool
) -> Dict[int, Dict[str, Dict[str, str]]]:
"""
Generate directory paths for predictions based on configurations.
Parameters:
pred_dir : str
Base directory for predictions.
combs : list of list of str
Combinations of modalities for each batch.
joint_latent : bool
Include joint latent variables.
mod_latent : bool
Include modality-specific latent variables.
impute : bool
Include imputed data.
batch_correct : bool
Include batch-corrected data.
translate : bool
Include translated data.
input : bool
Include input data.
Returns:
dict
Dictionary of directories for each batch and variable.
"""
dirs = {}
s_joint, mods = get_s_joint_mods(combs)
for batch_id in range(len(s_joint)):
batch_dir = os.path.join(pred_dir, f'batch_{batch_id}')
dirs[batch_id] = {}
if joint_latent or mod_latent:
dirs[batch_id]['z'] = {}
if joint_latent:
dirs[batch_id]['z']['joint'] = os.path.join(batch_dir, 'z', 'joint')
if mod_latent:
for mod in combs[batch_id]:
dirs[batch_id]['z'][mod] = os.path.join(batch_dir, 'z', mod)
if impute:
dirs[batch_id]['x_impt'] = {mod: os.path.join(batch_dir, 'x_impt', mod) for mod in mods}
if batch_correct:
dirs[batch_id]['x_bc'] = {mod: os.path.join(batch_dir, 'x_bc', mod) for mod in mods}
if translate:
dirs[batch_id]['x_trans'] = {}
all_combinations = generate_all_combinations(mods)
for input_mods, output_mods in all_combinations:
input_mods_sorted = sorted(input_mods)
for mod in output_mods:
key = '_'.join(input_mods_sorted) + '_to_' + mod
dirs[batch_id]['x_trans'][key] = os.path.join(batch_dir, 'x_trans', key)
if input:
dirs[batch_id]['x'] = {mod: os.path.join(batch_dir, 'x', mod) for mod in combs[batch_id]}
return dirs
[docs]
def rmdir(directory: str):
"""
Remove a directory if it exists.
Parameters:
directory : str
Path to the directory to remove.
"""
if os.path.exists(directory):
logging.warning(f'Removing directory "{directory}"')
shutil.rmtree(directory)
[docs]
def mkdir(directory: str, remove_old: bool = False):
"""
Create a directory, optionally removing the old one.
Parameters:
directory : str
Path to the directory.
remove_old : bool, optional
Whether to remove the old directory if it exists, by default False.
"""
if remove_old:
rmdir(directory)
if not os.path.exists(directory):
os.makedirs(directory)
[docs]
def mkdirs(directories: Union[str, List[str], Dict[str, Any]], remove_old: bool = False):
"""
Recursively create directories.
Parameters:
directories : str, list, or dict
Path(s) to directories to create.
remove_old : bool, optional
Whether to remove old directories if they exist, by default False.
"""
if isinstance(directories, (list, tuple)):
for d in directories:
mkdirs(d, remove_old=remove_old)
elif isinstance(directories, dict):
for d in directories.values():
mkdirs(d, remove_old=remove_old)
else:
mkdir(directories, remove_old=remove_old)
[docs]
def reverse_trsf(name: str, data: np.ndarray, **kwargs) -> np.ndarray:
"""
Apply a reverse transformation to the given data.
Parameters:
name : str
Name of the transformation to reverse (e.g., 'log1p').
data : np.ndarray
Data to transform.
kwargs : dict
Additional transformation parameters.
Returns:
np.ndarray
Transformed data.
"""
# Extract parameters from kwargs
params = {k.split('_')[-1]: v for k, v in kwargs.items()}
# Perform the reverse transformation based on the name
if name == 'log1p':
return data.exp()
else:
return data
[docs]
def generate_all_combinations(mods: List[str]) -> List[Tuple[Tuple[str, ...], List[str]]]:
"""
Generate all possible input-output combinations for a given list of modalities.
For N modalities, generate all combinations of size r (1 <= r < N) as input,
and the remaining modalities as output.
Parameters:
mods : list of str
List of modality names.
Returns:
list of tuple
A list of tuples, where each tuple contains:
- A tuple of input modalities.
- A list of output modalities.
"""
combinations = []
for r in range(1, len(mods)): # Generate combinations of size r
for input_mods in itertools.combinations(mods, r):
output_mods = list(set(mods) - set(input_mods))
combinations.append((input_mods, output_mods))
return combinations