Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions src/spatialdata_plot/pl/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,29 @@ def _modify_categorical_color_mapping(
return modified_mapping


def _default_categorical_palette(n: int) -> list[str]:
"""Return the scanpy default categorical palette sized for ``n`` categories (grey beyond 103)."""
if n <= 20:
return list(default_20)
if n <= 28:
return list(default_28)
if n <= len(default_102):
return list(default_102)
logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.")
return ["grey"] * n


def _next_palette_colors(used_colors: set[str], n: int) -> list[str]:
"""Pick ``n`` default-palette colors skipping ``used_colors`` (keeps a 2nd categorical render distinct, #364).

Falls back to the full palette if too few unused colors remain.
"""
used_norm = {to_hex(to_rgba(c)) for c in used_colors}
pool = _default_categorical_palette(n + len(used_norm))
unused = [c for c in pool if to_hex(to_rgba(c)) not in used_norm]
return (unused if len(unused) >= n else pool)[:n]


def _get_default_categorial_color_mapping(
color_source_vector: ArrayLike | pd.Series[CategoricalDtype],
cmap_params: CmapParams | None = None,
Expand All @@ -1179,17 +1202,8 @@ def _get_default_categorial_color_mapping(
else:
palette = None

# Fall back to default palettes if needed
if palette is None:
if len_cat <= 20:
palette = default_20
elif len_cat <= 28:
palette = default_28
elif len_cat <= len(default_102): # 103 colors
palette = default_102
else:
palette = ["grey"] * len_cat
logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.")
palette = _default_categorical_palette(len_cat)

return dict(zip(color_source_vector.categories, palette[:len_cat], strict=True))

Expand Down
88 changes: 75 additions & 13 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
from geopandas import GeoDataFrame
from matplotlib.axes import Axes
from matplotlib.backend_bases import RendererBase
from matplotlib.colors import Colormap, LogNorm, Normalize
from matplotlib.colors import Colormap, LogNorm, Normalize, to_hex, to_rgba
from matplotlib.figure import Figure
from matplotlib.legend import Legend
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from spatialdata._utils import _deprecation_alias
from spatialdata.transformations.operations import get_transformation
Expand All @@ -31,6 +32,7 @@
from spatialdata_plot._logging import _log_context, logger
from spatialdata_plot.pl._color import (
_maybe_set_colors,
_next_palette_colors,
_prepare_cmap_norm,
_set_outline,
)
Expand Down Expand Up @@ -1584,6 +1586,11 @@ def show(

_draw_scalebar(ax, scalebar_params_obj, panel_idx=i)

if fig_params.fig is not None:
for panel_ax in fig_params.axs if fig_params.axs is not None else [fig_params.ax]:
if isinstance(panel_ax, Axes):
_layout_panel_legends(panel_ax, fig_params.fig)

_layout_pending_colorbars(pending_colorbars, fig_params, colorbar_params)

if fig_params.fig is not None and save is not None:
Expand Down Expand Up @@ -1870,6 +1877,45 @@ def _draw_colorbar(
trackers_axes[location] = pad_axes + (bbox_axes.width if vertical else bbox_axes.height)


def _layout_panel_legends(ax: Axes, fig: Figure, gap: float = 0.01) -> None:
"""Title and stack the per-render categorical legends (#364) in the right margin.

Only legends this code created (tagged ``_sdata_column``) are touched, so fill/outline and
channel legends keep their own placement. Each legend is titled by its source column (matching
colorbars). When 2+ legends share an axis they are stacked top-to-bottom — the right-margin
convention — so wide legends grow the figure height on save instead of overflowing its right edge.
"""
legends = [c for c in ax.get_children() if isinstance(c, Legend) and hasattr(c, "_sdata_column")]
if not legends:
return
# Title each legend by its source column so it reads like the colorbars (an explicit title set
# earlier stays as-is). A lone legend keeps scanpy's placement; only its title is added here.
for leg in legends:
if not leg.get_title().get_text():
leg.set_title(leg._sdata_column)
if len(legends) < 2:
return
# 2+ legends share the axis. Let constrained_layout settle, then freeze it: otherwise it shrinks
# the axes to "make room" for the margin legends (squashing the plot, leaving a gap). Frozen, the
# legends still count for `bbox_inches="tight"` on save.
fig.canvas.draw()
fig.set_layout_engine("none")
invf = fig.transFigure.inverted()
ax_bb = ax.get_window_extent().transformed(invf)
left, top = ax_bb.x1 + gap, ax_bb.y1
# Anchor all at one point and settle, so heights are measured in a single consistent layout state.
for leg in legends:
leg.set_bbox_to_anchor((left, top), transform=fig.transFigure)
if hasattr(leg, "set_loc"):
leg.set_loc("upper left")
fig.canvas.draw()
heights = [leg.get_window_extent().transformed(invf).height for leg in legends]
y = top
for leg, h in zip(legends, heights, strict=True):
leg.set_bbox_to_anchor((left, y), transform=fig.transFigure)
y -= h + gap


def _layout_pending_colorbars(
pending_colorbars: list[tuple[Axes, list[ColorbarSpec]]],
fig_params: FigParams,
Expand Down Expand Up @@ -1944,19 +1990,32 @@ def _should_rasterize(
return scale is None or (isinstance(scale, str) and scale != "full" and (dpi is not None or figsize is not None))


def _maybe_set_label_colors(sdata: sd.SpatialData, render_params: LabelsRenderParams) -> None:
"""Materialize a categorical palette on the table annotating a labels element, if applicable."""
def _maybe_set_label_colors(
sdata: sd.SpatialData,
render_params: LabelsRenderParams,
used_colors: set[str] | None = None,
) -> None:
"""Materialize a categorical palette on the table annotating a labels element, if applicable.

``used_colors`` accumulates the colors already taken by earlier categorical label renders on
the same panel. When a column's colors are auto-generated (no user palette, not already in
``.uns``), they are shifted to skip ``used_colors`` so stacked legends stay distinct (#364).
"""
table = render_params.table_name
if table is None or render_params.col_for_color is None:
col = render_params.col_for_color
if table is None or col is None:
return
colors = sc.get.obs_df(sdata[table], [render_params.col_for_color])
if isinstance(colors[render_params.col_for_color].dtype, pd.CategoricalDtype):
_maybe_set_colors(
source=sdata[table],
target=sdata[table],
key=render_params.col_for_color,
palette=render_params.palette,
)
colors = sc.get.obs_df(sdata[table], [col])
if not isinstance(colors[col].dtype, pd.CategoricalDtype):
return
adata = sdata[table]
color_key = f"{col}_colors"
if render_params.palette is None and used_colors and color_key not in adata.uns:
adata.uns[color_key] = _next_palette_colors(used_colors, len(colors[col].cat.categories))
else:
_maybe_set_colors(source=adata, target=adata, key=col, palette=render_params.palette)
if used_colors is not None and color_key in adata.uns:
used_colors.update(to_hex(to_rgba(c)) for c in adata.uns[color_key])


def _render_panel(
Expand Down Expand Up @@ -1985,6 +2044,9 @@ def _render_panel(
"""
wants = dict.fromkeys(("images", "labels", "points", "shapes"), False)
wanted_elements: list[str] = []
# Colors already taken by categorical label renders on this panel, so later renders can
# avoid reusing them and their stacked legends stay distinct (#364).
used_label_colors: set[str] = set()

for cmd, params in render_cmds:
# Skip render entries that belong to a different color panel. Entries with no
Expand Down Expand Up @@ -2033,7 +2095,7 @@ def _render_panel(
cast("ImageRenderParams | LabelsRenderParams", element_params), dpi, figsize
)
if cmd == "render_labels":
_maybe_set_label_colors(sdata, cast(LabelsRenderParams, element_params))
_maybe_set_label_colors(sdata, cast(LabelsRenderParams, element_params), used_label_colors)
_RENDERERS[cmd](**kwargs)

# Panel finalization depends only on per-panel values, so run it once after the loop.
Expand Down
6 changes: 3 additions & 3 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
_decorate_axs,
_fast_extent,
_join_table_for_element,
_legend_ncol,
_mpl_ax_contains_elements,
_multiscale_to_spatial_image,
_pixel_to_coord,
Expand Down Expand Up @@ -548,7 +549,7 @@ def _add_outline_legend(
loc=loc,
bbox_to_anchor=anchor,
fontsize=legend_params.legend_fontsize,
ncol=(1 if len(outline_handles) <= 14 else 2 if len(outline_handles) <= 30 else 3),
ncol=_legend_ncol(len(outline_handles)),
)


Expand Down Expand Up @@ -697,8 +698,7 @@ def _render_shapes(
nan_count = int(pd.isna(cv).sum())
if nan_count:
logger.warning(
f"Found {nan_count} NaN values in color data. "
"These observations will be colored with the 'na_color'."
f"Found {nan_count} NaN values in color data. These observations will be colored with the 'na_color'."
)
color_spec = color_spec.evolve(color_vector=cv)

Expand Down
95 changes: 79 additions & 16 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
from matplotlib.legend import Legend
from matplotlib_scalebar.scalebar import ScaleBar
from pandas.api.types import CategoricalDtype, is_numeric_dtype
from pandas.core.arrays.categorical import Categorical
Expand Down Expand Up @@ -405,6 +406,49 @@ def _build_alignment_dtype_hint(
return ""


def _legend_ncol(n: int) -> int:
"""Column count for a categorical legend with ``n`` entries."""
return 1 if n <= 14 else 2 if n <= 30 else 3


def _categorical_legend_handles(ax: Axes, color_map: Mapping[Any, Any], na_hex: str | None = None) -> list[Any]:
"""Empty-scatter handles (colored dots) for a categorical legend, with an optional NA entry."""
handles = [ax.scatter([], [], c=color, label=str(cat)) for cat, color in color_map.items()]
if na_hex is not None:
handles.append(ax.scatter([], [], c=na_hex, label="NA"))
return handles


def _stack_categorical_legend(
ax: Axes,
color_mapping: Mapping[Any, Any],
*,
na_hex: str | None,
title: str | None,
column: str | None,
legend_fontsize: int | float | _FontSize | None,
) -> None:
"""Build the 2nd+ categorical legend on a shared axes without dropping existing ones (#364).

Placement and the column auto-title are finalized later by ``_layout_panel_legends``; the anchor
and (absent an explicit ``title``) the title here are provisional.
"""
handles = _categorical_legend_handles(ax, color_mapping, na_hex)
if (cur := ax.get_legend()) is not None:
ax.add_artist(cur) # only the current legend would be dropped by ax.legend() below
# Auto-title (by column) is applied in `_layout_panel_legends`; an explicit `title` still wins here.
new_leg = ax.legend(
handles=handles,
title=title,
frameon=False,
loc="upper left",
bbox_to_anchor=(1.02, 1.0),
fontsize=legend_fontsize,
ncol=_legend_ncol(len(handles)),
)
new_leg._sdata_column = column # type: ignore[attr-defined]


def _decorate_axs(
ax: Axes,
cax: PatchCollection,
Expand Down Expand Up @@ -449,22 +493,41 @@ def _decorate_axs(
}
)
color_mapping = group_to_color_matching.drop_duplicates("cats").set_index("cats")["color"].to_dict()
_add_categorical_legend(
ax,
pd.Categorical(values=color_source_vector, categories=clusters),
palette=color_mapping,
legend_loc=legend_loc,
legend_fontweight=legend_fontweight,
legend_fontsize=legend_fontsize,
legend_fontoutline=path_effect,
na_color=[na_color.get_hex()],
na_in_legend=na_in_legend,
multi_panel=fig_params.axs is not None,
)
# scanpy's helper doesn't accept a title; set it post-hoc so the user can
# disambiguate fill vs outline when both legends are drawn.
if legend_title is not None and (legend := ax.get_legend()) is not None:
legend.set_title(legend_title)
color_mapping = {k: v for k, v in color_mapping.items() if not pd.isnull(k)} # NA handled separately
# A lone categorical legend goes through scanpy unchanged. A 2nd categorical render would
# otherwise make scanpy's bare `ax.legend()` merge every labeled artist into one legend and
# drop the first (#364), so route 2nd+ legends through a helper that keeps them separate.
if legend_loc in (None, "none"):
pass
elif any(getattr(c, "_sdata_column", None) is not None for c in ax.get_children() if isinstance(c, Legend)):
na_hex = na_color.get_hex() if (na_in_legend and pd.isnull(color_source_vector).any()) else None
_stack_categorical_legend(
ax,
color_mapping,
na_hex=na_hex,
title=legend_title,
column=value_to_plot,
legend_fontsize=legend_fontsize,
)
else:
_add_categorical_legend(
ax,
pd.Categorical(values=color_source_vector, categories=clusters),
palette=color_mapping,
legend_loc=legend_loc,
legend_fontweight=legend_fontweight,
legend_fontsize=legend_fontsize,
legend_fontoutline=path_effect,
na_color=[na_color.get_hex()],
na_in_legend=na_in_legend,
multi_panel=fig_params.axs is not None,
)
# Tag with the column; the column auto-title is applied in `_layout_panel_legends`.
# An explicit title wins now.
if (legend := ax.get_legend()) is not None:
legend._sdata_column = value_to_plot # type: ignore[attr-defined]
if legend_title is not None:
legend.set_title(legend_title)
elif colorbar and colorbar_requests is not None and cax is not None:
colorbar_requests.append(
ColorbarSpec(
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading