From 220b53f58721166a975cfc29da9276bf33adefef Mon Sep 17 00:00:00 2001 From: kiymetakdemir Date: Tue, 23 Jun 2026 11:56:17 -0700 Subject: [PATCH 1/5] MLX: on-device token sampling (mlx::sample, Gumbel-max) --- backends/mlx/custom_kernel_ops/sample.py | 40 +++ .../mlx/custom_kernel_ops/test/test_sample.py | 285 ++++++++++++++++++ backends/mlx/llm/sampling.py | 30 ++ backends/mlx/ops.py | 103 +++++++ backends/mlx/runtime/MLXInterpreter.h | 16 + backends/mlx/serialization/generate.py | 32 +- backends/mlx/serialization/schema.fbs | 11 +- 7 files changed, 510 insertions(+), 7 deletions(-) create mode 100644 backends/mlx/custom_kernel_ops/sample.py create mode 100644 backends/mlx/custom_kernel_ops/test/test_sample.py create mode 100644 backends/mlx/llm/sampling.py diff --git a/backends/mlx/custom_kernel_ops/sample.py b/backends/mlx/custom_kernel_ops/sample.py new file mode 100644 index 00000000000..02133ec4c4d --- /dev/null +++ b/backends/mlx/custom_kernel_ops/sample.py @@ -0,0 +1,40 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +from typing import Optional + +import torch +from torch import Tensor + + +@torch.library.custom_op("mlx::sample", mutates_args=()) +def sample( + logits: Tensor, temperature: Tensor, seed: Optional[Tensor] = None +) -> Tensor: + """ + Gumbel-max sampling from softmax(logits / temperature). + logits: [B, vocab] + temperature: scalar float tensor (runtime input) + seed: scalar int tensor or None + - tensor -> deterministic, keyed RNG (random::key(seed)) + - None -> MLX global KeySequence (non-deterministic) + -> token_id: [B] int64 + Reference (CPU) implementation for export + numerical parity. + """ + if seed is None: + u = torch.rand(logits.shape) # global RNG + else: + gen = torch.Generator().manual_seed(int(seed.item())) + u = torch.rand(logits.shape, generator=gen) + gumbel = -torch.log(-torch.log(u)) + return torch.argmax(logits / temperature + gumbel, dim=-1) + + +@torch.library.register_fake("mlx::sample") +def sample_fake(logits, temperature, seed=None): + return logits.new_empty(logits.shape[:-1], dtype=torch.long) diff --git a/backends/mlx/custom_kernel_ops/test/test_sample.py b/backends/mlx/custom_kernel_ops/test/test_sample.py new file mode 100644 index 00000000000..a9ab9ffc368 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/test/test_sample.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for mlx::sample (Gumbel-max token sampling). + +Most tests exercise the op's eager reference implementation and the +export/partition/serialization path. The end-to-end test runs the exported +program through the compiled op_test_runner (see backends/mlx/test/README.md +for building it). + +Usage: + python -m unittest executorch.backends.mlx.custom_kernel_ops.test.test_sample +""" + +import shutil +import tempfile +import unittest +from pathlib import Path +from typing import Optional + +# Registers torch.ops.mlx.sample. +import executorch.backends.mlx.custom_kernel_ops.sample # noqa: F401 +import torch +import torch.nn as nn +from executorch.backends.mlx.llm.sampling import SamplingHead +from executorch.backends.mlx.test.test_utils import ( + count_mlx_delegate_segments, + export_model_to_pte, + get_mlx_node_counts, +) + + +def _ref_gumbel_max(logits: torch.Tensor, temperature: float, seed: int): + """Independent Gumbel-max reference using the same torch RNG as the op.""" + gen = torch.Generator().manual_seed(seed) + u = torch.rand(logits.shape, generator=gen) + gumbel = -torch.log(-torch.log(u)) + return torch.argmax(logits / temperature + gumbel, dim=-1) + + +def _tv_distance(p: torch.Tensor, q: torch.Tensor) -> float: + """Total-variation distance between two discrete distributions.""" + return 0.5 * torch.abs(p - q).sum().item() + + +def _sample(logits, temperature, seed: Optional[int]): + t = torch.tensor(float(temperature)) + s = None if seed is None else torch.tensor(int(seed), dtype=torch.int64) + return torch.ops.mlx.sample(logits, t, s) + + +class _LogitsPassthrough(nn.Module): + """Stand-in for a model returning logits [B, S, vocab].""" + + def forward(self, logits: torch.Tensor) -> torch.Tensor: + return logits + + +class SeededSampleModel(nn.Module): + """SamplingHead with temperature AND seed as runtime forward inputs.""" + + def __init__(self): + super().__init__() + self.head = SamplingHead(_LogitsPassthrough()) + + def forward(self, logits, temperature, seed): + return self.head(logits, temperature=temperature, seed=seed) + + +class UnseededSampleModel(nn.Module): + """SamplingHead with temperature as a runtime input and no seed.""" + + def __init__(self): + super().__init__() + self.head = SamplingHead(_LogitsPassthrough()) + + def forward(self, logits, temperature): + return self.head(logits, temperature=temperature) + + +class TestSampleOp(unittest.TestCase): + """Eager reference behavior of mlx::sample (no export / no runtime).""" + + def test_greedy_parity_small_temperature(self): + # Small temperature -> Gumbel-max collapses to argmax(logits). + torch.manual_seed(0) + logits = torch.randn(8, 1024) + token = _sample(logits, 1e-4, seed=0) + self.assertTrue(torch.equal(token, torch.argmax(logits, dim=-1))) + + def test_matches_independent_gumbel_reference(self): + # Same seed -> bit-identical token vs an independent Gumbel-max impl. + torch.manual_seed(1) + logits = torch.randn(8, 512) + for seed in (0, 1, 7, 42): + got = _sample(logits, 0.8, seed=seed) + expected = _ref_gumbel_max(logits, 0.8, seed) + self.assertTrue(torch.equal(got, expected), f"mismatch at seed={seed}") + + def test_distribution_matches_softmax(self): + # Empirical token frequencies match softmax(logits / T). + vocab = 5 + temperature = 1.0 + torch.manual_seed(0) + base = torch.randn(vocab) + n = 20000 + tokens = _sample(base.expand(n, vocab), temperature, seed=0) + + empirical = torch.bincount(tokens, minlength=vocab).float() / n + target = torch.softmax(base / temperature, dim=-1) + tv = _tv_distance(empirical, target) + self.assertLess(tv, 0.05, f"TV distance {tv:.4f} too large") + + def test_determinism_seeded(self): + # Same seed -> identical draws; different seed -> different draws. + torch.manual_seed(0) + logits = torch.randn(256, 64) + a = _sample(logits, 1.0, seed=123) + b = _sample(logits, 1.0, seed=123) + c = _sample(logits, 1.0, seed=124) + self.assertTrue(torch.equal(a, b)) + self.assertFalse(torch.equal(a, c)) + + def test_unseeded_varies_across_calls(self): + # seed=None uses the global RNG -> draws vary, tokens stay in range. + torch.manual_seed(0) + logits = torch.randn(256, 64) + a = _sample(logits, 1.0, seed=None) + b = _sample(logits, 1.0, seed=None) + self.assertFalse(torch.equal(a, b)) + self.assertGreaterEqual(int(a.min()), 0) + self.assertLess(int(a.max()), 64) + + +class TestSampleExport(unittest.TestCase): + """torch.export, runtime inputs, and MLXPartitioner lowering.""" + + def setUp(self): + self._tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self._tmp, ignore_errors=True) + + def test_runtime_temperature_single_export(self): + # One exported program run at two temperatures (no re-export) confirms + # temperature is a live graph input: small T is near-greedy, large T + # spreads draws. A fixed logits row is broadcast so the spread of tokens + # reflects the sampling entropy. + vocab = 50 + batch = 256 + torch.manual_seed(0) + row = torch.randn(vocab) + logits = row.expand(batch, 1, vocab).contiguous() # [B, S=1, vocab] + seed = torch.tensor(0, dtype=torch.int64) + + run = torch.export.export( + SeededSampleModel(), (logits, torch.tensor(1.0), seed), strict=True + ).module() + + cold = run(logits, torch.tensor(1e-4), seed) + hot = run(logits, torch.tensor(100.0), seed) + + self.assertTrue(torch.all(cold == int(torch.argmax(row)))) + self.assertEqual(cold.unique().numel(), 1) + self.assertGreater(hot.unique().numel(), 10) + + def test_seeded_export_reproducible_no_host_rng(self): + # Seeded export: same seed -> identical tokens across runs of one + # exported program, independent of host RNG state (the seed is a graph + # input, not host-side stateful RNG). Different seed -> different draws. + torch.manual_seed(0) + logits = torch.randn(128, 1, 64) + seed = torch.tensor(123, dtype=torch.int64) + + run = torch.export.export( + SeededSampleModel(), (logits, torch.tensor(1.0), seed), strict=True + ).module() + + first = run(logits, torch.tensor(1.0), seed) + # Perturb the host global RNG between runs; a seeded draw is unaffected. + _ = torch.rand(1024) + second = run(logits, torch.tensor(1.0), seed) + self.assertTrue(torch.equal(first, second)) + + other = run(logits, torch.tensor(1.0), torch.tensor(124, dtype=torch.int64)) + self.assertFalse(torch.equal(first, other)) + + def test_export_strict_with_graph_inputs(self): + # strict=True export keeps logits, temperature, and seed as graph inputs. + logits = torch.randn(1, 4, 256) + ep = torch.export.export( + SeededSampleModel(), + (logits, torch.tensor(0.8), torch.tensor(0, dtype=torch.int64)), + strict=True, + ) + self.assertEqual(len(ep.graph_signature.user_inputs), 3) + + def test_seeded_lowers_to_mlx_delegate(self): + # The op is assigned to the MLX delegate, with the seed threaded in. + pte = Path(self._tmp) / "seeded.pte" + logits = torch.randn(1, 4, 256) + export_model_to_pte( + SeededSampleModel(), + (logits, torch.tensor(0.8), torch.tensor(0, dtype=torch.int64)), + pte, + ) + self.assertEqual(count_mlx_delegate_segments(pte), 1) + counts = get_mlx_node_counts(pte) + self.assertEqual(counts.get("RandomBitsNode", 0), 1) + self.assertEqual(counts.get("ArgmaxNode", 0), 1) + self.assertEqual(counts.get("ItemIntNode", 0), 1) # seed via .item() + + def test_unseeded_lowers_without_seed_field(self): + # seed=None lowers cleanly: RandomBitsNode emitted with no seed field + # (hence no ItemIntNode threading a seed Vid). + pte = Path(self._tmp) / "unseeded.pte" + logits = torch.randn(1, 4, 256) + export_model_to_pte(UnseededSampleModel(), (logits, torch.tensor(0.8)), pte) + self.assertEqual(count_mlx_delegate_segments(pte), 1) + counts = get_mlx_node_counts(pte) + self.assertEqual(counts.get("RandomBitsNode", 0), 1) + self.assertEqual(counts.get("ItemIntNode", 0), 0) + + +class TestSampleEndToEnd(unittest.TestCase): + """Run the exported program through the compiled op_test_runner.""" + + def setUp(self): + self._tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self._tmp, ignore_errors=True) + + def test_end_to_end(self): + # Requires the compiled op_test_runner (see backends/mlx/test/README.md). + from executorch.backends.mlx.test.test_utils import ( + load_tensors_from_bin, + run_cpp_test_runner, + save_tensors_to_bin, + ) + + vocab = 32 + logits = torch.randn(1, 4, vocab) + inputs = (logits, torch.tensor(0.8), torch.tensor(0, dtype=torch.int64)) + + tmp = Path(self._tmp) + pte, in_bin, out_bin = tmp / "e2e.pte", tmp / "in.bin", tmp / "out.bin" + export_model_to_pte(SeededSampleModel(), inputs, pte) + save_tensors_to_bin(list(inputs), in_bin) + + self.assertTrue(run_cpp_test_runner(pte, in_bin, out_bin)) + (token,) = load_tensors_from_bin(out_bin) + self.assertEqual(tuple(token.shape), (1,)) + self.assertTrue(0 <= int(token) < vocab) + + def test_bf16_large_vocab_greedy_parity(self): + # Regression: bf16 logits + large vocab. A dominant logit must win under + # near-greedy sampling. Catches the bug where casting the uniform to bf16 + # rounded the clamp (~0.99999994) up to 1.0 -> log(0) -> +inf gumbel, + # which then beat even a huge logit and produced a constant wrong token. + from executorch.backends.mlx.test.test_utils import ( + load_tensors_from_bin, + run_cpp_test_runner, + save_tensors_to_bin, + ) + + torch.manual_seed(0) + vocab = 4000 + logits = torch.randn(1, 4, vocab, dtype=torch.bfloat16) + logits[0, -1, 1234] = 50.0 # unambiguous argmax + inputs = (logits, torch.tensor(1e-4), torch.tensor(0, dtype=torch.int64)) + + tmp = Path(self._tmp) + pte, in_bin, out_bin = tmp / "bf16.pte", tmp / "in.bin", tmp / "out.bin" + export_model_to_pte(SeededSampleModel(), inputs, pte) + save_tensors_to_bin(list(inputs), in_bin) + + self.assertTrue(run_cpp_test_runner(pte, in_bin, out_bin)) + (token,) = load_tensors_from_bin(out_bin) + self.assertEqual(int(token), 1234) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/mlx/llm/sampling.py b/backends/mlx/llm/sampling.py new file mode 100644 index 00000000000..20e6f397edb --- /dev/null +++ b/backends/mlx/llm/sampling.py @@ -0,0 +1,30 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torch.nn as nn + + +class SamplingHead(nn.Module): + """ + Wraps a model that returns logits and samples a token id on-device. + + forward(*model_args, temperature, seed=None, **model_kwargs) -> token_id + + temperature: scalar float tensor, e.g. torch.tensor(0.8) + seed: scalar int tensor (seeded) or None (unseeded export) + """ + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward(self, *args, temperature, seed=None, **kwargs): + logits = self.model(*args, **kwargs) # [B, S, vocab] + last = logits[:, -1, :] # [B, vocab] + return torch.ops.mlx.sample(last, temperature, seed) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 44536e675da..e7c42506389 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -18,10 +18,13 @@ import operator from typing import Any, Dict, List, Optional, Set, Tuple, Union +# Registers torch.ops.mlx.sample so _sample_handler can target it at import time. +import executorch.backends.mlx.custom_kernel_ops.sample # noqa: F401 E402 import torch from executorch.backends.mlx.builder.op_helpers import ( emit_lifted_constant, emit_quantized_biases, + emit_shape, parse_dequant_node, to_mlx_qparams, torch_dtype_to_scalar_type, @@ -115,6 +118,7 @@ PartitionNode, PowerNode, ProdNode, + RandomBitsNode, ReciprocalNode, RemainderNode, RepeatNode, @@ -3454,6 +3458,105 @@ def _argmax_handler(P: MLXProgramBuilder, n: Node) -> Slot: return out +@REGISTRY.register(target=[torch.ops.mlx.sample.default]) +def _sample_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Gumbel-max sampling: argmax(logits / temperature + gumbel_noise). + + Reproduces MLX's uniform -> gumbel -> argmax layering in the IR using the + new RandomBitsNode plus existing elementwise nodes, so a sampled token id is + produced on-device instead of returning the full logits tensor. + """ + args = P.args(n) + require_args(args, 2, 3, "mlx.sample") + require_kwargs(P.kwargs(n), set(), "mlx.sample") + logits, temperature = args[0], args[1] + seed = args[2] if len(args) > 2 and args[2] is not None else None + + dt = n.args[0].meta["val"].dtype + shape = emit_shape(P, n.args[0], logits) + + # Optional runtime seed: tensor -> SymInt (Vid) via ItemIntNode. Absent -> + # leave RandomBitsNode.seed unset (MLX global RNG). + seed_field = None + if seed is not None: + _, seed_val = P.make_tmp_value_slot() + P.emit(ItemIntNode(x=P.slot_to_tid(seed), out=P.slot_to_vid(seed_val))) + seed_field = P.to_int_or_vid(seed_val) + + # uniform u in [0, 1): bits/uint32_max, clamped just below 1 (random.cpp:95) + _, bits = P.make_tmp_slot() + P.emit( + RandomBitsNode(out=P.slot_to_tid(bits), shape=shape, width=4, seed=seed_field) + ) + _, bits_f = P.make_tmp_slot() + P.emit( + AsTypeNode( + x=P.slot_to_tid(bits), + out=P.slot_to_tid(bits_f), + scalar_type=torch_dtype_to_scalar_type(torch.float32), + ) + ) + umax = emit_lifted_constant(P, 4294967295.0, torch.float32) + _, div0 = P.make_tmp_slot() + P.emit( + DivideNode( + a=P.slot_to_tid(bits_f), b=P.slot_to_tid(umax), out=P.slot_to_tid(div0) + ) + ) + prev1 = emit_lifted_constant( + P, float(torch.nextafter(torch.tensor(1.0), torch.tensor(0.0))), torch.float32 + ) + _, clamp = P.make_tmp_slot() + P.emit( + MinimumNode( + a=P.slot_to_tid(div0), b=P.slot_to_tid(prev1), out=P.slot_to_tid(clamp) + ) + ) + # gumbel noise g = -log(-log(u)) (random.cpp:367), computed in float32. + # Keep the uniform in float32 through the log chain: casting it down to a + # low-precision dtype (e.g. bf16) can round the clamp (~0.99999994) back up + # to 1.0 -> log(1.0)=0 -> -log(0)=+inf, which then dominates argmax (a fixed + # seed makes that the same token every step). Only the finite gumbel is cast + # to the logits dtype. + _, l1 = P.make_tmp_slot() + P.emit(LogNode(x=P.slot_to_tid(clamp), out=P.slot_to_tid(l1))) + _, g1 = P.make_tmp_slot() + P.emit(NegNode(x=P.slot_to_tid(l1), out=P.slot_to_tid(g1))) + _, l2 = P.make_tmp_slot() + P.emit(LogNode(x=P.slot_to_tid(g1), out=P.slot_to_tid(l2))) + _, g_f32 = P.make_tmp_slot() + P.emit(NegNode(x=P.slot_to_tid(l2), out=P.slot_to_tid(g_f32))) + _, g = P.make_tmp_slot() + P.emit( + AsTypeNode( + x=P.slot_to_tid(g_f32), + out=P.slot_to_tid(g), + scalar_type=torch_dtype_to_scalar_type(dt), + ) + ) + + # sample: argmax(logits / temperature + g) over the vocab axis + _, scaled = P.make_tmp_slot() + P.emit( + DivideNode( + a=P.slot_to_tid(logits), + b=P.slot_to_tid(temperature), + out=P.slot_to_tid(scaled), + ) + ) + _, noisy = P.make_tmp_slot() + P.emit( + AddNode(a=P.slot_to_tid(scaled), b=P.slot_to_tid(g), out=P.slot_to_tid(noisy)) + ) + out = P.make_or_get_slot(n) + P.emit( + ArgmaxNode( + x=P.slot_to_tid(noisy), out=P.slot_to_tid(out), axis=-1, keepdims=False + ) + ) + return out + + @REGISTRY.register(target=[torch.ops.aten.argmin.default]) def _argmin_handler(P: MLXProgramBuilder, n: Node) -> Slot: """Handle aten.argmin - index of min element along axis.""" diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index 8563ff339a7..919a4ef04be 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -1695,6 +1695,19 @@ exec_argmax(const ArgmaxNode& n, ExecutionState& st, StreamOrDevice s) { st.set_tensor(n.out, argmax(x, n.axis, n.keepdims, s)); } +inline void exec_random_bits( + const RandomBitsNode& n, + ExecutionState& st, + StreamOrDevice s) { + auto shape = to_shape(n.shape, st); + check_allocation_bounded(shape, uint32, "random_bits"); + std::optional key = std::nullopt; + if (n.seed.has_value()) { + key = random::key(static_cast(resolve_int(n.seed.value(), st))); + } + st.set_tensor(n.out, random::bits(shape, n.width, key, s)); +} + inline void exec_argmin(const ArgminNode& n, ExecutionState& st, StreamOrDevice s) { const auto& x = st.const_tensor_ref(n.x); @@ -2057,6 +2070,9 @@ class Interpreter { case OpCode::ARGMAX: ops::exec_argmax(std::get(instr.node), st, s); break; + case OpCode::RANDOM_BITS: + ops::exec_random_bits(std::get(instr.node), st, s); + break; case OpCode::SLICE_UPDATE: ops::exec_slice_update(std::get(instr.node), st, s); break; diff --git a/backends/mlx/serialization/generate.py b/backends/mlx/serialization/generate.py index fd0b5b672b0..34721236784 100755 --- a/backends/mlx/serialization/generate.py +++ b/backends/mlx/serialization/generate.py @@ -831,7 +831,12 @@ def _emit_py_prebuild(kind: str, fld: FBSField) -> List[str]: if kind in _PY_PREBUILD_OFFSET: suffix = "_off" expr = _PY_PREBUILD_OFFSET[kind].format(name=n) - return [f" {n}{suffix} = {expr}"] + # optional_str carries its own None handling; other compound offset + # fields (int_or_vid, etc.) must be guarded when optional so a None + # value is serialized as an absent field rather than crashing. + if fld.required or kind == "optional_str": + return [f" {n}{suffix} = {expr}"] + return [f" {n}{suffix} = {expr} if op.{n} is not None else None"] return [] @@ -855,7 +860,12 @@ def _emit_py_add( return [f" {add}(builder, op.{n})"] # Pre-built offsets (string, compound types) if kind in ("str", "int_or_vid", "float_or_vid", "vid_or_tid", "int_or_vid_or_tid"): - return [f" {add}(builder, {n}_off)"] + if fld.required: + return [f" {add}(builder, {n}_off)"] + return [ + f" if {n}_off is not None:", + f" {add}(builder, {n}_off)", + ] # Pre-built vectors (required vs optional) if kind in ( "list_int", @@ -1056,6 +1066,8 @@ def _fbs_type_to_cpp( return "std::optional" if fbs_type == "Vid": return "std::optional" + if fbs_type in ("IntOrVid", "FloatOrVid", "VidOrTid", "IntOrVidOrTid"): + return f"std::optional<{cpp_type}>" if fld is not None and fld.default == "null" and fbs_type in FBS_TO_CPP: return f"std::optional<{cpp_type}>" @@ -1140,7 +1152,7 @@ def _generate_loader_case(table: FBSTable) -> List[str]: fb_field_name = fld.name kind = _get_field_kind(fld, table) - load_lines = _emit_cpp_load(kind, fld.name, fb_field_name, table) + load_lines = _emit_cpp_load(kind, fld.name, fb_field_name, table, fld) if load_lines is None: raise ValueError( f"Unhandled field kind '{kind}' for field '{fld.name}' in table '{table.name}'. " @@ -1172,16 +1184,24 @@ def _generate_loader_case(table: FBSTable) -> List[str]: } -def _emit_cpp_load( - kind: str, name: str, fb_name: str, table=None +def _emit_cpp_load( # noqa: C901 + kind: str, name: str, fb_name: str, table=None, fld=None ) -> "List[str] | None": """Emit C++ load lines for a field kind, or None if kind is unrecognized.""" # Interned string fields share one std::string via the load-time pool. if _is_interned_str(table, name) and kind in ("str", "optional_str"): return [f" node.{name} = strpool.intern(fb->{fb_name}());"] - # Required struct / compound via converter + # Struct / compound via converter if kind in _CPP_CONVERTER: conv = _CPP_CONVERTER[kind] + # Optional compound fields (e.g. an optional IntOrVid) must be + # presence-guarded; convert_* throws on a null FlatBuffer pointer. + if fld is not None and not fld.required: + return [ + f" if (fb->{fb_name}()) {{", + f" node.{name} = {conv}(fb->{fb_name}());", + " }", + ] return [f" node.{name} = {conv}(fb->{fb_name}());"] # Scalars (direct value) if kind in ("int", "float", "bool"): diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index 42c53e5172b..b5a7f737842 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -985,6 +985,14 @@ table IfNode { else_chain_idx: uint32; // index into MLXGraph.instruction_chains } +table RandomBitsNode { + out: Tid (required); + shape: [IntOrVid] (required); + seed: IntOrVid; // OPTIONAL: present -> random::key(seed); + // absent -> MLX global KeySequence + width: int32 = 4; // bytes per element (4 -> uint32) +} + // Custom Metal kernel execution via mlx::core::fast::metal_kernel(). // Two-phase API: // 1. Factory: metal_kernel(name, input_names, output_names, source, header, @@ -1161,7 +1169,8 @@ union OpNode { BitwiseAndNode, BitwiseOrNode, BitwiseXorNode, - IfNode + IfNode, + RandomBitsNode // BC: Add new op nodes here (append only) } From bddb8190ccd0ca08e89b318a674cc90fceb260ad Mon Sep 17 00:00:00 2001 From: kiymetakdemir Date: Tue, 23 Jun 2026 15:16:33 -0700 Subject: [PATCH 2/5] MLX: reject non-uint32 width; move sample to custom_ops; docs --- backends/mlx/custom_kernel_ops/sample.py | 40 ------------------- backends/mlx/custom_ops.py | 32 +++++++++++++++ backends/mlx/llm/sampling.py | 4 +- backends/mlx/ops.py | 2 - backends/mlx/runtime/MLXInterpreter.h | 4 ++ .../test/test_sample.py | 4 +- 6 files changed, 41 insertions(+), 45 deletions(-) delete mode 100644 backends/mlx/custom_kernel_ops/sample.py rename backends/mlx/{custom_kernel_ops => }/test/test_sample.py (98%) diff --git a/backends/mlx/custom_kernel_ops/sample.py b/backends/mlx/custom_kernel_ops/sample.py deleted file mode 100644 index 02133ec4c4d..00000000000 --- a/backends/mlx/custom_kernel_ops/sample.py +++ /dev/null @@ -1,40 +0,0 @@ -# -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -# - -from typing import Optional - -import torch -from torch import Tensor - - -@torch.library.custom_op("mlx::sample", mutates_args=()) -def sample( - logits: Tensor, temperature: Tensor, seed: Optional[Tensor] = None -) -> Tensor: - """ - Gumbel-max sampling from softmax(logits / temperature). - logits: [B, vocab] - temperature: scalar float tensor (runtime input) - seed: scalar int tensor or None - - tensor -> deterministic, keyed RNG (random::key(seed)) - - None -> MLX global KeySequence (non-deterministic) - -> token_id: [B] int64 - Reference (CPU) implementation for export + numerical parity. - """ - if seed is None: - u = torch.rand(logits.shape) # global RNG - else: - gen = torch.Generator().manual_seed(int(seed.item())) - u = torch.rand(logits.shape, generator=gen) - gumbel = -torch.log(-torch.log(u)) - return torch.argmax(logits / temperature + gumbel, dim=-1) - - -@torch.library.register_fake("mlx::sample") -def sample_fake(logits, temperature, seed=None): - return logits.new_empty(logits.shape[:-1], dtype=torch.long) diff --git a/backends/mlx/custom_ops.py b/backends/mlx/custom_ops.py index c03db05d918..4c3c095f9aa 100644 --- a/backends/mlx/custom_ops.py +++ b/backends/mlx/custom_ops.py @@ -391,3 +391,35 @@ def gather_qmm_fake( else: batch = w.shape[:-2] return x.new_empty((*batch, M, N)) + + +@torch.library.custom_op("mlx::sample", mutates_args=()) +def sample( + logits: Tensor, temperature: Tensor, seed: Optional[Tensor] = None +) -> Tensor: + """ + Gumbel-max sampling from softmax(logits / temperature). + logits: [B, vocab] + temperature: scalar float tensor (runtime input) + seed: scalar int tensor or None + - tensor -> deterministic, keyed RNG (random::key(seed)) + - None -> MLX global KeySequence (non-deterministic) + -> token_id: [B] int64 + + Host/CPU reference used for export (shape/meta) and distributional checks + only. It is NOT bit-identical to the lowered on-device graph: this uses torch + RNG (plain torch.rand, no uint32/nextafter uniform) while the delegate uses + MLX RNG, so a given seed does not reproduce the same tokens host vs. device. + """ + if seed is None: + u = torch.rand(logits.shape) # global RNG + else: + gen = torch.Generator().manual_seed(int(seed.item())) + u = torch.rand(logits.shape, generator=gen) + gumbel = -torch.log(-torch.log(u)) + return torch.argmax(logits / temperature + gumbel, dim=-1) + + +@torch.library.register_fake("mlx::sample") +def sample_fake(logits, temperature, seed=None): + return logits.new_empty(logits.shape[:-1], dtype=torch.long) diff --git a/backends/mlx/llm/sampling.py b/backends/mlx/llm/sampling.py index 20e6f397edb..79eee257de8 100644 --- a/backends/mlx/llm/sampling.py +++ b/backends/mlx/llm/sampling.py @@ -16,7 +16,9 @@ class SamplingHead(nn.Module): forward(*model_args, temperature, seed=None, **model_kwargs) -> token_id - temperature: scalar float tensor, e.g. torch.tensor(0.8) + temperature: scalar float tensor, e.g. torch.tensor(0.8). Must be > 0; + logits are divided by it, so 0.0 yields inf/nan. For greedy, + pass a small epsilon (e.g. 1e-4), not 0. seed: scalar int tensor (seeded) or None (unseeded export) """ diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index e7c42506389..fff319a9571 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -18,8 +18,6 @@ import operator from typing import Any, Dict, List, Optional, Set, Tuple, Union -# Registers torch.ops.mlx.sample so _sample_handler can target it at import time. -import executorch.backends.mlx.custom_kernel_ops.sample # noqa: F401 E402 import torch from executorch.backends.mlx.builder.op_helpers import ( emit_lifted_constant, diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index 919a4ef04be..a08f449d1c2 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -1699,6 +1699,10 @@ inline void exec_random_bits( const RandomBitsNode& n, ExecutionState& st, StreamOrDevice s) { + // Only width=4 (uint32) is supported; reject other widths. + if (n.width != 4) { + throw std::runtime_error("random_bits: only width=4 (uint32) is supported"); + } auto shape = to_shape(n.shape, st); check_allocation_bounded(shape, uint32, "random_bits"); std::optional key = std::nullopt; diff --git a/backends/mlx/custom_kernel_ops/test/test_sample.py b/backends/mlx/test/test_sample.py similarity index 98% rename from backends/mlx/custom_kernel_ops/test/test_sample.py rename to backends/mlx/test/test_sample.py index a9ab9ffc368..a01b14ce3e5 100644 --- a/backends/mlx/custom_kernel_ops/test/test_sample.py +++ b/backends/mlx/test/test_sample.py @@ -14,7 +14,7 @@ for building it). Usage: - python -m unittest executorch.backends.mlx.custom_kernel_ops.test.test_sample + python -m unittest executorch.backends.mlx.test.test_sample """ import shutil @@ -24,7 +24,7 @@ from typing import Optional # Registers torch.ops.mlx.sample. -import executorch.backends.mlx.custom_kernel_ops.sample # noqa: F401 +import executorch.backends.mlx.custom_ops # noqa: F401 import torch import torch.nn as nn from executorch.backends.mlx.llm.sampling import SamplingHead From 20af90881c2d8b23cfd10fb2bb1668f9e25416e3 Mon Sep 17 00:00:00 2001 From: kiymetakdemir Date: Wed, 24 Jun 2026 13:45:07 -0700 Subject: [PATCH 3/5] MLX: top-p sampling; temperature=0 greedy via IfNode; fp32 sampling chain; reorg sample tests --- .github/workflows/mlx.yml | 4 + backends/mlx/custom_ops.py | 32 +++- backends/mlx/llm/sampling.py | 17 +- backends/mlx/ops.py | 253 +++++++++++++++++-------- backends/mlx/runtime/MLXInterpreter.h | 7 +- backends/mlx/serialization/generate.py | 33 ++-- backends/mlx/test/test_ops.py | 216 +++++++++++++++++++++ backends/mlx/test/test_sample.py | 191 +++++-------------- 8 files changed, 500 insertions(+), 253 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index acc6b4840cf..9794f8080dc 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -74,12 +74,16 @@ jobs: echo "::endgroup::" echo "::group::Run Python unit tests" + # test_ops.py is listed for its unittest classes; run_all_tests.py only + # runs its OpTestCase classes. ${CONDA_RUN} python -m pytest \ backends/mlx/test/test_passes.py \ backends/mlx/test/test_pattern_utils.py \ backends/mlx/test/test_partitioner.py \ backends/mlx/test/test_serialization_dedup.py \ backends/mlx/test/test_slot_recycling.py \ + backends/mlx/test/test_sample.py \ + backends/mlx/test/test_ops.py \ examples/models/gemma4_31b/quant/tests/test_pack_mlx.py \ examples/models/gemma4_31b/tests/test_mlx_pipeline.py \ -v diff --git a/backends/mlx/custom_ops.py b/backends/mlx/custom_ops.py index 4c3c095f9aa..74e8a13586d 100644 --- a/backends/mlx/custom_ops.py +++ b/backends/mlx/custom_ops.py @@ -395,12 +395,18 @@ def gather_qmm_fake( @torch.library.custom_op("mlx::sample", mutates_args=()) def sample( - logits: Tensor, temperature: Tensor, seed: Optional[Tensor] = None + logits: Tensor, + temperature: Tensor, + top_p: Tensor, + seed: Optional[Tensor] = None, ) -> Tensor: """ - Gumbel-max sampling from softmax(logits / temperature). + Gumbel-max sampling from softmax(logits / temperature), with top-p (nucleus). logits: [B, vocab] - temperature: scalar float tensor (runtime input) + temperature: scalar float tensor (runtime input). temperature == 0 is + greedy: return argmax(logits) directly. + top_p: scalar float tensor in (0, 1]. top_p=1.0 keeps every + token, i.e. it is off. seed: scalar int tensor or None - tensor -> deterministic, keyed RNG (random::key(seed)) - None -> MLX global KeySequence (non-deterministic) @@ -411,15 +417,27 @@ def sample( RNG (plain torch.rand, no uint32/nextafter uniform) while the delegate uses MLX RNG, so a given seed does not reproduce the same tokens host vs. device. """ + if float(temperature) == 0: + return torch.argmax(logits, dim=-1) + # whole chain in fp32 to match the lowered graph (bf16 sums mis-rank ties). + scaled = logits.float() / temperature + probs = torch.softmax(scaled, dim=-1) + s_probs, _ = torch.sort(probs, dim=-1, descending=True) + cum = torch.cumsum(s_probs, dim=-1) + keep = (cum - s_probs) <= top_p + thresh = torch.where(keep, s_probs, s_probs.new_tensor(float("inf"))).amin( + dim=-1, keepdim=True + ) + scaled = torch.where(probs >= thresh, scaled, scaled.new_tensor(float("-inf"))) if seed is None: - u = torch.rand(logits.shape) # global RNG + u = torch.rand(scaled.shape) # global RNG else: gen = torch.Generator().manual_seed(int(seed.item())) - u = torch.rand(logits.shape, generator=gen) + u = torch.rand(scaled.shape, generator=gen) gumbel = -torch.log(-torch.log(u)) - return torch.argmax(logits / temperature + gumbel, dim=-1) + return torch.argmax(scaled + gumbel, dim=-1) @torch.library.register_fake("mlx::sample") -def sample_fake(logits, temperature, seed=None): +def sample_fake(logits, temperature, top_p, seed=None): return logits.new_empty(logits.shape[:-1], dtype=torch.long) diff --git a/backends/mlx/llm/sampling.py b/backends/mlx/llm/sampling.py index 79eee257de8..cb8d39c64e5 100644 --- a/backends/mlx/llm/sampling.py +++ b/backends/mlx/llm/sampling.py @@ -14,19 +14,24 @@ class SamplingHead(nn.Module): """ Wraps a model that returns logits and samples a token id on-device. - forward(*model_args, temperature, seed=None, **model_kwargs) -> token_id + forward(*model_args, temperature, seed=None, top_p=1.0, **model_kwargs) + -> token_id - temperature: scalar float tensor, e.g. torch.tensor(0.8). Must be > 0; - logits are divided by it, so 0.0 yields inf/nan. For greedy, - pass a small epsilon (e.g. 1e-4), not 0. + temperature: scalar float tensor, e.g. torch.tensor(0.8). Must be >= 0; + temperature=0 is greedy (returns argmax, no division). seed: scalar int tensor (seeded) or None (unseeded export) + top_p: scalar float tensor in (0, 1] for nucleus sampling. top_p=1.0 + (the default) keeps every token, i.e. no filtering. Pass it + as a runtime input to tune per request. """ def __init__(self, model: nn.Module): super().__init__() self.model = model - def forward(self, *args, temperature, seed=None, **kwargs): + def forward(self, *args, temperature, seed=None, top_p=1.0, **kwargs): logits = self.model(*args, **kwargs) # [B, S, vocab] last = logits[:, -1, :] # [B, vocab] - return torch.ops.mlx.sample(last, temperature, seed) + if not isinstance(top_p, torch.Tensor): + top_p = torch.tensor(float(top_p)) + return torch.ops.mlx.sample(last, temperature, top_p, seed) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index fff319a9571..82e322b5d85 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -20,6 +20,7 @@ import torch from executorch.backends.mlx.builder.op_helpers import ( + emit_if_else, emit_lifted_constant, emit_quantized_biases, emit_shape, @@ -3463,95 +3464,193 @@ def _sample_handler(P: MLXProgramBuilder, n: Node) -> Slot: Reproduces MLX's uniform -> gumbel -> argmax layering in the IR using the new RandomBitsNode plus existing elementwise nodes, so a sampled token id is produced on-device instead of returning the full logits tensor. + + temperature == 0 is greedy: an IfNode branches to a plain argmax(logits), + skipping the sampling chain (so 0 is exact, not the small-epsilon approx). """ args = P.args(n) - require_args(args, 2, 3, "mlx.sample") + require_args(args, 3, 4, "mlx.sample") require_kwargs(P.kwargs(n), set(), "mlx.sample") - logits, temperature = args[0], args[1] - seed = args[2] if len(args) > 2 and args[2] is not None else None - - dt = n.args[0].meta["val"].dtype - shape = emit_shape(P, n.args[0], logits) - - # Optional runtime seed: tensor -> SymInt (Vid) via ItemIntNode. Absent -> - # leave RandomBitsNode.seed unset (MLX global RNG). - seed_field = None - if seed is not None: - _, seed_val = P.make_tmp_value_slot() - P.emit(ItemIntNode(x=P.slot_to_tid(seed), out=P.slot_to_vid(seed_val))) - seed_field = P.to_int_or_vid(seed_val) - - # uniform u in [0, 1): bits/uint32_max, clamped just below 1 (random.cpp:95) - _, bits = P.make_tmp_slot() - P.emit( - RandomBitsNode(out=P.slot_to_tid(bits), shape=shape, width=4, seed=seed_field) - ) - _, bits_f = P.make_tmp_slot() - P.emit( - AsTypeNode( - x=P.slot_to_tid(bits), - out=P.slot_to_tid(bits_f), - scalar_type=torch_dtype_to_scalar_type(torch.float32), + logits, temperature, top_p = args[0], args[1], args[2] + seed = args[3] if len(args) > 3 and args[3] is not None else None + + temp_dt = n.args[1].meta["val"].dtype + out = P.make_or_get_slot(n) + + def emit_greedy(): + P.emit( + ArgmaxNode( + x=P.slot_to_tid(logits), + out=P.slot_to_tid(out), + axis=-1, + keepdims=False, + ) ) - ) - umax = emit_lifted_constant(P, 4294967295.0, torch.float32) - _, div0 = P.make_tmp_slot() - P.emit( - DivideNode( - a=P.slot_to_tid(bits_f), b=P.slot_to_tid(umax), out=P.slot_to_tid(div0) + + def emit_sample(): + shape = emit_shape(P, n.args[0], logits) + + # Optional runtime seed: tensor -> SymInt (Vid) via ItemIntNode. Absent -> + # leave RandomBitsNode.seed unset (MLX global RNG). + seed_field = None + if seed is not None: + _, seed_val = P.make_tmp_value_slot() + P.emit(ItemIntNode(x=P.slot_to_tid(seed), out=P.slot_to_vid(seed_val))) + seed_field = P.to_int_or_vid(seed_val) + + # uniform u in [0, 1): bits/uint32_max, clamped just below 1 (random.cpp:95) + _, bits = P.make_tmp_slot() + P.emit( + RandomBitsNode( + out=P.slot_to_tid(bits), shape=shape, width=4, seed=seed_field + ) ) - ) - prev1 = emit_lifted_constant( - P, float(torch.nextafter(torch.tensor(1.0), torch.tensor(0.0))), torch.float32 - ) - _, clamp = P.make_tmp_slot() - P.emit( - MinimumNode( - a=P.slot_to_tid(div0), b=P.slot_to_tid(prev1), out=P.slot_to_tid(clamp) + _, bits_f = P.make_tmp_slot() + P.emit( + AsTypeNode( + x=P.slot_to_tid(bits), + out=P.slot_to_tid(bits_f), + scalar_type=torch_dtype_to_scalar_type(torch.float32), + ) ) - ) - # gumbel noise g = -log(-log(u)) (random.cpp:367), computed in float32. - # Keep the uniform in float32 through the log chain: casting it down to a - # low-precision dtype (e.g. bf16) can round the clamp (~0.99999994) back up - # to 1.0 -> log(1.0)=0 -> -log(0)=+inf, which then dominates argmax (a fixed - # seed makes that the same token every step). Only the finite gumbel is cast - # to the logits dtype. - _, l1 = P.make_tmp_slot() - P.emit(LogNode(x=P.slot_to_tid(clamp), out=P.slot_to_tid(l1))) - _, g1 = P.make_tmp_slot() - P.emit(NegNode(x=P.slot_to_tid(l1), out=P.slot_to_tid(g1))) - _, l2 = P.make_tmp_slot() - P.emit(LogNode(x=P.slot_to_tid(g1), out=P.slot_to_tid(l2))) - _, g_f32 = P.make_tmp_slot() - P.emit(NegNode(x=P.slot_to_tid(l2), out=P.slot_to_tid(g_f32))) - _, g = P.make_tmp_slot() - P.emit( - AsTypeNode( - x=P.slot_to_tid(g_f32), - out=P.slot_to_tid(g), - scalar_type=torch_dtype_to_scalar_type(dt), + umax = emit_lifted_constant(P, 4294967295.0, torch.float32) + _, div0 = P.make_tmp_slot() + P.emit( + DivideNode( + a=P.slot_to_tid(bits_f), b=P.slot_to_tid(umax), out=P.slot_to_tid(div0) + ) + ) + prev1 = emit_lifted_constant( + P, + float(torch.nextafter(torch.tensor(1.0), torch.tensor(0.0))), + torch.float32, + ) + _, clamp = P.make_tmp_slot() + P.emit( + MinimumNode( + a=P.slot_to_tid(div0), b=P.slot_to_tid(prev1), out=P.slot_to_tid(clamp) + ) + ) + # gumbel g = -log(-log(u)); whole chain stays fp32 (bf16 mis-ranks ties; clamp->1.0->+inf). + _, l1 = P.make_tmp_slot() + P.emit(LogNode(x=P.slot_to_tid(clamp), out=P.slot_to_tid(l1))) + _, g1 = P.make_tmp_slot() + P.emit(NegNode(x=P.slot_to_tid(l1), out=P.slot_to_tid(g1))) + _, l2 = P.make_tmp_slot() + P.emit(LogNode(x=P.slot_to_tid(g1), out=P.slot_to_tid(l2))) + _, g = P.make_tmp_slot() + P.emit(NegNode(x=P.slot_to_tid(l2), out=P.slot_to_tid(g))) + + # sample: argmax(logits / temperature + g) over the vocab axis, in float32 + _, logits_f = P.make_tmp_slot() + P.emit( + AsTypeNode( + x=P.slot_to_tid(logits), + out=P.slot_to_tid(logits_f), + scalar_type=torch_dtype_to_scalar_type(torch.float32), + ) + ) + _, scaled = P.make_tmp_slot() + P.emit( + DivideNode( + a=P.slot_to_tid(logits_f), + b=P.slot_to_tid(temperature), + out=P.slot_to_tid(scaled), + ) ) - ) - # sample: argmax(logits / temperature + g) over the vocab axis - _, scaled = P.make_tmp_slot() - P.emit( - DivideNode( - a=P.slot_to_tid(logits), - b=P.slot_to_tid(temperature), - out=P.slot_to_tid(scaled), + # top-p nucleus mask; SortNode is ascending-only, so sort -probs for descending. + _, probs = P.make_tmp_slot() + P.emit(SoftmaxNode(x=P.slot_to_tid(scaled), out=P.slot_to_tid(probs), axis=-1)) + _, neg_p = P.make_tmp_slot() + P.emit(NegNode(x=P.slot_to_tid(probs), out=P.slot_to_tid(neg_p))) + _, sorted_neg = P.make_tmp_slot() + P.emit(SortNode(x=P.slot_to_tid(neg_p), out=P.slot_to_tid(sorted_neg), axis=-1)) + _, sorted_p = P.make_tmp_slot() + P.emit(NegNode(x=P.slot_to_tid(sorted_neg), out=P.slot_to_tid(sorted_p))) + _, cum = P.make_tmp_slot() + P.emit(CumsumNode(x=P.slot_to_tid(sorted_p), out=P.slot_to_tid(cum), axis=-1)) + _, prefix = P.make_tmp_slot() + P.emit( + SubtractNode( + a=P.slot_to_tid(cum), + b=P.slot_to_tid(sorted_p), + out=P.slot_to_tid(prefix), + ) ) - ) - _, noisy = P.make_tmp_slot() - P.emit( - AddNode(a=P.slot_to_tid(scaled), b=P.slot_to_tid(g), out=P.slot_to_tid(noisy)) - ) - out = P.make_or_get_slot(n) + # remove sorted tokens whose prefix mass already exceeds top_p (top-1: 0) + _, remove = P.make_tmp_slot() + P.emit( + GreaterNode( + a=P.slot_to_tid(prefix), + b=P.slot_to_tid(top_p), + out=P.slot_to_tid(remove), + ) + ) + pos_inf = emit_lifted_constant(P, float("inf"), torch.float32) + _, kept = P.make_tmp_slot() + P.emit( + WhereNode( + condition=P.slot_to_tid(remove), + x=P.slot_to_tid(pos_inf), + y=P.slot_to_tid(sorted_p), + out=P.slot_to_tid(kept), + ) + ) + # threshold = smallest kept probability (per row) + _, thresh = P.make_tmp_slot() + P.emit( + MinNode( + x=P.slot_to_tid(kept), + out=P.slot_to_tid(thresh), + axes=[-1], + keepdims=True, + ) + ) + _, drop = P.make_tmp_slot() + P.emit( + LessNode( + a=P.slot_to_tid(probs), + b=P.slot_to_tid(thresh), + out=P.slot_to_tid(drop), + ) + ) + neg_inf = emit_lifted_constant(P, float("-inf"), torch.float32) + _, masked = P.make_tmp_slot() + P.emit( + WhereNode( + condition=P.slot_to_tid(drop), + x=P.slot_to_tid(neg_inf), + y=P.slot_to_tid(scaled), + out=P.slot_to_tid(masked), + ) + ) + + _, noisy = P.make_tmp_slot() + P.emit( + AddNode( + a=P.slot_to_tid(masked), b=P.slot_to_tid(g), out=P.slot_to_tid(noisy) + ) + ) + P.emit( + ArgmaxNode( + x=P.slot_to_tid(noisy), out=P.slot_to_tid(out), axis=-1, keepdims=False + ) + ) + + # temperature == 0 -> greedy: IfNode branches to argmax(logits), skipping sampling. + zero = emit_lifted_constant(P, 0.0, temp_dt) + _, is_sampling = P.make_tmp_slot() P.emit( - ArgmaxNode( - x=P.slot_to_tid(noisy), out=P.slot_to_tid(out), axis=-1, keepdims=False + GreaterNode( + a=P.slot_to_tid(temperature), + b=P.slot_to_tid(zero), + out=P.slot_to_tid(is_sampling), ) ) + _, cond_val = P.make_tmp_value_slot() + P.emit(ItemIntNode(x=P.slot_to_tid(is_sampling), out=P.slot_to_vid(cond_val))) + emit_if_else(P, P.to_int_or_vid(cond_val), emit_sample, emit_greedy) return out diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index a08f449d1c2..95a09e33ac8 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -1699,11 +1699,12 @@ inline void exec_random_bits( const RandomBitsNode& n, ExecutionState& st, StreamOrDevice s) { - // Only width=4 (uint32) is supported; reject other widths. - if (n.width != 4) { - throw std::runtime_error("random_bits: only width=4 (uint32) is supported"); + // random::bits supports width (bytes/element) in {1, 2, 4} -> uint8/uint16/uint32. + if (n.width != 1 && n.width != 2 && n.width != 4) { + throw std::runtime_error("random_bits: width must be 1, 2, or 4"); } auto shape = to_shape(n.shape, st); + // uint32 (4 bytes, the widest supported) is a safe upper bound for the guard. check_allocation_bounded(shape, uint32, "random_bits"); std::optional key = std::nullopt; if (n.seed.has_value()) { diff --git a/backends/mlx/serialization/generate.py b/backends/mlx/serialization/generate.py index 34721236784..5b05c1b8b9d 100755 --- a/backends/mlx/serialization/generate.py +++ b/backends/mlx/serialization/generate.py @@ -572,7 +572,7 @@ def generate_python_serializers(schema: FBSSchema) -> str: "", "from __future__ import annotations", "", - "from typing import List, Tuple, Dict", + "from typing import Dict, List, Optional, Tuple", "", "import flatbuffers", "", @@ -640,8 +640,10 @@ def generate_python_serializers(schema: FBSSchema) -> str: "class GeneratedOpBuilders:", ' """Mixin class with auto-generated op builder methods."""', "", - " def _build_int_or_vid(self, builder: flatbuffers.Builder, iov: IntOrVid) -> int:", - ' """Build an IntOrVid table."""', + " def _build_int_or_vid(self, builder: flatbuffers.Builder, iov: IntOrVid) -> Optional[int]:", + ' """Build an IntOrVid table (None -> absent field, like _shared_string)."""', + " if iov is None:", + " return None", " from executorch.backends.mlx.serialization._generated.mlx_delegate import IntOrVid as FBIntOrVidModule", " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", "", @@ -653,8 +655,10 @@ def generate_python_serializers(schema: FBSSchema) -> str: " FBIntOrVidModule.AddVid(builder, CreateVid(builder, iov.vid.idx))", " return FBIntOrVidModule.End(builder)", "", - " def _build_float_or_vid(self, builder: flatbuffers.Builder, fov: FloatOrVid) -> int:", - ' """Build a FloatOrVid table."""', + " def _build_float_or_vid(self, builder: flatbuffers.Builder, fov: FloatOrVid) -> Optional[int]:", + ' """Build a FloatOrVid table (None -> absent field)."""', + " if fov is None:", + " return None", " from executorch.backends.mlx.serialization._generated.mlx_delegate import FloatOrVid as FBFloatOrVidModule", " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", "", @@ -665,8 +669,10 @@ def generate_python_serializers(schema: FBSSchema) -> str: " FBFloatOrVidModule.AddVid(builder, CreateVid(builder, fov.vid.idx))", " return FBFloatOrVidModule.End(builder)", "", - " def _build_vid_or_tid(self, builder: flatbuffers.Builder, vot: VidOrTid) -> int:", - ' """Build a TidOrVid table."""', + " def _build_vid_or_tid(self, builder: flatbuffers.Builder, vot: VidOrTid) -> Optional[int]:", + ' """Build a TidOrVid table (None -> absent field)."""', + " if vot is None:", + " return None", " from executorch.backends.mlx.serialization._generated.mlx_delegate import VidOrTid as FBVidOrTidModule", " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", @@ -679,8 +685,10 @@ def generate_python_serializers(schema: FBSSchema) -> str: " FBVidOrTidModule.AddVid(builder, CreateVid(builder, vot.vid.idx))", " return FBVidOrTidModule.End(builder)", "", - " def _build_int_or_vid_or_tid(self, builder: flatbuffers.Builder, ivt: IntOrVidOrTid) -> int:", - ' """Build an IntOrVidOrTid table."""', + " def _build_int_or_vid_or_tid(self, builder: flatbuffers.Builder, ivt: IntOrVidOrTid) -> Optional[int]:", + ' """Build an IntOrVidOrTid table (None -> absent field)."""', + " if ivt is None:", + " return None", " from executorch.backends.mlx.serialization._generated.mlx_delegate import IntOrVidOrTid as FBIntOrVidOrTidModule", " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", @@ -831,12 +839,7 @@ def _emit_py_prebuild(kind: str, fld: FBSField) -> List[str]: if kind in _PY_PREBUILD_OFFSET: suffix = "_off" expr = _PY_PREBUILD_OFFSET[kind].format(name=n) - # optional_str carries its own None handling; other compound offset - # fields (int_or_vid, etc.) must be guarded when optional so a None - # value is serialized as an absent field rather than crashing. - if fld.required or kind == "optional_str": - return [f" {n}{suffix} = {expr}"] - return [f" {n}{suffix} = {expr} if op.{n} is not None else None"] + return [f" {n}{suffix} = {expr}"] return [] diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 8f52116f6b8..03956d3bd08 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -25,6 +25,7 @@ """ import os +import unittest from typing import Callable, Dict, List, Optional, Tuple import torch @@ -35,6 +36,7 @@ custom_ops, ops, ) +from executorch.backends.mlx.llm.sampling import SamplingHead from torch.export import Dim from .test_utils import OpTestCase, register_test @@ -7579,3 +7581,217 @@ def create_model(self) -> nn.Module: def create_inputs(self) -> Tuple[torch.Tensor, ...]: x = torch.randint(0, self.num_embeddings, (self.batch_size, self.seq_len)) return (x,) + + +class _LogitsPassthrough(nn.Module): + """Stand-in for a model returning logits [B, S, vocab].""" + + def forward(self, logits: torch.Tensor) -> torch.Tensor: + return logits + + +class SeededSampleModel(nn.Module): + """SamplingHead with temperature AND seed as runtime forward inputs.""" + + def __init__(self): + super().__init__() + self.head = SamplingHead(_LogitsPassthrough()) + + def forward(self, logits, temperature, seed): + return self.head(logits, temperature=temperature, seed=seed) + + +class UnseededSampleModel(nn.Module): + """SamplingHead with temperature as a runtime input and no seed.""" + + def __init__(self): + super().__init__() + self.head = SamplingHead(_LogitsPassthrough()) + + def forward(self, logits, temperature): + return self.head(logits, temperature=temperature) + + +class TopPSampleModel(nn.Module): + """SamplingHead with temperature, seed, and top_p as runtime inputs.""" + + def __init__(self): + super().__init__() + self.head = SamplingHead(_LogitsPassthrough()) + + def forward(self, logits, temperature, seed, top_p): + return self.head(logits, temperature=temperature, seed=seed, top_p=top_p) + + +@register_test +class SampleSeededTest(OpTestCase): + """Seeded sample lowers to one MLX segment; seed threads in via ItemIntNode.""" + + name = "sample_seeded" + skip_comparison = True # sampling RNG is not host/device bit-identical + expected_node_counts = { + "IfNode": 1, # temperature==0 greedy branch + "RandomBitsNode": 1, + "ArgmaxNode": 2, # sampling branch + greedy branch + "ItemIntNode": 2, # seed + temperature>0 condition + "SoftmaxNode": 1, # top-p nucleus chain + "SortNode": 1, + "CumsumNode": 1, + "MinNode": 1, + "WhereNode": 2, + } + + def create_model(self) -> nn.Module: + return SeededSampleModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return ( + torch.randn(1, 4, 256), + torch.tensor(0.8), + torch.tensor(0, dtype=torch.int64), + ) + + +@register_test +class SampleUnseededTest(OpTestCase): + """Unseeded sample lowers without a seed field (only the cond ItemIntNode).""" + + name = "sample_unseeded" + skip_comparison = True # sampling RNG is not host/device bit-identical + expected_node_counts = { + "IfNode": 1, + "RandomBitsNode": 1, + "ArgmaxNode": 2, + "ItemIntNode": 1, # temperature>0 condition only (no seed) + "SoftmaxNode": 1, # top-p nucleus chain (top_p defaults to 1.0) + } + + def create_model(self) -> nn.Module: + return UnseededSampleModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(1, 4, 256), torch.tensor(0.8)) + + +@register_test +class SampleTopPTest(OpTestCase): + """Top-p sample emits the nucleus chain (softmax/sort/cumsum/min/where).""" + + name = "sample_top_p" + skip_comparison = True # sampling RNG is not host/device bit-identical + expected_node_counts = { + "IfNode": 1, + "RandomBitsNode": 1, + "ArgmaxNode": 2, + "ItemIntNode": 2, + "SoftmaxNode": 1, + "SortNode": 1, + "CumsumNode": 1, + "MinNode": 1, + "WhereNode": 2, + } + + def create_model(self) -> nn.Module: + return TopPSampleModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return ( + torch.randn(1, 4, 256), + torch.tensor(0.8), + torch.tensor(0, dtype=torch.int64), + torch.tensor(0.9), + ) + + +def _ref_gumbel_max(logits: torch.Tensor, temperature: float, seed: int): + """Independent Gumbel-max reference using the same torch RNG as the op.""" + gen = torch.Generator().manual_seed(seed) + u = torch.rand(logits.shape, generator=gen) + gumbel = -torch.log(-torch.log(u)) + return torch.argmax(logits / temperature + gumbel, dim=-1) + + +def _tv_distance(p: torch.Tensor, q: torch.Tensor) -> float: + """Total-variation distance between two discrete distributions.""" + return 0.5 * torch.abs(p - q).sum().item() + + +def _sample(logits, temperature, seed: Optional[int], top_p: float = 1.0): + t = torch.tensor(float(temperature)) + s = None if seed is None else torch.tensor(int(seed), dtype=torch.int64) + p = torch.tensor(float(top_p)) # 1.0 = off + return torch.ops.mlx.sample(logits, t, p, s) + + +class TestSampleOp(unittest.TestCase): + """Eager reference behavior of mlx::sample (no export / no runtime).""" + + def test_greedy_parity_small_temperature(self): + # Small temperature -> Gumbel-max collapses to argmax(logits). + torch.manual_seed(0) + logits = torch.randn(8, 1024) + token = _sample(logits, 1e-4, seed=0) + self.assertTrue(torch.equal(token, torch.argmax(logits, dim=-1))) + + def test_greedy_temperature_zero(self): + # temperature == 0 is exact greedy: argmax(logits), no RNG, no division. + torch.manual_seed(0) + logits = torch.randn(8, 1024) + token = _sample(logits, 0.0, seed=0) + self.assertTrue(torch.equal(token, torch.argmax(logits, dim=-1))) + + def test_matches_independent_gumbel_reference(self): + # Same seed -> bit-identical token vs an independent Gumbel-max impl. + torch.manual_seed(1) + logits = torch.randn(8, 512) + for seed in (0, 1, 7, 42): + got = _sample(logits, 0.8, seed=seed) + expected = _ref_gumbel_max(logits, 0.8, seed) + self.assertTrue(torch.equal(got, expected), f"mismatch at seed={seed}") + + def test_distribution_matches_softmax(self): + # Empirical token frequencies match softmax(logits / T). + vocab = 5 + temperature = 1.0 + torch.manual_seed(0) + base = torch.randn(vocab) + n = 20000 + tokens = _sample(base.expand(n, vocab), temperature, seed=0) + + empirical = torch.bincount(tokens, minlength=vocab).float() / n + target = torch.softmax(base / temperature, dim=-1) + tv = _tv_distance(empirical, target) + self.assertLess(tv, 0.05, f"TV distance {tv:.4f} too large") + + def test_determinism_seeded(self): + # Same seed -> identical draws; different seed -> different draws. + torch.manual_seed(0) + logits = torch.randn(256, 64) + a = _sample(logits, 1.0, seed=123) + b = _sample(logits, 1.0, seed=123) + c = _sample(logits, 1.0, seed=124) + self.assertTrue(torch.equal(a, b)) + self.assertFalse(torch.equal(a, c)) + + def test_unseeded_varies_across_calls(self): + # seed=None uses the global RNG -> draws vary, tokens stay in range. + torch.manual_seed(0) + logits = torch.randn(256, 64) + a = _sample(logits, 1.0, seed=None) + b = _sample(logits, 1.0, seed=None) + self.assertFalse(torch.equal(a, b)) + self.assertGreaterEqual(int(a.min()), 0) + self.assertLess(int(a.max()), 64) + + def test_top_p_restricts_to_nucleus(self): + # probs [0.5, 0.3, 0.15, 0.05]; top_p=0.9 keeps {0,1,2}, drops index 3. + base = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05])) + tokens = _sample(base.expand(5000, 4), 1.0, seed=0, top_p=0.9) + self.assertTrue((tokens != 3).all()) # tail token never drawn + self.assertEqual(set(tokens.tolist()), {0, 1, 2}) # nucleus covered + + def test_top_p_one_keeps_all(self): + # top_p=1.0 -> no filtering; the tail token (index 3) is reachable. + base = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05])) + tokens = _sample(base.expand(20000, 4), 1.0, seed=0, top_p=1.0) + self.assertTrue((tokens == 3).any()) diff --git a/backends/mlx/test/test_sample.py b/backends/mlx/test/test_sample.py index a01b14ce3e5..e7f2b39e2af 100644 --- a/backends/mlx/test/test_sample.py +++ b/backends/mlx/test/test_sample.py @@ -6,12 +6,10 @@ # LICENSE file in the root directory of this source tree. """ -Tests for mlx::sample (Gumbel-max token sampling). +Export and on-device tests for mlx::sample (Gumbel-max token sampling). -Most tests exercise the op's eager reference implementation and the -export/partition/serialization path. The end-to-end test runs the exported -program through the compiled op_test_runner (see backends/mlx/test/README.md -for building it). +The end-to-end cases run the exported program through the compiled +op_test_runner (see backends/mlx/test/README.md). Usage: python -m unittest executorch.backends.mlx.test.test_sample @@ -21,7 +19,6 @@ import tempfile import unittest from pathlib import Path -from typing import Optional # Registers torch.ops.mlx.sample. import executorch.backends.mlx.custom_ops # noqa: F401 @@ -29,31 +26,13 @@ import torch.nn as nn from executorch.backends.mlx.llm.sampling import SamplingHead from executorch.backends.mlx.test.test_utils import ( - count_mlx_delegate_segments, export_model_to_pte, - get_mlx_node_counts, + load_tensors_from_bin, + run_cpp_test_runner, + save_tensors_to_bin, ) -def _ref_gumbel_max(logits: torch.Tensor, temperature: float, seed: int): - """Independent Gumbel-max reference using the same torch RNG as the op.""" - gen = torch.Generator().manual_seed(seed) - u = torch.rand(logits.shape, generator=gen) - gumbel = -torch.log(-torch.log(u)) - return torch.argmax(logits / temperature + gumbel, dim=-1) - - -def _tv_distance(p: torch.Tensor, q: torch.Tensor) -> float: - """Total-variation distance between two discrete distributions.""" - return 0.5 * torch.abs(p - q).sum().item() - - -def _sample(logits, temperature, seed: Optional[int]): - t = torch.tensor(float(temperature)) - s = None if seed is None else torch.tensor(int(seed), dtype=torch.int64) - return torch.ops.mlx.sample(logits, t, s) - - class _LogitsPassthrough(nn.Module): """Stand-in for a model returning logits [B, S, vocab].""" @@ -72,77 +51,19 @@ def forward(self, logits, temperature, seed): return self.head(logits, temperature=temperature, seed=seed) -class UnseededSampleModel(nn.Module): - """SamplingHead with temperature as a runtime input and no seed.""" +class TopPSampleModel(nn.Module): + """SamplingHead with temperature, seed, and top_p as runtime inputs.""" def __init__(self): super().__init__() self.head = SamplingHead(_LogitsPassthrough()) - def forward(self, logits, temperature): - return self.head(logits, temperature=temperature) - - -class TestSampleOp(unittest.TestCase): - """Eager reference behavior of mlx::sample (no export / no runtime).""" - - def test_greedy_parity_small_temperature(self): - # Small temperature -> Gumbel-max collapses to argmax(logits). - torch.manual_seed(0) - logits = torch.randn(8, 1024) - token = _sample(logits, 1e-4, seed=0) - self.assertTrue(torch.equal(token, torch.argmax(logits, dim=-1))) - - def test_matches_independent_gumbel_reference(self): - # Same seed -> bit-identical token vs an independent Gumbel-max impl. - torch.manual_seed(1) - logits = torch.randn(8, 512) - for seed in (0, 1, 7, 42): - got = _sample(logits, 0.8, seed=seed) - expected = _ref_gumbel_max(logits, 0.8, seed) - self.assertTrue(torch.equal(got, expected), f"mismatch at seed={seed}") - - def test_distribution_matches_softmax(self): - # Empirical token frequencies match softmax(logits / T). - vocab = 5 - temperature = 1.0 - torch.manual_seed(0) - base = torch.randn(vocab) - n = 20000 - tokens = _sample(base.expand(n, vocab), temperature, seed=0) - - empirical = torch.bincount(tokens, minlength=vocab).float() / n - target = torch.softmax(base / temperature, dim=-1) - tv = _tv_distance(empirical, target) - self.assertLess(tv, 0.05, f"TV distance {tv:.4f} too large") - - def test_determinism_seeded(self): - # Same seed -> identical draws; different seed -> different draws. - torch.manual_seed(0) - logits = torch.randn(256, 64) - a = _sample(logits, 1.0, seed=123) - b = _sample(logits, 1.0, seed=123) - c = _sample(logits, 1.0, seed=124) - self.assertTrue(torch.equal(a, b)) - self.assertFalse(torch.equal(a, c)) - - def test_unseeded_varies_across_calls(self): - # seed=None uses the global RNG -> draws vary, tokens stay in range. - torch.manual_seed(0) - logits = torch.randn(256, 64) - a = _sample(logits, 1.0, seed=None) - b = _sample(logits, 1.0, seed=None) - self.assertFalse(torch.equal(a, b)) - self.assertGreaterEqual(int(a.min()), 0) - self.assertLess(int(a.max()), 64) + def forward(self, logits, temperature, seed, top_p): + return self.head(logits, temperature=temperature, seed=seed, top_p=top_p) class TestSampleExport(unittest.TestCase): - """torch.export, runtime inputs, and MLXPartitioner lowering.""" - - def setUp(self): - self._tmp = tempfile.mkdtemp() - self.addCleanup(shutil.rmtree, self._tmp, ignore_errors=True) + """torch.export and runtime-input semantics of the sampling head.""" def test_runtime_temperature_single_export(self): # One exported program run at two temperatures (no re-export) confirms @@ -198,73 +119,19 @@ def test_export_strict_with_graph_inputs(self): ) self.assertEqual(len(ep.graph_signature.user_inputs), 3) - def test_seeded_lowers_to_mlx_delegate(self): - # The op is assigned to the MLX delegate, with the seed threaded in. - pte = Path(self._tmp) / "seeded.pte" - logits = torch.randn(1, 4, 256) - export_model_to_pte( - SeededSampleModel(), - (logits, torch.tensor(0.8), torch.tensor(0, dtype=torch.int64)), - pte, - ) - self.assertEqual(count_mlx_delegate_segments(pte), 1) - counts = get_mlx_node_counts(pte) - self.assertEqual(counts.get("RandomBitsNode", 0), 1) - self.assertEqual(counts.get("ArgmaxNode", 0), 1) - self.assertEqual(counts.get("ItemIntNode", 0), 1) # seed via .item() - - def test_unseeded_lowers_without_seed_field(self): - # seed=None lowers cleanly: RandomBitsNode emitted with no seed field - # (hence no ItemIntNode threading a seed Vid). - pte = Path(self._tmp) / "unseeded.pte" - logits = torch.randn(1, 4, 256) - export_model_to_pte(UnseededSampleModel(), (logits, torch.tensor(0.8)), pte) - self.assertEqual(count_mlx_delegate_segments(pte), 1) - counts = get_mlx_node_counts(pte) - self.assertEqual(counts.get("RandomBitsNode", 0), 1) - self.assertEqual(counts.get("ItemIntNode", 0), 0) - class TestSampleEndToEnd(unittest.TestCase): - """Run the exported program through the compiled op_test_runner.""" + """On-device checks whose assertions the output-compare harness can't express.""" def setUp(self): self._tmp = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, self._tmp, ignore_errors=True) - def test_end_to_end(self): - # Requires the compiled op_test_runner (see backends/mlx/test/README.md). - from executorch.backends.mlx.test.test_utils import ( - load_tensors_from_bin, - run_cpp_test_runner, - save_tensors_to_bin, - ) - - vocab = 32 - logits = torch.randn(1, 4, vocab) - inputs = (logits, torch.tensor(0.8), torch.tensor(0, dtype=torch.int64)) - - tmp = Path(self._tmp) - pte, in_bin, out_bin = tmp / "e2e.pte", tmp / "in.bin", tmp / "out.bin" - export_model_to_pte(SeededSampleModel(), inputs, pte) - save_tensors_to_bin(list(inputs), in_bin) - - self.assertTrue(run_cpp_test_runner(pte, in_bin, out_bin)) - (token,) = load_tensors_from_bin(out_bin) - self.assertEqual(tuple(token.shape), (1,)) - self.assertTrue(0 <= int(token) < vocab) - def test_bf16_large_vocab_greedy_parity(self): # Regression: bf16 logits + large vocab. A dominant logit must win under # near-greedy sampling. Catches the bug where casting the uniform to bf16 # rounded the clamp (~0.99999994) up to 1.0 -> log(0) -> +inf gumbel, # which then beat even a huge logit and produced a constant wrong token. - from executorch.backends.mlx.test.test_utils import ( - load_tensors_from_bin, - run_cpp_test_runner, - save_tensors_to_bin, - ) - torch.manual_seed(0) vocab = 4000 logits = torch.randn(1, 4, vocab, dtype=torch.bfloat16) @@ -280,6 +147,40 @@ def test_bf16_large_vocab_greedy_parity(self): (token,) = load_tensors_from_bin(out_bin) self.assertEqual(int(token), 1234) + def test_greedy_temperature_zero_end_to_end(self): + # temperature=0 takes the IfNode greedy branch -> exact argmax on device. + torch.manual_seed(0) + vocab = 64 + logits = torch.randn(1, 4, vocab) + inputs = (logits, torch.tensor(0.0), torch.tensor(0, dtype=torch.int64)) + + tmp = Path(self._tmp) + pte, in_bin, out_bin = tmp / "greedy.pte", tmp / "in.bin", tmp / "out.bin" + export_model_to_pte(SeededSampleModel(), inputs, pte) + save_tensors_to_bin(list(inputs), in_bin) + + self.assertTrue(run_cpp_test_runner(pte, in_bin, out_bin)) + (token,) = load_tensors_from_bin(out_bin) + self.assertEqual(int(token), int(torch.argmax(logits[0, -1]))) + + def test_top_p_end_to_end(self): + # On-device nucleus: probs [0.5,0.3,0.15,0.05], top_p=0.9 -> token in {0,1,2}. + logits = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05])).view(1, 1, 4) + inputs = ( + logits, + torch.tensor(1.0), + torch.tensor(0, dtype=torch.int64), + torch.tensor(0.9), + ) + tmp = Path(self._tmp) + pte, in_bin, out_bin = tmp / "topp.pte", tmp / "in.bin", tmp / "out.bin" + export_model_to_pte(TopPSampleModel(), inputs, pte) + save_tensors_to_bin(list(inputs), in_bin) + + self.assertTrue(run_cpp_test_runner(pte, in_bin, out_bin)) + (token,) = load_tensors_from_bin(out_bin) + self.assertIn(int(token), {0, 1, 2}) # tail token (index 3) excluded + if __name__ == "__main__": unittest.main() From 6c5a9ce3a7ccc975baf1e5eb31607a6e87d9d39a Mon Sep 17 00:00:00 2001 From: kiymetakdemir Date: Thu, 25 Jun 2026 10:00:26 -0700 Subject: [PATCH 4/5] MLX sample: seed as optional Vid; greedy/bf16 as OpTestCase; negative-temp parity; move unittest tests to test_sample.py --- .github/workflows/mlx.yml | 3 - backends/mlx/custom_ops.py | 7 +- backends/mlx/ops.py | 2 +- backends/mlx/runtime/MLXInterpreter.h | 3 +- backends/mlx/serialization/generate.py | 51 +++------- backends/mlx/serialization/schema.fbs | 2 +- backends/mlx/test/test_ops.py | 133 ++++++++----------------- backends/mlx/test/test_sample.py | 131 +++++++++++++++--------- 8 files changed, 146 insertions(+), 186 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 9794f8080dc..2a10b1005c9 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -74,8 +74,6 @@ jobs: echo "::endgroup::" echo "::group::Run Python unit tests" - # test_ops.py is listed for its unittest classes; run_all_tests.py only - # runs its OpTestCase classes. ${CONDA_RUN} python -m pytest \ backends/mlx/test/test_passes.py \ backends/mlx/test/test_pattern_utils.py \ @@ -83,7 +81,6 @@ jobs: backends/mlx/test/test_serialization_dedup.py \ backends/mlx/test/test_slot_recycling.py \ backends/mlx/test/test_sample.py \ - backends/mlx/test/test_ops.py \ examples/models/gemma4_31b/quant/tests/test_pack_mlx.py \ examples/models/gemma4_31b/tests/test_mlx_pipeline.py \ -v diff --git a/backends/mlx/custom_ops.py b/backends/mlx/custom_ops.py index 74e8a13586d..17c07097f70 100644 --- a/backends/mlx/custom_ops.py +++ b/backends/mlx/custom_ops.py @@ -403,8 +403,9 @@ def sample( """ Gumbel-max sampling from softmax(logits / temperature), with top-p (nucleus). logits: [B, vocab] - temperature: scalar float tensor (runtime input). temperature == 0 is - greedy: return argmax(logits) directly. + temperature: scalar float tensor (runtime input). temperature <= 0 is + greedy: return argmax(logits) directly (matches the device, + which branches on temperature > 0). top_p: scalar float tensor in (0, 1]. top_p=1.0 keeps every token, i.e. it is off. seed: scalar int tensor or None @@ -417,7 +418,7 @@ def sample( RNG (plain torch.rand, no uint32/nextafter uniform) while the delegate uses MLX RNG, so a given seed does not reproduce the same tokens host vs. device. """ - if float(temperature) == 0: + if float(temperature) <= 0: # matches the device cond (temperature > 0) return torch.argmax(logits, dim=-1) # whole chain in fp32 to match the lowered graph (bf16 sums mis-rank ties). scaled = logits.float() / temperature diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 82e322b5d85..3ca87878cd9 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -3496,7 +3496,7 @@ def emit_sample(): if seed is not None: _, seed_val = P.make_tmp_value_slot() P.emit(ItemIntNode(x=P.slot_to_tid(seed), out=P.slot_to_vid(seed_val))) - seed_field = P.to_int_or_vid(seed_val) + seed_field = P.slot_to_vid(seed_val) # uniform u in [0, 1): bits/uint32_max, clamped just below 1 (random.cpp:95) _, bits = P.make_tmp_slot() diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index 95a09e33ac8..f420142787e 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -1708,7 +1708,8 @@ inline void exec_random_bits( check_allocation_bounded(shape, uint32, "random_bits"); std::optional key = std::nullopt; if (n.seed.has_value()) { - key = random::key(static_cast(resolve_int(n.seed.value(), st))); + key = random::key( + static_cast(st.const_value_ref(n.seed.value()))); } st.set_tensor(n.out, random::bits(shape, n.width, key, s)); } diff --git a/backends/mlx/serialization/generate.py b/backends/mlx/serialization/generate.py index 5b05c1b8b9d..fd0b5b672b0 100755 --- a/backends/mlx/serialization/generate.py +++ b/backends/mlx/serialization/generate.py @@ -572,7 +572,7 @@ def generate_python_serializers(schema: FBSSchema) -> str: "", "from __future__ import annotations", "", - "from typing import Dict, List, Optional, Tuple", + "from typing import List, Tuple, Dict", "", "import flatbuffers", "", @@ -640,10 +640,8 @@ def generate_python_serializers(schema: FBSSchema) -> str: "class GeneratedOpBuilders:", ' """Mixin class with auto-generated op builder methods."""', "", - " def _build_int_or_vid(self, builder: flatbuffers.Builder, iov: IntOrVid) -> Optional[int]:", - ' """Build an IntOrVid table (None -> absent field, like _shared_string)."""', - " if iov is None:", - " return None", + " def _build_int_or_vid(self, builder: flatbuffers.Builder, iov: IntOrVid) -> int:", + ' """Build an IntOrVid table."""', " from executorch.backends.mlx.serialization._generated.mlx_delegate import IntOrVid as FBIntOrVidModule", " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", "", @@ -655,10 +653,8 @@ def generate_python_serializers(schema: FBSSchema) -> str: " FBIntOrVidModule.AddVid(builder, CreateVid(builder, iov.vid.idx))", " return FBIntOrVidModule.End(builder)", "", - " def _build_float_or_vid(self, builder: flatbuffers.Builder, fov: FloatOrVid) -> Optional[int]:", - ' """Build a FloatOrVid table (None -> absent field)."""', - " if fov is None:", - " return None", + " def _build_float_or_vid(self, builder: flatbuffers.Builder, fov: FloatOrVid) -> int:", + ' """Build a FloatOrVid table."""', " from executorch.backends.mlx.serialization._generated.mlx_delegate import FloatOrVid as FBFloatOrVidModule", " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", "", @@ -669,10 +665,8 @@ def generate_python_serializers(schema: FBSSchema) -> str: " FBFloatOrVidModule.AddVid(builder, CreateVid(builder, fov.vid.idx))", " return FBFloatOrVidModule.End(builder)", "", - " def _build_vid_or_tid(self, builder: flatbuffers.Builder, vot: VidOrTid) -> Optional[int]:", - ' """Build a TidOrVid table (None -> absent field)."""', - " if vot is None:", - " return None", + " def _build_vid_or_tid(self, builder: flatbuffers.Builder, vot: VidOrTid) -> int:", + ' """Build a TidOrVid table."""', " from executorch.backends.mlx.serialization._generated.mlx_delegate import VidOrTid as FBVidOrTidModule", " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", @@ -685,10 +679,8 @@ def generate_python_serializers(schema: FBSSchema) -> str: " FBVidOrTidModule.AddVid(builder, CreateVid(builder, vot.vid.idx))", " return FBVidOrTidModule.End(builder)", "", - " def _build_int_or_vid_or_tid(self, builder: flatbuffers.Builder, ivt: IntOrVidOrTid) -> Optional[int]:", - ' """Build an IntOrVidOrTid table (None -> absent field)."""', - " if ivt is None:", - " return None", + " def _build_int_or_vid_or_tid(self, builder: flatbuffers.Builder, ivt: IntOrVidOrTid) -> int:", + ' """Build an IntOrVidOrTid table."""', " from executorch.backends.mlx.serialization._generated.mlx_delegate import IntOrVidOrTid as FBIntOrVidOrTidModule", " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", @@ -863,12 +855,7 @@ def _emit_py_add( return [f" {add}(builder, op.{n})"] # Pre-built offsets (string, compound types) if kind in ("str", "int_or_vid", "float_or_vid", "vid_or_tid", "int_or_vid_or_tid"): - if fld.required: - return [f" {add}(builder, {n}_off)"] - return [ - f" if {n}_off is not None:", - f" {add}(builder, {n}_off)", - ] + return [f" {add}(builder, {n}_off)"] # Pre-built vectors (required vs optional) if kind in ( "list_int", @@ -1069,8 +1056,6 @@ def _fbs_type_to_cpp( return "std::optional" if fbs_type == "Vid": return "std::optional" - if fbs_type in ("IntOrVid", "FloatOrVid", "VidOrTid", "IntOrVidOrTid"): - return f"std::optional<{cpp_type}>" if fld is not None and fld.default == "null" and fbs_type in FBS_TO_CPP: return f"std::optional<{cpp_type}>" @@ -1155,7 +1140,7 @@ def _generate_loader_case(table: FBSTable) -> List[str]: fb_field_name = fld.name kind = _get_field_kind(fld, table) - load_lines = _emit_cpp_load(kind, fld.name, fb_field_name, table, fld) + load_lines = _emit_cpp_load(kind, fld.name, fb_field_name, table) if load_lines is None: raise ValueError( f"Unhandled field kind '{kind}' for field '{fld.name}' in table '{table.name}'. " @@ -1187,24 +1172,16 @@ def _generate_loader_case(table: FBSTable) -> List[str]: } -def _emit_cpp_load( # noqa: C901 - kind: str, name: str, fb_name: str, table=None, fld=None +def _emit_cpp_load( + kind: str, name: str, fb_name: str, table=None ) -> "List[str] | None": """Emit C++ load lines for a field kind, or None if kind is unrecognized.""" # Interned string fields share one std::string via the load-time pool. if _is_interned_str(table, name) and kind in ("str", "optional_str"): return [f" node.{name} = strpool.intern(fb->{fb_name}());"] - # Struct / compound via converter + # Required struct / compound via converter if kind in _CPP_CONVERTER: conv = _CPP_CONVERTER[kind] - # Optional compound fields (e.g. an optional IntOrVid) must be - # presence-guarded; convert_* throws on a null FlatBuffer pointer. - if fld is not None and not fld.required: - return [ - f" if (fb->{fb_name}()) {{", - f" node.{name} = {conv}(fb->{fb_name}());", - " }", - ] return [f" node.{name} = {conv}(fb->{fb_name}());"] # Scalars (direct value) if kind in ("int", "float", "bool"): diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index b5a7f737842..281199a8002 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -988,7 +988,7 @@ table IfNode { table RandomBitsNode { out: Tid (required); shape: [IntOrVid] (required); - seed: IntOrVid; // OPTIONAL: present -> random::key(seed); + seed: Vid; // OPTIONAL: present -> random::key(seed); // absent -> MLX global KeySequence width: int32 = 4; // bytes per element (4 -> uint32) } diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 03956d3bd08..7d9b1a3b777 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -25,7 +25,6 @@ """ import os -import unittest from typing import Callable, Dict, List, Optional, Tuple import torch @@ -7703,95 +7702,43 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: ) -def _ref_gumbel_max(logits: torch.Tensor, temperature: float, seed: int): - """Independent Gumbel-max reference using the same torch RNG as the op.""" - gen = torch.Generator().manual_seed(seed) - u = torch.rand(logits.shape, generator=gen) - gumbel = -torch.log(-torch.log(u)) - return torch.argmax(logits / temperature + gumbel, dim=-1) - - -def _tv_distance(p: torch.Tensor, q: torch.Tensor) -> float: - """Total-variation distance between two discrete distributions.""" - return 0.5 * torch.abs(p - q).sum().item() - - -def _sample(logits, temperature, seed: Optional[int], top_p: float = 1.0): - t = torch.tensor(float(temperature)) - s = None if seed is None else torch.tensor(int(seed), dtype=torch.int64) - p = torch.tensor(float(top_p)) # 1.0 = off - return torch.ops.mlx.sample(logits, t, p, s) - - -class TestSampleOp(unittest.TestCase): - """Eager reference behavior of mlx::sample (no export / no runtime).""" - - def test_greedy_parity_small_temperature(self): - # Small temperature -> Gumbel-max collapses to argmax(logits). - torch.manual_seed(0) - logits = torch.randn(8, 1024) - token = _sample(logits, 1e-4, seed=0) - self.assertTrue(torch.equal(token, torch.argmax(logits, dim=-1))) - - def test_greedy_temperature_zero(self): - # temperature == 0 is exact greedy: argmax(logits), no RNG, no division. - torch.manual_seed(0) - logits = torch.randn(8, 1024) - token = _sample(logits, 0.0, seed=0) - self.assertTrue(torch.equal(token, torch.argmax(logits, dim=-1))) - - def test_matches_independent_gumbel_reference(self): - # Same seed -> bit-identical token vs an independent Gumbel-max impl. - torch.manual_seed(1) - logits = torch.randn(8, 512) - for seed in (0, 1, 7, 42): - got = _sample(logits, 0.8, seed=seed) - expected = _ref_gumbel_max(logits, 0.8, seed) - self.assertTrue(torch.equal(got, expected), f"mismatch at seed={seed}") - - def test_distribution_matches_softmax(self): - # Empirical token frequencies match softmax(logits / T). - vocab = 5 - temperature = 1.0 - torch.manual_seed(0) - base = torch.randn(vocab) - n = 20000 - tokens = _sample(base.expand(n, vocab), temperature, seed=0) - - empirical = torch.bincount(tokens, minlength=vocab).float() / n - target = torch.softmax(base / temperature, dim=-1) - tv = _tv_distance(empirical, target) - self.assertLess(tv, 0.05, f"TV distance {tv:.4f} too large") - - def test_determinism_seeded(self): - # Same seed -> identical draws; different seed -> different draws. - torch.manual_seed(0) - logits = torch.randn(256, 64) - a = _sample(logits, 1.0, seed=123) - b = _sample(logits, 1.0, seed=123) - c = _sample(logits, 1.0, seed=124) - self.assertTrue(torch.equal(a, b)) - self.assertFalse(torch.equal(a, c)) - - def test_unseeded_varies_across_calls(self): - # seed=None uses the global RNG -> draws vary, tokens stay in range. - torch.manual_seed(0) - logits = torch.randn(256, 64) - a = _sample(logits, 1.0, seed=None) - b = _sample(logits, 1.0, seed=None) - self.assertFalse(torch.equal(a, b)) - self.assertGreaterEqual(int(a.min()), 0) - self.assertLess(int(a.max()), 64) - - def test_top_p_restricts_to_nucleus(self): - # probs [0.5, 0.3, 0.15, 0.05]; top_p=0.9 keeps {0,1,2}, drops index 3. - base = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05])) - tokens = _sample(base.expand(5000, 4), 1.0, seed=0, top_p=0.9) - self.assertTrue((tokens != 3).all()) # tail token never drawn - self.assertEqual(set(tokens.tolist()), {0, 1, 2}) # nucleus covered - - def test_top_p_one_keeps_all(self): - # top_p=1.0 -> no filtering; the tail token (index 3) is reachable. - base = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05])) - tokens = _sample(base.expand(20000, 4), 1.0, seed=0, top_p=1.0) - self.assertTrue((tokens == 3).any()) +@register_test +class SampleGreedyTest(OpTestCase): + """Greedy argmax(logits) is bit-exact host/device, so verify the token with the + normal compare. Covers temperature=0, tiny temperature, and bf16 logits.""" + + name = "sample_greedy" + + def __init__(self, temperature: float = 0.0, dtype: torch.dtype = torch.float32): + self.temperature = temperature + self.dtype = dtype + if dtype == torch.bfloat16: + self.name = "sample_greedy_bf16" + elif temperature < 0: + self.name = "sample_greedy_neg" + elif temperature == 0.0: + self.name = "sample_greedy" + else: + self.name = "sample_greedy_eps" + + @classmethod + def get_test_configs(cls) -> List["SampleGreedyTest"]: + return [ + cls(temperature=0.0), + cls(temperature=1e-4), + cls(temperature=-1.0), # negative -> greedy on both paths (consistent) + cls(temperature=1e-4, dtype=torch.bfloat16), + ] + + def create_model(self) -> nn.Module: + return SeededSampleModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + logits = torch.randn(1, 4, 1024, dtype=self.dtype) + if self.dtype == torch.bfloat16: + logits[0, -1, 512] = 50.0 # dominant -> unambiguous bf16 argmax + return ( + logits, + torch.tensor(self.temperature), + torch.tensor(0, dtype=torch.int64), + ) diff --git a/backends/mlx/test/test_sample.py b/backends/mlx/test/test_sample.py index e7f2b39e2af..ddb0734d13b 100644 --- a/backends/mlx/test/test_sample.py +++ b/backends/mlx/test/test_sample.py @@ -19,6 +19,7 @@ import tempfile import unittest from pathlib import Path +from typing import Optional # Registers torch.ops.mlx.sample. import executorch.backends.mlx.custom_ops # noqa: F401 @@ -62,8 +63,90 @@ def forward(self, logits, temperature, seed, top_p): return self.head(logits, temperature=temperature, seed=seed, top_p=top_p) +def _ref_gumbel_max(logits: torch.Tensor, temperature: float, seed: int): + """Independent Gumbel-max reference using the same torch RNG as the op.""" + gen = torch.Generator().manual_seed(seed) + u = torch.rand(logits.shape, generator=gen) + gumbel = -torch.log(-torch.log(u)) + return torch.argmax(logits / temperature + gumbel, dim=-1) + + +def _tv_distance(p: torch.Tensor, q: torch.Tensor) -> float: + """Total-variation distance between two discrete distributions.""" + return 0.5 * torch.abs(p - q).sum().item() + + +def _sample(logits, temperature, seed: Optional[int], top_p: float = 1.0): + t = torch.tensor(float(temperature)) + s = None if seed is None else torch.tensor(int(seed), dtype=torch.int64) + p = torch.tensor(float(top_p)) # 1.0 = off + return torch.ops.mlx.sample(logits, t, p, s) + + +class TestSampleOp(unittest.TestCase): + """Eager reference behavior of mlx::sample (no export / no runtime).""" + + def test_matches_independent_gumbel_reference(self): + # Same seed -> bit-identical token vs an independent Gumbel-max impl. + torch.manual_seed(1) + logits = torch.randn(8, 512) + for seed in (0, 1, 7, 42): + got = _sample(logits, 0.8, seed=seed) + expected = _ref_gumbel_max(logits, 0.8, seed) + self.assertTrue(torch.equal(got, expected), f"mismatch at seed={seed}") + + def test_distribution_matches_softmax(self): + # Empirical token frequencies match softmax(logits / T). + vocab = 5 + temperature = 1.0 + torch.manual_seed(0) + base = torch.randn(vocab) + n = 20000 + tokens = _sample(base.expand(n, vocab), temperature, seed=0) + + empirical = torch.bincount(tokens, minlength=vocab).float() / n + target = torch.softmax(base / temperature, dim=-1) + tv = _tv_distance(empirical, target) + self.assertLess(tv, 0.05, f"TV distance {tv:.4f} too large") + + def test_determinism_seeded(self): + # Same seed -> identical draws; different seed -> different draws. + torch.manual_seed(0) + logits = torch.randn(256, 64) + a = _sample(logits, 1.0, seed=123) + b = _sample(logits, 1.0, seed=123) + c = _sample(logits, 1.0, seed=124) + self.assertTrue(torch.equal(a, b)) + self.assertFalse(torch.equal(a, c)) + + def test_unseeded_varies_across_calls(self): + # seed=None uses the global RNG -> draws vary, tokens stay in range. + torch.manual_seed(0) + logits = torch.randn(256, 64) + a = _sample(logits, 1.0, seed=None) + b = _sample(logits, 1.0, seed=None) + self.assertFalse(torch.equal(a, b)) + self.assertGreaterEqual(int(a.min()), 0) + self.assertLess(int(a.max()), 64) + + def test_top_p_restricts_to_nucleus(self): + # probs [0.5, 0.3, 0.15, 0.05]; top_p=0.9 keeps {0,1,2}, drops index 3. + base = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05])) + tokens = _sample(base.expand(5000, 4), 1.0, seed=0, top_p=0.9) + self.assertTrue((tokens != 3).all()) # tail token never drawn + self.assertEqual(set(tokens.tolist()), {0, 1, 2}) # nucleus covered + + def test_top_p_one_keeps_all(self): + # top_p=1.0 -> no filtering; the tail token (index 3) is reachable. + base = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05])) + tokens = _sample(base.expand(20000, 4), 1.0, seed=0, top_p=1.0) + self.assertTrue((tokens == 3).any()) + + class TestSampleExport(unittest.TestCase): - """torch.export and runtime-input semantics of the sampling head.""" + """Runtime-input semantics that survive export: temperature and seed stay + live graph inputs (not constant-folded). Lowering/partition is covered by the + OpTestCase classes in test_ops.py.""" def test_runtime_temperature_single_export(self): # One exported program run at two temperatures (no re-export) confirms @@ -109,16 +192,6 @@ def test_seeded_export_reproducible_no_host_rng(self): other = run(logits, torch.tensor(1.0), torch.tensor(124, dtype=torch.int64)) self.assertFalse(torch.equal(first, other)) - def test_export_strict_with_graph_inputs(self): - # strict=True export keeps logits, temperature, and seed as graph inputs. - logits = torch.randn(1, 4, 256) - ep = torch.export.export( - SeededSampleModel(), - (logits, torch.tensor(0.8), torch.tensor(0, dtype=torch.int64)), - strict=True, - ) - self.assertEqual(len(ep.graph_signature.user_inputs), 3) - class TestSampleEndToEnd(unittest.TestCase): """On-device checks whose assertions the output-compare harness can't express.""" @@ -127,42 +200,6 @@ def setUp(self): self._tmp = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, self._tmp, ignore_errors=True) - def test_bf16_large_vocab_greedy_parity(self): - # Regression: bf16 logits + large vocab. A dominant logit must win under - # near-greedy sampling. Catches the bug where casting the uniform to bf16 - # rounded the clamp (~0.99999994) up to 1.0 -> log(0) -> +inf gumbel, - # which then beat even a huge logit and produced a constant wrong token. - torch.manual_seed(0) - vocab = 4000 - logits = torch.randn(1, 4, vocab, dtype=torch.bfloat16) - logits[0, -1, 1234] = 50.0 # unambiguous argmax - inputs = (logits, torch.tensor(1e-4), torch.tensor(0, dtype=torch.int64)) - - tmp = Path(self._tmp) - pte, in_bin, out_bin = tmp / "bf16.pte", tmp / "in.bin", tmp / "out.bin" - export_model_to_pte(SeededSampleModel(), inputs, pte) - save_tensors_to_bin(list(inputs), in_bin) - - self.assertTrue(run_cpp_test_runner(pte, in_bin, out_bin)) - (token,) = load_tensors_from_bin(out_bin) - self.assertEqual(int(token), 1234) - - def test_greedy_temperature_zero_end_to_end(self): - # temperature=0 takes the IfNode greedy branch -> exact argmax on device. - torch.manual_seed(0) - vocab = 64 - logits = torch.randn(1, 4, vocab) - inputs = (logits, torch.tensor(0.0), torch.tensor(0, dtype=torch.int64)) - - tmp = Path(self._tmp) - pte, in_bin, out_bin = tmp / "greedy.pte", tmp / "in.bin", tmp / "out.bin" - export_model_to_pte(SeededSampleModel(), inputs, pte) - save_tensors_to_bin(list(inputs), in_bin) - - self.assertTrue(run_cpp_test_runner(pte, in_bin, out_bin)) - (token,) = load_tensors_from_bin(out_bin) - self.assertEqual(int(token), int(torch.argmax(logits[0, -1]))) - def test_top_p_end_to_end(self): # On-device nucleus: probs [0.5,0.3,0.15,0.05], top_p=0.9 -> token in {0,1,2}. logits = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05])).view(1, 1, 4) From a2953d47ad3b309ea26e851cdc84528598591e5d Mon Sep 17 00:00:00 2001 From: kiymetakdemir Date: Thu, 25 Jun 2026 11:08:36 -0700 Subject: [PATCH 5/5] MLX sample: reserve top_k in SamplingHead; add batch greedy OpTestCase; tag-style config naming --- backends/mlx/llm/sampling.py | 11 ++++++---- backends/mlx/runtime/MLXInterpreter.h | 3 ++- backends/mlx/test/test_ops.py | 30 ++++++++++++++------------- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/backends/mlx/llm/sampling.py b/backends/mlx/llm/sampling.py index cb8d39c64e5..a059e5cff08 100644 --- a/backends/mlx/llm/sampling.py +++ b/backends/mlx/llm/sampling.py @@ -14,22 +14,25 @@ class SamplingHead(nn.Module): """ Wraps a model that returns logits and samples a token id on-device. - forward(*model_args, temperature, seed=None, top_p=1.0, **model_kwargs) - -> token_id + forward(*model_args, temperature, top_k=None, top_p=1.0, seed=None, + **model_kwargs) -> token_id temperature: scalar float tensor, e.g. torch.tensor(0.8). Must be >= 0; temperature=0 is greedy (returns argmax, no division). - seed: scalar int tensor (seeded) or None (unseeded export) + top_k: not implemented yet (reserved); must be None. top_p: scalar float tensor in (0, 1] for nucleus sampling. top_p=1.0 (the default) keeps every token, i.e. no filtering. Pass it as a runtime input to tune per request. + seed: scalar int tensor (seeded) or None (unseeded export) """ def __init__(self, model: nn.Module): super().__init__() self.model = model - def forward(self, *args, temperature, seed=None, top_p=1.0, **kwargs): + def forward(self, *args, temperature, top_k=None, top_p=1.0, seed=None, **kwargs): + if top_k is not None: + raise NotImplementedError("top_k sampling is not implemented") logits = self.model(*args, **kwargs) # [B, S, vocab] last = logits[:, -1, :] # [B, vocab] if not isinstance(top_p, torch.Tensor): diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index f420142787e..3c3c2c323a8 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -1699,7 +1699,8 @@ inline void exec_random_bits( const RandomBitsNode& n, ExecutionState& st, StreamOrDevice s) { - // random::bits supports width (bytes/element) in {1, 2, 4} -> uint8/uint16/uint32. + // random::bits supports width (bytes/element) in {1, 2, 4} -> + // uint8/uint16/uint32. if (n.width != 1 && n.width != 2 && n.width != 4) { throw std::runtime_error("random_bits: width must be 1, 2, or 4"); } diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 7d9b1a3b777..6c022116066 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -7705,36 +7705,38 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: @register_test class SampleGreedyTest(OpTestCase): """Greedy argmax(logits) is bit-exact host/device, so verify the token with the - normal compare. Covers temperature=0, tiny temperature, and bf16 logits.""" + normal compare. Covers temperature=0, tiny temperature, bf16 logits, and a + batch (per-row argmax -> [B] on device).""" name = "sample_greedy" - def __init__(self, temperature: float = 0.0, dtype: torch.dtype = torch.float32): + def __init__( + self, + temperature: float = 0.0, + dtype: torch.dtype = torch.float32, + batch: int = 1, + tag: str = "", + ): self.temperature = temperature self.dtype = dtype - if dtype == torch.bfloat16: - self.name = "sample_greedy_bf16" - elif temperature < 0: - self.name = "sample_greedy_neg" - elif temperature == 0.0: - self.name = "sample_greedy" - else: - self.name = "sample_greedy_eps" + self.batch = batch + self.name = f"sample_greedy_{tag}" if tag else "sample_greedy" @classmethod def get_test_configs(cls) -> List["SampleGreedyTest"]: return [ cls(temperature=0.0), - cls(temperature=1e-4), - cls(temperature=-1.0), # negative -> greedy on both paths (consistent) - cls(temperature=1e-4, dtype=torch.bfloat16), + cls(temperature=1e-4, tag="eps"), + cls(temperature=-1.0, tag="neg"), # negative -> greedy on both paths + cls(temperature=1e-4, dtype=torch.bfloat16, tag="bf16"), + cls(temperature=0.0, batch=4, tag="batch"), # per-row argmax over a batch ] def create_model(self) -> nn.Module: return SeededSampleModel() def create_inputs(self) -> Tuple[torch.Tensor, ...]: - logits = torch.randn(1, 4, 1024, dtype=self.dtype) + logits = torch.randn(self.batch, 4, 1024, dtype=self.dtype) if self.dtype == torch.bfloat16: logits[0, -1, 512] = 50.0 # dominant -> unambiguous bf16 argmax return (