Skip to content

API: pl

Violin and Distribution Plots

omicverse.pl.violin(adata, keys, groupby=None, *, log=False, use_raw=None, stripplot=True, jitter=True, size=1, layer=None, density_norm='width', order=None, multi_panel=None, xlabel='', ylabel=None, rotation=None, show=None, save=None, ax=None, enhanced_style=True, show_means=False, show_boxplot=False, jitter_method='uniform', jitter_alpha=0.4, violin_alpha=0.8, background_color='white', spine_color='#b4aea9', grid_lines=True, statistical_tests=False, custom_colors=None, figsize=None, fontsize=13, ticks_fontsize=None, **kwds)

Enhanced violin plot compatible with scanpy's interface.

This function provides all the functionality of scanpy's violin plot with additional customization options for enhanced visualization, implemented using pure matplotlib.

Parameters:

Name Type Description Default
adata AnnData

AnnData. Annotated data matrix.

required
keys Union[str, Sequence[str]]

str | Sequence[str]. Keys for accessing variables of .var_names or fields of .obs.

required
groupby Optional[str]

str | None. The key of the observation grouping to consider. (None)

None
log bool

bool. Plot on logarithmic axis. (False)

False
use_raw Optional[bool]

bool | None. Whether to use raw attribute of adata. Defaults to True if .raw is present. (None)

None
stripplot bool

bool. Add a stripplot on top of the violin plot. (True)

True
jitter Union[float, bool]

float | bool. Add jitter to the stripplot (only when stripplot is True). (True)

True
size int

int. Size of the jitter points. (1)

1
layer Optional[str]

str | None. Name of the AnnData layer to plot. (None)

None
density_norm DensityNorm

str. The method used to scale the width of each violin. If 'width' (the default), each violin will have the same width. If 'area', each violin will have the same area. If 'count', a violin's width corresponds to the number of observations. ('width')

'width'
order Optional[Sequence[str]]

Sequence[str] | None. Order in which to show the categories. (None)

None
multi_panel Optional[bool]

bool | None. Draw each key in its own panel. Only consulted when groupby is None; when groupby is set and several keys are given, one panel per key is always drawn. (None)

None
xlabel str

str. Label of the x axis. ('')

''
ylabel Optional[Union[str, Sequence[str]]]

str | Sequence[str] | None. Label of the y axis. (None)

None
rotation Optional[float]

float | None. Rotation of xtick labels. (None)

None
show Optional[bool]

bool | None. If True, display the figure with plt.show() and return None. (None)

None
save Optional[Union[bool, str]]

bool | str | None. If a string, the path to save the figure to; if True, the figure is saved to 'violin_plot.pdf'. (None)

None
ax Optional[Axes]

Axes | None. A matplotlib axes object. (None)

None
enhanced_style bool

bool. Whether to apply enhanced styling. (True)

True
show_means bool

bool. Whether to show mean values with annotations. (False)

False
show_boxplot bool

bool. Whether to overlay box plots on violins. (False)

False
jitter_method str

str. Method for jittering: 'uniform' or 't_dist'. ('uniform')

'uniform'
jitter_alpha float

float. Transparency of jittered points. (0.4)

0.4
violin_alpha float

float. Transparency of violin plots. (0.8)

0.8
background_color str

str. Background color of the plot. ('white')

'white'
spine_color str

str. Color of plot spines. ('#b4aea9')

'#b4aea9'
grid_lines bool

bool. Whether to show horizontal grid lines. (True)

True
statistical_tests bool

bool. Whether to perform and display statistical tests. (False)

False
custom_colors Optional[Sequence[str]]

Sequence[str] | None. Custom colors for groups. (None)

None
figsize Optional[tuple]

tuple | None. Figure size (width, height). (None)

None
fontsize

int. Font size for labels and ticks. (13)

13
ticks_fontsize

int | None. Font size for axis ticks. If None, uses fontsize-1. (None)

None
**kwds

Additional keyword arguments passed to violinplot.

{}

Returns:

Name Type Description
ax Union[Axes, None]

matplotlib.axes.Axes | None. Normally a matplotlib Axes object; None when show=True (the figure is displayed instead).

Source code in omicverse/pl/_violin.py
def _positive_dynamic_range(values):
    """Return the max/min ratio over the strictly positive entries of ``values``.

    Used by :func:`violin` to decide whether to suggest ``log=True``. Zeros are
    excluded because a ratio against zero is undefined (it would be infinite for
    any vector containing a zero).

    Args:
        values: pandas Series of numeric values with NaNs already dropped.

    Returns:
        The ratio, or ``None`` when ``values`` is empty, contains negative
        numbers, or has no positive entries.
    """
    if len(values) == 0 or values.min() < 0:
        return None
    positive = values[values > 0]
    if len(positive) == 0:
        return None
    return positive.max() / positive.min()


def violin(
    adata: AnnData,
    keys: Union[str, Sequence[str]],
    groupby: Optional[str] = None,
    *,
    log: bool = False,
    use_raw: Optional[bool] = None,
    stripplot: bool = True,
    jitter: Union[float, bool] = True,
    size: int = 1,
    layer: Optional[str] = None,
    density_norm: DensityNorm = "width",
    order: Optional[Sequence[str]] = None,
    multi_panel: Optional[bool] = None,
    xlabel: str = "",
    ylabel: Optional[Union[str, Sequence[str]]] = None,
    rotation: Optional[float] = None,
    show: Optional[bool] = None,
    save: Optional[Union[bool, str]] = None,
    ax: Optional[Axes] = None,
    # Enhanced features
    enhanced_style: bool = True,
    show_means: bool = False,
    show_boxplot: bool = False,
    jitter_method: str = 'uniform',  # 'uniform', 't_dist'
    jitter_alpha: float = 0.4,
    violin_alpha: float = 0.8,
    background_color: str = 'white',
    spine_color: str = '#b4aea9',
    grid_lines: bool = True,
    statistical_tests: bool = False,
    custom_colors: Optional[Sequence[str]] = None,
    figsize: Optional[tuple] = None,
    fontsize=13,
    ticks_fontsize=None,
    **kwds
) -> Union[Axes, None]:
    r"""
    Enhanced violin plot compatible with scanpy's interface.

    This function provides all the functionality of scanpy's violin plot
    with additional customization options for enhanced visualization,
    implemented using pure matplotlib. Before plotting it prints a short
    data summary and parameter suggestions to stdout.

    Arguments:
        adata: AnnData. Annotated data matrix.
        keys: str | Sequence[str]. Keys for accessing variables of `.var_names` or fields of `.obs`.
        groupby: str | None. The key of the observation grouping to consider. (None)
        log: bool. Plot on logarithmic axis. (False)
        use_raw: bool | None. Whether to use `raw` attribute of `adata`. Defaults to `True` if `.raw` is present. (None)
        stripplot: bool. Add a stripplot on top of the violin plot. (True)
        jitter: float | bool. Add jitter to the stripplot (only when stripplot is True). (True)
        size: int. Size of the jitter points. (1)
        layer: str | None. Name of the AnnData layer to plot. (None)
        density_norm: str. The method used to scale the width of each violin. If 'width' (the default), each violin will have the same width. If 'area', each violin will have the same area. If 'count', a violin's width corresponds to the number of observations. ('width')
        order: Sequence[str] | None. Order in which to show the categories. (None)
        multi_panel: bool | None. Draw each key in its own panel. Only consulted when `groupby is None`; with `groupby` set and several keys, one panel per key is always drawn. (None)
        xlabel: str. Label of the x axis; when empty, `groupby` is used. ('')
        ylabel: str | Sequence[str] | None. Label(s) of the y axis. With `groupby`, one label
            per key (a single string is repeated); without `groupby`, exactly one label.
            Defaults to the key name(s), or 'value' when several keys share one axis. (None)
        rotation: float | None. Rotation of xtick labels. (None)
        show: bool | None. If True, display the figure with `plt.show()` and return None. (None)
        save: bool | str | None. If a string, the path to save the figure to; if True, saves to 'violin_plot.pdf'. (None)
        ax: Axes | None. A matplotlib axes object. Only used for single-panel plots;
            several keys with `groupby` always get a new figure with one subplot per key. (None)
        enhanced_style: bool. Whether to apply enhanced styling. (True)
        show_means: bool. Whether to show mean values with annotations. (False)
        show_boxplot: bool. Whether to overlay box plots on violins. (False)
        jitter_method: str. Method for jittering: 'uniform' or 't_dist'. ('uniform')
        jitter_alpha: float. Transparency of jittered points. (0.4)
        violin_alpha: float. Transparency of violin plots. (0.8)
        background_color: str. Background color of the plot. ('white')
        spine_color: str. Color of plot spines. ('#b4aea9')
        grid_lines: bool. Passed to the enhanced styling helper. (True)
        statistical_tests: bool. Whether to perform and display statistical tests. (False)
        custom_colors: Sequence[str] | None. Custom colors for groups. (None)
        figsize: tuple | None. Figure size (width, height). (None)
        fontsize: int. Font size for labels and ticks. (13)
        ticks_fontsize: int | None. Font size for axis ticks. If None, uses fontsize-1. (None)
        **kwds: Additional keyword arguments passed to violinplot.

    Returns:
        ax: matplotlib.axes.Axes | None. The Axes holding the plot (the first
        panel when several keys are drawn side by side), or None when
        `show=True`. In the `multi_panel` path the value returned by the
        multi-panel helper is passed through unchanged.

    Raises:
        ImportError: If AnnData is not available.
        ValueError: If the number of y-labels does not match the number of panels.
    """

    # Handle AnnData availability
    if not ANNDATA_AVAILABLE:
        raise ImportError("AnnData is required for this function. Install with: pip install anndata")

    # Ensure keys is a de-duplicated list, preserving the caller's order.
    if isinstance(keys, str):
        keys = [keys]
    keys = list(OrderedDict.fromkeys(keys))

    # Resolve y-labels: honour user-supplied labels, otherwise fall back to key names.
    if ylabel is None:
        if groupby is not None:
            ylabel = list(keys)
        else:
            # Without groupby all keys are melted onto a single 'value' axis.
            ylabel = [keys[0] if len(keys) == 1 else "value"]
    elif isinstance(ylabel, str):
        ylabel = [ylabel] * (1 if groupby is None else len(keys))
    else:
        ylabel = list(ylabel)

    # Validate ylabel length
    if groupby is None:
        if len(ylabel) != 1:
            raise ValueError(f"Expected number of y-labels to be `1`, found `{len(ylabel)}`.")
    elif len(ylabel) != len(keys):
        raise ValueError(f"Expected number of y-labels to be `{len(keys)}`, found `{len(ylabel)}`.")

    # Extract data from AnnData object
    obs_df = _extract_data_from_adata(adata, keys, groupby, layer, use_raw)

    # Group statistics reused by the report below (computed once).
    if groupby is not None:
        group_counts = obs_df[groupby].value_counts()
        group_labels = obs_df[groupby].unique()
        n_groups = len(group_labels)

    # Colorful data analysis and parameter optimization suggestions
    print(f"{Colors.HEADER}{Colors.BOLD}🎻 Violin Plot Analysis:{Colors.ENDC}")
    print(f"   {Colors.CYAN}Total cells: {Colors.BOLD}{len(obs_df)}{Colors.ENDC}")
    print(f"   {Colors.BLUE}Variables to plot: {Colors.BOLD}{len(keys)} {keys}{Colors.ENDC}")

    if groupby is not None:
        print(f"   {Colors.GREEN}Groupby variable: '{Colors.BOLD}{groupby}{Colors.ENDC}{Colors.GREEN}' with {Colors.BOLD}{len(group_counts)}{Colors.ENDC}{Colors.GREEN} groups{Colors.ENDC}")

        # Show group distribution (top 10 groups)
        for group, count in group_counts.head(10).items():
            if count < 10:
                color = Colors.WARNING
            elif count < 50:
                color = Colors.BLUE
            else:
                color = Colors.GREEN
            print(f"     {color}{group}: {Colors.BOLD}{count}{Colors.ENDC}{color} cells{Colors.ENDC}")

        if len(group_counts) > 10:
            print(f"     {Colors.CYAN}... and {Colors.BOLD}{len(group_counts) - 10}{Colors.ENDC}{Colors.CYAN} more groups{Colors.ENDC}")

        # Check for imbalanced groups; an empty category counts as imbalanced
        # (checked first to avoid dividing by zero).
        min_count = group_counts.min()
        max_count = group_counts.max()
        if min_count == 0 or max_count / min_count > 10:
            print(f"   {Colors.WARNING}⚠️  Imbalanced groups detected: {Colors.BOLD}{min_count}-{max_count}{Colors.ENDC}{Colors.WARNING} cells per group{Colors.ENDC}")
    else:
        print(f"   {Colors.BLUE}Groupby: {Colors.BOLD}None{Colors.ENDC}{Colors.BLUE} (comparing variables){Colors.ENDC}")

    # Analyze data distribution for each variable
    print(f"\n{Colors.HEADER}{Colors.BOLD}📊 Data Distribution Analysis:{Colors.ENDC}")
    for key in keys:
        if key not in obs_df.columns:
            continue
        data_vals = obs_df[key].dropna()
        if len(data_vals) == 0:
            continue
        zero_fraction = (data_vals == 0).sum() / len(data_vals)

        # Hint at log scale for large non-negative values; flag negatives.
        log_suggestion = ""
        if data_vals.min() >= 0 and data_vals.max() > 100:
            log_suggestion = f" {Colors.BLUE}(consider log=True){Colors.ENDC}"
        elif data_vals.min() < 0:
            log_suggestion = f" {Colors.WARNING}(negative values detected){Colors.ENDC}"

        print(f"   {Colors.BLUE}'{key}': range {Colors.BOLD}{data_vals.min():.2f}-{data_vals.max():.2f}{Colors.ENDC}{Colors.BLUE}, {Colors.BOLD}{zero_fraction*100:.1f}%{Colors.ENDC}{Colors.BLUE} zeros{log_suggestion}")

    # Display current function parameters
    print(f"\n{Colors.HEADER}{Colors.BOLD}⚙️  Current Function Parameters:{Colors.ENDC}")
    print(f"   {Colors.BLUE}Plot style: enhanced_style={Colors.BOLD}{enhanced_style}{Colors.ENDC}{Colors.BLUE}, stripplot={Colors.BOLD}{stripplot}{Colors.ENDC}{Colors.BLUE}, jitter={Colors.BOLD}{jitter}{Colors.ENDC}")
    print(f"   {Colors.BLUE}Additional features: show_means={Colors.BOLD}{show_means}{Colors.ENDC}{Colors.BLUE}, show_boxplot={Colors.BOLD}{show_boxplot}{Colors.ENDC}{Colors.BLUE}, statistical_tests={Colors.BOLD}{statistical_tests}{Colors.ENDC}")
    print(f"   {Colors.BLUE}Figure settings: figsize={Colors.BOLD}{figsize}{Colors.ENDC}{Colors.BLUE}, fontsize={Colors.BOLD}{fontsize}{Colors.ENDC}{Colors.BLUE}, violin_alpha={Colors.BOLD}{violin_alpha}{Colors.ENDC}")
    if custom_colors is not None:
        print(f"   {Colors.BLUE}Colors: {Colors.BOLD}{len(custom_colors)} custom colors specified{Colors.ENDC}")
    else:
        print(f"   {Colors.BLUE}Colors: {Colors.BOLD}Default palette{Colors.ENDC}")

    # Conditions shared by the suggestions and the copy-paste call below.
    wide_range = None  # (key, ratio) for the first key spanning >100x in positive values
    for key in keys:
        if key in obs_df.columns:
            ratio = _positive_dynamic_range(obs_df[key].dropna())
            if ratio is not None and ratio > 100:
                wide_range = (key, ratio)
                break
    cramped_variables = len(keys) > 3 and (figsize is None or figsize[0] < len(keys) * 2)
    suggested_height = figsize[1] if figsize else 6
    if groupby is not None:
        group_width = max(8, n_groups * 1.2)

    # Parameter optimization suggestions
    print(f"\n{Colors.HEADER}{Colors.BOLD}💡 Parameter Optimization Suggestions:{Colors.ENDC}")
    suggestions = []

    if groupby is not None:
        # Check for too many groups
        if n_groups > 8:
            suggestions.append(f"   {Colors.WARNING}▶ Many groups detected ({n_groups}):{Colors.ENDC}")
            suggestions.append(f"     {Colors.CYAN}Current: figsize={Colors.BOLD}{figsize}{Colors.ENDC}")
            suggestions.append(f"     {Colors.GREEN}Suggested: {Colors.BOLD}figsize=({group_width:g}, {suggested_height}){Colors.ENDC}")

            if rotation is None:
                suggestions.append(f"     {Colors.GREEN}Consider adding: {Colors.BOLD}rotation=45{Colors.ENDC} for better label readability")

        # Check for small sample sizes
        if group_counts.min() < 10:
            suggestions.append(f"   {Colors.WARNING}▶ Small sample sizes detected (min: {group_counts.min()}):{Colors.ENDC}")
            suggestions.append(f"     {Colors.GREEN}Suggested: {Colors.BOLD}stripplot=True, jitter=0.3{Colors.ENDC} to show individual points")

    # Check data distribution and log scale
    if wide_range is not None:
        wide_key, wide_ratio = wide_range
        suggestions.append(f"   {Colors.WARNING}▶ Wide data range for '{wide_key}' ({wide_ratio:.1f}x):{Colors.ENDC}")
        suggestions.append(f"     {Colors.CYAN}Current: log={Colors.BOLD}{log}{Colors.ENDC}")
        suggestions.append(f"     {Colors.GREEN}Suggested: {Colors.BOLD}log=True{Colors.ENDC} for better visualization")

    # Check figure size vs number of variables
    if cramped_variables:
        suggestions.append(f"   {Colors.WARNING}▶ Multiple variables with small figure:{Colors.ENDC}")
        variables_width = max(len(keys) * 2, 8)
        suggestions.append(f"     {Colors.CYAN}Current: figsize={Colors.BOLD}{figsize}{Colors.ENDC}")
        suggestions.append(f"     {Colors.GREEN}Suggested: {Colors.BOLD}figsize=({variables_width}, 6){Colors.ENDC} or {Colors.BOLD}multi_panel=True{Colors.ENDC}")

    # Font size optimization
    if groupby is not None:
        max_label_length = max((len(str(x)) for x in group_labels), default=0)
        if max_label_length > 8 and fontsize > 12:
            suggestions.append(f"   {Colors.WARNING}▶ Long group labels detected:{Colors.ENDC}")
            suggestions.append(f"     {Colors.CYAN}Current: fontsize={Colors.BOLD}{fontsize}{Colors.ENDC}")
            suggestions.append(f"     {Colors.GREEN}Suggested: {Colors.BOLD}fontsize=10, rotation=45{Colors.ENDC}")

    # Enhanced features suggestions
    if not show_means and not show_boxplot and not statistical_tests:
        suggestions.append(f"   {Colors.BLUE}▶ Consider enhancing your plot:{Colors.ENDC}")
        suggestions.append(f"     {Colors.GREEN}Options: {Colors.BOLD}show_means=True{Colors.ENDC}{Colors.GREEN}, {Colors.BOLD}show_boxplot=True{Colors.ENDC}{Colors.GREEN}, or {Colors.BOLD}statistical_tests=True{Colors.ENDC}")

    if suggestions:
        for suggestion in suggestions:
            print(suggestion)

        print(f"\n   {Colors.BOLD}📋 Copy-paste ready function call:{Colors.ENDC}")

        # Generate optimized function call
        if len(keys) == 1:
            optimized_params = [f"adata, keys='{keys[0]}'"]
        else:
            optimized_params = [f"adata, keys={keys}"]

        if groupby is not None:
            optimized_params.append(f"groupby='{groupby}'")
            if n_groups > 8:
                optimized_params.append(f"figsize=({group_width:g}, {suggested_height})")
            if rotation is None and n_groups > 6:
                optimized_params.append("rotation=45")
            if group_counts.min() < 10:
                optimized_params.append("stripplot=True")
                optimized_params.append("jitter=0.3")

        if wide_range is not None:
            optimized_params.append("log=True")

        if cramped_variables:
            optimized_params.append("multi_panel=True")

        if not show_means and not show_boxplot:
            optimized_params.append("show_means=True")

        optimized_call = f"   {Colors.GREEN}ov.pl.violin({', '.join(optimized_params)}){Colors.ENDC}"
        print(optimized_call)
    else:
        print(f"   {Colors.GREEN}✅ Current parameters are optimal for your data!{Colors.ENDC}")

    print(f"{Colors.CYAN}{'─' * 60}{Colors.ENDC}")

    # Prepare data for plotting
    if groupby is None:
        obs_tidy = pd.melt(obs_df, value_vars=keys)
        x_col = "variable"
        y_cols = ["value"]
        group_categories = keys
    else:
        obs_tidy = obs_df
        x_col = groupby
        y_cols = keys
        obs_df[groupby] = obs_df[groupby].astype('category')
        group_categories = obs_df[groupby].cat.categories if order is None else order

    # Set up colors
    colors = _setup_colors(custom_colors, group_categories, adata, groupby)

    # Handle multi-panel case
    if multi_panel and groupby is None and len(y_cols) == 1:
        return _create_multi_panel_plot(
            obs_tidy, keys, y_cols[0], colors, log, stripplot, jitter,
            size, density_norm, enhanced_style, **kwds
        )

    # Lay out all axes up front so no empty, orphaned figure is created.
    if len(y_cols) > 1:
        _, panel_axes = plt.subplots(1, len(y_cols), figsize=figsize or (5 * len(y_cols), 6))
        axes_list = list(panel_axes)
    else:
        if ax is None:
            _, ax = plt.subplots(figsize=figsize if figsize is not None else (10, 6))
        axes_list = [ax]

    if ticks_fontsize is None:
        ticks_fontsize = fontsize - 1
    x_axis_label = xlabel if xlabel else groupby
    # matplotlib treats a None rotation as horizontal; make that explicit.
    x_tick_rotation = 0 if rotation is None else rotation

    # Draw one panel per y variable
    for current_ax, y_col, y_label in zip(axes_list, y_cols, ylabel):
        if enhanced_style:
            _apply_enhanced_styling(current_ax, background_color, spine_color, grid_lines)

        plot_data = _prepare_plot_data(obs_tidy, x_col, y_col, group_categories, order)

        _create_violin_plots(
            current_ax, plot_data, group_categories, colors, density_norm,
            violin_alpha, enhanced_style, **kwds
        )

        if show_boxplot:
            _add_box_plots(current_ax, plot_data, group_categories)

        if stripplot:
            _add_strip_plot(
                current_ax, plot_data, group_categories, jitter, jitter_method,
                size, jitter_alpha, colors
            )

        if show_means:
            _add_mean_annotations(current_ax, plot_data, group_categories)

        if statistical_tests:
            _add_statistical_tests(current_ax, plot_data, group_categories)

        _customize_axis(
            current_ax, group_categories, xlabel, y_label, groupby,
            rotation, log, order
        )
        current_ax.spines['top'].set_visible(False)
        current_ax.spines['right'].set_visible(False)
        current_ax.spines['bottom'].set_visible(True)
        current_ax.spines['left'].set_visible(True)
        current_ax.spines['left'].set_position(('outward', 10))
        current_ax.spines['bottom'].set_position(('outward', 10))

        # tick_params restyles the existing ticks without freezing their text
        # behind a FixedFormatter (as set_*ticklabels(get_*ticklabels()) would).
        current_ax.tick_params(axis='x', labelsize=ticks_fontsize, labelrotation=x_tick_rotation)
        current_ax.tick_params(axis='y', labelsize=ticks_fontsize)
        current_ax.set_xlabel(x_axis_label, fontsize=fontsize)
        current_ax.set_ylabel(y_label, fontsize=fontsize)
        # NOTE(review): this unconditionally turns the grid off after styling;
        # confirm how it interacts with grid_lines in _apply_enhanced_styling.
        current_ax.grid(False)

    # Save the figure that actually holds the plot (not merely the current one).
    if save:
        save_path = save if isinstance(save, str) else "violin_plot.pdf"
        axes_list[0].figure.savefig(save_path, dpi=300, bbox_inches='tight')

    # Only an explicit show=True displays the figure and returns None.
    if show is True:
        plt.show()
        return None

    return axes_list[0]

