Skip to content

MLX: on-device token sampling with Gumbel-max#20454

Draft
kiymetakdemir wants to merge 3 commits into
pytorch:mainfrom
kiymetakdemir:mlx-ondevice-sampling
Draft

MLX: on-device token sampling with Gumbel-max#20454
kiymetakdemir wants to merge 3 commits into
pytorch:mainfrom
kiymetakdemir:mlx-ondevice-sampling

Conversation

@kiymetakdemir

Copy link
Copy Markdown

Summary

Adds token sampling that runs inside the exported .pte for the MLX backend: a model wrapped in SamplingHead returns a sampled token id instead of [B, S, vocab] logits, avoiding the per-step logits copy to host and the host-side softmax+multinomial.

Sampling uses Gumbel-max: argmax(logits / temperature + g), g = -log(-log(u)). The only new schema primitive is a random source, RandomBitsNode, the rest reuses existing nodes. Greedy = temperature → 0. temperature is a runtime input; seed is optional.

Changes

  • schema.fbs: new RandomBitsNode (append-only union member, optional seed).
  • custom_kernel_ops/sample.py: mlx::sample op + register_fake + CPU reference.
  • ops.py: _sample_handler lowering the Gumbel-max graph.
  • runtime/MLXInterpreter.h: exec_random_bits + dispatch.
  • llm/sampling.py: SamplingHead wrapper.
  • generate.py: None-guard optional compound fields so the optional seed (de)serializes.

Notes

  • Uniform/gumbel computed in fp32 (bf16 rounds the ~1.0 clamp up → log(0)=-inf → poisons argmax).
  • Tests: custom_kernel_ops/test/test_sample.py, eager parity/distribution/determinism, export+partition lowering, and on-device e2e (incl. a bf16 large-vocab regression).

Fixes #20353

@pytorch-bot

pytorch-bot Bot commented Jun 23, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20454

Note: Links to docs will display an error until the docs builds have been completed.

❗ 2 Active SEVs

There are 2 currently active SEVs. If your PR is affected, please view them below:

❌ You can merge normally! (1 Unrelated Failure), 3 Unclassified Failures

As of commit 20af908 with merge base 6f6225c (image):

UNCLASSIFIED FAILURES - DrCI could not classify the following jobs because the workflow did not run on the merge base. The failures may be pre-existing on trunk or introduced by this PR:

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 23, 2026
@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Comment thread backends/mlx/runtime/MLXInterpreter.h
from torch import Tensor


@torch.library.custom_op("mlx::sample", mutates_args=())

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be in backends/mlx/custom_ops.py

(It's not a custom kernel op b/c it doesn't have metal shader code)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved it to custom_ops.py and relocated the test.

@metascroy

Copy link
Copy Markdown
Contributor

@claude review this code

@claude

claude Bot commented Jun 23, 2026

Copy link
Copy Markdown

Claude finished @metascroy's task in 3m 8s —— View job


Code review: MLX on-device token sampling

Read through the full diff (sample.py, _sample_handler, exec_random_bits, SamplingHead, schema + generator changes, and the tests). Overall this is clean, well-tested work — the Gumbel-max lowering is faithful to MLX's uniform → gumbel → argmax layering, the schema change is correctly append-only, and the bf16 fp32-clamp fix is a nice catch with a dedicated regression test. Comments below.

Addressing the two inline review comments

1. MLXInterpreter.h:1703uint32 vs n.width (your question)

Yes, they must stay in sync, and right now they're only implicitly consistent:

check_allocation_bounded(shape, uint32, "random_bits");   // hard-coded 4-byte dtype
...
st.set_tensor(n.out, random::bits(shape, n.width, key, s)); // n.width drives real element size

The bound check assumes 4 bytes/element while the actual allocation is n.width bytes/element. They agree only because _sample_handler always emits width=4 and the schema default is 4. If anyone ever sets a different width, the bound check silently uses the wrong size and random::bits produces a non-uint32 dtype that the rest of the graph (the AsTypeNode → /uint32_max chain) assumes is uint32.

