{ "cells": [ { "cell_type": "markdown", "id": "ee8c5b91-0e7f-47ca-9a5d-a0911af7ebcf", "metadata": {}, "source": [ "# CellPLM\n", "\n", "According to its authors, CellPLM is the first single-cell pre-trained language model that encodes cell-cell relations. They report that it consistently outperforms both pre-trained and non-pre-trained models across diverse downstream tasks, with 100x faster inference than earlier pre-trained models.\n", "\n", "In omicverse, you can call this model directly through `omicverse.llm.SCLLMManager(model_type=\"cellplm\")`.\n", "\n", "Cite: Wen, H., Tang, W., Dai, X., Ding, J., Jin, W., Xie, Y., & Tang, J. (2023). CellPLM: Pre-training of cell language model beyond single cells. bioRxiv, 2023-10." ] }, { "cell_type": "code", "execution_count": 1, "id": "63de612a-31ac-4bd7-b332-51bb4c854084", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🔬 Starting plot initialization...\n", "Downloading Arial font from GitHub...\n", "Arial font downloaded successfully to: /tmp/omicverse_arial.ttf\n", "Registered as: Arial\n", "🧬 Detecting CUDA devices…\n", "✅ [GPU 0] NVIDIA H100 80GB HBM3\n", " • Total memory: 79.1 GB\n", " • Compute capability: 9.0\n", "\n", " ____ _ _ __ \n", " / __ \\____ ___ (_)___| | / /__ _____________ \n", " / / / / __ `__ \\/ / ___/ | / / _ \\/ ___/ ___/ _ \\ \n", "/ /_/ / / / / / / / /__ | |/ / __/ / (__ ) __/ \n", "\\____/_/ /_/ /_/_/\\___/ |___/\\___/_/ /____/\\___/ \n", "\n", "🔖 Version: 1.7.6rc1 📚 Tutorials: https://omicverse.readthedocs.io/\n", "✅ plot_set complete.\n", "\n" ] } ], "source": [ "import scanpy as sc\n", "import omicverse as ov\n", "ov.plot_set(font_path='Arial')\n", "\n", "# Enable auto-reload for development\n", "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "markdown", "id": "9cb24e71-c2a2-417e-92ac-829559f30328", "metadata": {}, "source": [ "## Load example datasets\n", "\n", "For this tutorial, we use three batches from the NeurIPS 2021 single-cell competition dataset, 
which provides an excellent test case for batch integration and cell type annotation.\n", "\n", "- s1d3: https://figshare.com/ndownloader/files/41932005\n", "- s2d1: https://figshare.com/ndownloader/files/41932011\n", "- s3d7: https://figshare.com/ndownloader/files/41932008" ] }, { "cell_type": "code", "execution_count": 2, "id": "002098b5-964a-4156-a74c-326d8681e421", "metadata": { "scrolled": true }, "outputs": [], "source": [ "adata1=ov.read('data/neurips2021_s1d3.h5ad')\n", "adata1.obs['batch']='s1d3'\n", "adata2=ov.read('data/neurips2021_s2d1.h5ad')\n", "adata2.obs['batch']='s2d1'\n", "adata3=ov.read('data/neurips2021_s3d7.h5ad')\n", "adata3.obs['batch']='s3d7'" ] }, { "cell_type": "code", "execution_count": 3, "id": "4c57596f-8c33-471e-b795-1751503fca41", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AnnData object with n_obs × n_vars = 27423 × 13953\n", " obs: 'GEX_n_genes_by_counts', 'GEX_pct_counts_mt', 'GEX_size_factors', 'GEX_phase', 'ADT_n_antibodies_by_counts', 'ADT_total_counts', 'ADT_iso_count', 'cell_type', 'batch', 'ADT_pseudotime_order', 'GEX_pseudotime_order', 'Samplename', 'Site', 'DonorNumber', 'Modality', 'VendorLot', 'DonorID', 'DonorAge', 'DonorBMI', 'DonorBloodType', 'DonorRace', 'Ethnicity', 'DonorGender', 'QCMeds', 'DonorSmoker', 'is_train'\n", " var: 'feature_types', 'gene_id'\n", " obsm: 'ADT_X_pca', 'ADT_X_umap', 'ADT_isotype_controls', 'GEX_X_pca', 'GEX_X_umap'\n", " layers: 'counts'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "adata=sc.concat([adata1,adata2,adata3],merge='same')\n", "adata" ] }, { "cell_type": "code", "execution_count": null, "id": "883a75f3-2acb-42c5-9431-e0b67699b1a3", "metadata": { "scrolled": true }, "outputs": [], "source": [ "adata=ov.pp.preprocess(adata,mode='shiftlog|pearson',\n", " n_HVGs=3000,batch_key=None,target_sum=1e4)\n", "adata" ] }, { "cell_type": "code", "execution_count": 6, "id": "13035e0e-6472-4188-802c-9fd2c40eaf94", "metadata": {}, 
"outputs": [ { "data": { "text/plain": [ "View of AnnData object with n_obs × n_vars = 27423 × 3000\n", " obs: 'GEX_n_genes_by_counts', 'GEX_pct_counts_mt', 'GEX_size_factors', 'GEX_phase', 'ADT_n_antibodies_by_counts', 'ADT_total_counts', 'ADT_iso_count', 'cell_type', 'batch', 'ADT_pseudotime_order', 'GEX_pseudotime_order', 'Samplename', 'Site', 'DonorNumber', 'Modality', 'VendorLot', 'DonorID', 'DonorAge', 'DonorBMI', 'DonorBloodType', 'DonorRace', 'Ethnicity', 'DonorGender', 'QCMeds', 'DonorSmoker', 'is_train'\n", " var: 'feature_types', 'gene_id', 'n_cells', 'percent_cells', 'robust', 'means', 'variances', 'residual_variances', 'highly_variable_rank', 'highly_variable_features'\n", " uns: 'log1p', 'hvg', 'status', 'status_args', 'REFERENCE_MANU'\n", " obsm: 'ADT_X_pca', 'ADT_X_umap', 'ADT_isotype_controls', 'GEX_X_pca', 'GEX_X_umap'\n", " layers: 'counts'" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "adata.raw = adata\n", "adata = adata[:, adata.var.highly_variable_features]\n", "adata" ] }, { "cell_type": "markdown", "id": "866a402d-d5d9-4949-bdab-300a03705a78", "metadata": {}, "source": [ "## Initialize CellPLM model\n", "\n", "CellPLM requires pre-trained model checkpoints that contain the transformer weights and tokenization vocabulary. 
The model supports multiple pipeline tasks, including embedding extraction, cell type annotation, and data imputation.\n", "\n", "Download the CellPLM model checkpoints from: https://www.dropbox.com/scl/fo/i5rmxgtqzg7iykt2e9uqm/h/ckpt?dl=0&subfolder_nav_tracking=1" ] }, { "cell_type": "code", "execution_count": 4, "id": "f8c07bc9-6bed-4553-8bae-ef95cbb95797", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "📥 Loading CellPLM model from llm_model/models/cellplm...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading imputation pipeline...: 100%|█████████████████████████████████| 3/3 [00:03<00:00, 1.21s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[Loaded] CellPLM model loaded from llm_model/models/cellplm\n", " - Loaded pipelines: annotation, embedding, imputation\n", " - Pretrain version: 20231027_85M\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "manager = ov.llm.SCLLMManager(\n", " model_type=\"cellplm\",\n", " model_path=\"llm_model/models/cellplm\",\n", " pretrain_version=\"20231027_85M\"\n", ")" ] }, { "cell_type": "markdown", "id": "f6e32259-fe0d-4df6-b374-4272b0461b3f", "metadata": {}, "source": [ "## Zero-shot embedding generation\n", "\n", "Zero-shot embedding leverages CellPLM's pre-trained knowledge to generate meaningful cell representations without any dataset-specific training.\n", "\n", "The 512-dimensional embeddings provide a compressed but information-rich representation of each cell's transcriptional state." 
] }, { "cell_type": "code", "execution_count": 7, "id": "e43d863e-0c83-4684-801b-8a4e78a1bb0a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[🔬Cells] Data Summary:\n", " Cells: 27,423\n", " Genes: 3,000\n", " Batches: 3\n", " s3d7: 11,230 cells\n", " s2d1: 10,258 cells\n", " s1d3: 5,935 cells\n", "[Embedding] Starting get_embeddings...\n", " cells: 27,423\n", " genes: 3,000\n", " [Embedding] Extracting embeddings for 27,423 cells...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[CellPLM] Computing embeddings...: 0%| | 0/2 [00:00" ] }, "metadata": { "image/png": { "height": 287, "width": 1163 } }, "output_type": "display_data" } ], "source": [ "print(f\"embedding: {embeddings.shape}\")\n", "\n", "adata.obsm['X_cellplm'] = embeddings\n", "\n", "sc.pp.neighbors(adata, use_rep='X_cellplm')\n", "sc.tl.umap(adata)\n", "ov.pl.embedding(\n", " adata, \n", " basis='X_umap',\n", " color=['batch', 'cell_type']\n", ")" ] }, { "cell_type": "markdown", "id": "709932be-af98-4341-bce4-a43dba7d2a36", "metadata": {}, "source": [ "## Fine-tuning for enhanced performance\n", "\n", "Fine-tuning adapts CellPLM's pre-trained weights to the specific characteristics of our dataset, significantly improving performance on downstream tasks. 
This supervised step uses the cell type annotations from the reference batch (s1d3) as labels for the annotation task." ] }, { "cell_type": "code", "execution_count": 13, "id": "dcd84bff-041f-4ef1-bd4e-d9c826d39b8f", "metadata": {}, "outputs": [], "source": [ "# Copy the slice so that adding a column below does not write to a view of adata\n", "reference_adata=adata[adata.obs['batch']=='s1d3'].copy()" ] }, { "cell_type": "code", "execution_count": 14, "id": "241ef024-e1bc-48db-b3e6-2212de48525b", "metadata": {}, "outputs": [], "source": [ "reference_adata.obs['celltype']=reference_adata.obs['cell_type'].copy()" ] }, { "cell_type": "code", "execution_count": 33, "id": "ed349cba-8d10-43ef-bd0c-4365a71a3392", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🚀 Starting CellPLM fine-tuning for annotation task...\n", "Training parameters: epochs=500, batch_size=32, lr=0.0001\n", "📊 Preparing cell type mapping...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Data preparation: 100%|█████████████████████████████████████████████| 3/3 [00:00<00:00, 4482.69it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Found 30 cell types: ['B1 B IGKC+', 'B1 B IGKC-', 'CD4+ T activated', 'CD4+ T activated integrinB7+', 'CD4+ T naive', 'CD8+ T CD49f+', 'CD8+ T CD57+ CD45RA+', 'CD8+ T CD69+ CD45RA+', 'CD8+ T CD69+ CD45RO+', 'CD8+ T TIGIT+ CD45RO+', 'CD8+ T naive', 'CD14+ Mono', 'CD16+ Mono', 'Erythroblast', 'G/M prog', 'HSC', 'Lymph prog', 'MAIT', 'MK/E prog', 'NK', 'Naive CD20+ B IGKC+', 'Naive CD20+ B IGKC-', 'Normoblast', 'Plasma cell IGKC+', 'Proerythroblast', 'Reticulocyte', 'T reg', 'Transitional B', 'cDC2', 'pDC']\n", "🔄 Preparing training and validation data...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Copying dataset: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 8.89it/s]\n", "Splitting data: 100%|████████████████████████████████████████████████| 3/3 [00:00<00:00, 257.23it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Split 
data: 4748 train, 1187 validation\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n", "WARNING:biothings.client:Input sequence provided is already in string format. No operation performed\n", "WARNING:biothings.client:Input sequence provided is already in string format. No operation performed\n", "INFO:biothings.client:querying 1-1000 ...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "🏋️ Starting training with CellPLM pipeline...\n", "📈 Training for 500 epochs with real-time metrics...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:biothings.client:querying 1001-2000 ...\n", "INFO:biothings.client:querying 2001-3000 ...\n", "INFO:biothings.client:Finished.\n", "WARNING:biothings.client:53 input query terms found dup hits:\t[('LINC00115', 2), ('TNFRSF14-AS1', 3), ('CCDC18-AS1', 2), ('LINC00623', 2), ('LINC02591', 2), ('TMC\n", "WARNING:biothings.client:292 input query terms found no hit:\t['AL139246.5', 'AL021155.5', 'BX284668.5', 'BX284668.6', 'AL020997.5', 'AL360012.1', 'AC004865.2', '\n", "INFO:biothings.client:Pass \"returnall=True\" to return complete lists of duplicate or missing query terms.\n", "100%|██████████| 500/500 [01:16<00:00, 6.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "After filtering, 2339 genes remain.\n", "\n", "📊 Final Training Results (Epoch 499):\n", " 🎯 Train ACC: 0.7077\n", " ✅ Valid ACC: 0.6934\n", " 📈 Train F1: 0.7066\n", " 📈 Valid F1: 0.6846\n", "✓ CellPLM annotation fine-tuning completed successfully!\n" ] } ], "source": [ "fine_tune_results = manager.model.fine_tune(\n", " train_adata=reference_adata,\n", " epochs=500, # fine-tuning epochs (500 matches the recorded output)\n", " batch_size=32, # training batch size\n", " lr=1e-4, # learning rate\n", ")" ] }, { "cell_type": "markdown", "id": "d851947b-4a0a-4f4d-9e4d-fc25dfb384d9", "metadata": {}, "source": [ "### Batch integration with fine-tuned model\n", "\n", "After fine-tuning, we perform batch integration to remove technical variations while preserving biological differences. 
This critical step ensures that cells from different batches can be properly compared and analyzed together." ] }, { "cell_type": "code", "execution_count": 34, "id": "f807cde7-b06d-4990-9887-0a782c0a7e2f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🔗 Performing batch integration for 27423 cells...\n", "🧬 Extracting embeddings for 27423 cells...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Computing embeddings...: 0%| | 0/2 [00:00" ] }, "metadata": { "image/png": { "height": 287, "width": 1163 } }, "output_type": "display_data" } ], "source": [ "sc.pp.neighbors(adata, use_rep='X_cellplm_fine')\n", "sc.tl.umap(adata)\n", "ov.pl.embedding(\n", " adata, \n", " basis='X_umap',\n", " color=['batch', 'cell_type']\n", ")" ] }, { "cell_type": "markdown", "id": "264119f7-402f-402a-8c53-5aa914077791", "metadata": {}, "source": [ "### Cell type annotation with fine-tuned model\n", "\n", "The fine-tuned CellPLM model can now predict cell types for all cells in the dataset, including those from batches not used in training. This demonstrates the model's ability to generalize learned patterns to new data while leveraging the improved discrimination capability gained through fine-tuning." 
] }, { "cell_type": "code", "execution_count": 36, "id": "b83b0287-b129-4ff4-ad56-b81457463567", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🔮 Predicting cell types for 27423 cells...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Running prediction...: 0%| | 0/2 [00:00" ] }, "metadata": { "image/png": { "height": 330, "width": 993 } }, "output_type": "display_data" } ], "source": [ "ov.pl.embedding(\n", " adata, \n", " basis='X_umap',\n", " color=['batch', 'predicted_celltype']\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "omicverse", "language": "python", "name": "omicverse" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.17" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 5 }