Dot Plots

omicverse.pl.dotplot(adata, var_names, groupby, *, use_raw=None, log=False, num_categories=7, categories_order=None, expression_cutoff=0.0, mean_only_expressed=False, standard_scale=None, title=None, colorbar_title='Mean expression\nin group', size_title='Fraction of cells\nin group (%)', figsize=None, dendrogram=False, gene_symbols=None, var_group_positions=None, var_group_labels=None, var_group_rotation=None, layer=None, swap_axes=False, dot_color_df=None, show=None, save=None, ax=None, return_fig=False, vmin=None, vmax=None, vcenter=None, norm=None, cmap='Reds', dot_max=None, dot_min=None, smallest_dot=0.0, fontsize=12, preserve_dict_order=False, **kwds)

Make a dot plot of the expression values of var_names.

For each var_name and each groupby category a dot is plotted. Each dot represents two values: mean expression within each category (visualized by color) and fraction of cells expressing the var_name in the category (visualized by the size of the dot).

Parameters:

Name Type Description Default
adata AnnData

AnnData Annotated data matrix.

required
var_names Union[_VarNames, Mapping[str, _VarNames]]

str or list of str or dict Variables to plot.

required
groupby Union[str, Sequence[str]]

str or list of str The key of the observation grouping to consider.

required
use_raw Optional[bool]

bool, optional (default=None) Use raw attribute of adata if present.

None
log bool

bool, optional (default=False) Whether to log-transform the data.

False
num_categories int

int, optional (default=7) Number of categories to show.

7
categories_order Optional[Sequence[str]]

list of str, optional (default=None) Order of categories to display.

None
expression_cutoff float

float, optional (default=0.0) Expression cutoff for calculating fraction of expressing cells.

0.0
mean_only_expressed bool

bool, optional (default=False) Whether to calculate mean only for expressing cells.

False
standard_scale Optional[Literal['var', 'group']]

{'var', 'group'} or None, optional (default=None) Whether to standardize data.

None
title Optional[str]

str, optional (default=None) Title for the plot.

None
colorbar_title Optional[str]

str, optional (default='Mean expression\nin group') Title for the color bar.

'Mean expression\nin group'
size_title Optional[str]

str, optional (default='Fraction of cells\nin group (%)') Title for the size legend.

'Fraction of cells\nin group (%)'
figsize Optional[Tuple[float, float]]

tuple, optional (default=None) Figure size (width, height) in inches. If provided, the plot dimensions will be scaled accordingly.

None
dendrogram Union[bool, str]

bool or str, optional (default=False) Whether to add dendrogram to the plot.

False
gene_symbols Optional[str]

str, optional (default=None) Key for gene symbols in adata.var.

None
var_group_positions Optional[Sequence[Tuple[int, int]]]

list of tuples, optional (default=None) Positions for variable groups.

None
var_group_labels Optional[Sequence[str]]

list of str, optional (default=None) Labels for variable groups.

None
var_group_rotation Optional[float]

float, optional (default=None) Rotation angle for variable group labels.

None
layer Optional[str]

str, optional (default=None) Layer to use for expression data.

None
swap_axes Optional[bool]

bool, optional (default=False) Whether to swap x and y axes.

False
dot_color_df Optional[pd.DataFrame]

pandas.DataFrame, optional (default=None) DataFrame for dot colors.

None
show Optional[bool]

bool, optional (default=None) Whether to show the plot.

None
save Optional[Union[str, bool]]

str or bool, optional (default=None) Whether to save the plot.

None
ax Optional[_AxesSubplot]

matplotlib.axes.Axes, optional (default=None) Axes object to plot on.

None
return_fig Optional[bool]

bool, optional (default=False) Whether to return the figure object.

False
vmin Optional[float]

float, optional (default=None) Minimum value for color scaling.

None
vmax Optional[float]

float, optional (default=None) Maximum value for color scaling.

None
vcenter Optional[float]

float, optional (default=None) Center value for diverging colormap.

None
norm Optional[Normalize]

matplotlib.colors.Normalize, optional (default=None) Normalization object for colors.

None
cmap Union[Colormap, str, None]

str or matplotlib.colors.Colormap, optional (default='Reds') Colormap for the plot.

'Reds'
dot_max Optional[float]

float, optional (default=None) Maximum dot size.

None
dot_min Optional[float]

float, optional (default=None) Minimum dot size.

None
smallest_dot float

float, optional (default=0.0) Size of the smallest dot.

0.0
fontsize int

int, optional (default=12) Font size for labels and legends. Titles will be one point larger.

12
preserve_dict_order bool

bool, optional (default=False) When var_names is a dictionary, whether to preserve the original dictionary order. If True, genes will be ordered according to the dictionary's insertion order. If False (default), genes will be ordered according to cell type categories.

False

Returns:

Type Description
Optional[Union[Dict, DotPlot]]

If return_fig is True, returns the figure object.

Optional[Union[Dict, DotPlot]]

If show is False, returns axes dictionary.

Source code in omicverse/pl/_dotplot.py
def _dotplot_flatten_gene_groups(var_names, group_order=None):
    """Flatten a ``{group: genes}`` mapping into parallel gene and label lists.

    Arguments:
        var_names: Mapping from group label to a gene name or an iterable of
            gene names. A bare string is treated as a single gene.
        group_order: If None, groups are taken in the mapping's insertion
            order. Otherwise, groups listed in ``group_order`` that are keys of
            the mapping come first (in that order), followed by the remaining
            keys in insertion order.

    Returns:
        ``(genes, labels)``: two lists of equal length where ``labels[i]`` is
        the group that ``genes[i]`` belongs to.
    """
    if group_order is None:
        ordered_groups = list(var_names)
    else:
        ordered_groups = [g for g in group_order if g in var_names]
        ordered_groups += [g for g in var_names if g not in group_order]

    genes, labels = [], []
    for group in ordered_groups:
        members = var_names[group]
        members = [members] if isinstance(members, str) else list(members)
        genes.extend(members)
        labels.extend([group] * len(members))
    return genes, labels


def _dotplot_minmax_scale(values, axis):
    """Min-max scale ``values`` to [0, 1] along ``axis``.

    Slices whose max equals their min would otherwise compute 0/0; they are
    set to 0 instead of NaN so constant genes/groups still get a valid color.
    """
    values = np.asarray(values, dtype=float)
    low = values.min(axis=axis, keepdims=True)
    span = values.max(axis=axis, keepdims=True) - low
    return np.divide(values - low, span, out=np.zeros_like(values), where=span > 0)


def _dotplot_take_colors(palette, start, count):
    """Return ``count`` colors from ``palette`` beginning at index ``start``.

    Indices past the end of ``palette`` wrap around, so the result always has
    ``count`` entries (empty if ``palette`` is empty or ``count`` <= 0).
    """
    if len(palette) == 0 or count <= 0:
        return []
    size = len(palette)
    return [palette[(start + k) % size] for k in range(count)]


def dotplot(
    adata: AnnData,
    var_names: Union[_VarNames, Mapping[str, _VarNames]],
    groupby: Union[str, Sequence[str]],
    *,
    use_raw: Optional[bool] = None,
    log: bool = False,
    num_categories: int = 7,
    categories_order: Optional[Sequence[str]] = None,
    expression_cutoff: float = 0.0,
    mean_only_expressed: bool = False,
    standard_scale: Optional[Literal['var', 'group']] = None,
    title: Optional[str] = None,
    colorbar_title: Optional[str] = 'Mean expression\nin group',
    size_title: Optional[str] = 'Fraction of cells\nin group (%)',
    figsize: Optional[Tuple[float, float]] = None,
    dendrogram: Union[bool, str] = False,
    gene_symbols: Optional[str] = None,
    var_group_positions: Optional[Sequence[Tuple[int, int]]] = None,
    var_group_labels: Optional[Sequence[str]] = None,
    var_group_rotation: Optional[float] = None,
    layer: Optional[str] = None,
    swap_axes: Optional[bool] = False,
    dot_color_df: Optional[pd.DataFrame] = None,
    show: Optional[bool] = None,
    save: Optional[Union[str, bool]] = None,
    ax: Optional[_AxesSubplot] = None,
    return_fig: Optional[bool] = False,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    vcenter: Optional[float] = None,
    norm: Optional[Normalize] = None,
    cmap: Union[Colormap, str, None] = 'Reds',
    dot_max: Optional[float] = None,
    dot_min: Optional[float] = None,
    smallest_dot: float = 0.0,
    fontsize: int = 12,
    preserve_dict_order: bool = False,
    **kwds,
) -> Optional[Union[Dict, 'DotPlot']]:
    r"""
    Make a dot plot of the expression values of `var_names`.

    For each var_name and each `groupby` category a dot is plotted.
    Each dot represents two values: mean expression within each category
    (visualized by color) and fraction of cells expressing the `var_name` in the
    category (visualized by the size of the dot).

    Arguments:
        adata: AnnData
            Annotated data matrix.
        var_names: str or list of str or dict
            Genes to plot. A mapping ``{group: genes}`` additionally draws a
            colored group bar above the columns; a single string value in the
            mapping is treated as one gene.
        groupby: str
            Column of `adata.obs` whose values define the rows (used as a
            single key).
        use_raw: bool, optional (default=None)
            If truthy and `adata.raw` is set, expression is read from
            `adata.raw.X` and `layer` is ignored. The default None behaves
            like False.
        log: bool, optional (default=False)
            Accepted for API compatibility; currently not used.
        num_categories: int, optional (default=7)
            Accepted for API compatibility; currently not used.
        categories_order: list of str, optional (default=None)
            Row order. Defaults to the category order of a categorical column,
            otherwise order of first appearance. When `var_names` is a dict and
            `preserve_dict_order` is False, gene groups follow this order too.
            Categories with no cells are shown with a count of 0.
        expression_cutoff: float, optional (default=0.0)
            A cell counts as expressing a gene when its value is strictly
            greater than this cutoff.
        mean_only_expressed: bool, optional (default=False)
            If True, the mean is taken over expressing cells only (0 when no
            cell in the group expresses the gene).
        standard_scale: {'var', 'group'} or None, optional (default=None)
            'group' min-max scales each row, 'var' each column. Constant
            rows/columns are set to 0.
        title: str, optional (default=None)
            Accepted for API compatibility; currently not used.
        colorbar_title: str, optional (default='Mean expression\nin group')
            Title for the color legend.
        size_title: str, optional (default='Fraction of cells\nin group (%)')
            Title for the size legend.
        figsize: tuple, optional (default=None)
            If given, 70% of each dimension is split between height and width
            in proportion to the grid shape. Otherwise height and width are
            rows/3 and columns/3.
        dendrogram: bool or str, optional (default=False)
            If truthy, rows are clustered on the fraction matrix and a
            dendrogram is drawn on the right.
        gene_symbols: str, optional (default=None)
            Accepted for API compatibility; currently not used.
        var_group_positions, var_group_labels, var_group_rotation: optional
            Accepted for API compatibility; currently not used.
        layer: str, optional (default=None)
            Layer to use for expression data (ignored when raw is used).
        swap_axes: bool, optional (default=False)
            Accepted for API compatibility; currently not used.
        dot_color_df: pandas.DataFrame, optional (default=None)
            Accepted for API compatibility; currently not used.
        show: bool, optional (default=None)
            Only affects the return value (see Returns).
        save, ax: optional
            Accepted for API compatibility; currently not used.
        return_fig: bool, optional (default=False)
            Return the result of rendering the heatmap.
        vmin, vmax: float, optional (default=None)
            Color limits forwarded to the heatmap.
        vcenter, norm: optional
            Accepted for API compatibility; currently not forwarded.
        cmap: str or matplotlib.colors.Colormap, optional (default='Reds')
            Colormap for the plot.
        dot_max: float, optional (default=None)
            Upper bound applied to the expressing fraction before sizing dots.
        dot_min: float, optional (default=None)
            Lower bound applied to the expressing fraction before sizing dots.
        smallest_dot: float, optional (default=0.0)
            If > 0, fractions are rescaled to
            ``smallest_dot + (1 - smallest_dot) * fraction``.
        fontsize: int, optional (default=12)
            Font size for labels and legends. Titles will be one point larger.
        preserve_dict_order: bool, optional (default=False)
            When var_names is a dictionary, whether to preserve the original dictionary order.
            If True, genes will be ordered according to the dictionary's insertion order.
            If False (default), genes will be ordered according to cell type categories.

    Returns:
        If `return_fig` is True, the value returned by rendering the heatmap.
        Otherwise, if `show` is falsy (including the default None), the
        SizedHeatmap object. Otherwise None.
    """
    obs_groups = adata.obs[groupby]
    groupby_is_categorical = isinstance(obs_groups.dtype, pd.CategoricalDtype)

    # Row order: explicit order > categorical order > order of first appearance.
    if categories_order is not None:
        cats = categories_order
    elif groupby_is_categorical:
        cats = obs_groups.cat.categories
    else:
        cats = obs_groups.unique()

    # Normalise var_names to a flat list; for dict input also keep a parallel
    # list with the group label of every gene.
    gene_groups = None
    if isinstance(var_names, str):
        var_names = [var_names]
    elif isinstance(var_names, Mapping):
        var_names, gene_groups = _dotplot_flatten_gene_groups(
            var_names, group_order=None if preserve_dict_order else list(cats)
        )

    # Expression source: raw (if requested and present) or X / the given layer.
    if use_raw and adata.raw is not None:
        matrix = adata.raw.X
        var_index = adata.raw.var_names
    else:
        matrix = adata.X if layer is None else adata.layers[layer]
        var_index = adata.var_names
    var_names_idx = [var_index.get_loc(name) for name in var_names]

    # Cells per displayed category; categories absent from the data count as 0
    # instead of NaN.
    agg = obs_groups.value_counts().reindex(cats, fill_value=0)
    cell_counts = agg.to_numpy()

    # Category colors from adata.uns['<groupby>_colors'] when available.
    color_dict = None
    color_key = f"{groupby}_colors"
    if color_key in adata.uns:
        colors = adata.uns[color_key]
        if groupby_is_categorical:
            color_dict = dict(zip(obs_groups.cat.categories, colors))
        else:
            unique_cats = obs_groups.unique()
            color_dict = dict(zip(unique_cats, colors[:len(unique_cats)]))

    # Mean expression and fraction of expressing cells per (group, gene).
    n_vars = len(var_names)
    means = np.zeros((len(agg), n_vars))
    fractions = np.zeros_like(means)
    for i, group in enumerate(agg.index):
        mask = (obs_groups == group).to_numpy()
        if not mask.any():
            # No cells in this category: leave zeros rather than NaN means.
            continue
        group_matrix = matrix[mask][:, var_names_idx]
        if hasattr(group_matrix, "toarray"):
            # Sparse input: densify the (small) group x genes slice so the
            # element-wise operations below behave like plain ndarrays.
            group_matrix = group_matrix.toarray()
        group_matrix = np.asarray(group_matrix)

        expressed = group_matrix > expression_cutoff
        if mean_only_expressed:
            n_expressed = expressed.sum(axis=0)
            totals = np.where(expressed, group_matrix, 0).sum(axis=0)
            means[i] = np.divide(
                totals, n_expressed, out=np.zeros(n_vars), where=n_expressed > 0
            )
        else:
            means[i] = group_matrix.mean(axis=0)
        fractions[i] = expressed.mean(axis=0)

    # Scale if requested (rows are groups, columns are genes).
    if standard_scale == 'group':
        means = _dotplot_minmax_scale(means, axis=1)
    elif standard_scale == 'var':
        means = _dotplot_minmax_scale(means, axis=0)

    # Handle dot size limits
    if dot_max is not None:
        fractions = np.minimum(fractions, dot_max)
    if dot_min is not None:
        fractions = np.maximum(fractions, dot_min)

    # Lift all fractions so the smallest dot is not zero-sized.
    if smallest_dot > 0:
        fractions = smallest_dot + (1 - smallest_dot) * fractions

    h, w = means.shape
    if figsize is not None:
        # Use 70% of figsize for the main grid, split in proportion to its shape.
        base_height = figsize[1] * 0.7
        base_width = figsize[0] * 0.7
        height = base_height * (h / max(h, w))
        width = base_width * (w / max(h, w))
    else:
        height = h / 3
        width = w / 3

    # NOTE: `norm` and `vcenter` are not forwarded to the heatmap.
    m = ma.SizedHeatmap(
        size=fractions,
        color=means,
        cluster_data=fractions if dendrogram else None,
        height=height,
        width=width,
        edgecolor="lightgray",
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        size_legend_kws=dict(
            colors="#c2c2c2",
            title=size_title,
            labels=[f"{int(x*100)}%" for x in [0.2, 0.4, 0.6, 0.8, 1.0]],
            show_at=[0.2, 0.4, 0.6, 0.8, 1.0],
            fontsize=fontsize,
            ncol=3,
            title_fontproperties={"size": fontsize + 1, "weight": 100}
        ),
        color_legend_kws=dict(
            title=colorbar_title,
            fontsize=fontsize,
            orientation="horizontal",
            title_fontproperties={"size": fontsize + 1, "weight": 100}
        ),
    )

    # Gene labels on top.
    m.add_top(mp.Labels(var_names, fontsize=fontsize), pad=0.1)

    # Gene-group color bar and column grouping when var_names was a dict.
    if gene_groups is not None:
        # First-occurrence order; with preserve_dict_order this equals the
        # dict's key order minus keys whose gene list was empty.
        unique_groups = list(dict.fromkeys(gene_groups))
        if color_dict is None:
            palette = palette_28 if len(unique_groups) <= 28 else palette_56
            color_dict = dict(
                zip(unique_groups, _dotplot_take_colors(palette, 0, len(unique_groups)))
            )
        else:
            # Reuse category colors where group names match; extend the rest
            # from the default palettes (wrapping instead of running out).
            missing_groups = [g for g in unique_groups if g not in color_dict]
            if missing_groups:
                start = len(color_dict)
                stop = start + len(missing_groups)
                source = palette_28 if len(palette_28) >= stop else palette_56
                color_dict.update(
                    zip(missing_groups, _dotplot_take_colors(source, start, len(missing_groups)))
                )

        used_color_dict = {g: color_dict[g] for g in unique_groups}
        m.add_top(
            mp.Colors(gene_groups, palette=used_color_dict),
            pad=0.1,
            size=0.15,
        )
        m.group_cols(gene_groups, order=unique_groups)

    # Category color bar; categories without an assigned color are drawn grey.
    if color_dict is not None:
        row_palette = {cat: color_dict.get(cat, '#CCCCCC') for cat in agg.index}
        m.add_left(
            mp.Colors(agg.index, palette=row_palette),
            size=0.15,
            pad=0.1,
            legend=False,
        )

    # Category labels
    m.add_left(mp.Labels(agg.index, align="right", fontsize=fontsize), pad=0.1)

    # Cell counts per category
    m.add_right(
        mp.Numbers(
            cell_counts,
            color="#EEB76B",
            label="Count",
            label_props={'size': fontsize},
            props={'size': fontsize},
            show_value=False
        ),
        size=0.5,
        pad=0.1,
    )

    if dendrogram:
        m.add_dendrogram("right", pad=0.1)

    m.add_legends(box_padding=2)

    fig = m.render()

    if return_fig:
        return fig
    elif not show:
        return m
    return None

