[ExecuTorch][WebGPU] q4gsw: route M==1 decode to a cooperative GEMV#20457
JulianCloudNTH wants to merge 4 commits into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20457
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEVs
There are 1 currently active SEVs. If your PR is affected, please view them below:
❌ 3 New Failures, 2 Unrelated Failures
As of commit d266a3d with merge base 68bb668:
NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a
@claude review if there are any bugs
Claude finished @JulianCloudNTH's task in 4m 26s
Bug review: q4gsw coop4 GEMV
Verdict: no functional bugs found. The kernel, routing, and dispatch math all check out under the stated constraints. Details below.

Correctness checks that pass ✅
1. Weight layout matches the GEMM exactly. In
2. The
3.
4. All buffer accesses are in-bounds. weight: max
5. Barriers are in uniform control flow. The grid-stride outer loop depends only on
6. Dispatch + grid-stride.

Minor nits (non-blocking)
The 3 reported CI "new failures" are
Stack from ghstack (oldest at bottom):
Add optimized GEMV kernel for M==1 decode path in q4gsw quantized-linear.
Problem: The register-tiled GEMM (from D109250327) wastes 75% of each 4×N tile when M=1, as only 1 of 4 rows is used.
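The 75% figure follows directly from the tile shape. A minimal sketch of the arithmetic, assuming the only source of waste is unused rows in a 4-row tile (the function name `tile_row_utilization` is hypothetical and not part of the PR):

```python
import math


def tile_row_utilization(m: int, tile_rows: int = 4) -> float:
    """Fraction of tile rows that hold real output rows for an M-row input.

    Illustrative helper only; tile_rows=4 mirrors the 4xN tile named in
    the PR description.
    """
    tiles = math.ceil(m / tile_rows)  # number of row tiles launched
    return m / (tiles * tile_rows)    # useful rows / rows computed


if __name__ == "__main__":
    for m in (1, 2, 4, 5):
        used = tile_row_utilization(m)
        print(f"M={m}: {used:.0%} of tile rows used, {1 - used:.0%} wasted")
```

For M=1 this gives 1 useful row out of 4, which is the 75% waste the description cites; the waste disappears once M is a multiple of the tile height.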
Solution: Add a cooperative GEMV kernel that routes M==1 decode to a more efficient path:
Routing Logic (build-time selection, M is static per graph):
Constraints:
Implementation:
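The routing idea can be sketched as a build-time selection on the static M. This is an illustration, not the PR's code: the enum, the function name, and the `gemv_constraints_ok` flag are hypothetical. The flag stands in for the constraints, which are not spelled out in this text, and falling back to the GEMM when they are unmet is an assumption.

```python
from enum import Enum


class Q4gswKernel(Enum):
    """Hypothetical labels for the two q4gsw kernel paths."""
    TILED_GEMM = "tiled_gemm"  # register-tiled GEMM (general M)
    COOP_GEMV = "coop_gemv"    # cooperative GEMV (M == 1 decode)


def select_kernel(m: int, gemv_constraints_ok: bool = True) -> Q4gswKernel:
    """Pick a kernel once at graph-build time, since M is static per graph.

    gemv_constraints_ok is a placeholder for the PR's constraints, which
    are not listed in this text. Falling back to the GEMM when they are
    not met is an assumption of this sketch.
    """
    if m == 1 and gemv_constraints_ok:
        return Q4gswKernel.COOP_GEMV
    return Q4gswKernel.TILED_GEMM


if __name__ == "__main__":
    print(select_kernel(1))    # decode step
    print(select_kernel(128))  # prefill-sized batch
```

Because the decision depends only on a shape known when the graph is built, it costs nothing at dispatch time.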
Performance: Keeps decode at the measured bandwidth plateau and avoids the M=1 tile waste. The GEMV uses a different reduction order than the GEMM, so results agree to within floating-point rounding but are not bit-exact.
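Why a different reduction order changes the low bits can be shown with a small model of the two summation orders. This is a sketch under assumptions: it uses Python's 64-bit floats rather than the shader's precision, and the lane count of 4 and the tree reduction are illustrative choices, not details confirmed by the PR.

```python
def sequential_dot(x, w):
    """Dot product accumulated left to right in a single accumulator."""
    acc = 0.0
    for xi, wi in zip(x, w):
        acc += xi * wi
    return acc


def cooperative_dot(x, w, lanes=4):
    """Dot product where each lane sums a strided slice of K, followed by
    a tree reduction across lanes. lanes must be a power of two.

    Models a cooperative reduction in general; not the PR's kernel.
    """
    partial = [0.0] * lanes
    for k, (xi, wi) in enumerate(zip(x, w)):
        partial[k % lanes] += xi * wi  # lane k % lanes owns element k
    stride = lanes // 2
    while stride > 0:                  # pairwise combine: 4 -> 2 -> 1
        for lane in range(stride):
            partial[lane] += partial[lane + stride]
        stride //= 2
    return partial[0]


if __name__ == "__main__":
    x = [0.1 * (i + 1) for i in range(64)]
    w = [1.0 / (i + 3) for i in range(64)]
    a, b = sequential_dot(x, w), cooperative_dot(x, w)
    print(a, b, abs(a - b))
```

Floating-point addition is not associative, so the two orders can differ in the last bits while remaining within rounding error of each other. Comparisons between the GEMV and GEMM paths are therefore better done with a tolerance than with exact equality.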
@exported-using-ghexport
Differential Revision: D109250570