PyTorch Docathon 2026 Results in 150+ Merged Pull Requests
https://pytorch.org/blog/pytorch-docathon-2026-results/
Wed, 20 May 2026 15:45:54 +0000

Thank you to everyone who participated in the PyTorch Docathon 2026! Once again, the community showed up with incredible energy and dedication to make PyTorch documentation better for developers everywhere.

The PyTorch Docathon ran from May 5th 2026 through May 19th 2026, bringing together more than 260 registrants and 30+ active participants. Participants tackled issues across difficulty levels, resulting in over 150 merged pull requests that fixed various issues, added API documentation, and contributed to the ExecuTorch documentation.

We want to give a special shout-out to our top contributors, whose dedication and expertise went above and beyond. Your work directly improves the experience for millions of PyTorch users worldwide. See the full list of contributors in the leaderboard.

Meet the top contributors:

First place: ymrohit

Second place: XAheli, PyDevC, darknight054

Third place: JonathanColetti, Kadermiyanyedi

Honorable mentions: AswaniSahoo, Vasanthadithya-mundrathi, Nazim-fad, ozgecinko, kiszk, saurabhkthakur, spzala

As we wrap up this Docathon, we want to remind everyone that great documentation is an ongoing effort. Whether this was your first open source contribution or your hundredth, your work matters. Clear docs lower the barrier to entry, help the entire deep learning community move faster, and shorten the path from research to production in machine learning. And as AI development accelerates, documentation matters even more. LLMs and AI agents increasingly rely on public technical documentation to learn APIs, generate code, and troubleshoot workflows. High-quality PyTorch docs don’t just help humans; they help ensure AI-generated guidance is more accurate, up-to-date, and aligned with best practices.

We encourage you to keep contributing to PyTorch documentation and code. Thank you for being part of this, and we look forward to seeing you at the next one.

Team PyTorch

vLLM and PyTorch Work Together to Improve the Developer Experience on aarch64
https://pytorch.org/blog/vllm-and-pytorch-work-together-to-improve-the-developer-experience-on-aarch64/
Mon, 18 May 2026 17:25:11 +0000

TLDR: PyTorch 2.11 makes it possible to install CUDA-enabled PyTorch wheels on aarch64 Linux directly from PyPI, eliminating the need for custom package indexes and workarounds that previously complicated deployment on systems such as NVIDIA GH200, GB200, and GB300. In this post, Kaichao You (Inferact) explains how this packaging change improves the installation experience for vLLM users and highlights how collaboration between vLLM and PyTorch through the PyTorch Foundation helped bring the fix to production.

A fix, two years in the making, that makes life much easier on GB200 / GB300 / GH200.

An issue I first hit at a hackathon

This story actually starts back in October 2024.

I was at the CUDA MODE (now GPU MODE) IRL hackathon, trying to get vLLM running on a GH200 box. It should have been a five-minute job. Instead, I spent a frustrating chunk of the day staring at a pip install that, on the surface, looked perfectly fine — wheels were resolved, dependencies were satisfied, the install completed without errors — but at runtime torch.cuda.is_available() stubbornly returned False.

The reason, once I dug in, was almost comically mundane: on aarch64 Linux, pip install torch was pulling the CPU-only wheel from PyPI. There simply was no GPU wheel for aarch64 published to the default PyPI index. To get a CUDA-enabled build, you had to explicitly point pip at the PyTorch download index:

pip install torch --index-url https://download.pytorch.org/whl/cu128

That, by itself, would be only mildly annoying. The real damage came from how this interacted with transitive dependencies. PyPI does not let a package specify a custom index for its dependencies. So if any package in vLLM’s dependency tree declared a requirement of torch==<some_version> and that version didn’t match the one already installed, pip would happily go back to the default PyPI index, find the CPU wheel, silently uninstall the GPU build I had just carefully installed, and replace it with the CPU one. You’d think everything was fine until your model refused to find a GPU.

For anyone trying to bring up vLLM on GH200 — and later on GB200 / GB300 — this turned a one-line install into a maze of --index-url flags, pinned versions, and post-install sanity checks.

The workarounds vLLM carried in the meantime

While we waited for a proper fix upstream, vLLM had to ship its own workarounds so that aarch64 users were not stuck.

The first one was use_existing_torch.py, added in vllm-project/vllm#8713 back in September 2024 — explicitly framed in the PR title as “enable existing pytorch (for GH200, aarch64, nightly)”. The flow is exactly what the name suggests: you install the right torch build yourself (from the PyTorch index, or a nightly, or a custom build), then run python use_existing_torch.py, which strips every torch/torchvision/torchaudio requirement out of vLLM’s requirements/*.txt, requirements/*.in, and pyproject.toml. With those pins gone, the subsequent vLLM install can no longer trigger pip to “helpfully” reach back into the default PyPI index and silently swap your CUDA-enabled torch for the CPU wheel. It is ugly — we are literally rewriting our own dependency files at install time — but it kept GH200 users unblocked for over a year.
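The idea behind that file-rewriting step can be sketched in a few lines of Python. This is not the actual use_existing_torch.py; the function name and the regular expression are illustrative assumptions, and the sketch only handles plain requirements-style lines (not quoted entries inside pyproject.toml).

```python
import re

# Packages whose pins should be dropped (the three named in the post).
TORCH_PACKAGES = ("torch", "torchvision", "torchaudio")

# Matches a requirement line whose package name is exactly one of the above,
# followed by a version specifier, extras, a marker, or the end of the line.
_TORCH_REQ = re.compile(
    r"^\s*(?:%s)\s*(?:$|[=<>!~\[;@ ])" % "|".join(TORCH_PACKAGES),
    re.IGNORECASE,
)


def strip_torch_requirements(text: str) -> str:
    """Return `text` with every torch/torchvision/torchaudio line removed."""
    kept = [line for line in text.splitlines() if not _TORCH_REQ.match(line)]
    return "\n".join(kept) + ("\n" if text.endswith("\n") else "")


if __name__ == "__main__":
    sample = "numpy\ntorch==2.10.0\ntorchvision>=0.20\ntorch-tb-profiler\n"
    # Unrelated packages that merely start with "torch" are kept.
    print(strip_torch_requirements(sample))
```

Matching on the exact package name matters here: a prefix match would also drop unrelated packages whose names merely begin with "torch".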

Later, as uv matured, we got a cleaner option. In vllm-project/vllm#24303 we added the following to pyproject.toml:

[tool.uv]
no-build-isolation-package = ["torch"]

This tells uv not to build torch in an isolated environment — which in practice means uv will reuse the torch already present in the current environment instead of trying to resolve and reinstall its own copy. Combined with installing torch first from the right index, this gave us a much more ergonomic path than the file-rewriting trick: a single config line in pyproject.toml, and uv pip install vllm (or a uv sync) would respect the pre-installed CUDA-enabled torch on aarch64.

The vLLM workaround is the community improvising around a gap in the packaging standard. Wheel Variants is NVIDIA and Astral formalizing the fix so the improvisation is no longer needed.

From a hackathon headache to a TAC agenda item

Fast forward to 2025. vLLM joined the PyTorch Foundation, and I became one of its representatives on the Technical Advisory Committee (TAC). The aarch64 wheel situation kept coming up — both in my own work and from other vLLM users on Grace Hopper and Grace Blackwell systems. In August 2025, I filed pytorch/pytorch#160162 to track the problem formally, and earlier this year, in a January 2026 TAC meeting, I raised it directly on behalf of vLLM users.

The ask was straightforward: publish aarch64 GPU wheels to the default PyPI index so that pip install torch “just works” on GB200-class machines, the same way it does on x86. Those wheels would dynamically link to libraries like NCCL and cuBLAS (the same approach already used on x86) so they don’t balloon in size. Large binaries are both hard for users to download and expensive for the PyPI maintainers to host, so PyPI limits and heavily discourages them.

The NVIDIA engineering team requested that the CUDA SBSA wheels be published to PyPI, and then drove the small wheel approach that links against them.

This is exactly the kind of cross-project, infrastructure-level issue that the PyTorch Foundation is well-positioned to coordinate. vLLM and PyTorch are both Foundation projects, and having a shared forum to surface ecosystem friction — rather than each project working around it independently — turned out to make a real difference.

The fix has landed

In April 2026, in another TAC meeting, I learned the issue is resolved: starting with PyTorch 2.11.0, the default pip install torch on aarch64 Linux now pulls a CUDA-enabled wheel rather than the CPU-only one. Piotr Bialecki from NVIDIA confirmed the change is live in the 2.11.0 release.

I verified it on a GB200, and the difference is exactly what you’d want — boring, in the best possible way:

$ uv run --no-project --python 3.12 --with 'torch==2.11.0' -- python -c "import torch; print(torch.cuda.is_available())"
True

$ uv run --no-project --python 3.12 --with 'torch==2.10.0' -- python -c "import torch; print(torch.cuda.is_available())"
False

One version bump, and the entire workaround stack disappears. No more custom index URLs propagating through requirements files. No more silent CPU-wheel substitutions clobbering a working install. No more “why is my GB200 not finding the GPU” debugging sessions for new users.

For vLLM specifically, this means installation on GB200 / GB300 is now genuinely smooth. New users showing up with a Grace Blackwell system can follow the standard install instructions and have things work the first time — which, when you’re trying to get inference up and running on a brand-new platform, matters a lot.

The workarounds in vLLM — both use_existing_torch.py and the [tool.uv] no-build-isolation-package = ["torch"] setting — will stay. They are still useful for advanced users who run a custom PyTorch build (a nightly, a patched fork, or a from-source build paired with a vLLM source build) and need vLLM’s install to leave that torch strictly alone. What changes is the default path: ordinary users on aarch64 no longer have to know any of this exists. They can pip install and get on with their work, and the workarounds quietly become an advanced-user tool rather than a tax on everyone.

Why this is worth writing about

It’s a small change in the grand scheme of things — a packaging tweak, not a new feature. But I think it’s worth taking a moment to appreciate, for a couple of reasons.

First, it’s a concrete example of vLLM and PyTorch collaborating productively under the PyTorch Foundation umbrella. The TAC isn’t just a governance ritual; it’s a venue where pain points from downstream projects can land in front of the people who can actually fix them, and where coordination across projects happens by default rather than by accident. This issue traveled the full path — from a developer cursing at a terminal during a hackathon, to a TAC discussion, to a tracked GitHub issue, to a release — and the Foundation is what made that path short.

Second, developer experience compounds. Every hour someone doesn’t spend wrestling with --index-url flags is an hour they spend actually building things on top of vLLM and PyTorch. aarch64 GPU systems are only going to get more common, and it’s much better to fix this now, in the boring infrastructure layer, than to leave each user to discover and work around it on their own.

The uv-side workaround (build isolation passthrough) is part of the broader WheelNext effort — a very welcome push to rethink how Python packaging handles accelerator-bound dependencies in the AI era.

A big shoutout to the people who made this happen: Alban Desmaison, Nikita Shulga, and Andrey Talman from the PyTorch core team, who picked up the original ask and helped move it through; the NVIDIA PyTorch team, who drove the aarch64 build work and confirmed the fix had landed in 2.11.0, with Piotr Bialecki supporting the effort and acting as the steady point of contact across NVIDIA and upstream on these issues; the PyTorch release engineering team for getting the wheels built and published; and the many engineers behind the scenes across PyTorch, NVIDIA, and Arm, whose work on toolchains, CI infrastructure, and packaging made this possible. Thanks also to everyone in the TAC for keeping the door open for these kinds of conversations.

Onwards.

Running PyTorch Models on Apple Silicon GPUs with the ExecuTorch MLX Delegate
https://pytorch.org/blog/running-pytorch-models-on-apple-silicon-gpus-with-the-executorch-mlx-delegate/
Mon, 18 May 2026 15:30:50 +0000

TL;DR: Introducing the ExecuTorch MLX Delegate

  • The new MLX delegate enables optimized, GPU-accelerated inference for PyTorch models on Apple Silicon Macs, using Apple’s MLX framework. 
  • The delegate seamlessly integrates with the PyTorch 2 export stack and supports a wide range of quantization options (BF16, FP16, FP32, 2/4/8-bit affine, NVFP4).
  • It supports various models, including dense transformers (Llama, Qwen, Gemma), sparse Mixture-of-Experts, and speech-to-text models (Whisper, Voxtral, Parakeet) for both offline and real-time transcription.
  • Note: The MLX delegate is currently experimental.

Apple Silicon has become a popular platform for running large language models locally. Until now, ExecuTorch users on macOS were limited to CPU-based backends like XNNPACK or the AOTI Metal backend. Now we’ve released the MLX delegate, which brings fully optimized GPU-accelerated inference to Apple Silicon Macs through Apple’s MLX framework.

In this post we’ll cover what the MLX delegate is, why we built it as an ExecuTorch backend, and what you can run with it today.

Note: The MLX delegate is currently experimental and under active development. APIs and supported features may change.

What is the MLX Delegate?

The MLX delegate is a new ExecuTorch backend that compiles and runs PyTorch models on Apple Silicon GPUs. You export your model using the standard ExecuTorch pipeline, and the delegate handles the rest: partitioning the graph, serializing it into an optimized format, and dispatching operations to MLX’s Metal GPU kernels at runtime.

From the user’s perspective, the workflow is the same as any other ExecuTorch backend:

  1. Export your model with torch.export
  2. Lower it with to_edge_transform_and_lower using the MLXPartitioner
  3. Run the resulting .pte file with the ExecuTorch runtime
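As a rough sketch, steps 1 and 2 plus serialization to a .pte file might look like the following. The helper name export_to_pte and its parameters are hypothetical. The post does not give the import path for MLXPartitioner, so the sketch takes an already-constructed partitioner from the caller rather than guessing it. The per-model READMEs remain the authoritative instructions.

```python
def export_to_pte(model, example_inputs, partitioner, out_path="model.pte"):
    """Export `model`, lower it with `partitioner`, and write a .pte file.

    `partitioner` is assumed to be an MLXPartitioner instance built by the
    caller; `example_inputs` is a tuple of example tensors for the model.
    """
    # Imported lazily so this sketch can be loaded without torch/executorch.
    import torch
    from executorch.exir import to_edge_transform_and_lower

    # Step 1: capture the model graph with torch.export.
    exported = torch.export.export(model.eval(), example_inputs)

    # Step 2: partition the graph and lower the delegated parts to the backend.
    edge = to_edge_transform_and_lower(exported, partitioner=[partitioner])

    # Serialize to the .pte format that the ExecuTorch runtime loads in step 3.
    program = edge.to_executorch()
    with open(out_path, "wb") as f:
        f.write(program.buffer)
    return out_path
```

Because the partitioner is just an argument, the same helper would work for other ExecuTorch backends by passing a different partitioner.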

The delegate currently supports around 90 ATen ops, covering the full range of operations needed for transformer inference: quantized matmul, multi-head attention, rotary position embeddings, mixture-of-experts routing, recurrent state-space operations, and more.

Why Build This as an ExecuTorch Delegate?

There are already excellent tools for running models on Apple Silicon, including MLX’s own mlx-lm. So why build another one? Three reasons:

Performance. The MLX delegate achieves 3-6x higher throughput on generative AI workloads compared to existing ExecuTorch delegates on macOS. Moving inference to MLX’s optimized Metal kernels makes a meaningful difference for ExecuTorch applications like chat and real-time transcription.

PyTorch 2 integration. The delegate plugs directly into the PyTorch 2 export stack. It uses torch.export for graph capture and TorchAO for quantization, the same tools used by every other ExecuTorch backend. If you can export a model with torch.export, you can run it on MLX. When new models or quantization techniques land in PyTorch, they become available to the MLX delegate without additional work.

Portable applications. ExecuTorch provides a single runtime API across all backends. An application built against the ExecuTorch C++ or Python runtime can run models exported for MLX, XNNPACK, CoreML, Vulkan, or CUDA without changing application code.

Quantization and Dtype Support

The delegate supports the precision and quantization options you’d expect for on-device inference:

  • BF16, FP16, and FP32 for weights and activations
  • 2, 4, and 8-bit affine quantization via TorchAO’s quantize_ API. This uses the same quantization scheme as the XNNPACK and Vulkan backends, which means a single quantized model definition can target multiple backends, and opens the door to fat PTE files that run on whichever backend is available at runtime.
  • NVFP4 quantization using NVIDIA’s FP4 data type
  • Tied quantized embeddings for models that share weights between the embedding layer and the language model head

What Models Can I Run?

We’ve validated the delegate across a range of architectures:

Large Language Models

Dense transformers work out of the box, with support for both full KV caches and sliding window caches:

  • Llama 3.2 1B
  • Qwen 3 (0.6B, 1.7B, 4B)
  • Phi-4 mini (3.8B)
  • Gemma 3 (1B, 4B) with sliding window attention

Sparse Mixture-of-Experts models are supported through custom gather operations that efficiently route tokens to the correct experts on the GPU:

  • Qwen 3.5 35B-A3B: 256 experts with top-8 routing, combining GatedDeltaNet linear attention layers with full SDPA attention layers

Speech-to-Text

Offline transcription models process a complete audio recording and return the transcript:

  • OpenAI Whisper (tiny through large-v3-turbo)
  • NVIDIA Parakeet TDT (0.6B) with word-level timestamps
  • Mistral Voxtral (3B)

Real-time streaming transcription processes audio in small chunks as it arrives, enabling live use cases:

  • Mistral Voxtral Realtime (4B) with live microphone input, ring buffer KV caches, and sliding window attention

Broader Coverage

Beyond these flagship models, over 30 additional models have been validated through our backend test suites, covering dense transformers, encoder-decoder architectures, and vision models.

Getting Started

Each supported model has a README with detailed export and inference instructions:

For an overview of the delegate architecture, supported operations, and development guide, see the MLX Delegate README.

We’d love to hear what models and use cases matter most to you. If you run into issues or have feature requests, please open an issue on the ExecuTorch GitHub repo or join our Discord Channel.

PyTorch 2.12 Release Blog
https://pytorch.org/blog/pytorch-2-12-release-blog/
Wed, 13 May 2026 18:36:29 +0000

We are excited to announce the release of PyTorch® 2.12 (release notes)!

The PyTorch 2.12 release features the following changes: 

  • Batched linalg.eigh on CUDA is up to 100x faster due to updated cuSolver backend selection
  • New torch.accelerator.Graph API unifies graph capture and replay across CUDA, XPU, and out-of-tree backends
  • torch.export.save now supports Microscaling (MX) quantization formats, enabling full export of aggressively compressed models
  • Adagrad now supports fused=True, joining Adam, AdamW, and SGD with a single-kernel optimizer implementation
  • torch.cond control flow can now be captured and replayed inside CUDA Graphs
  • ROCm users gain expandable memory segments, rocSHMEM symmetric memory collectives, and FlexAttention pipelining

This release is composed of 2,926 commits from 457 contributors since PyTorch 2.11. We want to sincerely thank our dedicated community for your contributions. As always, we encourage you to try these out and report any issues as we improve 2.12. More information about how to get started with the PyTorch 2-series can be found at our Getting Started page.

Have questions? Join us on Wednesday, May 20 at 10 am PST for a live Q&A with panelists Joe Spisak, Andrey Talman, and Alban Desmaison, and moderator Chris Gottbrath. We will provide a brief overview of the release and answer your questions live. Register now.

Throughout the 2.x series, PyTorch has been evolving from a research-first framework into a unified, hardware-agnostic platform for production training and inference at scale. PyTorch 2.10 laid the groundwork with cross-backend performance primitives and the formal deprecation of TorchScript. PyTorch 2.11 expanded that foundation with differentiable collectives for distributed training, FlashAttention-4 on next-generation GPUs, and broader export coverage. 

PyTorch 2.12 continues this direction: a new device-agnostic torch.accelerator.Graph API unifies graph capture and replay across CUDA, XPU, and out-of-tree backends; batched eigenvalue decomposition is up to 100x faster; and torch.export now supports Microscaling quantization formats for deploying aggressively compressed models. Across these releases, PyTorch is becoming faster across backends and usable in a wider variety of platforms as it continues to enable AI innovation. 

Performance Features

Up to 100x faster batched eigendecomposition on CUDA (linalg.eigh)

The backend selection for linalg.eigh on CUDA has been overhauled. The legacy MAGMA backend was deprecated in favor of cuSolver (PR #174619 by Grayson Derossi), and the cuSolver dispatch heuristics were updated to use syevj_batched unconditionally (PR #175403 by Johannes Z). For batched symmetric/Hermitian eigenvalue problems, this yields up to 100x speedups over the previous release, resolving longstanding performance gaps with CuPy.

Workloads which previously took minutes (because PyTorch was inefficiently dispatching each matrix solve individually) now run in seconds by using cuSolver’s syevj_batched kernel, which is designed to process many small/medium matrices as a single GPU operation. These gains are especially relevant for scientific computing and machine learning workloads that rely on eigendecompositions of batched matrices. (example usage in the doc)
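For reference, a batched call looks like the sketch below: a single torch.linalg.eigh call on a tensor of shape (batch, n, n). The helper name and sizes are illustrative. Whether a given call actually reaches the syevj_batched path depends on the build, device, and dispatch heuristics, so this shows usage only, not a benchmark.

```python
def batched_eigh_demo(batch=64, n=8, device=None):
    """Decompose a batch of symmetric matrices with one torch.linalg.eigh call."""
    import torch  # imported lazily so the sketch loads without torch installed

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Build a batch of symmetric matrices: A = M + M^T.
    m = torch.randn(batch, n, n, device=device, dtype=torch.float64)
    a = m + m.transpose(-1, -2)

    # One call handles the whole batch; eigenvalues are returned in ascending order.
    eigenvalues, eigenvectors = torch.linalg.eigh(a)

    # Sanity check the decomposition: A should equal V diag(w) V^T up to rounding.
    rebuilt = eigenvectors @ torch.diag_embed(eigenvalues) @ eigenvectors.transpose(-1, -2)
    max_error = (rebuilt - a).abs().max().item()
    return eigenvalues.shape, max_error
```

Calling batched_eigh_demo() returns the eigenvalue tensor shape (batch, n) and the largest reconstruction error, which should be tiny in float64.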

Fused Adagrad optimizer 

The Adagrad optimizer now supports fused=True, performing the entire optimizer step in a single CUDA kernel rather than launching separate kernels for each operation. This reduces kernel launch overhead and memory traffic. Adagrad joins Adam, AdamW, and SGD in offering a fused variant. The underlying CUDA kernel was contributed by @MeetThePatel in the 2.11 cycle (PR #159008), and the Python frontend that exposes it to users was finalized by Jane Xu in 2.12 (PR #177672).

Compilation and export across hardware

torch.accelerator.Graph: Device Agnostic Accelerator Graph Capture and Stream API

`torch.accelerator.Graph` is a new device-agnostic API for graph capture and replay, providing a unified abstraction over backend-specific implementations such as `torch.xpu.XPUGraph`. Each backend can register its own implementation through a lightweight GraphImplInterface, preserving backend autonomy while enabling a consistent user-facing API.

Alongside this, `c10::Stream` and `torch.Stream` now expose an `is_capturing()` method, replacing the device-specific `is_current_stream_capturing` with a backend-agnostic alternative. Stream context manager reentrance was also fixed. Together, these changes bring cross-backend parity to stream and graph management, with initial support for the XPU backend and extensibility to out-of-tree backends via `PrivateUse1`.

Contributed by Guangye Yu (Intel) across six PRs, anchored by the C++ interface (PR #171269) and Python frontend (PR #171285). (usage example in docstring)

torch.export now supports Microscaling (MX) quantization formats

As models move from research to production, torch.export is the standard path for serializing PyTorch models for deployment. However, models using Microscaling (MX) quantization — an increasingly popular technique for reducing model size and inference cost — could not previously be exported because torch.export.save did not handle the float8_e8m0fnu dtype used as the shared block-scale exponent in MX formats (MXFP4, MXFP6, MXFP8).

In PyTorch 2.12, torch.export.save and torch.export.load now correctly serialize and deserialize tensors with this dtype, unblocking the full export-to-deployment workflow for models leveraging Microscaling quantization. This is particularly relevant for teams deploying large language models to cost-constrained or edge environments where aggressive quantization is essential. Contributed by Chizkiyahu Raful (ARM) (PR #176270).
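To make the role of that dtype concrete, here is a small pure-Python sketch of how an E8M0 scale byte is interpreted. It assumes the layout from the OCP Microscaling specification as I understand it: 8 exponent bits, no sign or mantissa, a bias of 127, and 0xFF reserved for NaN. The function names are illustrative and unrelated to how torch.export.save serializes these tensors.

```python
import math

E8M0_BIAS = 127   # exponent bias of the scale format
E8M0_NAN = 0xFF   # the all-ones encoding is reserved for NaN


def decode_e8m0(byte: int) -> float:
    """Decode an E8M0 scale byte into the power-of-two scale it represents."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("E8M0 values are a single byte")
    if byte == E8M0_NAN:
        return math.nan
    return math.ldexp(1.0, byte - E8M0_BIAS)  # 2 ** (byte - 127)


def dequantize_block(elements, scale_byte):
    """Apply one shared block scale to already-decoded element values."""
    scale = decode_e8m0(scale_byte)
    return [x * scale for x in elements]


if __name__ == "__main__":
    print(decode_e8m0(127))                          # 1.0
    print(dequantize_block([1.0, -0.5, 3.0], 126))   # [0.5, -0.25, 1.5]
```

Since every element in a block shares one such byte, the scale adds very little storage on top of the low-precision elements themselves.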

Capture Control flow with torch.cond within CUDA Graph

Control-flow regions using torch.cond can now be captured and replayed as part of CUDA Graphs. Previously, data-dependent control flow forced fallback to CUDA graph trees because branching was evaluated on the CPU. By leveraging CUDA 12.4’s conditional IF nodes, torch.cond branches are now evaluated entirely on the GPU within a single graph capture.

This was contributed by Daniel Galvez and Ting-Yang Kuei (NVIDIA) (PR #168912), with Inductor ordering support added by Paul Zhang (Meta) (PR #179457). This currently works with the eager and cudagraphs backends; Inductor support is planned for a future release.

FMA-based addcdiv lowering for XPU

Inductor now uses fused multiply-add (FMA) instructions for addcdiv operations, achieving bitwise numerical parity with eager CUDA execution while preserving Triton kernel fusion benefits.

addcdiv is a fused arithmetic operation (result = input + value × (tensor1 / tensor2)) that sits at the heart of many optimizer update rules, including Adam, AdamW, and RMSprop. Previously, Inductor’s lowering used separate multiply and divide instructions, introducing small floating-point rounding differences compared to eager mode. These differences accumulate over thousands of training steps, making it difficult to validate that compiled models produce numerically identical results.
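A scalar, pure-Python illustration of why the two lowerings can disagree: a fused multiply-add rounds once, while a separate multiply and add round twice. The sketch emulates the single rounding with exact fractions; it uses Python's double-precision floats and hand-picked inputs, and it is not how Inductor or Triton implement the operation.

```python
from fractions import Fraction


def addcdiv_separate(inp, t1, t2, value):
    """input + value * (t1 / t2), rounding after every operation."""
    q = t1 / t2
    return inp + value * q


def addcdiv_fma(inp, t1, t2, value):
    """Same formula, but the multiply-add is rounded once, as an FMA would."""
    q = t1 / t2
    # Fractions are exact; float() rounds the exact result a single time.
    return float(Fraction(value) * Fraction(q) + Fraction(inp))


if __name__ == "__main__":
    x = 1.0 + 2.0 ** -30
    inp = -(1.0 + 2.0 ** -29)
    # x * x is exactly 1 + 2**-29 + 2**-60; the last term is lost when the
    # product is rounded on its own, but survives a single fused rounding.
    print(addcdiv_separate(inp, x, 1.0, x))  # 0.0
    print(addcdiv_fma(inp, x, 1.0, x))       # 2**-60, about 8.67e-19
```

The gap is tiny for one step, which is why it only becomes visible when comparing results bit for bit or after many accumulated updates.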

This was first implemented for CUDA by Michael Lazos (Meta) (PR #174912), then extended to XPU by Guangye Yu (Intel) (PR #176163), fixing several numerical correctness issues on Intel GPUs. Anyone using torch.compile with optimizer-heavy training loops now gets compiled performance without sacrificing numerical reproducibility — on both NVIDIA and Intel hardware.

Distributed Training

ProcessGroup support in custom ops

Custom operators can now accept ProcessGroup objects directly as arguments, rather than requiring callers to convert them to string group names and look them up in a global registry. All c10d functional collective ops (all_reduce, reduce_scatter, etc.) have been updated to accept both ProcessGroup objects and string names. Contributed by Aaron Orenstein (Meta) (PR #172795).

Multi-GPU/multi-node profiling improvements

The PyTorch Profiler Events API now exposes flow IDs, flow types, activity types, unfinished events, and Python function events, bringing events() to parity with the Chrome trace JSON output and enabling richer programmatic post-hoc analysis. In addition, it is now possible to correlate NCCL collective traces across ranks using a new seq_num field: all ranks participating in the same collective share the same sequence number within a process group. Together, these changes significantly improve the tooling for debugging distributed training performance across multiple GPUs and nodes. API enrichment by Ryan Zhang (Meta) (PR #177888) and NCCL seq_num added by Marvin Dsouza (Meta) (PR #177148).
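One way to use a shared sequence number is to group per-rank records by (process group, seq_num) and compare ranks within each group. The sketch below works on plain dictionaries with hypothetical keys ("rank", "pg", "seq_num", "name", "duration_us"); the real profiler event objects may expose these fields under different names, so treat the record layout as an assumption.

```python
from collections import defaultdict


def group_collectives(events):
    """Group per-rank events by (process group, seq_num).

    Each event is a dict with the hypothetical keys "rank", "pg",
    "seq_num", "name", and "duration_us".
    """
    groups = defaultdict(list)
    for event in events:
        groups[(event["pg"], event["seq_num"])].append(event)
    return dict(groups)


def slowest_rank_per_collective(events):
    """For every collective, report which rank spent the longest in it."""
    result = {}
    for key, members in group_collectives(events).items():
        slowest = max(members, key=lambda e: e["duration_us"])
        result[key] = slowest["rank"]
    return result


if __name__ == "__main__":
    events = [
        {"rank": 0, "pg": "default", "seq_num": 7, "name": "all_reduce", "duration_us": 120},
        {"rank": 1, "pg": "default", "seq_num": 7, "name": "all_reduce", "duration_us": 950},
        {"rank": 0, "pg": "default", "seq_num": 8, "name": "all_gather", "duration_us": 300},
        {"rank": 1, "pg": "default", "seq_num": 8, "name": "all_gather", "duration_us": 280},
    ]
    print(slowest_rank_per_collective(events))
```

Keying on the process group as well as the sequence number matters because, per the description above, sequence numbers are only shared within a process group.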

FlightRecorder: ncclx + gloo Backends

FlightRecorder’s trace analyzer now supports ncclx and gloo backends alongside the existing nccl and xccl backends, enabling distributed communication tracing across a broader set of collective backends. Additionally, FlightRecorder now recognizes torchcomms operations (e.g., all_gather_single, reduce_scatter_v, barrier) that were previously untracked. A race condition that could cause an infinite loop when multiple process groups concurrently accessed the FlightRecorder singleton was also fixed in this cycle. Backend allowlist added by Lily Janjigian (Meta) (PR #180268), with torchcomms operation support by Tushar Jain (PR #178359).

Platform Related Updates

CUDA

CUDA Graph kernel annotations

torch.cuda.graph now accepts an enable_annotations kwarg that injects annotation metadata (e.g., collective op names, process groups, message sizes) into individual kernels within captured CUDA graphs. After a trace is run through the companion post-processing script (python -m torch.cuda._annotate_cuda_graph_trace), the annotations are merged into it. These annotations appear in Perfetto/Chrome profiler traces, making it significantly easier to understand what each kernel in a replayed graph is doing. Contributed by Shangdi Yu (Meta) (PR #179768).

CUDA Green Context workqueue limit

CUDA Green Contexts now support specifying a workqueue limit, giving finer-grained control over GPU resource partitioning. This experimental feature allows users to constrain the number of concurrent work submissions within a green context, enabling more predictable resource sharing across concurrent workloads. Contributed by Matthias Jouanneaux (NVIDIA) (PR #177242).

ROCm

ROCm: Expandable segments

AMD GPUs (ROCm >= 7.02) now support expandable memory segments in PyTorch’s caching allocator, matching the CUDA feature that reduces memory fragmentation by dynamically growing allocations via virtual memory APIs. Added by Prachi Gupta (AMD) (PR #173330).

ROCm: rocSHMEM support

rocSHMEM support enables symmetric memory collective operations (torch.ops.symm_mem.*) on AMD GPUs, porting the NVSHMEM-based on-GPU communication primitives — including point-to-point, broadcast, all-to-all, and MoE-oriented 2D AllToAllv — to ROCm. The rocSHMEM implementation uses a dedicated compilation unit to handle API and warp-size differences between NVSHMEM and rocSHMEM. Contributed by Prachi Gupta (PR #173518).

ROCm: hipSPARSELt and FP8 semi-structured sparsity

hipSPARSELt is now enabled by default in PyTorch builds on ROCm >= 7.12, bringing semi-structured (2:4) sparsity support to AMD GPUs. FP8 (float8_e4m3fn) inputs are also now supported through hipSPARSELt on MI350X (gfx950), with FP32 output. This enables the same torch._cslt_sparse_mm sparsity acceleration path that was previously CUDA-only. hipSPARSELt enabled by rraminen (AMD) (PR #170852), with FP8 semi-structured sparsity added by Benji Beck (Meta) (PR #179310).

ROCm: Inductor FlexAttention pipelining

FlexAttention on AMD GPUs now uses two-stage pipelining in the Triton backend, delivering 5-26% speedups across a range of attention patterns (causal, alibi, sliding window) and shapes on MI350X. This was a one-line configuration change (num_stages=1 to 2) that unlocks more efficient memory-compute overlap. Contributed by nithinsubbiah (PR #176676).

Apple MPS

MPS: Metal-4 offline shader compilation

Apple Silicon binary wheels now ship with ahead-of-time-compiled Metal-4 shaders, built on macOS 26 with the metal-4 standard. This eliminates the runtime shader compilation overhead on first run, reducing startup latency for MPS workloads. Contributed by Isalia20 (Irakli Salia) (PR #179378).

Deprecations and Breaking Changes

Distributed: Planned Breaking Changes for torchcomms

We’ve been working hard on integrating torchcomms directly into PyTorch Distributed so everyone can get the benefits out of the box. In an upcoming release (2.13+) we’re planning on using torchcomms by default, which includes some breaking changes to how ProcessGroups operate. We aim to make these changes work automatically for most models and fix any incompatibilities in the ecosystem, but nevertheless, some models will be impacted.

We’re still polishing torchcomms but you can use it right now and get access to the new APIs, fault tolerance, window, scalability, and debuggability features. To get started, pip install torchcomms and set TORCH_DISTRIBUTED_USE_TORCHCOMMS=1.

See https://github.com/meta-pytorch/torchcomms for more details.

Key changes:

  • Eager Initialization: We will require all ProcessGroup/communicators to be eagerly initialized during dist.init_process_group and only support a single backend device. This means that the device will have to be specified during initialization.
  • P2P operations: We aim to make each ProcessGroup/communicator match 1:1 with the underlying communicator. This means that P2P operations issued on the same group/stream will not be guaranteed to run concurrently. Concurrent P2P operations will be required to use the batch APIs or a separate group/communicator.
  • torchcomms dependency: We plan to make torchcomms a required package for PyTorch Distributed and deprecate the existing c10d::Backends in favor of a single, more modern communication definition.

The torchcomms integration is being led by the PyTorch Distributed team, with groundwork in 2.12, including backend wrapper refactoring by Yifan Mao (PR #177157) and FlightRecorder integration by Tushar Jain (PR #175270).

TorchScript is now Deprecated

TorchScript was deprecated in 2.10. Use torch.export to replace the torch.jit.trace and torch.jit.script APIs, and ExecuTorch to replace the embedded runtime. For more details, see this talk from PTC.

Deprecation of the CUDA 12.8 Wheel

Starting with PyTorch 2.12, the CUDA 12.8 binary wheel is deprecated and will no longer be published as part of the standard release matrix. The default wheel remains CUDA 13.0 (via pip install torch from PyPI), and CUDA 13.2 has been added as an experimental build.

Users running on older architectures (e.g., Pascal, Volta) should switch to the CUDA 12.6 wheel, which remains supported in this release. Users running on newer GPUs (e.g., Blackwell) should use the CUDA 13.0+ wheels; note that this requires an NVIDIA driver upgrade to 580.65.06 (Linux) or 580.88 (Windows).
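The driver requirement above can be checked with a small version comparison. The following is a minimal sketch; the helper names and platform keys are illustrative and not part of PyTorch, and the minimum versions are simply the ones quoted above.

```python
# Minimum NVIDIA driver versions for the CUDA 13.0+ wheels, as quoted above.
MIN_DRIVER = {"linux": "580.65.06", "windows": "580.88"}


def parse_version(version):
    """Turn '580.65.06' into (580, 65, 6) so versions compare numerically."""
    return tuple(int(part) for part in version.strip().split("."))


def meets_cuda13_driver_minimum(installed, platform):
    """Hypothetical helper: True if `installed` satisfies the quoted minimum."""
    return parse_version(installed) >= parse_version(MIN_DRIVER[platform])


print(meets_cuda13_driver_minimum("570.124.06", "linux"))
```

Comparing integer tuples rather than raw strings avoids ordering mistakes such as "580.9" sorting after "580.65".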

Updated (2026-05-19): Removed the following sentence, as pytorch/pytorch#177276 did not land in the 2.12 release: A companion API (torch._C._mps_loadMetallib) was also added for loading pre-compiled .metallib blobs directly, supporting the Triton Apple MPS backend’s compile-time metallib workflow.

Efficient Edge AI on Arm CPUs and NPUs: Understanding ExecuTorch through Practical Labs
https://pytorch.org/blog/efficient-edge-ai-on-arm-cpus-and-npus/
Tue, 12 May 2026 15:50:53 +0000
https://pytorch.org/?p=79236

TL;DR:

  • ExecuTorch extends the PyTorch ecosystem to deliver local AI inference on constrained edge devices. To provide a practical entry point, Arm has created a set of Jupyter Labs that complement the official ExecuTorch documentation while explaining both the how and the why of each step. 
  • The blog and labs introduce both CPU and NPU inference, across Cortex-A and Cortex-M + Ethos-U platforms, and showcase use of Model Explorer adapters, developed by Arm, to gain visibility into model deployment with ExecuTorch. 

AI is rapidly and undisputedly becoming part of how we work and live. But today, much of that intelligence is still tied to the cloud, accessed through APIs and web interfaces. 

That model doesn’t always fit. Businesses increasingly want to bring AI closer to where it’s actually used—on devices like wearables, smart cameras, and other low-power edge systems. Running AI locally can reduce latency, improve privacy, and unlock new real-time capabilities, but it also introduces a new challenge: how do you run complex models efficiently on constrained hardware with limited memory, compute, and power? 

PyTorch has become the foremost framework for training and inferencing AI models in the cloud. ExecuTorch extends that ecosystem to bring local AI inference to the edge. It takes a PyTorch model, exports it into a lightweight format, and runs it through a runtime built specifically for edge inference. If you’re already familiar with PyTorch, the appeal is clear: you stay in the same ecosystem, while gaining a deployment path better suited to real devices.

To make this practical, Arm has created a set of hands-on Jupyter labs that walk through the deployment process—from CPU inference on a Raspberry Pi through to hardware acceleration on Ethos-U NPUs. Whether you’re an ML developer already comfortable with PyTorch or an embedded engineer building your ML foundations, this lab series provides a practical entry point, with executable examples that complement the official ExecuTorch documentation while explaining both the how and the why of each step.

ExecuTorch on Edge CPUs

You may already be familiar with running PyTorch on edge devices like the Raspberry Pi 5. We explore this in our course Optimizing Generative AI on Arm. While this works well, the Pi sits in the category of single-board computers (SBCs), with significantly more resources than many production-grade embedded or IoT systems. For more constrained targets, such as Cortex-M microcontrollers, running PyTorch is not viable due to its size and dependencies.

ExecuTorch addresses this and enables efficient deployment of PyTorch models to edge devices. It does so by exporting a model into a minimal .pte artefact containing both the model weights and a static computation graph. This removes the need for Python at runtime and avoids dynamic execution overhead that is unnecessary for inference.

The export step is followed by lowering, where the model graph is transformed into a backend-compatible form. This is where hardware-aware optimization begins.

The resulting artefact is: 

  • lightweight and portable
  • predictable in execution
  • suitable for deployment on constrained systems 

Beyond the portability of the .pte, there are other benefits. Even on devices like the Raspberry Pi, which can run PyTorch models directly, ExecuTorch can improve performance. That performance depends heavily on how the model is executed: ExecuTorch achieves it by delegating parts of the model to optimized backends.

On Arm CPUs, this is typically done using the XNNPACK backend. When enabled, supported operators—such as convolutions and matrix multiplications—are delegated to highly optimized implementations. On Arm platforms, these implementations leverage KleidiAI microkernels, which make efficient use of architectural features such as Neon. In our labs, we compare inference of an OPT-125M transformer model on a Raspberry Pi 5. The graph below shows a significant latency reduction when using ExecuTorch with XNNPACK:

Fig 1. Comparison of PyTorch and ExecuTorch Inference Time on Raspberry Pi 5 CPU (panels: PyTorch eager, ExecuTorch XNNPACK)

Note: Both PyTorch eager mode and ExecuTorch + XNNPACK were run with several warm-up iterations discarded to avoid the typical pattern of slower initial runs followed by faster steady-state performance. 

In the case of ExecuTorch, the opposite trend is observed: the first few measured runs are faster, with latency increasing over subsequent runs. This behaviour is attributed to thermal effects on the Raspberry Pi. Sustained, highly optimized inference with ExecuTorch + XNNPACK places greater load on the CPU, leading to increased temperature and a corresponding reduction in clock speed over time. No active cooling was used during these experiments.

It’s important to note that backend delegation doesn’t occur by default. Running ExecuTorch without XNNPACK will often result in higher latency compared to PyTorch (which has its own KleidiAI optimizations), though you still benefit from a reduced runtime footprint and improved portability.

The key takeaway is that ExecuTorch provides the deployment framework, but backend selection determines how effectively the hardware is utilized.

From CPU to NPU: Ethos-U and TOSA

To go further, we can target hardware acceleration using Arm Ethos-U NPUs, typically paired with Cortex-A or Cortex-M CPUs. 

At this point, execution becomes heterogeneous. Rather than running the entire model on one processor, ExecuTorch partitions the graph: 

  • supported subgraphs are delegated to the NPU
  • unsupported operators fall back to the CPU 

Ethos-U operates on quantized integer models (typically INT8), so models must be quantized before delegation. The first step is to create a quantizer specific to the backend using EthosUQuantizer and a compile_spec matching your specific target Ethos-U.  

For example, the Ethos-U targeted here is an Ethos-U85 with 256 multiply-accumulate (MAC) units: 

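# Import paths as in recent ExecuTorch releases; they may differ in other versions.
from executorch.backends.arm.ethosu import EthosUCompileSpec
from executorch.backends.arm.quantizer import EthosUQuantizer

# Describe the target: an Ethos-U85 with 256 MAC units.
compile_spec = EthosUCompileSpec(
    target="ethos-u85-256",
    system_config="Ethos_U85_SYS_DRAM_Mid",
    memory_mode="Shared_Sram",
    extra_flags=["--output-format=raw"],
)

# Backend-specific quantizer used by the PT2E quantization flow.
quantizer = EthosUQuantizer(compile_spec)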

Once a quantizer has been created, the PyTorch 2 Export (PT2E) quantization flow can be performed as normal.

The next step involves lowering the model into TOSA (Tensor Operator Set Architecture), an intermediate representation designed to bridge high-level frameworks and hardware backends. TOSA provides a stable, hardware-agnostic operator set. Instead of requiring each hardware vendor to support every framework-specific operator, models are lowered into TOSA, and hardware backends implement this smaller, standardized set.

This step uses the to_edge_transform_and_lower API, specifying use of the EthosUPartitioner. For Ethos-U this triggers the backend path that serializes to TOSA and runs Vela to produce an optimized command stream for execution on the NPU. Finally, .to_executorch(...) packages the result into a .pte file.

Understanding this flow is useful when analyzing performance. Efficient delegation typically results in large, contiguous subgraphs running on the NPU. If unsupported operators are present, the graph can become fragmented, leading to multiple smaller subgraphs and increased overhead due to frequent transitions between CPU and NPU.
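The effect of fragmentation can be illustrated with a toy model. The sketch below treats a model as a flat list of operator names and groups consecutive operators by whether a backend supports them. It is not the ExecuTorch partitioner, which works on a graph, and the supported-operator set is hypothetical rather than a statement about actual Ethos-U coverage.

```python
from itertools import groupby


def partition(ops, supported):
    """Group consecutive ops into (target, [ops]) segments (toy model)."""
    segments = []
    for on_npu, run in groupby(ops, key=lambda op: op in supported):
        segments.append(("npu" if on_npu else "cpu", list(run)))
    return segments


def count_transitions(segments):
    """Number of CPU<->NPU hand-offs between adjacent segments."""
    return max(len(segments) - 1, 0)


# Hypothetical operator sets, for illustration only.
SUPPORTED = {"conv2d", "relu", "add", "mul", "avg_pool2d"}
clean = ["conv2d", "relu", "conv2d", "add", "avg_pool2d"]
fragmented = ["conv2d", "relu", "pow", "avg_pool2d", "div", "mul", "conv2d"]

clean_segments = partition(clean, SUPPORTED)
fragmented_segments = partition(fragmented, SUPPORTED)

print(len(clean_segments), count_transitions(clean_segments))
print(len(fragmented_segments), count_transitions(fragmented_segments))
```

In this toy, two unsupported operators are enough to split one delegated region into three, with a hand-off between processors at every boundary.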

To make this visible, the labs utilize Google’s Model Explorer, along with adapters developed by Arm. These tools allow you to:

  • Inspect the ExecuTorch graph (.pte) and visualize how it is partitioned across backends
  • Examine the TOSA representation (.tosa)
  • Visualize VGF (.vgf) files used by the Arm ML SDK for Vulkan® (not covered in the hands-on labs)

For example, below are two .pte files targeting the same Ethos-U configuration, but generated from slightly different models.

The right-hand image shows a MobileNetV2 model with an additional LRN layer inserted. Because LRN is not natively supported, it is decomposed into lower-level operations during lowering. Not all of these operations can be delegated, and the graph is partitioned into multiple segments. Supported regions are delegated to the NPU, while the unsupported portion runs on the CPU. In contrast, the left-hand model is regular MobileNetV2, and contains only supported operators, allowing the entire compute region to be delegated as a single, continuous Ethos-U subgraph.

Fig 2. Model Explorer using the PTE Adapter to inspect .pte files targeting Ethos-U for two different models (left: MobileNetV2; right: MobileNetV2 + LRN layer)

This level of visibility helps explain performance behavior and can guide optimization decisions.

Practical Next Steps

To get familiar with these topics, we have released a collection of Jupyter labs, designed so you can run and modify the code on your own hardware – making the theory immediately actionable. Take a look here. 

This collection includes contributions from Professor Marcelo Rovai (UNIFEI University, and a member of the Edge AI Foundation Academia-Industry Partnership).

Additional thanks go to the academic reviewers at IIIT Bangalore, who ensured the material is rigorously validated and valuable for developers and learners.

For a broader overview of Edge AI Developer Resources provided by Arm, please look here. 

Building models is only half the story—getting them running efficiently at the edge is what matters. ExecuTorch makes that possible, and these labs show you how to get started quickly while understanding the underlying concepts.

In-Kernel Broadcast Optimization: Co-Designing Kernels for RecSys Inference
https://pytorch.org/blog/in-kernel-broadcast-optimization-co-designing-kernels-for-recsys-inference/
Tue, 05 May 2026 16:56:32 +0000
https://pytorch.org/?p=76581

TL;DR:

  • Traditional RecSys inference explicitly replicates shared user embeddings/sequences for every candidate. In-Kernel Broadcast Optimization (IKBO) eliminates this overhead via a kernel-model-system co-design that fuses broadcast logic directly into user-candidate interaction kernels. By decreasing both the memory footprint and IO utilization, IKBO unlocks even higher throughput.
  • IKBO delivers up to a 2/3 reduction in compute-intensive net latency, serving as the scalability backbone for the request-centric, inference-efficient framework that powers the Meta Adaptive Ranking Model.
  • Deployed end-to-end across Meta’s multi-stage recommendation funnel on both GPU and MTIA (Meta Training and Inference Accelerator).
  • The IKBO Linear Compression kernel achieved a cumulative ~4× speedup on H100 SXM5 after four stages of progressive co-design, culminating in warp-specialized fusion via TLX.
  • The IKBO co-design shifted the Flash Attention kernel from IO-bound to compute-bound (hitting 621 BF16 TFLOPs on H100 SXM5). Coupled with TLX warp-specialized optimization, this results in a 2.4×/6.4× throughput gain over the non-co-designed CuTeDSL FA4 Hopper baseline (kernel only / kernel + broadcasting).

In this post, we present In-Kernel Broadcast Optimization (IKBO), a kernel-model-system co-design approach that eliminates redundant user-embedding broadcast in recommendation model inference. In production RecSys, user embeddings are identical across all candidates for a given request, yet standard approaches require explicit replication, wasting memory bandwidth and compute that scale with candidate count. IKBO encodes a simple insight: broadcast is a data layout concern, not a computational necessity. Each IKBO kernel accepts user and candidate inputs at their natural, mismatched batch sizes and handles broadcast internally, so no replicated tensors ever materialize. We showcase the methodology through two kernel deep dives: Linear Compression and Flash Attention.

Deployed across Meta’s RecSys inference stack, from early-stage to late-stage ranking models and spanning both GPU and MTIA (Meta Training and Inference Accelerator), IKBO delivers up to a 2/3 reduction in compute-intensive net latency on co-designed models. It serves as the scalability backbone for the request-centric, inference-efficient framework underlying the Meta Adaptive Ranking Model (serving LLM-scale models in production). On H100 SXM5, our IKBO Linear Compression kernel achieves a ~4× speedup through four progressive co-design stages: matmul decomposition, memory alignment, broadcast fusion, and warp-specialized multi-stage fusion via TLX (Triton Low-level Language Extensions). For Flash Attention, IKBO delivers 2.4×/6.4× the throughput of the non-co-designed CuTeDSL FA4-Hopper baseline (kernel only / kernel + broadcasting), reaching 621 BF16 TFLOPs. Unlike system-level broadcast or net-splitting, which work around replication, IKBO eliminates it at the computational primitive layer, achieving dense interaction quality at near-independent cost.

Code Repository: https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/ikbo

Work done while at Meta

1. In-Kernel Broadcast Optimization: Eliminating Memory and Compute Redundancy

When a user opens their feed, the recommendation system must score hundreds to thousands of candidate items to decide what to show. The model’s inputs split into two categories: user features (e.g., browsing history, profile, context) that are identical for every candidate in a request, and candidate features (e.g., item ID, category, engagement statistics) that are unique to each item. Both pass through embedding lookups and subsequent processing to produce embedding representations. At various points in the model, interaction layers (e.g., linear projections, feature crosses, target attention) combine user and candidate embeddings. We call embeddings shared across all candidates in a request Request-Only (RO), and per-candidate embeddings Non-Request-Only (NRO).

Fig. 1. A very simplified RecSys inference data flow. Request-Only (RO) user embeddings must be broadcast (replicated) to match the Non-Request-Only (NRO) candidate batch dimension before interaction layers. IKBO eliminates this materialization by handling broadcast internally within each kernel.

Interaction layers require tensors with matching batch dimensions. In a batch of 1,024 candidates spread across ~15 users, RO embeddings must be broadcast (replicated ~70 times on average) to match the NRO batch size before any interaction (Fig. 1). As architectures have evolved from DLRM [1] and DCN [2] through sequential models like HSTU [3] and X’s Phoenix [4], they have steadily enriched user-candidate interaction. But richer interaction comes at a cost: user features must be broadcast across all candidates. For inference batch sizes of 10 to 10,000+, this replication overhead incurs significant computation and memory cost that scales linearly with candidate count.
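To make the replication cost concrete, the sketch below builds a candidate-to-user index from per-user candidate counts and compares how many elements are stored when RO embeddings are replicated versus kept deduplicated behind that index. The per-user counts and the embedding size of 4,096 elements are hypothetical values chosen only to match the 1,024-candidate, 15-user setting above.

```python
def build_candidate_to_user(counts):
    """Expand per-user candidate counts into a candidate -> user index list."""
    mapping = []
    for user, n in enumerate(counts):
        mapping.extend([user] * n)
    return mapping


def materialized_elements(num_candidates, embedding_elems):
    """Elements stored when RO embeddings are replicated once per candidate."""
    return num_candidates * embedding_elems


def indexed_elements(num_users, num_candidates, embedding_elems):
    """Elements stored when RO embeddings stay deduplicated, plus one index each."""
    return num_users * embedding_elems + num_candidates


# 15 users and 1,024 candidates with a non-uniform split (hypothetical counts).
counts = [69] * 4 + [68] * 11
mapping = build_candidate_to_user(counts)

replicated = materialized_elements(len(mapping), 4096)
deduplicated = indexed_elements(len(counts), len(mapping), 4096)
print(len(mapping), replicated, deduplicated)
```

The index list handles non-uniform candidate counts per user, so no assumption of an exact 70:1 ratio is needed.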

Broadcast is a data layout concern, not a computational necessity. Viewing the model and inference system through this lens opens optimization at every layer: the inference runtime eliminates system-level broadcast, user-only model layers run at the smaller user batch size, and kernels that mix both are redesigned to handle broadcast internally, so no replicated tensors ever materialize. Deployed across Meta’s RecSys inference stack, from early-stage to late-stage ranking models, spanning both GPU and MTIA, IKBO delivers up to a 2/3 reduction in compute-intensive net latency on co-designed models.

This post focuses on the kernel layer through two deep dives: Linear Compression and Flash Attention.

1.1. Kernel Optimization Types

Type I — Decomposable Operations. Mathematical restructuring lets the Request-Only (RO) portion be computed independently at small batch size, combining with the Non-Request-Only (NRO) portion only at the end. This saves both memory bandwidth and compute.

Type II — Memory-Only Optimization. Handling RO-NRO broadcasting within the kernel avoids redundant data movement, moving the kernel away from being IO-bound.

1.2. E2E System Design

Deploying IKBO touches three layers of the infra stack:

  1. Kernels: Custom GPU kernels that accept mismatched RO/NRO batch sizes and handle broadcast internally (Sections 2 and 3).
  2. Compilation Specification: The ML compiler needs per-operator dynamic shape ranges to select appropriately shaped kernels. With one batch size this is trivial; with two (user and candidate) or more, reliably resolving which batch size each operator uses requires systematic automation, because interactions in production models obscure batch lineage.
  3. Inference: The runtime passes the candidate-to-user mapping into the model instead of materializing the broadcast.

These kernels enter the model through one of two paths:

  1. Direct adoption: Model authors integrate IKBO kernels directly into their model definitions. When the candidate-to-user ratio exceeds 1 during training, the same kernels reduce training cost as well.
  2. Inference-time transformation: A pass automatically swaps standard ops for IKBO equivalents at inference time — no model code changes required.

The net effect: broadcast disappears from every stage of inference, with no architectural constraints on the model and no infrastructure changes beyond the inference runtime’s mapping interface.

1.3. Comparison with Other Approaches

Existing approaches work around broadcast rather than eliminating it. 

  1. System-level broadcast materializes the replicated tensor before GPU dispatch—simple but wasteful, with cost scaling linearly with candidate count. 
  2. Net-splitting (ROO) [5] partitions the model into RO and NRO sub-networks, reducing redundant work but constraining where user-candidate interactions can occur and still introducing extra cost at small RO batch sizes.

Both preserve broadcast as a materialized tensor. IKBO eliminates it at the computational primitive layer: savings scale with the candidate-to-user ratio, any interaction pattern works without broadcast cost, and the full NRO batch dimension provides GPU occupancy within fused kernels.

IKBO has been deployed on both GPU and MTIA accelerators. In this blog post, we focus on H100 GPU kernel design to illustrate the core optimization principles.

2. Kernel Deep Dive I: IKBO Linear Compression

Linear Compress Embedding (LCE) compresses input embeddings (B, K, N) via a learned projection (M, K) @ (B, K, N) → (B, M, N), and is widely adopted in Meta RecSys models, e.g., Wukong [6]. We go through four progressive optimization stages.

2.1 Matmul Decomposition

Fig. 2. LCE decomposition: baseline batched matmul (top-left), embedding separation and user deduplication along K (top-right), two independent GEMMs with broadcast-add on compressed output (bottom).

The baseline LCE computes a single batched matmul across all B candidates. The input embeddings concatenate user and candidate parts along K — but user embeddings are identical across all candidates for the same user.

Push broadcast past the matmul. Since W is batch-independent, we decompose by linearity: separate user and candidate embedding blocks along K, deduplicate the repeated user embeddings, and compute two independent GEMMs at their natural batch sizes. Instead of replicating user embeddings before the matmul, we broadcast only the small compressed result. See Fig. 2. With a candidate-to-user ratio of ~70 (a representative setting), the user batch shrinks from B=1024 to B_user ≈ 15, a ~70× reduction in user-side compute. The decomposition is implemented in standard PyTorch.
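The linearity argument can be checked on a tiny example. The sketch below uses plain Python lists in place of tensors, so it is not the production PyTorch implementation; the function names and the small integer shapes (M=2, K_user=2, K_cand=1, N=2) are illustrative assumptions.

```python
def matmul(a, b):
    """Plain (m, k) @ (k, n) matrix product on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    """Element-wise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def lce_baseline(w, user_emb, cand_emb, cand_to_user):
    """Replicate user rows per candidate, concatenate along K, then matmul."""
    return [matmul(w, user_emb[u] + cand_emb[b]) for b, u in enumerate(cand_to_user)]


def lce_decomposed(w, user_emb, cand_emb, cand_to_user, k_user):
    """Two GEMMs at natural batch sizes; broadcast only the compressed result."""
    w_user = [row[:k_user] for row in w]
    w_cand = [row[k_user:] for row in w]
    user_out = [matmul(w_user, u) for u in user_emb]  # runs B_user times
    cand_out = [matmul(w_cand, c) for c in cand_emb]  # runs B times
    return [add(user_out[u], cand_out[b]) for b, u in enumerate(cand_to_user)]


w = [[1, 2, 3], [4, 5, 6]]                        # (M, K) with K = 2 + 1
user_emb = [[[1, 0], [0, 1]], [[2, 1], [1, 2]]]   # (B_user, K_user, N)
cand_emb = [[[1, 1]], [[2, 0]], [[0, 3]]]         # (B, K_cand, N)
cand_to_user = [0, 0, 1]

baseline = lce_baseline(w, user_emb, cand_emb, cand_to_user)
decomposed = lce_decomposed(w, user_emb, cand_emb, cand_to_user, k_user=2)
print(baseline == decomposed)
```

Integer inputs keep the comparison exact; with floating-point data the two paths would agree only up to rounding, because the summation order differs.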

Result. 1.944 ms → 1.389 ms (28.5% reduction; benchmark setup in Appendix 1). Both the original batched GEMM (arithmetic intensity ~356 FLOPs/Byte, below H100’s ~495 FLOPs/Byte machine balance point; see Appendix 2 for derivations) and the two decomposed GEMMs are memory-bound, so the speedup is driven by memory cost reduction. Deduplication cuts memory cost by more than half, as the user-side GEMM (B_user ≈ 15 vs. B = 1024) becomes negligible in cost.

Note that a broadcast still remains after decomposition: the small compressed user result is replicated before the add. Section 2.3 eliminates this remaining broadcast entirely via in-kernel broadcast fusion.

The current bottleneck is L1/TEX pipeline utilization (84%) rather than DRAM utilization — a suspicious imbalance we will zoom into in the next section. Detailed profiling breakdown in Appendix 3.

2.2 Memory Layout Optimization

Profiling the decomposed GEMM reveals an imbalance: L1/TEX sits at 84% of peak while DRAM reaches only 19%, indicating unnecessarily narrow memory loads. SASS confirms this: each cp.async (LDGSTS) copies only 4 bytes, so four instructions are issued where a single 128-bit copy would do.

LDGSTS.E.LTC128B P0, [R203],      [R38.64]       // 4 bytes
LDGSTS.E.LTC128B P1, [R203+0x4],  [R38.64+0x4]   // 4 bytes  (×4 total, only 16B load in total)

cp.async width is capped by the source pointer’s natural alignment. Matrix A is (M, K) row-major with stride K × 2 bytes, so when K is not a multiple of 8, the stride breaks 128-bit alignment.

Model-kernel co-design insights. Memory alignment is a well-understood GPU optimization, but decomposition turns it into a model-kernel co-design challenge. K is formed by torch.cat of embedding tensors whose sizes depend on many model config factors, and after decomposition it becomes very hard to manually engineer these factors so that each decomposed K remains a multiple of 8. A systematic solution is needed.

Solution. Pad each decomposed K to the next multiple of 8 by appending zeros to the concat list. We prove this is mathematically equivalent in both forward and backward passes (see Proof 1 below), and with the ML compiler’s memory planner, reduces to a cheap constant copy.
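The alignment arithmetic and the forward-pass equivalence can be sketched in a few lines. The example assumes BF16 (2 bytes per element), a 16-byte-aligned base pointer, and a hypothetical decomposed K of 13; the helper names are illustrative.

```python
BYTES_PER_ELEM = 2   # BF16
ALIGN_BYTES = 16     # one 128-bit access


def row_stride_is_aligned(k):
    """True if every row of a row-major (M, K) BF16 matrix starts 16-byte aligned."""
    return (k * BYTES_PER_ELEM) % ALIGN_BYTES == 0


def pad_to_multiple(k, multiple=8):
    """Round k up to the next multiple (8 BF16 elements = 16 bytes)."""
    return -(-k // multiple) * multiple


def project(w_row, x_col):
    """Dot product along K: one output element of the GEMM."""
    return sum(w * x for w, x in zip(w_row, x_col))


k = 13                              # hypothetical decomposed K
k_pad = pad_to_multiple(k)
x = list(range(1, k + 1))
x_padded = x + [0] * (k_pad - k)    # zeros appended to the concat list
w = [2] * k
w_padded = w + [7] * (k_pad - k)    # extra weights only ever multiply zeros

print(k_pad, project(w, x), project(w_padded, x_padded))
```

The padded weight columns can hold any value without changing the output, since they are multiplied by zeros. For the same reason their gradient is zero, which is the intuition behind the backward-pass half of the equivalence.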

Proof 1. Zero-padding K preserves exact numerical equivalence in both forward and backward passes.

Result. 1.389 ms → 0.798 ms (42.5% reduction). Padding enables CUTLASS to select a TMA-based kernel, bypassing L1/TEX entirely (sectors 351M → 0) and cutting GEMM latency from 0.984 ms to 0.400 ms. With the GEMM resolved, the unfused broadcast and add (0.398 ms) now accounts for half the total latency — to be addressed in the next section. Detailed result analysis in Appendix 5.

2.3 Candidate GEMM In-Kernel Broadcast Fusion

The unfused broadcast and add are memory-bound: write the candidate GEMM result to HBM, read it back alongside the user result, add, and write again. We eliminate this by fusing the broadcast into the candidate GEMM epilogue (Fig. 3). After each tile’s accumulation, the epilogue looks up the user index, loads the pre-computed user result, adds it in registers, and writes the final sum — the intermediate tensor is never materialized. We implement this as a Triton kernel: a standard batched GEMM with a custom post-accumulation epilogue block.
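A rough traffic model shows why the fused epilogue helps. This is a simplified sketch: it ignores caches, assumes each tensor is read or written exactly once per step, and uses hypothetical M and N values, so it does not reproduce the measured figures.

```python
def epilogue_traffic_bytes(b, b_user, m, n, bytes_per_elem=2, fused=False):
    """Approximate DRAM bytes moved after the candidate GEMM's accumulation."""
    cand = b * m * n * bytes_per_elem        # (B, M, N) candidate result or output
    user = b_user * m * n * bytes_per_elem   # (B_user, M, N) user result
    if fused:
        # The epilogue reads user results and writes only the final sum.
        return user + cand
    # Unfused: write the candidate result, read it back with the user
    # result, then write the final sum.
    return cand + (cand + user) + cand


# Hypothetical shapes for illustration.
B, B_USER, M, N = 1024, 15, 64, 256
unfused = epilogue_traffic_bytes(B, B_USER, M, N)
fused = epilogue_traffic_bytes(B, B_USER, M, N, fused=True)
print(unfused, fused, unfused - fused)
```

Under this model the saving is exactly one write plus one read of the (B, M, N) intermediate, independent of the user batch size.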

Fig. 3. In-kernel broadcast fusion: the GEMM epilogue loads the pre-computed user result via index lookup and adds it in-register.

Result. 0.798 ms → 0.580 ms (27.4% reduction). Fusion eliminates 0.87 GB of intermediate DRAM traffic, contributing to the latency win. However, occupancy is just 6.25% (1 warp per scheduler), leaving every stall fully exposed. Beyond 42% of cycles waiting on global loads, 20% are spent waiting on WGMMA — stalls that cannot be hidden by the epilogue, and without persistence there is no next-tile load to overlap with. This is a challenging tradeoff: large tiles and deep pipelines are needed to keep tensor cores fed, but they consume most of the shared memory budget, leaving little room to hide latency through occupancy. Detailed result analysis in Appendix 6.

2.4 Warp-Specialized Multi-Stage Fusion with TLX

TLX (Triton Low-level Language Extensions) exposes Hopper’s warp specialization, TMA, mbarriers, and named barriers while preserving Triton’s Python DSL and autotuning infrastructure.

Using TLX, we address the occupancy limitation from Section 2.3 with warp specialization — hiding latency through functional partitioning rather than additional warps.

Sections 2.1 – 2.3 decomposed the original LCE into two independent computations: the user GEMM (Stage 1) and the candidate GEMM with fused broadcast-add epilogue (Stage 2). We first optimize latency hiding within Stage 2, the dominant bottleneck, then fuse both stages into a single persistent kernel.

Intra-Stage Latency Overlap

The candidate IKBO kernel is memory-bound — the design goal is to keep the memory pipeline continuously fed. Triton’s software pipelining (Section 2.3) already overlaps Loads with WGMMA, but the epilogue remains serialized — it blocks future Loads and exposes the WGMMA wait stalls. We resolve both by partitioning each CTA into specialized warp groups: a dedicated producer issues TMA loads continuously (Overlap #1, analogous to Triton’s software pipeline), while two consumers ping-pong tiles so one’s epilogue overlaps the other’s WGMMA (Overlap #2). With persistence, tiles flow continuously with no cross-tile gaps. See Fig. 4.

Fig. 4. Candidate IKBO kernel structure with two intra-stage latency overlaps and warp group role assignments.

Multi-Stage Fusion

We fuse user IKBO (Stage 1) and candidate IKBO (Stage 2) into a single mega-kernel to reduce wave quantization, eliminate kernel launch overhead, and improve L2 cache utilization. High candidate-to-user ratios amplify wave quantization in Stage 1. Since the candidate GEMM is independent of user results until its epilogue, we schedule both stages concurrently.

This concurrent scheduling unlocks two additional cross-stage overlaps, bringing the total overlaps to four. See Fig. 5.

Fig. 5. Concurrent stage scheduling: SMs without user tiles enter Stage 2 immediately, overlapping with Stage 1’s partial wave. All four latency overlaps after multi-stage fusion, showing intra-stage (#1, #2) and cross-stage (#3, #4) overlap opportunities. SM 0-49, 50-131 are example numbers.

Warp Group Specialization & Synchronization Setup

To realize all four overlaps, each CTA is partitioned into one producer and two consumer warp groups. Critically, both stages share the same circular buffer and mbarrier infrastructure — no pipeline drain or barrier reinitialization occurs at the stage boundary. The last user K-block and the first candidate K-block coexist in different buffer slots simultaneously. See Fig. 6.

Fig. 6. Per-CTA warp group setup and the three synchronization mechanisms.

Bidirectional Stage-Alternating Tile Scheduling

When neither stage’s tile count divides evenly by the SM count, naive unidirectional dispatch causes workload imbalance. We reverse tile assignment direction between stages: Stage 1 starts at pid, Stage 2 at NUM_SM - 1 - pid. See Fig. 7.

Fig. 7. Unidirectional (left) vs. bidirectional stage-alternating dispatch (right), balancing per-SM workload across partial waves.
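The balancing effect can be reproduced with a small counting model. The sketch below assumes one persistent CTA per SM, a stride of NUM_SM between the tiles each CTA picks up, and equal cost per tile; the tile counts (50 user tiles, 214 candidate tiles) are hypothetical.

```python
def tiles_per_sm(num_tiles, num_sm, reverse=False):
    """Tiles each persistent CTA (one per SM) picks up with stride num_sm."""
    counts = []
    for pid in range(num_sm):
        start = num_sm - 1 - pid if reverse else pid
        counts.append(len(range(start, num_tiles, num_sm)))
    return counts


def total_load(stage1_tiles, stage2_tiles, num_sm, bidirectional):
    """Per-SM tile count summed over both stages."""
    s1 = tiles_per_sm(stage1_tiles, num_sm)
    s2 = tiles_per_sm(stage2_tiles, num_sm, reverse=bidirectional)
    return [a + b for a, b in zip(s1, s2)]


NUM_SM = 132  # SM count of an H100 SXM5
uni = total_load(50, 214, NUM_SM, bidirectional=False)
bi = total_load(50, 214, NUM_SM, bidirectional=True)

print(min(uni), max(uni))
print(min(bi), max(bi))
```

With unidirectional dispatch the low-numbered SMs absorb the partial wave of both stages. Reversing the direction for Stage 2 hands its partial wave to the SMs that had no Stage 1 tile.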

Tile-Granularity Cross-CTA Synchronization

User and candidate tiles may execute on different CTAs, requiring cross-CTA synchronization — but a device-wide barrier would serialize all work and destroy the overlap. We synchronize at per-tile granularity using a three-step release-acquire protocol: 

  1. A single thread per warp group spins on the tile flag with ld.relaxed, minimizing memory traffic
  2. Once set, a single ld.acquire establishes the happens-before edge
  3. A named barrier broadcasts readiness to all 128 threads in the warp group

This avoids expensive fences during polling and lets candidate CTAs on different user tiles proceed fully independently. Details in Appendix 7.

Results

With all optimizations combined, latency improves from 0.580 ms to 0.482 ms (16.9% reduction). The Proton profiler timeline (Fig. 8) confirms that all four overlaps are realized in practice.

Fig. 8. Proton profiler timeline for two CTAs, with all four overlaps color-coded. The memory pipeline remains continuously fed.

The primary gain comes from Overlap #2: ping-ponging consumers hide WGMMA and epilogue stalls on every tile — directly addressing the dominant wasted cycles from Section 2.3. Overlap #1 (Load↔WGMMA) carries forward from Triton’s existing software pipelining. Overlaps #3 and #4 hide idle time at the user-to-candidate stage transition. See Fig. 8.

NCU confirms: occupancy rises from 6.25% to 18.75% (3 warp groups vs. 1), DRAM throughput from 39% to 52%, and L2 — the bottleneck — from 74% to 84% of peak. This is not occupancy alone: the aggressive latency hiding across all four overlaps keeps the memory pipeline saturated, which is what pushes L2 past 80%. Detailed NCU metrics in Appendix 8.

We benchmark across batch sizes and candidate-to-user ratios, using batch=1024 and ratio=70 as the default settings. See Fig. 9.

Fig. 9. Cumulative IKBO speedup across batch sizes (left, ratio=70) and candidate-to-user ratios (right, batch=1024).

The IKBO fusion delivers robust gains across scenarios: ~4× speedup across batch sizes (left) and candidate-to-user ratios (right). Even at low candidate-to-user ratios, the kernel still achieves meaningful speedup.

3. Kernel Deep Dive II: IKBO Flash Attention

As recommendation models scale to capture richer user sequential behavior, sequential architectures – including attention – have emerged as a critical compute bottleneck, accounting for approximately 40% of inference latency at 1K sequence lengths. This motivates our focus on IKBO-aware Flash Attention, co-designed with RecSys’s unique batching semantics.

Inspired by Transformers and Set Transformers [7, 8], two fundamental user history interaction modules have been widely adopted in RecSys: 

  • Target attention (analogous to cross-attention) captures the relationship between the prediction candidate and the user’s historical interactions.
  • Self-attention models sequential dependencies within the user history itself.

Since user history is an RO feature while the target operates on a distinct candidate (NRO) batch dimension, this architectural asymmetry presents an opportunity for IKBO to improve model scalability and computational efficiency. Target attention is our main optimization focus; with minor co-design, self-attention can also be fused into IKBO target attention (Section 3.3). As our model is encoder-driven, full attention is applied without causal masking.

The final optimized target attention kernel, which leverages end-to-end co-design, achieves 2.4×/6.4× the throughput of non-co-designed CuTeDSL FA4-Hopper (attention kernel only / attention kernel + broadcasting cost), reducing latency by 0.320 ms / 1.232 ms respectively (Table 2).

3.1 IKBO flash attention solves the IO bound issues under RecSys boundary conditions

Fig. 10: Traditional SDPA with candidate-user broadcasting (left) vs. fused IKBO target attention (right). 

IKBO fuses K/V broadcasting into the attention kernel, maintaining mathematical equivalence via a candidate-user mapping tensor from the inference runtime that handles non-uniform candidate-to-user ratios. Fig. 10 contrasts the two approaches: the traditional SDPA path broadcasts K and V to the full candidate batch size before attention, while the IKBO path eliminates this materialization entirely — each candidate indexes into its user’s K/V on the fly.
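The equivalence of the two paths in Fig. 10 can be illustrated with a minimal pure-Python sketch. The names `attend` and `cand_to_user`, along with the toy shapes and values, are illustrative assumptions rather than the kernel's actual interface; a real kernel performs the same lookup per threadblock instead of per Python loop iteration.

```python
import math

def attend(q, keys, values):
    """Single-query softmax attention over one user's K/V (lists of vectors)."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / z
            for j in range(len(values[0]))]

# K/V are stored once per user (two toy users with different history lengths).
user_k = [[[1.0, 0.0], [0.0, 1.0]], [[0.5, 0.5], [1.0, 1.0], [0.0, 2.0]]]
user_v = [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 1.0], [2.0, 2.0], [4.0, 0.0]]]

# Hypothetical non-uniform mapping: 3 candidates for user 0, 1 for user 1.
cand_to_user = [0, 0, 0, 1]
cand_q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]

# Traditional path: materialize a K/V copy for every candidate, then attend.
k_bcast = [[row[:] for row in user_k[u]] for u in cand_to_user]
v_bcast = [[row[:] for row in user_v[u]] for u in cand_to_user]
out_broadcast = [attend(q, k, v) for q, k, v in zip(cand_q, k_bcast, v_bcast)]

# IKBO-style path: each candidate indexes its user's K/V on the fly, no copies.
out_indexed = [attend(q, user_k[u], user_v[u])
               for q, u in zip(cand_q, cand_to_user)]

kv_rows_broadcast = sum(len(k) for k in k_bcast)  # rows materialized per candidate
kv_rows_shared = sum(len(k) for k in user_k)      # rows stored once per user
```

Both paths produce identical outputs; the only difference is how many K/V rows must exist in memory, which grows with the candidate-to-user ratio in the broadcast path.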

Shifting IO-Bound to Compute-Bound by IKBO co-design

In RecSys boundary conditions, target attention uses relatively few candidate embeddings to represent candidate attributes compared with the length of the user’s browsing history. Roofline analysis of standard attention reveals an arithmetic intensity of ~60 FLOPs/Byte, well below the H100 (SXM5, HBM2e version) machine balance point of ~495 FLOPs/Byte (Appendix 2), so even standard flash attention is heavily IO-bound. IKBO addresses this by amortizing K/V memory accesses across the candidates that share the same user context, raising arithmetic intensity from ~60 FLOPs/Byte to ~833 FLOPs/Byte (at B_candidate : B_user = 70:1; Appendix 9) and shifting the kernel firmly into compute-bound territory.
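The two intensity figures can be reproduced with a short calculation. This is a sketch under simplifying assumptions that are ours, not stated in the post: each tensor crosses HBM exactly once, softmax FLOPs are ignored, and all tensors share a 2-byte dtype (so the head dimension cancels). The function name is hypothetical; the shape values come from Appendix 9.

```python
def attention_intensity(n_q, seq_len, cands_per_user=1, bytes_per_elem=2):
    """FLOPs per byte when one K/V read serves `cands_per_user` candidates.

    Quantities are per unit of head dimension, which cancels in the ratio.
    """
    flops = cands_per_user * 4 * n_q * seq_len       # QK^T and PV, 2 FLOPs per MAC
    elems = cands_per_user * 2 * n_q + 2 * seq_len   # Q and O per candidate, K and V once
    return flops / (bytes_per_elem * elems)

MACHINE_BALANCE = 495  # FLOPs/Byte, the H100 figure quoted above

standard = attention_intensity(n_q=64, seq_len=1024)                 # about 60
ikbo = attention_intensity(n_q=64, seq_len=1024, cands_per_user=70)  # about 833
```

With these inputs the standard path sits far below the balance point and the amortized path sits above it, matching the IO-bound to compute-bound shift described above.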

To maximize this benefit, our implementation reorders the threadblock launch grid so that batch_size_candidate comes before num_heads. This ensures threadblocks processing different candidates — but sharing the same user K/V — are scheduled concurrently, improving L2 cache reuse.

| Grid dimension | Flash attention (SDPA) | IKBO target attention |
| --- | --- | --- |
| x | num_q_seq_block | num_q_seq_block |
| y | num_heads | batch_size_candidate |
| z | batch_size_candidate | num_heads |

Table 1: Launch grid configuration comparison. SDPA prioritizes GQA optimization by placing num_heads in grid.y. IKBO swaps head and candidate dimensions, placing batch_size_candidate in grid.y to enable efficient K/V sharing across candidates.
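The effect of the swap can be sketched by counting how many distinct (user, head) K/V tiles the first few resident threadblocks touch. This is an illustration only: it assumes one query block per candidate and that blocks are dispatched in linearized grid order (x fastest, then y, then z), which the hardware does not guarantee. The function name and the toy sizes are hypothetical.

```python
def kv_tiles_in_flight(layout, num_heads, num_candidates, cands_per_user,
                       resident_blocks):
    """Count distinct (user, head) K/V tiles used by the first resident blocks."""
    if layout == "sdpa":    # grid.y = head, grid.z = candidate
        order = [(c, h) for c in range(num_candidates) for h in range(num_heads)]
    elif layout == "ikbo":  # grid.y = candidate, grid.z = head
        order = [(c, h) for h in range(num_heads) for c in range(num_candidates)]
    else:
        raise ValueError(f"unknown layout: {layout}")
    first = order[:resident_blocks]
    return len({(c // cands_per_user, h) for c, h in first})

# Toy configuration: 4 heads, 64 candidates, 8 candidates per user, 8 resident blocks.
sdpa_tiles = kv_tiles_in_flight("sdpa", 4, 64, 8, 8)
ikbo_tiles = kv_tiles_in_flight("ikbo", 4, 64, 8, 8)
```

Fewer distinct tiles among concurrently running blocks means more of them hit the same K/V data in L2, which is the reuse the reordered grid is designed to encourage.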

Table 2 compares our IKBO Triton implementation (FA2 logic + IKBO) against state-of-the-art Flash Attention implementations on Hopper (without IKBO co-design). Throughput and IO are measured on attention only; the broadcasting latency for Key and Value is even larger than the attention cost itself.

| Kernel | Throughput (TFLOPs/s) | IO (GB/s) | Latency (ms) |
| --- | --- | --- | --- |
| Triton IKBO FA2 | 425 | 487 | 0.321 (broadcast fused) |
| TLX FA3 | 245 | 2152 | 0.561 + 0.912 (broadcast K&V) |
| CuTeDSL FA4 Hopper | 250 | 2193 | 0.550 + 0.912 (broadcast K&V) |
| TLX IKBO FA3 persistence generalized | 594 | 681 | 0.230 (broadcast fused) |

Table 2: Attention kernel comparison under RecSys boundary conditions (B_candidate = 2048, B_user = 32, uniform candidate-to-user ratio). Without co-design, even cutting-edge Hopper implementations remain IO-bound.

3.2 Adopting Modern Kernel Techniques (FA3, FA4) with IKBO on TLX

With IKBO shifting the kernel from IO-bound to compute-bound, the natural next step was to adopt the state-of-the-art compute optimizations from Flash Attention 3 (FA3 [10]) and Flash Attention 4 (FA4 [11]) on Hopper – specifically warp specialization and pipelining. However, our boundary conditions on the number of query embeddings (q_seq = 32 or 64) make it difficult to directly adopt FA3’s ping-pong or cooperative warp specialization.

Warp specialization on Hopper requires asynchronous WGMMA instructions, which impose a minimum BLOCK_M ≥ 64. Two consumer warp groups are also necessary to minimize bubbles between them. To satisfy these constraints, we customized the kernel so that a single threadblock processes two candidates, B_candidate indices i and i + 1, that share the same B_user index. In the discussion below, we assume all users rank an even number of candidates with q_seq = 64; odd-candidate handling follows afterward.

Performance improvement for IKBO FA3 kernel

Starting from FA3’s recipe — intra-warp pipelining, warpgroup specialization, and ping-pong scheduling — the initial TLX IKBO FA3 kernel performed similarly to the FA2 baseline (Fig. 12, blue vs. red, Appendix 11), with on-par throughput.

To diagnose the bottleneck, we visualized intra-warp pipelining using the Proton tracer with GPU cycles as the latency unit (Fig. 11). Table 3 summarizes the key bottlenecks before and after persistence.

Fig. 11: Proton-based intra-warp profiling of the TLX IKBO FA3 kernel. Representative warps from each warp group are shown: warp 0 (producer), warp 4 (consumer 1), and warp 8 (consumer 2). The softmax_PV_overlap and pure softmax regions are marked separately to identify the tensor core bubbles. (A) Before persistence, zoomed-in view of (B). (B) Before persistence, with 2 waves. (C) After persistence, with 2 waves.

| Bottleneck | Before | After | Key change |
| --- | --- | --- | --- |
| Tensor Core Bubbles (1st QK^T per wave, Blue) | ~1,300 cycles (400 cycles from warp scheduler switching) | ~1,300 cycles | Unchanged |
| Tensor Core Bubbles (last PV per wave, Blue) | ~2,000 cycles | ~300 cycles | Async TMA store + reciprocal overlap with last PV |
| Cross-CTA Stalls (Orange) | ~14,000 cycles | Eliminated | Persistence removes CTA re-launch entirely |
| Init Buffers & Barriers (Green) | ~1,600 cycles/wave | ~1,600 cycles (1st wave only) | Persistent shared buffers and barriers amortized across waves |
| Wait 1st Q/K Load (Dark purple) | 2,100~4,000 cycles/wave (length varies depending on HBM bandwidth contention) | ~2,000 cycles (1st wave only) | Cross-wave pipelining; producer prefetches ~3K cycles ahead |

Table 3: Key bottlenecks before and after persistence + optimizations.

Key takeaway: at these small query sequence lengths, cross-CTA stalls, not tensor core utilization, are the dominant bottleneck, so persistence is essential. The profiling results after persistence and the resulting latency changes are shown in Fig. 11C and Table 3.

HBM2e-Specific Optimizations

We further tuned the persistent kernel for the H100 SXM5’s HBM2e bandwidth constraints, trading shared memory capacity for reduced load/store blocking (Table 4).

| Customized optimization/fix | Benefit |
| --- | --- |
| Decoupled SMEM buffer of O from Q/V with pipelined TMA async store | Decoupling O from the shared Q/V SMEM buffers lets TMA async stores overlap with next-wave compute, shortening store blocking time from 1,300 to 400 cycles/wave |
| Separate Q₀ and Q₁ buffers | Reduces per-Q loading time, allowing one consumer group to start earlier; beneficial when wave count greatly exceeds K/V sequence iterations (common in RecSys) |
| Instruction cache miss fix | Merges the peeled-out last-iteration code path back into the main loop, eliminating icache thrashing caused by excessive warp-specialized instructions (Appendix 12) |

Table 4: Customized optimizations for the HBM2e H100 SXM5. These still fit within the available SMEM budget under RecSys boundary conditions (Appendix 10).

We also implemented persistent V2, which iterates from the end of the K sequence to the front (matching FA3/FA4-Hopper’s approach) to simplify masking logic. Both persistent variants apply the Table 4 optimizations. As shown in Fig. 12, at low sequence lengths (512–4,096) the TLX FA3 persistent kernel outperforms all other candidates; beyond 8K the two persistent variants converge.

Fig. 12: IKBO implementation throughput vs. sequence length (B_candidate = 2,048; B_candidate : B_user = 64 : 1; num_head = 2; d_head = 128). Practical RecSys sequence lengths are under 4K [3]; longer lengths are included for comparison with LLM use cases. The generalized version handles odd candidate counts per user; here each user has a 50% probability of an odd count.

Generalizing IKBO FA3 to Rank Arbitrary Candidate Batch Sizes

Our IKBO FA3 kernel co-processes two candidate batches per CTA to meet WGMMA’s BLOCK_M ≥ 64 requirement. When a user has an odd number of candidates, one consumer warpgroup has no pairing partner. We handle this with idling logic (Fig. 13, left; Algorithm 1):

  • The idle warpgroup drains K/V buffers via mbarrier signaling to prevent producer deadlock.
  • The active warpgroup disables ping-pong synchronization (its partner no longer arrives at the named barriers).

At a ~70 : 1 candidate-to-user ratio, the idle path triggers less than 0.7% of the time with negligible overhead (Fig. 12, IKBO TLX FA3 generalized). This approach generalizes to q_seq_len = 32, where four candidate batches are bundled per CTA using analogous idling and masking logic.
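The pairing rule above can be sketched on the host side as a work-list builder. This is an illustrative assumption about how work items might be enumerated, not the kernel's actual scheduling code; `assign_ctas` and its tuple layout are hypothetical names.

```python
def assign_ctas(cands_per_user):
    """Build (user, cand_a, cand_b) work items, two candidates per CTA.

    cand_b is None when a user has an odd candidate count, which corresponds
    to the second consumer warpgroup idling and only draining K/V buffers.
    """
    work, start = [], 0
    for user, count in enumerate(cands_per_user):
        cands = list(range(start, start + count))
        for i in range(0, count, 2):
            partner = cands[i + 1] if i + 1 < count else None
            work.append((user, cands[i], partner))
        start += count
    return work

# User 0 ranks 3 candidates (odd), user 1 ranks 2 (even).
work_items = assign_ctas([3, 2])
idle_ctas = sum(1 for _, _, partner in work_items if partner is None)
```

Each user contributes at most one work item with an idle consumer, which is why the idle path becomes rare as the candidate-to-user ratio grows.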

Fig. 13: CTA assignment for generalized target attention (left) and self + target attention fusion (right). Each CTA assigns two consumer warp groups sharing the same user K/V. When the candidate count is odd, the 2nd consumer idles and drains barriers.

Algorithm 1: IKBO Attention Forward Pass with Odd Candidate Handling

3.3 Self + Target Attention Fusion via Model Co-Design

The previous sections focused on optimizing target (cross) attention. A natural question arises: can we fold self-attention into the same kernel?

The key insight is that both attention types share the same key-value source — the user sequence. The only difference is the query: self-attention queries come from the user side, while target-attention queries come from the candidate side. By sharing K/V projections between the two, we enable direct horizontal kernel fusion within a single launch. Fig. 13 (right) illustrates the fused CTA layout: the first CTAs handle self-attention query blocks, while the remaining CTAs handle target-attention candidate pairs — all reading from the same pipelined K/V stream.

Similar co-design ideas have been explored in xAI’s Phoenix, an open-source recommendation system from X [4].

We prototyped a fused kernel to quantify the fusion benefit, excluding K/V projection savings (Fig. 13, right):

  • seq_len = 512:    6.6% improvement (514 vs. 482 TFLOPs/s)
  • seq_len = 1,024:  4.1% improvement (581 vs. 558 TFLOPs/s)
  • seq_len = 2,048:  0.3% improvement (612 vs. 610 TFLOPs/s) — self-attention saturates the SMs

The gains at short sequences stem from kernel fusion benefits: reduced launch overhead, shared buffer allocation savings, cross-kernel pipelining opportunities, and wave quantization mitigation — the same inefficiencies that megakernel techniques [12] target in LLM inference. In production, the shared K/V projections provide additional savings on linear projection cost, analogous to KV cache reuse.

4. Summary of Benchmarks and Results

We summarize the kernel-level benchmarks presented in this post alongside end-to-end deployment outcomes. All kernel benchmarks below are on H100 SXM5 (see details in Appendix 1).

  • Linear Compression (Section 2). Four progressive co-design stages — matmul decomposition, memory alignment, broadcast fusion, and warp-specialized multi-stage fusion via TLX — yield a cumulative ~4× speedup (1.944 ms → 0.482 ms) at representative settings. Gains remain robust across batch sizes and candidate-to-user ratios (Fig. 9).
  • Flash Attention (Section 3). IKBO shifts target attention from IO-bound (~60 FLOPs/Byte) to compute-bound (~833 FLOPs/Byte), achieving 2.4×/6.4× the throughput of non-co-designed CuTeDSL FA4-Hopper (kernel only / kernel + broadcasting) and peaking at 621 BF16 TFLOPs/s (Appendix 11).
  • End-to-end deployment. IKBO has been deployed broadly across Meta’s RecSys inference stack — from early-stage to late-stage ranking models, on both GPU and MTIA accelerators — delivering up to 2/3 reduction in compute-intensive net latency on co-designed models. IKBO has been validated across candidate-to-user broadcast ratios spanning from ~10,000 : 1 down to ~10 : 1, confirming both numerical stability and scalability across workloads.

5. Conclusion and Future Directions

IKBO demonstrates that broadcast — long treated as an unavoidable cost of user-candidate interaction — can be eliminated at the computational primitive layer through kernel-model-system co-design. By encoding broadcast semantics directly into kernels, no replicated tensors ever materialize, and savings scale naturally with the candidate-to-user ratio.

While the kernel implementations presented in this work target NVIDIA Hopper via Triton and TLX, the core idea — replacing materialized broadcasts with index-driven in-kernel lookups — is hardware-vendor independent. Adapting the IKBO kernels to CuTeDSL (for advanced NVIDIA backend support) and completing the AMD CK support are natural next steps.

Beyond the two-level user-candidate hierarchy presented here, some RecSys scenarios involve deeper hierarchies — for example, user → ads vendor → ads item, where each user sees multiple vendors and each vendor offers multiple items. This introduces two nested broadcast relationships with independent, non-uniform ratios. IKBO can handle this elegantly, and applying it to multi-level workloads is a natural direction for further reducing materialization overhead in production RecSys architectures.

Acknowledgements

We are grateful to Hongtao Yu, Yuanwei (Kevin) Fang, Daohang Shi, Yueming Hao, Srivatsan Ramesh, and Manman Ren for their strong internal support of the Triton and TLX foundation, for the powerful Triton profiling tooling, and for promptly resolving Triton-related issues throughout this work.

Thanks to Chris Gottbrath for his insightful feedback, which significantly improved the clarity of this post. We also greatly appreciate his help in facilitating a smooth review process.

Thanks to Santanu Kolay, Sandeep Pandey, Matt Steiner, GP Musumeci, Ashwin Kumar, Ian Barber, Aparna Ramani, and CQ Tang for leadership support.

References

[1] Naumov, M., et al. “Deep Learning Recommendation Model for Personalization and Recommendation Systems,” arXiv:1906.00091, 2019.

[2] Wang, R., et al. “Deep & Cross Network for Ad Click Predictions,” ADKDD, 2017.

[3] Zhai, J., et al. “Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for Generative Recommendations,” ICML, 2024.

[4] xAI. “Phoenix: Recommendation System,” GitHub, 2026. https://github.com/xai-org/x-algorithm

[5] Guo, L., et al. “Request-Only Optimization for Recommendation Systems,” arXiv:2508.05640, 2025.

[6] Zhang, B., et al. “Wukong: Towards a Scaling Law for Large-Scale Recommendation,” ICML, 2024.

[7] Vaswani, A., et al. “Attention Is All You Need,” NeurIPS, 2017.

[8] Lee, J., et al. “Set Transformer: A Framework for Attention-based Permutation-Invariant Input,” ICML, 2019.

[9] Dao, T. “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning,” ICLR, 2024.

[10] Shah, J., et al. “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision,” NeurIPS, 2024.

[11] Zadouri, T., et al. “FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling,” arXiv:2603.05451, 2026.

[12] Spector, B., et al. “Look Ma, No Bubbles! Designing a Low-Latency Megakernel for Llama-1B,” Hazy Research Blog, 2025. https://hazyresearch.stanford.edu/blog/2025-05-27-no-bubbles

Appendix

Appendix 1. Benchmark Setup

All experiments are conducted on a single NVIDIA H100 SXM5 GPU (700 W TDP, 96 GB HBM2e) with the following software stack:

  • CUDA: 12.4
  • PyTorch: 2.11.0a0+fb (internal build)
  • Triton: facebookexperimental/triton@4059e79bf (#831)

Appendix 2. Arithmetic Intensity Analysis

2.1 Machine Balance Point of H100 SXM5 (700 W TDP, 96 GB HBM2e)
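The ~495 FLOPs/Byte balance point quoted in Section 3.1 is consistent with the back-of-the-envelope calculation below. It assumes the H100 SXM5 dense BF16 tensor-core peak of roughly 989 TFLOP/s from the public datasheet, together with the 2 TB/s peak bandwidth stated in Appendix 4; the original derivation may use slightly different constants.

```latex
\text{machine balance}
  = \frac{\text{peak compute}}{\text{peak memory bandwidth}}
  \approx \frac{989 \times 10^{12}\ \text{FLOP/s}}{2 \times 10^{12}\ \text{B/s}}
  \approx 495\ \text{FLOPs/Byte}
```

A kernel whose arithmetic intensity falls below this value is limited by memory bandwidth; above it, by compute.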

2.2 Arithmetic Intensity of the Baseline LCE

For a batched matmul (M, K) @ (B, K, N) → (B, M, N) in FP16, with B=1024, M=433, K=2044, N=256:
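A minimal sketch of that calculation follows. It assumes an idealized traffic model that is ours rather than the post's: the shared (M, K) operand, the batched (B, K, N) operand, and the (B, M, N) output each cross HBM exactly once at 2 bytes per element. The function name is hypothetical.

```python
def bmm_arithmetic_intensity(B, M, K, N, bytes_per_elem=2):
    """Ideal FLOPs, bytes, and FLOPs/Byte for (M, K) @ (B, K, N) -> (B, M, N)."""
    flops = 2 * B * M * K * N               # one multiply and one add per MAC
    elems = M * K + B * K * N + B * M * N   # shared operand, batched operand, output
    nbytes = elems * bytes_per_elem
    return flops, nbytes, flops / nbytes

flops, nbytes, intensity = bmm_arithmetic_intensity(B=1024, M=433, K=2044, N=256)
print(f"{flops / 1e9:.0f} GFLOPs, {nbytes / 1e9:.2f} GB, {intensity:.0f} FLOPs/Byte")
# prints "464 GFLOPs, 1.30 GB, 357 FLOPs/Byte"
```

Under this model the FLOP count and byte count land close to the 460 GFLOPs and 1.31 GB reported for the baseline in Appendix 3, and the resulting intensity is below the ~495 FLOPs/Byte balance point.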

Appendix 3. Detailed Result Analysis for Section 2.1

Setup: H100 SXM5 (Appendix 1), PyTorch eager mode (no kernel fusion), inference. Shapes from a representative configuration.

| Version | Total (ms) | Kernels | Latency (ms) | DRAM (GB) | L1/TEX Sectors (M) | Compute (GFLOPs)* | Bottleneck† |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Baseline | 1.944 | 1 CUTLASS GEMM | 1.944 | 1.31 | 798 | 460 | L1/TEX (89%) |
| Decomposition | 1.389 | 2 CUTLASS GEMM (user + candidate matmul) | 0.984 | 0.68 | 351 | 200 | L1/TEX (84%) |
| | | 1 ATen Gather + 1 ATen add | 0.405 | 0.87 | 36 | 0.11 | DRAM (92%) |

*Total FLOPs executed, not throughput.
†Bottleneck identified via NCU Speed of Light analysis; methodology in Appendix 4.

Deduplication eliminates >98% of user-side work (batch 1024 → ~15), cutting L1/TEX sectors from 798M to 351M and GEMM latency from 1.944 ms to 0.984 ms. The post-GEMM broadcast and addition costs 0.405 ms (DRAM-bound), yielding a net saving of 0.555 ms.

Precision note. The baseline accumulates all K products in a single FP32/TF32 reduction. Decomposition accumulates K_user and K_cand separately, then sums the partial results in BF16/FP16. Training uses the same decomposition, so numerics match end-to-end. For exact inference parity, a fused kernel (Section 2.4) can perform the final summation in FP32.
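The decomposition analyzed in this appendix can be checked on a toy example. The sketch below uses small integer matrices so the comparison is exact; the shapes, the helper names, and the `cand_to_user` mapping are illustrative assumptions, and the precision effects described in the note above do not appear with integers.

```python
def matmul(a, b):
    """Plain (rows x inner) @ (inner x cols) product on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add(a, b):
    """Elementwise sum of two equally shaped nested lists."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

# Toy shapes: M=2, K_user=2, K_cand=1, N=2.
W = [[1, 2, 3], [4, 5, 6]]                    # (M, K_user + K_cand)
W_user = [row[:2] for row in W]
W_cand = [row[2:] for row in W]
x_user = [[[1, 0], [0, 1]], [[2, 2], [1, 3]]]  # one (K_user, N) block per unique user
x_cand = [[[5, 6]], [[7, 8]], [[9, 1]]]        # one (K_cand, N) block per candidate
cand_to_user = [0, 0, 1]

# Baseline: broadcast user rows to every candidate, concatenate along K, one matmul each.
baseline = [matmul(W, x_user[u] + x_cand[c]) for c, u in enumerate(cand_to_user)]

# Decomposition: user matmul once per unique user, then gather and add the candidate part.
user_part = [matmul(W_user, xu) for xu in x_user]
decomposed = [add(user_part[u], matmul(W_cand, x_cand[c]))
              for c, u in enumerate(cand_to_user)]
```

The user-side product is computed twice here instead of three times; at a batch of 1024 with roughly 15 unique users, the same idea removes almost all user-side work.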

Appendix 4. Bottleneck Analysis Methodology

For a closer look after roofline analysis, we use NCU’s Speed of Light analysis to identify hardware subsystem bottlenecks. The bottleneck is the subsystem with the highest utilization relative to its peak sustained throughput. For the analysis in Section 2.1, we monitor three metrics:

Compute is the peak SM pipeline utilization, reported directly by NCU (Compute (SM) Throughput). It measures how busy the most active execution pipeline (tensor cores for GEMMs) is relative to its peak instruction rate.

L1/TEX utilization is derived from the total sectors the L1/TEX unit must process, as below, where num_L1_tex_sectors is the sum of the l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum and l1tex__t_sectors_pipe_lsu_mem_global_op_st.sum counters, SM_active_cycles is the sm__cycles_active.avg counter, num_SM is 132, and num_sustained_peak_sectors_per_sm_per_cycle is 2.0 on H100.

DRAM utilization is derived from total HBM bytes transferred, as below, where dram_bytes_read_and_write is the sum of the dram__bytes_read.sum and dram__bytes_write.sum counters, and peak_bandwidth is 2 TB/s on the test GPU server.
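The two formulas themselves are not reproduced in this text. The sketch below is a plausible reconstruction from the variable names given above, not the authors' confirmed definitions; in particular, using elapsed kernel time as the DRAM denominator is our assumption, and the cycle count in the example is hypothetical.

```python
def l1tex_utilization(num_l1_tex_sectors, sm_active_cycles,
                      num_sm=132, peak_sectors_per_sm_per_cycle=2.0):
    """Sectors processed divided by the sectors the L1/TEX units could sustain."""
    capacity = sm_active_cycles * num_sm * peak_sectors_per_sm_per_cycle
    return num_l1_tex_sectors / capacity

def dram_utilization(dram_bytes_read_and_write, elapsed_seconds,
                     peak_bandwidth=2e12):
    """Bytes moved divided by the bytes peak bandwidth could move in that time."""
    return dram_bytes_read_and_write / (elapsed_seconds * peak_bandwidth)

# 798M sectors over a hypothetical 3.4M active cycles per SM.
example_l1tex = l1tex_utilization(798e6, 3.4e6)
```

Whichever subsystem reports the highest ratio is treated as the bottleneck for that kernel.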

Appendix 5. Detailed Result Analysis for Section 2.2

Result. 1.389 ms → 0.798 ms (42.5% reduction).

| Version | Total Latency (ms) | Kernels | Latency (ms) | DRAM Traffic (GB) | Compute (GFLOPs, total executed, not throughput) | L1/TEX Sectors (M) | Bottleneck |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Decomposition (unpadded) | 1.386 | 2 CUTLASS GEMM (user & candidate matmul) | 0.984 | 0.68 | 200 | 351 | L1/TEX (84%) |
| | | 1 ATen Gather (broadcast) + 1 ATen Elementwise (add) | 0.402 | 0.87 | 0.11 | 36 | DRAM (92%) |
| Decomposition (padded K) | 0.798 | 2 CUTLASS GEMM (user & candidate matmul) | 0.400 | 0.69 | 200 | 0 | Balanced |
| | | 1 ATen Gather (broadcast) + 1 ATen Elementwise (add) | 0.398 | 0.87 | 0.11 | 36 | DRAM (92%) |

Two factors drive the large speedup.

  • TMA. With aligned matrices, CUTLASS selects a TMA-based kernel, bypassing L1/TEX entirely (sectors → 0). The unpadded kernel also penalized matrix B unnecessarily: it applied 4-byte loads to both matrices, even though B (with aligned N) could have used 128-bit loads.
  • Bank conflicts. The unpadded kernel also uses the sm80 MMA path, whose swizzle pattern doesn’t protect against 4-byte cp.async writes, causing many shared memory bank conflicts. The padded kernel doesn’t have this issue.
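The alignment step can be sketched as a small padding helper. The 16-byte target reflects the TMA requirement that global strides be multiples of 16 bytes; the exact padded sizes used in this post are not stated here, so the helper and the example value are illustrative assumptions.

```python
def pad_to_alignment(k, bytes_per_elem=2, align_bytes=16):
    """Smallest K' >= k whose row size in bytes is a multiple of align_bytes."""
    elems_per_unit = align_bytes // bytes_per_elem
    return -(-k // elems_per_unit) * elems_per_unit  # ceiling division

padded_k = pad_to_alignment(2044)      # FP16 rows of 2044 elements are 4088 bytes
extra_columns = padded_k - 2044        # zero-filled columns added by padding
```

The padded columns are zero-filled, so the matmul result is unchanged while the kernel selection and load width improve.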

Appendix 6. Detailed Result Analysis for Section 2.3

Result. Latency: 0.798 ms → 0.580 ms (27.4% reduction).

| Version | Total Latency (ms) | Kernels | Latency (ms) | DRAM Traffic (GB) |
| --- | --- | --- | --- | --- |
| Decomposition (padded K) | 0.798 | 2 CUTLASS GEMM (user & candidate matmul) | 0.400 | 0.68 |
| | | 1 ATen Gather (broadcast) + 1 ATen Elementwise (add) | 0.398 | 0.87 |
| IKBO Fusion | 0.580 | user GEMM & candidate IKBO kernel | 0.580 | 0.68 |

The 0.87 GB of intermediate DRAM traffic is eliminated as expected. NCU profiling reveals further opportunity: occupancy is just 6.25% with 1 warp per scheduler, and PC sampling shows only 23% of cycles are productive:

| Stall Reason | Percentage | What it mainly refers to in the kernel |
| --- | --- | --- |
| Stall long scoreboard | 41.8% | Global memory loads |
| Selected (executing) | 23.1% | Productive work (good): instructions actually issued |
| Stall wait | 20.1% | Waiting on WGMMA |
| Stall barrier | 5.7% | bar.sync between software-pipeline stages |

With 1 warp per scheduler, every stall is fully exposed: there is no other warp to switch to. Increasing occupancy by reducing pipeline depth would sacrifice K-loop latency hiding. This is a challenging situation for this kernel: large tiles and deep pipelines are needed to sustain tensor core throughput, but they consume most of the shared memory budget, leaving little room to hide latency through occupancy.

Appendix 7. Release-Acquire Synchronization Protocol

Producer (user CTA). After storing a user tile to global memory, the CTA sets a per-tile flag with release semantics, ensuring data visibility before the flag write:

tl.atomic_add(user_tile_flag_ptr, 1, sem="release", scope="gpu")

Consumer (candidate CTA). A single thread per warp group polls the flag with ld.relaxed to minimize memory traffic during the spin. Once the flag transitions, a single ld.acquire establishes the happens-before edge, and a named barrier broadcasts readiness to all 128 threads in the warp group:

if tlx.thread_id(axis=0) % 128 == 0:  # 1 thread per warp group (4 warps)
    # Initial relaxed load of the per-tile flag.
    ready = tl.inline_asm_elementwise(
        "ld.relaxed.gpu.global.b32 $0, [$1];", "=r,l",
        [user_tile_flag_ptr], dtype=tl.int32, is_pure=False, pack=1)
    # Spin with a short back-off until the producer has set the flag.
    while ready == 0:
        ready = tl.inline_asm_elementwise(
            "nanosleep.u32 50; ld.relaxed.gpu.global.b32 $0, [$1];", "=r,l",
            [user_tile_flag_ptr], dtype=tl.int32, is_pure=False, pack=1)
    # A single acquire load pairs with the producer's release.
    tl.inline_asm_elementwise(
        "ld.acquire.gpu.global.b32 $0, [$1];", "=r,l",
        [user_tile_flag_ptr], dtype=tl.int32, is_pure=False, pack=1)
# Named barrier: broadcasts readiness to all 128 threads of the warp group.
tlx.named_barrier_wait(12, 128)
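The control flow of this handshake can be mimicked on the CPU with plain Python threads. This is only an analogy under our own assumptions: CPython's interpreter lock already preserves the store order, so no explicit release or acquire is needed here, whereas the GPU version depends on them. All names in the sketch are hypothetical.

```python
import threading
import time

def run_handshake(num_consumers=3):
    tile = [None]   # stands in for the user tile in global memory
    flag = [0]      # stands in for the per-tile flag
    seen = []
    seen_lock = threading.Lock()

    def producer():
        time.sleep(0.01)            # consumers start polling before data exists
        tile[0] = [1.0, 2.0, 3.0]   # 1) store the tile
        flag[0] = 1                 # 2) only then publish the flag

    def consumer():
        while flag[0] == 0:         # cheap poll with a short back-off
            time.sleep(0.0005)
        data = tile[0]              # read the tile only after the flag is seen
        with seen_lock:
            seen.append(sum(data))

    threads = [threading.Thread(target=consumer) for _ in range(num_consumers)]
    threads.append(threading.Thread(target=producer))
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return seen

results = run_handshake()
```

Every consumer observes the complete tile because the flag is written strictly after the data, which is the ordering the release and acquire pair enforces on the GPU.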

Appendix 8. NCU Profiling Metrics for TLX vs. Triton

| Metric | Triton | TLX | Notes |
| --- | --- | --- | --- |
| Theoretical Occupancy | 6.25% | 18.75% | 3 warp groups per CTA vs. 1 |
| DRAM Throughput (dram__cycles_active.avg.pct_of_peak_sustained_elapsed) | 38.51% | 52.39% | Higher utilization from continuous TMA loads |
| L2 Cache Throughput (lts__throughput.avg.pct_of_peak_sustained_elapsed) | 73.69% | 83.86% | Bottleneck. TLX pushes closer to peak |
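The two occupancy figures follow from simple warp counting. The sketch assumes one resident CTA per SM (plausible given the large shared memory footprint, but our assumption) and Hopper's limit of 64 resident warps per SM; the function name is hypothetical.

```python
def theoretical_occupancy(warp_groups_per_cta, ctas_per_sm=1,
                          warps_per_group=4, max_warps_per_sm=64):
    """Resident warps divided by the maximum resident warps per SM."""
    resident_warps = warp_groups_per_cta * warps_per_group * ctas_per_sm
    return resident_warps / max_warps_per_sm

triton_occupancy = theoretical_occupancy(1)  # one warp group: 4 of 64 warps
tlx_occupancy = theoretical_occupancy(3)     # producer + two consumers: 12 of 64 warps
```

With four warp schedulers per SM, the single-warp-group case leaves exactly one warp per scheduler, which is why every stall was fully exposed before warp specialization.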

Appendix 9. Roofline analysis of normal flash attention vs IKBO flash attention

Arithmetic intensity (AI) is calculated assuming FP16/BF16 precision, user_seq_len = 1024, n_seed = 64, and B_candidate (B in the equation) : B_user (B/num_cand_user in the equation) = 70 : 1.
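The equation itself is not reproduced in this text. The following is a reconstruction that matches the ~60 and ~833 FLOPs/Byte figures quoted in Section 3.1; it assumes Q, K, V, and O each cross HBM once at 2 bytes per element and ignores softmax FLOPs. Here L is user_seq_len, n is n_seed, d is the head dimension, and r is num_cand_user.

```latex
\mathrm{AI}_{\text{standard}}
  = \frac{4\,B\,n\,L\,d}{2\,(2\,B\,n\,d + 2\,B\,L\,d)}
  = \frac{n L}{n + L}
  = \frac{64 \cdot 1024}{64 + 1024} \approx 60

\mathrm{AI}_{\text{IKBO}}
  = \frac{4\,B\,n\,L\,d}{2\,\bigl(2\,B\,n\,d + 2\,\tfrac{B}{r}\,L\,d\bigr)}
  = \frac{r\,n\,L}{r\,n + L}
  = \frac{70 \cdot 64 \cdot 1024}{70 \cdot 64 + 1024} \approx 833
```

The only difference between the two expressions is the K/V term in the denominator, which shrinks by the factor r when K/V is read once per user rather than once per candidate.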

Appendix 10. SMEM consumption of IKBO TLX FA3 

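| SMEM buffer | Count | Block dim | Total size |
| --- | --- | --- | --- |
| Query | 2 (1 for each consumer group) | 64 × 128 (2 bytes per element) | 32 KB |
| Key | 2 | 128 × 128 (2 bytes per element) | 64 KB |
| Value | 2 | 128 × 128 (2 bytes per element) | 64 KB |
| Output | 2 (1 for each consumer group) | 64 × 128 (2 bytes per element) | 32 KB |
| Total | | | 192 KB |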
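The totals in the table can be recomputed directly from the buffer counts and tile shapes. The 227 KB figure used for headroom is the maximum dynamic shared memory per thread block on Hopper, which is our addition for context and not a number taken from this post.

```python
KB = 1024
BYTES_PER_ELEM = 2  # FP16/BF16

# name: (buffer count, rows, cols), taken from the table above
buffers = {
    "query": (2, 64, 128),
    "key": (2, 128, 128),
    "value": (2, 128, 128),
    "output": (2, 64, 128),
}

sizes_kb = {name: count * rows * cols * BYTES_PER_ELEM // KB
            for name, (count, rows, cols) in buffers.items()}
total_kb = sum(sizes_kb.values())

MAX_SMEM_PER_BLOCK_KB = 227  # Hopper per-block limit (assumption stated above)
headroom_kb = MAX_SMEM_PER_BLOCK_KB - total_kb
```

The remaining headroom has to cover barriers and any other scratch space, which is why the layout leaves no room for a second resident CTA per SM.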

Appendix 11. Benchmarking IKBO FA vs CuTeDSL FA4 Hopper and TLX FA3 Hopper kernel under RecSys boundary condition

The IKBO kernel essentially adds user-candidate mapping logic on top of attention, which gives it an IO and computation pattern similar to GQA. For benchmarking, we apply a fixed B_candidate : B_user = 64 : 1 to the IKBO kernels and a matching compute pattern to the CuTeDSL FA4 Hopper GQA version (Q_seq_len = 128, so that both consumer warpgroups are fully used). Note that the IKBO kernel must additionally read the candidate-user mapping tensor to handle a varying number of candidates per user at runtime.

| Kernel type | Throughput (TFLOPs/s) | IO (GB/s) |
| --- | --- | --- |
| Triton IKBO FA2 | 425 | 519 |
| TLX IKBO FA3 | 418 | 510 |
| TLX IKBO FA3 persistent | 592 | 723 |
| TLX IKBO FA3 persistent V2 (reverse k,v order) | 537 | 655 |
| CuTeDSL FA4 Hopper GQA | 518 | 633 |
| TLX FA3 GQA | 576 | 703 |

IKBO FA benchmarked against open-source GQA kernels. Shapes are given as [batch size, num heads, seq len, d_head]: Q_ikbo [2048, 2, 64, 128], K/V_ikbo [32, 2, 1024, 128]; Q_gqa [1024, 2, 128, 128], K/V_gqa [32, 2, 1024, 128].

| Kernel type | Throughput (TFLOPs/s) | IO (GB/s) |
| --- | --- | --- |
| Triton IKBO FA2 | 449 | 329 |
| TLX IKBO FA3 | 470 | 345 |
| TLX IKBO FA3 persistent | 621 | 455 |
| TLX IKBO FA3 persistent V2 (reverse k,v order) | 587 | 430 |
| CuTeDSL FA4 Hopper GQA | 608 | 445 |
| TLX FA3 GQA | 628 | 460 |

IKBO FA benchmarked against open-source GQA kernels. Q_ikbo [2048, 2, 64, 128], K/V_ikbo [32, 2, 2048, 128]; Q_gqa [1024, 2, 128, 128], K/V_gqa [32, 2, 2048, 128].

Note: Since standard Flash Attention kernels do not incorporate IKBO logic, we use a GQA configuration with similar IO cost and FLOPs consumption to simulate throughput results for the CuTeDSL versions.
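That the two configurations do the same amount of arithmetic can be checked with a short FLOP count. The sketch assumes the usual 4 FLOPs per query-key-dimension triple (QK^T plus PV) and ignores softmax; the helper names are hypothetical, and the shapes are those of the seq_len = 1024 benchmark above.

```python
def attention_flops(batch, heads, q_len, kv_len, d_head):
    """FLOPs for QK^T plus PV over the whole batch (softmax ignored)."""
    return 4 * batch * heads * q_len * kv_len * d_head

def latency_ms(flops, tflops_per_s):
    """Kernel time implied by a FLOP count and a sustained throughput."""
    return flops / (tflops_per_s * 1e12) * 1e3

ikbo_flops = attention_flops(batch=2048, heads=2, q_len=64, kv_len=1024, d_head=128)
gqa_flops = attention_flops(batch=1024, heads=2, q_len=128, kv_len=1024, d_head=128)

implied_ms = latency_ms(ikbo_flops, 594)  # 594 TFLOPs/s is the Table 2 figure
```

Halving the batch while doubling the query length leaves the FLOP count unchanged, and the latency implied by the Table 2 throughput lands close to the 0.230 ms reported there.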

Appendix 12. Instruction cache misses cause significant delay on the consumer-2 warpgroup

Fig. A1 

Instruction cache miss result before and after the fix

Before instruction cache miss fix:
    ---------------------------------------------------- ----------- ------------
    Metric Name                                          Metric Unit Metric Value
    ---------------------------------------------------- ----------- ------------
    gcc__cache_requests_type_instruction.sum                              319,394
    gcc__cache_requests_type_instruction_lookup_miss.sum                    7,234
    sm__icc_requests.sum                                       cycle    6,049,376
    sm__icc_requests_lookup_hit.sum                            cycle    5,438,421
    sm__icc_requests_lookup_miss.sum                           cycle      610,955
    ---------------------------------------------------- ----------- ------------

After instruction cache miss fix:
    ---------------------------------------------------- ----------- ------------
    Metric Name                                          Metric Unit Metric Value
    ---------------------------------------------------- ----------- ------------
    gcc__cache_requests_type_instruction.sum                               33,008
    gcc__cache_requests_type_instruction_lookup_miss.sum                      769
    sm__icc_requests.sum                                       cycle      792,437
    sm__icc_requests_lookup_hit.sum                            cycle      722,244
    sm__icc_requests_lookup_miss.sum                           cycle       70,193
    ---------------------------------------------------- ----------- ------------
]]>
SMG: The Case for Disaggregating CPU from GPU in LLM Serving https://pytorch.org/blog/lightseek-smg/ Thu, 30 Apr 2026 18:56:36 +0000 https://pytorch.org/?p=76091

How It Started: Hitting the GIL Wall at Scale

We’ve been running production model serving for many years. When we first started building Shepherd Model Gateway, the goal was modest: figure out if cache-aware load balancing could improve routing across inference replicas.

It could. And as we went deeper, we found a much bigger problem.

In both SGLang and vLLM, tokenization and detokenization had become bottlenecks. Not in theory — in production, under real traffic. The root cause was architectural: although both engines use Rust or C++ tokenizer libraries underneath, the calls go through Python. That means the GIL. That means a single-threaded ceiling on CPU-bound work that sits directly in the serving path.

At small scale, this doesn’t matter. At large-scale prefill-decode disaggregated serving, and at large-scale expert parallelism across GPU clusters, it matters enormously. These configurations make GPUs extremely fast, fast enough that the CPU side of the pipeline becomes the constraint. Every microsecond of GIL-bound tokenization is a microsecond where GPUs worth hundreds of thousands of dollars sit idle, waiting for input.

That’s where the journey really started. Not with a gateway vision — with a production problem. Could we disaggregate the entire CPU workload from the GPU path and run it in Rust? Not Python-calling-Rust. Pure Rust. No GIL. No single-threaded ceiling. No Python process boundaries.

The answer was yes, and the project that proved it became Shepherd Model Gateway.

SMG Architecture: Clients → Gateway → Router → Workers

SMG’s architecture is built on one principle: GPUs should do tensor math. Everything else belongs in a dedicated serving layer.

We looked at the model-serving stack and identified every CPU-bound workload entangled with GPU inference: tokenization, detokenization, reasoning output parsing, function call extraction, MCP tool orchestration, multimodal preprocessing, chat history management, structured output validation, stop sequence detection. Each one is a CPU task that, when co-located with the GPU process behind the Python GIL, creates back-pressure on the most expensive hardware in the rack.

SMG moves all of these into a Rust gateway layer that communicates with inference engines over gRPC. The protocol is minimal and GPU-focused: send preprocessed tokens in, stream generated tokens out. Everything else is the gateway’s responsibility.

This isn’t the approach most projects in the distributed inference space have taken. Excellent work is happening with projects like NVIDIA Dynamo and llm-d, which focus on optimizing the inference engine layer and the orchestration around it. We see that work as complementary. But SMG’s bet is different: rather than making the engine smarter, make the gateway smarter. Offload everything that doesn’t require a GPU onto a purpose-built Rust layer that scales independently, evolves independently, and runs with zero GIL contention.

The gRPC Re-Architecture: Making It Real

The gRPC Pipeline: Gateway-side processing before engine handoff

The single largest technical investment in SMG’s history was rebuilding the entire serving pipeline around a native Rust gRPC data plane. This was the architectural proof of the disaggregation thesis.

Tokenization and detokenization move into the gateway. SMG runs tokenizers natively in Rust with a two-level cache — L0 exact-match for repeated prompts, L1 prefix-aware at special-token boundaries. The inference engine receives pre-tokenized input and never touches a tokenizer. No Python. No GIL.
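The two-level idea can be sketched in a few lines of Python. This is a toy illustration, not SMG's Rust implementation: the class, the `<|im_end|>` boundary token, and the byte-level stand-in tokenizer are all assumptions, and splitting at a boundary is only safe for tokenizers that never merge tokens across a special token.

```python
class TwoLevelTokenizerCache:
    """Toy L0 (exact-match) plus L1 (prefix at special-token boundary) cache."""

    def __init__(self, tokenize, boundary="<|im_end|>"):
        self.tokenize = tokenize
        self.boundary = boundary
        self.l0 = {}              # full prompt -> tokens
        self.l1 = {}              # prefix ending at a boundary -> tokens
        self.chars_tokenized = 0  # work actually sent to the tokenizer

    def _run(self, text):
        self.chars_tokenized += len(text)
        return self.tokenize(text)

    def _boundary_cuts(self, prompt):
        cuts, pos = [], prompt.find(self.boundary)
        while pos != -1:
            cuts.append(pos + len(self.boundary))
            pos = prompt.find(self.boundary, pos + len(self.boundary))
        return cuts

    def encode(self, prompt):
        if prompt in self.l0:                 # L0: exact repeat
            return self.l0[prompt]
        cuts = self._boundary_cuts(prompt)
        start, tokens = 0, []
        for cut in reversed(cuts):            # L1: longest cached prefix wins
            if prompt[:cut] in self.l1:
                start, tokens = cut, list(self.l1[prompt[:cut]])
                break
        for cut in cuts:                      # tokenize and cache new boundary prefixes
            if cut > start:
                tokens = tokens + self._run(prompt[start:cut])
                start = cut
                self.l1[prompt[:cut]] = list(tokens)
        if start < len(prompt):               # uncached tail after the last boundary
            tokens = tokens + self._run(prompt[start:])
        self.l0[prompt] = tokens
        return tokens

cache = TwoLevelTokenizerCache(lambda s: list(s.encode("utf-8")))
system = "You are helpful.<|im_end|>"
first = cache.encode(system + "Hi")    # miss: the whole prompt is tokenized
second = cache.encode(system + "Bye")  # L1 hit: only "Bye" is tokenized
third = cache.encode(system + "Hi")    # L0 hit: nothing is tokenized
```

In this toy run the shared system prompt is tokenized once, which is the saving a prefix-aware level targets for chat traffic with long common prefixes.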

Reasoning and tool call parsing runs in the gateway’s streaming pipeline. As tokens arrive over gRPC, SMG’s parsers — including Cohere Command, DeepSeek, Llama, Nemotron, Kimi-K2, GLM-4, and Qwen Coder — extract reasoning blocks, function calls, and structured output in real-time. No post-processing step on the engine side.

Multimodal processing was the most ambitious piece. We rewrote major components of Hugging Face’s transformers image processor from Python to Rust, reimplementing vision preprocessing pipelines, tensor operations, and model-specific transformations in a completely different language and runtime. The result: SMG communicates preprocessed tensors directly to engines via gRPC with zero Python overhead. It supports Llama 4 Vision, Qwen VL, and all major vision-language models, with backend-specific optimizations for SGLang, vLLM, and TensorRT-LLM. This is, to our knowledge, an industry first.

MCP tool orchestration runs entirely in the gateway with auth-aware connection pooling, concurrent batch execution, approval workflows, automatic reconnection, and HTTP header forwarding. The inference engine has no knowledge of MCP. We also built a complete built-in tool routing infrastructure — turning any MCP server into native capabilities (FileSearch, WebSearch, CodeInterpreter) for any model. Deploy Llama or Qwen with the same built-in tools as GPT-4.

Chat history management uses pluggable storage (PostgreSQL, OracleDB, Redis, and in-memory), schema versioning via Flyway, customizable table/column names, and storage hooks for pre/post persistence callbacks. All of this lives in the gateway, keeping the engine stateless.

WASM middleware provides programmable extensibility without forking the codebase. Custom authentication, compliance logging, PII redaction, cost tracking, compression — all via WebAssembly plugins with sandboxed isolation. Another industry first.

The gRPC protocol itself — published as smg-grpc-proto on PyPI — defines the narrow contract between gateway and engine. This design means you can upgrade your gateway (new parsers, new protocols, new tools) without touching your inference engine, and upgrade your engine (new GPU kernels, new quantization) without touching your gateway. They evolve independently because the interface is clean.

What SMG Delivers Today

SMG was created by Simo Lin and Chang Su, members of the LightSeek Foundation. In roughly six months, we shipped thirteen releases. Rather than walk through each one, here is what the project delivers today — and the evidence behind each capability.

Multi-Model Inference Gateway

A single SMG process fronts your entire fleet — multiple models, multiple engines, one entry point. Route requests across SGLang, vLLM, TensorRT-LLM, and MLX backends simultaneously. Add OpenAI, Anthropic, Google Gemini, AWS Bedrock, and Azure OpenAI as external providers. One gateway, every engine, every vendor.

Five Native Agentic APIs

SMG natively supports Chat Completions (OpenAI), Responses API (OpenAI), Messages API (Anthropic), Interactions API (Gemini), and Realtime API (WebSocket/WebRTC). These are not translation layers — each is a first-class implementation. The Messages API preserves thinking blocks end-to-end with ThinkingConfig, thinking_delta streaming events, and interleaved reasoning + text + tool use content blocks. The Responses API brings OpenAI’s conversation management to Llama, DeepSeek, Qwen, and every open-source model — SMG remains the only open-source gateway supporting it. Run agentic workflows designed for Claude on Llama 4, Qwen 3, DeepSeek, or Kimi with full protocol fidelity.

Native Rust gRPC Data Plane

Two-Level Tokenizer Cache: L0 exact-match, L1 prefix-aware

The architectural core: a native Rust gRPC pipeline between gateway and engine. The contract is minimal — preprocessed tokens in, generated tokens out. Everything else is the gateway’s responsibility. Tokenization runs in Rust with a two-level cache (L0 exact-match, L1 prefix-aware). Reasoning and tool call parsing runs in the streaming pipeline as tokens arrive — supporting fifteen model families including DeepSeek-R1, Qwen3, GLM-4, Kimi, Llama-4, Cohere Command, and more. No Python. No GIL. The gRPC protocol is published as smg-grpc-proto on PyPI, and both vLLM (PR #36169) and NVIDIA TensorRT-LLM (five merged PRs) have adopted it upstream.
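The two-level cache idea can be sketched in a few lines of Python. This is an illustrative sketch only, not SMG's Rust implementation: `toy_tokenize`, `TwoLevelTokenizerCache`, and `register_prefix` are hypothetical names, and the whitespace tokenizer is a stand-in chosen so that a prefix ending on a word boundary can be tokenized independently of the text that follows it. Real subword tokenizers need more care about where a prefix may safely be cut.

```python
from collections import OrderedDict


def toy_tokenize(text):
    """Stand-in tokenizer (assumption): one token per whitespace-separated word."""
    return text.split()


class TwoLevelTokenizerCache:
    """L0 is an exact-match LRU on the full prompt; L1 reuses tokens of a cached prefix."""

    def __init__(self, tokenize, capacity=1024):
        self.tokenize = tokenize
        self.capacity = capacity
        self.l0 = OrderedDict()  # full text -> tokens, kept in LRU order
        self.l1 = {}             # prefix text -> tokens
        self.stats = {"l0_hit": 0, "l1_hit": 0, "miss": 0}

    def register_prefix(self, prefix):
        """Cache a shared prefix such as a system prompt.

        The prefix must end on a token boundary (here: trailing whitespace).
        """
        self.l1[prefix] = self.tokenize(prefix)

    def encode(self, text):
        # L0: the exact prompt was seen before.
        if text in self.l0:
            self.l0.move_to_end(text)
            self.stats["l0_hit"] += 1
            return self.l0[text]

        # L1: tokenize only the part after the longest cached prefix.
        best = max((p for p in self.l1 if text.startswith(p)), key=len, default=None)
        if best is not None:
            tokens = self.l1[best] + self.tokenize(text[len(best):])
            self.stats["l1_hit"] += 1
        else:
            tokens = self.tokenize(text)
            self.stats["miss"] += 1

        # Populate L0 and evict the least recently used entry if over capacity.
        self.l0[text] = tokens
        if len(self.l0) > self.capacity:
            self.l0.popitem(last=False)
        return tokens
```

With a shared system prompt registered as a prefix, the first request pays only for tokenizing the user turn, and an identical repeat request is served from L0 without tokenizing at all.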

Intelligent Routing

Cache-Aware Routing Flow

Eight load-balancing policies: cache-aware, round robin, random, power-of-two, consistent hashing, prefix hash, manual (sticky sessions), and bucket-based. Cache-aware routing was rewritten from the ground up — 10–12x faster (216,000 insertions/sec), 99% memory reduction (180 KB → 1.4 KB per node, 10,000 cached prefixes: 1.8 GB → 14 MB). Event-driven KV cache routing streams real-time cache state from all backends via SubscribeKvEvents RPC, with auto-learned block sizes. Production results on 8 Llama replicas: TTFT average down 23%, TTFT p99 down 28%. Prefill-decode disaggregation routes prefill and decode phases to separate worker pools with independent policies — 20–30% TTFT improvement in PD setups.
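To make the cache-aware policy concrete, here is a minimal sketch of the routing decision, assuming the router remembers which token sequences each worker has served. The class and method names are hypothetical, and the linear scan over cached sequences is a simplification for readability; it is not the data structure SMG uses, and it ignores the event-driven cache state described above.

```python
def shared_prefix_len(a, b):
    """Number of leading tokens that two sequences have in common."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


class CacheAwareRouter:
    """Prefer the worker with the longest cached prefix; otherwise pick the least loaded."""

    def __init__(self, workers, min_match=1):
        self.cached = {w: [] for w in workers}  # worker -> token sequences it has served
        self.load = {w: 0 for w in workers}     # worker -> in-flight request count
        self.min_match = min_match              # hypothetical threshold for a useful match

    def route(self, tokens):
        best_worker, best_len = None, 0
        for worker, seqs in self.cached.items():
            match = max((shared_prefix_len(tokens, s) for s in seqs), default=0)
            if match > best_len:
                best_worker, best_len = worker, match

        # No useful cache hit anywhere: fall back to the least-loaded worker.
        if best_len < self.min_match:
            best_worker = min(self.load, key=self.load.get)

        self.cached[best_worker].append(list(tokens))
        self.load[best_worker] += 1
        return best_worker

    def finish(self, worker):
        """Call when a request completes so load tracking stays accurate."""
        self.load[worker] -= 1
```

Requests that share a long prefix (for example the same system prompt) land on the same worker and can reuse its KV cache, while unrelated requests spread out by load.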

Multimodal Processing in Rust

We rewrote major components of Hugging Face’s image processors from Python to Rust — vision preprocessing pipelines, tensor operations, and model-specific transformations in a completely different language and runtime. Eight vision model families supported: Kimi K2.5, Llama-4 Vision, LLaVA, Phi-3/Phi-4 Vision, Pixtral, Qwen-VL, Qwen2-VL, and Qwen3-VL. Preprocessed tensors flow directly to engines via gRPC with zero Python overhead. To our knowledge, an industry first.

MCP Tool Orchestration & Built-in Tools

MCP Architecture: Tool orchestration in the gateway

MCP runs entirely in the gateway with auth-aware connection pooling, concurrent batch execution, approval workflows, automatic reconnection, and four transports (STDIO, HTTP, SSE, Streamable). Universal MCP Built-in Tools turn any MCP server into native capabilities — FileSearch, WebSearch, CodeInterpreter — for any model. Deploy Llama or Qwen with the same built-in tools as GPT-4. Per-tenant isolation, policy-based trust levels, and execution metrics come standard.

WASM Middleware

WASM Plugin Pipeline

Programmable extensibility via WebAssembly plugins with sandboxed isolation — another industry first. Custom authentication, compliance logging, PII redaction, cost tracking, compression — all without forking the codebase. Built on Wasmtime with Component Model and async support. Storage hooks intercept chat history operations for custom pre/post processing.

Enterprise Security & Observability

TLS/mTLS Architecture

JWT/OIDC authentication with JWKS discovery, role-based access control, API key auth, and multi-tenant rate limiting. TLS and mTLS for both client-facing and inter-node communication. A multi-layer metrics system with 40+ Prometheus metrics covering HTTP, router, worker, inference, discovery, MCP, database, and mesh layers. Full OpenTelemetry distributed tracing. Structured JSON logging with request correlation.

Reliability & High Availability

Circuit Breaker State Machine

Per-worker circuit breakers (closed/open/half-open), automatic retries with exponential backoff and jitter, periodic health checks, concurrent request rate limiting, request timeouts, and configurable graceful shutdown. SWIM-protocol gossip mesh with CRDT-based state sync for multi-node deployments. Distributed rate limiting via consistent hashing across cluster nodes. Partition-tolerant by design.
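The closed/open/half-open state machine and the backoff-with-jitter retry delay can be sketched as follows. This is a generic illustration of the two patterns, not SMG's code; the thresholds, the cooldown, and the choice of "full jitter" are assumptions made for the example.

```python
import random
import time


class CircuitBreaker:
    """Per-worker breaker: closed -> open after repeated failures -> half-open after a cooldown."""

    def __init__(self, failure_threshold=3, cooldown_s=30.0, clock=time.monotonic):
        self.failure_threshold = failure_threshold
        self.cooldown_s = cooldown_s
        self.clock = clock  # injectable so the state machine can be tested without sleeping
        self.state = "closed"
        self.failures = 0
        self.opened_at = 0.0

    def allow_request(self):
        # After the cooldown, move to half-open so probe requests can test the worker.
        if self.state == "open" and self.clock() - self.opened_at >= self.cooldown_s:
            self.state = "half-open"
        return self.state != "open"

    def record_success(self):
        self.state = "closed"
        self.failures = 0

    def record_failure(self):
        self.failures += 1
        # A failed probe in half-open reopens immediately; otherwise wait for the threshold.
        if self.state == "half-open" or self.failures >= self.failure_threshold:
            self.state = "open"
            self.opened_at = self.clock()


def backoff_delay(attempt, base_s=0.1, cap_s=5.0, rng=random):
    """Exponential backoff with full jitter: uniform in [0, min(cap_s, base_s * 2**attempt)]."""
    return rng.uniform(0.0, min(cap_s, base_s * (2 ** attempt)))
```

Keeping one breaker per worker means a single unhealthy replica is taken out of rotation without affecting routing to the rest of the pool.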

Data Persistence & Service Discovery

Service Discovery: Kubernetes, DNS, Manual

Chat history management with pluggable storage — PostgreSQL, OracleDB, Redis, or in-memory — with schema versioning and customizable table/column names. Kubernetes label-based pod discovery, DNS discovery, or manual worker URLs. Model ID sourced from pod namespace, labels, or annotations. Bootstrap port annotation for automatic prefill port discovery in PD setups.

Universal Platform Support

Linux, Windows, macOS, x86, ARM — from a single Python wheel (pip install smg). Python 3.8–3.14. Production-ready client SDKs in Python, Rust, Java, and Go. Engine-specific Docker images. Full modularization into standalone crates: smg-auth, smg-mesh, smg-mcp, smg-wasm, smg-grpc-client, smg-kv-index, llm-tokenizer, llm-multimodal, openai-protocol, and more.

Proving the Thesis: gRPC Gateway Benchmarks

The disaggregation thesis predicts that moving CPU workloads off the GPU path should show measurable benefits — especially under production conditions. We tested this systematically.

Methodology

All benchmarks run on NVIDIA H100 GPUs using NVIDIA GenAI-Perf (genai-perf) via the SMG nightly benchmark suite on GitHub Actions. 8 models (GPT-OSS-20B, Llama-3.1-8B, Llama-3.3-70B, Llama-3.3-70B-FP8, Llama-4-Maverick, Llama-4-Scout, Qwen2.5-7B, Qwen3-30B-MoE), 2 runtimes (SGLang, vLLM), 5 traffic scenarios, 9 concurrency levels (1–256). Total: 1,082 matched gRPC vs HTTP comparison points.

The Scaling Story: Advantage Grows with Concurrency

At concurrency 1, gRPC and HTTP perform within noise. At concurrency 256, gRPC delivers ~8% more throughput. The gateway’s binary serialization and HTTP/2 multiplexing compound under load — exactly when it matters.

Long Contexts: Where gRPC Transforms Performance

HTTP/JSON serialization cost grows linearly with prompt length. gRPC/protobuf uses compact binary encoding that doesn’t pay this tax. At 7800 input tokens, the serialization cost is substantial. The D(7800,200) scenario shows +12.2% throughput advantage across all models.

The most dramatic result: Llama-3.3-70B-FP8 with 7800-token inputs. This model, running FP8 quantization on H100, is fast enough that HTTP serialization becomes the dominant bottleneck. gRPC delivers up to 3.5x higher output throughput: 1,150 tok/s vs 327 tok/s.

Per-Model Breakdown at High Concurrency

At production concurrency levels (32–256), the gRPC advantage varies by model architecture. Llama-3.3-70B-FP8 sees the largest gains (+15.8% E2E p99, +44.6% output throughput). Smaller dense models (Llama-3.1-8B, Qwen2.5-7B) show modest improvements. The pattern is clear: faster GPUs → larger gRPC advantage, because CPU overhead becomes a bigger fraction of total latency.

The Landscape

We’re not the only team working on LLM infrastructure.

NVIDIA Dynamo brings deep hardware integration and optimized inference orchestration. llm-d tackles distributed inference scheduling with a Kubernetes-native approach. Both are doing important work at the engine and cluster layer.

SMG operates at a different boundary: the serving and protocol layer. We own everything between the client and the GPU — tokenization, agentic protocol translation, tool orchestration, cache-aware routing, multimodal preprocessing, reliability. One layer, zero external dependencies, pure Rust.

The key insight: these approaches compose. You can run SMG in front of vLLM managed by llm-d, or in front of TensorRT-LLM with Dynamo handling GPU orchestration. The boundaries are clean because the responsibilities are different.

Production Adoption

SMG powers production deployments at:

  • Google Cloud Platform — multi-tenant AI infrastructure
  • Oracle Cloud Infrastructure — enterprise GenAI services
  • Alibaba Cloud — cloud-native AI workloads
  • TogetherAI — distributed inference infrastructure

From startups to hyperscalers.

What’s Next

  • Batch API scheduling — two-tier architecture with Job Scheduler and Capacity Governor for offline workloads.
  • Semantic routing — lightweight classification-based dispatch to different backends based on content, not static rules.
  • Mixture of Vendors (MoV) — route the same model across multiple providers for A/B testing, cost optimization, and quality comparison.
  • MCP Semantic Search — efficient tool discovery across servers with hundreds of registered tools.
  • Custom metrics load balancing — CEL expressions over arbitrary metrics with sub-millisecond routing overhead.

GitHub: github.com/lightseekorg/smg

Install: pip install smg --upgrade

Docs: lightseekorg.github.io/smg

Acknowledgement

SMG’s development has been shaped by close collaboration with engineering teams and open-source communities across the industry. We’re grateful for the contributions, feedback, and partnership of:

  • Oracle Generative AI Service — Jun Qian, Jingqiao Zhang, Wei Gao, Keyang Ru, Xinyue Zhang, Yifeng Liu, Ziwen Zhao, Daisy Zhou, Khoa Tran.
  • TogetherAI — Yineng Zhang, Wei Gong, Chandra Mourya, Connor Li.
  • Thinking Machines Lab — Eric Zhang, Rajat Goel, Jeff Hanson.

We also thank the SGLang, vLLM, and TensorRT-LLM communities for upstream collaboration and protocol adoption, and the teams at radixArk and Inferact for their partnership and feedback.

Their production deployments, code contributions, and technical insights have shaped what SMG is today.

]]>
Introducing AutoSP https://pytorch.org/blog/introducing-autosp/ Wed, 29 Apr 2026 15:25:24 +0000 https://pytorch.org/?p=75374

¹ SSAIL Lab, University of Illinois Urbana-Champaign, ² Anyscale, ³ Snowflake

TL;DR: AutoSP automatically converts standard transformer training code into sequence-parallel code for long-context LLM training across multiple GPUs. Integrated with DeepSpeed, it increases maximum trainable context length with little runtime overhead versus hand-written baselines.

Large language models (LLMs) are increasingly trained for extremely long-context tasks, where token counts can exceed 100k. At these token counts, out-of-memory (OOM) issues start to surface, even when scaling device counts using conventional training techniques such as ZeRO/FSDP. Sequence parallelism (SP), which partitions the input tokens across devices so that trainable context length grows with GPU count, is a commonly used parallel training technique for circumventing these issues.

However, implementing SP is notoriously difficult, requiring invasive code changes to existing libraries such as DeepSpeed or HuggingFace. These changes often involve partitioning input token contexts (and intermediate activations), inserting communication collectives, and overlapping communication with computation, all of which must be done for both the forward and backward passes. As a result, researchers who want to experiment with long-context capabilities spend significant effort engineering the systems stack, and must repeat this effort for each hardware vendor.

To avoid this complexity, we introduce AutoSP: a fully automated compiler-based solution that automatically converts easy-to-write training code to multi-GPU sequence parallel code that efficiently uses GPUs to train on longer input contexts while composing with existing parallel strategies (such as ZeRO). This avoids the cumbersome need for developers to repeatedly modify training pipelines for long-context training. Users can now simply import AutoSP and compile arbitrary models using the AutoSP backend, giving the power of long-context training to anyone. Moreover, by embedding this technology into the compiler, our approach is performance-portable: highly performant SP can be realised on diverse hardware.

We structure this post as follows: (1) how model scientists can use AutoSP to enable long-context training, (2) key design decisions of AutoSP, (3) key AutoSP results demonstrating its ease of use and impact, and (4) limitations of AutoSP.

AutoSP Usage

A key design philosophy of AutoSP is simplicity in abstracting most of the complexity in programming multiple GPUs from users. To do this, we implement AutoSP within DeepCompile: a compiler ecosystem within DeepSpeed to programmatically enable diverse optimisations for deep neural network training. With this, any user who uses DeepSpeed can automatically enable Sequence Parallelism with almost zero hassle. We take a look at an example next.

import deepspeed

# We instantiate a DeepSpeed config.
# Assume 8 GPUs with 2 DP ranks and 4 SP ranks.

config = {
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": 2,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "zero_optimization": {
        "stage": 1, # AutoSP interoperates with ZeRO 0/1.
    },
    # Simply turn on deepcompile and set
    # the AutoSP pass to be triggered on.
    "compile": {
        "deepcompile": True,
        "passes": ["autosp"]
    },
    "sequence_parallel_size": 4,
    "gradient_clipping": 1.0,
}

# Initialise DeepSpeed with the model. initialize() returns a 4-tuple:
# (engine, optimizer, dataloader, lr_scheduler).
model, _, _, _ = deepspeed.initialize(config=config, model=model)

# Compiles model and automatically applies AutoSP passes.
model.compile(compile_kwargs={"dynamic": True})

for idx, batch in enumerate(train_loader):
    # Custom function that we expose within:
    #     deepspeed/compile/passes/sp_compile.
    inputs, labels, positions, mask = prepare_auto_sp_inputs(batch)

    loss = model(
        input_ids=inputs,
        labels=labels,
        position_ids=positions,
        attention_mask=mask
    )

    ... # Backwards pass, optimiser step etc...

As seen in the example above, users take existing training code that runs on a single device and do the following: (1) use the prepare_autosp_input utility function (exposed in DeepSpeed) to lightly tag input tokens, attention masks and position ids for program analysis within AutoSP; (2) adjust the DeepSpeed config to turn DeepCompile on, setting the “passes” flag to “autosp”. The rest is handled by the AutoSP compiler passes, which are invoked when compiling the model and automatically enable sequence parallelism alongside other long-context training optimisations. AutoSP also composes with ZeRO stage 1 out of the box: simply set the ZeRO-1 flag in DeepSpeed alongside the AutoSP flags to combine both strategies.

AutoSP Compiler Passes

Since AutoSP transforms user code to enable longer-context training, we briefly cover its key design points and code transformations, as well as their consequences for users, for transparency.

Sequence Parallelism Code Transformations. AutoSP automatically converts single-GPU code to multi-GPU sequence parallel (SP) code. The specific SP strategy AutoSP converts code into is DeepSpeed-Ulysses. We specifically focus on DeepSpeed-Ulysses over other strategies (e.g. RingAttention) as its communication overhead stays constant with increasing GPU counts on NVLink network topologies or fat-tree networks. However, DeepSpeed-Ulysses only enables scaling the SP-size to the number of heads in a model (32 in 7-8B models).
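The core data movement in a Ulysses-style strategy is an all-to-all that switches the sharding from the sequence dimension to the head dimension before attention. The single-process simulation below is only a sketch of that resharding: it tracks which (token, head) pairs each rank owns rather than moving real tensors, the function names are hypothetical, and it assumes the SP size evenly divides both the sequence length and the head count. It also shows why the SP size is bounded by the number of heads: each rank needs at least one whole head.

```python
def shard_by_sequence(seq_len, num_heads, sp_size):
    """Before attention: rank r owns a contiguous slice of tokens, for all heads."""
    per_rank = seq_len // sp_size
    return [
        [(t, h) for t in range(r * per_rank, (r + 1) * per_rank) for h in range(num_heads)]
        for r in range(sp_size)
    ]


def ulysses_all_to_all(shards, num_heads):
    """Simulated all-to-all: afterwards rank r owns every token for a slice of heads."""
    sp_size = len(shards)
    if num_heads % sp_size != 0:
        # Assumption of this sketch: heads are split evenly across SP ranks.
        raise ValueError("SP size must divide the number of attention heads")
    heads_per_rank = num_heads // sp_size

    received = [[] for _ in range(sp_size)]
    for src in range(sp_size):
        for token, head in shards[src]:
            dst = head // heads_per_rank  # the rank that owns this head after the exchange
            received[dst].append((token, head))
    return [sorted(chunk) for chunk in received]


before = shard_by_sequence(seq_len=8, num_heads=4, sp_size=4)
after = ulysses_all_to_all(before, num_heads=4)
# Each rank holds the same number of elements before and after; only the layout changes.
```

After the exchange every rank sees the full sequence for its heads, so attention can run locally; a second all-to-all (not shown) restores the sequence-sharded layout for the rest of the layer.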

Activation Checkpointing for longer-context training. AutoSP additionally applies a custom activation-checkpointing (AC) strategy curated for long-context modelling. AC releases intermediate activations of cheap-to-compute operators, recomputing them in the backward pass as required to compute relevant gradients. PyTorch-2.0 introduces an automated max-flow min-cut based AC formulation, but we find this to be overly conservative for long-context modelling. We accordingly introduce a novel AC strategy targeted at long-context training: Sequence-aware AC (SAC), which exploits unique long-context FLOP dynamics. When enabled (the default setting in AutoSP), SAC marginally reduces training throughput. However, without it, training on longer contexts is infeasible, so users can choose to keep this pass on only for configurations that would otherwise OOM.

Evaluating AutoSP on Real Models

To demonstrate AutoSP’s viability, we evaluate its performance on models of varying sizes on NVIDIA GPUs to show that its ease of use comes at little to no cost to runtime performance. We benchmark different Llama 3.1 models on a single node with 8 A100-80GB SXM GPUs. We use PyTorch 2.7 with CUDA 12.8, comparing AutoSP to torch-compiled hand-written baselines: RingFlashAttention, DeepSpeed-Ulysses, and ZeRO-3. We summarise key results in the figure below:

Not only can AutoSP increase the maximum trainable sequence length given the same resources (left figure – higher is better), but also these benefits come at little cost to runtime performance (right figure – lower is better). 

Limitations

There are two key limitations of AutoSP. First, we require that the user compile a transformer as a single compilable artifact. Occasionally, PyTorch users may compile many functions individually and stitch them together into one model. This is disallowed in AutoSP, as we need to compile and see the entire model to correctly shard input sequences and propagate this information throughout the entire graph. Second, we disallow any graph breaks in compilable artifacts. Graph breaks complicate analysis and propagation of information, and we leave extending AutoSP to be graph-break resilient to future research.

Conclusion

AutoSP enables users to easily extend arbitrary transformer training code to enable Sequence Parallelism, with a custom AC strategy for enhanced long-context training. Integration with DeepSpeed allows users to easily use existing DeepSpeed training code to train on longer contexts by simply changing a config file. We have prepared end-to-end examples for users to play around with on real model workloads (e.g. Llama 3.1 8B) here. Give it a try to see how easy long context training has become. 

]]>
IBM Research uses vLLM at the heart of its RITS Platform https://pytorch.org/blog/ibm-research-uses-vllm-at-the-heart-of-its-rits-platform/ Fri, 24 Apr 2026 17:38:54 +0000 https://pytorch.org/?p=69763

TL;DR: vLLM has been critical to democratizing our research community’s access to the latest and greatest LLMs as they are released.

Introduction

In mid-November 2024, IBM Research introduced the Research Inference & Tuning Service (RITS) Platform.  RITS is an Infrastructure / Service Platform accessible to the entire IBM Research community, providing centralized deployment of and shared access to Model Inferencing Endpoints and “Ancillary” Tuning Service Endpoints.  Since its inception, it has grown its research community user base to more than 1300 active users and hosts over 100 models at any given time.

The Business Challenge

RITS was introduced to ensure the IBM Research community has access to a shared operational Infrastructure / Service Platform, which could:

  • Optimize the utilization of GPU resources across Research work streams by democratizing Model Inference Endpoints (and thereby reducing overall operating costs)
  • Focus on providing access to exotic and experimental models in high demand, not available through other channels, which are typically needed for Research initiatives

When establishing RITS, Research leadership and the AI Platform Enablement team established the following high-level objectives:

  • Develop an Infrastructure / Service Platform for Expert Users to leverage via API access
  • Develop the Platform based on open standard interfaces, for example Red Hat OpenShift AI, vLLM Serving, and OpenAI API
  • Prioritize support for Inference at Scale, optimizing GPU utilization with Serverless-based Auto-Scaling, Scaling based on custom metrics, and Throttling of User Requests
  • Endpoint Security via the use of Self-managed API Keys governed via API Gateway technologies
  • Support easy Model Endpoint Discovery and Model Endpoint Introspection (visibility to model deployment and model runtime configuration)

Figure 1: RITS Platform High-level Architecture

Figure 2: RITS Platform UI – Model Discovery & Introspection

How IBM Research Uses vLLM

vLLM is at the heart of the RITS Platform. All Models deployed to the RITS Platform utilize vLLM as a model serving runtime.  vLLM integrates seamlessly into the Red Hat AI portfolio, with Red Hat AI Inference Server and OpenShift AI, which provide foundational capabilities to deploy, monitor, scale, and maintain large models that require specialized accelerator resources for model serving. Furthermore, because vLLM is a hosted project under the PyTorch Foundation, it guarantees a long-term, vendor-neutral open standard that aligns with IBM’s commitment to hybrid, open architectures. 

Red Hat OpenShift AI integrates KServe, which orchestrates model serving by leveraging model-serving runtimes that implement the loading of various types of model servers, in our case, the runtime being vLLM. Given the fact that RITS runs 100s of different models, many of which are often new or experimental, Red Hat OpenShift AI allows us to register different versions of vLLM as custom Serving Runtimes (when custom images are required for a particular model deployment). 

The Serving Runtime creates the environment for deploying and managing the model, creating the template for the vLLM pods that dynamically load and unload models, and exposes the service endpoint for inferencing requests.

Solving AI Challenges with vLLM

The efficient use of limited and costly GPU resources is always paramount for a Model as a Service Platform like RITS, especially when the platform is shared by 100s of active users at any given time, submitting varied workloads that are unknown and unpredictable to platform administrators. Serving performance is paramount, and vLLM as a Serving Runtime helps address our performance requirements while efficiently managing available compute resources.

vLLM is designed for inference serving efficiency. Features such as PagedAttention for efficient memory management, Continuous Batching for optimized serving performance, and Quantization Support allowing for the deployment of models of reduced size without sacrificing model accuracy are all features that contribute to RITS Platform adoption and success.

From an operations perspective, vLLM exposes a rich set of server-level and request-level metrics that can be utilized by both administrators and users to monitor model serving performance and stability.  These metrics, exported to Prometheus, provide critical insight to platform administrators responsible for ensuring the operational stability of the platform and the performance SLOs of its hosted models, as well as providing the means to establish near real-time dashboards which allow end-users to monitor model performance to optimize batch job scheduling and configuration.

Autoscaling model deployments based on unknown and unpredictable load is critical for the success of RITS (where GPU resources need to be carefully managed).  Performing model autoscaling based on basic metrics such as Requests per Second (RPS), which may be adequate for traditional CPU/Memory-intensive workloads, is not adequate for Model Serving workloads that require limited and expensive GPU Accelerator resources.  Once again, vLLM and its exported metrics have made it possible for RITS to implement a hybrid autoscaling model.  RITS leverages serverless technologies for 0 to 1 and 1 to 0 scaling, but then leverages Turbonomic, IBM’s Application Resource Management (ARM) product, to perform scaling from 1 to n and n to 1, using custom metrics emitted by the vLLM serving runtime (Requests Waiting being a much better scaling metric than RPS). This sophisticated, metric-driven approach to scaling and request routing directly anticipates the advanced orchestration features that are now becoming native in emerging distributed frameworks like llm-d.
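The hybrid policy can be illustrated with a small decision function. This is a sketch of the idea only, not how RITS or Turbonomic implement it: the function name, the per-replica queue target, and the replica cap are hypothetical, and the inputs are assumed to come from the waiting and running request gauges that vLLM exports to Prometheus.

```python
import math


def desired_replicas(current, requests_waiting, requests_running,
                     target_waiting_per_replica=4, max_replicas=8):
    """Return the replica count a hybrid autoscaler might aim for.

    0 <-> 1 mimics serverless scale-to-zero; 1 <-> n is driven by queue depth
    (Requests Waiting) rather than by requests per second.
    """
    # Idle model: release the GPU entirely (1 -> 0).
    if requests_waiting == 0 and requests_running == 0:
        return 0

    # Cold model receiving traffic: bring up the first replica (0 -> 1).
    if current == 0:
        return 1

    # Warm model: size the deployment so each replica has a bounded queue.
    needed = math.ceil(requests_waiting / target_waiting_per_replica)
    return max(1, min(max_replicas, needed))
```

Queue depth reflects GPU saturation directly: a model can sustain a high request rate with an empty queue, or a low rate with a long one when prompts are large, which is why a waiting-requests signal tends to be a better trigger than RPS.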

Figure 3: RITS Platform Hybrid Autoscaling Model

A Word from IBM

“The vLLM community is vibrant and responsive, and with collaborative expertise, we are able to do great things both upstream and internally by leveraging and contributing to this groundbreaking project. vLLM has been critical to democratizing access to our research community to the latest and greatest LLMs as they release.”

Priya Nagpurkar, Vice President, AI Platform, IBM Research

The Benefits of Using vLLM

Leveraging vLLM as our core serving runtime has significantly contributed to the RITS Platform’s overall success and adoption.  In addition to the performance optimizations, scalability, and stability benefits already mentioned, vLLM offers the end-user community an easy platform adoption path.  vLLM’s support of OpenAI API has allowed our user community to leverage a simple, consistent HTTP-based inference API, as well as leveraging the user’s favorite client-side SDKs.  In addition, vLLM’s broad adoption and integration with various open-source LLMs has allowed the RITS Administration team the ability to leverage a single and consistent serving runtime across all deployed models.  
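As an illustration of what the OpenAI-compatible interface looks like from the client side, the sketch below builds a chat completion request using only the Python standard library. The base URL, API key, and model name are placeholders invented for the example; only the `/v1/chat/completions` path and the Bearer authorization header follow the OpenAI-style API that vLLM serves.

```python
import json
import urllib.request


def build_chat_request(base_url, api_key, model, user_message):
    """Build (but do not send) an OpenAI-style chat completion request."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": user_message}],
    }
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        method="POST",
    )


# Placeholder endpoint, key, and model name; substitute real values before sending.
request = build_chat_request(
    base_url="https://inference.example.com/my-model",
    api_key="sk-placeholder",
    model="my-model",
    user_message="Summarize PagedAttention in one sentence.",
)
# To send it: urllib.request.urlopen(request), then json.load() the response body.
```

Because the request shape is the same for every deployed model, users can switch models by changing only the endpoint and the model name, or point an existing OpenAI client SDK at the endpoint instead of hand-building requests.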

vLLM has been and will continue to be the core enabler of the RITS Platform, even as we evolve to optimized distributed runtime frameworks such as llm-d in the coming months. By pairing vLLM’s high-performance inference engine with llm-d’s cluster-wide orchestration, RITS will be able to leverage techniques like predicted-latency-based scheduling and cache-aware affinity routing to maximize prefix reuse. This ensures every GPU has work, driving predictable, cost-efficient scale while seamlessly supporting a wider variety of hardware accelerators. vLLM will also allow us to expand beyond using GPU Accelerators, establishing a heterogeneous compute environment within RITS; the IBM Spyre Accelerator will become a core Accelerator for various models leveraging the vLLM serving runtime. Stay Tuned – exciting times lie ahead!

Learn More

This is just one of the many ways that IBM leverages the power of vLLM. Learn more about how we combine vLLM with KServe for fast inference at scale, how to set up and run vLLM on IBM Power, and run Granite models with vLLM in a container.

vLLM is a PyTorch Foundation-hosted project. Learn more >>

llm-d is a Cloud Native Computing Foundation Sandbox project. Learn more >>

]]>
Optimizing Effective Training Time for Meta’s Internal Recommendation/Ranking Workloads https://pytorch.org/blog/optimizing-effective-training-time-for-metas-internal-recommendation-ranking-workloads/ Fri, 17 Apr 2026 16:00:18 +0000 https://pytorch.org/?p=63618 Motivation and Introduction

Across the industry, teams training and serving large AI models face aggressive ROI targets under tight compute capacity. As workloads scale, improving infrastructure effectiveness gets harder because end-to-end runtime increasingly includes overheads beyond “real training” (initialization, orchestration, checkpointing, retries, failures, and recovery). 

Meta utilizes Effective Training Time (ETT%) to quantify efficiency, defining it as the percentage of total end-to-end (E2E) wall time dedicated to productive training. This metric directly points to areas where time is wasted, thus facilitating the prioritization of efficiency improvements.

In this work stream, while grounded in Meta’s production experience using PyTorch for model training, we aim to share broadly useful lessons: some improvements have been implemented in open source—e.g., TorchRec sharding plan improvements and PyTorch 2 (PT2) compilation optimizations that reduce compile time and recompilation—while others (like checkpointing and model publishing) are more Meta-specific, but address common industry bottlenecks and can be adapted elsewhere.

Effective Training Time Definition

Effective Training Time (ETT%) is defined as the percentage of E2E wall time spent on consuming new data. Since E2E wall time depends on many factors such as model architecture, complexity, and training data volume, it is hard to measure ETT% directly. Instead, we focus on measuring idleness and failures, which can be represented by the following formula:  

A visual view of the formula is shown below with three L1 sub-metrics: 

  • Time to Start: the period from when a job is allocated hardware to when it begins training the first batch of data.
  • Time to Recover: the duration required for a training job to restart and resume productive training after a failure or interruption.
  • Number of Failures: the total count of infra-related interruptions or unsuccessful attempts that occur during the lifecycle of a training job.

Time to Start and Time to Recover measure the idleness of each single attempt from the system optimization perspective, while Number of Failures measures different kinds of failures from the reliability perspective. 

Figure 1. Training Cycle Overview

where the definitions of the L2 areas are:  

  • Scheduling Time: time spent in infra to get a training job scheduled when resources are available. 
  • Hardware Setup Time: time spent to bring up launcher/trainer binaries on the hardware.
  • Launcher Init Time: time for the launcher to start and reach the PT2 compilation stage. 
  • PT2 Compilation Time: time to apply PT2 compilation to optimize the training model before it starts consuming training data. 
  • Effective Training Time: time spent training on new training data. 
  • Wasted Training Time: time within the train loop that does not consume new training data, such as repeated training on samples and blocked training time.
  • Shutdown Time: time to stop a training job. 
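As a worked example of the definition, ETT% for a single attempt can be computed from the stage breakdown above as the share of wall time spent in the Effective Training Time stage. The sketch below is illustrative only: the dictionary keys mirror the L2 areas listed above, and the durations are made-up numbers, not measurements from Meta's fleet.

```python
def ett_percent(stage_seconds):
    """ETT% = effective training time / total end-to-end wall time * 100."""
    total = sum(stage_seconds.values())
    if total <= 0:
        raise ValueError("total wall time must be positive")
    return 100.0 * stage_seconds["effective_training"] / total


# Hypothetical durations (seconds) for one training attempt.
attempt = {
    "scheduling": 120,
    "hardware_setup": 180,
    "launcher_init": 240,
    "pt2_compilation": 300,
    "effective_training": 8000,
    "wasted_training": 100,
    "shutdown": 60,
}

print(f"ETT% = {ett_percent(attempt):.1f}")  # 8000 / 9000 of wall time, about 88.9
```

In this example the 840 seconds before the first batch (scheduling through PT2 compilation) cost more than nine points of ETT%, which is why Time to Start is a natural first target.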

The Journey to Improve ETT% in Meta

Starting from H2 ’24, we have been proactively analyzing fleetwide Effective Training Time (ETT%). This effort aims to establish the ETT% status, identify key focus areas, and implement improvements. 

Over the past years, we have developed more than 40 new technologies to improve overall ETT%. The following diagram summarizes the Time to Start improvement for each main area:

Figure 2. Time to Start Improvement Over Each Techs

With the team’s concentrated efforts, we achieved a major milestone by the end of ’25, increasing Effective Training Time (ETT%) to >90% for offline training.

Technique Deep-Dives

The team conducted a detailed analysis of each area contributing to the Effective Training Time (ETT%) and focused optimizations primarily on the following initiatives:

  • Time to Start and Recover: Optimized trainer initialization and PT2 compilation to lower training costs related to Time to Start and Time to Recover metrics.
  • Checkpoint Management: Improved checkpoint processes to minimize idleness during training and reduce unsaved training time.
  • Shutdown Time Optimizations: Moved model publishing for inference from GPU machines to CPU machines, saving the GPU hours previously spent during job shutdown.
  • Failure Reduction and Observability: Collaborated with partner teams to reduce scheduling time and lower the job preemption ratio, established component-level observability, and refined the categorization of trainer errors to reduce the frequency of failures.

Trainer Initialization Optimizations

Figure 3. Trainer Initialization Overview

Trainer initialization comprises multiple sub-stages: device_init, process_group_init, preproc_creation, train_module_creation, init_plugins, pre_train, and get_first_batch_data.

Since 2024, we have pursued several initiatives to minimize trainer initialization time. The two main approaches are:

  1. Communication optimizations: remove unnecessary process group creation and communication between ranks to reduce overhead.
  2. Pipeline optimizations: run independent sub-stages concurrently so that they overlap and less time is spent waiting.

Communication Optimizations

Before this work stream, each job initialization involved numerous unnecessary process group creations and non-optimal communication across ranks, which collectively increased trainer initialization time.

For instance, instead of relying on numerous all_gather calls to build shard metadata piece by piece—a method that caused substantial overhead in the sharding process—the team implemented an optimization. They now have each rank build its section of the global rank using metadata that is already locally available after the sharding plan broadcast. This change significantly improved sharding time.
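The idea can be sketched with a toy, single-process simulation. This is not the actual trainer or TorchRec code: the plan layout, the `ShardMeta` type, and the `simulated_all_gather` helper are all hypothetical stand-ins. The sketch assumes the broadcast sharding plan already contains the number of rows assigned to each rank, so the same metadata can be derived locally without any collective call.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShardMeta:
    table: str
    rank: int
    row_offset: int
    num_rows: int


# Hypothetical sharding plan, identical on every rank after the broadcast:
# table name -> number of rows assigned to each rank, in rank order.
PLAN = {"user_emb": [4, 4, 2], "item_emb": [3, 3, 3]}


def simulated_all_gather(per_rank_values, counter):
    """Stand-in for a collective: every rank receives every rank's value."""
    counter["collectives"] += 1
    return list(per_rank_values)


def metadata_via_all_gather(plan, counter):
    """Baseline: one collective per table to learn every rank's shard size."""
    meta = []
    for table, rows_per_rank in plan.items():
        gathered = simulated_all_gather(rows_per_rank, counter)
        offset = 0
        for rank, rows in enumerate(gathered):
            meta.append(ShardMeta(table, rank, offset, rows))
            offset += rows
    return meta


def metadata_from_local_plan(plan):
    """Optimized: derive the same metadata from the already-broadcast plan."""
    meta = []
    for table, rows_per_rank in plan.items():
        offset = 0
        for rank, rows in enumerate(rows_per_rank):
            meta.append(ShardMeta(table, rank, offset, rows))
            offset += rows
    return meta


counter = {"collectives": 0}
baseline = metadata_via_all_gather(PLAN, counter)
local = metadata_from_local_plan(PLAN)
print("collectives used by baseline:", counter["collectives"])
print("same metadata without collectives:", baseline == local)
```

In this toy model the baseline issues one collective per table, so the saving grows with the number of sharded tables.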

Figure 4. Communication Optimizations Overview

Pipeline Optimizations

Many sub-stages in trainer initialization do not depend on each other, so they can run in separate processes and overlap.

For example, PT2 compilation and the DPP warm-up that fetches the first batch of data (DPP is the data process we use to fetch training data) are costly, time-consuming steps that occur before actual training begins. In the baseline flow, PT2 compilation is delayed because it can only start once the first batch of real data is available.

To make this more efficient, we introduced a fast batch that supplies data quickly, which allows PT2 to start compiling much earlier while DPP is still fetching the first batch of real data.

Figure 5. PT2 Compilation and DPP Warm-up in Parallel

This new technology is most beneficial for larger models, such as Foundation Models, because their data loading process is significantly more time-consuming than for other model types. 
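The overlap described in this section can be illustrated with a minimal sketch that uses only the Python standard library. The functions below are hypothetical stand-ins: `time.sleep` replaces the real data warm-up and the real compilation, and the sketch assumes that a fast batch with the same structure as real data is enough to start compiling.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fetch_first_real_batch():
    """Stand-in for the slow data warm-up that yields the first real batch."""
    time.sleep(0.3)
    return {"source": "real", "shape": (1024, 64)}


def make_fast_batch():
    """Stand-in for a cheap batch with the same structure as real data."""
    return {"source": "fast", "shape": (1024, 64)}


def compile_model(example_batch):
    """Stand-in for compilation, which only needs a representative batch."""
    time.sleep(0.3)
    return f"compiled for shape {example_batch['shape']}"


def sequential_start():
    # Compilation waits for the first real batch.
    batch = fetch_first_real_batch()
    return compile_model(batch), batch


def overlapped_start():
    # Data warm-up runs in the background while compilation uses a fast batch.
    with ThreadPoolExecutor(max_workers=1) as pool:
        pending_batch = pool.submit(fetch_first_real_batch)
        compiled = compile_model(make_fast_batch())
        return compiled, pending_batch.result()


def timed(fn):
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start


(seq_compiled, seq_batch), seq_seconds = timed(sequential_start)
(ovl_compiled, ovl_batch), ovl_seconds = timed(overlapped_start)
print(f"sequential: {seq_seconds:.2f}s, overlapped: {ovl_seconds:.2f}s")
```

With the overlap, the startup cost approaches the longer of the two steps rather than their sum, which is why the gain is largest when data loading is slow.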

PT 2.0 Compilation Optimizations

PyTorch 2.0 (PT2) compilation time is another major area the team invested in. We are pursuing three main methods to reduce long PT2 compilation times:

  1. Reduce unnecessary recompilations
  2. Improve overall PT2 cache hit and coverage
  3. Prune the large number of user-defined autotune kernel configs

The team previously published its experience reducing PT2 compilation time for Meta internal workloads. Here we only recap the main recent approaches; for more details, please refer to that blog.

Reduce unnecessary recompilations

Recompilation due to dynamic shapes is a significant source of overhead in our Meta workloads. This recompilation contributes substantially to the overall compilation time across the fleet, resulting in considerable cumulative cost.

To address this, the v-team collaborated with the PyTorch team in H1 ’25 to develop TORCH_COMPILE_DYNAMIC_SOURCES, which improves the handling of dynamic shapes for parameters by providing a simple way to mark parameters as dynamic without modifying the underlying code. The feature also supports marking integers as dynamic and accepts regular expressions to cover a broader range of parameters, which adds flexibility and reduces compilation time.

Figure 6. Internal Tool to Identify Dynamic Shape

Improve PT2 Cache

MegaCache brings together several types of PT2 compilation caches—including components like inductor (the core PT2 compiler), triton bundler (for GPU code), AOT Autograd (for efficient gradient computation), Dynamo PGO (profile-guided optimizations), and autotune settings—into a single archive that can be easily downloaded and shared. 

By consolidating these elements, MegaCache offers the following improvements:

  • Minimizes repeated requests to remote servers
  • Cuts down on time spent setting up models
  • Makes startup and retried jobs more dependable, even in distributed or cloud environments

By the end of 2025, teams worked together to enable MegaCache across all training platforms. This effort reduced the average PT2 compile time by approximately 40%.

Autotune config pruning

Autotune in PyTorch 2.0 is a feature that automatically optimizes the performance of PyTorch models by tuning various hyperparameters and settings. With the increasing adoption of Triton kernels, the time required to compile and search for the best settings and hyperparameters for Triton kernels has increased. 

To address this, we developed a process to identify the most time-consuming kernels and determine optimal runtime configurations for implementation in the codebase. This approach has led to a substantial reduction in compilation time.

Checkpoint Management

Checkpoint: a checkpoint is a saved snapshot of a model’s state during training, including its parameters, optimizer settings, and progress. 

At Meta, checkpoints are used to ensure that if a training job is interrupted—due to hardware or software issues—the process can resume from the last saved point rather than starting over.

Checkpoint saving, while necessary, currently blocks GPU training by demanding memory resources, leading to GPU idle time. Furthermore, the time interval between checkpoint saves directly impacts the amount of training progress that is lost (unsaved training time) if a failure occurs.

To address these inefficiencies, the team successfully developed and implemented Async Checkpointing and PyTorch Native Staging. These advancements have significantly improved checkpointing performance by reducing the checkpoint blocking time for all models.

Async checkpointing: it involves creating a copy of the checkpoint in CPU memory, allowing the main trainer process to resume the training loop while a background process completes the checkpoint upload.
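The pattern can be sketched in pure Python under a few stated assumptions. `copy.deepcopy` stands in for the device-to-CPU copy, a JSON file stands in for checkpoint storage, and all function names are hypothetical. It is a minimal illustration of the staging-then-background-upload idea, not the implementation used in production.

```python
import copy
import json
import os
import tempfile
import threading


def stage_to_host(state):
    """Blocking step: snapshot the state so training can keep mutating it."""
    return copy.deepcopy(state)


def upload(snapshot, path):
    """Background step: stand-in for the slow write to checkpoint storage."""
    with open(path, "w") as f:
        json.dump(snapshot, f)


def async_save(state, path):
    """Stage synchronously, then upload on a background thread."""
    snapshot = stage_to_host(state)
    worker = threading.Thread(target=upload, args=(snapshot, path))
    worker.start()
    return worker


state = {"step": 100, "weights": [0.1, 0.2, 0.3]}
ckpt_path = os.path.join(tempfile.mkdtemp(), "ckpt_step100.json")
pending = async_save(state, ckpt_path)

# Training resumes immediately and keeps updating the live state.
state["step"] += 1
state["weights"][0] = 0.5

# Wait for the upload before reading the checkpoint back.
pending.join()
with open(ckpt_path) as f:
    saved = json.load(f)
print("checkpointed step:", saved["step"], "live step:", state["step"])
```

Only the staging copy blocks the training loop here. The saved file reflects the state at the moment of staging, even though training has already moved on.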

PyTorch native staging: the initial async checkpoint implementation used custom C++ staging, which was designed to minimize trainer memory usage during staging by using a streaming copy. The checkpointing team has since developed a separate async checkpointing solution based on PyTorch native staging APIs, which reduces save blocking time at the cost of increased trainer memory consumption.

Together, these improvements significantly reduced the total daily GPU hours blocked by checkpointing.

Reducing Wasted Training Time

Optimizing the time required to save checkpoints directly boosts the Effective Training Time (ETT) percentage by reducing interruptions to the training loop. Furthermore, these checkpoint save improvements can unlock greater ETT% gains when paired with adjustments to the checkpoint interval.

Adjusting the checkpoint interval impacts two components of wasted training time:

Unsaved Training Time: this is the training progress lost after a job failure, as any work completed since the last checkpoint is discarded.

  • Calculation: (# train loop failures) * (checkpoint interval)/2

Checkpoint Save Blocking Time: this is the time the training loop is paused specifically while a new checkpoint is being created.

  • Calculation: ((time spent in train loop) / (checkpoint interval)) * (blocking time per checkpoint)

Given the job failure rate, the checkpoint interval can be tuned to minimize the expected wasted training time, which equals:

sum(unsaved training time, checkpoint save blocking time)

The following graph illustrates the relationship between checkpoint save intervals and the percentage of wasted training time (WTT%), using a hypothetical scenario with a 15-second checkpoint save blocking time and 3 daily failures.

Figure 7. Checkpoint Save Interval vs Wasted Training Time
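The trade-off can be reproduced numerically from the two calculations above. The sketch below uses the hypothetical scenario from the text (15-second blocking time, 3 failures per day) and additionally assumes that the training loop runs for the full 24 hours and that WTT% is measured against that same day. The function names are illustrative.

```python
import math

SECONDS_PER_DAY = 24 * 60 * 60


def wasted_seconds(interval_s, failures, blocking_s, train_loop_s=SECONDS_PER_DAY):
    """Unsaved training time plus checkpoint save blocking time."""
    unsaved = failures * interval_s / 2
    blocking = (train_loop_s / interval_s) * blocking_s
    return unsaved + blocking


def optimal_interval(failures, blocking_s, train_loop_s=SECONDS_PER_DAY):
    """Interval that minimizes wasted_seconds.

    Setting the derivative of F*T/2 + L*B/T to zero gives T = sqrt(2*L*B/F).
    """
    return math.sqrt(2 * train_loop_s * blocking_s / failures)


FAILURES_PER_DAY = 3
BLOCKING_SECONDS = 15

best = optimal_interval(FAILURES_PER_DAY, BLOCKING_SECONDS)
print(f"optimal interval: {best / 60:.1f} minutes")

for minutes in (5, 15, 30, 60):
    wasted = wasted_seconds(minutes * 60, FAILURES_PER_DAY, BLOCKING_SECONDS)
    print(f"{minutes:>3} min interval -> WTT% = {100 * wasted / SECONDS_PER_DAY:.2f}")
```

Under these assumptions the minimum falls at roughly 15.5 minutes. Shorter intervals are dominated by blocking time and longer intervals by unsaved training time, which matches the shape of the curve described above.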

By optimizing the checkpoint saving interval, the team successfully reduced the unsaved training time for both production and exploration jobs.

Shutdown Time Optimizations

The team examined each component of the shutdown phase and found that model publish processing (publishing the model for inference) dominated the post-training duration.

Model Publish Processing: model publishing is the process of optimizing a model with processing code to create an inference-ready snapshot that can be served.

The team’s analysis led to the adoption of a standalone publishing strategy, which decouples publishing from the training process. With this approach, publishing is initiated only after the training job has finished and created an anchor checkpoint. This checkpoint is then used by a model processing job, leveraging the stored data, to generate the final inference-ready snapshot.

The key differences between this standalone publishing method and the traditional “trending end” model publishing are visually represented in the diagram below.

Figure 8. “Trending End” Model Publish vs Standalone Publish

The implementation of the new model publishing pipeline has successfully shortened the shutdown time for each job by approximately 30 minutes.

Failure Reduction and Observability

A major focus area for the team has been failure reduction, as the number of failures significantly impacts the overall Effective Training Time (ETT) percentage. Regressions from code or configuration changes can directly cause this percentage to drop.

Fluctuations in the ETT dashboard are primarily attributed to two factors:

  1. Increased Job Preemptions: A higher volume of running jobs leads to more preemptions.
  2. Service Regressions: Issues with services cause a greater number of job failures.

To tackle preemptions, we are collaborating with infrastructure teams to develop a new scheduling algorithm aimed at lowering the preemption ratio without negatively affecting users’ quotas or experience.

Regarding failure reduction, a dedicated team is scrutinizing each ETT-related component and building dashboards to monitor overall ETT performance, including Time to Start/Time to Recover (TTS/TTR), unsaved training time, and checkpoint saving time. This proactive monitoring ensures that any regression is detected and mitigated early within the SLA.

In the End

As model training scales, resource constraints are becoming a defining challenge across the industry. For years, a major lever for improving training efficiency has been increasing Model FLOPs Utilization (MFU) through techniques like model co-design and kernel optimization. That work remains essential, but large-scale training has surfaced a complementary bottleneck: significant GPU time is spent idle outside the steady-state training loop.

Our analysis shows that non-training overhead can be substantial, especially on some of the largest runs.

To address this, we launched a successful workstream focused on improving Effective Training Time (ETT%), which has already produced meaningful capacity savings. The key takeaway for practitioners is simple: to improve cost and throughput at scale, you must optimize the “in-between” phases—not just the training steps.

Since our training stack utilizes PyTorch, we made an effort to ensure these enhancements are applicable beyond a single environment. We have open-sourced and shared relevant building blocks, such as those in TorchRec and PyTorch 2, within the open-source PyTorch ecosystem. This allows others to leverage these improvements, replicate our results, and build upon our work. Other components, like model publishing and checkpointing, are more specific to Meta but tackle common industry challenges and can be adapted for use elsewhere.

We hope these lessons help teams diagnose similar bottlenecks, apply ETT%-style measurement, and contribute further improvements back to the ecosystem.

Acknowledgements

We extend our gratitude to Max Leung, Apoorv Purwar, Musharaf Sultan, John Bocharov, Barak Pat, Jonathan Tang, Vivek Trehan, Chris Gottbrath and Vitor Brumatti Pereira for their valuable reviews and insightful support. We also thank the entire Meta team responsible for the development and productionization of this workstream.

]]>