From 2330467506494906b81bba57eedb16f53c309d70 Mon Sep 17 00:00:00 2001
From: Choudhry Abdullah
Date: Mon, 15 Jun 2026 16:39:42 -0500
Subject: [PATCH] Add parallel iteration for exact chunks

Implement Rayon IntoParallelIterator support for ExactChunks and
ExactChunksMut by wrapping their NdProducer traversal in an unindexed
parallel producer.

Add coverage for immutable and mutable exact chunk iteration and
document the new API.

Closes #1192
---
Reviewer notes (this section is ignored when the patch is applied):

* The received copy of this patch had its line breaks collapsed and every
  angle-bracketed span that resembled a markup tag removed. The generic
  argument lists in src/parallel/par.rs and tests/par_rayon.rs were restored
  to the forms the surrounding impls require. `<Self as NdProducer>::Item`
  and `Array2::<i32>` are reconstructions; confirm against the author's
  branch.
* The author e-mail address in the From: header was lost the same way and
  could not be recovered; restore it before running `git am`.
* Leading whitespace follows the existing macros in src/parallel/par.rs. If
  a hunk fails to apply, regenerate the patch with `git format-patch`.

 src/parallel/mod.rs |  22 ++++++++-
 src/parallel/par.rs | 115 +++++++++++++++++++++++++++++++++++++++++++-
 tests/par_rayon.rs  |  34 +++++++++++++
 3 files changed, 169 insertions(+), 2 deletions(-)

diff --git a/src/parallel/mod.rs b/src/parallel/mod.rs
index 3ac0d4b04..97301da2c 100644
--- a/src/parallel/mod.rs
+++ b/src/parallel/mod.rs
@@ -15,6 +15,7 @@
 //! - [`ArrayViewMut`] `.into_par_iter()`
 //! - [`AxisIter`], [`AxisIterMut`] `.into_par_iter()`
 //! - [`AxisChunksIter`], [`AxisChunksIterMut`] `.into_par_iter()`
+//! - [`ExactChunks`], [`ExactChunksMut`] `.into_par_iter()`
 //! - [`Zip`] `.into_par_iter()`
 //!
 //! The following other parallelized methods exist:
@@ -94,6 +95,23 @@
 //! assert_eq!(shapes, [vec![3, 16], vec![1, 16]]);
 //! ```
 //!
+//! ## Exact chunks
+//!
+//! Use parallel `.exact_chunks()` to process only complete chunks of an array.
+//!
+//! ```
+//! use ndarray::Array;
+//! use ndarray::parallel::prelude::*;
+//!
+//! let a = Array::linspace(0.0..=63.0, 64).into_shape_with_order((8, 8)).unwrap();
+//! let sum: f64 = a.exact_chunks((2, 4))
+//!     .into_par_iter()
+//!     .map(|chunk| chunk.sum())
+//!     .sum();
+//!
+//! assert_eq!(sum, a.sum());
+//! ```
+//!
 //! ## Zip
 //!
 //! Use zip for lock step function application across several arrays
@@ -118,7 +136,9 @@
 //! ```
 
 #[allow(unused_imports)] // used by rustdoc links
-use crate::iter::{AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut};
+use crate::iter::{
+    AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut,
+};
 #[allow(unused_imports)] // used by rustdoc links
 use crate::{ArcArray, Array, ArrayBase, ArrayView, ArrayViewMut, Zip};
 
diff --git a/src/parallel/par.rs b/src/parallel/par.rs
index efff8e6b7..e392d1a24 100644
--- a/src/parallel/par.rs
+++ b/src/parallel/par.rs
@@ -13,9 +13,11 @@ use crate::iter::AxisChunksIter;
 use crate::iter::AxisChunksIterMut;
 use crate::iter::AxisIter;
 use crate::iter::AxisIterMut;
+use crate::iter::ExactChunks;
+use crate::iter::ExactChunksMut;
 use crate::split_at::SplitPreference;
 use crate::Dimension;
-use crate::{ArrayView, ArrayViewMut};
+use crate::{ArrayView, ArrayViewMut, Axis};
 
 /// Parallel iterator wrapper.
 #[derive(Copy, Clone, Debug)]
@@ -225,6 +227,117 @@ par_iter_view_wrapper!(ArrayViewMut, [Sync + Send]);
 
 use crate::{FoldWhile, NdProducer, Zip};
 
