
keras-team/keras

Keras 3: Multi-Backend Deep Learning Framework

Last updated on Dec 17, 2025 (Commit: 7edb6a4)

Overview

Relevant Files
  • README.md
  • keras/src/__init__.py
  • keras/src/version.py
  • pyproject.toml

Keras 3 is a multi-backend deep learning framework designed for building and training neural networks across multiple computational backends. Version 3.14.0 supports JAX, TensorFlow, PyTorch, and OpenVINO (inference-only), enabling developers to write backend-agnostic code while leveraging framework-specific optimizations.

Core Purpose

Keras 3 serves as a unified API layer that abstracts away backend differences, allowing you to:

  • Write models once and run them on any supported backend
  • Switch backends without code changes by setting the KERAS_BACKEND environment variable
  • Combine Keras high-level workflows with low-level framework-specific code
  • Avoid framework lock-in while maintaining production-grade performance

Key Features

Multi-Backend Support: Keras integrates with JAX, TensorFlow, PyTorch, and OpenVINO (inference only). The backend is chosen when Keras is imported, so you can pick the one that performs best for a given workload.

Accelerated Development: Keras provides high-level APIs for common deep learning tasks including computer vision, natural language processing, audio processing, timeseries forecasting, and recommender systems.

Production-Ready Performance: Choosing the fastest backend for a given model architecture can yield speedups of 20% to 350% compared to other frameworks, according to the Keras team's benchmarks. JAX is often the fastest choice.

Datacenter-Scale Training: Built-in support for distributed training across GPUs and TPUs enables scaling from laptops to large clusters.

Architecture Overview


Backend Compatibility

The minimum supported versions for Keras 3.14.0 are:

  • TensorFlow: 2.16.1
  • JAX: 0.4.20
  • PyTorch: 2.1.0
  • OpenVINO: 2025.3.0

GPU support is available through separate CUDA-enabled requirement files (requirements-{backend}-cuda.txt).

Core Modules

The main Keras package exposes these primary modules:

  • keras.layers - Neural network layer implementations
  • keras.models - Model classes (Model, Sequential) and utilities such as load_model and clone_model
  • keras.optimizers - Training optimizers
  • keras.activations - Activation functions
  • keras.initializers - Weight initialization strategies
  • keras.regularizers - Regularization techniques
  • keras.ops - Backend-agnostic operations
  • keras.backend - Backend abstraction layer
  • keras.datasets - Built-in datasets
  • keras.applications - Pre-trained models
  • keras.utils - Utility functions
  • keras.visualization - Utilities for plotting images, bounding boxes, and segmentation masks

Backwards Compatibility

Keras 3 is designed as a drop-in replacement for tf.keras when using the TensorFlow backend. Existing tf.keras code usually needs only minimal changes, chiefly making sure that model.save() calls use the .keras format. Models without custom components can run on the JAX or PyTorch backends right away.

Architecture & Backend Abstraction

Relevant Files
  • keras/src/backend/__init__.py
  • keras/src/backend/config.py
  • keras/src/backend/common/global_state.py
  • keras/src/backend/common/variables.py
  • keras/src/backend/jax/core.py
  • keras/src/backend/tensorflow/core.py
  • keras/src/backend/torch/core.py
  • keras/src/backend/numpy/core.py
  • keras/src/backend/openvino/core.py

Keras implements a pluggable backend abstraction layer that allows the same code to run on multiple deep learning frameworks. This design enables users to switch between TensorFlow, JAX, PyTorch, NumPy, and OpenVINO without changing their model code.

Backend Selection & Configuration

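The active backend is fixed at import time. Keras reads it from the KERAS_BACKEND environment variable or, if that is unset, from the "backend" field of ~/.keras/keras.json, and defaults to TensorFlow. keras.config.backend() reports the active backend but cannot change it. Configuration lives in keras/src/backend/config.py, which also manages global settings such as the default float type (floatx), epsilon, and image data format.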

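# Switch backend via environment variable (shell)
export KERAS_BACKEND=jax

# Or from Python, before Keras is imported for the first time
import os
os.environ["KERAS_BACKEND"] = "jax"

import keras
print(keras.config.backend())  # Prints the active backend name, e.g. "jax"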

Core Abstraction Pattern

Each backend implements the same interface through a modular structure:

  • core.py — Tensor operations, type conversions, and device management
  • numpy.py — NumPy-like array operations
  • math.py — Mathematical functions (sin, cos, exp, etc.)
  • nn.py — Neural network ops (activations, pooling, convolutions)
  • random.py — Random number generation
  • linalg.py — Linear algebra operations

Each backend module exports the same function signatures, allowing layers and operations to call backend functions without knowing which framework is active.

Variable System

The Variable class in keras/src/backend/common/variables.py provides a backend-agnostic container for trainable state. Each backend subclasses this with framework-specific implementations:

  • JAX — JaxVariable wraps immutable arrays with optional distributed layouts
  • TensorFlow — Wraps tf.Variable with SavedModel support
  • PyTorch — Wraps torch.nn.Parameter with device management
  • NumPy — Simple NumPy array wrapper

# Backend-agnostic variable creation
import keras

var = keras.Variable(
    initializer=keras.initializers.RandomNormal(),
    shape=(10, 20),
    dtype="float32",
    trainable=True
)

Dynamic Dispatch

The main keras/src/backend/__init__.py uses conditional imports to load the active backend at startup:
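The conditional-import idea can be illustrated with a small, stdlib-only sketch. ListBackend, TupleBackend, the TOY_BACKEND variable, and load_backend() are hypothetical stand-ins rather than Keras code: two "backends" expose identical function signatures plus a SUPPORTS_SPARSE_TENSORS-style capability flag, one is chosen once, and the calling code never names a framework.

```python
import math
import os


class ListBackend:
    """Hypothetical backend that represents tensors as Python lists."""

    NAME = "list"
    SUPPORTS_SPARSE_TENSORS = False

    @staticmethod
    def add(x, y):
        return [a + b for a, b in zip(x, y)]

    @staticmethod
    def exp(x):
        return [math.exp(a) for a in x]


class TupleBackend:
    """Hypothetical backend with the same function signatures, using tuples."""

    NAME = "tuple"
    SUPPORTS_SPARSE_TENSORS = True

    @staticmethod
    def add(x, y):
        return tuple(a + b for a, b in zip(x, y))

    @staticmethod
    def exp(x):
        return tuple(math.exp(a) for a in x)


def load_backend(name=None):
    # Resolve the backend once, much as Keras reads KERAS_BACKEND at import time.
    name = name or os.environ.get("TOY_BACKEND", "list")
    if name == "list":
        return ListBackend
    elif name == "tuple":
        return TupleBackend
    raise ValueError(f"Unknown backend: {name!r}")


def exp_plus_input(backend, x):
    # "Layer code" that calls backend functions without naming a framework.
    return backend.add(backend.exp(x), x)
```

Because every backend exposes the same names, exp_plus_input() works unchanged with either one; only the selection step knows which implementation is active.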


Global State Management

keras/src/backend/common/global_state.py manages thread-local state including device context, distributed settings, and feature flags. The clear_session() function resets this state and backend-specific caches (e.g., TensorFlow's kernel cache, PyTorch's dynamo cache).
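A minimal sketch of thread-local global state follows. The function names echo those in global_state.py, but this is a simplified stand-in written for illustration; the real clear_session() does more work, such as resetting the backend caches mentioned above.

```python
import threading

# One threading.local() container holds per-thread attributes, so a value
# set in one thread is invisible to other threads.
_STATE = threading.local()


def set_global_attribute(name, value):
    setattr(_STATE, name, value)


def get_global_attribute(name, default=None):
    return getattr(_STATE, name, default)


def clear_session():
    # Replacing the container drops every attribute at once.
    global _STATE
    _STATE = threading.local()
```

Swapping in a fresh container, rather than deleting attributes one by one, guarantees that nothing from a previous session survives.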

Backend Capabilities

Each backend declares its capabilities via module-level constants:

  • SUPPORTS_SPARSE_TENSORS — JAX and TensorFlow support sparse operations
  • SUPPORTS_RAGGED_TENSORS — Only TensorFlow supports ragged tensors
  • IS_THREAD_SAFE — JAX, PyTorch, and NumPy are thread-safe; TensorFlow is not

This allows layers to adapt behavior based on backend capabilities at runtime.

Layers & Models

Relevant Files
  • keras/src/layers/layer.py
  • keras/src/models/model.py
  • keras/src/models/sequential.py
  • keras/src/models/functional.py
  • keras/src/ops/operation.py
  • keras/src/ops/node.py

Core Concepts

Keras organizes neural networks into two fundamental abstractions: Layers and Models. A Layer is a callable object that combines computation (via call()) with state (weights and variables). A Model is a container that groups layers into a trainable object with inference capabilities.

All layers inherit from keras.Layer, which itself inherits from Operation. This inheritance chain enables layers to participate in the computation graph while managing their own state through weight tracking and deferred building.

Layer Architecture

A layer encapsulates:

  1. State Management - Weights created via add_weight() in __init__() or build()
  2. Computation - Forward pass logic in the call() method
  3. Deferred Building - Shapes inferred on first call, enabling flexible input dimensions
  4. Tracking - Automatic tracking of nested layers and their weights

import keras
from keras import ops


class Linear(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Deferred: the kernel shape depends on the last input dimension
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer="zeros",
            trainable=True,
        )

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b
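Deferred building can also be shown without any framework. ToyDense is a hypothetical pure-Python stand-in, not keras.layers.Dense: a single input row is a Python list, and the weights are created on the first call from the observed input width, roughly as Keras triggers build() on first use.

```python
import random


class ToyDense:
    """Hypothetical pure-Python layer illustrating deferred building."""

    def __init__(self, units):
        self.units = units
        self.built = False

    def build(self, input_dim):
        # Weight shapes depend on the input, so they are created here, not in __init__.
        rng = random.Random(0)
        self.w = [[rng.gauss(0.0, 0.05) for _ in range(self.units)]
                  for _ in range(input_dim)]
        self.b = [0.0] * self.units
        self.built = True

    def __call__(self, x):
        # Build lazily on the first call, from the observed input width.
        if not self.built:
            self.build(len(x))
        return [sum(x[i] * self.w[i][j] for i in range(len(x))) + self.b[j]
                for j in range(self.units)]
```

Constructing ToyDense(3) allocates nothing; the 2x3 weight matrix appears only after the layer sees a two-element input.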

Three Model Patterns

Sequential - A linear stack of layers, each with one input and output:

model = keras.Sequential([
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(10, activation="softmax"),
])

Functional - A directed acyclic graph (DAG) of layers supporting multiple inputs/outputs and layer sharing:

inputs = keras.Input(shape=(784,))
x = keras.layers.Dense(32, activation="relu")(inputs)
outputs = keras.layers.Dense(10, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

Subclassed - Custom models via inheritance, enabling complex control flow:

class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = keras.layers.Dense(32, activation="relu")
        self.dense2 = keras.layers.Dense(10)

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

Computation Graph

The Operation class and Node system form Keras' computation graph. When a layer is called with symbolic tensors (during model construction), it creates a Node that records the operation, inputs, and outputs. Each KerasTensor carries _keras_history metadata linking it to its source node, enabling graph reconstruction and introspection.
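To make the node/history mechanism concrete, here is a framework-free sketch. SymbolicTensor, Node, apply(), and topological_ops() are hypothetical simplifications: each tensor's history attribute plays the role of _keras_history, and walking those links backwards from an output recovers the operations in execution order, which is roughly what graph reconstruction requires.

```python
class Node:
    """Records one operation call: its inputs and the tensor it produced."""

    def __init__(self, op_name, inputs):
        self.op_name = op_name
        self.inputs = inputs
        self.output = SymbolicTensor(history=self)


class SymbolicTensor:
    """Value-less placeholder; `history` plays the role of _keras_history."""

    def __init__(self, history=None):
        self.history = history  # None for graph inputs


def apply(op_name, *inputs):
    # Calling an "operation" on symbolic tensors only records a node.
    return Node(op_name, list(inputs)).output


def topological_ops(output):
    """Walk history links back from the output; return ops in execution order."""
    order, seen = [], set()

    def visit(tensor):
        node = tensor.history
        if node is None or id(node) in seen:
            return
        seen.add(id(node))
        for parent in node.inputs:
            visit(parent)
        order.append(node.op_name)

    visit(output)
    return order


x = SymbolicTensor()
h = apply("dense_1", x)
y = apply("dense_2", h)
z = apply("add", h, y)  # h is shared, so the graph is a DAG, not a chain
```

Calling topological_ops(z) returns ["dense_1", "dense_2", "add"]; the shared tensor h is visited only once.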


Model Execution Modes

Models support two execution paths:

  • Symbolic - Graph construction with KerasTensor inputs, used during model definition
  • Eager - Immediate computation with concrete array inputs, used during training and inference

The __call__() method automatically dispatches to the appropriate path based on input type.

Training Loop & Optimization

Relevant Files
  • keras/src/trainers/trainer.py
  • keras/src/trainers/epoch_iterator.py
  • keras/src/trainers/compile_utils.py
  • keras/src/backend/jax/trainer.py
  • keras/src/backend/tensorflow/trainer.py
  • keras/src/backend/torch/trainer.py
  • keras/src/optimizers/optimizer.py
  • keras/src/optimizers/adam.py

Overview

Keras implements a backend-agnostic training loop that abstracts away framework-specific details while maintaining high performance. The training system is built on three core components: the Trainer mixin (base logic), backend-specific trainers (TensorFlow, PyTorch, JAX), and the EpochIterator (data batching). This architecture enables the same model code to train efficiently across different backends.

Compilation & Configuration

Before training, models must be compiled using model.compile(), which configures:

  • Optimizer: Handles gradient updates (e.g., Adam, SGD)
  • Loss function: Computed via CompileLoss wrapper
  • Metrics: Tracked via CompileMetrics wrapper
  • Execution mode: run_eagerly (immediate execution) or jit_compile (graph compilation)
  • Steps per execution: Number of batches run inside each compiled function call. Weights are still updated after every batch; grouping batches reduces per-call dispatch overhead.

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.BinaryCrossentropy(),
    metrics=[keras.metrics.BinaryAccuracy()],
    jit_compile=True,
    steps_per_execution=10,
)

Training Loop Architecture


Backend-Specific Training Steps

Each backend implements train_step() differently to leverage native operations:

TensorFlow uses tf.GradientTape() for automatic differentiation:

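# Simplified body of the TensorFlow trainer's train_step(self, data)
x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
with tf.GradientTape() as tape:
    y_pred = self(x, training=True)
    loss = self.compute_loss(x, y, y_pred, sample_weight)
gradients = tape.gradient(loss, self.trainable_weights)
self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))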

PyTorch uses backward() and manual gradient management:

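# Simplified body of the PyTorch trainer's train_step(self, data)
x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
y_pred = self(x, training=True)
loss = self.compute_loss(x, y, y_pred, sample_weight)
self.zero_grad()
loss.backward()
gradients = [v.value.grad for v in self.trainable_weights]
with torch.no_grad():  # the weight update itself must not be tracked by autograd
    self.optimizer.apply(gradients, self.trainable_weights)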

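JAX arrays are immutable, so the JAX trainer works in a pure functional style: it computes the loss and gradients with jax.value_and_grad over a stateless function and passes model, optimizer, and metric variables explicitly into and out of each train step.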
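The stateless style can be sketched in plain Python. Here a hand-derived gradient for one-parameter least squares stands in for jax.value_and_grad, and train_step() is a hypothetical pure function: it receives the current state and returns a new one instead of mutating anything.

```python
def loss_and_grad(w, batch):
    """Mean squared error of y ~ w * x and its derivative with respect to w."""
    n = len(batch)
    loss = sum((w * x - y) ** 2 for x, y in batch) / n
    grad = sum(2 * (w * x - y) * x for x, y in batch) / n
    return loss, grad


def train_step(state, batch, learning_rate=0.1):
    # Pure function: state in, new state out; nothing is updated in place.
    loss, grad = loss_and_grad(state["w"], batch)
    new_state = {"w": state["w"] - learning_rate * grad, "step": state["step"] + 1}
    return new_state, loss


state = {"w": 0.0, "step": 0}
batch = [(1.0, 2.0), (2.0, 4.0)]  # points on the line y = 2x
for _ in range(100):
    state, loss = train_step(state, batch)
```

Because the caller threads the state through each call explicitly, the step has no hidden side effects, which is the property that lets JAX trace and compile it.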

Data Iteration & Batching

The EpochIterator handles data loading and batching:

  • Adapts various input formats (NumPy arrays, tf.data.Dataset, torch.DataLoader, generators)
  • Respects steps_per_epoch and steps_per_execution settings
  • Manages shuffling and class weighting
  • Yields (begin_step, end_step, batch_data) tuples for callback integration
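The grouping behavior can be sketched with plain lists. epoch_iterator() is a hypothetical simplification with no shuffling, class weighting, or backend conversion; it shows how steps_per_epoch truncates an epoch and how steps_per_execution bundles consecutive batches into one group per compiled call. In this sketch end_step is inclusive.

```python
def epoch_iterator(samples, batch_size, steps_per_execution=1, steps_per_epoch=None):
    """Yield (begin_step, end_step, batches); end_step is inclusive."""
    batches = [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]
    if steps_per_epoch is not None:
        batches = batches[:steps_per_epoch]
    for begin in range(0, len(batches), steps_per_execution):
        group = batches[begin:begin + steps_per_execution]
        # One group would be handed to a single compiled train function call.
        yield begin, begin + len(group) - 1, group
```

For seven samples with batch_size=2 and steps_per_execution=3, the sketch yields steps 0 to 2 as one group and the final partial batch as step 3 on its own.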

Optimizer & Weight Updates

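Optimizers subclass keras.optimizers.Optimizer, which builds on BaseOptimizer:

  • build(): Implemented by subclasses to create optimizer state variables (momentum, adaptive learning rates)
  • update_step(): Implemented by subclasses with the core update logic (e.g., Adam's exponential moving averages)
  • apply(): Provided by the base class; applies gradients to the trainable weights by running update_step() for each variable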

Adam optimizer example:

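from keras.optimizers import Optimizer


class Adam(Optimizer):  # simplified; the real class accepts more arguments
    def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999,
                 epsilon=1e-7, name="adam", **kwargs):
        super().__init__(learning_rate=learning_rate, name=name, **kwargs)
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon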
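The arithmetic behind an Adam update_step() can be sketched for a single scalar parameter. adam_step() is a hypothetical helper, not the Keras method: it applies the standard Adam moving averages, folds the bias correction into the step size, and uses the same default hyperparameters as the constructor above.

```python
import math


def adam_step(param, grad, m, v, t, lr=1e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-7):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    m = m + (grad - m) * (1 - beta_1)         # moving average of gradients
    v = v + (grad * grad - v) * (1 - beta_2)  # moving average of squared gradients
    # Bias correction for the zero-initialized averages, folded into the step size.
    alpha = lr * math.sqrt(1 - beta_2 ** t) / (1 - beta_1 ** t)
    param = param - alpha * m / (math.sqrt(v) + epsilon)
    return param, m, v


# Usage: minimize f(x) = x**2, whose gradient is 2x.
x, m, v = 1.0, 0.0, 0.0
for t in range(1, 2001):
    x, m, v = adam_step(x, 2 * x, m, v, t, lr=0.01)
```

On the first step the bias correction makes the update almost exactly lr in magnitude, regardless of the gradient's scale.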

Key Features

  • Mixed precision training: Automatic loss scaling for float16 via LossScaleOptimizer
  • Gradient clipping: clipnorm, clipvalue, global_clipnorm prevent exploding gradients
  • Gradient accumulation: Simulate larger batch sizes by accumulating gradients
  • EMA (Exponential Moving Average): Optional weight averaging for better generalization
  • Callbacks: Hooks at epoch/batch boundaries for custom logic (early stopping, learning rate scheduling)
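Gradient clipping is easy to state in code. The helpers below are hypothetical pure-Python sketches operating on a flat list of scalar gradients: clip_by_global_norm() mirrors what global_clipnorm describes, rescaling all gradients by one common factor, while clip_by_value() mirrors clipvalue. By contrast, clipnorm applies the norm limit to each weight's gradient separately.

```python
import math


def clip_by_global_norm(grads, clip_norm):
    """Rescale all gradients together so their joint L2 norm is at most clip_norm."""
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= clip_norm:
        return list(grads)
    scale = clip_norm / global_norm
    return [g * scale for g in grads]


def clip_by_value(grads, clip_value):
    """Clamp every gradient element independently to [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in grads]
```

Global-norm clipping preserves the direction of the combined update, whereas value clipping can change it by clamping some components but not others.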

Operations & Ops API

Relevant Files
  • keras/src/ops/operation.py
  • keras/src/ops/core.py
  • keras/src/ops/numpy.py
  • keras/src/ops/nn.py
  • keras/src/ops/math.py
  • keras/src/ops/image.py
  • keras/src/ops/linalg.py
  • keras/src/ops/function.py

The Keras Operations API provides a backend-agnostic interface for tensor computations. All operations inherit from the Operation base class, which handles both eager execution and symbolic graph building. This enables seamless switching between TensorFlow, JAX, and PyTorch backends.

Core Architecture

The Operation class is the foundation of the ops system. When called, it routes to call() for eager execution, or to symbolic_call() for graph building if any argument is a symbolic tensor. The symbolic_call() method performs shape and dtype inference via compute_output_spec(), then records a node in the computation graph.

class Operation(KerasSaveable):
    def __call__(self, *args, **kwargs):
        if any_symbolic_tensors(args, kwargs):
            return self.symbolic_call(*args, **kwargs)
        else:
            return self.call(*args, **kwargs)
    
    def symbolic_call(self, *args, **kwargs):
        outputs = self.compute_output_spec(*args, **kwargs)
        Node(operation=self, call_args=args, call_kwargs=kwargs, outputs=outputs)
        return outputs

Operation Categories

Keras organizes operations into six source modules under keras/src/ops/, all re-exported from the keras.ops namespace:

  1. NumPy Operations (numpy.py) - Array manipulation, arithmetic, and linear algebra basics (add, matmul, reshape, concatenate, etc.)
  2. Neural Network Operations (nn.py) - Activation functions, convolutions, pooling, normalization (relu, sigmoid, conv, batch_normalization, etc.)
  3. Math Operations (math.py) - Advanced math functions, segment reductions, FFT (segment_sum, segment_max, fft, etc.)
  4. Image Operations (image.py) - Image processing utilities (resize, crop_images, pad_images, etc.)
  5. Linear Algebra (linalg.py) - Matrix decompositions and solvers (qr, svd, solve, etc.)
  6. Core Operations (core.py) - Control flow and graph utilities (map, scan, cond, etc.)

Creating Custom Operations

To create a custom operation, subclass Operation and implement call() and compute_output_spec():

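from keras import KerasTensor, Operation, ops
from keras.src.backend import any_symbolic_tensors  # internal helper


class MyOp(Operation):
    def call(self, x):
        # Eager path; ops.square stands in for a backend-specific kernel
        return ops.square(x)

    def compute_output_spec(self, x):
        # Symbolic path: describe the output without computing it
        return KerasTensor(shape=x.shape, dtype=x.dtype)


def my_op(x):
    if any_symbolic_tensors((x,)):
        return MyOp().symbolic_call(x)
    return ops.square(x)  # eager: call the implementation directly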

Function API

The Function class captures computation graphs of operations without state tracking. Unlike Functional Models, Function instances are stateless and don't implement the Layer API, making them lightweight for pure computation graphs.

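import numpy as np
import keras

input_1 = keras.KerasTensor(shape=(None, 2, 3))
input_2 = keras.KerasTensor(shape=(None, 2, 3))
x = input_1 + input_2
output = keras.ops.sigmoid(x)
fn = keras.Function(inputs=[input_1, input_2], outputs=output)
output_val = fn([np.ones((4, 2, 3)), np.ones((4, 2, 3))])  # run on concrete arrays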

Backend Abstraction

All operations delegate to backend-specific implementations. The ops layer provides a unified API while backends (TensorFlow, JAX, PyTorch) handle actual computation. This design enables code portability across frameworks without modification.

Losses, Metrics & Callbacks

Relevant Files
  • keras/src/losses/loss.py
  • keras/src/losses/losses.py
  • keras/src/metrics/metric.py
  • keras/src/metrics/accuracy_metrics.py
  • keras/src/callbacks/callback.py
  • keras/src/callbacks/model_checkpoint.py
  • keras/src/callbacks/early_stopping.py
  • keras/src/callbacks/tensorboard.py

Overview

Keras provides three core components for training: losses quantify prediction errors, metrics track performance, and callbacks hook into the training lifecycle. These work together during model.fit() to optimize, monitor, and control training.

Losses

The loss system is built on a base Loss class that subclasses implement via a call() method:

class Loss(KerasSaveable):
    def __call__(self, y_true, y_pred, sample_weight=None):
        # Converts inputs, calls call(), applies reduction
        losses = self.call(y_true, y_pred)
        return reduce_weighted_values(losses, sample_weight, reduction=self.reduction)
    
    def call(self, y_true, y_pred):
        raise NotImplementedError  # Subclasses implement this

Key features:

  • Reduction modes: "sum_over_batch_size" (default), "sum", "mean", "mean_with_sample_weight", or None
  • LossFunctionWrapper: Wraps functional losses (e.g., mean_squared_error) into class-based losses
  • Masking support: Respects Keras masks for variable-length sequences
  • Dtype policies: Supports mixed precision via dtype policies

Common losses: MeanSquaredError, BinaryCrossentropy, CategoricalCrossentropy, Huber, Dice, CTC
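The reduction modes can be illustrated in one dimension. reduce_losses() is a hypothetical sketch, not Keras's reduction function; treating "mean" like "sum_over_batch_size" and returning 0.0 when the divisor is zero are assumptions of this sketch.

```python
def _safe_divide(numerator, denominator):
    # Like a divide-no-NaN: return 0.0 instead of dividing by zero.
    return numerator / denominator if denominator else 0.0


def reduce_losses(losses, sample_weight=None, reduction="sum_over_batch_size"):
    """Apply sample weights to per-sample losses, then reduce them."""
    weights = list(sample_weight) if sample_weight is not None else [1.0] * len(losses)
    weighted = [loss * w for loss, w in zip(losses, weights)]
    if reduction is None or reduction == "none":
        return weighted
    total = sum(weighted)
    if reduction == "sum":
        return total
    if reduction in ("sum_over_batch_size", "mean"):
        # Divide by the number of elements, not by the sum of the weights.
        return _safe_divide(total, len(weighted))
    if reduction == "mean_with_sample_weight":
        return _safe_divide(total, sum(weights))
    raise ValueError(f"Unknown reduction: {reduction!r}")
```

With losses [1, 2, 3] and weights [1, 0, 1], "sum_over_batch_size" gives 4/3 because the zero-weighted sample still counts toward the batch size, while "mean_with_sample_weight" gives 4/2 = 2.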

Metrics

Metrics track model performance independently of loss. The Metric base class maintains state variables updated during training:

class Metric(KerasSaveable):
    def __init__(self, dtype=None, name=None):
        self._variables = []  # State variables
        self._metrics = []    # Nested metrics
    
    def update_state(self, y_true, y_pred, sample_weight=None):
        # Update internal state
        pass
    
    def result(self):
        # Return computed metric value
        pass
    
    def reset_state(self):
        # Reset variables between epochs
        pass

Key patterns:

  • MeanMetricWrapper: Wraps functional metrics (e.g., accuracy) with automatic mean aggregation
  • Direction inference: Metrics auto-detect if they should be maximized ("up") or minimized ("down")
  • Sample weighting: Supports weighted metrics for imbalanced datasets
  • Nested metrics: Metrics can contain other metrics for complex aggregations

Common metrics: Accuracy, Precision, Recall, AUC, MeanSquaredError, CategoricalAccuracy
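The update_state() / result() / reset_state() contract can be sketched with a running weighted mean. RunningMean is a hypothetical pure-Python class, not keras.metrics.Mean; it shows why metric state accumulates across batches and must be reset between epochs.

```python
class RunningMean:
    """Hypothetical stateful metric: keeps a running weighted total and count."""

    def __init__(self):
        self.reset_state()

    def update_state(self, values, sample_weight=None):
        weights = sample_weight if sample_weight is not None else [1.0] * len(values)
        self.total += sum(v * w for v, w in zip(values, weights))
        self.count += sum(weights)

    def result(self):
        # Guard against division by zero before any update.
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        self.total = 0.0
        self.count = 0.0
```

Because result() is computed from accumulated totals rather than per-batch averages, batches of different sizes are weighted correctly.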

Callbacks

Callbacks hook into training at multiple stages via the Callback base class:

class Callback:
    def on_train_begin(self, logs=None): pass
    def on_epoch_begin(self, epoch, logs=None): pass
    def on_train_batch_end(self, batch, logs=None): pass
    def on_epoch_end(self, epoch, logs=None): pass
    def on_train_end(self, logs=None): pass

Key implementations:

  • ModelCheckpoint: Saves model/weights when monitored metric improves
  • EarlyStopping: Stops training when metric plateaus (with patience and min_delta)
  • TensorBoard: Logs metrics, histograms, and profiling data for visualization
  • MonitorCallback: Base class for callbacks that track metric improvements

Callback lifecycle: CallbackList manages multiple callbacks, dispatching events synchronously or asynchronously via thread pools.
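The callback flow can be sketched without a model. ToyEarlyStopping, ToyCallbackList, and fake_fit() are hypothetical: the list forwards each event synchronously to every callback, and the early-stopping logic counts epochs without improvement against patience and min_delta. In Keras the stop signal is set on the model (model.stop_training) rather than on the callback.

```python
class ToyEarlyStopping:
    """Hypothetical early-stopping callback that monitors a loss-like value."""

    def __init__(self, monitor="val_loss", patience=2, min_delta=0.0):
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.wait = 0
        self.stop_training = False
        self.stopped_epoch = None

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {})[self.monitor]
        if current < self.best - self.min_delta:  # counts as an improvement
            self.best = current
            self.wait = 0
            return
        self.wait += 1
        if self.wait >= self.patience:
            self.stop_training = True
            self.stopped_epoch = epoch


class ToyCallbackList:
    """Forwards each event to every callback, in registration order."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def on_epoch_end(self, epoch, logs=None):
        for cb in self.callbacks:
            cb.on_epoch_end(epoch, logs)


def fake_fit(val_losses, callbacks):
    """Replay precomputed validation losses; return the last epoch run."""
    cb_list = ToyCallbackList(callbacks)
    for epoch, loss in enumerate(val_losses):
        cb_list.on_epoch_end(epoch, {"val_loss": loss})
        if any(getattr(cb, "stop_training", False) for cb in callbacks):
            return epoch
    return len(val_losses) - 1
```

With validation losses 1.0, 0.9, 0.95, 0.96, 0.97 and patience=2, training stops after epoch 3, the second epoch without improvement over 0.9.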

Integration Example

import keras

model.compile(
    optimizer='adam',
    loss=keras.losses.CategoricalCrossentropy(),
    metrics=[keras.metrics.CategoricalAccuracy()]
)

model.fit(
    x_train, y_train,
    epochs=50,
    validation_data=(x_val, y_val),
    callbacks=[
        keras.callbacks.EarlyStopping(monitor='val_loss', patience=5),
        # Log keys follow metric names: CategoricalAccuracy logs 'val_categorical_accuracy'
        keras.callbacks.ModelCheckpoint(
            'best_model.keras', monitor='val_categorical_accuracy',
            mode='max', save_best_only=True,
        ),
    ]
)

During training, losses are computed per batch, metrics accumulate state, and callbacks respond to lifecycle events—enabling flexible, production-ready training workflows.

Data Adapters & Model Saving

Relevant Files
  • keras/src/trainers/data_adapters/data_adapter.py
  • keras/src/trainers/data_adapters/array_data_adapter.py
  • keras/src/trainers/data_adapters/tf_dataset_adapter.py
  • keras/src/trainers/data_adapters/torch_data_loader_adapter.py
  • keras/src/trainers/data_adapters/py_dataset_adapter.py
  • keras/src/saving/saving_api.py
  • keras/src/saving/saving_lib.py
  • keras/src/saving/serialization_lib.py
  • keras/src/export/saved_model.py
  • keras/src/export/onnx.py

Data Adapters: Unified Input Interface

Keras uses a DataAdapter pattern to normalize diverse input formats into a unified interface. The base DataAdapter class defines methods to convert data into iterators compatible with different backends (NumPy, TensorFlow, JAX, PyTorch).

Supported input types:

  • NumPy arrays, TensorFlow tensors, JAX arrays, PyTorch tensors
  • tf.data.Dataset and tf.distribute.DistributedDataset
  • torch.utils.data.DataLoader
  • grain.DataLoader, grain.MapDataset, grain.IterDataset
  • Python generators
  • keras.utils.PyDataset (custom parallel dataset class)

Adapter selection logic (in get_data_adapter())

The framework checks input types in order:

  1. If already a DataAdapter, return as-is
  2. If array-like (NumPy, tensors) → ArrayDataAdapter
  3. If tf.data.Dataset → TFDatasetAdapter
  4. If torch.utils.data.DataLoader → TorchDataLoaderAdapter
  5. If grain dataset → GrainDatasetAdapter
  6. If Python generator → GeneratorDataAdapter
  7. If PyDataset → PyDatasetAdapter
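Ordered type checks like these can be expressed as a rule table. pick_adapter(), PyDatasetLike, and the predicates are hypothetical simplifications covering only arrays, generators, and a PyDataset stand-in (Python lists play the role of arrays here); the point is that the first matching rule wins, so the order of checks matters.

```python
import inspect


class PyDatasetLike:
    """Hypothetical stand-in for keras.utils.PyDataset (indexable batches)."""

    def __init__(self, batches):
        self.batches = batches

    def __getitem__(self, index):
        return self.batches[index]

    def __len__(self):
        return len(self.batches)


def _is_array_like(x):
    # Lists, or anything exposing .shape (NumPy arrays, framework tensors).
    return isinstance(x, list) or hasattr(x, "shape")


# Checked top to bottom; the first matching predicate wins.
_RULES = [
    (_is_array_like, "ArrayDataAdapter"),
    (inspect.isgenerator, "GeneratorDataAdapter"),
    (lambda x: isinstance(x, PyDatasetLike), "PyDatasetAdapter"),
]


def pick_adapter(x):
    for predicate, adapter_name in _RULES:
        if predicate(x):
            return adapter_name
    raise ValueError(f"Unrecognized data type: {type(x).__name__}")
```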

Key Adapter Features

Each adapter implements backend-specific iterators:

  • get_numpy_iterator() - yields NumPy arrays
  • get_tf_dataset() - returns tf.data.Dataset
  • get_jax_iterator() - yields JAX-compatible arrays
  • get_torch_dataloader() - returns PyTorch DataLoader
  • builtin_prefetch - indicates if adapter handles prefetching
  • num_batches, batch_size, has_partial_batch - metadata

Model Saving & Serialization

Keras provides a unified saving system via keras.saving.save_model() that creates .keras files (ZIP archives).

Archive structure:

  • config.json - model architecture (serialized configuration)
  • model.weights.h5 or model.weights.npz - model weights
  • metadata.json - version and backend info
  • assets/ - optional custom assets
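Because a .keras file is an ordinary ZIP archive, Python's zipfile module can inspect one. The sketch builds an in-memory stand-in with placeholder contents, since producing a real archive requires Keras; with a file saved by model.save("model.keras") you would pass its path to list_keras_archive(), a hypothetical helper, instead.

```python
import io
import json
import zipfile


def list_keras_archive(source):
    """Return member names and the parsed config of a .keras-style ZIP archive."""
    with zipfile.ZipFile(source) as archive:
        names = archive.namelist()
        config = json.loads(archive.read("config.json"))
    return names, config


# In-memory stand-in with placeholder contents, not a loadable model.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("config.json", json.dumps({"class_name": "Sequential"}))
    archive.writestr("metadata.json", json.dumps({"keras_version": "placeholder"}))
    archive.writestr("model.weights.h5", b"")  # placeholder bytes, not real HDF5
buffer.seek(0)

names, config = list_keras_archive(buffer)
```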

Serialization process (serialization_lib.py):

  • serialize_keras_object() converts layers, optimizers, and configs to JSON
  • Safe mode prevents deserialization of arbitrary Python code (lambdas)
  • ObjectSharingScope handles shared object references
  • Custom objects registered with @keras.saving.register_keras_serializable() can be resolved by name at load time

Model Export Formats

Beyond native .keras format, Keras supports specialized export formats:

  • SavedModel (ExportArchive) - TensorFlow serving format with configurable endpoints
  • ONNX - cross-platform inference via ONNX Runtime
  • OpenVINO - Intel optimization format
  • LiteRT - mobile/edge deployment

Export methods accept input_signature to define input shapes and dtypes, enabling static graph compilation and format-specific optimizations.


Advanced Features & Extensions

Relevant Files
  • keras/src/distribution/distribution_lib.py
  • keras/src/quantizers/gptq.py
  • keras/src/quantizers/quantization_config.py
  • keras/src/distillation/distiller.py
  • keras/src/dtype_policies/dtype_policy.py
  • keras/src/export/litert.py
  • keras/src/export/saved_model.py
  • keras/src/applications/resnet.py
  • keras/src/applications/efficientnet.py

Distributed Training & Model Parallelism

Keras provides unified distribution APIs for scaling models across multiple devices and hosts. The Distribution class enables both data parallelism and model parallelism strategies.

Data Parallelism replicates model variables across all devices while sharding input data. Model Parallelism shards variables across devices using a DeviceMesh and LayoutMap, allowing you to split large models that don't fit on a single device.

import keras
from keras.distribution import DeviceMesh, LayoutMap, ModelParallel, list_devices

# Create a mesh with 2 devices for data parallelism and 4 for model parallelism
# (assumes 8 available devices)
device_mesh = DeviceMesh(
    shape=(2, 4),
    axis_names=('batch', 'model'),
    devices=list_devices()
)

# Define variable sharding patterns
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)

# Use the distribution strategy; the device mesh is taken from layout_map
distribution = ModelParallel(layout_map=layout_map, batch_dim_name='batch')
with distribution.scope():
    model = keras.Sequential([...])
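LayoutMap keys such as 'dense.*kernel' act as regular expressions over variable paths. resolve_layout() below is a hypothetical simplification of that lookup, not LayoutMap's implementation: it tries an exact key first, then a regex search, and rejects ambiguous paths. Returning None for unmatched paths stands in for "no sharding rule" in this sketch.

```python
import re


def resolve_layout(layout_rules, variable_path):
    """Hypothetical simplified lookup: exact key first, then regex search."""
    if variable_path in layout_rules:
        return layout_rules[variable_path]
    matches = [key for key in layout_rules if re.search(key, variable_path)]
    if len(matches) > 1:
        raise ValueError(f"Path {variable_path!r} matches several rules: {matches}")
    return layout_rules[matches[0]] if matches else None


rules = {
    "dense.*kernel": (None, "model"),  # shard the kernel's last axis over 'model'
    "dense.*bias": ("model",),
}
```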

Quantization & Precision Control

Keras supports multiple quantization modes for model compression: int8, int4, float8, and GPTQ (a post-training quantization method originally designed for generative pre-trained transformers).

DTypePolicy controls both computation and variable dtypes globally or per-layer. Mixed precision policies like "mixed_float16" compute in float16 while storing variables in float32 for stability.

import keras

# Set global mixed precision policy (before building the model)
keras.config.set_dtype_policy("mixed_float16")

# Or quantize a built model's weights for compression
model.quantize("int8")
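The int8 mode rests on mapping floating-point weights to 8-bit integer codes plus a scale. The sketch shows symmetric abs-max quantization of a flat list in plain Python; quantize_int8() and dequantize_int8() are hypothetical helpers illustrating the idea, not Keras's implementation, which operates on tensors.

```python
def quantize_int8(values):
    """Symmetric abs-max quantization of a list of floats to int8 codes."""
    max_abs = max(abs(v) for v in values) or 1.0  # avoid dividing by zero
    scale = 127.0 / max_abs
    codes = [max(-127, min(127, round(v * scale))) for v in values]
    return codes, scale


def dequantize_int8(codes, scale):
    return [c / scale for c in codes]


vals = [0.5, -1.0, 0.25]
codes, scale = quantize_int8(vals)
restored = dequantize_int8(codes, scale)
```

The largest-magnitude value maps exactly to -127 or 127, and every restored value is within half a quantization step (0.5 / scale) of the original.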

GPTQ performs post-training quantization with error correction, ideal for large language models. It uses inverse Hessian information to minimize accuracy loss during quantization.

Knowledge Distillation

The Distiller class enables knowledge transfer from a large teacher model to a smaller student model. This can improve the student's accuracy beyond what supervised training alone achieves.

from keras.distillation import Distiller, LogitsDistillation

distiller = Distiller(
    teacher=teacher_model,
    student=student_model,
    distillation_losses=LogitsDistillation(temperature=3.0),
    student_loss_weight=0.5
)

distiller.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
distiller.fit(x_train, y_train, epochs=10)

trained_student = distiller.student

Multiple distillation losses can be combined (logits, features) with individual weights for fine-grained control.
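The logits-based loss can be illustrated with the standard knowledge-distillation formulation of Hinton et al. (2015): soften both models' logits with a temperature T, take the KL divergence from the teacher's distribution to the student's, and scale by T squared so gradient magnitudes stay comparable across temperatures. The functions below are hypothetical pure-Python helpers; that LogitsDistillation computes exactly this is an assumption.

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def logits_distillation_loss(teacher_logits, student_logits, temperature=3.0):
    """KL(teacher || student) on temperature-softened distributions, times T**2."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return kl * temperature ** 2
```

Higher temperatures flatten the teacher's distribution, exposing how it ranks the wrong classes, which is the extra signal the student learns from.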

Model Export & Deployment

Keras models can be exported to multiple formats for deployment across different platforms:

  • SavedModel (ExportArchive) - TensorFlow serving format with configurable endpoints
  • ONNX - cross-platform inference via ONNX Runtime
  • OpenVINO - Intel hardware optimization
  • LiteRT - mobile and edge device deployment

# Export to different formats
model.export("path/to/model", format="tf_saved_model")
model.export("path/to/model.onnx", format="onnx")
model.export("path/to/model.xml", format="openvino")
model.export("path/to/model.tflite", format="litert")

The input_signature parameter defines input shapes and dtypes, enabling static graph compilation and format-specific optimizations.

Pre-trained Applications

Keras includes production-ready pre-trained models for computer vision tasks. ResNet and EfficientNet families provide various sizes with ImageNet weights.

from keras.applications import ResNet50, EfficientNetB0

# Load with pre-trained weights
model = ResNet50(weights='imagenet', include_top=True)

# Or reuse the ImageNet features without the classification head (transfer learning)
model = EfficientNetB0(weights='imagenet', include_top=False, input_shape=(224, 224, 3), pooling='avg')

These models support flexible input shapes, custom pooling strategies, and custom classification heads for fine-tuning on new tasks.