omicverse.pl.rank_genes_groups_dotplot(adata, plot_type='dotplot', *, groups=None, n_genes=None, groupby=None, values_to_plot=None, var_names=None, min_logfoldchange=None, key=None, show=None, save=None, return_fig=False, gene_symbols=None, **kwds)

Create a dot plot from rank_genes_groups results.

Parameters

adata : AnnData

Annotated data matrix.

plot_type : str

Currently only 'dotplot' is supported.

groups : str or list of str, optional

Groups to include in the plot.

n_genes : int, optional

Number of genes to include in the plot.

groupby : str, optional

Key in adata.obs to group by.

values_to_plot : str, optional

Key in rank_genes_groups results to plot (e.g. 'logfoldchanges', 'scores').

var_names : str or list of str or dict, optional

Variables to include in the plot. Can be:

- A list of gene names: ['gene1', 'gene2', ...]
- A dictionary mapping group names to gene lists: {'group1': ['gene1', 'gene2'], 'group2': ['gene3', 'gene4']}

When a dictionary is provided, genes will be grouped and labeled accordingly in the plot.

min_logfoldchange : float, optional

Minimum log fold change to include in the plot.

key : str, optional

Key in adata.uns to use for rank_genes_groups results.

show : bool, optional

Whether to show the plot.

save : bool, optional

Whether to save the plot.

return_fig : bool

Whether to return the figure object.

gene_symbols : str, optional

Key for gene symbols in adata.var.

**kwds : dict

Additional keyword arguments to pass to dotplot.

Returns

If return_fig is True, or show is falsy (including the default None), returns the object produced by dotplot(..., return_fig=True); otherwise returns None.

Examples

Basic usage with top genes

omicverse.pl.rank_genes_groups_dotplot(adata, n_genes=5)

Using logfoldchanges for coloring

omicverse.pl.rank_genes_groups_dotplot(adata, n_genes=5, values_to_plot='logfoldchanges')

Grouping genes manually

gene_groups = {
    'Group1': ['gene1', 'gene2'],
    'Group2': ['gene3', 'gene4'],
}
omicverse.pl.rank_genes_groups_dotplot(adata, var_names=gene_groups)

Source code in omicverse/pl/_dotplot.py
def rank_genes_groups_dotplot(
    adata: AnnData,
    plot_type: str = "dotplot",
    *,
    groups: Optional[Union[str, Sequence[str]]] = None,
    n_genes: Optional[int] = None,
    groupby: Optional[str] = None,
    values_to_plot: Optional[str] = None,
    var_names: Optional[Union[Sequence[str], Mapping[str, Sequence[str]]]] = None,
    min_logfoldchange: Optional[float] = None,
    key: Optional[str] = None,
    show: Optional[bool] = None,
    save: Optional[bool] = None,
    return_fig: bool = False,
    gene_symbols: Optional[str] = None,
    **kwds: Any,
) -> Optional[Union[Dict, Any]]:
    """
    Create a dot plot from rank_genes_groups results.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix.
    plot_type : str
        Currently only 'dotplot' is supported; anything else raises ValueError.
    groups : str or list of str, optional
        Groups whose top genes are plotted. A single string is treated as one
        group. Defaults to all groups in ``adata.uns[key]['names']``.
    n_genes : int, optional
        Number of top genes per group (default 10 when neither ``n_genes`` nor
        ``var_names`` is given). A negative value takes the last ``-n_genes``
        genes. Mutually exclusive with ``var_names``.
    groupby : str, optional
        Key in `adata.obs` to group by. Defaults to the ``groupby`` stored in
        ``adata.uns[key]['params']``.
    values_to_plot : str, optional
        Key in rank_genes_groups results (e.g. 'logfoldchanges', 'scores').
        It is used to build ``dot_color_df`` for `dotplot` and to derive the
        color legend title (unless ``colorbar_title`` is passed in ``kwds``).
    var_names : str or list of str or dict, optional
        Variables to include in the plot. Can be:
        - A list of gene names: ['gene1', 'gene2', ...]
        - A dictionary mapping group names to gene lists (a single gene name
          string per group is also accepted):
          {'group1': ['gene1', 'gene2'], 'group2': ['gene3', 'gene4']}
        When a dictionary is provided, genes will be grouped and labeled accordingly in the plot.
    min_logfoldchange : float, optional
        Minimum log fold change to include in the plot.
    key : str, optional
        Key in `adata.uns` to use for rank_genes_groups results
        (default ``'rank_genes_groups'``).
    show : bool, optional
        Only affects the return value (see Returns).
    save : bool, optional
        Accepted for API compatibility; currently not used.
    return_fig : bool
        Whether to return the figure object.
    gene_symbols : str, optional
        Key for gene symbols in `adata.var`.
    **kwds : dict
        Additional keyword arguments passed to `dotplot`. ``dendrogram``
        defaults to True and ``preserve_dict_order`` to True.

    Returns
    -------
    The object returned by ``dotplot(..., return_fig=True)`` when
    `return_fig` is True or `show` is falsy (including the default None);
    otherwise None.

    Raises
    ------
    ValueError
        If ``plot_type`` is not 'dotplot', or both ``n_genes`` and
        ``var_names`` are given.

    Examples
    --------
    >>> # Basic usage with top genes
    >>> omicverse.pl.rank_genes_groups_dotplot(adata, n_genes=5)

    >>> # Using logfoldchanges for coloring
    >>> omicverse.pl.rank_genes_groups_dotplot(adata, n_genes=5, values_to_plot='logfoldchanges')

    >>> # Grouping genes manually
    >>> gene_groups = {
    ...     'Group1': ['gene1', 'gene2'],
    ...     'Group2': ['gene3', 'gene4']
    ... }
    >>> omicverse.pl.rank_genes_groups_dotplot(adata, var_names=gene_groups)
    """
    if plot_type != "dotplot":
        raise ValueError("Only 'dotplot' is currently supported")

    if var_names is not None and n_genes is not None:
        msg = (
            "The arguments n_genes and var_names are mutually exclusive. Please "
            "select only one."
        )
        raise ValueError(msg)

    if key is None:
        key = "rank_genes_groups"

    if groupby is None:
        groupby = str(adata.uns[key]["params"]["groupby"])

    # A single group name must not be iterated character by character.
    if groups is None:
        group_names = adata.uns[key]["names"].dtype.names
    elif isinstance(groups, str):
        group_names = [groups]
    else:
        group_names = groups

    if var_names is not None:
        if isinstance(var_names, Mapping):
            # Flatten all gene lists; a bare string value is one gene, not a
            # sequence of characters.
            var_names_list = []
            for genes in var_names.values():
                var_names_list.extend([genes] if isinstance(genes, str) else genes)
        elif isinstance(var_names, str):
            var_names_list = [var_names]
        else:
            var_names_list = var_names
    else:
        # set n_genes = 10 as default when none of the options is given
        if n_genes is None:
            n_genes = 10

        # dict in which each group is the key and the top genes are the values
        var_names = {}
        var_names_list = []
        for group in group_names:
            df = rank_genes_groups_df(
                adata,
                group,
                key=key,
                gene_symbols=gene_symbols,
                log2fc_min=min_logfoldchange,
            )

            if gene_symbols is not None:
                df["names"] = df[gene_symbols]

            genes_list = df.names[df.names.notnull()].tolist()

            if len(genes_list) == 0:
                print(f"Warning: No genes found for group {group}")
                continue
            genes_list = genes_list[n_genes:] if n_genes < 0 else genes_list[:n_genes]
            var_names[group] = genes_list
            var_names_list.extend(genes_list)

    # Get values to plot if specified
    title = None
    values_df = None
    if values_to_plot is not None:
        values_df = _get_values_to_plot(
            adata,
            values_to_plot,
            var_names_list,
            key=key,
            gene_symbols=gene_symbols,
        )
        if values_to_plot == "logfoldchanges":
            title = "log fold change"
        else:
            title = values_to_plot.replace("_", " ").replace("pvals", "p-value")

    # dotplot draws its color legend title from `colorbar_title`; pass the
    # derived title there (a caller-supplied value wins) rather than trying to
    # modify the rendered result afterwards.
    if title is not None:
        kwds.setdefault("colorbar_title", title)
    # by default add dendrogram to plots and keep the group order built above
    kwds.setdefault("dendrogram", True)
    kwds.setdefault("preserve_dict_order", True)

    _pl = dotplot(
        adata,
        var_names,
        groupby,
        dot_color_df=values_df,
        return_fig=True,
        gene_symbols=gene_symbols,
        **kwds,
    )

    if return_fig or not show:
        return _pl
    return None

Embedding and Scatter Plots

omicverse.pl.embedding(adata, basis, *, color=None, gene_symbols=None, use_raw=None, sort_order=True, edges=False, edges_width=0.1, edges_color='grey', neighbors_key=None, arrows=False, arrows_kwds=None, groups=None, components=None, dimensions=None, layer=None, projection='2d', scale_factor=None, color_map=None, cmap=None, palette=None, na_color='lightgray', na_in_legend=True, size=None, frameon='small', legend_fontsize=None, legend_fontweight='bold', legend_loc='right margin', legend_fontoutline=None, colorbar_loc='right', vmax=None, vmin=None, vcenter=None, norm=None, add_outline=False, outline_width=(0.3, 0.05), outline_color=('black', 'white'), ncols=4, hspace=0.25, wspace=None, title=None, show=None, save=None, ax=None, return_fig=None, marker='.', **kwargs)

Scatter plot for user specified embedding basis (e.g. umap, pca, etc).

Parameters:

Name Type Description Default
adata AnnData

Annotated data matrix.

required
basis str

Name of the obsm basis to use.

required
color Union[str, Sequence[str], None]

Keys for annotations of observations/cells or variables/genes. (None)

None
gene_symbols Optional[str]

Key for field in .var that stores gene symbols. (None)

None
use_raw Optional[bool]

Use .raw attribute of adata if present. (None)

None
sort_order bool

For continuous annotations used as color parameter, plot data points with higher values on top of others. (True)

True
edges bool

Show edges between cells. (False)

False
edges_width float

Width of edges. (0.1)

0.1
edges_color Union[str, Sequence[float], Sequence[str]]

Color of edges. ('grey')

'grey'
neighbors_key Optional[str]

Key to use for neighbors. (None)

None
arrows bool

Show arrows for velocity. (False)

False
arrows_kwds Optional[Mapping[str, Any]]

Keyword arguments for arrow plots. (None)

None
groups Optional[str]

Groups to highlight. (None)

None
components Union[str, Sequence[str]]

Components to plot. (None)

None
dimensions Optional[Union[Tuple[int, int], Sequence[Tuple[int, int]]]]

Dimensions to plot. (None)

None
layer Optional[str]

Name of the layer to use for coloring. (None)

None
projection Literal['2d', '3d']

Type of projection ('2d' or '3d'). ('2d')

'2d'
scale_factor Optional[float]

Scaling factor for sizes. (None)

None
color_map Union[Colormap, str, None]

Colormap to use for continuous variables. (None)

None
cmap Union[Colormap, str, None]

Colormap to use for continuous variables. (None)

None
palette Union[str, Sequence[str], Cycler, None]

Colors to use for categorical variables. (None)

None
na_color ColorLike

Color to use for NaN values. ('lightgray')

'lightgray'
na_in_legend bool

Include NaN values in legend. (True)

True
size Union[float, Sequence[float], None]

Size of the dots. (None)

None
frameon Optional[bool]

Draw a frame around the plot. ('small')

'small'
legend_fontsize Union[int, float, _FontSize, None]

Font size for legend. (None)

None
legend_fontweight Union[int, _FontWeight]

Font weight for legend. ('bold')

'bold'
legend_loc str

Location of legend. ('right margin')

'right margin'
legend_fontoutline Optional[int]

Outline width for legend text. (None)

None
colorbar_loc Optional[str]

Location of colorbar. ('right')

'right'
vmax Union[VBound, Sequence[VBound], None]

Maximum value for colorbar. (None)

None
vmin Union[VBound, Sequence[VBound], None]

Minimum value for colorbar. (None)

None
vcenter Union[VBound, Sequence[VBound], None]

Center value for colorbar. (None)

None
norm Union[Normalize, Sequence[Normalize], None]

Normalization for colorbar. (None)

None
add_outline Optional[bool]

Add outline to points. (False)

False
outline_width Tuple[float, float]

Width of outline. ((0.3, 0.05))

(0.3, 0.05)
outline_color Tuple[str, str]

Color of outline. (('black', 'white'))

('black', 'white')
ncols int

Number of columns for subplots. (4)

4
hspace float

Height spacing between subplots. (0.25)

0.25
wspace Optional[float]

Width spacing between subplots. (None)

None
title Union[str, Sequence[str], None]

Title for the plot. (None)

None
show Optional[bool]

Show the plot. (None)

None
save Union[bool, str, None]

Save the plot. (None)

None
ax Optional[Axes]

Matplotlib axes object. (None)

None
return_fig Optional[bool]

Return figure object. (None)

None
marker Union[str, Sequence[str]]

Marker style. ('.')

'.'
**kwargs

Additional keyword arguments.

{}

Returns:

Name Type Description
ax Union[Figure, Axes, None]

If `show` is False, a `matplotlib.axes.Axes` or a list of them.

Source code in omicverse/pl/_single.py
def embedding(
    adata: AnnData,
    basis: str,
    *,
    color: Union[str, Sequence[str], None] = None,
    gene_symbols: Optional[str] = None,
    use_raw: Optional[bool] = None,
    sort_order: bool = True,
    edges: bool = False,
    edges_width: float = 0.1,
    edges_color: Union[str, Sequence[float], Sequence[str]] = 'grey',
    neighbors_key: Optional[str] = None,
    arrows: bool = False,
    arrows_kwds: Optional[Mapping[str, Any]] = None,
    groups: Optional[str] = None,
    components: Union[str, Sequence[str], None] = None,
    dimensions: Optional[Union[Tuple[int, int], Sequence[Tuple[int, int]]]] = None,
    layer: Optional[str] = None,
    projection: Literal['2d', '3d'] = '2d',
    scale_factor: Optional[float] = None,
    color_map: Union[Colormap, str, None] = None,
    cmap: Union[Colormap, str, None] = None,
    palette: Union[str, Sequence[str], Cycler, None] = None,
    na_color: ColorLike = "lightgray",
    na_in_legend: bool = True,
    size: Union[float, Sequence[float], None] = None,
    frameon: Union[bool, str, None] = 'small',
    legend_fontsize: Union[int, float, _FontSize, None] = None,
    legend_fontweight: Union[int, _FontWeight] = 'bold',
    legend_loc: str = 'right margin',
    legend_fontoutline: Optional[int] = None,
    colorbar_loc: Optional[str] = "right",
    vmax: Union[VBound, Sequence[VBound], None] = None,
    vmin: Union[VBound, Sequence[VBound], None] = None,
    vcenter: Union[VBound, Sequence[VBound], None] = None,
    norm: Union[Normalize, Sequence[Normalize], None] = None,
    add_outline: Optional[bool] = False,
    outline_width: Tuple[float, float] = (0.3, 0.05),
    outline_color: Tuple[str, str] = ('black', 'white'),
    ncols: int = 4,
    hspace: float = 0.25,
    wspace: Optional[float] = None,
    title: Union[str, Sequence[str], None] = None,
    show: Optional[bool] = None,
    save: Union[bool, str, None] = None,
    ax: Optional[Axes] = None,
    return_fig: Optional[bool] = None,
    marker: Union[str, Sequence[str]] = '.',
    **kwargs,
) -> Union[Figure, Axes, None]:
    r"""Scatter plot for user specified embedding basis (e.g. umap, pca, etc).

    This is a keyword-forwarding wrapper: every argument (including
    ``**kwargs``) is passed unchanged to ``_embedding``, whose result is
    returned as-is.

    Arguments:
        adata: Annotated data matrix.
        basis: Name of the `obsm` basis to use.
        color: Keys for annotations of observations/cells or variables/genes. (None)
        gene_symbols: Key for field in `.var` that stores gene symbols. (None)
        use_raw: Use `.raw` attribute of `adata` if present. (None)
        sort_order: For continuous annotations used as color parameter, plot data points with higher values on top of others. (True)
        edges: Show edges between cells. (False)
        edges_width: Width of edges. (0.1)
        edges_color: Color of edges. ('grey')
        neighbors_key: Key to use for neighbors. (None)
        arrows: Show arrows for velocity. (False)
        arrows_kwds: Keyword arguments for arrow plots. (None)
        groups: Groups to highlight. (None)
        components: Components to plot. (None)
        dimensions: Dimensions to plot. (None)
        layer: Name of the layer to use for coloring. (None)
        projection: Type of projection ('2d' or '3d'). ('2d')
        scale_factor: Scaling factor for sizes. (None)
        color_map: Colormap to use for continuous variables. (None)
        cmap: Colormap to use for continuous variables. (None)
        palette: Colors to use for categorical variables. (None)
        na_color: Color to use for NaN values. ('lightgray')
        na_in_legend: Include NaN values in legend. (True)
        size: Size of the dots. (None)
        frameon: Frame style: a bool, or the string 'small' (the default). ('small')
        legend_fontsize: Font size for legend. (None)
        legend_fontweight: Font weight for legend. ('bold')
        legend_loc: Location of legend. ('right margin')
        legend_fontoutline: Outline width for legend text. (None)
        colorbar_loc: Location of colorbar. ('right')
        vmax: Maximum value for colorbar. (None)
        vmin: Minimum value for colorbar. (None)
        vcenter: Center value for colorbar. (None)
        norm: Normalization for colorbar. (None)
        add_outline: Add outline to points. (False)
        outline_width: Width of outline. ((0.3, 0.05))
        outline_color: Color of outline. (('black', 'white'))
        ncols: Number of columns for subplots. (4)
        hspace: Height spacing between subplots. (0.25)
        wspace: Width spacing between subplots. (None)
        title: Title for the plot. (None)
        show: Show the plot. (None)
        save: Save the plot. (None)
        ax: Matplotlib axes object. (None)
        return_fig: Return figure object. (None)
        marker: Marker style. ('.')
        **kwargs: Additional keyword arguments, forwarded to ``_embedding``.

    Returns:
        ax: Whatever ``_embedding`` returns. Per the return annotation this is
        a Figure, an Axes (or a list of them when ``show=False``), or None.
    """
    return _embedding(
        adata=adata, basis=basis, color=color,
        gene_symbols=gene_symbols, use_raw=use_raw,
        sort_order=sort_order,
        edges=edges, edges_width=edges_width, edges_color=edges_color,
        neighbors_key=neighbors_key,
        arrows=arrows, arrows_kwds=arrows_kwds,
        groups=groups, components=components, dimensions=dimensions,
        layer=layer, projection=projection, scale_factor=scale_factor,
        color_map=color_map, cmap=cmap, palette=palette,
        na_color=na_color, na_in_legend=na_in_legend,
        size=size, frameon=frameon,
        legend_fontsize=legend_fontsize, legend_fontweight=legend_fontweight,
        legend_loc=legend_loc, legend_fontoutline=legend_fontoutline,
        colorbar_loc=colorbar_loc,
        vmax=vmax, vmin=vmin, vcenter=vcenter, norm=norm,
        add_outline=add_outline, outline_width=outline_width,
        outline_color=outline_color,
        ncols=ncols, hspace=hspace, wspace=wspace,
        title=title, show=show, save=save, ax=ax,
        return_fig=return_fig, marker=marker,
        **kwargs,
    )

