omicverse.space.pySTAGATE
- omicverse.space.pySTAGATE(adata: AnnData, num_batch_x, num_batch_y, spatial_key: list = ['X', 'Y'], batch_size: int = 1, rad_cutoff: int = 200, num_epoch: int = 1000, lr: float = 0.001, weight_decay: float = 0.0001, hidden_dims: list = [512, 30], device: str = 'cuda:0') → None
A class representing the PyTorch implementation of STAGATE (Spatial Transcriptomics Analysis using Graph Attention autoEncoder).
- Parameters:
adata (AnnData) – Spatial AnnData with coordinates and expression matrix.
num_batch_x (int) – Number of tiles along x-axis for mini-batch graph construction.
num_batch_y (int) – Number of tiles along y-axis for mini-batch graph construction.
spatial_key (list, default=['X', 'Y']) – Coordinate columns in adata.obs used to build the spatial graph.
batch_size (int, default=1) – Number of tiled graphs per optimization step.
rad_cutoff (int, default=200) – Radius cutoff, in the units of the spatial coordinates, for connecting neighboring spots in the spatial graph.
num_epoch (int, default=1000) – Number of training epochs.
lr (float, default=0.001) – Learning rate for Adam optimizer.
weight_decay (float, default=1e-4) – L2 regularization strength.
hidden_dims (list, default=[512, 30]) – Hidden-layer sizes of the STAGATE encoder; the last entry is the dimension of the latent spot embedding.
device (str, default='cuda:0') – Device specifier; falls back to CPU when CUDA is unavailable.
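The tiling and neighbor parameters above (num_batch_x, num_batch_y, rad_cutoff) can be pictured with a small NumPy sketch. It assumes the coordinates have already been pulled from adata.obs[spatial_key] into an (N, 2) array, that tile borders sit at quantiles of each axis so tiles hold roughly equal numbers of spots, and that two spots are neighbors when they are closer than rad_cutoff. The helper names (tile_indices, radius_edges, build_tiled_graphs) are hypothetical and not part of omicverse; the library's own tiling may differ, for example in how spots on tile borders are handled.

```python
import numpy as np


def _cuts(values, n_tiles):
    """Interior cut points splitting `values` into n_tiles quantile bins (assumption)."""
    if n_tiles <= 1:
        return np.empty(0)
    return np.quantile(values, np.linspace(0, 1, n_tiles + 1)[1:-1])


def tile_indices(coords, num_batch_x, num_batch_y):
    """Assign spots to a num_batch_x x num_batch_y grid; one index array per tile."""
    x, y = coords[:, 0], coords[:, 1]
    # Tile id along each axis = number of cut points <= coordinate.
    x_bin = np.searchsorted(_cuts(x, num_batch_x), x, side="right")
    y_bin = np.searchsorted(_cuts(y, num_batch_y), y, side="right")
    return [
        np.flatnonzero((x_bin == i) & (y_bin == j))
        for i in range(num_batch_x)
        for j in range(num_batch_y)
    ]


def radius_edges(coords, rad_cutoff):
    """Directed edges (2, E) between distinct spots closer than rad_cutoff."""
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    close = dist < rad_cutoff
    np.fill_diagonal(close, False)  # no self-edges in this sketch
    src, dst = np.nonzero(close)
    return np.vstack([src, dst])


def build_tiled_graphs(coords, num_batch_x, num_batch_y, rad_cutoff):
    """Return (spot_indices, local_edges) pairs, one per tile."""
    coords = np.asarray(coords, dtype=float)
    return [
        (idx, radius_edges(coords[idx], rad_cutoff))
        for idx in tile_indices(coords, num_batch_x, num_batch_y)
    ]


if __name__ == "__main__":
    coords = np.array([[0.0, 0.0], [1.0, 0.0], [10.0, 10.0], [11.0, 10.0]])
    for idx, edges in build_tiled_graphs(coords, 2, 1, rad_cutoff=2.0):
        # idx[edges] maps tile-local edge indices back to global spot rows.
        print(idx, idx[edges].tolist())
```

Edge indices are local to each tile, so idx[edges] recovers the original row numbers. The dense distance matrix grows quadratically with tile size, which is fine for a sketch; a KD-tree would be the usual choice for real slides.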
Attributes –
- device: torch.device
Device where the model is running.
- loader: DataLoader
PyTorch DataLoader for batch processing.
- model: STAGATE
The STAGATE model instance.
- optimizer: torch.optim.Adam
Adam optimizer for model training.
- adata: AnnData
Input annotated data matrix.
- data: torch_geometric.data.Data
PyTorch Geometric data object.
Notes – The STAGATE model is designed for analyzing spatial transcriptomics data by incorporating spatial information through a graph attention autoencoder architecture.
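To make the graph-attention part of this note concrete, here is a single-head attention aggregation step in NumPy. It follows the general graph-attention pattern: score each edge from the projected features of its two endpoints, normalize the scores with a softmax over each spot's neighbors, and sum neighbor features with those weights. The sigmoid-then-softmax scoring and the names attention_layer, W, v_s and v_n are illustrative assumptions, not pySTAGATE's internals; a real layer would also apply a nonlinearity and learn its parameters by backpropagation.

```python
import numpy as np


def attention_layer(h, edges, W, v_s, v_n):
    """One single-head graph-attention aggregation step (illustrative sketch).

    h:        (N, F_in) spot features, e.g. normalized expression.
    edges:    (2, E) integer array; column (i, j) means spot j is a neighbor of i.
    W:        (F_in, F_out) shared projection.
    v_s, v_n: (F_out,) attention vectors for the centre spot and the neighbor.
    Returns the (N, F_out) aggregated features and the (E,) edge weights.
    """
    z = h @ W  # project every spot
    centre, neigh = edges
    # Edge score from both endpoints, squashed to (0, 1) by a sigmoid.
    score = z[centre] @ v_s + z[neigh] @ v_n
    e = 1.0 / (1.0 + np.exp(-score))
    # Softmax over the neighbors of each centre spot.
    exp_e = np.exp(e)
    denom = np.zeros(h.shape[0])
    np.add.at(denom, centre, exp_e)
    alpha = exp_e / denom[centre]
    # Attention-weighted sum of neighbor projections.
    out = np.zeros((h.shape[0], W.shape[1]))
    np.add.at(out, centre, alpha[:, None] * z[neigh])
    return out, alpha


if __name__ == "__main__":
    h = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    # Every spot has a self-loop; spots 0 and 1 are also neighbors of each other.
    edges = np.array([[0, 0, 1, 1, 2], [0, 1, 1, 0, 2]])
    out, alpha = attention_layer(h, edges, np.eye(2), np.zeros(2), np.zeros(2))
    print(alpha)
    print(out)
```

With zero attention vectors every neighbor receives the same weight, so the step reduces to a plain neighborhood mean; learned attention vectors are what allow the weights to differ between neighbors. Including a self-loop for each spot keeps the spot's own features in its update.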
Examples –
>>> import scanpy as sc
>>> import omicverse as ov
>>> adata = sc.read_h5ad('spatial_data.h5ad')
>>> stagate = ov.space.pySTAGATE(adata, num_batch_x=3, num_batch_y=2)
>>> stagate.train()
>>> stagate.predicted()