Api pl
Violin and Distribution Plots¶
omicverse.pl.violin(adata, keys, groupby=None, *, log=False, use_raw=None, stripplot=True, jitter=True, size=1, layer=None, density_norm='width', order=None, multi_panel=None, xlabel='', ylabel=None, rotation=None, show=None, save=None, ax=None, enhanced_style=True, show_means=False, show_boxplot=False, jitter_method='uniform', jitter_alpha=0.4, violin_alpha=0.8, background_color='white', spine_color='#b4aea9', grid_lines=True, statistical_tests=False, custom_colors=None, figsize=None, fontsize=13, ticks_fontsize=None, **kwds)
¶
Enhanced violin plot compatible with scanpy's interface.
This function provides all the functionality of scanpy's violin plot with additional customization options for enhanced visualization, implemented using pure matplotlib.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
adata |
AnnData
|
AnnData. Annotated data matrix. |
required |
keys |
Union[str, Sequence[str]]
|
str | Sequence[str]. Keys for accessing variables of `.var_names` or fields of `.obs`. |
required |
groupby |
Optional[str]
|
str | None. The key of the observation grouping to consider. (None) |
None
|
log |
bool
|
bool. Plot on logarithmic axis. (False) |
False
|
use_raw |
Optional[bool]
|
bool | None. Whether to use `raw` attribute of `adata`. Defaults to `True` if `.raw` is present. (None) |
None
|
stripplot |
bool
|
bool. Add a stripplot on top of the violin plot. (True) |
True
|
jitter |
Union[float, bool]
|
float | bool. Add jitter to the stripplot (only when stripplot is True). (True) |
True
|
size |
int
|
int. Size of the jitter points. (1) |
1
|
layer |
Optional[str]
|
str | None. Name of the AnnData object layer that wants to be plotted. (None) |
None
|
density_norm |
DensityNorm
|
str. The method used to scale the width of each violin. If 'width' (the default), each violin will have the same width. If 'area', each violin will have the same area. If 'count', a violin's width corresponds to the number of observations. ('width') |
'width'
|
order |
Optional[Sequence[str]]
|
Sequence[str] | None. Order in which to show the categories. (None) |
None
|
multi_panel |
Optional[bool]
|
bool | None. Display keys in multiple panels also when `groupby is not None`. (None) |
None
|
xlabel |
str
|
str. Label of the x axis. ('') |
''
|
ylabel |
Optional[Union[str, Sequence[str]]]
|
str | Sequence[str] | None. Label of the y axis. (None) |
None
|
rotation |
Optional[float]
|
float | None. Rotation of xtick labels. (None) |
None
|
show |
Optional[bool]
|
bool | None. Whether to show the plot. (None) |
None
|
save |
Optional[Union[bool, str]]
|
bool | str | None. Path to save the figure. (None) |
None
|
ax |
Optional[Axes]
|
Axes | None. A matplotlib axes object. (None) |
None
|
enhanced_style |
bool
|
bool. Whether to apply enhanced styling. (True) |
True
|
show_means |
bool
|
bool. Whether to show mean values with annotations. (False) |
False
|
show_boxplot |
bool
|
bool. Whether to overlay box plots on violins. (False) |
False
|
jitter_method |
str
|
str. Method for jittering: 'uniform' or 't_dist'. ('uniform') |
'uniform'
|
jitter_alpha |
float
|
float. Transparency of jittered points. (0.4) |
0.4
|
violin_alpha |
float
|
float. Transparency of violin plots. (0.8) |
0.8
|
background_color |
str
|
str. Background color of the plot. ('white') |
'white'
|
spine_color |
str
|
str. Color of plot spines. ('#b4aea9') |
'#b4aea9'
|
grid_lines |
bool
|
bool. Whether to show horizontal grid lines. (True) |
True
|
statistical_tests |
bool
|
bool. Whether to perform and display statistical tests. (False) |
False
|
custom_colors |
Optional[Sequence[str]]
|
Sequence[str] | None. Custom colors for groups. (None) |
None
|
figsize |
Optional[tuple]
|
tuple | None. Figure size (width, height). (None) |
None
|
fontsize |
int. Font size for labels and ticks. (13) |
13
|
|
ticks_fontsize |
int | None. Font size for axis ticks. If None, uses fontsize-1. (None) |
None
|
|
**kwds |
Additional keyword arguments passed to violinplot. |
{}
|
Returns:
Name | Type | Description |
---|---|---|
ax |
Union[Axes, None]
|
matplotlib.axes.Axes | None. The matplotlib axes object, or `None` when the figure is shown (`show=True`). |
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_violin.py
def violin(
    adata: AnnData,
    keys: Union[str, Sequence[str]],
    groupby: Optional[str] = None,
    *,
    log: bool = False,
    use_raw: Optional[bool] = None,
    stripplot: bool = True,
    jitter: Union[float, bool] = True,
    size: int = 1,
    layer: Optional[str] = None,
    density_norm: DensityNorm = "width",
    order: Optional[Sequence[str]] = None,
    multi_panel: Optional[bool] = None,
    xlabel: str = "",
    ylabel: Optional[Union[str, Sequence[str]]] = None,
    rotation: Optional[float] = None,
    show: Optional[bool] = None,
    save: Optional[Union[bool, str]] = None,
    ax: Optional[Axes] = None,
    # Enhanced features
    enhanced_style: bool = True,
    show_means: bool = False,
    show_boxplot: bool = False,
    jitter_method: str = 'uniform',  # 'uniform', 't_dist'
    jitter_alpha: float = 0.4,
    violin_alpha: float = 0.8,
    background_color: str = 'white',
    spine_color: str = '#b4aea9',
    grid_lines: bool = True,
    statistical_tests: bool = False,
    custom_colors: Optional[Sequence[str]] = None,
    figsize: Optional[tuple] = None,
    fontsize: int = 13,
    ticks_fontsize: Optional[int] = None,
    **kwds
) -> Union[Axes, None]:
    r"""
    Enhanced violin plot compatible with scanpy's interface.

    This function provides all the functionality of scanpy's violin plot
    with additional customization options for enhanced visualization,
    implemented using pure matplotlib. Before plotting it prints a short
    analysis of the data and parameter suggestions to stdout.

    Arguments:
        adata: AnnData. Annotated data matrix.
        keys: str | Sequence[str]. Keys for accessing variables of `.var_names` or fields of `.obs`. Duplicates are dropped, keeping first occurrence order.
        groupby: str | None. The key of the observation grouping to consider. (None)
        log: bool. Plot on logarithmic axis. (False)
        use_raw: bool | None. Whether to use `raw` attribute of `adata`. Defaults to `True` if `.raw` is present. (None)
        stripplot: bool. Add a stripplot on top of the violin plot. (True)
        jitter: float | bool. Add jitter to the stripplot (only when stripplot is True). (True)
        size: int. Size of the jitter points. (1)
        layer: str | None. Name of the AnnData object layer that wants to be plotted. (None)
        density_norm: str. The method used to scale the width of each violin. If 'width' (the default), each violin will have the same width. If 'area', each violin will have the same area. If 'count', a violin's width corresponds to the number of observations. ('width')
        order: Sequence[str] | None. Order in which to show the categories. (None)
        multi_panel: bool | None. Display keys in multiple panels also when `groupby is not None`. (None)
        xlabel: str. Label of the x axis. When empty, `groupby` is used as the label. ('')
        ylabel: str | Sequence[str] | None. Label of the y axis. A single string is reused for every panel. If None, the key name is used (or 'value' when several keys are compared without `groupby`). (None)
        rotation: float | None. Rotation of xtick labels. (None)
        show: bool | None. Whether to show the plot. (None)
        save: bool | str | None. Path to save the figure. `True` saves to 'violin_plot.pdf'. (None)
        ax: Axes | None. A matplotlib axes object. (None)
        enhanced_style: bool. Whether to apply enhanced styling. (True)
        show_means: bool. Whether to show mean values with annotations. (False)
        show_boxplot: bool. Whether to overlay box plots on violins. (False)
        jitter_method: str. Method for jittering: 'uniform' or 't_dist'. ('uniform')
        jitter_alpha: float. Transparency of jittered points. (0.4)
        violin_alpha: float. Transparency of violin plots. (0.8)
        background_color: str. Background color of the plot. ('white')
        spine_color: str. Color of plot spines. ('#b4aea9')
        grid_lines: bool. Whether to show horizontal grid lines. (True)
        statistical_tests: bool. Whether to perform and display statistical tests. (False)
        custom_colors: Sequence[str] | None. Custom colors for groups. (None)
        figsize: tuple | None. Figure size (width, height). (None)
        fontsize: int. Font size for labels and ticks. (13)
        ticks_fontsize: int | None. Font size for axis ticks. If None, uses fontsize-1. (None)
        **kwds: Additional keyword arguments passed to violinplot.

    Returns:
        ax: matplotlib.axes.Axes | None. `None` when the figure is shown (`show=True`); otherwise the axes object (`ax` if given, else the one created here). When `multi_panel` is set without `groupby`, the result of the multi-panel helper is returned instead.

    Raises:
        ImportError: If AnnData is not available.
        ValueError: If the number of supplied y-labels does not match the number of panels.
    """
    # Handle AnnData availability
    if not ANNDATA_AVAILABLE:
        raise ImportError("AnnData is required for this function. Install with: pip install anndata")

    # Ensure keys is a list without duplicates, preserving order
    if isinstance(keys, str):
        keys = [keys]
    keys = list(OrderedDict.fromkeys(keys))

    # Resolve y-axis labels: one label when comparing variables (no groupby),
    # one label per key otherwise. A user-supplied label is always honoured.
    n_expected_labels = 1 if groupby is None else len(keys)
    if ylabel is None:
        if groupby is None and len(keys) > 1:
            # Several keys share one axis; use the melted column name.
            ylabel = ["value"]
        else:
            ylabel = list(keys)
    elif isinstance(ylabel, str):
        ylabel = [ylabel] * n_expected_labels

    # Validate ylabel length
    if len(ylabel) != n_expected_labels:
        raise ValueError(
            f"Expected number of y-labels to be `{n_expected_labels}`, found `{len(ylabel)}`."
        )

    # Extract data from AnnData object
    obs_df = _extract_data_from_adata(adata, keys, groupby, layer, use_raw)

    # Group statistics are needed by several report sections; compute them once.
    # `value_counts` and `unique` are kept separate on purpose: they can differ
    # (e.g. unused categories, missing values).
    if groupby is not None:
        group_counts = obs_df[groupby].value_counts()
        group_labels = obs_df[groupby].unique()
        n_groups = len(group_labels)

    # Colorful data analysis and parameter optimization suggestions
    print(f"{Colors.HEADER}{Colors.BOLD}🎻 Violin Plot Analysis:{Colors.ENDC}")
    print(f" {Colors.CYAN}Total cells: {Colors.BOLD}{len(obs_df)}{Colors.ENDC}")
    print(f" {Colors.BLUE}Variables to plot: {Colors.BOLD}{len(keys)} {keys}{Colors.ENDC}")
    if groupby is not None:
        print(f" {Colors.GREEN}Groupby variable: '{Colors.BOLD}{groupby}{Colors.ENDC}{Colors.GREEN}' with {Colors.BOLD}{len(group_counts)}{Colors.ENDC}{Colors.GREEN} groups{Colors.ENDC}")
        # Show group distribution
        for group, count in group_counts.head(10).items():  # Show top 10 groups
            if count < 10:
                color = Colors.WARNING
            elif count < 50:
                color = Colors.BLUE
            else:
                color = Colors.GREEN
            print(f" {color}• {group}: {Colors.BOLD}{count}{Colors.ENDC}{color} cells{Colors.ENDC}")
        if len(group_counts) > 10:
            print(f" {Colors.CYAN}... and {Colors.BOLD}{len(group_counts) - 10}{Colors.ENDC}{Colors.CYAN} more groups{Colors.ENDC}")
        # Check for imbalanced groups. Written as a product so that an empty
        # group (count 0) does not trigger a division by zero.
        min_count = group_counts.min()
        max_count = group_counts.max()
        if max_count > 10 * min_count:
            print(f" {Colors.WARNING}⚠️ Imbalanced groups detected: {Colors.BOLD}{min_count}-{max_count}{Colors.ENDC}{Colors.WARNING} cells per group{Colors.ENDC}")
    else:
        print(f" {Colors.BLUE}Groupby: {Colors.BOLD}None{Colors.ENDC}{Colors.BLUE} (comparing variables){Colors.ENDC}")

    # Analyze data distribution for each variable
    print(f"\n{Colors.HEADER}{Colors.BOLD}📊 Data Distribution Analysis:{Colors.ENDC}")
    for key in keys:
        if key not in obs_df.columns:
            continue
        data_vals = obs_df[key].dropna()
        if len(data_vals) == 0:
            continue
        zero_fraction = (data_vals == 0).sum() / len(data_vals)
        # Determine if data might be log-transformed already
        log_suggestion = ""
        if data_vals.min() >= 0 and data_vals.max() > 100:
            log_suggestion = f" {Colors.BLUE}(consider log=True){Colors.ENDC}"
        elif data_vals.min() < 0:
            log_suggestion = f" {Colors.WARNING}(negative values detected){Colors.ENDC}"
        print(f" {Colors.BLUE}'{key}': range {Colors.BOLD}{data_vals.min():.2f}-{data_vals.max():.2f}{Colors.ENDC}{Colors.BLUE}, {Colors.BOLD}{zero_fraction*100:.1f}%{Colors.ENDC}{Colors.BLUE} zeros{log_suggestion}")

    # Display current function parameters
    print(f"\n{Colors.HEADER}{Colors.BOLD}⚙️ Current Function Parameters:{Colors.ENDC}")
    print(f" {Colors.BLUE}Plot style: enhanced_style={Colors.BOLD}{enhanced_style}{Colors.ENDC}{Colors.BLUE}, stripplot={Colors.BOLD}{stripplot}{Colors.ENDC}{Colors.BLUE}, jitter={Colors.BOLD}{jitter}{Colors.ENDC}")
    print(f" {Colors.BLUE}Additional features: show_means={Colors.BOLD}{show_means}{Colors.ENDC}{Colors.BLUE}, show_boxplot={Colors.BOLD}{show_boxplot}{Colors.ENDC}{Colors.BLUE}, statistical_tests={Colors.BOLD}{statistical_tests}{Colors.ENDC}")
    print(f" {Colors.BLUE}Figure settings: figsize={Colors.BOLD}{figsize}{Colors.ENDC}{Colors.BLUE}, fontsize={Colors.BOLD}{fontsize}{Colors.ENDC}{Colors.BLUE}, violin_alpha={Colors.BOLD}{violin_alpha}{Colors.ENDC}")
    if custom_colors is not None:
        print(f" {Colors.BLUE}Colors: {Colors.BOLD}{len(custom_colors)} custom colors specified{Colors.ENDC}")
    else:
        print(f" {Colors.BLUE}Colors: {Colors.BOLD}Default palette{Colors.ENDC}")

    # Find the first key whose strictly positive values span more than 100x.
    # The minimum must be > 0: a max/min ratio is undefined for a zero minimum
    # and a log axis cannot display zeros anyway.
    wide_range_key = None
    wide_range_ratio = None
    for key in keys:
        if key not in obs_df.columns:
            continue
        data_vals = obs_df[key].dropna()
        if len(data_vals) == 0:
            continue
        lowest = data_vals.min()
        if lowest > 0 and data_vals.max() / lowest > 100:
            wide_range_key = key
            wide_range_ratio = data_vals.max() / lowest
            break

    # Many keys on a figure that is (or defaults to) too narrow for them.
    cramped_figure = len(keys) > 3 and (figsize is None or figsize[0] < len(keys) * 2)

    # Parameter optimization suggestions
    print(f"\n{Colors.HEADER}{Colors.BOLD}💡 Parameter Optimization Suggestions:{Colors.ENDC}")
    suggestions = []

    # Check for too many groups
    if groupby is not None and n_groups > 8:
        suggestions.append(f" {Colors.WARNING}▶ Many groups detected ({n_groups}):{Colors.ENDC}")
        suggestions.append(f" {Colors.CYAN}Current: figsize={Colors.BOLD}{figsize}{Colors.ENDC}")
        suggested_width = max(8, n_groups * 1.2)
        suggestions.append(f" {Colors.GREEN}Suggested: {Colors.BOLD}figsize=({suggested_width}, {figsize[1] if figsize else 6}){Colors.ENDC}")
        if rotation is None:
            suggestions.append(f" {Colors.GREEN}Consider adding: {Colors.BOLD}rotation=45{Colors.ENDC} for better label readability")

    # Check for small sample sizes
    if groupby is not None and group_counts.min() < 10:
        suggestions.append(f" {Colors.WARNING}▶ Small sample sizes detected (min: {group_counts.min()}):{Colors.ENDC}")
        suggestions.append(f" {Colors.GREEN}Suggested: {Colors.BOLD}stripplot=True, jitter=0.3{Colors.ENDC} to show individual points")

    # Check data distribution and log scale
    if wide_range_key is not None:
        suggestions.append(f" {Colors.WARNING}▶ Wide data range for '{wide_range_key}' ({wide_range_ratio:.1f}x):{Colors.ENDC}")
        suggestions.append(f" {Colors.CYAN}Current: log={Colors.BOLD}{log}{Colors.ENDC}")
        suggestions.append(f" {Colors.GREEN}Suggested: {Colors.BOLD}log=True{Colors.ENDC} for better visualization")

    # Check figure size vs number of variables
    if cramped_figure:
        suggestions.append(f" {Colors.WARNING}▶ Multiple variables with small figure:{Colors.ENDC}")
        suggested_width = max(len(keys) * 2, 8)
        suggestions.append(f" {Colors.CYAN}Current: figsize={Colors.BOLD}{figsize}{Colors.ENDC}")
        suggestions.append(f" {Colors.GREEN}Suggested: {Colors.BOLD}figsize=({suggested_width}, 6){Colors.ENDC} or {Colors.BOLD}multi_panel=True{Colors.ENDC}")

    # Font size optimization
    if groupby is not None:
        # default=0 keeps this safe when there are no observations at all
        max_label_length = max((len(str(label)) for label in group_labels), default=0)
        if max_label_length > 8 and fontsize > 12:
            suggestions.append(f" {Colors.WARNING}▶ Long group labels detected:{Colors.ENDC}")
            suggestions.append(f" {Colors.CYAN}Current: fontsize={Colors.BOLD}{fontsize}{Colors.ENDC}")
            suggestions.append(f" {Colors.GREEN}Suggested: {Colors.BOLD}fontsize=10, rotation=45{Colors.ENDC}")

    # Enhanced features suggestions
    if not show_means and not show_boxplot and not statistical_tests:
        suggestions.append(f" {Colors.BLUE}▶ Consider enhancing your plot:{Colors.ENDC}")
        suggestions.append(f" {Colors.GREEN}Options: {Colors.BOLD}show_means=True{Colors.ENDC}{Colors.GREEN}, {Colors.BOLD}show_boxplot=True{Colors.ENDC}{Colors.GREEN}, or {Colors.BOLD}statistical_tests=True{Colors.ENDC}")

    if suggestions:
        for suggestion in suggestions:
            print(suggestion)
        print(f"\n {Colors.BOLD}📋 Copy-paste ready function call:{Colors.ENDC}")

        # Generate optimized function call
        optimized_params = []
        # Core parameters
        if len(keys) == 1:
            optimized_params.append(f"adata, keys='{keys[0]}'")
        else:
            optimized_params.append(f"adata, keys={keys}")
        if groupby is not None:
            optimized_params.append(f"groupby='{groupby}'")
            # Add optimized parameters based on suggestions
            if n_groups > 8:
                suggested_width = max(8, n_groups * 1.2)
                optimized_params.append(f"figsize=({suggested_width}, 6)")
            if rotation is None and n_groups > 6:
                optimized_params.append("rotation=45")
            if group_counts.min() < 10:
                optimized_params.append("stripplot=True")
                optimized_params.append("jitter=0.3")
        # Check for log scale suggestion
        if wide_range_key is not None:
            optimized_params.append("log=True")
        # Multi-panel suggestion
        if cramped_figure:
            optimized_params.append("multi_panel=True")
        # Enhanced features
        if not show_means and not show_boxplot:
            optimized_params.append("show_means=True")
        optimized_call = f" {Colors.GREEN}ov.pl.violin({', '.join(optimized_params)}){Colors.ENDC}"
        print(optimized_call)
    else:
        print(f" {Colors.GREEN}✅ Current parameters are optimal for your data!{Colors.ENDC}")
    print(f"{Colors.CYAN}{'─' * 60}{Colors.ENDC}")

    # Prepare data for plotting
    if groupby is None:
        # Long format: one violin per key, all on a shared "value" axis
        obs_tidy = pd.melt(obs_df, value_vars=keys)
        x_col = "variable"
        y_cols = ["value"]
        group_categories = keys
    else:
        obs_tidy = obs_df
        x_col = groupby
        y_cols = keys
        obs_df[groupby] = obs_df[groupby].astype('category')
        group_categories = obs_df[groupby].cat.categories if order is None else order

    # Set up colors
    colors = _setup_colors(custom_colors, group_categories, adata, groupby)

    # Handle multi-panel case
    if multi_panel and groupby is None and len(y_cols) == 1:
        return _create_multi_panel_plot(
            obs_tidy, keys, y_cols[0], colors, log, stripplot, jitter,
            size, density_norm, enhanced_style, **kwds
        )

    # Create single or multiple axis plots
    # NOTE(review): when several keys are plotted with `groupby`, a separate
    # figure is created inside the loop below, so the axes created/styled here
    # stay empty yet are still the returned value. Left unchanged because
    # callers may rely on the current return value.
    if ax is None and figsize is not None:
        fig, ax = plt.subplots(figsize=figsize)
    elif ax is None:
        fig, ax = plt.subplots(figsize=(10, 6))

    # Apply enhanced styling
    if enhanced_style:
        _apply_enhanced_styling(ax, background_color, spine_color, grid_lines)

    # Tick font size is loop-invariant; resolve its default once.
    if ticks_fontsize is None:
        ticks_fontsize = fontsize - 1

    # Create plots for each y variable
    several_panels = len(y_cols) > 1
    for i, (y_col, y_label) in enumerate(zip(y_cols, ylabel)):
        if several_panels:
            # Create subplots for multiple keys
            if i == 0:
                fig, axes = plt.subplots(1, len(y_cols), figsize=figsize or (5*len(y_cols), 6))
            current_ax = axes[i]
        else:
            current_ax = ax

        # Prepare data for current y variable
        plot_data = _prepare_plot_data(obs_tidy, x_col, y_col, group_categories, order)

        # Create violin plots
        _create_violin_plots(
            current_ax, plot_data, group_categories, colors, density_norm,
            violin_alpha, enhanced_style, **kwds
        )

        # Add box plots if requested
        if show_boxplot:
            _add_box_plots(current_ax, plot_data, group_categories)

        # Add strip plot (jittered points)
        if stripplot:
            _add_strip_plot(
                current_ax, plot_data, group_categories, jitter, jitter_method,
                size, jitter_alpha, colors
            )

        # Add mean annotations
        if show_means:
            _add_mean_annotations(current_ax, plot_data, group_categories)

        # Add statistical tests
        if statistical_tests:
            _add_statistical_tests(current_ax, plot_data, group_categories)

        # Customize axis
        _customize_axis(
            current_ax, group_categories, xlabel, y_label, groupby,
            rotation, log, order
        )

        # Open-frame look: hide top/right spines, offset the remaining two
        current_ax.spines['top'].set_visible(False)
        current_ax.spines['right'].set_visible(False)
        current_ax.spines['bottom'].set_visible(True)
        current_ax.spines['left'].set_visible(True)
        current_ax.spines['left'].set_position(('outward', 10))
        current_ax.spines['bottom'].set_position(('outward', 10))

        current_ax.set_xticklabels(current_ax.get_xticklabels(), fontsize=ticks_fontsize, rotation=rotation)
        current_ax.set_yticklabels(current_ax.get_yticklabels(), fontsize=ticks_fontsize)
        # An explicit xlabel wins; the empty default falls back to `groupby`
        current_ax.set_xlabel(xlabel or groupby, fontsize=fontsize)
        current_ax.set_ylabel(y_label, fontsize=fontsize)
        current_ax.grid(False)

    # Save figure if requested
    if save:
        save_path = save if isinstance(save, str) else "violin_plot.pdf"
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    # Show figure if requested
    # NOTE(review): `ax` has been assigned above whenever it was None, so the
    # second clause cannot be true here; only `show=True` shows the figure.
    if show is True or (show is None and ax is None):
        plt.show()
        return None
    return ax