+macro_rules! par_ndproducer_wrapper {
+    // thread_bounds are either Sync or Send + Sync
+    ($producer_name:ident, [$($thread_bounds:tt)*]) => {
+    /// Requires crate feature `rayon`.
+    impl<'a, A, D> IntoParallelIterator for $producer_name<'a, A, D>
+        where D: Dimension,
+              A: $($thread_bounds)*,
+    {
+        type Item = <Self as NdProducer>::Item;
+        type Iter = Parallel<Self>;
+        fn into_par_iter(self) -> Self::Iter {
+            Parallel {
+                iter: self,
+                min_len: DEFAULT_MIN_LEN,
+            }
+        }
+    }
+
+    impl<'a, A, D> ParallelIterator for Parallel<$producer_name<'a, A, D>>
+        where D: Dimension,
+              A: $($thread_bounds)*,
+    {
+        type Item = <$producer_name<'a, A, D> as NdProducer>::Item;
+        fn drive_unindexed<C>(self, consumer: C) -> C::Result
+            where C: UnindexedConsumer<Self::Item>
+        {
+            bridge_unindexed(ParallelProducer(self.iter, self.min_len), consumer)
+        }
+
+        fn opt_len(&self) -> Option<usize> {
+            Some(self.iter.raw_dim().size())
+        }
+    }
+
+    impl<'a, A, D> Parallel<$producer_name<'a, A, D>>
+        where D: Dimension,
+              A: $($thread_bounds)*,
+    {
+        /// Sets the minimum number of chunks desired to process in each job. This will not be
+        /// split any smaller than this length, but of course a producer could already be smaller
+        /// to begin with.
+        ///
+        /// ***Panics*** if `min_len` is zero.
+        pub fn with_min_len(self, min_len: usize) -> Self {
+            assert_ne!(min_len, 0, "Minimum number of elements must at least be one to avoid splitting off empty tasks.");
+
+            Self {
+                min_len,
+                ..self
+            }
+        }
+    }
+
+    impl<'a, A, D> UnindexedProducer for ParallelProducer<$producer_name<'a, A, D>>
+        where D: Dimension,
+              A: $($thread_bounds)*,
+    {
+        type Item = <$producer_name<'a, A, D> as NdProducer>::Item;
+        fn split(self) -> (Self, Option<Self>) {
+            let dim = self.0.raw_dim();
+            if dim.size() <= self.1 {
+                return (self, None)
+            }
+
+            let Some((axis, &len)) = dim
+                .slice()
+                .iter()
+                .enumerate()
+                .max_by_key(|&(_, len)| len)
+            else {
+                return (self, None)
+            };
+            if len <= 1 {
+                return (self, None)
+            }
+
+            let (a, b) = self.0.split_at(Axis(axis), len / 2);
+            (ParallelProducer(a, self.1), Some(ParallelProducer(b, self.1)))
+        }
+
+        fn fold_with<F>(self, folder: F) -> F
+            where F: Folder<Self::Item>,
+        {
+            Zip::from(self.0).fold_while(folder, |mut folder, elt| {
+                folder = folder.consume(elt);
+                if folder.full() {
+                    FoldWhile::Done(folder)
+                } else {
+                    FoldWhile::Continue(folder)
+                }
+            }).into_inner()
+        }
+    }
+
+    impl<'a, A, D> IntoIterator for ParallelProducer<$producer_name<'a, A, D>>
+        where D: Dimension,
+              A: $($thread_bounds)*,
+    {
+        type Item = <$producer_name<'a, A, D> as IntoIterator>::Item;
+        type IntoIter = <$producer_name<'a, A, D> as IntoIterator>::IntoIter;
+        fn into_iter(self) -> Self::IntoIter {
+            self.0.into_iter()
+        }
+    }
+
+    };
+}
+
+par_ndproducer_wrapper!(ExactChunks, [Sync]);
+par_ndproducer_wrapper!(ExactChunksMut, [Send + Sync]);
+
 macro_rules! zip_impl {
     ($([$($p:ident)*],)+) => {
         $(
diff --git a/tests/par_rayon.rs b/tests/par_rayon.rs
index 1b6b2b794..7fc7c630b 100644
--- a/tests/par_rayon.rs
+++ b/tests/par_rayon.rs
@@ -77,6 +77,40 @@ fn test_axis_chunks_iter()
     assert_eq!(s, a.sum());
 }
 
+#[test]
+fn test_exact_chunks()
+{
+    let a = Array::from_iter(0..100)
+        .into_shape_with_order((10, 10))
+        .unwrap();
+    let s: i32 = a
+        .exact_chunks((2, 5))
+        .into_par_iter()
+        .map(|chunk| chunk.sum())
+        .sum();
+    assert_eq!(s, a.sum());
+}
+
+#[test]
+fn test_exact_chunks_mut()
+{
+    let mut a = Array2::<i32>::zeros((7, 8));
+    a.exact_chunks_mut((2, 3))
+        .into_par_iter()
+        .for_each(|mut chunk| chunk.fill(1));
+
+    let ans = array![
+        [1, 1, 1, 1, 1, 1, 0, 0],
+        [1, 1, 1, 1, 1, 1, 0, 0],
+        [1, 1, 1, 1, 1, 1, 0, 0],
+        [1, 1, 1, 1, 1, 1, 0, 0],
+        [1, 1, 1, 1, 1, 1, 0, 0],
+        [1, 1, 1, 1, 1, 1, 0, 0],
+        [0, 0, 0, 0, 0, 0, 0, 0],
+    ];
+    assert_eq!(a, ans);
+}
+
 #[test]
 #[cfg(feature = "approx")]
 fn test_axis_chunks_iter_mut()