pytorch/pytorch

PyTorch Core Architecture

Last updated on Dec 18, 2025 (Commit: 1984725)

Overview

Relevant Files
  • README.md
  • torch/__init__.py
  • torch/_tensor.py
  • torch/autograd/__init__.py
  • torch/nn/__init__.py

PyTorch is a Python-based machine learning framework that provides two core capabilities: tensor computation with GPU acceleration and automatic differentiation for building neural networks. It is designed to be intuitive, flexible, and production-ready.

Core Architecture

PyTorch's architecture is built around several interconnected components:

Tensors are the fundamental data structure—multi-dimensional arrays similar to NumPy's ndarrays but with GPU support and automatic differentiation capabilities. Tensors can live on CPU or GPU devices and support a comprehensive set of mathematical operations.

Autograd is PyTorch's automatic differentiation engine. It uses reverse-mode auto-differentiation (backpropagation) to compute gradients dynamically. When you set requires_grad=True on a tensor, PyTorch records all operations performed on it, building a computational graph that can be traversed to compute gradients.
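This recording can be seen in a few lines using the standard autograd API:

```python
import torch

# Leaf tensor created by the user; autograd will track operations on it
x = torch.tensor(3.0, requires_grad=True)

# Each op is recorded as a node in the computational graph
y = x * x        # y.grad_fn is a MulBackward0 node
z = y + 1        # z.grad_fn is an AddBackward0 node

# Traversing the graph backward applies the chain rule: dz/dx = 2x
z.backward()
print(x.grad)    # tensor(6.)
```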

Neural Networks (torch.nn) provides a modular framework for building deep learning models. The nn.Module base class allows you to define custom architectures with learnable parameters, pre-built layers (Linear, Conv2d, etc.), activation functions, and loss functions.

Optimization (torch.optim) supplies optimizers like SGD, Adam, and RMSprop for updating model parameters during training.

Key Design Principles

Imperative Execution: Code runs immediately when executed, not in a deferred graph.
Python-First: Deep integration with Python—write layers in Python using NumPy-like APIs.
Dynamic Graphs: Computational graphs are built on-the-fly, enabling flexible architectures.
GPU Acceleration: Seamless CUDA support for fast tensor operations on NVIDIA GPUs.

Typical Workflow

  1. Create tensors and define a model using nn.Module
  2. Forward pass: compute predictions
  3. Compute loss using a loss function
  4. Backward pass: call .backward() to compute gradients via autograd
  5. Update parameters using an optimizer
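The five steps can be sketched as a minimal training loop (toy shapes chosen purely for illustration):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)                            # 1. define a model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

x, target = torch.randn(8, 4), torch.randn(8, 1)
for _ in range(20):
    pred = model(x)                                # 2. forward pass
    loss = loss_fn(pred, target)                   # 3. compute loss
    optimizer.zero_grad()
    loss.backward()                                # 4. backward pass via autograd
    optimizer.step()                               # 5. update parameters
```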

Supported Devices

PyTorch supports multiple hardware backends:

  • CPU - Default, always available
  • CUDA - NVIDIA GPUs (compute capability >= 3.0)
  • ROCm - AMD GPUs (Linux only)
  • MPS - Apple Metal Performance Shaders
  • XPU - Intel GPUs
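Backend availability can be probed at runtime with the public query functions, falling back gracefully when no accelerator is present:

```python
import torch

# CPU is always present; accelerators are probed at runtime
print("CUDA:", torch.cuda.is_available())
print("MPS: ", torch.backends.mps.is_available())

# Pick the best available device for subsequent tensors
device = "cuda" if torch.cuda.is_available() else "cpu"
t = torch.zeros(2, device=device)
print(t.device)
```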

Additional Modules

  • torch.jit - TorchScript compiler for serialization and optimization
  • torch.distributed - Multi-GPU and multi-machine training
  • torch.utils.data - DataLoader for efficient batch processing
  • torch.linalg - Linear algebra operations
  • torch.fft - Fast Fourier Transform operations
  • torch.sparse - Sparse tensor support

Architecture & Core Components

Relevant Files
  • aten/src/ATen/ATen.h
  • c10/core/TensorImpl.h
  • c10/core/DispatchKey.h
  • aten/src/ATen/core/dispatch/Dispatcher.h
  • torch/csrc/Module.cpp
  • torch/_tensor.py
  • torch/nn/__init__.py

PyTorch's architecture is organized in layered components, from low-level tensor operations to high-level Python APIs. Understanding these layers is essential for contributing to the codebase.

Core Layers

C10 (Core Tensor Library)

C10 provides foundational abstractions for PyTorch. The TensorImpl class in c10/core/TensorImpl.h is the low-level representation of a tensor, containing:

  • A pointer to Storage (the actual data buffer)
  • Metadata: sizes, strides, storage offset, data type
  • Reference counting for memory management
  • A DispatchKeySet that determines which backend implementation to use

ATen (Abstract Tensor Library)

ATen, defined in aten/src/ATen/, is the C++ tensor library built on top of C10. It provides:

  • The Tensor class (a reference-counted handle to TensorImpl)
  • Hundreds of tensor operations (add, matmul, conv, etc.)
  • Device-agnostic operator definitions
  • Automatic dispatch to CPU, CUDA, or other backends

The Dispatcher System

The dispatcher is PyTorch's central routing mechanism. Located in aten/src/ATen/core/dispatch/, it uses a multi-dispatch strategy:

  1. DispatchKey: Identifies a specific implementation (e.g., CPU, CUDA, Autograd, Sparse)
  2. DispatchKeySet: A set of keys on each tensor indicating which implementations apply
  3. Dispatcher: The singleton that maintains a registry of kernels and routes calls

When an operator is called, the dispatcher:

  • Extracts the DispatchKeySet from input tensors
  • Looks up the appropriate kernel in the dispatch table
  • Calls the kernel with the correct backend implementation

Python Bindings

The torch._C module (built from torch/csrc/Module.cpp) bridges Python and C++:

  • Exposes C++ classes and functions to Python via pybind11
  • Initializes all backend modules (CUDA, CPU, MPS, etc.)
  • Registers Python-specific dispatch keys for custom behavior
  • Manages the Python interpreter integration

The Python Tensor class in torch/_tensor.py inherits from torch._C.TensorBase, providing a Pythonic interface to the C++ tensor implementation.
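The inheritance chain is visible from Python (the exact base-class name varies slightly across versions):

```python
import torch

# torch.Tensor is a thin Python subclass of the C++-backed TensorBase
print([cls.__name__ for cls in torch.Tensor.__mro__])
# e.g. ['Tensor', 'TensorBase', 'object']

t = torch.randn(2, 2)
assert isinstance(t, torch.Tensor)
```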

Neural Network Modules

torch/nn provides high-level building blocks:

  • Module: Base class for neural network layers
  • Parameter: Learnable tensor wrapper
  • functional: Stateless operations (activations, pooling, etc.)
  • init: Weight initialization utilities

Data Flow Example

Python: torch.add(a, b)
  ↓
torch._C (pybind11 binding)
  ↓
ATen: at::add(Tensor, Tensor)
  ↓
Dispatcher: lookup kernel for DispatchKeySet
  ↓
Backend kernel: CPU/CUDA implementation
  ↓
C10: TensorImpl storage access
  ↓
Result: new Tensor

Key Design Principles

  • Separation of concerns: C10 handles memory, ATen handles operations, dispatcher handles routing
  • Multiple dispatch: Kernels selected based on all input tensors, not just the first
  • Extensibility: Custom backends and operations can register with the dispatcher
  • Performance: Dispatch overhead minimized through inline dispatch tables and caching

Compilation & Optimization Systems

Relevant Files
  • torch/_dynamo/__init__.py
  • torch/_inductor/__init__.py
  • torch/fx/__init__.py
  • torch/_export/__init__.py
  • torch/jit/__init__.py

PyTorch provides multiple compilation and optimization systems for different use cases. The primary modern stack is PT2 (PyTorch 2.0), which combines TorchDynamo, TorchInductor, and FX for dynamic compilation and optimization.

TorchDynamo: Python-Level JIT Compilation

TorchDynamo is a Python-level just-in-time compiler that hooks into CPython's frame evaluation API (PEP 523) to capture PyTorch computation graphs dynamically. It rewrites Python bytecode to extract sequences of tensor operations into FX graphs before execution.

Key capabilities:

  • Bytecode analysis and symbolic execution to trace Python code
  • Automatic graph capture with minimal code changes
  • Guard system to validate graph assumptions at runtime
  • Graph break handling for unsupported Python features
  • Backend-agnostic compilation pipeline

import torch

@torch.compile(backend="inductor")
def model(x):
    return torch.sin(x) + torch.cos(x)

TorchInductor: Code Generation Backend

TorchInductor is the default compiler backend that generates optimized code for CPUs and GPUs. It takes FX graphs and produces high-performance implementations through multi-stage optimization.

Compilation pipeline:

  1. Graph lowering - Decompose high-level operations to ATen operators
  2. Scheduling - Fuse operations, reorder computations, optimize memory
  3. Code generation - Generate Triton (GPU) or C++ (CPU) code
  4. Compilation - Compile generated code to machine code

Inductor supports multiple optimization modes via torch.compile(mode=...): "default", "reduce-overhead" (which uses CUDA graphs to cut launch overhead), "max-autotune", and "max-autotune-no-cudagraphs".

FX: Intermediate Representation and Transformations

FX provides the intermediate representation (IR) for all compilation systems. It consists of three components:

Symbolic Tracer - Records operations on proxy objects to capture program semantics without executing real data.

Graph IR - A list of nodes representing inputs, operations, and outputs. Each node tracks data flow and dependencies.

Code Generation - Converts graphs back to executable Python code as GraphModule instances.

from torch.fx import symbolic_trace

traced = symbolic_trace(model)
print(traced.graph)  # View IR
print(traced.code)   # View generated code

torch.export: Ahead-of-Time Compilation

torch.export captures models for deployment without requiring the original Python code. It produces ExportedProgram objects containing normalized ATen operations and metadata.

Workflow:

  1. Trace module with example inputs
  2. Decompose to ATen operators
  3. Serialize to portable format
  4. Deploy with torch._inductor.aoti_compile_and_package()

ep = torch.export.export(model, (example_input,))
torch._inductor.aoti_compile_and_package(ep, package_path="model.pt2")

TorchScript: Legacy Compilation System

TorchScript provides two compilation modes (now deprecated in favor of PT2):

  • Scripting - Compiles Python source code to TorchScript IR
  • Tracing - Records operations during execution

Both support graph optimization, serialization, and deployment but have limited Python feature support.


Optimization Strategies

Fusion - Combine multiple operations into single kernels to reduce memory bandwidth.

Scheduling - Reorder operations to improve cache locality and reduce peak memory.

Autotuning - Search for optimal kernel configurations and operator implementations.

Dead code elimination - Remove unused computations and intermediate buffers.

CUDAGraphs - Capture GPU operations into graphs for reduced CPU overhead.

Choose compilation mode based on your needs: torch.compile() for training/inference with dynamic shapes, torch.export for deployment, or TorchScript for legacy code.

Automatic Differentiation

Relevant Files
  • torch/autograd/__init__.py
  • torch/autograd/function.py
  • torch/autograd/grad_mode.py
  • torch/autograd/graph.py
  • torch/_functorch/aot_autograd.py

PyTorch's automatic differentiation system enables gradient computation through reverse-mode AD (backpropagation). The system builds a computational graph during the forward pass and traverses it backward to compute gradients.

Core Concepts

Computational Graph: During forward computation, PyTorch records operations as nodes in a directed acyclic graph (DAG). Each tensor has a grad_fn attribute pointing to the operation that created it. Leaf tensors (created by users with requires_grad=True) have no grad_fn.

Gradient Accumulation: Gradients flow backward through the graph via the chain rule. Each operation's backward function receives upstream gradients and computes gradients for its inputs.

Key Components

1. Function Class

The Function class enables custom differentiable operations:

from torch.autograd import Function

class MyExp(Function):
    @staticmethod
    def forward(ctx, x):
        result = x.exp()
        ctx.save_for_backward(result)
        return result
    
    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

output = MyExp.apply(input)

Use ctx.save_for_backward() to store tensors needed during backward. The backward() method receives upstream gradients and returns gradients for inputs.
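Custom functions can be validated numerically with torch.autograd.gradcheck, which compares the analytical backward against finite differences. A self-contained sketch with a simple square function:

```python
import torch
from torch.autograd import Function, gradcheck

class Square(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x   # d(x^2)/dx = 2x

# Double-precision inputs are recommended for stable numerical Jacobians
inp = torch.randn(5, dtype=torch.double, requires_grad=True)
ok = gradcheck(Square.apply, (inp,), eps=1e-6, atol=1e-4)
print(ok)   # True
```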

2. Gradient Modes

Control whether gradients are computed:

with torch.no_grad():
    y = x * 2  # y.requires_grad = False

with torch.enable_grad():
    y = x * 2  # y.requires_grad = True (even inside no_grad)

torch.set_grad_enabled(False)  # Global setting

3. Backward Execution


The backward pass:

  1. Starts from output tensors with gradients
  2. Traverses the graph via next_functions edges
  3. Calls each node's backward function with upstream gradients
  4. Accumulates gradients into leaf tensors
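These graph edges can be inspected directly through grad_fn:

```python
import torch

x = torch.ones(3, requires_grad=True)
y = (x * 3).sum()

# The output's grad_fn is the node for the last op; next_functions
# are the upstream edges the backward pass will traverse
print(y.grad_fn)                  # e.g. <SumBackward0 ...>
print(y.grad_fn.next_functions)   # edges leading toward the leaves

# Leaf tensors have no grad_fn; gradients are accumulated into them instead
print(x.grad_fn)                  # None
```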

Advanced Features

Higher-Order Gradients: Set create_graph=True to build a graph of the backward pass:

x = torch.tensor(1.0, requires_grad=True)
y = x ** 2
y.backward(create_graph=True)  # x.grad itself becomes part of a graph, enabling second derivatives

Gradient Checkpointing: Trade computation for memory by recomputing intermediates during backward instead of storing them.

Hooks: Register functions to inspect or modify gradients:

def hook(grad):
    return grad * 2

tensor.register_hook(hook)

Performance Considerations

  • Use torch.no_grad() during inference to skip graph construction
  • Call tensor.detach() to break gradient flow
  • Set retain_graph=False (default) to free intermediate activations after backward
  • Compiled autograd optimizes the backward pass by fusing operations

Neural Network Modules

Relevant Files
  • torch/nn/init.py
  • torch/nn/modules/module.py
  • torch/nn/modules/ (activation, conv, linear, rnn, batchnorm, etc.)
  • torch/nn/functional.py
  • torch/nn/parameter.py

PyTorch's neural network modules provide the building blocks for constructing deep learning models. The architecture is organized around two complementary approaches: stateful modules (classes inheriting from nn.Module) and functional operations (functions in nn.functional).

Core Architecture

Module Base Class: All neural network components inherit from torch.nn.Module, which manages parameters, buffers, and submodules. When you assign a module or parameter as an attribute, PyTorch automatically registers it, enabling automatic differentiation and device management.

import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.fc = nn.Linear(320, 10)
    
    def forward(self, x):
        x = self.conv1(x)
        return self.fc(x.view(x.size(0), -1))

Parameters vs. Buffers: Parameter objects are automatically tracked for gradient computation and optimization. Non-learnable state (such as the running statistics in batch norm) is registered explicitly with register_buffer(); buffers appear in the state_dict and move with the module across devices, but receive no gradients.
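A minimal module showing both kinds of state (names here are illustrative):

```python
import torch
import torch.nn as nn

class Scaler(nn.Module):
    def __init__(self):
        super().__init__()
        # Learnable: auto-registered as a parameter, visible to optimizers
        self.scale = nn.Parameter(torch.ones(1))
        # Non-learnable state: registered explicitly as a buffer
        self.register_buffer("running_mean", torch.zeros(1))

    def forward(self, x):
        return (x - self.running_mean) * self.scale

m = Scaler()
print([n for n, _ in m.named_parameters()])   # ['scale']
print([n for n, _ in m.named_buffers()])      # ['running_mean']
```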

Module Categories

Layers: Core building blocks including:

  • Linear: Linear, Bilinear, LazyLinear
  • Convolutional: Conv1d, Conv2d, Conv3d, ConvTranspose1d/2d/3d
  • Recurrent: RNN, LSTM, GRU, and cell variants (LSTMCell, GRUCell)
  • Embedding: Embedding, EmbeddingBag for discrete input lookup tables

Normalization: Stabilize training by normalizing activations:

  • BatchNorm1d/2d/3d, LayerNorm, GroupNorm, InstanceNorm1d/2d/3d
  • SyncBatchNorm for distributed training

Activation Functions: Non-linearities like ReLU, Sigmoid, Tanh, GELU, LeakyReLU, Softmax, and many others.

Regularization: Dropout, Dropout1d/2d/3d, AlphaDropout for preventing overfitting.

Pooling: MaxPool1d/2d/3d, AvgPool1d/2d/3d, AdaptiveMaxPool, AdaptiveAvgPool.

Containers: Organize modules hierarchically:

  • Sequential: Applies modules in order
  • ModuleList: Holds modules in a list
  • ModuleDict: Holds modules in a dictionary
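For instance, Sequential chains layers without requiring a custom forward:

```python
import torch
import torch.nn as nn

# Children run in registration order: Linear -> ReLU -> Linear
net = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Linear(16, 2),
)
out = net(torch.randn(4, 8))
print(out.shape)   # torch.Size([4, 2])
```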

Functional vs. Module Approach


Modules store parameters and state, ideal for reusable components in models. Functional operations in torch.nn.functional are stateless and require explicit parameter passing, useful for dynamic architectures or one-off operations.

Loss Functions

Located in torch.nn, loss modules compute training objectives:

  • Classification: CrossEntropyLoss, BCELoss, BCEWithLogitsLoss
  • Regression: MSELoss, L1Loss, SmoothL1Loss
  • Ranking: MarginRankingLoss, TripletMarginLoss
  • Specialized: CTCLoss, KLDivLoss, CosineEmbeddingLoss

Lazy Modules

Lazy variants (LazyLinear, LazyConv2d, LazyBatchNorm1d) infer input dimensions on first forward pass, simplifying model definition when input shapes are unknown.
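A short sketch of the deferred initialization:

```python
import torch
import torch.nn as nn

lazy = nn.LazyLinear(out_features=4)       # in_features not known yet
print(type(lazy.weight).__name__)          # UninitializedParameter

out = lazy(torch.randn(3, 7))              # first forward infers in_features=7
print(lazy.in_features, lazy.weight.shape) # 7 torch.Size([4, 7])
```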

Quantization & Fusion

For production optimization, PyTorch provides:

  • Quantized modules in torch.ao.nn.quantized for reduced precision inference
  • Fused modules in torch.nn.intrinsic combining operations (e.g., ConvBnReLU2d) for efficiency
  • QAT modules in torch.ao.nn.qat for quantization-aware training

Distributed Training & Communication

Relevant Files
  • torch/distributed/__init__.py
  • torch/distributed/distributed_c10d.py
  • torch/distributed/fsdp/fully_sharded_data_parallel.py
  • torch/distributed/tensor/_api.py
  • torch/distributed/device_mesh.py

PyTorch's distributed training system enables efficient multi-GPU and multi-node training through collective communication primitives and advanced parallelism strategies. The core abstraction is the ProcessGroup, which manages communication between ranks (processes) using backends like NCCL, Gloo, and MPI.

Core Concepts

ProcessGroup & Backends: A ProcessGroup represents a set of processes that can communicate. The init_process_group() function initializes the default group, specifying a backend (NCCL for GPU, Gloo for CPU/GPU, MPI for HPC). Each backend implements collective operations efficiently for its target hardware.

Collective Operations: PyTorch provides standard collective communication primitives:

  • all_reduce: Reduces tensors across all ranks and broadcasts result to all
  • all_gather: Gathers tensors from all ranks into a single tensor
  • reduce_scatter: Reduces tensors then scatters results to individual ranks
  • broadcast: Sends tensor from one rank to all others
  • barrier: Synchronizes all ranks
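The API shape can be exercised even in a single-process world of size 1 (a sketch; real jobs launch one process per rank, e.g. with torchrun, and all_reduce then combines values across ranks):

```python
import tempfile
import torch
import torch.distributed as dist

# A world_size=1 group on the CPU-friendly gloo backend, purely to
# demonstrate the API; with a single rank, all_reduce is the identity
store = dist.FileStore(tempfile.mktemp(), 1)
dist.init_process_group("gloo", store=store, rank=0, world_size=1)

t = torch.ones(3)
dist.all_reduce(t, op=dist.ReduceOp.SUM)   # sums tensors across ranks
print(t)                                   # tensor([1., 1., 1.])

dist.destroy_process_group()
```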

Distributed Tensor (DTensor)

DTensor provides a single-program, multiple-data (SPMD) abstraction for distributed computation. It describes how tensors are distributed across a DeviceMesh using Placement specifications:

  • Shard(dim): Tensor sharded along dimension dim across mesh devices
  • Replicate(): Tensor replicated on all mesh devices
  • Partial(): Tensor pending reduction across mesh devices

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

mesh = init_device_mesh("cuda", (4,))
tensor = torch.randn(100, 88)
# Shard tensor's 0th dimension across 4 GPUs
sharded = distribute_tensor(tensor, mesh, [Shard(0)])

Fully Sharded Data Parallel (FSDP)

FSDP shards model parameters, gradients, and optimizer states across data-parallel workers to reduce memory usage. The training loop follows this pattern:

  1. Before forward: All-gather sharded parameters to reconstruct full tensors
  2. Forward pass: Compute with full parameters
  3. After forward: Free unsharded parameters (if reshard_after_forward=True)
  4. Backward pass: Recompute or all-gather parameters again
  5. After backward: Reduce-scatter gradients back to sharded form

FSDP2 (fully_shard) uses per-parameter DTensor sharding for cleaner semantics than FSDP1's flat-parameter approach, enabling better integration with tensor parallelism and simpler state dict handling.

Communication Patterns

Collective operations are typically synchronous (blocking) but support async_op=True for non-blocking execution. Async operations return a Work object that can be waited on later:

work = dist.all_reduce(tensor, async_op=True)
# Do other work
work.wait()

Coalescing optimizes multiple small collectives into fewer large ones, reducing overhead. The _coalescing_manager context manager batches operations for efficient execution.

Device Mesh & Multi-Dimensional Parallelism

DeviceMesh enables multi-dimensional parallelism by organizing devices into n-dimensional grids. This supports:

  • Data Parallelism: Replicate model, shard data
  • Tensor Parallelism: Shard model parameters across devices
  • HSDP (Hybrid Sharded Data Parallel): Shard parameters within each replica group and replicate across groups

Best Practices

  • Use NCCL for GPU-to-GPU communication (fastest)
  • Enable async operations for overlapping computation and communication
  • Use FSDP for memory-efficient training of large models
  • Leverage DTensor for flexible multi-dimensional parallelism
  • Monitor collective operation timing to identify communication bottlenecks

Quantization & Model Optimization

Relevant Files
  • torch/ao/quantization - Core quantization APIs and observers
  • torch/ao/quantization/fx - FX graph mode quantization
  • torch/ao/pruning - Pruning and sparsification utilities
  • torch/ao/nn/quantized - Quantized module implementations
  • torch/quantization - Legacy quantization APIs (deprecated)

PyTorch provides comprehensive tools for model optimization through quantization and pruning, enabling efficient inference and reduced memory footprint.

Quantization Overview

Quantization reduces model size and accelerates inference by representing weights and activations with lower precision (e.g., int8 instead of float32). PyTorch supports three main quantization approaches:

  1. Eager Mode Quantization - Direct module replacement in Python
  2. FX Graph Mode Quantization - Automated graph-level transformations
  3. PT2 Export Quantization (PT2E) - Export-based flow, since migrated to the torchao library

Quantization Workflow


The standard workflow involves three steps:

  1. Prepare: Insert observers to measure activation and weight ranges
  2. Calibrate/Train: Run data through the model to collect statistics
  3. Convert: Replace float operations with quantized equivalents

QConfig and Observers

QConfig specifies how to quantize a layer by defining observer classes for activations and weights:

from torch.ao.quantization import QConfig, MinMaxObserver

qconfig = QConfig(
    activation=MinMaxObserver.with_args(quant_min=0, quant_max=127),
    weight=MinMaxObserver.with_args(dtype=torch.qint8)
)

Key observer types include:

  • MinMaxObserver - Tracks min/max values for affine quantization
  • PerChannelMinMaxObserver - Per-channel quantization for weights
  • MovingAverageMinMaxObserver - Exponential moving average for dynamic ranges
  • HistogramObserver - Entropy-based calibration
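Observers are callable modules: running data through one accumulates range statistics, from which the quantization parameters are derived.

```python
import torch
from torch.ao.quantization import MinMaxObserver

obs = MinMaxObserver(dtype=torch.quint8)
obs(torch.tensor([-1.0, 0.0, 2.0]))   # observe a batch
obs(torch.tensor([0.5, 3.0]))         # running min=-1.0, max=3.0

# Derive scale/zero-point from the observed range
scale, zero_point = obs.calculate_qparams()
print(scale.item(), zero_point.item())
```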

Eager Mode Quantization

Direct API for quantizing models in Python:

from torch.ao.quantization import quantize, quantize_dynamic

# Static quantization (requires calibration)
quantized_model = quantize(model, run_fn, run_args)

# Dynamic quantization (weights-only, no calibration)
quantized_model = quantize_dynamic(model)
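Dynamic quantization is runnable end-to-end on CPU; a self-contained sketch with a toy model:

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2))

# Quantize only the Linear layers: int8 weights, activations quantized on the fly
qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

out = qmodel(torch.randn(4, 16))
print(out.shape)   # torch.Size([4, 2])
```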

FX Graph Mode Quantization

Automated quantization via symbolic tracing:

from torch.ao.quantization import prepare_fx, convert_fx

prepared = prepare_fx(model, qconfig_mapping, example_inputs)
# Calibrate with data
quantized = convert_fx(prepared)

Advantages: automatic fusion, pattern matching, backend-aware optimization.

Pruning and Sparsification

PyTorch provides structured and unstructured pruning through sparsifiers:

from torch.ao.pruning import WeightNormSparsifier

sparsifier = WeightNormSparsifier(sparsity_level=0.5)
sparsifier.prepare(model, config)
sparsifier.step()  # Update masks
sparsifier.convert(model, mapping)

Key sparsifiers:

  • WeightNormSparsifier - Block-wise sparsity based on weight norms
  • NearlyDiagonalSparsifier - Structured patterns for specific architectures
  • SaliencyPruner - Structured pruning via saliency scores (experimental)

Quantized Modules

Quantized implementations in torch.ao.nn.quantized:

  • Linear, Conv1d/2d/3d - Quantized linear and convolutional layers
  • LSTM - Quantized recurrent layers
  • Embedding, EmbeddingBag - Quantized embeddings
  • BatchNorm2d/3d - Quantized normalization

Quantization Schemes

Supported quantization schemes:

  • torch.per_tensor_affine - Single scale/zero-point per tensor
  • torch.per_channel_affine - Per-channel scale/zero-point (weights)
  • torch.per_tensor_symmetric - Symmetric quantization around zero
  • torch.per_channel_symmetric - Per-channel symmetric quantization

Data Types

Common quantization data types:

  • torch.qint8 - Signed 8-bit integer
  • torch.quint8 - Unsigned 8-bit integer
  • torch.float16 - Half-precision floating point
  • torch.quint4x2 - 4-bit quantization (experimental)

Best Practices

  • Use prepare_fx for new code; eager mode is deprecated
  • Calibrate with representative data for static quantization
  • Combine quantization with pruning for maximum compression
  • Test quantized models thoroughly for accuracy degradation
  • Use backend_config to specify hardware-specific optimizations

Device Backends & Acceleration

Relevant Files
  • torch/cuda/__init__.py
  • torch/backends/
  • torch/cpu/
  • torch/xpu/
  • torch/mps/
  • torch/_dynamo/device_interface.py
  • c10/core/DeviceType.h
  • c10/core/Backend.h
  • aten/src/ATen/DeviceAccelerator.h

PyTorch supports multiple hardware backends through a unified device abstraction layer. Each backend (CPU, CUDA, XPU, MPS, MTIA) provides device-specific implementations while maintaining a consistent API for tensor operations and acceleration features.

Device Types & Backends

PyTorch defines device types in c10/core/DeviceType.h as an enum with entries for CPU, CUDA, HIP, XPU, MPS, Metal, Vulkan, and others. Each device type can have multiple backends (e.g., SparseCUDA, QuantizedCUDA). The mapping between device types and backends is managed through the Backend enum in c10/core/Backend.h.

Key device types:

  • CPU: Default backend for CPU-based computation
  • CUDA: NVIDIA GPU acceleration (also supports HIP for AMD GPUs)
  • XPU: Intel GPU acceleration
  • MPS: Apple Metal Performance Shaders for macOS
  • MTIA: Meta Training and Inference Accelerator

Backend Registration System

Backends are registered through a dispatch key system. The DispatchKey enum identifies specific implementations, and the dispatcher routes operations to the correct kernel based on tensor properties. Backend registration happens in torch/_inductor/codegen/common.py via register_backend_for_device(), which associates each device with scheduling strategies and code generation wrappers.


# Example: CUDA backend registration
register_backend_for_device(
    "cuda",
    lambda scheduling: CUDACombinedScheduling(scheduling),
    PythonWrapperCodegen,
    CppWrapperGpu,
    WrapperFxCodegen,
)

Device Interface Abstraction

torch/_dynamo/device_interface.py provides a unified DeviceInterface base class that all backends implement. This enables device-agnostic code in TorchDynamo and Inductor. Key methods include:

  • current_device(): Get active device index
  • set_device(): Switch to a device
  • device_count(): Query available devices
  • Stream & Event: Asynchronous execution primitives
  • Worker: Multi-process device property caching

Backend-Specific Modules

Each backend has a dedicated module under torch/backends/ and torch/ with device-specific APIs:

  • torch.cuda: Memory management, streams, graphs, profiling
  • torch.backends.cudnn: cuDNN configuration and optimization
  • torch.backends.mps: Metal backend configuration
  • torch.xpu: XPU-specific utilities
  • torch.cpu: CPU-specific features like AMP

Accelerator Concept

aten/src/ATen/DeviceAccelerator.h defines the accelerator concept for backends that support asynchronous compute via streams and events. Accelerator devices include CUDA, MTIA, XPU, HIP, MPS, and PrivateUse1. The getAccelerator() function determines the active accelerator, enabling device-agnostic synchronization and memory management.

Custom Backend Registration

External backends can hook in through the PrivateUse1 mechanism in torch.utils.backend_registration (e.g. rename_privateuse1_backend() to give the backend a custom name). Custom backends must implement required APIs like is_available(), current_device(), and optionally AMP support through get_amp_supported_dtype().

Code Generation & IR

Relevant Files
  • torch/fx/graph.py
  • torch/fx/graph_module.py
  • torch/fx/node.py
  • torch/_inductor/ir.py
  • torch/_inductor/lowering.py
  • torch/_inductor/codegen/wrapper.py
  • torch/_inductor/codegen/wrapper_fxir.py
  • torch/_inductor/compile_fx.py

PyTorch's code generation pipeline transforms high-level neural network code into optimized machine code through multiple intermediate representations. This section covers FX graphs, inductor IR, and the lowering process.

FX: Functional Transformation Framework

FX is PyTorch's symbolic tracing and transformation toolkit. It captures program semantics without executing real data by feeding proxy objects through code and recording operations.

Three core components:

  1. Symbolic Tracer - Records operations on proxy objects to build a computation graph
  2. Graph IR - A list of Node objects representing inputs, operations, and outputs
  3. Code Generation - Converts graphs back to executable Python code as GraphModule instances
from torch.fx import symbolic_trace

traced = symbolic_trace(model)
print(traced.graph)   # View IR
print(traced.code)    # View generated code

Graph and Node Structure

A Graph is a linked list of Node objects. Each node represents a callsite or syntactic construct with an operation type (op), target, arguments, and keyword arguments.

Node operation types:

  • placeholder - Function input
  • call_function - Call to a Python function or operator
  • call_method - Method invocation on an object
  • call_module - Invocation of a submodule
  • get_attr - Attribute access
  • output - Return value

graph = torch.fx.Graph()
x = graph.placeholder("x")
y = graph.call_function(torch.relu, args=(x,))
graph.output(y)
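Wrapping such a graph in a GraphModule regenerates Python code and makes it callable:

```python
import torch
import torch.fx as fx

graph = fx.Graph()
x = graph.placeholder("x")
y = graph.call_function(torch.relu, args=(x,))
graph.output(y)

# GraphModule pairs the graph with a module root and emits real Python code
gm = fx.GraphModule(torch.nn.Module(), graph)
print(gm.code)                        # the regenerated forward()
out = gm(torch.tensor([-1.0, 2.0]))
print(out)                            # tensor([0., 2.])
```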

Inductor IR: Low-Level Representation

After FX graphs are created, the Inductor compiler lowers them to an intermediate representation optimized for code generation. The IR hierarchy models tensor storage and views:

  • TensorBox - Top-level tensor representation
  • StorageBox - Introduces layout information
  • Buffer - Simple 1D allocation
  • View - Metadata-only transformations (transpose, reshape)

TensorBox -> StorageBox -> Buffer
TensorBox -> View -> StorageBox -> Buffer

Lowering: Operations to IR

The lowering process converts ATen operations to inductor IR nodes. The @register_lowering decorator maps operations to decomposition functions that produce IR nodes like Pointwise, Reduction, or Scan.

@register_lowering(aten.add, broadcast=True)
def add(a, b):
    return Pointwise.create(
        device=a.get_device(),
        dtype=a.get_dtype(),
        inner_fn=lambda x, y: x + y,
        ranges=broadcast_shapes(a.get_size(), b.get_size()),
        inputs=[a, b],
    )

Code Generation Pipeline


The wrapper codegen converts IR nodes to kernel code. For each operation, the codegen selects appropriate backends (Triton for GPU, C++ for CPU) and generates optimized kernels with memory planning and fusion.

Utilities & Development Tools

Relevant Files
  • torch/utils/__init__.py
  • torch/utils/data/ (DataLoader, Datasets, Samplers, DataPipes)
  • torch/utils/checkpoint.py
  • torch/utils/benchmark/
  • torch/profiler/
  • torch/testing/
  • tools/ (Build, code generation, testing infrastructure)

PyTorch provides a comprehensive suite of utilities and development tools that support data loading, performance analysis, testing, and model optimization. These tools are essential for both research and production workflows.

Data Loading & Processing

The torch.utils.data module provides the core infrastructure for efficient data handling:

  • DataLoader - Handles batching, shuffling, and parallel data loading with multiprocessing support. Supports both map-style and iterable datasets with customizable collation functions.
  • Datasets - Base classes (Dataset, IterableDataset) and utilities like ConcatDataset, ChainDataset, Subset, and TensorDataset for composing data sources.
  • Samplers - Control data ordering: SequentialSampler, RandomSampler, BatchSampler, WeightedRandomSampler, and DistributedSampler for multi-GPU training.
  • DataPipes - Functional API for composing data transformations with IterDataPipe and MapDataPipe for declarative data pipelines.

Memory & Computation Optimization

Gradient Checkpointing (torch.utils.checkpoint) trades computation for memory by recomputing activations during backpropagation instead of storing them. Key features include:

  • checkpoint() - Wraps functions to enable selective activation checkpointing
  • checkpoint_sequential() - Applies checkpointing to sequential module layers
  • SelectiveCheckpointContext - Fine-grained control over which operations to checkpoint
  • Debug modes and policy configurations for advanced use cases

Performance Analysis

Profiler (torch.profiler) provides comprehensive performance metrics:

  • profile() - Context manager for collecting CPU, GPU, and memory metrics
  • record_function() - Manual annotation of code regions for profiling
  • schedule() - Control profiling behavior (e.g., warmup, active steps)
  • tensorboard_trace_handler() - Export traces for visualization
  • Memory profiling and FLOP counting capabilities

Benchmark Utilities (torch.utils.benchmark) enable reproducible performance testing:

  • Timer - Precise timing with automatic warmup and statistical analysis
  • Fuzzing tools for operator testing
  • Comparison utilities for before/after performance analysis

Testing Infrastructure

Public Testing API (torch.testing) provides:

  • assert_close() - Tensor comparison with tolerance handling (the older assert_allclose() is deprecated in its favor)
  • make_tensor() - Generate test tensors with specified properties
  • FileCheck - Pattern matching for JIT and compiler output validation
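Both helpers are quick to use:

```python
import torch
from torch.testing import assert_close, make_tensor

# make_tensor constrains device, dtype, and value range
a = make_tensor((2, 3), device="cpu", dtype=torch.float32, low=0, high=1)

# assert_close checks values (with dtype/device matching) under default tolerances
assert_close(a, a.clone())
assert_close(torch.tensor(1.0), torch.tensor(1.0 + 1e-9))
```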

Internal Testing (torch.testing._internal) includes:

  • Device-specific test utilities (common_cuda.py, common_device_type.py)
  • Distributed training test helpers (common_distributed.py, common_fsdp.py)
  • Operator metadata and test generation (opinfo/)
  • Hypothesis-based property testing utilities

Development Tools

The tools/ directory contains build and code generation infrastructure:

  • Autograd Generation - tools/autograd/ generates backward functions from YAML specifications
  • Code Analysis - Operator registration, linting, and static analysis
  • Build System - CMake, Bazel, and Buck integration helpers
  • Testing Infrastructure - Test discovery, statistics collection, and CI integration
  • Code Coverage - Coverage tracking and reporting plugins

# Example: DataLoader with custom collation
from torch.utils.data import DataLoader, Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

loader = DataLoader(
    CustomDataset(data),
    batch_size=32,
    shuffle=True,
    num_workers=4
)

# Example: Gradient checkpointing
from torch.utils.checkpoint import checkpoint

def forward(x):
    # use_reentrant=False selects the recommended non-reentrant implementation
    return checkpoint(expensive_layer, x, use_reentrant=False)

# Example: Performance profiling
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("model_inference"):
        output = model(input_data)

These utilities form the backbone of PyTorch's ecosystem, enabling efficient data pipelines, reproducible research, and production-grade performance optimization.