diff --git a/src/spatialdata_plot/pl/_color.py b/src/spatialdata_plot/pl/_color.py index b3e3425a..f45d1fd9 100644 --- a/src/spatialdata_plot/pl/_color.py +++ b/src/spatialdata_plot/pl/_color.py @@ -1160,6 +1160,29 @@ def _modify_categorical_color_mapping( return modified_mapping +def _default_categorical_palette(n: int) -> list[str]: + """Return the scanpy default categorical palette sized for ``n`` categories (grey beyond 103).""" + if n <= 20: + return list(default_20) + if n <= 28: + return list(default_28) + if n <= len(default_102): + return list(default_102) + logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.") + return ["grey"] * n + + +def _next_palette_colors(used_colors: set[str], n: int) -> list[str]: + """Pick ``n`` default-palette colors skipping ``used_colors`` (keeps a 2nd categorical render distinct, #364). + + Falls back to the full palette if too few unused colors remain. + """ + used_norm = {to_hex(to_rgba(c)) for c in used_colors} + pool = _default_categorical_palette(n + len(used_norm)) + unused = [c for c in pool if to_hex(to_rgba(c)) not in used_norm] + return (unused if len(unused) >= n else pool)[:n] + + def _get_default_categorial_color_mapping( color_source_vector: ArrayLike | pd.Series[CategoricalDtype], cmap_params: CmapParams | None = None, @@ -1179,17 +1202,8 @@ def _get_default_categorial_color_mapping( else: palette = None - # Fall back to default palettes if needed if palette is None: - if len_cat <= 20: - palette = default_20 - elif len_cat <= 28: - palette = default_28 - elif len_cat <= len(default_102): # 103 colors - palette = default_102 - else: - palette = ["grey"] * len_cat - logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.") + palette = _default_categorical_palette(len_cat) return dict(zip(color_source_vector.categories, palette[:len_cat], strict=True)) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 3001a317..c780c152 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -20,8 +20,9 @@ from geopandas import GeoDataFrame from matplotlib.axes import Axes from matplotlib.backend_bases import RendererBase -from matplotlib.colors import Colormap, LogNorm, Normalize +from matplotlib.colors import Colormap, LogNorm, Normalize, to_hex, to_rgba from matplotlib.figure import Figure +from matplotlib.legend import Legend from mpl_toolkits.axes_grid1.inset_locator import inset_axes from spatialdata._utils import _deprecation_alias from spatialdata.transformations.operations import get_transformation @@ -31,6 +32,7 @@ from spatialdata_plot._logging import _log_context, logger from spatialdata_plot.pl._color import ( _maybe_set_colors, + _next_palette_colors, _prepare_cmap_norm, _set_outline, ) @@ -1584,6 +1586,11 @@ def show( _draw_scalebar(ax, scalebar_params_obj, panel_idx=i) + if fig_params.fig is not None: + for panel_ax in fig_params.axs if fig_params.axs is not None else [fig_params.ax]: + if isinstance(panel_ax, Axes): + _layout_panel_legends(panel_ax, fig_params.fig) + _layout_pending_colorbars(pending_colorbars, fig_params, colorbar_params) if fig_params.fig is not None and save is not None: @@ -1870,6 +1877,45 @@ def _draw_colorbar( trackers_axes[location] = pad_axes + (bbox_axes.width if vertical else bbox_axes.height) +def _layout_panel_legends(ax: Axes, fig: Figure, gap: float = 0.01) -> None: + """Title and stack the per-render categorical legends (#364) in the right margin. + + Only legends this code created (tagged ``_sdata_column``) are touched, so fill/outline and + channel legends keep their own placement. Each legend is titled by its source column (matching + colorbars). When 2+ legends share an axis they are stacked top-to-bottom — the right-margin + convention — so wide legends grow the figure height on save instead of overflowing its right edge. + """ + legends = [c for c in ax.get_children() if isinstance(c, Legend) and hasattr(c, "_sdata_column")] + if not legends: + return + # Title each legend by its source column so it reads like the colorbars (an explicit title set + # earlier stays as-is). A lone legend keeps scanpy's placement; only its title is added here. + for leg in legends: + if not leg.get_title().get_text(): + leg.set_title(leg._sdata_column) + if len(legends) < 2: + return + # 2+ legends share the axis. Let constrained_layout settle, then freeze it: otherwise it shrinks + # the axes to "make room" for the margin legends (squashing the plot, leaving a gap). Frozen, the + # legends still count for `bbox_inches="tight"` on save. + fig.canvas.draw() + fig.set_layout_engine("none") + invf = fig.transFigure.inverted() + ax_bb = ax.get_window_extent().transformed(invf) + left, top = ax_bb.x1 + gap, ax_bb.y1 + # Anchor all at one point and settle, so heights are measured in a single consistent layout state. + for leg in legends: + leg.set_bbox_to_anchor((left, top), transform=fig.transFigure) + if hasattr(leg, "set_loc"): + leg.set_loc("upper left") + fig.canvas.draw() + heights = [leg.get_window_extent().transformed(invf).height for leg in legends] + y = top + for leg, h in zip(legends, heights, strict=True): + leg.set_bbox_to_anchor((left, y), transform=fig.transFigure) + y -= h + gap + + def _layout_pending_colorbars( pending_colorbars: list[tuple[Axes, list[ColorbarSpec]]], fig_params: FigParams, @@ -1944,19 +1990,32 @@ def _should_rasterize( return scale is None or (isinstance(scale, str) and scale != "full" and (dpi is not None or figsize is not None)) -def _maybe_set_label_colors(sdata: sd.SpatialData, render_params: LabelsRenderParams) -> None: - """Materialize a categorical palette on the table annotating a labels element, if applicable.""" +def _maybe_set_label_colors( + sdata: sd.SpatialData, + render_params: LabelsRenderParams, + used_colors: set[str] | None = None, +) -> None: + """Materialize a categorical palette on the table annotating a labels element, if applicable. + + ``used_colors`` accumulates the colors already taken by earlier categorical label renders on + the same panel. When a column's colors are auto-generated (no user palette, not already in + ``.uns``), they are shifted to skip ``used_colors`` so stacked legends stay distinct (#364). + """ table = render_params.table_name - if table is None or render_params.col_for_color is None: + col = render_params.col_for_color + if table is None or col is None: return - colors = sc.get.obs_df(sdata[table], [render_params.col_for_color]) - if isinstance(colors[render_params.col_for_color].dtype, pd.CategoricalDtype): - _maybe_set_colors( - source=sdata[table], - target=sdata[table], - key=render_params.col_for_color, - palette=render_params.palette, - ) + colors = sc.get.obs_df(sdata[table], [col]) + if not isinstance(colors[col].dtype, pd.CategoricalDtype): + return + adata = sdata[table] + color_key = f"{col}_colors" + if render_params.palette is None and used_colors and color_key not in adata.uns: + adata.uns[color_key] = _next_palette_colors(used_colors, len(colors[col].cat.categories)) + else: + _maybe_set_colors(source=adata, target=adata, key=col, palette=render_params.palette) + if used_colors is not None and color_key in adata.uns: + used_colors.update(to_hex(to_rgba(c)) for c in adata.uns[color_key]) def _render_panel( @@ -1985,6 +2044,9 @@ def _render_panel( """ wants = dict.fromkeys(("images", "labels", "points", "shapes"), False) wanted_elements: list[str] = [] + # Colors already taken by categorical label renders on this panel, so later renders can + # avoid reusing them and their stacked legends stay distinct (#364). + used_label_colors: set[str] = set() for cmd, params in render_cmds: # Skip render entries that belong to a different color panel. Entries with no @@ -2033,7 +2095,7 @@ def _render_panel( cast("ImageRenderParams | LabelsRenderParams", element_params), dpi, figsize ) if cmd == "render_labels": - _maybe_set_label_colors(sdata, cast(LabelsRenderParams, element_params)) + _maybe_set_label_colors(sdata, cast(LabelsRenderParams, element_params), used_label_colors) _RENDERERS[cmd](**kwargs) # Panel finalization depends only on per-panel values, so run it once after the loop. diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 1ab4825a..d8a31a6b 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -81,6 +81,7 @@ _decorate_axs, _fast_extent, _join_table_for_element, + _legend_ncol, _mpl_ax_contains_elements, _multiscale_to_spatial_image, _pixel_to_coord, @@ -548,7 +549,7 @@ def _add_outline_legend( loc=loc, bbox_to_anchor=anchor, fontsize=legend_params.legend_fontsize, - ncol=(1 if len(outline_handles) <= 14 else 2 if len(outline_handles) <= 30 else 3), + ncol=_legend_ncol(len(outline_handles)), ) @@ -697,8 +698,7 @@ def _render_shapes( nan_count = int(pd.isna(cv).sum()) if nan_count: logger.warning( - f"Found {nan_count} NaN values in color data. " - "These observations will be colored with the 'na_color'." + f"Found {nan_count} NaN values in color data. These observations will be colored with the 'na_color'." ) color_spec = color_spec.evolve(color_vector=cv) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 323d6c3e..a14834a2 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -26,6 +26,7 @@ ) from matplotlib.figure import Figure from matplotlib.gridspec import GridSpec +from matplotlib.legend import Legend from matplotlib_scalebar.scalebar import ScaleBar from pandas.api.types import CategoricalDtype, is_numeric_dtype from pandas.core.arrays.categorical import Categorical @@ -405,6 +406,49 @@ def _build_alignment_dtype_hint( return "" +def _legend_ncol(n: int) -> int: + """Column count for a categorical legend with ``n`` entries.""" + return 1 if n <= 14 else 2 if n <= 30 else 3 + + +def _categorical_legend_handles(ax: Axes, color_map: Mapping[Any, Any], na_hex: str | None = None) -> list[Any]: + """Empty-scatter handles (colored dots) for a categorical legend, with an optional NA entry.""" + handles = [ax.scatter([], [], c=color, label=str(cat)) for cat, color in color_map.items()] + if na_hex is not None: + handles.append(ax.scatter([], [], c=na_hex, label="NA")) + return handles + + +def _stack_categorical_legend( + ax: Axes, + color_mapping: Mapping[Any, Any], + *, + na_hex: str | None, + title: str | None, + column: str | None, + legend_fontsize: int | float | _FontSize | None, +) -> None: + """Build the 2nd+ categorical legend on a shared axes without dropping existing ones (#364). + + Placement and the column auto-title are finalized later by ``_layout_panel_legends``; the anchor + and (absent an explicit ``title``) the title here are provisional. + """ + handles = _categorical_legend_handles(ax, color_mapping, na_hex) + if (cur := ax.get_legend()) is not None: + ax.add_artist(cur) # only the current legend would be dropped by ax.legend() below + # Auto-title (by column) is applied in `_layout_panel_legends`; an explicit `title` still wins here. + new_leg = ax.legend( + handles=handles, + title=title, + frameon=False, + loc="upper left", + bbox_to_anchor=(1.02, 1.0), + fontsize=legend_fontsize, + ncol=_legend_ncol(len(handles)), + ) + new_leg._sdata_column = column # type: ignore[attr-defined] + + def _decorate_axs( ax: Axes, cax: PatchCollection, @@ -449,22 +493,41 @@ def _decorate_axs( } ) color_mapping = group_to_color_matching.drop_duplicates("cats").set_index("cats")["color"].to_dict() - _add_categorical_legend( - ax, - pd.Categorical(values=color_source_vector, categories=clusters), - palette=color_mapping, - legend_loc=legend_loc, - legend_fontweight=legend_fontweight, - legend_fontsize=legend_fontsize, - legend_fontoutline=path_effect, - na_color=[na_color.get_hex()], - na_in_legend=na_in_legend, - multi_panel=fig_params.axs is not None, - ) - # scanpy's helper doesn't accept a title; set it post-hoc so the user can - # disambiguate fill vs outline when both legends are drawn. - if legend_title is not None and (legend := ax.get_legend()) is not None: - legend.set_title(legend_title) + color_mapping = {k: v for k, v in color_mapping.items() if not pd.isnull(k)} # NA handled separately + # A lone categorical legend goes through scanpy unchanged. A 2nd categorical render would + # otherwise make scanpy's bare `ax.legend()` merge every labeled artist into one legend and + # drop the first (#364), so route 2nd+ legends through a helper that keeps them separate. + if legend_loc in (None, "none"): + pass + elif any(getattr(c, "_sdata_column", None) is not None for c in ax.get_children() if isinstance(c, Legend)): + na_hex = na_color.get_hex() if (na_in_legend and pd.isnull(color_source_vector).any()) else None + _stack_categorical_legend( + ax, + color_mapping, + na_hex=na_hex, + title=legend_title, + column=value_to_plot, + legend_fontsize=legend_fontsize, + ) + else: + _add_categorical_legend( + ax, + pd.Categorical(values=color_source_vector, categories=clusters), + palette=color_mapping, + legend_loc=legend_loc, + legend_fontweight=legend_fontweight, + legend_fontsize=legend_fontsize, + legend_fontoutline=path_effect, + na_color=[na_color.get_hex()], + na_in_legend=na_in_legend, + multi_panel=fig_params.axs is not None, + ) + # Tag with the column; the column auto-title is applied in `_layout_panel_legends`. + # An explicit title wins now. + if (legend := ax.get_legend()) is not None: + legend._sdata_column = value_to_plot # type: ignore[attr-defined] + if legend_title is not None: + legend.set_title(legend_title) elif colorbar and colorbar_requests is not None and cax is not None: colorbar_requests.append( ColorbarSpec( diff --git a/tests/_images/Shapes_can_color_two_queried_shapes_elements_by_annotation.png b/tests/_images/Shapes_can_color_two_queried_shapes_elements_by_annotation.png index 77d14762..e7114687 100644 Binary files a/tests/_images/Shapes_can_color_two_queried_shapes_elements_by_annotation.png and b/tests/_images/Shapes_can_color_two_queried_shapes_elements_by_annotation.png differ diff --git a/tests/_images/Shapes_can_color_two_shapes_elements_by_annotation.png b/tests/_images/Shapes_can_color_two_shapes_elements_by_annotation.png index e99f972d..2df96201 100644 Binary files a/tests/_images/Shapes_can_color_two_shapes_elements_by_annotation.png and b/tests/_images/Shapes_can_color_two_shapes_elements_by_annotation.png differ diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index 2048d11c..eaba52dc 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -7,6 +7,7 @@ import scanpy as sc from anndata import AnnData from matplotlib.colors import Normalize +from matplotlib.legend import Legend from spatial_image import to_spatial_image from spatialdata import SpatialData, deepcopy, get_element_instances from spatialdata.models import Labels2DModel, Labels3DModel, TableModel @@ -96,6 +97,80 @@ def test_plot_can_stack_render_labels(self, sdata_blobs: SpatialData): .pl.show() ) + def test_two_categorical_label_renders_make_two_distinct_legends(self, sdata_blobs: SpatialData): + # Regression test for #364: two render_labels calls coloring by two categorical columns + # must produce two separate, column-titled legends (not one merged/replaced legend), and + # the second render must not reuse the first's colors. State-based (no image comparison). + n = sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_labels"] * n) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_labels" + sdata_blobs["table"].obs["cat0"] = pd.Categorical((["A", "B"] * ((n + 1) // 2))[:n]) + sdata_blobs["table"].obs["cat1"] = pd.Categorical((["C", "D"] * ((n + 1) // 2))[:n]) + + ( + sdata_blobs.pl.render_labels("blobs_labels", color="cat0") + .pl.render_labels("blobs_labels", color="cat1") + .pl.show() + ) + + fig = plt.gcf() + ax = fig.axes[0] + legends = [c for c in ax.get_children() if isinstance(c, Legend)] + assert len(legends) == 2 + entries = {leg.get_title().get_text(): {t.get_text() for t in leg.get_texts()} for leg in legends} + assert entries == {"cat0": {"A", "B"}, "cat1": {"C", "D"}} + + # palette offset: the second categorical render must not reuse the first's colors + c0 = set(sdata_blobs["table"].uns["cat0_colors"]) + c1 = set(sdata_blobs["table"].uns["cat1_colors"]) + assert c0.isdisjoint(c1) + + # legends are stacked vertically in the right margin (not left-to-right): left-aligned and + # non-overlapping in y, so the second can't overflow the figure's right edge (#364) + fig.canvas.draw() + inv = ax.transAxes.inverted() + boxes = sorted((leg.get_window_extent().transformed(inv) for leg in legends), key=lambda b: b.y0, reverse=True) + assert boxes[1].y1 <= boxes[0].y0 # no vertical overlap (lower legend's top <= upper's bottom) + assert abs(boxes[0].x0 - boxes[1].x0) < 0.01 # left edges aligned + plt.close() + + def test_three_categorical_label_renders_make_three_legends(self, sdata_blobs: SpatialData): + # Regression test for #364: re-adding prior legends must not duplicate them; three renders + # yield exactly three distinct legends (not four with a repeat). + n = sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_labels"] * n) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_labels" + for col, (a, b) in {"cat0": ("A", "B"), "cat1": ("C", "D"), "cat2": ("E", "F")}.items(): + sdata_blobs["table"].obs[col] = pd.Categorical(([a, b] * ((n + 1) // 2))[:n]) + + ( + sdata_blobs.pl.render_labels("blobs_labels", color="cat0") + .pl.render_labels("blobs_labels", color="cat1") + .pl.render_labels("blobs_labels", color="cat2") + .pl.show() + ) + + ax = plt.gcf().axes[0] + titles = sorted(c.get_title().get_text() for c in ax.get_children() if isinstance(c, Legend)) + assert titles == ["cat0", "cat1", "cat2"] + plt.close() + + def test_single_categorical_label_render_legend_titled_by_column(self, sdata_blobs: SpatialData): + # A lone categorical render produces exactly one legend titled by its source column, matching + # the colorbar's auto-title (#364). + n = sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_labels"] * n) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_labels" + sdata_blobs["table"].obs["cat0"] = pd.Categorical((["A", "B"] * ((n + 1) // 2))[:n]) + + sdata_blobs.pl.render_labels("blobs_labels", color="cat0").pl.show() + + ax = plt.gcf().axes[0] + legends = [c for c in ax.get_children() if isinstance(c, Legend)] + assert len(legends) == 1 + assert legends[0].get_title().get_text() == "cat0" + plt.close() + def test_plot_can_color_by_rgba_array(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_labels("blobs_labels", color=[0.5, 0.5, 1.0, 0.5]).pl.show()