Skip to content

[LA] Lightning Attention MTP decode + KVBuffer parallel verify / commit#97

Open
fkuner wants to merge 14 commits into
inclusionAI:mainfrom
fkuner:la-decode-kvbuffer
Open

[LA] Lightning Attention MTP decode + KVBuffer parallel verify / commit#97
fkuner wants to merge 14 commits into
inclusionAI:mainfrom
fkuner:la-decode-kvbuffer

Conversation

@fkuner

@fkuner fkuner commented Jun 21, 2026

Copy link
Copy Markdown
Collaborator

📌 Description

Adds target-side Lightning Attention (LA) support for speculative decoding (multi-token prediction), in two complementary pieces:

  • Fused MTP decode (la_decode_mtp): a single launch advances the LA recurrence over all T draft tokens per (batch, head), with ILP variants and a work-unit heuristic (get_mtp_config); packed F32×2 FMA on SM100.
  • KVBuffer parallel verify + state-update (la_verify_kvbuffer / la_state_update_kvbuffer): the parallel-verification path. Verify computes every draft step's output in closed form straight from (h0, k, v) — without materializing the T intermediate states — and optionally writes k/v into a compact pooled KV buffer; state-update (commit) then advances the pooled state by the per-request accepted prefix length L read from that buffer.

Closed form used by verify (per draft step t):

o_t = α^{t+1} · (h0 · q_t · scale)                  # term1  (h0–Q GEMM)
    + Σ_{i=0..t} α^{t-i} · (q_t · k_i · scale) · v_i  # term2  (Q–K GEMM, then ·V)

The two dot-product GEMMs run on Blackwell tensor cores via inline-PTX mma.sync.m16n8k8 (TF32, fp32 SMEM staging); everything downstream is plain scalar math. M/N are padded to BT=8, so any draft length T ∈ [1, 8] (odd or even) is handled.

What changed

  • cula/lightning/la_decode_mtp.py — fused MTP decode kernel + config heuristic + shared dot/update helpers.
  • cula/lightning/la_verify_kvbuffer.py — KVBuffer verify (TF32 MMA, register-shuffle variant) + optional KV-buffer write.
  • cula/lightning/la_state_update_kvbuffer.py — KVBuffer state-update (commit) kernel; per-request accepted_len, skips padded slots (h0_indices < 0).
  • tests/test_la_decode_mtp.py, tests/test_la_kvbuffer.py, — correctness vs a PyTorch reference.
  • benchmarks/bench_la_decode_mtp.py, benchmarks/bench_la_kvbuffer.py — MTP-decode bench, and verify+commit chain bench with an optional SGLang baseline (no hard dependency on a local SGLang checkout).

🔍 Related Issues

🚀 Pull Request Checklist

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit.
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

All kernels are checked against a PyTorch reference. Coverage:

  • MTP decode — output vs reference across batch/T/head shapes.
  • KVBuffer verify — output vs reference, including odd T (1, 3, 5, 7) and the M/N-padding path; negative h0_indices slots left untouched.
  • State-update (commit) — full / partial / per-request accept length, L=0 no-op, and skipped (padded) slots.
$ pytest tests/test_la_decode_mtp.py tests/test_la_kvbuffer.py -q
Test file Result
tests/test_la_decode_mtp.py 15 passed
tests/test_la_kvbuffer.py 35 passed
Total 50 passed in 51.68s

⚡ Performance

B200 (SM100), K=V=128, bf16 in / fp32 state, accept m = full (L = T), kernel-only timing.
Baseline = SGLang upstream chain (seg_la_mtp_kernel verify + fused_mamba_state_scatter_with_mask commit, both Triton). Each cell = verify + commit chain speedup = sglang_total / cuLA_total; >1 means cuLA is faster. H == HV (no GQA) for a fair comparison.

HV = 32