omicverse.pl.embedding_celltype(adata, figsize=(6, 4), basis='umap', celltype_key='major_celltype', title=None, celltype_range=(2, 9), embedding_range=(3, 10), xlim=-1000)

Plot embedding with celltype color by omicverse.

Parameters:

Name Type Description Default
adata AnnData

AnnData object

required
figsize tuple

tuple, optional (default=(6,4)) Figure size

(6, 4)
basis str

str, optional (default='umap') Embedding method

'umap'
celltype_key str

str, optional (default='major_celltype') Celltype key in adata.obs

'major_celltype'
title str

str, optional (default=None) Figure title

None
celltype_range tuple

tuple, optional (default=(2,9)) Celltype range to plot

(2, 9)
embedding_range tuple

tuple, optional (default=(3,10)) Embedding range to plot

(3, 10)
xlim int

int, optional (default=-1000) X axis limit

-1000

Returns:

Name Type Description
fig Figure

The matplotlib figure containing both panels.

ax list

[ax1, ax2]: the embedding axes and the cell-count axes.

Source code in omicverse/pl/_single.py
def embedding_celltype(adata:AnnData,figsize:tuple=(6,4),basis:str='umap',
                            celltype_key:str='major_celltype',title:str=None,
                            celltype_range:tuple=(2,9),
                            embedding_range:tuple=(3,10),
                            xlim:int=-1000)->tuple:
    r"""
    Plot an embedding coloured by cell type next to a per-cell-type count panel.

    The figure uses a 10x10 grid. The embedding fills every row of the
    columns ``embedding_range``. A lollipop panel with one bar per cell type
    (length = number of cells) fills rows ``celltype_range`` of the first
    two columns.

    Arguments:
        adata: AnnData object. ``adata.obs[celltype_key]`` is converted to
            ``category`` dtype in place.
        figsize: tuple, optional (default=(6,4))
            Figure size. If None, (6,4) is used when there are fewer than 10
            cell types; otherwise a message is printed and None is returned.
        basis: str, optional (default='umap')
            Embedding basis passed to ``sc.pl.embedding``.
        celltype_key: str, optional (default='major_celltype')
            Celltype key in adata.obs
        title: str, optional (default=None)
            Text drawn above the count panel.
        celltype_range: tuple, optional (default=(2,9))
            (start, stop) grid rows occupied by the count panel.
        embedding_range: tuple, optional (default=(3,10))
            (start, stop) grid columns occupied by the embedding panel.
        xlim: int, optional (default=-1000)
            Left x-limit of the count panel (leaves room for the labels).

    Returns:
        fig: The matplotlib figure.
        ax: List ``[ax1, ax2]``: the embedding axes and the count axes.
    """
    adata.obs[celltype_key] = adata.obs[celltype_key].astype('category')
    categories = adata.obs[celltype_key].cat.categories

    # Per-type cell counts (sorted descending) in a column named
    # ``celltype_key``. ``to_frame(name=...)`` gives the same layout on every
    # pandas version, unlike the lexicographic ``pd.__version__`` comparison.
    cell_num_pd = adata.obs[celltype_key].value_counts().to_frame(name=celltype_key)

    colors_key = '{}_colors'.format(celltype_key)
    if colors_key in adata.uns:
        cell_color_dict = dict(zip(categories.tolist(), adata.uns[colors_key]))
    elif len(categories) > 28:
        cell_color_dict = dict(zip(categories, sc.pl.palettes.default_102))
    else:
        cell_color_dict = dict(zip(categories, sc.pl.palettes.zeileis_28))

    if figsize is None:
        if len(categories) >= 10:
            print('The number of cell types is too large, please set the figsize parameter')
            return
        figsize = (6, 4)
    fig = plt.figure(figsize=figsize)

    grid = plt.GridSpec(10, 10)
    # Embedding: all rows, columns embedding_range.
    ax1 = fig.add_subplot(grid[:, embedding_range[0]:embedding_range[1]])
    # Count panel: rows celltype_range, first two columns.
    ax2 = fig.add_subplot(grid[celltype_range[0]:celltype_range[1], :2])

    sc.pl.embedding(
        adata,
        basis=basis,
        color=[celltype_key],
        title='',
        frameon=False,
        ncols=3,
        ax=ax1,
        legend_loc=False,
        show=False
    )

    for idx, cell in enumerate(categories):
        count = cell_num_pd.loc[cell, celltype_key]
        cell_color = cell_color_dict[cell]
        # Dot at x=100 marks the bar start; the bar extends to the cell count.
        ax2.scatter(100, idx, c=cell_color, s=50)
        ax2.plot((100, count), (idx, idx), c=cell_color, lw=4)
        # f-string so non-string category labels work too.
        ax2.text(100, idx + 0.2, f"{cell}({count:,})", fontsize=11)

    # The right limit is the second-largest count, kept from the original
    # layout, so the largest bar runs to the panel edge. With only one cell
    # type there is no second row, so fall back to the largest count.
    counts = cell_num_pd[celltype_key]
    ax2.set_xlim(xlim, counts.iloc[min(1, len(counts) - 1)])
    ax2.text(xlim, idx + 1, title, fontsize=12)
    ax2.grid(False)
    for side in ('top', 'right', 'bottom', 'left'):
        ax2.spines[side].set_visible(False)
    ax2.axis('off')

    # Make sure neither panel carries a legend.
    for panel in (ax1, ax2):
        panel_legend = panel.get_legend()
        if panel_legend is not None:
            panel_legend.remove()

    return fig, [ax1, ax2]

omicverse.pl.embedding_adjust(adata, groupby, exclude=(), basis='X_umap', ax=None, adjust_kwargs=None, text_kwargs=None)

Get locations of cluster median and adjust text labels accordingly.

Borrowed from scanpy github forum.

Parameters:

Name Type Description Default
adata

AnnData object

required
groupby

str Key in adata.obs for grouping

required
exclude

tuple, optional (default=()) Groups to exclude from labeling

()
basis

str, optional (default='X_umap') Embedding basis key in adata.obsm

'X_umap'
ax

matplotlib.axes.Axes, optional (default=None) Axes object to plot on

None
adjust_kwargs

dict, optional (default=None) Arguments for adjust_text function

None
text_kwargs

dict, optional (default=None) Arguments for text annotation

None

Returns:

Name Type Description
medians

dict Dictionary of median positions for each group

Source code in omicverse/pl/_single.py
def embedding_adjust(
    adata, groupby, exclude=(), 
    basis='X_umap',ax=None, adjust_kwargs=None, text_kwargs=None
):
    r"""
    Label each group at its median embedding position and de-overlap the labels.

    Borrowed from scanpy github forum.

    Arguments:
        adata: AnnData object
        groupby: str
            Key in adata.obs for grouping
        exclude: tuple, optional (default=())
            Groups to exclude from labeling
        basis: str, optional (default='X_umap')
            Embedding basis key in adata.obsm
        ax: matplotlib.axes.Axes, optional (default=None)
            Axes object to plot on; the current axes (``plt.gca()``) if None.
        adjust_kwargs: dict, optional (default=None)
            Arguments for adjust_text function. Defaults to
            ``{"text_from_points": False}``. ``ax`` is filled in automatically
            unless given here.
        text_kwargs: dict, optional (default=None)
            Arguments for text annotation

    Returns:
        medians: dict
            Dictionary mapping each labelled group to the median of its cells'
            coordinates in ``adata.obsm[basis]`` (all components). The labels
            are placed at the first two components.
    """
    from adjustText import adjust_text

    if adjust_kwargs is None:
        adjust_kwargs = {"text_from_points": False}
    if text_kwargs is None:
        text_kwargs = {}
    if ax is None:
        ax = plt.gca()
    # Point adjust_text at the axes we drew on, without mutating the caller's dict.
    adjust_kwargs = {"ax": ax, **adjust_kwargs}

    coords = np.asarray(adata.obsm[basis])
    # ``.indices`` gives positional row numbers per group, so the coordinate
    # array is sliced directly instead of building an AnnData view per group.
    # ``observed=True`` skips empty categories, whose median would be NaN.
    group_rows = adata.obs.groupby(groupby, observed=True).indices

    medians = {}
    for g, rows in group_rows.items():
        if g in exclude:
            continue
        medians[g] = np.median(coords[rows], axis=0)

    texts = [
        ax.text(x=med[0], y=med[1], s=k, **text_kwargs) for k, med in medians.items()
    ]
    if texts:
        adjust_text(texts, **adjust_kwargs)
    return medians

omicverse.pl.embedding_atlas(adata, basis, color, title=None, figsize=(4, 4), ax=None, cmap='RdBu', legend_loc='right margin', frameon='small', fontsize=12)

Create high-resolution embedding plots using Datashader for large datasets.

Uses Datashader to render embeddings at high resolution, suitable for datasets with millions of cells where standard scatter plots become ineffective.

Parameters:

Name Type Description Default
adata

Annotated data object with embedding coordinates

required
basis

Key in adata.obsm containing embedding coordinates (e.g., 'X_umap')

required
color

Gene name or obs column to color cells by

required
title

Plot title (None, uses color name)

None
figsize

Figure dimensions as (width, height) ((4,4))

(4, 4)
ax

Existing matplotlib axes object (None)

None
cmap

Colormap for continuous values ('RdBu')

'RdBu'
legend_loc

Legend position ('right margin')

'right margin'
frameon

Frame style - False, 'small', or True ('small')

'small'
fontsize

Font size for labels and title (12)

12

Returns:

Name Type Description
ax

matplotlib.axes.Axes object with rendered embedding

Source code in omicverse/pl/_embedding.py
def embedding_atlas(adata,basis,color,
                    title=None,figsize=(4,4),ax=None,cmap='RdBu',
                    legend_loc = 'right margin',frameon='small',
                    fontsize=12):
    r"""
    Create high-resolution embedding plots using Datashader for large datasets.

    Uses Datashader to render embeddings at high resolution, suitable for datasets
    with millions of cells where standard scatter plots become ineffective.
    String labels are drawn categorically (``ds.count_cat``). Numeric labels
    are drawn as the per-pixel mean (``ds.mean``). Only the first two
    embedding components are used.

    Arguments:
        adata: Annotated data object with embedding coordinates
        basis: Key in adata.obsm containing embedding coordinates (e.g., 'X_umap')
        color: Gene name or obs column to color cells by. Lookup order:
            ``adata.obs`` columns, ``adata.var_names``, ``adata.raw.var_names``.
        title: Plot title (None, uses color name)
        figsize: Figure dimensions as (width, height) ((4,4))
        ax: Existing matplotlib axes object (None)
        cmap: Colormap for continuous values; bokeh palette names are resolved
            through ``bokeh.palettes.all_palettes`` ('RdBu')
        legend_loc: Legend position; only 'right margin' draws a legend ('right margin')
        frameon: Frame style - False, 'small', or True ('small')
        fontsize: Font size for labels and title (12)

    Returns:
        ax: matplotlib.axes.Axes object with rendered embedding, or None if
        the label type is not recognised.

    Raises:
        KeyError: if ``color`` is not found in obs, var_names or raw.var_names.
        ValueError: if ``adata`` has no observations.
    """
    import scanpy as sc
    import pandas as pd
    import datashader as ds
    import datashader.transfer_functions as tf
    from scipy.sparse import issparse
    from bokeh.palettes import all_palettes

    # Raster size in pixels; the 'small' frame spine bounds below use these
    # pixel coordinates.
    canvas_size = 800
    cvs = ds.Canvas(plot_width=canvas_size, plot_height=canvas_size)

    # Only the first two components are rendered.
    embedding = np.asarray(adata.obsm[basis])[:, :2]
    df = pd.DataFrame(embedding, columns=['x', 'y'])

    def _flatten(X):
        # np.asarray also turns a dense np.matrix into a flat 1-D array.
        return X.toarray().reshape(-1) if issparse(X) else np.asarray(X).reshape(-1)

    if color in adata.obs.columns:
        labels = adata.obs[color].tolist()
    elif color in adata.var_names:
        labels = _flatten(adata[:, color].X)
    elif adata.raw is not None and color in adata.raw.var_names:
        labels = _flatten(adata.raw[:, color].X)
    else:
        raise KeyError(
            f"{color!r} was not found in adata.obs, adata.var_names or adata.raw.var_names"
        )
    if len(labels) == 0:
        raise ValueError('adata contains no observations to plot')

    df['label'] = labels
    first = labels[0]
    is_text = isinstance(first, str)
    # bool is a subclass of int but is not treated as numeric here.
    is_number = (isinstance(first, (int, float, np.integer, np.floating))
                 and not isinstance(first, (bool, np.bool_)))

    if is_text:
        df['label'] = df['label'].astype('category')
        categories = df['label'].cat.categories
        agg = cvs.points(df, 'x', 'y', ds.count_cat('label'))
        legend_tag = True

        obs_col = adata.obs[color]
        uns_key = f'{color}_colors'
        if isinstance(obs_col.dtype, pd.CategoricalDtype) and uns_key in adata.uns:
            color_key = dict(zip(obs_col.cat.categories, adata.uns[uns_key]))
        else:
            # No stored palette for this column: use scanpy's default palettes.
            palette = (sc.pl.palettes.default_102 if len(categories) > 28
                       else sc.pl.palettes.zeileis_28)
            color_key = dict(zip(categories, palette))

        img = tf.shade(tf.spread(agg, px=0),
                       color_key=[color_key[c] for c in categories],
                       how='eq_hist')
    elif is_number:
        agg = cvs.points(df, 'x', 'y', ds.mean('label'))
        legend_tag = False
        if cmap in all_palettes:
            # Use the last-listed size variant of this bokeh palette.
            num = list(all_palettes[cmap].keys())[-1]
            img = tf.shade(agg, cmap=all_palettes[cmap][num])
        else:
            img = tf.shade(agg, cmap=cmap)
    else:
        print('Unrecognized label type')
        return None

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    ax.imshow(img.to_pil(), aspect='auto')

    # Show cursor coordinates with two decimals.
    def format_coord(x, y):
        return f"x={x:.2f}, y={y:.2f}"

    ax.format_coord = format_coord

    if legend_tag:
        # Empty scatters act as legend handles, one per category.
        for label, label_color in color_key.items():
            ax.scatter([], [], c=label_color, label=label)

        n_labels = len(color_key)
        if legend_loc == "right margin":
            ax.legend(
                frameon=False,
                loc="center left",
                bbox_to_anchor=(1, 0.5),
                ncol=(1 if n_labels <= 14 else 2 if n_labels <= 30 else 3),
                fontsize=fontsize-1,
            )

    if frameon == False:  # noqa: E712 -- mirrors the original comparison (0 also counts)
        ax.axis('off')
    elif frameon == 'small':
        ax.axis('on')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.spines['left'].set_visible(True)
        ax.spines['bottom'].set_visible(True)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        # Short axis stubs in pixel units. The image y axis points down, so
        # the left stub sits at the bottom rows of the canvas.
        stub = 150
        ax.spines['bottom'].set_bounds(0, stub)
        ax.spines['left'].set_bounds(canvas_size - stub, canvas_size)
        ax.set_xlabel(f'{basis}1', loc='left', fontsize=fontsize)
        ax.set_ylabel(f'{basis}2', loc='bottom', fontsize=fontsize)
    else:
        ax.axis('on')
        ax.set_xticks([])
        ax.set_yticks([])
        for side in ('left', 'bottom', 'top', 'right'):
            ax.spines[side].set_visible(True)
        ax.set_xlabel(f'{basis}1', loc='center', fontsize=fontsize)
        ax.set_ylabel(f'{basis}2', loc='center', fontsize=fontsize)

    line_width = 1.2
    ax.spines['left'].set_linewidth(line_width)
    ax.spines['bottom'].set_linewidth(line_width)

    if title is None:
        title = color
    ax.set_title(title, fontsize=fontsize+1)

    return ax

omicverse.pl.embedding_multi(data, basis, color=None, use_raw=None, layer=None, **kwargs)

Create embedding scatter plots for multi-modal data (MuData) or single-cell data.

Produces scatter plots on specified embeddings, supporting cross-modality feature visualization. For modality-specific embeddings, use format 'modality:embedding' (e.g., 'rna:X_pca').

Parameters:

Name Type Description Default
data AnnData

AnnData or MuData object containing embedding and feature data

required
basis str

Name of embedding in obsm (e.g., 'X_umap') or modality-specific ('rna:X_pca')

required
color Optional[Union[str, Sequence[str]]]

Gene names or obs columns to color points by (None)

None
use_raw Optional[bool]

Whether to use .raw attribute for features (None, auto-determined)

None
layer Optional[str]

Specific data layer to use for coloring (None)

None
**kwargs

Additional arguments passed to embedding plotting function

{}

Returns:

Type Description

Plot axes or figure depending on underlying plotting function

Source code in omicverse/pl/_multi.py
def embedding_multi(
    data: AnnData,
    basis: str,
    color: Optional[Union[str, Sequence[str]]] = None,
    use_raw: Optional[bool] = None,
    layer: Optional[str] = None,
    **kwargs,
):
    r"""
    Create embedding scatter plots for multi-modal data (MuData) or single-cell data.

    Produces scatter plots on specified embeddings, supporting cross-modality feature visualization.
    For modality-specific embeddings, use format 'modality:embedding' (e.g., 'rna:X_pca').

    For MuData input, each ``color`` key that is not an obs column is looked up
    in every modality's ``.var_names``. Unless ``use_raw=False``, ``.raw.var_names``
    is searched too. A ``'<modality>:'`` prefix on the key is also accepted.
    The matching values are copied into a temporary ``.obs`` table and plotted
    from there.

    Arguments:
        data: AnnData or MuData object containing embedding and feature data
        basis: Name of embedding in obsm (e.g., 'X_umap') or modality-specific ('rna:X_pca')
        color: Gene names or obs columns to color points by (None)
        use_raw: Whether to use .raw attribute for features (None, auto-determined)
        layer: Specific data layer to use for coloring. May also be a dict
            mapping modality name to layer name. (None)
        **kwargs: Additional arguments passed to embedding plotting function

    Returns:
        Plot axes or figure depending on underlying plotting function

    Raises:
        ValueError: if the basis or modality is missing, or if a requested
            feature is absent from ``.raw`` and cannot be taken from ``.X``.
        TypeError: if ``color`` is neither a string nor an iterable.
    """
    if isinstance(data, AnnData):
        return embedding(
            data, basis=basis, color=color, use_raw=use_raw, layer=layer, **kwargs
        )

    # `data` is MuData
    if basis not in data.obsm and "X_" + basis in data.obsm:
        basis = "X_" + basis

    if basis in data.obsm:
        adata = data
        basis_mod = basis
    else:
        # basis is not a joint embedding
        try:
            mod, basis_mod = basis.split(":")
        except ValueError:
            raise ValueError(f"Basis {basis} is not present in the MuData object (.obsm)")

        if mod not in data.mod:
            raise ValueError(
                f"Modality {mod} is not present in the MuData object with modalities {', '.join(data.mod)}"
            )

        adata = data.mod[mod]
        if basis_mod not in adata.obsm:
            if "X_" + basis_mod in adata.obsm:
                basis_mod = "X_" + basis_mod
            elif len(adata.obsm) > 0:
                raise ValueError(
                    f"Basis {basis_mod} is not present in the modality {mod} with embeddings {', '.join(adata.obsm)}"
                )
            else:
                raise ValueError(
                    f"Basis {basis_mod} is not present in the modality {mod} with no embeddings"
                )

    obs = data.obs.loc[adata.obs.index.values]

    if color is None:
        ad = AnnData(obs=obs, obsm=adata.obsm, obsp=adata.obsp)
        return sc.pl.embedding(ad, basis=basis_mod, **kwargs)

    # Some `color` has been provided
    if isinstance(color, str):
        keys = color = [color]
    elif isinstance(color, Iterable):
        keys = color
    else:
        raise TypeError("Expected color to be a string or an iterable.")

    # Fetch respective features
    if not all(key in obs for key in keys):
        # {'rna': [True, False], 'prot': [False, True]}
        keys_in_mod = {m: [key in data.mod[m].var_names for key in keys] for m in data.mod}

        # .raw slots might have exclusive var_names
        if use_raw is None or use_raw:
            for i, k in enumerate(keys):
                for m in data.mod:
                    if keys_in_mod[m][i] == False and data.mod[m].raw is not None:
                        keys_in_mod[m][i] = k in data.mod[m].raw.var_names

        # e.g. color="rna:CD8A" - especially relevant for mdata.axis == -1
        mod_key_modifier: dict[str, str] = dict()
        for i, k in enumerate(keys):
            mod_key_modifier[k] = k
            for m in data.mod:
                if not keys_in_mod[m][i]:
                    k_clean = k
                    if k.startswith(f"{m}:"):
                        k_clean = k.split(":", 1)[1]

                    keys_in_mod[m][i] = k_clean in data.mod[m].var_names
                    if keys_in_mod[m][i]:
                        mod_key_modifier[k] = k_clean
                    if use_raw is None or use_raw:
                        if keys_in_mod[m][i] == False and data.mod[m].raw is not None:
                            keys_in_mod[m][i] = k_clean in data.mod[m].raw.var_names

        for m in data.mod:
            if np.sum(keys_in_mod[m]) > 0:
                mod_keys = np.array(keys)[keys_in_mod[m]]
                mod_keys = np.array([mod_key_modifier[k] for k in mod_keys])

                raw = data.mod[m].raw
                if (use_raw is None or use_raw) and raw is not None:
                    keysidx = raw.var.index.get_indexer_for(mod_keys)
                    # get_indexer_for returns -1 for labels missing from .raw;
                    # indexing raw.X with -1 would silently pick the last feature.
                    missing = mod_keys[keysidx < 0]
                    if len(missing) == 0:
                        fmod_adata = AnnData(
                            X=raw.X[:, keysidx],
                            var=pd.DataFrame(index=mod_keys),
                            obs=data.mod[m].obs,
                        )
                    elif all(k in data.mod[m].var_names for k in mod_keys):
                        warnings.warn(
                            f"Features {', '.join(missing.tolist())} are not present in .raw "
                            f"of the modality {m}, using .X instead"
                        )
                        fmod_adata = data.mod[m][:, mod_keys]
                    else:
                        raise ValueError(
                            f"Features {', '.join(missing.tolist())} are not present in .raw "
                            f"of the modality {m}, and not all requested features are in .X"
                        )
                else:
                    if use_raw:
                        warnings.warn(
                            f"Attribute .raw is None for the modality {m}, using .X instead"
                        )
                    fmod_adata = data.mod[m][:, mod_keys]

                if layer is not None:
                    if isinstance(layer, Dict):
                        m_layer = layer.get(m, None)
                        if m_layer is not None:
                            x = data.mod[m][:, mod_keys].layers[m_layer]
                            # toarray() keeps a plain ndarray (todense() yields np.matrix).
                            fmod_adata.X = x.toarray() if issparse(x) else x
                            if use_raw:
                                warnings.warn(f"Layer='{layer}' superseded use_raw={use_raw}")
                    elif layer in data.mod[m].layers:
                        x = data.mod[m][:, mod_keys].layers[layer]
                        fmod_adata.X = x.toarray() if issparse(x) else x
                        if use_raw:
                            warnings.warn(f"Layer='{layer}' superseded use_raw={use_raw}")
                    else:
                        warnings.warn(
                            f"Layer {layer} is not present for the modality {m}, using count matrix instead"
                        )
                x = fmod_adata.X.toarray() if issparse(fmod_adata.X) else fmod_adata.X
                # Append the feature values as obs columns so they can be used as colors.
                obs = obs.join(
                    pd.DataFrame(x, columns=mod_keys, index=fmod_adata.obs_names),
                    how="left",
                )

        color = [mod_key_modifier[k] for k in keys]

    ad = AnnData(obs=obs, obsm=adata.obsm, obsp=adata.obsp, uns=adata.uns)
    retval = embedding(ad, basis=basis_mod, color=color, **kwargs)
    # Copy any palettes generated on the temporary object back to `adata`.
    for key, col in zip(keys, color):
        try:
            adata.uns[f"{key}_colors"] = ad.uns[f"{col}_colors"]
        except KeyError:
            pass
    return retval

