Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/mlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ jobs:
backends/mlx/test/test_partitioner.py \
backends/mlx/test/test_serialization_dedup.py \
backends/mlx/test/test_slot_recycling.py \
backends/mlx/test/test_sample.py \
examples/models/gemma4_31b/quant/tests/test_pack_mlx.py \
examples/models/gemma4_31b/tests/test_mlx_pipeline.py \
-v
Expand Down
51 changes: 51 additions & 0 deletions backends/mlx/custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,54 @@ def gather_qmm_fake(
else:
batch = w.shape[:-2]
return x.new_empty((*batch, M, N))


@torch.library.custom_op("mlx::sample", mutates_args=())
def sample(
logits: Tensor,
temperature: Tensor,
top_p: Tensor,
seed: Optional[Tensor] = None,
) -> Tensor:
"""
Gumbel-max sampling from softmax(logits / temperature), with top-p (nucleus).
logits: [B, vocab]
temperature: scalar float tensor (runtime input). temperature <= 0 is
greedy: return argmax(logits) directly (matches the device,
which branches on temperature > 0).
top_p: scalar float tensor in (0, 1]. top_p=1.0 keeps every
token, i.e. it is off.
seed: scalar int tensor or None
- tensor -> deterministic, keyed RNG (random::key(seed))
- None -> MLX global KeySequence (non-deterministic)
-> token_id: [B] int64

Host/CPU reference used for export (shape/meta) and distributional checks
only. It is NOT bit-identical to the lowered on-device graph: this uses torch
RNG (plain torch.rand, no uint32/nextafter uniform) while the delegate uses
MLX RNG, so a given seed does not reproduce the same tokens host vs. device.
"""
if float(temperature) <= 0: # matches the device cond (temperature > 0)
return torch.argmax(logits, dim=-1)
# whole chain in fp32 to match the lowered graph (bf16 sums mis-rank ties).
scaled = logits.float() / temperature
probs = torch.softmax(scaled, dim=-1)
s_probs, _ = torch.sort(probs, dim=-1, descending=True)
cum = torch.cumsum(s_probs, dim=-1)
keep = (cum - s_probs) <= top_p
thresh = torch.where(keep, s_probs, s_probs.new_tensor(float("inf"))).amin(
dim=-1, keepdim=True
)
scaled = torch.where(probs >= thresh, scaled, scaled.new_tensor(float("-inf")))
if seed is None:
u = torch.rand(scaled.shape) # global RNG
else:
gen = torch.Generator().manual_seed(int(seed.item()))
u = torch.rand(scaled.shape, generator=gen)
gumbel = -torch.log(-torch.log(u))
return torch.argmax(scaled + gumbel, dim=-1)


@torch.library.register_fake("mlx::sample")
def sample_fake(logits, temperature, top_p, seed=None):
return logits.new_empty(logits.shape[:-1], dtype=torch.long)
40 changes: 40 additions & 0 deletions backends/mlx/llm/sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#

import torch
import torch.nn as nn


class SamplingHead(nn.Module):
"""
Wraps a model that returns logits and samples a token id on-device.

forward(*model_args, temperature, top_k=None, top_p=1.0, seed=None,
**model_kwargs) -> token_id

temperature: scalar float tensor, e.g. torch.tensor(0.8). Must be >= 0;
temperature=0 is greedy (returns argmax, no division).
top_k: not implemented yet (reserved); must be None.
top_p: scalar float tensor in (0, 1] for nucleus sampling. top_p=1.0
(the default) keeps every token, i.e. no filtering. Pass it
as a runtime input to tune per request.
seed: scalar int tensor (seeded) or None (unseeded export)
"""

def __init__(self, model: nn.Module):
super().__init__()
self.model = model

def forward(self, *args, temperature, top_k=None, top_p=1.0, seed=None, **kwargs):
if top_k is not None:
raise NotImplementedError("top_k sampling is not implemented")
logits = self.model(*args, **kwargs) # [B, S, vocab]
last = logits[:, -1, :] # [B, vocab]
if not isinstance(top_p, torch.Tensor):
top_p = torch.tensor(float(top_p))
return torch.ops.mlx.sample(last, temperature, top_p, seed)
200 changes: 200 additions & 0 deletions backends/mlx/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@

import torch
from executorch.backends.mlx.builder.op_helpers import (
emit_if_else,
emit_lifted_constant,
emit_quantized_biases,
emit_shape,
parse_dequant_node,
to_mlx_qparams,
torch_dtype_to_scalar_type,
Expand Down Expand Up @@ -115,6 +117,7 @@
PartitionNode,
PowerNode,
ProdNode,
RandomBitsNode,
ReciprocalNode,
RemainderNode,
RepeatNode,
Expand Down Expand Up @@ -3513,6 +3516,203 @@ def _argmax_handler(P: MLXProgramBuilder, n: Node) -> Slot:
return out


@REGISTRY.register(target=[torch.ops.mlx.sample.default])
def _sample_handler(P: MLXProgramBuilder, n: Node) -> Slot:
"""Gumbel-max sampling: argmax(logits / temperature + gumbel_noise).

Reproduces MLX's uniform -> gumbel -> argmax layering in the IR using the
new RandomBitsNode plus existing elementwise nodes, so a sampled token id is
produced on-device instead of returning the full logits tensor.

temperature == 0 is greedy: an IfNode branches to a plain argmax(logits),
skipping the sampling chain (so 0 is exact, not the small-epsilon approx).
"""
args = P.args(n)
require_args(args, 3, 4, "mlx.sample")
require_kwargs(P.kwargs(n), set(), "mlx.sample")
logits, temperature, top_p = args[0], args[1], args[2]
seed = args[3] if len(args) > 3 and args[3] is not None else None

temp_dt = n.args[1].meta["val"].dtype
out = P.make_or_get_slot(n)

def emit_greedy():
P.emit(
ArgmaxNode(
x=P.slot_to_tid(logits),
out=P.slot_to_tid(out),
axis=-1,
keepdims=False,
)
)

def emit_sample():
shape = emit_shape(P, n.args[0], logits)

# Optional runtime seed: tensor -> SymInt (Vid) via ItemIntNode. Absent ->
# leave RandomBitsNode.seed unset (MLX global RNG).
seed_field = None
if seed is not None:
_, seed_val = P.make_tmp_value_slot()
P.emit(ItemIntNode(x=P.slot_to_tid(seed), out=P.slot_to_vid(seed_val)))
seed_field = P.slot_to_vid(seed_val)

# uniform u in [0, 1): bits/uint32_max, clamped just below 1 (random.cpp:95)
_, bits = P.make_tmp_slot()
P.emit(
RandomBitsNode(
out=P.slot_to_tid(bits), shape=shape, width=4, seed=seed_field
)
)
_, bits_f = P.make_tmp_slot()
P.emit(
AsTypeNode(
x=P.slot_to_tid(bits),
out=P.slot_to_tid(bits_f),
scalar_type=torch_dtype_to_scalar_type(torch.float32),
)
)
umax = emit_lifted_constant(P, 4294967295.0, torch.float32)
_, div0 = P.make_tmp_slot()
P.emit(
DivideNode(
a=P.slot_to_tid(bits_f), b=P.slot_to_tid(umax), out=P.slot_to_tid(div0)
)
)
prev1 = emit_lifted_constant(
P,
float(torch.nextafter(torch.tensor(1.0), torch.tensor(0.0))),
torch.float32,
)
_, clamp = P.make_tmp_slot()
P.emit(
MinimumNode(
a=P.slot_to_tid(div0), b=P.slot_to_tid(prev1), out=P.slot_to_tid(clamp)
)
)
# gumbel g = -log(-log(u)); whole chain stays fp32 (bf16 mis-ranks ties; clamp->1.0->+inf).
_, l1 = P.make_tmp_slot()
P.emit(LogNode(x=P.slot_to_tid(clamp), out=P.slot_to_tid(l1)))
_, g1 = P.make_tmp_slot()
P.emit(NegNode(x=P.slot_to_tid(l1), out=P.slot_to_tid(g1)))
_, l2 = P.make_tmp_slot()
P.emit(LogNode(x=P.slot_to_tid(g1), out=P.slot_to_tid(l2)))
_, g = P.make_tmp_slot()
P.emit(NegNode(x=P.slot_to_tid(l2), out=P.slot_to_tid(g)))

# sample: argmax(logits / temperature + g) over the vocab axis, in float32
_, logits_f = P.make_tmp_slot()
P.emit(
AsTypeNode(
x=P.slot_to_tid(logits),
out=P.slot_to_tid(logits_f),
scalar_type=torch_dtype_to_scalar_type(torch.float32),
)
)
_, scaled = P.make_tmp_slot()
P.emit(
DivideNode(
a=P.slot_to_tid(logits_f),
b=P.slot_to_tid(temperature),
out=P.slot_to_tid(scaled),
)
)

# top-p nucleus mask; SortNode is ascending-only, so sort -probs for descending.
_, probs = P.make_tmp_slot()
P.emit(SoftmaxNode(x=P.slot_to_tid(scaled), out=P.slot_to_tid(probs), axis=-1))
_, neg_p = P.make_tmp_slot()
P.emit(NegNode(x=P.slot_to_tid(probs), out=P.slot_to_tid(neg_p)))
_, sorted_neg = P.make_tmp_slot()
P.emit(SortNode(x=P.slot_to_tid(neg_p), out=P.slot_to_tid(sorted_neg), axis=-1))
_, sorted_p = P.make_tmp_slot()
P.emit(NegNode(x=P.slot_to_tid(sorted_neg), out=P.slot_to_tid(sorted_p)))
_, cum = P.make_tmp_slot()
P.emit(CumsumNode(x=P.slot_to_tid(sorted_p), out=P.slot_to_tid(cum), axis=-1))
_, prefix = P.make_tmp_slot()
P.emit(
SubtractNode(
a=P.slot_to_tid(cum),
b=P.slot_to_tid(sorted_p),
out=P.slot_to_tid(prefix),
)
)
# remove sorted tokens whose prefix mass already exceeds top_p (top-1: 0)
_, remove = P.make_tmp_slot()
P.emit(
GreaterNode(
a=P.slot_to_tid(prefix),
b=P.slot_to_tid(top_p),
out=P.slot_to_tid(remove),
)
)
pos_inf = emit_lifted_constant(P, float("inf"), torch.float32)
_, kept = P.make_tmp_slot()
P.emit(
WhereNode(
condition=P.slot_to_tid(remove),
x=P.slot_to_tid(pos_inf),
y=P.slot_to_tid(sorted_p),
out=P.slot_to_tid(kept),
)
)
# threshold = smallest kept probability (per row)
_, thresh = P.make_tmp_slot()
P.emit(
MinNode(
x=P.slot_to_tid(kept),
out=P.slot_to_tid(thresh),
axes=[-1],
keepdims=True,
)
)
_, drop = P.make_tmp_slot()
P.emit(
LessNode(
a=P.slot_to_tid(probs),
b=P.slot_to_tid(thresh),
out=P.slot_to_tid(drop),
)
)
neg_inf = emit_lifted_constant(P, float("-inf"), torch.float32)
_, masked = P.make_tmp_slot()
P.emit(
WhereNode(
condition=P.slot_to_tid(drop),
x=P.slot_to_tid(neg_inf),
y=P.slot_to_tid(scaled),
out=P.slot_to_tid(masked),
)
)

_, noisy = P.make_tmp_slot()
P.emit(
AddNode(
a=P.slot_to_tid(masked), b=P.slot_to_tid(g), out=P.slot_to_tid(noisy)
)
)
P.emit(
ArgmaxNode(
x=P.slot_to_tid(noisy), out=P.slot_to_tid(out), axis=-1, keepdims=False
)
)

# temperature == 0 -> greedy: IfNode branches to argmax(logits), skipping sampling.
zero = emit_lifted_constant(P, 0.0, temp_dt)
_, is_sampling = P.make_tmp_slot()
P.emit(
GreaterNode(
a=P.slot_to_tid(temperature),
b=P.slot_to_tid(zero),
out=P.slot_to_tid(is_sampling),
)
)
_, cond_val = P.make_tmp_value_slot()
P.emit(ItemIntNode(x=P.slot_to_tid(is_sampling), out=P.slot_to_vid(cond_val)))
emit_if_else(P, P.to_int_or_vid(cond_val), emit_sample, emit_greedy)
return out


@REGISTRY.register(target=[torch.ops.aten.argmin.default])
def _argmin_handler(P: MLXProgramBuilder, n: Node) -> Slot:
"""Handle aten.argmin - index of min element along axis."""
Expand Down
23 changes: 23 additions & 0 deletions backends/mlx/runtime/MLXInterpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1695,6 +1695,26 @@ exec_argmax(const ArgmaxNode& n, ExecutionState& st, StreamOrDevice s) {
st.set_tensor(n.out, argmax(x, n.axis, n.keepdims, s));
}

inline void exec_random_bits(
const RandomBitsNode& n,
ExecutionState& st,
StreamOrDevice s) {
// random::bits supports width (bytes/element) in {1, 2, 4} ->
// uint8/uint16/uint32.
if (n.width != 1 && n.width != 2 && n.width != 4) {
throw std::runtime_error("random_bits: width must be 1, 2, or 4");
}
auto shape = to_shape(n.shape, st);
// uint32 (4 bytes, the widest supported) is a safe upper bound for the guard.
check_allocation_bounded(shape, uint32, "random_bits");
Comment thread
metascroy marked this conversation as resolved.
std::optional<array> key = std::nullopt;
if (n.seed.has_value()) {
key = random::key(
static_cast<uint64_t>(st.const_value_ref<int32_t>(n.seed.value())));
}
st.set_tensor(n.out, random::bits(shape, n.width, key, s));
}

inline void
exec_argmin(const ArgminNode& n, ExecutionState& st, StreamOrDevice s) {
const auto& x = st.const_tensor_ref(n.x);
Expand Down Expand Up @@ -2057,6 +2077,9 @@ class Interpreter {
case OpCode::ARGMAX:
ops::exec_argmax(std::get<ArgmaxNode>(instr.node), st, s);
break;
case OpCode::RANDOM_BITS:
ops::exec_random_bits(std::get<RandomBitsNode>(instr.node), st, s);
break;
case OpCode::SLICE_UPDATE:
ops::exec_slice_update(std::get<SliceUpdateNode>(instr.node), st, s);
break;
Expand Down
11 changes: 10 additions & 1 deletion backends/mlx/serialization/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,14 @@ table IfNode {
else_chain_idx: uint32; // index into MLXGraph.instruction_chains
}

table RandomBitsNode {
out: Tid (required);
shape: [IntOrVid] (required);
seed: Vid; // OPTIONAL: present -> random::key(seed);
// absent -> MLX global KeySequence
width: int32 = 4; // bytes per element (4 -> uint32)
}

// Custom Metal kernel execution via mlx::core::fast::metal_kernel().
// Two-phase API:
// 1. Factory: metal_kernel(name, input_names, output_names, source, header,
Expand Down Expand Up @@ -1161,7 +1169,8 @@ union OpNode {
BitwiseAndNode,
BitwiseOrNode,
BitwiseXorNode,
IfNode
IfNode,
RandomBitsNode
// BC: Add new op nodes here (append only)
}

Expand Down
Loading
Loading