{ "nbformat": 4, "nbformat_minor": 5, "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" } }, "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# scBERT — Foundation Model Tutorial\n", "\n", "**scBERT** — Compact 200-dim embeddings, BERT-style masked gene pretraining, lightweight model\n", "\n", "| Property | Value |\n", "|----------|-------|\n", "| **Tasks** | embed, integrate |\n", "| **Species** | human |\n", "| **Gene IDs** | symbol |\n", "| **GPU Required** | Yes |\n", "| **Min VRAM** | 8 GB |\n", "| **Embedding Dim** | 200 |\n", "| **Repository** | [https://github.com/TencentAILabHealthcare/scBERT](https://github.com/TencentAILabHealthcare/scBERT) |\n", "", "\n", "This tutorial demonstrates how to use **scBERT** through the unified `ov.fm` API.\n", "\n", "**Cite:** Zeng, Z. et al. (2024). OmicVerse: a framework for bridging and deepening insights across bulk and single-cell sequencing. *Nature Communications*, 15(1), 5983." ] }, { "cell_type": "code", "metadata": {}, "source": [ "import omicverse as ov\n", "import scanpy as sc\n", "import os\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "ov.plot_set()" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Why scBERT?\n", "\n", "scBERT stands out for its **compact 200-dimensional embeddings** — the smallest among all foundation models in `ov.fm`. 
This makes it ideal for:\n\n- **Memory-constrained environments** — 200-dim embeddings take roughly 2.5x less RAM than 512-dim ones (200 vs. 512 values per cell)\n- **Fast downstream clustering** — lower dimensionality speeds up neighbor-graph construction and UMAP\n- **Lightweight deployment** — only 8 GB VRAM required, with CPU fallback support\n\nDuring BERT-style masked gene pretraining, scBERT hides the expression values of randomly chosen genes and learns to predict them from the other genes in the same cell, much as BERT learns language by predicting masked words. This is how it picks up gene co-expression patterns." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1: Inspect Model Specification\n", "\n", "Use `ov.fm.describe_model()` to retrieve scBERT's full specification, including its input contract (gene ID scheme, preprocessing) and output contract (embedding key and dimension)." ] }, { "cell_type": "code", "metadata": {}, "source": [ "info = ov.fm.describe_model(\"scbert\")\n", "\n", "print(\"=== Model Info ===\")\n", "print(f\"Name: {info['model']['name']}\")\n", "print(f\"Version: {info['model']['version']}\")\n", "print(f\"Tasks: {info['model']['tasks']}\")\n", "print(f\"Species: {info['model']['species']}\")\n", "print(f\"Embedding dim: {info['model']['embedding_dim']}\")\n", "print(f\"Differentiator: {info['model']['differentiator']}\")\n", "\n", "print(\"\\n=== Input Contract ===\")\n", "print(f\"Gene ID scheme: {info['input_contract']['gene_id_scheme']}\")\n", "print(f\"Preprocessing: {info['input_contract']['preprocessing']}\")\n", "\n", "print(\"\\n=== Output Contract ===\")\n", "print(f\"Embedding key: {info['output_contract']['embedding_key']}\")\n", "print(f\"Embedding dim: {info['output_contract']['embedding_dim']}\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2: Prepare Data\n", "\n", "Load a dataset and save it as an `.h5ad` file, since the `ov.fm` functions below read their input from a file path. Most foundation models expect raw counts (non-negative integers) rather than normalized or log-transformed values."
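, "\n", "\n", "Before handing a file to `ov.fm`, you may want a quick sanity check that `X` holds raw counts. The next cell is a minimal heuristic sketch, not part of the `ov.fm` API: the helper `looks_like_raw_counts` is hypothetical and only tests that every stored value is non-negative and integer-valued, demonstrated on a toy matrix and its `log1p` transform. `ov.fm.preprocess_validate()` in Step 3 remains the authoritative check; once `adata` is loaded below, the same helper can be called on `adata.X`." ] }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "import numpy as np\n", "import scipy.sparse as sp\n", "\n", "def looks_like_raw_counts(X, tol=1e-6):\n", "    # Hypothetical helper (not part of ov.fm): a heuristic raw-count check.\n", "    # For sparse input only the stored entries matter; implicit zeros are valid counts.\n", "    values = X.data if sp.issparse(X) else np.asarray(X).ravel()\n", "    if values.size == 0:\n", "        return True\n", "    non_negative = bool((values >= 0).all())\n", "    integer_valued = bool(np.allclose(values, np.round(values), atol=tol))\n", "    return non_negative and integer_valued\n", "\n", "# Toy example: a raw-count matrix and its log1p-transformed copy\n", "counts = sp.csr_matrix(np.array([[0, 3, 1], [5, 0, 2]], dtype=np.float32))\n", "logged = counts.copy()\n", "logged.data = np.log1p(logged.data)\n", "print(looks_like_raw_counts(counts))  # True\n", "print(looks_like_raw_counts(logged))  # False"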
] }, { "cell_type": "code", "metadata": {}, "source": [ "adata = sc.datasets.pbmc3k()\n", "adata.var_names_make_unique()  # scBERT matches genes by symbol, so keep symbols unique\n", "sc.pp.filter_cells(adata, min_genes=200)\n", "sc.pp.filter_genes(adata, min_cells=3)\n", "print(f'Dataset: {adata.n_obs} cells x {adata.n_vars} genes')\n", "print(f'Gene names (first 5): {adata.var_names[:5].tolist()}')\n", "# Raw counts: the minimum should be 0 and the maximum a whole number\n", "print(f'X range: [{adata.X.min():.1f}, {adata.X.max():.1f}]')\n", "\n", "adata.write_h5ad('pbmc3k_scbert.h5ad')\n" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3: Profile Data & Validate Compatibility\n", "\n", "Profile the saved file (species, gene ID scheme, modality, size) and check that it is compatible with scBERT before running inference." ] }, { "cell_type": "code", "metadata": {}, "source": [ "profile = ov.fm.profile_data(\"pbmc3k_scbert.h5ad\")\n", "\n", "print(\"=== Data Profile ===\")\n", "print(f\"Species: {profile['species']}\")\n", "print(f\"Gene scheme: {profile['gene_scheme']}\")\n", "print(f\"Modality: {profile['modality']}\")\n", "print(f\"Cells: {profile['n_cells']:,}\")\n", "print(f\"Genes: {profile['n_genes']:,}\")\n", "\n", "# Validate compatibility for the embedding task\n", "validation = ov.fm.preprocess_validate(\"pbmc3k_scbert.h5ad\", \"scbert\", \"embed\")\n", "print(f\"\\n=== Validation: {validation['status']} ===\")\n", "for d in validation.get(\"diagnostics\", []):\n", "    print(f\"  [{d['severity']}] {d['message']}\")\n", "if validation.get(\"auto_fixes\"):\n", "    print(\"\\nSuggested fixes:\")\n", "    for fix in validation[\"auto_fixes\"]:\n", "        print(f\"  - {fix}\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 4: Run scBERT Inference\n", "\n", "Execute scBERT through `ov.fm.run()`. The function handles preprocessing, model loading, inference, and output writing."
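, "\n", "\n", "For a symbol-based model like scBERT, one part of that preprocessing is mapping the dataset's genes onto the model's fixed, ordered gene vocabulary. The next cell sketches the idea under stated assumptions and is not necessarily how `ov.fm.run()` implements it: `align_to_vocab` and the toy `vocab` list are hypothetical. Vocabulary genes missing from the data are zero-filled, genes outside the vocabulary are dropped, and the returned coverage fraction hints at how much of the model's vocabulary your data actually supplies." ] }, { "cell_type": "code", "metadata": {}, "execution_count": null, "outputs": [], "source": [ "import numpy as np\n", "import scipy.sparse as sp\n", "\n", "def align_to_vocab(X, genes, vocab):\n", "    # Hypothetical helper (not part of ov.fm): reorder the columns of a\n", "    # cells x genes matrix to follow `vocab`, zero-filling absent genes.\n", "    col_of = {g: i for i, g in enumerate(genes)}\n", "    src = np.array([col_of[g] for g in vocab if g in col_of], dtype=np.int64)\n", "    dst = np.array([j for j, g in enumerate(vocab) if g in col_of], dtype=np.int64)\n", "    # 0/1 selection matrix: column dst[k] of the result copies column src[k] of X\n", "    S = sp.csr_matrix((np.ones(src.size), (src, dst)), shape=(len(genes), len(vocab)))\n", "    aligned = sp.csr_matrix(X) @ S\n", "    coverage = src.size / max(len(vocab), 1)\n", "    return aligned, coverage\n", "\n", "# Toy example: MS4A1 is not in the vocabulary; GNLY and NKG7 are not in the data\n", "genes = ['CD3E', 'MS4A1', 'LYZ']\n", "vocab = ['LYZ', 'CD3E', 'GNLY', 'NKG7']\n", "X = np.array([[1, 0, 4], [0, 2, 0]])\n", "aligned, coverage = align_to_vocab(X, genes, vocab)\n", "print(aligned.toarray())\n", "print(f'vocab coverage: {coverage:.0%}')  # vocab coverage: 50%"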
] }, { "cell_type": "code", "metadata": {}, "source": [ "result = ov.fm.run(\n", "    task=\"embed\",\n", "    model_name=\"scbert\",\n", "    adata_path=\"pbmc3k_scbert.h5ad\",\n", "    output_path=\"pbmc3k_scbert_out.h5ad\",\n", "    device=\"auto\",\n", ")\n", "\n", "if \"error\" in result:\n", "    print(f\"Error: {result['error']}\")\n", "    if \"suggestion\" in result:\n", "        print(f\"Suggestion: {result['suggestion']}\")\n", "else:\n", "    print(f\"Status: {result['status']}\")\n", "    print(f\"Output keys: {result.get('output_keys', [])}\")\n", "    print(f\"Cells processed: {result.get('n_cells', 0)}\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 5: Visualize & Interpret Results\n", "\n", "Load the output file, build a neighbor graph, UMAP, and Leiden clusters from the scBERT embeddings, then compute embedding-quality metrics with `ov.fm.interpret_results()`." ] }, { "cell_type": "code", "metadata": {}, "source": [ "if os.path.exists(\"pbmc3k_scbert_out.h5ad\"):\n", "    adata_out = sc.read_h5ad(\"pbmc3k_scbert_out.h5ad\")\n", "    emb_key = \"X_scBERT\"\n", "\n", "    if emb_key in adata_out.obsm:\n", "        print(f\"Embedding shape: {adata_out.obsm[emb_key].shape}\")\n", "\n", "        # Neighbor graph, UMAP, and Leiden clustering on the scBERT embedding\n", "        sc.pp.neighbors(adata_out, use_rep=emb_key)\n", "        sc.tl.umap(adata_out)\n", "        sc.tl.leiden(adata_out, resolution=0.5)\n", "        sc.pl.umap(adata_out, color=[\"leiden\"],\n", "                   title=\"scBERT Embedding (PBMC 3k)\")\n", "\n", "        # QA metrics; .get() avoids a KeyError if no metrics are returned\n", "        interpretation = ov.fm.interpret_results(\"pbmc3k_scbert_out.h5ad\", task=\"embed\")\n", "        metrics = interpretation.get(\"metrics\", {})\n", "        if \"embeddings\" in metrics:\n", "            for k, v in metrics[\"embeddings\"].items():\n", "                print(f\"\\n{k}: dim={v['dim']}\", end=\"\")\n", "                if \"silhouette\" in v:\n", "                    print(f\", silhouette={v['silhouette']:.4f}\", end=\"\")\n", "                print()\n", "    else:\n", "        print(f\"Embedding key {emb_key} not found.\")\n", "        print(f\"Available keys: {list(adata_out.obsm.keys())}\")\n", "else:\n", "    print(\"Output file not found — check model installation and adapter status.\")\n", "    print(\"See the Guide page for installation instructions.\")" ], "execution_count": null, "outputs": [] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "| Step | Function | What it does |\n", "|------|----------|-------------|\n", "| 1 | `ov.fm.describe_model(\"scbert\")` | Inspect model spec and I/O contract |\n", "| 2 | `sc.datasets.pbmc3k()` | Load, filter, and save input data |\n", "| 3 | `ov.fm.profile_data()` + `ov.fm.preprocess_validate()` | Check compatibility |\n", "| 4 | `ov.fm.run()` | Execute scBERT inference |\n", "| 5 | `ov.fm.interpret_results()` | Evaluate embedding quality |\n", "\n", "For the full model catalog, see `ov.fm.list_models()` or the [ov.fm API Overview](t_fm_guide.md).\n", "For detailed scBERT specifications, see the [scBERT Guide](t_fm_scbert_guide.md)." ] } ] }