class pyVIA(object):

    def __init__(self,adata:anndata.AnnData,adata_key:str='X_pca',adata_ncomps:int=80,basis:str='X_umap',
                 clusters:str='',dist_std_local:float=2, jac_std_global=0.15, labels:np.ndarray=None,
                 keep_all_local_dist='auto', too_big_factor:float=0.4, resolution_parameter:float=1.0, partition_type:str="ModularityVP", small_pop:int=10,
                 jac_weighted_edges:bool=True, knn:int=30, n_iter_leiden:int=5, random_seed:int=42,
                 num_threads=-1, distance='l2', time_smallpop=15,
                 super_cluster_labels:bool=False,                 super_node_degree_list:bool=False, super_terminal_cells:bool=False, x_lazy:float=0.95, alpha_teleport:float=0.99,
                 root_user=None, preserve_disconnected:bool=True, dataset:str='', super_terminal_clusters:list=[],
                 is_coarse=True, csr_full_graph:np.ndarray='', csr_array_locally_pruned='', ig_full_graph='',
                 full_neighbor_array='', full_distance_array='',  df_annot=None,
                 secondary_annotations:list=None, pseudotime_threshold_TS:int=30, cluster_graph_pruning_std:float=0.15,
                 visual_cluster_graph_pruning:float=0.15, neighboring_terminal_states_threshold=3, num_mcmc_simulations=1300,
                 piegraph_edgeweight_scalingfactor=1.5, max_visual_outgoing_edges:int=2, via_coarse=None, velocity_matrix=None,
                 gene_matrix=None, velo_weight=0.5, edgebundle_pruning=None, A_velo = None, CSM = None, edgebundle_pruning_twice=False, pca_loadings = None, time_series=False,
                 time_series_labels:list=None, knn_sequential:int = 10, knn_sequential_reverse:int = 0,t_diff_step:int = 1,single_cell_transition_matrix = None,
                 embedding_type:str='via-mds',do_compute_embedding:bool=False, color_dict:dict=None,user_defined_terminal_cell:list=[], user_defined_terminal_group:list=[],
                 do_gaussian_kernel_edgeweights:bool=False,RW2_mode:bool=False,working_dir_fp:str ='/home/shobi/Trajectory/Datasets/') -> None:
        Initialize a pyVIA object.

            adata: An AnnData object containing the scRNA-seq.
            adata_key: the key of the AnnData in obsm to perform VIA on. default: 'X_pca'
            adata_ncomps: the number of components to use from the AnnData in obsm to perform VIA on. default: 80
            basis: the key of the AnnData in obsm to use as the basis for the embedding. default: 'X_umap'
            clusters: the clusters to use for the VIA analysis. default: ''
            dist_std_local: local level of pruning for PARC graph clustering stage. Range (0.1,3) higher numbers mean more edge retention
            jac_std_global: (optional, default = 0.15, can also set as 'median') global level graph pruning for PARC clustering stage. Number of standard deviations below the network’s mean-jaccard-weighted edges. 0.1-1 provide reasonable pruning. A higher value means less pruning (more edges retained), e.g. a value of 0.15 means all edges that are above mean(edgeweight)-0.15*std(edge-weights) are retained. We find both 0.15 and ‘median’ to yield good results/starting point, resulting in pruning away ~50-60% of edges.
            labels: default is None, in which case PARC clusters are used for the viagraph. Alternatively, provide a list of cluster memberships that are integer values (not strings) to construct the viagraph using another clustering method or available annotations.
            keep_all_local_dist: default value of 'auto' means that for smaller datasets local-pruning is done prior to clustering, but for large datasets local pruning is set to False for speed.

            too_big_factor: (optional, default=0.4). Forces clusters > 0.4*n_cells to be re-clustered
            resolution_parameter: (float) the resolution parameter for the Louvain algorithm.
            partition_type: (str, default: "ModularityVP") the partitioning algorithm to use.
            small_pop: (int, default: 10) the number of cells to be considered in a small population.
            jac_weighted_edges: (bool, default: True) whether to use weighted edges in the PARC clustering step.
            knn: (int, optional, default: 30) the number of K-Nearest Neighbors for HNSWlib KNN graph. Larger knn means more graph connectivity. Lower knn means more loosely connected clusters/cells.
            n_iter_leiden: (int) the number of iterations for the Leiden algorithm.
            random_seed: (int) the random seed to pass to the clustering algorithm.
            num_threads: (int) the number of threads to use for the clustering algorithm.
            distance: (str, default: 'l2') the distance metric to use for graph construction and similarity. Options are 'l2', 'ip', and 'cosine'.
            visual_cluster_graph_pruning: (float, default: 0.15) the pruning level for the cluster graph. This only comes into play if the user deliberately chooses not to use the default edge-bundling method of visualizing edges (draw_piechart_graph()) and instead calls draw_piechart_graph_nobundle(). It controls the number of edges plotted for visual effect. This does not impact computation of terminal states, pseudotime, or lineage likelihoods.
            cluster_graph_pruning_std: (float, default: 0.15) the pruning level of the cluster graph. Often set to the same value as the PARC clustering level of jac_std_global. Reasonable range is [0.1, 1]. To retain more connectivity in the clustergraph underlying the trajectory computations, increase the value.
            time_smallpop: (max time to be allowed handling singletons) the maximum time allowed to handle singletons.
            x_lazy: (float, default: 0.95) 1-x = probability of staying in the same node (lazy). Values between 0.9-0.99 are reasonable.
            alpha_teleport: (float, default: 0.99) 1-alpha is probability of jumping. Values between 0.95-0.99 are reasonable unless prior knowledge of teleportation.
            root_user: (list, None) the root user list. Can be a list of strings, a list of int, or None. The default is None. When the root_user is set as None and an RNA velocity matrix is available, a root will be automatically computed. If the root_user is None and no velocity matrix is provided, then an arbitrary root is selected. If the root_user is ['celltype_earlystage'] where the str corresponds to an item in true_label, then a suitable starting point will be selected corresponding to this group.
            preserve_disconnected: bool (default = True) If you believe there may be disconnected trajectories then set this to False
            dataset: str Can be set to 'group' or '' (default). This refers to the type of root label (group-level root or single-cell index) you are going to provide. If your true_label has a sensible group of cells for a root, then you can set dataset to 'group' and make the root parameter ['labelname_root_cell_type']. If your root corresponds to one particular cell, then set dataset = '' (default).
            embedding: ndarray (optional, default = None) embedding (e.g. precomputed tsne, umap, phate, via-umap) for plotting data. Size n_cells x 2 If an embedding is provided when running VIA, then a scatterplot colored by pseudotime, highlighting terminal fates
            velo_weight: float (optional, default = 0.5) #float between [0,1]. the weight assigned to directionality and connectivity derived from scRNA-velocity
            neighboring_terminal_states_threshold:int (default = 3). Candidates for terminal states that are neighbors of each other may be removed from the list if they have this number or more of terminal states as neighbors
            knn_sequential:int (default =10) number of knn in the adjacent time-point for time-series data (t_i and t_i+1)
            knn_sequential_reverse: int (default = 0) number of knn enforced from current to previous time point
            t_diff_step: int (default =1) Number of permitted temporal intervals between connected nodes. If time data is labeled as [0,25,50,75,100,..] then t_diff_step=1 corresponds to '25' and only edges within t_diff_steps are retained
            is_coarse:bool (default = True) If running VIA in two iterations where you wish to link the second fine-grained iteration with the initial iteration, then you set to False
            via_coarse: VIA (default = None) If instantiating a second iteration of VIA that needs to be linked to a previous iteration (e.g. via0), then set via_coarse to the previous via0 object
            df_annot: DataFrame (default None) used for the Mouse Organ data
            preserve_disconnected_after_pruning:bool (default = True) If you believe there are disconnected trajectories then set this to False and test your hypothesis
            A_velo: ndarray Cluster Graph Transition matrix based on rna velocity [n_clus x n_clus]
            velocity_matrix: matrix (default None) matrix of size [n_samples x n_genes]. This is the velocity matrix computed by scVelo (or similar package) and stored in adata.layers['velocity']. The genes used for computing velocity should correspond to those used in gene_matrix. Requires gene_matrix to be provided too.
            gene_matrix: matrix (default None) Only used if Velocity_matrix is available. matrix of size [n_samples x n_genes]. We recommend using a subset like HVGs rather than full set of genes. (need to densify input if taking from adata = adata.X.todense())
            time_series: if the data has time-series labels then set to True
            time_series_labels:list (default None) list of integer values of temporal annotations corresponding to e.g. hours (post fert), days, or sequential ordering
            pca_loadings: array (default None) the loadings of the pcs used to project the cells (to projected euclidean location based on velocity). n_cells x n_pcs
            secondary_annotations: None (default None)
            edgebundle_pruning:float (default=None) will by default be set to the same as the cluster_graph_pruning_std and influences the visualized level of pruning of edges. Typical values can be between [0,1] with higher numbers retaining more edges
            edgebundle_pruning_twice:bool default: False. When True, the edgebundling is applied to a further visually pruned (visual_cluster_graph_pruning) and can sometimes simplify the visualization. it does not impact the pseudotime and lineage computations
            piegraph_arrow_head_width: float (default = 0.1) size of arrow heads in via cluster graph
            piegraph_edgeweight_scalingfactor: (default = 1.5) scaling factor for edge thickness in via cluster graph
            max_visual_outgoing_edges: int (default =2) Rarely comes into play. Only used if the user later chooses to plot the via-graph without edge bundling using draw_piechart_graph_nobundle(). Only allows max_visual_outgoing_edges to come out of any given node.
            edgebundle_pruning:float (default=None) will by default be set to the same as the cluster_graph_pruning_std and influences the visualized level of pruning of edges. Typical values can be between [0,1] with higher numbers retaining more edges
            edgebundle_pruning_twice:bool default: False. When True, the edgebundling is applied to a further visually pruned (visual_cluster_graph_pruning) and can sometimes simplify the visualization. it does not impact the pseudotime and lineage computations
            pseudotime_threshold_TS: int (default = 30) corresponds to the criteria for a state to be considered a candidate terminal cell fate to be 30% or later of the computed pseudotime range
            num_mcmc_simulations:int (default = 1300) number of random walk simulations conducted
            embedding_type: str (default = 'via-mds'), other options are 'via-umap' and 'via-force'
            do_compute_embedding: bool (default = False) If you want an embedding (n_samples x2) to be computed on the basis of the via sc graph then set this to True
            do_gaussian_kernel_edgeweights: bool (default = False) Type of edgeweighting on the graph edges


        self.adata = adata
        #self.adata_key = adata_key
        data = adata.obsm[adata_key][:, 0:adata_ncomps]

        if root_user is not None:

    def run(self):
        """calculate the via graph and pseudotime


    def get_piechart_dict(self,label:int=0,clusters:str='')->dict:
        Cluster composition graph

            label: int (default=0) cluster label of pie chart
            clusters: the cell type of interest

            res_dict: cluster composition graph
        if clusters=='':
        cluster_i_loc=np.where(np.asarray(self.model.labels) == label)[0]
        return res_dict

    def get_pseudotime(self,adata=None):
        Extract the pseudotime of VIA

            adata: an AnnData object of interest; if None, the pseudotime is added to `self.adata.obs['pt_via']`


        print('...the pseudotime of VIA added to AnnData obs named `pt_via`')
        if adata is None:

    def plot_piechart_graph(self,clusters:str='', type_data='pt',
                                gene_exp:list=[], title='', 
                                cmap:str=None, ax_text=True, figsize:tuple=(8,4),
                                dpi=150,headwidth_arrow = 0.1, 
                                alpha_edge=0.4, linewidth_edge=2, 
                                show_legend:bool=True, pie_size_scale:float=0.8, fontsize:float=8)->Tuple[matplotlib.figure.Figure,
        """plot two subplots with a clustergraph level representation of the viagraph showing true-label composition (lhs) and pseudotime/gene expression (rhs)

            clusters : column name of the adata.obs dataframe that contains the cluster labels
            type_data : string  default 'pt' for pseudotime colored nodes. or 'gene'
            gene_exp : list of values (column of dataframe) corresponding to feature or gene expression to be used to color nodes at CLUSTER level
            title : string
            cmap : default None. automatically chooses coolwarm for gene expression or viridis_r for pseudotime
            ax_text : Bool default= True. Annotates each node with cluster number and population of membership
            dpi : int default = 150
            headwidth_arrow : default = 0.1. width of arrowhead used for directed edges
            reference : None or list. list of categorical (str) labels for cluster composition of the piecharts (LHS subplot) length = n_samples.
            pie_size_scale : float default=0.8 scaling factor of the piechart nodes
            fontsize : float default=8. fontsize of the text in the piecharts
            figsize : tuple default=(8,4). size of the figure

            fig: Returns matplotlib figure with two axes that plot the clustergraph using edge bundling
            ax: left axis shows the clustergraph with each node colored by annotated ground truth membership.
            ax1: right axis shows the same clustergraph with each node colored by the pseudotime or gene expression

        if clusters=='':
        fig, ax, ax1 = draw_piechart_graph_pyomic(clusters=clusters,adata=self.adata,
                                   via_object=self.model, type_data=type_data,
                                gene_exp=gene_exp, title=title, 
                                cmap=cmap, ax_text=ax_text,figsize=figsize,
                                dpi=dpi,headwidth_arrow = headwidth_arrow,
                                alpha_edge=alpha_edge, linewidth_edge=linewidth_edge,
                                show_legend=show_legend, pie_size_scale=pie_size_scale, fontsize=fontsize)
        return fig, ax, ax1

    def plot_stream(self,clusters:str='',basis:str='',
                   density_grid:float=0.5, arrow_size:float=0.7, arrow_color:str = 'k',
                   arrow_style="-|>",  max_length:int=4, linewidth:float=1,min_mass = 1, cutoff_perc:int = 5,
                   scatter_size:int=500, scatter_alpha:float=0.5,marker_edgewidth:float=0.1,
                   density_stream:int = 2, smooth_transition:int=1, smooth_grid:float=0.5,
                   color_scheme:str = 'annotation', add_outline_clusters:bool=False,
                   cluster_outline_edgewidth = 0.001,gp_color = 'white', bg_color='black' ,
                   dpi=80 , title='Streamplot', b_bias=20, n_neighbors_velocity_grid=None,
                   other_labels:list = None,use_sequentially_augmented:bool=False, cmap_str:str='rainbow')->Tuple[matplotlib.figure.Figure,
        """Construct vector streamplot on the embedding to show a fine-grained view of inferred directions in the trajectory

            clusters : column name of the adata.obs dataframe that contains the cluster labels
            basis : str, default = 'X_umap', which to use for the embedding
            density_grid : float, default = 0.5, density of the grid on which to project the directionality of cells
            arrow_size : float, default = 0.7, size of the arrows in the streamplot
            arrow_color : str, default = 'k', color of the arrows in the streamplot
            arrow_style : str, default = "-|>", style of the arrows in the streamplot
            max_length : int, default = 4, maximum length of the arrows in the streamplot
            linewidth : float, default = 1, width of  lines in streamplot
            min_mass : float, default = 1, minimum mass of the arrows in the streamplot
            cutoff_perc : int, default = 5, cutoff percentage of the arrows in the streamplot
            scatter_size : int, default = 500, size of scatter points
            scatter_alpha : float, default = 0.5, transparency of scatter points
            marker_edgewidth : float, default = 0.1, width of outline around each scatter point
            density_stream : int, default = 2, density of the streamplot
            smooth_transition : int, default = 1, smoothness of the transition between the streamplot and the scatter points
            smooth_grid : float, default = 0.5, smoothness of the grid on which to project the directionality of cells
            color_scheme : str, default = 'annotation' corresponds to self.true_labels. Other options are 'time' (uses single-cell pseudotime) or 'clusters' (uses self.clusters)
            add_outline_clusters : bool, default = False, whether to add an outline to the clusters
            cluster_outline_edgewidth : float, default = 0.001, width of the outline around the clusters
            gp_color : str, default = 'white', color of the grid points
            bg_color : str, default = 'black', color of the background
            dpi : int, default = 80, dpi of the figure
            title : str, default = 'Streamplot', title of the figure
            b_bias : int, default = 20, higher value makes the forward bias of pseudotime stronger
            n_neighbors_velocity_grid : int, default = None, number of neighbors to use for the velocity grid
            other_labels : list, default = None, list of other labels to plot in the streamplot
            use_sequentially_augmented : bool, default = False, whether to use the sequentially augmented data
            cmap_str : str, default = 'rainbow', color map to use for the streamplot

            fig : matplotlib figure
            ax : matplotlib axis

        if clusters=='':
        if basis=='':
        fig,ax = via_streamplot_pyomic(adata=self.adata,clusters=clusters,via_object=self.model, 
                                 embedding=embedding,density_grid=density_grid, arrow_size=arrow_size,
                                 arrow_color=arrow_color,arrow_style=arrow_style,  max_length=max_length,
                                 linewidth=linewidth,min_mass = min_mass, cutoff_perc=cutoff_perc,
                                 scatter_size=scatter_size, scatter_alpha=scatter_alpha,marker_edgewidth=marker_edgewidth,
                                 density_stream=density_stream, smooth_transition=smooth_transition, smooth_grid=smooth_grid,
                                 color_scheme=color_scheme, add_outline_clusters=add_outline_clusters,
                                 cluster_outline_edgewidth = cluster_outline_edgewidth,gp_color = gp_color, bg_color=bg_color,
                                 dpi=dpi , title=title, b_bias=b_bias, n_neighbors_velocity_grid=n_neighbors_velocity_grid,
                                 other_labels=other_labels,use_sequentially_augmented=use_sequentially_augmented, cmap_str=cmap_str)
        return fig,ax

    def plot_trajectory_gams(self,clusters:str='',basis:str='',via_fine=None, idx=None,
                         title_str:str= "Pseudotime", draw_all_curves:bool=True, arrow_width_scale_factor:float=15.0,
                         scatter_size:float=50, scatter_alpha:float=0.5,figsize:tuple=(8,4),
                         linewidth:float=1.5, marker_edgewidth:float=1, cmap_pseudotime:str='viridis_r',dpi:int=80,
                         highlight_terminal_states:bool=True, use_maxout_edgelist:bool =False)->Tuple[matplotlib.figure.Figure,
        """projects the graph based coarse trajectory onto a umap/tsne embedding

            clusters : column name of the adata.obs dataframe that contains the cluster labels
            basis : str, default = 'X_umap', which to use for the embedding
            via_fine : via object. It is suggested to use via_object only, unless you found that running via_fine gave better pathways
            idx : default: None. Or List. if you had previously computed a umap/tsne (embedding) only on a subset of the total n_samples (subsampled as per idx), then the via objects and results will be indexed according to idx too
            title_str : title of figure
            draw_all_curves : if the clustergraph has too many edges to project in a visually interpretable way, set this to False to get a simplified view of the graph pathways
            arrow_width_scale_factor : the width of the arrows is proportional to the edge weight. This factor scales the width of the arrows
            scatter_size : size of the scatter points
            scatter_alpha : transparency of the scatter points
            linewidth : width of the lines
            marker_edgewidth : width of the outline around each scatter point
            cmap_pseudotime : color map to use for the pseudotime
            dpi : dpi of the figure
            highlight_terminal_states :  whether or not to highlight/distinguish the clusters which are detected as the terminal states by via

            fig : matplotlib figure
            ax1 : matplotlib axis
            ax2 : matplotlib axis


        if clusters=='':
        if basis=='':
        fig,ax1,ax2 = draw_trajectory_gams_pyomic(adata=self.adata,clusters=clusters,via_object=self.model, 
                                            via_fine=via_fine, embedding=embedding, idx=idx,
                                            title_str=title_str, draw_all_curves=draw_all_curves, arrow_width_scale_factor=arrow_width_scale_factor,
                                            scatter_size=scatter_size, scatter_alpha=scatter_alpha,figsize=figsize,
                                            linewidth=linewidth, marker_edgewidth=marker_edgewidth, cmap_pseudotime=cmap_pseudotime,dpi=dpi,
                                            highlight_terminal_states=highlight_terminal_states, use_maxout_edgelist=use_maxout_edgelist)
        return fig,ax1,ax2

    def plot_lineage_probability(self,clusters:str='',basis:str='',via_fine=None, 
                                idx=None, figsize:tuple=(8,4),
                                cmap:str='plasma', dpi:int=80, scatter_size =None,
                                marker_lineages:list = [], fontsize:int=12)->Tuple[matplotlib.figure.Figure,
        """G is the igraph knn (low K) used for shortest path in high dim space. no idx needed as it's made on full sample, knn_hnsw is the knn made in the embedded space used for query to find the nearest point in the downsampled embedding that corresponds to the single cells in the full graph

            clusters : column name of the adata.obs dataframe that contains the cluster labels
            basis : str, default = 'X_umap', which to use for the embedding
            via_fine : usually just set to same as via_coarse unless you ran a refined run and want to link it to initial via_coarse's terminal clusters
            idx : if one uses a downsampled embedding of the original data, then idx is the selected indices of the downsampled samples used in the visualization
            figsize : size of the figure
            cmap : color map to use for the lineage probability
            dpi : dpi of the figure
            scatter_size : size of the scatter points
            marker_lineages : Default is to use all lineage pathways. Otherwise, provide a list of lineage numbers (terminal cluster numbers).
            fontsize : fontsize of the title

            fig : matplotlib figure
            axs : matplotlib axis

        if clusters=='':
        if basis=='':
        fig, axs = draw_sc_lineage_probability(via_object=self.model,via_fine=via_fine, embedding=embedding,figsize=figsize,
                                               idx=idx, cmap_name=cmap, dpi=dpi, scatter_size =scatter_size,
                                            marker_lineages = marker_lineages, fontsize=fontsize)
        return fig, axs

    def plot_gene_trend(self,gene_list:list=None,figsize:tuple=(8,4),
                        magic_steps:int=3, spline_order:int=5, dpi:int=80,cmap:str='jet', 
                        marker_genes:list = [], linewidth:float = 2.0,
                        n_splines:int=10,  fontsize_:int=12, marker_lineages=[])->Tuple[matplotlib.figure.Figure,
        """plots the gene expression trend along the pseudotime

            gene_list : list of genes to plot
            figsize : size of the figure
            magic_steps : number of magic steps to use for imputation
            spline_order : order of the spline to use for smoothing
            dpi : dpi of the figure
            cmap : color map to use for the gene expression
            marker_genes : Default is to use all genes in gene_exp. Otherwise, provide a list of marker genes that will be used from gene_exp.
            linewidth : width of the lines
            n_splines : number of splines to use for smoothing
            fontsize_ : fontsize of the title
            marker_lineages : Default is to use all lineage pathways. Otherwise, provide a list of lineage numbers (terminal cluster numbers).

            fig : matplotlib figure
            axs : matplotlib axis


        df_magic = self.model.do_impute(self.adata[:,gene_list].to_df(), magic_steps=magic_steps, gene_list=gene_list)
        fig, axs=get_gene_expression_pyomic(self.model,df_magic,spline_order=spline_order,dpi=dpi,
                                   cmap=cmap, marker_genes=marker_genes, linewidth=linewidth,figsize=figsize,
                                   n_splines=n_splines,  fontsize_=fontsize_, marker_lineages=marker_lineages)
        return fig, axs

    def plot_clustergraph(self,gene_list:list,arrow_head:float=0.1,figsize:tuple=(8,4),dpi=80,magic_steps=3,
                          edgeweight_scale:float=1.5, cmap=None, label_=True,)->Tuple[matplotlib.figure.Figure,
        """plot the gene in pie chart for each cluster

            gene_list : list of genes to plot
            arrow_head : size of the arrow head
            figsize : size of the figure
            edgeweight_scale : scale of the edge weight
            cmap : color map to use for the gene expression
            label_ : whether to label the nodes

            fig : matplotlib figure
            axs : matplotlib axis
        df_magic = self.model.do_impute(self.adata[:,gene_list].to_df(), magic_steps=magic_steps, gene_list=gene_list)
        df_magic['parc'] = self.model.labels
        df_magic_cluster = df_magic.groupby('parc', as_index=True).mean()
        fig, axs = draw_clustergraph_pyomic(via_object=self.model, type_data='gene', gene_exp=df_magic_cluster, 
                                    gene_list=gene_list, arrow_head=arrow_head,figsize=figsize,
                                    edgeweight_scale=edgeweight_scale, cmap=cmap, label_=label_,dpi=dpi)
        return fig,axs

    def plot_gene_trend_heatmap(self,gene_list:list,marker_lineages:list = [], 
                             fontsize:int=8,cmap:str='viridis', normalize:bool=True, ytick_labelrotation:int = 0, 
        """Plot the gene trends on heatmap: a heatmap is generated for each lineage (identified by terminal cluster number). Default selects all lineages

            gene_list : list of genes to plot
            marker_lineages : list, default = [], which plots all detected lineages. Optionally provide a list of integers corresponding to the cluster number of terminal cell fates
            fontsize : int default = 8
            cmap : str default = 'viridis'
            normalize : bool = True
            ytick_labelrotation : int default = 0
            figsize : size of the figure

            fig : matplotlib figure
            axs : list of matplotlib axis       

        df_magic = self.model.do_impute(self.adata[:,gene_list].to_df(), magic_steps=3, gene_list=gene_list)
        df_magic['parc'] = self.model.labels
        df_magic_cluster = df_magic.groupby('parc', as_index=True).mean()
        fig,axs=plot_gene_trend_heatmaps_pyomic(via_object=self.model, df_gene_exp=df_magic, 
        return fig,axs

calculate the via graph and pseudotime

def run(self):
    """calculate the via graph and pseudotime


get_piechart_dict(label=0, clusters='')

Cluster composition graph


Name Type Description Default
label int

int (default=0) cluster label of pie chart

clusters str

the cell type of interest



Name Type Description
res_dict dict

cluster composition graph

def get_piechart_dict(self,label:int=0,clusters:str='')->dict:
    Cluster composition graph

        label: int (default=0) cluster label of pie chart
        clusters: the cell type of interest

        res_dict: cluster composition graph
    if clusters=='':
    cluster_i_loc=np.where(np.asarray(self.model.labels) == label)[0]
    return res_dict


Extract the pseudotime of VIA


Name Type Description Default
adata

an AnnData object of interest; if None, the pseudotime is added to self.adata.obs['pt_via']

None

def get_pseudotime(self,adata=None):
    Extract the pseudotime of VIA

        adata: an AnnData object of interest; if None, the pseudotime is added to `self.adata.obs['pt_via']`


    print('...the pseudotime of VIA added to AnnData obs named `pt_via`')
    if adata is None:

plot_piechart_graph(clusters='', type_data='pt', gene_exp=[], title='', cmap=None, ax_text=True, figsize=(8, 4), dpi=150, headwidth_arrow=0.1, alpha_edge=0.4, linewidth_edge=2, edge_color='darkblue', reference=None, show_legend=True, pie_size_scale=0.8, fontsize=8)

plot two subplots with a clustergraph level representation of the viagraph showing true-label composition (lhs) and pseudotime/gene expression (rhs)


Name Type Description Default
clusters str

column name of the adata.obs dataframe that contains the cluster labels

''

type_data

string, default 'pt' for pseudotime-colored nodes, or 'gene'

'pt'

gene_exp list

list of values (column of dataframe) corresponding to feature or gene expression to be used to color nodes at CLUSTER level

[]

title

string

''

cmap str

default None. automatically chooses coolwarm for gene expression or viridis_r for pseudotime

None

ax_text

Bool default= True. Annotates each node with cluster number and population of membership

True

dpi

int default = 150

150

headwidth_arrow

default = 0.1. width of arrowhead used for directed edges

0.1

reference

None or list. list of categorical (str) labels for cluster composition of the piecharts (LHS subplot) length = n_samples.

None

pie_size_scale float

float default=0.8 scaling factor of the piechart nodes

0.8

fontsize float

float default=8. fontsize of the text in the piecharts

8

figsize tuple

tuple default=(8,4). size of the figure

(8, 4)


Name Type Description
fig matplotlib.figure.Figure

Returns matplotlib figure with two axes that plot the clustergraph using edge bundling

ax matplotlib.axes._axes.Axes

left axis shows the clustergraph with each node colored by annotated ground truth membership.

ax1 matplotlib.axes._axes.Axes

right axis shows the same clustergraph with each node colored by the pseudotime or gene expression

def plot_piechart_graph(self,clusters:str='', type_data='pt',
                            gene_exp:list=[], title='',
                            cmap:str=None, ax_text=True, figsize:tuple=(8,4),
                            dpi=150,headwidth_arrow = 0.1,
                            alpha_edge=0.4, linewidth_edge=2,
                            edge_color='darkblue', reference=None,
                            show_legend:bool=True, pie_size_scale:float=0.8, fontsize:float=8)->Tuple[matplotlib.figure.Figure,
                                                                                                      matplotlib.axes._axes.Axes,
                                                                                                      matplotlib.axes._axes.Axes]:
    """Plot two subplots with a clustergraph-level representation of the viagraph.

    The left subplot shows the true-label composition of every cluster as a
    pie chart; the right subplot colors the same nodes by pseudotime or by
    gene expression.

    Arguments:
        clusters : column name of the adata.obs dataframe that contains the cluster labels
        type_data : string, default 'pt' for pseudotime colored nodes, or 'gene'
        gene_exp : list of values (column of dataframe) corresponding to feature or gene expression to be used to color nodes at CLUSTER level
        title : string
        cmap : default None. automatically chooses coolwarm for gene expression or viridis_r for pseudotime
        ax_text : bool, default True. Annotates each node with cluster number and population of membership
        figsize : tuple, default (8,4). size of the figure
        dpi : int, default 150
        headwidth_arrow : default 0.1. width of arrowhead used for directed edges
        alpha_edge : default 0.4, forwarded unchanged to the drawing helper
        linewidth_edge : default 2, forwarded unchanged to the drawing helper
        edge_color : default 'darkblue', forwarded unchanged to the drawing helper
        reference : None or list. list of categorical (str) labels for cluster composition of the piecharts (LHS subplot), length = n_samples
        show_legend : bool, default True, forwarded unchanged to the drawing helper
        pie_size_scale : float, default 0.8. scaling factor of the piechart nodes
        fontsize : float, default 8. fontsize of the text in the piecharts

    Returns:
        fig : matplotlib figure with two axes that plot the clustergraph using edge bundling
        ax : left axis shows the clustergraph with each node colored by annotated ground truth membership
        ax1 : right axis shows the same clustergraph with each node colored by the pseudotime or gene expression

    """
    if clusters=='':
        # NOTE(review): the fallback line was lost from this listing; this assumes
        # __init__ stores its `clusters` argument as self.clusters -- confirm.
        clusters=self.clusters
    # NOTE(review): edge_color/reference appear in the published signature but
    # their forwarding line was lost from this listing; the keyword names below
    # assume the helper mirrors this wrapper's parameters -- confirm.
    fig, ax, ax1 = draw_piechart_graph_pyomic(clusters=clusters,adata=self.adata,
                               via_object=self.model, type_data=type_data,
                            gene_exp=gene_exp, title=title,
                            cmap=cmap, ax_text=ax_text,figsize=figsize,
                            dpi=dpi,headwidth_arrow = headwidth_arrow,
                            alpha_edge=alpha_edge, linewidth_edge=linewidth_edge,
                            edge_color=edge_color, reference=reference,
                            show_legend=show_legend, pie_size_scale=pie_size_scale, fontsize=fontsize)
    return fig, ax, ax1

plot_stream(clusters='', basis='', density_grid=0.5, arrow_size=0.7, arrow_color='k', arrow_style='-|>', max_length=4, linewidth=1, min_mass=1, cutoff_perc=5, scatter_size=500, scatter_alpha=0.5, marker_edgewidth=0.1, density_stream=2, smooth_transition=1, smooth_grid=0.5, color_scheme='annotation', add_outline_clusters=False, cluster_outline_edgewidth=0.001, gp_color='white', bg_color='black', dpi=80, title='Streamplot', b_bias=20, n_neighbors_velocity_grid=None, other_labels=None, use_sequentially_augmented=False, cmap_str='rainbow')

Construct vector streamplot on the embedding to show a fine-grained view of inferred directions in the trajectory


| Name | Type | Description | Default |
| --- | --- | --- | --- |
| clusters | str | column name of the adata.obs dataframe that contains the cluster labels | '' |
| basis | str | str, default = 'X_umap', which to use for the embedding | '' |
| density_grid | float | float, default = 0.5, density of the grid on which to project the directionality of cells | 0.5 |
| arrow_size | float | float, default = 0.7, size of the arrows in the streamplot | 0.7 |
| arrow_color | str | str, default = 'k', color of the arrows in the streamplot | 'k' |
| arrow_style | — | str, default = "-\|>", style of the arrows in the streamplot | '-\|>' |
| max_length | int | int, default = 4, maximum length of the arrows in the streamplot | 4 |
| linewidth | float | float, default = 1, width of lines in streamplot | 1 |
| min_mass | — | float, default = 1, minimum mass of the arrows in the streamplot | 1 |
| cutoff_perc | int | int, default = 5, cutoff percentage of the arrows in the streamplot | 5 |
| scatter_size | int | int, default = 500, size of scatter points | 500 |
| scatter_alpha | float | float, default = 0.5, transparency of scatter points | 0.5 |
| marker_edgewidth | float | float, default = 0.1, width of outline around each scatter point | 0.1 |
| density_stream | int | int, default = 2, density of the streamplot | 2 |
| smooth_transition | int | int, default = 1, smoothness of the transition between the streamplot and the scatter points | 1 |
| smooth_grid | float | float, default = 0.5, smoothness of the grid on which to project the directionality of cells | 0.5 |
| color_scheme | str | str, default = 'annotation' corresponds to self.true_labels. Other options are 'time' (uses single-cell pseudotime) or 'clusters' (uses self.clusters) | 'annotation' |
| add_outline_clusters | bool | bool, default = False, whether to add an outline to the clusters | False |
| cluster_outline_edgewidth | — | float, default = 0.001, width of the outline around the clusters | 0.001 |
| gp_color | — | str, default = 'white', color of the grid points | 'white' |
| bg_color | — | str, default = 'black', color of the background | 'black' |
| dpi | — | int, default = 80, dpi of the figure | 80 |
| title | — | str, default = 'Streamplot', title of the figure | 'Streamplot' |
| b_bias | — | int, default = 20, higher value makes the forward bias of pseudotime stronger | 20 |
| n_neighbors_velocity_grid | — | int, default = None, number of neighbors to use for the velocity grid | None |
| other_labels | list | list, default = None, list of other labels to plot in the streamplot | None |
| use_sequentially_augmented | bool | bool, default = False, whether to use the sequentially augmented data | False |
| cmap_str | str | str, default = 'rainbow', color map to use for the streamplot | 'rainbow' |



Name Type Description
fig matplotlib.figure.Figure

matplotlib figure

ax matplotlib.axes._axes.Axes

matplotlib axis

def plot_stream(self,clusters:str='',basis:str='',
               density_grid:float=0.5, arrow_size:float=0.7, arrow_color:str = 'k',
               arrow_style="-|>",  max_length:int=4, linewidth:float=1,min_mass = 1, cutoff_perc:int = 5,
               scatter_size:int=500, scatter_alpha:float=0.5,marker_edgewidth:float=0.1,
               density_stream:int = 2, smooth_transition:int=1, smooth_grid:float=0.5,
               color_scheme:str = 'annotation', add_outline_clusters:bool=False,
               cluster_outline_edgewidth = 0.001,gp_color = 'white', bg_color='black' ,
               dpi=80 , title='Streamplot', b_bias=20, n_neighbors_velocity_grid=None,
               other_labels:list = None,use_sequentially_augmented:bool=False, cmap_str:str='rainbow')->Tuple[matplotlib.figure.Figure,
                                                                                                         matplotlib.axes._axes.Axes]:
    """Construct vector streamplot on the embedding to show a fine-grained view of inferred directions in the trajectory

    Arguments:
        clusters : column name of the adata.obs dataframe that contains the cluster labels
        basis : str, key of the embedding to draw on (e.g. 'X_umap'); '' falls back to the instance default
        density_grid : float, default = 0.5, density of the grid on which to project the directionality of cells
        arrow_size : float, default = 0.7, size of the arrows in the streamplot
        arrow_color : str, default = 'k', color of the arrows in the streamplot
        arrow_style : str, default = "-|>", style of the arrows in the streamplot
        max_length : int, default = 4, maximum length of the arrows in the streamplot
        linewidth : float, default = 1, width of lines in streamplot
        min_mass : float, default = 1, minimum mass of the arrows in the streamplot
        cutoff_perc : int, default = 5, cutoff percentage of the arrows in the streamplot
        scatter_size : int, default = 500, size of scatter points
        scatter_alpha : float, default = 0.5, transparency of scatter points
        marker_edgewidth : float, default = 0.1, width of outline around each scatter point
        density_stream : int, default = 2, density of the streamplot
        smooth_transition : int, default = 1, smoothness of the transition between the streamplot and the scatter points
        smooth_grid : float, default = 0.5, smoothness of the grid on which to project the directionality of cells
        color_scheme : str, default = 'annotation' corresponds to self.true_labels. Other options are 'time' (uses single-cell pseudotime) or 'clusters' (uses self.clusters)
        add_outline_clusters : bool, default = False, whether to add an outline to the clusters
        cluster_outline_edgewidth : float, default = 0.001, width of the outline around the clusters
        gp_color : str, default = 'white', color of the grid points
        bg_color : str, default = 'black', color of the background
        dpi : int, default = 80, dpi of the figure
        title : str, default = 'Streamplot', title of the figure
        b_bias : int, default = 20, higher value makes the forward bias of pseudotime stronger
        n_neighbors_velocity_grid : int, default = None, number of neighbors to use for the velocity grid
        other_labels : list, default = None, list of other labels to plot in the streamplot
        use_sequentially_augmented : bool, default = False, whether to use the sequentially augmented data
        cmap_str : str, default = 'rainbow', color map to use for the streamplot

    Returns:
        fig : matplotlib figure
        ax : matplotlib axis

    """
    # NOTE(review): the three lines resolving the defaults were lost from this
    # listing. The fallbacks assume __init__ stores its `clusters`/`basis`
    # arguments as self.clusters/self.basis, and that the embedding lives in
    # adata.obsm under the `basis` key -- confirm against the full source.
    if clusters=='':
        clusters=self.clusters
    if basis=='':
        basis=self.basis
    embedding=self.adata.obsm[basis]
    fig,ax = via_streamplot_pyomic(adata=self.adata,clusters=clusters,via_object=self.model,
                             embedding=embedding,density_grid=density_grid, arrow_size=arrow_size,
                             arrow_color=arrow_color,arrow_style=arrow_style,  max_length=max_length,
                             linewidth=linewidth,min_mass = min_mass, cutoff_perc=cutoff_perc,
                             scatter_size=scatter_size, scatter_alpha=scatter_alpha,marker_edgewidth=marker_edgewidth,
                             density_stream=density_stream, smooth_transition=smooth_transition, smooth_grid=smooth_grid,
                             color_scheme=color_scheme, add_outline_clusters=add_outline_clusters,
                             cluster_outline_edgewidth = cluster_outline_edgewidth,gp_color = gp_color, bg_color=bg_color,
                             dpi=dpi , title=title, b_bias=b_bias, n_neighbors_velocity_grid=n_neighbors_velocity_grid,
                             other_labels=other_labels,use_sequentially_augmented=use_sequentially_augmented, cmap_str=cmap_str)
    return fig,ax

plot_trajectory_gams(clusters='', basis='', via_fine=None, idx=None, title_str='Pseudotime', draw_all_curves=True, arrow_width_scale_factor=15.0, scatter_size=50, scatter_alpha=0.5, figsize=(8, 4), linewidth=1.5, marker_edgewidth=1, cmap_pseudotime='viridis_r', dpi=80, highlight_terminal_states=True, use_maxout_edgelist=False)

projects the graph based coarse trajectory onto a umap/tsne embedding


| Name | Type | Description | Default |
| --- | --- | --- | --- |
| clusters | str | column name of the adata.obs dataframe that contains the cluster labels | '' |
| basis | str | str, default = 'X_umap', which to use for the embedding | '' |
| via_fine | — | via object; use via_object alone unless running via_fine gave better pathways | None |
| idx | — | default: None. Or List. if you had previously computed a umap/tsne (embedding) only on a subset of the total n_samples (subsampled as per idx), then the via objects and results will be indexed according to idx too | None |
| title_str | str | title of figure | 'Pseudotime' |
| draw_all_curves | bool | if the clustergraph has too many edges to project in a visually interpretable way, set this to False to get a simplified view of the graph pathways | True |
| arrow_width_scale_factor | float | the width of the arrows is proportional to the edge weight. This factor scales the width of the arrows | 15.0 |
| scatter_size | float | size of the scatter points | 50 |
| scatter_alpha | float | transparency of the scatter points | 0.5 |
| linewidth | float | width of the lines | 1.5 |
| marker_edgewidth | float | width of the outline around each scatter point | 1 |
| cmap_pseudotime | str | color map to use for the pseudotime | 'viridis_r' |
| dpi | int | dpi of the figure | 80 |
| highlight_terminal_states | bool | whether or not to highlight/distinguish the clusters which are detected as the terminal states by via | True |



Name Type Description
fig matplotlib.figure.Figure

matplotlib figure

ax1 matplotlib.axes._axes.Axes

matplotlib axis

ax2 matplotlib.axes._axes.Axes

matplotlib axis

def plot_trajectory_gams(self,clusters:str='',basis:str='',via_fine=None, idx=None,
                     title_str:str= "Pseudotime", draw_all_curves:bool=True, arrow_width_scale_factor:float=15.0,
                     scatter_size:float=50, scatter_alpha:float=0.5,figsize:tuple=(8,4),
                     linewidth:float=1.5, marker_edgewidth:float=1, cmap_pseudotime:str='viridis_r',dpi:int=80,
                     highlight_terminal_states:bool=True, use_maxout_edgelist:bool =False)->Tuple[matplotlib.figure.Figure,
                                                                                                   matplotlib.axes._axes.Axes,
                                                                                                   matplotlib.axes._axes.Axes]:
    """Project the graph-based coarse trajectory onto a umap/tsne embedding.

    Arguments:
        clusters : column name of the adata.obs dataframe that contains the cluster labels
        basis : str, key of the embedding to draw on (e.g. 'X_umap'); '' falls back to the instance default
        via_fine : via object; use via_object alone unless running via_fine gave better pathways
        idx : default None, or list. if you had previously computed a umap/tsne (embedding) only on a subset of the total n_samples (subsampled as per idx), then the via objects and results will be indexed according to idx too
        title_str : title of figure
        draw_all_curves : if the clustergraph has too many edges to project in a visually interpretable way, set this to False to get a simplified view of the graph pathways
        arrow_width_scale_factor : the width of the arrows is proportional to the edge weight. This factor scales the width of the arrows
        scatter_size : size of the scatter points
        scatter_alpha : transparency of the scatter points
        figsize : size of the figure
        linewidth : width of the lines
        marker_edgewidth : width of the outline around each scatter point
        cmap_pseudotime : color map to use for the pseudotime
        dpi : dpi of the figure
        highlight_terminal_states : whether or not to highlight/distinguish the clusters which are detected as the terminal states by via
        use_maxout_edgelist : bool, default False, forwarded unchanged to the drawing helper

    Returns:
        fig : matplotlib figure
        ax1 : matplotlib axis
        ax2 : matplotlib axis

    """
    # NOTE(review): the three lines resolving the defaults were lost from this
    # listing. The fallbacks assume __init__ stores its `clusters`/`basis`
    # arguments as self.clusters/self.basis, and that the embedding lives in
    # adata.obsm under the `basis` key -- confirm against the full source.
    if clusters=='':
        clusters=self.clusters
    if basis=='':
        basis=self.basis
    embedding=self.adata.obsm[basis]
    fig,ax1,ax2 = draw_trajectory_gams_pyomic(adata=self.adata,clusters=clusters,via_object=self.model,
                                        via_fine=via_fine, embedding=embedding, idx=idx,
                                        title_str=title_str, draw_all_curves=draw_all_curves, arrow_width_scale_factor=arrow_width_scale_factor,
                                        scatter_size=scatter_size, scatter_alpha=scatter_alpha,figsize=figsize,
                                        linewidth=linewidth, marker_edgewidth=marker_edgewidth, cmap_pseudotime=cmap_pseudotime,dpi=dpi,
                                        highlight_terminal_states=highlight_terminal_states, use_maxout_edgelist=use_maxout_edgelist)
    return fig,ax1,ax2

plot_lineage_probability(clusters='', basis='', via_fine=None, idx=None, figsize=(8, 4), cmap='plasma', dpi=80, scatter_size=None, marker_lineages=[], fontsize=12)

G is the igraph knn (low K) used for shortest path in high dim space. no idx needed as it's made on full sample, knn_hnsw is the knn made in the embedded space used for query to find the nearest point in the downsampled embedding that corresponds to the single cells in the full graph


| Name | Type | Description | Default |
| --- | --- | --- | --- |
| clusters | str | column name of the adata.obs dataframe that contains the cluster labels | '' |
| basis | str | str, default = 'X_umap', which to use for the embedding | '' |
| via_fine | — | usually just set to same as via_coarse unless you ran a refined run and want to link it to initial via_coarse's terminal clusters | None |
| idx | — | if one uses a downsampled embedding of the original data, then idx is the selected indices of the downsampled samples used in the visualization | None |
| figsize | tuple | size of the figure | (8, 4) |
| cmap | str | color map to use for the lineage probability | 'plasma' |
| dpi | int | dpi of the figure | 80 |
| scatter_size | — | size of the scatter points | None |
| marker_lineages | list | Default is to use all lineage pathways. Otherwise provide a list of lineage numbers (terminal cluster numbers). | [] |
| fontsize | int | fontsize of the title | 12 |



Name Type Description
fig matplotlib.figure.Figure

matplotlib figure

axs matplotlib.axes._axes.Axes

matplotlib axis

def plot_lineage_probability(self,clusters:str='',basis:str='',via_fine=None,
                            idx=None, figsize:tuple=(8,4),
                            cmap:str='plasma', dpi:int=80, scatter_size =None,
                            marker_lineages:list = [], fontsize:int=12)->Tuple[matplotlib.figure.Figure,
                                                                               matplotlib.axes._axes.Axes]:
    """Plot the single-cell lineage probabilities on the embedding.

    Note carried over from the original docstring: G is the igraph knn (low K)
    used for shortest path in high dim space. No idx is needed as it is made on
    the full sample; knn_hnsw is the knn made in the embedded space, used to
    find the nearest point in the downsampled embedding that corresponds to the
    single cells in the full graph.

    Arguments:
        clusters : column name of the adata.obs dataframe that contains the cluster labels
        basis : str, key of the embedding to draw on (e.g. 'X_umap'); '' falls back to the instance default
        via_fine : usually just set to same as via_coarse unless you ran a refined run and want to link it to initial via_coarse's terminal clusters
        idx : if one uses a downsampled embedding of the original data, then idx is the selected indices of the downsampled samples used in the visualization
        figsize : size of the figure
        cmap : color map to use for the lineage probability
        dpi : dpi of the figure
        scatter_size : size of the scatter points
        marker_lineages : Default is to use all lineage pathways. Otherwise provide a list of lineage numbers (terminal cluster numbers).
        fontsize : fontsize of the title

    Returns:
        fig : matplotlib figure
        axs : matplotlib axis

    """
    # NOTE(review): the three lines resolving the defaults were lost from this
    # listing. The fallbacks assume __init__ stores its `clusters`/`basis`
    # arguments as self.clusters/self.basis, and that the embedding lives in
    # adata.obsm under the `basis` key -- confirm against the full source.
    if clusters=='':
        clusters=self.clusters
    if basis=='':
        basis=self.basis
    embedding=self.adata.obsm[basis]
    fig, axs = draw_sc_lineage_probability(via_object=self.model,via_fine=via_fine, embedding=embedding,figsize=figsize,
                                           idx=idx, cmap_name=cmap, dpi=dpi, scatter_size =scatter_size,
                                        marker_lineages = marker_lineages, fontsize=fontsize)
    return fig, axs

plot_gene_trend(gene_list=None, figsize=(8, 4), magic_steps=3, spline_order=5, dpi=80, cmap='jet', marker_genes=[], linewidth=2.0, n_splines=10, fontsize_=12, marker_lineages=[])

plots the gene expression trend along the pseudotime


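| Name | Type | Description | Default |
| --- | --- | --- | --- |
| gene_list | list | list of genes to plot | None |
| figsize | tuple | size of the figure | (8, 4) |
| magic_steps | int | number of magic steps to use for imputation | 3 |
| spline_order | int | order of the spline to use for smoothing | 5 |
| dpi | int | dpi of the figure | 80 |
| cmap | str | color map to use for the gene expression | 'jet' |
| marker_genes | list | Default is to use all genes in gene_exp. Otherwise provide a list of marker genes that will be used from gene_exp. | [] |
| linewidth | float | width of the lines | 2.0 |
| n_splines | int | number of splines to use for smoothing | 10 |
| fontsize_ | int | fontsize of the title | 12 |
| marker_lineages | — | Default is to use all lineage pathways. Otherwise provide a list of lineage numbers (terminal cluster numbers). | [] |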



Name Type Description
fig matplotlib.figure.Figure

matplotlib figure

axs matplotlib.axes._axes.Axes

matplotlib axis

def plot_gene_trend(self,gene_list:list=None,figsize:tuple=(8,4),
                    magic_steps:int=3, spline_order:int=5, dpi:int=80,cmap:str='jet',
                    marker_genes:list = [], linewidth:float = 2.0,
                    n_splines:int=10,  fontsize_:int=12, marker_lineages=[])->Tuple[matplotlib.figure.Figure,
                                                                                   matplotlib.axes._axes.Axes]:
    """Plot the gene expression trend along the pseudotime.

    The expression of `gene_list` is first imputed with `self.model.do_impute`
    and the imputed frame is then handed to the trend-plotting helper.

    Arguments:
        gene_list : list of genes to plot
        figsize : size of the figure
        magic_steps : number of magic steps to use for imputation
        spline_order : order of the spline to use for smoothing
        dpi : dpi of the figure
        cmap : color map to use for the gene expression
        marker_genes : Default is to use all genes in gene_exp. Otherwise provide a list of marker genes that will be used from gene_exp.
        linewidth : width of the lines
        n_splines : number of splines to use for smoothing
        fontsize_ : fontsize of the title
        marker_lineages : Default is to use all lineage pathways. Otherwise provide a list of lineage numbers (terminal cluster numbers).

    Returns:
        fig : matplotlib figure
        axs : matplotlib axis

    """
    # Restrict the AnnData to the requested genes before imputation.
    df_magic = self.model.do_impute(self.adata[:,gene_list].to_df(), magic_steps=magic_steps, gene_list=gene_list)
    fig, axs=get_gene_expression_pyomic(self.model,df_magic,spline_order=spline_order,dpi=dpi,
                               cmap=cmap, marker_genes=marker_genes, linewidth=linewidth,figsize=figsize,
                               n_splines=n_splines,  fontsize_=fontsize_, marker_lineages=marker_lineages)
    return fig, axs

plot_gene_trend_heatmap(gene_list, marker_lineages=[], fontsize=8, cmap='viridis', normalize=True, ytick_labelrotation=0, figsize=(2, 4))

Plot the gene trends on heatmap: a heatmap is generated for each lineage (identified by terminal cluster number). Default selects all lineages


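| Name | Type | Description | Default |
| --- | --- | --- | --- |
| gene_list | list | list of genes to plot | required |
| marker_lineages | list | list, default = [] plots all detected lineages. Optionally provide a list of integers corresponding to the cluster number of terminal cell fates | [] |
| fontsize | int | int default = 8 | 8 |
| cmap | str | str default = 'viridis' | 'viridis' |
| normalize | bool | bool = True | True |
| ytick_labelrotation | int | int default = 0 | 0 |
| figsize | tuple | size of the figure | (2, 4) |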


Name Type Description
fig matplotlib.figure.Figure

matplotlib figure

axs list

list of matplotlib axis

def plot_gene_trend_heatmap(self,gene_list:list,marker_lineages:list = [],
                         fontsize:int=8,cmap:str='viridis', normalize:bool=True, ytick_labelrotation:int = 0,
                         figsize:tuple=(2,4))->Tuple[matplotlib.figure.Figure,
                                                     list]:
    """Plot the gene trends on heatmap: a heatmap is generated for each lineage (identified by terminal cluster number). Default selects all lineages

    Arguments:
        gene_list : list of genes to plot
        marker_lineages : list, default = [] plots all detected lineages. Optionally provide a list of integers corresponding to the cluster number of terminal cell fates
        fontsize : int default = 8
        cmap : str default = 'viridis'
        normalize : bool = True
        ytick_labelrotation : int default = 0
        figsize : size of the figure

    Returns:
        fig : matplotlib figure
        axs : list of matplotlib axis

    """
    # Imputation uses a fixed 3 magic steps here (no parameter is exposed).
    df_magic = self.model.do_impute(self.adata[:,gene_list].to_df(), magic_steps=3, gene_list=gene_list)
    # The cluster label column is attached to the frame handed to the helper.
    df_magic['parc'] = self.model.labels
    # NOTE(review): the tail of this call was lost from the listing; the keyword
    # names below assume the helper mirrors this wrapper's parameters -- confirm
    # against plot_gene_trend_heatmaps_pyomic before relying on it.
    fig,axs=plot_gene_trend_heatmaps_pyomic(via_object=self.model, df_gene_exp=df_magic,
                                            marker_lineages=marker_lineages, fontsize=fontsize,
                                            cmap=cmap, normalize=normalize,
                                            ytick_labelrotation=ytick_labelrotation, figsize=figsize)
    return fig,axs

plot_clustergraph(gene_list, arrow_head=0.1, figsize=(8, 4), dpi=80, magic_steps=3, edgeweight_scale=1.5, cmap=None, label_=True)

plot the gene in pie chart for each cluster


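| Name | Type | Description | Default |
| --- | --- | --- | --- |
| gene_list | list | list of genes to plot | required |
| arrow_head | float | size of the arrow head | 0.1 |
| figsize | tuple | size of the figure | (8, 4) |
| edgeweight_scale | float | scale of the edge weight | 1.5 |
| cmap | — | color map to use for the gene expression | None |
| label_ | — | whether to label the nodes | True |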



Name Type Description
fig matplotlib.figure.Figure

matplotlib figure

axs matplotlib.axes._axes.Axes

matplotlib axis

def plot_clustergraph(self,gene_list:list,arrow_head:float=0.1,figsize:tuple=(8,4),dpi=80,magic_steps=3,
                      edgeweight_scale:float=1.5, cmap=None, label_=True,)->Tuple[matplotlib.figure.Figure,
                                                                                 matplotlib.axes._axes.Axes]:
    """Plot the cluster graph with nodes colored by per-cluster gene expression.

    The expression of `gene_list` is imputed with `self.model.do_impute`,
    averaged within each cluster label of `self.model.labels`, and the
    cluster-level means are passed to the drawing helper with type_data='gene'.

    Arguments:
        gene_list : list of genes to plot
        arrow_head : size of the arrow head
        figsize : size of the figure
        dpi : dpi of the figure, default 80
        magic_steps : number of magic steps to use for imputation, default 3
        edgeweight_scale : scale of the edge weight
        cmap : color map to use for the gene expression
        label_ : whether to label the nodes

    Returns:
        fig : matplotlib figure
        axs : matplotlib axis

    """
    df_magic = self.model.do_impute(self.adata[:,gene_list].to_df(), magic_steps=magic_steps, gene_list=gene_list)
    # Attach the cluster labels so expression can be averaged per cluster.
    df_magic['parc'] = self.model.labels
    df_magic_cluster = df_magic.groupby('parc', as_index=True).mean()
    fig, axs = draw_clustergraph_pyomic(via_object=self.model, type_data='gene', gene_exp=df_magic_cluster,
                                gene_list=gene_list, arrow_head=arrow_head,figsize=figsize,
                                edgeweight_scale=edgeweight_scale, cmap=cmap, label_=label_,dpi=dpi)
    return fig,axs