Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_wan_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ use_base2_exp: True
use_experimental_scheduler: True
# For attention=ulysses_ring: the hidden Ulysses shard count; the ring shard count is the context shard count divided by this value.
ulysses_shards: -1
# Number of head-group chunks the Ulysses all-to-all is split into (1 means no chunking). The last chunk carries any remainder.
ulysses_attention_chunks: 1
flash_min_seq_length: 4096
dropout: 0.0

Expand Down
104 changes: 101 additions & 3 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,44 @@ def _build_padding_segment_ids(
return segment_ids_cls(q=q_segment_ids, kv=kv_segment_ids)


def _ulysses_head_chunk_ranges(num_heads: int, ulysses_shards: int, num_chunks: int):
  """Build head-axis ranges for chunked Ulysses all-to-all.

  The Ulysses all-to-all splits each local chunk's head axis over
  `ulysses_shards`, so every returned range length is a multiple of
  `ulysses_shards`. Heads are grouped into `num_heads // ulysses_shards`
  Ulysses-sized groups, and `num_chunks` is clamped to that group count. When
  `num_chunks` does not evenly divide the number of groups, earlier chunks get
  the floor-sized range and the final chunk carries the remainder.

  Args:
    num_heads: Size of the head axis being partitioned.
    ulysses_shards: Number of shards the head axis of each chunk is split over.
      Must be positive whenever chunking is requested (`num_chunks > 1`).
    num_chunks: Requested number of chunks. Values `<= 1` select the unchunked
      path and skip all validation.

  Returns:
    A non-empty list of `(start, end)` half-open ranges over the head axis.
    Concatenating tensors sliced with these ranges along the head axis restores
    the original head layout. Returns `[(0, num_heads)]` whenever a single
    chunk is requested or is all that the head-group count allows.

  Raises:
    ValueError: If chunking is requested and `ulysses_shards` is not positive,
      or `num_heads` is not divisible by `ulysses_shards`.
  """
  if num_chunks <= 1:
    return [(0, num_heads)]
  # Reject non-positive shard counts explicitly: zero would otherwise surface as
  # a ZeroDivisionError and a negative value would silently yield no ranges.
  if ulysses_shards <= 0:
    raise ValueError(
        "Ulysses attention requires a positive Ulysses shard count, "
        f"got ulysses_shards={ulysses_shards}."
    )
  if num_heads % ulysses_shards != 0:
    raise ValueError(
        "Ulysses attention requires the number of heads to be divisible by the Ulysses shard count, "
        f"got heads={num_heads} and ulysses_shards={ulysses_shards}."
    )

  head_groups = num_heads // ulysses_shards
  # Never create more chunks than there are Ulysses-sized head groups.
  num_chunks = min(num_chunks, head_groups)
  if num_chunks <= 1:
    # Zero or one head group: nothing to split, use the unchunked range.
    return [(0, num_heads)]

  # After clamping, every chunk holds at least one group. The first
  # `num_chunks - 1` chunks take the floor-sized share; the last boundary is
  # pinned to `head_groups` so the final chunk absorbs the remainder.
  groups_per_chunk = head_groups // num_chunks
  boundaries = [chunk_idx * groups_per_chunk for chunk_idx in range(num_chunks)]
  boundaries.append(head_groups)
  return [
      (start_group * ulysses_shards, end_group * ulysses_shards)
      for start_group, end_group in zip(boundaries, boundaries[1:])
  ]


def _tpu_flash_attention(
query: jax.Array,
key: jax.Array,
Expand Down Expand Up @@ -640,6 +678,7 @@ def _ulysses_attention(
use_custom_kernel: bool = False,
use_base2_exp: bool = True,
use_experimental_scheduler: bool = False,
ulysses_attention_chunks: int = 1,
) -> jax.Array:
"""Ulysses sequence-parallel attention.

Expand All @@ -662,6 +701,7 @@ def _ulysses_attention(
"Ulysses attention requires the number of heads to be divisible by the context shard count, "
f"got heads={num_heads} and context_shards={num_shards}."
)

if not use_custom_kernel:
block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")

Expand Down Expand Up @@ -762,7 +802,21 @@ def wrap_ulysses_attention(query, key, value):
"Warning, batch dimension should be shardable among the devices in data and fsdp"
f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
)
x = wrap_ulysses_attention(query, key, value)
head_chunk_ranges = _ulysses_head_chunk_ranges(num_heads, num_shards, ulysses_attention_chunks)
if len(head_chunk_ranges) > 1:
# Run Ulysses all-to-all per head group so XLA can overlap one group's
# collective with another group's head-parallel local attention compute.
chunk_outputs = [
wrap_ulysses_attention(
query[:, start:end],
key[:, start:end],
value[:, start:end],
)
for start, end in head_chunk_ranges
]
x = jnp.concatenate(chunk_outputs, axis=1)
else:
x = wrap_ulysses_attention(query, key, value)
x = x[:, :, :orig_q_seq_len, :]
x = _reshape_heads_to_head_dim(x)

Expand All @@ -787,6 +841,7 @@ def _ulysses_ring_attention(
use_base2_exp: bool = False,
use_experimental_scheduler: bool = False,
ulysses_shards: int = -1,
ulysses_attention_chunks: int = 1,
) -> jax.Array:
"""2D context-parallel attention using a private Ulysses x ring mesh.

Expand Down Expand Up @@ -916,7 +971,21 @@ def wrap_ulysses_ring_attention(query, key, value):
"Warning, batch dimension should be shardable among the devices in data and fsdp"
f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
)
x = wrap_ulysses_ring_attention(query, key, value)
head_chunk_ranges = _ulysses_head_chunk_ranges(num_heads, num_ulysses_shards, ulysses_attention_chunks)
if len(head_chunk_ranges) > 1:
# Run the Ulysses phase per head group before the ring phase so XLA can
# overlap all-to-all collectives with head-parallel local attention compute.
chunk_outputs = [
wrap_ulysses_ring_attention(
query[:, start:end],
key[:, start:end],
value[:, start:end],
)
for start, end in head_chunk_ranges
]
x = jnp.concatenate(chunk_outputs, axis=1)
else:
x = wrap_ulysses_ring_attention(query, key, value)
x = jax.lax.with_sharding_constraint(x, q_axis_names)
x = x[:, :, :orig_q_seq_len, :]
x = _reshape_heads_to_head_dim(x)
Expand All @@ -941,6 +1010,7 @@ def _ulysses_ring_custom_attention(
use_base2_exp: bool = True,
use_experimental_scheduler: bool = False,
bidirectional: bool = False,
ulysses_attention_chunks: int = 1,
) -> jax.Array:
"""Hybrid Ulysses + Ring (USP) with the CUSTOM splash kernel on main's mesh.

Expand Down Expand Up @@ -1057,7 +1127,21 @@ def wrap_ulysses_ring_attention(query, key, value):
attention_output = a2a(attention_output, split_axis=2, concat_axis=1)
return attention_output

x = wrap_ulysses_ring_attention(query, key, value)
head_chunk_ranges = _ulysses_head_chunk_ranges(num_heads, num_ulysses_shards, ulysses_attention_chunks)
if len(head_chunk_ranges) > 1:
# Run the custom Ulysses phase per head group before the custom ring phase so
# XLA can overlap all-to-all collectives with head-parallel attention compute.
chunk_outputs = [
wrap_ulysses_ring_attention(
query[:, start:end],
key[:, start:end],
value[:, start:end],
)
for start, end in head_chunk_ranges
]
x = jnp.concatenate(chunk_outputs, axis=1)
else:
x = wrap_ulysses_ring_attention(query, key, value)
x = jax.lax.with_sharding_constraint(x, q_axis_names)
x = x[:, :, :orig_q_seq_len, :]
x = _reshape_heads_to_head_dim(x)
Expand Down Expand Up @@ -1207,6 +1291,7 @@ def ulysses_custom_kernel(q, k, v, context):
use_custom_kernel=True,
use_base2_exp=context.get("use_base2_exp", True),
use_experimental_scheduler=context.get("use_experimental_scheduler", False),
ulysses_attention_chunks=context["ulysses_attention_chunks"],
)


Expand All @@ -1228,6 +1313,7 @@ def ulysses_ring_custom_kernel(q, k, v, context):
ulysses_shards=context["ulysses_shards"],
use_base2_exp=context.get("use_base2_exp", True),
use_experimental_scheduler=context.get("use_experimental_scheduler", False),
ulysses_attention_chunks=context["ulysses_attention_chunks"],
)


Expand All @@ -1253,6 +1339,7 @@ def ulysses_ring_custom_bidir_kernel(q, k, v, context):
use_base2_exp=context.get("use_base2_exp", True),
use_experimental_scheduler=context.get("use_experimental_scheduler", False),
bidirectional=True,
ulysses_attention_chunks=context["ulysses_attention_chunks"],
)


Expand All @@ -1271,6 +1358,7 @@ def ulysses_kernel(q, k, v, context):
mask_padding_tokens=context["mask_padding_tokens"],
residual_checkpoint_name=context["residual_checkpoint_name"],
attention_mask=context["attention_mask"],
ulysses_attention_chunks=context["ulysses_attention_chunks"],
)


Expand All @@ -1292,6 +1380,7 @@ def ulysses_ring_kernel(q, k, v, context):
use_base2_exp=context["use_base2_exp"],
use_experimental_scheduler=context["use_experimental_scheduler"],
ulysses_shards=context["ulysses_shards"],
ulysses_attention_chunks=context["ulysses_attention_chunks"],
)


Expand Down Expand Up @@ -1404,6 +1493,7 @@ def _apply_attention(
use_base2_exp: bool = False,
use_experimental_scheduler: bool = False,
ulysses_shards: int = -1,
ulysses_attention_chunks: int = 1,
):
"""Routes to different attention kernels using a module-level registry."""

Expand Down Expand Up @@ -1435,6 +1525,7 @@ def _apply_attention(
"use_base2_exp": use_base2_exp,
"use_experimental_scheduler": use_experimental_scheduler,
"ulysses_shards": ulysses_shards,
"ulysses_attention_chunks": ulysses_attention_chunks,
"dim_head": dim_head,
"split_head_dim": split_head_dim,
"float32_qk_product": float32_qk_product,
Expand Down Expand Up @@ -1648,11 +1739,13 @@ def __init__(
use_base2_exp: bool = False,
use_experimental_scheduler: bool = False,
ulysses_shards: int = -1,
ulysses_attention_chunks: int = 1,
):
self.dpa_layer = None
self.use_base2_exp = use_base2_exp
self.use_experimental_scheduler = use_experimental_scheduler
self.ulysses_shards = ulysses_shards
self.ulysses_attention_chunks = ulysses_attention_chunks
if attention_kernel == "cudnn_flash_te":
from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error

Expand Down Expand Up @@ -1716,6 +1809,7 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask
use_base2_exp=self.use_base2_exp if hasattr(self, "use_base2_exp") else False,
use_experimental_scheduler=self.use_experimental_scheduler if hasattr(self, "use_experimental_scheduler") else False,
ulysses_shards=(self.ulysses_shards if hasattr(self, "ulysses_shards") else -1),
ulysses_attention_chunks=(self.ulysses_attention_chunks if hasattr(self, "ulysses_attention_chunks") else 1),
)


Expand All @@ -1737,6 +1831,7 @@ class AttentionOp(nn.Module):
use_base2_exp: bool = False
use_experimental_scheduler: bool = False
ulysses_shards: int = -1
ulysses_attention_chunks: int = 1

def setup(self):
self.dpa_layer = None
Expand Down Expand Up @@ -1785,6 +1880,7 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask
use_base2_exp=self.use_base2_exp,
use_experimental_scheduler=self.use_experimental_scheduler,
ulysses_shards=self.ulysses_shards,
ulysses_attention_chunks=self.ulysses_attention_chunks,
)


Expand Down Expand Up @@ -1827,6 +1923,7 @@ def __init__(
"use_base2_exp": False,
"use_experimental_scheduler": False,
"ulysses_shards": -1,
"ulysses_attention_chunks": 1,
**(attention_config or {}),
}

Expand Down Expand Up @@ -1878,6 +1975,7 @@ def __init__(
use_base2_exp=attention_config["use_base2_exp"],
use_experimental_scheduler=attention_config["use_experimental_scheduler"],
ulysses_shards=attention_config["ulysses_shards"],
ulysses_attention_chunks=attention_config["ulysses_attention_chunks"],
)
# None axes corresponds to the stacked weights across all blocks
# because of the use of nnx.vmap and nnx.scan.
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/models/wan/transformers/transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def __init__(
"use_base2_exp": False,
"use_experimental_scheduler": False,
"ulysses_shards": -1,
"ulysses_attention_chunks": 1,
**(attention_config or {}),
}

Expand Down Expand Up @@ -584,6 +585,7 @@ def __init__(
"use_base2_exp": False,
"use_experimental_scheduler": False,
"ulysses_shards": -1,
"ulysses_attention_chunks": 1,
**(attention_config or {}),
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,7 @@ def __init__(
enable_jax_named_scopes: bool = False,
use_base2_exp: bool = False,
use_experimental_scheduler: bool = False,
attention_config: Optional[dict] = None,
face_flash_min_seq_length: int = 0,
motion_encoder_channel_sizes: Optional[Dict[str, int]] = None,
motion_encoder_size: int = 512,
Expand All @@ -812,6 +813,13 @@ def __init__(
self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
self.names_which_can_be_saved = names_which_can_be_saved or []
self.names_which_can_be_offloaded = names_which_can_be_offloaded or []
attention_config = {
"use_base2_exp": use_base2_exp,
"use_experimental_scheduler": use_experimental_scheduler,
"ulysses_shards": -1,
"ulysses_attention_chunks": 1,
**(attention_config or {}),
}

self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)

Expand Down Expand Up @@ -903,8 +911,7 @@ def init_block(rngs):
dropout=dropout,
mask_padding_tokens=mask_padding_tokens,
enable_jax_named_scopes=enable_jax_named_scopes,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
attention_config=attention_config,
)

if scan_layers:
Expand Down Expand Up @@ -932,8 +939,7 @@ def init_block(rngs):
dropout=dropout,
mask_padding_tokens=mask_padding_tokens,
enable_jax_named_scopes=enable_jax_named_scopes,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
attention_config=attention_config,
)
blocks.append(block)
self.blocks = nnx.List(blocks)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
"use_base2_exp": False,
"use_experimental_scheduler": False,
"ulysses_shards": -1,
"ulysses_attention_chunks": 1,
**(attention_config or {}),
}

Expand Down Expand Up @@ -348,6 +349,7 @@ def __init__(
"use_base2_exp": False,
"use_experimental_scheduler": False,
"ulysses_shards": -1,
"ulysses_attention_chunks": 1,
**(attention_config or {}),
}

Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
"use_base2_exp": config.use_base2_exp,
"use_experimental_scheduler": config.use_experimental_scheduler,
"ulysses_shards": getattr(config, "ulysses_shards", -1),
"ulysses_attention_chunks": getattr(config, "ulysses_attention_chunks", 1),
}

# 2. eval_shape - will not use flops or create weights on device
Expand Down
6 changes: 6 additions & 0 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ def _create_model(rngs: nnx.Rngs, wan_config: dict):
wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes
wan_config["use_base2_exp"] = config.use_base2_exp
wan_config["use_experimental_scheduler"] = config.use_experimental_scheduler
wan_config["attention_config"] = {
"use_base2_exp": config.use_base2_exp,
"use_experimental_scheduler": config.use_experimental_scheduler,
"ulysses_shards": getattr(config, "ulysses_shards", -1),
"ulysses_attention_chunks": getattr(config, "ulysses_attention_chunks", 1),
}

# 2. eval_shape – creates the model structure without allocating HBM.
p_model_factory = partial(_create_model, wan_config=wan_config)
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
"use_base2_exp": config.use_base2_exp,
"use_experimental_scheduler": config.use_experimental_scheduler,
"ulysses_shards": getattr(config, "ulysses_shards", -1),
"ulysses_attention_chunks": getattr(config, "ulysses_attention_chunks", 1),
}

wan_config["scan_layers"] = False
Expand Down
Loading
Loading