B T cu_vfy cu_cmt cu_total sg_vfy sg_cmt sg_total speedup
1 2 0.0162 0.0151 0.0312 0.0348 0.0391 0.0739 2.36x
2 2 0.0155 0.0149 0.0304 0.0344 0.0390 0.0734 2.41x
4 2 0.0153 0.0149 0.0302 0.0347 0.0383 0.0730 2.42x
8 2 0.0157 0.0151 0.0308 0.0346 0.0383 0.0729 2.37x
16 2 0.0172 0.0167 0.0340 0.0392 0.0392 0.0784 2.31x
32 2 0.0187 0.0192 0.0379 0.0482 0.0430 0.0912 2.40x
64 2 0.0336 0.0501 0.0837 0.0831 0.0502 0.1333 1.59x
128 2 0.0541 0.0910 0.1451 0.1485 0.0907 0.2392 1.65x
1 4 0.0156 0.0150 0.0305 0.0350 0.0402 0.0753 2.46x
2 4 0.0153 0.0148 0.0301 0.0338 0.0377 0.0715 2.38x
4 4 0.0155 0.0148 0.0303 0.0339 0.0378 0.0717 2.37x
8 4 0.0159 0.0150 0.0309 0.0386 0.0384 0.0770 2.49x
16 4 0.0186 0.0181 0.0367 0.0443 0.0387 0.0830 2.26x
32 4 0.0212 0.0197 0.0409 0.0732 0.0422 0.1155 2.82x
64 4 0.0458 0.0525 0.0983 0.1289 0.0498 0.1788 1.82x
128 4 0.0790 0.0943 0.1733 0.2508 0.0898 0.3406 1.97x
1 8 0.0160 0.0154 0.0315 0.0347 0.0388 0.0734 2.33x
2 8 0.0158 0.0154 0.0313 0.0348 0.0389 0.0737 2.36x
4 8 0.0168 0.0154 0.0322 0.0373 0.0394 0.0767 2.38x
8 8 0.0189 0.0175 0.0364 0.0443 0.0388 0.0831 2.28x
16 8 0.0198 0.0192 0.0390 0.0719 0.0384 0.1103 2.83x
32 8 0.0254 0.0254 0.0508 0.1223 0.0441 0.1664 3.28x
64 8 0.0540 0.0598 0.1139 0.2322 0.0496 0.2818 2.47x
128 8 0.0950 0.1029 0.1979 0.4407 0.0901 0.5308 2.68x

HV = 64

B T cu_vfy cu_cmt cu_total sg_vfy sg_cmt sg_total speedup
1 2 0.0168 0.0156 0.0324 0.0376 0.0425 0.0801 2.47x
2 2 0.0160 0.0156 0.0316 0.0382 0.0424 0.0805 2.55x
4 2 0.0160 0.0154 0.0313 0.0387 0.0421 0.0808 2.58x
8 2 0.0176 0.0171 0.0347 0.0417 0.0428 0.0844 2.44x
16 2 0.0196 0.0191 0.0386 0.0483 0.0472 0.0956 2.48x
32 2 0.0336 0.0501 0.0837 0.0830 0.0503 0.1333 1.59x
64 2 0.0542 0.0911 0.1453 0.1486 0.0905 0.2391 1.65x
128 2 0.0952 0.1709 0.2660 0.2819 0.1694 0.4513 1.70x
1 4 0.0156 0.0151 0.0308 0.0373 0.0426 0.0798 2.59x
2 4 0.0159 0.0154 0.0313 0.0382 0.0416 0.0798 2.55x
4 4 0.0159 0.0167 0.0326 0.0410 0.0422 0.0832 2.55x
8 4 0.0191 0.0183 0.0373 0.0444 0.0425 0.0869 2.33x
16 4 0.0213 0.0203 0.0416 0.0731 0.0476 0.1207 2.90x
32 4 0.0455 0.0527 0.0982 0.1299 0.0499 0.1799 1.83x
64 4 0.0789 0.0942 0.1731 0.2511 0.0902 0.3414 1.97x
128 4 0.1393 0.1741 0.3134 0.4804 0.1688 0.6492 2.07x
1 8 0.0160 0.0153 0.0313 0.0377 0.0425 0.0802 2.56x
2 8 0.0166 0.0154 0.0320 0.0401 0.0417 0.0818 2.56x
4 8 0.0180 0.0172 0.0352 0.0442 0.0413 0.0855 2.42x
8 8 0.0195 0.0188 0.0383 0.0714 0.0421 0.1135 2.96x
16 8 0.0254 0.0254 0.0508 0.1241 0.0469 0.1710 3.36x
32 8 0.0540 0.0602 0.1142 0.2317 0.0501 0.2817 2.47x
64 8 0.0951 0.1032 0.1983 0.4430 0.0897 0.5328 2.69x
128 8 0.1689 0.1877 0.3566 0.8720 0.1688 1.0407 2.92x

