checkpoints
Checkpoint utilities for saving and loading learner state.
Provides save_checkpoint and load_checkpoint for persisting any
learner state (LearnerState, MLPLearnerState, MultiHeadMLPState,
TDLearnerState) to disk using orbax-checkpoint.
The caller provides a template state (from learner.init()) to
load_checkpoint so the tree structure is known at load time.
For loading just metadata without a template (e.g. to read learner config
before constructing the template), use load_checkpoint_metadata.
Examples:
import jax.random as jr
from alberta_framework import MultiHeadMLPLearner, save_checkpoint, load_checkpoint
from alberta_framework.core.checkpoints import load_checkpoint_metadata
learner = MultiHeadMLPLearner(n_heads=5, hidden_sizes=(64, 64))
state = learner.init(feature_dim=20, key=jr.key(42))
# Save (creates a checkpoint directory at the given path)
save_checkpoint(state, "agent.ckpt", metadata={"epoch": 1})
# Load (template provides tree structure)
template = learner.init(feature_dim=20, key=jr.key(0))
loaded_state, meta = load_checkpoint(template, "agent.ckpt")
assert meta["epoch"] == 1
# Load metadata only (no template needed)
meta = load_checkpoint_metadata("agent.ckpt")
assert meta["epoch"] == 1
save_checkpoint(state, path, metadata=None)
Save learner state to disk.
Creates a checkpoint directory at path containing the serialized
state PyTree and optional user metadata as JSON.
Args:
state: Any learner state (LearnerState, MLPLearnerState,
MultiHeadMLPState, TDLearnerState).
path: Path for the checkpoint directory.
metadata: Optional user metadata dict to store alongside the
checkpoint (e.g. epoch or learner config).
Source code in src/alberta_framework/core/checkpoints.py
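Because the metadata is stored as JSON, its values have to be JSON-representable, and some Python types change shape on the way back. The sketch below does not call save_checkpoint. It uses only the standard json module to mimic that round trip, on the assumption that the stored metadata behaves like ordinary JSON. The helper name roundtrip_metadata and the learner_config layout are hypothetical.

```python
import json


def roundtrip_metadata(metadata: dict) -> dict:
    """Encode and decode metadata the way a JSON file on disk would.

    Hypothetical helper, not part of alberta_framework. json.dumps raises
    TypeError if a value (e.g. an array or a custom object) has no JSON
    encoding, so this also works as a pre-save sanity check.
    """
    return json.loads(json.dumps(metadata))


# Illustrative metadata; the key names are assumptions, not a required schema.
config = {"epoch": 1, "learner_config": {"n_heads": 5, "hidden_sizes": (64, 64)}}
restored = roundtrip_metadata(config)

# JSON has no tuple type, so hidden_sizes comes back as a list.
# Convert it back before passing it to a constructor that expects a tuple.
hidden_sizes = tuple(restored["learner_config"]["hidden_sizes"])
```

Running a check like this before saving can surface non-serializable values early, rather than at the moment the checkpoint is written.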
load_checkpoint(state_template, path)
Load checkpoint into a state matching the template's tree structure.
The template state (from learner.init()) provides the PyTree
structure for deserialization.
Args:
state_template: A state of the same type and structure as the
saved state. Typically created via learner.init() with
the same architecture.
path: Path to the checkpoint directory.
Returns:
Tuple of (loaded_state, user_metadata) where user_metadata
is the dict passed to save_checkpoint, or an empty dict if
none was provided.
Raises:
FileNotFoundError: If checkpoint directory doesn't exist
ValueError: If state structure doesn't match template
Source code in src/alberta_framework/core/checkpoints.py
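To see why a template is needed, it can help to think of a saved state as a flat sequence of leaf values with the nesting stripped away. The following is a minimal pure-Python sketch of that idea, not the mechanism orbax-checkpoint or load_checkpoint actually uses. The functions flatten and restore_like are hypothetical and handle only dicts, lists, and tuples.

```python
from typing import Any


def flatten(tree: Any) -> list:
    """Return the leaves of a nested dict/list/tuple in a fixed order."""
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in flatten(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in flatten(item)]
    return [tree]


def restore_like(template: Any, leaves: list) -> Any:
    """Rebuild a tree shaped like `template` from a flat list of leaves."""
    remaining = list(leaves)

    def build(node: Any) -> Any:
        if isinstance(node, dict):
            return {key: build(node[key]) for key in sorted(node)}
        if isinstance(node, (list, tuple)):
            return type(node)(build(item) for item in node)
        if not remaining:
            raise ValueError("saved state has fewer leaves than the template")
        return remaining.pop(0)

    rebuilt = build(template)
    if remaining:
        raise ValueError("saved state has more leaves than the template")
    return rebuilt


saved = {"weights": [1.0, 2.0], "bias": 0.5, "step": 7}
leaves = flatten(saved)  # the structure is gone; only values remain

# A freshly initialised state with the same shape supplies the structure.
template = {"weights": [0.0, 0.0], "bias": 0.0, "step": 0}
loaded = restore_like(template, leaves)
```

A template built with a different architecture (here, a different number of weights) cannot absorb the saved leaves, which is the situation the ValueError above describes.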
load_checkpoint_metadata(path)
Load only the user metadata from a checkpoint, without a state template.
This is useful when metadata contains configuration needed to construct the state template (e.g. learner_config in rlsecd).
Args:
path: Path to the checkpoint directory.
Returns:
The user metadata dict, or an empty dict if none was stored.
Raises:
FileNotFoundError: If checkpoint directory doesn't exist
Source code in src/alberta_framework/core/checkpoints.py
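The metadata-first workflow described above can be sketched as a small helper that reads the config, builds the template from it, and only then loads the state. The helper load_with_metadata_config is hypothetical, and the framework functions are passed in as arguments so the sketch runs without a checkpoint on disk. In real use, load_metadata would be load_checkpoint_metadata and load_state would be load_checkpoint. The "learner_config" key and the stand-in functions are assumptions for illustration.

```python
from typing import Any, Callable


def load_with_metadata_config(
    path: str,
    load_metadata: Callable[[str], dict],
    make_template: Callable[[dict], Any],
    load_state: Callable[[Any, str], tuple],
    config_key: str = "learner_config",
) -> tuple:
    """Read the stored config, build a matching template, then load the state."""
    meta = load_metadata(path)
    config = meta.get(config_key, {})
    template = make_template(config)
    return load_state(template, path)


# Stand-ins for the framework calls: an in-memory "disk" with one checkpoint.
fake_store = {"agent.ckpt": ({"w": [1.0, 2.0]}, {"learner_config": {"n_heads": 2}})}


def fake_load_metadata(path: str) -> dict:
    return fake_store[path][1]


def fake_make_template(config: dict) -> dict:
    # A real version would construct the learner from config and call init().
    return {"w": [0.0] * config["n_heads"]}


def fake_load_state(template: Any, path: str) -> tuple:
    return fake_store[path]


state, meta = load_with_metadata_config(
    "agent.ckpt", fake_load_metadata, fake_make_template, fake_load_state
)
```

Storing everything the template constructor needs in the metadata at save time is what makes this pattern work; anything left out has to be known by the caller some other way.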
checkpoint_exists(path)
Check whether a checkpoint exists at the given path.
Args:
path: Path to check for a checkpoint directory.
Returns:
True if a checkpoint directory exists at the path.
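A common use for an existence check is a resume-or-start-fresh branch at the top of a training script. The sketch below shows one way that might look. The helper resume_or_init is hypothetical, and the framework functions are injected as arguments so the example runs on its own. In real use, exists would be checkpoint_exists, load would be load_checkpoint, and init would wrap learner.init().

```python
from typing import Any, Callable


def resume_or_init(
    path: str,
    exists: Callable[[str], bool],
    load: Callable[[Any, str], tuple],
    init: Callable[[], Any],
) -> tuple:
    """Return (state, metadata, resumed), loading only if a checkpoint exists."""
    template = init()
    if exists(path):
        state, meta = load(template, path)
        return state, meta, True
    # No checkpoint yet: the freshly initialised state is the starting point.
    return template, {}, False


# Stand-ins for the framework calls: a dict plays the role of the disk.
saved: dict = {}


def fake_exists(path: str) -> bool:
    return path in saved


def fake_load(template: Any, path: str) -> tuple:
    return saved[path]


def fake_init() -> dict:
    return {"w": [0.0, 0.0]}


first = resume_or_init("agent.ckpt", fake_exists, fake_load, fake_init)
saved["agent.ckpt"] = ({"w": [0.3, -0.1]}, {"epoch": 4})
second = resume_or_init("agent.ckpt", fake_exists, fake_load, fake_init)
```

Checking first avoids relying on the FileNotFoundError from the load path for ordinary control flow on a first run.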