Apple Silicon support: arm64 build fixes + MPS device for torch backend#609
Open
oceanapplications wants to merge 4 commits into
Open
Apple Silicon support: arm64 build fixes + MPS device for torch backend#609oceanapplications wants to merge 4 commits into
oceanapplications wants to merge 4 commits into
Conversation
- build.sh: gate -mavx2/-mfma on x86_64 (arm64 clang rejects them) - build.sh: macOS OpenMP via -Xpreprocessor -fopenmp with Homebrew's omp.h, linked against torch's bundled libomp.dylib. Linking a second OpenMP runtime (e.g. Homebrew's) into a process that imports torch aborts at startup or segfaults in the vecenv's parallel regions. - torch_pufferl: select mps when available and _C has no CUDA; move actions to host memory before cpu_step; round-trip advantage computation through CPU for non-CUDA accelerators - pufferl: fall back to the torch backend automatically when _C was built with --cpu, so 'puffer train env' works without --slowly Verified on M-series (breakout): 605K SPS on MPS vs 190K on CPU. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_019AsyRcLQqeJondzSTM6xsn
…ckend Shipped configs contain sweep-produced floats (e.g. num_layers = 2.11327 in cartpole.ini). The native backend truncates them on assignment to C ints; the torch backend passed them straight to nn.Linear and crashed with 'float object cannot be interpreted as an integer'. Fixes the torch (--slowly) backend for 17 of the 38 currently-buildable ocean envs, on all platforms. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_019AsyRcLQqeJondzSTM6xsn
Author
ocean env matrix on Apple Silicon (arm64, --cpu build, MPS training)37/38 buildable envs train one full epoch on MPS. 20 build failures: 17 are broken on all platforms on current master (missing ocean/env_binding.h or OBS_TENSOR_T defines), 1 is x86-only SIMD (craftax_classic), and matsci/nethack/impulse_wars need external deps or unvendored headers.
|
The MPS multinomial kernel can intermittently return indices outside [0, num_categories) (pytorch#136623, still unfixed upstream; the fix PR pytorch#170195 was closed unmerged). In long MPS training runs this surfaced as an intermittent 'AcceleratorError: index N is out of bounds' raised at the next sync point — the .cpu() transfer in compute_puff_advantage — with N ~ 2x total_agents, because the bad index from the prioritized-replay multinomial feeds lazily-queued gathers (obs[idx]) and scatters (ratio[idx], val[idx]) that only validate at materialization. Clamp multinomial output on MPS at both call sites: minibatch segment sampling and action sampling (where an out-of-range action would be memcpy'd into the C envs and corrupt memory silently instead of raising). Cost is one elementwise op; out-of-range draws are ~1e-5 rare. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_019AsyRcLQqeJondzSTM6xsn
Include paths (pybind11, numpy, sysconfig) and the torch libomp lookup used bare 'python', which fails with ModuleNotFoundError when the target venv is not the active interpreter. Resolve once at the top: PYTHON=$path ./build.sh <env> --cpu now works from any shell. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_019AsyRcLQqeJondzSTM6xsn
f506345 to
62c1220
Compare
Author
|
Used to train this 6 link inverted pendulum on a MacBook. So real world tested over hundreds of billions of steps. demo_250th.trimmed.mp4 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Apple Silicon support: arm64 build fixes + MPS device for the torch backend
Makes
./build.sh <env> --cpubuild on Apple Silicon and the PyTorch (--slowly) backend train on the M-series GPU via MPS. Related: #590, #507, #422, #532.Changes
build.sh
-mavx2 -mfmaonx86_64(arm64 clang rejects them; same idea as Conditionally apply AVX2/FMA flags by architecture (fix Apple Silicon build) #590).-Xpreprocessor -fopenmp+ Homebrew'somp.h, but link against torch's bundledlibomp.dylib. This is the subtle one: linking Homebrew's libomp builds fine, but any process that also imports torch then has two OpenMP runtimes loaded — it either aborts at startup (OMP: Error #15) or segfaults insidecpu_vec_step's parallel regions. Falls back to Homebrew's libomp when torch isn't importable at build time.brew install libompwhenomp.his missing.pufferlib/torch_pufferl.py
cuda→mps(when available) →cpu.cpu_step(the vecenv memcpys from the raw pointer; an MPSdata_ptr()segfaults).compute_puff_advantage: round-trip through CPU for non-CUDA accelerators, since_C.puff_advantage_cpureads raw host pointers. (A native Metal kernel like feature: add mps kernel forcompute_puff_advantage#422 would avoid the copy; this keeps the diff minimal.)pufferlib/pufferl.py
_Cwas built with--cpu(nocreate_pufferl), sopuffer train <env>works on macOS without knowing about--slowly.Results (M-series Mac, macOS 25.5 / Apple clang 21, torch 2.12.1)
Benchmark on
breakout(default config, 4096 agents, 32.5K-param policy), steady-state over 3 epochs:Losses match CPU training qualitatively; checkpoints save/load fine.
Also ran a full build + one-epoch MPS training smoke test of every env in
ocean/on arm64: 37 of the 38 currently-buildable envs train on MPS (the 38th,squared_continuous, has no config file). Includes chess, craftax, drive, nmmo3, moba, and terraform. The 20 envs that don't build fail for reasons unrelated to this PR: 17 are broken on all platforms on current master (references to a missingocean/env_binding.h, orbinding.cmissing itsOBS_TENSOR_Tdefine), pluscraftax_classic(x86-only AVX-512),matsci(needs LAMMPS), andnethack/impulse_wars(external/unvendored deps).The second commit fixes a torch-backend crash that predates this PR and affects all platforms: shipped configs contain sweep-produced float values (
num_layers = 2.11327incartpole.ini), which the native backend truncates to C ints but the torch backend passed straight tonn.Linear. 17 of the 38 buildable envs crashed on this before the fix.Not addressed
craftax_classic: hand-written AVX-512 obs path (44_mm512_*intrinsics, no scalar fallback) — x86-only regardless of this PR.ocean/env_binding.h, missingOBS_TENSOR_Tdefines).