dmlc/xgboost

XGBoost Wiki

Last updated on Dec 19, 2025 (Commit: b31e279)

Overview

Relevant Files
  • README.md
  • python-package/xgboost/__init__.py
  • include/xgboost/base.h
  • include/xgboost/learner.h
  • src/learner.cc
  • src/gbm/gbtree.cc
  • src/predictor/predictor.cc

XGBoost (eXtreme Gradient Boosting) is an optimized distributed gradient boosting library designed for efficiency, flexibility, and portability. It implements machine learning algorithms under the Gradient Boosting framework, providing parallel tree boosting (GBDT/GBM) that solves data science problems at scale—from small datasets to problems with billions of examples.

Core Architecture

XGBoost follows a modular architecture with a clear separation of concerns.

Key Components

Learner (include/xgboost/learner.h, src/learner.cc): The central orchestrator that integrates training and prediction. It manages the objective function, gradient booster, and evaluation metrics. The learner handles one iteration of boosting via UpdateOneIter(), which computes gradients and updates the model.

Gradient Booster (src/gbm/gbtree.cc): Implements the tree boosting algorithm. It creates new trees each iteration, updates leaf weights, and manages the ensemble. Supports both single-target and multi-target models with configurable strategies.

Data Matrix (python-package/xgboost/core.py): Efficient data representation supporting dense, sparse, and external memory formats. Includes DMatrix, QuantileDMatrix, and ExtMemQuantileDMatrix for different use cases.

Objective Functions (src/objective/): Compute loss gradients and Hessians for regression, classification, ranking, and survival tasks. Examples include logistic loss, softmax, and LambdaRank.

Metrics (src/metric/): Evaluate model performance with measures such as AUC, logloss, RMSE, and ranking metrics. Computed during training for monitoring and early stopping.

Predictors (src/predictor/): Execute inference on trained models. Supports CPU and GPU prediction, with specialized implementations for leaf indices, feature contributions, and interaction contributions.

Training Flow

  1. Initialize: Create learner with data and parameters
  2. Per Iteration: Compute predictions → Calculate gradients → Build trees → Update leaves
  3. Evaluate: Compute metrics on validation sets
  4. Finalize: Return trained booster model
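
The same flow is visible from the Python API; a minimal sketch with a synthetic dataset (the parameter values are illustrative):

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X, y = rng.random((100, 10)), rng.integers(0, 2, size=100)
dtrain = xgb.DMatrix(X, label=y)

# Each boosting round corresponds to one UpdateOneIter: predict, compute
# gradients, build a tree, then evaluate the metrics on the `evals` sets.
booster = xgb.train(
    {"objective": "binary:logistic", "eval_metric": "logloss"},
    dtrain,
    num_boost_round=10,
    evals=[(dtrain, "train")],
)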

Distributed Training

XGBoost supports distributed training via collective communication (src/collective/), enabling training on Kubernetes, Hadoop, Spark, and Dask. The framework handles gradient aggregation and model synchronization across workers.

Language Bindings

  • Python: Full-featured API with scikit-learn compatibility (XGBClassifier, XGBRegressor)
  • R: Native R package with formula interface
  • C++: Core implementation with C API for language interoperability
  • JVM: Spark and Flink integrations for big data platforms
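
As a quick illustration of the Python binding, the scikit-learn wrappers behave like ordinary estimators; a minimal sketch using a synthetic dataset:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier

X, y = make_classification(n_samples=500, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# The wrapper drives the same C++ Learner as the core xgboost.train API.
clf = XGBClassifier(n_estimators=50, max_depth=4, eval_metric="logloss")
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
print(clf.score(X_test, y_test))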

Architecture & Core Components

Relevant Files
  • include/xgboost/learner.h
  • include/xgboost/gbm.h
  • include/xgboost/data.h
  • include/xgboost/objective.h
  • include/xgboost/metric.h
  • src/learner.cc

XGBoost's architecture follows a modular, layered design that separates concerns into distinct components. The system is built around the Learner class, which orchestrates training and prediction by coordinating three core subsystems: the Gradient Booster, Objective Function, and Evaluation Metrics.

Core Components

Learner is the main user-facing interface. It manages the entire training pipeline and exposes methods like UpdateOneIter() for training iterations, Predict() for inference, and EvalOneIter() for model evaluation. The Learner holds three key internal components:

  • obj_ - The objective function (e.g., regression, classification)
  • gbm_ - The gradient booster (e.g., tree ensemble)
  • metrics_ - A vector of evaluation metrics

GradientBooster is the core algorithm engine. It implements the boosting logic through DoBoost(), which updates the model using gradient statistics. Different booster types (gbtree, gblinear) are registered via the XGBOOST_REGISTER_GBM macro. The booster also handles prediction through PredictBatch() and specialized methods like PredictLeaf() and PredictContribution() for feature importance.

ObjFunction computes gradients and applies transformations. It implements GetGradient() to compute first and second-order gradients from predictions and labels, and PredTransform() to apply inverse link functions (e.g., sigmoid for binary classification). The objective also determines the task type (regression, classification, ranking).

Metric evaluates model performance. Each metric implements Evaluate() to compute a score given predictions and a DMatrix. Metrics are independent of training and used purely for monitoring.

DMatrix is the data abstraction layer. It holds feature data in MetaInfo (labels, weights, groups) and provides batch access through templated GetBatches() methods. DMatrix supports multiple storage formats: sparse pages (CSR), column-compressed (CSC), and GPU-optimized formats (Ellpack, GHistIndex).

Data Flow

Each training iteration moves data between these components: the Learner hands current predictions to the ObjFunction, passes the resulting gradients to the GradientBooster, and feeds updated predictions to the Metrics, with feature data served through the DMatrix.

Training Loop

Each iteration follows this sequence:

  1. Gradient Computation: Learner calls ObjFunction::GetGradient() with current predictions to compute gradients and Hessians
  2. Boosting Update: Learner calls GradientBooster::DoBoost() with gradient statistics to fit a new tree/model
  3. Prediction: Learner calls GradientBooster::PredictBatch() to generate predictions for the next iteration
  4. Evaluation: Learner calls each Metric::Evaluate() to monitor progress
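
The gradient step can be mirrored from Python with a custom objective callable; the sketch below reproduces what a squared-error ObjFunction::GetGradient computes, not the C++ implementation itself:

import numpy as np
import xgboost as xgb

def squared_error_obj(preds, dtrain):
    # Python analogue of step 1: first- and second-order gradients per sample.
    labels = dtrain.get_label()
    grad = preds - labels          # d/dpred of 1/2 * (pred - label)^2
    hess = np.ones_like(preds)     # second derivative is constant
    return grad, hess

rng = np.random.default_rng(0)
X = rng.random((200, 5))
y = X @ np.arange(1.0, 6.0)
dtrain = xgb.DMatrix(X, label=y)

# Each round: predict -> custom objective (gradients) -> DoBoost -> evaluate.
booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=5, obj=squared_error_obj)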

Model Parameters

LearnerModelParam stores global model metadata: number of features, output groups (for multi-class), task type, and base score. This is distinct from training parameters and is serialized with the model. The MultiStrategy enum controls whether multi-target models use one output per tree or vector leaves.

Key Design Patterns

  • Registry Pattern: Objectives, metrics, and boosters are registered dynamically via macros (XGBOOST_REGISTER_*), enabling plugin-like extensibility
  • Batch Processing: DMatrix provides templated batch iterators for efficient memory management and GPU acceleration
  • Separation of Concerns: Training (Learner), boosting (GradientBooster), loss (ObjFunction), and evaluation (Metric) are independent, allowing flexible composition

Tree Building & Updaters

Relevant Files
  • include/xgboost/tree_model.h
  • include/xgboost/tree_updater.h
  • src/tree/updater_colmaker.cc
  • src/tree/updater_approx.cc
  • src/tree/updater_quantile_hist.cc
  • src/gbm/gbtree.h

Tree Model Structure

XGBoost's tree building system centers on the RegTree class, which represents a regression tree with a compact node-based structure. Each tree consists of:

  • Nodes: Internal split nodes and leaf nodes stored in a flat vector
  • Node Statistics: Loss change, hessian sum, and base weight for each node
  • Split Information: Feature index, split condition, and default direction for missing values
  • Categorical Support: Optional bitsets for categorical feature splits

The RegTree::Node class uses bit-packing to store parent-child relationships and split metadata efficiently. The highest bit of the split index encodes the default direction for missing values, while the parent field's highest bit indicates whether a node is a left or right child.

Node Expansion and Allocation

When a leaf node is split, the ExpandNode() method allocates two new child nodes and updates the parent's split information. The tree maintains a pool of deleted nodes for reuse, avoiding repeated memory reallocation. Each node stores:

  • Left and right child indices
  • Parent index with left-child flag
  • Split feature index and condition value
  • Leaf value (for leaves) or split condition (for internal nodes)

Tree Updater Architecture

Tree updaters implement the TreeUpdater interface and are responsible for constructing or modifying trees. The updater pattern allows modular composition of tree-building strategies. Key updaters include:

  • grow_colmaker (Exact): Column-wise enumeration of all split points. Fast for small datasets but doesn't scale to distributed training.
  • grow_histmaker (Approx): Approximate histogram-based construction with global histogram proposals. Supports distributed training.
  • grow_quantile_histmaker (Hist): Quantized histogram construction with local histogram building. Default for most use cases.
  • prune: Post-processing updater that removes splits with negative loss change.
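
These updaters are normally selected indirectly through the tree_method parameter rather than by name; a sketch of that mapping (setting the updater parameter explicitly is an advanced option):

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
dtrain = xgb.DMatrix(rng.random((300, 8)), label=rng.random(300))

# "exact" -> grow_colmaker, "approx" -> grow_histmaker, "hist" -> grow_quantile_histmaker.
for method in ("exact", "approx", "hist"):
    xgb.train({"tree_method": method, "max_depth": 4}, dtrain, num_boost_round=2)

# The updater sequence can also be set explicitly, e.g. exact growth plus pruning.
xgb.train({"updater": "grow_colmaker,prune", "max_depth": 4}, dtrain, num_boost_round=2)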

Tree Building Pipeline

Starting from the gradient statistics of each iteration, an updater builds histograms (for the approx/hist methods), enumerates candidate splits, evaluates them, and expands leaves until depth, gain, or weight constraints stop further growth.

Split Evaluation and Loss Calculation

Updaters evaluate candidate splits using gradient statistics. For each potential split, they compute:

  • Left and right child gradient sums
  • Loss change (gain) from the split
  • Leaf weights based on gradient statistics

The split with maximum loss change is selected, subject to constraints like minimum child weight and interaction constraints. The TreeEvaluator class handles weight calculation and split scoring.
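
The gain and leaf-weight formulas from the original XGBoost paper can be written down directly; a small sketch, where reg_lambda and gamma correspond to the public regularization parameters:

def split_gain(grad_left, hess_left, grad_right, hess_right, reg_lambda=1.0, gamma=0.0):
    # Loss reduction from splitting a node with totals (GL + GR, HL + HR)
    # into children (GL, HL) and (GR, HR); the updater keeps the maximum.
    def score(g, h):
        return g * g / (h + reg_lambda)
    parent = score(grad_left + grad_right, hess_left + hess_right)
    return 0.5 * (score(grad_left, hess_left) + score(grad_right, hess_right) - parent) - gamma

def leaf_weight(grad_sum, hess_sum, reg_lambda=1.0):
    # Optimal leaf value for the accumulated gradient statistics.
    return -grad_sum / (hess_sum + reg_lambda)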

Distributed and GPU Support

The histogram-based updaters support distributed training through collective communication for histogram aggregation. GPU variants (grow_gpu_hist) parallelize histogram building and split evaluation on CUDA devices. The updater registry system allows seamless switching between CPU and GPU implementations.

Data Loading & DMatrix

Relevant Files
  • include/xgboost/data.h
  • src/data/data.cc
  • src/data/adapter.h
  • src/data/simple_dmatrix.h
  • src/data/quantile_dmatrix.h
  • src/data/iterative_dmatrix.cc
  • src/data/sparse_page_dmatrix.h
  • python-package/xgboost/data.py

XGBoost's data loading system centers on DMatrix, an abstraction layer that provides uniform access to diverse data sources while supporting multiple internal storage formats optimized for different algorithms.

Core Concepts

DMatrix is the primary data structure holding feature data and metadata. It stores features in sparse format (CSR) and metadata in MetaInfo, which includes labels, weights, sample groups (for ranking), base margins, and feature information. The key design principle is separating data format concerns from algorithm concerns through adapters and batch iterators.

MetaInfo contains all auxiliary information: num_row_, num_col_, num_nonzero_, labels, weights, feature names/types, and categorical information. It supports serialization and slicing operations for distributed training.

Data Adapters

Adapters provide a uniform interface for reading external data formats. Each adapter implements a batch-based iterator pattern, yielding COO tuples (row, column, value) without requiring the entire dataset in memory. This abstraction allows DMatrix constructors to remain format-agnostic.

Supported adapters include:

  • FileAdapter: Reads LIBSVM and CSV formats from disk
  • ArrayAdapter: Handles dense NumPy arrays and similar structures
  • SparseAdapter: Processes scipy CSR/CSC matrices
  • DataFrameAdapter: Converts pandas DataFrames and cuDF DataFrames
  • ArrowAdapter: Supports PyArrow tables

DMatrix Variants

SimpleDMatrix stores all data in memory as a single SparsePage (CSR format). It's the default for small datasets and supports on-demand generation of column-compressed (CSC), sorted CSC, Ellpack, and gradient index formats.

QuantileDMatrix is a base class for quantile-based storage, storing quantile cuts and histogram indices instead of raw features. It's optimized for the hist tree method and supports both CPU (CSR + CSC) and GPU (Ellpack) formats.

IterativeDMatrix processes streaming data on CPU by iteratively walking through batches to compute quantiles, then building a gradient index. This reduces memory usage by avoiding data concatenation.

ExtMemQuantileDMatrix combines streaming iteration with GPU support, caching gradient indices on disk and prefetching them during training.

SparsePageDMatrix handles external memory by splitting data into multiple cached pages, with async prefetching to overlap I/O with computation.

Batch Access Pattern

DMatrix provides templated GetBatches<T>() methods returning BatchSet<T> objects for range-based iteration:

for (auto& batch : dmatrix->GetBatches<SparsePage>()) {
  // Process batch
}

Supported batch types include SparsePage (CSR), CSCPage (column-compressed), EllpackPage (GPU-optimized), and GHistIndexMatrix (preprocessed histogram indices). Each DMatrix implementation provides efficient batch generation, with lazy creation and caching where appropriate.

Quantization Pipeline

For histogram-based training, XGBoost quantizes features into bins. The pipeline:

  1. Sketch Phase: Compute approximate quantile cuts from data stream
  2. Index Phase: Map feature values to bin indices using cuts
  3. Batch Generation: Provide indexed batches to tree builder

QuantileDMatrix stores cuts and indices, enabling efficient histogram computation without accessing raw features. This is critical for GPU training where data may reside on device.
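
From Python, the same pipeline runs when a QuantileDMatrix is constructed; a sketch, where max_bin bounds the number of quantile cuts per feature:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X, y = rng.random((10_000, 20)), rng.random(10_000)

# Sketch and index phases happen here; raw feature values are not kept around.
qdm = xgb.QuantileDMatrix(X, label=y, max_bin=256)
booster = xgb.train({"tree_method": "hist"}, qdm, num_boost_round=10)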

Python Integration

The Python API dispatches data through xgboost/data.py, which detects input types (NumPy, pandas, cuDF, PyArrow, scipy sparse) and routes them to appropriate adapters. The DMatrix constructor accepts data, label, weight, group, and other metadata parameters, automatically handling type conversions and validation.
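
A sketch of that entry point (the column names, weights, and the categorical column are illustrative):

import numpy as np
import pandas as pd
import xgboost as xgb

rng = np.random.default_rng(0)
df = pd.DataFrame({
    "age": rng.integers(18, 80, size=100),
    "colour": pd.Categorical(rng.choice(["red", "green", "blue"], size=100)),
})
y, w = rng.random(100), np.ones(100)

# data.py detects the pandas input, dispatches it to the matching adapter,
# and attaches the metadata; enable_categorical preserves the category dtype.
dtrain = xgb.DMatrix(df, label=y, weight=w, enable_categorical=True)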

Prediction & Inference

Relevant Files
  • include/xgboost/predictor.h
  • src/predictor/cpu_predictor.cc
  • src/predictor/gpu_predictor.cu
  • src/predictor/treeshap.cc
  • src/predictor/treeshap.h
  • src/predictor/predict_fn.h
  • src/predictor/gbtree_view.h
  • src/predictor/array_tree_layout.h

Prediction Architecture

XGBoost's prediction system is designed for high-performance inference on trained models. The Predictor abstract base class defines the interface for generating predictions, with specialized implementations for CPU and GPU execution. Predictions are computed by traversing each tree in the ensemble and accumulating leaf values.

Core Prediction Flow

The prediction process follows these steps:

  1. Initialization: InitOutPredictions() allocates output vectors and initializes them with base scores or base margins from the model
  2. Batch Processing: PredictBatch() processes multiple samples, traversing trees and accumulating predictions
  3. Tree Traversal: For each sample, the predictor navigates from root to leaf using feature values and split conditions
  4. Leaf Accumulation: Leaf values are added to the running prediction sum
  5. Output: Final predictions are returned as a vector or matrix (for multi-output models)
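
From the Python API these steps are driven by Booster.predict; a sketch, where iteration_range limits which trees are traversed and output_margin skips the objective's transform:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
dtrain = xgb.DMatrix(rng.random((500, 10)), label=rng.integers(0, 2, size=500))
booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=20)

dtest = xgb.DMatrix(rng.random((5, 10)))
probs = booster.predict(dtest)                              # full ensemble
probs_10 = booster.predict(dtest, iteration_range=(0, 10))  # first 10 trees only
margins = booster.predict(dtest, output_margin=True)        # raw scores before sigmoid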

Tree Traversal and Node Navigation

Tree traversal uses the GetNextNode() function to determine which child to visit at each internal node. The decision depends on:

  • Feature Value: Compare against the split condition threshold
  • Missing Values: Use the default child direction if the feature is missing
  • Categorical Features: Check if the feature value belongs to the categorical split set

The traversal continues until reaching a leaf node, which stores the prediction value to accumulate.

Prediction Caching

The PredictionContainer maintains a cache of predictions for frequently-used datasets. Each cache entry stores:

  • Predictions: A HostDeviceVector holding cached prediction values
  • Version: Tracks the number of trees used to compute the cache

When predicting on the same dataset multiple times, the cache avoids redundant computation by reusing previous results and only computing incremental predictions for newly-added trees.

CPU Prediction Implementation

The CPU predictor (cpu_predictor.cc) optimizes prediction through:

  • Array Tree Layout: Unrolls the top levels of each tree into a flat array structure, reducing branch mispredictions and improving cache locality
  • Vectorization: Processes multiple samples in parallel using OpenMP
  • Feature Caching: Stores feature vectors in thread-local memory to minimize cache misses
  • Categorical Handling: Efficiently processes categorical splits using bitsets

GPU Prediction Implementation

The GPU predictor (gpu_predictor.cu) accelerates inference on CUDA devices:

  • Kernel Parallelization: Each thread processes one sample, traversing the tree in parallel
  • Shared Memory: Caches tree structure and split information for fast access
  • Batch Processing: Processes thousands of samples simultaneously
  • Multi-GPU Support: Distributes predictions across multiple GPUs for large datasets

Feature Contributions and SHAP Values

XGBoost supports computing feature contributions to individual predictions using TreeSHAP, implemented in treeshap.cc. Two algorithms are available:

  • Approximate: Fast attribution that distributes the change in expected prediction along each sample's decision path; suitable for real-time applications
  • Exact: The TreeSHAP algorithm, which computes exact SHAP values accounting for all feature coalitions; more accurate but slower

The contribution computation traverses each tree, tracking the decision path and calculating how each feature's value changes the prediction. The output is a vector of length (num_features + 1) * num_samples, where the extra element represents the base value.
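
In the Python API these are exposed through pred_contribs and approx_contribs; a sketch showing the output shape:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
dtrain = xgb.DMatrix(rng.random((200, 8)), label=rng.random(200))
booster = xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=10)

# Exact TreeSHAP contributions: one column per feature plus the base value.
contribs = booster.predict(dtrain, pred_contribs=True)
assert contribs.shape == (200, 8 + 1)

# Faster approximate attribution.
approx = booster.predict(dtrain, pred_contribs=True, approx_contribs=True)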

Leaf Index Prediction

PredictLeaf() returns the leaf index for each sample in each tree, useful for understanding model structure and feature interactions. The output is a matrix of shape (num_samples, num_trees) where each entry is the leaf node index.
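
A corresponding sketch for leaf indices:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
dtrain = xgb.DMatrix(rng.random((100, 5)), label=rng.random(100))
booster = xgb.train({"max_depth": 3}, dtrain, num_boost_round=7)

# One leaf index per (sample, tree).
leaves = booster.predict(dtrain, pred_leaf=True)
assert leaves.shape == (100, 7)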

Inplace Prediction

InplacePredict() enables prediction directly on data without creating a full DMatrix, supporting various data formats:

  • Dense arrays (NumPy, Pandas)
  • Sparse matrices (CSR, CSC)
  • External data sources via adapters

This is particularly useful for real-time inference where data conversion overhead must be minimized.
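
A sketch of inplace prediction, which skips DMatrix construction entirely:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
booster = xgb.train(
    {"objective": "reg:squarederror"},
    xgb.DMatrix(rng.random((300, 6)), label=rng.random(300)),
    num_boost_round=5,
)

# Predict straight from the NumPy array; no DMatrix is built for inference.
preds = booster.inplace_predict(rng.random((10, 6)))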

Multi-Target and Multi-Output Models

For models with multiple targets or output groups, predictions are organized as a matrix where each row represents a sample and each column represents an output. The predictor handles:

  • Multi-class Classification: One output per class
  • Multi-target Regression: One output per target variable
  • Ranking: Multiple outputs for ranking tasks

The GBTreeModelView manages thread-safe access to tree ensembles, ensuring consistent predictions across concurrent inference requests.

Distributed Training & Collective Communication

Relevant Files
  • src/collective/coll.h
  • src/collective/comm.h
  • src/collective/allreduce.cc
  • src/collective/broadcast.cc
  • src/collective/comm_group.h
  • src/collective/loop.h
  • python-package/xgboost/collective.py
  • python-package/xgboost/tracker.py

XGBoost's distributed training system enables parallel model training across multiple machines using collective communication primitives. The architecture is built on a Rabit communicator that coordinates workers through a central tracker, with support for both CPU and GPU-accelerated collective operations.

Core Architecture

The distributed system consists of three main layers:

  1. Communicator Layer (Comm and HostComm): Manages connections between workers, handles bootstrapping via the tracker, and provides error signaling. The RabitComm class implements the standard Rabit protocol with TCP sockets.

  2. Collective Operations Layer (Coll): Implements distributed algorithms like allreduce, broadcast, and allgather. CPU implementations use ring-based and tree-based algorithms, while GPU implementations leverage NCCL.

  3. Event Loop (Loop): An asynchronous I/O system that queues read/write operations on TCP sockets and processes them in a dedicated worker thread, enabling non-blocking collective operations.

Collective Communication Primitives

Allreduce aggregates data across all workers and returns the result to each worker. The CPU implementation uses a ring-based scatter-reduce-allgather pattern:

// Ring allreduce: each worker sends/receives segments in a ring topology.
// Uses more rounds than a tree (O(n) vs. O(log n)) but with smaller messages,
// which is bandwidth-efficient for large payloads.
Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, 
                     Func const& op, ArrayInterfaceHandler::Type type);

Broadcast distributes data from a root worker to all others using a binomial tree:

// Binomial tree broadcast: logarithmic depth, efficient for large messages
Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root);

Allgather collects data from all workers. The ring-based variant (AllgatherV) handles variable-length data efficiently.

Tracker & Bootstrap

The RabitTracker coordinates worker initialization:

from xgboost.tracker import RabitTracker
import xgboost.collective as coll

tracker = RabitTracker(n_workers=4, host_ip="127.0.0.1")
tracker.start()

# In each worker process:
with coll.CommunicatorContext(**tracker.worker_args()):
    coll.allreduce(data, op=coll.Op.kSum)  # data: the worker's local array

Workers connect to the tracker, exchange peer information, and establish direct TCP connections in a ring topology. The tracker assigns ranks deterministically (by host or task ID) and monitors worker health.

GPU Support

GPU collective operations are dispatched through CommGroup::Backend(), which selects either CPU or GPU implementations based on device type. GPU implementations use NCCL for optimized communication on NVIDIA GPUs.

Python API

The Python collective module provides high-level functions:

import xgboost.collective as coll

coll.init(dmlc_tracker_uri="localhost", dmlc_tracker_port=9091)
rank = coll.get_rank()
world = coll.get_world_size()

# Broadcast pickled objects
data = coll.broadcast({"model": weights}, root=0)

# Allreduce with built-in operations
coll.allreduce(gradient_array, op=coll.Op.kSum)

coll.finalize()

Error Handling & Resilience

The system includes timeout and retry mechanisms. Workers signal errors to the tracker, which can trigger recovery or graceful shutdown. The Loop class enforces operation timeouts (default 30 minutes) and supports configurable retry counts.

Language Bindings & APIs

Relevant Files
  • include/xgboost/c_api.h
  • src/c_api/c_api.cc
  • python-package/xgboost/core.py
  • python-package/xgboost/sklearn.py
  • python-package/xgboost/dask/__init__.py
  • R-package/R/xgb.train.R
  • jvm-packages/xgboost4j/src/native/xgboost4j.cpp

XGBoost provides language bindings for Python, R, JVM (Java/Scala), and C, all built on a unified C API foundation. This architecture enables consistent behavior across languages while allowing idiomatic interfaces for each ecosystem.

C API Foundation

The C API (include/xgboost/c_api.h) is the core abstraction layer that all language bindings use. It defines opaque handles for key objects:

  • DMatrixHandle - represents training/evaluation data
  • BoosterHandle - represents the trained model
  • CategoriesHandle - represents categorical feature metadata

All C API functions return an integer status code (0 for success, non-zero for errors) and use XGBGetLastError() to retrieve error messages. This design enables safe error handling across language boundaries.

Python Binding

The Python binding uses ctypes to load the native library dynamically (python-package/xgboost/core.py). Key mechanisms:

  • Library Loading: _load_lib() locates and loads the shared library (.so, .dll, .dylib) using platform-specific paths
  • Handle Management: Python objects (DMatrix, Booster) wrap C handles as ctypes.c_void_p pointers
  • Type Conversion: Helper functions convert between Python types and C types (e.g., c_str() for strings, c_array() for arrays)
  • Array Interfaces: Uses NumPy's __array_interface__ protocol to pass data efficiently without copying
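
A hedged sketch of that status-code convention as seen through ctypes; _LIB is core.py's internal handle to the loaded shared library and is used here purely for illustration, and check_call mirrors (but is not) the package's own error-checking helper:

import ctypes
from xgboost.core import _LIB  # internal: the ctypes handle returned by _load_lib()

def check_call(ret: int) -> None:
    # The C API convention: 0 means success; otherwise an error message is
    # waiting in XGBGetLastError().
    if ret != 0:
        raise RuntimeError(_LIB.XGBGetLastError().decode("utf-8"))

# Example: read the global configuration as a JSON string via the raw C API.
config = ctypes.c_char_p()
check_call(_LIB.XGBGetGlobalConfig(ctypes.byref(config)))
print(config.value.decode("utf-8"))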

The Python package offers multiple APIs:

  • Core API (xgboost.train, xgboost.Booster) - low-level, direct C API access
  • Scikit-Learn API (XGBClassifier, XGBRegressor) - familiar estimator interface
  • Dask API (xgboost.dask) - distributed training across clusters

R Binding

The R package wraps the C API through native code in R-package/src/xgboost_R.cc. It provides:

  • Low-level interface (xgb.train()) - mirrors Python's core API for consistency
  • High-level interface (xgboost()) - accepts R data frames and matrices directly
  • External pointers - R's mechanism for holding C++ object references
  • Callback support - custom objectives and evaluation metrics via R functions

JVM Binding

The JVM binding uses JNI (Java Native Interface) in jvm-packages/xgboost4j/src/native/xgboost4j.cpp to bridge Java and C:

  • Handle Encoding: C pointers are cast to Java long values for storage
  • Array Marshalling: Java arrays are converted to C arrays for passing to C API
  • Spark Integration: xgboost4j-spark provides distributed training via Spark RDDs

Data Flow Pattern

All bindings follow this pattern:

  1. User provides data in language-native format (NumPy array, R data frame, Spark RDD)
  2. Binding converts to C-compatible format (JSON array interface, CSR matrix, or callback iterator)
  3. C API creates DMatrixHandle from the data
  4. Training/prediction calls use handles to reference data and models
  5. Results are converted back to language-native format

This abstraction enables XGBoost to optimize data handling at the C++ level while maintaining clean APIs in each language.

Plugin System & Extensions

Relevant Files
  • plugin/README.md
  • plugin/example/custom_obj.cc
  • plugin/federated/
  • plugin/updater_gpu/
  • include/xgboost/objective.h
  • include/xgboost/metric.h
  • include/xgboost/tree_updater.h
  • include/xgboost/gbm.h

XGBoost's plugin system enables extending the framework with custom objectives, metrics, tree updaters, and gradient boosters without modifying core code. Plugins are compiled into the main library and registered at build time using a registry-based factory pattern.

Core Plugin Types

XGBoost supports four primary plugin categories:

  1. Objective Functions (XGBOOST_REGISTER_OBJECTIVE) - Custom loss functions for training
  2. Evaluation Metrics (XGBOOST_REGISTER_METRIC) - Custom evaluation criteria
  3. Tree Updaters (XGBOOST_REGISTER_TREE_UPDATER) - Custom tree-building algorithms
  4. Gradient Boosters (XGBOOST_REGISTER_GBM) - Custom boosting strategies

Registration Mechanism

Plugins use the DMLC registry system for dynamic factory creation. Each plugin type has:

  • Interface class (e.g., ObjFunction, Metric, TreeUpdater)
  • Registry struct (e.g., ObjFunctionReg, MetricReg)
  • Registration macro that creates a static registry entry

When XGBoost starts, all registered plugins are available via their string names. The factory methods (ObjFunction::Create(), Metric::Create(), etc.) look up plugins by name and instantiate them.

Writing a Custom Objective Plugin

#include <xgboost/json.h>
#include <xgboost/objective.h>

namespace xgboost {

class MyObjective : public ObjFunction {
 public:
  void Configure(const Args& args) override { /* parse params */ }
  void GetGradient(const HostDeviceVector<float>& preds,
                   MetaInfo const& info, std::int32_t iter,
                   linalg::Matrix<GradientPair>* out_gpair) override {
    // Compute gradients and Hessians for each sample here.
  }
  const char* DefaultEvalMetric() const override { return "metric_name"; }
  ObjInfo Task() const override { return ObjInfo::kRegression; }
  // Required by the Configurable base for model (de)serialization.
  void SaveConfig(Json* out) const override { (*out)["name"] = String("my_objective"); }
  void LoadConfig(Json const&) override {}
};

XGBOOST_REGISTER_OBJECTIVE(MyObjective, "my_objective")
    .describe("Custom objective function")
    .set_body([]() { return new MyObjective(); });

}  // namespace xgboost

Building and Integrating Plugins

To add a plugin to XGBoost:

  1. Create source file in plugin/ directory
  2. Implement the interface and register using the macro
  3. Add to plugin/CMakeLists.txt:
target_sources(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/your_plugin.cc)
  4. Rebuild XGBoost with CMake

Existing Plugins

  • Federated Learning - Distributed training across federated nodes
  • SYCL Support - Intel GPU acceleration via SYCL
  • GPU Updaters - CUDA-accelerated tree building (now in core)

Plugin Parameters

Plugins can define custom parameters using DMLC_REGISTER_PARAMETER. Parameters are configured via the Configure() method and can be serialized/deserialized for model persistence.

struct MyParams : public XGBoostParameter<MyParams> {
  float learning_rate;
  DMLC_DECLARE_PARAMETER(MyParams) {
    DMLC_DECLARE_FIELD(learning_rate).set_default(0.1f);
  }
};
DMLC_REGISTER_PARAMETER(MyParams);

Key Design Patterns

  • Factory Pattern: Plugins are created by name at runtime
  • Static Registration: Macros create static initializers that register at startup
  • Interface Inheritance: All plugins inherit from base classes defining required methods
  • Context Passing: Plugins receive Context for device/runtime configuration
  • Lazy Instantiation: Plugins are only created when explicitly requested