Dot Plots¶
omicverse.pl.dotplot(adata, var_names, groupby, *, use_raw=None, log=False, num_categories=7, categories_order=None, expression_cutoff=0.0, mean_only_expressed=False, standard_scale=None, title=None, colorbar_title='Mean expression\nin group', size_title='Fraction of cells\nin group (%)', figsize=None, dendrogram=False, gene_symbols=None, var_group_positions=None, var_group_labels=None, var_group_rotation=None, layer=None, swap_axes=False, dot_color_df=None, show=None, save=None, ax=None, return_fig=False, vmin=None, vmax=None, vcenter=None, norm=None, cmap='Reds', dot_max=None, dot_min=None, smallest_dot=0.0, fontsize=12, preserve_dict_order=False, **kwds)
¶
Make a dot plot of the expression values of `var_names`.
For each var_name and each `groupby` category a dot is plotted.
Each dot represents two values: mean expression within each category
(visualized by color) and fraction of cells expressing the `var_name` in the
category (visualized by the size of the dot).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
adata |
AnnData
|
AnnData Annotated data matrix. |
required |
var_names |
Union[_VarNames, Mapping[str, _VarNames]]
|
str or list of str or dict Variables to plot. |
required |
groupby |
Union[str, Sequence[str]]
|
str or list of str The key of the observation grouping to consider. |
required |
use_raw |
Optional[bool]
|
bool, optional (default=None) Use `raw` attribute of `adata` if present. |
None
|
log |
bool
|
bool, optional (default=False) Whether to log-transform the data. |
False
|
num_categories |
int
|
int, optional (default=7) Number of categories to show. |
7
|
categories_order |
Optional[Sequence[str]]
|
list of str, optional (default=None) Order of categories to display. |
None
|
expression_cutoff |
float
|
float, optional (default=0.0) Expression cutoff for calculating fraction of expressing cells. |
0.0
|
mean_only_expressed |
bool
|
bool, optional (default=False) Whether to calculate mean only for expressing cells. |
False
|
standard_scale |
Optional[Literal['var', 'group']]
|
{'var', 'group'} or None, optional (default=None) Whether to standardize data. |
None
|
title |
Optional[str]
|
str, optional (default=None) Title for the plot. |
None
|
colorbar_title |
Optional[str]
|
str, optional (default='Mean expression\nin group') Title for the color bar. |
'Mean expression\nin group'
|
size_title |
Optional[str]
|
str, optional (default='Fraction of cells\nin group (%)') Title for the size legend. |
'Fraction of cells\nin group (%)'
|
figsize |
Optional[Tuple[float, float]]
|
tuple, optional (default=None) Figure size (width, height) in inches. If provided, the plot dimensions will be scaled accordingly. |
None
|
dendrogram |
Union[bool, str]
|
bool or str, optional (default=False) Whether to add dendrogram to the plot. |
False
|
gene_symbols |
Optional[str]
|
str, optional (default=None) Key for gene symbols in `adata.var`. |
None
|
var_group_positions |
Optional[Sequence[Tuple[int, int]]]
|
list of tuples, optional (default=None) Positions for variable groups. |
None
|
var_group_labels |
Optional[Sequence[str]]
|
list of str, optional (default=None) Labels for variable groups. |
None
|
var_group_rotation |
Optional[float]
|
float, optional (default=None) Rotation angle for variable group labels. |
None
|
layer |
Optional[str]
|
str, optional (default=None) Layer to use for expression data. |
None
|
swap_axes |
Optional[bool]
|
bool, optional (default=False) Whether to swap x and y axes. |
False
|
dot_color_df |
Optional[pd.DataFrame]
|
pandas.DataFrame, optional (default=None) DataFrame for dot colors. |
None
|
show |
Optional[bool]
|
bool, optional (default=None) Whether to show the plot. |
None
|
save |
Optional[Union[str, bool]]
|
str or bool, optional (default=None) Whether to save the plot. |
None
|
ax |
Optional[_AxesSubplot]
|
matplotlib.axes.Axes, optional (default=None) Axes object to plot on. |
None
|
return_fig |
Optional[bool]
|
bool, optional (default=False) Whether to return the figure object. |
False
|
vmin |
Optional[float]
|
float, optional (default=None) Minimum value for color scaling. |
None
|
vmax |
Optional[float]
|
float, optional (default=None) Maximum value for color scaling. |
None
|
vcenter |
Optional[float]
|
float, optional (default=None) Center value for diverging colormap. |
None
|
norm |
Optional[Normalize]
|
matplotlib.colors.Normalize, optional (default=None) Normalization object for colors. |
None
|
cmap |
Union[Colormap, str, None]
|
str or matplotlib.colors.Colormap, optional (default='Reds') Colormap for the plot. |
'Reds'
|
dot_max |
Optional[float]
|
float, optional (default=None) Maximum dot size. |
None
|
dot_min |
Optional[float]
|
float, optional (default=None) Minimum dot size. |
None
|
smallest_dot |
float
|
float, optional (default=0.0) Size of the smallest dot. |
0.0
|
fontsize |
int
|
int, optional (default=12) Font size for labels and legends. Titles will be one point larger. |
12
|
preserve_dict_order |
bool
|
bool, optional (default=False) When var_names is a dictionary, whether to preserve the original dictionary order. If True, genes will be ordered according to the dictionary's insertion order. If False (default), genes will be ordered according to cell type categories. |
False
|
Returns:
Type | Description |
---|---|
Optional[Union[Dict, DotPlot]]
|
If `return_fig` is True, returns the figure object. |
Optional[Union[Dict, DotPlot]]
|
If `show` is False, returns axes dictionary. |
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_dotplot.py
def dotplot(
adata: AnnData,
var_names: Union[_VarNames, Mapping[str, _VarNames]],
groupby: Union[str, Sequence[str]],
*,
use_raw: Optional[bool] = None,
log: bool = False,
num_categories: int = 7,
categories_order: Optional[Sequence[str]] = None,
expression_cutoff: float = 0.0,
mean_only_expressed: bool = False,
standard_scale: Optional[Literal['var', 'group']] = None,
title: Optional[str] = None,
colorbar_title: Optional[str] = 'Mean expression\nin group',
size_title: Optional[str] = 'Fraction of cells\nin group (%)',
figsize: Optional[Tuple[float, float]] = None,
dendrogram: Union[bool, str] = False,
gene_symbols: Optional[str] = None,
var_group_positions: Optional[Sequence[Tuple[int, int]]] = None,
var_group_labels: Optional[Sequence[str]] = None,
var_group_rotation: Optional[float] = None,
layer: Optional[str] = None,
swap_axes: Optional[bool] = False,
dot_color_df: Optional[pd.DataFrame] = None,
show: Optional[bool] = None,
save: Optional[Union[str, bool]] = None,
ax: Optional[_AxesSubplot] = None,
return_fig: Optional[bool] = False,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
vcenter: Optional[float] = None,
norm: Optional[Normalize] = None,
cmap: Union[Colormap, str, None] = 'Reds',
dot_max: Optional[float] = None,
dot_min: Optional[float] = None,
smallest_dot: float = 0.0,
fontsize: int = 12,
preserve_dict_order: bool = False,
**kwds,
) -> Optional[Union[Dict, 'DotPlot']]:
r"""
Make a dot plot of the expression values of `var_names`.
For each var_name and each `groupby` category a dot is plotted.
Each dot represents two values: mean expression within each category
(visualized by color) and fraction of cells expressing the `var_name` in the
category (visualized by the size of the dot).
Arguments:
adata: AnnData
Annotated data matrix.
var_names: str or list of str or dict
Variables to plot.
groupby: str or list of str
The key of the observation grouping to consider.
use_raw: bool, optional (default=None)
Use `raw` attribute of `adata` if present.
log: bool, optional (default=False)
Whether to log-transform the data.
num_categories: int, optional (default=7)
Number of categories to show.
categories_order: list of str, optional (default=None)
Order of categories to display.
expression_cutoff: float, optional (default=0.0)
Expression cutoff for calculating fraction of expressing cells.
mean_only_expressed: bool, optional (default=False)
Whether to calculate mean only for expressing cells.
standard_scale: {'var', 'group'} or None, optional (default=None)
Whether to standardize data.
title: str, optional (default=None)
Title for the plot.
colorbar_title: str, optional (default='Mean expression\nin group')
Title for the color bar.
size_title: str, optional (default='Fraction of cells\nin group (%)')
Title for the size legend.
figsize: tuple, optional (default=None)
Figure size (width, height) in inches. If provided, the plot dimensions will be scaled accordingly.
dendrogram: bool or str, optional (default=False)
Whether to add dendrogram to the plot.
gene_symbols: str, optional (default=None)
Key for gene symbols in `adata.var`.
var_group_positions: list of tuples, optional (default=None)
Positions for variable groups.
var_group_labels: list of str, optional (default=None)
Labels for variable groups.
var_group_rotation: float, optional (default=None)
Rotation angle for variable group labels.
layer: str, optional (default=None)
Layer to use for expression data.
swap_axes: bool, optional (default=False)
Whether to swap x and y axes.
dot_color_df: pandas.DataFrame, optional (default=None)
DataFrame for dot colors.
show: bool, optional (default=None)
Whether to show the plot.
save: str or bool, optional (default=None)
Whether to save the plot.
ax: matplotlib.axes.Axes, optional (default=None)
Axes object to plot on.
return_fig: bool, optional (default=False)
Whether to return the figure object.
vmin: float, optional (default=None)
Minimum value for color scaling.
vmax: float, optional (default=None)
Maximum value for color scaling.
vcenter: float, optional (default=None)
Center value for diverging colormap.
norm: matplotlib.colors.Normalize, optional (default=None)
Normalization object for colors.
cmap: str or matplotlib.colors.Colormap, optional (default='Reds')
Colormap for the plot.
dot_max: float, optional (default=None)
Maximum dot size.
dot_min: float, optional (default=None)
Minimum dot size.
smallest_dot: float, optional (default=0.0)
Size of the smallest dot.
fontsize: int, optional (default=12)
Font size for labels and legends. Titles will be one point larger.
preserve_dict_order: bool, optional (default=False)
When var_names is a dictionary, whether to preserve the original dictionary order.
If True, genes will be ordered according to the dictionary's insertion order.
If False (default), genes will be ordered according to cell type categories.
Returns:
If `return_fig` is True, returns the figure object.
If `show` is False, returns axes dictionary.
"""
# Convert var_names to list if string
original_var_names_dict = None
if isinstance(var_names, str):
var_names = [var_names]
elif isinstance(var_names, Mapping):
# Save original dictionary reference for color bar ordering
if preserve_dict_order:
original_var_names_dict = var_names
# Get gene groups
gene_groups = []
var_names_list = []
if preserve_dict_order:
# Preserve the original dictionary order
for group, genes in var_names.items():
if isinstance(genes, str):
genes = [genes]
var_names_list.extend(genes)
gene_groups.extend([group] * len(genes))
else:
# Get cell type order (original behavior)
if categories_order is not None:
group_order = categories_order
elif pd.api.types.is_categorical_dtype(adata.obs[groupby]):
group_order = list(adata.obs[groupby].cat.categories)
else:
group_order = list(adata.obs[groupby].unique())
# Order gene groups according to cell types
for group in group_order:
if group in var_names:
genes = var_names[group]
if isinstance(genes, str):
genes = [genes]
var_names_list.extend(genes)
gene_groups.extend([group] * len(genes))
# Add any remaining groups that weren't in the cell types
for group, genes in var_names.items():
if group not in group_order:
if isinstance(genes, str):
genes = [genes]
var_names_list.extend(genes)
gene_groups.extend([group] * len(genes))
var_names = var_names_list
# Get expression matrix
if use_raw and adata.raw is not None:
matrix = adata.raw.X
var_names_idx = [adata.raw.var_names.get_loc(name) for name in var_names]
else:
matrix = adata.X if layer is None else adata.layers[layer]
var_names_idx = [adata.var_names.get_loc(name) for name in var_names]
# Determine category order
if categories_order is not None:
cats = categories_order
else:
# Use the categorical order from adata if available
if pd.api.types.is_categorical_dtype(adata.obs[groupby]):
cats = adata.obs[groupby].cat.categories
else:
# If not categorical, get unique values
cats = adata.obs[groupby].unique()
# Get aggregated data with specified order
agg = adata.obs[groupby].value_counts().reindex(cats)
cell_counts = agg.to_numpy()
# Get colors for cell types if available
cell_colors = None
color_dict = None
try:
color_key = f"{groupby}_colors"
if color_key in adata.uns:
colors = adata.uns[color_key]
# Create color dictionary mapping cell types to colors
if pd.api.types.is_categorical_dtype(adata.obs[groupby]):
# Use categorical order for colors
color_dict = dict(zip(adata.obs[groupby].cat.categories, colors))
else:
# Use unique order for colors
unique_cats = adata.obs[groupby].unique()
color_dict = dict(zip(unique_cats, colors[:len(unique_cats)]))
# Get colors for the actual categories in the plot
cell_colors = [color_dict.get(cat, '#CCCCCC') for cat in agg.index]
except (KeyError, IndexError):
cell_colors = None
color_dict = None
# Calculate mean expression and fraction of expressing cells
means = np.zeros((len(agg), len(var_names)))
fractions = np.zeros_like(means)
for i, group in enumerate(agg.index):
mask = adata.obs[groupby] == group
group_matrix = matrix[mask][:, var_names_idx]
# Calculate mean expression
if mean_only_expressed:
expressed = group_matrix > expression_cutoff
means[i] = np.array([
group_matrix[:, j][expressed[:, j]].mean() if expressed[:, j].any() else 0
for j in range(group_matrix.shape[1])
])
else:
means[i] = np.mean(group_matrix, axis=0)
# Calculate fraction of expressing cells
fractions[i] = np.mean(group_matrix > expression_cutoff, axis=0)
# Scale if requested
if standard_scale == 'group':
means = (means - means.min(axis=1, keepdims=True)) / (means.max(axis=1, keepdims=True) - means.min(axis=1, keepdims=True))
elif standard_scale == 'var':
means = (means - means.min(axis=0)) / (means.max(axis=0) - means.min(axis=0))
# Handle dot size limits
if dot_max is not None:
fractions = np.minimum(fractions, dot_max)
if dot_min is not None:
fractions = np.maximum(fractions, dot_min)
# Scale dot sizes to account for smallest_dot
if smallest_dot > 0:
fractions = smallest_dot + (1 - smallest_dot) * fractions
# Create the plot
h, w = means.shape
# Calculate dimensions based on figsize if provided
if figsize is not None:
# Use figsize to determine height and width
# Adjust for the number of rows and columns to maintain aspect ratio
base_height = figsize[1] * 0.7 # Use 70% of figsize height for main plot
base_width = figsize[0] * 0.7 # Use 70% of figsize width for main plot
# Scale based on data dimensions
height = base_height * (h / max(h, w))
width = base_width * (w / max(h, w))
else:
# Default behavior
height = h / 3
width = w / 3
# Create SizedHeatmap
m = ma.SizedHeatmap(
size=fractions,
color=means,
cluster_data=fractions if dendrogram else None,
height=height,
width=width,
edgecolor="lightgray",
cmap=cmap,
vmin=vmin,
vmax=vmax,
#norm=norm,
size_legend_kws=dict(
colors="#c2c2c2",
title=size_title,
labels=[f"{int(x*100)}%" for x in [0.2, 0.4, 0.6, 0.8, 1.0]],
show_at=[0.2, 0.4, 0.6, 0.8, 1.0],
fontsize=fontsize,
ncol=3,
title_fontproperties={"size": fontsize + 1, "weight": 100}
),
color_legend_kws=dict(
title=colorbar_title,
fontsize=fontsize,
orientation="horizontal",
title_fontproperties={"size": fontsize + 1, "weight": 100}
),
)
# Add labels
m.add_top(mp.Labels(var_names, fontsize=fontsize), pad=0.1)
# Group genes if var_names was a dictionary
if 'gene_groups' in locals():
# Get colors for gene groups
try:
# Use the same color_dict that was created for cell types
if color_dict is not None:
# Get unique groups and check which ones are not in color_dict
unique_groups = list(dict.fromkeys(gene_groups))
missing_groups = [g for g in unique_groups if g not in color_dict]
# If there are missing groups, add colors from palette
if missing_groups:
if len(palette_28) >= len(color_dict) + len(missing_groups):
extra_colors = palette_28[len(color_dict):len(color_dict) + len(missing_groups)]
else:
extra_colors = palette_56[len(color_dict):len(color_dict) + len(missing_groups)]
color_dict.update(dict(zip(missing_groups, extra_colors)))
else:
# If no colors found in uns, use default palette
unique_groups = list(dict.fromkeys(gene_groups))
if len(unique_groups) <= 28:
palette = palette_28
else:
palette = palette_56
color_dict = dict(zip(unique_groups, palette[:len(unique_groups)]))
except (KeyError, AttributeError):
# If colors not found in uns, use default palette
unique_groups = list(dict.fromkeys(gene_groups))
if len(unique_groups) <= 28:
palette = palette_28
else:
palette = palette_56
color_dict = dict(zip(unique_groups, palette[:len(unique_groups)]))
# Add color bars with matching order
# Add group labels
# Only show used colors in legend and increase group spacing
if preserve_dict_order and original_var_names_dict is not None:
# When preserving dict order, use the original dictionary key order
used_groups = list(original_var_names_dict.keys())
else:
# Use the order as they appear in gene_groups
used_groups = list(dict.fromkeys(gene_groups))
used_color_dict = {k: color_dict[k] for k in used_groups}
m.add_top(
mp.Colors(gene_groups, palette=used_color_dict),
pad=0.1,
size=0.15,
)
# Add group labels with increased spacing
m.group_cols(gene_groups,order=used_groups)
# Add cell type colors if available
if color_dict is not None:
# Add color bar using the properly created color_dict
m.add_left(
mp.Colors(agg.index, palette=color_dict),
size=0.15,
pad=0.1,
legend=False,
)
# Add cell type labels
m.add_left(mp.Labels(agg.index, align="right", fontsize=fontsize), pad=0.1)
# Add cell counts
m.add_right(
mp.Numbers(
cell_counts,
color="#EEB76B",
label="Count",
label_props={'size': fontsize},
props={'size': fontsize},
show_value=False
),
size=0.5,
pad=0.1,
)
# Add dendrogram if requested
if dendrogram:
m.add_dendrogram("right", pad=0.1)
# Add legends
m.add_legends(box_padding=2)
# Render the plot
fig = m.render()
if return_fig:
return fig
elif not show:
return m
return None
omicverse.pl.rank_genes_groups_dotplot(adata, plot_type='dotplot', *, groups=None, n_genes=None, groupby=None, values_to_plot=None, var_names=None, min_logfoldchange=None, key=None, show=None, save=None, return_fig=False, gene_symbols=None, **kwds)
¶
Create a dot plot from rank_genes_groups results.
Parameters¶
adata : AnnData
Annotated data matrix.
plot_type : str
Currently only 'dotplot' is supported.
groups : str or list of str, optional
Groups to include in the plot.
n_genes : int, optional
Number of genes to include in the plot.
groupby : str, optional
Key in `adata.obs` to group by.
values_to_plot : str, optional
Key in rank_genes_groups results to plot (e.g. 'logfoldchanges', 'scores').
var_names : str or list of str or dict, optional
Variables to include in the plot. Can be: - A list of gene names: ['gene1', 'gene2', ...] - A dictionary mapping group names to gene lists: {'group1': ['gene1', 'gene2'], 'group2': ['gene3', 'gene4']} When a dictionary is provided, genes will be grouped and labeled accordingly in the plot.
min_logfoldchange : float, optional
Minimum log fold change to include in the plot.
key : str, optional
Key in `adata.uns` to use for rank_genes_groups results.
show : bool, optional
Whether to show the plot.
save : bool, optional
Whether to save the plot.
return_fig : bool
Whether to return the figure object.
gene_symbols : str, optional
Key for gene symbols in `adata.var`.
**kwds : dict
Additional keyword arguments to pass to dotplot.
Returns¶
If return_fig
is True, returns the figure object.
If show
is False, returns axes dictionary.
Examples¶
Basic usage with top genes¶
sc.pl.rank_genes_groups_dotplot(adata, n_genes=5)
Using logfoldchanges for coloring¶
sc.pl.rank_genes_groups_dotplot(adata, n_genes=5, values_to_plot='logfoldchanges')
Grouping genes manually¶
gene_groups = {
    'Group1': ['gene1', 'gene2'],
    'Group2': ['gene3', 'gene4'],
}
sc.pl.rank_genes_groups_dotplot(adata, var_names=gene_groups)
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_dotplot.py
def rank_genes_groups_dotplot(
    adata: AnnData,
    plot_type: str = "dotplot",
    *,
    groups: Optional[Union[str, Sequence[str]]] = None,
    n_genes: Optional[int] = None,
    groupby: Optional[str] = None,
    values_to_plot: Optional[str] = None,
    var_names: Optional[Union[Sequence[str], Mapping[str, Sequence[str]]]] = None,
    min_logfoldchange: Optional[float] = None,
    key: Optional[str] = None,
    show: Optional[bool] = None,
    save: Optional[bool] = None,
    return_fig: bool = False,
    gene_symbols: Optional[str] = None,
    **kwds: Any,
) -> Optional[Union[Dict, Any]]:
    """
    Create a dot plot from rank_genes_groups results.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix.
    plot_type : str
        Currently only 'dotplot' is supported.
    groups : str or list of str, optional
        Groups to include in the plot. A single group name given as a string
        is treated as a one-element list. Defaults to every group stored in
        ``adata.uns[key]["names"]``.
    n_genes : int, optional
        Number of genes to include in the plot. A negative value selects
        ``genes_list[n_genes:]`` instead of the first ``n_genes`` entries.
        Defaults to 10 when neither `n_genes` nor `var_names` is given.
        Mutually exclusive with `var_names`.
    groupby : str, optional
        Key in `adata.obs` to group by. Defaults to the ``groupby`` value
        stored in ``adata.uns[key]["params"]``.
    values_to_plot : str, optional
        Key in rank_genes_groups results to plot (e.g. 'logfoldchanges', 'scores').
    var_names : str or list of str or dict, optional
        Variables to include in the plot. Can be:
        - A list of gene names: ['gene1', 'gene2', ...]
        - A dictionary mapping group names to gene lists: {'group1': ['gene1', 'gene2'], 'group2': ['gene3', 'gene4']}
        When a dictionary is provided, genes will be grouped and labeled accordingly in the plot.
    min_logfoldchange : float, optional
        Minimum log fold change to include in the plot.
    key : str, optional
        Key in `adata.uns` to use for rank_genes_groups results
        (defaults to "rank_genes_groups").
    show : bool, optional
        Whether to show the plot. In this function it only affects the
        return value (see Returns).
    save : bool, optional
        Whether to save the plot. Accepted for interface compatibility;
        this function does not read it.
    return_fig : bool
        Whether to return the figure object.
    gene_symbols : str, optional
        Key for gene symbols in `adata.var`.
    **kwds : dict
        Additional keyword arguments to pass to dotplot.

    Returns
    -------
    The object returned by ``dotplot(..., return_fig=True)`` when
    `return_fig` is True or `show` is falsy (including the default ``None``);
    otherwise ``None``.

    Raises
    ------
    ValueError
        If `plot_type` is not 'dotplot', or if both `n_genes` and
        `var_names` are given.

    Examples
    --------
    >>> # Basic usage with top genes
    >>> ov.pl.rank_genes_groups_dotplot(adata, n_genes=5)
    >>> # Using logfoldchanges for coloring
    >>> ov.pl.rank_genes_groups_dotplot(adata, n_genes=5, values_to_plot='logfoldchanges')
    >>> # Grouping genes manually
    >>> gene_groups = {
    ...     'Group1': ['gene1', 'gene2'],
    ...     'Group2': ['gene3', 'gene4']
    ... }
    >>> ov.pl.rank_genes_groups_dotplot(adata, var_names=gene_groups)
    """
    if plot_type != "dotplot":
        raise ValueError("Only 'dotplot' is currently supported")
    if var_names is not None and n_genes is not None:
        msg = (
            "The arguments n_genes and var_names are mutually exclusive. Please "
            "select only one."
        )
        raise ValueError(msg)
    if key is None:
        key = "rank_genes_groups"
    if groupby is None:
        groupby = str(adata.uns[key]["params"]["groupby"])

    if groups is None:
        group_names = adata.uns[key]["names"].dtype.names
    elif isinstance(groups, str):
        # A bare string would otherwise be iterated character by character below.
        group_names = [groups]
    else:
        group_names = groups

    if var_names is not None:
        if isinstance(var_names, Mapping):
            # get a single list of all gene names in the dictionary
            var_names_list = [
                gene for genes in var_names.values() for gene in genes
            ]
        elif isinstance(var_names, str):
            var_names_list = [var_names]
        else:
            var_names_list = var_names
    else:
        # set n_genes = 10 as default when none of the options is given
        if n_genes is None:
            n_genes = 10
        # dict in which each group is the key and the n_genes are the values
        var_names = {}
        var_names_list = []
        for group in group_names:
            df = rank_genes_groups_df(
                adata,
                group,
                key=key,
                gene_symbols=gene_symbols,
                log2fc_min=min_logfoldchange,
            )
            if gene_symbols is not None:
                df["names"] = df[gene_symbols]
            genes_list = df.names[df.names.notnull()].tolist()
            if len(genes_list) == 0:
                print(f"Warning: No genes found for group {group}")
                continue
            genes_list = genes_list[n_genes:] if n_genes < 0 else genes_list[:n_genes]
            var_names[group] = genes_list
            var_names_list.extend(genes_list)

    # by default add dendrogram to plots
    kwds.setdefault("dendrogram", True)

    # Get values to plot if specified
    title = None
    values_df = None
    if values_to_plot is not None:
        values_df = _get_values_to_plot(
            adata,
            values_to_plot,
            var_names_list,
            key=key,
            gene_symbols=gene_symbols,
        )
        if values_to_plot == "logfoldchanges":
            title = "log fold change"
        else:
            title = values_to_plot.replace("_", " ").replace("pvals", "p-value")

    # Create the plot
    _pl = dotplot(
        adata,
        var_names,
        groupby,
        dot_color_df=values_df,
        return_fig=True,
        gene_symbols=gene_symbols,
        preserve_dict_order=True,
        **kwds,
    )
    # NOTE(review): `_pl` is whatever `dotplot(..., return_fig=True)` returns;
    # confirm that object exposes `.legend(colorbar_title=...)`.
    if title is not None and "colorbar_title" not in kwds:
        _pl.legend(colorbar_title=title)

    # Both `return_fig=True` and a falsy `show` hand the plot object back.
    if return_fig or not show:
        return _pl
    return None
