
[ExecuTorch][WebGPU] q4gsw: route M==1 decode to a cooperative GEMV #20457

Open
JulianCloudNTH wants to merge 4 commits into gh/JulianCloudNTH/52/base from gh/JulianCloudNTH/52/head


JulianCloudNTH commented Jun 23, 2026

Contributor

Stack from ghstack (oldest at bottom):

Add optimized GEMV kernel for M==1 decode path in q4gsw quantized-linear.

Problem: The register-tiled GEMM (from D109250327) wastes 75% of each 4×N tile when M=1, as only 1 of 4 rows is used.

Solution: Add a cooperative GEMV kernel that routes M==1 decode to a more efficient path:

  • GEMV: 64 lanes per workgroup cooperate over the K dimension. Each lane loads u32 words (8 K-values per word), and the per-lane partial sums are reduced via shared memory
  • GEMM: M>1 prefill continues using the tiled GEMM

Routing Logic (build-time selection, M is static per graph):

  • Use GEMV when: M==1 && K%8==0 && group_size%8==0
  • Otherwise: Fall back to tiled GEMM
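The routing rule above can be written as a small predicate. This is a minimal Python sketch, not the C++ in QuantizedLinear.cpp; the function name `select_q4gsw_kernel` and the string return values are hypothetical and used only for illustration.

```python
def select_q4gsw_kernel(m: int, k: int, group_size: int) -> str:
    """Return "gemv" or "gemm" following the routing rule in the PR description.

    Hypothetical helper: the real selection happens at build time in C++.
    """
    # GEMV needs a single output row and whole u32 words (8 nibbles each)
    # that never straddle a quantization-group boundary.
    if m == 1 and k % 8 == 0 and group_size % 8 == 0:
        return "gemv"
    # Everything else (prefill with M > 1, or unaligned shapes) uses the tiled GEMM.
    return "gemm"
```

Because M is static per graph, a check like this only needs to run once when the graph is built, not per dispatch.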

Constraints:

  • K%8==0: Kernel loads 8 K-values per u32 word
  • group_size%8==0: Ensures no quantization-group boundary splits a word (validated via CPU cross-check)
  • Llama models (group_size=32/64) satisfy both constraints

Implementation:

  • New kernel: q4gsw_linear_coop4.wgsl (fixed 64-lane workgroup)
  • New utility: clamp_workgroup_count() for grid-stride dispatch (vs compute_1d_workgroup_count which throws)
  • Shared infrastructure: Same bind layout, Params, weight format

Performance: Keeps decode at the measured bandwidth plateau and avoids the M=1 tile waste. The GEMV uses a different reduction order, so its results agree with the GEMM to fp-rounding but are not bit-exact.
@exported-using-ghexport

Differential Revision: D109250570



pytorch-bot Bot commented Jun 23, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20457

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 3 New Failures, 2 Unrelated Failures

As of commit d266a3d with merge base 68bb668 (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions


This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@JulianCloudNTH

Contributor Author

@claude review if there are any bugs

@claude

claude Bot commented Jun 24, 2026


Claude finished @JulianCloudNTH's task in 4m 26s.


Bug review: q4gsw coop4 GEMV

  • Read changed files
  • Analyze GEMV kernel (q4gsw_linear_coop4.wgsl)
  • Analyze routing/dispatch logic (QuantizedLinear.cpp)
  • Analyze clamp_workgroup_count util
  • Cross-check weight/scale layout vs. the existing GEMM
  • Post review

Verdict: no functional bugs found. The kernel, routing, and dispatch math all check out under the stated constraints. Details below.

Correctness checks that pass ✅

1. Weight layout matches the GEMM exactly. In q4gsw_linear.wgsl, byte j of row n holds K=2j (low nibble) and K=2j+1 (high nibble). In coop4, word w covers bytes [4w, 4w+3]; for byte bi, kk = bi<<1 and k0 = w<<3, so lo → t_input[in_base + 8w + 2·bi] (even K, low nibble) and hi → +1 (odd K, high nibble). Identical packing convention. ✅
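The index mapping in point 1 can be checked on the CPU. The sketch below unpacks raw nibbles from a packed row two ways: byte by byte (the GEMM convention described above) and word by word (the coop4 convention). It assumes the u32 words are read in little-endian byte order, and it deals only with raw 4-bit codes, not with any zero-point or sign offset; the function names are hypothetical.

```python
import struct

def unpack_bytewise(row_bytes: bytes) -> list[int]:
    """Byte j holds K=2j in the low nibble and K=2j+1 in the high nibble."""
    out = []
    for b in row_bytes:
        out.append(b & 0xF)
        out.append(b >> 4)
    return out

def unpack_wordwise(row_bytes: bytes) -> list[int]:
    """Word w covers bytes [4w, 4w+3]; assumes little-endian u32 loads."""
    num_words = len(row_bytes) // 4
    words = struct.unpack(f"<{num_words}I", row_bytes)
    out = [0] * (num_words * 8)
    for w, word in enumerate(words):
        k0 = w << 3                      # first K index covered by this word
        for bi in range(4):
            byte = (word >> (8 * bi)) & 0xFF
            kk = bi << 1                 # K offset of this byte within the word
            out[k0 + kk] = byte & 0xF    # even K, low nibble
            out[k0 + kk + 1] = byte >> 4 # odd K, high nibble
    return out
```

If the two functions agree on arbitrary rows, the word-wise kernel is reading the same packing the byte-wise kernel expects.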

2. The group_size % 8 == 0 constraint genuinely prevents a group boundary from splitting a word. k0 = w<<3 is a multiple of 8, and group boundaries sit at multiples of gs (themselves multiples of 8). The first boundary strictly after k0 is therefore a multiple of 8 greater than k0, i.e. at least k0 + 8 > k0 + 7, so all 8 K-values in a word share floor(k0/gs). Using one scale per word is therefore correct. ✅
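The argument in point 2 is easy to confirm by brute force. This is a small sketch with a hypothetical helper name; it checks, for every u32 word in a row, that all 8 K-values fall into the same quantization group.

```python
def word_stays_in_one_group(k: int, group_size: int) -> bool:
    """True if no 8-wide word of a K-length row straddles a group boundary."""
    for w in range(k // 8):
        k0 = w << 3
        # Collect the group index of each of the 8 K-values in this word.
        groups = {(k0 + i) // group_size for i in range(8)}
        if len(groups) != 1:
            return False
    return True
```

With a group size that is not a multiple of 8 (for example 12), the second word covers K = 8..15 and crosses the boundary at K = 12, which is exactly the case the routing constraint excludes.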

3. num_words and row_words agree. With K%8==0: K_packed = K/2, K_packed % 4 == 0, so row_words = K_packed>>2 = K/8 = K>>3 = num_words. The row stride and loop bound are consistent, and wbase = n*row_words is u32-aligned. ✅

4. All buffer accesses are in-bounds. weight: max wbase+w = N·(K/8) − 1 < N·K_packed/4; input (M=1): max k0+7 = K−1; scales: max group (K−8)/gs < ceil(K/gs) ≤ num_groups, n < N ≤ padded_N; out: idx < M·N; bias guarded by has_bias. ✅

5. Barriers are in uniform control flow. The grid-stride outer loop depends only on wid.x/ngrp.x (both workgroup-uniform), so every lane executes the same iteration count and reaches each workgroupBarrier() together — no barrier-divergence UB. The trailing barrier before idx += stride correctly prevents the next iteration's partial[lid.x] = acc from racing the reduction reads. The tree reduction (s = 32→1, partial[lid.x+s], max index 63) stays in bounds, and idle lanes (when num_words < 64) contribute 0.0 harmlessly. ✅
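The reduction pattern in point 5 can be modelled on the CPU as below. This is a sketch under stated assumptions, not a port of q4gsw_linear_coop4.wgsl: it assumes lane `lid` handles words `lid, lid + 64, ...`, and it takes already-dequantized weights as floats so that no quantization format has to be guessed. The name `coop_dot` is hypothetical.

```python
WG_SIZE = 64  # mirrors the fixed 64-lane workgroup described above

def coop_dot(x: list[float], wq: list[float]) -> float:
    """Dot product of one weight row with x, reduced like a 64-lane workgroup."""
    num_words = len(x) // 8
    partial = [0.0] * WG_SIZE  # stands in for workgroup shared memory
    # Phase 1: each lane accumulates its strided share of the 8-wide words
    # (the lane-to-word assignment is an assumption of this sketch).
    for lid in range(WG_SIZE):
        acc = 0.0
        for w in range(lid, num_words, WG_SIZE):
            k0 = w << 3
            for i in range(8):
                acc += x[k0 + i] * wq[k0 + i]
        partial[lid] = acc  # lanes with no work store 0.0
    # A workgroupBarrier() would separate the two phases on the GPU.
    s = WG_SIZE // 2
    while s > 0:
        # Lanes lid < s read partial[lid + s]; highest index read is 31 + 32 = 63.
        for lid in range(s):
            partial[lid] += partial[lid + s]
        s >>= 1  # another barrier would follow each step on the GPU
    return partial[0]
```

Within each step the reads (indices >= s) and writes (indices < s) are disjoint, which is why the sequential loop here gives the same result as lanes running in lockstep between barriers.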

6. Dispatch + grid-stride. clamp_workgroup_count caps ngrp.x at the device limit and the kernel strides by ngrp.x, so large N (> 65535) is still fully covered. constantCount = 0 for GEMV is correct since the shader hardcodes @workgroup_size(64) (≤ WebGPU's min maxComputeInvocationsPerWorkgroup of 256, so always valid). ✅
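The coverage claim in point 6 can be illustrated with a short model. This is a Python stand-in, not the PR's C++ utility: the signature of `clamp_workgroup_count` here is assumed, the default limit of 65535 is the WebGPU default for workgroups per dimension, and `rows_covered` is a hypothetical helper that assumes one workgroup per output element per stride step.

```python
def clamp_workgroup_count(requested: int, device_limit: int = 65535) -> int:
    """Cap the dispatch size at the device limit instead of throwing (assumed behaviour)."""
    return min(requested, device_limit)

def rows_covered(n: int, num_workgroups: int) -> list[int]:
    """Count how many times each output index is visited by a grid-stride loop."""
    hits = [0] * n
    for wid in range(num_workgroups):
        idx = wid
        while idx < n:
            hits[idx] += 1
            idx += num_workgroups  # stride by the number of dispatched workgroups
    return hits
```

For example, `rows_covered(50, clamp_workgroup_count(50, device_limit=7))` visits each of the 50 outputs exactly once even though only 7 workgroups are dispatched.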

Minor nits (non-blocking)

  • q4gsw_linear_coop4.wgsl:38-39: num_words and row_words are provably equal under the routing constraints, so computing both is slightly redundant. Harmless, but a one-line comment noting they coincide (or dropping one) would reduce confusion.
  • QuantizedLinear.cpp:125-126: wg_size is computed unconditionally but only consumed on the GEMM path (the GEMV uses a fixed 64). It's still referenced at line 225, so there is no unused-var warning, but it reads as if it applies to both paths.
  • The PR notes that the GEMV uses a different reduction order, so results agree only to fp-rounding (not bit-exact). Confirmed: the math is associative-equivalent (input·q·scale vs input·(q·scale)), so this is expected and not a bug. Any cross-check test should compare with a tolerance rather than exactly.
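A tolerance-based cross-check along the lines of the last nit might look like the sketch below. It is illustrative only: it runs in Python's float64 rather than the GPU's f32, the two functions are hypothetical stand-ins for the GEMM and GEMV accumulation orders, and the tolerances would need to be loosened for f32 results.

```python
import math
import random

def dot_sequential(x: list[float], w: list[float]) -> float:
    """Left-to-right accumulation, standing in for one reduction order."""
    acc = 0.0
    for a, b in zip(x, w):
        acc += a * b
    return acc

def dot_pairwise(x: list[float], w: list[float]) -> float:
    """Tree-style accumulation, standing in for a different reduction order."""
    vals = [a * b for a, b in zip(x, w)]
    while len(vals) > 1:
        if len(vals) % 2:
            vals.append(0.0)  # pad odd-length levels
        vals = [vals[i] + vals[i + 1] for i in range(0, len(vals), 2)]
    return vals[0]

rng = random.Random(0)  # fixed seed so the check is reproducible
x = [rng.uniform(-1.0, 1.0) for _ in range(4096)]
w = [rng.uniform(-1.0, 1.0) for _ in range(4096)]
ref, alt = dot_sequential(x, w), dot_pairwise(x, w)
# Compare with a tolerance; an exact == may fail purely due to rounding order.
close = math.isclose(ref, alt, rel_tol=1e-9, abs_tol=1e-9)
```

An absolute tolerance is included because a dot product of mixed-sign values can land near zero, where a purely relative tolerance becomes too strict.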

The 3 reported CI "new failures" are docker exec ... failed with exit code 1/92 errors, and the QNN/macos ones are flagged as broken-trunk. They look like infra issues unrelated to this change.

JulianCloudNTH requested a review from psiddh June 24, 2026 17:16

Labels

CLA Signed, meta-exported

3 participants