diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 6f1c5035f..f081ba95f 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -94,6 +94,8 @@ use_base2_exp: True use_experimental_scheduler: True # For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. ulysses_shards: -1 +# Splits Ulysses all-to-all into head-group chunks. The last chunk carries any remainder. +ulysses_attention_chunks: 1 flash_min_seq_length: 4096 dropout: 0.0 diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index f6edc8309..9a88470c0 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -413,6 +413,44 @@ def _build_padding_segment_ids( return segment_ids_cls(q=q_segment_ids, kv=kv_segment_ids) +def _ulysses_head_chunk_ranges(num_heads: int, ulysses_shards: int, num_chunks: int): + """Build head-axis ranges for chunked Ulysses all-to-all. + + The Ulysses all-to-all splits each local chunk's head axis over + `ulysses_shards`, so every returned range length is a multiple of + `ulysses_shards`. When `num_chunks` does not evenly divide the number of + Ulysses-sized head groups, earlier chunks get the floor-sized range and the + final chunk carries the remainder. + + Returns: + A list of `(start, end)` half-open ranges over the head axis. Concatenating + tensors sliced with these ranges along the head axis restores the original + head layout. For `num_chunks <= 1`, returns `[(0, num_heads)]`, which is the + unchunked all-to-all path. + """ + if num_chunks <= 1: + return [(0, num_heads)] + if num_heads % ulysses_shards != 0: + raise ValueError( + "Ulysses attention requires the number of heads to be divisible by the Ulysses shard count, " + f"got heads={num_heads} and ulysses_shards={ulysses_shards}." + ) + + head_groups = num_heads // ulysses_shards + num_chunks = min(num_chunks, head_groups) + regular_groups_per_chunk = max(1, head_groups // num_chunks) + + ranges = [] + start_group = 0 + for chunk_idx in range(num_chunks): + end_group = head_groups if chunk_idx == num_chunks - 1 else min(start_group + regular_groups_per_chunk, head_groups) + if start_group >= end_group: + break + ranges.append((start_group * ulysses_shards, end_group * ulysses_shards)) + start_group = end_group + return ranges + + def _tpu_flash_attention( query: jax.Array, key: jax.Array, @@ -640,6 +678,7 @@ def _ulysses_attention( use_custom_kernel: bool = False, use_base2_exp: bool = True, use_experimental_scheduler: bool = False, + ulysses_attention_chunks: int = 1, ) -> jax.Array: """Ulysses sequence-parallel attention. @@ -662,6 +701,7 @@ def _ulysses_attention( "Ulysses attention requires the number of heads to be divisible by the context shard count, " f"got heads={num_heads} and context_shards={num_shards}." ) + if not use_custom_kernel: block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash") @@ -762,7 +802,21 @@ def wrap_ulysses_attention(query, key, value): "Warning, batch dimension should be shardable among the devices in data and fsdp" f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" ) - x = wrap_ulysses_attention(query, key, value) + head_chunk_ranges = _ulysses_head_chunk_ranges(num_heads, num_shards, ulysses_attention_chunks) + if len(head_chunk_ranges) > 1: + # Run Ulysses all-to-all per head group so XLA can overlap one group's + # collective with another group's head-parallel local attention compute. + chunk_outputs = [ + wrap_ulysses_attention( + query[:, start:end], + key[:, start:end], + value[:, start:end], + ) + for start, end in head_chunk_ranges + ] + x = jnp.concatenate(chunk_outputs, axis=1) + else: + x = wrap_ulysses_attention(query, key, value) x = x[:, :, :orig_q_seq_len, :] x = _reshape_heads_to_head_dim(x) @@ -787,6 +841,7 @@ def _ulysses_ring_attention( use_base2_exp: bool = False, use_experimental_scheduler: bool = False, ulysses_shards: int = -1, + ulysses_attention_chunks: int = 1, ) -> jax.Array: """2D context-parallel attention using a private Ulysses x ring mesh. @@ -916,7 +971,21 @@ def wrap_ulysses_ring_attention(query, key, value): "Warning, batch dimension should be shardable among the devices in data and fsdp" f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" ) - x = wrap_ulysses_ring_attention(query, key, value) + head_chunk_ranges = _ulysses_head_chunk_ranges(num_heads, num_ulysses_shards, ulysses_attention_chunks) + if len(head_chunk_ranges) > 1: + # Run the Ulysses phase per head group before the ring phase so XLA can + # overlap all-to-all collectives with head-parallel local attention compute. + chunk_outputs = [ + wrap_ulysses_ring_attention( + query[:, start:end], + key[:, start:end], + value[:, start:end], + ) + for start, end in head_chunk_ranges + ] + x = jnp.concatenate(chunk_outputs, axis=1) + else: + x = wrap_ulysses_ring_attention(query, key, value) x = jax.lax.with_sharding_constraint(x, q_axis_names) x = x[:, :, :orig_q_seq_len, :] x = _reshape_heads_to_head_dim(x) @@ -941,6 +1010,7 @@ def _ulysses_ring_custom_attention( use_base2_exp: bool = True, use_experimental_scheduler: bool = False, bidirectional: bool = False, + ulysses_attention_chunks: int = 1, ) -> jax.Array: """Hybrid Ulysses + Ring (USP) with the CUSTOM splash kernel on main's mesh. @@ -1057,7 +1127,21 @@ def wrap_ulysses_ring_attention(query, key, value): attention_output = a2a(attention_output, split_axis=2, concat_axis=1) return attention_output - x = wrap_ulysses_ring_attention(query, key, value) + head_chunk_ranges = _ulysses_head_chunk_ranges(num_heads, num_ulysses_shards, ulysses_attention_chunks) + if len(head_chunk_ranges) > 1: + # Run the custom Ulysses phase per head group before the custom ring phase so + # XLA can overlap all-to-all collectives with head-parallel attention compute. + chunk_outputs = [ + wrap_ulysses_ring_attention( + query[:, start:end], + key[:, start:end], + value[:, start:end], + ) + for start, end in head_chunk_ranges + ] + x = jnp.concatenate(chunk_outputs, axis=1) + else: + x = wrap_ulysses_ring_attention(query, key, value) x = jax.lax.with_sharding_constraint(x, q_axis_names) x = x[:, :, :orig_q_seq_len, :] x = _reshape_heads_to_head_dim(x) @@ -1207,6 +1291,7 @@ def ulysses_custom_kernel(q, k, v, context): use_custom_kernel=True, use_base2_exp=context.get("use_base2_exp", True), use_experimental_scheduler=context.get("use_experimental_scheduler", False), + ulysses_attention_chunks=context["ulysses_attention_chunks"], ) @@ -1228,6 +1313,7 @@ def ulysses_ring_custom_kernel(q, k, v, context): ulysses_shards=context["ulysses_shards"], use_base2_exp=context.get("use_base2_exp", True), use_experimental_scheduler=context.get("use_experimental_scheduler", False), + ulysses_attention_chunks=context["ulysses_attention_chunks"], ) @@ -1253,6 +1339,7 @@ def ulysses_ring_custom_bidir_kernel(q, k, v, context): use_base2_exp=context.get("use_base2_exp", True), use_experimental_scheduler=context.get("use_experimental_scheduler", False), bidirectional=True, + ulysses_attention_chunks=context["ulysses_attention_chunks"], ) @@ -1271,6 +1358,7 @@ def ulysses_kernel(q, k, v, context): mask_padding_tokens=context["mask_padding_tokens"], residual_checkpoint_name=context["residual_checkpoint_name"], attention_mask=context["attention_mask"], + ulysses_attention_chunks=context["ulysses_attention_chunks"], ) @@ -1292,6 +1380,7 @@ def ulysses_ring_kernel(q, k, v, context): use_base2_exp=context["use_base2_exp"], use_experimental_scheduler=context["use_experimental_scheduler"], ulysses_shards=context["ulysses_shards"], + ulysses_attention_chunks=context["ulysses_attention_chunks"], ) @@ -1404,6 +1493,7 @@ def _apply_attention( use_base2_exp: bool = False, use_experimental_scheduler: bool = False, ulysses_shards: int = -1, + ulysses_attention_chunks: int = 1, ): """Routes to different attention kernels using a module-level registry.""" @@ -1435,6 +1525,7 @@ def _apply_attention( "use_base2_exp": use_base2_exp, "use_experimental_scheduler": use_experimental_scheduler, "ulysses_shards": ulysses_shards, + "ulysses_attention_chunks": ulysses_attention_chunks, "dim_head": dim_head, "split_head_dim": split_head_dim, "float32_qk_product": float32_qk_product, @@ -1648,11 +1739,13 @@ def __init__( use_base2_exp: bool = False, use_experimental_scheduler: bool = False, ulysses_shards: int = -1, + ulysses_attention_chunks: int = 1, ): self.dpa_layer = None self.use_base2_exp = use_base2_exp self.use_experimental_scheduler = use_experimental_scheduler self.ulysses_shards = ulysses_shards + self.ulysses_attention_chunks = ulysses_attention_chunks if attention_kernel == "cudnn_flash_te": from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error @@ -1716,6 +1809,7 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask use_base2_exp=self.use_base2_exp if hasattr(self, "use_base2_exp") else False, use_experimental_scheduler=self.use_experimental_scheduler if hasattr(self, "use_experimental_scheduler") else False, ulysses_shards=(self.ulysses_shards if hasattr(self, "ulysses_shards") else -1), + ulysses_attention_chunks=(self.ulysses_attention_chunks if hasattr(self, "ulysses_attention_chunks") else 1), ) @@ -1737,6 +1831,7 @@ class AttentionOp(nn.Module): use_base2_exp: bool = False use_experimental_scheduler: bool = False ulysses_shards: int = -1 + ulysses_attention_chunks: int = 1 def setup(self): self.dpa_layer = None @@ -1785,6 +1880,7 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask use_base2_exp=self.use_base2_exp, use_experimental_scheduler=self.use_experimental_scheduler, ulysses_shards=self.ulysses_shards, + ulysses_attention_chunks=self.ulysses_attention_chunks, ) @@ -1827,6 +1923,7 @@ def __init__( "use_base2_exp": False, "use_experimental_scheduler": False, "ulysses_shards": -1, + "ulysses_attention_chunks": 1, **(attention_config or {}), } @@ -1878,6 +1975,7 @@ def __init__( use_base2_exp=attention_config["use_base2_exp"], use_experimental_scheduler=attention_config["use_experimental_scheduler"], ulysses_shards=attention_config["ulysses_shards"], + ulysses_attention_chunks=attention_config["ulysses_attention_chunks"], ) # None axes corresponds to the stacked weights across all blocks # because of the use of nnx.vmap and nnx.scan. diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 40c6be3f7..4cdfd0ca1 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -360,6 +360,7 @@ def __init__( "use_base2_exp": False, "use_experimental_scheduler": False, "ulysses_shards": -1, + "ulysses_attention_chunks": 1, **(attention_config or {}), } @@ -584,6 +585,7 @@ def __init__( "use_base2_exp": False, "use_experimental_scheduler": False, "ulysses_shards": -1, + "ulysses_attention_chunks": 1, **(attention_config or {}), } diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py b/src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py index 400b967b8..91efde148 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py @@ -788,6 +788,7 @@ def __init__( enable_jax_named_scopes: bool = False, use_base2_exp: bool = False, use_experimental_scheduler: bool = False, + attention_config: Optional[dict] = None, face_flash_min_seq_length: int = 0, motion_encoder_channel_sizes: Optional[Dict[str, int]] = None, motion_encoder_size: int = 512, @@ -812,6 +813,13 @@ def __init__( self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) self.names_which_can_be_saved = names_which_can_be_saved or [] self.names_which_can_be_offloaded = names_which_can_be_offloaded or [] + attention_config = { + "use_base2_exp": use_base2_exp, + "use_experimental_scheduler": use_experimental_scheduler, + "ulysses_shards": -1, + "ulysses_attention_chunks": 1, + **(attention_config or {}), + } self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) @@ -903,8 +911,7 @@ def init_block(rngs): dropout=dropout, mask_padding_tokens=mask_padding_tokens, enable_jax_named_scopes=enable_jax_named_scopes, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) if scan_layers: @@ -932,8 +939,7 @@ def init_block(rngs): dropout=dropout, mask_padding_tokens=mask_padding_tokens, enable_jax_named_scopes=enable_jax_named_scopes, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) blocks.append(block) self.blocks = nnx.List(blocks) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py index e9acbdb48..ca052282f 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py @@ -100,6 +100,7 @@ def __init__( "use_base2_exp": False, "use_experimental_scheduler": False, "ulysses_shards": -1, + "ulysses_attention_chunks": 1, **(attention_config or {}), } @@ -348,6 +349,7 @@ def __init__( "use_base2_exp": False, "use_experimental_scheduler": False, "ulysses_shards": -1, + "ulysses_attention_chunks": 1, **(attention_config or {}), } diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 8b0493ed3..b099af122 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -174,6 +174,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): "use_base2_exp": config.use_base2_exp, "use_experimental_scheduler": config.use_experimental_scheduler, "ulysses_shards": getattr(config, "ulysses_shards", -1), + "ulysses_attention_chunks": getattr(config, "ulysses_attention_chunks", 1), } # 2. eval_shape - will not use flops or create weights on device diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py index ffd87ddfd..7531460e1 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py @@ -92,6 +92,12 @@ def _create_model(rngs: nnx.Rngs, wan_config: dict): wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes wan_config["use_base2_exp"] = config.use_base2_exp wan_config["use_experimental_scheduler"] = config.use_experimental_scheduler + wan_config["attention_config"] = { + "use_base2_exp": config.use_base2_exp, + "use_experimental_scheduler": config.use_experimental_scheduler, + "ulysses_shards": getattr(config, "ulysses_shards", -1), + "ulysses_attention_chunks": getattr(config, "ulysses_attention_chunks", 1), + } # 2. eval_shape – creates the model structure without allocating HBM. p_model_factory = partial(_create_model, wan_config=wan_config) diff --git a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py index dcdf9396d..87e4d7a1e 100644 --- a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py @@ -90,6 +90,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): "use_base2_exp": config.use_base2_exp, "use_experimental_scheduler": config.use_experimental_scheduler, "ulysses_shards": getattr(config, "ulysses_shards", -1), + "ulysses_attention_chunks": getattr(config, "ulysses_attention_chunks", 1), } wan_config["scan_layers"] = False diff --git a/src/maxdiffusion/tests/attention_test.py b/src/maxdiffusion/tests/attention_test.py index 708af4066..4421b2006 100644 --- a/src/maxdiffusion/tests/attention_test.py +++ b/src/maxdiffusion/tests/attention_test.py @@ -235,6 +235,26 @@ def test_select_flash_block_sizes_derives_cross_attn_defaults_for_tokamax(self): self.assertIsNone(cross_attention_block_sizes.block_kv_dq) self.assertTrue(cross_attention_block_sizes.use_fused_bwd_kernel) + def test_ulysses_head_chunk_ranges_preserve_head_layout_with_remainder(self): + ranges = attention_flax._ulysses_head_chunk_ranges(num_heads=40, ulysses_shards=8, num_chunks=2) + + self.assertEqual(ranges, [(0, 16), (16, 40)]) + self.assertEqual( + attention_flax._ulysses_head_chunk_ranges(num_heads=40, ulysses_shards=8, num_chunks=5), + [(0, 8), (8, 16), (16, 24), (24, 32), (32, 40)], + ) + self.assertEqual( + attention_flax._ulysses_head_chunk_ranges(num_heads=40, ulysses_shards=8, num_chunks=3), [(0, 8), (8, 16), (16, 40)] + ) + self.assertEqual(attention_flax._ulysses_head_chunk_ranges(num_heads=40, ulysses_shards=8, num_chunks=1), [(0, 40)]) + + head_major = jnp.arange(40 * 3, dtype=jnp.float32).reshape(40, 3) + reconstructed = jnp.concatenate((head_major[0:16], head_major[16:40]), axis=0) + self.assertTrue(jnp.array_equal(reconstructed, head_major)) + + ranges_array = jnp.array(ranges) + self.assertTrue(jnp.all((ranges_array[:, 1] - ranges_array[:, 0]) % 8 == 0)) + def test_ulysses_attention_round_trips_query_when_heads_are_divisible(self): """Ulysses attention should preserve the query layout after its collectives.""" batch = 2 @@ -287,6 +307,65 @@ def fake_kernel(q, k, v, segment_ids): self.assertEqual(output.shape, query.shape) self.assertTrue(jnp.array_equal(output, query)) + def test_ulysses_attention_chunk_counts_are_numerically_equivalent(self): + """Chunked all-to-all should preserve the same head/sequence layout as one-shot all-to-all.""" + batch = 2 + length = 6 + heads = 8 + head_depth = 3 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_mesh() + + def fake_make_splash_mha(**unused_kwargs): + def fake_kernel(q, k, v, segment_ids): + del k, segment_ids + return q + jnp.mean(v, axis=1, keepdims=True) + + return fake_kernel + + def run_with_chunks(num_chunks): + with ( + mesh, + nn_partitioning.axis_rules(self._ulysses_axis_rules()), + mock.patch.object( + attention_flax.splash_attention_kernel, + "make_splash_mha", + side_effect=fake_make_splash_mha, + ), + ): + return attention_flax._ulysses_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ulysses_attention_chunks=num_chunks, + ) + + one_chunk = run_with_chunks(1) + two_chunks = run_with_chunks(2) + three_chunks_with_remainder = run_with_chunks(3) + + self.assertEqual(one_chunk.shape, query.shape) + self.assertTrue(jnp.array_equal(one_chunk, two_chunks)) + self.assertTrue(jnp.array_equal(one_chunk, three_chunks_with_remainder)) + def test_ulysses_attention_raises_when_heads_are_not_divisible_by_context_shards(self): """Ulysses attention should fail fast when heads cannot be evenly sharded.""" batch = 2 @@ -508,6 +587,67 @@ def fake_kernel(q, k, v, segment_ids): self.assertEqual(output.shape, query.shape) self.assertTrue(jnp.array_equal(output, query)) + @unittest.skipIf(len(jax.devices()) < 4, "Ulysses ring chunk equivalence test requires at least 4 devices.") + def test_ulysses_ring_attention_chunk_counts_are_numerically_equivalent(self): + """Chunked Ulysses+ring all-to-all should match the one-shot layout and numerics.""" + batch = 2 + length = 8 + heads = 8 + head_depth = 3 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_ring_mesh() + + def fake_make_ring_attention(**unused_kwargs): + def fake_kernel(q, k, v, segment_ids): + del k, segment_ids + return q + jnp.mean(v, axis=1, keepdims=True) + + return fake_kernel + + def run_with_chunks(num_chunks): + with ( + mesh, + nn_partitioning.axis_rules(self._ulysses_ring_axis_rules()), + mock.patch.object( + attention_flax.tokamax_ring_attention_kernel, + "make_ring_attention", + side_effect=fake_make_ring_attention, + ), + ): + return attention_flax._ulysses_ring_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ulysses_shards=2, + ulysses_attention_chunks=num_chunks, + ) + + one_chunk = run_with_chunks(1) + two_chunks = run_with_chunks(2) + three_chunks_with_remainder = run_with_chunks(3) + + self.assertEqual(one_chunk.shape, query.shape) + self.assertTrue(jnp.array_equal(one_chunk, two_chunks)) + self.assertTrue(jnp.array_equal(one_chunk, three_chunks_with_remainder)) + @unittest.skipIf(len(jax.devices()) < 4, "Ulysses ring attention mask test requires at least 4 devices.") def test_ulysses_ring_attention_masks_global_kv_padding(self): """Hybrid Ulysses+ring masks padding via segment ids, not a NumpyMask."""