Embedding and Scatter Plots¶
omicverse.pl.embedding(adata, basis, *, color=None, gene_symbols=None, use_raw=None, sort_order=True, edges=False, edges_width=0.1, edges_color='grey', neighbors_key=None, arrows=False, arrows_kwds=None, groups=None, components=None, dimensions=None, layer=None, projection='2d', scale_factor=None, color_map=None, cmap=None, palette=None, na_color='lightgray', na_in_legend=True, size=None, frameon='small', legend_fontsize=None, legend_fontweight='bold', legend_loc='right margin', legend_fontoutline=None, colorbar_loc='right', vmax=None, vmin=None, vcenter=None, norm=None, add_outline=False, outline_width=(0.3, 0.05), outline_color=('black', 'white'), ncols=4, hspace=0.25, wspace=None, title=None, show=None, save=None, ax=None, return_fig=None, marker='.', **kwargs)
¶
Scatter plot for user specified embedding basis (e.g. umap, pca, etc).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
adata |
AnnData
|
Annotated data matrix. |
required |
basis |
str
|
Name of the `obsm` basis to use. |
required |
color |
Union[str, Sequence[str], None]
|
Keys for annotations of observations/cells or variables/genes. (None) |
None
|
gene_symbols |
Optional[str]
|
Key for field in `.var` that stores gene symbols. (None) |
None
|
use_raw |
Optional[bool]
|
Use `.raw` attribute of `adata` if present. (None) |
None
|
sort_order |
bool
|
For continuous annotations used as color parameter, plot data points with higher values on top of others. (True) |
True
|
edges |
bool
|
Show edges between cells. (False) |
False
|
edges_width |
float
|
Width of edges. (0.1) |
0.1
|
edges_color |
Union[str, Sequence[float], Sequence[str]]
|
Color of edges. ('grey') |
'grey'
|
neighbors_key |
Optional[str]
|
Key to use for neighbors. (None) |
None
|
arrows |
bool
|
Show arrows for velocity. (False) |
False
|
arrows_kwds |
Optional[Mapping[str, Any]]
|
Keyword arguments for arrow plots. (None) |
None
|
groups |
Optional[str]
|
Groups to highlight. (None) |
None
|
components |
Union[str, Sequence[str]]
|
Components to plot. (None) |
None
|
dimensions |
Optional[Union[Tuple[int, int], Sequence[Tuple[int, int]]]]
|
Dimensions to plot. (None) |
None
|
layer |
Optional[str]
|
Name of the layer to use for coloring. (None) |
None
|
projection |
Literal['2d', '3d']
|
Type of projection ('2d' or '3d'). ('2d') |
'2d'
|
scale_factor |
Optional[float]
|
Scaling factor for sizes. (None) |
None
|
color_map |
Union[Colormap, str, None]
|
Colormap to use for continuous variables. (None) |
None
|
cmap |
Union[Colormap, str, None]
|
Colormap to use for continuous variables. (None) |
None
|
palette |
Union[str, Sequence[str], Cycler, None]
|
Colors to use for categorical variables. (None) |
None
|
na_color |
ColorLike
|
Color to use for NaN values. ('lightgray') |
'lightgray'
|
na_in_legend |
bool
|
Include NaN values in legend. (True) |
True
|
size |
Union[float, Sequence[float], None]
|
Size of the dots. (None) |
None
|
frameon |
Optional[bool]
|
Draw a frame around the plot. ('small') |
'small'
|
legend_fontsize |
Union[int, float, _FontSize, None]
|
Font size for legend. (None) |
None
|
legend_fontweight |
Union[int, _FontWeight]
|
Font weight for legend. ('bold') |
'bold'
|
legend_loc |
str
|
Location of legend. ('right margin') |
'right margin'
|
legend_fontoutline |
Optional[int]
|
Outline width for legend text. (None) |
None
|
colorbar_loc |
Optional[str]
|
Location of colorbar. ('right') |
'right'
|
vmax |
Union[VBound, Sequence[VBound], None]
|
Maximum value for colorbar. (None) |
None
|
vmin |
Union[VBound, Sequence[VBound], None]
|
Minimum value for colorbar. (None) |
None
|
vcenter |
Union[VBound, Sequence[VBound], None]
|
Center value for colorbar. (None) |
None
|
norm |
Union[Normalize, Sequence[Normalize], None]
|
Normalization for colorbar. (None) |
None
|
add_outline |
Optional[bool]
|
Add outline to points. (False) |
False
|
outline_width |
Tuple[float, float]
|
Width of outline. ((0.3, 0.05)) |
(0.3, 0.05)
|
outline_color |
Tuple[str, str]
|
Color of outline. (('black', 'white')) |
('black', 'white')
|
ncols |
int
|
Number of columns for subplots. (4) |
4
|
hspace |
float
|
Height spacing between subplots. (0.25) |
0.25
|
wspace |
Optional[float]
|
Width spacing between subplots. (None) |
None
|
title |
Union[str, Sequence[str], None]
|
Title for the plot. (None) |
None
|
show |
Optional[bool]
|
Show the plot. (None) |
None
|
save |
Union[bool, str, None]
|
Save the plot. (None) |
None
|
ax |
Optional[Axes]
|
Matplotlib axes object. (None) |
None
|
return_fig |
Optional[bool]
|
Return figure object. (None) |
None
|
marker |
Union[str, Sequence[str]]
|
Marker style. ('.') |
'.'
|
**kwargs |
Additional keyword arguments. |
{}
|
Returns:
Name | Type | Description |
---|---|---|
ax |
Union[Figure, Axes, None]
|
If `show==False` a `matplotlib.axes.Axes` or a list of it. |
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_single.py
def embedding(
    adata: AnnData,
    basis: str,
    *,
    color: Union[str, Sequence[str], None] = None,
    gene_symbols: Optional[str] = None,
    use_raw: Optional[bool] = None,
    sort_order: bool = True,
    edges: bool = False,
    edges_width: float = 0.1,
    edges_color: Union[str, Sequence[float], Sequence[str]] = 'grey',
    neighbors_key: Optional[str] = None,
    arrows: bool = False,
    arrows_kwds: Optional[Mapping[str, Any]] = None,
    groups: Optional[str] = None,
    components: Union[str, Sequence[str]] = None,
    dimensions: Optional[Union[Tuple[int, int], Sequence[Tuple[int, int]]]] = None,
    layer: Optional[str] = None,
    projection: Literal['2d', '3d'] = '2d',
    scale_factor: Optional[float] = None,
    color_map: Union[Colormap, str, None] = None,
    cmap: Union[Colormap, str, None] = None,
    palette: Union[str, Sequence[str], Cycler, None] = None,
    na_color: ColorLike = "lightgray",
    na_in_legend: bool = True,
    size: Union[float, Sequence[float], None] = None,
    frameon: Optional[bool] = 'small',
    legend_fontsize: Union[int, float, _FontSize, None] = None,
    legend_fontweight: Union[int, _FontWeight] = 'bold',
    legend_loc: str = 'right margin',
    legend_fontoutline: Optional[int] = None,
    colorbar_loc: Optional[str] = "right",
    vmax: Union[VBound, Sequence[VBound], None] = None,
    vmin: Union[VBound, Sequence[VBound], None] = None,
    vcenter: Union[VBound, Sequence[VBound], None] = None,
    norm: Union[Normalize, Sequence[Normalize], None] = None,
    add_outline: Optional[bool] = False,
    outline_width: Tuple[float, float] = (0.3, 0.05),
    outline_color: Tuple[str, str] = ('black', 'white'),
    ncols: int = 4,
    hspace: float = 0.25,
    wspace: Optional[float] = None,
    title: Union[str, Sequence[str], None] = None,
    show: Optional[bool] = None,
    save: Union[bool, str, None] = None,
    ax: Optional[Axes] = None,
    return_fig: Optional[bool] = None,
    marker: Union[str, Sequence[str]] = '.',
    **kwargs,
) -> Union[Figure, Axes, None]:
    r"""Scatter plot of cells on a user-chosen embedding basis (e.g. umap, pca).

    This is a thin public wrapper: every argument is forwarded, by keyword
    and unchanged, to `_embedding`, whose result is returned as-is.

    Arguments:
        adata: Annotated data matrix.
        basis: Name of the `obsm` basis to plot.
        color: Observation/cell or variable/gene annotation keys used for coloring. (None)
        gene_symbols: Field of `.var` holding gene symbols. (None)
        use_raw: Whether to read from the `.raw` attribute of `adata` when present. (None)
        sort_order: With a continuous color annotation, draw points with higher values on top. (True)
        edges: Whether to draw edges between cells. (False)
        edges_width: Line width of the edges. (0.1)
        edges_color: Color of the edges. ('grey')
        neighbors_key: Key to use for neighbors. (None)
        arrows: Whether to draw velocity arrows. (False)
        arrows_kwds: Keyword arguments for the arrow plots. (None)
        groups: Groups to highlight. (None)
        components: Components to plot. (None)
        dimensions: Dimensions to plot. (None)
        layer: Layer whose values are used for coloring. (None)
        projection: Projection type, '2d' or '3d'. ('2d')
        scale_factor: Scaling factor for sizes. (None)
        color_map: Colormap for continuous variables. (None)
        cmap: Colormap for continuous variables. (None)
        palette: Colors for categorical variables. (None)
        na_color: Color used for NaN values. ('lightgray')
        na_in_legend: Whether NaN values get a legend entry. (True)
        size: Dot size. (None)
        frameon: Whether/how to draw a frame around the plot. ('small')
        legend_fontsize: Legend font size. (None)
        legend_fontweight: Legend font weight. ('bold')
        legend_loc: Legend location. ('right margin')
        legend_fontoutline: Outline width of legend text. (None)
        colorbar_loc: Colorbar location. ('right')
        vmax: Upper colorbar bound. (None)
        vmin: Lower colorbar bound. (None)
        vcenter: Colorbar center value. (None)
        norm: Colorbar normalization. (None)
        add_outline: Whether to outline the points. (False)
        outline_width: Outline width. ((0.3, 0.05))
        outline_color: Outline color. (('black', 'white'))
        ncols: Number of subplot columns. (4)
        hspace: Vertical spacing between subplots. (0.25)
        wspace: Horizontal spacing between subplots. (None)
        title: Plot title. (None)
        show: Whether to show the plot. (None)
        save: Whether to save the plot. (None)
        ax: Matplotlib axes object to draw on. (None)
        return_fig: Whether to return the figure object. (None)
        marker: Marker style. ('.')
        **kwargs: Additional keyword arguments, passed through untouched.

    Returns:
        ax: If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.

    """
    return _embedding(
        adata=adata,
        basis=basis,
        color=color,
        gene_symbols=gene_symbols,
        use_raw=use_raw,
        sort_order=sort_order,
        edges=edges,
        edges_width=edges_width,
        edges_color=edges_color,
        neighbors_key=neighbors_key,
        arrows=arrows,
        arrows_kwds=arrows_kwds,
        groups=groups,
        components=components,
        dimensions=dimensions,
        layer=layer,
        projection=projection,
        scale_factor=scale_factor,
        color_map=color_map,
        cmap=cmap,
        palette=palette,
        na_color=na_color,
        na_in_legend=na_in_legend,
        size=size,
        frameon=frameon,
        legend_fontsize=legend_fontsize,
        legend_fontweight=legend_fontweight,
        legend_loc=legend_loc,
        legend_fontoutline=legend_fontoutline,
        colorbar_loc=colorbar_loc,
        vmax=vmax,
        vmin=vmin,
        vcenter=vcenter,
        norm=norm,
        add_outline=add_outline,
        outline_width=outline_width,
        outline_color=outline_color,
        ncols=ncols,
        hspace=hspace,
        wspace=wspace,
        title=title,
        show=show,
        save=save,
        ax=ax,
        return_fig=return_fig,
        marker=marker,
        **kwargs,
    )
