From e09195e9c22eb4f94934fc4db452eb290d4d5306 Mon Sep 17 00:00:00 2001 From: sagarchapara Date: Wed, 24 Jun 2026 07:33:18 +0000 Subject: [PATCH] feat: configure chunked Ulysses all-to-all Add a ulysses_attention_chunks attention config to split the Ulysses all-to-all into head-group passes. The chunked path lets XLA overlap all-to-all collectives with head-parallel local attention compute while preserving the existing single-shot path by default. Apply the same chunking to plain Ulysses and Ulysses+Ring, and allow the final chunk to carry the remainder when the requested chunk count does not divide the Ulysses head groups evenly. Add mocked attention tests for numerical and layout equivalence across chunk counts. --- src/maxdiffusion/configs/base_wan_27b.yml | 2 + src/maxdiffusion/models/attention_flax.py | 104 ++++++++++++- .../wan/transformers/transformer_wan.py | 2 + .../transformers/transformer_wan_animate.py | 14 +- .../wan/transformers/transformer_wan_vace.py | 2 + .../pipelines/wan/wan_pipeline.py | 1 + .../pipelines/wan/wan_pipeline_animate.py | 6 + .../pipelines/wan/wan_vace_pipeline_2_1.py | 1 + src/maxdiffusion/tests/attention_test.py | 140 ++++++++++++++++++ 9 files changed, 265 insertions(+), 7 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 6f1c5035f..f081ba95f 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -94,6 +94,8 @@ use_base2_exp: True use_experimental_scheduler: True # For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. ulysses_shards: -1 +# Splits Ulysses all-to-all into head-group chunks. The last chunk carries any remainder. +ulysses_attention_chunks: 1 flash_min_seq_length: 4096 dropout: 0.0 diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index f6edc8309..9a88470c0 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -413,6 +413,44 @@ def _build_padding_segment_ids( return segment_ids_cls(q=q_segment_ids, kv=kv_segment_ids) +def _ulysses_head_chunk_ranges(num_heads: int, ulysses_shards: int, num_chunks: int): + """Build head-axis ranges for chunked Ulysses all-to-all. + + The Ulysses all-to-all splits each local chunk's head axis over + `ulysses_shards`, so every returned range length is a multiple of + `ulysses_shards`. When `num_chunks` does not evenly divide the number of + Ulysses-sized head groups, earlier chunks get the floor-sized range and the + final chunk carries the remainder. + + Returns: + A list of `(start, end)` half-open ranges over the head axis. Concatenating + tensors sliced with these ranges along the head axis restores the original + head layout. For `num_chunks <= 1`, returns `[(0, num_heads)]`, which is the + unchunked all-to-all path. + """ + if num_chunks <= 1: + return [(0, num_heads)] + if num_heads % ulysses_shards != 0: + raise ValueError( + "Ulysses attention requires the number of heads to be divisible by the Ulysses shard count, " + f"got heads={num_heads} and ulysses_shards={ulysses_shards}." + ) + + head_groups = num_heads // ulysses_shards + num_chunks = min(num_chunks, head_groups) + regular_groups_per_chunk = max(1, head_groups // num_chunks) + + ranges = [] + start_group = 0 + for chunk_idx in range(num_chunks): + end_group = head_groups if chunk_idx == num_chunks - 1 else min(start_group + regular_groups_per_chunk, head_groups) + if start_group >= end_group: + break + ranges.append((start_group * ulysses_shards, end_group * ulysses_shards)) + start_group = end_group + return ranges + + def _tpu_flash_attention( query: jax.Array, key: jax.Array, @@ -640,6 +678,7 @@ def _ulysses_attention( use_custom_kernel: bool = False, use_base2_exp: bool = True, use_experimental_scheduler: bool = False, + ulysses_attention_chunks: int = 1, ) -> jax.Array: """Ulysses sequence-parallel attention. @@ -662,6 +701,7 @@ def _ulysses_attention( "Ulysses attention requires the number of heads to be divisible by the context shard count, " f"got heads={num_heads} and context_shards={num_shards}." ) + if not use_custom_kernel: block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash") @@ -762,7 +802,21 @@ def wrap_ulysses_attention(query, key, value): "Warning, batch dimension should be shardable among the devices in data and fsdp" f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" ) - x = wrap_ulysses_attention(query, key, value) + head_chunk_ranges = _ulysses_head_chunk_ranges(num_heads, num_shards, ulysses_attention_chunks) + if len(head_chunk_ranges) > 1: + # Run Ulysses all-to-all per head group so XLA can overlap one group's + # collective with another group's head-parallel local attention compute. + chunk_outputs = [ + wrap_ulysses_attention( + query[:, start:end], + key[:, start:end], + value[:, start:end], + ) + for start, end in head_chunk_ranges + ] + x = jnp.concatenate(chunk_outputs, axis=1) + else: + x = wrap_ulysses_attention(query, key, value) x = x[:, :, :orig_q_seq_len, :] x = _reshape_heads_to_head_dim(x) @@ -787,6 +841,7 @@ def _ulysses_ring_attention( use_base2_exp: bool = False, use_experimental_scheduler: bool = False, ulysses_shards: int = -1, + ulysses_attention_chunks: int = 1, ) -> jax.Array: """2D context-parallel attention using a private Ulysses x ring mesh. @@ -916,7 +971,21 @@ def wrap_ulysses_ring_attention(query, key, value): "Warning, batch dimension should be shardable among the devices in data and fsdp" f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" ) - x = wrap_ulysses_ring_attention(query, key, value) + head_chunk_ranges = _ulysses_head_chunk_ranges(num_heads, num_ulysses_shards, ulysses_attention_chunks) + if len(head_chunk_ranges) > 1: + # Run the Ulysses phase per head group before the ring phase so XLA can + # overlap all-to-all collectives with head-parallel local attention compute. + chunk_outputs = [ + wrap_ulysses_ring_attention( + query[:, start:end], + key[:, start:end], + value[:, start:end], + ) + for start, end in head_chunk_ranges + ] + x = jnp.concatenate(chunk_outputs, axis=1) + else: + x = wrap_ulysses_ring_attention(query, key, value) x = jax.lax.with_sharding_constraint(x, q_axis_names) x = x[:, :, :orig_q_seq_len, :] x = _reshape_heads_to_head_dim(x) @@ -941,6 +1010,7 @@ def _ulysses_ring_custom_attention( use_base2_exp: bool = True, use_experimental_scheduler: bool = False, bidirectional: bool = False, + ulysses_attention_chunks: int = 1, ) -> jax.Array: """Hybrid Ulysses + Ring (USP) with the CUSTOM splash kernel on main's mesh. @@ -1057,7 +1127,21 @@ def wrap_ulysses_ring_attention(query, key, value): attention_output = a2a(attention_output, split_axis=2, concat_axis=1) return attention_output - x = wrap_ulysses_ring_attention(query, key, value) + head_chunk_ranges = _ulysses_head_chunk_ranges(num_heads, num_ulysses_shards, ulysses_attention_chunks) + if len(head_chunk_ranges) > 1: + # Run the custom Ulysses phase per head group before the custom ring phase so + # XLA can overlap all-to-all collectives with head-parallel attention compute. + chunk_outputs = [ + wrap_ulysses_ring_attention( + query[:, start:end], + key[:, start:end], + value[:, start:end], + ) + for start, end in head_chunk_ranges + ] + x = jnp.concatenate(chunk_outputs, axis=1) + else: + x = wrap_ulysses_ring_attention(query, key, value) x = jax.lax.with_sharding_constraint(x, q_axis_names) x = x[:, :, :orig_q_seq_len, :] x = _reshape_heads_to_head_dim(x) @@ -1207,6 +1291,7 @@ def ulysses_custom_kernel(q, k, v, context): use_custom_kernel=True, use_base2_exp=context.get("use_base2_exp", True), use_experimental_scheduler=context.get("use_experimental_scheduler", False), + ulysses_attention_chunks=context["ulysses_attention_chunks"], ) @@ -1228,6 +1313,7 @@ def ulysses_ring_custom_kernel(q, k, v, context): ulysses_shards=context["ulysses_shards"], use_base2_exp=context.get("use_base2_exp", True), use_experimental_scheduler=context.get("use_experimental_scheduler", False), + ulysses_attention_chunks=context["ulysses_attention_chunks"], ) @@ -1253,6 +1339,7 @@ def ulysses_ring_custom_bidir_kernel(q, k, v, context): use_base2_exp=context.get("use_base2_exp", True), use_experimental_scheduler=context.get("use_experimental_scheduler", False), bidirectional=True, + ulysses_attention_chunks=context["ulysses_attention_chunks"], ) @@ -1271,6 +1358,7 @@ def ulysses_kernel(q, k, v, context): mask_padding_tokens=context["mask_padding_tokens"], residual_checkpoint_name=context["residual_checkpoint_name"], attention_mask=context["attention_mask"], + ulysses_attention_chunks=context["ulysses_attention_chunks"], ) @@ -1292,6 +1380,7 @@ def ulysses_ring_kernel(q, k, v, context): use_base2_exp=context["use_base2_exp"], use_experimental_scheduler=context["use_experimental_scheduler"], ulysses_shards=context["ulysses_shards"], + ulysses_attention_chunks=context["ulysses_attention_chunks"], ) @@ -1404,6 +1493,7 @@ def _apply_attention( use_base2_exp: bool = False, use_experimental_scheduler: bool = False, ulysses_shards: int = -1, + ulysses_attention_chunks: int = 1, ): """Routes to different attention kernels using a module-level registry.""" @@ -1435,6 +1525,7 @@ def _apply_attention( "use_base2_exp": use_base2_exp, "use_experimental_scheduler": use_experimental_scheduler, "ulysses_shards": ulysses_shards, + "ulysses_attention_chunks": ulysses_attention_chunks, "dim_head": dim_head, "split_head_dim": split_head_dim, "float32_qk_product": float32_qk_product, @@ -1648,11 +1739,13 @@ def __init__( use_base2_exp: bool = False, use_experimental_scheduler: bool = False, ulysses_shards: int = -1, + ulysses_attention_chunks: int = 1, ): self.dpa_layer = None self.use_base2_exp = use_base2_exp self.use_experimental_scheduler = use_experimental_scheduler self.ulysses_shards = ulysses_shards + self.ulysses_attention_chunks = ulysses_attention_chunks if attention_kernel == "cudnn_flash_te": from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error @@ -1716,6 +1809,7 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask use_base2_exp=self.use_base2_exp if hasattr(self, "use_base2_exp") else False, use_experimental_scheduler=self.use_experimental_scheduler if hasattr(self, "use_experimental_scheduler") else False, ulysses_shards=(self.ulysses_shards if hasattr(self, "ulysses_shards") else -1), + ulysses_attention_chunks=(self.ulysses_attention_chunks if hasattr(self, "ulysses_attention_chunks") else 1), ) @@ -1737,6 +1831,7 @@ class AttentionOp(nn.Module): use_base2_exp: bool = False use_experimental_scheduler: bool = False ulysses_shards: int = -1 + ulysses_attention_chunks: int = 1 def setup(self): self.dpa_layer = None @@ -1785,6 +1880,7 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask use_base2_exp=self.use_base2_exp, use_experimental_scheduler=self.use_experimental_scheduler, ulysses_shards=self.ulysses_shards, + ulysses_attention_chunks=self.ulysses_attention_chunks, ) @@ -1827,6 +1923,7 @@ def __init__( "use_base2_exp": False, "use_experimental_scheduler": False, "ulysses_shards": -1, + "ulysses_attention_chunks": 1, **(attention_config or {}), } @@ -1878,6 +1975,7 @@ def __init__( use_base2_exp=attention_config["use_base2_exp"], use_experimental_scheduler=attention_config["use_experimental_scheduler"], ulysses_shards=attention_config["ulysses_shards"], + ulysses_attention_chunks=attention_config["ulysses_attention_chunks"], ) # None axes corresponds to the stacked weights across all blocks # because of the use of nnx.vmap and nnx.scan. diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 40c6be3f7..4cdfd0ca1 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -360,6 +360,7 @@ def __init__( "use_base2_exp": False, "use_experimental_scheduler": False, "ulysses_shards": -1, + "ulysses_attention_chunks": 1, **(attention_config or {}), } @@ -584,6 +585,7 @@ def __init__( "use_base2_exp": False, "use_experimental_scheduler": False, "ulysses_shards": -1, + "ulysses_attention_chunks": 1, **(attention_config or {}), } diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py b/src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py index 400b967b8..91efde148 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py @@ -788,6 +788,7 @@ def __init__( enable_jax_named_scopes: bool = False, use_base2_exp: bool = False, use_experimental_scheduler: bool = False, + attention_config: Optional[dict] = None, face_flash_min_seq_length: int = 0, motion_encoder_channel_sizes: Optional[Dict[str, int]] = None, motion_encoder_size: int = 512, @@ -812,6 +813,13 @@ def __init__( self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) self.names_which_can_be_saved = names_which_can_be_saved or [] self.names_which_can_be_offloaded = names_which_can_be_offloaded or [] + attention_config = { + "use_base2_exp": use_base2_exp, + "use_experimental_scheduler": use_experimental_scheduler, + "ulysses_shards": -1, + "ulysses_attention_chunks": 1, + **(attention_config or {}), + } self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) @@ -903,8 +911,7 @@ def init_block(rngs): dropout=dropout, mask_padding_tokens=mask_padding_tokens, enable_jax_named_scopes=enable_jax_named_scopes, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) if scan_layers: @@ -932,8 +939,7 @@ def init_block(rngs): dropout=dropout, mask_padding_tokens=mask_padding_tokens, enable_jax_named_scopes=enable_jax_named_scopes, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) blocks.append(block) self.blocks = nnx.List(blocks) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py index e9acbdb48..ca052282f 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py @@ -100,6 +100,7 @@ def __init__( "use_base2_exp": False, "use_experimental_scheduler": False, "ulysses_shards": -1, + "ulysses_attention_chunks": 1, **(attention_config or {}), } @@ -348,6 +349,7 @@ def __init__( "use_base2_exp": False, "use_experimental_scheduler": False, "ulysses_shards": -1, + "ulysses_attention_chunks": 1, **(attention_config or {}), } diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 8b0493ed3..b099af122 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -174,6 +174,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): "use_base2_exp": config.use_base2_exp, "use_experimental_scheduler": config.use_experimental_scheduler, "ulysses_shards": getattr(config, "ulysses_shards", -1), + "ulysses_attention_chunks": getattr(config, "ulysses_attention_chunks", 1), } # 2. eval_shape - will not use flops or create weights on device diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py index ffd87ddfd..7531460e1 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py @@ -92,6 +92,12 @@ def _create_model(rngs: nnx.Rngs, wan_config: dict): wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes wan_config["use_base2_exp"] = config.use_base2_exp wan_config["use_experimental_scheduler"] = config.use_experimental_scheduler + wan_config["attention_config"] = { + "use_base2_exp": config.use_base2_exp, + "use_experimental_scheduler": config.use_experimental_scheduler, + "ulysses_shards": getattr(config, "ulysses_shards", -1), + "ulysses_attention_chunks": getattr(config, "ulysses_attention_chunks", 1), + } # 2. eval_shape – creates the model structure without allocating HBM. p_model_factory = partial(_create_model, wan_config=wan_config) diff --git a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py index dcdf9396d..87e4d7a1e 100644 --- a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py @@ -90,6 +90,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): "use_base2_exp": config.use_base2_exp, "use_experimental_scheduler": config.use_experimental_scheduler, "ulysses_shards": getattr(config, "ulysses_shards", -1), + "ulysses_attention_chunks": getattr(config, "ulysses_attention_chunks", 1), } wan_config["scan_layers"] = False diff --git a/src/maxdiffusion/tests/attention_test.py b/src/maxdiffusion/tests/attention_test.py index 708af4066..4421b2006 100644 --- a/src/maxdiffusion/tests/attention_test.py +++ b/src/maxdiffusion/tests/attention_test.py @@ -235,6 +235,26 @@ def test_select_flash_block_sizes_derives_cross_attn_defaults_for_tokamax(self): self.assertIsNone(cross_attention_block_sizes.block_kv_dq) self.assertTrue(cross_attention_block_sizes.use_fused_bwd_kernel) + def test_ulysses_head_chunk_ranges_preserve_head_layout_with_remainder(self): + ranges = attention_flax._ulysses_head_chunk_ranges(num_heads=40, ulysses_shards=8, num_chunks=2) + + self.assertEqual(ranges, [(0, 16), (16, 40)]) + self.assertEqual( + attention_flax._ulysses_head_chunk_ranges(num_heads=40, ulysses_shards=8, num_chunks=5), + [(0, 8), (8, 16), (16, 24), (24, 32), (32, 40)], + ) + self.assertEqual( + attention_flax._ulysses_head_chunk_ranges(num_heads=40, ulysses_shards=8, num_chunks=3), [(0, 8), (8, 16), (16, 40)] + ) + self.assertEqual(attention_flax._ulysses_head_chunk_ranges(num_heads=40, ulysses_shards=8, num_chunks=1), [(0, 40)]) + + head_major = jnp.arange(40 * 3, dtype=jnp.float32).reshape(40, 3) + reconstructed = jnp.concatenate((head_major[0:16], head_major[16:40]), axis=0) + self.assertTrue(jnp.array_equal(reconstructed, head_major)) + + ranges_array = jnp.array(ranges) + self.assertTrue(jnp.all((ranges_array[:, 1] - ranges_array[:, 0]) % 8 == 0)) + def test_ulysses_attention_round_trips_query_when_heads_are_divisible(self): """Ulysses attention should preserve the query layout after its collectives.""" batch = 2 @@ -287,6 +307,65 @@ def fake_kernel(q, k, v, segment_ids): self.assertEqual(output.shape, query.shape) self.assertTrue(jnp.array_equal(output, query)) + def test_ulysses_attention_chunk_counts_are_numerically_equivalent(self): + """Chunked all-to-all should preserve the same head/sequence layout as one-shot all-to-all.""" + batch = 2 + length = 6 + heads = 8 + head_depth = 3 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_mesh() + + def fake_make_splash_mha(**unused_kwargs): + def fake_kernel(q, k, v, segment_ids): + del k, segment_ids + return q + jnp.mean(v, axis=1, keepdims=True) + + return fake_kernel + + def run_with_chunks(num_chunks): + with ( + mesh, + nn_partitioning.axis_rules(self._ulysses_axis_rules()), + mock.patch.object( + attention_flax.splash_attention_kernel, + "make_splash_mha", + side_effect=fake_make_splash_mha, + ), + ): + return attention_flax._ulysses_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ulysses_attention_chunks=num_chunks, + ) + + one_chunk = run_with_chunks(1) + two_chunks = run_with_chunks(2) + three_chunks_with_remainder = run_with_chunks(3) + + self.assertEqual(one_chunk.shape, query.shape) + self.assertTrue(jnp.array_equal(one_chunk, two_chunks)) + self.assertTrue(jnp.array_equal(one_chunk, three_chunks_with_remainder)) + def test_ulysses_attention_raises_when_heads_are_not_divisible_by_context_shards(self): """Ulysses attention should fail fast when heads cannot be evenly sharded.""" batch = 2 @@ -508,6 +587,67 @@ def fake_kernel(q, k, v, segment_ids): self.assertEqual(output.shape, query.shape) self.assertTrue(jnp.array_equal(output, query)) + @unittest.skipIf(len(jax.devices()) < 4, "Ulysses ring chunk equivalence test requires at least 4 devices.") + def test_ulysses_ring_attention_chunk_counts_are_numerically_equivalent(self): + """Chunked Ulysses+ring all-to-all should match the one-shot layout and numerics.""" + batch = 2 + length = 8 + heads = 8 + head_depth = 3 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_ring_mesh() + + def fake_make_ring_attention(**unused_kwargs): + def fake_kernel(q, k, v, segment_ids): + del k, segment_ids + return q + jnp.mean(v, axis=1, keepdims=True) + + return fake_kernel + + def run_with_chunks(num_chunks): + with ( + mesh, + nn_partitioning.axis_rules(self._ulysses_ring_axis_rules()), + mock.patch.object( + attention_flax.tokamax_ring_attention_kernel, + "make_ring_attention", + side_effect=fake_make_ring_attention, + ), + ): + return attention_flax._ulysses_ring_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ulysses_shards=2, + ulysses_attention_chunks=num_chunks, + ) + + one_chunk = run_with_chunks(1) + two_chunks = run_with_chunks(2) + three_chunks_with_remainder = run_with_chunks(3) + + self.assertEqual(one_chunk.shape, query.shape) + self.assertTrue(jnp.array_equal(one_chunk, two_chunks)) + self.assertTrue(jnp.array_equal(one_chunk, three_chunks_with_remainder)) + @unittest.skipIf(len(jax.devices()) < 4, "Ulysses ring attention mask test requires at least 4 devices.") def test_ulysses_ring_attention_masks_global_kv_padding(self): """Hybrid Ulysses+ring masks padding via segment ids, not a NumpyMask."""