From e73d24a13d0cae792e5f8f6e9933ea9a6e9ef931 Mon Sep 17 00:00:00 2001 From: harryswift01 Date: Wed, 17 Jun 2026 15:20:17 +0100 Subject: [PATCH 1/3] refactor(dihedrals): add serial chunked conformational map-reduce --- CodeEntropy/config/runtime.py | 4 +- CodeEntropy/levels/conformation_dag.py | 18 +- CodeEntropy/levels/dihedrals.py | 810 ------------------ CodeEntropy/levels/dihedrals/__init__.py | 0 .../levels/dihedrals/angle_observations.py | 259 ++++++ .../dihedrals/conformational_state_builder.py | 198 +++++ CodeEntropy/levels/dihedrals/kernels.py | 116 +++ .../levels/dihedrals/peak_detection.py | 270 ++++++ .../levels/dihedrals/state_assignment.py | 281 ++++++ CodeEntropy/levels/dihedrals/topology.py | 146 ++++ 10 files changed, 1290 insertions(+), 812 deletions(-) delete mode 100644 CodeEntropy/levels/dihedrals.py create mode 100644 CodeEntropy/levels/dihedrals/__init__.py create mode 100644 CodeEntropy/levels/dihedrals/angle_observations.py create mode 100644 CodeEntropy/levels/dihedrals/conformational_state_builder.py create mode 100644 CodeEntropy/levels/dihedrals/kernels.py create mode 100644 CodeEntropy/levels/dihedrals/peak_detection.py create mode 100644 CodeEntropy/levels/dihedrals/state_assignment.py create mode 100644 CodeEntropy/levels/dihedrals/topology.py diff --git a/CodeEntropy/config/runtime.py b/CodeEntropy/config/runtime.py index 79afd73d..246944d3 100644 --- a/CodeEntropy/config/runtime.py +++ b/CodeEntropy/config/runtime.py @@ -36,7 +36,9 @@ from CodeEntropy.core.dask_clusters import HPCDaskManager from CodeEntropy.core.logging import LoggingConfig from CodeEntropy.entropy.workflow import EntropyWorkflow -from CodeEntropy.levels.dihedrals import ConformationStateBuilder +from CodeEntropy.levels.dihedrals.conformational_state_builder import ( + ConformationStateBuilder, +) from CodeEntropy.molecules.grouping import MoleculeGrouper from CodeEntropy.results.reporter import ResultsReporter from CodeEntropy.trajectory.mda import UniverseOperations diff --git a/CodeEntropy/levels/conformation_dag.py b/CodeEntropy/levels/conformation_dag.py index 78ad12ab..39a84b3b 100644 --- a/CodeEntropy/levels/conformation_dag.py +++ b/CodeEntropy/levels/conformation_dag.py @@ -8,7 +8,10 @@ from typing import Any -from CodeEntropy.levels.dihedrals import ConformationStateBuilder +from CodeEntropy.levels.dihedrals.conformational_state_builder import ( + ConformationStateBuilder, +) +from CodeEntropy.levels.execution.policy import ExecutionPolicy from CodeEntropy.trajectory.frames import FrameSelection SharedData = dict[str, Any] @@ -20,9 +23,16 @@ class ConformationDAG: """Execute conformational-state construction for selected trajectory frames.""" def __init__(self, universe_operations: Any | None = None) -> None: + """Initialize the conformational DAG. + + Args: + universe_operations: Optional universe-operation adapter passed to the + underlying conformation-state builder. + """ self._builder = ConformationStateBuilder( universe_operations=universe_operations ) + self._policy = ExecutionPolicy() def build(self) -> ConformationDAG: """Build the conformational DAG topology. @@ -44,6 +54,7 @@ def execute( shared_data: Shared workflow data containing ``reduced_universe``, ``levels``, ``groups``, ``frame_selection``, and ``args.bin_width``. progress: Optional progress sink forwarded to the conformation builder. + Returns: A dictionary containing the computed ``conformational_states`` mapping. """ @@ -52,6 +63,10 @@ def execute( groups = shared_data["groups"] frame_selection: FrameSelection = shared_data["frame_selection"] bin_width = int(shared_data["args"].bin_width) + chunk_size = self._policy.frame_chunk_size( + shared_data, + n_frames=frame_selection.n_frames, + ) states_ua, states_res, flexible_ua, flexible_res = ( self._builder.build_conformational_states( @@ -61,6 +76,7 @@ def execute( bin_width=bin_width, frame_selection=frame_selection, progress=progress, + chunk_size=chunk_size, ) ) diff --git a/CodeEntropy/levels/dihedrals.py b/CodeEntropy/levels/dihedrals.py deleted file mode 100644 index b54eb275..00000000 --- a/CodeEntropy/levels/dihedrals.py +++ /dev/null @@ -1,810 +0,0 @@ -"""Dihedral state assignment for conformational entropy. - -This module converts selected-frame dihedral angle time series into discrete -conformational state labels. The resulting state labels are used downstream to -compute configurational entropy. - -Frame-index contract: - - ``FrameSelection.analysis_indices`` are used for MDAnalysis trajectory access - in the active analysis universe. - - ``Dihedral(...).run(start, stop, step)`` uses frame bounds in the active - analysis-universe index space. - - ``dihedral_results.results.angles`` is always indexed locally from zero. - Never use an absolute/source frame index directly into that result array. -""" - -from __future__ import annotations - -import logging -from dataclasses import dataclass -from typing import Any - -import numpy as np -from MDAnalysis.analysis.dihedrals import Dihedral -from rich.progress import TaskID - -from CodeEntropy.results.reporter import _RichProgressSink -from CodeEntropy.trajectory.frames import FrameSelection - -logger = logging.getLogger(__name__) - -UAKey = tuple[int, int] -PhiValues = dict[int, list[float]] -PhiContainer = dict[int, PhiValues | list[Any]] - - -@dataclass -class DihedralAngleData: - """Selected-frame dihedral angle data used to identify peaks. - - Attributes: - num_residues: Number of residues in the representative molecule. - num_dihedrals_ua: Number of united-atom dihedrals by residue index. - num_dihedrals_res: Number of residue-level dihedrals. - phi_ua: United-atom angle values by residue and dihedral index. - phi_res: Residue-level angle values by dihedral index, or an empty list - when no residue-level dihedrals are present. - """ - - num_residues: int - num_dihedrals_ua: list[int] - num_dihedrals_res: int - phi_ua: PhiContainer - phi_res: PhiValues | list[Any] - - -@dataclass -class DihedralPeakData: - """Histogram peak definitions used for conformational state assignment. - - Attributes: - peaks_ua: United-atom peak values by residue and dihedral index. - peaks_res: Residue-level peak values by dihedral index. - """ - - peaks_ua: list[list[Any]] - peaks_res: list[Any] - - -@dataclass -class ConformationStateData: - """Serial conformational state data calculated for one molecule group. - - Attributes: - state_res: Residue-level state labels for the group. - flex_res: Number of flexible residue-level dihedrals for the group. - states_ua_updates: United-atom state-label updates by ``(group, residue)``. - flexible_ua_updates: United-atom flexible-dihedral updates by - ``(group, residue)``. - """ - - state_res: list[str] - flex_res: int - states_ua_updates: dict[UAKey, list[str]] - flexible_ua_updates: dict[UAKey, int] - - -class ConformationStateBuilder: - """Build conformational state labels from selected-frame dihedral angles.""" - - def __init__(self, universe_operations: Any) -> None: - """Initialize the analysis helper. - - Args: - universe_operations: Object providing helper methods: - - extract_fragment(data_container, molecule_id) - - select_atoms(atomgroup, selection_string) - """ - self._universe_operations = universe_operations - - def build_conformational_states( - self, - data_container: Any, - levels: dict[Any, list[str]], - groups: dict[int, list[Any]], - bin_width: float, - frame_selection: FrameSelection, - progress: _RichProgressSink | None = None, - ) -> tuple[dict[UAKey, list[str]], list[list[str]], dict[UAKey, int], list[int]]: - """Build conformational state labels from selected trajectory frames. - - Args: - data_container: MDAnalysis Universe or compatible container used to - extract fragments and compute dihedral time series. - levels: Mapping of molecule id to enabled level names. - groups: Mapping of group id to molecule ids. - bin_width: Histogram bin width in degrees used when identifying peak - dihedral populations. - frame_selection: FrameSelection controlling which frames are analysed. - During the current migration stage, ``analysis_indices`` are local - indices into the physically frame-sliced analysis universe. - progress: Optional progress sink. - - Returns: - Tuple ``(states_ua, states_res, flexible_ua, flexible_res)``. - """ - number_groups = len(groups) - states_ua: dict[UAKey, list[str]] = {} - states_res: list[list[str]] = [[] for _ in range(number_groups)] - flexible_ua: dict[UAKey, int] = {} - flexible_res: list[int] = [] - - task: TaskID | None = None - if progress is not None: - total = max(1, len(groups)) - task = progress.add_task( - "[green]Conformational states", - total=total, - title="Initializing", - ) - - if not groups: - if progress is not None and task is not None: - progress.update(task, title="No groups") - progress.advance(task) - return states_ua, states_res, flexible_ua, flexible_res - - for group_id in groups.keys(): - molecules = groups[group_id] - if not molecules: - if progress is not None and task is not None: - progress.update(task, title=f"Group {group_id} (empty)") - progress.advance(task) - continue - - if progress is not None and task is not None: - progress.update(task, title=f"Group {group_id}") - - level_list = levels[molecules[0]] - - peaks_ua, peaks_res = self._identify_peaks( - data_container=data_container, - molecules=molecules, - bin_width=bin_width, - level_list=level_list, - frame_selection=frame_selection, - ) - - self._assign_states( - data_container=data_container, - group_id=group_id, - molecules=molecules, - level_list=level_list, - peaks_ua=peaks_ua, - peaks_res=peaks_res, - states_ua=states_ua, - states_res=states_res, - flexible_ua=flexible_ua, - flexible_res=flexible_res, - frame_selection=frame_selection, - ) - - if progress is not None and task is not None: - progress.advance(task) - - logger.debug("States UA: %s", states_ua) - logger.debug("Number of flexible dihedrals UA: %s", flexible_ua) - logger.debug("States Res: %s", states_res) - logger.debug("Number of flexible dihedrals Res: %s", flexible_res) - - return states_ua, states_res, flexible_ua, flexible_res - - def _select_heavy_residue(self, mol: Any, res_id: int) -> Any: - """Select heavy atoms in a residue by residue index. - - Args: - mol: Representative molecule AtomGroup. - res_id: Local residue index. - - Returns: - AtomGroup containing heavy atoms in the residue selection. - """ - selection1 = mol.residues[res_id].atoms.indices[0] - selection2 = mol.residues[res_id].atoms.indices[-1] - - res_container = self._universe_operations.select_atoms( - mol, f"index {selection1}:{selection2}" - ) - return self._universe_operations.select_atoms(res_container, "prop mass > 1.1") - - def _get_dihedrals(self, data_container: Any, level: str) -> list[Any]: - """Return dihedral AtomGroups for a container at a given level. - - Args: - data_container: MDAnalysis container. - level: Either ``"united_atom"`` or ``"residue"``. - - Returns: - List of AtomGroups, each representing a dihedral definition. - """ - atom_groups: list[Any] = [] - - if level == "united_atom": - for dihedral in data_container.dihedrals: - atom_groups.append(dihedral.atoms) - - if level == "residue": - num_residues = len(data_container.residues) - if num_residues >= 4: - for residue in range(4, num_residues + 1): - atom1 = data_container.select_atoms( - f"resindex {residue - 4} and bonded resindex {residue - 3}" - ) - atom2 = data_container.select_atoms( - f"resindex {residue - 3} and bonded resindex {residue - 4}" - ) - atom3 = data_container.select_atoms( - f"resindex {residue - 2} and bonded resindex {residue - 1}" - ) - atom4 = data_container.select_atoms( - f"resindex {residue - 1} and bonded resindex {residue - 2}" - ) - atom_groups.append(atom1 + atom2 + atom3 + atom4) - - logger.debug("Level: %s, Dihedrals: %s", level, atom_groups) - return atom_groups - - def _identify_peaks( - self, - data_container: Any, - molecules: list[Any], - bin_width: float, - level_list: list[Any], - frame_selection: FrameSelection, - ) -> tuple[list[list[Any]], list[Any]]: - """Identify histogram peaks for each selected-frame dihedral series. - - Args: - data_container: MDAnalysis universe. - molecules: Molecule ids in the group. - bin_width: Histogram bin width in degrees. - level_list: Enabled hierarchy levels for the representative molecule. - frame_selection: Selected frames in the active analysis-universe index - space. - - Returns: - Tuple of ``(peaks_ua, peaks_res)``. - """ - angle_data = self._collect_dihedral_angle_data( - data_container=data_container, - molecules=molecules, - level_list=level_list, - frame_selection=frame_selection, - ) - peak_data = self._build_peak_data( - angle_data=angle_data, - level_list=level_list, - bin_width=bin_width, - ) - return peak_data.peaks_ua, peak_data.peaks_res - - def _collect_dihedral_angle_data( - self, - data_container: Any, - molecules: list[Any], - level_list: list[Any], - frame_selection: FrameSelection, - ) -> DihedralAngleData: - """Collect selected-frame dihedral angle values for peak detection. - - Args: - data_container: MDAnalysis universe. - molecules: Molecule ids in the group. - level_list: Enabled hierarchy levels for the representative molecule. - frame_selection: Selected frames in the active analysis-universe index - space. - - Returns: - Dihedral angle values and dihedral counts for the group. - """ - rep_mol = self._universe_operations.extract_fragment( - data_container, molecules[0] - ) - number_frames = self._analysis_frame_count(frame_selection) - num_residues = len(rep_mol.residues) - - num_dihedrals_ua: list[int] = [0 for _ in range(num_residues)] - phi_ua: PhiContainer = {} - phi_res: PhiValues | list[Any] = {} - num_dihedrals_res = 0 - - for molecule in molecules: - mol = self._universe_operations.extract_fragment(data_container, molecule) - - for level in level_list: - if level == "united_atom": - for res_id in range(num_residues): - heavy_res = self._select_heavy_residue(mol, res_id) - dihedrals = self._get_dihedrals(heavy_res, level) - num_dihedrals_ua[res_id] = len(dihedrals) - - if num_dihedrals_ua[res_id] == 0: - phi_ua[res_id] = [] - continue - - if res_id not in phi_ua or isinstance(phi_ua[res_id], list): - phi_ua[res_id] = {} - - dihedral_results = self._run_dihedrals( - dihedrals=dihedrals, - frame_selection=frame_selection, - ) - phi_ua[res_id] = self._process_dihedral_phi( - dihedral_results=dihedral_results, - num_dihedrals=num_dihedrals_ua[res_id], - number_frames=number_frames, - phi_values=phi_ua[res_id], - ) - - elif level == "residue": - dihedrals = self._get_dihedrals(mol, level) - num_dihedrals_res = len(dihedrals) - - if num_dihedrals_res == 0: - phi_res = [] - continue - - if isinstance(phi_res, list): - phi_res = {} - - dihedral_results = self._run_dihedrals( - dihedrals=dihedrals, - frame_selection=frame_selection, - ) - phi_res = self._process_dihedral_phi( - dihedral_results=dihedral_results, - num_dihedrals=num_dihedrals_res, - number_frames=number_frames, - phi_values=phi_res, - ) - - logger.debug("phi_ua %s", phi_ua) - logger.debug("phi_res %s", phi_res) - - return DihedralAngleData( - num_residues=num_residues, - num_dihedrals_ua=num_dihedrals_ua, - num_dihedrals_res=num_dihedrals_res, - phi_ua=phi_ua, - phi_res=phi_res, - ) - - def _build_peak_data( - self, - angle_data: DihedralAngleData, - level_list: list[Any], - bin_width: float, - ) -> DihedralPeakData: - """Build histogram peak definitions from collected angle values. - - Args: - angle_data: Selected-frame angle values and dihedral counts. - level_list: Enabled hierarchy levels for the representative molecule. - bin_width: Histogram bin width in degrees. - - Returns: - Peak definitions for united-atom and residue-level states. - """ - peaks_ua: list[list[Any]] = [[] for _ in range(angle_data.num_residues)] - peaks_res: list[Any] = [] - - for level in level_list: - if level == "united_atom": - for res_id in range(angle_data.num_residues): - phi_values = angle_data.phi_ua.get(res_id) - if not phi_values: - peaks_ua[res_id] = [] - else: - peaks_ua[res_id] = self._process_histogram( - num_dihedrals=angle_data.num_dihedrals_ua[res_id], - phi_values=phi_values, - bin_width=bin_width, - ) - - elif level == "residue": - if not angle_data.phi_res: - peaks_res = [] - else: - peaks_res = self._process_histogram( - num_dihedrals=angle_data.num_dihedrals_res, - phi_values=angle_data.phi_res, - bin_width=bin_width, - ) - - return DihedralPeakData(peaks_ua=peaks_ua, peaks_res=peaks_res) - - def _process_dihedral_phi( - self, - dihedral_results: Any, - num_dihedrals: int, - number_frames: int, - phi_values: PhiValues, - ) -> PhiValues: - """Collect positive-angle dihedral values from a local result array. - - Args: - dihedral_results: Result of ``MDAnalysis.analysis.dihedrals.Dihedral``. - num_dihedrals: Number of dihedrals in the result. - number_frames: Number of local frames in ``dihedral_results``. - phi_values: Existing accumulator mapping dihedral index to values. - - Returns: - Updated ``phi_values`` accumulator. - - Notes: - ``dihedral_results.results.angles`` is indexed locally from zero. - """ - for dihedral_index in range(num_dihedrals): - phi: list[float] = [] - - for local_i in range(number_frames): - value = dihedral_results.results.angles[local_i][dihedral_index] - if value < 0: - value += 360 - phi.append(float(value)) - - if dihedral_index not in phi_values: - phi_values[dihedral_index] = phi - else: - phi_values[dihedral_index].extend(phi) - - return phi_values - - def _process_histogram( - self, - num_dihedrals: int, - phi_values: PhiValues, - bin_width: float, - ) -> list[Any]: - """Find histogram peaks from dihedral angle values. - - Args: - num_dihedrals: Number of dihedrals. - phi_values: Mapping from dihedral index to angle values. - bin_width: Histogram bin width in degrees. - - Returns: - List of peak lists, one per dihedral. - """ - peak_values = [] - for dihedral_index in range(num_dihedrals): - phi = phi_values[dihedral_index] - number_bins = int(360 / bin_width) - popul, bin_edges = np.histogram(a=phi, bins=number_bins, range=(0, 360)) - - logger.debug("Histogram: %s", popul) - - bin_value = [ - 0.5 * (bin_edges[i] + bin_edges[i + 1]) for i in range(0, len(popul)) - ] - - peaks = self._find_histogram_peaks(popul=popul, bin_value=bin_value) - peak_values.append(peaks) - - logger.debug("Dihedral: %s Peaks: %s", dihedral_index, peaks) - - return peak_values - - @staticmethod - def _find_histogram_peaks( - popul: np.ndarray[Any, Any], bin_value: list[float] - ) -> list[float]: - """Return convex turning-point peaks from a histogram. - - Args: - popul: Histogram bin populations. - bin_value: Histogram bin centre values. - - Returns: - List of peak positions. - """ - number_bins = len(popul) - peaks: list[float] = [] - - for bin_index in range(number_bins): - if popul[bin_index] == 0: - continue - - left = popul[bin_index - 1] - right = popul[0] if bin_index == number_bins - 1 else popul[bin_index + 1] - - if popul[bin_index] >= left and popul[bin_index] > right: - peaks.append(bin_value[bin_index]) - - return peaks - - def _assign_states( - self, - data_container: Any, - group_id: int, - molecules: list[Any], - level_list: list[Any], - peaks_ua: list[list[Any]], - peaks_res: list[Any], - states_ua: dict[UAKey, list[str]], - states_res: list[list[str]], - flexible_ua: dict[UAKey, int], - flexible_res: list[int], - frame_selection: FrameSelection, - ) -> None: - """Assign discrete state labels for selected-frame dihedrals. - - Args: - data_container: MDAnalysis universe. - group_id: Molecule group id. - molecules: Molecule ids in the group. - level_list: Enabled hierarchy levels. - peaks_ua: UA-level peaks by residue. - peaks_res: Residue-level peaks. - states_ua: UA state accumulator. - states_res: Residue state accumulator. - flexible_ua: UA flexible-dihedral accumulator. - flexible_res: Residue flexible-dihedral accumulator. - frame_selection: Selected frames in the active analysis-universe index - space. - - Returns: - None. Mutates the provided state/flexible accumulators. - """ - state_data = self._calculate_group_state_data( - data_container=data_container, - group_id=group_id, - molecules=molecules, - level_list=level_list, - peaks_ua=peaks_ua, - peaks_res=peaks_res, - frame_selection=frame_selection, - ) - self._merge_group_state_data( - state_data=state_data, - states_ua=states_ua, - states_res=states_res, - flexible_ua=flexible_ua, - flexible_res=flexible_res, - ) - - def _calculate_group_state_data( - self, - data_container: Any, - group_id: int, - molecules: list[Any], - level_list: list[Any], - peaks_ua: list[list[Any]], - peaks_res: list[Any], - frame_selection: FrameSelection, - ) -> ConformationStateData: - """Calculate conformational states for one group without final merging. - - Args: - data_container: MDAnalysis universe. - group_id: Molecule group id. - molecules: Molecule ids in the group. - level_list: Enabled hierarchy levels. - peaks_ua: UA-level peaks by residue. - peaks_res: Residue-level peaks. - frame_selection: Selected frames in the active analysis-universe index - space. - - Returns: - Serial conformational state data for the group. - """ - rep_mol = self._universe_operations.extract_fragment( - data_container, molecules[0] - ) - number_frames = self._analysis_frame_count(frame_selection) - num_residues = len(rep_mol.residues) - - state_res: list[str] = [] - flex_res = 0 - states_ua_updates: dict[UAKey, list[str]] = {} - flexible_ua_updates: dict[UAKey, int] = {} - - for molecule in molecules: - mol = self._universe_operations.extract_fragment(data_container, molecule) - - for level in level_list: - if level == "united_atom": - for res_id in range(num_residues): - key = (group_id, res_id) - heavy_res = self._select_heavy_residue(mol, res_id) - dihedrals = self._get_dihedrals(heavy_res, level) - num_dihedrals = len(dihedrals) - - if num_dihedrals == 0: - states_ua_updates[key] = [] - flexible_ua_updates[key] = 0 - continue - - dihedral_results = self._run_dihedrals( - dihedrals=dihedrals, - frame_selection=frame_selection, - ) - states, flexible = self._process_conformations( - peaks=peaks_ua[res_id], - dihedral_results=dihedral_results, - num_dihedrals=num_dihedrals, - number_frames=number_frames, - ) - - if key not in states_ua_updates: - states_ua_updates[key] = states - flexible_ua_updates[key] = flexible - else: - states_ua_updates[key].extend(states) - flexible_ua_updates[key] = max( - flexible_ua_updates[key], flexible - ) - - if level == "residue": - dihedrals = self._get_dihedrals(mol, level) - num_dihedrals = len(dihedrals) - - if num_dihedrals == 0: - state_res = [] - continue - - dihedral_results = self._run_dihedrals( - dihedrals=dihedrals, - frame_selection=frame_selection, - ) - states, flexible = self._process_conformations( - peaks=peaks_res, - dihedral_results=dihedral_results, - num_dihedrals=num_dihedrals, - number_frames=number_frames, - ) - state_res.extend(states) - flex_res = max(flex_res, flexible) - - return ConformationStateData( - state_res=state_res, - flex_res=flex_res, - states_ua_updates=states_ua_updates, - flexible_ua_updates=flexible_ua_updates, - ) - - @staticmethod - def _merge_group_state_data( - state_data: ConformationStateData, - states_ua: dict[UAKey, list[str]], - states_res: list[list[str]], - flexible_ua: dict[UAKey, int], - flexible_res: list[int], - ) -> None: - """Merge one group's state data into final output accumulators. - - Args: - state_data: Serial conformational state data for one group. - states_ua: UA state accumulator to mutate. - states_res: Residue state accumulator to mutate. - flexible_ua: UA flexible-dihedral accumulator to mutate. - flexible_res: Residue flexible-dihedral accumulator to mutate. - - Returns: - None. Mutates the provided accumulators. - """ - for key, states in state_data.states_ua_updates.items(): - if key not in states_ua: - states_ua[key] = states - flexible_ua[key] = state_data.flexible_ua_updates[key] - else: - states_ua[key].extend(states) - flexible_ua[key] = max( - flexible_ua[key], - state_data.flexible_ua_updates[key], - ) - - states_res.append(state_data.state_res) - flexible_res.append(state_data.flex_res) - - def _process_conformations( - self, - peaks: list[Any], - dihedral_results: Any, - num_dihedrals: int, - number_frames: int, - ) -> tuple[list[str], int]: - """Assign conformational state labels from local dihedral results. - - Args: - peaks: Histogram peaks. - dihedral_results: Result of ``Dihedral(...).run(...)``. - num_dihedrals: Number of dihedrals. - number_frames: Number of local result frames. - - Returns: - Tuple of ``(states, num_flexible)``. - - Notes: - ``dihedral_results.results.angles`` is indexed locally from zero. - """ - states: list[str] = [] - conformations: list[list[Any]] = [] - num_flexible = 0 - - for dihedral_index in range(num_dihedrals): - conformation: list[Any] = [] - - for local_i in range(number_frames): - value = dihedral_results.results.angles[local_i][dihedral_index] - if value < 0: - value += 360 - - distances = [abs(value - peak) for peak in peaks[dihedral_index]] - conformation.append(np.argmin(distances)) - - unique = np.unique(conformation) - if len(unique) > 1: - num_flexible += 1 - - conformations.append(conformation) - - mol_states = [ - state - for state in ( - "".join(str(int(conformations[d][f])) for d in range(num_dihedrals)) - for f in range(number_frames) - ) - if state - ] - - states.extend(mol_states) - - return states, num_flexible - - def _run_dihedrals( - self, dihedrals: list[Any], frame_selection: FrameSelection - ) -> Any: - """Run MDAnalysis dihedral analysis over selected analysis frames. - - Args: - dihedrals: Dihedral AtomGroups. - frame_selection: Selected trajectory frame selection. - - Returns: - MDAnalysis Dihedral analysis result. - - Notes: - ``Dihedral.run(start, stop, step)`` uses frame bounds in the active - analysis-universe index space. The returned ``results.angles`` array - is indexed locally from zero. - """ - if not dihedrals: - raise ValueError("Cannot run Dihedral analysis with no dihedrals.") - - start, stop, step = self._analysis_run_bounds(frame_selection) - return Dihedral(dihedrals).run(start=start, stop=stop, step=step) - - @staticmethod - def _analysis_frame_count(frame_selection: FrameSelection) -> int: - """Return the number of selected frames. - - Args: - frame_selection: Selected trajectory frame selection. - - Returns: - Number of selected frames. - """ - return frame_selection.n_frames - - @staticmethod - def _analysis_run_bounds(frame_selection: FrameSelection) -> tuple[int, int, int]: - """Return MDAnalysis run bounds for selected analysis frames. - - Args: - frame_selection: Selected trajectory frame selection. - - Returns: - Tuple of ``(start, stop, step)`` in active analysis-universe index - space. - - Raises: - ValueError: If the selection is empty. - """ - start = frame_selection.source_start - stop = frame_selection.source_stop_exclusive - - if start is None or stop is None: - raise ValueError("Frame selection is empty.") - - return start, stop, frame_selection.infer_source_step() diff --git a/CodeEntropy/levels/dihedrals/__init__.py b/CodeEntropy/levels/dihedrals/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CodeEntropy/levels/dihedrals/angle_observations.py b/CodeEntropy/levels/dihedrals/angle_observations.py new file mode 100644 index 00000000..1ad1e606 --- /dev/null +++ b/CodeEntropy/levels/dihedrals/angle_observations.py @@ -0,0 +1,259 @@ +"""Selected-frame dihedral angle observation helpers. + +This module contains the frame-aware angle collection logic used by the +conformational state workflow. It preserves the MDAnalysis frame-index contract: +``Dihedral.run(...)`` receives active analysis-universe frame bounds, while the +returned ``results.angles`` array is indexed locally from zero. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import numpy as np +from MDAnalysis.analysis.dihedrals import Dihedral + +from CodeEntropy.levels.dihedrals.kernels import wrap_degrees_positive +from CodeEntropy.levels.dihedrals.topology import ( + DihedralTopologyDiscovery, + MoleculeDihedralTopology, +) +from CodeEntropy.levels.execution.chunks import chunk_frame_indices +from CodeEntropy.trajectory.frames import FrameSelection + + +@dataclass(frozen=True) +class ConformationChunkTask: + """Serial conformational work item for one molecule and frame chunk. + + Attributes: + group_id: Molecule group id. + molecule_id: Molecule id. + molecule_order: Position of the molecule within the group. + chunk_id: Deterministic frame-chunk index. + frame_indices: Absolute analysis trajectory indices in this chunk. + frame_selection: FrameSelection covering this chunk. + """ + + group_id: int + molecule_id: Any + molecule_order: int + chunk_id: int + frame_indices: tuple[int, ...] + frame_selection: FrameSelection + + +@dataclass +class DihedralAngleObservable: + """Chunk-local dihedral angle arrays for one molecule/frame chunk. + + Attributes: + task: Source molecule/frame chunk task. + num_residues: Number of residues in the molecule. + ua_angles_by_residue: Positive-angle arrays by residue index. Each array + has shape ``(n_chunk_frames, n_dihedrals)``. + residue_angles: Positive-angle residue-level array with shape + ``(n_chunk_frames, n_residue_dihedrals)``, or ``None`` when the + residue level is disabled or has no dihedrals. + """ + + task: ConformationChunkTask + num_residues: int + ua_angles_by_residue: dict[int, np.ndarray] + residue_angles: np.ndarray | None + + +class DihedralAngleCollector(DihedralTopologyDiscovery): + """Collect dihedral angle observations from selected trajectory frames.""" + + def _build_conformation_chunk_tasks( + self, + topologies: list[MoleculeDihedralTopology], + frame_selection: FrameSelection, + chunk_size: int, + ) -> list[ConformationChunkTask]: + """Build deterministic molecule/frame chunk tasks for conformations. + + Args: + topologies: Per-molecule conformational topology entries. + frame_selection: Selected frames in active analysis-universe index + space. + chunk_size: Number of selected frames per chunk. + + Returns: + Conformation chunk tasks ordered by molecule order, then chunk id. + """ + frame_indices = tuple(int(i) for i in frame_selection.analysis_indices) + frame_chunks = chunk_frame_indices(list(frame_indices), int(chunk_size)) + tasks: list[ConformationChunkTask] = [] + + for topology in topologies: + for chunk_id, frame_chunk in enumerate(frame_chunks): + chunk_indices = tuple(int(index) for index in frame_chunk) + tasks.append( + ConformationChunkTask( + group_id=topology.group_id, + molecule_id=topology.molecule_id, + molecule_order=topology.molecule_order, + chunk_id=chunk_id, + frame_indices=chunk_indices, + frame_selection=self._frame_selection_from_chunk(chunk_indices), + ) + ) + + return tasks + + @staticmethod + def _frame_selection_from_chunk(frame_indices: tuple[int, ...]) -> FrameSelection: + """Build a FrameSelection for a selected frame chunk. + + Args: + frame_indices: Absolute trajectory frame indices in the chunk. + + Returns: + FrameSelection containing exactly the chunk frame indices. + + Raises: + ValueError: If the chunk is empty. + """ + if not frame_indices: + raise ValueError("Cannot build a frame selection from an empty chunk.") + + return FrameSelection(indices=tuple(int(index) for index in frame_indices)) + + def _collect_angle_observable( + self, + topology: MoleculeDihedralTopology, + task: ConformationChunkTask, + level_list: list[Any], + ) -> DihedralAngleObservable: + """Collect chunk-local positive-angle arrays for one molecule. + + Args: + topology: Static dihedral topology for the molecule. + task: Molecule/frame chunk task. + level_list: Enabled hierarchy levels. + + Returns: + Chunk-local angle observable used by both conformational reductions. + """ + number_frames = self._analysis_frame_count(task.frame_selection) + ua_angles_by_residue: dict[int, np.ndarray] = {} + residue_angles: np.ndarray | None = None + + if "united_atom" in level_list: + for res_id in range(topology.num_residues): + dihedrals = topology.ua_dihedrals_by_residue.get(res_id, []) + if not dihedrals: + ua_angles_by_residue[res_id] = np.empty( + (number_frames, 0), dtype=np.float64 + ) + continue + + dihedral_results = self._run_dihedrals( + dihedrals=dihedrals, + frame_selection=task.frame_selection, + ) + ua_angles_by_residue[res_id] = self._extract_positive_angle_array( + dihedral_results=dihedral_results, + num_dihedrals=len(dihedrals), + number_frames=number_frames, + ) + + if "residue" in level_list and topology.residue_dihedrals: + dihedral_results = self._run_dihedrals( + dihedrals=topology.residue_dihedrals, + frame_selection=task.frame_selection, + ) + residue_angles = self._extract_positive_angle_array( + dihedral_results=dihedral_results, + num_dihedrals=len(topology.residue_dihedrals), + number_frames=number_frames, + ) + + return DihedralAngleObservable( + task=task, + num_residues=topology.num_residues, + ua_angles_by_residue=ua_angles_by_residue, + residue_angles=residue_angles, + ) + + def _extract_positive_angle_array( + self, + dihedral_results: Any, + num_dihedrals: int, + number_frames: int, + ) -> np.ndarray: + """Extract a positive-angle NumPy array from MDAnalysis results. + + Args: + dihedral_results: Result of ``Dihedral(...).run(...)``. + num_dihedrals: Number of dihedrals in the result. + number_frames: Number of local result frames. + + Returns: + Positive-angle array with shape ``(number_frames, num_dihedrals)``. + """ + angles = np.asarray( + dihedral_results.results.angles[:number_frames, :num_dihedrals], + dtype=np.float64, + ) + return wrap_degrees_positive(angles) + + def _run_dihedrals( + self, dihedrals: list[Any], frame_selection: FrameSelection + ) -> Any: + """Run MDAnalysis dihedral analysis over selected analysis frames. + + Args: + dihedrals: Dihedral AtomGroups. + frame_selection: Selected trajectory frame selection. + + Returns: + MDAnalysis Dihedral analysis result. + + Notes: + ``Dihedral.run(start, stop, step)`` uses absolute active trajectory + frame bounds. The returned ``results.angles`` array is indexed + locally from zero. + """ + if not dihedrals: + raise ValueError("Cannot run Dihedral analysis with no dihedrals.") + + start, stop, step = self._analysis_run_bounds(frame_selection) + return Dihedral(dihedrals).run(start=start, stop=stop, step=step) + + @staticmethod + def _analysis_frame_count(frame_selection: FrameSelection) -> int: + """Return the number of selected frames. + + Args: + frame_selection: Selected trajectory frame selection. + + Returns: + Number of selected frames. + """ + return frame_selection.n_frames + + @staticmethod + def _analysis_run_bounds(frame_selection: FrameSelection) -> tuple[int, int, int]: + """Return MDAnalysis run bounds for selected analysis frames. + + Args: + frame_selection: Selected trajectory frame selection. + + Returns: + Tuple of ``(start, stop, step)`` in active analysis-universe index + space. + + Raises: + ValueError: If the selection is empty or irregularly spaced. + """ + start = frame_selection.source_start + stop = frame_selection.source_stop_exclusive + + if start is None or stop is None: + raise ValueError("Frame selection is empty.") + + return start, stop, frame_selection.infer_source_step() diff --git a/CodeEntropy/levels/dihedrals/conformational_state_builder.py b/CodeEntropy/levels/dihedrals/conformational_state_builder.py new file mode 100644 index 00000000..3cff0508 --- /dev/null +++ b/CodeEntropy/levels/dihedrals/conformational_state_builder.py @@ -0,0 +1,198 @@ +"""Conformational-state builder for dihedral analysis. + +This module builds the conformational state builder which is splits +domain-specific helpers for topology discovery, angle observation, +peak detection, and state assignment. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from rich.progress import TaskID + +from CodeEntropy.levels.dihedrals.peak_detection import ConformationPeakDetector +from CodeEntropy.levels.dihedrals.state_assignment import ( + ConformationStateAssigner, + UAKey, +) +from CodeEntropy.results.reporter import _RichProgressSink +from CodeEntropy.trajectory.frames import FrameSelection + +logger = logging.getLogger(__name__) + + +class ConformationStateBuilder(ConformationPeakDetector, ConformationStateAssigner): + """Build conformational state labels from selected-frame dihedral angles.""" + + def __init__(self, universe_operations: Any) -> None: + """Initialize the analysis helper. + + Args: + universe_operations: Object providing helper methods: + - extract_fragment(data_container, molecule_id) + - select_atoms(atomgroup, selection_string) + """ + self._universe_operations = universe_operations + + def build_conformational_states( + self, + data_container: Any, + levels: dict[Any, list[str]], + groups: dict[int, list[Any]], + bin_width: float, + frame_selection: FrameSelection, + progress: _RichProgressSink | None = None, + chunk_size: int | None = None, + ) -> tuple[dict[UAKey, list[str]], list[list[str]], dict[UAKey, int], list[int]]: + """Build conformational state labels from selected trajectory frames. + + Args: + data_container: MDAnalysis Universe or compatible container used to + extract fragments and compute dihedral time series. + levels: Mapping of molecule id to enabled level names. + groups: Mapping of group id to molecule ids. + bin_width: Histogram bin width in degrees used when identifying peak + dihedral populations. + frame_selection: FrameSelection controlling which absolute frames are + analysed. + progress: Optional progress sink. + chunk_size: Optional internal frame chunk size. When omitted, the + full selected-frame range is processed as a single chunk. + + Returns: + Tuple ``(states_ua, states_res, flexible_ua, flexible_res)``. + """ + if chunk_size is None: + chunk_size = max(1, int(frame_selection.n_frames)) + + return self._build_conformational_states_serial_chunked( + data_container=data_container, + levels=levels, + groups=groups, + bin_width=bin_width, + frame_selection=frame_selection, + chunk_size=chunk_size, + progress=progress, + ) + + def _build_conformational_states_serial_chunked( + self, + data_container: Any, + levels: dict[Any, list[str]], + groups: dict[int, list[Any]], + bin_width: float, + frame_selection: FrameSelection, + chunk_size: int, + progress: _RichProgressSink | None = None, + ) -> tuple[dict[UAKey, list[str]], list[list[str]], dict[UAKey, int], list[int]]: + """Build conformational states with serial frame-chunk map-reduce. + + Args: + data_container: MDAnalysis universe. + levels: Mapping of molecule id to enabled level names. + groups: Mapping of group id to molecule ids. + bin_width: Histogram bin width in degrees. + frame_selection: Selected absolute trajectory frames. + chunk_size: Number of selected frames per chunk. + progress: Optional progress sink. + + Returns: + Tuple ``(states_ua, states_res, flexible_ua, flexible_res)``. + + Raises: + ValueError: If ``chunk_size`` is less than one. + """ + if chunk_size < 1: + raise ValueError("chunk_size must be >= 1") + + number_groups = len(groups) + states_ua: dict[UAKey, list[str]] = {} + states_res: list[list[str]] = [[] for _ in range(number_groups)] + flexible_ua: dict[UAKey, int] = {} + flexible_res: list[int] = [] + + task: TaskID | None = None + if progress is not None: + total = max(1, len(groups)) + task = progress.add_task( + "[green]Conformational states", + total=total, + title="Initializing", + ) + + if not groups: + if progress is not None and task is not None: + progress.update(task, title="No groups") + progress.advance(task) + return states_ua, states_res, flexible_ua, flexible_res + + for group_id in groups.keys(): + molecules = groups[group_id] + if not molecules: + if progress is not None and task is not None: + progress.update(task, title=f"Group {group_id} (empty)") + progress.advance(task) + continue + + if progress is not None and task is not None: + progress.update(task, title=f"Group {group_id}") + + level_list = levels[molecules[0]] + topologies = self._discover_group_dihedral_topology( + data_container=data_container, + group_id=group_id, + molecules=molecules, + level_list=level_list, + ) + tasks = self._build_conformation_chunk_tasks( + topologies=topologies, + frame_selection=frame_selection, + chunk_size=chunk_size, + ) + topology_by_order = { + topology.molecule_order: topology for topology in topologies + } + + observables = [ + self._collect_angle_observable( + topology=topology_by_order[task_item.molecule_order], + task=task_item, + level_list=level_list, + ) + for task_item in tasks + ] + peak_data = self._reduce_angle_observables_to_peak_data( + observables=observables, + level_list=level_list, + bin_width=bin_width, + ) + state_partials = [ + self._assign_state_partial_from_observable( + observable=observable, + topology=topology_by_order[observable.task.molecule_order], + level_list=level_list, + peaks_ua=peak_data.peaks_ua, + peaks_res=peak_data.peaks_res, + ) + for observable in observables + ] + state_data = self._reduce_state_partials(state_partials) + self._merge_group_state_data( + state_data=state_data, + states_ua=states_ua, + states_res=states_res, + flexible_ua=flexible_ua, + flexible_res=flexible_res, + ) + + if progress is not None and task is not None: + progress.advance(task) + + logger.debug("States UA: %s", states_ua) + logger.debug("Number of flexible dihedrals UA: %s", flexible_ua) + logger.debug("States Res: %s", states_res) + logger.debug("Number of flexible dihedrals Res: %s", flexible_res) + + return states_ua, states_res, flexible_ua, flexible_res diff --git a/CodeEntropy/levels/dihedrals/kernels.py b/CodeEntropy/levels/dihedrals/kernels.py new file mode 100644 index 00000000..dc02094d --- /dev/null +++ b/CodeEntropy/levels/dihedrals/kernels.py @@ -0,0 +1,116 @@ +"""Numba kernels for dihedral conformational-state analysis. + +This module contains numeric kernels used by the serial chunked conformational +workflow. The kernels operate only on NumPy arrays and avoid MDAnalysis objects +so they remain safe to JIT compile and reuse inside future distributed workers. +""" + +from __future__ import annotations + +import numpy as np +from numba import njit + + +@njit(cache=True) +def wrap_degrees_positive(angles: np.ndarray) -> np.ndarray: + """Return dihedral angles wrapped into the positive degree range. + + Args: + angles: Angle array in degrees. The expected shape is + ``(n_frames, n_dihedrals)``. + + Returns: + Copy of ``angles`` with negative values shifted by 360 degrees. + """ + wrapped = angles.copy() + + for frame_i in range(wrapped.shape[0]): + for dihedral_i in range(wrapped.shape[1]): + if wrapped[frame_i, dihedral_i] < 0.0: + wrapped[frame_i, dihedral_i] += 360.0 + + return wrapped + + +@njit(cache=True) +def histogram_counts_by_dihedral( + angles: np.ndarray, + number_bins: int, +) -> np.ndarray: + """Build histogram counts for each dihedral angle series. + + Args: + angles: Positive-angle array with shape ``(n_frames, n_dihedrals)``. + number_bins: Number of histogram bins spanning 0 to 360 degrees. + + Returns: + Histogram counts with shape ``(n_dihedrals, number_bins)``. + """ + n_frames = angles.shape[0] + n_dihedrals = angles.shape[1] + counts = np.zeros((n_dihedrals, number_bins), dtype=np.int64) + bin_width = 360.0 / float(number_bins) + + for frame_i in range(n_frames): + for dihedral_i in range(n_dihedrals): + value = angles[frame_i, dihedral_i] + bin_i = int(value / bin_width) + + if bin_i < 0: + bin_i = 0 + elif bin_i >= number_bins: + bin_i = number_bins - 1 + + counts[dihedral_i, bin_i] += 1 + + return counts + + +@njit(cache=True) +def assign_peak_labels_and_count_flexible( + angles: np.ndarray, + padded_peaks: np.ndarray, + peak_counts: np.ndarray, +) -> tuple[np.ndarray, int]: + """Assign nearest-peak labels and count flexible dihedrals. + + Args: + angles: Positive-angle array with shape ``(n_frames, n_dihedrals)``. + padded_peaks: Peak values with shape ``(n_dihedrals, max_peaks)``. + peak_counts: Number of valid peaks for each dihedral. + + Returns: + Tuple containing an integer label array with shape + ``(n_frames, n_dihedrals)`` and the number of flexible dihedrals. + """ + n_frames = angles.shape[0] + n_dihedrals = angles.shape[1] + labels = np.zeros((n_frames, n_dihedrals), dtype=np.int64) + flexible_count = 0 + + for dihedral_i in range(n_dihedrals): + n_peaks = peak_counts[dihedral_i] + + if n_peaks < 1: + continue + + for frame_i in range(n_frames): + value = angles[frame_i, dihedral_i] + best_label = 0 + best_distance = abs(value - padded_peaks[dihedral_i, 0]) + + for peak_i in range(1, n_peaks): + distance = abs(value - padded_peaks[dihedral_i, peak_i]) + if distance < best_distance: + best_distance = distance + best_label = peak_i + + labels[frame_i, dihedral_i] = best_label + + first_label = labels[0, dihedral_i] + for frame_i in range(1, n_frames): + if labels[frame_i, dihedral_i] != first_label: + flexible_count += 1 + break + + return labels, flexible_count diff --git a/CodeEntropy/levels/dihedrals/peak_detection.py b/CodeEntropy/levels/dihedrals/peak_detection.py new file mode 100644 index 00000000..3efaa2e9 --- /dev/null +++ b/CodeEntropy/levels/dihedrals/peak_detection.py @@ -0,0 +1,270 @@ +"""Conformational peak detection from dihedral angle observations. + +This module contains histogram and peak-identification logic for converting +chunk-local selected-frame dihedral angle observations into global +conformational peak definitions. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, cast + +import numpy as np + +from CodeEntropy.levels.dihedrals.angle_observations import ( + DihedralAngleCollector, + DihedralAngleObservable, +) +from CodeEntropy.levels.dihedrals.kernels import ( + histogram_counts_by_dihedral, +) + +logger = logging.getLogger(__name__) + +HistogramValues = dict[int, np.ndarray] +HistogramContainer = dict[int, HistogramValues | list[Any]] + + +@dataclass +class DihedralPeakData: + """Histogram peak definitions used for conformational state assignment. + + Attributes: + peaks_ua: United-atom peak values by residue and dihedral index. + peaks_res: Residue-level peak values by dihedral index. + """ + + peaks_ua: list[list[Any]] + peaks_res: list[Any] + + +@dataclass +class DihedralHistogramData: + """Reduced histogram counts for one conformational group. + + Attributes: + num_residues: Number of residues in the representative molecule. + num_dihedrals_ua: Number of united-atom dihedrals by residue index. + num_dihedrals_res: Number of residue-level dihedrals. + hist_ua: United-atom histogram counts by residue and dihedral index. + hist_res: Residue-level histogram counts by dihedral index, or an empty + list when no residue-level histograms are present. + """ + + num_residues: int + num_dihedrals_ua: list[int] + num_dihedrals_res: int + hist_ua: HistogramContainer + hist_res: HistogramValues | list[Any] + + +class ConformationPeakDetector(DihedralAngleCollector): + """Identify conformational peak definitions from dihedral observations.""" + + def _reduce_angle_observables_to_peak_data( + self, + observables: list[DihedralAngleObservable], + level_list: list[Any], + bin_width: float, + ) -> DihedralPeakData: + """Reduce chunk-local angle observables into global peak definitions. + + Args: + observables: Chunk-local angle observables for one group. + level_list: Enabled hierarchy levels. + bin_width: Histogram bin width in degrees. + + Returns: + Global peak definitions for the group. + """ + histogram_data = self._reduce_angle_observables_to_histograms( + observables=observables, + level_list=level_list, + bin_width=bin_width, + ) + return self._build_peak_data_from_histograms( + histogram_data=histogram_data, + level_list=level_list, + bin_width=bin_width, + ) + + def _reduce_angle_observables_to_histograms( + self, + observables: list[DihedralAngleObservable], + level_list: list[Any], + bin_width: float, + ) -> DihedralHistogramData: + """Reduce chunk-local angle arrays into summed histogram counts. + + Args: + observables: Chunk-local angle observables for one group. + level_list: Enabled hierarchy levels. + bin_width: Histogram bin width in degrees. + + Returns: + Reduced histogram counts for the group. + """ + if not observables: + return DihedralHistogramData( + num_residues=0, + num_dihedrals_ua=[], + num_dihedrals_res=0, + hist_ua={}, + hist_res=[], + ) + + ordered_observables = sorted( + observables, + key=lambda observable: ( + observable.task.molecule_order, + observable.task.chunk_id, + ), + ) + number_bins = int(360 / bin_width) + first = ordered_observables[0] + num_residues = first.num_residues + num_dihedrals_ua = [0 for _ in range(num_residues)] + hist_ua: HistogramContainer = {} + hist_res: HistogramValues | list[Any] = {} + num_dihedrals_res = 0 + + if "united_atom" in level_list: + for res_id in range(num_residues): + for observable in ordered_observables: + angles = observable.ua_angles_by_residue.get(res_id) + if angles is None or angles.shape[1] == 0: + hist_ua.setdefault(res_id, []) + continue + + num_dihedrals_ua[res_id] = angles.shape[1] + counts = histogram_counts_by_dihedral(angles, number_bins) + + if res_id not in hist_ua or isinstance(hist_ua[res_id], list): + hist_ua[res_id] = {} + + target = cast(HistogramValues, hist_ua[res_id]) + for dihedral_index in range(counts.shape[0]): + if dihedral_index not in target: + target[dihedral_index] = counts[dihedral_index].copy() + else: + target[dihedral_index] = ( + target[dihedral_index] + counts[dihedral_index] + ) + + if "residue" in level_list: + for observable in ordered_observables: + if observable.residue_angles is None: + continue + + angles = observable.residue_angles + if angles.shape[1] == 0: + continue + + num_dihedrals_res = angles.shape[1] + counts = histogram_counts_by_dihedral(angles, number_bins) + + if isinstance(hist_res, list): + hist_res = {} + + target_res = cast(HistogramValues, hist_res) + for dihedral_index in range(counts.shape[0]): + if dihedral_index not in target_res: + target_res[dihedral_index] = counts[dihedral_index].copy() + else: + target_res[dihedral_index] = ( + target_res[dihedral_index] + counts[dihedral_index] + ) + + return DihedralHistogramData( + num_residues=num_residues, + num_dihedrals_ua=num_dihedrals_ua, + num_dihedrals_res=num_dihedrals_res, + hist_ua=hist_ua, + hist_res=hist_res, + ) + + def _build_peak_data_from_histograms( + self, + histogram_data: DihedralHistogramData, + level_list: list[Any], + bin_width: float, + ) -> DihedralPeakData: + """Build peak definitions from reduced histogram counts. + + Args: + histogram_data: Reduced histogram counts for one group. + level_list: Enabled hierarchy levels. + bin_width: Histogram bin width in degrees. + + Returns: + Peak definitions for united-atom and residue-level states. + """ + peaks_ua: list[list[Any]] = [[] for _ in range(histogram_data.num_residues)] + peaks_res: list[Any] = [] + number_bins = int(360 / bin_width) + bin_edges = np.linspace(0.0, 360.0, number_bins + 1) + bin_value = [ + 0.5 * (bin_edges[i] + bin_edges[i + 1]) for i in range(number_bins) + ] + + if "united_atom" in level_list: + for res_id in range(histogram_data.num_residues): + hist_values = histogram_data.hist_ua.get(res_id) + if not hist_values: + peaks_ua[res_id] = [] + continue + + hist_values = cast(HistogramValues, hist_values) + residue_peaks = [] + for dihedral_index in range(histogram_data.num_dihedrals_ua[res_id]): + counts = hist_values[dihedral_index] + residue_peaks.append( + self._find_histogram_peaks( + popul=counts, + bin_value=bin_value, + ) + ) + peaks_ua[res_id] = residue_peaks + + if "residue" in level_list and histogram_data.hist_res: + hist_res = cast(HistogramValues, histogram_data.hist_res) + for dihedral_index in range(histogram_data.num_dihedrals_res): + counts = hist_res[dihedral_index] + peaks_res.append( + self._find_histogram_peaks( + popul=counts, + bin_value=bin_value, + ) + ) + + return DihedralPeakData(peaks_ua=peaks_ua, peaks_res=peaks_res) + + @staticmethod + def _find_histogram_peaks( + popul: np.ndarray[Any, Any], bin_value: list[float] + ) -> list[float]: + """Return convex turning-point peaks from a histogram. + + Args: + popul: Histogram bin populations. + bin_value: Histogram bin centre values. + + Returns: + List of peak positions. + """ + number_bins = len(popul) + peaks: list[float] = [] + + for bin_index in range(number_bins): + if popul[bin_index] == 0: + continue + + left = popul[bin_index - 1] + right = popul[0] if bin_index == number_bins - 1 else popul[bin_index + 1] + + if popul[bin_index] >= left and popul[bin_index] > right: + peaks.append(bin_value[bin_index]) + + return peaks diff --git a/CodeEntropy/levels/dihedrals/state_assignment.py b/CodeEntropy/levels/dihedrals/state_assignment.py new file mode 100644 index 00000000..d2a1bd6d --- /dev/null +++ b/CodeEntropy/levels/dihedrals/state_assignment.py @@ -0,0 +1,281 @@ +"""Conformational state assignment from dihedral peak definitions. + +This module contains the logic for converting positive-angle dihedral arrays and +global peak definitions into state labels and flexible-dihedral counts. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import numpy as np + +from CodeEntropy.levels.dihedrals.angle_observations import ( + ConformationChunkTask, + DihedralAngleObservable, +) +from CodeEntropy.levels.dihedrals.kernels import ( + assign_peak_labels_and_count_flexible, +) +from CodeEntropy.levels.dihedrals.topology import MoleculeDihedralTopology + +UAKey = tuple[int, int] + + +@dataclass +class ConformationStateData: + """Serial conformational state data calculated for one molecule group. + + Attributes: + state_res: Residue-level state labels for the group. + flex_res: Number of flexible residue-level dihedrals for the group. + states_ua_updates: United-atom state-label updates by ``(group, residue)``. + flexible_ua_updates: United-atom flexible-dihedral updates by + ``(group, residue)``. + """ + + state_res: list[str] + flex_res: int + states_ua_updates: dict[UAKey, list[str]] + flexible_ua_updates: dict[UAKey, int] + + +@dataclass +class ConformationStatePartial: + """Chunk-local conformational state labels and flexible counts. + + Attributes: + task: Source molecule/frame chunk task. + state_res: Residue-level state labels for this chunk. + flex_res: Number of flexible residue-level dihedrals for this chunk. + states_ua_updates: United-atom state-label updates by ``(group, residue)``. + flexible_ua_updates: United-atom flexible-dihedral updates by + ``(group, residue)``. + """ + + task: ConformationChunkTask + state_res: list[str] + flex_res: int + states_ua_updates: dict[UAKey, list[str]] + flexible_ua_updates: dict[UAKey, int] + + +class ConformationStateAssigner: + """Assign conformational state labels from global dihedral peak definitions.""" + + def _assign_state_partial_from_observable( + self, + observable: DihedralAngleObservable, + topology: MoleculeDihedralTopology, + level_list: list[Any], + peaks_ua: list[list[Any]], + peaks_res: list[Any], + ) -> ConformationStatePartial: + """Assign chunk-local states from cached angle arrays and global peaks. + + Args: + observable: Chunk-local angle observable. + topology: Static topology for the observable molecule. + level_list: Enabled hierarchy levels. + peaks_ua: Global united-atom peaks by residue. + peaks_res: Global residue-level peaks. + + Returns: + Chunk-local state partial. + """ + state_res: list[str] = [] + flex_res = 0 + states_ua_updates: dict[UAKey, list[str]] = {} + flexible_ua_updates: dict[UAKey, int] = {} + + if "united_atom" in level_list: + for res_id in range(topology.num_residues): + key = (topology.group_id, res_id) + angles = observable.ua_angles_by_residue.get(res_id) + + if angles is None or angles.shape[1] == 0: + states_ua_updates[key] = [] + flexible_ua_updates[key] = 0 + continue + + states, flexible = self._process_conformations_from_angles( + peaks=peaks_ua[res_id], + angles=angles, + ) + states_ua_updates[key] = states + flexible_ua_updates[key] = flexible + + if "residue" in level_list and observable.residue_angles is not None: + if observable.residue_angles.shape[1] > 0: + state_res, flex_res = self._process_conformations_from_angles( + peaks=peaks_res, + angles=observable.residue_angles, + ) + + return ConformationStatePartial( + task=observable.task, + state_res=state_res, + flex_res=flex_res, + states_ua_updates=states_ua_updates, + flexible_ua_updates=flexible_ua_updates, + ) + + def _reduce_state_partials( + self, + partials: list[ConformationStatePartial], + ) -> ConformationStateData: + """Merge chunk-local state partials into one group-level result. + + Args: + partials: Chunk-local state partials for one group. + + Returns: + Group-level state data using deterministic molecule/chunk ordering. + """ + ordered_partials = sorted( + partials, + key=lambda partial: ( + partial.task.molecule_order, + partial.task.chunk_id, + ), + ) + + state_res: list[str] = [] + flex_res = 0 + states_ua_updates: dict[UAKey, list[str]] = {} + flexible_ua_updates: dict[UAKey, int] = {} + + for partial in ordered_partials: + for key, states in partial.states_ua_updates.items(): + if key not in states_ua_updates: + states_ua_updates[key] = list(states) + flexible_ua_updates[key] = partial.flexible_ua_updates[key] + else: + states_ua_updates[key].extend(states) + flexible_ua_updates[key] = max( + flexible_ua_updates[key], + partial.flexible_ua_updates[key], + ) + + state_res.extend(partial.state_res) + flex_res = max(flex_res, partial.flex_res) + + return ConformationStateData( + state_res=state_res, + flex_res=flex_res, + states_ua_updates=states_ua_updates, + flexible_ua_updates=flexible_ua_updates, + ) + + @staticmethod + def _merge_group_state_data( + state_data: ConformationStateData, + states_ua: dict[UAKey, list[str]], + states_res: list[list[str]], + flexible_ua: dict[UAKey, int], + flexible_res: list[int], + ) -> None: + """Merge one group's state data into final output accumulators. + + Args: + state_data: Serial conformational state data for one group. + states_ua: UA state accumulator to mutate. + states_res: Residue state accumulator to mutate. + flexible_ua: UA flexible-dihedral accumulator to mutate. + flexible_res: Residue flexible-dihedral accumulator to mutate. + + Returns: + None. Mutates the provided accumulators. + """ + for key, states in state_data.states_ua_updates.items(): + if key not in states_ua: + states_ua[key] = states + flexible_ua[key] = state_data.flexible_ua_updates[key] + else: + states_ua[key].extend(states) + flexible_ua[key] = max( + flexible_ua[key], + state_data.flexible_ua_updates[key], + ) + + states_res.append(state_data.state_res) + flexible_res.append(state_data.flex_res) + + def _process_conformations_from_angles( + self, + peaks: list[Any], + angles: np.ndarray, + ) -> tuple[list[str], int]: + """Assign conformational states from a positive-angle NumPy array. + + Args: + peaks: Histogram peaks by dihedral. + angles: Positive-angle array with shape ``(n_frames, n_dihedrals)``. + + Returns: + Tuple of ``(states, num_flexible)``. + """ + if angles.size == 0 or angles.shape[1] == 0: + return [], 0 + + padded_peaks, peak_counts = self._pad_peak_values(peaks) + labels, num_flexible = assign_peak_labels_and_count_flexible( + angles, + padded_peaks, + peak_counts, + ) + states = self._state_strings_from_labels(labels) + return states, int(num_flexible) + + @staticmethod + def _pad_peak_values(peaks: list[Any]) -> tuple[np.ndarray, np.ndarray]: + """Convert ragged peak lists into padded arrays for kernels. + + Args: + peaks: Peak values by dihedral. + + Returns: + Tuple of ``(padded_peaks, peak_counts)``. + """ + if not peaks: + return ( + np.zeros((0, 1), dtype=np.float64), + np.zeros(0, dtype=np.int64), + ) + + max_peaks = max((len(dihedral_peaks) for dihedral_peaks in peaks), default=0) + max_peaks = max(1, max_peaks) + padded = np.zeros((len(peaks), max_peaks), dtype=np.float64) + counts = np.zeros(len(peaks), dtype=np.int64) + + for dihedral_index, dihedral_peaks in enumerate(peaks): + counts[dihedral_index] = len(dihedral_peaks) + for peak_index, peak in enumerate(dihedral_peaks): + padded[dihedral_index, peak_index] = float(peak) + + return padded, counts + + @staticmethod + def _state_strings_from_labels(labels: np.ndarray) -> list[str]: + """Convert integer per-frame labels into legacy state strings. + + Args: + labels: Integer labels with shape ``(n_frames, n_dihedrals)``. + + Returns: + Legacy state strings, one per frame. + """ + states: list[str] = [] + number_frames = labels.shape[0] + num_dihedrals = labels.shape[1] + + for frame_index in range(number_frames): + state = "".join( + str(int(labels[frame_index, dihedral_index])) + for dihedral_index in range(num_dihedrals) + ) + if state: + states.append(state) + + return states diff --git a/CodeEntropy/levels/dihedrals/topology.py b/CodeEntropy/levels/dihedrals/topology.py new file mode 100644 index 00000000..680639a1 --- /dev/null +++ b/CodeEntropy/levels/dihedrals/topology.py @@ -0,0 +1,146 @@ +"""Dihedral topology discovery for conformational state analysis. + +This module contains the static molecule/residue dihedral discovery logic used +by conformational entropy calculations. The methods here identify which +dihedrals should be analysed; they do not inspect trajectory frames. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class MoleculeDihedralTopology: + """Static conformational dihedral topology for one molecule. + + Attributes: + group_id: Molecule group id. + molecule_id: Molecule id. + molecule_order: Position of the molecule within its group. + num_residues: Number of residues in the molecule. + ua_dihedrals_by_residue: United-atom dihedrals by residue index. + residue_dihedrals: Residue-level dihedrals for the molecule. + """ + + group_id: int + molecule_id: Any + molecule_order: int + num_residues: int + ua_dihedrals_by_residue: dict[int, list[Any]] + residue_dihedrals: list[Any] + + +class DihedralTopologyDiscovery: + """Discover molecule-level dihedral definitions for conformational analysis.""" + + def _discover_group_dihedral_topology( + self, + data_container: Any, + group_id: int, + molecules: list[Any], + level_list: list[Any], + ) -> list[MoleculeDihedralTopology]: + """Discover static conformational topology for a molecule group. + + Args: + data_container: MDAnalysis universe. + group_id: Molecule group id. + molecules: Molecule ids in the group. + level_list: Enabled hierarchy levels. + + Returns: + Static per-molecule dihedral topology used by both chunked passes. + """ + topologies: list[MoleculeDihedralTopology] = [] + + for molecule_order, molecule_id in enumerate(molecules): + mol = self._universe_operations.extract_fragment( + data_container, molecule_id + ) + num_residues = len(mol.residues) + ua_dihedrals_by_residue: dict[int, list[Any]] = {} + residue_dihedrals: list[Any] = [] + + if "united_atom" in level_list: + for res_id in range(num_residues): + heavy_res = self._select_heavy_residue(mol, res_id) + ua_dihedrals_by_residue[res_id] = self._get_dihedrals( + heavy_res, + "united_atom", + ) + + if "residue" in level_list: + residue_dihedrals = self._get_dihedrals(mol, "residue") + + topologies.append( + MoleculeDihedralTopology( + group_id=group_id, + molecule_id=molecule_id, + molecule_order=molecule_order, + num_residues=num_residues, + ua_dihedrals_by_residue=ua_dihedrals_by_residue, + residue_dihedrals=residue_dihedrals, + ) + ) + + return topologies + + def _select_heavy_residue(self, mol: Any, res_id: int) -> Any: + """Select heavy atoms in a residue by residue index. + + Args: + mol: Representative molecule AtomGroup. + res_id: Local residue index. + + Returns: + AtomGroup containing heavy atoms in the residue selection. + """ + selection1 = mol.residues[res_id].atoms.indices[0] + selection2 = mol.residues[res_id].atoms.indices[-1] + + res_container = self._universe_operations.select_atoms( + mol, f"index {selection1}:{selection2}" + ) + return self._universe_operations.select_atoms(res_container, "prop mass > 1.1") + + def _get_dihedrals(self, data_container: Any, level: str) -> list[Any]: + """Return dihedral AtomGroups for a container at a given level. + + Args: + data_container: MDAnalysis container. + level: Either ``"united_atom"`` or ``"residue"``. + + Returns: + List of AtomGroups, each representing a dihedral definition. + """ + atom_groups: list[Any] = [] + + if level == "united_atom": + for dihedral in data_container.dihedrals: + atom_groups.append(dihedral.atoms) + + if level == "residue": + num_residues = len(data_container.residues) + if num_residues >= 4: + for residue in range(4, num_residues + 1): + atom1 = data_container.select_atoms( + f"resindex {residue - 4} and bonded resindex {residue - 3}" + ) + atom2 = data_container.select_atoms( + f"resindex {residue - 3} and bonded resindex {residue - 4}" + ) + atom3 = data_container.select_atoms( + f"resindex {residue - 2} and bonded resindex {residue - 1}" + ) + atom4 = data_container.select_atoms( + f"resindex {residue - 1} and bonded resindex {residue - 2}" + ) + atom_groups.append(atom1 + atom2 + atom3 + atom4) + + logger.debug("Level: %s, Dihedrals: %s", level, atom_groups) + return atom_groups From a9e3b69e2d492df755fa1f701d49afc2a4183959 Mon Sep 17 00:00:00 2001 From: harryswift01 Date: Thu, 18 Jun 2026 09:33:33 +0100 Subject: [PATCH 2/3] tests(unit): add test cases for dihedral chunking module --- .../levels/dihedrals/angle_observations.py | 16 +- .../dihedrals/conformational_state_builder.py | 29 +- .../levels/dihedrals/peak_detection.py | 2 +- .../dihedrals/test_angle_observations.py | 180 ++++ .../test_conformational_state_builder.py | 247 +++++ .../levels/dihedrals/test_kernels.py | 106 ++ .../levels/dihedrals/test_peak_detection.py | 259 +++++ .../levels/dihedrals/test_state_assignment.py | 200 ++++ .../levels/dihedrals/test_topology.py | 146 +++ .../levels/test_conformation_dag.py | 199 ++-- .../unit/CodeEntropy/levels/test_dihedrals.py | 952 ------------------ 11 files changed, 1269 insertions(+), 1067 deletions(-) create mode 100644 tests/unit/CodeEntropy/levels/dihedrals/test_angle_observations.py create mode 100644 tests/unit/CodeEntropy/levels/dihedrals/test_conformational_state_builder.py create mode 100644 tests/unit/CodeEntropy/levels/dihedrals/test_kernels.py create mode 100644 tests/unit/CodeEntropy/levels/dihedrals/test_peak_detection.py create mode 100644 tests/unit/CodeEntropy/levels/dihedrals/test_state_assignment.py create mode 100644 tests/unit/CodeEntropy/levels/dihedrals/test_topology.py delete mode 100644 tests/unit/CodeEntropy/levels/test_dihedrals.py diff --git a/CodeEntropy/levels/dihedrals/angle_observations.py b/CodeEntropy/levels/dihedrals/angle_observations.py index 1ad1e606..95e2f6f5 100644 --- a/CodeEntropy/levels/dihedrals/angle_observations.py +++ b/CodeEntropy/levels/dihedrals/angle_observations.py @@ -115,12 +115,24 @@ def _frame_selection_from_chunk(frame_indices: tuple[int, ...]) -> FrameSelectio FrameSelection containing exactly the chunk frame indices. Raises: - ValueError: If the chunk is empty. + ValueError: If the chunk is empty, not strictly increasing, or not + regularly strided. """ if not frame_indices: raise ValueError("Cannot build a frame selection from an empty chunk.") - return FrameSelection(indices=tuple(int(index) for index in frame_indices)) + indices = tuple(int(index) for index in frame_indices) + + if len(indices) > 1: + step = indices[1] - indices[0] + if step <= 0: + raise ValueError("Frame chunk indices must be strictly increasing.") + + for previous, current in zip(indices, indices[1:], strict=False): + if current - previous != step: + raise ValueError("Frame chunk indices must be regularly strided.") + + return FrameSelection(indices=indices) def _collect_angle_observable( self, diff --git a/CodeEntropy/levels/dihedrals/conformational_state_builder.py b/CodeEntropy/levels/dihedrals/conformational_state_builder.py index 3cff0508..ee5268b7 100644 --- a/CodeEntropy/levels/dihedrals/conformational_state_builder.py +++ b/CodeEntropy/levels/dihedrals/conformational_state_builder.py @@ -1,8 +1,9 @@ -"""Conformational-state builder for dihedral analysis. +"""Public conformational-state builder for dihedral analysis. -This module builds the conformational state builder which is splits -domain-specific helpers for topology discovery, angle observation, -peak detection, and state assignment. +This module keeps the stable ``ConformationStateBuilder`` entry point used by +``ConformationDAG`` while the implementation is split across domain-specific +helpers for topology discovery, angle observation, peak detection, and state +assignment. """ from __future__ import annotations @@ -27,7 +28,7 @@ class ConformationStateBuilder(ConformationPeakDetector, ConformationStateAssign """Build conformational state labels from selected-frame dihedral angles.""" def __init__(self, universe_operations: Any) -> None: - """Initialize the analysis helper. + """Initialise the analysis helper. Args: universe_operations: Object providing helper methods: @@ -107,18 +108,16 @@ def _build_conformational_states_serial_chunked( if chunk_size < 1: raise ValueError("chunk_size must be >= 1") - number_groups = len(groups) states_ua: dict[UAKey, list[str]] = {} - states_res: list[list[str]] = [[] for _ in range(number_groups)] + states_res: list[list[str]] = [] flexible_ua: dict[UAKey, int] = {} flexible_res: list[int] = [] task: TaskID | None = None if progress is not None: - total = max(1, len(groups)) task = progress.add_task( "[green]Conformational states", - total=total, + total=max(1, len(groups)), title="Initializing", ) @@ -126,20 +125,24 @@ def _build_conformational_states_serial_chunked( if progress is not None and task is not None: progress.update(task, title="No groups") progress.advance(task) + return states_ua, states_res, flexible_ua, flexible_res - for group_id in groups.keys(): - molecules = groups[group_id] + for group_id, molecules in groups.items(): if not molecules: + states_res.append([]) + if progress is not None and task is not None: progress.update(task, title=f"Group {group_id} (empty)") progress.advance(task) + continue if progress is not None and task is not None: progress.update(task, title=f"Group {group_id}") level_list = levels[molecules[0]] + topologies = self._discover_group_dihedral_topology( data_container=data_container, group_id=group_id, @@ -163,11 +166,13 @@ def _build_conformational_states_serial_chunked( ) for task_item in tasks ] + peak_data = self._reduce_angle_observables_to_peak_data( observables=observables, level_list=level_list, bin_width=bin_width, ) + state_partials = [ self._assign_state_partial_from_observable( observable=observable, @@ -178,7 +183,9 @@ def _build_conformational_states_serial_chunked( ) for observable in observables ] + state_data = self._reduce_state_partials(state_partials) + self._merge_group_state_data( state_data=state_data, states_ua=states_ua, diff --git a/CodeEntropy/levels/dihedrals/peak_detection.py b/CodeEntropy/levels/dihedrals/peak_detection.py index 3efaa2e9..e87ac78a 100644 --- a/CodeEntropy/levels/dihedrals/peak_detection.py +++ b/CodeEntropy/levels/dihedrals/peak_detection.py @@ -127,7 +127,7 @@ def _reduce_angle_observables_to_histograms( num_residues = first.num_residues num_dihedrals_ua = [0 for _ in range(num_residues)] hist_ua: HistogramContainer = {} - hist_res: HistogramValues | list[Any] = {} + hist_res: HistogramValues | list[Any] = [] num_dihedrals_res = 0 if "united_atom" in level_list: diff --git a/tests/unit/CodeEntropy/levels/dihedrals/test_angle_observations.py b/tests/unit/CodeEntropy/levels/dihedrals/test_angle_observations.py new file mode 100644 index 00000000..f1eb3093 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/dihedrals/test_angle_observations.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from CodeEntropy.levels.dihedrals.angle_observations import ( + ConformationChunkTask, + DihedralAngleCollector, +) +from CodeEntropy.levels.dihedrals.topology import MoleculeDihedralTopology +from CodeEntropy.trajectory.frames import FrameSelection + + +class _AngleCollector(DihedralAngleCollector): + """Concrete angle collector for unit tests.""" + + def __init__(self) -> None: + """Initialize the test collector.""" + self._universe_operations = MagicMock() + + +def _make_frame_selection(*indices: int) -> FrameSelection: + """Build a FrameSelection from explicit indices. + + Args: + *indices: Absolute frame indices. + + Returns: + FrameSelection containing the requested indices. + """ + return FrameSelection(indices=tuple(indices)) + + +def _make_topology() -> MoleculeDihedralTopology: + """Build a small molecule topology used by angle-observation tests. + + Returns: + MoleculeDihedralTopology with one UA dihedral and one residue dihedral. + """ + return MoleculeDihedralTopology( + group_id=0, + molecule_id=7, + molecule_order=0, + num_residues=2, + ua_dihedrals_by_residue={0: ["ua0"], 1: []}, + residue_dihedrals=["res0"], + ) + + +def test_frame_selection_from_chunk_preserves_absolute_indices(): + frame_selection = _AngleCollector._frame_selection_from_chunk((10, 20, 30)) + + assert frame_selection.indices == (10, 20, 30) + assert frame_selection.analysis_indices == (10, 20, 30) + + +def test_frame_selection_from_single_frame_chunk_preserves_absolute_index(): + frame_selection = _AngleCollector._frame_selection_from_chunk((42,)) + + assert frame_selection.indices == (42,) + + +def test_frame_selection_from_chunk_rejects_invalid_chunks(): + with pytest.raises(ValueError, match="empty chunk"): + _AngleCollector._frame_selection_from_chunk(()) + + with pytest.raises(ValueError, match="strictly increasing"): + _AngleCollector._frame_selection_from_chunk((2, 1)) + + with pytest.raises(ValueError, match="regularly strided"): + _AngleCollector._frame_selection_from_chunk((0, 2, 5)) + + +def test_build_conformation_chunk_tasks_orders_by_molecule_then_chunk(): + collector = _AngleCollector() + frame_selection = _make_frame_selection(10, 20, 30) + topologies = [ + MoleculeDihedralTopology(0, "mol-a", 0, 1, {}, []), + MoleculeDihedralTopology(0, "mol-b", 1, 1, {}, []), + ] + + tasks = collector._build_conformation_chunk_tasks( + topologies=topologies, + frame_selection=frame_selection, + chunk_size=2, + ) + + assert [ + (task.molecule_id, task.chunk_id, task.frame_indices) for task in tasks + ] == [ + ("mol-a", 0, (10, 20)), + ("mol-a", 1, (30,)), + ("mol-b", 0, (10, 20)), + ("mol-b", 1, (30,)), + ] + + +def test_extract_positive_angle_array_wraps_negative_values(): + collector = _AngleCollector() + dihedral_results = SimpleNamespace( + results=SimpleNamespace( + angles=np.array([[-10.0, 20.0], [30.0, -40.0]], dtype=float) + ) + ) + + angles = collector._extract_positive_angle_array( + dihedral_results=dihedral_results, + num_dihedrals=2, + number_frames=2, + ) + + np.testing.assert_allclose(angles, np.array([[350.0, 20.0], [30.0, 320.0]])) + + +def test_collect_angle_observable_collects_ua_and_residue_arrays(): + collector = _AngleCollector() + task = ConformationChunkTask( + group_id=0, + molecule_id=7, + molecule_order=0, + chunk_id=0, + frame_indices=(0, 1), + frame_selection=FrameSelection.from_bounds(0, 2, 1), + ) + topology = _make_topology() + + ua_results = SimpleNamespace( + results=SimpleNamespace(angles=np.array([[-10.0], [10.0]], dtype=float)) + ) + residue_results = SimpleNamespace( + results=SimpleNamespace(angles=np.array([[30.0], [-40.0]], dtype=float)) + ) + collector._run_dihedrals = MagicMock(side_effect=[ua_results, residue_results]) + + observable = collector._collect_angle_observable( + topology=topology, + task=task, + level_list=["united_atom", "residue"], + ) + + np.testing.assert_allclose(observable.ua_angles_by_residue[0], [[350.0], [10.0]]) + assert observable.ua_angles_by_residue[1].shape == (2, 0) + np.testing.assert_allclose(observable.residue_angles, [[30.0], [320.0]]) + assert collector._run_dihedrals.call_count == 2 + + +def test_run_dihedrals_uses_frame_selection_bounds(): + collector = _AngleCollector() + frame_selection = FrameSelection.from_bounds(10, 40, 10) + fake_runner = MagicMock() + fake_runner.run.return_value = "result" + + with patch("CodeEntropy.levels.dihedrals.angle_observations.Dihedral") as fake_cls: + fake_cls.return_value = fake_runner + out = collector._run_dihedrals( + dihedrals=["D0"], + frame_selection=frame_selection, + ) + + assert out == "result" + fake_cls.assert_called_once_with(["D0"]) + fake_runner.run.assert_called_once_with(start=10, stop=31, step=10) + + +def test_run_dihedrals_raises_when_no_dihedrals(): + collector = _AngleCollector() + + with pytest.raises(ValueError, match="no dihedrals"): + collector._run_dihedrals( + dihedrals=[], + frame_selection=FrameSelection.from_bounds(0, 1, 1), + ) + + +def test_analysis_run_bounds_raises_when_frame_selection_empty(): + with pytest.raises(ValueError, match="Frame selection is empty"): + _AngleCollector._analysis_run_bounds(FrameSelection(indices=())) diff --git a/tests/unit/CodeEntropy/levels/dihedrals/test_conformational_state_builder.py b/tests/unit/CodeEntropy/levels/dihedrals/test_conformational_state_builder.py new file mode 100644 index 00000000..390049d9 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/dihedrals/test_conformational_state_builder.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +from CodeEntropy.levels.dihedrals.angle_observations import ( + ConformationChunkTask, + DihedralAngleObservable, +) +from CodeEntropy.levels.dihedrals.conformational_state_builder import ( + ConformationStateBuilder, +) +from CodeEntropy.levels.dihedrals.peak_detection import DihedralPeakData +from CodeEntropy.levels.dihedrals.state_assignment import ( + ConformationStateData, + ConformationStatePartial, +) +from CodeEntropy.levels.dihedrals.topology import MoleculeDihedralTopology +from CodeEntropy.trajectory.frames import FrameSelection + + +def _make_frame_selection( + start: int = 0, + stop: int = 2, + step: int = 1, +) -> FrameSelection: + """Build a FrameSelection for builder tests. + + Args: + start: Inclusive source-frame start. + stop: Exclusive source-frame stop. + step: Source-frame stride. + + Returns: + FrameSelection covering the requested bounds. + """ + return FrameSelection.from_bounds(start=start, stop=stop, step=step) + + +def test_build_conformational_states_defaults_chunk_size_to_selected_frame_count(): + builder = ConformationStateBuilder(universe_operations=MagicMock()) + builder._build_conformational_states_serial_chunked = MagicMock( + return_value=("states_ua", "states_res", "flex_ua", "flex_res") + ) + frame_selection = _make_frame_selection(start=0, stop=3, step=1) + + out = builder.build_conformational_states( + data_container="universe", + levels={7: ["residue"]}, + groups={0: [7]}, + bin_width=30.0, + frame_selection=frame_selection, + ) + + assert out == ("states_ua", "states_res", "flex_ua", "flex_res") + builder._build_conformational_states_serial_chunked.assert_called_once_with( + data_container="universe", + levels={7: ["residue"]}, + groups={0: [7]}, + bin_width=30.0, + frame_selection=frame_selection, + chunk_size=3, + progress=None, + ) + + +def test_build_conformational_states_passes_explicit_chunk_size(): + builder = ConformationStateBuilder(universe_operations=MagicMock()) + builder._build_conformational_states_serial_chunked = MagicMock( + return_value=({}, [], {}, []) + ) + frame_selection = _make_frame_selection(start=0, stop=3, step=1) + + builder.build_conformational_states( + data_container="universe", + levels={7: ["residue"]}, + groups={0: [7]}, + bin_width=30.0, + frame_selection=frame_selection, + chunk_size=2, + ) + + assert ( + builder._build_conformational_states_serial_chunked.call_args.kwargs[ + "chunk_size" + ] + == 2 + ) + + +def test_chunked_serial_rejects_invalid_chunk_size(): + builder = ConformationStateBuilder(universe_operations=MagicMock()) + + try: + builder._build_conformational_states_serial_chunked( + data_container="universe", + levels={}, + groups={}, + bin_width=30.0, + frame_selection=_make_frame_selection(start=0, stop=1, step=1), + chunk_size=0, + ) + except ValueError as exc: + assert "chunk_size must be >= 1" in str(exc) + else: + raise AssertionError("Expected invalid chunk size to raise ValueError") + + +def test_build_conformational_states_with_progress_handles_no_groups(): + builder = ConformationStateBuilder(universe_operations=MagicMock()) + progress = MagicMock() + progress.add_task.return_value = 123 + + states_ua, states_res, flex_ua, flex_res = builder.build_conformational_states( + data_container=MagicMock(), + levels={}, + groups={}, + bin_width=30.0, + frame_selection=_make_frame_selection(start=0, stop=1, step=1), + progress=progress, + ) + + assert states_ua == {} + assert states_res == [] + assert flex_ua == {} + assert flex_res == [] + progress.add_task.assert_called_once() + progress.update.assert_called_once_with(123, title="No groups") + progress.advance.assert_called_once_with(123) + + +def test_build_conformational_states_with_progress_skips_empty_molecule_group(): + builder = ConformationStateBuilder(universe_operations=MagicMock()) + progress = MagicMock() + progress.add_task.return_value = 5 + + states_ua, states_res, flex_ua, flex_res = builder.build_conformational_states( + data_container=MagicMock(), + levels={}, + groups={0: []}, + bin_width=30.0, + frame_selection=_make_frame_selection(start=0, stop=1, step=1), + progress=progress, + ) + + assert states_ua == {} + assert states_res == [[]] + assert flex_ua == {} + assert flex_res == [] + progress.update.assert_called_with(5, title="Group 0 (empty)") + progress.advance.assert_called_with(5) + + +def test_chunked_serial_group_flow_calls_domain_phases_in_order(): + builder = ConformationStateBuilder(universe_operations=MagicMock()) + frame_selection = _make_frame_selection(start=0, stop=2, step=1) + topology = MoleculeDihedralTopology(0, 7, 0, 1, {0: ["ua"]}, ["res"]) + task = ConformationChunkTask(0, 7, 0, 0, (0, 1), frame_selection) + observable = DihedralAngleObservable(task, 1, {}, None) + peak_data = DihedralPeakData(peaks_ua=[[[10.0]]], peaks_res=[]) + partial = ConformationStatePartial(task, [], 0, {}, {}) + state_data = ConformationStateData([], 0, {(0, 0): ["0"]}, {(0, 0): 0}) + + builder._discover_group_dihedral_topology = MagicMock(return_value=[topology]) + builder._build_conformation_chunk_tasks = MagicMock(return_value=[task]) + builder._collect_angle_observable = MagicMock(return_value=observable) + builder._reduce_angle_observables_to_peak_data = MagicMock(return_value=peak_data) + builder._assign_state_partial_from_observable = MagicMock(return_value=partial) + builder._reduce_state_partials = MagicMock(return_value=state_data) + + states_ua, states_res, flexible_ua, flexible_res = ( + builder._build_conformational_states_serial_chunked( + data_container="universe", + levels={7: ["united_atom"]}, + groups={0: [7]}, + bin_width=30.0, + frame_selection=frame_selection, + chunk_size=2, + ) + ) + + assert states_ua == {(0, 0): ["0"]} + assert states_res == [[]] + assert flexible_ua == {(0, 0): 0} + assert flexible_res == [0] + builder._discover_group_dihedral_topology.assert_called_once_with( + data_container="universe", + group_id=0, + molecules=[7], + level_list=["united_atom"], + ) + builder._build_conformation_chunk_tasks.assert_called_once_with( + topologies=[topology], + frame_selection=frame_selection, + chunk_size=2, + ) + builder._collect_angle_observable.assert_called_once_with( + topology=topology, + task=task, + level_list=["united_atom"], + ) + builder._reduce_angle_observables_to_peak_data.assert_called_once_with( + observables=[observable], + level_list=["united_atom"], + bin_width=30.0, + ) + builder._assign_state_partial_from_observable.assert_called_once_with( + observable=observable, + topology=topology, + level_list=["united_atom"], + peaks_ua=peak_data.peaks_ua, + peaks_res=peak_data.peaks_res, + ) + builder._reduce_state_partials.assert_called_once_with([partial]) + + +def test_chunked_serial_with_progress_updates_and_advances_non_empty_group(): + builder = ConformationStateBuilder(universe_operations=MagicMock()) + progress = MagicMock() + progress.add_task.return_value = 44 + + frame_selection = _make_frame_selection(start=0, stop=2, step=1) + topology = MoleculeDihedralTopology(0, 7, 0, 1, {0: ["ua"]}, ["res"]) + task = ConformationChunkTask(0, 7, 0, 0, (0, 1), frame_selection) + observable = DihedralAngleObservable(task, 1, {}, None) + peak_data = DihedralPeakData(peaks_ua=[[[10.0]]], peaks_res=[]) + partial = ConformationStatePartial(task, [], 0, {}, {}) + state_data = ConformationStateData([], 0, {(0, 0): ["0"]}, {(0, 0): 0}) + + builder._discover_group_dihedral_topology = MagicMock(return_value=[topology]) + builder._build_conformation_chunk_tasks = MagicMock(return_value=[task]) + builder._collect_angle_observable = MagicMock(return_value=observable) + builder._reduce_angle_observables_to_peak_data = MagicMock(return_value=peak_data) + builder._assign_state_partial_from_observable = MagicMock(return_value=partial) + builder._reduce_state_partials = MagicMock(return_value=state_data) + + builder._build_conformational_states_serial_chunked( + data_container="universe", + levels={7: ["united_atom"]}, + groups={0: [7]}, + bin_width=30.0, + frame_selection=frame_selection, + chunk_size=2, + progress=progress, + ) + + progress.update.assert_any_call(44, title="Group 0") + progress.advance.assert_called_with(44) diff --git a/tests/unit/CodeEntropy/levels/dihedrals/test_kernels.py b/tests/unit/CodeEntropy/levels/dihedrals/test_kernels.py new file mode 100644 index 00000000..21c31b29 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/dihedrals/test_kernels.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import numpy as np + +from CodeEntropy.levels.dihedrals.kernels import ( + assign_peak_labels_and_count_flexible, + histogram_counts_by_dihedral, + wrap_degrees_positive, +) + + +def test_wrap_degrees_positive_returns_copy_and_wraps_negative_values(): + angles = np.array([[-10.0, 20.0], [-180.0, 0.0]], dtype=np.float64) + + wrapped = wrap_degrees_positive(angles) + + np.testing.assert_allclose(wrapped, np.array([[350.0, 20.0], [180.0, 0.0]])) + np.testing.assert_allclose(angles, np.array([[-10.0, 20.0], [-180.0, 0.0]])) + + +def test_histogram_counts_by_dihedral_counts_each_dihedral_series(): + angles = np.array( + [ + [0.0, 89.0], + [90.0, 180.0], + [359.0, 360.0], + ], + dtype=np.float64, + ) + + counts = histogram_counts_by_dihedral(angles, number_bins=4) + + np.testing.assert_array_equal( + counts, + np.array( + [ + [1, 1, 0, 1], + [1, 0, 1, 1], + ], + dtype=np.int64, + ), + ) + + +def test_assign_peak_labels_uses_first_minimum_tie_and_counts_flexible(): + angles = np.array( + [ + [5.0, 100.0], + [15.0, 100.0], + ], + dtype=np.float64, + ) + padded_peaks = np.array( + [ + [0.0, 10.0], + [100.0, 0.0], + ], + dtype=np.float64, + ) + peak_counts = np.array([2, 1], dtype=np.int64) + + labels, flexible = assign_peak_labels_and_count_flexible( + angles, + padded_peaks, + peak_counts, + ) + + np.testing.assert_array_equal(labels, np.array([[0, 0], [1, 0]], dtype=np.int64)) + assert flexible == 1 + + +def test_assign_peak_labels_handles_dihedrals_with_no_peaks(): + angles = np.array([[10.0], [20.0]], dtype=np.float64) + padded_peaks = np.zeros((1, 1), dtype=np.float64) + peak_counts = np.array([0], dtype=np.int64) + + labels, flexible = assign_peak_labels_and_count_flexible( + angles, + padded_peaks, + peak_counts, + ) + + np.testing.assert_array_equal(labels, np.zeros((2, 1), dtype=np.int64)) + assert flexible == 0 + + +def test_histogram_counts_by_dihedral_clamps_negative_values_to_first_bin(): + angles = np.array([[-1.0], [10.0]], dtype=np.float64) + + counts = histogram_counts_by_dihedral(angles, number_bins=4) + + np.testing.assert_array_equal(counts, np.array([[2, 0, 0, 0]])) + + +def test_histogram_counts_by_dihedral_clamps_negative_bin_to_zero(): + angles = np.array([[-90.0]], dtype=np.float64) + + counts = histogram_counts_by_dihedral( + angles=angles, + number_bins=4, + ) + + np.testing.assert_array_equal( + counts, + np.array([[1, 0, 0, 0]], dtype=np.int64), + ) diff --git a/tests/unit/CodeEntropy/levels/dihedrals/test_peak_detection.py b/tests/unit/CodeEntropy/levels/dihedrals/test_peak_detection.py new file mode 100644 index 00000000..920d8ca8 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/dihedrals/test_peak_detection.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import numpy as np + +from CodeEntropy.levels.dihedrals.angle_observations import ( + ConformationChunkTask, + DihedralAngleObservable, +) +from CodeEntropy.levels.dihedrals.peak_detection import ( + ConformationPeakDetector, + DihedralHistogramData, + DihedralPeakData, +) +from CodeEntropy.trajectory.frames import FrameSelection + + +class _PeakDetector(ConformationPeakDetector): + """Concrete peak detector for unit tests.""" + + def __init__(self) -> None: + """Initialize the test detector.""" + self._universe_operations = MagicMock() + + +def _make_task( + molecule_order: int = 0, + chunk_id: int = 0, +) -> ConformationChunkTask: + """Build a minimal conformation chunk task. + + Args: + molecule_order: Molecule order in the group. + chunk_id: Frame-chunk id. + + Returns: + ConformationChunkTask for one selected frame. + """ + return ConformationChunkTask( + group_id=0, + molecule_id=molecule_order, + molecule_order=molecule_order, + chunk_id=chunk_id, + frame_indices=(chunk_id,), + frame_selection=FrameSelection(indices=(chunk_id,)), + ) + + +def test_find_histogram_peaks_hits_interior_and_wraparound_last_bin(): + popul = np.array([0, 2, 0, 3], dtype=np.int64) + bin_value = [10.0, 20.0, 30.0, 40.0] + + peaks = _PeakDetector._find_histogram_peaks(popul=popul, bin_value=bin_value) + + assert peaks == [20.0, 40.0] + + +def test_reduce_angle_observables_to_peak_data_delegates_to_reducers(): + detector = _PeakDetector() + histogram_data = DihedralHistogramData(0, [], 0, {}, []) + peak_data = DihedralPeakData(peaks_ua=[], peaks_res=[]) + + detector._reduce_angle_observables_to_histograms = MagicMock( + return_value=histogram_data + ) + detector._build_peak_data_from_histograms = MagicMock(return_value=peak_data) + + out = detector._reduce_angle_observables_to_peak_data( + observables=[], + level_list=["united_atom"], + bin_width=30.0, + ) + + assert out is peak_data + detector._reduce_angle_observables_to_histograms.assert_called_once_with( + observables=[], + level_list=["united_atom"], + bin_width=30.0, + ) + detector._build_peak_data_from_histograms.assert_called_once_with( + histogram_data=histogram_data, + level_list=["united_atom"], + bin_width=30.0, + ) + + +def test_reduce_angle_observables_to_histograms_handles_empty_observables(): + detector = _PeakDetector() + + histogram_data = detector._reduce_angle_observables_to_histograms( + observables=[], + level_list=["united_atom", "residue"], + bin_width=90.0, + ) + + assert histogram_data.num_residues == 0 + assert histogram_data.num_dihedrals_ua == [] + assert histogram_data.num_dihedrals_res == 0 + assert histogram_data.hist_ua == {} + assert histogram_data.hist_res == [] + + +def test_reduce_angle_observables_to_histograms_sums_chunk_counts(): + detector = _PeakDetector() + observables = [ + DihedralAngleObservable( + task=_make_task(molecule_order=0, chunk_id=1), + num_residues=1, + ua_angles_by_residue={0: np.array([[190.0]], dtype=np.float64)}, + residue_angles=np.array([[190.0]], dtype=np.float64), + ), + DihedralAngleObservable( + task=_make_task(molecule_order=0, chunk_id=0), + num_residues=1, + ua_angles_by_residue={0: np.array([[10.0], [100.0]], dtype=np.float64)}, + residue_angles=np.array([[10.0], [100.0]], dtype=np.float64), + ), + ] + + histogram_data = detector._reduce_angle_observables_to_histograms( + observables=observables, + level_list=["united_atom", "residue"], + bin_width=90.0, + ) + + np.testing.assert_array_equal( + histogram_data.hist_ua[0][0], + np.array([1, 1, 1, 0], dtype=np.int64), + ) + np.testing.assert_array_equal( + histogram_data.hist_res[0], + np.array([1, 1, 1, 0], dtype=np.int64), + ) + assert histogram_data.num_dihedrals_ua == [1] + assert histogram_data.num_dihedrals_res == 1 + + +def test_reduce_angle_observables_to_histograms_handles_empty_ua_angles(): + detector = _PeakDetector() + observable = DihedralAngleObservable( + task=_make_task(), + num_residues=1, + ua_angles_by_residue={0: np.empty((2, 0), dtype=np.float64)}, + residue_angles=None, + ) + + histogram_data = detector._reduce_angle_observables_to_histograms( + observables=[observable], + level_list=["united_atom"], + bin_width=90.0, + ) + + assert histogram_data.num_residues == 1 + assert histogram_data.num_dihedrals_ua == [0] + assert histogram_data.hist_ua == {0: []} + + +def test_reduce_angle_observables_to_histograms_skips_missing_residue_angles(): + detector = _PeakDetector() + observable = DihedralAngleObservable( + task=_make_task(), + num_residues=1, + ua_angles_by_residue={}, + residue_angles=None, + ) + + histogram_data = detector._reduce_angle_observables_to_histograms( + observables=[observable], + level_list=["residue"], + bin_width=90.0, + ) + + assert histogram_data.num_dihedrals_res == 0 + assert histogram_data.hist_res == [] + + +def test_reduce_angle_observables_to_histograms_skips_empty_residue_angles(): + detector = _PeakDetector() + observable = DihedralAngleObservable( + task=_make_task(), + num_residues=1, + ua_angles_by_residue={}, + residue_angles=np.empty((2, 0), dtype=np.float64), + ) + + histogram_data = detector._reduce_angle_observables_to_histograms( + observables=[observable], + level_list=["residue"], + bin_width=90.0, + ) + + assert histogram_data.num_dihedrals_res == 0 + assert histogram_data.hist_res == [] + + +def test_reduce_angle_observables_to_histograms_initialises_residue_histograms(): + detector = _PeakDetector() + observable = DihedralAngleObservable( + task=_make_task(), + num_residues=1, + ua_angles_by_residue={}, + residue_angles=np.array([[10.0], [20.0]], dtype=np.float64), + ) + + histogram_data = detector._reduce_angle_observables_to_histograms( + observables=[observable], + level_list=["residue"], + bin_width=90.0, + ) + + assert histogram_data.num_dihedrals_res == 1 + assert isinstance(histogram_data.hist_res, dict) + np.testing.assert_array_equal( + histogram_data.hist_res[0], + np.array([2, 0, 0, 0], dtype=np.int64), + ) + + +def test_build_peak_data_from_histograms_finds_ua_and_residue_peaks(): + detector = _PeakDetector() + histogram_data = DihedralHistogramData( + num_residues=1, + num_dihedrals_ua=[1], + num_dihedrals_res=1, + hist_ua={0: {0: np.array([0, 2, 0, 1], dtype=np.int64)}}, + hist_res={0: np.array([1, 0, 3, 0], dtype=np.int64)}, + ) + + peak_data = detector._build_peak_data_from_histograms( + histogram_data=histogram_data, + level_list=["united_atom", "residue"], + bin_width=90.0, + ) + + assert peak_data == DihedralPeakData( + peaks_ua=[[[135.0, 315.0]]], + peaks_res=[[45.0, 225.0]], + ) + + +def test_build_peak_data_from_histograms_handles_empty_ua_histogram(): + detector = _PeakDetector() + histogram_data = DihedralHistogramData( + num_residues=1, + num_dihedrals_ua=[0], + num_dihedrals_res=0, + hist_ua={0: []}, + hist_res=[], + ) + + peak_data = detector._build_peak_data_from_histograms( + histogram_data=histogram_data, + level_list=["united_atom"], + bin_width=90.0, + ) + + assert peak_data.peaks_ua == [[]] + assert peak_data.peaks_res == [] diff --git a/tests/unit/CodeEntropy/levels/dihedrals/test_state_assignment.py b/tests/unit/CodeEntropy/levels/dihedrals/test_state_assignment.py new file mode 100644 index 00000000..4f7b90af --- /dev/null +++ b/tests/unit/CodeEntropy/levels/dihedrals/test_state_assignment.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import numpy as np + +from CodeEntropy.levels.dihedrals.angle_observations import ( + ConformationChunkTask, + DihedralAngleObservable, +) +from CodeEntropy.levels.dihedrals.state_assignment import ( + ConformationStateAssigner, + ConformationStateData, + ConformationStatePartial, +) +from CodeEntropy.levels.dihedrals.topology import MoleculeDihedralTopology +from CodeEntropy.trajectory.frames import FrameSelection + + +class _StateAssigner(ConformationStateAssigner): + """Concrete state assigner for unit tests.""" + + def __init__(self) -> None: + """Initialize the test assigner.""" + self._universe_operations = MagicMock() + + +def _make_task(molecule_order: int, chunk_id: int) -> ConformationChunkTask: + """Build a minimal conformation task for reducer tests. + + Args: + molecule_order: Molecule ordering value. + chunk_id: Chunk ordering value. + + Returns: + ConformationChunkTask for a single frame. + """ + return ConformationChunkTask( + group_id=0, + molecule_id=molecule_order, + molecule_order=molecule_order, + chunk_id=chunk_id, + frame_indices=(chunk_id,), + frame_selection=FrameSelection.from_bounds(chunk_id, chunk_id + 1, 1), + ) + + +def test_pad_peak_values_converts_ragged_peaks_to_arrays(): + padded, counts = _StateAssigner._pad_peak_values([[10.0, 20.0], [30.0]]) + + np.testing.assert_allclose(padded, np.array([[10.0, 20.0], [30.0, 0.0]])) + np.testing.assert_array_equal(counts, np.array([2, 1], dtype=np.int64)) + + +def test_pad_peak_values_handles_empty_peak_list(): + padded, counts = _StateAssigner._pad_peak_values([]) + + assert padded.shape == (0, 1) + assert counts.shape == (0,) + + +def test_state_strings_from_labels_builds_legacy_state_strings(): + labels = np.array([[0, 1], [1, 0]], dtype=np.int64) + + states = _StateAssigner._state_strings_from_labels(labels) + + assert states == ["01", "10"] + + +def test_state_strings_from_labels_filters_empty_states(): + labels = np.empty((2, 0), dtype=np.int64) + + states = _StateAssigner._state_strings_from_labels(labels) + + assert states == [] + + +def test_process_conformations_from_angles_assigns_states_and_flexible_count(): + assigner = _StateAssigner() + angles = np.array([[5.0], [15.0]], dtype=np.float64) + + states, flexible = assigner._process_conformations_from_angles( + peaks=[[5.0, 15.0]], + angles=angles, + ) + + assert states == ["0", "1"] + assert flexible == 1 + + +def test_process_conformations_from_angles_handles_no_dihedrals(): + assigner = _StateAssigner() + angles = np.empty((2, 0), dtype=np.float64) + + states, flexible = assigner._process_conformations_from_angles( + peaks=[], + angles=angles, + ) + + assert states == [] + assert flexible == 0 + + +def test_assign_state_partial_from_observable_handles_ua_and_residue_levels(): + assigner = _StateAssigner() + task = _make_task(molecule_order=0, chunk_id=0) + topology = MoleculeDihedralTopology( + group_id=2, + molecule_id=7, + molecule_order=0, + num_residues=2, + ua_dihedrals_by_residue={0: ["ua0"], 1: []}, + residue_dihedrals=["res0"], + ) + observable = DihedralAngleObservable( + task=task, + num_residues=2, + ua_angles_by_residue={ + 0: np.array([[5.0], [15.0]], dtype=np.float64), + 1: np.empty((2, 0), dtype=np.float64), + }, + residue_angles=np.array([[30.0], [40.0]], dtype=np.float64), + ) + + partial = assigner._assign_state_partial_from_observable( + observable=observable, + topology=topology, + level_list=["united_atom", "residue"], + peaks_ua=[[[5.0, 15.0]], []], + peaks_res=[[30.0, 40.0]], + ) + + assert partial.states_ua_updates[(2, 0)] == ["0", "1"] + assert partial.flexible_ua_updates[(2, 0)] == 1 + assert partial.states_ua_updates[(2, 1)] == [] + assert partial.flexible_ua_updates[(2, 1)] == 0 + assert partial.state_res == ["0", "1"] + assert partial.flex_res == 1 + + +def test_reduce_state_partials_preserves_molecule_then_chunk_order_and_max_flex(): + assigner = _StateAssigner() + partials = [ + ConformationStatePartial( + task=_make_task(molecule_order=1, chunk_id=0), + state_res=["m1c0"], + flex_res=0, + states_ua_updates={(0, 0): ["m1c0"]}, + flexible_ua_updates={(0, 0): 0}, + ), + ConformationStatePartial( + task=_make_task(molecule_order=0, chunk_id=1), + state_res=["m0c1"], + flex_res=2, + states_ua_updates={(0, 0): ["m0c1"]}, + flexible_ua_updates={(0, 0): 2}, + ), + ConformationStatePartial( + task=_make_task(molecule_order=0, chunk_id=0), + state_res=["m0c0"], + flex_res=1, + states_ua_updates={(0, 0): ["m0c0"]}, + flexible_ua_updates={(0, 0): 1}, + ), + ] + + state_data = assigner._reduce_state_partials(partials) + + assert state_data.state_res == ["m0c0", "m0c1", "m1c0"] + assert state_data.flex_res == 2 + assert state_data.states_ua_updates[(0, 0)] == ["m0c0", "m0c1", "m1c0"] + assert state_data.flexible_ua_updates[(0, 0)] == 2 + + +def test_merge_group_state_data_extends_existing_ua_states(): + states_ua = {(0, 0): ["0"]} + states_res = [] + flexible_ua = {(0, 0): 1} + flexible_res = [] + state_data = ConformationStateData( + state_res=["1"], + flex_res=0, + states_ua_updates={(0, 0): ["1"], (0, 1): ["2"]}, + flexible_ua_updates={(0, 0): 2, (0, 1): 0}, + ) + + _StateAssigner._merge_group_state_data( + state_data=state_data, + states_ua=states_ua, + states_res=states_res, + flexible_ua=flexible_ua, + flexible_res=flexible_res, + ) + + assert states_ua[(0, 0)] == ["0", "1"] + assert states_ua[(0, 1)] == ["2"] + assert flexible_ua[(0, 0)] == 2 + assert flexible_ua[(0, 1)] == 0 + assert states_res == [["1"]] + assert flexible_res == [0] diff --git a/tests/unit/CodeEntropy/levels/dihedrals/test_topology.py b/tests/unit/CodeEntropy/levels/dihedrals/test_topology.py new file mode 100644 index 00000000..60affa71 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/dihedrals/test_topology.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import numpy as np + +from CodeEntropy.levels.dihedrals.topology import DihedralTopologyDiscovery + + +class _AddableAG: + """Minimal addable AtomGroup test double.""" + + def __init__(self, name: str) -> None: + """Initialize the fake AtomGroup. + + Args: + name: Human-readable identifier used in composed names. + """ + self.name = name + + def __add__(self, other: _AddableAG) -> _AddableAG: + """Return a composed fake AtomGroup. + + Args: + other: Fake AtomGroup to combine with this object. + + Returns: + New fake AtomGroup containing a composed name. + """ + return _AddableAG(f"({self.name}+{other.name})") + + +class _TopologyDiscovery(DihedralTopologyDiscovery): + """Concrete topology-discovery helper for unit tests.""" + + def __init__(self, universe_operations: MagicMock) -> None: + """Initialize the test helper. + + Args: + universe_operations: Mock universe-operation adapter. + """ + self._universe_operations = universe_operations + + +def test_select_heavy_residue_builds_expected_selections(): + uops = MagicMock() + helper = _TopologyDiscovery(universe_operations=uops) + + mol = MagicMock() + mol.residues = [MagicMock()] + mol.residues[0].atoms.indices = np.array([10, 11, 12], dtype=int) + uops.select_atoms.side_effect = ["residue_atoms", "heavy_atoms"] + + out = helper._select_heavy_residue(mol, res_id=0) + + assert out == "heavy_atoms" + assert uops.select_atoms.call_args_list == [ + ((mol, "index 10:12"),), + (("residue_atoms", "prop mass > 1.1"),), + ] + + +def test_get_dihedrals_united_atom_collects_atoms_from_dihedral_objects(): + helper = _TopologyDiscovery(universe_operations=MagicMock()) + + d0 = MagicMock() + d0.atoms = "A0" + d1 = MagicMock() + d1.atoms = "A1" + + container = MagicMock() + container.dihedrals = [d0, d1] + + assert helper._get_dihedrals(container, level="united_atom") == ["A0", "A1"] + + +def test_get_dihedrals_residue_returns_empty_when_less_than_four_residues(): + helper = _TopologyDiscovery(universe_operations=MagicMock()) + + mol = MagicMock() + mol.residues = [MagicMock(), MagicMock(), MagicMock()] + mol.select_atoms = MagicMock() + + assert helper._get_dihedrals(mol, level="residue") == [] + mol.select_atoms.assert_not_called() + + +def test_get_dihedrals_residue_builds_one_dihedral_when_four_residues(): + helper = _TopologyDiscovery(universe_operations=MagicMock()) + + mol = MagicMock() + mol.residues = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] + mol.select_atoms = MagicMock( + side_effect=[ + _AddableAG("a1"), + _AddableAG("a2"), + _AddableAG("a3"), + _AddableAG("a4"), + ] + ) + + out = helper._get_dihedrals(mol, level="residue") + + assert len(out) == 1 + assert isinstance(out[0], _AddableAG) + assert mol.select_atoms.call_count == 4 + + +def test_discover_group_dihedral_topology_builds_one_entry_per_molecule(): + uops = MagicMock() + helper = _TopologyDiscovery(universe_operations=uops) + + mol0 = MagicMock() + mol0.residues = [MagicMock(), MagicMock()] + mol1 = MagicMock() + mol1.residues = [MagicMock(), MagicMock()] + uops.extract_fragment.side_effect = [mol0, mol1] + + helper._select_heavy_residue = MagicMock( + side_effect=["heavy0", "heavy1", "heavy2", "heavy3"] + ) + helper._get_dihedrals = MagicMock( + side_effect=[ + ["ua0r0"], + ["ua0r1"], + ["res0"], + ["ua1r0"], + ["ua1r1"], + ["res1"], + ] + ) + + topologies = helper._discover_group_dihedral_topology( + data_container="universe", + group_id=3, + molecules=[7, 8], + level_list=["united_atom", "residue"], + ) + + assert [topology.molecule_id for topology in topologies] == [7, 8] + assert [topology.molecule_order for topology in topologies] == [0, 1] + assert topologies[0].group_id == 3 + assert topologies[0].ua_dihedrals_by_residue == {0: ["ua0r0"], 1: ["ua0r1"]} + assert topologies[0].residue_dihedrals == ["res0"] + assert topologies[1].ua_dihedrals_by_residue == {0: ["ua1r0"], 1: ["ua1r1"]} + assert topologies[1].residue_dihedrals == ["res1"] diff --git a/tests/unit/CodeEntropy/levels/test_conformation_dag.py b/tests/unit/CodeEntropy/levels/test_conformation_dag.py index ad469ffe..92568ca9 100644 --- a/tests/unit/CodeEntropy/levels/test_conformation_dag.py +++ b/tests/unit/CodeEntropy/levels/test_conformation_dag.py @@ -1,137 +1,134 @@ -"""Unit tests for the conformational-state DAG stage.""" - from __future__ import annotations from types import SimpleNamespace +from unittest.mock import MagicMock -from CodeEntropy.levels import conformation_dag from CodeEntropy.levels.conformation_dag import ConformationDAG +from CodeEntropy.trajectory.frames import FrameSelection -class FakeConformationStateBuilder: - """Test double for ConformationStateBuilder.""" - - def __init__(self, universe_operations): - self.universe_operations = universe_operations - self.calls = [] - - def build_conformational_states( - self, - *, - data_container, - levels, - groups, - bin_width, - frame_selection, - progress=None, - ): - self.calls.append( - { - "data_container": data_container, - "levels": levels, - "groups": groups, - "bin_width": bin_width, - "frame_selection": frame_selection, - "progress": progress, - } - ) - return ( - {"ua_key": ["state_a"]}, - [["res_state"]], - {"ua_key": 1}, - [1], - ) - - -def test_conformation_dag_build_returns_self(): - dag = ConformationDAG() +def _make_frame_selection( + start: int = 0, + stop: int = 3, + step: int = 1, +) -> FrameSelection: + """Build a FrameSelection for ConformationDAG tests. - assert dag.build() is dag + Args: + start: Inclusive source-frame start. + stop: Exclusive source-frame stop. + step: Source-frame stride. + Returns: + FrameSelection covering the requested bounds. + """ + return FrameSelection.from_bounds(start=start, stop=stop, step=step) -def test_conformation_dag_executes_builder_and_writes_shared_data(monkeypatch): - builder_holder = {} - def builder_factory(universe_operations): - builder = FakeConformationStateBuilder(universe_operations) - builder_holder["builder"] = builder - return builder +def test_build_returns_self(): + """Test that build returns the DAG instance.""" + dag = ConformationDAG(universe_operations=MagicMock()) - monkeypatch.setattr( - conformation_dag, - "ConformationStateBuilder", - builder_factory, - ) + assert dag.build() is dag - universe_operations = object() - dag = ConformationDAG(universe_operations=universe_operations) - universe = object() - frame_selection = object() - progress = object() +def test_execute_uses_execution_policy_chunk_size_and_stores_outputs(): + """Test that execute stores conformational and flexible-dihedral outputs.""" + dag = ConformationDAG(universe_operations=MagicMock()) + dag._policy = MagicMock() + dag._policy.frame_chunk_size.return_value = 2 + dag._builder = MagicMock() + dag._builder.build_conformational_states.return_value = ( + {(0, 0): ["0"]}, + [["1"]], + {(0, 0): 1}, + [0], + ) + frame_selection = _make_frame_selection(start=0, stop=3, step=1) + progress = MagicMock() shared_data = { - "reduced_universe": universe, - "levels": [["united_atom", "residue"]], - "groups": {0: [0]}, + "reduced_universe": "universe", + "levels": {7: ["united_atom"]}, + "groups": {0: [7]}, "frame_selection": frame_selection, "args": SimpleNamespace(bin_width=30), } - result = dag.execute(shared_data, progress=progress) + out = dag.execute(shared_data, progress=progress) + + dag._policy.frame_chunk_size.assert_called_once_with( + shared_data, + n_frames=frame_selection.n_frames, + ) + dag._builder.build_conformational_states.assert_called_once_with( + data_container="universe", + levels={7: ["united_atom"]}, + groups={0: [7]}, + bin_width=30, + frame_selection=frame_selection, + progress=progress, + chunk_size=2, + ) assert shared_data["conformational_states"] == { - "ua": {"ua_key": ["state_a"]}, - "res": [["res_state"]], + "ua": {(0, 0): ["0"]}, + "res": [["1"]], } assert shared_data["flexible_dihedrals"] == { - "ua": {"ua_key": 1}, - "res": [1], - } - assert result == { - "conformational_states": shared_data["conformational_states"], + "ua": {(0, 0): 1}, + "res": [0], } + assert out == {"conformational_states": shared_data["conformational_states"]} - builder = builder_holder["builder"] - assert builder.universe_operations is universe_operations - assert builder.calls == [ - { - "data_container": universe, - "levels": [["united_atom", "residue"]], - "groups": {0: [0]}, - "bin_width": 30, - "frame_selection": frame_selection, - "progress": progress, - } - ] - - -def test_conformation_dag_converts_bin_width_to_int(monkeypatch): - captured = {} - - class Builder: - def __init__(self, universe_operations): - self.universe_operations = universe_operations - - def build_conformational_states(self, **kwargs): - captured.update(kwargs) - return {}, [], {}, [] - - monkeypatch.setattr( - conformation_dag, - "ConformationStateBuilder", - Builder, - ) +def test_execute_converts_bin_width_to_int(): + """Test that execute converts args.bin_width before calling the builder.""" dag = ConformationDAG() + dag._policy = MagicMock() + dag._policy.frame_chunk_size.return_value = 3 + dag._builder = MagicMock() + dag._builder.build_conformational_states.return_value = ({}, [], {}, []) + + frame_selection = _make_frame_selection(start=0, stop=3, step=1) shared_data = { "reduced_universe": object(), - "levels": [], + "levels": {}, "groups": {}, - "frame_selection": object(), + "frame_selection": frame_selection, "args": SimpleNamespace(bin_width="45"), } dag.execute(shared_data) - assert captured["bin_width"] == 45 + assert dag._builder.build_conformational_states.call_args.kwargs["bin_width"] == 45 + assert dag._builder.build_conformational_states.call_args.kwargs["chunk_size"] == 3 + + +def test_execute_passes_real_frame_selection_to_builder(): + """Test that execute forwards the existing FrameSelection object unchanged.""" + dag = ConformationDAG() + dag._policy = MagicMock() + dag._policy.frame_chunk_size.return_value = 1 + dag._builder = MagicMock() + dag._builder.build_conformational_states.return_value = ({}, [], {}, []) + + frame_selection = _make_frame_selection(start=10, stop=31, step=10) + shared_data = { + "reduced_universe": "universe", + "levels": {}, + "groups": {}, + "frame_selection": frame_selection, + "args": SimpleNamespace(bin_width=30), + } + + dag.execute(shared_data) + + assert ( + dag._builder.build_conformational_states.call_args.kwargs["frame_selection"] + is frame_selection + ) + dag._policy.frame_chunk_size.assert_called_once_with( + shared_data, + n_frames=frame_selection.n_frames, + ) diff --git a/tests/unit/CodeEntropy/levels/test_dihedrals.py b/tests/unit/CodeEntropy/levels/test_dihedrals.py deleted file mode 100644 index ded662d9..00000000 --- a/tests/unit/CodeEntropy/levels/test_dihedrals.py +++ /dev/null @@ -1,952 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from unittest.mock import MagicMock, call, patch - -import numpy as np -import pytest - -from CodeEntropy.levels.dihedrals import ( - ConformationStateBuilder, - ConformationStateData, - DihedralAngleData, - DihedralPeakData, -) -from CodeEntropy.trajectory.frames import FrameSelection - - -class _AddableAG: - """Minimal addable AtomGroup test double.""" - - def __init__(self, name: str): - """Initialize the fake AtomGroup. - - Args: - name: Human-readable identifier used in composed names. - """ - self.name = name - - def __add__(self, other: _AddableAG) -> _AddableAG: - """Return a composed fake AtomGroup. - - Args: - other: Fake AtomGroup to combine with this object. - - Returns: - New fake AtomGroup containing a composed name. - """ - return _AddableAG(f"({self.name}+{other.name})") - - -def _make_frame_selection( - start: int = 0, - stop: int = 2, - step: int = 1, -) -> FrameSelection: - """Build a FrameSelection for dihedral unit tests. - - Args: - start: Inclusive source-frame start. - stop: Exclusive source-frame stop. - step: Source-frame step. - - Returns: - FrameSelection covering the requested bounds. - """ - return FrameSelection.from_bounds(start=start, stop=stop, step=step) - - -def test_select_heavy_residue_builds_two_selections(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - mol = MagicMock() - mol.residues = [MagicMock()] - mol.residues[0].atoms.indices = np.array([10, 11, 12], dtype=int) - - uops.select_atoms.side_effect = ["res_container", "heavy_only"] - - out = dt._select_heavy_residue(mol, res_id=0) - - assert out == "heavy_only" - assert uops.select_atoms.call_count == 2 - uops.select_atoms.assert_any_call(mol, "index 10:12") - uops.select_atoms.assert_any_call("res_container", "prop mass > 1.1") - - -def test_get_dihedrals_united_atom_collects_atoms_from_dihedral_objects(): - dt = ConformationStateBuilder(universe_operations=MagicMock()) - - d0 = MagicMock() - d0.atoms = "A0" - d1 = MagicMock() - d1.atoms = "A1" - - container = MagicMock() - container.dihedrals = [d0, d1] - - assert dt._get_dihedrals(container, level="united_atom") == ["A0", "A1"] - - -def test_get_dihedrals_residue_returns_empty_when_less_than_4_residues(): - dt = ConformationStateBuilder(universe_operations=MagicMock()) - - mol = MagicMock() - mol.residues = [MagicMock(), MagicMock(), MagicMock()] - mol.select_atoms = MagicMock() - - assert dt._get_dihedrals(mol, level="residue") == [] - mol.select_atoms.assert_not_called() - - -def test_get_dihedrals_residue_builds_one_dihedral_when_4_residues(): - dt = ConformationStateBuilder(universe_operations=MagicMock()) - - mol = MagicMock() - mol.residues = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] - mol.select_atoms = MagicMock( - side_effect=[ - _AddableAG("a1"), - _AddableAG("a2"), - _AddableAG("a3"), - _AddableAG("a4"), - ] - ) - - out = dt._get_dihedrals(mol, level="residue") - - assert len(out) == 1 - assert isinstance(out[0], _AddableAG) - assert mol.select_atoms.call_count == 4 - - -def test_collect_dihedral_angle_data_sets_empty_outputs_when_no_dihedrals(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - mol = MagicMock() - mol.residues = [MagicMock()] - mol.residues[0].atoms.indices = np.array([0, 1, 2, 3], dtype=int) - uops.extract_fragment.return_value = mol - - dt._select_heavy_residue = MagicMock(return_value=mol) - dt._get_dihedrals = MagicMock(return_value=[]) - - frame_selection = _make_frame_selection(start=0, stop=2, step=1) - - angle_data = dt._collect_dihedral_angle_data( - data_container=MagicMock(), - molecules=[0], - level_list=["united_atom", "residue"], - frame_selection=frame_selection, - ) - - assert angle_data.num_residues == 1 - assert angle_data.num_dihedrals_ua == [0] - assert angle_data.num_dihedrals_res == 0 - assert angle_data.phi_ua == {0: []} - assert angle_data.phi_res == [] - - -def test_collect_dihedral_angle_data_wraps_negative_angles(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - mol = MagicMock() - mol.residues = [MagicMock()] - mol.residues[0].atoms.indices = np.array([0, 1, 2, 3], dtype=int) - uops.extract_fragment.return_value = mol - - dihedrals = ["D0"] - angles = np.array([[-10.0], [10.0]], dtype=float) - - dt._select_heavy_residue = MagicMock(return_value=mol) - dt._get_dihedrals = MagicMock(return_value=dihedrals) - - class _FakeDihedral: - def __init__(self, _dihedrals): - pass - - def run(self, *args, **kwargs): - return SimpleNamespace(results=SimpleNamespace(angles=angles)) - - frame_selection = _make_frame_selection(start=0, stop=2, step=1) - - with patch("CodeEntropy.levels.dihedrals.Dihedral", _FakeDihedral): - angle_data = dt._collect_dihedral_angle_data( - data_container=MagicMock(), - molecules=[0], - level_list=["united_atom", "residue"], - frame_selection=frame_selection, - ) - - assert angle_data.phi_ua[0][0] == [350.0, 10.0] - assert angle_data.phi_res[0] == [350.0, 10.0] - - -def test_build_peak_data_returns_empty_outputs_when_no_angles(): - dt = ConformationStateBuilder(universe_operations=MagicMock()) - - angle_data = DihedralAngleData( - num_residues=1, - num_dihedrals_ua=[0], - num_dihedrals_res=0, - phi_ua={0: []}, - phi_res=[], - ) - - peak_data = dt._build_peak_data( - angle_data=angle_data, - level_list=["united_atom", "residue"], - bin_width=30.0, - ) - - assert peak_data == DihedralPeakData(peaks_ua=[[]], peaks_res=[]) - - -def test_build_peak_data_calls_process_histogram_for_ua_and_residue(): - dt = ConformationStateBuilder(universe_operations=MagicMock()) - - angle_data = DihedralAngleData( - num_residues=1, - num_dihedrals_ua=[1], - num_dihedrals_res=1, - phi_ua={0: {0: [10.0, 20.0]}}, - phi_res={0: [30.0, 40.0]}, - ) - - dt._process_histogram = MagicMock(side_effect=[["ua_peak"], ["res_peak"]]) - - peak_data = dt._build_peak_data( - angle_data=angle_data, - level_list=["united_atom", "residue"], - bin_width=30.0, - ) - - assert peak_data.peaks_ua == [["ua_peak"]] - assert peak_data.peaks_res == ["res_peak"] - assert dt._process_histogram.call_args_list == [ - call( - num_dihedrals=1, - phi_values={0: [10.0, 20.0]}, - bin_width=30.0, - ), - call( - num_dihedrals=1, - phi_values={0: [30.0, 40.0]}, - bin_width=30.0, - ), - ] - - -def test_identify_peaks_delegates_to_angle_collection_and_peak_building(): - dt = ConformationStateBuilder(universe_operations=MagicMock()) - - angle_data = DihedralAngleData( - num_residues=1, - num_dihedrals_ua=[1], - num_dihedrals_res=1, - phi_ua={0: {0: [10.0]}}, - phi_res={0: [10.0]}, - ) - peak_data = DihedralPeakData(peaks_ua=[[[10.0]]], peaks_res=[[10.0]]) - - dt._collect_dihedral_angle_data = MagicMock(return_value=angle_data) - dt._build_peak_data = MagicMock(return_value=peak_data) - - frame_selection = _make_frame_selection(start=0, stop=2, step=1) - - peaks_ua, peaks_res = dt._identify_peaks( - data_container="universe", - molecules=[0], - bin_width=30.0, - level_list=["united_atom", "residue"], - frame_selection=frame_selection, - ) - - assert peaks_ua == peak_data.peaks_ua - assert peaks_res == peak_data.peaks_res - dt._collect_dihedral_angle_data.assert_called_once_with( - data_container="universe", - molecules=[0], - level_list=["united_atom", "residue"], - frame_selection=frame_selection, - ) - dt._build_peak_data.assert_called_once_with( - angle_data=angle_data, - level_list=["united_atom", "residue"], - bin_width=30.0, - ) - - -def test_identify_peaks_wraps_negative_angles_and_calls_process_histogram(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - mol = MagicMock() - mol.residues = [MagicMock()] - mol.residues[0].atoms.indices = np.array([0, 1, 2, 3], dtype=int) - uops.extract_fragment.return_value = mol - - dihedrals = ["D0"] - angles = np.array([[-10.0], [10.0]], dtype=float) - - dt._select_heavy_residue = MagicMock(return_value=mol) - dt._get_dihedrals = MagicMock(return_value=dihedrals) - - class _FakeDihedral: - def __init__(self, _dihedrals): - pass - - def run(self, *args, **kwargs): - return SimpleNamespace(results=SimpleNamespace(angles=angles)) - - frame_selection = _make_frame_selection(start=0, stop=2, step=1) - - with ( - patch("CodeEntropy.levels.dihedrals.Dihedral", _FakeDihedral), - patch.object(dt, "_process_histogram", return_value=[15.0]) as peaks_spy, - ): - out_ua, out_res = dt._identify_peaks( - data_container=MagicMock(), - molecules=[0], - bin_width=10.0, - level_list=["united_atom", "residue"], - frame_selection=frame_selection, - ) - - assert out_ua == [[15.0]] - assert out_res == [15.0] - assert peaks_spy.call_count == 2 - - -def test_find_histogram_peaks_hits_interior_and_wraparound_last_bin(): - popul = [0, 2, 0, 3] - bin_value = [10.0, 20.0, 30.0, 40.0] - assert ConformationStateBuilder._find_histogram_peaks(popul, bin_value) == [ - 20.0, - 40.0, - ] - - -def test_calculate_group_state_data_initialises_then_extends_for_multiple_molecules(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - mol = MagicMock() - mol.residues = [MagicMock()] - mol.residues[0].atoms.indices = np.array([0, 1, 2, 3], dtype=int) - uops.extract_fragment.return_value = mol - - dihedrals = ["D0"] - angles = np.array([[5.0], [15.0]], dtype=float) - peaks = [[5.0, 15.0]] - - dt._select_heavy_residue = MagicMock(return_value=mol) - dt._get_dihedrals = MagicMock(return_value=dihedrals) - - class _FakeDihedral: - def __init__(self, _dihedrals): - pass - - def run(self, *args, **kwargs): - return SimpleNamespace(results=SimpleNamespace(angles=angles)) - - frame_selection = _make_frame_selection(start=0, stop=2, step=1) - - with patch("CodeEntropy.levels.dihedrals.Dihedral", _FakeDihedral): - state_data = dt._calculate_group_state_data( - data_container=MagicMock(), - group_id=0, - molecules=[0, 1], - level_list=["united_atom", "residue"], - peaks_ua=[peaks], - peaks_res=peaks, - frame_selection=frame_selection, - ) - - assert state_data.states_ua_updates[(0, 0)] == ["0", "1", "0", "1"] - assert state_data.flexible_ua_updates[(0, 0)] == 1 - assert state_data.state_res == ["0", "1", "0", "1"] - assert state_data.flex_res == 1 - - -def test_merge_group_state_data_initialises_final_accumulators(): - states_ua = {} - states_res = [] - flexible_ua = {} - flexible_res = [] - - state_data = ConformationStateData( - state_res=["0", "1"], - flex_res=1, - states_ua_updates={(0, 0): ["0", "1"]}, - flexible_ua_updates={(0, 0): 1}, - ) - - ConformationStateBuilder._merge_group_state_data( - state_data=state_data, - states_ua=states_ua, - states_res=states_res, - flexible_ua=flexible_ua, - flexible_res=flexible_res, - ) - - assert states_ua == {(0, 0): ["0", "1"]} - assert states_res == [["0", "1"]] - assert flexible_ua == {(0, 0): 1} - assert flexible_res == [1] - - -def test_merge_group_state_data_extends_existing_ua_states(): - states_ua = {(0, 0): ["0"]} - states_res = [] - flexible_ua = {(0, 0): 1} - flexible_res = [] - - state_data = ConformationStateData( - state_res=["1"], - flex_res=0, - states_ua_updates={(0, 0): ["1"], (0, 1): ["2"]}, - flexible_ua_updates={(0, 0): 2, (0, 1): 0}, - ) - - ConformationStateBuilder._merge_group_state_data( - state_data=state_data, - states_ua=states_ua, - states_res=states_res, - flexible_ua=flexible_ua, - flexible_res=flexible_res, - ) - - assert states_ua[(0, 0)] == ["0", "1"] - assert states_ua[(0, 1)] == ["2"] - assert flexible_ua[(0, 0)] == 2 - assert flexible_ua[(0, 1)] == 0 - assert states_res == [["1"]] - assert flexible_res == [0] - - -def test_assign_states_delegates_to_calculation_and_merge(): - dt = ConformationStateBuilder(universe_operations=MagicMock()) - - state_data = ConformationStateData( - state_res=["0"], - flex_res=1, - states_ua_updates={(0, 0): ["0"]}, - flexible_ua_updates={(0, 0): 1}, - ) - dt._calculate_group_state_data = MagicMock(return_value=state_data) - dt._merge_group_state_data = MagicMock() - - frame_selection = _make_frame_selection(start=0, stop=2, step=1) - states_ua = {} - states_res = [] - flexible_ua = {} - flexible_res = [] - - dt._assign_states( - data_container="universe", - group_id=0, - molecules=[0], - level_list=["united_atom"], - peaks_ua=[[[10.0]]], - peaks_res=[], - states_ua=states_ua, - states_res=states_res, - flexible_ua=flexible_ua, - flexible_res=flexible_res, - frame_selection=frame_selection, - ) - - dt._calculate_group_state_data.assert_called_once_with( - data_container="universe", - group_id=0, - molecules=[0], - level_list=["united_atom"], - peaks_ua=[[[10.0]]], - peaks_res=[], - frame_selection=frame_selection, - ) - dt._merge_group_state_data.assert_called_once_with( - state_data=state_data, - states_ua=states_ua, - states_res=states_res, - flexible_ua=flexible_ua, - flexible_res=flexible_res, - ) - - -def test_assign_states_initialises_then_extends_for_multiple_molecules(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - mol = MagicMock() - mol.residues = [MagicMock()] - mol.residues[0].atoms.indices = np.array([0, 1, 2, 3], dtype=int) - uops.extract_fragment.return_value = mol - - dihedrals = ["D0"] - angles = np.array([[5.0], [15.0]], dtype=float) - peaks = [[5.0, 15.0]] - - states_ua = {} - states_res = [] - flexible_ua = {} - flexible_res = [] - - dt._select_heavy_residue = MagicMock(return_value=mol) - dt._get_dihedrals = MagicMock(return_value=dihedrals) - - class _FakeDihedral: - def __init__(self, _dihedrals): - pass - - def run(self, *args, **kwargs): - return SimpleNamespace(results=SimpleNamespace(angles=angles)) - - frame_selection = _make_frame_selection(start=0, stop=2, step=1) - - with patch("CodeEntropy.levels.dihedrals.Dihedral", _FakeDihedral): - dt._assign_states( - data_container=MagicMock(), - group_id=0, - molecules=[0, 1], - level_list=["united_atom", "residue"], - peaks_ua=[peaks], - peaks_res=peaks, - states_ua=states_ua, - states_res=states_res, - flexible_ua=flexible_ua, - flexible_res=flexible_res, - frame_selection=frame_selection, - ) - - assert states_ua[(0, 0)] == ["0", "1", "0", "1"] - assert flexible_ua[(0, 0)] == 1 - assert states_res[0] == ["0", "1", "0", "1"] - assert flexible_res[0] == 1 - - -def test_build_conformational_states_runs_group_and_skips_empty_group(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - groups = {0: [], 1: [7]} - levels = {7: ["residue"]} - - dt._identify_peaks = MagicMock(return_value=([], [])) - dt._assign_states = MagicMock() - - frame_selection = _make_frame_selection(start=0, stop=1, step=1) - - states_ua, states_res, flex_ua, flex_res = dt.build_conformational_states( - data_container=MagicMock(), - levels=levels, - groups=groups, - bin_width=30.0, - frame_selection=frame_selection, - ) - - assert states_ua == {} - assert states_res == [[], []] - assert flex_ua == {} - assert flex_res == [] - - dt._identify_peaks.assert_called_once() - dt._assign_states.assert_called_once() - assert dt._identify_peaks.call_args.kwargs["frame_selection"] is frame_selection - assert dt._assign_states.call_args.kwargs["frame_selection"] is frame_selection - - -def test_identify_peaks_handles_multiple_dihedrals(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - mol = MagicMock() - mol.residues = [MagicMock()] - mol.residues[0].atoms.indices = np.array([0, 1, 2, 3], dtype=int) - uops.extract_fragment.return_value = mol - - dihedrals = ["D0", "D1"] - angles = np.array( - [ - [-10.0, 10.0], - [20.0, -20.0], - ], - dtype=float, - ) - - dt._select_heavy_residue = MagicMock(return_value=mol) - dt._get_dihedrals = MagicMock(return_value=dihedrals) - dt._process_histogram = MagicMock(return_value=[1, 2]) - - class _FakeDihedral: - def __init__(self, _dihedrals): - pass - - def run(self, *args, **kwargs): - return SimpleNamespace(results=SimpleNamespace(angles=angles)) - - frame_selection = _make_frame_selection(start=0, stop=2, step=1) - - with patch("CodeEntropy.levels.dihedrals.Dihedral", _FakeDihedral): - out_ua, out_res = dt._identify_peaks( - data_container=MagicMock(), - molecules=[0], - bin_width=30.0, - level_list=["united_atom", "residue"], - frame_selection=frame_selection, - ) - - assert out_ua == [[1, 2]] - assert out_res == [1, 2] - assert dt._process_histogram.call_count == 2 - - -def test_collect_dihedral_angle_data_initialises_phi_res_dict_before_processing(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - mol = MagicMock() - mol.residues = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] - uops.extract_fragment.return_value = mol - - frame_selection = _make_frame_selection(start=0, stop=2, step=1) - - dihedrals = ["D0"] - dihedral_results = MagicMock() - processed_phi = {0: [10.0, 20.0]} - - dt._get_dihedrals = MagicMock(return_value=dihedrals) - dt._run_dihedrals = MagicMock(return_value=dihedral_results) - dt._process_dihedral_phi = MagicMock(return_value=processed_phi) - - angle_data = dt._collect_dihedral_angle_data( - data_container=MagicMock(), - molecules=[0], - level_list=["residue"], - frame_selection=frame_selection, - ) - - assert angle_data.num_residues == 4 - assert angle_data.phi_res == processed_phi - assert angle_data.num_dihedrals_res == 1 - - dt._process_dihedral_phi.assert_called_once_with( - dihedral_results=dihedral_results, - num_dihedrals=1, - number_frames=2, - phi_values={}, - ) - - -def test_identify_peaks_initialises_phi_res_dict_before_processing_residue_dihedrals(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - mol = MagicMock() - mol.residues = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] - uops.extract_fragment.return_value = mol - - frame_selection = _make_frame_selection(start=0, stop=2, step=1) - - dihedrals = ["D0"] - dihedral_results = MagicMock() - processed_phi = {0: [10.0, 20.0]} - - dt._get_dihedrals = MagicMock(return_value=dihedrals) - dt._run_dihedrals = MagicMock(return_value=dihedral_results) - dt._process_dihedral_phi = MagicMock(return_value=processed_phi) - dt._process_histogram = MagicMock(return_value=[[15.0]]) - - peaks_ua, peaks_res = dt._identify_peaks( - data_container=MagicMock(), - molecules=[0], - bin_width=30.0, - level_list=["residue"], - frame_selection=frame_selection, - ) - - assert peaks_ua == [[], [], [], []] - assert peaks_res == [[15.0]] - - dt._process_dihedral_phi.assert_called_once_with( - dihedral_results=dihedral_results, - num_dihedrals=1, - number_frames=2, - phi_values={}, - ) - - -def test_assign_states_filters_out_empty_state_strings_when_no_dihedrals(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - mol = MagicMock() - mol.residues = [MagicMock()] - mol.residues[0].atoms.indices = np.array([0, 1, 2, 3], dtype=int) - uops.extract_fragment.return_value = mol - - states_ua = {} - states_res = [] - flexible_ua = {} - flexible_res = [] - - dt._select_heavy_residue = MagicMock(return_value=mol) - dt._get_dihedrals = MagicMock(return_value=[]) - - frame_selection = _make_frame_selection(start=0, stop=3, step=1) - - dt._assign_states( - data_container=MagicMock(), - group_id=0, - molecules=[0], - level_list=["united_atom", "residue"], - peaks_ua=[], - peaks_res=[], - states_ua=states_ua, - states_res=states_res, - flexible_ua=flexible_ua, - flexible_res=flexible_res, - frame_selection=frame_selection, - ) - - assert states_ua[(0, 0)] == [] - assert flexible_ua[(0, 0)] == 0 - assert states_res[0] == [] - assert flexible_res[0] == 0 - - -def test_identify_peaks_multiple_molecules_real_histogram(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - mol0 = MagicMock() - mol0.residues = [MagicMock()] - mol0.residues[0].atoms.indices = np.array([0, 1, 2, 3], dtype=int) - - mol1 = MagicMock() - mol1.residues = [MagicMock()] - mol1.residues[0].atoms.indices = np.array([0, 1, 2, 3], dtype=int) - - uops.extract_fragment.side_effect = [mol0, mol0, mol1] - - dihedrals = ["D0"] - angles = np.array([[10.0], [20.0]], dtype=float) - - dt._select_heavy_residue = MagicMock(return_value=mol0) - dt._get_dihedrals = MagicMock(return_value=dihedrals) - - class _FakeDihedral: - def __init__(self, _dihedrals): - pass - - def run(self, *args, **kwargs): - return SimpleNamespace(results=SimpleNamespace(angles=angles)) - - frame_selection = _make_frame_selection(start=0, stop=2, step=1) - - with patch("CodeEntropy.levels.dihedrals.Dihedral", _FakeDihedral): - peaks_ua, peaks_res = dt._identify_peaks( - data_container=MagicMock(), - molecules=[0, 1], - bin_width=90.0, - level_list=["united_atom", "residue"], - frame_selection=frame_selection, - ) - - assert len(peaks_ua) == 1 - assert len(peaks_res) == 1 - - -def test_assign_states_wraps_negative_angles(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - mol = MagicMock() - mol.residues = [MagicMock()] - mol.residues[0].atoms.indices = np.array([0, 1, 2, 3], dtype=int) - uops.extract_fragment.return_value = mol - - angles = np.array([[-10.0], [10.0]], dtype=float) - peaks = [[10.0, 350.0]] - dihedrals = ["D0"] - - states_ua = {} - states_res = [] - flexible_ua = {} - flexible_res = [] - - dt._select_heavy_residue = MagicMock(return_value=mol) - dt._get_dihedrals = MagicMock(return_value=dihedrals) - - class _FakeDihedral: - def __init__(self, _dihedrals): - pass - - def run(self, *args, **kwargs): - return SimpleNamespace(results=SimpleNamespace(angles=angles)) - - frame_selection = _make_frame_selection(start=0, stop=2, step=1) - - with patch("CodeEntropy.levels.dihedrals.Dihedral", _FakeDihedral): - dt._assign_states( - data_container=MagicMock(), - group_id=0, - molecules=[0, 1], - level_list=["united_atom", "residue"], - peaks_ua=[peaks], - peaks_res=peaks, - states_ua=states_ua, - states_res=states_res, - flexible_ua=flexible_ua, - flexible_res=flexible_res, - frame_selection=frame_selection, - ) - - assert states_ua[(0, 0)] == ["1", "0", "1", "0"] - assert flexible_ua[(0, 0)] == 1 - assert states_res[0] == ["1", "0", "1", "0"] - assert flexible_res[0] == 1 - - -def test_build_conformational_states_with_progress_handles_no_groups(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - progress = MagicMock() - progress.add_task.return_value = 123 - - frame_selection = _make_frame_selection(start=0, stop=1, step=1) - - states_ua, states_res, flex_ua, flex_res = dt.build_conformational_states( - data_container=MagicMock(), - levels={}, - groups={}, - bin_width=30.0, - frame_selection=frame_selection, - progress=progress, - ) - - assert states_ua == {} - assert states_res == [] - assert flex_ua == {} - assert flex_res == [] - - progress.add_task.assert_called_once() - progress.update.assert_called_once_with(123, title="No groups") - progress.advance.assert_called_once_with(123) - - -def test_build_conformational_states_with_progress_skips_empty_molecule_group(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - progress = MagicMock() - progress.add_task.return_value = 5 - - frame_selection = _make_frame_selection(start=0, stop=1, step=1) - - states_ua, states_res, flex_ua, flex_res = dt.build_conformational_states( - data_container=MagicMock(), - levels={}, - groups={0: []}, - bin_width=30.0, - frame_selection=frame_selection, - progress=progress, - ) - - assert states_ua == {} - assert states_res == [[]] - assert flex_ua == {} - assert flex_res == [] - - progress.update.assert_called_with(5, title="Group 0 (empty)") - progress.advance.assert_called_with(5) - - -def test_build_conformational_states_with_progress_updates_title_per_group(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - progress = MagicMock() - progress.add_task.return_value = 9 - - groups = {1: [7]} - levels = {7: ["residue"]} - - dt._identify_peaks = MagicMock(return_value=([], [])) - dt._assign_states = MagicMock() - - frame_selection = _make_frame_selection(start=0, stop=1, step=1) - - dt.build_conformational_states( - data_container=MagicMock(), - levels=levels, - groups=groups, - bin_width=30.0, - frame_selection=frame_selection, - progress=progress, - ) - - progress.update.assert_any_call(9, title="Group 1") - progress.advance.assert_called_with(9) - assert dt._identify_peaks.call_args.kwargs["frame_selection"] is frame_selection - assert dt._assign_states.call_args.kwargs["frame_selection"] is frame_selection - - -def test_process_dihedral_phi(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - dihedral_results = MagicMock() - dihedral_results.results.angles = [[0, 1, 2], [3, 4, 5]] - num_dihedrals = 3 - number_frames = 2 - phi_values = {} - - phi_values = dt._process_dihedral_phi( - dihedral_results, num_dihedrals, number_frames, phi_values - ) - - assert len(phi_values) == 3 - assert phi_values[0] == [0, 3] - - -def test_process_dihedral_phi_negative(): - uops = MagicMock() - dt = ConformationStateBuilder(universe_operations=uops) - - dihedral_results = MagicMock() - dihedral_results.results.angles = [[0, 1, 2], [-3, 4, 5]] - num_dihedrals = 3 - number_frames = 2 - phi_values = {} - - phi_values = dt._process_dihedral_phi( - dihedral_results, num_dihedrals, number_frames, phi_values - ) - - assert len(phi_values) == 3 - assert phi_values[0] == [0, 357] - - -def test_run_dihedrals_raises_when_no_dihedrals(): - dt = ConformationStateBuilder(universe_operations=MagicMock()) - frame_selection = _make_frame_selection(start=0, stop=2, step=1) - - with pytest.raises( - ValueError, match="Cannot run Dihedral analysis with no dihedrals" - ): - dt._run_dihedrals( - dihedrals=[], - frame_selection=frame_selection, - ) - - -def test_analysis_run_bounds_raises_when_frame_selection_empty(): - frame_selection = FrameSelection(indices=()) - - with pytest.raises(ValueError, match="Frame selection is empty"): - ConformationStateBuilder._analysis_run_bounds(frame_selection) From 331270e614b909b4de10aeaed7cbbd92079b3f51 Mon Sep 17 00:00:00 2001 From: harryswift01 Date: Thu, 18 Jun 2026 09:53:59 +0100 Subject: [PATCH 3/3] fix(dihedrals): preserve residue conformational state indexing --- CodeEntropy/levels/dihedrals/conformational_state_builder.py | 5 ++--- .../levels/dihedrals/test_conformational_state_builder.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/CodeEntropy/levels/dihedrals/conformational_state_builder.py b/CodeEntropy/levels/dihedrals/conformational_state_builder.py index ee5268b7..a944201e 100644 --- a/CodeEntropy/levels/dihedrals/conformational_state_builder.py +++ b/CodeEntropy/levels/dihedrals/conformational_state_builder.py @@ -108,8 +108,9 @@ def _build_conformational_states_serial_chunked( if chunk_size < 1: raise ValueError("chunk_size must be >= 1") + number_groups = len(groups) states_ua: dict[UAKey, list[str]] = {} - states_res: list[list[str]] = [] + states_res: list[list[str]] = [[] for _ in range(number_groups)] flexible_ua: dict[UAKey, int] = {} flexible_res: list[int] = [] @@ -130,8 +131,6 @@ def _build_conformational_states_serial_chunked( for group_id, molecules in groups.items(): if not molecules: - states_res.append([]) - if progress is not None and task is not None: progress.update(task, title=f"Group {group_id} (empty)") progress.advance(task) diff --git a/tests/unit/CodeEntropy/levels/dihedrals/test_conformational_state_builder.py b/tests/unit/CodeEntropy/levels/dihedrals/test_conformational_state_builder.py index 390049d9..1ac72cc5 100644 --- a/tests/unit/CodeEntropy/levels/dihedrals/test_conformational_state_builder.py +++ b/tests/unit/CodeEntropy/levels/dihedrals/test_conformational_state_builder.py @@ -179,7 +179,7 @@ def test_chunked_serial_group_flow_calls_domain_phases_in_order(): ) assert states_ua == {(0, 0): ["0"]} - assert states_res == [[]] + assert states_res == [[], []] assert flexible_ua == {(0, 0): 0} assert flexible_res == [0] builder._discover_group_dihedral_topology.assert_called_once_with(