Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 120 additions & 57 deletions src/maxdiffusion/kernels/custom_splash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,21 @@ def __init__(self, block_q: int, block_kv: int, block_kv_compute: int | None = N
self.block_kv_compute = block_kv_compute if block_kv_compute is not None else block_kv


# Fixed-m softmax-bound constants. Instead of tracking the online-softmax
# running max per KV block, eligible heads subtract a precomputed per-query
# upper bound on the logits (Cauchy-Schwarz: max_j q_i.k_j <= ||q_i|| *
# max_j||k_j||). _FIXED_M_RECENTER (C) shifts the exp2 exponents up so the
# largest surviving term stays above the f32 subnormal-flush floor 2^-126:
# with k-smoothing the per-row max is >= 0, so the max term has exponent
# >= -ceil(bound) + C, which stays > -126 while ceil(bound) <=
# _FIXED_M_SAFE_BOUND (= C + 126 - 1 = 88 + 126 - 1 = 213, i.e. one unit of
# margin). Heads whose worst-case bound exceeds the gate fall back to online
# softmax (the "sink" heads).
_FIXED_M_RECENTER = 88.0
_FIXED_M_SAFE_BOUND = 213.0


def _flash_attention_kernel(
mk_ref,
q_ref,
k_ref,
v_ref,
Expand All @@ -73,34 +87,45 @@ def _flash_attention_kernel(
kv_seq_len: int,
use_base2_exp: bool = True,
fuse_reciprocal: bool = True,
use_fixed_m: bool = False,
):
float32 = jnp.float32
head_dim_v_repeats, rem = divmod(head_dim_v, NUM_SUBLANES)
if rem != 0:
raise NotImplementedError(f"{head_dim_v=} should be a multiple of {NUM_SUBLANES}")

_, _, j = pl.program_id(0), pl.program_id(1), pl.program_id(2)
h, _, j = pl.program_id(0), pl.program_id(1), pl.program_id(2)
exp = jnp.exp2 if use_base2_exp else jnp.exp
sv_dims = (((0,), (0,)), ((), ()))

# Per-head dispatch: heads inside the no-flush window run fixed-m, the rest
# keep online softmax. Branch once per head (body level), never per step.
is_fixed = (mk_ref[1, h] > 0.5) if use_fixed_m else False

@pl.when(j == 0)
def init():
o_scratch_ref[...] = jnp.zeros_like(o_scratch_ref)
m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value)
l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref)
if use_fixed_m:

def compute_body(kv_compute_index, _):
m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...]
q = q_ref[...]
o_prev = o_scratch_ref[:]
@pl.when(is_fixed)
def _init_fixed():
# Per-query Cauchy-Schwarz bound m_i = ceil(||q_i|| * max_j||k_j||) - C.
qf = q_ref[...].astype(float32)
qn = jnp.sqrt((qf * qf).sum(axis=1))[None, :] # (1, bq) per-query norm
bound = qn * mk_ref[0, h]
m_fixed = jnp.ceil(bound) - _FIXED_M_RECENTER
m_scratch_ref[...] = jnp.broadcast_to(m_fixed, m_scratch_ref.shape)

base_offset = kv_compute_index * bkv_compute
slice_k = pl.ds(base_offset, bkv_compute)
k_chunk = k_ref[slice_k, :]
@pl.when(jnp.logical_not(is_fixed))
def _init_online():
m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value)

qk = lax.dot_general(k_chunk, q, NT_DIM_NUMBERS, preferred_element_type=float32)
v_chunk = v_ref[slice_k, :]
else:
m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value)

# --- V1 VPU REGISTER TILING ---
def _online_inner(qk, v_chunk, m_prev, l_prev, o_prev):
# Standard online-softmax tiling over the VPU register block.
step = bkv_compute_in
for i in range(0, qk.shape[0], step):
qk_slice = qk[i : i + step]
Expand All @@ -113,84 +138,99 @@ def compute_body(kv_compute_index, _):
alpha = exp(m_prev - m_next)
l_next = l_curr + alpha * l_prev

sv_dims = (((0,), (0,)), ((), ()))
o_curr = lax.dot_general(
v_chunk[i : i + step],
s_curr.astype(q_ref.dtype),
sv_dims,
preferred_element_type=float32,
)

alpha_o = alpha[0:1, ...]
o_prev = alpha_o * o_prev + o_curr

o_prev = alpha[0:1, ...] * o_prev + o_curr
m_prev, l_prev = m_next, l_next
# --- END V1 TILING ---

m_scratch_ref[...], l_scratch_ref[...] = m_prev, l_prev
o_scratch_ref[:] = o_prev
return m_prev, l_prev, o_prev

def last_compute_body(kv_compute_index):
m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...]
q = q_ref[...]
o_prev = o_scratch_ref[:]

slice_k_len = kv_seq_len % bkv_compute
slice_k = pl.ds(kv_compute_index * bkv_compute, slice_k_len)
k_chunk = k_ref[slice_k, :]

qk = lax.dot_general(k_chunk, q, NT_DIM_NUMBERS, preferred_element_type=float32)
v_chunk = v_ref[slice_k, :]

# --- V1 VPU REGISTER TILING ---
def _fixed_inner(qk, v_chunk, m_fix, l_prev, o_prev):
# Fixed-m fast path: m is constant, so no reduce-max and no alpha rescale.
step = bkv_compute_in
for i in range(0, qk.shape[0], step):
qk_slice = qk[i : i + step]

m_curr = qk_slice.max(axis=0)[None, :]
m_next = jnp.maximum(m_prev, m_curr)
s_curr = exp(qk_slice - m_next[0:1])
s_curr = exp(qk_slice - m_fix[0:1])
l_curr = s_curr.sum(axis=0, keepdims=True)

alpha = exp(m_prev - m_next)
l_next = l_curr + alpha * l_prev

sv_dims = (((0,), (0,)), ((), ()))
o_curr = lax.dot_general(
v_chunk[i : i + step],
s_curr.astype(q_ref.dtype),
sv_dims,
preferred_element_type=float32,
)
o_prev = o_prev + o_curr
l_prev = l_prev + l_curr
return l_prev, o_prev

alpha_o = alpha[0:1, ...]
o_prev = alpha_o * o_prev + o_curr
def compute_body_online(kv_compute_index, _):
q = q_ref[...]
base_offset = kv_compute_index * bkv_compute
slice_k = pl.ds(base_offset, bkv_compute)
qk = lax.dot_general(k_ref[slice_k, :], q, NT_DIM_NUMBERS, preferred_element_type=float32)
v_chunk = v_ref[slice_k, :]
m_prev, l_prev, o_prev = _online_inner(qk, v_chunk, m_scratch_ref[...], l_scratch_ref[...], o_scratch_ref[:])
m_scratch_ref[...], l_scratch_ref[...] = m_prev, l_prev
o_scratch_ref[:] = o_prev

m_prev, l_prev = m_next, l_next
# --- END V1 TILING ---
def compute_body_fixed(kv_compute_index, _):
q = q_ref[...]
base_offset = kv_compute_index * bkv_compute
slice_k = pl.ds(base_offset, bkv_compute)
qk = lax.dot_general(k_ref[slice_k, :], q, NT_DIM_NUMBERS, preferred_element_type=float32)
v_chunk = v_ref[slice_k, :]
l_prev, o_prev = _fixed_inner(qk, v_chunk, m_scratch_ref[...], l_scratch_ref[...], o_scratch_ref[:])
l_scratch_ref[...] = l_prev
o_scratch_ref[:] = o_prev

def last_compute_body_online(kv_compute_index):
q = q_ref[...]
slice_k_len = kv_seq_len % bkv_compute
slice_k = pl.ds(kv_compute_index * bkv_compute, slice_k_len)
qk = lax.dot_general(k_ref[slice_k, :], q, NT_DIM_NUMBERS, preferred_element_type=float32)
v_chunk = v_ref[slice_k, :]
m_prev, l_prev, o_prev = _online_inner(qk, v_chunk, m_scratch_ref[...], l_scratch_ref[...], o_scratch_ref[:])
m_scratch_ref[...], l_scratch_ref[...] = m_prev, l_prev
o_scratch_ref[:] = o_prev

assert bkv % bkv_compute == 0

@pl.when(j != grid_width - 1)
def body():
lax.fori_loop(0, (bkv // bkv_compute), compute_body, None, unroll=True)
if use_fixed_m:

@pl.when((j != grid_width - 1) & is_fixed)
def _body_fixed():
lax.fori_loop(0, (bkv // bkv_compute), compute_body_fixed, None, unroll=True)

@pl.when((j != grid_width - 1) & jnp.logical_not(is_fixed))
def _body_online():
lax.fori_loop(0, (bkv // bkv_compute), compute_body_online, None, unroll=True)

else:

@pl.when(j != grid_width - 1)
def body():
lax.fori_loop(0, (bkv // bkv_compute), compute_body_online, None, unroll=True)

# The final KV block always runs online for every head. Fixed-m heads arrive
# with m_scratch = ceil(bound) - C and with o/l accumulated relative to that
# reference; the online alpha = exp(m_prev - m_next) rescale (exp is exp2 when
# use_base2_exp is set) renormalizes them transparently.
@pl.when(j == grid_width - 1)
def last_body():
if kv_seq_len % bkv == 0:
iter_num = bkv // bkv_compute
lax.fori_loop(0, iter_num, compute_body, None, unroll=True)
lax.fori_loop(0, iter_num, compute_body_online, None, unroll=True)
else:
remain_kv_seq_len = kv_seq_len % bkv
iter_num = (remain_kv_seq_len + bkv_compute - 1) // bkv_compute
if remain_kv_seq_len % bkv_compute == 0:
lax.fori_loop(0, iter_num, compute_body, None, unroll=True)
lax.fori_loop(0, iter_num, compute_body_online, None, unroll=True)
else:
lax.fori_loop(0, iter_num - 1, compute_body, None, unroll=True)
last_compute_body(iter_num - 1)
lax.fori_loop(0, iter_num - 1, compute_body_online, None, unroll=True)
last_compute_body_online(iter_num - 1)

@pl.when(j == grid_width - 1)
def end():
Expand Down Expand Up @@ -373,9 +413,16 @@ def _splash_attention_forward(
use_base2_exp: bool = True,
use_experimental_scheduler: bool = False,
vmem_limit_bytes: int | None = None,
use_fixed_m: bool = False,
mk: jax.Array | None = None,
):
num_q_heads, padded_q_seq_len, head_dim_qk = q.shape
head_dim_v = v.shape[-1]
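# Scalar-prefetch operand carrying per-head fixed-m data:
# mk[0, h] = max_j||k_j|| (Cauchy-Schwarz factor), mk[1, h] = eligibility
# flag (the kernel treats values > 0.5 as "use fixed-m for head h").
# When mk is None an all-zero dummy is supplied: the kernel never reads it
# with use_fixed_m=False, and with use_fixed_m=True it marks every head
# ineligible, so all heads run online softmax.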
if mk is None:
mk = jnp.zeros((2, num_q_heads), jnp.float32)
bq, bkv = block_sizes.block_q, block_sizes.block_kv
bkv_compute = block_sizes.block_kv_compute
num_kv_heads = k.shape[0]
Expand Down Expand Up @@ -431,9 +478,10 @@ def v_index_map(h, i, j, *_):
q_seq_len=actual_q_seq_len,
kv_seq_len=actual_kv_seq_len,
use_base2_exp=use_base2_exp,
use_fixed_m=use_fixed_m,
),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
num_scalar_prefetch=1,
in_specs=in_specs,
out_specs=out_specs,
grid=grid,
Expand All @@ -446,7 +494,7 @@ def v_index_map(h, i, j, *_):
vmem_limit_bytes=vmem_limit_bytes,
),
out_shape=out_shapes,
)(q, k, v)
)(mk, q, k, v)
return all_out[-1]


Expand All @@ -461,6 +509,8 @@ def _splash_attention_forward_ring(
use_base2_exp: bool = True,
use_experimental_scheduler: bool = False,
vmem_limit_bytes: int | None = None,
use_fixed_m: bool = False,
mk: jax.Array | None = None,
):
"""Ring-specific forward path that returns pre-reciprocal fp32 accumulators.

Expand Down Expand Up @@ -524,6 +574,13 @@ def v_index_map(h, i, j, *_):
grid_height = (actual_q_seq_len + bq - 1) // bq
grid = (num_q_heads, grid_height, grid_width)

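# Scalar-prefetch operand carrying per-head fixed-m data (same convention as
# `_splash_attention_forward`): mk[0, h] = max_j||k_j|| over ALL ring shards
# (the caller all-reduces this over the ring axis), mk[1, h] = eligibility
# flag (the kernel treats values > 0.5 as "use fixed-m for head h").
# When mk is None an all-zero dummy is supplied: the kernel never reads it
# with use_fixed_m=False, and with use_fixed_m=True it marks every head
# ineligible, so all heads run online softmax.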
if mk is None:
mk = jnp.zeros((2, num_q_heads), jnp.float32)

all_out = pl.pallas_call(
functools.partial(
_flash_attention_kernel,
Expand All @@ -538,9 +595,10 @@ def v_index_map(h, i, j, *_):
kv_seq_len=actual_kv_seq_len,
use_base2_exp=use_base2_exp,
fuse_reciprocal=False,
use_fixed_m=use_fixed_m,
),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
num_scalar_prefetch=1,
in_specs=in_specs,
out_specs=out_specs,
grid=grid,
Expand All @@ -553,7 +611,7 @@ def v_index_map(h, i, j, *_):
vmem_limit_bytes=vmem_limit_bytes,
),
out_shape=out_shapes,
)(q, k, v)
)(mk, q, k, v)
out = jnp.swapaxes(all_out[3], 1, 2) # (h, head_dim_v, s) -> (h, s, head_dim_v)
l = all_out[4][:, 0, :] # (h, s)
m = all_out[5][:, 0, :] # (h, s)
Expand Down Expand Up @@ -660,9 +718,12 @@ def make_splash_mha(
use_base2_exp: bool = True,
use_experimental_scheduler: bool = False,
vmem_limit_bytes: int | None = None,
use_fixed_m: bool = False,
):
def _splash_attention(q, k, v):
def _splash_attention(q, k, v, mk=None):
if heads_per_tile > 1:
if use_fixed_m:
raise NotImplementedError("fixed-m is not supported with heads_per_tile > 1")
return _splash_attention_forward_mhpt(
q,
k,
Expand All @@ -687,6 +748,8 @@ def _splash_attention(q, k, v):
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
vmem_limit_bytes=vmem_limit_bytes,
use_fixed_m=use_fixed_m,
mk=mk,
)

return _splash_attention
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,8 @@ def _custom_ring_attention_forward(
ring_size: int | None = None,
perm: list[tuple[int, int]] | None = None,
bidirectional: bool = False,
use_fixed_m: bool = False,
mk: jax.Array | None = None,
) -> jax.Array:
"""Forward-only ring attention using the custom dense splash kernel.

Expand Down Expand Up @@ -886,6 +888,8 @@ def _custom_ring_attention_forward(
"bidirectional (wrap-free) ring requires perm=None and ring_size==axis_size "
"(it operates on the full real ring axis)."
)
if use_fixed_m:
raise NotImplementedError("fixed-m is not yet supported on the bidirectional ring path.")
return _custom_bidirectional_ring_forward(
q,
k,
Expand Down Expand Up @@ -932,6 +936,8 @@ def body(carry, i):
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
vmem_limit_bytes=vmem_limit_bytes,
use_fixed_m=use_fixed_m,
mk=mk,
)
m_curr = m_curr.astype(jnp.float32)
l_curr = l_curr.astype(jnp.float32)
Expand Down Expand Up @@ -972,6 +978,8 @@ def make_custom_ring_attention(
ring_size: int | None = None,
perm: list[tuple[int, int]] | None = None,
bidirectional: bool = False,
use_fixed_m: bool = False,
mk: jax.Array | None = None,
):
"""Builds a forward-only ring-attention callable around the custom kernel.

Expand Down Expand Up @@ -1006,6 +1014,8 @@ def _ring(q, k, v):
ring_size=ring_size,
perm=perm,
bidirectional=bidirectional,
use_fixed_m=use_fixed_m,
mk=mk,
)

return _ring
Loading
Loading