Source code for classify

import anndata
import argparse
import numpy as np
import os
import pandas as pd
import scanpy as sc
import scvi
from scvi.model import TOTALVI
import sys
import torch
import warnings

# Silence FutureWarning and UserWarning for the whole run.
warnings.simplefilter('ignore', FutureWarning)
warnings.simplefilter('ignore', UserWarning)


# sets path for ref_model and ref_adata
def get_model_path(ref_model, ref_adata):
    """Return the user-supplied reference paths once both are known to exist.

    Args:
        ref_model: Path to the reference model.
        ref_adata: Path to the reference AnnData file.

    Returns:
        The tuple ``(ref_model, ref_adata)``, unchanged.

    Raises:
        FileNotFoundError: If either path does not exist. The model path is
            checked before the AnnData path.
    """
    model_found = os.path.exists(ref_model)
    if not model_found:
        raise FileNotFoundError(f"Reference model file not found: {ref_model}")

    adata_found = os.path.exists(ref_adata)
    if not adata_found:
        raise FileNotFoundError(f"Reference AnnData file not found: {ref_adata}")

    return ref_model, ref_adata
def parse_arguments(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so existing callers are
            unaffected.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing, a
            choice is invalid, or no usable input was given (neither
            ``--adata_file`` nor both ``--RNApath`` and ``--metapath``).
    """
    parser = argparse.ArgumentParser(description="Process AnnData for TOTALVI")

    input_data = parser.add_argument_group('Input Data')
    input_data.add_argument('--adata_file', type=str, required=False,
                            help="Path to the input AnnData file")
    input_data.add_argument('--RNApath', type=str,
                            help="Path to RNA counts file (CSV format)")
    input_data.add_argument('--metapath', type=str,
                            help="Path to metadata file (CSV format)")
    input_data.add_argument('--ADTpath', type=str,
                            help="Path to ADT counts file (CSV format)")
    input_data.add_argument('--umappath', type=str, required=False,
                            help="Path to UMAP file (CSV format)")

    parser.add_argument('--ref_model', type=str, required=True,
                        help="Path to the reference model file")
    parser.add_argument('--ref_adata', type=str, required=True,
                        help="Path to the reference AnnData file")
    parser.add_argument('--classifier_type', type=str, choices=["BBC", "BRF"], default="BBC",
                        help="Classifier type to use for NK cell classification (BBC or BRF)")
    parser.add_argument('--output_dir', type=str, default="./output",
                        help="Directory to save output files")
    parser.add_argument('--protein', action='store_true',
                        help="Flag to include protein data")
    parser.add_argument('--batch', type=str, default="sample",
                        help="Batch column to use in AnnData")
    parser.add_argument('--proteins_file', type=str, required=True,
                        help="Path to a file containing proteins to exclude from protein_expression.")
    parser.add_argument('--protein_suffix', type=str, default='-TotalSeqC',
                        help="Suffix to be replaced from protein names in the expression matrix.")
    parser.add_argument('--adversarial_classifier', type=str, choices=["None", "True", "False"],
                        default="None",
                        help="Enable adversarial classifier in TOTALVI (None, True, False)")
    parser.add_argument('--mouse', action='store_true',
                        help="Flag to process mouse genes (mm10)")
    parser.add_argument('--patient', type=str, required=True,
                        help="Name of the patient or dataset being processed")
    parser.add_argument('--disable_NK_type', action='store_true',
                        help="Disable NK cell classification v1.1 step")
    parser.add_argument('--proteintech', action='store_true',
                        help="Flag if ADT data is in ProteinTech format e.g. 'prot:CD16.65090.1' → 'CD16ADT'.")

    args = parser.parse_args(argv)

    # Without an .h5ad file the loader reads both the RNA counts CSV and the
    # metadata CSV; reject an unusable combination here instead of failing
    # later inside the CSV readers.
    if not args.adata_file and not (args.RNApath and args.metapath):
        parser.error("Provide either --adata_file or both --RNApath and --metapath.")

    return args
def validate_files(adata_file, RNApath, metapath, umappath, ADTpath, protein):
    """Raise if any supplied input path does not exist.

    The AnnData, RNA and metadata paths are always considered; the UMAP path
    only when given, and the ADT path only when ``protein`` is truthy. Paths
    that are empty or ``None`` are skipped.

    Raises:
        FileNotFoundError: For the first considered path that is missing.
    """
    candidates = [adata_file, RNApath, metapath]
    candidates += [umappath] if umappath else []
    candidates += [ADTpath] if protein else []

    # filter(None, ...) drops unset (falsy) entries before the existence check.
    for path in filter(None, candidates):
        if not os.path.exists(path):
            raise FileNotFoundError(f"File not found: {path}")
def load_proteins_from_file(file_path):
    """Load protein names from a plain-text file, one name per line.

    Surrounding whitespace is stripped from every line, and lines that are
    empty after stripping are skipped so they do not become empty-string
    protein names.

    Args:
        file_path: Path to the text file.

    Returns:
        List of protein names in file order.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Protein file not found: {file_path}")

    with open(file_path, 'r') as handle:
        stripped = (line.strip() for line in handle)
        return [name for name in stripped if name]
def validate_data_integrity(RNA_counts, ADT_counts):
    """Check that RNA and ADT data describe the same cells in the same order.

    Args:
        RNA_counts: Object exposing ``shape`` and ``obs_names`` for the RNA data.
        ADT_counts: Object exposing ``shape`` and ``obs_names`` for the ADT data.

    Raises:
        ValueError: If the number of rows differs, or if the two ``obs_names``
            indexes are not equal.
    """
    n_rna = RNA_counts.shape[0]
    n_adt = ADT_counts.shape[0]
    if n_rna != n_adt:
        raise ValueError(f"Mismatch in number of cells: RNA ({n_rna}) vs ADT ({n_adt}).")

    # Index.equals compares labels element by element and in order, so no
    # further per-element comparison is needed once it passes.
    if not RNA_counts.obs_names.equals(ADT_counts.obs_names):
        raise ValueError("Mismatch in barcodes between RNA and ADT counts.")

    print("RNA and ADT data integrity validated successfully.")
def load_data(adata_file, ref_adata, RNApath, metapath, umappath, ADTpath, protein):
    """Load the query AnnData, the reference AnnData and optional protein data.

    Args:
        adata_file: Path to an ``.h5ad`` query file. When falsy, the query is
            built from ``RNApath`` and ``metapath`` instead.
        ref_adata: Path to the reference ``.h5ad`` file.
        RNApath: CSV of RNA counts; it is transposed after reading.
        metapath: CSV of per-cell metadata (first column is the index),
            assigned to ``adata.obs``.
        umappath: Optional CSV of UMAP coordinates (first column is the
            index). When given, the query is subset to its index and the
            values are stored in ``obsm["X_umap"]``.
        ADTpath: Optional CSV of ADT counts; it is transposed after reading.
        protein: Whether protein data should be loaded.

    Returns:
        Tuple ``(adata, ref, protein_adata)``. ``protein_adata`` is a copy of
        ``adata.obsm["protein_expression"]`` when that key already exists, the
        AnnData read from ``ADTpath`` when that is provided, and ``None``
        otherwise (including whenever ``protein`` is falsy).
    """
    if adata_file:
        adata = sc.read_h5ad(adata_file)
    else:
        adata = sc.read_csv(RNApath)
        adata = adata.transpose()
        meta = pd.read_csv(metapath, index_col=0)
        adata.obs = meta

    # Convert all object columns to string type
    for col in adata.obs.select_dtypes(include=['object']).columns:
        adata.obs[col] = adata.obs[col].astype(str)

    if umappath:  # Only process UMAP if a path is given
        umap = pd.read_csv(umappath, index_col=0)
        # Take an explicit copy: the subset is a view and is modified on the
        # next line.
        adata = adata[umap.index].copy()
        adata.obsm["X_umap"] = umap.to_numpy()

    # Load the reference data
    ref = sc.read_h5ad(ref_adata)

    if protein and "protein_expression" in adata.obsm:
        # Protein data already travels with the query: hand it back separately
        # and drop it from obsm so that it is processed again downstream.
        protein_adata = adata.obsm["protein_expression"].copy()
        del adata.obsm["protein_expression"]  # force reprocessing
    elif protein and ADTpath:
        protein_adata = sc.read_csv(ADTpath)
        protein_adata = protein_adata.transpose()

        # Validate integrity between RNA and ADT data
        # NOTE(review): this runs before the barcode clean-up below, so the
        # raw ADT barcodes must already match the query barcodes — confirm
        # that this ordering is intended.
        validate_data_integrity(adata, protein_adata)

        print('Protein Indexes Before')
        print(protein_adata.obs.index)
        # NOTE(review): removes every 'X' in a barcode, not only a leading
        # one — confirm against the expected barcode format.
        protein_adata.obs.index = protein_adata.obs.index.str.replace('X', '')
        # Raw string: '\.' in a normal string literal is an invalid escape
        # sequence. regex is passed explicitly because its default differs
        # between pandas versions.
        protein_adata.obs.index = protein_adata.obs.index.str.replace(r'\.', '-', regex=True)
        print('Protein Indexes After')
        print(protein_adata.obs.index)
        print()
    else:
        protein_adata = None

    return adata, ref, protein_adata
def preprocess_data(adata, protein, protein_adata, ref, meta, batch, proteins_to_check,
                    protein_suffix, proteintech, output_dir, patient, mouse):
    """Attach protein expression to the query and prepare it for TOTALVI.

    Protein data is integrated when ``protein`` is set and ``protein_adata``
    is available. In every other case ``obsm["protein_expression"]`` is
    initialised with zeros using the reference's protein columns, so that the
    alignment step in ``prepare_adata_for_totalvi`` always finds that key.

    NOTE(review): with ``protein`` unset, any protein expression already
    stored on ``adata`` is replaced by zeros — confirm this is the intended
    meaning of running without the protein flag.

    Returns:
        The AnnData returned by ``prepare_adata_for_totalvi``.
    """
    if protein and protein_adata is not None:
        adata = integrate_protein_data(adata, protein_adata, meta, proteins_to_check,
                                       protein_suffix, proteintech)
    else:
        adata = initialize_protein_data(adata, ref)

    adata = prepare_adata_for_totalvi(adata, batch, ref, output_dir, patient, mouse)
    return adata
# Updated integrate_protein_data to include proteintech data
def integrate_protein_data(adata, protein_adata, meta, proteins_to_check, protein_suffix, proteintech):
    """Integrate protein expression data into the AnnData object.

    Supports two protein naming formats:
    - Standard (default): CD16-TotalSeqC → CD16ADT (use --protein_suffix)
    - ProteinTech: prot:CD16.65090.1 → CD16ADT (use --proteintech flag)

    Args:
        adata: Query object whose ``obsm["protein_expression"]`` is filled in.
            If that key is already present, ``adata`` is returned untouched.
        protein_adata: Either a DataFrame of protein expression, or an object
            providing ``obs`` and ``to_df()``; in the latter case its ``obs``
            is replaced by ``meta`` before conversion.
        meta: Per-cell metadata assigned to ``protein_adata.obs``.
        proteins_to_check: Renamed protein names to drop from the result.
        protein_suffix: Text replaced by ``'ADT'`` in standard-format names.
            It is matched literally, so characters such as ``.`` carry no
            regex meaning.
        proteintech: When truthy, apply the ProteinTech renaming instead.

    Returns:
        ``adata``, with ``obsm["protein_expression"]`` set to the renamed
        columns minus ``proteins_to_check``.
    """
    if "protein_expression" in adata.obsm:
        return adata

    if isinstance(protein_adata, pd.DataFrame):
        expression = protein_adata
    else:
        protein_adata.obs = meta
        expression = protein_adata.to_df()

    cols = expression.columns
    if proteintech:
        # Handle ProteinTech format: prot:CD16.65090.1 → CD16ADT
        cols = cols.str.replace('prot:', '', regex=False).str.split('.').str[0] + 'ADT'
        print("ProteinTech format detected. Renamed protein columns to ADT format.")
    else:
        # Handle standard format: CD16-TotalSeqC → CD16ADT
        cols = cols.str.replace(protein_suffix, 'ADT', regex=False)

    # set_axis returns a new frame, so the caller's DataFrame keeps its names.
    expression = expression.set_axis(cols, axis=1)
    print(f"Protein names after renaming: {cols.tolist()}")

    kept = expression.columns.difference(proteins_to_check)
    adata.obsm["protein_expression"] = expression[kept]
    return adata
def initialize_protein_data(adata, ref):
    """Store an all-zero protein expression table on ``adata``.

    The table has one row per query cell (indexed by ``adata.obs_names``) and
    one column per column of ``ref.obsm["protein_expression"]``.

    Returns:
        ``adata``, with ``obsm["protein_expression"]`` replaced.
    """
    reference_proteins = ref.obsm["protein_expression"]
    zeros = np.zeros((adata.n_obs, reference_proteins.shape[1]))
    placeholder = pd.DataFrame(
        data=zeros,
        index=adata.obs_names,
        columns=reference_proteins.columns,
    )
    adata.obsm["protein_expression"] = placeholder
    return adata
def prepare_adata_for_totalvi(adata, batch, ref, output_dir, patient, mouse):
    """Prepare AnnData for TOTALVI.

    Steps, in order: copy ``adata.obs[batch]`` into ``obs["batch"]``;
    optionally drop genes whose name starts with ``mm10``; keep the current
    ``X`` in ``layers["counts"]``; normalise to a total of 1e4 and log1p
    transform; store the result as ``raw``; set ``obs["celltype.l2"]`` to
    ``"Unknown"``; align protein columns to the reference; tag reference and
    query with ``obs["dataset_name"]``; write ``<patient>_prepped.h5ad`` into
    ``output_dir``.

    Args:
        adata: Query AnnData.
        batch: Name of the ``adata.obs`` column to use as batch.
        ref: Reference AnnData; its ``obs["dataset_name"]`` is set to
            ``"Reference"``.
        output_dir: Directory the prepared file is written to.
        patient: Prefix for the output file name.
        mouse: When truthy, genes whose name starts with ``mm10`` are removed.

    Returns:
        The prepared AnnData. With ``mouse`` set this is a new object, not the
        one passed in.

    Raises:
        KeyError: If ``batch`` is not a column of ``adata.obs``.
    """
    if batch not in adata.obs.columns:
        raise KeyError(f"Batch column '{batch}' not found in adata.obs")
    adata.obs['batch'] = adata.obs[batch]

    if mouse:
        print("Filtering out mouse genes (mm10)...")
        # A local mask avoids writing a temporary column into adata.var, and
        # the explicit copy avoids modifying a view below.
        is_mouse = adata.var_names.str.startswith('mm10')
        adata = adata[:, ~is_mouse].copy()

    adata.layers["counts"] = adata.X.copy()
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    adata.raw = adata
    adata.obs["celltype.l2"] = "Unknown"

    align_protein_data(adata, ref, output_dir, patient)

    ref.obs["dataset_name"] = "Reference"
    adata.obs["dataset_name"] = "Query"

    prepped_file = os.path.join(output_dir, f'{patient}_prepped.h5ad')
    adata.write(prepped_file)
    return adata
def align_protein_data(adata, ref, output_dir, patient):
    """Reindex the query's protein columns to match the reference's.

    Columns present in ``ref.obsm["protein_expression"]`` but absent from the
    query are added and filled with ``0.0``; query columns not in the
    reference are dropped. ``output_dir`` and ``patient`` are not used.

    Returns:
        ``adata``, with ``obsm["protein_expression"]`` replaced.
    """
    print(adata.obsm.keys())
    query_proteins = adata.obsm["protein_expression"]
    reference_columns = ref.obsm["protein_expression"].columns
    aligned = query_proteins.reindex(columns=reference_columns, fill_value=0.0)
    adata.obsm["protein_expression"] = aligned
    return adata
def train_totalvi_model(adata, ref_model, ref, adversarial_classifier, max_epochs=150):
    """Train the TOTALVI model on the query data.

    Loads the reference model, prepares ``adata`` as query data, trains a
    query model and stores its latent representation in
    ``adata.obsm["X_totalvi_scarches"]``.

    Args:
        adata: Query AnnData; modified in place.
        ref_model: Path passed to ``TOTALVI.load``.
        ref: Reference AnnData passed to ``TOTALVI.load``.
        adversarial_classifier: One of the strings ``"None"``, ``"True"`` or
            ``"False"``; converted to the matching Python value and forwarded
            to ``train``.
        max_epochs: Number of training epochs forwarded to ``train``.

    Returns:
        The trained query model.

    Raises:
        ValueError: If ``adversarial_classifier`` is not one of the accepted
            strings, if preparing the query AnnData fails, or if the latent
            representation cannot be generated.
    """
    # Convert string input to Python types. This is checked first so that an
    # invalid value fails before any model is loaded.
    adversarial_options = {"None": None, "True": True, "False": False}
    if adversarial_classifier not in adversarial_options:
        raise ValueError(f"Invalid value for adversarial_classifier: {adversarial_classifier}")
    adv_classifier = adversarial_options[adversarial_classifier]

    print("Loading pre-trained TOTALVI model...")
    vae = TOTALVI.load(ref_model, ref)

    print("Preparing query AnnData for TOTALVI...")
    try:
        TOTALVI.prepare_query_anndata(adata, reference_model=vae)
        print(f"Keys in adata.obsm after prepare_query_anndata: {adata.obsm.keys()}")
    except Exception as e:
        raise ValueError(f"Failed to prepare query AnnData: {e}") from e

    print("Training TOTALVI query model...")
    vae_q = TOTALVI.load_query_data(adata, vae)

    vae_q.train(
        max_epochs=max_epochs,
        plan_kwargs=dict(weight_decay=0.0, scale_adversarial_loss=0.0),
        adversarial_classifier=adv_classifier
    )
    print(f"Training status: {vae_q.is_trained}")

    # Check latent space generation
    try:
        latent_rep = vae_q.get_latent_representation(adata)
        print(f"Shape of latent representation: {latent_rep.shape}")
        adata.obsm["X_totalvi_scarches"] = latent_rep
    except Exception as e:
        raise ValueError(f"Failed to generate latent space: {e}") from e

    print("Model training complete. Latent space created successfully.")
    return vae_q
def classify_latent_space(vae_q, adata, classifier_type):
    """Run the selected latent-space classifier on the query latent space.

    ``"BBC"`` uses ``vae_q.latent_space_classifer_bbc_`` and ``"BRF"`` uses
    ``vae_q.latent_space_classifer_brf_``. Both are applied to
    ``adata.obsm["X_totalvi_scarches"]``. Predicted labels ``"ML1"`` and
    ``"ML2"`` are reported as ``"eML1"`` and ``"eML2"``.

    Returns:
        Tuple ``(predictions, probs)`` from ``predict`` (after relabelling)
        and ``predict_proba``.

    Raises:
        ValueError: If ``classifier_type`` is neither ``"BBC"`` nor ``"BRF"``.
    """
    if classifier_type == "BBC":
        classifier = vae_q.latent_space_classifer_bbc_
    elif classifier_type == "BRF":
        classifier = vae_q.latent_space_classifer_brf_
    else:
        raise ValueError(f"Unknown classifier type '{classifier_type}'")

    predictions = classifier.predict(adata.obsm["X_totalvi_scarches"])
    probs = classifier.predict_proba(adata.obsm["X_totalvi_scarches"])

    for old_label, new_label in (("ML1", "eML1"), ("ML2", "eML2")):
        predictions = np.where(predictions == old_label, new_label, predictions)

    print("Classifier predictions completed")
    return predictions, probs
def save_results(adata, predictions, probs, output_dir, patient, vae_q, classifier_type, mouse):
    """Write classifier outputs to disk and record them on ``adata.obs``.

    Produces, inside ``output_dir``:
    - ``<patient>_probabilities<classifier_type>output.csv`` with one column
      per class of the selected classifier;
    - ``<patient>_eMLclassified_adata.h5ad`` with the annotated query;
    - ``<patient>_vae_model_withclassifiers`` via ``vae_q.save``.

    ``adata.obs`` gains four probability columns (``CD56bright``, ``CD56dim``,
    ``eML1``, ``eML2``, each followed by ``<classifier_type>prob``) and a
    ``predictions<classifier_type>`` column. The ``eML1``/``eML2`` columns are
    read from the classifier's ``ML1``/``ML2`` classes.
    """
    classifier = getattr(vae_q, f"latent_space_classifer_{classifier_type.lower()}_")
    probability_table = pd.DataFrame(probs, columns=classifier.classes_, index=adata.obs_names)

    # Save probabilities to CSV
    csv_path = os.path.join(output_dir, f'{patient}_probabilities{classifier_type}output.csv')
    probability_table.to_csv(csv_path)

    per_cell = probability_table.loc[adata.obs_names]

    # (obs column prefix, classifier class name)
    column_sources = (
        ("CD56bright", "CD56bright"),
        ("CD56dim", "CD56dim"),
        ("eML1", "ML1"),
        ("eML2", "ML2"),
    )
    for obs_prefix, class_name in column_sources:
        adata.obs[f"{obs_prefix}{classifier_type}prob"] = per_cell[class_name]

    adata.obs[f"predictions{classifier_type}"] = predictions

    print(adata.obs.columns)
    print(f"Classifier type: {classifier_type}")

    # Drop a leftover 'mouse' column from adata.var before saving.
    if mouse and 'mouse' in adata.var:
        del adata.var['mouse']

    classified_file = os.path.join(output_dir, f'{patient}_eMLclassified_adata.h5ad')
    adata.write_h5ad(classified_file)
    print("Saved the classified AnnData object")

    vae_model_file = os.path.join(output_dir, f'{patient}_vae_model_withclassifiers')
    vae_q.save(vae_model_file, overwrite=True)
    print("Saved the updated VAE model")
def classify_cells(adata, classifier_type, output_dir, patient):
    """Assign an ``NK_type`` label to every cell from its class probabilities.

    Reads the four ``<label><classifier_type>prob`` columns of ``adata.obs``
    (labels ``CD56bright``, ``CD56dim``, ``eML1``, ``eML2``), coercing them to
    numeric first. For each cell:

    - if all four probabilities are below 0.5 and the eML1 and eML2
      probabilities sum to more than 0.5, the label is ``'eML_transition'``;
    - if all four are below 0.5 otherwise, the label is ``'unclassified'``;
    - in every other case the label with the highest probability is used.

    The result is stored in ``adata.obs['NK_type']`` and the object is written
    to ``<patient>_eMLclassified_adata.h5ad`` in ``output_dir``.

    Returns:
        ``adata`` with the new ``NK_type`` column.
    """
    probability_columns = [f'CD56bright{classifier_type}prob',
                           f'CD56dim{classifier_type}prob',
                           f'eML1{classifier_type}prob',
                           f'eML2{classifier_type}prob']
    labels = ['CD56bright', 'CD56dim', 'eML1', 'eML2']
    eml1_column = f'eML1{classifier_type}prob'
    eml2_column = f'eML2{classifier_type}prob'

    # Ensure probability columns are numeric
    for col in probability_columns:
        adata.obs[col] = pd.to_numeric(adata.obs[col], errors='coerce')

    def classify_row(row):
        if all(row[col] < 0.5 for col in probability_columns):
            if row[eml1_column] + row[eml2_column] > 0.5:
                return 'eML_transition'
            return 'unclassified'
        max_index = row[probability_columns].astype(float).idxmax()
        return labels[probability_columns.index(max_index)]

    # Only the four probability columns are needed per row, so apply over that
    # slice rather than over every obs column.
    adata.obs['NK_type'] = adata.obs[probability_columns].apply(classify_row, axis=1)

    # Save the classified AnnData object
    classified_file = os.path.join(output_dir, f'{patient}_eMLclassified_adata.h5ad')
    adata.write_h5ad(classified_file)
    return adata
def main():
    """Run the full pipeline: parse arguments, load, preprocess, train, classify, save.

    All outputs are written into the ``--output_dir`` directory, which is
    created if needed. The parsed arguments are also saved there as
    ``<patient>_arguments_used.txt``.
    """
    args = parse_arguments()

    # Create the output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Save arguments to a file for later review
    args_file = os.path.join(args.output_dir, f'{args.patient}_arguments_used.txt')
    with open(args_file, 'w') as f:
        for key, value in vars(args).items():
            f.write(f"{key}: {value}\n")
    print(f"Arguments saved to {args_file}")

    # Load proteins from the specified file
    proteins_to_check = load_proteins_from_file(args.proteins_file)

    # Validate file paths
    validate_files(args.adata_file, args.RNApath, args.metapath, args.umappath,
                   args.ADTpath, args.protein)

    # Get model paths
    ref_model, ref_adata = get_model_path(args.ref_model, args.ref_adata)

    # Load data
    adata, ref, protein_adata = load_data(args.adata_file, ref_adata, args.RNApath,
                                          args.metapath, args.umappath, args.ADTpath,
                                          args.protein)

    # Snapshot of the query metadata. Passing the copy keeps the protein
    # object from sharing the live adata.obs frame.
    meta = adata.obs.copy()

    # Preprocess data
    adata = preprocess_data(adata, args.protein, protein_adata, ref, meta, args.batch,
                            proteins_to_check, args.protein_suffix, args.proteintech,
                            args.output_dir, args.patient, args.mouse)

    # Train the model
    vae_q = train_totalvi_model(adata, ref_model, ref, args.adversarial_classifier)

    # Classify the data
    predictions, probs = classify_latent_space(vae_q, adata, args.classifier_type)

    # Save results
    save_results(adata, predictions, probs, args.output_dir, args.patient, vae_q,
                 args.classifier_type, args.mouse)

    # Run classification by default unless disabled
    if not args.disable_NK_type:
        adata = classify_cells(adata, args.classifier_type, args.output_dir, args.patient)
        print(adata.obs["NK_type"].value_counts())
    else:
        print("NK Cell classification v1.1 step skipped.")

    print("All output files are saved in output directory")
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()