omicverse.pl.embedding_celltype(adata, figsize=(6, 4), basis='umap', celltype_key='major_celltype', title=None, celltype_range=(2, 9), embedding_range=(3, 10), xlim=-1000)
¶
Plot embedding with celltype color by omicverse.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
adata |
AnnData
|
AnnData object |
required |
figsize |
tuple
|
tuple, optional (default=(6,4)) Figure size |
(6, 4)
|
basis |
str
|
str, optional (default='umap') Embedding method |
'umap'
|
celltype_key |
str
|
str, optional (default='major_celltype') Celltype key in adata.obs |
'major_celltype'
|
title |
str
|
str, optional (default=None) Figure title |
None
|
celltype_range |
tuple
|
tuple, optional (default=(2,9)) Celltype range to plot |
(2, 9)
|
embedding_range |
tuple
|
tuple, optional (default=(3,10)) Embedding range to plot |
(3, 10)
|
xlim |
int
|
int, optional (default=-1000) X axis limit |
-1000
|
Returns:
Name | Type | Description |
---|---|---|
fig |
tuple
|
figure and axis |
ax |
tuple
|
axis |
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_single.py
def embedding_celltype(adata:AnnData,figsize:tuple=(6,4),basis:str='umap',
                       celltype_key:str='major_celltype',title:str=None,
                       celltype_range:tuple=(2,9),
                       embedding_range:tuple=(3,10),
                       xlim:int=-1000)->tuple:
    r"""
    Plot embedding with celltype color by omicverse.

    The figure is laid out on a 10x10 grid: the embedding scatter plot on the
    right and, on the left, one horizontal bar per cell type labelled with
    the cell type name and its cell count.

    Note that ``adata.obs[celltype_key]`` is converted to a categorical
    column in place.

    Arguments:
        adata: AnnData object
        figsize: tuple, optional (default=(6,4))
            Figure size. If None, a (6,4) figure is used when there are fewer
            than 10 cell types; otherwise a message is printed and None is
            returned.
        basis: str, optional (default='umap')
            Embedding method
        celltype_key: str, optional (default='major_celltype')
            Celltype key in adata.obs
        title: str, optional (default=None)
            Figure title
        celltype_range: tuple, optional (default=(2,9))
            Grid rows (start, stop) occupied by the cell type panel
        embedding_range: tuple, optional (default=(3,10))
            Grid columns (start, stop) occupied by the embedding panel
        xlim: int, optional (default=-1000)
            Lower x axis limit of the cell type panel

    Returns:
        fig: the matplotlib figure
        ax: list ``[ax1, ax2]`` with the embedding axes and the cell type axes

    """
    adata.obs[celltype_key] = adata.obs[celltype_key].astype('category')
    categories = adata.obs[celltype_key].cat.categories

    # value_counts() yields a Series indexed by cell type on every pandas
    # version, so no version-dependent column renaming is needed.
    cell_counts = adata.obs[celltype_key].value_counts()

    colors_key = '{}_colors'.format(celltype_key)
    if colors_key in adata.uns.keys():
        cell_color_dict = dict(zip(categories.tolist(), adata.uns[colors_key]))
    elif len(categories) > 28:
        cell_color_dict = dict(zip(categories, sc.pl.palettes.default_102))
    else:
        cell_color_dict = dict(zip(categories, sc.pl.palettes.zeileis_28))

    if figsize is None:
        if len(categories) < 10:
            fig = plt.figure(figsize=(6, 4))
        else:
            print('The number of cell types is too large, please set the figsize parameter')
            return None
    else:
        fig = plt.figure(figsize=figsize)

    grid = plt.GridSpec(10, 10)
    # Embedding panel: all rows, the columns selected by `embedding_range`.
    ax1 = fig.add_subplot(grid[:, embedding_range[0]:embedding_range[1]])
    # Cell type panel: rows selected by `celltype_range`, first two columns.
    ax2 = fig.add_subplot(grid[celltype_range[0]:celltype_range[1], :2])

    sc.pl.embedding(
        adata,
        basis=basis,
        color=[celltype_key],
        title='',
        frameon=False,
        #wspace=0.65,
        ncols=3,
        ax=ax1,
        legend_loc=False,
        show=False
    )

    for idx, cell in enumerate(categories):
        n_cells = cell_counts.loc[cell]
        ax2.scatter(100, cell, c=cell_color_dict[cell], s=50)
        # Bar from x=100 to the cell count, drawn on row `idx`.
        ax2.plot((100, n_cells), (idx, idx), c=cell_color_dict[cell], lw=4)
        ax2.text(100, idx + 0.2,
                 cell + '(' + "{:,}".format(n_cells) + ')', fontsize=11)

    # The upper x limit is the second entry of the descending value counts;
    # with a single cell type fall back to the only entry instead of failing.
    upper = cell_counts.iloc[min(1, len(cell_counts) - 1)]
    ax2.set_xlim(xlim, upper)
    # The title sits one row above the last cell type.
    ax2.text(xlim, len(categories), title, fontsize=12)
    ax2.grid(False)
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)
    ax2.spines['bottom'].set_visible(False)
    ax2.spines['left'].set_visible(False)
    ax2.axis('off')

    # Make sure neither panel carries a legend: remove it if one exists.
    if ax1.get_legend() is not None:
        ax1.get_legend().remove()
    if ax2.get_legend() is not None:
        ax2.get_legend().remove()
    return fig, [ax1, ax2]
omicverse.pl.embedding_adjust(adata, groupby, exclude=(), basis='X_umap', ax=None, adjust_kwargs=None, text_kwargs=None)
¶
Get locations of cluster median and adjust text labels accordingly.
Borrowed from scanpy github forum.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
adata |
AnnData object |
required | |
groupby |
str Key in adata.obs for grouping |
required | |
exclude |
tuple, optional (default=()) Groups to exclude from labeling |
()
|
|
basis |
str, optional (default='X_umap') Embedding basis key in adata.obsm |
'X_umap'
|
|
ax |
matplotlib.axes.Axes, optional (default=None) Axes object to plot on |
None
|
|
adjust_kwargs |
dict, optional (default=None) Arguments for adjust_text function |
None
|
|
text_kwargs |
dict, optional (default=None) Arguments for text annotation |
None
|
Returns:
Name | Type | Description |
---|---|---|
medians | dict Dictionary of median positions for each group |
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_single.py
def embedding_adjust(
    adata, groupby, exclude=(),
    basis='X_umap',ax=None, adjust_kwargs=None, text_kwargs=None
):
    r"""
    Get locations of cluster median and adjust text labels accordingly.

    Borrowed from scanpy github forum.

    One text label per group is placed at the per-group median of
    ``adata.obsm[basis]`` and the labels are then passed to
    ``adjustText.adjust_text``.

    Arguments:
        adata: AnnData object
        groupby: str
            Key in adata.obs for grouping
        exclude: tuple, optional (default=())
            Groups to exclude from labeling
        basis: str, optional (default='X_umap')
            Embedding basis key in adata.obsm. Each median is unpacked as
            ``(x, y)``, so the basis must have exactly two columns.
        ax: matplotlib.axes.Axes, optional (default=None)
            Axes object to plot on; when None, ``plt.text`` is used.
        adjust_kwargs: dict, optional (default=None)
            Arguments for adjust_text function
            (None means ``{"text_from_points": False}``)
        text_kwargs: dict, optional (default=None)
            Arguments for text annotation

    Returns:
        medians: dict
            Dictionary of median positions for each group

    """
    # Import first so a missing optional dependency fails before any text
    # is drawn on the figure.
    from adjustText import adjust_text

    if adjust_kwargs is None:
        adjust_kwargs = {"text_from_points": False}
    if text_kwargs is None:
        text_kwargs = {}

    medians = {}
    for g, g_idx in adata.obs.groupby(groupby).groups.items():
        if g in exclude:
            continue
        medians[g] = np.median(adata[g_idx].obsm[basis], axis=0)

    # `plt.text` and `Axes.text` take the same keywords, so pick the target once.
    target = plt if ax is None else ax
    texts = [
        target.text(x=x, y=y, s=k, **text_kwargs) for k, (x, y) in medians.items()
    ]
    adjust_text(texts, **adjust_kwargs)
    return medians
omicverse.pl.embedding_atlas(adata, basis, color, title=None, figsize=(4, 4), ax=None, cmap='RdBu', legend_loc='right margin', frameon='small', fontsize=12)
¶
Create high-resolution embedding plots using Datashader for large datasets.
Uses Datashader to render embeddings at high resolution, suitable for datasets with millions of cells where standard scatter plots become ineffective.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
adata |
Annotated data object with embedding coordinates |
required | |
basis |
Key in adata.obsm containing embedding coordinates (e.g., 'X_umap') |
required | |
color |
Gene name or obs column to color cells by |
required | |
title |
Plot title (None, uses color name) |
None
|
|
figsize |
Figure dimensions as (width, height) ((4,4)) |
(4, 4)
|
|
ax |
Existing matplotlib axes object (None) |
None
|
|
cmap |
Colormap for continuous values ('RdBu') |
'RdBu'
|
|
legend_loc |
Legend position ('right margin') |
'right margin'
|
|
frameon |
Frame style - False, 'small', or True ('small') |
'small'
|
|
fontsize |
Font size for labels and title (12) |
12
|
Returns:
Name | Type | Description |
---|---|---|
ax | matplotlib.axes.Axes object with rendered embedding |
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_embedding.py
def embedding_atlas(adata,basis,color,
                    title=None,figsize=(4,4),ax=None,cmap='RdBu',
                    legend_loc = 'right margin',frameon='small',
                    fontsize=12):
    r"""
    Create high-resolution embedding plots using Datashader for large datasets.

    Uses Datashader to render embeddings at high resolution, suitable for datasets
    with millions of cells where standard scatter plots become ineffective.

    Arguments:
        adata: Annotated data object with embedding coordinates
        basis: Key in adata.obsm containing embedding coordinates (e.g., 'X_umap').
            Only the first two columns are rendered.
        color: Gene name or obs column to color cells by. Looked up in
            adata.obs.columns, then adata.var_names, then adata.raw.var_names.
        title: Plot title (None, uses color name)
        figsize: Figure dimensions as (width, height) ((4,4))
        ax: Existing matplotlib axes object (None)
        cmap: Colormap for continuous values ('RdBu'). If it names a bokeh
            palette family, the last-listed palette of that family is used.
        legend_loc: Legend position ('right margin'); a legend is drawn only
            for string labels and only for this value.
        frameon: Frame style - False, 'small', or True ('small')
        fontsize: Font size for labels and title (12)

    Returns:
        ax: matplotlib.axes.Axes object with rendered embedding, or None when
            the label type is neither string nor numeric.

    Raises:
        ValueError: If `color` is found in none of adata.obs.columns,
            adata.var_names and adata.raw.var_names.

    """
    import pandas as pd
    import datashader as ds
    import datashader.transfer_functions as tf
    from scipy.sparse import issparse
    # Import the submodule explicitly; a bare `import bokeh` is not
    # guaranteed to make `bokeh.palettes` available.
    from bokeh.palettes import all_palettes

    def _dense_1d(X):
        # Flatten a single-gene expression column to a dense 1-D array.
        return (X.toarray() if issparse(X) else np.asarray(X)).reshape(-1)

    # Rendering canvas: the image is 800 x 800 pixels.
    cvs = ds.Canvas(plot_width=800, plot_height=800)

    # Only two coordinates can be drawn, so keep the first two columns.
    embedding = np.asarray(adata.obsm[basis])[:, :2]
    df = pd.DataFrame(embedding, columns=['x', 'y'])

    if color in adata.obs.columns:
        labels = adata.obs[color].tolist()
    elif color in adata.var_names:
        labels = _dense_1d(adata[:, color].X)
    elif adata.raw is not None and color in adata.raw.var_names:
        labels = _dense_1d(adata.raw[:, color].X)
    else:
        raise ValueError(
            f"'{color}' was not found in adata.obs.columns, adata.var_names "
            "or adata.raw.var_names"
        )
    df['label'] = labels

    first = labels[0]
    if isinstance(first, str):
        # Categorical labels: count points per category in every pixel.
        df['label'] = df['label'].astype('category')
        agg = cvs.points(df, 'x', 'y', ds.count_cat('label'))
        legend_tag = True
        color_key = dict(zip(adata.obs[color].cat.categories,
                             adata.uns[f'{color}_colors']))
        # Shade with the per-category colors, ordered like the label categories.
        img = tf.shade(
            tf.spread(agg, px=0),
            color_key=[color_key[i] for i in df['label'].cat.categories],
            how='eq_hist',
        )
    elif (isinstance(first, (int, float, np.integer, np.floating))
          and not isinstance(first, bool)):
        # Numeric labels: average the value of the points in every pixel.
        # (`np.int` no longer exists in recent NumPy, hence the isinstance check.)
        agg = cvs.points(df, 'x', 'y', ds.mean('label'))
        legend_tag = False
        if cmap in all_palettes:
            num = list(all_palettes[cmap].keys())[-1]
            img = tf.shade(agg, cmap=all_palettes[cmap][num])
        else:
            img = tf.shade(agg, cmap=cmap)
    else:
        print('Unrecognized label type')
        return None

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # Draw the image rendered by Datashader.
    ax.imshow(img.to_pil(), aspect='auto')

    # Custom formatter for the coordinate readout.
    def format_coord(x, y):
        return f"x={x:.2f}, y={y:.2f}"
    ax.format_coord = format_coord

    if legend_tag:
        # Build the legend by hand: one empty scatter per category.
        unique_labels = adata.obs[color].cat.categories
        for label in unique_labels:
            ax.scatter([], [], c=color_key[label], label=label)
        if legend_loc == "right margin":
            ax.legend(
                frameon=False,
                loc="center left",
                bbox_to_anchor=(1, 0.5),
                ncol=(1 if len(unique_labels) <= 14 else 2 if len(unique_labels) <= 30 else 3),
                fontsize=fontsize-1,
            )

    if frameon == False:  # noqa: E712 - keep equality semantics of the original
        ax.axis('off')
    elif frameon == 'small':
        ax.axis('on')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.spines['left'].set_visible(True)
        ax.spines['bottom'].set_visible(True)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        # Short axis stubs near the origin corner; the bounds are hard-coded
        # numbers that presumably match the 800 x 800 canvas - TODO confirm.
        ax.spines['bottom'].set_bounds(0, 150)
        ax.spines['left'].set_bounds(650, 800)
        ax.set_xlabel(f'{basis}1', loc='left', fontsize=fontsize)
        ax.set_ylabel(f'{basis}2', loc='bottom', fontsize=fontsize)
    else:
        ax.axis('on')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.spines['left'].set_visible(True)
        ax.spines['bottom'].set_visible(True)
        ax.spines['top'].set_visible(True)
        ax.spines['right'].set_visible(True)
        ax.set_xlabel(f'{basis}1', loc='center', fontsize=fontsize)
        ax.set_ylabel(f'{basis}2', loc='center', fontsize=fontsize)

    # Thickness of the axis lines.
    line_width = 1.2
    ax.spines['left'].set_linewidth(line_width)
    ax.spines['bottom'].set_linewidth(line_width)

    if title is None:
        title = color
    ax.set_title(title, fontsize=fontsize+1)
    return ax
