Refactor custom_splash_attention for improved clarity #431
Open
Perseus14 wants to merge 1 commit into
csgoogle previously approved these changes on Jun 29, 2026
entrpn reviewed on Jun 30, 2026
entrpn previously approved these changes on Jun 30, 2026
eltsai previously approved these changes on Jun 30, 2026
eltsai left a comment
Collaborator
Thanks for cleaning this up, @Perseus14! LGTM
fb64673 to 632e774
Collaborator
Author
@eltsai Could you check again? I have also made some minor cleanups to the ring attention implementation.
632e774 to f6c2c11
f6c2c11 to fd56996
Overview
This PR refactors and cleans up the custom Splash attention kernels by encapsulating the inner-VPU tiling step (block_kv_compute_in) and removing redundant parameters. It also aligns the newly introduced custom Ring Attention implementations with these cleaned-up signatures.
Changes
custom_splash_attention.py:
- Adds block_kv_compute_in directly to the _BlockSizes dataclass, streamlining parameter passing.
- Removes tpu_custom_attention and make_custom_splash_sdpa, since orchestration is now fully handled in attention_flax.py.
- Simplifies _flash_attention_kernel and _splash_attention_forward_ring by stripping out redundant arguments such as bq, q_seq_len, and the explicitly passed bkv_compute_in.
ring_attention_kernel.py:
- Updates make_custom_ring_attention, _custom_ring_attention_forward, and _custom_bidirectional_ring_forward to drop the explicit bkv_compute_in argument and read it from block_sizes instead.
attention_flax.py:
- Updates the construction of custom_splash._BlockSizes to pass block_kv_compute_in.
- Updates the calls to make_splash_mha and make_custom_ring_attention to remove the now-redundant bkv_compute_in kwarg.
Impact
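For readers who have not opened the diff, the signature change listed under Changes can be sketched roughly as follows. This is an illustrative sketch only, not the PR's code: BlockSizesSketch and make_attention_sketch are hypothetical names, and every field other than block_kv_compute_in is an assumption made for the example.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class BlockSizesSketch:
    # Hypothetical stand-in for _BlockSizes. Only block_kv_compute_in is
    # named in the PR description; the other fields are assumed.
    block_q: int
    block_kv: int
    block_kv_compute_in: int  # inner tiling step, carried by the dataclass


def make_attention_sketch(block_sizes: BlockSizesSketch) -> dict:
    # Hypothetical factory. It reads the inner tiling step from block_sizes,
    # so callers do not pass a separate bkv_compute_in keyword argument.
    bkv_compute_in = block_sizes.block_kv_compute_in
    return {
        "block_q": block_sizes.block_q,
        "block_kv": block_sizes.block_kv,
        "bkv_compute_in": bkv_compute_in,
    }


# Call site: one object carries every tiling parameter.
sizes = BlockSizesSketch(block_q=512, block_kv=1024, block_kv_compute_in=256)
config = make_attention_sketch(sizes)
```

Carrying the value on the dataclass gives it a single source of truth, so a call site cannot pass a bkv_compute_in that disagrees with the rest of the block-size configuration.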