Skip to content

Refactor custom_splash_attention for improved clarity#431

Open
Perseus14 wants to merge 1 commit into
mainfrom
custom_attn_fix
Open

Refactor custom_splash_attention for improved clarity#431
Perseus14 wants to merge 1 commit into
mainfrom
custom_attn_fix

Conversation

@Perseus14

@Perseus14 Perseus14 commented Jun 29, 2026

Copy link
Copy Markdown
Collaborator

Overview

This PR refactors and cleans up the custom Splash attention kernels by properly encapsulating the inner-VPU tiling step (block_kv_compute_in) and removing redundant parameters. It also aligns the newly introduced custom Ring Attention implementations with these cleaned-up signatures.

Changes

  • custom_splash_attention.py:
    • Added block_kv_compute_in directly to the _BlockSizes dataclass, streamlining parameter passing.
    • Removed unused high-level attention wrappers (tpu_custom_attention, make_custom_splash_sdpa) since orchestration is now fully handled in attention_flax.py.
    • Cleaned up _flash_attention_kernel and _splash_attention_forward_ring by stripping out redundant arguments like bq, q_seq_len, and explicit bkv_compute_in passes.
  • ring_attention_kernel.py:
    • Updated the function signatures for make_custom_ring_attention, _custom_ring_attention_forward, and _custom_bidirectional_ring_forward to drop the explicit bkv_compute_in argument, correctly extracting it from block_sizes instead.
  • attention_flax.py:
    • Updated all instantiations of custom_splash._BlockSizes to pass block_kv_compute_in.
    • Fixed downstream calls to make_splash_mha and make_custom_ring_attention to remove the now-redundant bkv_compute_in kwarg.

Impact

  • Reduces parameter bloat across the low-level Pallas kernels.
  • Ensures API consistency across standard Ulysses attention and Tokamax Ring Attention.

@Perseus14 Perseus14 requested a review from entrpn as a code owner June 29, 2026 06:30
@github-actions

Copy link
Copy Markdown

@Perseus14 Perseus14 self-assigned this Jun 29, 2026
@Perseus14 Perseus14 requested review from csgoogle and eltsai June 29, 2026 06:57
csgoogle
csgoogle previously approved these changes Jun 29, 2026
Comment thread src/maxdiffusion/kernels/custom_splash_attention.py
entrpn
entrpn previously approved these changes Jun 30, 2026
eltsai
eltsai previously approved these changes Jun 30, 2026

@eltsai eltsai left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for cleaning this up @Perseus14 ! LGTM

@Perseus14 Perseus14 dismissed stale reviews from eltsai, entrpn, and csgoogle via 632e774 July 1, 2026 18:04
@Perseus14

Copy link
Copy Markdown
Collaborator Author

@eltsai Could you check again? I have made minor cleanup to ring attention implementation as well

@Perseus14 Perseus14 requested review from csgoogle and eltsai July 1, 2026 18:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants