diff --git a/src/parallel/mod.rs b/src/parallel/mod.rs index 3ac0d4b04..97301da2c 100644 --- a/src/parallel/mod.rs +++ b/src/parallel/mod.rs @@ -15,6 +15,7 @@ //! - [`ArrayViewMut`] `.into_par_iter()` //! - [`AxisIter`], [`AxisIterMut`] `.into_par_iter()` //! - [`AxisChunksIter`], [`AxisChunksIterMut`] `.into_par_iter()` +//! - [`ExactChunks`], [`ExactChunksMut`] `.into_par_iter()` //! - [`Zip`] `.into_par_iter()` //! //! The following other parallelized methods exist: @@ -94,6 +95,23 @@ //! assert_eq!(shapes, [vec![3, 16], vec![1, 16]]); //! ``` //! +//! ## Exact chunks +//! +//! Use parallel `.exact_chunks()` to process only complete chunks of an array. +//! +//! ``` +//! use ndarray::Array; +//! use ndarray::parallel::prelude::*; +//! +//! let a = Array::linspace(0.0..=63.0, 64).into_shape_with_order((8, 8)).unwrap(); +//! let sum: f64 = a.exact_chunks((2, 4)) +//! .into_par_iter() +//! .map(|chunk| chunk.sum()) +//! .sum(); +//! +//! assert_eq!(sum, a.sum()); +//! ``` +//! //! ## Zip //! //! Use zip for lock step function application across several arrays @@ -118,7 +136,9 @@ //! ``` #[allow(unused_imports)] // used by rustdoc links -use crate::iter::{AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut}; +use crate::iter::{ + AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut, +}; #[allow(unused_imports)] // used by rustdoc links use crate::{ArcArray, Array, ArrayBase, ArrayView, ArrayViewMut, Zip}; diff --git a/src/parallel/par.rs b/src/parallel/par.rs index efff8e6b7..e392d1a24 100644 --- a/src/parallel/par.rs +++ b/src/parallel/par.rs @@ -13,9 +13,11 @@ use crate::iter::AxisChunksIter; use crate::iter::AxisChunksIterMut; use crate::iter::AxisIter; use crate::iter::AxisIterMut; +use crate::iter::ExactChunks; +use crate::iter::ExactChunksMut; use crate::split_at::SplitPreference; use crate::Dimension; -use crate::{ArrayView, ArrayViewMut}; +use crate::{ArrayView, ArrayViewMut, Axis}; /// Parallel iterator wrapper. #[derive(Copy, Clone, Debug)] @@ -225,6 +227,117 @@ par_iter_view_wrapper!(ArrayViewMut, [Sync + Send]); use crate::{FoldWhile, NdProducer, Zip}; +macro_rules! par_ndproducer_wrapper { + // thread_bounds are either Sync or Send + Sync + ($producer_name:ident, [$($thread_bounds:tt)*]) => { + /// Requires crate feature `rayon`. + impl<'a, A, D> IntoParallelIterator for $producer_name<'a, A, D> + where D: Dimension, + A: $($thread_bounds)*, + { + type Item = ::Item; + type Iter = Parallel; + fn into_par_iter(self) -> Self::Iter { + Parallel { + iter: self, + min_len: DEFAULT_MIN_LEN, + } + } + } + + impl<'a, A, D> ParallelIterator for Parallel<$producer_name<'a, A, D>> + where D: Dimension, + A: $($thread_bounds)*, + { + type Item = <$producer_name<'a, A, D> as NdProducer>::Item; + fn drive_unindexed(self, consumer: C) -> C::Result + where C: UnindexedConsumer + { + bridge_unindexed(ParallelProducer(self.iter, self.min_len), consumer) + } + + fn opt_len(&self) -> Option { + Some(self.iter.raw_dim().size()) + } + } + + impl<'a, A, D> Parallel<$producer_name<'a, A, D>> + where D: Dimension, + A: $($thread_bounds)*, + { + /// Sets the minimum number of chunks desired to process in each job. This will not be + /// split any smaller than this length, but of course a producer could already be smaller + /// to begin with. + /// + /// ***Panics*** if `min_len` is zero. + pub fn with_min_len(self, min_len: usize) -> Self { + assert_ne!(min_len, 0, "Minimum number of elements must at least be one to avoid splitting off empty tasks."); + + Self { + min_len, + ..self + } + } + } + + impl<'a, A, D> UnindexedProducer for ParallelProducer<$producer_name<'a, A, D>> + where D: Dimension, + A: $($thread_bounds)*, + { + type Item = <$producer_name<'a, A, D> as NdProducer>::Item; + fn split(self) -> (Self, Option) { + let dim = self.0.raw_dim(); + if dim.size() <= self.1 { + return (self, None) + } + + let Some((axis, &len)) = dim + .slice() + .iter() + .enumerate() + .max_by_key(|&(_, len)| len) + else { + return (self, None) + }; + if len <= 1 { + return (self, None) + } + + let (a, b) = self.0.split_at(Axis(axis), len / 2); + (ParallelProducer(a, self.1), Some(ParallelProducer(b, self.1))) + } + + fn fold_with(self, folder: F) -> F + where F: Folder, + { + Zip::from(self.0).fold_while(folder, |mut folder, elt| { + folder = folder.consume(elt); + if folder.full() { + FoldWhile::Done(folder) + } else { + FoldWhile::Continue(folder) + } + }).into_inner() + } + } + + impl<'a, A, D> IntoIterator for ParallelProducer<$producer_name<'a, A, D>> + where D: Dimension, + A: $($thread_bounds)*, + { + type Item = <$producer_name<'a, A, D> as IntoIterator>::Item; + type IntoIter = <$producer_name<'a, A, D> as IntoIterator>::IntoIter; + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } + } + + }; +} + +par_ndproducer_wrapper!(ExactChunks, [Sync]); +par_ndproducer_wrapper!(ExactChunksMut, [Send + Sync]); + macro_rules! zip_impl { ($([$($p:ident)*],)+) => { $( diff --git a/tests/par_rayon.rs b/tests/par_rayon.rs index 1b6b2b794..7fc7c630b 100644 --- a/tests/par_rayon.rs +++ b/tests/par_rayon.rs @@ -77,6 +77,40 @@ fn test_axis_chunks_iter() assert_eq!(s, a.sum()); } +#[test] +fn test_exact_chunks() +{ + let a = Array::from_iter(0..100) + .into_shape_with_order((10, 10)) + .unwrap(); + let s: i32 = a + .exact_chunks((2, 5)) + .into_par_iter() + .map(|chunk| chunk.sum()) + .sum(); + assert_eq!(s, a.sum()); +} + +#[test] +fn test_exact_chunks_mut() +{ + let mut a = Array2::::zeros((7, 8)); + a.exact_chunks_mut((2, 3)) + .into_par_iter() + .for_each(|mut chunk| chunk.fill(1)); + + let ans = array![ + [1, 1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ]; + assert_eq!(a, ans); +} + #[test] #[cfg(feature = "approx")] fn test_axis_chunks_iter_mut()