Memory (B=128, T=8): cuLA KV buffer 16.8 MB (HV=32) / 33.6 MB (HV=64) vs SGLang intermediate caches 2148 MB / 4295 MB → 128× less in both cases.

Takeaways

  • Chain vs SGLang: 1.59× – 3.36× across every shape (both HV), strongest at T = 8 and mid batch (B = 16–32).
  • Verify scales flat in T. At B=128 the cuLA verify kernel grows only +76% (HV=32) / +77% (HV=64) from T=2→8, while SGLang's Triton
    verify grows +197% / +209%. The verify kernel alone (accept-independent) reaches up to ~5.2× (B=128, T=8, HV=64; 4.6× at HV=32) — the
    tensor-core, closed-form parallel verification paying off as T grows.
  • Memory: 128× less rollback storage. The pooled KV buffer (k,v per draft token) replaces SGLang's T per-token d×d intermediate
    states. At B=128, T=8: 16.8 MB vs 2147 MB (HV=32) and 33.6 MB vs 4295 MB (HV=64) — independent of latency.
  • Correctness: cuLA RMSE ≤ 2.5e-3 vs the PyTorch reference (bf16 in / fp32 state), at or below SGLang's own RMSE on the same inputs.

Speed of Light (NCU, B=128, H=HV=32, K=V=128)

la_decode_mtp (decode kernel)

T Memory SOL Compute (SM) DRAM L1/TEX Duration
2 74.4% 58.7% 73.7% 77.4% 154 µs
4 79.6% 58.5% 72.0% 81.5% 271 µs
8 83.3% 57.4% 69.9% 84.6% 514 µs

Memory-bound (83% mem SOL @ T=8) — consistent with LA's ~30%-lighter per-step compute vs GDN (no delta rule). Small-tile config maximizes occupancy to feed HBM bandwidth.

la_verify_kvbuffer (verify kernel)

T Memory SOL Compute (SM) DRAM L1/TEX Duration path
2 78.3% 60.9% 54.7% 86.0% 78 µs shuffle
4 69.6% 42.3% 45.7% 76.8% 95 µs MMA
8 74.1% 45.3% 43.5% 81.0% 107 µs MMA

L1-cache-bound (L1/TEX 77–86%, DRAM only 43–55%) — q/k/v are small enough to reside in cache and are reused. Large-tile config (tile_v=128, ilp=8) amortizes q/k SMEM staging and fills the m16n8k8 MMA tiles.

Reviewer Notes

  • The public linear_attention_verify_kvbuffer dispatches by draft depth: the inline-PTX mma.sync.m16n8k8 tensor-core kernel for T >= 4, and a warp-shuffle kernel for T < 4 (where the MMA GEMMs would be under-utilized).
  • Verify's KV-buffer write is optional (write_kv) so the verify kernel can run standalone (verify-only path above) or fused with the buffer write.
  • The commit kernel's recurrence body is bit-identical to the baseline T-loop, so at L == T the committed state is bit-equivalent to running the baseline with disable_state_update=False.
  • The benchmark treats SGLang as an optional baseline (set LA_SGLANG_PYTHON to enable); without it the bench still validates against the PyTorch reference.

