scBERT: Foundation Model Tutorial
scBERT — Compact 200-dim embeddings, BERT-style masked gene pretraining, lightweight model
| Property | Value |
|---|---|
| Tasks | embed, integrate |
| Species | human |
| Gene IDs | symbol |
| GPU Required | Yes |
| Min VRAM | 8 GB |
| Embedding Dim | 200 |
| Repository | https://github.com/TencentAILabHealthcare/scBERT |
This tutorial demonstrates how to use scBERT through the unified ov.fm API.
Cite: Zeng, Z. et al. (2024). OmicVerse: a framework for bridging and deepening insights across bulk and single-cell sequencing. Nature Communications, 15(1), 5983.
import omicverse as ov
import scanpy as sc
import os
import warnings
warnings.filterwarnings('ignore')
ov.plot_set()
Why scBERT?
scBERT stands out for its compact 200-dimensional embeddings, the smallest among the foundation models in ov.fm. This makes it well suited to:
- Memory-constrained environments: embeddings need about 2.5x less memory than 512-dim embeddings (200 vs. 512 values per cell)
- Fast downstream clustering: lower dimensionality speeds up neighbor-graph construction and UMAP
- Lightweight deployment: a minimum of 8 GB VRAM, with CPU fallback support
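The memory saving is plain arithmetic: a dense float32 embedding costs 4 bytes per dimension per cell, so its footprint grows linearly with the embedding dimension and the ratio between two models is just the ratio of their dimensions. The sketch below assumes float32 storage and a hypothetical 1,000,000-cell dataset purely for illustration.

```python
import numpy as np

def embedding_bytes(n_cells: int, dim: int, dtype=np.float32) -> int:
    """Bytes needed for a dense (n_cells x dim) embedding matrix."""
    return n_cells * dim * np.dtype(dtype).itemsize

n_cells = 1_000_000                      # assumed dataset size, for illustration only
small = embedding_bytes(n_cells, 200)    # scBERT-sized embedding
large = embedding_bytes(n_cells, 512)    # a 512-dim embedding

print(f"200-dim: {small / 1e9:.2f} GB")  # 0.80 GB
print(f"512-dim: {large / 1e9:.2f} GB")  # 2.05 GB
print(f"ratio:   {large / small:.2f}x")  # 2.56x
```

The ratio does not depend on the cell count; only the absolute sizes do.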
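In BERT-style masked gene pretraining, scBERT randomly masks the expression values of some genes in each cell and learns to predict them from the remaining genes, much as BERT predicts masked words from their sentence context. Solving this task pushes the model to capture gene co-expression patterns.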
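The masking objective can be illustrated on a single toy cell. This is not scBERT's training code: it assumes expression has already been discretized into integer bins 0..5, reserves a hypothetical `MASK_TOKEN` id, masks about 15% of positions (the BERT default rate; scBERT's exact masking settings are an assumption here), and keeps the original bins as prediction targets. Unmasked positions get the target `-100`, a common convention for "exclude from the loss" (it is the default `ignore_index` of PyTorch's cross-entropy loss).

```python
import numpy as np

MASK_TOKEN = 6   # hypothetical id reserved for "masked" (bins assumed to be 0..5)
IGNORE = -100    # target value for positions that do not contribute to the loss

def mask_expression(bins: np.ndarray, mask_prob: float = 0.15, seed: int = 0):
    """Return (inputs, targets, is_masked) for a masked-value prediction task."""
    rng = np.random.default_rng(seed)
    is_masked = rng.random(bins.shape) < mask_prob
    inputs = np.where(is_masked, MASK_TOKEN, bins)  # the model sees MASK_TOKEN here
    targets = np.where(is_masked, bins, IGNORE)     # the model must recover these bins
    return inputs, targets, is_masked

# One toy cell with 20 genes, expression already binned to 0..5
cell = np.array([0, 3, 1, 0, 5, 2, 0, 0, 4, 1, 0, 2, 3, 0, 1, 0, 5, 0, 2, 1])
inputs, targets, is_masked = mask_expression(cell, mask_prob=0.15, seed=42)
print("masked positions:", np.flatnonzero(is_masked))
print("inputs: ", inputs)
print("targets:", targets)
```

A model trained on such pairs only receives loss at the masked positions, so it has to infer each hidden bin from the co-expression of the visible genes.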
Step 1: Inspect Model Specification
Use ov.fm.describe_model() to get the full spec for scBERT.
info = ov.fm.describe_model("scbert")
print("=== Model Info ===")
print(f"Name: {info['model']['name']}")
print(f"Version: {info['model']['version']}")
print(f"Tasks: {info['model']['tasks']}")
print(f"Species: {info['model']['species']}")
print(f"Embedding dim: {info['model']['embedding_dim']}")
print(f"Differentiator: {info['model']['differentiator']}")
print("\n=== Input Contract ===")
print(f"Gene ID scheme: {info['input_contract']['gene_id_scheme']}")
print(f"Preprocessing: {info['input_contract']['preprocessing']}")
print("\n=== Output Contract ===")
print(f"Embedding key: {info['output_contract']['embedding_key']}")
print(f"Embedding dim: {info['output_contract']['embedding_dim']}")
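Step 2: Prepare Data
Load a dataset and save it for the ov.fm workflow. Most foundation models expect raw counts (non-negative integers), so the data are only filtered here, not normalized; ov.fm.run() applies the model-specific preprocessing later.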
adata = sc.datasets.pbmc3k()
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)
print(f'Dataset: {adata.n_obs} cells x {adata.n_vars} genes')
print(f'Gene names (first 5): {adata.var_names[:5].tolist()}')
print(f'X range: [{adata.X.min():.1f}, {adata.X.max():.1f}]')
adata.write_h5ad('pbmc3k_scbert.h5ad')
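Because the model expects raw counts, a quick sanity check can catch input that was already normalized or log-transformed. The helper below is a sketch; `looks_like_raw_counts` is a hypothetical name, not part of ov.fm. It treats a matrix as raw counts if every stored value is non-negative and integer-valued, and it handles both dense NumPy arrays and SciPy sparse matrices (by checking only their stored values, since implicit zeros are valid counts).

```python
import numpy as np

def looks_like_raw_counts(X, tol: float = 1e-6) -> bool:
    """Heuristic: True if all stored values are non-negative integers."""
    # SciPy sparse matrices expose their stored (nonzero) values via .tocsr().data
    values = X.tocsr().data if hasattr(X, "tocsr") else np.asarray(X)
    if values.size == 0:
        return True
    non_negative = values.min() >= 0
    integer_valued = np.all(np.abs(values - np.round(values)) < tol)
    return bool(non_negative and integer_valued)

counts = np.array([[0, 3, 1], [2, 0, 7]], dtype=np.float32)
lognorm = np.log1p(counts)

print(looks_like_raw_counts(counts))   # True
print(looks_like_raw_counts(lognorm))  # False
```

In this tutorial it would be called as `looks_like_raw_counts(adata.X)` before writing the file; a `False` result suggests recovering raw counts (for example from a stored counts layer) first.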
Step 3: Profile Data & Validate Compatibility
Check whether your data is compatible with scBERT before running inference.
profile = ov.fm.profile_data("pbmc3k_scbert.h5ad")
print("=== Data Profile ===")
print(f"Species: {profile['species']}")
print(f"Gene scheme: {profile['gene_scheme']}")
print(f"Modality: {profile['modality']}")
print(f"Cells: {profile['n_cells']:,}")
print(f"Genes: {profile['n_genes']:,}")
# Validate compatibility
validation = ov.fm.preprocess_validate("pbmc3k_scbert.h5ad", "scbert", "embed")
print(f"\n=== Validation: {validation['status']} ===")
for d in validation.get("diagnostics", []):
    print(f"  [{d['severity']}] {d['message']}")
if validation.get("auto_fixes"):
    print("\nSuggested fixes:")
    for fix in validation["auto_fixes"]:
        print(f"  - {fix}")
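scBERT expects gene symbols, so a useful pre-flight check is how many of your genes the model can actually see. The sketch below guesses the ID scheme with a simple regex (Ensembl gene IDs look like `ENSG00000141510`, or `ENSMUSG...` for mouse) and reports the overlap with a model vocabulary. This is an illustrative heuristic, not how `profile_data` or `preprocess_validate` work internally, and `toy_vocab` is a hypothetical stand-in for scBERT's real gene list.

```python
import re

# Ensembl gene IDs: "ENS" + optional species code + "G" + 11 digits, optional ".version"
ENSEMBL_GENE = re.compile(r"^ENS[A-Z]*G\d{11}(\.\d+)?$")

def guess_gene_scheme(gene_ids, threshold: float = 0.9) -> str:
    """Guess 'ensembl' vs 'symbol' from the fraction of Ensembl-shaped IDs."""
    ids = list(gene_ids)
    if not ids:
        return "unknown"
    n_ens = sum(bool(ENSEMBL_GENE.match(g)) for g in ids)
    return "ensembl" if n_ens / len(ids) >= threshold else "symbol"

def vocab_overlap(gene_ids, vocab) -> float:
    """Fraction of the dataset's genes that appear in the model vocabulary."""
    ids = list(gene_ids)
    if not ids:
        return 0.0
    vocab = set(vocab)
    return sum(g in vocab for g in ids) / len(ids)

genes = ["CD3E", "MS4A1", "LYZ", "NKG7", "AL627309.1"]
toy_vocab = {"CD3E", "MS4A1", "LYZ", "NKG7", "GNLY"}  # hypothetical stand-in vocabulary

print(guess_gene_scheme(genes))                  # symbol
print(f"{vocab_overlap(genes, toy_vocab):.0%}")  # 80%
```

With real data, `adata.var_names` would take the place of `genes`; genes missing from the vocabulary are simply invisible to the model, so a low overlap is worth investigating before inference.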
Step 4: Run scBERT Inference
Execute scBERT through ov.fm.run(). The function handles preprocessing, model loading, inference, and output writing.
result = ov.fm.run(
task="embed",
model_name="scbert",
adata_path="pbmc3k_scbert.h5ad",
output_path="pbmc3k_scbert_out.h5ad",
device="auto",
)
if "error" in result:
    print(f"Error: {result['error']}")
    if "suggestion" in result:
        print(f"Suggestion: {result['suggestion']}")
else:
    print(f"Status: {result['status']}")
    print(f"Output keys: {result.get('output_keys', [])}")
    print(f"Cells processed: {result.get('n_cells', 0)}")
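ov.fm.run() performs the model-specific preprocessing internally. scBERT represents expression as discrete bins rather than continuous values, and the NumPy sketch below mimics such a pipeline under assumed constants: per-cell normalization to a total of 1e4, natural-log `log1p`, truncation to integers, and a cap at bin 5. These constants are illustrative assumptions; the exact settings used by the ov.fm adapter are not documented in this tutorial.

```python
import numpy as np

def bin_expression(counts, target_sum: float = 1e4, max_bin: int = 5) -> np.ndarray:
    """Normalize each cell to target_sum, apply log1p, and truncate into bins 0..max_bin.

    The constants are assumptions for illustration, not the ov.fm adapter's settings.
    """
    counts = np.asarray(counts, dtype=np.float64)
    totals = counts.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1.0                    # avoid division by zero for empty cells
    normalized = counts / totals * target_sum    # library-size normalization
    logged = np.log1p(normalized)                # natural-log transform
    bins = np.floor(logged).astype(np.int64)     # truncate to integer bins
    return np.minimum(bins, max_bin)             # cap high expression at max_bin

counts = np.array([[0, 1, 9, 90], [5, 5, 0, 0]])
print(bin_expression(counts))
# [[0 4 5 5]
#  [5 5 0 0]]
```

Truncating and capping keeps the set of expression tokens small (six values here), which is what lets a BERT-style model treat expression levels like a word vocabulary; the trade-off is that highly expressed genes collapse into the top bin.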
Step 5: Visualize & Interpret Results
Load the output, compute UMAP from scBERT embeddings, and evaluate quality.
if os.path.exists("pbmc3k_scbert_out.h5ad"):
    adata_out = sc.read_h5ad("pbmc3k_scbert_out.h5ad")
    emb_key = "X_scBERT"
    if emb_key in adata_out.obsm:
        print(f"Embedding shape: {adata_out.obsm[emb_key].shape}")
        # UMAP visualization
        sc.pp.neighbors(adata_out, use_rep=emb_key)
        sc.tl.umap(adata_out)
        sc.tl.leiden(adata_out, resolution=0.5)
        sc.pl.umap(adata_out, color=["leiden"],
                   title="scBERT Embedding (PBMC 3k)")
        # QA metrics
        interpretation = ov.fm.interpret_results("pbmc3k_scbert_out.h5ad", task="embed")
        if "embeddings" in interpretation["metrics"]:
            for k, v in interpretation["metrics"]["embeddings"].items():
                print(f"\n{k}: dim={v['dim']}", end="")
                if "silhouette" in v:
                    print(f", silhouette={v['silhouette']:.4f}", end="")
                print()
    else:
        print(f"Embedding key {emb_key} not found.")
        print(f"Available keys: {list(adata_out.obsm.keys())}")
else:
    print("Output file not found: check model installation and adapter status.")
    print("See the Guide page for installation instructions.")
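interpret_results() reports a silhouette score per embedding. For a cell i, the silhouette is (b - a) / max(a, b), where a is the mean distance to the other cells in its own cluster and b is the mean distance to the cells of the nearest other cluster; averaged over cells it ranges from -1 to 1, and higher values mean better-separated clusters. The NumPy sketch below computes it with Euclidean distance for given labels. Which labels and distance metric ov.fm uses is not stated in this tutorial, so treat this as an illustration of the metric, not a reimplementation of interpret_results().

```python
import numpy as np

def mean_silhouette(X, labels) -> float:
    """Mean silhouette score using Euclidean distances.

    Builds the full n x n distance matrix, so it is only meant for small n.
    """
    X = np.asarray(X, dtype=np.float64)
    labels = np.asarray(labels)
    clusters = np.unique(labels)
    if len(clusters) < 2:
        raise ValueError("silhouette needs at least two clusters")
    # Pairwise Euclidean distance matrix, shape (n, n)
    D = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1))
    scores = np.zeros(len(X))
    for i in range(len(X)):
        same = labels == labels[i]
        n_same = same.sum() - 1            # cluster mates, excluding cell i itself
        if n_same == 0:
            continue                       # singleton cluster: score stays 0
        a = D[i, same].sum() / n_same      # D[i, i] is 0, so including i adds nothing
        b = min(D[i, labels == c].mean() for c in clusters if c != labels[i])
        denom = max(a, b)
        scores[i] = 0.0 if denom == 0 else (b - a) / denom
    return float(scores.mean())

# Two well-separated toy blobs (stand-ins for clusters in an embedding)
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 0.1, (20, 2)), rng.normal(5.0, 0.1, (20, 2))])
labels = np.array([0] * 20 + [1] * 20)
print(f"silhouette = {mean_silhouette(X, labels):.3f}")  # close to 1 for well-separated blobs
```

Applied to this tutorial, `X` would be `adata_out.obsm[emb_key]` and `labels` could be `adata_out.obs["leiden"]`; for large datasets, scikit-learn's `silhouette_score` with its `sample_size` argument avoids the quadratic memory of this sketch.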
Summary
| Step | Function | What it does |
|---|---|---|
| 1 | ov.fm.describe_model("scbert") | Inspect model spec and I/O contract |
| 2 | sc.datasets.pbmc3k() | Prepare input data |
| 3 | ov.fm.profile_data() + ov.fm.preprocess_validate() | Check compatibility |
| 4 | ov.fm.run() | Execute scBERT inference |
| 5 | ov.fm.interpret_results() | Evaluate embedding quality |
For the full model catalog, see ov.fm.list_models() or the ov.fm API Overview.
For detailed scBERT specifications, see the scBERT Guide.