omicverse.pl.embedding_multi(data, basis, color=None, use_raw=None, layer=None, **kwargs)
¶
Create embedding scatter plots for multi-modal data (MuData) or single-cell data.
Produces scatter plots on specified embeddings, supporting cross-modality feature visualization. For modality-specific embeddings, use format 'modality:embedding' (e.g., 'rna:X_pca').
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
AnnData
|
AnnData or MuData object containing embedding and feature data |
required |
basis |
str
|
Name of embedding in obsm (e.g., 'X_umap') or modality-specific ('rna:X_pca') |
required |
color |
Optional[Union[str, Sequence[str]]]
|
Gene names or obs columns to color points by (None) |
None
|
use_raw |
Optional[bool]
|
Whether to use .raw attribute for features (None, auto-determined) |
None
|
layer |
Optional[str]
|
Specific data layer to use for coloring (None) |
None
|
**kwargs |
Additional arguments passed to embedding plotting function |
{}
|
Returns:
Type | Description |
---|---|
Plot axes or figure depending on underlying plotting function |
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_multi.py
def embedding_multi(
    data: AnnData,
    basis: str,
    color: Optional[Union[str, Sequence[str]]] = None,
    use_raw: Optional[bool] = None,
    layer: Optional[str] = None,
    **kwargs,
):
    r"""
    Create embedding scatter plots for multi-modal data (MuData) or single-cell data.

    Produces scatter plots on specified embeddings, supporting cross-modality feature visualization.
    For modality-specific embeddings, use format 'modality:embedding' (e.g., 'rna:X_pca').

    Arguments:
        data: AnnData or MuData object containing embedding and feature data.
            An AnnData input is forwarded unchanged to ``embedding``.
        basis: Name of embedding in obsm (e.g., 'X_umap') or modality-specific ('rna:X_pca').
            A missing name is retried with an ``"X_"`` prefix.
        color: Gene names or obs columns to color points by (None).
            A key may be prefixed with ``"<modality>:"`` to select a feature of that modality.
        use_raw: Whether to use .raw attribute for features (None, auto-determined).
            ``None`` behaves like ``True`` whenever a modality has a ``.raw`` slot.
        layer: Specific data layer to use for coloring (None). The code also accepts
            a dict mapping modality name to layer name.
        **kwargs: Additional arguments passed to embedding plotting function

    Returns:
        Plot axes or figure depending on underlying plotting function

    Raises:
        ValueError: If ``basis`` cannot be resolved to an embedding, or names a
            modality that is not present in ``data.mod``.
        TypeError: If ``color`` is neither a string nor an iterable.
    """
    # Plain AnnData: nothing multi-modal to resolve, delegate directly.
    if isinstance(data, AnnData):
        return embedding(
            data, basis=basis, color=color, use_raw=use_raw, layer=layer, **kwargs
        )
    # `data` is MuData
    # Accept the short embedding name (e.g. "umap") when only "X_umap" exists.
    if basis not in data.obsm and "X_" + basis in data.obsm:
        basis = "X_" + basis
    if basis in data.obsm:
        # Joint embedding stored on the MuData object itself.
        adata = data
        basis_mod = basis
    else:
        # basis is not a joint embedding
        # Expect exactly "modality:embedding"; any other number of ":" parts
        # makes the tuple unpacking raise ValueError, which is re-raised with
        # a descriptive message.
        try:
            mod, basis_mod = basis.split(":")
        except ValueError:
            raise ValueError(f"Basis {basis} is not present in the MuData object (.obsm)")
        if mod not in data.mod:
            raise ValueError(
                f"Modality {mod} is not present in the MuData object with modalities {', '.join(data.mod)}"
            )
        adata = data.mod[mod]
        if basis_mod not in adata.obsm:
            if "X_" + basis_mod in adata.obsm:
                basis_mod = "X_" + basis_mod
            elif len(adata.obsm) > 0:
                raise ValueError(
                    f"Basis {basis_mod} is not present in the modality {mod} with embeddings {', '.join(adata.obsm)}"
                )
            else:
                raise ValueError(
                    f"Basis {basis_mod} is not present in the modality {mod} with no embeddings"
                )
    # Global (MuData-level) obs restricted to, and ordered like, the
    # observations of the object that owns the embedding.
    obs = data.obs.loc[adata.obs.index.values]
    if color is None:
        # No colouring requested: plot a lightweight AnnData that carries only
        # the obs table and the embeddings/graphs.
        ad = AnnData(obs=obs, obsm=adata.obsm, obsp=adata.obsp)
        return sc.pl.embedding(ad, basis=basis_mod, **kwargs)
    # Some `color` has been provided
    if isinstance(color, str):
        keys = color = [color]
    elif isinstance(color, Iterable):
        keys = color
    else:
        raise TypeError("Expected color to be a string or an iterable.")
    # Fetch respective features
    # Only needed when at least one key is not already an obs column.
    if not all([key in obs for key in keys]):
        # {'rna': [True, False], 'prot': [False, True]}
        keys_in_mod = {m: [key in data.mod[m].var_names for key in keys] for m in data.mod}
        # .raw slots might have exclusive var_names
        if use_raw is None or use_raw:
            for i, k in enumerate(keys):
                for m in data.mod:
                    if keys_in_mod[m][i] == False and data.mod[m].raw is not None:
                        keys_in_mod[m][i] = k in data.mod[m].raw.var_names
        # e.g. color="rna:CD8A" - especially relevant for mdata.axis == -1
        # Maps each requested key to the name actually used for the lookup
        # (the key with its "<modality>:" prefix stripped when that matched).
        mod_key_modifier: dict[str, str] = dict()
        for i, k in enumerate(keys):
            mod_key_modifier[k] = k
            for m in data.mod:
                if not keys_in_mod[m][i]:
                    k_clean = k
                    if k.startswith(f"{m}:"):
                        k_clean = k.split(":", 1)[1]
                    keys_in_mod[m][i] = k_clean in data.mod[m].var_names
                    if keys_in_mod[m][i]:
                        mod_key_modifier[k] = k_clean
                    if use_raw is None or use_raw:
                        if keys_in_mod[m][i] == False and data.mod[m].raw is not None:
                            keys_in_mod[m][i] = k_clean in data.mod[m].raw.var_names
        # For every modality that owns at least one requested feature, pull
        # the values and append them to `obs` as additional columns.
        for m in data.mod:
            if np.sum(keys_in_mod[m]) > 0:
                mod_keys = np.array(keys)[keys_in_mod[m]]
                mod_keys = np.array([mod_key_modifier[k] for k in mod_keys])
                if use_raw is None or use_raw:
                    if data.mod[m].raw is not None:
                        # Select the requested columns of .raw by position.
                        keysidx = data.mod[m].raw.var.index.get_indexer_for(mod_keys)
                        fmod_adata = AnnData(
                            X=data.mod[m].raw.X[:, keysidx],
                            var=pd.DataFrame(index=mod_keys),
                            obs=data.mod[m].obs,
                        )
                    else:
                        # Warn only when .raw was explicitly requested.
                        if use_raw:
                            warnings.warn(
                                f"Attibute .raw is None for the modality {m}, using .X instead"
                            )
                        fmod_adata = data.mod[m][:, mod_keys]
                else:
                    fmod_adata = data.mod[m][:, mod_keys]
                if layer is not None:
                    if isinstance(layer, Dict):
                        # Per-modality layer selection: {modality: layer name}.
                        m_layer = layer.get(m, None)
                        if m_layer is not None:
                            x = data.mod[m][:, mod_keys].layers[m_layer]
                            fmod_adata.X = x.todense() if issparse(x) else x
                            if use_raw:
                                warnings.warn(f"Layer='{layer}' superseded use_raw={use_raw}")
                    elif layer in data.mod[m].layers:
                        x = data.mod[m][:, mod_keys].layers[layer]
                        fmod_adata.X = x.todense() if issparse(x) else x
                        if use_raw:
                            warnings.warn(f"Layer='{layer}' superseded use_raw={use_raw}")
                    else:
                        warnings.warn(
                            f"Layer {layer} is not present for the modality {m}, using count matrix instead"
                        )
                # Densify so the values can be stored as DataFrame columns.
                x = fmod_adata.X.toarray() if issparse(fmod_adata.X) else fmod_adata.X
                obs = obs.join(
                    pd.DataFrame(x, columns=mod_keys, index=fmod_adata.obs_names),
                    how="left",
                )
        # Plot using the cleaned names, which are the column names joined above.
        color = [mod_key_modifier[k] for k in keys]
    # Temporary AnnData that holds the (possibly extended) obs table.
    ad = AnnData(obs=obs, obsm=adata.obsm, obsp=adata.obsp, uns=adata.uns)
    retval = embedding(ad, basis=basis_mod, color=color, **kwargs)
    # Copy any "<col>_colors" palette found on the temporary object back to
    # the source object's .uns, keyed by the originally requested name.
    for key, col in zip(keys, color):
        try:
            adata.uns[f"{key}_colors"] = ad.uns[f"{col}_colors"]
        except KeyError:
            # No palette stored under this key; nothing to copy.
            pass
    return retval