Since nothing emits a non-4 width, the simplest hardening is to assert the invariant rather than carry a flexible-but-unsupported field:

if (n.width != 4) {
  throw std::runtime_error("random_bits: only width=4 (uint32) is supported");
}

That makes the uint32 in the bound check provably correct. (Alternatively, map n.width → dtype and pass that to check_allocation_bounded, but that's more machinery than the single supported path warrants.) Fix this →

2. sample.py location — move to custom_ops.py

Agreed. mlx::sample has no Metal shader; it's a graph-lowered op like mlx::rope / mlx::gather_mm / mlx::gather_qmm, which all live in backends/mlx/custom_ops.py. custom_kernel_ops/ is for ops with actual kernel source (tq_norm, tq4_compress, gated_delta_rule, …). Moving it requires updating the two import sites:

  • backends/mlx/ops.py:22 (import ...custom_kernel_ops.sample)
  • backends/mlx/custom_kernel_ops/test/test_sample.py:27 (and ideally relocate the test to backends/mlx/test/).

Fix this →

Additional observations

Reference op vs. on-device graph are two different implementations. The CPU reference in sample.py uses torch.rand directly (no nextafter clamp), while the lowered graph uses bits/uint32_max → min(·, nextafter(1,0)) → gumbel. The eager tests (test_matches_independent_gumbel_reference) validate the reference against an independent torch impl, and the e2e tests validate the graph distributionally — but the two paths are never checked to agree with each other, and seeded determinism differs (torch RNG host-side vs. MLX RNG on-device). That's a reasonable design given the host-copy avoidance goal, but worth a one-line comment in sample.py noting the reference is for export/shape only and is not bit-identical to the delegated path, so a future reader doesn't assume seeded parity across host/device.

temperature = 0 → division by zero. Greedy is documented as "temperature → 0", and tests use 1e-4. Since temperature is a runtime input, an actual 0.0 produces inf/nan in logits / temperature. Not a blocker (callers are expected to pass a small epsilon), but a brief note in SamplingHead's docstring that temperature must be > 0 would prevent a confusing on-device result.

ArgmaxNode(axis=-1) — confirm exec_argmax/MLX argmax accepts a negative axis at runtime. Logits are [B, vocab] so -1 is unambiguous, but if MLX expects a non-negative axis it'd be safer to emit the resolved positive axis (the handler has n.args[0].meta available). The e2e test would catch this if it's a problem; flagging in case the e2e runner wasn't exercised on this path.

Minor: emit_lifted_constant(P, 4294967295.0, ...) and the nextafter constant are recomputed each call — fine, but if make_or_get_constant-style dedup applies these will already be shared; no action needed.

None of these are blockers. The two inline items (width invariant + file relocation) are the ones worth doing before merge.
· branch mlx-ondevice-sampling

Comment thread backends/mlx/test/test_sample.py Outdated
return self.head(logits, temperature=temperature)


class TestSampleOp(unittest.TestCase):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These test cases should be inside backends/mlx/test/test_ops.py

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved TestSampleOp into test_ops.py.

Comment thread backends/mlx/test/test_sample.py Outdated
pte,
)
self.assertEqual(count_mlx_delegate_segments(pte), 1)
counts = get_mlx_node_counts(pte)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See test_ops.py. There are utilities for testing node counts

Comment thread backends/mlx/serialization/generate.py Outdated
# optional_str carries its own None handling; other compound offset
# fields (int_or_vid, etc.) must be guarded when optional so a None
# value is serialized as an absent field rather than crashing.
if fld.required or kind == "optional_str":

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these changes needed? You don't have a string arg on the new node?

Gumbel-max sampling from softmax(logits / temperature).
logits: [B, vocab]
temperature: scalar float tensor (runtime input)
seed: scalar int tensor or None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it not export if seed is an int?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, a plain Python int doesn't export.