@fkuner fkuner requested a review from icavan June 21, 2026 05:39

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces Lightning Attention MTP (Multi-Token Processing) decode and KVBuffer verify/state-update kernels to optimize speculative-decoding verification scenarios, along with corresponding benchmarks and unit tests. It also adds a Global-to-Register (G2R) prototype to optimize the big-batch decode path. The review feedback is highly constructive and identifies critical correctness and safety issues that must be addressed. Specifically, defensive checks should be added to Python entry points to prevent out-of-bounds memory accesses when the number of tokens T > 8 or head dimension K != 128. Additionally, wrappers must validate that the head dimension V is a multiple of ilp_rows to prevent silent correctness bugs where boundary chunks are skipped. Finally, environment variable lookups in the hot path of linear_attention_decode should be cached at the module level to eliminate unnecessary Python overhead.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread cula/lightning/la_decode_mtp.py
Comment thread cula/lightning/la_decode_mtp.py Outdated
Comment thread cula/lightning/la_verify_kvbuffer.py
Comment thread cula/lightning/la_verify_kvbuffer.py Outdated
Comment thread cula/lightning/la_verify_kvbuffer.py
Comment thread cula/lightning/la_verify_kvbuffer.py Outdated
Comment thread cula/lightning/la_state_update_kvbuffer.py
Comment thread cula/lightning/la_state_update_kvbuffer.py
Comment thread cula/ops/la_decode.py Outdated
Comment thread cula/ops/la_decode.py Outdated
fankun.fan added 3 commits June 21, 2026 13:44
Fused multi-token (MTP) Lightning Attention decode kernel for speculative
decoding: a single launch processes T draft tokens, with ILP variants and a
work-unit heuristic (get_mtp_config). Includes packed F32x2 FMA on SM100.

- cula/lightning/la_decode_mtp.py: kernel + config + shared dot/update helpers
- tests/test_la_decode_mtp.py + tests/_la_mtp_ref.py: correctness vs PyTorch ref
- benchmarks/bench_la_decode_mtp.py: vs sequential decode and FLA, with SOL model
KVBuffer-backed Lightning Attention for speculative decode verify/commit.
Verify computes each draft step's output in closed form (paper Eq. 7) with
the two dot-product GEMMs on tensor cores via inline-PTX mma.sync.m16n8k8
(TF32); state-update commits the accepted prefix into the pooled state
(paper Eq. 8), bit-equivalent to the baseline T-loop at L == T.

- cula/lightning/la_verify_kvbuffer.py: TF32 MMA verify kernel (+ shuffle variant)
- cula/lightning/la_update_kvbuffer.py: KV buffer state-update (commit) kernel
- tests/test_la_kvbuffer.py: correctness vs PyTorch ref (verify + update)
- benchmarks/bench_la_kvbuffer.py: vs SGLang verify+commit (optional), with SOL model
- test_la_kvbuffer.py: add odd-T cases (verify T=1,3,5,7; state-update T=3,7)
  to guard the BT=8 M/N padding path that handles non-even draft lengths.
- la_verify_kvbuffer.py: the shuffle launcher's SMEM byte estimate omitted the
  16B per-allocation alignment padding (4 SMEM tensors), so the declared launch
  size could fall ~12B short of actual usage and trip CUTLASS's size check. Add
  the 4*16 padding term, matching the main-kernel launcher.
@fkuner fkuner force-pushed the la-decode-kvbuffer branch 4 times, most recently from 971d524 to d564e14 Compare June 21, 2026 16:55
Comment thread benchmarks/bench_la_kvbuffer.py Outdated
Comment thread benchmarks/bench_la_decode_mtp.py Outdated
@fkuner fkuner force-pushed the la-decode-kvbuffer branch 2 times, most recently from 59af7b2 to 640983b Compare June 23, 2026 03:39
fankun.fan added 10 commits June 23, 2026 11:41
Module name now matches the public symbol linear_attention_state_update_kvbuffer.
Pure rename plus import-path updates; no behavior change.
Structural cleanup of the LA decode-MTP kernel (no semantic change),
split out of the prior pre-commit chore commit for reviewability.
Formatting/lint fixes plus inlining the shared _la_mtp_ref helper directly
into the test files. Benchmark updates included.
…nstexpr loop