Cell Proportion and Analysis¶
omicverse.pl.cellproportion(adata, celltype_clusters, groupby, groupby_li=None, figsize=(4, 6), ticks_fontsize=12, labels_fontsize=12, ax=None, legend=False, legend_awargs={'ncol': 1}, transpose=False)
¶
Plot cell proportion of each cell type in each visual cluster.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
adata |
AnnData
|
AnnData object. |
required |
celltype_clusters |
str
|
Cell type clusters. |
required |
groupby |
str
|
Visual clusters. |
required |
groupby_li |
Visual cluster list. (None) |
None
|
|
figsize |
tuple
|
Figure size. ((4,6)) |
(4, 6)
|
ticks_fontsize |
int
|
Ticks fontsize. (12) |
12
|
labels_fontsize |
int
|
Labels fontsize. (12) |
12
|
ax |
Matplotlib axes object. (None) |
None
|
|
legend |
bool
|
Whether to show legend. (False) |
False
|
legend_awargs |
Legend arguments. ({'ncol':1}) |
{'ncol': 1}
|
|
transpose |
bool
|
Whether to transpose the plot (horizontal bars). (False) |
False
|
Returns:
Type | Description |
---|---|
None |
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_single.py
def cellproportion(adata:AnnData,celltype_clusters:str,groupby:str,
                   groupby_li=None,figsize:tuple=(4,6),
                   ticks_fontsize:int=12,labels_fontsize:int=12,ax=None,
                   legend:bool=False,legend_awargs={'ncol':1},transpose:bool=False):
    r"""Plot cell proportion of each cell type in each visual cluster.

    For every group in ``groupby`` the fraction of cells belonging to each
    category of ``celltype_clusters`` is drawn as one stacked bar.

    Arguments:
        adata: AnnData object. ``adata.obs[celltype_clusters]`` must be
            categorical and ``adata.uns['<celltype_clusters>_colors']`` must
            hold one color per category.
        celltype_clusters: Cell type clusters.
        groupby: Visual clusters.
        groupby_li: Visual cluster list. (None) When None, ``adata.obs[groupby]``
            is cast to categorical in place and all of its categories are used.
        figsize: Figure size. ((4,6)) Only used when ``ax`` is None.
        ticks_fontsize: Ticks fontsize. (12)
        labels_fontsize: Labels fontsize. (12)
        ax: Matplotlib axes object. (None)
        legend: Whether to show legend. (False)
        legend_awargs: Legend arguments. ({'ncol':1})
        transpose: Whether to transpose the plot (horizontal bars). (False)

    Returns:
        ``(fig, ax)`` when the function created the figure itself (``ax`` was
        None); None when drawing on a caller-supplied axes.
    """
    if groupby_li is None:
        adata.obs[groupby] = adata.obs[groupby].astype('category')
        groupby_li = adata.obs[groupby].cat.categories

    # One small table per group: cell type, its fraction and the bar label.
    per_group = []
    for group in groupby_li:
        counts = adata.obs.loc[adata.obs[groupby] == group, celltype_clusters].value_counts()
        per_group.append(pd.DataFrame({
            'cell_type': counts.index,
            'value': counts.values / counts.sum(),
            # The 'Retinoblastoma_' prefix is stripped from labels (kept for
            # backward compatibility); str() lets non-string labels through.
            'Week': str(group).replace('Retinoblastoma_', ''),
        }))
    if per_group:
        proportions = pd.concat(per_group, ignore_index=True)
    else:
        proportions = pd.DataFrame(columns=['cell_type', 'value', 'Week'])

    celltypes = adata.obs[celltype_clusters].cat.categories
    celltype_colors = dict(zip(celltypes,
                               adata.uns['{}_colors'.format(celltype_clusters)]))

    # Remember whether the figure is ours: `ax` is rebound below, so testing
    # it again at the end could never tell the two cases apart.
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=figsize)

    # Running top (or right edge) of the bars drawn so far, one entry per group.
    stacked = None
    for celltype in celltypes:
        rows = proportions[proportions['cell_type'] == celltype]
        positions = rows['Week'].to_numpy()
        # Fresh float array: never accumulate into a buffer owned by pandas.
        extents = rows['value'].to_numpy(dtype=float)
        if transpose:
            ax.barh(y=positions, width=extents, left=stacked, height=0.8,
                    color=celltype_colors[celltype], label=celltype)
        else:
            ax.bar(x=positions, height=extents, bottom=stacked, width=0.8,
                   color=celltype_colors[celltype], label=celltype)
        stacked = extents.copy() if stacked is None else stacked + extents

    # All styling goes through `ax` so a caller-supplied axes is the one that
    # gets decorated, even when it is not matplotlib's current axes.
    if legend:
        ax.legend(bbox_to_anchor=(1.05, -0.05), loc=3, borderaxespad=0,
                  fontsize=10, **legend_awargs)
    ax.grid(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(True)
    # Draw the left and bottom axis lines as detached segments.
    ax.spines['left'].set_position(('outward', 10))
    ax.spines['bottom'].set_position(('outward', 10))

    if transpose:
        ax.tick_params(axis='y', labelsize=ticks_fontsize, labelrotation=0)
        ax.tick_params(axis='x', labelsize=ticks_fontsize)
        ax.set_ylabel(groupby, fontsize=labels_fontsize)
        ax.set_xlabel('Cells per Stage', fontsize=labels_fontsize)
    else:
        ax.tick_params(axis='x', labelsize=ticks_fontsize, labelrotation=90)
        ax.tick_params(axis='y', labelsize=ticks_fontsize)
        ax.set_xlabel(groupby, fontsize=labels_fontsize)
        ax.set_ylabel('Cells per Stage', fontsize=labels_fontsize)

    if created_fig:
        return fig, ax
Bulk Data Visualization¶
omicverse.pl.volcano(result, pval_name='qvalue', fc_name='log2FC', pval_max=None, FC_max=None, figsize=(4, 4), title='', titlefont={'weight': 'normal', 'size': 14}, up_color='#e25d5d', down_color='#7388c1', normal_color='#d7d7d7', up_fontcolor='#e25d5d', down_fontcolor='#7388c1', normal_fontcolor='#d7d7d7', legend_bbox=(0.8, -0.2), legend_ncol=2, legend_fontsize=12, plot_genes=None, plot_genes_num=10, plot_genes_fontsize=10, ticks_fontsize=12, pval_threshold=0.05, fc_max=1.5, fc_min=-1.5, ax=None)
¶
Create a volcano plot for differential expression analysis.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
result |
pandas.DataFrame Dataframe containing differential expression results with 'sig' column |
required | |
pval_name |
str, optional (default='qvalue') Column name for p-values/q-values |
'qvalue'
|
|
fc_name |
str, optional (default='log2FC') Column name for fold change values |
'log2FC'
|
|
pval_max |
float, optional (default=None) Upper cap applied to the plotted -log10(p-value) values on the y-axis; larger values are clipped to this cap |
None
|
|
FC_max |
float, optional (default=None) Absolute cap for the plotted fold change on the x-axis; values beyond ±FC_max are clipped to ±FC_max |
None
|
|
figsize |
tuple
|
tuple, optional (default=(4,4)) Figure size (width, height) |
(4, 4)
|
title |
str
|
str, optional (default='') Plot title |
''
|
titlefont |
dict
|
dict, optional (default={'weight':'normal','size':14}) Font settings for title and axis labels |
{'weight': 'normal', 'size': 14}
|
up_color |
str
|
str, optional (default='#e25d5d') Color for upregulated genes |
'#e25d5d'
|
down_color |
str
|
str, optional (default='#7388c1') Color for downregulated genes |
'#7388c1'
|
normal_color |
str
|
str, optional (default='#d7d7d7') Color for non-significant genes |
'#d7d7d7'
|
up_fontcolor |
str
|
str, optional (default='#e25d5d') Font color for upregulated gene labels |
'#e25d5d'
|
down_fontcolor |
str
|
str, optional (default='#7388c1') Font color for downregulated gene labels |
'#7388c1'
|
normal_fontcolor |
str
|
str, optional (default='#d7d7d7') Font color for normal gene labels |
'#d7d7d7'
|
legend_bbox |
tuple
|
tuple, optional (default=(0.8, -0.2)) Legend bounding box position |
(0.8, -0.2)
|
legend_ncol |
int
|
int, optional (default=2) Number of legend columns |
2
|
legend_fontsize |
int
|
int, optional (default=12) Legend font size |
12
|
plot_genes |
list
|
list, optional (default=None) Specific genes to label on plot |
None
|
plot_genes_num |
int
|
int, optional (default=10) Number of top genes to label automatically |
10
|
plot_genes_fontsize |
int
|
int, optional (default=10) Font size for gene labels |
10
|
ticks_fontsize |
int
|
int, optional (default=12) Font size for axis ticks |
12
|
pval_threshold |
float
|
float, optional (default=0.05) P-value threshold for significance |
0.05
|
fc_max |
float
|
float, optional (default=1.5) Upper fold change threshold |
1.5
|
fc_min |
float
|
float, optional (default=-1.5) Lower fold change threshold |
-1.5
|
ax |
matplotlib.axes, optional (default=None) Existing axes to plot on |
None
|
Returns:
Name | Type | Description |
---|---|---|
ax | matplotlib.axes The plot axes object |
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_bulk.py
def volcano(result,pval_name='qvalue',fc_name='log2FC',pval_max=None,FC_max=None,
            figsize:tuple=(4,4),title:str='',titlefont:dict={'weight':'normal','size':14,},
            up_color:str='#e25d5d',down_color:str='#7388c1',normal_color:str='#d7d7d7',
            up_fontcolor:str='#e25d5d',down_fontcolor:str='#7388c1',normal_fontcolor:str='#d7d7d7',
            legend_bbox:tuple=(0.8, -0.2),legend_ncol:int=2,legend_fontsize:int=12,
            plot_genes:list=None,plot_genes_num:int=10,plot_genes_fontsize:int=10,
            ticks_fontsize:int=12,pval_threshold:float=0.05,fc_max:float=1.5,fc_min:float=-1.5,
            ax = None,):
    r"""
    Create a volcano plot for differential expression analysis.

    A summary of the input and parameter suggestions are printed to stdout
    before the plot is drawn. The input frame is copied, never modified.

    Arguments:
        result: pandas.DataFrame
            Dataframe containing differential expression results with 'sig' column
            (values 'up', 'down' or 'normal'); the index supplies the gene labels
        pval_name: str, optional (default='qvalue')
            Column name for p-values/q-values
        fc_name: str, optional (default='log2FC')
            Column name for fold change values
        pval_max: float, optional (default=None)
            Upper cap applied to the plotted -log10(p-value) values
        FC_max: float, optional (default=None)
            Absolute cap for the plotted fold change (clipped to +/-FC_max)
        figsize: tuple, optional (default=(4,4))
            Figure size (width, height); only used when `ax` is None
        title: str, optional (default='')
            Plot title
        titlefont: dict, optional (default={'weight':'normal','size':14})
            Font settings for title and axis labels
        up_color: str, optional (default='#e25d5d')
            Color for upregulated genes
        down_color: str, optional (default='#7388c1')
            Color for downregulated genes
        normal_color: str, optional (default='#d7d7d7')
            Color for non-significant genes
        up_fontcolor: str, optional (default='#e25d5d')
            Font color for upregulated gene labels
        down_fontcolor: str, optional (default='#7388c1')
            Font color for downregulated gene labels
        normal_fontcolor: str, optional (default='#d7d7d7')
            Font color for normal gene labels
        legend_bbox: tuple, optional (default=(0.8, -0.2))
            Legend bounding box position
        legend_ncol: int, optional (default=2)
            Number of legend columns
        legend_fontsize: int, optional (default=12)
            Legend font size
        plot_genes: list, optional (default=None)
            Specific genes to label on plot
        plot_genes_num: int, optional (default=10)
            Number of top genes to label automatically (half up, half down).
            Pass None together with plot_genes=None to draw no labels
        plot_genes_fontsize: int, optional (default=10)
            Font size for gene labels
        ticks_fontsize: int, optional (default=12)
            Font size for axis ticks
        pval_threshold: float, optional (default=0.05)
            P-value threshold for significance (drawn as a horizontal line)
        fc_max: float, optional (default=1.5)
            Upper fold change threshold (drawn as a vertical line)
        fc_min: float, optional (default=-1.5)
            Lower fold change threshold (drawn as a vertical line)
        ax: matplotlib.axes, optional (default=None)
            Existing axes to plot on

    Returns:
        ax: matplotlib.axes
            The plot axes object

    Raises:
        ValueError: If `pval_name`, `fc_name` or 'sig' is not a column of `result`.
    """
    # Color codes for terminal output
    class Colors:
        HEADER = '\033[95m'
        BLUE = '\033[94m'
        CYAN = '\033[96m'
        GREEN = '\033[92m'
        WARNING = '\033[93m'
        FAIL = '\033[91m'
        ENDC = '\033[0m'
        BOLD = '\033[1m'
        UNDERLINE = '\033[4m'

    # Analyze the input data
    print(f"{Colors.HEADER}{Colors.BOLD}🌋 Volcano Plot Analysis:{Colors.ENDC}")
    print(f" {Colors.CYAN}Total genes: {Colors.BOLD}{len(result)}{Colors.ENDC}")

    # Check required columns
    required_cols = [pval_name, fc_name, 'sig']
    missing_cols = [col for col in required_cols if col not in result.columns]
    if missing_cols:
        print(f" {Colors.FAIL}❌ Missing required columns: {Colors.BOLD}{missing_cols}{Colors.ENDC}")
        raise ValueError(f"Missing required columns: {missing_cols}")

    # Calculate gene counts by significance
    sig_counts = result['sig'].value_counts()
    total_sig = sig_counts.get('up', 0) + sig_counts.get('down', 0)
    print(f" {Colors.GREEN}↗️ Upregulated genes: {Colors.BOLD}{sig_counts.get('up', 0)}{Colors.ENDC}")
    print(f" {Colors.BLUE}↘️ Downregulated genes: {Colors.BOLD}{sig_counts.get('down', 0)}{Colors.ENDC}")
    print(f" {Colors.CYAN}➡️ Non-significant genes: {Colors.BOLD}{sig_counts.get('normal', 0)}{Colors.ENDC}")
    print(f" {Colors.WARNING}🎯 Total significant genes: {Colors.BOLD}{total_sig}{Colors.ENDC}")

    # Data range information
    fc_range = result[fc_name].max() - result[fc_name].min()
    print(f" {Colors.BLUE}{fc_name} range: {Colors.BOLD}{result[fc_name].min():.2f} to {result[fc_name].max():.2f}{Colors.ENDC}")
    print(f" {Colors.BLUE}{pval_name} range: {Colors.BOLD}{result[pval_name].min():.2e} to {result[pval_name].max():.2e}{Colors.ENDC}")

    # Display current function parameters
    print(f"\n{Colors.HEADER}{Colors.BOLD}⚙️ Current Function Parameters:{Colors.ENDC}")
    print(f" {Colors.BLUE}Data columns: pval_name='{pval_name}', fc_name='{fc_name}'{Colors.ENDC}")
    print(f" {Colors.BLUE}Thresholds: pval_threshold={Colors.BOLD}{pval_threshold}{Colors.ENDC}{Colors.BLUE}, fc_max={Colors.BOLD}{fc_max}{Colors.ENDC}{Colors.BLUE}, fc_min={Colors.BOLD}{fc_min}{Colors.ENDC}")
    print(f" {Colors.BLUE}Plot size: figsize={Colors.BOLD}{figsize}{Colors.ENDC}")
    print(f" {Colors.BLUE}Gene labels: plot_genes_num={Colors.BOLD}{plot_genes_num}{Colors.ENDC}{Colors.BLUE}, plot_genes_fontsize={Colors.BOLD}{plot_genes_fontsize}{Colors.ENDC}")
    if plot_genes is not None:
        print(f" {Colors.BLUE}Custom genes: {Colors.BOLD}{len(plot_genes)} specified{Colors.ENDC}")
    else:
        print(f" {Colors.BLUE}Custom genes: {Colors.BOLD}None{Colors.ENDC}{Colors.BLUE} (auto-select top genes){Colors.ENDC}")

    # Parameter optimization suggestions
    print(f"\n{Colors.HEADER}{Colors.BOLD}💡 Parameter Optimization Suggestions:{Colors.ENDC}")

    # plot_genes_num=None means "no automatic labels"; treat it as zero labels
    # here so the numeric checks below cannot raise TypeError.
    label_count = 0 if plot_genes_num is None else plot_genes_num

    # Evaluate every check once; both the advice text and the generated call
    # below are driven by the same flags.
    few_significant = total_sig < 10
    wide_fc_range = fc_range > 10 and (fc_max <= 2 or abs(fc_min) <= 2)
    crowded_labels = label_count > 20 and figsize[0] < 6
    tiny_label_font = plot_genes_fontsize < 8 and label_count > 15
    unbalanced_figure = abs(figsize[0] - figsize[1]) > 2
    if wide_fc_range:
        new_fc_max = min(round(result[fc_name].quantile(0.95), 1), 4.0)
        new_fc_min = max(round(result[fc_name].quantile(0.05), 1), -4.0)
    if unbalanced_figure:
        optimal_size = max(figsize)

    suggestions = []
    # Check if there are enough significant genes
    if few_significant:
        suggestions.append(f" {Colors.WARNING}▶ Few significant genes detected ({total_sig}):{Colors.ENDC}")
        suggestions.append(f" {Colors.CYAN}Current: pval_threshold={Colors.BOLD}{pval_threshold}{Colors.ENDC}")
        suggestions.append(f" {Colors.GREEN}Suggested: {Colors.BOLD}pval_threshold=0.1{Colors.ENDC} or {Colors.BOLD}pval_threshold=0.2{Colors.ENDC}")
    # Check fold change thresholds
    if wide_fc_range:
        suggestions.append(f" {Colors.WARNING}▶ Wide fold change range detected:{Colors.ENDC}")
        suggestions.append(f" {Colors.CYAN}Current: fc_max={Colors.BOLD}{fc_max}{Colors.ENDC}{Colors.CYAN}, fc_min={Colors.BOLD}{fc_min}{Colors.ENDC}")
        suggestions.append(f" {Colors.GREEN}Suggested: {Colors.BOLD}fc_max={new_fc_max}, fc_min={new_fc_min}{Colors.ENDC}")
    # Check plot size based on gene label settings
    if crowded_labels:
        suggestions.append(f" {Colors.WARNING}▶ Many gene labels with small plot:{Colors.ENDC}")
        suggestions.append(f" {Colors.CYAN}Current: plot_genes_num={Colors.BOLD}{plot_genes_num}{Colors.ENDC}{Colors.CYAN}, figsize={Colors.BOLD}{figsize}{Colors.ENDC}")
        suggestions.append(f" {Colors.GREEN}Suggested: {Colors.BOLD}figsize=(6, 6){Colors.ENDC} or {Colors.BOLD}plot_genes_num=15{Colors.ENDC}")
    # Check if gene labels might be too small
    if tiny_label_font:
        suggestions.append(f" {Colors.WARNING}▶ Small font with many labels:{Colors.ENDC}")
        suggestions.append(f" {Colors.CYAN}Current: plot_genes_fontsize={Colors.BOLD}{plot_genes_fontsize}{Colors.ENDC}{Colors.CYAN}, plot_genes_num={Colors.BOLD}{plot_genes_num}{Colors.ENDC}")
        suggestions.append(f" {Colors.GREEN}Suggested: {Colors.BOLD}plot_genes_fontsize=10{Colors.ENDC} or {Colors.BOLD}plot_genes_num=10{Colors.ENDC}")
    # Check figure aspect ratio
    if unbalanced_figure:
        suggestions.append(f" {Colors.BLUE}▶ Unbalanced figure dimensions:{Colors.ENDC}")
        suggestions.append(f" {Colors.CYAN}Current: figsize={Colors.BOLD}{figsize}{Colors.ENDC}")
        suggestions.append(f" {Colors.GREEN}Suggested: {Colors.BOLD}figsize=({optimal_size}, {optimal_size}){Colors.ENDC}")

    if suggestions:
        for suggestion in suggestions:
            print(suggestion)
        print(f"\n {Colors.BOLD}📋 Copy-paste ready function call:{Colors.ENDC}")
        # Generate optimized function call
        optimized_params = ["result"]
        # Add data column parameters if different from defaults
        if pval_name != 'qvalue':
            optimized_params.append(f"pval_name='{pval_name}'")
        if fc_name != 'log2FC':
            optimized_params.append(f"fc_name='{fc_name}'")
        # Add optimized parameters based on suggestions
        if few_significant:
            optimized_params.append("pval_threshold=0.1")
        if wide_fc_range:
            optimized_params.append(f"fc_max={new_fc_max}")
            optimized_params.append(f"fc_min={new_fc_min}")
        # Only one figsize may be emitted; the crowded-label advice wins.
        if crowded_labels:
            optimized_params.append("figsize=(6, 6)")
        elif unbalanced_figure:
            optimized_params.append(f"figsize=({optimal_size}, {optimal_size})")
        if tiny_label_font:
            optimized_params.append("plot_genes_fontsize=10")
        if plot_genes is not None:
            optimized_params.append(f"plot_genes={plot_genes}")
        optimized_call = f" {Colors.GREEN}ov.pl.volcano({', '.join(optimized_params)}){Colors.ENDC}"
        print(optimized_call)
    else:
        print(f" {Colors.GREEN}✅ Current parameters are optimal for your data!{Colors.ENDC}")
    print(f"{Colors.CYAN}{'─' * 60}{Colors.ENDC}")

    # Work on a copy with two helper columns holding the plotted coordinates.
    result = result.copy()
    result['-log(qvalue)'] = -np.log10(result[pval_name])
    result['log2FC'] = result[fc_name].copy()
    if pval_max is not None:
        result.loc[result['-log(qvalue)'] > pval_max, '-log(qvalue)'] = pval_max
    if FC_max is not None:
        result.loc[result['log2FC'] > FC_max, 'log2FC'] = FC_max
        result.loc[result['log2FC'] < -FC_max, 'log2FC'] = 0 - FC_max

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # Non-significant genes first so the colored points are drawn on top.
    for sig_label, point_color in (('normal', normal_color),
                                   ('up', up_color),
                                   ('down', down_color)):
        subset = result[result['sig'] == sig_label]
        ax.scatter(x=subset['log2FC'],
                   y=subset['-log(qvalue)'],
                   color=point_color,
                   alpha=.5,
                   )

    # Dashed guide lines: significance threshold and both fold-change cut-offs.
    guide_style = dict(linewidth=2, linestyle="--", color='black')
    y_threshold = -np.log10(pval_threshold)
    ax.plot([result['log2FC'].min(), result['log2FC'].max()],
            [y_threshold, y_threshold],
            **guide_style)
    y_span = [result['-log(qvalue)'].min(), result['-log(qvalue)'].max()]
    for fc_cutoff in (fc_max, fc_min):
        ax.plot([fc_cutoff, fc_cutoff], y_span, **guide_style)

    # Axis labels and title
    ax.set_ylabel(r'$-log_{10}(qvalue)$',titlefont)
    ax.set_xlabel(r'$log_{2}FC$',titlefont)
    ax.set_title(title,titlefont)

    # Legend: one colored patch per direction, annotated with the gene count.
    labels = ['up:{0}'.format(len(result[result['sig']=='up'])),
              'down:{0}'.format(len(result[result['sig']=='down']))]
    patches = [mpatches.Patch(color=patch_color, label="{:s}".format(patch_label))
               for patch_color, patch_label in zip([up_color, down_color], labels)]
    ax.legend(handles=patches,
              bbox_to_anchor=legend_bbox,
              ncol=legend_ncol,
              fontsize=legend_fontsize)

    # Drop the outermost automatic ticks and relabel the remaining ones.
    inner_ticks = [round(i,2) for i in ax.get_xticks()[1:-1]]
    ax.set_xticks(inner_ticks,
                  inner_ticks,
                  fontsize=ticks_fontsize,
                  fontweight='normal'
                  )

    # Style the axes handed to us, not whatever axes is currently active.
    ax.grid(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(True)
    ax.spines['left'].set_visible(True)
    ax.spines['left'].set_position(('outward', 10))
    ax.spines['bottom'].set_position(('outward', 10))

    # Decide which genes get a text label.
    if plot_genes is not None:
        hub_gene = plot_genes
    elif plot_genes_num is None:
        # Labels explicitly disabled.
        return ax
    else:
        up_result = result.loc[result['sig']=='up']
        down_result = result.loc[result['sig']=='down']
        per_side = plot_genes_num // 2
        hub_gene = (up_result.sort_values(pval_name).index[:per_side].tolist()
                    + down_result.sort_values(pval_name).index[:per_side].tolist())

    color_dict={
        'up':up_fontcolor,
        'down':down_fontcolor,
        'normal':normal_fontcolor
    }
    texts=[ax.text(result.loc[i,'log2FC'],
                   result.loc[i,'-log(qvalue)'],
                   i,
                   fontdict={'size':plot_genes_fontsize,'weight':'bold','color':color_dict[result.loc[i,'sig']]}
                   ) for i in hub_gene]
    if not texts:
        # Nothing to de-overlap.
        return ax

    # Imported only when labels are actually drawn, so label-free plots do
    # not require the optional dependency.
    from adjustText import adjust_text
    import adjustText
    # NOTE(review): string comparison of version numbers; it separates the
    # 0.x series from 1.x, but is not a general version ordering.
    if adjustText.__version__<='0.8':
        adjust_text(texts,only_move={'text': 'xy'},arrowprops=dict(arrowstyle='->', color='red'),)
    else:
        adjust_text(texts,only_move={"text": "xy", "static": "xy", "explode": "xy", "pull": "xy"},
                    arrowprops=dict(arrowstyle='->', color='red'))
    return ax
omicverse.pl.venn(sets={}, out='./', palette='bgrc', ax=False, ext='png', dpi=300, fontsize=3.5, bbox_to_anchor=(0.5, 0.99), nc=2, cs=4)
¶
Create a Venn diagram to visualize set overlaps.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sets |
Dictionary with set names as keys and sets as values ({}) |
{}
|
|
out |
Output directory path ('./') |
'./'
|
|
palette |
Color palette for sets ('bgrc') |
'bgrc'
|
|
ax |
Matplotlib axes object or False to create new (False) |
False
|
|
ext |
File extension for output ('png') |
'png'
|
|
dpi |
Resolution for output image (300) |
300
|
|
fontsize |
Font size for text (3.5) |
3.5
|
|
bbox_to_anchor |
Legend position ((.5, .99)) |
(0.5, 0.99)
|
|
nc |
Number of legend columns (2) |
2
|
|
cs |
Font size for legend (4) |
4
|
Returns:
Name | Type | Description |
---|---|---|
ax | The `ax` argument exactly as it was passed in (a matplotlib.axes.Axes object, or False when no axes was supplied) |
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_bulk.py
def venn(sets={}, out='./', palette='bgrc',
         ax=False, ext='png', dpi=300, fontsize=3.5,
         bbox_to_anchor=(.5, .99),nc=2,cs=4):
    r"""
    Draw a Venn diagram of set overlaps by delegating to ``venny4py``.

    Arguments:
        sets: Dictionary with set names as keys and sets as values ({})
        out: Output directory path ('./')
        palette: Color palette for sets ('bgrc'); forwarded as ``ce``
        ax: Matplotlib axes object or False to create new (False); forwarded as ``asax``
        ext: File extension for output ('png')
        dpi: Resolution for output image (300)
        fontsize: Font size for text (3.5); forwarded as ``size``
        bbox_to_anchor: Legend position ((.5, .99))
        nc: Number of legend columns (2)
        cs: Font size for legend (4)

    Returns:
        ax: the ``ax`` argument exactly as it was passed in
    """
    from ..utils import venny4py

    # Translate this function's parameter names to the helper's keywords.
    forwarded = dict(
        sets=sets,
        out=out,
        ce=palette,
        asax=ax,
        ext=ext,
        dpi=dpi,
        size=fontsize,
        bbox_to_anchor=bbox_to_anchor,
        nc=nc,
        cs=cs,
    )
    venny4py(**forwarded)
    return ax
omicverse.pl.boxplot(data, hue, x_value, y_value, width=0.3, title='', figsize=(6, 3), palette=None, fontsize=10, legend_bbox=(1, 0.55), legend_ncol=1, hue_order=None)
¶
Create a boxplot with jittered points to visualize data distribution across categories.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data |
DataFrame containing the data to plot |
required | |
hue |
Column name for grouping variable (color coding) |
required | |
x_value |
Column name for x-axis categories |
required | |
y_value |
Column name for y-axis values |
required | |
width |
Width of each boxplot (0.3) |
0.3
|
|
title |
Plot title ('') |
''
|
|
figsize |
Figure dimensions as (width, height) ((6,3)) |
(6, 3)
|
|
palette |
Color palette for groups (None, uses default) |
None
|
|
fontsize |
Font size for labels and ticks (10) |
10
|
|
legend_bbox |
Legend position as (x, y) ((1, 0.55)) |
(1, 0.55)
|
|
legend_ncol |
Number of legend columns (1) |
1
|
|
hue_order |
Custom order for hue categories (None, uses alphabetical) |
None
|
Returns:
Name | Type | Description |
---|---|---|
fig | matplotlib.figure.Figure object |
|
ax | matplotlib.axes.Axes object |
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_bulk.py
def boxplot(data,hue,x_value,y_value,width=0.3,title='',
figsize=(6,3),palette=None,fontsize=10,
legend_bbox=(1, 0.55),legend_ncol=1,hue_order=None):
r"""
Create a boxplot with jittered points to visualize data distribution across categories.
Arguments:
data: DataFrame containing the data to plot
hue: Column name for grouping variable (color coding)
x_value: Column name for x-axis categories
y_value: Column name for y-axis values
width: Width of each boxplot (0.3)
title: Plot title ('')
figsize: Figure dimensions as (width, height) ((6,3))
palette: Color palette for groups (None, uses default)
fontsize: Font size for labels and ticks (10)
legend_bbox: Legend position as (x, y) ((1, 0.55))
legend_ncol: Number of legend columns (1)
hue_order: Custom order for hue categories (None, uses alphabetical)
Returns:
fig: matplotlib.figure.Figure object
ax: matplotlib.axes.Axes object
"""
# Color codes for terminal output
class Colors:
HEADER = '\033[95m'
BLUE = '\033[94m'
CYAN = '\033[96m'
GREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
# Print data information for user guidance
print(f"{Colors.HEADER}{Colors.BOLD}📊 Boxplot Data Analysis:{Colors.ENDC}")
print(f" {Colors.CYAN}Total samples: {Colors.BOLD}{len(data)}{Colors.ENDC}")
print(f" {Colors.BLUE}X-axis variable ('{x_value}'): {Colors.BOLD}{sorted(data[x_value].unique())}{Colors.ENDC}")
print(f" {Colors.BLUE}Hue variable ('{hue}'): {Colors.BOLD}{sorted(data[hue].unique())}{Colors.ENDC}")
print(f" {Colors.BLUE}Y-axis variable: '{y_value}' (range: {Colors.BOLD}{data[y_value].min():.2f} - {data[y_value].max():.2f}{Colors.ENDC})")
# Check for missing data
missing_data = data[[hue, x_value, y_value]].isnull().sum().sum()
if missing_data > 0:
print(f" {Colors.WARNING}⚠️ Warning: Found {Colors.BOLD}{missing_data}{Colors.ENDC}{Colors.WARNING} missing values in key columns{Colors.ENDC}")
# Display current function parameters
print(f"\n{Colors.HEADER}{Colors.BOLD}⚙️ Current Function Parameters:{Colors.ENDC}")
print(f" {Colors.BLUE}hue='{hue}', x_value='{x_value}', y_value='{y_value}'{Colors.ENDC}")
print(f" {Colors.BLUE}width={Colors.BOLD}{width}{Colors.ENDC}{Colors.BLUE}, figsize={Colors.BOLD}{figsize}{Colors.ENDC}{Colors.BLUE}, fontsize={Colors.BOLD}{fontsize}{Colors.ENDC}")
if hue_order is not None:
print(f" {Colors.BLUE}hue_order={Colors.BOLD}{hue_order}{Colors.ENDC}")
else:
print(f" {Colors.BLUE}hue_order={Colors.BOLD}None{Colors.ENDC}{Colors.BLUE} (using alphabetical order){Colors.ENDC}")
def calculate_box_positions(n_hues, spacing=0.8):
    """
    Return evenly spaced box offsets centred on zero.

    Parameters
    ----------
    n_hues : int
        Number of hue categories
    spacing : float
        Total width covered by the offsets; the first offset is
        ``-spacing / 2`` and the last is ``+spacing / 2``
        (0.8 means use 80% of the [-0.5, 0.5] range)

    Returns
    -------
    positions : list
        One offset per hue; ``[0.0]`` when ``n_hues`` is 1 or less
    """
    # With at most one box there is nothing to spread out.
    if n_hues <= 1:
        return [0.0]
    half_span = spacing / 2
    gap = spacing / (n_hues - 1)
    return [-half_span + slot * gap for slot in range(n_hues)]
#获取需要分割的数据
if hue_order is not None:
hue_datas = hue_order
# Check if all hue values in data are in hue_order
data_hue_values = set(data[hue].unique())
hue_order_set = set(hue_order)
if not data_hue_values.issubset(hue_order_set):
missing_values = data_hue_values - hue_order_set
raise ValueError(f"The following hue values are in data but not in hue_order: {missing_values}")
print(f" {Colors.GREEN}📋 Using custom hue order: {Colors.BOLD}{hue_order}{Colors.ENDC}")
else:
hue_datas = sorted(list(set(data[hue])))
print(f" {Colors.GREEN}📋 Using alphabetical hue order: {Colors.BOLD}{hue_datas}{Colors.ENDC}")
#获取箱线图的横坐标
x=x_value
ticks=sorted(list(set(data[x])))
# Calculate box positions
box_positions = calculate_box_positions(len(hue_datas))
print(f"\n{Colors.HEADER}{Colors.BOLD}🎯 Box Positioning:{Colors.ENDC}")
print(f" {Colors.CYAN}Number of hue groups: {Colors.BOLD}{len(hue_datas)}{Colors.ENDC}")
print(f" {Colors.CYAN}Box positions: {Colors.BOLD}{[round(pos, 3) for pos in box_positions]}{Colors.ENDC}")
print(f" {Colors.CYAN}Box width: {Colors.BOLD}{width}{Colors.ENDC}")
# Calculate sample sizes for each combination
print(f"\n{Colors.HEADER}{Colors.BOLD}📈 Sample sizes per group:{Colors.ENDC}")
for hue_cat in hue_datas:
for x_cat in ticks:
count = len(data[(data[hue] == hue_cat) & (data[x] == x_cat)])
if count < 5:
color = Colors.WARNING
elif count < 10:
color = Colors.BLUE
else:
color = Colors.GREEN
print(f" {color}{hue_cat} × {x_cat}: {Colors.BOLD}{count}{Colors.ENDC}{color} samples{Colors.ENDC}")
# Provide parameter suggestions with current vs suggested comparison
print(f"\n{Colors.HEADER}{Colors.BOLD}💡 Parameter Optimization Suggestions:{Colors.ENDC}")
suggestions = []
if len(hue_datas) > 4:
suggested_width = round(max(0.1, 0.8 / len(hue_datas)), 1)
suggested_figsize_width = max(8, len(ticks) * 2)
suggestions.append(f" {Colors.WARNING}▶ Many hue groups detected:{Colors.ENDC}")
suggestions.append(f" {Colors.CYAN}Current: {Colors.BOLD}width={width}{Colors.ENDC}{Colors.CYAN}, figsize={figsize}{Colors.ENDC}")
suggestions.append(f" {Colors.GREEN}Suggested: {Colors.BOLD}width={suggested_width}, figsize=({suggested_figsize_width}, {figsize[1]}){Colors.ENDC}")
if len(ticks) > 5:
suggested_figsize_width = max(10, len(ticks) * 1.5)
suggestions.append(f" {Colors.WARNING}▶ Many x-categories detected:{Colors.ENDC}")
suggestions.append(f" {Colors.CYAN}Current: {Colors.BOLD}figsize={figsize}{Colors.ENDC}")
suggestions.append(f" {Colors.GREEN}Suggested: {Colors.BOLD}figsize=({suggested_figsize_width}, {figsize[1]}){Colors.ENDC}")
max_samples = max([len(data[(data[hue] == h) & (data[x] == x_cat)]) for h in hue_datas for x_cat in ticks])
if max_samples < 5:
suggested_width = round(max(0.1, width * 0.7), 1)
suggestions.append(f" {Colors.WARNING}▶ Small sample sizes detected:{Colors.ENDC}")
suggestions.append(f" {Colors.CYAN}Current: {Colors.BOLD}width={width}{Colors.ENDC}")
suggestions.append(f" {Colors.GREEN}Suggested: {Colors.BOLD}width={suggested_width}{Colors.ENDC}")
# Check if current width might cause overlap
if len(hue_datas) > 3 and width > 0.25:
suggested_width = round(0.8 / len(hue_datas), 1)
suggestions.append(f" {Colors.FAIL}▶ Box overlap risk detected:{Colors.ENDC}")
suggestions.append(f" {Colors.CYAN}Current: {Colors.BOLD}width={width}{Colors.ENDC}{Colors.CYAN} (too wide for {len(hue_datas)} groups){Colors.ENDC}")
suggestions.append(f" {Colors.GREEN}Suggested: {Colors.BOLD}width={suggested_width}{Colors.ENDC}")
# Figure size optimization based on both dimensions
if len(ticks) > 3 and len(hue_datas) > 3:
suggested_width = max(8, len(ticks) * 1.5)
suggested_height = max(4, figsize[1])
suggestions.append(f" {Colors.BLUE}▶ Complex plot optimization:{Colors.ENDC}")
suggestions.append(f" {Colors.CYAN}Current: {Colors.BOLD}figsize={figsize}{Colors.ENDC}")
suggestions.append(f" {Colors.GREEN}Suggested: {Colors.BOLD}figsize=({suggested_width}, {suggested_height}){Colors.ENDC}")
if suggestions:
for suggestion in suggestions:
print(suggestion)
print(f"\n {Colors.BOLD}📋 Copy-paste ready function call:{Colors.ENDC}")
# Generate optimized function call
optimized_params = []
optimized_params.append(f"data, hue='{hue}', x_value='{x_value}', y_value='{y_value}'")
# Add optimized parameters
if len(hue_datas) > 4 or (len(hue_datas) > 3 and width > 0.25) or max_samples < 5:
if len(hue_datas) > 4:
opt_width = round(max(0.1, 0.8 / len(hue_datas)), 1)
elif max_samples < 5:
opt_width = round(max(0.1, width * 0.7), 1)
else:
opt_width = round(0.8 / len(hue_datas), 1)
optimized_params.append(f"width={opt_width}")
if len(ticks) > 5 or len(hue_datas) > 4 or (len(ticks) > 3 and len(hue_datas) > 3):
if len(hue_datas) > 4:
opt_fig_w = max(8, len(ticks) * 2)
elif len(ticks) > 5:
opt_fig_w = max(10, len(ticks) * 1.5)
else:
opt_fig_w = max(8, len(ticks) * 1.5)
opt_fig_h = max(4, figsize[1])
optimized_params.append(f"figsize=({opt_fig_w}, {opt_fig_h})")
if hue_order is not None:
optimized_params.append(f"hue_order={hue_order}")
optimized_call = f" {Colors.GREEN}ov.pl.boxplot({', '.join(optimized_params)}){Colors.ENDC}"
print(optimized_call)
else:
print(f" {Colors.GREEN}✅ Current parameters are optimal for your data!{Colors.ENDC}")
print(f"{Colors.CYAN}{'─' * 60}{Colors.ENDC}")
#在这个数据中,我们有6个不同的癌症,每个癌症都有2个基因(2个箱子)
#所以我们需要得到每一个基因的6个箱线图位置,6个散点图的抖动
plot_data1={}#字典里的每一个元素就是每一个基因的所有值
plot_data_random1={}#字典里的每一个元素就是每一个基因的随机20个值
plot_data_xs1={}#字典里的每一个元素就是每一个基因的20个抖动值
#箱子的参数
#width=0.6
y=y_value
import random
for hue_data,num in zip(hue_datas,box_positions):
data_a=[]
data_a_random=[]
data_a_xs=[]
for i,k in zip(ticks,range(len(ticks))):
test_data=data.loc[((data[x]==i)&(data[hue]==hue_data)),y].tolist()
data_a.append(test_data)
if len(test_data)<20:
data_size=len(test_data)
else:
data_size=20
if len(test_data) > 0:
random_data=random.sample(test_data,data_size)
else:
random_data=[]
data_a_random.append(random_data)
data_a_xs.append(np.random.normal(k+num, 0.04, len(random_data)))
#data_a=np.array(data_a)
data_a_random=np.array(data_a_random,dtype=object)
plot_data1[hue_data]=data_a
plot_data_random1[hue_data]=data_a_random
plot_data_xs1[hue_data]=data_a_xs
fig, ax = plt.subplots(figsize=figsize)
#色卡
if palette==None:
from ._palette import sc_color
palette=sc_color
#palette=["#a64d79","#674ea7"]
#绘制箱线图
for hue_data,hue_color,num in zip(hue_datas,palette,box_positions):
b1=ax.boxplot(plot_data1[hue_data],
positions=np.array(range(len(ticks)))+num,
sym='',
widths=width,)
plt.setp(b1['boxes'], color=hue_color)
plt.setp(b1['whiskers'], color=hue_color)
plt.setp(b1['caps'], color=hue_color)
plt.setp(b1['medians'], color=hue_color)
clevels = np.linspace(0., 1., len(plot_data_random1[hue_data]))
for x, val, clevel in zip(plot_data_xs1[hue_data], plot_data_random1[hue_data], clevels):
if len(val) > 0: # Only plot if there's data
plt.scatter(x, val,c=hue_color,alpha=0.4)
#坐标轴字体
#fontsize=10
#修改横坐标
ax.set_xticks(range(len(ticks)), ticks,fontsize=fontsize)
#修改纵坐标
yticks=ax.get_yticks()
ax.set_yticks(yticks[yticks>=0],yticks[yticks>=0],fontsize=fontsize)
labels = hue_datas #legend标签列表,上面的color即是颜色列表
color = palette
#用label和color列表生成mpatches.Patch对象,它将作为句柄来生成legend
patches = [ mpatches.Patch(color=color[i], label="{:s}".format(labels[i]) ) for i in range(len(hue_datas)) ]
ax.legend(handles=patches,bbox_to_anchor=legend_bbox, ncol=legend_ncol,fontsize=fontsize)
#设置标题
ax.set_title(title,fontsize=fontsize+1)
#设置spines可视化情况
plt.grid(False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(True)
ax.spines['left'].set_visible(True)
ax.spines['left'].set_position(('outward', 10))
ax.spines['bottom'].set_position(('outward', 10))
return fig,ax
Density and Contour¶
omicverse.pl.calculate_gene_density(adata, features, basis='X_umap', dims=(0, 1), adjust=1, min_expr=0.1)
¶
Calculate weighted kernel density estimates for gene expression on 2D embeddings.
Computes KDE for each feature using expression values as weights and stores density values in adata.obs as 'density_{feature}' columns.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
adata | | Annotated data object with embedding coordinates | required |
features | | List of gene names or feature names to process | required |
basis | | Key in adata.obsm containing 2D embedding coordinates ('X_umap') | 'X_umap' |
dims | | Embedding dimensions to use as (x_dim, y_dim) ((0, 1)) | (0, 1) |
adjust | | Bandwidth scaling factor for KDE (1) | 1 |
min_expr | | Minimum expression threshold for including cells (0.1) | 0.1 |
Returns:
Name | Type | Description |
---|---|---|
None | Updates adata.obs with density_{feature} columns |
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_density.py
def calculate_gene_density(
    adata,
    features,
    basis="X_umap",
    dims=(0, 1),
    adjust=1,
    min_expr=0.1,  # minimal raw expression to keep as weight > 0
):
    r"""
    Calculate weighted kernel density estimates for gene expression on 2D embeddings.

    Fits one ``gaussian_kde`` per feature on the cells whose value is finite and
    above ``min_expr``, using the min-max scaled values as weights, then
    evaluates it on every cell and stores the result in
    ``adata.obs['density_{feature}']``.

    Arguments:
        adata: Annotated data object with embedding coordinates
        features: Feature name or list of feature names to process. Each name is
            looked up in ``adata.obs`` first, then in ``adata.var_names``.
        basis: Key in adata.obsm containing 2D embedding coordinates ('X_umap')
        dims: Embedding dimensions to use as (x_dim, y_dim) ((0, 1))
        adjust: Value passed to ``gaussian_kde`` as ``bw_method`` (1)
        min_expr: Minimum expression threshold for including cells (0.1)

    Returns:
        None: Updates adata.obs with density_{feature} columns. A feature with
            fewer than 5 usable cells gets an all-NaN column instead.

    Raises:
        ValueError: If ``dims`` does not have length 2, ``basis`` is missing
            from ``adata.obsm``, or a feature is found in neither ``obs`` nor
            ``var_names``.
    """
    if len(dims) != 2:
        raise ValueError("`dims` must have length 2")
    if basis not in adata.obsm:
        raise ValueError(f"Embedding '{basis}' not found.")
    # A bare string would otherwise be iterated character by character.
    if isinstance(features, str):
        features = [features]

    emb_all = adata.obsm[basis][:, dims]  # (n_cells, 2)
    # Coordinate validity does not depend on the feature: compute it once.
    coords_finite = np.all(np.isfinite(emb_all), axis=1)

    for feat in features:
        # ----- fetch raw weights -------------------------------------------
        if feat in adata.obs:
            w_raw = adata.obs[feat].to_numpy()
        elif feat in adata.var_names:
            expr = adata[:, feat].X
            # Sparse matrices expose .toarray(); dense arrays do not.
            if hasattr(expr, "toarray"):
                expr = expr.toarray()
            w_raw = np.asarray(expr).ravel()
        else:
            raise ValueError(f"Feature '{feat}' not found in obs or var.")

        # ----- validity mask: finite coords & finite expr, above threshold --
        mask_train = np.isfinite(w_raw) & coords_finite & (w_raw > min_expr)
        emb_train = emb_all[mask_train]
        w_train_raw = w_raw[mask_train]
        if emb_train.shape[0] < 5:
            print(f"[{feat}] too few cells above threshold; skipping KDE.")
            adata.obs[f"density_{feat}"] = np.nan
            continue

        # ----- min–max scale to 0-1 ----------------------------------------
        w_min, w_max = w_train_raw.min(), w_train_raw.max()
        if w_max == w_min:
            # Constant weights would scale to 0/0 = NaN; fall back to an
            # unweighted KDE, as add_density_contour does.
            w_train = None
        else:
            w_train = (w_train_raw - w_min) / (w_max - w_min)

        # ----- KDE fit ------------------------------------------------------
        kde = gaussian_kde(emb_train.T, weights=w_train, bw_method=adjust)

        # ----- evaluate on ALL cells ---------------------------------------
        density = kde(emb_all.T)
        adata.obs[f"density_{feat}"] = density
        print(f"✅ density_{feat} written (train cells = {emb_train.shape[0]})")
omicverse.pl.add_density_contour(ax, embeddings, weights, levels='quantile', n_quantiles=5, bw_adjust=0.3, cmap_contour='Greys', linewidth=1.0, zorder=10, fill=False, alpha=0.4)
¶
Add KDE-based density contours to an existing matplotlib plot.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ax | | matplotlib.axes.Axes object to draw contours on | required |
embeddings | | 2D coordinate array with shape (n_cells, 2) | required |
weights | | 1D weight array for KDE, will be min-max normalized | required |
levels | | Contour level specification - 'quantile' or list of values ('quantile') | 'quantile' |
n_quantiles | | Number of quantile levels when levels='quantile' (5) | 5 |
bw_adjust | | Bandwidth adjustment factor for KDE (0.3) | 0.3 |
cmap_contour | | Colormap for contour lines ('Greys') | 'Greys' |
linewidth | | Width of contour lines (1.0) | 1.0 |
zorder | | Drawing order for contours (10) | 10 |
fill | | Whether to fill contours (False for lines, True for filled) | False |
alpha | | Transparency for filled contours (0.4) | 0.4 |
Returns:
Name | Type | Description |
---|---|---|
cs | matplotlib contour object for potential colorbar addition |
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_density.py
def add_density_contour(
    ax,
    embeddings,  # (n_cells, 2) array
    weights,  # 1-D array, will be min-max scaled
    levels="quantile",  # "quantile" or a numeric list, see below
    n_quantiles=5,
    bw_adjust=0.3,
    cmap_contour="Greys",
    linewidth=1.0,
    zorder=10,
    fill=False,
    alpha=0.4,
):
    r"""
    Add KDE-based density contours to an existing matplotlib plot.

    A weighted ``gaussian_kde`` is fitted on ``embeddings`` and evaluated on a
    300x300 grid spanning the coordinate range; the grid is then drawn with
    ``ax.contour`` or ``ax.contourf``.

    Arguments:
        ax: matplotlib.axes.Axes object to draw contours on
        embeddings: 2D coordinate array with shape (n_cells, 2)
        weights: 1D weight array for KDE, will be min-max normalized; constant
            weights result in an unweighted KDE
        levels: Contour level specification - 'quantile' or a list/array of
            values ('quantile')
        n_quantiles: Number of quantile levels when levels='quantile' (5)
        bw_adjust: Value passed to ``gaussian_kde`` as ``bw_method`` (0.3)
        cmap_contour: Colormap for contour lines ('Greys')
        linewidth: Width of contour lines (1.0); only used when ``fill`` is False
        zorder: Drawing order for contours (10)
        fill: Whether to fill contours (False for lines, True for filled)
        alpha: Transparency for filled contours (0.4); only used when ``fill``
            is True

    Returns:
        cs: matplotlib contour object for potential colorbar addition
    """
    # Accept plain lists/tuples as well as arrays.
    embeddings = np.asarray(embeddings)
    weights = np.asarray(weights)

    # ---------- fit KDE ----------------------------------------------------
    w_min, w_max = weights.min(), weights.max()
    w_norm = None if w_max == w_min else (weights - w_min) / (w_max - w_min)
    kde = gaussian_kde(embeddings.T, weights=w_norm, bw_method=bw_adjust)

    # ---------- prepare evaluation grid -----------------------------------
    xmin, xmax = embeddings[:, 0].min(), embeddings[:, 0].max()
    ymin, ymax = embeddings[:, 1].min(), embeddings[:, 1].max()
    xx, yy = np.mgrid[xmin:xmax:300j, ymin:ymax:300j]  # 300×300 grid
    grid = kde(np.vstack([xx.ravel(), yy.ravel()])).reshape(xx.shape)

    # ---------- determine contour levels ----------------------------------
    # The isinstance guard keeps an ndarray of levels from being compared
    # elementwise against the string (ambiguous truth value).
    if isinstance(levels, str) and levels == "quantile":
        qs = np.linspace(0, 1, n_quantiles + 2)[1:-1]  # drop 0 & 1
        # Quantiles of a mostly-flat grid can repeat; matplotlib requires
        # strictly increasing levels, so drop duplicates.
        levels = np.unique(np.quantile(grid, qs))

    # ---------- draw -------------------------------------------------------
    if fill:
        cs = ax.contourf(
            xx, yy, grid,
            levels=levels,
            cmap=cmap_contour,
            alpha=alpha,
            zorder=zorder,
        )
    else:
        cs = ax.contour(
            xx, yy, grid,
            levels=levels,
            cmap=cmap_contour,
            linewidths=linewidth,
            zorder=zorder,
        )
    return cs  # so you can add a colorbar if desired
Spatial Plotting¶
omicverse.pl.spatial_segment(adata, seg_cell_id, seg=True, seg_key='segmentation', seg_contourpx=None, seg_outline=False, **kwargs)
¶
Plot spatial omics data with segmentation masks on top.
Argument `seg_cell_id` in `anndata.AnnData.obs` controls unique segmentation mask's ids to be plotted.
By default, `'segmentation'`, `seg_key` for the segmentation and `'hires'` for the image is attempted.
- Use ``seg_key`` to display the image in the background.
- Use ``seg_contourpx`` or ``seg_outline`` to control how the segmentation mask is displayed.
%(spatial_plot.summary_ext)s
.. seealso::
- :func:squidpy.pl.spatial_scatter
on how to plot spatial data with overlayed data on top.
Parameters¶
%(adata)s %(plotting_segment)s %(color)s %(groups)s %(library_id)s %(library_key)s %(spatial_key)s %(plotting_image)s %(plotting_features)s
Returns¶
%(spatial_plot.returns)s
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_spatial.py
def spatial_segment(
    adata: AnnData,
    seg_cell_id: str,
    seg: bool | Sequence[np.ndarray] | None = True,
    seg_key: str = "segmentation",
    seg_contourpx: int | None = None,
    seg_outline: bool = False,
    **kwargs: Any,
) -> Axes | Sequence[Axes] | None:
    """
    Plot spatial omics data with segmentation masks on top.

    Argument ``seg_cell_id`` in :attr:`anndata.AnnData.obs` controls unique segmentation mask's ids to be plotted.
    By default, ``'segmentation'``, ``seg_key`` for the segmentation and ``'hires'`` for the image is attempted.

        - Use ``seg_key`` to display the image in the background.
        - Use ``seg_contourpx`` or ``seg_outline`` to control how the segmentation mask is displayed.

    This function is a thin wrapper: every argument is forwarded unchanged to
    ``_spatial_plot``.

    .. seealso::
        - :func:`squidpy.pl.spatial_scatter` on how to plot spatial data with overlayed data on top.

    Parameters
    ----------
    adata
        Annotated data matrix.
    seg_cell_id
        Key in :attr:`anndata.AnnData.obs` that contains unique cell IDs for segmentation.
    seg
        Segmentation mask. If `True`, uses segmentation from `adata.uns['spatial']`.
    seg_key
        Key for segmentation mask in `adata.uns['spatial'][library_id]['images']`.
    seg_contourpx
        Contour width in pixels. If specified, segmentation boundaries will be eroded.
    seg_outline
        If `True`, show segmentation boundaries.
    **kwargs
        Any further plotting options (colour, library, image and feature
        selection, ...); passed through to ``_spatial_plot`` as-is.

    Returns
    -------
    Whatever ``_spatial_plot`` returns; annotated as an axes object, a
    sequence of axes, or `None`.
    """
    # NOTE: the docstring used to rely on docrep placeholders (``%(adata)s``
    # etc.) that were never substituted, and imported ``squidpy._docs.d``
    # without using it; both have been replaced by the explicit text above.
    return _spatial_plot(
        adata,
        seg=seg,
        seg_key=seg_key,
        seg_cell_id=seg_cell_id,
        seg_contourpx=seg_contourpx,
        seg_outline=seg_outline,
        **kwargs,
    )
omicverse.pl.spatial_segment_overlay(adata, seg_cell_id, color, seg=True, seg_key='segmentation', seg_contourpx=None, seg_outline=False, alpha=0.5, cmaps=None, library_id=None, library_key=None, spatial_key='spatial', img=True, img_res_key='hires', img_alpha=None, img_cmap=None, img_channel=None, use_raw=None, layer=None, alt_var=None, groups=None, palette=None, norm=None, na_color=(0, 0, 0, 0), size=None, size_key='spot_diameter_fullres', scale_factor=None, crop_coord=None, connectivity_key=None, edges_width=1.0, edges_color='grey', frameon=None, legend_loc='right margin', legend_fontsize=None, legend_fontweight='bold', legend_fontoutline=None, legend_na=True, colorbar=True, colorbar_position='bottom', colorbar_grid=None, colorbar_tick_size=10, colorbar_title_size=12, colorbar_width=None, colorbar_height=None, colorbar_spacing=None, scalebar_dx=None, scalebar_units=None, title=None, axis_label=None, fig=None, ax=None, return_ax=False, figsize=None, dpi=None, save=None, scalebar_kwargs=MappingProxyType({}), edges_kwargs=MappingProxyType({}), **kwargs)
¶
Plot multiple genes overlaid on the same spatial segmentation plot.
This function allows visualization of multiple genes in the same spatial context, using different colors and transparency to show expression overlap and co-localization.
Parameters¶
%(adata)s
seg_cell_id
Key in `anndata.AnnData.obs` that contains unique cell IDs for segmentation.
color
Gene names or features to plot. Can be a single gene or list of genes.
seg
Segmentation mask. If `True`, uses segmentation from `adata.uns['spatial']`.
seg_key
Key for segmentation mask in `adata.uns['spatial'][library_id]['images']`.
seg_contourpx
Contour width in pixels. If specified, segmentation boundaries will be eroded.
seg_outline
If `True`, show segmentation boundaries.
alpha
Transparency level for gene expression overlay (0-1).
cmaps
Colormap(s) for each gene. If single string, uses same colormap for all genes.
If list, should match length of the `color` parameter.
%(library_id)s
%(library_key)s
%(spatial_key)s
%(plotting_image)s
%(plotting_features)s
groups
Categories to plot for categorical features.
palette
Color palette for categorical features.
norm
Colormap normalization.
na_color
Color for NA/missing values.
colorbar
Whether to show colorbars for each gene.
colorbar_position
Position of colorbars: 'bottom', 'right', or 'none'.
colorbar_grid
Grid layout for colorbars as (rows, cols). If None, auto-determined.
colorbar_tick_size
Font size for colorbar tick labels.
colorbar_title_size
Font size for colorbar titles.
colorbar_width
Width of individual colorbars. If None, uses default values.
colorbar_height
Height of individual colorbars. If None, uses default values.
colorbar_spacing
Dictionary with spacing parameters: 'hspace' (vertical gaps), 'wspace' (horizontal gaps).
If None, uses default spacing.
title
Plot title.
ax
Matplotlib axes object to plot on.
return_ax
Whether to return the axes object.
figsize
Figure size (width, height).
save
Path to save the figure.
Returns¶
If `return_ax` is `True`, returns matplotlib axes object, otherwise `None`.
Examples¶
```python
import omicverse as ov

# Overlay two genes with different colors
ov.pl.spatial_segment_overlay(
    adata,
    seg_cell_id='cell_id',
    color=['gene1', 'gene2'],
    cmaps=['Reds', 'Blues'],
    alpha=0.7,
)
```
Source code in /Users/fernandozeng/miniforge3/envs/space/lib/python3.10/site-packages/omicverse/pl/_spatial.py
def spatial_segment_overlay(
    adata: AnnData,
    seg_cell_id: str,
    color: str | Sequence[str],
    seg: bool | Sequence[np.ndarray] | None = True,
    seg_key: str = "segmentation",
    seg_contourpx: int | None = None,
    seg_outline: bool = False,
    alpha: float = 0.5,
    cmaps: str | Sequence[str] | None = None,
    library_id: _SeqStr | None = None,
    library_key: str | None = None,
    spatial_key: str = "spatial",
    img: bool | Sequence[np.ndarray] | None = True,
    img_res_key: str | None = "hires",
    img_alpha: float | None = None,
    img_cmap: Colormap | str | None = None,
    img_channel: int | list[int] | None = None,
    use_raw: bool | None = None,
    layer: str | None = None,
    alt_var: str | None = None,
    groups: _SeqStr | None = None,
    palette: Palette_t = None,
    norm: _Normalize | None = None,
    na_color: str | tuple[float, ...] = (0, 0, 0, 0),
    size: _SeqFloat | None = None,
    size_key: str | None = "spot_diameter_fullres",
    scale_factor: _SeqFloat | None = None,
    crop_coord: _CoordTuple | Sequence[_CoordTuple] | None = None,
    connectivity_key: str | None = None,
    edges_width: float = 1.0,
    edges_color: str | Sequence[str] | Sequence[float] = "grey",
    frameon: bool | None = None,
    legend_loc: str | None = "right margin",
    legend_fontsize: int | float | _FontSize | None = None,
    legend_fontweight: int | _FontWeight = "bold",
    legend_fontoutline: int | None = None,
    legend_na: bool = True,
    colorbar: bool = True,
    colorbar_position: str = "bottom",
    colorbar_grid: tuple[int, int] | None = None,
    colorbar_tick_size: int = 10,
    colorbar_title_size: int = 12,
    colorbar_width: float | None = None,
    colorbar_height: float | None = None,
    colorbar_spacing: dict[str, float] | None = None,
    scalebar_dx: _SeqFloat | None = None,
    scalebar_units: _SeqStr | None = None,
    title: _SeqStr | None = None,
    axis_label: _SeqStr | None = None,
    fig: Figure | None = None,
    ax: Axes | None = None,
    return_ax: bool = False,
    figsize: tuple[float, float] | None = None,
    dpi: int | None = None,
    save: str | Path | None = None,
    scalebar_kwargs: Mapping[str, Any] = MappingProxyType({}),
    edges_kwargs: Mapping[str, Any] = MappingProxyType({}),
    **kwargs: Any,
) -> Axes | None:
    """
    Plot multiple genes overlaid on the same spatial segmentation plot.

    This function allows visualization of multiple genes in the same spatial context,
    using different colors and transparency to show expression overlap and co-localization.
    Only the first library resolved from ``library_id`` / ``library_key`` is drawn.

    Parameters
    ----------
    adata
        Annotated data matrix.
    seg_cell_id
        Key in :attr:`anndata.AnnData.obs` that contains unique cell IDs for segmentation.
    color
        Gene names or features to plot. Can be a single gene or list of genes.
    seg
        Segmentation mask. If `True`, uses segmentation from `adata.uns['spatial']`.
    seg_key
        Key for segmentation mask in `adata.uns['spatial'][library_id]['images']`.
    seg_contourpx
        Contour width in pixels. If specified, segmentation boundaries will be eroded.
    seg_outline
        If `True`, show segmentation boundaries.
    alpha
        Transparency level for gene expression overlay (0-1).
    cmaps
        Colormap(s) for each gene. If single string, uses same colormap for all genes.
        If list, should match length of `color` parameter (extra genes or
        colormaps are ignored). If `None`, a transparent-to-opaque colormap is
        built per gene from a cycle of red, green, blue, cyan, magenta, yellow.
    library_id, library_key, spatial_key
        Library / coordinate selection, forwarded to the internal
        ``_image_spatial_attrs`` helper. ``spatial_key`` is also validated with
        ``_assert_spatial_basis``.
    img, img_res_key, img_channel, img_cmap, img_alpha
        Background image selection (forwarded to ``_image_spatial_attrs``);
        ``img_cmap`` and ``img_alpha`` are also used when the image is drawn
        with ``ax.imshow``.
    use_raw, layer, alt_var
        Where to read feature values from; forwarded to ``_set_color_source_vec``.
    groups
        Categories to plot for categorical features.
    palette
        Color palette for categorical features.
    norm
        Colormap normalization. A copy is used for every gene, so the object
        passed in is never modified; if its ``vmin`` or ``vmax`` is unset, both
        limits are taken from that gene's values.
    na_color
        Color for NA/missing values.
    size, size_key, scale_factor
        Forwarded to ``_image_spatial_attrs``.
    crop_coord
        Forwarded to ``_set_coords_crops``.
    connectivity_key, edges_width, edges_color, edges_kwargs, frameon, \
legend_loc, legend_fontsize, legend_fontweight, legend_fontoutline, legend_na, axis_label
        Accepted for signature compatibility; not used by this function.
    colorbar
        Whether to show colorbars for each gene.
    colorbar_position
        Position of colorbars: 'bottom', 'right', or 'none'. Colorbars are only
        drawn when this function creates the figure itself (``ax`` is `None`).
    colorbar_grid
        Grid layout for colorbars as (rows, cols). If None, auto-determined.
    colorbar_tick_size
        Font size for colorbar tick labels.
    colorbar_title_size
        Font size for colorbar titles.
    colorbar_width
        Width of individual colorbars. If None, uses default values.
    colorbar_height
        Height of individual colorbars. If None, uses default values.
    colorbar_spacing
        Dictionary with spacing parameters: 'hspace' (vertical gaps), 'wspace' (horizontal gaps).
        If None, uses default spacing.
    scalebar_dx, scalebar_units, scalebar_kwargs
        If ``scalebar_dx`` is given, a ``matplotlib_scalebar`` ``ScaleBar`` is added.
    title
        Plot title. Defaults to ``"Overlay: <gene1>, <gene2>, ..."``.
    fig
        Only used as a fallback figure when ``ax.figure`` is `None`.
    ax
        Matplotlib axes object to plot on.
    return_ax
        Whether to return the axes object.
    figsize
        Figure size (width, height). Defaults to (10, 8) when a figure is created.
    dpi
        Resolution of the created figure.
    save
        Path to save the figure.
    **kwargs
        Forwarded to ``_plot_segment``.

    Returns
    -------
    If `return_ax` is `True`, returns matplotlib axes object, otherwise `None`.

    Examples
    --------
    >>> import omicverse as ov
    >>> # Overlay two genes with different colors
    >>> ov.pl.spatial_segment_overlay(
    ...     adata,
    ...     seg_cell_id='cell_id',
    ...     color=['gene1', 'gene2'],
    ...     cmaps=['Reds', 'Blues'],
    ...     alpha=0.7
    ... )
    """
    from copy import copy

    import matplotlib as mpl
    from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgb
    from matplotlib.gridspec import GridSpec
    from squidpy.gr._utils import _assert_spatial_basis
    from squidpy.pl._utils import sanitize_anndata, save_fig

    sanitize_anndata(adata)
    _assert_spatial_basis(adata, spatial_key)

    # Ensure color is a list
    if isinstance(color, str):
        color = [color]

    # Handle colormaps
    if cmaps is None:
        # Default colors for overlay: Red, Green, Blue, Cyan, Magenta, Yellow
        default_colors = ['#FF0000', '#00FF00', '#0000FF', '#00FFFF', '#FF00FF', '#FFFF00']
        cmaps = []
        for i in range(len(color)):
            base_rgb = to_rgb(default_colors[i % len(default_colors)])
            cmap_colors = [
                base_rgb + (0.0,),  # Transparent
                base_rgb + (1.0,),  # Opaque
            ]
            cmaps.append(LinearSegmentedColormap.from_list(f'overlay_{i}', cmap_colors, N=100))
    elif isinstance(cmaps, str):
        cmaps = [cmaps] * len(color)

    # Create figure and axes with colorbar layout.
    # ``gs`` stays None unless this function builds a GridSpec with colorbar
    # slots; the colorbar section below is keyed on it.
    gs = None
    if ax is None:
        figsize = figsize or (10, 8)
        if colorbar and colorbar_position in ("bottom", "right"):
            if colorbar_spacing is None:
                colorbar_spacing = {}

            # Determine colorbar grid layout
            if colorbar_grid is None:
                if colorbar_position == "bottom":
                    if len(color) <= 3:
                        colorbar_grid = (1, len(color))
                    else:
                        n_rows = (len(color) + 2) // 3  # Round up division
                        colorbar_grid = (n_rows, 3)
                else:
                    colorbar_grid = (len(color), 1)

            fig = plt.figure(figsize=figsize, dpi=dpi)
            if colorbar_position == "bottom":
                # Default dimensions for bottom colorbars
                default_width = 0.2 if colorbar_width is None else colorbar_width
                default_height = 0.08 if colorbar_height is None else colorbar_height
                n_cols = max(colorbar_grid[1], 1)
                gs = GridSpec(
                    nrows=colorbar_grid[0] + 1,
                    ncols=n_cols + 2,
                    width_ratios=[0.1, *[default_width] * n_cols, 0.1],
                    height_ratios=[1, *[default_height] * colorbar_grid[0]],
                    hspace=colorbar_spacing.get('hspace', 0.4),
                    wspace=colorbar_spacing.get('wspace', 0.3),
                    figure=fig,
                )
                ax = fig.add_subplot(gs[0, :])
            else:
                # Default dimensions for right colorbars
                default_width = 0.15 if colorbar_width is None else colorbar_width
                default_height = 0.2 if colorbar_height is None else colorbar_height
                n_rows = max(colorbar_grid[0], 1)
                gs = GridSpec(
                    nrows=n_rows + 2,
                    ncols=colorbar_grid[1] + 1,
                    width_ratios=[1, *[default_width] * colorbar_grid[1]],
                    height_ratios=[0.1, *[default_height] * n_rows, 0.1],
                    hspace=colorbar_spacing.get('hspace', 0.3),
                    wspace=colorbar_spacing.get('wspace', 0.1),
                    figure=fig,
                )
                ax = fig.add_subplot(gs[:, 0])
        else:
            # No colorbars requested (or an unsupported position): plain axes.
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        ax.grid(False)
    else:
        fig = ax.figure if ax.figure is not None else fig

    # Get spatial parameters once
    spatial_params = _image_spatial_attrs(
        adata=adata,
        spatial_key=spatial_key,
        library_id=library_id,
        library_key=library_key,
        img=img,
        img_res_key=img_res_key,
        img_channel=img_channel,
        seg=seg,
        seg_key=seg_key,
        cell_id_key=seg_cell_id,
        scale_factor=scale_factor,
        size=size,
        size_key=size_key,
        img_cmap=img_cmap,
    )

    # Get coordinates and crops
    coords, crops = _set_coords_crops(
        adata=adata,
        spatial_params=spatial_params,
        spatial_key=spatial_key,
        crop_coord=crop_coord,
    )

    # Use first library for now (could be extended for multiple libraries)
    _lib_count = 0
    _img = spatial_params.img[_lib_count]
    _seg = spatial_params.segment[_lib_count]
    _cell_id = spatial_params.cell_id[_lib_count]
    _crops = crops[_lib_count]
    _lib = spatial_params.library_id[_lib_count]
    _coords = coords[_lib_count]

    # Subset data
    adata_sub, coords_sub, image_sub = _subs(
        adata,
        _coords,
        _img,
        library_key=library_key,
        library_id=_lib,
        crop_coords=_crops,
    )

    # Store colorbar information for each gene
    colorbar_info = []

    # Plot each gene as an overlay
    for gene, cmap in zip(color, cmaps, strict=False):
        # Get color vector for this gene
        color_source_vector, color_vector, categorical = _set_color_source_vec(
            adata_sub,
            gene,
            layer=layer,
            use_raw=use_raw,
            alt_var=alt_var,
            groups=groups,
            palette=palette,
            na_color=na_color,
            alpha=alpha,
        )

        if _seg is None or _cell_id is None:
            continue

        # Work on a per-gene copy so the caller's norm is never mutated and
        # one gene's range does not leak into the next gene.
        norm_to_use = copy(norm) if norm is not None else Normalize()

        # Fix normalization if needed
        if not categorical and (norm_to_use.vmin is None or norm_to_use.vmax is None):
            if hasattr(color_vector, 'isna'):
                valid_values = color_vector[~pd.isna(color_vector)]
            else:
                valid_values = color_vector[~np.isnan(color_vector)]
            if len(valid_values) > 0:
                norm_to_use.vmin = np.min(valid_values)
                norm_to_use.vmax = np.max(valid_values)

        cmap_params = CmapParams(cmap, img_cmap, norm_to_use)
        color_params = ColorParams(None, [gene], groups, alpha, img_alpha or 1.0, use_raw or False)

        # Store colorbar information
        if colorbar and not categorical:
            colorbar_info.append({
                'gene': gene,
                'cmap': cmap,
                'norm': norm_to_use,
                'values': color_vector,
            })

        # Plot this gene's segmentation
        ax, _ = _plot_segment(
            seg=_seg,
            cell_id=_cell_id,
            color_vector=color_vector,
            color_source_vector=color_source_vector,
            seg_contourpx=seg_contourpx,
            seg_outline=seg_outline,
            na_color=na_color,
            ax=ax,
            cmap_params=cmap_params,
            color_params=color_params,
            categorical=categorical,
            **kwargs,
        )

    # Handle axis decoration manually (essential parts from _decorate_axs)
    ax.set_yticks([])
    ax.set_xticks([])
    ax.autoscale_view()  # needed when plotting points but no image

    # Handle image display and axis orientation (core logic from _decorate_axs)
    if image_sub is not None:
        ax.imshow(image_sub, cmap=img_cmap, alpha=img_alpha)
    else:
        ax.set_aspect("equal")
        ax.invert_yaxis()

    # Set title
    if title is None:
        title = f"Overlay: {', '.join(color)}"
    ax.set_title(title)

    # Add colorbars. Slots only exist in a GridSpec built above, so an
    # externally supplied ``ax`` gets no colorbars instead of a crash.
    if gs is not None and colorbar_info:
        n_genes = len(colorbar_info)
        if colorbar_position == "bottom":
            cbar_axes = [
                fig.add_subplot(gs[row + 1, col + 1])
                for row in range(colorbar_grid[0])
                for col in range(colorbar_grid[1])
                if row * colorbar_grid[1] + col < n_genes
            ]
            orientation = "horizontal"
        else:
            cbar_axes = [
                fig.add_subplot(gs[row + 1, 1])
                for row in range(min(colorbar_grid[0], n_genes))
            ]
            orientation = "vertical"

        # Create colorbars
        for cbar_info, cbar_ax in zip(colorbar_info, cbar_axes):
            vmin, vmax = cbar_info['norm'].vmin, cbar_info['norm'].vmax
            if vmin is None or vmax is None:
                # No usable value range for this gene: leave its slot blank
                # rather than drawing a colorbar with undefined ticks.
                cbar_ax.set_axis_off()
                continue

            # Calculate ticks
            ticks = [vmin, (vmin + vmax) / 2, vmax]
            if vmax > 10:
                ticks = [int(t) for t in ticks]
            else:
                ticks = [round(t, 2) for t in ticks]

            cbar = fig.colorbar(
                mpl.cm.ScalarMappable(norm=cbar_info['norm'], cmap=cbar_info['cmap']),
                cax=cbar_ax,
                orientation=orientation,
                extend="both",
                ticks=ticks,
            )

            # Style colorbar
            cbar.ax.tick_params(labelsize=colorbar_tick_size)
            cbar.ax.set_title(cbar_info['gene'], size=colorbar_title_size, pad=10)

    # Add scalebar if specified
    if scalebar_dx is not None:
        from matplotlib_scalebar.scalebar import ScaleBar

        scalebar_dx, scalebar_units = _get_scalebar(scalebar_dx, scalebar_units, 1)
        if scalebar_dx and scalebar_units:
            scalebar = ScaleBar(scalebar_dx[0], units=scalebar_units[0], **scalebar_kwargs)
            ax.add_artist(scalebar)

    if save is not None:
        save_fig(fig, path=save)

    if return_ax:
        return ax