Cell Proportion and Analysis

omicverse.pl.cellproportion(adata, celltype_clusters, groupby, groupby_li=None, figsize=(4, 6), ticks_fontsize=12, labels_fontsize=12, ax=None, legend=False, legend_awargs={'ncol': 1}, transpose=False)

Plot cell proportion of each cell type in each visual cluster.

Parameters:

Name Type Description Default
adata AnnData

AnnData object.

required
celltype_clusters str

Cell type clusters.

required
groupby str

Visual clusters.

required
groupby_li

Visual cluster list. (None)

None
figsize tuple

Figure size. ((4,6))

(4, 6)
ticks_fontsize int

Ticks fontsize. (12)

12
labels_fontsize int

Labels fontsize. (12)

12
ax

Matplotlib axes object. (None)

None
legend bool

Whether to show legend. (False)

False
legend_awargs

Legend arguments. ({'ncol':1})

{'ncol': 1}
transpose bool

Whether to transpose the plot (horizontal bars). (False)

False

Returns:

Type Description

None

Source code in omicverse/pl/_single.py
def cellproportion(adata:AnnData,celltype_clusters:str,groupby:str,
                       groupby_li=None,figsize:tuple=(4,6),
                       ticks_fontsize:int=12,labels_fontsize:int=12,ax=None,
                       legend:bool=False,legend_awargs={'ncol':1},transpose:bool=False):
    r"""Plot the proportion of each cell type within each group as stacked bars.

    Only ``adata.obs`` and ``adata.uns`` are read.

    Arguments:
        adata: AnnData object.
        celltype_clusters: Key in ``adata.obs`` holding the cell-type labels.
            The column must be categorical, and
            ``adata.uns['{celltype_clusters}_colors']`` provides one colour per
            category (a cell type without a colour uses matplotlib's default).
        groupby: Key in ``adata.obs`` defining the groups (one bar per group).
        groupby_li: Groups to plot, in display order. (None -> all categories
            of ``groupby``; in that case the column is converted to
            ``category`` in place.)
        figsize: Figure size, used only when ``ax`` is None. ((4,6))
        ticks_fontsize: Ticks fontsize. (12)
        labels_fontsize: Labels fontsize. (12)
        ax: Matplotlib axes object to draw on. (None -> a new figure is created)
        legend: Whether to show legend. (False)
        legend_awargs: Extra keyword arguments forwarded to ``ax.legend``
            (never modified). ({'ncol':1})
        transpose: Whether to draw horizontal bars instead of vertical ones. (False)

    Returns:
        ``(fig, ax)`` when this call created the figure (``ax`` was None),
        otherwise None.
    """
    if groupby_li is None:
        # Side effect kept for compatibility: convert the groupby column to
        # ``category`` in place so its categories define the bar order.
        adata.obs[groupby] = adata.obs[groupby].astype('category')
        groupby_li = adata.obs[groupby].cat.categories

    all_celltype = adata.obs[celltype_clusters].cat.categories
    celltype_colors = dict(zip(all_celltype,
                               adata.uns['{}_colors'.format(celltype_clusters)]))

    def _group_label(group):
        # Legacy behaviour: strip a 'Retinoblastoma_' prefix from group names.
        # str() lets non-string (e.g. integer) categories work as well.
        return str(group).replace('Retinoblastoma_', '')

    # One row per group, one column per cell type. Each row holds the fraction
    # of that group's cells falling in each cell type. A group with no cells
    # gets NaN (0/0), so no bar is drawn for it.
    rows = []
    for group in groupby_li:
        counts = (adata.obs.loc[adata.obs[groupby] == group, celltype_clusters]
                  .value_counts()
                  .reindex(all_celltype, fill_value=0))
        rows.append(counts / counts.sum())
    proportions = pd.DataFrame(rows, columns=all_celltype).astype(float)
    proportions.index = [_group_label(group) for group in groupby_li]

    # Start of each stacked segment = running total of the segments before it.
    offsets = proportions.cumsum(axis=1) - proportions

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    positions = list(proportions.index)
    for cell_type in all_celltype:
        sizes = proportions[cell_type].to_numpy(dtype=float)
        starts = offsets[cell_type].to_numpy(dtype=float)
        color = celltype_colors.get(cell_type)
        if transpose:
            ax.barh(y=positions, width=sizes, left=starts, height=0.8,
                    color=color, label=cell_type)
        else:
            ax.bar(x=positions, height=sizes, bottom=starts, width=0.8,
                   color=color, label=cell_type)

    if legend:
        ax.legend(bbox_to_anchor=(1.05, -0.05), loc=3, borderaxespad=0,
                  fontsize=10, **legend_awargs)

    # All styling targets ``ax`` explicitly. pyplot-level calls would hit the
    # *current* axes, which need not be the one passed in.
    ax.grid(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(True)
    # Detach the left and bottom axis lines from the plotting area by 10 points.
    ax.spines['left'].set_position(('outward', 10))
    ax.spines['bottom'].set_position(('outward', 10))

    if transpose:
        ax.tick_params(axis='y', labelsize=ticks_fontsize, labelrotation=0)
        ax.tick_params(axis='x', labelsize=ticks_fontsize)
        ax.set_ylabel(groupby, fontsize=labels_fontsize)
        ax.set_xlabel('Cells per Stage', fontsize=labels_fontsize)
    else:
        ax.tick_params(axis='x', labelsize=ticks_fontsize, labelrotation=90)
        ax.tick_params(axis='y', labelsize=ticks_fontsize)
        ax.set_xlabel(groupby, fontsize=labels_fontsize)
        ax.set_ylabel('Cells per Stage', fontsize=labels_fontsize)

    if fig is not None:
        return fig, ax
    return None

Bulk Data Visualization

omicverse.pl.volcano(result, pval_name='qvalue', fc_name='log2FC', pval_max=None, FC_max=None, figsize=(4, 4), title='', titlefont={'weight': 'normal', 'size': 14}, up_color='#e25d5d', down_color='#7388c1', normal_color='#d7d7d7', up_fontcolor='#e25d5d', down_fontcolor='#7388c1', normal_fontcolor='#d7d7d7', legend_bbox=(0.8, -0.2), legend_ncol=2, legend_fontsize=12, plot_genes=None, plot_genes_num=10, plot_genes_fontsize=10, ticks_fontsize=12, pval_threshold=0.05, fc_max=1.5, fc_min=-1.5, ax=None)

Create a volcano plot for differential expression analysis.

Parameters:

Name Type Description Default
result

pandas.DataFrame DataFrame containing differential expression results; must include a 'sig' column whose values are 'up', 'down' or 'normal' (the index is used for gene labels).

required
pval_name

str, optional (default='qvalue') Column name for p-values/q-values

'qvalue'
fc_name

str, optional (default='log2FC') Column name for fold change values

'log2FC'
pval_max

float, optional (default=None) Maximum p-value for y-axis scaling

None
FC_max

float, optional (default=None) Maximum fold change for x-axis scaling

None
figsize tuple

tuple, optional (default=(4,4)) Figure size (width, height)

(4, 4)
title str

str, optional (default='') Plot title

''
titlefont dict

dict, optional (default={'weight':'normal','size':14}) Font settings for title and axis labels

{'weight': 'normal', 'size': 14}
up_color str

str, optional (default='#e25d5d') Color for upregulated genes

'#e25d5d'
down_color str

str, optional (default='#7388c1') Color for downregulated genes

'#7388c1'
normal_color str

str, optional (default='#d7d7d7') Color for non-significant genes

'#d7d7d7'
up_fontcolor str

str, optional (default='#e25d5d') Font color for upregulated gene labels

'#e25d5d'
down_fontcolor str

str, optional (default='#7388c1') Font color for downregulated gene labels

'#7388c1'
normal_fontcolor str

str, optional (default='#d7d7d7') Font color for normal gene labels

'#d7d7d7'
legend_bbox tuple

tuple, optional (default=(0.8, -0.2)) Legend bounding box position

(0.8, -0.2)
legend_ncol int

int, optional (default=2) Number of legend columns

2
legend_fontsize int

int, optional (default=12) Legend font size

12
plot_genes list

list, optional (default=None) Specific genes to label on plot

None
plot_genes_num int

int, optional (default=10) Number of top genes to label automatically

10
plot_genes_fontsize int

int, optional (default=10) Font size for gene labels

10
ticks_fontsize int

int, optional (default=12) Font size for axis ticks

12
pval_threshold float

float, optional (default=0.05) P-value at which the horizontal dashed guide line is drawn (point colouring comes from the 'sig' column)

0.05
fc_max float

float, optional (default=1.5) Upper fold-change value at which a vertical dashed guide line is drawn

1.5
fc_min float

float, optional (default=-1.5) Lower fold-change value at which a vertical dashed guide line is drawn

-1.5
ax

matplotlib.axes, optional (default=None) Existing axes to plot on

None

Returns:

Name Type Description
ax

matplotlib.axes The plot axes object

Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_bulk.py
def volcano(result,pval_name='qvalue',fc_name='log2FC',pval_max=None,FC_max=None,
            figsize:tuple=(4,4),title:str='',titlefont:dict={'weight':'normal','size':14,},
                     up_color:str='#e25d5d',down_color:str='#7388c1',normal_color:str='#d7d7d7',
                     up_fontcolor:str='#e25d5d',down_fontcolor:str='#7388c1',normal_fontcolor:str='#d7d7d7',
                     legend_bbox:tuple=(0.8, -0.2),legend_ncol:int=2,legend_fontsize:int=12,
                     plot_genes:list=None,plot_genes_num:int=10,plot_genes_fontsize:int=10,
                     ticks_fontsize:int=12,pval_threshold:float=0.05,fc_max:float=1.5,fc_min:float=-1.5,
                     ax = None,):
    r"""
    Create a volcano plot for differential expression analysis.

    Before plotting, a colour-coded summary of ``result`` and a list of
    parameter suggestions are printed to stdout. ``result`` itself is not
    modified (a copy is plotted).

    Arguments:
        result: pandas.DataFrame
            Differential expression results. Must contain ``pval_name``,
            ``fc_name`` and a ``'sig'`` column whose values are ``'up'``,
            ``'down'`` or ``'normal'``. The index is used for gene labels.
        pval_name: str, optional (default='qvalue')
            Column name for p-values/q-values
        fc_name: str, optional (default='log2FC')
            Column name for fold change values
        pval_max: float, optional (default=None)
            Cap applied to -log10(p) values before plotting
        FC_max: float, optional (default=None)
            Fold changes are clipped to [-FC_max, FC_max] before plotting
        figsize: tuple, optional (default=(4,4))
            Figure size (width, height); used only when ``ax`` is None
        title: str, optional (default='')
            Plot title
        titlefont: dict, optional (default={'weight':'normal','size':14})
            Font settings for title and axis labels
        up_color: str, optional (default='#e25d5d')
            Color for upregulated genes
        down_color: str, optional (default='#7388c1')
            Color for downregulated genes
        normal_color: str, optional (default='#d7d7d7')
            Color for non-significant genes
        up_fontcolor: str, optional (default='#e25d5d')
            Font color for upregulated gene labels
        down_fontcolor: str, optional (default='#7388c1')
            Font color for downregulated gene labels
        normal_fontcolor: str, optional (default='#d7d7d7')
            Font color for normal gene labels
        legend_bbox: tuple, optional (default=(0.8, -0.2))
            Legend bounding box position
        legend_ncol: int, optional (default=2)
            Number of legend columns
        legend_fontsize: int, optional (default=12)
            Legend font size
        plot_genes: list, optional (default=None)
            Specific genes (index labels of ``result``) to label on plot
        plot_genes_num: int, optional (default=10)
            Number of top genes to label automatically: half from 'up' and
            half from 'down', ranked by ``pval_name``. ``None`` disables
            automatic labelling.
        plot_genes_fontsize: int, optional (default=10)
            Font size for gene labels
        ticks_fontsize: int, optional (default=12)
            Font size for x-axis ticks
        pval_threshold: float, optional (default=0.05)
            Position of the horizontal dashed guide line (at -log10 of it)
        fc_max: float, optional (default=1.5)
            Position of the right vertical dashed guide line
        fc_min: float, optional (default=-1.5)
            Position of the left vertical dashed guide line
        ax: matplotlib.axes, optional (default=None)
            Existing axes to plot on

    Note:
        ``pval_threshold``, ``fc_max`` and ``fc_min`` only draw guide lines.
        Point colours come solely from the ``'sig'`` column.
        ``adjustText`` is imported only when gene labels are drawn.

    Returns:
        ax: matplotlib.axes
            The plot axes object

    Raises:
        ValueError: if any of ``pval_name``, ``fc_name`` or ``'sig'`` is
            missing from ``result``.
    """

    # Color codes for terminal output
    class Colors:
        HEADER = '\033[95m'
        BLUE = '\033[94m'
        CYAN = '\033[96m'
        GREEN = '\033[92m'
        WARNING = '\033[93m'
        FAIL = '\033[91m'
        ENDC = '\033[0m'
        BOLD = '\033[1m'
        UNDERLINE = '\033[4m'

    # ---- Summary of the input data -------------------------------------
    print(f"{Colors.HEADER}{Colors.BOLD}🌋 Volcano Plot Analysis:{Colors.ENDC}")
    print(f"   {Colors.CYAN}Total genes: {Colors.BOLD}{len(result)}{Colors.ENDC}")

    required_cols = [pval_name, fc_name, 'sig']
    missing_cols = [col for col in required_cols if col not in result.columns]
    if missing_cols:
        print(f"   {Colors.FAIL}❌ Missing required columns: {Colors.BOLD}{missing_cols}{Colors.ENDC}")
        raise ValueError(f"Missing required columns: {missing_cols}")

    sig_counts = result['sig'].value_counts()
    total_sig = sig_counts.get('up', 0) + sig_counts.get('down', 0)

    print(f"   {Colors.GREEN}↗️  Upregulated genes: {Colors.BOLD}{sig_counts.get('up', 0)}{Colors.ENDC}")
    print(f"   {Colors.BLUE}↘️  Downregulated genes: {Colors.BOLD}{sig_counts.get('down', 0)}{Colors.ENDC}")
    print(f"   {Colors.CYAN}➡️  Non-significant genes: {Colors.BOLD}{sig_counts.get('normal', 0)}{Colors.ENDC}")
    print(f"   {Colors.WARNING}🎯 Total significant genes: {Colors.BOLD}{total_sig}{Colors.ENDC}")

    fc_range = result[fc_name].max() - result[fc_name].min()
    print(f"   {Colors.BLUE}{fc_name} range: {Colors.BOLD}{result[fc_name].min():.2f} to {result[fc_name].max():.2f}{Colors.ENDC}")
    print(f"   {Colors.BLUE}{pval_name} range: {Colors.BOLD}{result[pval_name].min():.2e} to {result[pval_name].max():.2e}{Colors.ENDC}")

    # ---- Current function parameters -----------------------------------
    print(f"\n{Colors.HEADER}{Colors.BOLD}⚙️  Current Function Parameters:{Colors.ENDC}")
    print(f"   {Colors.BLUE}Data columns: pval_name='{pval_name}', fc_name='{fc_name}'{Colors.ENDC}")
    print(f"   {Colors.BLUE}Thresholds: pval_threshold={Colors.BOLD}{pval_threshold}{Colors.ENDC}{Colors.BLUE}, fc_max={Colors.BOLD}{fc_max}{Colors.ENDC}{Colors.BLUE}, fc_min={Colors.BOLD}{fc_min}{Colors.ENDC}")
    print(f"   {Colors.BLUE}Plot size: figsize={Colors.BOLD}{figsize}{Colors.ENDC}")
    print(f"   {Colors.BLUE}Gene labels: plot_genes_num={Colors.BOLD}{plot_genes_num}{Colors.ENDC}{Colors.BLUE}, plot_genes_fontsize={Colors.BOLD}{plot_genes_fontsize}{Colors.ENDC}")
    if plot_genes is not None:
        print(f"   {Colors.BLUE}Custom genes: {Colors.BOLD}{len(plot_genes)} specified{Colors.ENDC}")
    else:
        print(f"   {Colors.BLUE}Custom genes: {Colors.BOLD}None{Colors.ENDC}{Colors.BLUE} (auto-select top genes){Colors.ENDC}")

    # ---- Parameter optimization suggestions ----------------------------
    print(f"\n{Colors.HEADER}{Colors.BOLD}💡 Parameter Optimization Suggestions:{Colors.ENDC}")

    # plot_genes_num=None (labelling disabled) counts as 0 labels, so the
    # numeric comparisons below do not raise TypeError.
    n_labels = plot_genes_num if plot_genes_num is not None else 0
    few_sig = total_sig < 10
    wide_fc = fc_range > 10 and (fc_max <= 2 or abs(fc_min) <= 2)
    crowded = n_labels > 20 and figsize[0] < 6
    small_font = plot_genes_fontsize < 8 and n_labels > 15
    unbalanced = abs(figsize[0] - figsize[1]) > 2
    optimal_size = max(figsize)
    if wide_fc:
        new_fc_max = min(round(result[fc_name].quantile(0.95), 1), 4.0)
        new_fc_min = max(round(result[fc_name].quantile(0.05), 1), -4.0)

    suggestions = []
    if few_sig:
        suggestions.append(f"   {Colors.WARNING}▶ Few significant genes detected ({total_sig}):{Colors.ENDC}")
        suggestions.append(f"     {Colors.CYAN}Current: pval_threshold={Colors.BOLD}{pval_threshold}{Colors.ENDC}")
        suggestions.append(f"     {Colors.GREEN}Suggested: {Colors.BOLD}pval_threshold=0.1{Colors.ENDC} or {Colors.BOLD}pval_threshold=0.2{Colors.ENDC}")
    if wide_fc:
        suggestions.append(f"   {Colors.WARNING}▶ Wide fold change range detected:{Colors.ENDC}")
        suggestions.append(f"     {Colors.CYAN}Current: fc_max={Colors.BOLD}{fc_max}{Colors.ENDC}{Colors.CYAN}, fc_min={Colors.BOLD}{fc_min}{Colors.ENDC}")
        suggestions.append(f"     {Colors.GREEN}Suggested: {Colors.BOLD}fc_max={new_fc_max}, fc_min={new_fc_min}{Colors.ENDC}")
    if crowded:
        suggestions.append(f"   {Colors.WARNING}▶ Many gene labels with small plot:{Colors.ENDC}")
        suggestions.append(f"     {Colors.CYAN}Current: plot_genes_num={Colors.BOLD}{plot_genes_num}{Colors.ENDC}{Colors.CYAN}, figsize={Colors.BOLD}{figsize}{Colors.ENDC}")
        suggestions.append(f"     {Colors.GREEN}Suggested: {Colors.BOLD}figsize=(6, 6){Colors.ENDC} or {Colors.BOLD}plot_genes_num=15{Colors.ENDC}")
    if small_font:
        suggestions.append(f"   {Colors.WARNING}▶ Small font with many labels:{Colors.ENDC}")
        suggestions.append(f"     {Colors.CYAN}Current: plot_genes_fontsize={Colors.BOLD}{plot_genes_fontsize}{Colors.ENDC}{Colors.CYAN}, plot_genes_num={Colors.BOLD}{plot_genes_num}{Colors.ENDC}")
        suggestions.append(f"     {Colors.GREEN}Suggested: {Colors.BOLD}plot_genes_fontsize=10{Colors.ENDC} or {Colors.BOLD}plot_genes_num=10{Colors.ENDC}")
    if unbalanced:
        suggestions.append(f"   {Colors.BLUE}▶ Unbalanced figure dimensions:{Colors.ENDC}")
        suggestions.append(f"     {Colors.CYAN}Current: figsize={Colors.BOLD}{figsize}{Colors.ENDC}")
        suggestions.append(f"     {Colors.GREEN}Suggested: {Colors.BOLD}figsize=({optimal_size}, {optimal_size}){Colors.ENDC}")

    if suggestions:
        for suggestion in suggestions:
            print(suggestion)
        print(f"\n   {Colors.BOLD}📋 Copy-paste ready function call:{Colors.ENDC}")

        # Build a copy-paste-ready call that applies the suggestions above.
        optimized_params = ["result"]
        if pval_name != 'qvalue':
            optimized_params.append(f"pval_name='{pval_name}'")
        if fc_name != 'log2FC':
            optimized_params.append(f"fc_name='{fc_name}'")
        if few_sig:
            optimized_params.append("pval_threshold=0.1")
        if wide_fc:
            optimized_params.append(f"fc_max={new_fc_max}")
            optimized_params.append(f"fc_min={new_fc_min}")
        if crowded:
            optimized_params.append("figsize=(6, 6)")
        elif unbalanced:
            optimized_params.append(f"figsize=({optimal_size}, {optimal_size})")
        if small_font:
            optimized_params.append("plot_genes_fontsize=10")
        if plot_genes is not None:
            optimized_params.append(f"plot_genes={plot_genes}")

        optimized_call = f"   {Colors.GREEN}ov.pl.volcano({', '.join(optimized_params)}){Colors.ENDC}"
        print(optimized_call)
    else:
        print(f"   {Colors.GREEN}✅ Current parameters are optimal for your data!{Colors.ENDC}")

    print(f"{Colors.CYAN}{'─' * 60}{Colors.ENDC}")

    # ---- Plot ------------------------------------------------------------
    result = result.copy()
    result['-log(qvalue)'] = -np.log10(result[pval_name])
    result['log2FC'] = result[fc_name].copy()
    if pval_max is not None:
        result.loc[result['-log(qvalue)'] > pval_max, '-log(qvalue)'] = pval_max
    if FC_max is not None:
        result.loc[result['log2FC'] > FC_max, 'log2FC'] = FC_max
        result.loc[result['log2FC'] < -FC_max, 'log2FC'] = 0 - FC_max

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # Non-significant points first, so the up/down points are drawn on top.
    for status, point_color in (('normal', normal_color),
                                ('up', up_color),
                                ('down', down_color)):
        subset = result[result['sig'] == status]
        ax.scatter(x=subset['log2FC'], y=subset['-log(qvalue)'],
                   color=point_color, alpha=.5)

    # Dashed guide lines: horizontal at -log10(pval_threshold), spanning the
    # x-range of the data; vertical at fc_max and fc_min, spanning the y-range.
    y_cut = -np.log10(pval_threshold)
    ax.plot([result['log2FC'].min(), result['log2FC'].max()], [y_cut, y_cut],
            linewidth=2, linestyle="--", color='black')
    y_span = [result['-log(qvalue)'].min(), result['-log(qvalue)'].max()]
    for x_cut in (fc_max, fc_min):
        ax.plot([x_cut, x_cut], y_span, linewidth=2, linestyle="--", color='black')

    ax.set_ylabel(r'$-log_{10}(qvalue)$', titlefont)
    ax.set_xlabel(r'$log_{2}FC$', titlefont)
    ax.set_title(title, titlefont)

    # Legend: one colour patch per regulated direction with its gene count.
    patches = [
        mpatches.Patch(color=up_color, label=f"up:{len(result[result['sig'] == 'up'])}"),
        mpatches.Patch(color=down_color, label=f"down:{len(result[result['sig'] == 'down'])}"),
    ]
    ax.legend(handles=patches, bbox_to_anchor=legend_bbox,
              ncol=legend_ncol, fontsize=legend_fontsize)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(True)

    # Drop the outermost auto-generated x ticks and round the remaining ones.
    inner_ticks = [round(tick, 2) for tick in ax.get_xticks()[1:-1]]
    ax.set_xticks(inner_ticks, inner_ticks, fontsize=ticks_fontsize, fontweight='normal')

    ax.grid(False)
    ax.spines['left'].set_position(('outward', 10))
    ax.spines['bottom'].set_position(('outward', 10))

    # ---- Gene labels -------------------------------------------------------
    if plot_genes is not None:
        hub_gene = plot_genes
    elif plot_genes_num is None:
        return ax
    else:
        up_result = result.loc[result['sig'] == 'up']
        down_result = result.loc[result['sig'] == 'down']
        hub_gene = (up_result.sort_values(pval_name).index[:plot_genes_num // 2].tolist()
                    + down_result.sort_values(pval_name).index[:plot_genes_num // 2].tolist())

    # adjustText is only required when labels are actually drawn.
    import re
    import adjustText
    from adjustText import adjust_text

    color_dict = {
        'up': up_fontcolor,
        'down': down_fontcolor,
        'normal': normal_fontcolor,
    }
    texts = [ax.text(result.loc[gene, 'log2FC'],
                     result.loc[gene, '-log(qvalue)'],
                     gene,
                     fontdict={'size': plot_genes_fontsize, 'weight': 'bold',
                               'color': color_dict[result.loc[gene, 'sig']]})
             for gene in hub_gene]

    # Compare versions numerically. A plain string comparison would misorder
    # e.g. '0.10' vs '0.8'. adjustText >= 1.0 accepts the richer only_move keys.
    version = tuple(int(part) for part in re.findall(r'\d+', str(adjustText.__version__))[:3])
    if version <= (0, 8):
        adjust_text(texts, only_move={'text': 'xy'},
                    arrowprops=dict(arrowstyle='->', color='red'))
    else:
        adjust_text(texts, only_move={"text": "xy", "static": "xy", "explode": "xy", "pull": "xy"},
                    arrowprops=dict(arrowstyle='->', color='red'))

    return ax

omicverse.pl.venn(sets={}, out='./', palette='bgrc', ax=False, ext='png', dpi=300, fontsize=3.5, bbox_to_anchor=(0.5, 0.99), nc=2, cs=4)

Create a Venn diagram to visualize set overlaps.

Parameters:

Name Type Description Default
sets

Dictionary with set names as keys and sets as values ({})

{}
out

Output directory path ('./')

'./'
palette

Color palette for sets ('bgrc')

'bgrc'
ax

Matplotlib axes object or False to create new (False)

False
ext

File extension for output ('png')

'png'
dpi

Resolution for output image (300)

300
fontsize

Font size for text (3.5)

3.5
bbox_to_anchor

Legend position ((.5, .99))

(0.5, 0.99)
nc

Number of legend columns (2)

2
cs

Font size for legend (4)

4

Returns:

Name Type Description
ax

matplotlib.axes.Axes object

Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_bulk.py
def venn(sets={}, out='./', palette='bgrc',
             ax=False, ext='png', dpi=300, fontsize=3.5,
             bbox_to_anchor=(.5, .99),nc=2,cs=4):
    r"""
    Draw a Venn diagram of the overlaps between the given sets.

    Thin wrapper around ``venny4py`` from the package's ``utils`` module. Each
    argument is forwarded, renamed where venny4py uses a different name.

    Arguments:
        sets: Mapping of set name -> set of items ({})
        out: Output directory path, forwarded as ``out`` ('./')
        palette: Colours for the sets, forwarded as ``ce`` ('bgrc')
        ax: Matplotlib axes to draw on, or False; forwarded as ``asax`` (False)
        ext: Output file extension ('png')
        dpi: Output resolution (300)
        fontsize: Text size, forwarded as ``size`` (3.5)
        bbox_to_anchor: Legend anchor position ((.5, .99))
        nc: Number of legend columns (2)
        cs: Forwarded unchanged as ``cs`` (documented as legend font size) (4)

    Returns:
        The ``ax`` argument exactly as passed in (``False`` when omitted).
        venny4py's own return value is discarded.
        NOTE(review): callers expecting a fresh Axes when ``ax`` is omitted
        get ``False``.
    """
    from ..utils import venny4py as _draw_venn

    # Map this wrapper's argument names onto venny4py's parameter names.
    venny_kwargs = dict(
        sets=sets,
        out=out,
        ce=palette,
        asax=ax,
        ext=ext,
        dpi=dpi,
        size=fontsize,
        bbox_to_anchor=bbox_to_anchor,
        nc=nc,
        cs=cs,
    )
    _draw_venn(**venny_kwargs)
    return ax

omicverse.pl.boxplot(data, hue, x_value, y_value, width=0.3, title='', figsize=(6, 3), palette=None, fontsize=10, legend_bbox=(1, 0.55), legend_ncol=1, hue_order=None)

Create a boxplot with jittered points to visualize data distribution across categories.

Parameters:

Name Type Description Default
data

DataFrame containing the data to plot

required
hue

Column name for grouping variable (color coding)

required
x_value

Column name for x-axis categories

required
y_value

Column name for y-axis values

required
width

Width of each boxplot (0.3)

0.3
title

Plot title ('')

''
figsize

Figure dimensions as (width, height) ((6,3))

(6, 3)
palette

Color palette for groups (None, uses default)

None
fontsize

Font size for labels and ticks (10)

10
legend_bbox

Legend position as (x, y) ((1, 0.55))

(1, 0.55)
legend_ncol

Number of legend columns (1)

1
hue_order

Custom order for hue categories (None, uses alphabetical)

None

Returns:

Name Type Description
fig

matplotlib.figure.Figure object

ax

matplotlib.axes.Axes object

Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_bulk.py
def boxplot(data,hue,x_value,y_value,width=0.3,title='',
                 figsize=(6,3),palette=None,fontsize=10,
                 legend_bbox=(1, 0.55),legend_ncol=1,hue_order=None):
    r"""
    Create a boxplot with jittered points to visualize data distribution across categories.

    One box is drawn per (x category, hue category) pair, offset around each
    x tick. Up to 20 randomly sampled observations per box are overlaid as
    jittered points. Before plotting, a colored diagnostic report (group
    sizes and parameter suggestions) is printed to stdout.

    Arguments:
        data: DataFrame containing the data to plot
        hue: Column name for grouping variable (color coding)
        x_value: Column name for x-axis categories
        y_value: Column name for y-axis values
        width: Width of each boxplot (0.3)
        title: Plot title ('')
        figsize: Figure dimensions as (width, height) ((6,3))
        palette: Colors for the hue groups (None, uses ``sc_color``).
            Accepted forms:
            - a sequence of colors, assigned in hue order and cycled
              (with a printed warning) if it is shorter than the number
              of hue groups;
            - a dict mapping every hue value to a color.
        fontsize: Font size for labels and ticks (10)
        legend_bbox: Legend position as (x, y) ((1, 0.55))
        legend_ncol: Number of legend columns (1)
        hue_order: Custom order for hue categories (None, uses the sorted
            hue values)

    Returns:
        fig: matplotlib.figure.Figure object
        ax: matplotlib.axes.Axes object

    Raises:
        ValueError: if ``hue_order`` omits a hue value present in ``data``,
            or if ``palette`` is an empty sequence.
        KeyError: if a dict ``palette`` lacks one of the hue values.

    Notes:
        Rows whose hue or x value is missing are not plotted. Missing y
        values are dropped before the box statistics are computed, because
        NaN would otherwise propagate into the quartiles. Point sampling and
        jitter use the global ``random`` / ``numpy.random`` state, so seed
        both for reproducible figures.
    """

    # ANSI color codes used to highlight the diagnostic report in a terminal
    class Colors:
        HEADER = '\033[95m'
        BLUE = '\033[94m'
        CYAN = '\033[96m'
        GREEN = '\033[92m'
        WARNING = '\033[93m'
        FAIL = '\033[91m'
        ENDC = '\033[0m'
        BOLD = '\033[1m'
        UNDERLINE = '\033[4m'

    # Missing categories can never match a group (NaN != NaN) and make sorting
    # of mixed NaN/str values fail, so they are excluded from the category lists.
    x_categories = data[x_value].dropna()
    hue_categories = data[hue].dropna()

    # Print data information for user guidance
    print(f"{Colors.HEADER}{Colors.BOLD}📊 Boxplot Data Analysis:{Colors.ENDC}")
    print(f"   {Colors.CYAN}Total samples: {Colors.BOLD}{len(data)}{Colors.ENDC}")
    print(f"   {Colors.BLUE}X-axis variable ('{x_value}'): {Colors.BOLD}{sorted(x_categories.unique())}{Colors.ENDC}")
    print(f"   {Colors.BLUE}Hue variable ('{hue}'): {Colors.BOLD}{sorted(hue_categories.unique())}{Colors.ENDC}")
    print(f"   {Colors.BLUE}Y-axis variable: '{y_value}' (range: {Colors.BOLD}{data[y_value].min():.2f} - {data[y_value].max():.2f}{Colors.ENDC})")

    # Check for missing data
    missing_data = data[[hue, x_value, y_value]].isnull().sum().sum()
    if missing_data > 0:
        print(f"   {Colors.WARNING}⚠️  Warning: Found {Colors.BOLD}{missing_data}{Colors.ENDC}{Colors.WARNING} missing values in key columns{Colors.ENDC}")

    # Display current function parameters
    print(f"\n{Colors.HEADER}{Colors.BOLD}⚙️  Current Function Parameters:{Colors.ENDC}")
    print(f"   {Colors.BLUE}hue='{hue}', x_value='{x_value}', y_value='{y_value}'{Colors.ENDC}")
    print(f"   {Colors.BLUE}width={Colors.BOLD}{width}{Colors.ENDC}{Colors.BLUE}, figsize={Colors.BOLD}{figsize}{Colors.ENDC}{Colors.BLUE}, fontsize={Colors.BOLD}{fontsize}{Colors.ENDC}")
    if hue_order is not None:
        print(f"   {Colors.BLUE}hue_order={Colors.BOLD}{hue_order}{Colors.ENDC}")
    else:
        print(f"   {Colors.BLUE}hue_order={Colors.BOLD}None{Colors.ENDC}{Colors.BLUE} (using alphabetical order){Colors.ENDC}")

    def calculate_box_positions(n_hues, spacing=0.8):
        """
        Return evenly spaced x offsets for ``n_hues`` boxes around each tick.

        Parameters
        ----------
        n_hues : int
            Number of hue categories.
        spacing : float
            Total span of the offsets; they cover ``[-spacing/2, spacing/2]``.

        Returns
        -------
        list of float
            One offset per hue; a single (or no) hue group is centred at 0.
        """
        if n_hues <= 1:
            return [0.0]
        half_range = spacing / 2
        step = spacing / (n_hues - 1)
        return [-half_range + i * step for i in range(n_hues)]

    # Resolve the hue categories and the order in which they are drawn
    if hue_order is not None:
        hue_datas = list(hue_order)
        # Every hue value present in the data must appear in hue_order
        missing_values = set(hue_categories.unique()) - set(hue_datas)
        if missing_values:
            raise ValueError(f"The following hue values are in data but not in hue_order: {missing_values}")
        print(f"   {Colors.GREEN}📋 Using custom hue order: {Colors.BOLD}{hue_order}{Colors.ENDC}")
    else:
        hue_datas = sorted(set(hue_categories))
        print(f"   {Colors.GREEN}📋 Using alphabetical hue order: {Colors.BOLD}{hue_datas}{Colors.ENDC}")

    # x-axis categories: one tick per category
    x = x_value
    ticks = sorted(set(x_categories))

    # Calculate box positions
    box_positions = calculate_box_positions(len(hue_datas))
    print(f"\n{Colors.HEADER}{Colors.BOLD}🎯 Box Positioning:{Colors.ENDC}")
    print(f"   {Colors.CYAN}Number of hue groups: {Colors.BOLD}{len(hue_datas)}{Colors.ENDC}")
    print(f"   {Colors.CYAN}Box positions: {Colors.BOLD}{[round(pos, 3) for pos in box_positions]}{Colors.ENDC}")
    print(f"   {Colors.CYAN}Box width: {Colors.BOLD}{width}{Colors.ENDC}")

    # Number of rows in every (hue, x) group; computed once, reused below
    group_sizes = {
        (hue_cat, x_cat): int(((data[hue] == hue_cat) & (data[x] == x_cat)).sum())
        for hue_cat in hue_datas
        for x_cat in ticks
    }

    print(f"\n{Colors.HEADER}{Colors.BOLD}📈 Sample sizes per group:{Colors.ENDC}")
    for hue_cat in hue_datas:
        for x_cat in ticks:
            count = group_sizes[(hue_cat, x_cat)]
            if count < 5:
                color = Colors.WARNING
            elif count < 10:
                color = Colors.BLUE
            else:
                color = Colors.GREEN
            print(f"   {color}{hue_cat} × {x_cat}: {Colors.BOLD}{count}{Colors.ENDC}{color} samples{Colors.ENDC}")

    # Provide parameter suggestions with current vs suggested comparison
    print(f"\n{Colors.HEADER}{Colors.BOLD}💡 Parameter Optimization Suggestions:{Colors.ENDC}")
    suggestions = []

    if len(hue_datas) > 4:
        suggested_width = round(max(0.1, 0.8 / len(hue_datas)), 1)
        suggested_figsize_width = max(8, len(ticks) * 2)
        suggestions.append(f"   {Colors.WARNING}▶ Many hue groups detected:{Colors.ENDC}")
        suggestions.append(f"     {Colors.CYAN}Current: {Colors.BOLD}width={width}{Colors.ENDC}{Colors.CYAN}, figsize={figsize}{Colors.ENDC}")
        suggestions.append(f"     {Colors.GREEN}Suggested: {Colors.BOLD}width={suggested_width}, figsize=({suggested_figsize_width}, {figsize[1]}){Colors.ENDC}")

    if len(ticks) > 5:
        suggested_figsize_width = max(10, len(ticks) * 1.5)
        suggestions.append(f"   {Colors.WARNING}▶ Many x-categories detected:{Colors.ENDC}")
        suggestions.append(f"     {Colors.CYAN}Current: {Colors.BOLD}figsize={figsize}{Colors.ENDC}")
        suggestions.append(f"     {Colors.GREEN}Suggested: {Colors.BOLD}figsize=({suggested_figsize_width}, {figsize[1]}){Colors.ENDC}")

    max_samples = max(group_sizes.values(), default=0)
    if max_samples < 5:
        suggested_width = round(max(0.1, width * 0.7), 1)
        suggestions.append(f"   {Colors.WARNING}▶ Small sample sizes detected:{Colors.ENDC}")
        suggestions.append(f"     {Colors.CYAN}Current: {Colors.BOLD}width={width}{Colors.ENDC}")
        suggestions.append(f"     {Colors.GREEN}Suggested: {Colors.BOLD}width={suggested_width}{Colors.ENDC}")

    # Check if current width might cause overlap
    if len(hue_datas) > 3 and width > 0.25:
        # Floor at 0.1 so very many groups never round down to a suggested width of 0.0
        suggested_width = round(max(0.1, 0.8 / len(hue_datas)), 1)
        suggestions.append(f"   {Colors.FAIL}▶ Box overlap risk detected:{Colors.ENDC}")
        suggestions.append(f"     {Colors.CYAN}Current: {Colors.BOLD}width={width}{Colors.ENDC}{Colors.CYAN} (too wide for {len(hue_datas)} groups){Colors.ENDC}")
        suggestions.append(f"     {Colors.GREEN}Suggested: {Colors.BOLD}width={suggested_width}{Colors.ENDC}")

    # Figure size optimization based on both dimensions
    if len(ticks) > 3 and len(hue_datas) > 3:
        suggested_fig_w = max(8, len(ticks) * 1.5)
        suggested_height = max(4, figsize[1])
        suggestions.append(f"   {Colors.BLUE}▶ Complex plot optimization:{Colors.ENDC}")
        suggestions.append(f"     {Colors.CYAN}Current: {Colors.BOLD}figsize={figsize}{Colors.ENDC}")
        suggestions.append(f"     {Colors.GREEN}Suggested: {Colors.BOLD}figsize=({suggested_fig_w}, {suggested_height}){Colors.ENDC}")

    if suggestions:
        for suggestion in suggestions:
            print(suggestion)
        print(f"\n   {Colors.BOLD}📋 Copy-paste ready function call:{Colors.ENDC}")
        # Generate optimized function call
        optimized_params = [f"data, hue='{hue}', x_value='{x_value}', y_value='{y_value}'"]

        # Add optimized parameters
        if len(hue_datas) > 4 or (len(hue_datas) > 3 and width > 0.25) or max_samples < 5:
            if len(hue_datas) > 4:
                opt_width = round(max(0.1, 0.8 / len(hue_datas)), 1)
            elif max_samples < 5:
                opt_width = round(max(0.1, width * 0.7), 1)
            else:
                opt_width = round(0.8 / len(hue_datas), 1)
            optimized_params.append(f"width={opt_width}")

        if len(ticks) > 5 or len(hue_datas) > 4 or (len(ticks) > 3 and len(hue_datas) > 3):
            if len(hue_datas) > 4:
                opt_fig_w = max(8, len(ticks) * 2)
            elif len(ticks) > 5:
                opt_fig_w = max(10, len(ticks) * 1.5)
            else:
                opt_fig_w = max(8, len(ticks) * 1.5)
            opt_fig_h = max(4, figsize[1])
            optimized_params.append(f"figsize=({opt_fig_w}, {opt_fig_h})")

        if hue_order is not None:
            optimized_params.append(f"hue_order={hue_order}")

        optimized_call = f"   {Colors.GREEN}ov.pl.boxplot({', '.join(optimized_params)}){Colors.ENDC}"
        print(optimized_call)
    else:
        print(f"   {Colors.GREEN}✅ Current parameters are optimal for your data!{Colors.ENDC}")

    print(f"{Colors.CYAN}{'─' * 60}{Colors.ENDC}")

    # For each hue group, collect per-x-category:
    #   plot_data1        -> all (non-missing) y values, used for the boxes
    #   plot_data_random1 -> up to 20 randomly sampled values, used for the points
    #   plot_data_xs1     -> jittered x coordinates for those sampled points
    import random
    plot_data1 = {}
    plot_data_random1 = {}
    plot_data_xs1 = {}

    y = y_value
    for hue_data, offset in zip(hue_datas, box_positions):
        box_values, sampled_values, jitter_xs = [], [], []
        for tick_idx, x_cat in enumerate(ticks):
            values = data.loc[(data[x] == x_cat) & (data[hue] == hue_data), y].dropna().tolist()
            box_values.append(values)
            sample = random.sample(values, min(len(values), 20))
            sampled_values.append(sample)
            jitter_xs.append(np.random.normal(tick_idx + offset, 0.04, len(sample)))
        plot_data1[hue_data] = box_values
        plot_data_random1[hue_data] = sampled_values
        plot_data_xs1[hue_data] = jitter_xs

    # Resolve one color per hue group before creating the figure, so a bad
    # palette does not leave an empty figure behind.
    if palette is None:
        from ._palette import sc_color
        palette = sc_color
    if isinstance(palette, dict):
        hue_colors = [palette[hue_data] for hue_data in hue_datas]
    else:
        palette = list(palette)
        if not palette:
            raise ValueError("palette must contain at least one color")
        if len(palette) < len(hue_datas):
            print(f"   {Colors.WARNING}⚠️  Warning: palette has {len(palette)} colors for {len(hue_datas)} hue groups; colors will repeat{Colors.ENDC}")
        hue_colors = [palette[i % len(palette)] for i in range(len(hue_datas))]

    fig, ax = plt.subplots(figsize=figsize)

    # Draw the boxes and the jittered points
    base_positions = np.arange(len(ticks))
    for hue_data, hue_color, offset in zip(hue_datas, hue_colors, box_positions):
        bp = ax.boxplot(plot_data1[hue_data],
                        positions=base_positions + offset,
                        sym='',
                        widths=width)
        for part in ('boxes', 'whiskers', 'caps', 'medians'):
            plt.setp(bp[part], color=hue_color)

        # `color=` rather than `c=`, so an RGB(A) tuple is never mistaken for
        # per-point values when a group happens to have 3 or 4 points.
        for xs, ys in zip(plot_data_xs1[hue_data], plot_data_random1[hue_data]):
            if len(ys) > 0:
                ax.scatter(xs, ys, color=hue_color, alpha=0.4)

    # x-axis: one labelled tick per category
    ax.set_xticks(range(len(ticks)), ticks, fontsize=fontsize)

    # y-axis: keep ticks at or above min(0, lower y-limit). Non-negative data
    # therefore shows no ticks below zero, while negative data still gets
    # labelled ticks.
    yticks = ax.get_yticks()
    y_floor = min(0.0, ax.get_ylim()[0])
    visible_yticks = yticks[yticks >= y_floor]
    ax.set_yticks(visible_yticks, visible_yticks, fontsize=fontsize)

    # Legend: one patch per hue group; str() so non-string hue values work
    patches = [mpatches.Patch(color=hue_color, label=str(hue_data))
               for hue_data, hue_color in zip(hue_datas, hue_colors)]
    ax.legend(handles=patches, bbox_to_anchor=legend_bbox, ncol=legend_ncol, fontsize=fontsize)

    ax.set_title(title, fontsize=fontsize + 1)

    # Spine styling: hide top/right, push left/bottom outward
    ax.grid(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(True)
    ax.spines['left'].set_position(('outward', 10))
    ax.spines['bottom'].set_position(('outward', 10))

    return fig, ax

Density and Contour

omicverse.pl.calculate_gene_density(adata, features, basis='X_umap', dims=(0, 1), adjust=1, min_expr=0.1)

Calculate weighted kernel density estimates for gene expression on 2D embeddings.

Computes KDE for each feature using expression values as weights and stores density values in adata.obs as 'density_{feature}' columns.

Parameters:

Name Type Description Default
adata

Annotated data object with embedding coordinates

required
features

List of gene names or feature names to process

required
basis

Key in adata.obsm containing 2D embedding coordinates ('X_umap')

'X_umap'
dims

Embedding dimensions to use as (x_dim, y_dim) ((0, 1))

(0, 1)
adjust

Bandwidth scaling factor for KDE (1)

1
min_expr

Minimum expression threshold; only cells with expression strictly above this value are used to fit the KDE (0.1)

0.1

Returns:

Name Type Description
None

Updates adata.obs with density_{feature} columns

Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_density.py
def calculate_gene_density(
    adata,
    features,
    basis="X_umap",
    dims=(0, 1),
    adjust=1,
    min_expr=0.1,          # minimal raw expression for a cell to enter the KDE fit
):
    r"""
    Calculate weighted kernel density estimates for gene expression on 2D embeddings.

    For each feature, the KDE is fitted on training cells that meet all of:
    - a finite feature value;
    - finite embedding coordinates;
    - a feature value strictly greater than ``min_expr``.
    The fit is a weighted Gaussian KDE whose weights are the min-max scaled
    values (uniform weights if all training values are equal). The KDE is
    then evaluated at every cell and stored in ``adata.obs['density_{feature}']``.

    Arguments:
        adata: Annotated data object with embedding coordinates
        features: A feature name or list of names. Each is looked up first
            among ``adata.obs`` columns, then in ``adata.var_names``.
        basis: Key in adata.obsm containing 2D embedding coordinates ('X_umap')
        dims: Embedding dimensions to use as (x_dim, y_dim) ((0, 1))
        adjust: Bandwidth factor, passed to ``scipy.stats.gaussian_kde`` as
            ``bw_method`` (1)
        min_expr: Cells must have expression strictly above this value to be
            used in the fit (0.1)

    Returns:
        None: Updates adata.obs with density_{feature} columns. A feature with
        fewer than 5 usable cells gets an all-NaN column.

    Raises:
        ValueError: if ``dims`` does not have length 2, ``basis`` is missing
            from ``adata.obsm``, or a feature is found in neither obs nor var.
    """
    if len(dims) != 2:
        raise ValueError("`dims` must have length 2")
    if basis not in adata.obsm:
        raise ValueError(f"Embedding '{basis}' not found.")

    # A bare string is one feature, not a sequence of one-letter features
    if isinstance(features, str):
        features = [features]

    emb_all = np.asarray(adata.obsm[basis])[:, list(dims)]     # (n_cells, 2)
    # Coordinate validity does not depend on the feature, so compute it once
    coords_finite = np.all(np.isfinite(emb_all), axis=1)

    for feat in features:
        # ----- fetch raw weights -------------------------------------------
        if feat in adata.obs:
            w_raw = adata.obs[feat].to_numpy()
        elif feat in adata.var_names:
            x_col = adata[:, feat].X
            # X may be sparse (has .toarray) or already a dense array
            x_col = x_col.toarray() if hasattr(x_col, "toarray") else np.asarray(x_col)
            w_raw = x_col.ravel()
        else:
            raise ValueError(f"Feature '{feat}' not found in obs or var.")

        # ----- training mask: finite expr, finite coords, above threshold ---
        mask_train = np.isfinite(w_raw) & coords_finite & (w_raw > min_expr)

        emb_train = emb_all[mask_train]
        w_train_raw = w_raw[mask_train]

        if emb_train.shape[0] < 5:
            print(f"[{feat}] too few cells above threshold; skipping KDE.")
            adata.obs[f"density_{feat}"] = np.nan
            continue

        # ----- min–max scale to 0-1 ----------------------------------------
        # Constant values would give 0/0 = NaN weights; fall back to uniform
        # weights (None), as add_density_contour does.
        w_min, w_max = w_train_raw.min(), w_train_raw.max()
        if w_max > w_min:
            w_train = (w_train_raw - w_min) / (w_max - w_min)
        else:
            w_train = None

        # ----- KDE fit ------------------------------------------------------
        kde = gaussian_kde(emb_train.T, weights=w_train, bw_method=adjust)

        # ----- evaluate on ALL cells ---------------------------------------
        density = kde(emb_all.T)
        adata.obs[f"density_{feat}"] = density

        print(f"✅ density_{feat} written (train cells = {emb_train.shape[0]})")

omicverse.pl.add_density_contour(ax, embeddings, weights, levels='quantile', n_quantiles=5, bw_adjust=0.3, cmap_contour='Greys', linewidth=1.0, zorder=10, fill=False, alpha=0.4)

Add KDE-based density contours to an existing matplotlib plot.

Parameters:

Name Type Description Default
ax

matplotlib.axes.Axes object to draw contours on

required
embeddings

2D coordinate array with shape (n_cells, 2)

required
weights

1D weight array for KDE, will be min-max normalized

required
levels

Contour level specification - 'quantile' or list of values ('quantile')

'quantile'
n_quantiles

Number of quantile levels when levels='quantile' (5)

5
bw_adjust

Bandwidth adjustment factor for KDE (0.3)

0.3
cmap_contour

Colormap for contour lines ('Greys')

'Greys'
linewidth

Width of contour lines (1.0)

1.0
zorder

Drawing order for contours (10)

10
fill

Whether to fill contours (False for lines, True for filled)

False
alpha

Transparency for filled contours (0.4)

0.4

Returns:

Name Type Description
cs

matplotlib contour object for potential colorbar addition

Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_density.py
def add_density_contour(
    ax,
    embeddings,              # (n_cells, 2) array-like
    weights,                 # 1-D array-like, will be min-max scaled
    levels="quantile",       # "quantile" or explicit numeric levels
    n_quantiles=5,
    bw_adjust=0.3,
    cmap_contour="Greys",
    linewidth=1.0,
    zorder=10,
    fill=False,
    alpha=0.4,
):
    r"""
    Add KDE-based density contours to an existing matplotlib plot.

    A weighted Gaussian KDE is fitted on ``embeddings`` using the min-max
    scaled ``weights`` (uniform weights when all weights are equal). It is
    evaluated on a 300x300 grid spanning the embedding's bounding box, then
    drawn with ``ax.contour`` (or ``ax.contourf`` when ``fill=True``).

    Arguments:
        ax: matplotlib.axes.Axes object to draw contours on
        embeddings: 2D coordinate array with shape (n_cells, 2)
        weights: 1D weight array for KDE, will be min-max normalized
        levels: Contour level specification: the string 'quantile', or
            explicit numeric levels (list or array) passed to matplotlib
            unchanged ('quantile')
        n_quantiles: Number of evenly spaced interior quantiles of the grid
            density to use as levels when levels='quantile'. Duplicate
            levels are removed, so fewer may be drawn. (5)
        bw_adjust: Bandwidth factor, passed to ``gaussian_kde`` as ``bw_method`` (0.3)
        cmap_contour: Colormap for contour lines ('Greys')
        linewidth: Width of contour lines (1.0)
        zorder: Drawing order for contours (10)
        fill: Whether to fill contours (False for lines, True for filled)
        alpha: Transparency for filled contours (0.4)

    Returns:
        cs: matplotlib contour object for potential colorbar addition

    Raises:
        ValueError: if ``embeddings`` is not of shape (n_cells, 2), or if
            ``levels`` is a string other than 'quantile'.
    """
    embeddings = np.asarray(embeddings, dtype=float)
    weights = np.asarray(weights, dtype=float)
    if embeddings.ndim != 2 or embeddings.shape[1] != 2:
        raise ValueError(
            f"`embeddings` must have shape (n_cells, 2), got {embeddings.shape}"
        )

    # ---------- fit KDE ----------------------------------------------------
    # Equal weights would give 0/0; None means uniform weights in gaussian_kde
    w_min, w_max = weights.min(), weights.max()
    w_norm = None if w_max == w_min else (weights - w_min) / (w_max - w_min)
    kde = gaussian_kde(embeddings.T, weights=w_norm, bw_method=bw_adjust)

    # ---------- prepare evaluation grid -----------------------------------
    xmin, xmax = embeddings[:, 0].min(), embeddings[:, 0].max()
    ymin, ymax = embeddings[:, 1].min(), embeddings[:, 1].max()
    xx, yy = np.mgrid[xmin:xmax:300j, ymin:ymax:300j]   # 300×300 grid
    grid = kde(np.vstack([xx.ravel(), yy.ravel()])).reshape(xx.shape)

    # ---------- determine contour levels ----------------------------------
    # isinstance guard: comparing an ndarray of levels to a string is
    # elementwise, so its truth value would be ambiguous.
    if isinstance(levels, str):
        if levels != "quantile":
            raise ValueError(
                f"`levels` must be 'quantile' or a sequence of numbers, got {levels!r}"
            )
        qs = np.linspace(0, 1, n_quantiles + 2)[1:-1]   # drop 0 & 1
        # matplotlib requires strictly increasing levels; large regions of
        # (near-)zero density can produce repeated quantiles.
        levels = np.unique(np.quantile(grid, qs))

    # ---------- draw -------------------------------------------------------
    if fill:
        cs = ax.contourf(
            xx, yy, grid,
            levels=levels,
            cmap=cmap_contour,
            alpha=alpha,
            zorder=zorder,
        )
    else:
        cs = ax.contour(
            xx, yy, grid,
            levels=levels,
            cmap=cmap_contour,
            linewidths=linewidth,
            zorder=zorder,
        )
    return cs   # so you can add a colorbar if desired

Spatial Plotting

omicverse.pl.spatial_segment(adata, seg_cell_id, seg=True, seg_key='segmentation', seg_contourpx=None, seg_outline=False, **kwargs)

Plot spatial omics data with segmentation masks on top.

Argument ``seg_cell_id`` in :attr:`anndata.AnnData.obs` controls which unique segmentation-mask IDs are plotted. By default, the segmentation is looked up under ``seg_key='segmentation'`` and the image under ``'hires'``.

- Use ``seg_key`` to display the image in the background.
- Use ``seg_contourpx`` or ``seg_outline`` to control how the segmentation mask is displayed.

%(spatial_plot.summary_ext)s

See also: :func:`squidpy.pl.spatial_scatter` for plotting spatial data with overlaid features.

Parameters

%(adata)s %(plotting_segment)s %(color)s %(groups)s %(library_id)s %(library_key)s %(spatial_key)s %(plotting_image)s %(plotting_features)s

Returns

%(spatial_plot.returns)s

Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_spatial.py
def spatial_segment(
    adata: AnnData,
    seg_cell_id: str,
    seg: bool | Sequence[np.ndarray] | None = True,
    seg_key: str = "segmentation",
    seg_contourpx: int | None = None,
    seg_outline: bool = False,
    **kwargs: Any,
) -> Axes | Sequence[Axes] | None:
    """
    Plot spatial omics data with segmentation masks on top.

    Thin wrapper around this module's ``_spatial_plot``. It forwards the
    segmentation arguments below, plus every extra keyword argument,
    unchanged.

        - Use ``seg_key`` to choose which stored segmentation mask is drawn.
        - Use ``seg_contourpx`` or ``seg_outline`` to control how the segmentation mask is displayed.

    .. seealso::
        - :func:`squidpy.pl.spatial_scatter` on how to plot spatial data with overlayed data on top.

    Parameters
    ----------
    adata
        Annotated data object to plot.
    seg_cell_id
        Key in :attr:`anndata.AnnData.obs` that holds the unique
        segmentation-mask ID of each cell.
    seg
        Segmentation mask(s). ``True`` presumably loads the mask stored under
        ``seg_key`` (TODO: confirm against ``_spatial_plot``).
    seg_key
        Key of the segmentation mask, presumably in
        ``adata.uns['spatial'][library_id]['images']`` (TODO: confirm).
    seg_contourpx
        Contour width in pixels; if given, only mask contours of this width
        are drawn.
    seg_outline
        If ``True``, show segmentation boundaries.
    **kwargs
        Forwarded to ``_spatial_plot``, e.g. ``color``, ``groups``,
        ``library_id``, ``library_key``, ``spatial_key`` and image/feature
        options.

    Returns
    -------
    Whatever ``_spatial_plot`` returns. Per the annotation this is a
    matplotlib ``Axes``, a sequence of them, or ``None``.
    """
    # The former in-function ``from squidpy._docs import d`` was never used
    # (no docstring substitution was applied). It only forced a runtime import
    # of a private third-party module, so it has been removed.
    return _spatial_plot(
        adata,
        seg=seg,
        seg_key=seg_key,
        seg_cell_id=seg_cell_id,
        seg_contourpx=seg_contourpx,
        seg_outline=seg_outline,
        **kwargs,
    )

omicverse.pl.spatial_segment_overlay(adata, seg_cell_id, color, seg=True, seg_key='segmentation', seg_contourpx=None, seg_outline=False, alpha=0.5, cmaps=None, library_id=None, library_key=None, spatial_key='spatial', img=True, img_res_key='hires', img_alpha=None, img_cmap=None, img_channel=None, use_raw=None, layer=None, alt_var=None, groups=None, palette=None, norm=None, na_color=(0, 0, 0, 0), size=None, size_key='spot_diameter_fullres', scale_factor=None, crop_coord=None, connectivity_key=None, edges_width=1.0, edges_color='grey', frameon=None, legend_loc='right margin', legend_fontsize=None, legend_fontweight='bold', legend_fontoutline=None, legend_na=True, colorbar=True, colorbar_position='bottom', colorbar_grid=None, colorbar_tick_size=10, colorbar_title_size=12, colorbar_width=None, colorbar_height=None, colorbar_spacing=None, scalebar_dx=None, scalebar_units=None, title=None, axis_label=None, fig=None, ax=None, return_ax=False, figsize=None, dpi=None, save=None, scalebar_kwargs=MappingProxyType({}), edges_kwargs=MappingProxyType({}), **kwargs)

Plot multiple genes overlaid on the same spatial segmentation plot.

This function allows visualization of multiple genes in the same spatial context, using different colors and transparency to show expression overlap and co-localization.

Parameters

%(adata)s seg_cell_id Key in :attr:anndata.AnnData.obs that contains unique cell IDs for segmentation. color Gene names or features to plot. Can be a single gene or list of genes. seg Segmentation mask. If True, uses segmentation from adata.uns['spatial']. seg_key Key for segmentation mask in adata.uns['spatial'][library_id]['images']. seg_contourpx Contour width in pixels. If specified, segmentation boundaries will be eroded. seg_outline If True, show segmentation boundaries. alpha Transparency level for gene expression overlay (0-1). cmaps Colormap(s) for each gene. If single string, uses same colormap for all genes. If list, should match length of color parameter. %(library_id)s %(library_key)s
%(spatial_key)s %(plotting_image)s %(plotting_features)s groups Categories to plot for categorical features. palette Color palette for categorical features. norm Colormap normalization. na_color Color for NA/missing values. colorbar Whether to show colorbars for each gene. colorbar_position Position of colorbars: 'bottom', 'right', or 'none'. colorbar_grid Grid layout for colorbars as (rows, cols). If None, auto-determined. colorbar_tick_size Font size for colorbar tick labels. colorbar_title_size Font size for colorbar titles. colorbar_width Width of individual colorbars. If None, uses default values. colorbar_height Height of individual colorbars. If None, uses default values. colorbar_spacing Dictionary with spacing parameters: 'hspace' (vertical gaps), 'wspace' (horizontal gaps). If None, uses default spacing. title Plot title. ax Matplotlib axes object to plot on. return_ax Whether to return the axes object. figsize Figure size (width, height). save Path to save the figure.

Returns

If return_ax is True, returns matplotlib axes object, otherwise None.

Examples

>>> import omicverse as ov

Overlay two genes with different colors

>>> ov.pl.spatial_segment_overlay(
...     adata,
...     seg_cell_id='cell_id',
...     color=['gene1', 'gene2'],
...     cmaps=['Reds', 'Blues'],
...     alpha=0.7,
... )

Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_spatial.py
def _overlay_resolve_cmaps(color: list[str], cmaps: Any) -> list[Any]:
    """Return exactly one colormap per feature in ``color``.

    Parameters
    ----------
    color
        Feature names to plot, already normalised to a list.
    cmaps
        ``None`` for the default transparent-to-opaque overlay colormaps, a single
        colormap name or :class:`~matplotlib.colors.Colormap` reused for every
        feature, or a sequence with one colormap per feature.

    Returns
    -------
    A list of ``len(color)`` colormaps (names or Colormap objects).

    Raises
    ------
    ValueError
        If a sequence provides fewer colormaps than there are features.
    """
    from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgb

    if cmaps is None:
        # Red, Green, Blue, Cyan, Magenta, Yellow; cycled when there are more features.
        base_colors = ("#FF0000", "#00FF00", "#0000FF", "#00FFFF", "#FF00FF", "#FFFF00")
        overlay_cmaps = []
        for i in range(len(color)):
            rgb = to_rgb(base_colors[i % len(base_colors)])
            # Alpha ramps 0 -> 1 so low values stay see-through and overlays can blend.
            overlay_cmaps.append(
                LinearSegmentedColormap.from_list(f"overlay_{i}", [rgb + (0.0,), rgb + (1.0,)], N=100)
            )
        return overlay_cmaps

    # A single colormap (by name or as an object) is shared by every feature.
    if isinstance(cmaps, (str, Colormap)):
        return [cmaps] * len(color)

    cmaps = list(cmaps)
    if len(cmaps) < len(color):
        # Previously zip() silently dropped the features without a colormap.
        raise ValueError(
            f"Expected one colormap per feature in `color` ({len(color)}), got {len(cmaps)}."
        )
    # Surplus colormaps are ignored.
    return cmaps[: len(color)]


def _overlay_make_layout(
    n_features: int,
    colorbar: bool,
    colorbar_position: str,
    colorbar_grid: tuple[int, int] | None,
    colorbar_width: float | None,
    colorbar_height: float | None,
    colorbar_spacing: dict[str, float] | None,
    figsize: tuple[float, float] | None,
    dpi: int | None,
) -> tuple[Figure, Axes, Any, tuple[int, int] | None]:
    """Create a new figure with the main axes and, if needed, a GridSpec with colorbar cells.

    Returns
    -------
    ``(fig, ax, gs, colorbar_grid)``. ``gs`` is ``None`` when no colorbar cells are
    reserved. ``colorbar_grid`` is the (rows, cols) layout actually used; it is
    auto-determined when the caller passed ``None``.
    """
    from matplotlib.gridspec import GridSpec

    figsize = figsize or (10, 8)

    if not colorbar or colorbar_position == "none":
        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        ax.grid(False)
        return fig, ax, None, colorbar_grid

    spacing = {} if colorbar_spacing is None else colorbar_spacing

    if colorbar_grid is None:
        if colorbar_position == "bottom":
            # At most three colorbars per row; the row count is rounded up.
            colorbar_grid = (1, n_features) if n_features <= 3 else ((n_features + 2) // 3, 3)
        else:  # "right": one colorbar per row in a single column
            colorbar_grid = (n_features, 1)
    n_rows, n_cols = colorbar_grid

    fig = plt.figure(figsize=figsize, dpi=dpi)
    if colorbar_position == "bottom":
        width = 0.2 if colorbar_width is None else colorbar_width
        height = 0.08 if colorbar_height is None else colorbar_height
        # Row 0 holds the main axes; thin padding columns on both sides centre the colorbars.
        gs = GridSpec(
            nrows=n_rows + 1,
            ncols=max(n_cols, 1) + 2,
            width_ratios=[0.1, *[width] * max(n_cols, 1), 0.1],
            height_ratios=[1, *[height] * n_rows],
            hspace=spacing.get("hspace", 0.4),
            wspace=spacing.get("wspace", 0.3),
            figure=fig,
        )
        ax = fig.add_subplot(gs[0, :])
    else:  # "right"
        width = 0.15 if colorbar_width is None else colorbar_width
        height = 0.2 if colorbar_height is None else colorbar_height
        # Column 0 holds the main axes; thin padding rows at top and bottom centre the colorbars.
        gs = GridSpec(
            nrows=max(n_rows, 1) + 2,
            ncols=n_cols + 1,
            width_ratios=[1, *[width] * n_cols],
            height_ratios=[0.1, *[height] * max(n_rows, 1), 0.1],
            hspace=spacing.get("hspace", 0.3),
            wspace=spacing.get("wspace", 0.1),
            figure=fig,
        )
        ax = fig.add_subplot(gs[:, 0])
    ax.grid(False)
    return fig, ax, gs, colorbar_grid


def _overlay_add_colorbars(
    fig: Figure,
    ax: Axes,
    gs: Any,
    colorbar_info: list[dict[str, Any]],
    colorbar_position: str,
    colorbar_grid: tuple[int, int] | None,
    colorbar_tick_size: int,
    colorbar_title_size: int,
) -> None:
    """Draw one colorbar per continuous feature.

    When ``gs`` is a GridSpec, colorbars fill its reserved cells row-major, starting
    at (1, 1) because row/column 0 belongs to the main axes or padding. When ``gs`` is
    ``None`` (the caller supplied ``ax``), each colorbar is attached to ``ax`` and
    takes its space from it.
    """
    import matplotlib as mpl

    # Features whose values were all missing have no range to display.
    drawable = [
        info for info in colorbar_info
        if info["norm"].vmin is not None and info["norm"].vmax is not None
    ]

    if gs is not None:
        n_rows, n_cols = colorbar_grid
        cells = [gs[row + 1, col + 1] for row in range(n_rows) for col in range(n_cols)]
        # Only create axes for cells that will actually hold a colorbar.
        placements = [{"cax": fig.add_subplot(cell)} for cell in cells[: len(drawable)]]
    else:
        placements = [{"ax": ax}] * len(drawable)

    orientation = "horizontal" if colorbar_position == "bottom" else "vertical"
    for info, placement in zip(drawable, placements):
        vmin, vmax = info["norm"].vmin, info["norm"].vmax
        ticks = [vmin, (vmin + vmax) / 2, vmax]
        ticks = [int(t) for t in ticks] if vmax > 10 else [round(t, 2) for t in ticks]

        cbar = fig.colorbar(
            mpl.cm.ScalarMappable(norm=info["norm"], cmap=info["cmap"]),
            orientation=orientation,
            extend="both",
            ticks=ticks,
            **placement,
        )
        cbar.ax.tick_params(labelsize=colorbar_tick_size)
        cbar.ax.set_title(info["gene"], size=colorbar_title_size, pad=10)


def spatial_segment_overlay(
    adata: AnnData,
    seg_cell_id: str,
    color: str | Sequence[str],
    seg: bool | Sequence[np.ndarray] | None = True,
    seg_key: str = "segmentation",
    seg_contourpx: int | None = None,
    seg_outline: bool = False,
    alpha: float = 0.5,
    cmaps: str | Sequence[str] | None = None,
    library_id: _SeqStr | None = None,
    library_key: str | None = None,
    spatial_key: str = "spatial",
    img: bool | Sequence[np.ndarray] | None = True,
    img_res_key: str | None = "hires",
    img_alpha: float | None = None,
    img_cmap: Colormap | str | None = None,
    img_channel: int | list[int] | None = None,
    use_raw: bool | None = None,
    layer: str | None = None,
    alt_var: str | None = None,
    groups: _SeqStr | None = None,
    palette: Palette_t = None,
    norm: _Normalize | None = None,
    na_color: str | tuple[float, ...] = (0, 0, 0, 0),
    size: _SeqFloat | None = None,
    size_key: str | None = "spot_diameter_fullres",
    scale_factor: _SeqFloat | None = None,
    crop_coord: _CoordTuple | Sequence[_CoordTuple] | None = None,
    connectivity_key: str | None = None,
    edges_width: float = 1.0,
    edges_color: str | Sequence[str] | Sequence[float] = "grey",
    frameon: bool | None = None,
    legend_loc: str | None = "right margin",
    legend_fontsize: int | float | _FontSize | None = None,
    legend_fontweight: int | _FontWeight = "bold",
    legend_fontoutline: int | None = None,
    legend_na: bool = True,
    colorbar: bool = True,
    colorbar_position: str = "bottom",
    colorbar_grid: tuple[int, int] | None = None,
    colorbar_tick_size: int = 10,
    colorbar_title_size: int = 12,
    colorbar_width: float | None = None,
    colorbar_height: float | None = None,
    colorbar_spacing: dict[str, float] | None = None,
    scalebar_dx: _SeqFloat | None = None,
    scalebar_units: _SeqStr | None = None,
    title: _SeqStr | None = None,
    axis_label: _SeqStr | None = None,
    fig: Figure | None = None,
    ax: Axes | None = None,
    return_ax: bool = False,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = None,
    save: str | Path | None = None,
    scalebar_kwargs: Mapping[str, Any] = MappingProxyType({}),
    edges_kwargs: Mapping[str, Any] = MappingProxyType({}),
    **kwargs: Any,
) -> Axes | None:
    """
    Plot multiple genes overlaid on the same spatial segmentation plot.

    Each feature in ``color`` is drawn on the same axes with its own colormap, so
    expression overlap and co-localization can be compared directly. By default each
    feature gets a colormap that ramps from fully transparent to an opaque base color.
    Only the first library is drawn.

    Parameters
    ----------
    adata
        Annotated data matrix.
    seg_cell_id
        Key in :attr:`anndata.AnnData.obs` that contains unique cell IDs for segmentation.
    color
        Gene names or features to plot. Can be a single gene or list of genes.
    seg
        Segmentation mask. If `True`, uses segmentation from `adata.uns['spatial']`.
    seg_key
        Key for segmentation mask in `adata.uns['spatial'][library_id]['images']`.
    seg_contourpx
        Contour width in pixels. If specified, segmentation boundaries will be eroded.
    seg_outline
        If `True`, show segmentation boundaries.
    alpha
        Transparency level for gene expression overlay (0-1).
    cmaps
        Colormap(s) for each gene. A single name or :class:`~matplotlib.colors.Colormap`
        is used for all genes; a sequence must provide at least one colormap per gene
        (extra entries are ignored). If `None`, transparent-to-opaque overlay colormaps
        cycling through red, green, blue, cyan, magenta and yellow are used.
    library_id, library_key, spatial_key
        Library selection and spatial data keys; same meaning as in
        :func:`squidpy.pl.spatial_segment`. Only the first library is drawn.
    img, img_res_key, img_alpha, img_cmap, img_channel
        Background image options; same meaning as in :func:`squidpy.pl.spatial_segment`.
        ``img_alpha`` and ``img_cmap`` are passed to :meth:`~matplotlib.axes.Axes.imshow`.
    use_raw, layer, alt_var
        Source of the feature values; same meaning as in :func:`squidpy.pl.spatial_segment`.
    groups
        Categories to plot for categorical features.
    palette
        Color palette for categorical features.
    norm
        Colormap normalization. A copy is made for every feature. Bounds left as `None`
        are filled from that feature's non-missing values, and the caller's object is
        never modified.
    na_color
        Color for NA/missing values.
    size, size_key, scale_factor, crop_coord
        Spatial sizing and cropping options; same meaning as in
        :func:`squidpy.pl.spatial_segment`.
    connectivity_key, edges_width, edges_color, edges_kwargs, frameon, legend_loc, legend_fontsize, legend_fontweight, legend_fontoutline, legend_na, axis_label
        Accepted for signature compatibility; not used by this function.
    colorbar
        Whether to show a colorbar for each continuous feature.
    colorbar_position
        Position of colorbars: 'bottom', 'right', or 'none'. When ``ax`` is supplied,
        the colorbars are attached to ``ax`` and take their space from it.
    colorbar_grid
        Grid layout for colorbars as (rows, cols). If None, auto-determined. Only used
        when this function creates the figure.
    colorbar_tick_size
        Font size for colorbar tick labels.
    colorbar_title_size
        Font size for colorbar titles.
    colorbar_width
        Width of individual colorbars. If None, uses default values.
    colorbar_height
        Height of individual colorbars. If None, uses default values.
    colorbar_spacing
        Dictionary with spacing parameters: 'hspace' (vertical gaps), 'wspace' (horizontal gaps).
        If None, uses default spacing.
    scalebar_dx, scalebar_units, scalebar_kwargs
        If ``scalebar_dx`` is given, a :class:`matplotlib_scalebar.scalebar.ScaleBar`
        is added using the first ``scalebar_dx``/``scalebar_units`` entry.
    title
        Plot title. Defaults to ``"Overlay: <features>"``.
    fig
        Figure to use only when ``ax`` is given and has no associated figure.
    ax
        Matplotlib axes object to plot on. If None, a new figure is created.
    return_ax
        Whether to return the axes object.
    figsize
        Figure size (width, height) of a newly created figure. Defaults to (10, 8).
    dpi
        Resolution of a newly created figure.
    save
        Path to save the figure.
    kwargs
        Forwarded to the segment-plotting helper for every feature.

    Returns
    -------
    If `return_ax` is `True`, returns matplotlib axes object, otherwise `None`.

    Raises
    ------
    ValueError
        If ``colorbar`` is enabled and ``colorbar_position`` is not 'bottom', 'right' or
        'none'; if ``color`` is empty; or if ``cmaps`` has fewer entries than ``color``.

    Examples
    --------
    >>> import omicverse as ov
    >>> # Overlay two genes with different colors
    >>> ov.pl.spatial_segment_overlay(
    ...     adata,
    ...     seg_cell_id='cell_id',
    ...     color=['gene1', 'gene2'],
    ...     cmaps=['Reds', 'Blues'],
    ...     alpha=0.7
    ... )
    """
    import warnings
    from copy import deepcopy

    from matplotlib.colors import Normalize
    from squidpy.gr._utils import _assert_spatial_basis
    from squidpy.pl._utils import sanitize_anndata, save_fig

    # Fail early: an unknown position used to leave `ax` unset and crash much later.
    if colorbar and colorbar_position not in ("bottom", "right", "none"):
        raise ValueError(
            "`colorbar_position` must be one of 'bottom', 'right' or 'none', "
            f"got {colorbar_position!r}."
        )

    sanitize_anndata(adata)
    _assert_spatial_basis(adata, spatial_key)

    color = [color] if isinstance(color, str) else list(color)
    if not color:
        raise ValueError("`color` must name at least one gene or feature.")
    cmaps = _overlay_resolve_cmaps(color, cmaps)

    if ax is None:
        fig, ax, gs, colorbar_grid = _overlay_make_layout(
            len(color),
            colorbar,
            colorbar_position,
            colorbar_grid,
            colorbar_width,
            colorbar_height,
            colorbar_spacing,
            figsize,
            dpi,
        )
    else:
        fig = ax.figure if ax.figure is not None else fig
        gs = None  # no reserved cells; colorbars are attached to `ax` instead

    # Get spatial parameters once
    spatial_params = _image_spatial_attrs(
        adata=adata,
        spatial_key=spatial_key,
        library_id=library_id,
        library_key=library_key,
        img=img,
        img_res_key=img_res_key,
        img_channel=img_channel,
        seg=seg,
        seg_key=seg_key,
        cell_id_key=seg_cell_id,
        scale_factor=scale_factor,
        size=size,
        size_key=size_key,
        img_cmap=img_cmap,
    )

    coords, crops = _set_coords_crops(
        adata=adata,
        spatial_params=spatial_params,
        spatial_key=spatial_key,
        crop_coord=crop_coord,
    )

    # Only the first library is drawn (could be extended to multiple libraries).
    lib_idx = 0
    _img = spatial_params.img[lib_idx]
    _seg = spatial_params.segment[lib_idx]
    _cell_id = spatial_params.cell_id[lib_idx]
    _lib = spatial_params.library_id[lib_idx]

    adata_sub, coords_sub, image_sub = _subs(
        adata,
        coords[lib_idx],
        _img,
        library_key=library_key,
        library_id=_lib,
        crop_coords=crops[lib_idx],
    )

    has_segmentation = _seg is not None and _cell_id is not None
    if not has_segmentation:
        warnings.warn(
            f"No segmentation mask or cell IDs are available for library {_lib!r}; "
            "no features will be overlaid.",
            stacklevel=2,
        )

    colorbar_info: list[dict[str, Any]] = []
    for gene, cmap in zip(color, cmaps):
        # Resolved even without a segmentation so unknown feature names still raise.
        color_source_vector, color_vector, categorical = _set_color_source_vec(
            adata_sub,
            gene,
            layer=layer,
            use_raw=use_raw,
            alt_var=alt_var,
            groups=groups,
            palette=palette,
            na_color=na_color,
            alpha=alpha,
        )
        if not has_segmentation:
            continue

        # Per-feature copy: autoscaling one feature must not leak into the next one
        # (or mutate the caller's norm).
        norm_to_use = Normalize() if norm is None else deepcopy(norm)
        if not categorical and (norm_to_use.vmin is None or norm_to_use.vmax is None):
            values = np.asarray(color_vector)
            valid_values = values[~pd.isna(values)]
            if valid_values.size > 0:
                # Fill only the missing bound(s); keep any bound the caller fixed.
                if norm_to_use.vmin is None:
                    norm_to_use.vmin = np.min(valid_values)
                if norm_to_use.vmax is None:
                    norm_to_use.vmax = np.max(valid_values)

        cmap_params = CmapParams(cmap, img_cmap, norm_to_use)
        color_params = ColorParams(None, [gene], groups, alpha, img_alpha or 1.0, use_raw or False)

        if colorbar and not categorical:
            colorbar_info.append({"gene": gene, "cmap": cmap, "norm": norm_to_use})

        ax, _cax = _plot_segment(
            seg=_seg,
            cell_id=_cell_id,
            color_vector=color_vector,
            color_source_vector=color_source_vector,
            seg_contourpx=seg_contourpx,
            seg_outline=seg_outline,
            na_color=na_color,
            ax=ax,
            cmap_params=cmap_params,
            color_params=color_params,
            categorical=categorical,
            **kwargs,
        )

    # Axis decoration (essential parts of squidpy's _decorate_axs).
    ax.set_yticks([])
    ax.set_xticks([])
    ax.autoscale_view()  # needed when plotting points but no image

    if image_sub is not None:
        ax.imshow(image_sub, cmap=img_cmap, alpha=img_alpha)
    else:
        # Without a background image, match image orientation (y grows downwards).
        ax.set_aspect("equal")
        ax.invert_yaxis()

    if title is None:
        title = f"Overlay: {', '.join(color)}"
    ax.set_title(title)

    if colorbar and colorbar_position != "none" and colorbar_info:
        _overlay_add_colorbars(
            fig,
            ax,
            gs,
            colorbar_info,
            colorbar_position,
            colorbar_grid,
            colorbar_tick_size,
            colorbar_title_size,
        )

    if scalebar_dx is not None:
        from matplotlib_scalebar.scalebar import ScaleBar

        scalebar_dx, scalebar_units = _get_scalebar(scalebar_dx, scalebar_units, 1)
        if scalebar_dx and scalebar_units:
            ax.add_artist(ScaleBar(scalebar_dx[0], units=scalebar_units[0], **scalebar_kwargs))

    if save is not None:
        save_fig(fig, path=save)

    if return_ax:
        return ax
    return None