Replace three explicit ilp_rows==2/==4/==8 branches with a single
range_constexpr(ilp_rows) path, mirroring the pattern already used in
la_state_update_kvbuffer.  Cuts ~550 LOC without changing semantics.
- la_verify_kvbuffer: re-check V % ilp_rows == 0 AFTER the ilp_rows->8
  promotion (the pre-promotion assert could let a partial row-block be
  silently skipped); zero the sH0 M-padding rows before GEMM1 so the MMA
  fragment is well-defined instead of consuming stale/NaN SMEM.
- assert K == 128 in the verify (MMA + shuffle) and state-update entry
  points, documenting the hardcoded head-dim assumption.
The kernel-only timing path passed shuffle-only args (use_smem_v, use_packed_fma)
to _get_compiled_verify_kvbuffer_kernel, causing a TypeError at T >= MMA_MIN_T.
Drop duplicated local benchmark_fn helpers; IQR-mean aggregation is already
the default in benchmarks.utils.benchmark_cuda_fn.
Align bench_la_decode_mtp and bench_la_kvbuffer with bench_la_decode_vs_fla:
wrapper for correctness+compile warmup, then get_compiled_*_handle for
kernel-only timing. Centralize cache-key dispatch in kernel modules.
@fkuner fkuner force-pushed the la-decode-kvbuffer branch from 640983b to a773efa Compare June 23, 2026 03:41

@icavan icavan left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment thread benchmarks/bench_la_kvbuffer.py Outdated
q_4d = torch.randn(B, T, H, K, device=device, dtype=dtype)
k_4d = torch.randn(B, T, H, K, device=device, dtype=dtype)
v_4d = torch.randn(B, T, HV, V, device=device, dtype=dtype)
state_init_kmaj = torch.randn(B, H, K, V, device=device, dtype=torch.float32) * 0.01

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The correct k major semantic should adopt B,H,V,K layout, here it's ok since V==K==128.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean K-last here?



@dsl_user_op
def _mma_m16n8k8_tf32(a0, a1, a2, a3, b0, b1, c0, c1, c2, c3, *, loc=None, ip=None):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to make sure whether TF32 is enough for long context. Need to make sure the precision is acceptable.

LGTM if the precision is good enough.

The shared GDN-derived `get_mtp_config` was suboptimal for LA on both
kernels. A grid search (B200, H=HV=64, K=V=128, B in [1..128], T in
[2,4,8]) showed the two kernels sit in opposite regimes and need
opposite tile directions:

- decode (memory-bound, 83% mem SOL @ B=128,T=8): wants SMALL tiles +
  high occupancy. New thresholds tile_v in {32,16,8} as work_units grows
  (was 64 at large WU). 3-9% faster at medium-large B; SOL 74->81%.
- verify/state-update (L1/compute-bound, 86% L1 SOL): wants LARGE tiles
  to fill m16n8k8 MMA and amortize q/k SMEM staging. New thresholds
  (64,8)/(128,8). 5-15% faster on the MMA path at large B.

Split the single shared function into two `get_mtp_config` (one per
kernel module) since one config cannot serve both regimes. state-update
reuses the verify config (both prefer large tiles).

Drop `use_smem_v` from the decode kernel: v has no cross-row reuse in
LA, so SMEM staging only added a barrier. Grid search confirmed the
direct-global path wins for every tile config — the True branch was dead
code (never compiled). Removes sVdata/sOutput SMEM, the cooperative
v-load, and the cooperative output writeback.

50/50 la_decode_mtp + la_kvbuffer tests pass.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants