diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 167ceb7da83..40ebabbcf28 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -84,6 +84,7 @@ jobs: backends/mlx/test/test_partitioner.py \ backends/mlx/test/test_serialization_dedup.py \ backends/mlx/test/test_slot_recycling.py \ + backends/mlx/test/test_sample.py \ examples/models/gemma4_31b/quant/tests/test_pack_mlx.py \ examples/models/gemma4_31b/tests/test_mlx_pipeline.py \ -v diff --git a/backends/mlx/custom_ops.py b/backends/mlx/custom_ops.py index c03db05d918..17c07097f70 100644 --- a/backends/mlx/custom_ops.py +++ b/backends/mlx/custom_ops.py @@ -391,3 +391,54 @@ def gather_qmm_fake( else: batch = w.shape[:-2] return x.new_empty((*batch, M, N)) + + +@torch.library.custom_op("mlx::sample", mutates_args=()) +def sample( + logits: Tensor, + temperature: Tensor, + top_p: Tensor, + seed: Optional[Tensor] = None, +) -> Tensor: + """ + Gumbel-max sampling from softmax(logits / temperature), with top-p (nucleus). + logits: [B, vocab] + temperature: scalar float tensor (runtime input). temperature <= 0 is + greedy: return argmax(logits) directly (matches the device, + which branches on temperature > 0). + top_p: scalar float tensor in (0, 1]. top_p=1.0 keeps every + token, i.e. it is off. + seed: scalar int tensor or None + - tensor -> deterministic, keyed RNG (random::key(seed)) + - None -> MLX global KeySequence (non-deterministic) + -> token_id: [B] int64 + + Host/CPU reference used for export (shape/meta) and distributional checks + only. It is NOT bit-identical to the lowered on-device graph: this uses torch + RNG (plain torch.rand, no uint32/nextafter uniform) while the delegate uses + MLX RNG, so a given seed does not reproduce the same tokens host vs. device. + """ + if float(temperature) <= 0: # matches the device cond (temperature > 0) + return torch.argmax(logits, dim=-1) + # whole chain in fp32 to match the lowered graph (bf16 sums mis-rank ties). + scaled = logits.float() / temperature + probs = torch.softmax(scaled, dim=-1) + s_probs, _ = torch.sort(probs, dim=-1, descending=True) + cum = torch.cumsum(s_probs, dim=-1) + keep = (cum - s_probs) <= top_p + thresh = torch.where(keep, s_probs, s_probs.new_tensor(float("inf"))).amin( + dim=-1, keepdim=True + ) + scaled = torch.where(probs >= thresh, scaled, scaled.new_tensor(float("-inf"))) + if seed is None: + u = torch.rand(scaled.shape) # global RNG + else: + gen = torch.Generator().manual_seed(int(seed.item())) + u = torch.rand(scaled.shape, generator=gen) + gumbel = -torch.log(-torch.log(u)) + return torch.argmax(scaled + gumbel, dim=-1) + + +@torch.library.register_fake("mlx::sample") +def sample_fake(logits, temperature, top_p, seed=None): + return logits.new_empty(logits.shape[:-1], dtype=torch.long) diff --git a/backends/mlx/llm/sampling.py b/backends/mlx/llm/sampling.py new file mode 100644 index 00000000000..a059e5cff08 --- /dev/null +++ b/backends/mlx/llm/sampling.py @@ -0,0 +1,40 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torch.nn as nn + + +class SamplingHead(nn.Module): + """ + Wraps a model that returns logits and samples a token id on-device. + + forward(*model_args, temperature, top_k=None, top_p=1.0, seed=None, + **model_kwargs) -> token_id + + temperature: scalar float tensor, e.g. torch.tensor(0.8). Must be >= 0; + temperature=0 is greedy (returns argmax, no division). + top_k: not implemented yet (reserved); must be None. + top_p: scalar float tensor in (0, 1] for nucleus sampling. top_p=1.0 + (the default) keeps every token, i.e. no filtering. Pass it + as a runtime input to tune per request. + seed: scalar int tensor (seeded) or None (unseeded export) + """ + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward(self, *args, temperature, top_k=None, top_p=1.0, seed=None, **kwargs): + if top_k is not None: + raise NotImplementedError("top_k sampling is not implemented") + logits = self.model(*args, **kwargs) # [B, S, vocab] + last = logits[:, -1, :] # [B, vocab] + if not isinstance(top_p, torch.Tensor): + top_p = torch.tensor(float(top_p)) + return torch.ops.mlx.sample(last, temperature, top_p, seed) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index e3a636466c1..00c45196d41 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -20,8 +20,10 @@ import torch from executorch.backends.mlx.builder.op_helpers import ( + emit_if_else, emit_lifted_constant, emit_quantized_biases, + emit_shape, parse_dequant_node, to_mlx_qparams, torch_dtype_to_scalar_type, @@ -115,6 +117,7 @@ PartitionNode, PowerNode, ProdNode, + RandomBitsNode, ReciprocalNode, RemainderNode, RepeatNode, @@ -3513,6 +3516,203 @@ def _argmax_handler(P: MLXProgramBuilder, n: Node) -> Slot: return out +@REGISTRY.register(target=[torch.ops.mlx.sample.default]) +def _sample_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Gumbel-max sampling: argmax(logits / temperature + gumbel_noise). + + Reproduces MLX's uniform -> gumbel -> argmax layering in the IR using the + new RandomBitsNode plus existing elementwise nodes, so a sampled token id is + produced on-device instead of returning the full logits tensor. + + temperature == 0 is greedy: an IfNode branches to a plain argmax(logits), + skipping the sampling chain (so 0 is exact, not the small-epsilon approx). + """ + args = P.args(n) + require_args(args, 3, 4, "mlx.sample") + require_kwargs(P.kwargs(n), set(), "mlx.sample") + logits, temperature, top_p = args[0], args[1], args[2] + seed = args[3] if len(args) > 3 and args[3] is not None else None + + temp_dt = n.args[1].meta["val"].dtype + out = P.make_or_get_slot(n) + + def emit_greedy(): + P.emit( + ArgmaxNode( + x=P.slot_to_tid(logits), + out=P.slot_to_tid(out), + axis=-1, + keepdims=False, + ) + ) + + def emit_sample(): + shape = emit_shape(P, n.args[0], logits) + + # Optional runtime seed: tensor -> SymInt (Vid) via ItemIntNode. Absent -> + # leave RandomBitsNode.seed unset (MLX global RNG). + seed_field = None + if seed is not None: + _, seed_val = P.make_tmp_value_slot() + P.emit(ItemIntNode(x=P.slot_to_tid(seed), out=P.slot_to_vid(seed_val))) + seed_field = P.slot_to_vid(seed_val) + + # uniform u in [0, 1): bits/uint32_max, clamped just below 1 (random.cpp:95) + _, bits = P.make_tmp_slot() + P.emit( + RandomBitsNode( + out=P.slot_to_tid(bits), shape=shape, width=4, seed=seed_field + ) + ) + _, bits_f = P.make_tmp_slot() + P.emit( + AsTypeNode( + x=P.slot_to_tid(bits), + out=P.slot_to_tid(bits_f), + scalar_type=torch_dtype_to_scalar_type(torch.float32), + ) + ) + umax = emit_lifted_constant(P, 4294967295.0, torch.float32) + _, div0 = P.make_tmp_slot() + P.emit( + DivideNode( + a=P.slot_to_tid(bits_f), b=P.slot_to_tid(umax), out=P.slot_to_tid(div0) + ) + ) + prev1 = emit_lifted_constant( + P, + float(torch.nextafter(torch.tensor(1.0), torch.tensor(0.0))), + torch.float32, + ) + _, clamp = P.make_tmp_slot() + P.emit( + MinimumNode( + a=P.slot_to_tid(div0), b=P.slot_to_tid(prev1), out=P.slot_to_tid(clamp) + ) + ) + # gumbel g = -log(-log(u)); whole chain stays fp32 (bf16 mis-ranks ties; clamp->1.0->+inf). + _, l1 = P.make_tmp_slot() + P.emit(LogNode(x=P.slot_to_tid(clamp), out=P.slot_to_tid(l1))) + _, g1 = P.make_tmp_slot() + P.emit(NegNode(x=P.slot_to_tid(l1), out=P.slot_to_tid(g1))) + _, l2 = P.make_tmp_slot() + P.emit(LogNode(x=P.slot_to_tid(g1), out=P.slot_to_tid(l2))) + _, g = P.make_tmp_slot() + P.emit(NegNode(x=P.slot_to_tid(l2), out=P.slot_to_tid(g))) + + # sample: argmax(logits / temperature + g) over the vocab axis, in float32 + _, logits_f = P.make_tmp_slot() + P.emit( + AsTypeNode( + x=P.slot_to_tid(logits), + out=P.slot_to_tid(logits_f), + scalar_type=torch_dtype_to_scalar_type(torch.float32), + ) + ) + _, scaled = P.make_tmp_slot() + P.emit( + DivideNode( + a=P.slot_to_tid(logits_f), + b=P.slot_to_tid(temperature), + out=P.slot_to_tid(scaled), + ) + ) + + # top-p nucleus mask; SortNode is ascending-only, so sort -probs for descending. + _, probs = P.make_tmp_slot() + P.emit(SoftmaxNode(x=P.slot_to_tid(scaled), out=P.slot_to_tid(probs), axis=-1)) + _, neg_p = P.make_tmp_slot() + P.emit(NegNode(x=P.slot_to_tid(probs), out=P.slot_to_tid(neg_p))) + _, sorted_neg = P.make_tmp_slot() + P.emit(SortNode(x=P.slot_to_tid(neg_p), out=P.slot_to_tid(sorted_neg), axis=-1)) + _, sorted_p = P.make_tmp_slot() + P.emit(NegNode(x=P.slot_to_tid(sorted_neg), out=P.slot_to_tid(sorted_p))) + _, cum = P.make_tmp_slot() + P.emit(CumsumNode(x=P.slot_to_tid(sorted_p), out=P.slot_to_tid(cum), axis=-1)) + _, prefix = P.make_tmp_slot() + P.emit( + SubtractNode( + a=P.slot_to_tid(cum), + b=P.slot_to_tid(sorted_p), + out=P.slot_to_tid(prefix), + ) + ) + # remove sorted tokens whose prefix mass already exceeds top_p (top-1: 0) + _, remove = P.make_tmp_slot() + P.emit( + GreaterNode( + a=P.slot_to_tid(prefix), + b=P.slot_to_tid(top_p), + out=P.slot_to_tid(remove), + ) + ) + pos_inf = emit_lifted_constant(P, float("inf"), torch.float32) + _, kept = P.make_tmp_slot() + P.emit( + WhereNode( + condition=P.slot_to_tid(remove), + x=P.slot_to_tid(pos_inf), + y=P.slot_to_tid(sorted_p), + out=P.slot_to_tid(kept), + ) + ) + # threshold = smallest kept probability (per row) + _, thresh = P.make_tmp_slot() + P.emit( + MinNode( + x=P.slot_to_tid(kept), + out=P.slot_to_tid(thresh), + axes=[-1], + keepdims=True, + ) + ) + _, drop = P.make_tmp_slot() + P.emit( + LessNode( + a=P.slot_to_tid(probs), + b=P.slot_to_tid(thresh), + out=P.slot_to_tid(drop), + ) + ) + neg_inf = emit_lifted_constant(P, float("-inf"), torch.float32) + _, masked = P.make_tmp_slot() + P.emit( + WhereNode( + condition=P.slot_to_tid(drop), + x=P.slot_to_tid(neg_inf), + y=P.slot_to_tid(scaled), + out=P.slot_to_tid(masked), + ) + ) + + _, noisy = P.make_tmp_slot() + P.emit( + AddNode( + a=P.slot_to_tid(masked), b=P.slot_to_tid(g), out=P.slot_to_tid(noisy) + ) + ) + P.emit( + ArgmaxNode( + x=P.slot_to_tid(noisy), out=P.slot_to_tid(out), axis=-1, keepdims=False + ) + ) + + # temperature == 0 -> greedy: IfNode branches to argmax(logits), skipping sampling. + zero = emit_lifted_constant(P, 0.0, temp_dt) + _, is_sampling = P.make_tmp_slot() + P.emit( + GreaterNode( + a=P.slot_to_tid(temperature), + b=P.slot_to_tid(zero), + out=P.slot_to_tid(is_sampling), + ) + ) + _, cond_val = P.make_tmp_value_slot() + P.emit(ItemIntNode(x=P.slot_to_tid(is_sampling), out=P.slot_to_vid(cond_val))) + emit_if_else(P, P.to_int_or_vid(cond_val), emit_sample, emit_greedy) + return out + + @REGISTRY.register(target=[torch.ops.aten.argmin.default]) def _argmin_handler(P: MLXProgramBuilder, n: Node) -> Slot: """Handle aten.argmin - index of min element along axis.""" diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index 8563ff339a7..3c3c2c323a8 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -1695,6 +1695,26 @@ exec_argmax(const ArgmaxNode& n, ExecutionState& st, StreamOrDevice s) { st.set_tensor(n.out, argmax(x, n.axis, n.keepdims, s)); } +inline void exec_random_bits( + const RandomBitsNode& n, + ExecutionState& st, + StreamOrDevice s) { + // random::bits supports width (bytes/element) in {1, 2, 4} -> + // uint8/uint16/uint32. + if (n.width != 1 && n.width != 2 && n.width != 4) { + throw std::runtime_error("random_bits: width must be 1, 2, or 4"); + } + auto shape = to_shape(n.shape, st); + // uint32 (4 bytes, the widest supported) is a safe upper bound for the guard. + check_allocation_bounded(shape, uint32, "random_bits"); + std::optional key = std::nullopt; + if (n.seed.has_value()) { + key = random::key( + static_cast(st.const_value_ref(n.seed.value()))); + } + st.set_tensor(n.out, random::bits(shape, n.width, key, s)); +} + inline void exec_argmin(const ArgminNode& n, ExecutionState& st, StreamOrDevice s) { const auto& x = st.const_tensor_ref(n.x); @@ -2057,6 +2077,9 @@ class Interpreter { case OpCode::ARGMAX: ops::exec_argmax(std::get(instr.node), st, s); break; + case OpCode::RANDOM_BITS: + ops::exec_random_bits(std::get(instr.node), st, s); + break; case OpCode::SLICE_UPDATE: ops::exec_slice_update(std::get(instr.node), st, s); break; diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index 42c53e5172b..281199a8002 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -985,6 +985,14 @@ table IfNode { else_chain_idx: uint32; // index into MLXGraph.instruction_chains } +table RandomBitsNode { + out: Tid (required); + shape: [IntOrVid] (required); + seed: Vid; // OPTIONAL: present -> random::key(seed); + // absent -> MLX global KeySequence + width: int32 = 4; // bytes per element (4 -> uint32) +} + // Custom Metal kernel execution via mlx::core::fast::metal_kernel(). // Two-phase API: // 1. Factory: metal_kernel(name, input_names, output_names, source, header, @@ -1161,7 +1169,8 @@ union OpNode { BitwiseAndNode, BitwiseOrNode, BitwiseXorNode, - IfNode + IfNode, + RandomBitsNode // BC: Add new op nodes here (append only) } diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index e96c8075903..afd4f276dde 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -35,6 +35,7 @@ custom_ops, ops, ) +from executorch.backends.mlx.llm.sampling import SamplingHead from torch.export import Dim from .test_utils import OpTestCase, register_test @@ -7633,3 +7634,167 @@ def create_model(self) -> nn.Module: def create_inputs(self) -> Tuple[torch.Tensor, ...]: x = torch.randint(0, self.num_embeddings, (self.batch_size, self.seq_len)) return (x,) + + +class _LogitsPassthrough(nn.Module): + """Stand-in for a model returning logits [B, S, vocab].""" + + def forward(self, logits: torch.Tensor) -> torch.Tensor: + return logits + + +class SeededSampleModel(nn.Module): + """SamplingHead with temperature AND seed as runtime forward inputs.""" + + def __init__(self): + super().__init__() + self.head = SamplingHead(_LogitsPassthrough()) + + def forward(self, logits, temperature, seed): + return self.head(logits, temperature=temperature, seed=seed) + + +class UnseededSampleModel(nn.Module): + """SamplingHead with temperature as a runtime input and no seed.""" + + def __init__(self): + super().__init__() + self.head = SamplingHead(_LogitsPassthrough()) + + def forward(self, logits, temperature): + return self.head(logits, temperature=temperature) + + +class TopPSampleModel(nn.Module): + """SamplingHead with temperature, seed, and top_p as runtime inputs.""" + + def __init__(self): + super().__init__() + self.head = SamplingHead(_LogitsPassthrough()) + + def forward(self, logits, temperature, seed, top_p): + return self.head(logits, temperature=temperature, seed=seed, top_p=top_p) + + +@register_test +class SampleSeededTest(OpTestCase): + """Seeded sample lowers to one MLX segment; seed threads in via ItemIntNode.""" + + name = "sample_seeded" + skip_comparison = True # sampling RNG is not host/device bit-identical + expected_node_counts = { + "IfNode": 1, # temperature==0 greedy branch + "RandomBitsNode": 1, + "ArgmaxNode": 2, # sampling branch + greedy branch + "ItemIntNode": 2, # seed + temperature>0 condition + "SoftmaxNode": 1, # top-p nucleus chain + "SortNode": 1, + "CumsumNode": 1, + "MinNode": 1, + "WhereNode": 2, + } + + def create_model(self) -> nn.Module: + return SeededSampleModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return ( + torch.randn(1, 4, 256), + torch.tensor(0.8), + torch.tensor(0, dtype=torch.int64), + ) + + +@register_test +class SampleUnseededTest(OpTestCase): + """Unseeded sample lowers without a seed field (only the cond ItemIntNode).""" + + name = "sample_unseeded" + skip_comparison = True # sampling RNG is not host/device bit-identical + expected_node_counts = { + "IfNode": 1, + "RandomBitsNode": 1, + "ArgmaxNode": 2, + "ItemIntNode": 1, # temperature>0 condition only (no seed) + "SoftmaxNode": 1, # top-p nucleus chain (top_p defaults to 1.0) + } + + def create_model(self) -> nn.Module: + return UnseededSampleModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(1, 4, 256), torch.tensor(0.8)) + + +@register_test +class SampleTopPTest(OpTestCase): + """Top-p sample emits the nucleus chain (softmax/sort/cumsum/min/where).""" + + name = "sample_top_p" + skip_comparison = True # sampling RNG is not host/device bit-identical + expected_node_counts = { + "IfNode": 1, + "RandomBitsNode": 1, + "ArgmaxNode": 2, + "ItemIntNode": 2, + "SoftmaxNode": 1, + "SortNode": 1, + "CumsumNode": 1, + "MinNode": 1, + "WhereNode": 2, + } + + def create_model(self) -> nn.Module: + return TopPSampleModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return ( + torch.randn(1, 4, 256), + torch.tensor(0.8), + torch.tensor(0, dtype=torch.int64), + torch.tensor(0.9), + ) + + +@register_test +class SampleGreedyTest(OpTestCase): + """Greedy argmax(logits) is bit-exact host/device, so verify the token with the + normal compare. Covers temperature=0, tiny temperature, bf16 logits, and a + batch (per-row argmax -> [B] on device).""" + + name = "sample_greedy" + + def __init__( + self, + temperature: float = 0.0, + dtype: torch.dtype = torch.float32, + batch: int = 1, + tag: str = "", + ): + self.temperature = temperature + self.dtype = dtype + self.batch = batch + self.name = f"sample_greedy_{tag}" if tag else "sample_greedy" + + @classmethod + def get_test_configs(cls) -> List["SampleGreedyTest"]: + return [ + cls(temperature=0.0), + cls(temperature=1e-4, tag="eps"), + cls(temperature=-1.0, tag="neg"), # negative -> greedy on both paths + cls(temperature=1e-4, dtype=torch.bfloat16, tag="bf16"), + cls(temperature=0.0, batch=4, tag="batch"), # per-row argmax over a batch + ] + + def create_model(self) -> nn.Module: + return SeededSampleModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + logits = torch.randn(self.batch, 4, 1024, dtype=self.dtype) + if self.dtype == torch.bfloat16: + logits[0, -1, 512] = 50.0 # dominant -> unambiguous bf16 argmax + return ( + logits, + torch.tensor(self.temperature), + torch.tensor(0, dtype=torch.int64), + ) diff --git a/backends/mlx/test/test_sample.py b/backends/mlx/test/test_sample.py new file mode 100644 index 00000000000..ddb0734d13b --- /dev/null +++ b/backends/mlx/test/test_sample.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Export and on-device tests for mlx::sample (Gumbel-max token sampling). + +The end-to-end cases run the exported program through the compiled +op_test_runner (see backends/mlx/test/README.md). + +Usage: + python -m unittest executorch.backends.mlx.test.test_sample +""" + +import shutil +import tempfile +import unittest +from pathlib import Path +from typing import Optional + +# Registers torch.ops.mlx.sample. +import executorch.backends.mlx.custom_ops # noqa: F401 +import torch +import torch.nn as nn +from executorch.backends.mlx.llm.sampling import SamplingHead +from executorch.backends.mlx.test.test_utils import ( + export_model_to_pte, + load_tensors_from_bin, + run_cpp_test_runner, + save_tensors_to_bin, +) + + +class _LogitsPassthrough(nn.Module): + """Stand-in for a model returning logits [B, S, vocab].""" + + def forward(self, logits: torch.Tensor) -> torch.Tensor: + return logits + + +class SeededSampleModel(nn.Module): + """SamplingHead with temperature AND seed as runtime forward inputs.""" + + def __init__(self): + super().__init__() + self.head = SamplingHead(_LogitsPassthrough()) + + def forward(self, logits, temperature, seed): + return self.head(logits, temperature=temperature, seed=seed) + + +class TopPSampleModel(nn.Module): + """SamplingHead with temperature, seed, and top_p as runtime inputs.""" + + def __init__(self): + super().__init__() + self.head = SamplingHead(_LogitsPassthrough()) + + def forward(self, logits, temperature, seed, top_p): + return self.head(logits, temperature=temperature, seed=seed, top_p=top_p) + + +def _ref_gumbel_max(logits: torch.Tensor, temperature: float, seed: int): + """Independent Gumbel-max reference using the same torch RNG as the op.""" + gen = torch.Generator().manual_seed(seed) + u = torch.rand(logits.shape, generator=gen) + gumbel = -torch.log(-torch.log(u)) + return torch.argmax(logits / temperature + gumbel, dim=-1) + + +def _tv_distance(p: torch.Tensor, q: torch.Tensor) -> float: + """Total-variation distance between two discrete distributions.""" + return 0.5 * torch.abs(p - q).sum().item() + + +def _sample(logits, temperature, seed: Optional[int], top_p: float = 1.0): + t = torch.tensor(float(temperature)) + s = None if seed is None else torch.tensor(int(seed), dtype=torch.int64) + p = torch.tensor(float(top_p)) # 1.0 = off + return torch.ops.mlx.sample(logits, t, p, s) + + +class TestSampleOp(unittest.TestCase): + """Eager reference behavior of mlx::sample (no export / no runtime).""" + + def test_matches_independent_gumbel_reference(self): + # Same seed -> bit-identical token vs an independent Gumbel-max impl. + torch.manual_seed(1) + logits = torch.randn(8, 512) + for seed in (0, 1, 7, 42): + got = _sample(logits, 0.8, seed=seed) + expected = _ref_gumbel_max(logits, 0.8, seed) + self.assertTrue(torch.equal(got, expected), f"mismatch at seed={seed}") + + def test_distribution_matches_softmax(self): + # Empirical token frequencies match softmax(logits / T). + vocab = 5 + temperature = 1.0 + torch.manual_seed(0) + base = torch.randn(vocab) + n = 20000 + tokens = _sample(base.expand(n, vocab), temperature, seed=0) + + empirical = torch.bincount(tokens, minlength=vocab).float() / n + target = torch.softmax(base / temperature, dim=-1) + tv = _tv_distance(empirical, target) + self.assertLess(tv, 0.05, f"TV distance {tv:.4f} too large") + + def test_determinism_seeded(self): + # Same seed -> identical draws; different seed -> different draws. + torch.manual_seed(0) + logits = torch.randn(256, 64) + a = _sample(logits, 1.0, seed=123) + b = _sample(logits, 1.0, seed=123) + c = _sample(logits, 1.0, seed=124) + self.assertTrue(torch.equal(a, b)) + self.assertFalse(torch.equal(a, c)) + + def test_unseeded_varies_across_calls(self): + # seed=None uses the global RNG -> draws vary, tokens stay in range. + torch.manual_seed(0) + logits = torch.randn(256, 64) + a = _sample(logits, 1.0, seed=None) + b = _sample(logits, 1.0, seed=None) + self.assertFalse(torch.equal(a, b)) + self.assertGreaterEqual(int(a.min()), 0) + self.assertLess(int(a.max()), 64) + + def test_top_p_restricts_to_nucleus(self): + # probs [0.5, 0.3, 0.15, 0.05]; top_p=0.9 keeps {0,1,2}, drops index 3. + base = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05])) + tokens = _sample(base.expand(5000, 4), 1.0, seed=0, top_p=0.9) + self.assertTrue((tokens != 3).all()) # tail token never drawn + self.assertEqual(set(tokens.tolist()), {0, 1, 2}) # nucleus covered + + def test_top_p_one_keeps_all(self): + # top_p=1.0 -> no filtering; the tail token (index 3) is reachable. + base = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05])) + tokens = _sample(base.expand(20000, 4), 1.0, seed=0, top_p=1.0) + self.assertTrue((tokens == 3).any()) + + +class TestSampleExport(unittest.TestCase): + """Runtime-input semantics that survive export: temperature and seed stay + live graph inputs (not constant-folded). Lowering/partition is covered by the + OpTestCase classes in test_ops.py.""" + + def test_runtime_temperature_single_export(self): + # One exported program run at two temperatures (no re-export) confirms + # temperature is a live graph input: small T is near-greedy, large T + # spreads draws. A fixed logits row is broadcast so the spread of tokens + # reflects the sampling entropy. + vocab = 50 + batch = 256 + torch.manual_seed(0) + row = torch.randn(vocab) + logits = row.expand(batch, 1, vocab).contiguous() # [B, S=1, vocab] + seed = torch.tensor(0, dtype=torch.int64) + + run = torch.export.export( + SeededSampleModel(), (logits, torch.tensor(1.0), seed), strict=True + ).module() + + cold = run(logits, torch.tensor(1e-4), seed) + hot = run(logits, torch.tensor(100.0), seed) + + self.assertTrue(torch.all(cold == int(torch.argmax(row)))) + self.assertEqual(cold.unique().numel(), 1) + self.assertGreater(hot.unique().numel(), 10) + + def test_seeded_export_reproducible_no_host_rng(self): + # Seeded export: same seed -> identical tokens across runs of one + # exported program, independent of host RNG state (the seed is a graph + # input, not host-side stateful RNG). Different seed -> different draws. + torch.manual_seed(0) + logits = torch.randn(128, 1, 64) + seed = torch.tensor(123, dtype=torch.int64) + + run = torch.export.export( + SeededSampleModel(), (logits, torch.tensor(1.0), seed), strict=True + ).module() + + first = run(logits, torch.tensor(1.0), seed) + # Perturb the host global RNG between runs; a seeded draw is unaffected. + _ = torch.rand(1024) + second = run(logits, torch.tensor(1.0), seed) + self.assertTrue(torch.equal(first, second)) + + other = run(logits, torch.tensor(1.0), torch.tensor(124, dtype=torch.int64)) + self.assertFalse(torch.equal(first, other)) + + +class TestSampleEndToEnd(unittest.TestCase): + """On-device checks whose assertions the output-compare harness can't express.""" + + def setUp(self): + self._tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self._tmp, ignore_errors=True) + + def test_top_p_end_to_end(self): + # On-device nucleus: probs [0.5,0.3,0.15,0.05], top_p=0.9 -> token in {0,1,2}. + logits = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05])).view(1, 1, 4) + inputs = ( + logits, + torch.tensor(1.0), + torch.tensor(0, dtype=torch.int64), + torch.tensor(0.9), + ) + tmp = Path(self._tmp) + pte, in_bin, out_bin = tmp / "topp.pte", tmp / "in.bin", tmp / "out.bin" + export_model_to_pte(TopPSampleModel(), inputs, pte) + save_tensors_to_bin(list(inputs), in_bin) + + self.assertTrue(run_cpp_test_runner(pte, in_bin, out_bin)) + (token,) = load_tensors_from_bin(out_bin) + self.assertIn(int(token), {0, 1, 2}) # tail token (index 3) excluded + + +if __name__ == "__main__": + unittest.main()