MLX: on-device token sampling with Gumbel-max #20454
    from torch import Tensor


    @torch.library.custom_op("mlx::sample", mutates_args=())
This should be in backends/mlx/custom_ops.py
(It's not a custom kernel op b/c it doesn't have Metal shader code)
Moved it to custom_ops.py and relocated the test.
@claude review this code
Claude finished @metascroy's task in 3m 8s.
Code review: MLX on-device token sampling
Read through the full diff ( […]
Addressing the two inline review comments
1. Yes, they must stay in sync, and right now they're only implicitly consistent:

    check_allocation_bounded(shape, uint32, "random_bits"); // hard-coded 4-byte dtype
    ...
    st.set_tensor(n.out, random::bits(shape, n.width, key, s)); // n.width drives real element size

The bound check assumes 4 bytes/element while the actual allocation is […] Since nothing emits a non-4 width, the simplest hardening is to assert the invariant rather than carry a flexible-but-unsupported field:

    if (n.width != 4) {
      throw std::runtime_error("random_bits: only width=4 (uint32) is supported");
    }

That makes the […]
2. Agreed. […]
Additional observations
Reference op vs. on-device graph are two different implementations. The CPU reference in […]
Minor: […]
None of these are blockers. The two inline items (width invariant + file relocation) are the ones worth doing before merge.
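For illustration only, a minimal Python model of the hardening suggested in item 1. The real check lives in the C++ runtime; `check_random_bits` and `max_bytes` are hypothetical names, not part of the PR. The point of the sketch is that once the width is pinned to 4, the size used for the bound check and the size actually allocated cannot diverge.

```python
def check_random_bits(shape, width, max_bytes):
    """Return the byte size of a random-bits tensor, or raise if it is invalid."""
    # Pin the element width so the bound check and the allocation cannot disagree.
    if width != 4:
        raise ValueError("random_bits: only width=4 (uint32) is supported")
    numel = 1
    for dim in shape:
        if dim < 0:
            raise ValueError("random_bits: negative dimension")
        numel *= dim
    # Same width is used for the bound check and (conceptually) the allocation.
    nbytes = numel * width
    if nbytes > max_bytes:
        raise ValueError("random_bits: allocation exceeds bound")
    return nbytes


print(check_random_bits((1, 32000), 4, 1 << 20))  # 128000
```

An alternative would be to feed `width` into the bound check itself, but that keeps a field that no emitter currently uses.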
    return self.head(logits, temperature=temperature)


    class TestSampleOp(unittest.TestCase):
These test cases should be inside backends/mlx/test/test_ops.py
Moved TestSampleOp into test_ops.py.
    pte,
    )
    self.assertEqual(count_mlx_delegate_segments(pte), 1)
    counts = get_mlx_node_counts(pte)
See test_ops.py. There are utilities for testing node counts
    # optional_str carries its own None handling; other compound offset
    # fields (int_or_vid, etc.) must be guarded when optional so a None
    # value is serialized as an absent field rather than crashing.
    if fld.required or kind == "optional_str":
Why are these changes needed? You don't have a string arg on the new node?
    Gumbel-max sampling from softmax(logits / temperature).
    logits: [B, vocab]
    temperature: scalar float tensor (runtime input)
    seed: scalar int tensor or None
Does it not export if seed is an int?
Yes, a plain Python int doesn't export.
    AsTypeNode(
        x=P.slot_to_tid(g_f32),
        out=P.slot_to_tid(g),
        scalar_type=torch_dtype_to_scalar_type(dt),
Should we have this at all? Why not compute divide/argmax/etc in same fp32 dtype? The final output type is integer
    """
    Gumbel-max sampling from softmax(logits / temperature).
    logits: [B, vocab]
    temperature: scalar float tensor (runtime input)
Can we add top-p as well?
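For context, a minimal sketch of what top-p (nucleus) filtering means when placed in front of Gumbel-max. This is plain host-side Python, not the PR's lowering; `top_p_mask` is a hypothetical helper. It keeps the smallest set of highest-probability tokens whose cumulative mass reaches `top_p` and sets every other logit to `-inf`, so the masked tokens can never win the subsequent argmax.

```python
import math


def top_p_mask(logits, temperature, top_p):
    """Mask logits outside the nucleus with -inf (hypothetical helper)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]  # subtract max for stability
    total = sum(exps)
    probs = [e / total for e in exps]
    # Visit tokens from most to least probable and stop once the mass is reached.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, mass = set(), 0.0
    for i in order:
        keep.add(i)
        mass += probs[i]
        if mass >= top_p:
            break
    return [x if i in keep else -math.inf for i, x in enumerate(logits)]


masked = top_p_mask([2.0, 1.0, 0.0, -1.0], temperature=1.0, top_p=0.9)
print(masked)  # [2.0, 1.0, 0.0, -inf]
```

The sort is the expensive part, which is one reason an on-device version may be structured differently from this sketch.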
    logits, temperature = args[0], args[1]
    seed = args[2] if len(args) > 2 and args[2] is not None else None

    dt = n.args[0].meta["val"].dtype
Can we use emit_if_else to specialize on temperature 0 as argmax?
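To make the suggestion concrete, a rough sketch of the control flow in plain Python. It does not use `emit_if_else` or any MLX builder API, and `sample_token` is a hypothetical name. The greedy branch skips the division entirely, so `temperature == 0` never reaches `logits / temperature`.

```python
import math
import random


def argmax(values):
    # Returns the first index of the maximum value.
    return max(range(len(values)), key=lambda i: values[i])


def sample_token(logits, temperature, rng):
    """Greedy when temperature is exactly 0, Gumbel-max otherwise (sketch)."""
    if temperature == 0.0:
        # Specialized branch: no noise, no division by zero.
        return argmax(logits)
    noisy = []
    for logit in logits:
        # Clamp u into the open interval (0, 1) so both logs are finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        noisy.append(logit / temperature - math.log(-math.log(u)))
    return argmax(noisy)


rng = random.Random(0)
greedy = sample_token([0.1, 3.0, -2.0], 0.0, rng)
sampled = sample_token([0.1, 3.0, -2.0], 1.0, rng)
```

With the branch in place, the greedy path is deterministic and independent of the seed.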
…hain; reorg sample tests
    echo "::group::Run Python unit tests"
    # test_ops.py is listed for its unittest classes; run_all_tests.py only
    # runs its OpTestCase classes.
Wouldn't this run it twice?
What tests are non-OpTestCase classes in test_ops.py?
    return torch.ops.mlx.sample(logits, t, p, s)


    class TestSampleOp(unittest.TestCase):
Can these move into test_sample?
test_ops.py should only have OpTestCase because of how it's driven in CI
    class TestSampleOp(unittest.TestCase):
        """Eager reference behavior of mlx::sample (no export / no runtime)."""

        def test_greedy_parity_small_temperature(self):
This case and the next greedy one can be handled in an OpTestCase using the regular machinery (b/c it's bit-exact)
    other = run(logits, torch.tensor(1.0), torch.tensor(124, dtype=torch.int64))
    self.assertFalse(torch.equal(first, other))

    def test_export_strict_with_graph_inputs(self):
Isn't OpTestCase in test_ops already testing export?
    "from __future__ import annotations",
    "",
    "from typing import List, Tuple, Dict",
    "from typing import Dict, List, Optional, Tuple",
Do we need any generate.py changes if we make seed Vid (optional)?
You already emit seed as a Vid (not an int literal).
@claude review this code
Claude finished @metascroy's task in 32m 1s.
Code review: MLX on-device token sampling (round 2)
Re-reviewed the latest revision (top-p nucleus, […]
Open review threads
1. […]
2. CI […] This is the consequence of (1).
3. The two greedy cases are bit-exact → make them […]
So both are host/device bit-identical and belong as real […]
4. […]
5. Short answer: yes, the generate.py changes are still needed, and switching […]
Why: before this PR no schema had an optional compound/struct field. […]
If you instead declared […]
New observations (non-blocking)
[…]
Summary
Logic and lowering look correct and well-tested. The remaining work is test organization, all from your own latest pass: move […]
Summary
Adds token sampling that runs inside the exported .pte for the MLX backend: a model wrapped in SamplingHead returns a sampled token id instead of [B, S, vocab] logits, avoiding the per-step copy of logits to the host and the host-side softmax + multinomial.
Sampling uses Gumbel-max: argmax(logits / temperature + g), where g = -log(-log(u)) and u is uniform noise on (0, 1). The only new schema primitive is a random source, RandomBitsNode; the rest reuses existing nodes. Greedy decoding is the limit temperature → 0. temperature is a runtime input; seed is optional.
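As a quick sanity check of the formula, here is a small host-side sketch in plain Python. It is not the exported graph, and `gumbel_max_sample` and `softmax` are hypothetical helpers. It draws many samples with Gumbel-max and compares the empirical frequencies with softmax(logits / temperature), which is the distribution the trick is meant to reproduce.

```python
import math
import random


def gumbel_max_sample(logits, temperature, rng):
    """Return argmax_i(logits[i] / temperature + g_i), g_i = -log(-log(u_i))."""
    best_index, best_score = 0, -math.inf
    for i, logit in enumerate(logits):
        # Clamp u into the open interval (0, 1) so both logs are finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        score = logit / temperature - math.log(-math.log(u))
        if score > best_score:
            best_index, best_score = i, score
    return best_index


def softmax(logits, temperature):
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
logits = [2.0, 1.0, 0.0]
draws = 20000
counts = [0] * len(logits)
for _ in range(draws):
    counts[gumbel_max_sample(logits, 1.0, rng)] += 1
empirical = [c / draws for c in counts]
expected = softmax(logits, 1.0)
print(empirical, expected)
```

The two lists should agree to within sampling noise, and shrinking the temperature pushes all of the mass onto the largest logit.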
Changes
Notes
Fixes #20353