Comment thread backends/mlx/ops.py Outdated
AsTypeNode(
x=P.slot_to_tid(g_f32),
out=P.slot_to_tid(g),
scalar_type=torch_dtype_to_scalar_type(dt),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have this at all? Why not compute divide/argmax/etc in same fp32 dtype? The final output type is integer

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, fixed this.

Comment thread backends/mlx/custom_ops.py Outdated
"""
Gumbel-max sampling from softmax(logits / temperature).
logits: [B, vocab]
temperature: scalar float tensor (runtime input)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add top-p as well?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

Comment thread backends/mlx/ops.py Outdated
logits, temperature = args[0], args[1]
seed = args[2] if len(args) > 2 and args[2] is not None else None

dt = n.args[0].meta["val"].dtype

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use emit_if_else to specialize on temperature 0 as argmax?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this.

Comment thread .github/workflows/mlx.yml

echo "::group::Run Python unit tests"
# test_ops.py is listed for its unittest classes; run_all_tests.py only
# runs its OpTestCase classes.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't this run it twice?

What tests are non-OpTestCase classes in test_ops.py?

return torch.ops.mlx.sample(logits, t, p, s)


class TestSampleOp(unittest.TestCase):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can these move into test_sample?

test_ops.py should only have OpTestCase because of how its driven in CI

class TestSampleOp(unittest.TestCase):
"""Eager reference behavior of mlx::sample (no export / no runtime)."""

def test_greedy_parity_small_temperature(self):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case and the next greedy one can be handled in a OpTestCase using regular. machinery (b/c it's bit exact)

other = run(logits, torch.tensor(1.0), torch.tensor(124, dtype=torch.int64))
self.assertFalse(torch.equal(first, other))

def test_export_strict_with_graph_inputs(self):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't OpTestCase in test_ops already testing export?

"from __future__ import annotations",
"",
"from typing import List, Tuple, Dict",
"from typing import Dict, List, Optional, Tuple",

@metascroy metascroy Jun 25, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need any generate.py changes if we make seed Vid (optional)?

You already emit seed as a Vid (not an int literal).

@metascroy

Copy link
Copy Markdown
Contributor

@claude review this code

@claude

claude Bot commented Jun 25, 2026

Copy link
Copy Markdown

Claude finished @metascroy's task in 32m 1s —— View job


Code review: MLX on-device token sampling (round 2)

Re-reviewed the latest revision (top-p nucleus, temperature==0 greedy via IfNode, fp32 sampling chain, test reorg, optional-seed serialization). The lowering is correct and the bf16/fp32 reasoning is sound. Below I focus on the still-open threads from your latest review pass plus a couple of new observations.

Open review threads

1. test_ops.py should contain only OpTestCase (CI driver constraint) — not yet resolved

TestSampleOp(unittest.TestCase) is still in test_ops.py:7726. Earlier the class was moved into test_ops.py, but your later comment ("test_ops.py should only have OpTestCase because of how it's driven in CI") asks for the opposite: the plain unittest.TestCase eager tests should live in test_sample.py, leaving test_ops.py with only the OpTestCase subclasses (SampleSeededTest, SampleUnseededTest, SampleTopPTest). Right now test_ops.py mixes both. Fix this →

2. CI mlx.yml:78test_ops.py in the pytest list runs twice / becomes unnecessary

This is the consequence of (1). run_all_tests.py already executes the OpTestCase classes in test_ops.py; the only reason to also list test_ops.py under pytest is to pick up the unittest.TestCase class. Once TestSampleOp moves to test_sample.py, test_ops.py should be dropped from the pytest invocation (it would otherwise re-import/re-collect with nothing pytest-runnable, and the comment about "its unittest classes" no longer applies). Note: I can't edit .github/workflows/ myself (GitHub App permissions), so this one is yours regardless.

3. The two greedy cases are bit-exact → make them OpTestCase, dropping skip_comparison

test_greedy_parity_small_temperature (T=1e-4) and test_greedy_temperature_zero (T=0) are both deterministic on device:

  • T=0 takes the IfNode greedy branch → argmax(logits), no RNG.
  • T=1e-4 makes logits/T dominate the O(1) Gumbel noise, so argmax collapses to argmax(logits) regardless of which RNG produced the noise.

So both are host/device bit-identical and belong as real OpTestCase runs (no skip_comparison), which exercises export + lowering + on-device parity in one shot via the standard harness — exactly what you flagged. The sampling cases that genuinely aren't bit-identical (SampleSeededTest etc.) correctly keep skip_comparison=True.

4. test_sample.py:112 export tests vs. OpTestCase export coverage

TestSampleExport (export-only, runs .module() on host) overlaps with what the OpTestCase lowering tests already cover (export → partition → node counts). The genuinely unique host-only assertions are runtime-input semantics: same exported program at two temperatures (test_runtime_temperature_single_export) and seed-independent-of-host-RNG (test_seeded_export_reproducible_no_host_rng). test_export_strict_with_graph_inputs (just counts 3 user inputs) is fully subsumed by the OpTestCase export and can go. I'd keep the two semantic tests, drop the pure-export one.

5. generate.py:575 — do we need these changes if seed is an optional Vid?

Short answer: yes, the generate.py changes are still needed, and switching seed from IntOrVid to Vid would not remove them — it would move the problem to a path that isn't handled.

Why: before this PR no schema had an optional compound/struct field. seed is always emitted via P.to_int_or_vid(seed_val) where seed_val is a Slot, so it's IntOrVid.from_vid(...) — never a literal — and may be absent. The codegen needed three gaps closed for an absent compound field:

  • Python _build_int_or_vid returning None for None input, and the Add site guarding if seed_off is not None (schema.fbs seed is optional → no AddSeed call when absent).
  • C++ loader presence-guarding convert_* (it dereferences a null FlatBuffer pointer otherwise).
  • _fbs_type_to_cpp mapping optional IntOrVidstd::optional<IntOrVid>.

If you instead declared seed: Vid (optional), the Python emitter would route through _PY_STRUCT_CREATORCreateVid(builder, op.seed.idx) inlined unconditionally, which throws on None and has no None-guard today — so you'd have to add optional-Vid support there instead. Net: IntOrVid reuses the prebuilt-offset path that you've now made None-safe, with less new machinery than optional-Vid would require. So the current choice is the lower-churn one; the generate.py diff is justified. (If you do prefer Vid for clarity since it's always a vid, that's fine too — just be aware it needs its own None-guard in the struct-creator path.)

New observations (non-blocking)

  • MLXInterpreter.h:1708 bound check. Generalizing to widths {1,2,4} is fine, but check_allocation_bounded(shape, uint32, ...) now intentionally over-estimates for width 1/2. The comment says so, which is good — just confirming the guard is conservative-safe (it bounds above the real allocation), not unsafe.

  • Negative temperature diverges between reference and device. The device emits GreaterNode(temperature, 0) → any temperature < 0 takes the greedy branch, while the CPU reference only treats == 0 as greedy and would compute inverted-sampling for a negative value. Callers shouldn't pass negatives, but the two paths disagree there. A one-line note in SamplingHead's docstring ("temperature must be ≥ 0") or clamping in the reference would keep them consistent.

  • sample.py/custom_ops.py reference honesty. The docstring already states the reference isn't bit-identical to the delegate — 👍, that addresses my prior comment.

Summary

Logic and lowering look correct and well-tested. The remaining work is test organization, all from your own latest pass: move TestSampleOp out of test_ops.py (1), drop test_ops.py from the pytest list afterward (2), promote the two bit-exact greedy cases to OpTestCase (3), and trim the redundant export-only test (4). Item (5) is answered above — keep as-is.
· branch mlx-ondevice-sampling

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

# Good First Issue: In-Model Sampling Head for the MLX Backend (Gumbel Sampling)

2 participants