omicverse.space.pySTAGATE

omicverse.space.pySTAGATE(adata: AnnData, num_batch_x, num_batch_y, spatial_key: list = ['X', 'Y'], batch_size: int = 1, rad_cutoff: int = 200, num_epoch: int = 1000, lr: float = 0.001, weight_decay: float = 0.0001, hidden_dims: list = [512, 30], device: str = 'cuda:0') -> None

A class representing the PyTorch implementation of STAGATE (Spatial Transcriptomics Analysis using Graph Attention autoEncoder).

Parameters:
  • adata (AnnData) – Spatial AnnData with coordinates and expression matrix.

  • num_batch_x (int) – Number of tiles along the x-axis for mini-batch graph construction.

  • num_batch_y (int) – Number of tiles along the y-axis for mini-batch graph construction.

  • spatial_key (list, default=['X', 'Y']) – Coordinate columns in adata.obs used to build the spatial graph.

  • batch_size (int, default=1) – Number of tiled graphs per optimization step.

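  • rad_cutoff (int, default=200) – Radius cutoff for constructing spatial neighbors: spots within this distance (in the units of the spatial_key coordinates) are connected in the spatial graph.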

  • num_epoch (int, default=1000) – Number of training epochs.

  • lr (float, default=0.001) – Learning rate for the Adam optimizer.

  • weight_decay (float, default=1e-4) – L2 regularization strength.

  • hidden_dims (list, default=[512, 30]) – Hidden-layer sizes of the STAGATE encoder; the last entry is the dimension of the learned embedding.

  • device (str, default='cuda:0') – Device specifier; falls back to CPU when CUDA is unavailable.
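The tiling and graph parameters above (num_batch_x, num_batch_y, rad_cutoff) can be pictured with the NumPy sketch below. It is not omicverse's internal code: the helpers `percentile_tiles` and `radius_edges` are hypothetical, and placing tile edges at coordinate percentiles (so tiles hold similar numbers of spots) and connecting spots whose Euclidean distance is at most `rad_cutoff` are assumptions made for illustration.

```python
import numpy as np


def percentile_tiles(coords, num_batch_x, num_batch_y):
    """Hypothetical helper: split spots into a num_batch_x x num_batch_y grid.

    Tile edges sit at evenly spaced percentiles of each coordinate (assumption).
    Bounds are inclusive, so a spot on an internal edge lands in both tiles.
    Returns a list of index arrays, one per tile.
    """
    x_edges = np.percentile(coords[:, 0], np.linspace(0, 100, num_batch_x + 1))
    y_edges = np.percentile(coords[:, 1], np.linspace(0, 100, num_batch_y + 1))
    tiles = []
    for i in range(num_batch_x):
        in_x = (coords[:, 0] >= x_edges[i]) & (coords[:, 0] <= x_edges[i + 1])
        for j in range(num_batch_y):
            in_y = (coords[:, 1] >= y_edges[j]) & (coords[:, 1] <= y_edges[j + 1])
            tiles.append(np.flatnonzero(in_x & in_y))
    return tiles


def radius_edges(coords, rad_cutoff):
    """Hypothetical helper: directed edges (i, j), i != j, with distance <= rad_cutoff.

    Uses dense pairwise distances, which is fine for small tiles.
    Returns an array of shape (2, n_edges) with local spot indices.
    """
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    mask = dist <= rad_cutoff
    np.fill_diagonal(mask, False)  # no self-edges
    src, dst = np.nonzero(mask)
    return np.stack([src, dst])


# Toy 10 x 10 lattice with 100-unit spacing (made-up data, not a real dataset).
xs, ys = np.meshgrid(np.arange(10) * 100.0, np.arange(10) * 100.0)
coords = np.column_stack([xs.ravel(), ys.ravel()])

tiles = percentile_tiles(coords, num_batch_x=3, num_batch_y=2)
edges = radius_edges(coords[tiles[0]], rad_cutoff=200)
```

With this toy lattice the 3 × 2 grid yields six tiles of 20 spots each; spots on an internal tile edge appear in two tiles under the inclusive-boundary assumption.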

Attributes:
  • device (torch.device) – Device where the model is running.

  • loader (DataLoader) – PyTorch DataLoader that serves the tiled graphs in mini-batches.

  • model (STAGATE) – The STAGATE model instance.

  • optimizer (torch.optim.Adam) – Adam optimizer for model training.

  • adata (AnnData) – Input annotated data matrix.

  • data (torch_geometric.data.Data) – PyTorch Geometric data object.

Notes:
  The STAGATE model is designed for analyzing spatial transcriptomics data; it incorporates spatial information through a graph attention autoencoder architecture.
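The architecture is described here only in outline. As a rough illustration of what a single graph-attention step computes, the NumPy sketch below weights each spot's neighbors by attention scores and normalizes them with a softmax over that spot's neighborhood. The function `graph_attention_aggregate`, its parameter names, the sigmoid scoring, and the added self-loops are assumptions for this sketch, not omicverse's or STAGATE's exact layer.

```python
import numpy as np


def graph_attention_aggregate(h, adj, w, v_self, v_neigh):
    """Hypothetical single-head graph-attention aggregation (illustrative sketch).

    h:       (n, d) input features, one row per spot
    adj:     (n, n) boolean adjacency, True where j is a spatial neighbor of i
    w:       (d, k) shared linear projection
    v_self:  (k,) attention vector applied to the centre spot
    v_neigh: (k,) attention vector applied to the neighbor
    Returns the (n, k) aggregated features and the (n, n) attention weights.
    """
    z = h @ w                                    # project features: (n, k)
    mask = adj | np.eye(len(h), dtype=bool)      # add self-loops so no row is empty
    # Raw score e_ij = sigmoid(v_self . z_i + v_neigh . z_j) (assumed scoring)
    raw = (z @ v_self)[:, None] + (z @ v_neigh)[None, :]
    scores = 1.0 / (1.0 + np.exp(-raw))
    # Softmax restricted to each spot's neighborhood; non-neighbors get weight 0
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=1, keepdims=True)
    alpha = np.exp(scores)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)
    return alpha @ z, alpha


# Toy example: 5 spots on a chain, random features and weights (made-up data).
rng = np.random.default_rng(0)
h = rng.normal(size=(5, 8))
adj = np.zeros((5, 5), dtype=bool)
for i in range(4):
    adj[i, i + 1] = adj[i + 1, i] = True
out, alpha = graph_attention_aggregate(
    h, adj, rng.normal(size=(8, 3)), rng.normal(size=3), rng.normal(size=3)
)
```

Each row of `alpha` sums to 1 and is zero outside the spot's own neighborhood (plus itself), so every output row is a weighted average of projected neighbor features.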

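Examples:

    >>> import scanpy as sc
    >>> import omicverse as ov
    >>> adata = sc.read_h5ad('spatial_data.h5ad')
    >>> # spatial_key defaults to ['X', 'Y'], so the coordinates must be columns of adata.obs
    >>> # (this assumes they are stored in adata.obsm['spatial'])
    >>> adata.obs['X'] = adata.obsm['spatial'][:, 0]
    >>> adata.obs['Y'] = adata.obsm['spatial'][:, 1]
    >>> stagate = ov.space.pySTAGATE(adata, num_batch_x=3, num_batch_y=2)
    >>> stagate.train()
    >>> stagate.predicted()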