diff --git a/src/maxdiffusion/kernels/custom_splash_attention.py b/src/maxdiffusion/kernels/custom_splash_attention.py index ab5b6ee6a..1277200cf 100644 --- a/src/maxdiffusion/kernels/custom_splash_attention.py +++ b/src/maxdiffusion/kernels/custom_splash_attention.py @@ -51,7 +51,21 @@ def __init__(self, block_q: int, block_kv: int, block_kv_compute: int | None = N self.block_kv_compute = block_kv_compute if block_kv_compute is not None else block_kv +# Fixed-m softmax-bound constants. Instead of tracking the online-softmax +# running max per KV block, eligible heads subtract a precomputed per-query +# upper bound on the logits (Cauchy-Schwarz: max_j q_i.k_j <= ||q_i|| * +# max_j||k_j||). _FIXED_M_RECENTER (C) shifts the exp2 exponents up so the +# largest surviving term stays above the f32 subnormal-flush floor 2^-126: +# with k-smoothing the per-row max is >= 0, so the max term has exponent +# >= -ceil(bound) + C, which stays > -126 while ceil(bound) <= +# _FIXED_M_SAFE_BOUND (= C + 126 - 1 of margin). Heads whose worst-case bound +# exceeds the gate fall back to online softmax (the "sink" heads). +_FIXED_M_RECENTER = 88.0 +_FIXED_M_SAFE_BOUND = 213.0 + + def _flash_attention_kernel( + mk_ref, q_ref, k_ref, v_ref, @@ -73,34 +87,45 @@ def _flash_attention_kernel( kv_seq_len: int, use_base2_exp: bool = True, fuse_reciprocal: bool = True, + use_fixed_m: bool = False, ): float32 = jnp.float32 head_dim_v_repeats, rem = divmod(head_dim_v, NUM_SUBLANES) if rem != 0: raise NotImplementedError(f"{head_dim_v=} should be a multiple of {NUM_SUBLANES}") - _, _, j = pl.program_id(0), pl.program_id(1), pl.program_id(2) + h, _, j = pl.program_id(0), pl.program_id(1), pl.program_id(2) exp = jnp.exp2 if use_base2_exp else jnp.exp + sv_dims = (((0,), (0,)), ((), ())) + + # Per-head dispatch: heads inside the no-flush window run fixed-m, the rest + # keep online softmax. Branch once per head (body level), never per step. 
+ is_fixed = (mk_ref[1, h] > 0.5) if use_fixed_m else False @pl.when(j == 0) def init(): o_scratch_ref[...] = jnp.zeros_like(o_scratch_ref) - m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value) l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref) + if use_fixed_m: - def compute_body(kv_compute_index, _): - m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] - q = q_ref[...] - o_prev = o_scratch_ref[:] + @pl.when(is_fixed) + def _init_fixed(): + # Per-query Cauchy-Schwarz bound m_i = ceil(||q_i|| * max_j||k_j||) - C. + qf = q_ref[...].astype(float32) + qn = jnp.sqrt((qf * qf).sum(axis=1))[None, :] # (1, bq) per-query norm + bound = qn * mk_ref[0, h] + m_fixed = jnp.ceil(bound) - _FIXED_M_RECENTER + m_scratch_ref[...] = jnp.broadcast_to(m_fixed, m_scratch_ref.shape) - base_offset = kv_compute_index * bkv_compute - slice_k = pl.ds(base_offset, bkv_compute) - k_chunk = k_ref[slice_k, :] + @pl.when(jnp.logical_not(is_fixed)) + def _init_online(): + m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value) - qk = lax.dot_general(k_chunk, q, NT_DIM_NUMBERS, preferred_element_type=float32) - v_chunk = v_ref[slice_k, :] + else: + m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value) - # --- V1 VPU REGISTER TILING --- + def _online_inner(qk, v_chunk, m_prev, l_prev, o_prev): + # Standard online-softmax tiling over the VPU register block. step = bkv_compute_in for i in range(0, qk.shape[0], step): qk_slice = qk[i : i + step] @@ -113,84 +138,99 @@ def compute_body(kv_compute_index, _): alpha = exp(m_prev - m_next) l_next = l_curr + alpha * l_prev - sv_dims = (((0,), (0,)), ((), ())) o_curr = lax.dot_general( v_chunk[i : i + step], s_curr.astype(q_ref.dtype), sv_dims, preferred_element_type=float32, ) - - alpha_o = alpha[0:1, ...] - o_prev = alpha_o * o_prev + o_curr - + o_prev = alpha[0:1, ...] * o_prev + o_curr m_prev, l_prev = m_next, l_next - # --- END V1 TILING --- - - m_scratch_ref[...], l_scratch_ref[...] 
= m_prev, l_prev - o_scratch_ref[:] = o_prev + return m_prev, l_prev, o_prev - def last_compute_body(kv_compute_index): - m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] - q = q_ref[...] - o_prev = o_scratch_ref[:] - - slice_k_len = kv_seq_len % bkv_compute - slice_k = pl.ds(kv_compute_index * bkv_compute, slice_k_len) - k_chunk = k_ref[slice_k, :] - - qk = lax.dot_general(k_chunk, q, NT_DIM_NUMBERS, preferred_element_type=float32) - v_chunk = v_ref[slice_k, :] - - # --- V1 VPU REGISTER TILING --- + def _fixed_inner(qk, v_chunk, m_fix, l_prev, o_prev): + # Fixed-m fast path: m is constant, so no reduce-max and no alpha rescale. step = bkv_compute_in for i in range(0, qk.shape[0], step): qk_slice = qk[i : i + step] - m_curr = qk_slice.max(axis=0)[None, :] - m_next = jnp.maximum(m_prev, m_curr) - s_curr = exp(qk_slice - m_next[0:1]) + s_curr = exp(qk_slice - m_fix[0:1]) l_curr = s_curr.sum(axis=0, keepdims=True) - alpha = exp(m_prev - m_next) - l_next = l_curr + alpha * l_prev - - sv_dims = (((0,), (0,)), ((), ())) o_curr = lax.dot_general( v_chunk[i : i + step], s_curr.astype(q_ref.dtype), sv_dims, preferred_element_type=float32, ) + o_prev = o_prev + o_curr + l_prev = l_prev + l_curr + return l_prev, o_prev - alpha_o = alpha[0:1, ...] - o_prev = alpha_o * o_prev + o_curr + def compute_body_online(kv_compute_index, _): + q = q_ref[...] + base_offset = kv_compute_index * bkv_compute + slice_k = pl.ds(base_offset, bkv_compute) + qk = lax.dot_general(k_ref[slice_k, :], q, NT_DIM_NUMBERS, preferred_element_type=float32) + v_chunk = v_ref[slice_k, :] + m_prev, l_prev, o_prev = _online_inner(qk, v_chunk, m_scratch_ref[...], l_scratch_ref[...], o_scratch_ref[:]) + m_scratch_ref[...], l_scratch_ref[...] = m_prev, l_prev + o_scratch_ref[:] = o_prev - m_prev, l_prev = m_next, l_next - # --- END V1 TILING --- + def compute_body_fixed(kv_compute_index, _): + q = q_ref[...] 
+ base_offset = kv_compute_index * bkv_compute + slice_k = pl.ds(base_offset, bkv_compute) + qk = lax.dot_general(k_ref[slice_k, :], q, NT_DIM_NUMBERS, preferred_element_type=float32) + v_chunk = v_ref[slice_k, :] + l_prev, o_prev = _fixed_inner(qk, v_chunk, m_scratch_ref[...], l_scratch_ref[...], o_scratch_ref[:]) + l_scratch_ref[...] = l_prev + o_scratch_ref[:] = o_prev + def last_compute_body_online(kv_compute_index): + q = q_ref[...] + slice_k_len = kv_seq_len % bkv_compute + slice_k = pl.ds(kv_compute_index * bkv_compute, slice_k_len) + qk = lax.dot_general(k_ref[slice_k, :], q, NT_DIM_NUMBERS, preferred_element_type=float32) + v_chunk = v_ref[slice_k, :] + m_prev, l_prev, o_prev = _online_inner(qk, v_chunk, m_scratch_ref[...], l_scratch_ref[...], o_scratch_ref[:]) m_scratch_ref[...], l_scratch_ref[...] = m_prev, l_prev o_scratch_ref[:] = o_prev assert bkv % bkv_compute == 0 - @pl.when(j != grid_width - 1) - def body(): - lax.fori_loop(0, (bkv // bkv_compute), compute_body, None, unroll=True) + if use_fixed_m: + + @pl.when((j != grid_width - 1) & is_fixed) + def _body_fixed(): + lax.fori_loop(0, (bkv // bkv_compute), compute_body_fixed, None, unroll=True) + + @pl.when((j != grid_width - 1) & jnp.logical_not(is_fixed)) + def _body_online(): + lax.fori_loop(0, (bkv // bkv_compute), compute_body_online, None, unroll=True) + + else: + + @pl.when(j != grid_width - 1) + def body(): + lax.fori_loop(0, (bkv // bkv_compute), compute_body_online, None, unroll=True) + # The final KV block always runs online for every head. Fixed-m heads arrive + # with m_scratch = ceil(bound) - C and o/l at that reference; the online + # alpha = exp2(m_prev - m_next) rescale renormalizes them transparently. 
@pl.when(j == grid_width - 1) def last_body(): if kv_seq_len % bkv == 0: iter_num = bkv // bkv_compute - lax.fori_loop(0, iter_num, compute_body, None, unroll=True) + lax.fori_loop(0, iter_num, compute_body_online, None, unroll=True) else: remain_kv_seq_len = kv_seq_len % bkv iter_num = (remain_kv_seq_len + bkv_compute - 1) // bkv_compute if remain_kv_seq_len % bkv_compute == 0: - lax.fori_loop(0, iter_num, compute_body, None, unroll=True) + lax.fori_loop(0, iter_num, compute_body_online, None, unroll=True) else: - lax.fori_loop(0, iter_num - 1, compute_body, None, unroll=True) - last_compute_body(iter_num - 1) + lax.fori_loop(0, iter_num - 1, compute_body_online, None, unroll=True) + last_compute_body_online(iter_num - 1) @pl.when(j == grid_width - 1) def end(): @@ -373,9 +413,16 @@ def _splash_attention_forward( use_base2_exp: bool = True, use_experimental_scheduler: bool = False, vmem_limit_bytes: int | None = None, + use_fixed_m: bool = False, + mk: jax.Array | None = None, ): num_q_heads, padded_q_seq_len, head_dim_qk = q.shape head_dim_v = v.shape[-1] + # Scalar-prefetch operand carrying per-head fixed-m data: + # mk[0, h] = max_j||k_j|| (Cauchy-Schwarz factor), mk[1, h] = eligibility. + # A dummy is supplied for online callers; the kernel ignores it. 
+ if mk is None: + mk = jnp.zeros((2, num_q_heads), jnp.float32) bq, bkv = block_sizes.block_q, block_sizes.block_kv bkv_compute = block_sizes.block_kv_compute num_kv_heads = k.shape[0] @@ -431,9 +478,10 @@ def v_index_map(h, i, j, *_): q_seq_len=actual_q_seq_len, kv_seq_len=actual_kv_seq_len, use_base2_exp=use_base2_exp, + use_fixed_m=use_fixed_m, ), grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=0, + num_scalar_prefetch=1, in_specs=in_specs, out_specs=out_specs, grid=grid, @@ -446,7 +494,7 @@ def v_index_map(h, i, j, *_): vmem_limit_bytes=vmem_limit_bytes, ), out_shape=out_shapes, - )(q, k, v) + )(mk, q, k, v) return all_out[-1] @@ -461,6 +509,8 @@ def _splash_attention_forward_ring( use_base2_exp: bool = True, use_experimental_scheduler: bool = False, vmem_limit_bytes: int | None = None, + use_fixed_m: bool = False, + mk: jax.Array | None = None, ): """Ring-specific forward path that returns pre-reciprocal fp32 accumulators. @@ -524,6 +574,13 @@ def v_index_map(h, i, j, *_): grid_height = (actual_q_seq_len + bq - 1) // bq grid = (num_q_heads, grid_height, grid_width) + # Scalar-prefetch operand carrying per-head fixed-m data (same convention as + # `_splash_attention_forward`): mk[0, h] = max_j||k_j|| over ALL ring shards + # (the caller all-reduces this over the ring axis), mk[1, h] = eligibility. + # A dummy is supplied for online callers; the kernel ignores it. 
+ if mk is None: + mk = jnp.zeros((2, num_q_heads), jnp.float32) + all_out = pl.pallas_call( functools.partial( _flash_attention_kernel, @@ -538,9 +595,10 @@ def v_index_map(h, i, j, *_): kv_seq_len=actual_kv_seq_len, use_base2_exp=use_base2_exp, fuse_reciprocal=False, + use_fixed_m=use_fixed_m, ), grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=0, + num_scalar_prefetch=1, in_specs=in_specs, out_specs=out_specs, grid=grid, @@ -553,7 +611,7 @@ def v_index_map(h, i, j, *_): vmem_limit_bytes=vmem_limit_bytes, ), out_shape=out_shapes, - )(q, k, v) + )(mk, q, k, v) out = jnp.swapaxes(all_out[3], 1, 2) # (h, head_dim_v, s) -> (h, s, head_dim_v) l = all_out[4][:, 0, :] # (h, s) m = all_out[5][:, 0, :] # (h, s) @@ -660,9 +718,12 @@ def make_splash_mha( use_base2_exp: bool = True, use_experimental_scheduler: bool = False, vmem_limit_bytes: int | None = None, + use_fixed_m: bool = False, ): - def _splash_attention(q, k, v): + def _splash_attention(q, k, v, mk=None): if heads_per_tile > 1: + if use_fixed_m: + raise NotImplementedError("fixed-m is not supported with heads_per_tile > 1") return _splash_attention_forward_mhpt( q, k, @@ -687,6 +748,8 @@ def _splash_attention(q, k, v): use_base2_exp=use_base2_exp, use_experimental_scheduler=use_experimental_scheduler, vmem_limit_bytes=vmem_limit_bytes, + use_fixed_m=use_fixed_m, + mk=mk, ) return _splash_attention diff --git a/src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py b/src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py index d8a80b512..76c9e386d 100644 --- a/src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py +++ b/src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py @@ -846,6 +846,8 @@ def _custom_ring_attention_forward( ring_size: int | None = None, perm: list[tuple[int, int]] | None = None, bidirectional: bool = False, + use_fixed_m: bool = False, + mk: jax.Array | None = None, ) -> jax.Array: """Forward-only ring attention using the custom 
dense splash kernel. @@ -886,6 +888,8 @@ def _custom_ring_attention_forward( "bidirectional (wrap-free) ring requires perm=None and ring_size==axis_size " "(it operates on the full real ring axis)." ) + if use_fixed_m: + raise NotImplementedError("fixed-m is not yet supported on the bidirectional ring path.") return _custom_bidirectional_ring_forward( q, k, @@ -932,6 +936,8 @@ def body(carry, i): use_base2_exp=use_base2_exp, use_experimental_scheduler=use_experimental_scheduler, vmem_limit_bytes=vmem_limit_bytes, + use_fixed_m=use_fixed_m, + mk=mk, ) m_curr = m_curr.astype(jnp.float32) l_curr = l_curr.astype(jnp.float32) @@ -972,6 +978,8 @@ def make_custom_ring_attention( ring_size: int | None = None, perm: list[tuple[int, int]] | None = None, bidirectional: bool = False, + use_fixed_m: bool = False, + mk: jax.Array | None = None, ): """Builds a forward-only ring-attention callable around the custom kernel. @@ -1006,6 +1014,8 @@ def _ring(q, k, v): ring_size=ring_size, perm=perm, bidirectional=bidirectional, + use_fixed_m=use_fixed_m, + mk=mk, ) return _ring diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index f6edc8309..cb63c445c 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -15,6 +15,7 @@ import contextlib import functools import math +import os from typing import Optional, Callable, Tuple, Any, Dict import flax.linen as nn from flax import nnx @@ -375,6 +376,11 @@ def _extract_custom_block_sizes(flash_block_sizes): bkv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", bkv_compute_in) heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", heads_per_tile) vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", vmem_limit_bytes) + # A BlockSizes object carries heads_per_tile=None when the config dict omitted + # it; getattr then returns that None instead of the default, so coerce it back + # to 1 (the custom-kernel default) 
to keep the `heads_per_tile > 1` guards safe. + if heads_per_tile is None: + heads_per_tile = 1 return bq, bkv, bkv_compute, bkv_compute_in, heads_per_tile, vmem_limit_bytes @@ -640,6 +646,7 @@ def _ulysses_attention( use_custom_kernel: bool = False, use_base2_exp: bool = True, use_experimental_scheduler: bool = False, + use_fixed_m: bool = False, ) -> jax.Array: """Ulysses sequence-parallel attention. @@ -693,10 +700,28 @@ def wrap_ulysses_attention(query, key, value): if use_base2_exp: query = query * LOG2E + if use_fixed_m: + # k-smoothing (output-invariant): subtracting the per-row key mean + # forces every logit row to have mean 0, hence row-max >= 0 — the + # precondition that keeps the fixed-m Cauchy-Schwarz bound flush-free. + key = key - jnp.mean(key, axis=2, keepdims=True) + query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq) key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv) value, _, _ = _pad_data_for_flash(value, heads, bkv) + mk_arr = None + if use_fixed_m: + # Per-(local-)head Cauchy-Schwarz inputs over the (batch, seq) slice; + # padded rows have zero norm and never raise the max. mk[0] feeds the + # in-kernel per-query bound, mk[1] flags heads within the no-flush gate. 
+ qf = query.astype(jnp.float32) + kf = key.astype(jnp.float32) + qn_max = jnp.sqrt((qf * qf).sum(-1)).max(axis=(0, 2)) # (local_heads,) + mk_h = jnp.sqrt((kf * kf).sum(-1)).max(axis=(0, 2)) # (local_heads,) + fixed_ok = (qn_max * mk_h <= custom_splash._FIXED_M_SAFE_BOUND).astype(jnp.float32) + mk_arr = jnp.stack([mk_h, fixed_ok]) # (2, local_heads) + bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute) splash_kernel = custom_splash.make_splash_mha( @@ -708,10 +733,15 @@ def wrap_ulysses_attention(query, key, value): use_base2_exp=use_base2_exp, use_experimental_scheduler=use_experimental_scheduler, vmem_limit_bytes=vmem_limit_bytes, + use_fixed_m=use_fixed_m, ) - vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0)) - attention_output = vmapped_splash(query, key, value) + if use_fixed_m: + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) + attention_output = vmapped_splash(query, key, value, mk_arr) + else: + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0)) + attention_output = vmapped_splash(query, key, value) attention_output = jnp.swapaxes(attention_output, 2, 3) attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) else: @@ -762,7 +792,26 @@ def wrap_ulysses_attention(query, key, value): "Warning, batch dimension should be shardable among the devices in data and fsdp" f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" ) + + # Fold the (CFG) batch into the heads axis around the Ulysses exchange. + # Each (batch, head) pair is an independent attention problem, so + # [B, H, S, D] -> [1, B*H, S, D] is mathematically identity — but it makes + # XLA compile the attention path as the batch=1 case. 
At batch=2 XLA + # otherwise places the size-2 batch in the tile sublanes ({3,0,1,2:T(2,128)} + # instead of T(8,128)) which quadruples the cost of every op touching the + # a2a tensors inside the scanned layers (measured 7.0 -> expected ~3.5 + # s/step at 720p 81f cp8 CFG). + batch = query.shape[0] + fold_batch = batch > 1 and (batch * num_heads) % num_shards == 0 + if fold_batch: + query = query.reshape(1, batch * num_heads, *query.shape[2:]) + key = key.reshape(1, batch * num_heads, *key.shape[2:]) + value = value.reshape(1, batch * num_heads, *value.shape[2:]) + x = wrap_ulysses_attention(query, key, value) + + if fold_batch: + x = x.reshape(batch, num_heads, *x.shape[2:]) x = x[:, :, :orig_q_seq_len, :] x = _reshape_heads_to_head_dim(x) @@ -941,6 +990,7 @@ def _ulysses_ring_custom_attention( use_base2_exp: bool = True, use_experimental_scheduler: bool = False, bidirectional: bool = False, + use_fixed_m: bool = False, ) -> jax.Array: """Hybrid Ulysses + Ring (USP) with the CUSTOM splash kernel on main's mesh. @@ -1013,10 +1063,38 @@ def wrap_ulysses_ring_attention(query, key, value): if use_base2_exp: query = query * LOG2E + if use_fixed_m: + # K-smoothing precondition for fixed-m, computed PER SHARD (no ring pmean). + # A global mean would be a perfectly-uniform per-query logit shift, but the + # per-shard local mean differs from it by only O(1/sqrt(local_seq)), and the + # ring's outer online-softmax merge re-normalizes across shards anyway, so we + # drop the per-layer ring collective and accept the negligible shift error. + kbar = jnp.mean(key, axis=2, keepdims=True) + key = key - kbar + query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq) key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv) value, _, _ = _pad_data_for_flash(value, heads, bkv) + mk_arr = None + if use_fixed_m: + # Per-(local-)head Cauchy-Schwarz inputs, all LOCAL to this ring shard. 
The + # one exception is max||k||: K/V rotate around the ring while mk stays fixed, + # so it is ring-pmax'ed below to bound the keys of every shard. + qf = query.astype(jnp.float32) + kf = key.astype(jnp.float32) + qn_max = jnp.sqrt((qf * qf).sum(-1)).max(axis=(0, 2)) # (local_heads,) + mk_h = jnp.sqrt((kf * kf).sum(-1)).max(axis=(0, 2)) # (local_heads,) local + if num_ring_shards > 1: + mk_h = jax.lax.pmax(mk_h, ring_axis) # (local_heads,) max over all ring shards + fixed_ok = (qn_max * mk_h <= custom_splash._FIXED_M_SAFE_BOUND).astype(jnp.float32) + if os.environ.get("FIXED_M_FORCE_ALL", "0") == "1": + # PERF PROBE ONLY (unsafe): force every head onto the fixed-m fast path, + # bypassing the safety gate, to measure fixed-m's speed CEILING on the + # ring kernel. Output may be garbage; timing is still valid. + fixed_ok = jnp.ones_like(fixed_ok) + mk_arr = jnp.stack([mk_h, fixed_ok]) # (2, local_heads) + bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute) if num_ring_shards == 1: # (2a) R=1: the ring is trivial (no rotation) -> use the lighter dedicated @@ -1032,8 +1110,14 @@ def wrap_ulysses_ring_attention(query, key, value): use_base2_exp=use_base2_exp, use_experimental_scheduler=use_experimental_scheduler, vmem_limit_bytes=vmem_limit_bytes, + use_fixed_m=use_fixed_m, ) - attention_output = jnp.swapaxes(jax.vmap(splash_kernel, in_axes=(0, 0, 0))(query, key, value), 2, 3) + if use_fixed_m: + attention_output = jnp.swapaxes( + jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))(query, key, value, mk_arr), 2, 3 + ) + else: + attention_output = jnp.swapaxes(jax.vmap(splash_kernel, in_axes=(0, 0, 0))(query, key, value), 2, 3) else: # (2b) Ring (full ppermute over the cross-chip ring axis) with the custom kernel.
# bidirectional=True -> wrap-free schedule (streams K/V both directions one hop @@ -1049,6 +1133,8 @@ def wrap_ulysses_ring_attention(query, key, value): ring_axis=ring_axis, ring_size=num_ring_shards, bidirectional=bidirectional, + use_fixed_m=use_fixed_m, + mk=mk_arr, ) attention_output = jax.vmap(ring_kernel, in_axes=(0, 0, 0))(query, key, value) attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) @@ -1231,6 +1317,33 @@ def ulysses_ring_custom_kernel(q, k, v, context): ) +@register_kernel("ulysses_ring_custom_fixed_m") +def ulysses_ring_custom_fixed_m_kernel(q, k, v, context): + """fixed-m variant of ulysses_ring_custom: the per-shard custom splash kernel + uses the Cauchy-Schwarz fixed-m softmax bound (no in-kernel running-max + rescale). max||k|| and the K-smoothing mean are taken LOCALLY per ring shard + (no per-layer ring collective); the outer ring online-softmax merge still + re-normalizes across shards, so per-shard bounds stay correct.""" + return _ulysses_ring_custom_attention( + q, + k * context["scale"], + v, + context["heads"], + context["mesh"], + context["axis_names_q"], + context["axis_names_kv"], + context["flash_block_sizes"], + context["dtype"], + mask_padding_tokens=context["mask_padding_tokens"], + residual_checkpoint_name=context["residual_checkpoint_name"], + attention_mask=context["attention_mask"], + ulysses_shards=context["ulysses_shards"], + use_base2_exp=context.get("use_base2_exp", True), + use_experimental_scheduler=context.get("use_experimental_scheduler", False), + use_fixed_m=True, + ) + + @register_kernel("ulysses_ring_custom_bidir") def ulysses_ring_custom_bidir_kernel(q, k, v, context): """Wrap-free (bidirectional) variant of ulysses_ring_custom: the ring streams @@ -1256,6 +1369,28 @@ def ulysses_ring_custom_bidir_kernel(q, k, v, context): ) +@register_kernel("ulysses_custom_fixed_m") +def ulysses_custom_fixed_m_kernel(q, k, v, context): + return _ulysses_attention( + q, + k * 
context["scale"], + v, + context["heads"], + context["mesh"], + context["axis_names_q"], + context["axis_names_kv"], + context["flash_block_sizes"], + context["dtype"], + mask_padding_tokens=context["mask_padding_tokens"], + residual_checkpoint_name=context["residual_checkpoint_name"], + attention_mask=context["attention_mask"], + use_custom_kernel=True, + use_base2_exp=context.get("use_base2_exp", True), + use_experimental_scheduler=context.get("use_experimental_scheduler", False), + use_fixed_m=True, + ) + + @register_kernel("ulysses") def ulysses_kernel(q, k, v, context): return _ulysses_attention( @@ -1413,7 +1548,7 @@ def _apply_attention( seq_len_idx = 2 can_use_flash_attention = True - if attention_kernel in ["flash", "tokamax_flash", "ulysses", "ulysses_custom", "ulysses_ring"]: + if attention_kernel in ["flash", "tokamax_flash", "ulysses", "ulysses_custom", "ulysses_custom_fixed_m", "ulysses_ring"]: can_use_flash_attention = ( query.shape[seq_len_idx] >= flash_min_seq_length and key.shape[seq_len_idx] >= flash_min_seq_length diff --git a/src/maxdiffusion/tests/custom_splash_fixed_m_test.py b/src/maxdiffusion/tests/custom_splash_fixed_m_test.py new file mode 100644 index 000000000..5571007dd --- /dev/null +++ b/src/maxdiffusion/tests/custom_splash_fixed_m_test.py @@ -0,0 +1,145 @@ +""" +Copyright 2026 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Unit tests for the fixed-m path of the custom splash attention kernel. 
+ +The fixed-m optimization replaces the online-softmax running max with a +precomputed per-query Cauchy-Schwarz bound for eligible heads, falling back to +online softmax for "sink" heads whose bound exceeds the no-flush gate. These +tests check that, mirroring the production calling convention, the kernel: + + * matches an f32 softmax reference for both online and fixed-m modes, + * produces fixed-m output that agrees with online output to bf16 precision, + * flags an out-of-gate head ineligible and falls back without NaNs. +""" + +import math +import unittest + +import jax +import jax.numpy as jnp + +from maxdiffusion.kernels import custom_splash_attention as custom_splash + +_LOG2E = math.log2(math.e) + + +class CustomSplashFixedMTest(unittest.TestCase): + """Numerical equivalence tests for the fixed-m kernel path.""" + + num_heads = 5 + seq_len = 4096 + head_dim = 128 + + def setUp(self): + super().setUp() + self.scale = 1.0 / math.sqrt(self.head_dim) + self.block_sizes = custom_splash._BlockSizes( + block_q=2048, block_kv=1024, block_kv_compute=512 + ) + + def _random_qkv( + self, q_gain: float = 1.0, k_gain: float = 1.0 + ) -> tuple[jax.Array, jax.Array, jax.Array]: + """Returns bf16 (q, k, v), optionally amplifying head 0 of q and k.""" + shape = (self.num_heads, self.seq_len, self.head_dim) + q = jax.random.normal(jax.random.PRNGKey(0), shape, jnp.bfloat16) + k = jax.random.normal(jax.random.PRNGKey(1), shape, jnp.bfloat16) + v = jax.random.normal(jax.random.PRNGKey(2), shape, jnp.bfloat16) + q = q.at[0].multiply(q_gain) + k = k.at[0].multiply(k_gain) + return q, k, v + + def _reference( + self, q: jax.Array, k: jax.Array, v: jax.Array + ) -> jax.Array: + """Per-head f32 softmax attention reference.""" + qf, kf, vf = (x.astype(jnp.float32) for x in (q, k, v)) + logits = jnp.einsum("hsd,htd->hst", qf, kf) * self.scale + probs = jax.nn.softmax(logits, axis=-1) + return jnp.einsum("hst,htd->hsd", probs, vf) + + def _run_kernel( + self, q: jax.Array, k: 
jax.Array, v: jax.Array, use_fixed_m: bool + ) -> tuple[jax.Array, jax.Array | None]: + """Runs the custom kernel using the production scaling convention. + + Args: + q: Query tensor of shape (heads, seq, dim). + k: Key tensor of shape (heads, seq, dim). + v: Value tensor of shape (heads, seq, dim). + use_fixed_m: Whether to enable the fixed-m bound path. + + Returns: + A tuple of the f32 attention output (heads, seq, dim) and the per-head + mk array (or None for the online path). + """ + q_in = (q * _LOG2E).astype(jnp.bfloat16) + k_in = k * self.scale + mk = None + if use_fixed_m: + # k-smoothing makes every logit row mean-zero so row-max >= 0. + k_in = k_in - jnp.mean(k_in, axis=1, keepdims=True) + qn = jnp.sqrt((q_in.astype(jnp.float32) ** 2).sum(-1)).max(axis=1) + mk_h = jnp.sqrt((k_in.astype(jnp.float32) ** 2).sum(-1)).max(axis=1) + eligible = (qn * mk_h <= custom_splash._FIXED_M_SAFE_BOUND).astype( + jnp.float32 + ) + mk = jnp.stack([mk_h, eligible]) + kernel = custom_splash.make_splash_mha( + block_sizes=self.block_sizes, + bkv_compute_in=256, + orig_q_seq_len=self.seq_len, + orig_kv_seq_len=self.seq_len, + use_base2_exp=True, + use_fixed_m=use_fixed_m, + ) + out = kernel(q_in, k_in, v, mk) if use_fixed_m else kernel(q_in, k_in, v) + out = jnp.swapaxes(out, 1, 2) # (heads, dim, seq) -> (heads, seq, dim) + return out.astype(jnp.float32), mk + + def test_online_matches_reference(self): + """Online softmax path agrees with the f32 reference at bf16 precision.""" + q, k, v = self._random_qkv() + online, _ = self._run_kernel(q, k, v, use_fixed_m=False) + self.assertLess(float(jnp.max(jnp.abs(online - self._reference(q, k, v)))), 2e-2) + + def test_fixed_m_matches_online_when_all_eligible(self): + """With uniform data all heads are eligible and match online output.""" + q, k, v = self._random_qkv() + online, _ = self._run_kernel(q, k, v, use_fixed_m=False) + fixed, mk = self._run_kernel(q, k, v, use_fixed_m=True) + self.assertTrue(bool(jnp.all(mk[1] > 0.5))) # 
every head eligible + self.assertTrue(bool(jnp.all(jnp.isfinite(fixed)))) + self.assertLess(float(jnp.max(jnp.abs(fixed - online))), 5e-3) + + def test_fixed_m_matches_reference(self): + """Fixed-m output agrees with the f32 softmax reference.""" + q, k, v = self._random_qkv() + fixed, _ = self._run_kernel(q, k, v, use_fixed_m=True) + self.assertLess(float(jnp.max(jnp.abs(fixed - self._reference(q, k, v)))), 2e-2) + + def test_sink_head_falls_back_to_online(self): + """An out-of-gate head is flagged ineligible and stays finite (no flush).""" + q, k, v = self._random_qkv(q_gain=6.0, k_gain=6.0) + fixed, mk = self._run_kernel(q, k, v, use_fixed_m=True) + self.assertEqual(float(mk[1][0]), 0.0) # head 0 is a sink -> ineligible + self.assertTrue(bool(jnp.all(mk[1][1:] > 0.5))) # the rest stay eligible + self.assertTrue(bool(jnp.all(jnp.isfinite(fixed)))) + + +if __name__ == "__main__": + unittest.main()