
alberta_framework

Alberta Framework: A JAX-based research framework for continual AI.

The Alberta Framework provides foundational components for continual reinforcement learning research. Built on JAX for hardware acceleration, the framework emphasizes temporal uniformity — every component updates at every time step, with no special training phases or batch processing.

Roadmap
Step   Focus                                          Status
1      Meta-learned step-sizes (IDBD, Autostep)       Complete
2      Nonlinear function approximation (MLP, ObGD)   In Progress
3      GVF predictions, Horde architecture            Planned
4      Actor-critic with eligibility traces           Planned
5-6    Off-policy learning, average reward            Planned
7-12   Hierarchical, multi-agent, world models        Future

Examples:

import jax.random as jr
from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop

# Non-stationary stream where target weights drift over time
stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)

# Learner with IDBD meta-learned step-sizes
learner = LinearLearner(optimizer=IDBD())

# JIT-compiled training via jax.lax.scan
state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(42))

References
  • The Alberta Plan for AI Research (Sutton et al., 2022): https://arxiv.org/abs/2208.11173
  • Adapting Bias by Gradient Descent (Sutton, 1992)
  • Tuning-free Step-size Adaptation (Mahmood et al., 2012)
  • Streaming Deep Reinforcement Learning Finally Works (Elsayed et al., 2024)

FeatureRelevance

Per-feature and per-head relevance metrics extracted from learner state.

All fields are derived from existing MultiHeadMLPState arrays. No forward pass is required.

Attributes:
  • weight_relevance: Path-norm relevance from input features to each head. Shape (n_heads, feature_dim).
  • step_size_activity: Mean absolute step-size on the input layer per feature. Shape (feature_dim,).
  • trace_activity: Mean absolute trunk trace magnitude on the input layer per feature. Shape (feature_dim,).
  • normalizer_mean: Per-feature normalizer mean estimate, or None if no normalizer. Shape (feature_dim,).
  • normalizer_std: Per-feature normalizer std estimate, or None if no normalizer. Shape (feature_dim,).
  • head_reliance: L1 norm of each head's weight vector over the last hidden layer. Shape (n_heads, hidden_dim_last).
  • head_mean_step_size: Mean step-size per head, or None if the optimizer has no per-weight step-sizes. Shape (n_heads,).
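
The weight_relevance field is a path-norm quantity. As a rough illustration of what a path-norm from inputs to heads looks like (a sketch only: the matrices and the exact formula here are assumptions, not the framework's implementation), one can chain absolute weight matrices layer by layer:

```python
import jax.numpy as jnp

# Assumed toy shapes: feature_dim=4, one hidden layer of 3, n_heads=2.
w_hidden = jnp.array([[0.5, -1.0, 0.0, 2.0],
                      [0.0, 0.3, -0.7, 0.0],
                      [1.0, 0.0, 0.0, -0.2]])   # (hidden, feature_dim)
w_heads = jnp.array([[1.0, -2.0, 0.5],
                     [0.0, 1.0, 1.0]])          # (n_heads, hidden)

# Path-norm style relevance: sum over all input-to-head paths of the
# product of absolute weights along each path.
relevance = jnp.abs(w_heads) @ jnp.abs(w_hidden)  # shape (n_heads, feature_dim)
```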

BatchedHordeResult

Result from batched Horde learning loop.

Attributes:
  • states: Batched multi-head MLP learner states
  • per_demon_metrics: Per-demon metrics, shape (n_seeds, num_steps, n_demons, 3)
  • td_errors: TD errors, shape (n_seeds, num_steps, n_demons)

HordeLearner(horde_spec, hidden_sizes=(128, 128), optimizer=None, step_size=1.0, bounder=None, normalizer=None, sparsity=0.9, leaky_relu_slope=0.01, use_layer_norm=True, head_optimizer=None)

Horde: GVF demons sharing a trunk (Sutton et al. 2011).

Wraps MultiHeadMLPLearner. Adds:
  • Per-demon gamma/lambda from HordeSpec
  • TD target computation for temporal demons (gamma > 0)
  • GVF metadata

The trunk uses gamma=0, lamda=0 (no temporal trace decay on shared features). Each head uses its own gamma * lambda product for trace decay, set via per_head_gamma_lamda on the inner learner.

For all-gamma=0 Hordes (e.g. rlsecd's 5 prediction heads), this produces identical results to MultiHeadMLPLearner since the TD target reduces to just the cumulant.
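
The TD-target arithmetic behind this can be checked standalone. A sketch with made-up numbers, mirroring the `targets = cumulants + gammas * next_preds` line in `update()`:

```python
import jax.numpy as jnp

# Assumed numbers for three demons: two temporal (gamma > 0) and one
# single-step prediction demon (gamma = 0).
gammas = jnp.array([0.5, 0.25, 0.0])
cumulants = jnp.array([1.0, jnp.nan, 2.0])   # NaN marks an inactive demon
next_preds = jnp.array([10.0, 10.0, 10.0])   # V(s') from the shared trunk

# Same arithmetic as update(): r + gamma * V(s')
targets = cumulants + gammas * next_preds
# -> demon 0: 6.0; demon 1: NaN (stays inactive); demon 2: just the cumulant, 2.0
```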

Single-Step (Daemon) Usage

Both predict() and update() work with single unbatched observations (1D arrays). JIT-compiled automatically.

Attributes:
  • horde_spec: The HordeSpec defining all demons
  • n_demons: Number of demons (heads)

Args:
  • horde_spec: Specification of all GVF demons
  • hidden_sizes: Tuple of hidden layer sizes (default: two layers of 128)
  • optimizer: Optimizer for weight updates. Defaults to LMS(step_size).
  • step_size: Base learning rate (used only when optimizer is None)
  • bounder: Optional update bounder (e.g. ObGDBounding)
  • normalizer: Optional feature normalizer
  • sparsity: Fraction of weights zeroed out per neuron (default: 0.9)
  • leaky_relu_slope: Negative slope for LeakyReLU (default: 0.01)
  • use_layer_norm: Whether to apply parameterless layer normalization
  • head_optimizer: Optional separate optimizer for heads

Source code in src/alberta_framework/core/horde.py
def __init__(
    self,
    horde_spec: HordeSpec,
    hidden_sizes: tuple[int, ...] = (128, 128),
    optimizer: AnyOptimizer | None = None,
    step_size: float = 1.0,
    bounder: Bounder | None = None,
    normalizer: (
        Normalizer[EMANormalizerState] | Normalizer[WelfordNormalizerState] | None
    ) = None,
    sparsity: float = 0.9,
    leaky_relu_slope: float = 0.01,
    use_layer_norm: bool = True,
    head_optimizer: AnyOptimizer | None = None,
):
    """Initialize the Horde learner.

    Args:
        horde_spec: Specification of all GVF demons
        hidden_sizes: Tuple of hidden layer sizes (default: two layers of 128)
        optimizer: Optimizer for weight updates. Defaults to LMS(step_size).
        step_size: Base learning rate (used only when optimizer is None)
        bounder: Optional update bounder (e.g. ObGDBounding)
        normalizer: Optional feature normalizer
        sparsity: Fraction of weights zeroed out per neuron (default: 0.9)
        leaky_relu_slope: Negative slope for LeakyReLU (default: 0.01)
        use_layer_norm: Whether to apply parameterless layer normalization
        head_optimizer: Optional separate optimizer for heads
    """
    self._horde_spec = horde_spec
    self._hidden_sizes = hidden_sizes
    self._step_size = step_size
    self._sparsity = sparsity
    self._leaky_relu_slope = leaky_relu_slope
    self._use_layer_norm = use_layer_norm

    # Compute per-head gamma*lambda products
    per_head_gl = tuple(
        float(d.gamma * d.lamda) for d in horde_spec.demons
    )

    self._learner = MultiHeadMLPLearner(
        n_heads=len(horde_spec.demons),
        hidden_sizes=hidden_sizes,
        optimizer=optimizer,
        step_size=step_size,
        bounder=bounder,
        gamma=0.0,  # trunk: no trace decay
        lamda=0.0,
        normalizer=normalizer,
        sparsity=sparsity,
        leaky_relu_slope=leaky_relu_slope,
        use_layer_norm=use_layer_norm,
        head_optimizer=head_optimizer,
        per_head_gamma_lamda=per_head_gl,
    )

horde_spec property

The HordeSpec defining all demons.

n_demons property

Number of demons (heads).

learner property

The underlying MultiHeadMLPLearner.

to_config()

Serialize learner configuration to dict.

Returns: Dict with horde_spec and all MultiHeadMLPLearner constructor args.

Source code in src/alberta_framework/core/horde.py
def to_config(self) -> dict[str, Any]:
    """Serialize learner configuration to dict.

    Returns:
        Dict with horde_spec and all MultiHeadMLPLearner constructor args.
    """
    learner_config = self._learner.to_config()
    # Remove fields managed by HordeLearner
    learner_config.pop("type", None)
    learner_config.pop("n_heads", None)
    learner_config.pop("gamma", None)
    learner_config.pop("lamda", None)
    learner_config.pop("per_head_gamma_lamda", None)

    return {
        "type": "HordeLearner",
        "horde_spec": self._horde_spec.to_config(),
        **learner_config,
    }

from_config(config) classmethod

Reconstruct from config dict.

Args: config: Dict as produced by to_config()

Returns: Reconstructed HordeLearner

Source code in src/alberta_framework/core/horde.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "HordeLearner":
    """Reconstruct from config dict.

    Args:
        config: Dict as produced by ``to_config()``

    Returns:
        Reconstructed HordeLearner
    """
    from alberta_framework.core.normalizers import normalizer_from_config
    from alberta_framework.core.optimizers import (
        bounder_from_config,
        optimizer_from_config,
    )

    config = dict(config)
    config.pop("type", None)

    horde_spec = HordeSpec.from_config(config.pop("horde_spec"))
    optimizer = optimizer_from_config(config.pop("optimizer"))
    bounder_cfg = config.pop("bounder", None)
    bounder = bounder_from_config(bounder_cfg) if bounder_cfg is not None else None
    normalizer_cfg = config.pop("normalizer", None)
    normalizer = (
        normalizer_from_config(normalizer_cfg) if normalizer_cfg is not None else None
    )
    head_opt_cfg = config.pop("head_optimizer", None)
    head_optimizer = (
        optimizer_from_config(head_opt_cfg) if head_opt_cfg is not None else None
    )

    return cls(
        horde_spec=horde_spec,
        hidden_sizes=tuple(config.pop("hidden_sizes")),
        optimizer=optimizer,
        bounder=bounder,
        normalizer=normalizer,
        head_optimizer=head_optimizer,
        **config,
    )

init(feature_dim, key)

Initialize Horde learner state.

Args:
  • feature_dim: Dimension of the input feature vector
  • key: JAX random key for weight initialization

Returns: Initial MultiHeadMLPState

Source code in src/alberta_framework/core/horde.py
def init(self, feature_dim: int, key: Array) -> MultiHeadMLPState:
    """Initialize Horde learner state.

    Args:
        feature_dim: Dimension of the input feature vector
        key: JAX random key for weight initialization

    Returns:
        Initial MultiHeadMLPState
    """
    return self._learner.init(feature_dim, key)

predict(state, observation)

Compute predictions from all demons.

Args:
  • state: Current learner state
  • observation: Input feature vector

Returns: Array of shape (n_demons,) with one prediction per demon

Source code in src/alberta_framework/core/horde.py
@functools.partial(jax.jit, static_argnums=(0,))
def predict(self, state: MultiHeadMLPState, observation: Array) -> Array:
    """Compute predictions from all demons.

    Args:
        state: Current learner state
        observation: Input feature vector

    Returns:
        Array of shape ``(n_demons,)`` with one prediction per demon
    """
    return self._learner.predict(state, observation)  # type: ignore[no-any-return]

update(state, observation, cumulants, next_observation)

Update Horde given observation, cumulants, and next observation.

Computes TD targets r + gamma * V(s') for each demon, then delegates to MultiHeadMLPLearner.update(). For gamma=0 demons, the target equals the cumulant.

Args:
  • state: Current state
  • observation: Input feature vector, shape (feature_dim,)
  • cumulants: Per-demon pseudo-rewards, shape (n_demons,). NaN = inactive demon.
  • next_observation: Next feature vector, shape (feature_dim,). Used for V(s') bootstrapping. For all-gamma=0 Hordes, this is required but doesn't affect results.

Returns: HordeUpdateResult with updated state, predictions, TD errors, TD targets, and per-demon metrics

Source code in src/alberta_framework/core/horde.py
@functools.partial(jax.jit, static_argnums=(0,))
def update(
    self,
    state: MultiHeadMLPState,
    observation: Array,
    cumulants: Array,
    next_observation: Array,
) -> HordeUpdateResult:
    """Update Horde given observation, cumulants, and next observation.

    Computes TD targets ``r + gamma * V(s')`` for each demon, then
    delegates to ``MultiHeadMLPLearner.update()``. For gamma=0 demons,
    the target equals the cumulant.

    Args:
        state: Current state
        observation: Input feature vector, shape ``(feature_dim,)``
        cumulants: Per-demon pseudo-rewards, shape ``(n_demons,)``.
            NaN = inactive demon.
        next_observation: Next feature vector, shape ``(feature_dim,)``.
            Used for V(s') bootstrapping. For all-gamma=0 Hordes,
            this is required but doesn't affect results.

    Returns:
        HordeUpdateResult with updated state, predictions, TD errors,
        TD targets, and per-demon metrics
    """
    # 1. Compute V(s') for bootstrapping
    next_preds = self._learner.predict(state, next_observation)

    # 2. TD targets: r + gamma * V(s')
    # For gamma=0 demons: target = cumulant (single-step prediction)
    # NaN cumulants stay NaN (inactive demons)
    gammas = self._horde_spec.gammas
    targets = cumulants + gammas * next_preds

    # 3. Delegate to MultiHeadMLPLearner
    result = self._learner.update(state, observation, targets)

    return HordeUpdateResult(  # type: ignore[call-arg]
        state=result.state,
        predictions=result.predictions,
        td_errors=result.errors,
        td_targets=targets,
        per_demon_metrics=result.per_head_metrics,
        trunk_bounding_metric=result.trunk_bounding_metric,
    )

HordeLearningResult

Result from a Horde scan-based learning loop.

Attributes:
  • state: Final multi-head MLP learner state
  • per_demon_metrics: Per-demon metrics over time, shape (num_steps, n_demons, 3)
  • td_errors: TD errors over time, shape (num_steps, n_demons)

HordeUpdateResult

Result of a single Horde update step.

Attributes:
  • state: Updated multi-head MLP learner state
  • predictions: Predictions from all demons, shape (n_demons,)
  • td_errors: TD errors (target - prediction), shape (n_demons,). NaN for inactive demons.
  • td_targets: Computed TD targets r + gamma * V(s'), shape (n_demons,). NaN for inactive demons.
  • per_demon_metrics: Per-demon metrics, shape (n_demons, 3). Columns: [squared_error, raw_error, mean_step_size]. NaN for inactive demons.
  • trunk_bounding_metric: Scalar trunk bounding metric

LinearLearner(optimizer=None, normalizer=None)

Linear function approximator with pluggable optimizer and optional normalizer.

Computes predictions as: y = w @ x + b

The learner maintains weights and bias, delegating the adaptation of learning rates to the optimizer (e.g., LMS or IDBD).

This follows the Alberta Plan philosophy of temporal uniformity: every component updates at every time step.
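
With the default LMS optimizer, one update step amounts to the classic delta rule. A standalone sketch of a single step in plain JAX (not the framework's internal code):

```python
import jax.numpy as jnp

# One LMS step, written out by hand.
step_size = 0.01
w = jnp.zeros(3)
b = jnp.array(0.0)
x = jnp.array([1.0, 2.0, 0.0])
target = 1.0

prediction = jnp.dot(w, x) + b     # y = w @ x + b, 0.0 with zero init
error = target - prediction        # 1.0 on this first step
w = w + step_size * error * x      # LMS weight delta
b = b + step_size * error          # LMS bias delta
```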

Attributes:
  • optimizer: The optimizer to use for weight updates
  • normalizer: Optional online feature normalizer

Args:
  • optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
  • normalizer: Optional feature normalizer (e.g. EMANormalizer, WelfordNormalizer)

Source code in src/alberta_framework/core/learners.py
def __init__(
    self,
    optimizer: AnyOptimizer | None = None,
    normalizer: (
        Normalizer[EMANormalizerState] | Normalizer[WelfordNormalizerState] | None
    ) = None,
):
    """Initialize the linear learner.

    Args:
        optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
        normalizer: Optional feature normalizer (e.g. EMANormalizer, WelfordNormalizer)
    """
    self._optimizer: AnyOptimizer = optimizer or LMS(step_size=0.01)
    self._normalizer = normalizer

normalizer property

The feature normalizer, or None if normalization is disabled.

init(feature_dim)

Initialize learner state.

Args: feature_dim: Dimension of the input feature vector

Returns: Initial learner state with zero weights and bias

Source code in src/alberta_framework/core/learners.py
def init(self, feature_dim: int) -> LearnerState:
    """Initialize learner state.

    Args:
        feature_dim: Dimension of the input feature vector

    Returns:
        Initial learner state with zero weights and bias
    """
    optimizer_state = self._optimizer.init(feature_dim)

    normalizer_state = None
    if self._normalizer is not None:
        normalizer_state = self._normalizer.init(feature_dim)

    return LearnerState(
        weights=jnp.zeros(feature_dim, dtype=jnp.float32),
        bias=jnp.array(0.0, dtype=jnp.float32),
        optimizer_state=optimizer_state,
        normalizer_state=normalizer_state,
        step_count=jnp.array(0, dtype=jnp.int32),
        birth_timestamp=time.time(),
        uptime_s=0.0,
    )

predict(state, observation)

Compute prediction for an observation.

Args:
  • state: Current learner state
  • observation: Input feature vector

Returns: Scalar prediction y = w @ x + b

Source code in src/alberta_framework/core/learners.py
def predict(self, state: LearnerState, observation: Observation) -> Prediction:
    """Compute prediction for an observation.

    Args:
        state: Current learner state
        observation: Input feature vector

    Returns:
        Scalar prediction ``y = w @ x + b``
    """
    return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)

update(state, observation, target)

Update learner given observation and target.

Performs one step of the learning algorithm:
  1. Optionally normalize observation
  2. Compute prediction
  3. Compute error
  4. Get weight updates from optimizer
  5. Apply updates to weights and bias

Args:
  • state: Current learner state
  • observation: Input feature vector
  • target: Desired output

Returns: UpdateResult with new state, prediction, error, and metrics

Source code in src/alberta_framework/core/learners.py
def update(
    self,
    state: LearnerState,
    observation: Observation,
    target: Target,
) -> UpdateResult:
    """Update learner given observation and target.

    Performs one step of the learning algorithm:
    1. Optionally normalize observation
    2. Compute prediction
    3. Compute error
    4. Get weight updates from optimizer
    5. Apply updates to weights and bias

    Args:
        state: Current learner state
        observation: Input feature vector
        target: Desired output

    Returns:
        UpdateResult with new state, prediction, error, and metrics
    """
    # Handle normalization
    new_normalizer_state = state.normalizer_state
    obs = observation
    if self._normalizer is not None and state.normalizer_state is not None:
        obs, new_normalizer_state = self._normalizer.normalize(
            state.normalizer_state, observation
        )

    # Make prediction
    prediction = self.predict(
        LearnerState(
            weights=state.weights,
            bias=state.bias,
            optimizer_state=state.optimizer_state,
            normalizer_state=new_normalizer_state,
            step_count=state.step_count,
            birth_timestamp=state.birth_timestamp,
            uptime_s=state.uptime_s,
        ),
        obs,
    )

    # Compute error (target - prediction)
    error = jnp.squeeze(target) - jnp.squeeze(prediction)

    # Get update from optimizer
    opt_update = self._optimizer.update(
        state.optimizer_state,
        error,
        obs,
    )

    # Apply updates
    new_weights = state.weights + opt_update.weight_delta
    new_bias = state.bias + opt_update.bias_delta

    new_state = LearnerState(
        weights=new_weights,
        bias=new_bias,
        optimizer_state=opt_update.new_state,
        normalizer_state=new_normalizer_state,
        step_count=state.step_count + 1,
        birth_timestamp=state.birth_timestamp,
        uptime_s=state.uptime_s,
    )

    # Pack metrics as array for scan compatibility
    squared_error = error**2
    mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)

    if self._normalizer is not None and new_normalizer_state is not None:
        normalizer_mean_var = jnp.mean(new_normalizer_state.var)
        metrics = jnp.array(
            [squared_error, error, mean_step_size, normalizer_mean_var],
            dtype=jnp.float32,
        )
    else:
        metrics = jnp.array(
            [squared_error, error, mean_step_size], dtype=jnp.float32
        )

    return UpdateResult(
        state=new_state,
        prediction=prediction,
        error=jnp.atleast_1d(error),
        metrics=metrics,
    )

MLPLearner(hidden_sizes=(128, 128), optimizer=None, step_size=1.0, bounder=None, gamma=0.0, lamda=0.0, normalizer=None, sparsity=0.9, leaky_relu_slope=0.01, use_layer_norm=True, head_optimizer=None)

Multi-layer perceptron with composable optimizer, bounder, and normalizer.

Architecture: Input -> [Dense(H) -> LayerNorm -> LeakyReLU] x N -> Dense(1)

When use_layer_norm=False, the architecture simplifies to: Input -> [Dense(H) -> LeakyReLU] x N -> Dense(1)

Uses parameterless layer normalization and sparse initialization following Elsayed et al. 2024. Accepts a pluggable optimizer (LMS, Autostep), an optional bounder (ObGDBounding), and an optional feature normalizer (EMANormalizer, WelfordNormalizer).
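
"Parameterless" layer normalization means normalizing activations to zero mean and unit variance with no learnable gain or bias. A minimal sketch of the idea (the framework's internals, e.g. the eps value, may differ):

```python
import jax.numpy as jnp

def parameterless_layer_norm(x, eps=1e-5):
    # Zero-mean / unit-variance normalization; nothing to learn.
    return (x - jnp.mean(x)) / jnp.sqrt(jnp.var(x) + eps)

h = jnp.array([1.0, 2.0, 3.0, 4.0])
h_norm = parameterless_layer_norm(h)
# h_norm has mean ~0 and variance ~1
```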

The update flow:
  1. If normalizer: normalize observation, update normalizer state
  2. Forward pass + jax.grad to get per-layer prediction gradients
  3. Update eligibility traces: z = gamma * lamda * z + grad
  4. Per-layer optimizer step: step, new_opt = optimizer.update_from_gradient(state, z)
  5. If bounder: bound all steps globally
  6. Apply: param += scale * error * step
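
The trace and apply steps of this flow can be written out for a single parameter vector. A sketch with made-up numbers, assuming a plain unbounded LMS-style step:

```python
import jax.numpy as jnp

gamma, lamda = 0.99, 0.8
step_size, error = 0.1, 0.5

z = jnp.zeros(3)                    # eligibility trace, starts at zero
grad = jnp.array([1.0, 0.0, -1.0])  # d(prediction)/d(param), as from jax.grad

z = gamma * lamda * z + grad        # trace decay + accumulate; here z == grad
delta = step_size * error * z       # apply: param += scale * error * step
params = jnp.zeros(3) + delta
```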

Reference: Elsayed et al. 2024, "Streaming Deep Reinforcement Learning Finally Works"

Attributes:
  • hidden_sizes: Tuple of hidden layer sizes
  • optimizer: Optimizer for per-weight step-size adaptation
  • bounder: Optional update bounder (e.g. ObGDBounding)
  • normalizer: Optional feature normalizer
  • use_layer_norm: Whether to apply parameterless layer normalization
  • gamma: Discount factor for trace decay
  • lamda: Eligibility trace decay parameter
  • sparsity: Fraction of weights zeroed out per output neuron
  • leaky_relu_slope: Negative slope for LeakyReLU activation

Single-Step (Daemon) Usage

Both predict() and update() work with single unbatched observations (1D arrays of shape (feature_dim,)). This is the intended usage for daemon-style deployments.

For low-latency daemon use, pre-compile predict and update at startup by running a dummy warmup call:

import jax.numpy as jnp  # assumes feature_dim, learner, and state are already defined

dummy_obs = jnp.zeros(feature_dim)
dummy_target = jnp.zeros(1)
_ = learner.predict(state, dummy_obs)
result = learner.update(state, dummy_obs, dummy_target)

Args:
  • hidden_sizes: Tuple of hidden layer sizes (default: two layers of 128)
  • optimizer: Optimizer for weight updates. Defaults to LMS(step_size). Must support init_for_shape and update_from_gradient.
  • step_size: Base learning rate (used only when optimizer is None, default: 1.0)
  • bounder: Optional update bounder (e.g. ObGDBounding for ObGD-style bounding). When None, no bounding is applied.
  • gamma: Discount factor for trace decay (default: 0.0 for supervised)
  • lamda: Eligibility trace decay parameter (default: 0.0 for supervised)
  • normalizer: Optional feature normalizer. When provided, features are normalized before prediction and learning.
  • sparsity: Fraction of weights zeroed out per output neuron (default: 0.9)
  • leaky_relu_slope: Negative slope for LeakyReLU (default: 0.01)
  • use_layer_norm: Whether to apply parameterless layer normalization between hidden layers (default: True). Set to False for ablation studies.
  • head_optimizer: Optional separate optimizer for the output (head) layer. When None (default), all layers use optimizer. When set, hidden layers use optimizer while the output layer uses head_optimizer. This enables hybrid configurations like stable LMS for the trunk with adaptive Autostep for the head.

Source code in src/alberta_framework/core/learners.py
def __init__(
    self,
    hidden_sizes: tuple[int, ...] = (128, 128),
    optimizer: AnyOptimizer | None = None,
    step_size: float = 1.0,
    bounder: Bounder | None = None,
    gamma: float = 0.0,
    lamda: float = 0.0,
    normalizer: (
        Normalizer[EMANormalizerState] | Normalizer[WelfordNormalizerState] | None
    ) = None,
    sparsity: float = 0.9,
    leaky_relu_slope: float = 0.01,
    use_layer_norm: bool = True,
    head_optimizer: AnyOptimizer | None = None,
):
    """Initialize MLP learner.

    Args:
        hidden_sizes: Tuple of hidden layer sizes (default: two layers of 128)
        optimizer: Optimizer for weight updates. Defaults to LMS(step_size).
            Must support ``init_for_shape`` and ``update_from_gradient``.
        step_size: Base learning rate (used only when optimizer is None,
            default: 1.0)
        bounder: Optional update bounder (e.g. ObGDBounding for ObGD-style
            bounding). When None, no bounding is applied.
        gamma: Discount factor for trace decay (default: 0.0 for supervised)
        lamda: Eligibility trace decay parameter (default: 0.0 for supervised)
        normalizer: Optional feature normalizer. When provided, features are
            normalized before prediction and learning.
        sparsity: Fraction of weights zeroed out per output neuron (default: 0.9)
        leaky_relu_slope: Negative slope for LeakyReLU (default: 0.01)
        use_layer_norm: Whether to apply parameterless layer normalization
            between hidden layers (default: True). Set to False for ablation
            studies.
        head_optimizer: Optional separate optimizer for the output (head) layer.
            When None (default), all layers use ``optimizer``. When set, hidden
            layers use ``optimizer`` while the output layer uses
            ``head_optimizer``. This enables hybrid configurations like
            stable LMS for the trunk with adaptive Autostep for the head.
    """
    self._hidden_sizes = hidden_sizes
    self._optimizer: AnyOptimizer = optimizer or LMS(step_size=step_size)
    self._head_optimizer: AnyOptimizer | None = head_optimizer
    self._bounder = bounder
    self._gamma = gamma
    self._lamda = lamda
    self._normalizer = normalizer
    self._sparsity = sparsity
    self._leaky_relu_slope = leaky_relu_slope
    self._use_layer_norm = use_layer_norm

normalizer property

The feature normalizer, or None if normalization is disabled.

to_config()

Serialize learner configuration to dict.

Returns: Dict with all constructor arguments needed to recreate the learner via from_config().

Source code in src/alberta_framework/core/learners.py
def to_config(self) -> dict[str, Any]:
    """Serialize learner configuration to dict.

    Returns:
        Dict with all constructor arguments needed to recreate
        the learner via ``from_config()``.
    """
    config: dict[str, Any] = {
        "type": "MLPLearner",
        "hidden_sizes": list(self._hidden_sizes),
        "optimizer": self._optimizer.to_config(),
        "bounder": self._bounder.to_config() if self._bounder is not None else None,
        "normalizer": self._normalizer.to_config() if self._normalizer is not None else None,
        "head_optimizer": (
            self._head_optimizer.to_config()
            if self._head_optimizer is not None
            else None
        ),
        "sparsity": self._sparsity,
        "leaky_relu_slope": self._leaky_relu_slope,
        "use_layer_norm": self._use_layer_norm,
        "gamma": self._gamma,
        "lamda": self._lamda,
    }
    return config

from_config(config) classmethod

Reconstruct learner from a config dict.

Args: config: Dict as produced by to_config()

Returns: Reconstructed MLPLearner instance

Source code in src/alberta_framework/core/learners.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "MLPLearner":
    """Reconstruct learner from a config dict.

    Args:
        config: Dict as produced by ``to_config()``

    Returns:
        Reconstructed MLPLearner instance
    """
    from alberta_framework.core.normalizers import normalizer_from_config
    from alberta_framework.core.optimizers import (
        bounder_from_config,
        optimizer_from_config,
    )

    config = dict(config)
    config.pop("type", None)

    optimizer = optimizer_from_config(config.pop("optimizer"))
    bounder_cfg = config.pop("bounder", None)
    bounder = bounder_from_config(bounder_cfg) if bounder_cfg is not None else None
    normalizer_cfg = config.pop("normalizer", None)
    normalizer = normalizer_from_config(normalizer_cfg) if normalizer_cfg is not None else None
    head_opt_cfg = config.pop("head_optimizer", None)
    head_optimizer = optimizer_from_config(head_opt_cfg) if head_opt_cfg is not None else None

    return cls(
        hidden_sizes=tuple(config.pop("hidden_sizes")),
        optimizer=optimizer,
        bounder=bounder,
        normalizer=normalizer,
        head_optimizer=head_optimizer,
        **config,
    )

init(feature_dim, key)

Initialize MLP learner state with sparse weights.

Args:
  • feature_dim: Dimension of the input feature vector
  • key: JAX random key for weight initialization

Returns: Initial MLP learner state with sparse weights and zero biases

Source code in src/alberta_framework/core/learners.py
def init(self, feature_dim: int, key: Array) -> MLPLearnerState:
    """Initialize MLP learner state with sparse weights.

    Args:
        feature_dim: Dimension of the input feature vector
        key: JAX random key for weight initialization

    Returns:
        Initial MLP learner state with sparse weights and zero biases
    """
    # Build layer sizes: [feature_dim, hidden1, hidden2, ..., 1]
    layer_sizes = [feature_dim, *self._hidden_sizes, 1]

    weights_list = []
    biases_list = []
    traces_list = []
    opt_states_list = []

    n_total_layers = len(layer_sizes) - 1
    for i in range(n_total_layers):
        fan_out = layer_sizes[i + 1]
        fan_in = layer_sizes[i]
        key, subkey = jax.random.split(key)
        w = sparse_init(subkey, (fan_out, fan_in), sparsity=self._sparsity)
        b = jnp.zeros(fan_out, dtype=jnp.float32)
        weights_list.append(w)
        biases_list.append(b)
        # Traces for weights and biases (interleaved: w0, b0, w1, b1, ...)
        traces_list.append(jnp.zeros_like(w))
        traces_list.append(jnp.zeros_like(b))
        # Optimizer states: use head_optimizer for the output layer if set
        is_output = i == n_total_layers - 1
        opt = (
            self._head_optimizer
            if (self._head_optimizer is not None and is_output)
            else self._optimizer
        )
        opt_states_list.append(opt.init_for_shape(w.shape))
        opt_states_list.append(opt.init_for_shape(b.shape))

    params = MLPParams(
        weights=tuple(weights_list),
        biases=tuple(biases_list),
    )

    normalizer_state = None
    if self._normalizer is not None:
        normalizer_state = self._normalizer.init(feature_dim)

    return MLPLearnerState(
        params=params,
        optimizer_states=tuple(opt_states_list),
        traces=tuple(traces_list),
        normalizer_state=normalizer_state,
        step_count=jnp.array(0, dtype=jnp.int32),
        birth_timestamp=time.time(),
        uptime_s=0.0,
    )

predict(state, observation)

Compute prediction for an observation.

JIT-compiled automatically. First call triggers tracing; subsequent calls with the same learner instance use the cached compilation.

Args:
  • state: Current MLP learner state
  • observation: Input feature vector

Returns: Scalar prediction

Source code in src/alberta_framework/core/learners.py
@functools.partial(jax.jit, static_argnums=(0,))
def predict(self, state: MLPLearnerState, observation: Observation) -> Prediction:
    """Compute prediction for an observation.

    JIT-compiled automatically. First call triggers tracing; subsequent
    calls with the same learner instance use the cached compilation.

    Args:
        state: Current MLP learner state
        observation: Input feature vector

    Returns:
        Scalar prediction
    """
    y = self._forward(
        state.params.weights,
        state.params.biases,
        observation,
        self._leaky_relu_slope,
        self._use_layer_norm,
    )
    return jnp.atleast_1d(y)

update(state, observation, target)

Update MLP given observation and target.

JIT-compiled automatically. Performs one step of the learning algorithm:

  1. Optionally normalize observation
  2. Compute prediction and error
  3. Compute gradients via jax.grad on the forward pass
  4. Update eligibility traces
  5. Per-layer optimizer step from traces
  6. Optionally bound steps
  7. Apply bounded weight updates

Args:
  • state: Current MLP learner state
  • observation: Input feature vector
  • target: Desired output

Returns: MLPUpdateResult with new state, prediction, error, and metrics

Source code in src/alberta_framework/core/learners.py
@functools.partial(jax.jit, static_argnums=(0,))
def update(
    self,
    state: MLPLearnerState,
    observation: Observation,
    target: Target,
) -> MLPUpdateResult:
    """Update MLP given observation and target.

    JIT-compiled automatically. Performs one step of the learning
    algorithm:

    1. Optionally normalize observation
    2. Compute prediction and error
    3. Compute gradients via jax.grad on the forward pass
    4. Update eligibility traces
    5. Per-layer optimizer step from traces
    6. Optionally bound steps
    7. Apply bounded weight updates

    Args:
        state: Current MLP learner state
        observation: Input feature vector
        target: Desired output

    Returns:
        MLPUpdateResult with new state, prediction, error, and metrics
    """
    target_scalar = jnp.squeeze(target)

    # Handle normalization
    obs = observation
    new_normalizer_state = state.normalizer_state
    if self._normalizer is not None and state.normalizer_state is not None:
        obs, new_normalizer_state = self._normalizer.normalize(
            state.normalizer_state, observation
        )

    # Forward pass for prediction
    prediction_val = self._forward(
        state.params.weights,
        state.params.biases,
        obs,
        self._leaky_relu_slope,
        self._use_layer_norm,
    )
    prediction = jnp.atleast_1d(prediction_val)
    error = target_scalar - prediction_val

    # Compute gradients w.r.t. prediction
    slope = self._leaky_relu_slope
    ln = self._use_layer_norm

    def pred_fn(weights: tuple[Array, ...], biases: tuple[Array, ...]) -> Array:
        return self._forward(weights, biases, obs, slope, ln)

    weight_grads, bias_grads = jax.grad(pred_fn, argnums=(0, 1))(
        state.params.weights, state.params.biases
    )

    # Update eligibility traces: z = gamma * lamda * z + grad
    gamma_lamda = jnp.array(self._gamma * self._lamda, dtype=jnp.float32)
    n_layers = len(state.params.weights)

    new_traces = []
    for i in range(n_layers):
        # Weight trace (index 2*i)
        new_wt = gamma_lamda * state.traces[2 * i] + weight_grads[i]
        new_traces.append(new_wt)
        # Bias trace (index 2*i + 1)
        new_bt = gamma_lamda * state.traces[2 * i + 1] + bias_grads[i]
        new_traces.append(new_bt)

    # Per-parameter optimizer step from traces
    # Output layer uses head_optimizer if set (last 2 entries: weight + bias)
    n_trace_entries = len(new_traces)
    all_steps = []
    new_opt_states = []
    for j in range(n_trace_entries):
        is_output = self._head_optimizer is not None and j >= n_trace_entries - 2
        opt = self._head_optimizer if is_output else self._optimizer
        step, new_opt = opt.update_from_gradient(
            state.optimizer_states[j], new_traces[j], error=error
        )
        all_steps.append(step)
        new_opt_states.append(new_opt)

    # Bounding (optional)
    bounding_metric = jnp.array(1.0, dtype=jnp.float32)
    if self._bounder is not None:
        all_params = []
        for i in range(n_layers):
            all_params.append(state.params.weights[i])
            all_params.append(state.params.biases[i])
        bounded_steps, bounding_metric = self._bounder.bound(
            tuple(all_steps), error, tuple(all_params)
        )
        all_steps = list(bounded_steps)

    # Apply updates: param += error * step
    new_weights = []
    new_biases = []
    for i in range(n_layers):
        new_w = state.params.weights[i] + error * all_steps[2 * i]
        new_weights.append(new_w)
        new_b = state.params.biases[i] + error * all_steps[2 * i + 1]
        new_biases.append(new_b)

    new_params = MLPParams(
        weights=tuple(new_weights), biases=tuple(new_biases)
    )
    new_state = MLPLearnerState(
        params=new_params,
        optimizer_states=tuple(new_opt_states),
        traces=tuple(new_traces),
        normalizer_state=new_normalizer_state,
        step_count=state.step_count + 1,
        birth_timestamp=state.birth_timestamp,
        uptime_s=state.uptime_s,
    )

    squared_error = error**2

    if self._normalizer is not None and new_normalizer_state is not None:
        normalizer_mean_var = jnp.mean(new_normalizer_state.var)
        metrics = jnp.array(
            [squared_error, error, bounding_metric, normalizer_mean_var],
            dtype=jnp.float32,
        )
    else:
        metrics = jnp.array(
            [squared_error, error, bounding_metric], dtype=jnp.float32
        )

    return MLPUpdateResult(
        state=new_state,
        prediction=prediction,
        error=jnp.atleast_1d(error),
        metrics=metrics,
    )
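The arithmetic of steps 4-7 above can be sketched independently of the framework. This plain-NumPy fragment shows one layer's trace accumulation and parameter update with a fixed LMS-style step-size (alpha = 0.1 is illustrative, not a framework default):

```python
import numpy as np

# Steps 4-7 for a single layer, with a fixed LMS-style step-size
# (alpha = 0.1 is illustrative, not a framework default).
gamma, lamda, alpha = 0.0, 0.0, 0.1
w = np.zeros(3)                        # layer weights
z = np.zeros(3)                        # eligibility trace
grad = np.array([1.0, 0.5, -0.25])     # d(prediction)/d(w), as from jax.grad
error = 2.0                            # target - prediction

z = gamma * lamda * z + grad           # step 4: trace accumulation
step = alpha * z                       # step 5: optimizer step from trace
w = w + error * step                   # step 7: param += error * step
```

With gamma * lamda = 0 the trace resets to the fresh gradient each step, which is the supervised default described above.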

MLPUpdateResult

Result of an MLP learner update step.

Attributes:
    state: Updated MLP learner state
    prediction: Prediction made before update
    error: Prediction error
    metrics: Array of metrics -- shape (3,) without normalizer, (4,) with normalizer

TDLinearLearner(optimizer=None)

Linear function approximator for TD learning.

Computes value predictions as: V(s) = w @ phi(s) + b

The learner maintains weights, bias, and eligibility traces, delegating the adaptation of learning rates to the TD optimizer (e.g., TDIDBD).

This follows the Alberta Plan philosophy of temporal uniformity: every component updates at every time step.

Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"

Attributes: optimizer: The TD optimizer to use for weight updates

Args: optimizer: TD optimizer for weight updates. Defaults to TDIDBD()

Source code in src/alberta_framework/core/learners.py
def __init__(self, optimizer: AnyTDOptimizer | None = None):
    """Initialize the TD linear learner.

    Args:
        optimizer: TD optimizer for weight updates. Defaults to TDIDBD()
    """
    self._optimizer: AnyTDOptimizer = optimizer or TDIDBD()

init(feature_dim)

Initialize TD learner state.

Args: feature_dim: Dimension of the input feature vector

Returns: Initial TD learner state with zero weights and bias

Source code in src/alberta_framework/core/learners.py
def init(self, feature_dim: int) -> TDLearnerState:
    """Initialize TD learner state.

    Args:
        feature_dim: Dimension of the input feature vector

    Returns:
        Initial TD learner state with zero weights and bias
    """
    optimizer_state = self._optimizer.init(feature_dim)

    return TDLearnerState(
        weights=jnp.zeros(feature_dim, dtype=jnp.float32),
        bias=jnp.array(0.0, dtype=jnp.float32),
        optimizer_state=optimizer_state,
        step_count=jnp.array(0, dtype=jnp.int32),
        birth_timestamp=time.time(),
        uptime_s=0.0,
    )

predict(state, observation)

Compute value prediction for an observation.

Args:
    state: Current TD learner state
    observation: Input feature vector phi(s)

Returns: Scalar value prediction V(s) = w @ phi(s) + b

Source code in src/alberta_framework/core/learners.py
def predict(self, state: TDLearnerState, observation: Observation) -> Prediction:
    """Compute value prediction for an observation.

    Args:
        state: Current TD learner state
        observation: Input feature vector phi(s)

    Returns:
        Scalar value prediction ``V(s) = w @ phi(s) + b``
    """
    return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)

update(state, observation, reward, next_observation, gamma)

Update learner given a TD transition.

Performs one step of TD learning:

1. Compute V(s) and V(s')
2. Compute TD error delta = R + gamma*V(s') - V(s)
3. Get weight updates from TD optimizer
4. Apply updates to weights and bias

Args:
    state: Current TD learner state
    observation: Current observation phi(s)
    reward: Reward R received
    next_observation: Next observation phi(s')
    gamma: Discount factor gamma (0 at terminal states)

Returns: TDUpdateResult with new state, prediction, TD error, and metrics

Source code in src/alberta_framework/core/learners.py
def update(
    self,
    state: TDLearnerState,
    observation: Observation,
    reward: Array,
    next_observation: Observation,
    gamma: Array,
) -> TDUpdateResult:
    """Update learner given a TD transition.

    Performs one step of TD learning:
    1. Compute V(s) and V(s')
    2. Compute TD error delta = R + gamma*V(s') - V(s)
    3. Get weight updates from TD optimizer
    4. Apply updates to weights and bias

    Args:
        state: Current TD learner state
        observation: Current observation phi(s)
        reward: Reward R received
        next_observation: Next observation phi(s')
        gamma: Discount factor gamma (0 at terminal states)

    Returns:
        TDUpdateResult with new state, prediction, TD error, and metrics
    """
    # Compute predictions
    prediction = self.predict(state, observation)
    next_prediction = self.predict(state, next_observation)

    # Compute TD error: delta = R + gamma*V(s') - V(s)
    gamma_scalar = jnp.squeeze(gamma)
    td_error = (
        jnp.squeeze(reward)
        + gamma_scalar * jnp.squeeze(next_prediction)
        - jnp.squeeze(prediction)
    )

    # Get update from TD optimizer
    opt_update = self._optimizer.update(
        state.optimizer_state,
        td_error,
        observation,
        next_observation,
        gamma,
    )

    # Apply updates
    new_weights = state.weights + opt_update.weight_delta
    new_bias = state.bias + opt_update.bias_delta

    new_state = TDLearnerState(
        weights=new_weights,
        bias=new_bias,
        optimizer_state=opt_update.new_state,
        step_count=state.step_count + 1,
        birth_timestamp=state.birth_timestamp,
        uptime_s=state.uptime_s,
    )

    # Pack metrics as array for scan compatibility
    squared_td_error = td_error**2
    mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)
    mean_elig_trace = opt_update.metrics.get("mean_eligibility_trace", 0.0)
    metrics = jnp.array(
        [squared_td_error, td_error, mean_step_size, mean_elig_trace],
        dtype=jnp.float32,
    )

    return TDUpdateResult(
        state=new_state,
        prediction=prediction,
        td_error=jnp.atleast_1d(td_error),
        metrics=metrics,
    )
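The TD error in steps 1-2 can be checked by hand. This NumPy sketch reproduces the computation for a two-feature linear value function (all numbers illustrative):

```python
import numpy as np

# TD error for a linear value function, steps 1-2 (numbers illustrative).
w, b = np.array([0.5, -0.2]), 0.0
phi = np.array([1.0, 0.0])        # phi(s)
phi_next = np.array([0.0, 1.0])   # phi(s')
reward, gamma = 1.0, 0.9

v = w @ phi + b                   # V(s)  = 0.5
v_next = w @ phi_next + b         # V(s') = -0.2
td_error = reward + gamma * v_next - v
```

Passing gamma = 0 at terminal states makes the bootstrap term vanish, so the update targets the reward alone.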

TDUpdateResult

Result of a TD learner update step.

Attributes:
    state: Updated TD learner state
    prediction: Value prediction V(s) before update
    td_error: TD error delta = R + gamma*V(s') - V(s)
    metrics: Array of metrics [squared_td_error, td_error, mean_step_size, mean_eligibility_trace]

UpdateResult

Result of a learner update step.

Attributes:
    state: Updated learner state
    prediction: Prediction made before update
    error: Prediction error
    metrics: Array of metrics -- shape (3,) without normalizer, (4,) with normalizer

BatchedMultiHeadResult

Result from batched multi-head learning loop.

Attributes:
    states: Batched multi-head MLP learner states
    per_head_metrics: Per-head metrics, shape (n_seeds, num_steps, n_heads, 3)

MultiHeadLearningResult

Result from multi-head learning loop.

Attributes:
    state: Final multi-head MLP learner state
    per_head_metrics: Per-head metrics over time, shape (num_steps, n_heads, 3)

MultiHeadMLPLearner(n_heads, hidden_sizes=(128, 128), optimizer=None, step_size=1.0, bounder=None, gamma=0.0, lamda=0.0, normalizer=None, sparsity=0.9, leaky_relu_slope=0.01, use_layer_norm=True, head_optimizer=None, per_head_gamma_lamda=None)

Multi-head MLP with shared trunk and independent prediction heads.

Architecture: Input -> [Dense(H) -> LayerNorm -> LeakyReLU] x N -> {Head_i: Dense(1)} x n_heads

All hidden layers are shared (the trunk). Each head is an independent linear projection from the last hidden representation to a scalar.

The update method uses VJP with accumulated cotangents so that only one backward pass through the trunk is needed regardless of the number of active heads.

Trunk trace constraint: When hidden_sizes is non-empty (MLP mode), trunk gamma * lamda must be 0. The VJP backward pass folds per-head errors into the trunk cotangent before trace accumulation, so traces accumulate error-weighted gradients. For gamma * lamda = 0 this is correct (traces reset each step); for gamma * lamda > 0 it would produce biased trace updates that violate forward-view equivalence (Sutton & Barto Ch. 12).

Use HordeLearner for per-head trace decay: it sets trunk gamma=0, lamda=0 and applies per-head gamma * lambda only to the head layers. For linear baselines (hidden_sizes=()), there is no trunk, so any gamma * lamda is fine.

Attributes:
    n_heads: Number of prediction heads
    hidden_sizes: Tuple of hidden layer sizes. Pass () for a multi-head linear model (heads project directly from input features).
    optimizer: Optimizer for per-weight step-size adaptation
    bounder: Optional update bounder (e.g. ObGDBounding)
    normalizer: Optional feature normalizer
    use_layer_norm: Whether to apply parameterless layer normalization
    gamma: Discount factor for trace decay
    lamda: Eligibility trace decay parameter
    sparsity: Fraction of weights zeroed out per output neuron
    leaky_relu_slope: Negative slope for LeakyReLU activation

Single-Step (Daemon) Usage

Both predict() and update() work with single unbatched observations (1D arrays). This is the intended usage for daemon-style deployments where one observation arrives at a time.

Both methods are JIT-compiled automatically. The first call triggers JAX's tracing; subsequent calls use the cached compilation. For low-latency startup, run a warmup call so the first real event is fast:

import jax.numpy as jnp

# At daemon startup, after learner.init():
dummy_obs = jnp.zeros(feature_dim)
dummy_targets = jnp.full(n_heads, jnp.nan)
learner.predict(state, dummy_obs).block_until_ready()  # Warmup trace
learner.update(state, dummy_obs, dummy_targets)        # Warmup trace
# First real event will now be fast

NaN target masking works per-step: pass jnp.nan for any head that should not update. Inactive heads preserve their params, traces, and optimizer states.
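The masking semantics can be sketched without the framework. This NumPy fragment mirrors how the active mask and safe targets are computed inside update() (values illustrative):

```python
import numpy as np

# Mirror of the per-head NaN masking inside update() (values illustrative).
targets = np.array([1.0, np.nan, 3.0])   # head 1 is inactive this step
preds = np.array([0.5, 0.5, 0.5])

active_mask = ~np.isnan(targets)
safe_targets = np.where(active_mask, targets, 0.0)
masked_errors = np.where(active_mask, safe_targets - preds, 0.0)
# Inactive heads contribute zero error, so their params, traces, and
# optimizer states are left unchanged by the update.
```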

Args:
    n_heads: Number of prediction heads
    hidden_sizes: Tuple of hidden layer sizes (default: two layers of 128)
    optimizer: Optimizer for weight updates. Defaults to LMS(step_size). Must support init_for_shape and update_from_gradient.
    step_size: Base learning rate (used only when optimizer is None)
    bounder: Optional update bounder (e.g. ObGDBounding)
    gamma: Discount factor for trace decay (default: 0.0 for supervised)
    lamda: Eligibility trace decay parameter (default: 0.0)
    normalizer: Optional feature normalizer
    sparsity: Fraction of weights zeroed out per neuron (default: 0.9)
    leaky_relu_slope: Negative slope for LeakyReLU (default: 0.01)
    use_layer_norm: Whether to apply parameterless layer normalization (default: True)
    head_optimizer: Optional separate optimizer for the output heads. When None (default), all layers use optimizer. When set, trunk (hidden) layers use optimizer while each head uses head_optimizer. This enables hybrid configurations like stable LMS for the trunk with adaptive Autostep for the heads.
    per_head_gamma_lamda: Optional per-head trace decay factors. When set, each head uses its own gamma * lambda product for trace decay instead of the global gamma * lamda. Length must equal n_heads. Used by HordeLearner to assign per-demon discount/trace parameters.

Source code in src/alberta_framework/core/multi_head_learner.py
def __init__(
    self,
    n_heads: int,
    hidden_sizes: tuple[int, ...] = (128, 128),
    optimizer: AnyOptimizer | None = None,
    step_size: float = 1.0,
    bounder: Bounder | None = None,
    gamma: float = 0.0,
    lamda: float = 0.0,
    normalizer: (
        Normalizer[EMANormalizerState] | Normalizer[WelfordNormalizerState] | None
    ) = None,
    sparsity: float = 0.9,
    leaky_relu_slope: float = 0.01,
    use_layer_norm: bool = True,
    head_optimizer: AnyOptimizer | None = None,
    per_head_gamma_lamda: tuple[float, ...] | None = None,
):
    """Initialize the multi-head MLP learner.

    Args:
        n_heads: Number of prediction heads
        hidden_sizes: Tuple of hidden layer sizes (default: two layers of 128)
        optimizer: Optimizer for weight updates. Defaults to LMS(step_size).
            Must support ``init_for_shape`` and ``update_from_gradient``.
        step_size: Base learning rate (used only when optimizer is None)
        bounder: Optional update bounder (e.g. ObGDBounding)
        gamma: Discount factor for trace decay (default: 0.0 for supervised)
        lamda: Eligibility trace decay parameter (default: 0.0)
        normalizer: Optional feature normalizer
        sparsity: Fraction of weights zeroed out per neuron (default: 0.9)
        leaky_relu_slope: Negative slope for LeakyReLU (default: 0.01)
        use_layer_norm: Whether to apply parameterless layer normalization
            (default: True)
        head_optimizer: Optional separate optimizer for the output heads.
            When None (default), all layers use ``optimizer``. When set,
            trunk (hidden) layers use ``optimizer`` while each head uses
            ``head_optimizer``. This enables hybrid configurations like
            stable LMS for the trunk with adaptive Autostep for the heads.
        per_head_gamma_lamda: Optional per-head trace decay factors.
            When set, each head uses its own ``gamma * lambda`` product
            for trace decay instead of the global ``gamma * lamda``.
            Length must equal ``n_heads``. Used by ``HordeLearner``
            to assign per-demon discount/trace parameters.
    """
    self._n_heads = n_heads
    self._hidden_sizes = hidden_sizes
    self._optimizer: AnyOptimizer = optimizer or LMS(step_size=step_size)
    self._head_optimizer: AnyOptimizer | None = head_optimizer
    self._bounder = bounder
    self._gamma = gamma
    self._lamda = lamda
    self._normalizer = normalizer
    self._sparsity = sparsity
    self._leaky_relu_slope = leaky_relu_slope
    self._use_layer_norm = use_layer_norm
    self._per_head_gl: tuple[float, ...] | None = per_head_gamma_lamda

    # Validate trunk trace constraint: gamma*lamda > 0 is only safe
    # when there is no trunk (linear baseline). With a trunk, the VJP
    # cotangent folds error into gradients before trace accumulation,
    # producing biased traces when gamma*lamda > 0.
    if gamma * lamda > 0 and len(hidden_sizes) > 0:
        msg = (
            f"Trunk gamma*lamda must be 0 when hidden_sizes is non-empty "
            f"(got gamma={gamma}, lamda={lamda}, hidden_sizes={hidden_sizes}). "
            f"The VJP backward pass bakes error into trunk gradients before "
            f"trace accumulation, which is only correct when traces reset "
            f"each step (gamma*lamda=0). Use HordeLearner for per-head "
            f"trace decay with a shared trunk."
        )
        raise ValueError(msg)
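The validation rule reduces to a small predicate. This sketch (trunk_ok is a hypothetical helper, not part of the API) enumerates which configurations pass:

```python
# Hypothetical predicate expressing the constraint checked in __init__:
# a trunk (non-empty hidden_sizes) with gamma*lamda > 0 is rejected.
def trunk_ok(gamma: float, lamda: float, hidden_sizes: tuple) -> bool:
    return not (gamma * lamda > 0 and len(hidden_sizes) > 0)

assert trunk_ok(0.9, 0.8, ())          # linear baseline: no trunk, decay is fine
assert trunk_ok(0.0, 0.8, (128,))      # gamma*lamda = 0: traces reset each step
assert not trunk_ok(0.9, 0.8, (128,))  # trunk with decay: ValueError in practice
```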

n_heads property

Number of prediction heads.

normalizer property

The feature normalizer, or None if normalization is disabled.

to_config()

Serialize learner configuration to dict.

Returns: Dict with all constructor arguments needed to recreate the learner via from_config().

Source code in src/alberta_framework/core/multi_head_learner.py
def to_config(self) -> dict[str, Any]:
    """Serialize learner configuration to dict.

    Returns:
        Dict with all constructor arguments needed to recreate
        the learner via ``from_config()``.
    """
    config: dict[str, Any] = {
        "type": "MultiHeadMLPLearner",
        "n_heads": self._n_heads,
        "hidden_sizes": list(self._hidden_sizes),
        "optimizer": self._optimizer.to_config(),
        "bounder": self._bounder.to_config() if self._bounder is not None else None,
        "normalizer": (
            self._normalizer.to_config() if self._normalizer is not None else None
        ),
        "head_optimizer": (
            self._head_optimizer.to_config()
            if self._head_optimizer is not None
            else None
        ),
        "sparsity": self._sparsity,
        "leaky_relu_slope": self._leaky_relu_slope,
        "use_layer_norm": self._use_layer_norm,
        "gamma": self._gamma,
        "lamda": self._lamda,
        "per_head_gamma_lamda": (
            list(self._per_head_gl) if self._per_head_gl is not None else None
        ),
    }
    return config

from_config(config) classmethod

Reconstruct learner from a config dict.

Args: config: Dict as produced by to_config()

Returns: Reconstructed MultiHeadMLPLearner instance

Source code in src/alberta_framework/core/multi_head_learner.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "MultiHeadMLPLearner":
    """Reconstruct learner from a config dict.

    Args:
        config: Dict as produced by ``to_config()``

    Returns:
        Reconstructed MultiHeadMLPLearner instance
    """
    from alberta_framework.core.normalizers import normalizer_from_config
    from alberta_framework.core.optimizers import (
        bounder_from_config,
        optimizer_from_config,
    )

    config = dict(config)
    config.pop("type", None)

    optimizer = optimizer_from_config(config.pop("optimizer"))
    bounder_cfg = config.pop("bounder", None)
    bounder = bounder_from_config(bounder_cfg) if bounder_cfg is not None else None
    normalizer_cfg = config.pop("normalizer", None)
    normalizer = (
        normalizer_from_config(normalizer_cfg) if normalizer_cfg is not None else None
    )
    head_opt_cfg = config.pop("head_optimizer", None)
    head_optimizer = (
        optimizer_from_config(head_opt_cfg) if head_opt_cfg is not None else None
    )

    per_head_gl = config.pop("per_head_gamma_lamda", None)
    if per_head_gl is not None:
        per_head_gl = tuple(per_head_gl)

    return cls(
        n_heads=config.pop("n_heads"),
        hidden_sizes=tuple(config.pop("hidden_sizes")),
        optimizer=optimizer,
        bounder=bounder,
        normalizer=normalizer,
        head_optimizer=head_optimizer,
        per_head_gamma_lamda=per_head_gl,
        **config,
    )
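The pop-based reconstruction pattern can be illustrated with a plain dict. This sketch omits the optimizer, bounder, and normalizer sub-configs, which the real method pops and rebuilds first:

```python
# Sketch of the pop-based reconstruction used by from_config().
# Sub-configs (optimizer, bounder, normalizer) are omitted here.
config = {
    "type": "MultiHeadMLPLearner",
    "n_heads": 4,
    "hidden_sizes": [128, 128],
    "sparsity": 0.9,
}
config = dict(config)          # copy so the caller's dict is untouched
config.pop("type", None)       # dispatch tag, not a constructor arg
n_heads = config.pop("n_heads")
hidden_sizes = tuple(config.pop("hidden_sizes"))
# Remaining keys (here just sparsity) pass through as **config kwargs.
```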

init(feature_dim, key)

Initialize multi-head MLP learner state with sparse weights.

Args:
    feature_dim: Dimension of the input feature vector
    key: JAX random key for weight initialization

Returns: Initial state with sparse trunk weights, zero biases, and per-head output layers

Source code in src/alberta_framework/core/multi_head_learner.py
def init(self, feature_dim: int, key: Array) -> MultiHeadMLPState:
    """Initialize multi-head MLP learner state with sparse weights.

    Args:
        feature_dim: Dimension of the input feature vector
        key: JAX random key for weight initialization

    Returns:
        Initial state with sparse trunk weights, zero biases, and
        per-head output layers
    """
    # Trunk: [feature_dim, *hidden_sizes] — all hidden layers
    trunk_layer_sizes = [feature_dim, *self._hidden_sizes]

    trunk_weights: list[Array] = []
    trunk_biases: list[Array] = []
    trunk_traces: list[Array] = []
    trunk_opt_states: list[LMSState | AutostepParamState] = []

    for i in range(len(trunk_layer_sizes) - 1):
        fan_out = trunk_layer_sizes[i + 1]
        fan_in = trunk_layer_sizes[i]
        key, subkey = jax.random.split(key)
        w = sparse_init(subkey, (fan_out, fan_in), sparsity=self._sparsity)
        b = jnp.zeros(fan_out, dtype=jnp.float32)
        trunk_weights.append(w)
        trunk_biases.append(b)
        # Interleaved traces and optimizer states: w0, b0, w1, b1, ...
        trunk_traces.append(jnp.zeros_like(w))
        trunk_traces.append(jnp.zeros_like(b))
        trunk_opt_states.append(self._optimizer.init_for_shape(w.shape))
        trunk_opt_states.append(self._optimizer.init_for_shape(b.shape))

    trunk_params = MLPParams(
        weights=tuple(trunk_weights),
        biases=tuple(trunk_biases),
    )

    # Heads: n_heads output layers, each (1, h_last)
    # h_last = last hidden dim, or feature_dim when no trunk layers
    h_last = self._hidden_sizes[-1] if self._hidden_sizes else feature_dim
    head_weights: list[Array] = []
    head_biases: list[Array] = []
    head_traces_list: list[tuple[Array, Array]] = []
    head_opt_states_list: list[tuple[Any, ...]] = []

    head_opt = self._head_optimizer if self._head_optimizer is not None else self._optimizer
    for _ in range(self._n_heads):
        key, subkey = jax.random.split(key)
        w = sparse_init(subkey, (1, h_last), sparsity=self._sparsity)
        b = jnp.zeros(1, dtype=jnp.float32)
        head_weights.append(w)
        head_biases.append(b)
        head_traces_list.append((jnp.zeros_like(w), jnp.zeros_like(b)))
        head_opt_states_list.append((
            head_opt.init_for_shape(w.shape),
            head_opt.init_for_shape(b.shape),
        ))

    head_params = MLPParams(
        weights=tuple(head_weights),
        biases=tuple(head_biases),
    )

    normalizer_state = None
    if self._normalizer is not None:
        normalizer_state = self._normalizer.init(feature_dim)

    return MultiHeadMLPState(
        trunk_params=trunk_params,
        head_params=head_params,
        trunk_optimizer_states=tuple(trunk_opt_states),
        head_optimizer_states=tuple(head_opt_states_list),
        trunk_traces=tuple(trunk_traces),
        head_traces=tuple(head_traces_list),
        normalizer_state=normalizer_state,
        step_count=jnp.array(0, dtype=jnp.int32),
        birth_timestamp=time.time(),
        uptime_s=0.0,
    )

predict(state, observation)

Compute predictions from all heads.

JIT-compiled automatically. First call triggers tracing; subsequent calls with the same learner instance use the cached compilation.

Args:
    state: Current multi-head MLP learner state
    observation: Input feature vector

Returns: Array of shape (n_heads,) with one prediction per head

Source code in src/alberta_framework/core/multi_head_learner.py
@functools.partial(jax.jit, static_argnums=(0,))
def predict(self, state: MultiHeadMLPState, observation: Array) -> Array:
    """Compute predictions from all heads.

    JIT-compiled automatically. First call triggers tracing; subsequent
    calls with the same learner instance use the cached compilation.

    Args:
        state: Current multi-head MLP learner state
        observation: Input feature vector

    Returns:
        Array of shape ``(n_heads,)`` with one prediction per head
    """
    obs = observation
    if self._normalizer is not None and state.normalizer_state is not None:
        obs = self._normalizer.normalize_only(state.normalizer_state, observation)

    hidden = self._trunk_forward(
        state.trunk_params.weights,
        state.trunk_params.biases,
        obs,
        self._leaky_relu_slope,
        self._use_layer_norm,
    )

    predictions = []
    for i in range(self._n_heads):
        pred = self._head_forward(
            state.head_params.weights[i],
            state.head_params.biases[i],
            hidden,
        )
        predictions.append(pred)

    return jnp.array(predictions)
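Per-head prediction is a set of independent linear readouts from the shared hidden vector. This NumPy sketch mirrors the loop above for two hypothetical heads (all numbers illustrative):

```python
import numpy as np

# Two hypothetical heads reading out from a shared hidden vector.
hidden = np.array([0.2, -0.1, 0.4])
head_weights = [np.array([[1.0, 0.0, 0.0]]), np.array([[0.0, 2.0, 1.0]])]
head_biases = [np.array([0.1]), np.array([0.0])]

predictions = np.array(
    [float(w @ hidden + b) for w, b in zip(head_weights, head_biases)]
)
# predictions has shape (n_heads,), one scalar per head
```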

update(state, observation, targets)

Update multi-head MLP given observation and per-head targets.

JIT-compiled automatically. Uses VJP with accumulated cotangents for a single backward pass through the trunk. Error from each active head is folded into the trunk gradient before trace accumulation.

Args:
    state: Current state
    observation: Input feature vector
    targets: Per-head targets, shape (n_heads,). NaN = inactive head.

Returns: MultiHeadMLPUpdateResult with updated state, predictions, errors, and per-head metrics

Source code in src/alberta_framework/core/multi_head_learner.py
@functools.partial(jax.jit, static_argnums=(0,))
def update(
    self,
    state: MultiHeadMLPState,
    observation: Array,
    targets: Array,
) -> MultiHeadMLPUpdateResult:
    """Update multi-head MLP given observation and per-head targets.

    JIT-compiled automatically. Uses VJP with accumulated cotangents
    for a single backward pass through the trunk. Error from each
    active head is folded into the trunk gradient before trace
    accumulation.

    Args:
        state: Current state
        observation: Input feature vector
        targets: Per-head targets, shape ``(n_heads,)``.
            NaN = inactive head.

    Returns:
        MultiHeadMLPUpdateResult with updated state, predictions,
        errors, and per-head metrics
    """
    n_heads = self._n_heads
    gamma_lamda = jnp.array(self._gamma * self._lamda, dtype=jnp.float32)

    # 1. Handle NaN targets
    active_mask = ~jnp.isnan(targets)  # (n_heads,)
    safe_targets = jnp.where(active_mask, targets, 0.0)

    # 2. Normalize observation if needed
    obs = observation
    new_normalizer_state = state.normalizer_state
    if self._normalizer is not None and state.normalizer_state is not None:
        obs, new_normalizer_state = self._normalizer.normalize(
            state.normalizer_state, observation
        )

    # 3. Forward trunk via VJP
    slope = self._leaky_relu_slope
    ln = self._use_layer_norm

    def trunk_fn(
        weights: tuple[Array, ...], biases: tuple[Array, ...]
    ) -> Array:
        return self._trunk_forward(weights, biases, obs, slope, ln)

    hidden, trunk_vjp_fn = jax.vjp(
        trunk_fn,
        state.trunk_params.weights,
        state.trunk_params.biases,
    )

    # 4. Per-head forward + compute errors + accumulate cotangent
    h_last = hidden.shape[0]
    cotangent = jnp.zeros(h_last, dtype=jnp.float32)
    predictions_list: list[Array] = []
    errors_list: list[Array] = []

    for i in range(n_heads):
        pred_i = self._head_forward(
            state.head_params.weights[i],
            state.head_params.biases[i],
            hidden,
        )
        error_i = safe_targets[i] - pred_i
        masked_error_i = jnp.where(active_mask[i], error_i, 0.0)

        predictions_list.append(pred_i)
        errors_list.append(jnp.where(active_mask[i], error_i, jnp.nan))

        # Accumulate cotangent: error_i * d(pred_i)/d(hidden)
        # d(pred_i)/d(hidden) = head_w_i squeezed to (H_last,)
        # NOTE: Error is folded into the cotangent here, so trunk VJP
        # gradients are error-weighted. This is safe because trunk
        # gamma*lamda=0 (validated in __init__), so traces reset each
        # step and the error-gradient coupling doesn't accumulate.
        cotangent = cotangent + masked_error_i * jnp.squeeze(
            state.head_params.weights[i]
        )

    predictions_arr = jnp.array(predictions_list)
    errors_arr = jnp.array(errors_list)

    # 5. One backward pass through trunk
    trunk_weight_grads, trunk_bias_grads = trunk_vjp_fn(cotangent)
    # These grads are already error-weighted

    # 6. Update trunk traces and optimizer
    n_trunk_layers = len(state.trunk_params.weights)
    new_trunk_traces: list[Array] = []
    trunk_steps: list[Array] = []
    new_trunk_opt_states: list[LMSState | AutostepParamState] = []

    for i in range(n_trunk_layers):
        # Weight trace (index 2*i)
        new_wt = gamma_lamda * state.trunk_traces[2 * i] + trunk_weight_grads[i]
        new_trunk_traces.append(new_wt)
        w_step, new_w_opt = self._optimizer.update_from_gradient(
            state.trunk_optimizer_states[2 * i], new_wt, error=None
        )
        trunk_steps.append(w_step)
        new_trunk_opt_states.append(new_w_opt)

        # Bias trace (index 2*i + 1)
        new_bt = gamma_lamda * state.trunk_traces[2 * i + 1] + trunk_bias_grads[i]
        new_trunk_traces.append(new_bt)
        b_step, new_b_opt = self._optimizer.update_from_gradient(
            state.trunk_optimizer_states[2 * i + 1], new_bt, error=None
        )
        trunk_steps.append(b_step)
        new_trunk_opt_states.append(new_b_opt)

    # Trunk bounding (pseudo_error=1.0 since error is in gradient)
    trunk_bounding_metric = jnp.array(1.0, dtype=jnp.float32)
    if self._bounder is not None:
        trunk_params_flat: list[Array] = []
        for i in range(n_trunk_layers):
            trunk_params_flat.append(state.trunk_params.weights[i])
            trunk_params_flat.append(state.trunk_params.biases[i])
        bounded_trunk_steps, trunk_bounding_metric = self._bounder.bound(
            tuple(trunk_steps), jnp.array(1.0), tuple(trunk_params_flat)
        )
        trunk_steps = list(bounded_trunk_steps)

    # Apply trunk updates (no error multiply -- error already in gradient)
    new_trunk_weights: list[Array] = []
    new_trunk_biases: list[Array] = []
    for i in range(n_trunk_layers):
        new_trunk_weights.append(
            state.trunk_params.weights[i] + trunk_steps[2 * i]
        )
        new_trunk_biases.append(
            state.trunk_params.biases[i] + trunk_steps[2 * i + 1]
        )

    new_trunk_params = MLPParams(
        weights=tuple(new_trunk_weights),
        biases=tuple(new_trunk_biases),
    )

    # 7. Per-head updates
    new_head_weights: list[Array] = []
    new_head_biases: list[Array] = []
    new_head_traces_list: list[tuple[Array, Array]] = []
    new_head_opt_states_list: list[tuple[Any, ...]] = []
    per_head_metrics_list: list[Array] = []

    for i in range(n_heads):
        head_w = state.head_params.weights[i]
        head_b = state.head_params.biases[i]
        old_w_trace, old_b_trace = state.head_traces[i]
        old_w_opt, old_b_opt = state.head_optimizer_states[i]

        # Head prediction gradient: d(pred_i)/d(head_w) = hidden
        w_grad = hidden.reshape(1, -1)  # (1, H_last)
        b_grad = jnp.ones(1, dtype=jnp.float32)

        # Update traces (per-head decay if configured)
        head_gl = (
            jnp.array(self._per_head_gl[i], dtype=jnp.float32)
            if self._per_head_gl is not None
            else gamma_lamda
        )
        new_w_trace = head_gl * old_w_trace + w_grad
        new_b_trace = head_gl * old_b_trace + b_grad

        # Error for this head (masked to 0 for inactive)
        error_i = jnp.where(
            active_mask[i], safe_targets[i] - predictions_list[i], 0.0
        )

        # Optimizer step (with error for meta-learning)
        head_opt = self._head_optimizer if self._head_optimizer is not None else self._optimizer
        w_step, new_w_opt = head_opt.update_from_gradient(
            old_w_opt, new_w_trace, error=error_i
        )
        b_step, new_b_opt = head_opt.update_from_gradient(
            old_b_opt, new_b_trace, error=error_i
        )

        # Head bounding
        if self._bounder is not None:
            bounded_head_steps, _ = self._bounder.bound(
                (w_step, b_step), error_i, (head_w, head_b)
            )
            w_step, b_step = bounded_head_steps

        # Apply: param += error_i * step
        new_w = head_w + error_i * w_step
        new_b = head_b + error_i * b_step

        # Mask: for inactive heads, keep old state
        new_w = jnp.where(active_mask[i], new_w, head_w)
        new_b = jnp.where(active_mask[i], new_b, head_b)
        new_w_trace = jnp.where(active_mask[i], new_w_trace, old_w_trace)
        new_b_trace = jnp.where(active_mask[i], new_b_trace, old_b_trace)

        # Mask optimizer states back to old for inactive heads
        new_w_opt = jax.tree.map(
            lambda new, old: jnp.where(active_mask[i], new, old),
            new_w_opt,
            old_w_opt,
        )
        new_b_opt = jax.tree.map(
            lambda new, old: jnp.where(active_mask[i], new, old),
            new_b_opt,
            old_b_opt,
        )

        new_head_weights.append(new_w)
        new_head_biases.append(new_b)
        new_head_traces_list.append((new_w_trace, new_b_trace))
        new_head_opt_states_list.append((new_w_opt, new_b_opt))

        # Per-head metrics
        se_i = jnp.where(active_mask[i], error_i**2, jnp.nan)
        raw_error_i = jnp.where(active_mask[i], error_i, jnp.nan)
        mean_ss_i = _extract_mean_step_size(new_w_opt)
        mean_ss_i = jnp.where(active_mask[i], mean_ss_i, jnp.nan)
        per_head_metrics_list.append(
            jnp.array([se_i, raw_error_i, mean_ss_i])
        )

    new_head_params = MLPParams(
        weights=tuple(new_head_weights),
        biases=tuple(new_head_biases),
    )

    new_state = MultiHeadMLPState(
        trunk_params=new_trunk_params,
        head_params=new_head_params,
        trunk_optimizer_states=tuple(new_trunk_opt_states),
        head_optimizer_states=tuple(new_head_opt_states_list),
        trunk_traces=tuple(new_trunk_traces),
        head_traces=tuple(new_head_traces_list),
        normalizer_state=new_normalizer_state,
        step_count=state.step_count + 1,
        birth_timestamp=state.birth_timestamp,
        uptime_s=state.uptime_s,
    )

    per_head_metrics = jnp.stack(per_head_metrics_list)  # (n_heads, 3)

    return MultiHeadMLPUpdateResult(
        state=new_state,
        predictions=predictions_arr,
        errors=errors_arr,
        per_head_metrics=per_head_metrics,
        trunk_bounding_metric=trunk_bounding_metric,
    )

MultiHeadMLPState

State for a multi-head MLP learner.

The trunk (shared hidden layers) and heads (per-task output layers) maintain separate parameters, optimizer states, and eligibility traces.

Trunk optimizer states and traces use an interleaved layout (w0, b0, w1, b1, ...) matching the MLPLearner convention. Head optimizer states and traces use a nested layout ((w_opt, b_opt), ...) indexed by head.

Attributes:
    trunk_params: Shared hidden layer parameters
    head_params: Per-head output layer parameters. weights[i] / biases[i] = head i.
    trunk_optimizer_states: Interleaved (w0, b0, w1, b1, ...) optimizer states for trunk layers
    head_optimizer_states: Per-head ((w_opt, b_opt), ...) optimizer states
    trunk_traces: Interleaved (w0, b0, w1, b1, ...) eligibility traces for trunk layers
    head_traces: Per-head ((w_trace, b_trace), ...) eligibility traces
    normalizer_state: Optional online feature normalizer state
    step_count: Scalar step counter
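The two layouts can be sketched with placeholder values (the strings below are illustrative stand-ins, not real trace arrays):

```python
# Trunk layout: flat interleaved tuple -- layer i's weight entry sits at
# index 2*i and its bias entry at 2*i + 1 (the MLPLearner convention).
trunk_traces = ("w0", "b0", "w1", "b1")
layer = 1
w_trace = trunk_traces[2 * layer]      # "w1"
b_trace = trunk_traces[2 * layer + 1]  # "b1"

# Head layout: nested tuple indexed by head.
head_traces = (("h0_w", "h0_b"), ("h1_w", "h1_b"))
head_w_trace, head_b_trace = head_traces[0]
```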

MultiHeadMLPUpdateResult

Result of a multi-head MLP learner update step.

Attributes:
    state: Updated multi-head MLP learner state
    predictions: Predictions from all heads, shape (n_heads,)
    errors: Prediction errors, shape (n_heads,). NaN for inactive heads.
    per_head_metrics: Per-head metrics, shape (n_heads, 3). Columns: [squared_error, raw_error, mean_step_size]. NaN for inactive heads.
    trunk_bounding_metric: Scalar trunk bounding metric
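Because inactive heads are reported as NaN, aggregate per-head metrics with NaN-aware reductions. A small sketch with made-up values:

```python
import jax.numpy as jnp

# Dummy per_head_metrics for 3 heads; head 2 was inactive this step.
# Columns: [squared_error, raw_error, mean_step_size]
per_head_metrics = jnp.array([
    [0.25, -0.5, 0.01],
    [0.04,  0.2, 0.02],
    [jnp.nan, jnp.nan, jnp.nan],
])
squared_errors = per_head_metrics[:, 0]
# nanmean skips the inactive head: (0.25 + 0.04) / 2
mean_se_active = jnp.nanmean(squared_errors)
```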

EMANormalizer(epsilon=1e-08, decay=0.99)

Bases: Normalizer[EMANormalizerState]

Online feature normalizer using exponential moving average.

Estimates mean and variance via EMA, suitable for non-stationary environments where recent observations should be weighted more heavily.

The effective decay ramps up from 0 to the target decay over early steps to prevent instability.

Attributes:
    epsilon: Small constant for numerical stability
    decay: Exponential decay for running estimates (0.99 = slower adaptation)

Args:
    epsilon: Small constant added to std for numerical stability
    decay: Exponential decay factor for running estimates. Lower values adapt faster to changes. 1.0 means pure online average (no decay).
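The update rule, including the decay ramp, can be sketched in isolation (a simplified standalone version of one normalize step, not the library implementation itself):

```python
import jax.numpy as jnp

def ema_step(mean, var, count, x, decay=0.99, epsilon=1e-8):
    # One EMA normalizer update; the effective decay ramps from 0 toward
    # `decay`, so early observations are weighted heavily.
    count = count + 1.0
    eff = jnp.minimum(decay, 1.0 - 1.0 / (count + 1.0))
    delta = x - mean
    mean = mean + (1.0 - eff) * delta
    var = jnp.maximum(eff * var + (1.0 - eff) * delta * (x - mean), epsilon)
    normalized = (x - mean) / (jnp.sqrt(var) + epsilon)
    return normalized, mean, var, count
```

Feeding a constant input drives the mean estimate to that constant: the ramp closes most of the gap in the first ~1/(1-decay) steps, after which adaptation proceeds at the target decay rate.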

Source code in src/alberta_framework/core/normalizers.py
def __init__(
    self,
    epsilon: float = 1e-8,
    decay: float = 0.99,
):
    """Initialize the EMA normalizer.

    Args:
        epsilon: Small constant added to std for numerical stability
        decay: Exponential decay factor for running estimates.
               Lower values adapt faster to changes.
               1.0 means pure online average (no decay).
    """
    super().__init__(epsilon=epsilon)
    self._decay = decay

normalize_only(state, observation)

Normalize observation without updating statistics.

Useful for inference or when you want to normalize multiple observations with the same statistics.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns: Normalized observation

Source code in src/alberta_framework/core/normalizers.py
def normalize_only(
    self,
    state: StateT,
    observation: Array,
) -> Array:
    """Normalize observation without updating statistics.

    Useful for inference or when you want to normalize multiple
    observations with the same statistics.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Normalized observation
    """
    std = jnp.sqrt(state.var)
    return (observation - state.mean) / (std + self._epsilon)

update_only(state, observation)

Update statistics without returning normalized observation.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns: Updated normalizer state

Source code in src/alberta_framework/core/normalizers.py
def update_only(
    self,
    state: StateT,
    observation: Array,
) -> StateT:
    """Update statistics without returning normalized observation.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Updated normalizer state
    """
    _, new_state = self.normalize(state, observation)
    return new_state

to_config()

Serialize configuration to dict.

Source code in src/alberta_framework/core/normalizers.py
def to_config(self) -> dict[str, Any]:
    """Serialize configuration to dict."""
    return {"type": "EMANormalizer", "epsilon": self._epsilon, "decay": self._decay}

init(feature_dim)

Initialize EMA normalizer state.

Args: feature_dim: Dimension of feature vectors

Returns: Initial normalizer state with zero mean and unit variance

Source code in src/alberta_framework/core/normalizers.py
def init(self, feature_dim: int) -> EMANormalizerState:
    """Initialize EMA normalizer state.

    Args:
        feature_dim: Dimension of feature vectors

    Returns:
        Initial normalizer state with zero mean and unit variance
    """
    return EMANormalizerState(
        mean=jnp.zeros(feature_dim, dtype=jnp.float32),
        var=jnp.ones(feature_dim, dtype=jnp.float32),
        sample_count=jnp.array(0.0, dtype=jnp.float32),
        decay=jnp.array(self._decay, dtype=jnp.float32),
    )

normalize(state, observation)

Normalize observation and update EMA running statistics.

Args:
    state: Current EMA normalizer state
    observation: Raw feature vector

Returns: Tuple of (normalized_observation, new_state)

Source code in src/alberta_framework/core/normalizers.py
def normalize(
    self,
    state: EMANormalizerState,
    observation: Array,
) -> tuple[Array, EMANormalizerState]:
    """Normalize observation and update EMA running statistics.

    Args:
        state: Current EMA normalizer state
        observation: Raw feature vector

    Returns:
        Tuple of (normalized_observation, new_state)
    """
    # Update count
    new_count = state.sample_count + 1.0

    # Compute effective decay (ramp up from 0 to target decay)
    # This prevents instability in early steps
    effective_decay = jnp.minimum(state.decay, 1.0 - 1.0 / (new_count + 1.0))

    # Update mean using exponential moving average
    delta = observation - state.mean
    new_mean = state.mean + (1.0 - effective_decay) * delta

    # Update variance using exponential moving average of squared deviations
    delta2 = observation - new_mean
    new_var = effective_decay * state.var + (1.0 - effective_decay) * delta * delta2

    # Ensure variance is positive
    new_var = jnp.maximum(new_var, self._epsilon)

    # Normalize using updated statistics
    std = jnp.sqrt(new_var)
    normalized = (observation - new_mean) / (std + self._epsilon)

    new_state = EMANormalizerState(
        mean=new_mean,
        var=new_var,
        sample_count=new_count,
        decay=state.decay,
    )

    return normalized, new_state

EMANormalizerState

State for EMA-based online feature normalization.

Uses exponential moving average to estimate running mean and variance, suitable for non-stationary distributions.

Attributes:
    mean: Running mean estimate per feature
    var: Running variance estimate per feature
    sample_count: Number of samples seen
    decay: Exponential decay factor for estimates (1.0 = no decay, pure online)

Normalizer(epsilon=1e-08)

Bases: ABC

Abstract base class for online feature normalizers.

Normalizes features using running estimates of mean and standard deviation: x_normalized = (x - mean) / (std + epsilon)

The normalizer updates its estimates at every time step, following temporal uniformity.

Subclasses must implement init and normalize. The normalize_only and update_only methods have default implementations.

Attributes: epsilon: Small constant for numerical stability

Args: epsilon: Small constant added to std for numerical stability

Source code in src/alberta_framework/core/normalizers.py
def __init__(self, epsilon: float = 1e-8):
    """Initialize the normalizer.

    Args:
        epsilon: Small constant added to std for numerical stability
    """
    self._epsilon = epsilon

to_config() abstractmethod

Serialize normalizer configuration to dict.

Source code in src/alberta_framework/core/normalizers.py
@abstractmethod
def to_config(self) -> dict[str, Any]:
    """Serialize normalizer configuration to dict."""
    ...

init(feature_dim) abstractmethod

Initialize normalizer state.

Args: feature_dim: Dimension of feature vectors

Returns: Initial normalizer state with zero mean and unit variance

Source code in src/alberta_framework/core/normalizers.py
@abstractmethod
def init(self, feature_dim: int) -> StateT:
    """Initialize normalizer state.

    Args:
        feature_dim: Dimension of feature vectors

    Returns:
        Initial normalizer state with zero mean and unit variance
    """
    ...

normalize(state, observation) abstractmethod

Normalize observation and update running statistics.

This method both normalizes the current observation AND updates the running statistics, maintaining temporal uniformity.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns: Tuple of (normalized_observation, new_state)

Source code in src/alberta_framework/core/normalizers.py
@abstractmethod
def normalize(
    self,
    state: StateT,
    observation: Array,
) -> tuple[Array, StateT]:
    """Normalize observation and update running statistics.

    This method both normalizes the current observation AND updates
    the running statistics, maintaining temporal uniformity.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Tuple of (normalized_observation, new_state)
    """
    ...

normalize_only(state, observation)

Normalize observation without updating statistics.

Useful for inference or when you want to normalize multiple observations with the same statistics.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns: Normalized observation

Source code in src/alberta_framework/core/normalizers.py
def normalize_only(
    self,
    state: StateT,
    observation: Array,
) -> Array:
    """Normalize observation without updating statistics.

    Useful for inference or when you want to normalize multiple
    observations with the same statistics.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Normalized observation
    """
    std = jnp.sqrt(state.var)
    return (observation - state.mean) / (std + self._epsilon)

update_only(state, observation)

Update statistics without returning normalized observation.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns: Updated normalizer state

Source code in src/alberta_framework/core/normalizers.py
def update_only(
    self,
    state: StateT,
    observation: Array,
) -> StateT:
    """Update statistics without returning normalized observation.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Updated normalizer state
    """
    _, new_state = self.normalize(state, observation)
    return new_state

WelfordNormalizer(epsilon=1e-08)

Bases: Normalizer[WelfordNormalizerState]

Online feature normalizer using Welford's algorithm.

Computes cumulative sample mean and variance with Bessel's correction, suitable for stationary distributions. Numerically stable for large sample counts.

Reference: Welford 1962, "Note on a Method for Calculating Corrected Sums of Squares and Products"

Attributes: epsilon: Small constant for numerical stability

Args: epsilon: Small constant added to std for numerical stability

Source code in src/alberta_framework/core/normalizers.py
def __init__(self, epsilon: float = 1e-8):
    """Initialize the normalizer.

    Args:
        epsilon: Small constant added to std for numerical stability
    """
    self._epsilon = epsilon

normalize_only(state, observation)

Normalize observation without updating statistics.

Useful for inference or when you want to normalize multiple observations with the same statistics.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns: Normalized observation

Source code in src/alberta_framework/core/normalizers.py
def normalize_only(
    self,
    state: StateT,
    observation: Array,
) -> Array:
    """Normalize observation without updating statistics.

    Useful for inference or when you want to normalize multiple
    observations with the same statistics.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Normalized observation
    """
    std = jnp.sqrt(state.var)
    return (observation - state.mean) / (std + self._epsilon)

update_only(state, observation)

Update statistics without returning normalized observation.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns: Updated normalizer state

Source code in src/alberta_framework/core/normalizers.py
def update_only(
    self,
    state: StateT,
    observation: Array,
) -> StateT:
    """Update statistics without returning normalized observation.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Updated normalizer state
    """
    _, new_state = self.normalize(state, observation)
    return new_state

to_config()

Serialize configuration to dict.

Source code in src/alberta_framework/core/normalizers.py
def to_config(self) -> dict[str, Any]:
    """Serialize configuration to dict."""
    return {"type": "WelfordNormalizer", "epsilon": self._epsilon}

init(feature_dim)

Initialize Welford normalizer state.

Args: feature_dim: Dimension of feature vectors

Returns: Initial normalizer state with zero mean and unit variance

Source code in src/alberta_framework/core/normalizers.py
def init(self, feature_dim: int) -> WelfordNormalizerState:
    """Initialize Welford normalizer state.

    Args:
        feature_dim: Dimension of feature vectors

    Returns:
        Initial normalizer state with zero mean and unit variance
    """
    return WelfordNormalizerState(
        mean=jnp.zeros(feature_dim, dtype=jnp.float32),
        var=jnp.ones(feature_dim, dtype=jnp.float32),
        sample_count=jnp.array(0.0, dtype=jnp.float32),
        p=jnp.zeros(feature_dim, dtype=jnp.float32),
    )

normalize(state, observation)

Normalize observation and update Welford running statistics.

Uses Welford's online algorithm:

1. Increment count
2. Update mean incrementally
3. Update sum of squared deviations (p / M2)
4. Compute variance with Bessel's correction when count >= 2

Args:
    state: Current Welford normalizer state
    observation: Raw feature vector

Returns: Tuple of (normalized_observation, new_state)

Source code in src/alberta_framework/core/normalizers.py
def normalize(
    self,
    state: WelfordNormalizerState,
    observation: Array,
) -> tuple[Array, WelfordNormalizerState]:
    """Normalize observation and update Welford running statistics.

    Uses Welford's online algorithm:
    1. Increment count
    2. Update mean incrementally
    3. Update sum of squared deviations (p / M2)
    4. Compute variance with Bessel's correction when count >= 2

    Args:
        state: Current Welford normalizer state
        observation: Raw feature vector

    Returns:
        Tuple of (normalized_observation, new_state)
    """
    new_count = state.sample_count + 1.0

    # Welford's incremental mean update
    delta = observation - state.mean
    new_mean = state.mean + delta / new_count

    # Update sum of squared deviations: p += (x - old_mean) * (x - new_mean)
    delta2 = observation - new_mean
    new_p = state.p + delta * delta2

    # Bessel-corrected variance; use unit variance when count < 2
    new_var = jnp.where(
        new_count >= 2.0,
        new_p / (new_count - 1.0),
        jnp.ones_like(new_p),
    )

    # Normalize using updated statistics
    std = jnp.sqrt(new_var)
    normalized = (observation - new_mean) / (std + self._epsilon)

    new_state = WelfordNormalizerState(
        mean=new_mean,
        var=new_var,
        sample_count=new_count,
        p=new_p,
    )

    return normalized, new_state
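As a sanity check, the incremental statistics match a batch computation with Bessel's correction. A standalone sketch of one Welford step (not the library class):

```python
import jax.numpy as jnp

def welford_step(mean, p, count, x):
    # One Welford update: incremental mean plus the M2 accumulator `p`.
    count = count + 1.0
    delta = x - mean
    mean = mean + delta / count
    p = p + delta * (x - mean)
    # Bessel-corrected variance; unit variance until two samples are seen.
    var = jnp.where(count >= 2.0, p / (count - 1.0), jnp.ones_like(p))
    return mean, p, count, var

mean, p, count = jnp.zeros(1), jnp.zeros(1), jnp.array(0.0)
for x in [1.0, 2.0, 3.0, 4.0, 5.0]:
    mean, p, count, var = welford_step(mean, p, count, jnp.array([x]))
# mean == 3.0, var == 2.5, matching jnp.var(jnp.arange(1.0, 6.0), ddof=1)
```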

WelfordNormalizerState

State for Welford's online normalization algorithm.

Uses Welford's algorithm for numerically stable estimation of cumulative sample mean and variance with Bessel's correction.

Attributes:
    mean: Running mean estimate per feature
    var: Running variance estimate per feature (Bessel-corrected)
    sample_count: Number of samples seen
    p: Sum of squared deviations from the current mean (M2 accumulator)

IDBD(initial_step_size=0.01, meta_step_size=0.01, h_decay_mode='prediction_grads')

Bases: Optimizer[IDBDState]

Incremental Delta-Bar-Delta optimizer.

IDBD maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation. When successive gradients agree in sign, the step-size for that weight increases. When they disagree, it decreases.

This implements Sutton's 1992 algorithm for adapting step-sizes online without requiring manual tuning.

Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta"

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate beta for adapting step-sizes

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate beta for adapting step-sizes
    h_decay_mode: Mode for computing the h-decay term in MLP path. "prediction_grads": h_decay = z^2 (squared prediction gradients). This is the principled generalization: for linear models, z = x, so z^2 = x^2, recovering Sutton 1992. "loss_grads": h_decay = (error * z)^2 (Fisher approximation of the Hessian diagonal). Only affects the MLP path (update_from_gradient); the linear update() method always uses x^2.

Raises: ValueError: If h_decay_mode is not one of the valid modes
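A scalar sketch of the core behavior (an illustrative standalone reduction of the linear rule, not the library class): when successive errors agree in sign, the trace h and the log step-size grow, so learning accelerates.

```python
import jax.numpy as jnp

def idbd_step(w, log_alpha, h, x, target, beta=0.1):
    # Sutton 1992 ordering: meta-update with the OLD trace, then use the
    # NEW alpha for both the weight and trace updates.
    error = target - w * x
    log_alpha = jnp.clip(log_alpha + beta * error * x * h, -10.0, 2.0)
    alpha = jnp.exp(log_alpha)
    w = w + alpha * error * x
    h = h * jnp.maximum(0.0, 1.0 - alpha * x**2) + alpha * error * x
    return w, log_alpha, h

w, log_alpha, h = 0.0, jnp.log(0.01), 0.0
for _ in range(300):
    w, log_alpha, h = idbd_step(w, log_alpha, h, x=1.0, target=1.0)
# log_alpha has grown well above its initial value and w approaches the target.
```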

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    h_decay_mode: str = "prediction_grads",
):
    """Initialize IDBD optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate beta for adapting step-sizes
        h_decay_mode: Mode for computing the h-decay term in MLP path.
            ``"prediction_grads"``: h_decay = z^2 (squared prediction
            gradients). This is the principled generalization — for
            linear models, z = x so z^2 = x^2, recovering Sutton 1992.
            ``"loss_grads"``: h_decay = (error * z)^2 (Fisher
            approximation of the Hessian diagonal).
            Only affects the MLP path (``update_from_gradient``);
            the linear ``update()`` method always uses x^2.

    Raises:
        ValueError: If ``h_decay_mode`` is not one of the valid modes
    """
    if h_decay_mode not in ("prediction_grads", "loss_grads"):
        raise ValueError(
            f"Invalid h_decay_mode: {h_decay_mode!r}. "
            "Must be 'prediction_grads' or 'loss_grads'."
        )
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
    self._h_decay_mode = h_decay_mode

to_config()

Serialize configuration to dict.

Source code in src/alberta_framework/core/optimizers.py
def to_config(self) -> dict[str, Any]:
    """Serialize configuration to dict."""
    config: dict[str, Any] = {
        "type": "IDBD",
        "initial_step_size": self._initial_step_size,
        "meta_step_size": self._meta_step_size,
    }
    if self._h_decay_mode != "prediction_grads":
        config["h_decay_mode"] = self._h_decay_mode
    return config

init(feature_dim)

Initialize IDBD state.

Args: feature_dim: Dimension of weight vector

Returns: IDBD state with per-weight step-sizes and traces

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> IDBDState:
    """Initialize IDBD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        IDBD state with per-weight step-sizes and traces
    """
    return IDBDState(
        log_step_sizes=jnp.full(
            feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        bias_step_size=jnp.array(self._initial_step_size, dtype=jnp.float32),
        bias_trace=jnp.array(0.0, dtype=jnp.float32),
    )

init_for_shape(shape)

Initialize IDBD state for arbitrary-shape parameters.

Args: shape: Shape of the parameter array

Returns: IDBDParamState with arrays matching the given shape

Source code in src/alberta_framework/core/optimizers.py
def init_for_shape(self, shape: tuple[int, ...]) -> IDBDParamState:
    """Initialize IDBD state for arbitrary-shape parameters.

    Args:
        shape: Shape of the parameter array

    Returns:
        IDBDParamState with arrays matching the given shape
    """
    return IDBDParamState(
        log_step_sizes=jnp.full(
            shape, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        traces=jnp.zeros(shape, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
    )

update_from_gradient(state, gradient, error=None)

Compute IDBD update from pre-computed gradient (MLP path).

Implements Meyer's adaptation of IDBD to nonlinear models. The key insight: replace x^2 in the h-decay term with (dy/dw)^2 (squared prediction gradients), which generalizes IDBD to arbitrary architectures.

This follows Meyer's implementation, which differs from the linear IDBD (Sutton 1992) in two ways to better handle deep networks:

  1. The meta-update uses z * h (prediction gradient times trace) without the current error, rather than error * z * h.
  2. The h-trace accumulates loss gradients (-error * z) rather than error-scaled prediction gradients (error * z).

These changes address problems with IDBD in deep networks where the step-size being factored into both h and beta updates causes compounding effects.

Reference: Meyer, https://github.com/ejmejm/phd_research

Operation order (meta-update first, then new alpha for trace):

  1. Compute h_decay based on mode: z^2 or (error * z)^2
  2. Meta-update with OLD traces: log_alpha += beta * z * h
  3. Clip log step-sizes to [-10.0, 2.0]
  4. New step-sizes: alpha = exp(log_alpha)
  5. Step: alpha * z (error applied externally by caller)
  6. Trace update: h = h * max(0, 1 - alpha * h_decay) + alpha * g where g = -error * z (loss gradient direction)

When error is None (trunk path in multi-head), the gradient is already in loss gradient direction (accumulated cotangents), so the trace uses alpha * z directly.

Args:
    state: Current IDBD param state
    gradient: Pre-computed prediction gradient / eligibility trace (same shape as state arrays)
    error: Optional prediction error scalar. When provided, used for h_decay (loss_grads mode) and h-trace sign.

Returns: (step, new_state) where step has the same shape as gradient

Source code in src/alberta_framework/core/optimizers.py
def update_from_gradient(
    self,
    state: IDBDParamState,
    gradient: Array,
    error: Array | None = None,
) -> tuple[Array, IDBDParamState]:
    """Compute IDBD update from pre-computed gradient (MLP path).

    Implements Meyer's adaptation of IDBD to nonlinear models. The key
    insight: replace ``x^2`` in the h-decay term with ``(dy/dw)^2``
    (squared prediction gradients), which generalizes IDBD to arbitrary
    architectures.

    This follows Meyer's implementation, which differs from the linear
    IDBD (Sutton 1992) in two ways to better handle deep networks:

    1. The meta-update uses ``z * h`` (prediction gradient times trace)
       without the current error, rather than ``error * z * h``.
    2. The h-trace accumulates loss gradients (``-error * z``) rather
       than error-scaled prediction gradients (``error * z``).

    These changes address problems with IDBD in deep networks where
    the step-size being factored into both h and beta updates causes
    compounding effects.

    Reference: Meyer, https://github.com/ejmejm/phd_research

    Operation order (meta-update first, then new alpha for trace):

    1. Compute h_decay based on mode: ``z^2`` or ``(error * z)^2``
    2. Meta-update with OLD traces: ``log_alpha += beta * z * h``
    3. Clip log step-sizes to ``[-10.0, 2.0]``
    4. New step-sizes: ``alpha = exp(log_alpha)``
    5. Step: ``alpha * z`` (error applied externally by caller)
    6. Trace update: ``h = h * max(0, 1 - alpha * h_decay) + alpha * g``
       where ``g = -error * z`` (loss gradient direction)

    When ``error`` is None (trunk path in multi-head), the gradient
    is already in loss gradient direction (accumulated cotangents),
    so the trace uses ``alpha * z`` directly.

    Args:
        state: Current IDBD param state
        gradient: Pre-computed prediction gradient / eligibility trace
            (same shape as state arrays)
        error: Optional prediction error scalar. When provided,
            used for h_decay (loss_grads mode) and h-trace sign.

    Returns:
        ``(step, new_state)`` where step has the same shape as gradient
    """
    beta = state.meta_step_size
    z = gradient

    # 1. Compute h_decay based on mode
    if self._h_decay_mode == "loss_grads" and error is not None:
        h_decay = (jnp.squeeze(error) * z) ** 2
    else:
        # prediction_grads mode, or loss_grads without error
        h_decay = z**2

    # 2. Meta-update with OLD traces (Meyer: prediction_grads * h, no error)
    meta_gradient = z * state.traces
    new_log_step_sizes = state.log_step_sizes + beta * meta_gradient

    # 3. Clip log step-sizes
    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)

    # 4. New step-sizes
    new_alphas = jnp.exp(new_log_step_sizes)

    # 5. Step: alpha * z (error applied externally)
    step = new_alphas * z

    # 6. Trace update: h = h * decay + alpha * loss_grads
    # Meyer uses loss_grads = -error * z when error is available;
    # when error is None (trunk path), z is already loss gradient direction.
    decay = jnp.maximum(0.0, 1.0 - new_alphas * h_decay)
    if error is not None:
        new_traces = state.traces * decay - new_alphas * jnp.squeeze(error) * z
    else:
        new_traces = state.traces * decay + new_alphas * z

    new_state = IDBDParamState(
        log_step_sizes=new_log_step_sizes,
        traces=new_traces,
        meta_step_size=beta,
    )

    return step, new_state
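The division of labor matters at the call site: the returned step is alpha * z with the error applied externally on the head path, while trunk cotangents already carry the error and are applied as-is. A minimal sketch of the two application patterns (function names are illustrative, not part of the API):

```python
import jax.numpy as jnp

def apply_head_update(w, step, error):
    # Head path: optimizer returned alpha * z; the caller multiplies in the error.
    return w + error * step

def apply_trunk_update(w, step):
    # Trunk path: the error is already folded into the gradient, so the
    # step is added directly.
    return w + step

w = jnp.array([1.0, 2.0])
step = jnp.array([0.1, 0.2])
error = jnp.array(0.5)
```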

update(state, error, observation)

Compute IDBD weight update with adaptive step-sizes.

Following Sutton 1992, Figure 2, the operation ordering is:

  1. Meta-update: log_alpha_i += beta * error * x_i * h_i (using OLD traces)
  2. Compute NEW step-sizes: alpha_i = exp(log_alpha_i)
  3. Update weights: w_i += alpha_i * error * x_i (using NEW alpha)
  4. Update traces: h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i (using NEW alpha)

The trace h_i tracks the correlation between current and past gradients. When gradients consistently point the same direction, h_i grows, leading to larger step-sizes.

Args:
    state: Current IDBD state
    error: Prediction error (scalar)
    observation: Feature vector

Returns: OptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: IDBDState,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute IDBD weight update with adaptive step-sizes.

    Following Sutton 1992, Figure 2, the operation ordering is:

    1. Meta-update: ``log_alpha_i += beta * error * x_i * h_i`` (using OLD traces)
    2. Compute NEW step-sizes: ``alpha_i = exp(log_alpha_i)``
    3. Update weights: ``w_i += alpha_i * error * x_i`` (using NEW alpha)
    4. Update traces: ``h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i``
       (using NEW alpha)

    The trace h_i tracks the correlation between current and past gradients.
    When gradients consistently point the same direction, h_i grows,
    leading to larger step-sizes.

    Args:
        state: Current IDBD state
        error: Prediction error (scalar)
        observation: Feature vector

    Returns:
        OptimizerUpdate with weight deltas and updated state
    """
    error_scalar = jnp.squeeze(error)
    beta = state.meta_step_size

    # 1. Meta-update: adapt step-sizes using OLD traces
    gradient_correlation = error_scalar * observation * state.traces
    new_log_step_sizes = state.log_step_sizes + beta * gradient_correlation

    # Clip log step-sizes to prevent numerical issues
    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)

    # 2. Compute NEW step-sizes
    new_alphas = jnp.exp(new_log_step_sizes)

    # 3. Weight updates using NEW alpha: alpha_i * error * x_i
    weight_delta = new_alphas * error_scalar * observation

    # 4. Update traces using NEW alpha: h_i = h_i * decay + alpha_i * error * x_i
    # decay = max(0, 1 - alpha_i * x_i^2)
    decay = jnp.maximum(0.0, 1.0 - new_alphas * observation**2)
    new_traces = state.traces * decay + new_alphas * error_scalar * observation

    # Bias updates (same ordering: meta-update first, then new alpha)
    bias_gradient_correlation = error_scalar * state.bias_trace
    new_bias_step_size = state.bias_step_size * jnp.exp(beta * bias_gradient_correlation)
    new_bias_step_size = jnp.clip(new_bias_step_size, 1e-6, 1.0)

    bias_delta = new_bias_step_size * error_scalar

    bias_decay = jnp.maximum(0.0, 1.0 - new_bias_step_size)
    new_bias_trace = state.bias_trace * bias_decay + new_bias_step_size * error_scalar

    new_state = IDBDState(
        log_step_sizes=new_log_step_sizes,
        traces=new_traces,
        meta_step_size=beta,
        bias_step_size=new_bias_step_size,
        bias_trace=new_bias_trace,
    )

    return OptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(new_alphas),
            "min_step_size": jnp.min(new_alphas),
            "max_step_size": jnp.max(new_alphas),
        },
    )
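To make the ordering concrete, here is a standalone NumPy sketch of the four steps above (an illustration of the Figure 2 update, not the framework's `jnp`-based implementation; the feature values and error are arbitrary, and the bias terms handled in the source are omitted):

```python
import numpy as np

def idbd_step(w, log_alpha, h, x, error, beta=0.01):
    """One IDBD update following Sutton (1992), Figure 2 ordering."""
    # 1. Meta-update with OLD traces h
    log_alpha = log_alpha + beta * error * x * h
    log_alpha = np.clip(log_alpha, -10.0, 2.0)  # numerical safety, as in the source
    # 2. Compute NEW step-sizes
    alpha = np.exp(log_alpha)
    # 3. Weight update with NEW alpha
    w = w + alpha * error * x
    # 4. Trace update with NEW alpha
    h = h * np.maximum(0.0, 1.0 - alpha * x**2) + alpha * error * x
    return w, log_alpha, h

w = np.zeros(3)
log_alpha = np.full(3, np.log(0.05))
h = np.zeros(3)
x = np.array([1.0, 0.5, 0.0])
w, log_alpha, h = idbd_step(w, log_alpha, h, x, error=2.0)
```

On the first step the traces are zero, so the step-sizes stay at 0.05 and only the weights and traces move; features with zero activation (the third one here) are untouched.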

LMS(step_size=0.01)

Bases: Optimizer[LMSState]

Least Mean Square optimizer with fixed step-size.

The simplest gradient-based optimizer: w_{t+1} = w_t + alpha * delta * x_t

This serves as a baseline. The challenge is that the optimal step-size depends on the problem and changes as the task becomes non-stationary.

Attributes: step_size: Fixed learning rate alpha

Args: step_size: Fixed learning rate

Source code in src/alberta_framework/core/optimizers.py
def __init__(self, step_size: float = 0.01):
    """Initialize LMS optimizer.

    Args:
        step_size: Fixed learning rate
    """
    self._step_size = step_size

to_config()

Serialize configuration to dict.

Source code in src/alberta_framework/core/optimizers.py
def to_config(self) -> dict[str, Any]:
    """Serialize configuration to dict."""
    return {"type": "LMS", "step_size": self._step_size}

init(feature_dim)

Initialize LMS state.

Args: feature_dim: Dimension of weight vector (unused for LMS)

Returns: LMS state containing the step-size

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> LMSState:
    """Initialize LMS state.

    Args:
        feature_dim: Dimension of weight vector (unused for LMS)

    Returns:
        LMS state containing the step-size
    """
    return LMSState(step_size=jnp.array(self._step_size, dtype=jnp.float32))

init_for_shape(shape)

Initialize LMS state for arbitrary-shape parameters.

LMS state is shape-independent (single scalar), so this returns the same state regardless of shape.

Source code in src/alberta_framework/core/optimizers.py
def init_for_shape(self, shape: tuple[int, ...]) -> LMSState:
    """Initialize LMS state for arbitrary-shape parameters.

    LMS state is shape-independent (single scalar), so this returns
    the same state regardless of shape.
    """
    return LMSState(step_size=jnp.array(self._step_size, dtype=jnp.float32))

update_from_gradient(state, gradient, error=None)

Compute step from gradient: step = alpha * gradient.

Args:
    state: Current LMS state
    gradient: Pre-computed gradient (any shape)
    error: Unused by LMS (accepted for interface compatibility)

Returns: (step, state) -- state is unchanged for LMS

Source code in src/alberta_framework/core/optimizers.py
def update_from_gradient(
    self, state: LMSState, gradient: Array, error: Array | None = None
) -> tuple[Array, LMSState]:
    """Compute step from gradient: ``step = alpha * gradient``.

    Args:
        state: Current LMS state
        gradient: Pre-computed gradient (any shape)
        error: Unused by LMS (accepted for interface compatibility)

    Returns:
        ``(step, state)`` -- state is unchanged for LMS
    """
    del error  # LMS doesn't meta-learn
    return state.step_size * gradient, state

update(state, error, observation)

Compute LMS weight update.

Update rule: delta_w = alpha * error * x

Args:
    state: Current LMS state
    error: Prediction error (scalar)
    observation: Feature vector

Returns: OptimizerUpdate with weight and bias deltas

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: LMSState,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute LMS weight update.

    Update rule: ``delta_w = alpha * error * x``

    Args:
        state: Current LMS state
        error: Prediction error (scalar)
        observation: Feature vector

    Returns:
        OptimizerUpdate with weight and bias deltas
    """
    alpha = state.step_size
    error_scalar = jnp.squeeze(error)

    # Weight update: alpha * error * x
    weight_delta = alpha * error_scalar * observation

    # Bias update: alpha * error
    bias_delta = alpha * error_scalar

    return OptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=state,  # LMS state doesn't change
        metrics={"step_size": alpha},
    )
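Since LMS serves as the baseline the adaptive optimizers are measured against, it is worth seeing how little it does. This standalone NumPy sketch (illustrative, not the framework's implementation) runs one predict-then-update step:

```python
import numpy as np

def lms_update(w, b, x, target, alpha=0.1):
    """One LMS step: predict, compute the error, move along error * x."""
    error = target - (w @ x + b)
    w = w + alpha * error * x   # weight update: alpha * error * x
    b = b + alpha * error       # bias update: alpha * error
    return w, b, error

w, b = np.zeros(2), 0.0
x = np.array([1.0, -1.0])
w, b, error = lms_update(w, b, x, target=1.0)
```

The fixed `alpha` is exactly the weakness the meta-learned optimizers address: too small and tracking lags a drifting target, too large and the weights oscillate.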

TDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, use_semi_gradient=True)

Bases: TDOptimizer[TDIDBDState]

TD-IDBD optimizer for temporal-difference learning.

Extends IDBD to TD learning with eligibility traces. Maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation in the TD setting.

Two variants are supported:

- Semi-gradient (default): uses only phi(s) in the meta-update; more stable
- Ordinary gradient: uses both phi(s) and phi(s'); more accurate but more sensitive

Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta
    trace_decay: Eligibility trace decay lambda
    use_semi_gradient: If True, use semi-gradient variant (default)

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay lambda (0 = TD(0))
    use_semi_gradient: If True, use semi-gradient variant (recommended)

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
    use_semi_gradient: bool = True,
):
    """Initialize TD-IDBD optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay lambda (0 = TD(0))
        use_semi_gradient: If True, use semi-gradient variant (recommended)
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
    self._trace_decay = trace_decay
    self._use_semi_gradient = use_semi_gradient

init(feature_dim)

Initialize TD-IDBD state.

Args: feature_dim: Dimension of weight vector

Returns: TD-IDBD state with per-weight step-sizes, traces, and h traces

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> TDIDBDState:
    """Initialize TD-IDBD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        TD-IDBD state with per-weight step-sizes, traces, and h traces
    """
    return TDIDBDState(
        log_step_sizes=jnp.full(
            feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(self._trace_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(self._initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
    )

update(state, td_error, observation, next_observation, gamma)

Compute TD-IDBD weight update with adaptive step-sizes.

Implements Algorithm 3 (semi-gradient) or Algorithm 4 (ordinary gradient) from Kearney et al. 2019.

Args:
    state: Current TD-IDBD state
    td_error: TD error delta = R + gamma*V(s') - V(s)
    observation: Current observation phi(s)
    next_observation: Next observation phi(s')
    gamma: Discount factor gamma (0 at terminal)

Returns: TDOptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: TDIDBDState,
    td_error: Array,
    observation: Array,
    next_observation: Array,
    gamma: Array,
) -> TDOptimizerUpdate:
    """Compute TD-IDBD weight update with adaptive step-sizes.

    Implements Algorithm 3 (semi-gradient) or Algorithm 4 (ordinary gradient)
    from Kearney et al. 2019.

    Args:
        state: Current TD-IDBD state
        td_error: TD error delta = R + gamma*V(s') - V(s)
        observation: Current observation phi(s)
        next_observation: Next observation phi(s')
        gamma: Discount factor gamma (0 at terminal)

    Returns:
        TDOptimizerUpdate with weight deltas and updated state
    """
    delta = jnp.squeeze(td_error)
    theta = state.meta_step_size
    lam = state.trace_decay
    gamma_scalar = jnp.squeeze(gamma)

    if self._use_semi_gradient:
        gradient_correlation = delta * observation * state.h_traces
        new_log_step_sizes = state.log_step_sizes + theta * gradient_correlation
    else:
        feature_diff = gamma_scalar * next_observation - observation
        gradient_correlation = delta * feature_diff * state.h_traces
        new_log_step_sizes = state.log_step_sizes - theta * gradient_correlation

    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)
    new_alphas = jnp.exp(new_log_step_sizes)

    new_eligibility_traces = gamma_scalar * lam * state.eligibility_traces + observation
    weight_delta = new_alphas * delta * new_eligibility_traces

    if self._use_semi_gradient:
        h_decay = jnp.maximum(0.0, 1.0 - new_alphas * observation * new_eligibility_traces)
        new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces
    else:
        feature_diff = gamma_scalar * next_observation - observation
        h_decay = jnp.maximum(0.0, 1.0 + new_alphas * new_eligibility_traces * feature_diff)
        new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces

    # Bias updates
    if self._use_semi_gradient:
        bias_gradient_correlation = delta * state.bias_h_trace
        new_bias_log_step_size = state.bias_log_step_size + theta * bias_gradient_correlation
    else:
        bias_feature_diff = gamma_scalar - 1.0
        bias_gradient_correlation = delta * bias_feature_diff * state.bias_h_trace
        new_bias_log_step_size = state.bias_log_step_size - theta * bias_gradient_correlation

    new_bias_log_step_size = jnp.clip(new_bias_log_step_size, -10.0, 2.0)
    new_bias_alpha = jnp.exp(new_bias_log_step_size)

    new_bias_eligibility_trace = gamma_scalar * lam * state.bias_eligibility_trace + 1.0
    bias_delta = new_bias_alpha * delta * new_bias_eligibility_trace

    if self._use_semi_gradient:
        bias_h_decay = jnp.maximum(0.0, 1.0 - new_bias_alpha * new_bias_eligibility_trace)
        new_bias_h_trace = (
            state.bias_h_trace * bias_h_decay
            + new_bias_alpha * delta * new_bias_eligibility_trace
        )
    else:
        bias_feature_diff = gamma_scalar - 1.0
        bias_h_decay = jnp.maximum(
            0.0, 1.0 + new_bias_alpha * new_bias_eligibility_trace * bias_feature_diff
        )
        new_bias_h_trace = (
            state.bias_h_trace * bias_h_decay
            + new_bias_alpha * delta * new_bias_eligibility_trace
        )

    new_state = TDIDBDState(
        log_step_sizes=new_log_step_sizes,
        eligibility_traces=new_eligibility_traces,
        h_traces=new_h_traces,
        meta_step_size=theta,
        trace_decay=lam,
        bias_log_step_size=new_bias_log_step_size,
        bias_eligibility_trace=new_bias_eligibility_trace,
        bias_h_trace=new_bias_h_trace,
    )

    return TDOptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(new_alphas),
            "min_step_size": jnp.min(new_alphas),
            "max_step_size": jnp.max(new_alphas),
            "mean_eligibility_trace": jnp.mean(jnp.abs(new_eligibility_traces)),
        },
    )
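The semi-gradient branch of the update can be traced with a standalone NumPy sketch (weights only; the bias path in the source follows the same shape with an implicit feature of 1, and the example values are arbitrary):

```python
import numpy as np

def td_idbd_step(log_alpha, z, h, delta, phi, theta=0.01, gamma=0.99, lam=0.0):
    """Semi-gradient TIDBD step (after Kearney et al. 2019, Algorithm 3)."""
    # Meta-update with OLD h traces, then clip for numerical safety
    log_alpha = np.clip(log_alpha + theta * delta * phi * h, -10.0, 2.0)
    alpha = np.exp(log_alpha)
    # Accumulate the eligibility trace, then form the weight delta
    z = gamma * lam * z + phi
    weight_delta = alpha * delta * z
    # h-trace update with NEW alpha
    decay = np.maximum(0.0, 1.0 - alpha * phi * z)
    h = h * decay + alpha * delta * z
    return weight_delta, log_alpha, z, h

log_alpha = np.full(2, np.log(0.1))
z = np.zeros(2)
h = np.zeros(2)
phi = np.array([1.0, 0.0])
wd, log_alpha, z, h = td_idbd_step(log_alpha, z, h, delta=0.5, phi=phi)
```

With lambda = 0 the eligibility trace collapses to phi(s), recovering a TD(0)-style per-feature update.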

AGCBounding(clip_factor=0.01, eps=0.001)

Bases: Bounder

Adaptive Gradient Clipping (Brock et al. 2021).

Clips per-output-unit based on the ratio of gradient norm to weight norm. Units where ||grad|| / max(||weight||, eps) > clip_factor get scaled down to respect the constraint.

Unlike ObGDBounding which applies a single global scale factor, AGC applies fine-grained, per-unit clipping that adapts to each layer's weight magnitude.

The metric returned is the fraction of units that were clipped (0.0 = no clipping, 1.0 = all units clipped).

Reference: Brock, A., De, S., Smith, S.L., & Simonyan, K. (2021). "High-Performance Large-Scale Image Recognition Without Normalization" (arXiv: 2102.06171)

Attributes:
    clip_factor: Maximum allowed gradient-to-weight ratio (lambda). Default 0.01.
    eps: Floor for weight norm to avoid division by zero. Default 1e-3.

Source code in src/alberta_framework/core/optimizers.py
def __init__(self, clip_factor: float = 0.01, eps: float = 1e-3):
    self._clip_factor = clip_factor
    self._eps = eps

to_config()

Serialize configuration to dict.

Source code in src/alberta_framework/core/optimizers.py
def to_config(self) -> dict[str, Any]:
    """Serialize configuration to dict."""
    return {"type": "AGCBounding", "clip_factor": self._clip_factor, "eps": self._eps}

bound(steps, error, params)

Bound proposed steps using per-unit adaptive gradient clipping.

For each parameter/step pair, computes unit-wise norms and clips units where |error| * ||step|| > clip_factor * max(||param||, eps).

Args:
    steps: Per-parameter step arrays from the optimizer
    error: Prediction error scalar
    params: Current parameter values (used for weight norms)

Returns: (clipped_steps, frac_clipped) where frac_clipped is the fraction of units that were clipped

Source code in src/alberta_framework/core/optimizers.py
def bound(
    self,
    steps: tuple[Array, ...],
    error: Array,
    params: tuple[Array, ...],
) -> tuple[tuple[Array, ...], Array]:
    """Bound proposed steps using per-unit adaptive gradient clipping.

    For each parameter/step pair, computes unit-wise norms and clips
    units where ``|error| * ||step|| > clip_factor * max(||param||, eps)``.

    Args:
        steps: Per-parameter step arrays from the optimizer
        error: Prediction error scalar
        params: Current parameter values (used for weight norms)

    Returns:
        ``(clipped_steps, frac_clipped)`` where frac_clipped is the
        fraction of units that were clipped
    """
    error_abs = jnp.abs(jnp.squeeze(error))
    clipped = []
    total_units = 0
    clipped_units = jnp.array(0.0)

    for step, param in zip(steps, params):
        p_norm = _unitwise_norm(param)
        s_norm = _unitwise_norm(step)
        g_norm = error_abs * s_norm
        max_norm = jnp.maximum(p_norm, self._eps) * self._clip_factor
        scale = max_norm / jnp.maximum(g_norm, 1e-6)
        needs_clip = g_norm > max_norm
        clipped_step = jnp.where(needs_clip, step * scale, step)
        clipped.append(clipped_step)

        total_units += needs_clip.size
        clipped_units = clipped_units + jnp.sum(needs_clip.astype(jnp.float32))

    frac_clipped = clipped_units / jnp.maximum(total_units, 1)
    return tuple(clipped), frac_clipped
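The per-unit clipping rule can be illustrated with a standalone NumPy sketch for a single weight matrix. Note that `_unitwise_norm` is not shown in this excerpt, so the per-row L2 norm below (one norm per output unit of a 2-D weight matrix, elementwise for 1-D) is an assumption about its behavior:

```python
import numpy as np

def unitwise_norm(a):
    """Assumed helper: L2 norm per output unit (rows of a 2-D matrix)."""
    if a.ndim <= 1:
        return np.abs(a)
    return np.sqrt(np.sum(a**2, axis=1, keepdims=True))

def agc_clip(step, param, error, clip_factor=0.01, eps=1e-3):
    """Scale down units where |error| * ||step|| > clip_factor * max(||param||, eps)."""
    g_norm = np.abs(error) * unitwise_norm(step)
    max_norm = clip_factor * np.maximum(unitwise_norm(param), eps)
    scale = max_norm / np.maximum(g_norm, 1e-6)
    return np.where(g_norm > max_norm, step * scale, step)

param = np.array([[3.0, 4.0],   # unit norm 5.0 -> allowed step norm 0.05
                  [0.1, 0.0]])  # unit norm 0.1 -> allowed step norm 0.001
step = np.array([[1.0, 0.0],
                 [0.0, 0.0]])
clipped = agc_clip(step, param, error=1.0)
```

Only the first unit exceeds its ratio, so only its row is rescaled; the second row is left alone, which is the per-unit behavior that distinguishes AGC from a single global scale.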

Autostep(initial_step_size=0.01, meta_step_size=0.01, tau=10000.0)

Bases: Optimizer[AutostepState]

Autostep optimizer with tuning-free step-size adaptation.

Implements the exact algorithm from Mahmood et al. 2012, Table 1.

The algorithm maintains per-weight step-sizes that adapt based on meta-gradient correlation. The key innovations are:

- Self-regulated normalizers (v_i) that track the meta-gradient magnitude |delta * x_i * h_i| for stable meta-updates
- Overshoot prevention via effective step-size normalization M = max(sum(alpha_i * x_i^2), 1)

Per-sample update (Table 1):

  1. v_i = max(|delta*x_i*h_i|, v_i + (1/tau)*alpha_i*x_i^2*(|delta*x_i*h_i| - v_i))
  2. alpha_i *= exp(mu * delta*x_i*h_i / v_i) where v_i > 0
  3. M = max(sum(alpha_i * x_i^2), 1); alpha_i /= M
  4. w_i += alpha_i * delta * x_i (weight update with NEW alpha)
  5. h_i = h_i * (1 - alpha_i * x_i^2) + alpha_i * delta * x_i (trace update)

Reference: Mahmood, A.R., Sutton, R.S., Degris, T., & Pilarski, P.M. (2012). "Tuning-free step-size adaptation"

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate mu for adapting step-sizes
    tau: Time constant for normalizer adaptation (default: 10000)

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate for adapting step-sizes
    tau: Time constant for normalizer adaptation (default: 10000).
        Higher values mean slower normalizer decay.

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    tau: float = 10000.0,
):
    """Initialize Autostep optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate for adapting step-sizes
        tau: Time constant for normalizer adaptation (default: 10000).
            Higher values mean slower normalizer decay.
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
    self._tau = tau

to_config()

Serialize configuration to dict.

Source code in src/alberta_framework/core/optimizers.py
def to_config(self) -> dict[str, Any]:
    """Serialize configuration to dict."""
    return {
        "type": "Autostep",
        "initial_step_size": self._initial_step_size,
        "meta_step_size": self._meta_step_size,
        "tau": self._tau,
    }

init(feature_dim)

Initialize Autostep state.

Normalizers (v_i) and traces (h_i) are initialized to 0 per the paper.

Args: feature_dim: Dimension of weight vector

Returns: Autostep state with per-weight step-sizes, traces, and normalizers

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> AutostepState:
    """Initialize Autostep state.

    Normalizers (v_i) and traces (h_i) are initialized to 0 per the paper.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        Autostep state with per-weight step-sizes, traces, and normalizers
    """
    return AutostepState(
        step_sizes=jnp.full(feature_dim, self._initial_step_size, dtype=jnp.float32),
        traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        normalizers=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        tau=jnp.array(self._tau, dtype=jnp.float32),
        bias_step_size=jnp.array(self._initial_step_size, dtype=jnp.float32),
        bias_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_normalizer=jnp.array(0.0, dtype=jnp.float32),
    )

init_for_shape(shape)

Initialize Autostep state for arbitrary-shape parameters.

Args: shape: Shape of the parameter array

Returns: AutostepParamState with arrays matching the given shape

Source code in src/alberta_framework/core/optimizers.py
def init_for_shape(self, shape: tuple[int, ...]) -> AutostepParamState:
    """Initialize Autostep state for arbitrary-shape parameters.

    Args:
        shape: Shape of the parameter array

    Returns:
        AutostepParamState with arrays matching the given shape
    """
    return AutostepParamState(
        step_sizes=jnp.full(shape, self._initial_step_size, dtype=jnp.float32),
        traces=jnp.zeros(shape, dtype=jnp.float32),
        normalizers=jnp.zeros(shape, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        tau=jnp.array(self._tau, dtype=jnp.float32),
    )

update_from_gradient(state, gradient, error=None)

Compute Autostep update from pre-computed gradient (MLP path).

Implements the Table 1 algorithm generalized for arbitrary-shape parameters, where gradient plays the role of the eligibility trace z (prediction gradient).

When error is provided, the full paper algorithm is used: the meta-gradient is error * z * h. When error is None, it falls back to an error-free approximation (z * h).

The returned step does NOT include the error -- the caller applies param += error * step after optional bounding.

Args:
    state: Current Autostep param state
    gradient: Pre-computed gradient / eligibility trace (same shape as state arrays)
    error: Optional prediction error scalar. When provided, enables the
        full paper algorithm with error-scaled meta-gradients.

Returns: (step, new_state) where step has the same shape as gradient

Source code in src/alberta_framework/core/optimizers.py
def update_from_gradient(
    self,
    state: AutostepParamState,
    gradient: Array,
    error: Array | None = None,
) -> tuple[Array, AutostepParamState]:
    """Compute Autostep update from pre-computed gradient (MLP path).

    Implements the Table 1 algorithm generalized for arbitrary-shape
    parameters, where ``gradient`` plays the role of the eligibility
    trace ``z`` (prediction gradient).

    When ``error`` is provided, the full paper algorithm is used:
    meta-gradient is ``error * z * h``. When ``error`` is None,
    falls back to error-free approximation (``z * h``).

    The returned step does NOT include the error -- the caller applies
    ``param += error * step`` after optional bounding.

    Args:
        state: Current Autostep param state
        gradient: Pre-computed gradient / eligibility trace (same shape as state arrays)
        error: Optional prediction error scalar. When provided, enables
            the full paper algorithm with error-scaled meta-gradients.

    Returns:
        ``(step, new_state)`` where step has the same shape as gradient
    """
    mu = state.meta_step_size
    tau = state.tau

    z = gradient  # eligibility trace
    z_sq = z**2

    # Compute meta-gradient: δ*z*h (or z*h if error is None)
    if error is not None:
        error_scalar = jnp.squeeze(error)
        meta_gradient = error_scalar * z * state.traces
    else:
        meta_gradient = z * state.traces

    abs_meta_gradient = jnp.abs(meta_gradient)

    # Eq. 4: v_i = max(|meta_grad|, v_i + (1/τ)*α_i*z_i²*(|meta_grad| - v_i))
    v_update = state.normalizers + (1.0 / tau) * state.step_sizes * z_sq * (
        abs_meta_gradient - state.normalizers
    )
    new_normalizers = jnp.maximum(abs_meta_gradient, v_update)

    # Eq. 5: α_i *= exp(μ * meta_grad / v_i) where v_i > 0
    safe_v = jnp.maximum(new_normalizers, 1e-38)
    new_step_sizes = jnp.where(
        new_normalizers > 0,
        state.step_sizes * jnp.exp(mu * meta_gradient / safe_v),
        state.step_sizes,
    )

    # Eq. 6-7: M = max(Σ α_i*z_i², 1); α_i /= M
    effective_step = jnp.sum(new_step_sizes * z_sq)
    m_factor = jnp.maximum(effective_step, 1.0)
    new_step_sizes = new_step_sizes / m_factor

    # Clip step-sizes for numerical safety
    new_step_sizes = jnp.clip(new_step_sizes, 1e-8, 1.0)

    # Compute step: α_i * z_i (error applied externally)
    step = new_step_sizes * z

    # Trace update: h_i = h_i*(1 - α_i*z_i²) + α_i*δ*z_i
    trace_decay = 1.0 - new_step_sizes * z_sq
    if error is not None:
        new_traces = state.traces * trace_decay + new_step_sizes * error_scalar * z
    else:
        new_traces = state.traces * trace_decay + new_step_sizes * z

    new_state = AutostepParamState(
        step_sizes=new_step_sizes,
        traces=new_traces,
        normalizers=new_normalizers,
        meta_step_size=mu,
        tau=tau,
    )

    return step, new_state

update(state, error, observation)

Compute Autostep weight update following Mahmood et al. 2012, Table 1.

The algorithm per sample:

  1. Eq. 4: v_i = max(|δ*x_i*h_i|, v_i + (1/τ)*α_i*x_i²*(|δ*x_i*h_i| - v_i))
  2. Eq. 5: α_i *= exp(μ * δ*x_i*h_i / v_i) where v_i > 0
  3. Eq. 6-7: M = max(Σ α_i*x_i² + α_bias, 1); α_i /= M, α_bias /= M
  4. Weight update: w_i += α_i * δ * x_i (with NEW alpha)
  5. Trace update: h_i = h_i*(1 - α_i*x_i²) + α_i*δ*x_i

Args:
    state: Current Autostep state
    error: Prediction error (scalar)
    observation: Feature vector

Returns: OptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: AutostepState,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute Autostep weight update following Mahmood et al. 2012, Table 1.

    The algorithm per sample:

    1. Eq. 4: ``v_i = max(|δ*x_i*h_i|, v_i + (1/τ)*α_i*x_i²*(|δ*x_i*h_i| - v_i))``
    2. Eq. 5: ``α_i *= exp(μ * δ*x_i*h_i / v_i)`` where ``v_i > 0``
    3. Eq. 6-7: ``M = max(Σ α_i*x_i² + α_bias, 1)``; ``α_i /= M``, ``α_bias /= M``
    4. Weight update: ``w_i += α_i * δ * x_i`` (with NEW alpha)
    5. Trace update: ``h_i = h_i*(1 - α_i*x_i²) + α_i*δ*x_i``

    Args:
        state: Current Autostep state
        error: Prediction error (scalar)
        observation: Feature vector

    Returns:
        OptimizerUpdate with weight deltas and updated state
    """
    error_scalar = jnp.squeeze(error)
    mu = state.meta_step_size
    tau = state.tau

    x = observation
    x_sq = x**2

    # --- Weights ---
    # Meta-gradient: δ*x_i*h_i
    meta_gradient = error_scalar * x * state.traces
    abs_meta_gradient = jnp.abs(meta_gradient)

    # Eq. 4: v_i update (self-regulated EMA)
    v_update = state.normalizers + (1.0 / tau) * state.step_sizes * x_sq * (
        abs_meta_gradient - state.normalizers
    )
    new_normalizers = jnp.maximum(abs_meta_gradient, v_update)

    # Eq. 5: α_i *= exp(μ * meta_grad / v_i) where v_i > 0
    safe_v = jnp.maximum(new_normalizers, 1e-38)
    new_step_sizes = jnp.where(
        new_normalizers > 0,
        state.step_sizes * jnp.exp(mu * meta_gradient / safe_v),
        state.step_sizes,
    )

    # --- Bias ---
    # Meta-gradient for bias (implicit x=1): δ*h_bias
    bias_meta_gradient = error_scalar * state.bias_trace
    abs_bias_meta_gradient = jnp.abs(bias_meta_gradient)

    # Eq. 4 for bias
    bias_v_update = state.bias_normalizer + (1.0 / tau) * state.bias_step_size * (
        abs_bias_meta_gradient - state.bias_normalizer
    )
    new_bias_normalizer = jnp.maximum(abs_bias_meta_gradient, bias_v_update)

    # Eq. 5 for bias
    safe_bias_v = jnp.maximum(new_bias_normalizer, 1e-38)
    new_bias_step_size = jnp.where(
        new_bias_normalizer > 0,
        state.bias_step_size * jnp.exp(mu * bias_meta_gradient / safe_bias_v),
        state.bias_step_size,
    )

    # Eq. 6-7: Overshoot prevention (joint over weights + bias)
    # M = max(Σ α_i*x_i² + α_bias*1², 1)
    effective_step = jnp.sum(new_step_sizes * x_sq) + new_bias_step_size
    m_factor = jnp.maximum(effective_step, 1.0)
    new_step_sizes = new_step_sizes / m_factor
    new_bias_step_size = new_bias_step_size / m_factor

    # Clip step-sizes for numerical safety
    new_step_sizes = jnp.clip(new_step_sizes, 1e-8, 1.0)
    new_bias_step_size = jnp.clip(new_bias_step_size, 1e-8, 1.0)

    # Weight update with NEW alpha: α_i * δ * x_i
    weight_delta = new_step_sizes * error_scalar * x

    # Bias update: α_bias * δ
    bias_delta = new_bias_step_size * error_scalar

    # Trace update: h_i = h_i*(1 - α_i*x_i²) + α_i*δ*x_i
    trace_decay = 1.0 - new_step_sizes * x_sq
    new_traces = state.traces * trace_decay + new_step_sizes * error_scalar * x

    # Bias trace: h_bias = h_bias*(1 - α_bias) + α_bias*δ
    bias_trace_decay = 1.0 - new_bias_step_size
    new_bias_trace = state.bias_trace * bias_trace_decay + new_bias_step_size * error_scalar

    new_state = AutostepState(
        step_sizes=new_step_sizes,
        traces=new_traces,
        normalizers=new_normalizers,
        meta_step_size=mu,
        tau=tau,
        bias_step_size=new_bias_step_size,
        bias_trace=new_bias_trace,
        bias_normalizer=new_bias_normalizer,
    )

    return OptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(new_step_sizes),
            "min_step_size": jnp.min(new_step_sizes),
            "max_step_size": jnp.max(new_step_sizes),
            "mean_normalizer": jnp.mean(new_normalizers),
        },
    )
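The Table 1 recipe above can be sketched standalone in NumPy (weights only; the joint weight-plus-bias normalization in the source is reduced here to the weight sum, and the example values are arbitrary):

```python
import numpy as np

def autostep_update(w, alpha, h, v, x, error, mu=0.01, tau=1e4):
    """One Autostep update (after Mahmood et al. 2012, Table 1), weights only."""
    g = error * x * h  # meta-gradient: delta * x_i * h_i
    # Eq. 4: self-regulated normalizer
    v = np.maximum(np.abs(g), v + (1.0 / tau) * alpha * x**2 * (np.abs(g) - v))
    # Eq. 5: multiplicative step-size update where v_i > 0
    alpha = np.where(v > 0, alpha * np.exp(mu * g / np.maximum(v, 1e-38)), alpha)
    # Eq. 6-7: overshoot prevention via effective step-size M
    m = max(float(np.sum(alpha * x**2)), 1.0)
    alpha = alpha / m
    # Weight and trace updates with NEW alpha
    w = w + alpha * error * x
    h = h * (1.0 - alpha * x**2) + alpha * error * x
    return w, alpha, h, v

w, h, v = np.zeros(2), np.zeros(2), np.zeros(2)
alpha = np.full(2, 0.5)
x = np.array([1.0, 1.0])
w, alpha, h, v = autostep_update(w, alpha, h, v, x, error=1.0)
```

On the first step the traces are zero, so the step-sizes pass through unchanged (here the effective step-size sum is exactly 1, so M does not shrink them either) and only the weights and traces move.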

AutoTDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, normalizer_decay=10000.0)

Bases: TDOptimizer[AutoTDIDBDState]

AutoStep-style normalized TD-IDBD optimizer.

Adds AutoStep-style normalization to TDIDBD for improved stability and reduced sensitivity to the meta step-size theta.

Reference: Kearney et al. 2019, Algorithm 6 "AutoStep Style Normalized TIDBD(lambda)"

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta
    trace_decay: Eligibility trace decay lambda
    normalizer_decay: Decay parameter tau for normalizers

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay lambda (0 = TD(0))
    normalizer_decay: Decay parameter tau for normalizers (default: 10000)

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
    normalizer_decay: float = 10000.0,
):
    """Initialize AutoTDIDBD optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay lambda (0 = TD(0))
        normalizer_decay: Decay parameter tau for normalizers (default: 10000)
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
    self._trace_decay = trace_decay
    self._normalizer_decay = normalizer_decay

init(feature_dim)

Initialize AutoTDIDBD state.

Args: feature_dim: Dimension of weight vector

Returns: AutoTDIDBD state with per-weight step-sizes, traces, h traces, and normalizers

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> AutoTDIDBDState:
    """Initialize AutoTDIDBD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        AutoTDIDBD state with per-weight step-sizes, traces, h traces, and normalizers
    """
    return AutoTDIDBDState(
        log_step_sizes=jnp.full(
            feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(self._trace_decay, dtype=jnp.float32),
        normalizer_decay=jnp.array(self._normalizer_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(self._initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
    )

update(state, td_error, observation, next_observation, gamma)

Compute AutoTDIDBD weight update with normalized adaptive step-sizes.

Implements Algorithm 6 from Kearney et al. 2019.

Args:
    state: Current AutoTDIDBD state
    td_error: TD error delta = R + gamma*V(s') - V(s)
    observation: Current observation phi(s)
    next_observation: Next observation phi(s')
    gamma: Discount factor gamma (0 at terminal)

Returns: TDOptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: AutoTDIDBDState,
    td_error: Array,
    observation: Array,
    next_observation: Array,
    gamma: Array,
) -> TDOptimizerUpdate:
    """Compute AutoTDIDBD weight update with normalized adaptive step-sizes.

    Implements Algorithm 6 from Kearney et al. 2019.

    Args:
        state: Current AutoTDIDBD state
        td_error: TD error delta = R + gamma*V(s') - V(s)
        observation: Current observation phi(s)
        next_observation: Next observation phi(s')
        gamma: Discount factor gamma (0 at terminal)

    Returns:
        TDOptimizerUpdate with weight deltas and updated state
    """
    delta = jnp.squeeze(td_error)
    theta = state.meta_step_size
    lam = state.trace_decay
    tau = state.normalizer_decay
    gamma_scalar = jnp.squeeze(gamma)

    feature_diff = gamma_scalar * next_observation - observation
    alphas = jnp.exp(state.log_step_sizes)

    # Update normalizers
    abs_weight_update = jnp.abs(delta * feature_diff * state.h_traces)
    normalizer_decay_term = (
        (1.0 / tau)
        * alphas
        * feature_diff
        * state.eligibility_traces
        * (jnp.abs(delta * observation * state.h_traces) - state.normalizers)
    )
    new_normalizers = jnp.maximum(abs_weight_update, state.normalizers - normalizer_decay_term)
    new_normalizers = jnp.maximum(new_normalizers, 1e-8)

    # Normalized meta-update
    normalized_gradient = delta * feature_diff * state.h_traces / new_normalizers
    new_log_step_sizes = state.log_step_sizes - theta * normalized_gradient

    # Effective step-size normalization
    effective_step_size = -jnp.sum(
        jnp.exp(new_log_step_sizes) * feature_diff * state.eligibility_traces
    )
    normalization_factor = jnp.maximum(effective_step_size, 1.0)
    new_log_step_sizes = new_log_step_sizes - jnp.log(normalization_factor)

    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)
    new_alphas = jnp.exp(new_log_step_sizes)

    new_eligibility_traces = gamma_scalar * lam * state.eligibility_traces + observation
    weight_delta = new_alphas * delta * new_eligibility_traces

    # Update h traces
    h_decay = jnp.maximum(0.0, 1.0 + new_alphas * feature_diff * new_eligibility_traces)
    new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces

    # Bias updates
    bias_alpha = jnp.exp(state.bias_log_step_size)
    bias_feature_diff = gamma_scalar - 1.0

    abs_bias_weight_update = jnp.abs(delta * bias_feature_diff * state.bias_h_trace)
    bias_normalizer_decay_term = (
        (1.0 / tau)
        * bias_alpha
        * bias_feature_diff
        * state.bias_eligibility_trace
        * (jnp.abs(delta * state.bias_h_trace) - state.bias_normalizer)
    )
    new_bias_normalizer = jnp.maximum(
        abs_bias_weight_update, state.bias_normalizer - bias_normalizer_decay_term
    )
    new_bias_normalizer = jnp.maximum(new_bias_normalizer, 1e-8)

    normalized_bias_gradient = (
        delta * bias_feature_diff * state.bias_h_trace / new_bias_normalizer
    )
    new_bias_log_step_size = state.bias_log_step_size - theta * normalized_bias_gradient

    bias_effective_step_size = (
        -jnp.exp(new_bias_log_step_size) * bias_feature_diff * state.bias_eligibility_trace
    )
    bias_norm_factor = jnp.maximum(bias_effective_step_size, 1.0)
    new_bias_log_step_size = new_bias_log_step_size - jnp.log(bias_norm_factor)

    new_bias_log_step_size = jnp.clip(new_bias_log_step_size, -10.0, 2.0)
    new_bias_alpha = jnp.exp(new_bias_log_step_size)

    new_bias_eligibility_trace = gamma_scalar * lam * state.bias_eligibility_trace + 1.0
    bias_delta = new_bias_alpha * delta * new_bias_eligibility_trace

    bias_h_decay = jnp.maximum(
        0.0, 1.0 + new_bias_alpha * bias_feature_diff * new_bias_eligibility_trace
    )
    new_bias_h_trace = (
        state.bias_h_trace * bias_h_decay + new_bias_alpha * delta * new_bias_eligibility_trace
    )

    new_state = AutoTDIDBDState(
        log_step_sizes=new_log_step_sizes,
        eligibility_traces=new_eligibility_traces,
        h_traces=new_h_traces,
        normalizers=new_normalizers,
        meta_step_size=theta,
        trace_decay=lam,
        normalizer_decay=tau,
        bias_log_step_size=new_bias_log_step_size,
        bias_eligibility_trace=new_bias_eligibility_trace,
        bias_h_trace=new_bias_h_trace,
        bias_normalizer=new_bias_normalizer,
    )

    return TDOptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(new_alphas),
            "min_step_size": jnp.min(new_alphas),
            "max_step_size": jnp.max(new_alphas),
            "mean_eligibility_trace": jnp.mean(jnp.abs(new_eligibility_traces)),
            "mean_normalizer": jnp.mean(new_normalizers),
        },
    )

Bounder

Bases: ABC

Base class for update bounding strategies.

A bounder takes the proposed per-parameter step arrays from an optimizer and optionally scales them down to prevent overshooting.

to_config() abstractmethod

Serialize bounding configuration to dict.

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def to_config(self) -> dict[str, Any]:
    """Serialize bounding configuration to dict."""
    ...

bound(steps, error, params) abstractmethod

Bound proposed update steps.

Args:
    steps: Per-parameter step arrays from the optimizer
    error: Prediction error scalar
    params: Current parameter values (needed by some bounders like AGC)

Returns:
    (bounded_steps, metric) where metric is a scalar for reporting
    (e.g., scale factor for ObGD, mean clip ratio for AGC)

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def bound(
    self,
    steps: tuple[Array, ...],
    error: Array,
    params: tuple[Array, ...],
) -> tuple[tuple[Array, ...], Array]:
    """Bound proposed update steps.

    Args:
        steps: Per-parameter step arrays from the optimizer
        error: Prediction error scalar
        params: Current parameter values (needed by some bounders like AGC)

    Returns:
        ``(bounded_steps, metric)`` where metric is a scalar for reporting
        (e.g., scale factor for ObGD, mean clip ratio for AGC)
    """
    ...

ObGD(step_size=1.0, kappa=2.0, gamma=0.0, lamda=0.0)

Bases: Optimizer[ObGDState]

Observation-bounded Gradient Descent optimizer.

ObGD prevents overshooting by dynamically bounding the effective step-size based on the magnitude of the prediction error and eligibility traces. When the combined update magnitude would be too large, the step-size is scaled down to prevent the prediction from overshooting the target.

This is the deep-network generalization of Autostep's overshooting prevention, designed for streaming reinforcement learning.

For supervised learning (gamma=0, lamda=0), traces equal the current observation each step, making ObGD equivalent to LMS with dynamic step-size bounding.

The ObGD algorithm:

  1. Update traces: z = gamma * lamda * z + observation
  2. Compute bound: M = alpha * kappa * max(|error|, 1) * (||z_w||_1 + |z_b|)
  3. Effective step: alpha_eff = min(alpha, alpha / M) (i.e. alpha / max(M, 1))
  4. Weight delta: delta_w = alpha_eff * error * z_w
  5. Bias delta: delta_b = alpha_eff * error * z_b
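The five steps above can be traced by hand for the supervised case (gamma = lamda = 0, so the traces reduce to the current observation). The NumPy sketch below uses made-up values:

```python
import numpy as np

alpha, kappa = 1.0, 2.0
obs = np.array([1.0, -2.0, 0.5])
error = 3.0                                    # target - prediction

z_w = obs                                      # step 1 with gamma * lamda = 0
z_b = 1.0
M = alpha * kappa * max(abs(error), 1.0) * (np.abs(z_w).sum() + abs(z_b))
alpha_eff = alpha / max(M, 1.0)                # steps 2-3
delta_w = alpha_eff * error * z_w              # step 4
delta_b = alpha_eff * error * z_b              # step 5

# When M >= 1 the total update magnitude |error| * alpha_eff * z_sum is
# at most 1/kappa, so a single step cannot overshoot the target.
```

Here z_sum = 4.5 and M = 27, so alpha_eff = 1/27 and the total update magnitude is exactly 1/kappa = 0.5.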

Reference: Elsayed et al. 2024, "Streaming Deep Reinforcement Learning Finally Works"

Attributes:
    step_size: Base learning rate alpha
    kappa: Bounding sensitivity parameter (higher = more conservative)
    gamma: Discount factor for trace decay (0 for supervised learning)
    lamda: Eligibility trace decay parameter (0 for supervised learning)

Args:
    step_size: Base learning rate (default: 1.0)
    kappa: Bounding sensitivity parameter (default: 2.0)
    gamma: Discount factor for trace decay (default: 0.0 for supervised)
    lamda: Eligibility trace decay parameter (default: 0.0 for supervised)

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    step_size: float = 1.0,
    kappa: float = 2.0,
    gamma: float = 0.0,
    lamda: float = 0.0,
):
    """Initialize ObGD optimizer.

    Args:
        step_size: Base learning rate (default: 1.0)
        kappa: Bounding sensitivity parameter (default: 2.0)
        gamma: Discount factor for trace decay (default: 0.0 for supervised)
        lamda: Eligibility trace decay parameter (default: 0.0 for supervised)
    """
    self._step_size = step_size
    self._kappa = kappa
    self._gamma = gamma
    self._lamda = lamda

init_for_shape(shape)

Initialize optimizer state for parameters of arbitrary shape.

Used by MLP learners where parameters are matrices/vectors of varying shapes. Not all optimizers support this.

The return type varies by subclass (e.g. LMSState for LMS, AutostepParamState for Autostep) so the base signature uses Any.

Args: shape: Shape of the parameter array

Returns: Initial optimizer state with arrays matching the given shape

Raises: NotImplementedError: If the optimizer does not support this

Source code in src/alberta_framework/core/optimizers.py
def init_for_shape(self, shape: tuple[int, ...]) -> Any:
    """Initialize optimizer state for parameters of arbitrary shape.

    Used by MLP learners where parameters are matrices/vectors of
    varying shapes. Not all optimizers support this.

    The return type varies by subclass (e.g. ``LMSState`` for LMS,
    ``AutostepParamState`` for Autostep) so the base signature uses
    ``Any``.

    Args:
        shape: Shape of the parameter array

    Returns:
        Initial optimizer state with arrays matching the given shape

    Raises:
        NotImplementedError: If the optimizer does not support this
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not support init_for_shape. "
        "Only LMS, IDBD, and Autostep currently implement this."
    )

update_from_gradient(state, gradient, error=None)

Compute step delta from pre-computed gradient.

The returned delta does NOT include the error -- the caller is responsible for multiplying error * delta before applying.

The state type varies by subclass (e.g. LMSState for LMS, AutostepParamState for Autostep) so the base signature uses Any.

Args:
    state: Current optimizer state
    gradient: Pre-computed gradient (e.g. eligibility trace)
    error: Optional prediction error scalar. Optimizers with meta-learning
        (e.g. Autostep) use this for meta-gradient computation. LMS ignores it.

Returns: (step, new_state) where step has the same shape as gradient

Raises: NotImplementedError: If the optimizer does not support this

Source code in src/alberta_framework/core/optimizers.py
def update_from_gradient(
    self, state: Any, gradient: Array, error: Array | None = None
) -> tuple[Array, Any]:
    """Compute step delta from pre-computed gradient.

    The returned delta does NOT include the error -- the caller is
    responsible for multiplying ``error * delta`` before applying.

    The state type varies by subclass (e.g. ``LMSState`` for LMS,
    ``AutostepParamState`` for Autostep) so the base signature uses
    ``Any``.

    Args:
        state: Current optimizer state
        gradient: Pre-computed gradient (e.g. eligibility trace)
        error: Optional prediction error scalar. Optimizers with
            meta-learning (e.g. Autostep) use this for meta-gradient
            computation. LMS ignores it.

    Returns:
        ``(step, new_state)`` where step has the same shape as gradient

    Raises:
        NotImplementedError: If the optimizer does not support this
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not support update_from_gradient. "
        "Only LMS, IDBD, and Autostep currently implement this."
    )

to_config()

Serialize configuration to dict.

Source code in src/alberta_framework/core/optimizers.py
def to_config(self) -> dict[str, Any]:
    """Serialize configuration to dict."""
    return {
        "type": "ObGD",
        "step_size": self._step_size,
        "kappa": self._kappa,
        "gamma": self._gamma,
        "lamda": self._lamda,
    }

init(feature_dim)

Initialize ObGD state.

Args: feature_dim: Dimension of weight vector

Returns: ObGD state with eligibility traces

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> ObGDState:
    """Initialize ObGD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        ObGD state with eligibility traces
    """
    return ObGDState(
        step_size=jnp.array(self._step_size, dtype=jnp.float32),
        kappa=jnp.array(self._kappa, dtype=jnp.float32),
        traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        bias_trace=jnp.array(0.0, dtype=jnp.float32),
        gamma=jnp.array(self._gamma, dtype=jnp.float32),
        lamda=jnp.array(self._lamda, dtype=jnp.float32),
    )

update(state, error, observation)

Compute ObGD weight update with overshooting prevention.

The bounding mechanism scales down the step-size when the combined effect of error magnitude, trace norm, and step-size would cause the prediction to overshoot the target.

Args:
    state: Current ObGD state
    error: Prediction error (target - prediction)
    observation: Current observation/feature vector

Returns: OptimizerUpdate with bounded weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: ObGDState,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute ObGD weight update with overshooting prevention.

    The bounding mechanism scales down the step-size when the combined
    effect of error magnitude, trace norm, and step-size would cause
    the prediction to overshoot the target.

    Args:
        state: Current ObGD state
        error: Prediction error (target - prediction)
        observation: Current observation/feature vector

    Returns:
        OptimizerUpdate with bounded weight deltas and updated state
    """
    error_scalar = jnp.squeeze(error)
    alpha = state.step_size
    kappa = state.kappa

    # Update eligibility traces: z = gamma * lamda * z + observation
    new_traces = state.gamma * state.lamda * state.traces + observation
    new_bias_trace = state.gamma * state.lamda * state.bias_trace + 1.0

    # Compute z_sum (L1 norm of all traces)
    z_sum = jnp.sum(jnp.abs(new_traces)) + jnp.abs(new_bias_trace)

    # Compute bounding factor: M = alpha * kappa * max(|error|, 1) * z_sum
    delta_bar = jnp.maximum(jnp.abs(error_scalar), 1.0)
    dot_product = delta_bar * z_sum * alpha * kappa

    # Effective step-size: alpha / max(M, 1)
    alpha_eff = alpha / jnp.maximum(dot_product, 1.0)

    # Weight and bias deltas
    weight_delta = alpha_eff * error_scalar * new_traces
    bias_delta = alpha_eff * error_scalar * new_bias_trace

    new_state = ObGDState(
        step_size=alpha,
        kappa=kappa,
        traces=new_traces,
        bias_trace=new_bias_trace,
        gamma=state.gamma,
        lamda=state.lamda,
    )

    return OptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "step_size": alpha,
            "effective_step_size": alpha_eff,
            "bounding_factor": dot_product,
        },
    )

ObGDBounding(kappa=2.0)

Bases: Bounder

ObGD-style global update bounding (Elsayed et al. 2024).

Computes a global bounding factor from the L1 norm of all proposed steps and the error magnitude, then uniformly scales all steps down if the combined update would be too large.

For LMS with a single scalar step-size alpha: total_step = alpha * z_sum, giving M = alpha * kappa * max(|error|, 1) * z_sum -- identical to the original Elsayed et al. 2024 formula.

Attributes: kappa: Bounding sensitivity parameter (higher = more conservative)

Source code in src/alberta_framework/core/optimizers.py
def __init__(self, kappa: float = 2.0):
    self._kappa = kappa

to_config()

Serialize configuration to dict.

Source code in src/alberta_framework/core/optimizers.py
def to_config(self) -> dict[str, Any]:
    """Serialize configuration to dict."""
    return {"type": "ObGDBounding", "kappa": self._kappa}

bound(steps, error, params)

Bound proposed steps using ObGD formula.

Args:
    steps: Per-parameter step arrays
    error: Prediction error scalar
    params: Current parameter values (unused by ObGD)

Returns: (bounded_steps, scale) where scale is the bounding factor

Source code in src/alberta_framework/core/optimizers.py
def bound(
    self,
    steps: tuple[Array, ...],
    error: Array,
    params: tuple[Array, ...],
) -> tuple[tuple[Array, ...], Array]:
    """Bound proposed steps using ObGD formula.

    Args:
        steps: Per-parameter step arrays
        error: Prediction error scalar
        params: Current parameter values (unused by ObGD)

    Returns:
        ``(bounded_steps, scale)`` where scale is the bounding factor
    """
    del params  # ObGD bounds based on step/error magnitude only
    error_scalar = jnp.squeeze(error)
    total_step = jnp.array(0.0)
    for s in steps:
        total_step = total_step + jnp.sum(jnp.abs(s))
    delta_bar = jnp.maximum(jnp.abs(error_scalar), 1.0)
    bound_magnitude = self._kappa * delta_bar * total_step
    scale = 1.0 / jnp.maximum(bound_magnitude, 1.0)
    bounded = tuple(scale * s for s in steps)
    return bounded, scale
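The same arithmetic can be replicated outside JAX. The NumPy sketch below computes one global scale from the L1 norm of every proposed step array and applies it uniformly, exactly as `bound()` does; the step arrays and error are made up for illustration.

```python
import numpy as np

kappa = 2.0
# A tuple of per-parameter step arrays, e.g. a weight matrix and a bias vector
steps = (np.array([[0.5, -0.5], [1.0, 0.0]]), np.array([0.25, 0.25]))
error = 4.0

total_step = sum(np.abs(s).sum() for s in steps)        # L1 norm over all arrays: 2.5
delta_bar = max(abs(error), 1.0)
scale = 1.0 / max(kappa * delta_bar * total_step, 1.0)  # 1 / 20 = 0.05
bounded = tuple(scale * s for s in steps)
```

Because the scale is global, the relative magnitudes of the individual steps are preserved; only their overall size is reduced.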

Optimizer

Bases: ABC

Base class for optimizers.

to_config() abstractmethod

Serialize optimizer configuration to dict.

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def to_config(self) -> dict[str, Any]:
    """Serialize optimizer configuration to dict."""
    ...

init(feature_dim) abstractmethod

Initialize optimizer state.

Args: feature_dim: Dimension of weight vector

Returns: Initial optimizer state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def init(self, feature_dim: int) -> StateT:
    """Initialize optimizer state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        Initial optimizer state
    """
    ...

update(state, error, observation) abstractmethod

Compute weight updates given prediction error.

Args:
    state: Current optimizer state
    error: Prediction error (target - prediction)
    observation: Current observation/feature vector

Returns: OptimizerUpdate with deltas and new state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def update(
    self,
    state: StateT,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute weight updates given prediction error.

    Args:
        state: Current optimizer state
        error: Prediction error (target - prediction)
        observation: Current observation/feature vector

    Returns:
        OptimizerUpdate with deltas and new state
    """
    ...

init_for_shape(shape)

Initialize optimizer state for parameters of arbitrary shape.

Used by MLP learners where parameters are matrices/vectors of varying shapes. Not all optimizers support this.

The return type varies by subclass (e.g. LMSState for LMS, AutostepParamState for Autostep) so the base signature uses Any.

Args: shape: Shape of the parameter array

Returns: Initial optimizer state with arrays matching the given shape

Raises: NotImplementedError: If the optimizer does not support this

Source code in src/alberta_framework/core/optimizers.py
def init_for_shape(self, shape: tuple[int, ...]) -> Any:
    """Initialize optimizer state for parameters of arbitrary shape.

    Used by MLP learners where parameters are matrices/vectors of
    varying shapes. Not all optimizers support this.

    The return type varies by subclass (e.g. ``LMSState`` for LMS,
    ``AutostepParamState`` for Autostep) so the base signature uses
    ``Any``.

    Args:
        shape: Shape of the parameter array

    Returns:
        Initial optimizer state with arrays matching the given shape

    Raises:
        NotImplementedError: If the optimizer does not support this
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not support init_for_shape. "
        "Only LMS, IDBD, and Autostep currently implement this."
    )

update_from_gradient(state, gradient, error=None)

Compute step delta from pre-computed gradient.

The returned delta does NOT include the error -- the caller is responsible for multiplying error * delta before applying.

The state type varies by subclass (e.g. LMSState for LMS, AutostepParamState for Autostep) so the base signature uses Any.

Args:
    state: Current optimizer state
    gradient: Pre-computed gradient (e.g. eligibility trace)
    error: Optional prediction error scalar. Optimizers with meta-learning
        (e.g. Autostep) use this for meta-gradient computation. LMS ignores it.

Returns: (step, new_state) where step has the same shape as gradient

Raises: NotImplementedError: If the optimizer does not support this

Source code in src/alberta_framework/core/optimizers.py
def update_from_gradient(
    self, state: Any, gradient: Array, error: Array | None = None
) -> tuple[Array, Any]:
    """Compute step delta from pre-computed gradient.

    The returned delta does NOT include the error -- the caller is
    responsible for multiplying ``error * delta`` before applying.

    The state type varies by subclass (e.g. ``LMSState`` for LMS,
    ``AutostepParamState`` for Autostep) so the base signature uses
    ``Any``.

    Args:
        state: Current optimizer state
        gradient: Pre-computed gradient (e.g. eligibility trace)
        error: Optional prediction error scalar. Optimizers with
            meta-learning (e.g. Autostep) use this for meta-gradient
            computation. LMS ignores it.

    Returns:
        ``(step, new_state)`` where step has the same shape as gradient

    Raises:
        NotImplementedError: If the optimizer does not support this
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not support update_from_gradient. "
        "Only LMS, IDBD, and Autostep currently implement this."
    )

TDOptimizer

Bases: ABC

Base class for TD optimizers.

TD optimizers handle temporal-difference learning with eligibility traces. They take TD error and both current and next observations as input.
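For concreteness, the TD error that a TD optimizer consumes can be worked through in plain Python, assuming a linear value estimate V(s) = w · phi(s) (all numbers illustrative):

```python
# Worked TD-error example: delta = R + gamma * V(s') - V(s)
w = [0.5, 1.0]                     # weight vector
phi_s, phi_next = [1.0, 0.0], [0.0, 1.0]
reward, gamma = 1.0, 0.9

v_s = sum(wi * xi for wi, xi in zip(w, phi_s))        # V(s)  = 0.5
v_next = sum(wi * xi for wi, xi in zip(w, phi_next))  # V(s') = 1.0
td_error = reward + gamma * v_next - v_s              # 1.0 + 0.9 - 0.5 = 1.4
```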

init(feature_dim) abstractmethod

Initialize optimizer state.

Args: feature_dim: Dimension of weight vector

Returns: Initial optimizer state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def init(self, feature_dim: int) -> StateT:
    """Initialize optimizer state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        Initial optimizer state
    """
    ...

update(state, td_error, observation, next_observation, gamma) abstractmethod

Compute weight updates given TD error.

Args:
    state: Current optimizer state
    td_error: TD error delta = R + gamma*V(s') - V(s)
    observation: Current observation phi(s)
    next_observation: Next observation phi(s')
    gamma: Discount factor gamma (0 at terminal)

Returns: TDOptimizerUpdate with deltas and new state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def update(
    self,
    state: StateT,
    td_error: Array,
    observation: Array,
    next_observation: Array,
    gamma: Array,
) -> TDOptimizerUpdate:
    """Compute weight updates given TD error.

    Args:
        state: Current optimizer state
        td_error: TD error delta = R + gamma*V(s') - V(s)
        observation: Current observation phi(s)
        next_observation: Next observation phi(s')
        gamma: Discount factor gamma (0 at terminal)

    Returns:
        TDOptimizerUpdate with deltas and new state
    """
    ...

TDOptimizerUpdate

Result of a TD optimizer update step.

Attributes:
    weight_delta: Change to apply to weights
    bias_delta: Change to apply to bias
    new_state: Updated optimizer state
    metrics: Dictionary of metrics for logging

SARSAAgent(sarsa_config, hidden_sizes=(128, 128), optimizer=None, step_size=1.0, bounder=None, normalizer=None, sparsity=0.9, leaky_relu_slope=0.01, use_layer_norm=True, head_optimizer=None, prediction_demons=None, lamda=0.0)

On-policy SARSA control agent via Horde architecture.

Wraps HordeLearner with epsilon-greedy action selection and SARSA target computation. Each action maps to a control demon (head) in the Horde. The SARSA target r + gamma * Q(s', a') is computed externally and passed as the cumulant, so control demons use gamma=0 internally.

Optionally, additional prediction demons can coexist with the control demons — they learn alongside the Q-heads.
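The externally computed SARSA target described above amounts to one line of arithmetic; the plain-Python example below uses made-up Q-values:

```python
# The cumulant passed to each control demon is the SARSA target
# r + gamma * Q(s', a'), with a' the action actually taken (on-policy).
gamma = 0.99
reward = 1.0
q_next = [0.2, 0.7, 0.4]   # Q(s', .) from the control heads
a_next = 1                 # action selected in s'

target = reward + gamma * q_next[a_next]   # 1.0 + 0.99 * 0.7
```

Because this target already folds in the bootstrapped next-state value, the control demons themselves can learn it as a gamma = 0 prediction.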

Single-Step (Daemon) Usage

Both select_action() and update() work with single unbatched observations (1D arrays). JIT-compiled automatically.

Attributes:
    sarsa_config: SARSA configuration
    horde: The underlying HordeLearner
    n_actions: Number of discrete actions

Args:
    sarsa_config: SARSA configuration (n_actions, gamma, epsilon)
    hidden_sizes: Tuple of hidden layer sizes (default: two layers of 128)
    optimizer: Optimizer for weight updates. Defaults to LMS(step_size).
    step_size: Base learning rate (used only when optimizer is None)
    bounder: Optional update bounder (e.g. ObGDBounding)
    normalizer: Optional feature normalizer
    sparsity: Fraction of weights zeroed out per neuron (default: 0.9)
    leaky_relu_slope: Negative slope for LeakyReLU (default: 0.01)
    use_layer_norm: Whether to apply parameterless layer normalization
    head_optimizer: Optional separate optimizer for heads
    prediction_demons: Optional additional prediction demons to learn
        alongside Q-heads. These are appended after the control demons
        in the Horde.
    lamda: Trace decay for control demon heads (default: 0.0)

Source code in src/alberta_framework/core/sarsa.py
def __init__(
    self,
    sarsa_config: SARSAConfig,
    hidden_sizes: tuple[int, ...] = (128, 128),
    optimizer: AnyOptimizer | None = None,
    step_size: float = 1.0,
    bounder: Bounder | None = None,
    normalizer: (
        Normalizer[EMANormalizerState] | Normalizer[WelfordNormalizerState] | None
    ) = None,
    sparsity: float = 0.9,
    leaky_relu_slope: float = 0.01,
    use_layer_norm: bool = True,
    head_optimizer: AnyOptimizer | None = None,
    prediction_demons: list[GVFSpec] | None = None,
    lamda: float = 0.0,
):
    """Initialize the SARSA agent.

    Args:
        sarsa_config: SARSA configuration (n_actions, gamma, epsilon)
        hidden_sizes: Tuple of hidden layer sizes (default: two layers of 128)
        optimizer: Optimizer for weight updates. Defaults to LMS(step_size).
        step_size: Base learning rate (used only when optimizer is None)
        bounder: Optional update bounder (e.g. ObGDBounding)
        normalizer: Optional feature normalizer
        sparsity: Fraction of weights zeroed out per neuron (default: 0.9)
        leaky_relu_slope: Negative slope for LeakyReLU (default: 0.01)
        use_layer_norm: Whether to apply parameterless layer normalization
        head_optimizer: Optional separate optimizer for heads
        prediction_demons: Optional additional prediction demons to
            learn alongside Q-heads. These are appended after the
            control demons in the Horde.
        lamda: Trace decay for control demon heads (default: 0.0)
    """
    self._sarsa_config = sarsa_config
    self._hidden_sizes = hidden_sizes
    self._lamda = lamda

    # Build HordeSpec: control demons first, then prediction demons
    control_demons = _make_control_demons(sarsa_config.n_actions, lamda=lamda)
    all_demons: list[GVFSpec] = list(control_demons)
    if prediction_demons is not None:
        all_demons.extend(prediction_demons)
    self._n_prediction_demons = len(prediction_demons) if prediction_demons else 0

    horde_spec = create_horde_spec(all_demons)

    self._horde = HordeLearner(
        horde_spec=horde_spec,
        hidden_sizes=hidden_sizes,
        optimizer=optimizer,
        step_size=step_size,
        bounder=bounder,
        normalizer=normalizer,
        sparsity=sparsity,
        leaky_relu_slope=leaky_relu_slope,
        use_layer_norm=use_layer_norm,
        head_optimizer=head_optimizer,
    )

sarsa_config property

The SARSA configuration.

horde property

The underlying HordeLearner.

n_actions property

Number of discrete actions.

to_config()

Serialize agent configuration to dict.

Source code in src/alberta_framework/core/sarsa.py
def to_config(self) -> dict[str, Any]:
    """Serialize agent configuration to dict."""
    horde_config = self._horde.to_config()
    # Remove fields managed by SARSAAgent
    horde_config.pop("type", None)
    horde_config.pop("horde_spec", None)

    # Extract prediction demon specs if any
    pred_demons = None
    if self._n_prediction_demons > 0:
        all_demons = self._horde.horde_spec.demons
        pred_demons = [
            d.to_config()
            for d in all_demons[self._sarsa_config.n_actions :]
        ]

    return {
        "type": "SARSAAgent",
        "sarsa_config": self._sarsa_config.to_config(),
        "lamda": self._lamda,
        "prediction_demons": pred_demons,
        **horde_config,
    }

from_config(config) classmethod

Reconstruct from config dict.

Source code in src/alberta_framework/core/sarsa.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "SARSAAgent":
    """Reconstruct from config dict."""
    from alberta_framework.core.normalizers import normalizer_from_config
    from alberta_framework.core.optimizers import (
        bounder_from_config,
        optimizer_from_config,
    )

    config = dict(config)
    config.pop("type", None)

    sarsa_config = SARSAConfig.from_config(config.pop("sarsa_config"))
    optimizer = optimizer_from_config(config.pop("optimizer"))
    bounder_cfg = config.pop("bounder", None)
    bounder = bounder_from_config(bounder_cfg) if bounder_cfg is not None else None
    normalizer_cfg = config.pop("normalizer", None)
    normalizer = (
        normalizer_from_config(normalizer_cfg) if normalizer_cfg is not None else None
    )
    head_opt_cfg = config.pop("head_optimizer", None)
    head_optimizer = (
        optimizer_from_config(head_opt_cfg) if head_opt_cfg is not None else None
    )
    pred_demons_cfg = config.pop("prediction_demons", None)
    prediction_demons = None
    if pred_demons_cfg is not None:
        prediction_demons = [GVFSpec.from_config(d) for d in pred_demons_cfg]

    return cls(
        sarsa_config=sarsa_config,
        hidden_sizes=tuple(config.pop("hidden_sizes")),
        optimizer=optimizer,
        bounder=bounder,
        normalizer=normalizer,
        head_optimizer=head_optimizer,
        prediction_demons=prediction_demons,
        **config,
    )

init(feature_dim, key)

Initialize SARSA agent state.

Args: feature_dim: Dimension of the input feature vector key: JAX random key

Returns: Initial SARSAState with zeroed last_action/observation

Source code in src/alberta_framework/core/sarsa.py
def init(self, feature_dim: int, key: Array) -> SARSAState:
    """Initialize SARSA agent state.

    Args:
        feature_dim: Dimension of the input feature vector
        key: JAX random key

    Returns:
        Initial SARSAState with zeroed last_action/observation
    """
    key, subkey = jr.split(key)
    learner_state = self._horde.init(feature_dim, subkey)

    return SARSAState(  # type: ignore[call-arg]
        learner_state=learner_state,
        last_action=jnp.array(-1, dtype=jnp.int32),
        last_observation=jnp.zeros(feature_dim, dtype=jnp.float32),
        epsilon=jnp.array(self._sarsa_config.epsilon_start, dtype=jnp.float32),
        rng_key=key,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

select_action(state, observation)

Select action via epsilon-greedy over Q-values.

JIT-compiled. Uses Gumbel trick for uniform tie-breaking among equal Q-values (avoids left-side bias from jnp.argmax).

Args: state: Current SARSA state (uses rng_key and epsilon) observation: Input feature vector

Returns: Tuple of (action, new_rng_key)

Source code in src/alberta_framework/core/sarsa.py
@functools.partial(jax.jit, static_argnums=(0,))
def select_action(
    self,
    state: SARSAState,
    observation: Array,
) -> tuple[Int[Array, ""], Array]:
    """Select action via epsilon-greedy over Q-values.

    JIT-compiled. Uses Gumbel trick for uniform tie-breaking among
    equal Q-values (avoids left-side bias from ``jnp.argmax``).

    Args:
        state: Current SARSA state (uses rng_key and epsilon)
        observation: Input feature vector

    Returns:
        Tuple of (action, new_rng_key)
    """
    key, explore_key, noise_key, random_key = jr.split(state.rng_key, 4)

    # Get Q-values (first n_actions heads are control demons)
    all_preds = self._horde.predict(state.learner_state, observation)
    q_values = all_preds[: self._sarsa_config.n_actions]

    # Greedy action with Gumbel tie-breaking
    # Add small noise only to max-valued actions for uniform tie-breaking
    gumbel_noise = jr.gumbel(noise_key, shape=q_values.shape) * 1e-6
    greedy_action = jnp.argmax(q_values + gumbel_noise).astype(jnp.int32)

    # Random action
    random_action = jr.randint(
        random_key, (), 0, self._sarsa_config.n_actions
    ).astype(jnp.int32)

    # Epsilon-greedy selection
    explore = jr.uniform(explore_key) < state.epsilon
    action = jax.lax.select(explore, random_action, greedy_action)

    return action, key
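The tie-breaking trick can be demonstrated in plain numpy (an illustration of the idea, not the framework's JAX code): `argmax` alone always returns the leftmost tied index, while tiny Gumbel noise splits exact ties roughly uniformly without disturbing clearly separated Q-values.

```python
import numpy as np

rng = np.random.default_rng(0)
q = np.array([1.0, 1.0, 0.5])  # actions 0 and 1 are exactly tied

# Plain argmax exhibits left-side bias: action 1 is never selected
plain = [int(np.argmax(q)) for _ in range(1000)]
assert set(plain) == {0}

# Tiny Gumbel noise breaks the tie approximately uniformly
noisy = [int(np.argmax(q + rng.gumbel(size=q.shape) * 1e-6)) for _ in range(1000)]
counts = np.bincount(noisy, minlength=3)
assert counts[2] == 0          # noise is far too small to flip a real gap
assert 350 < counts[0] < 650   # tied actions split roughly 50/50
assert 350 < counts[1] < 650
```

With scale 1e-6, the noise only matters when Q-values are bit-for-bit equal, which is exactly the case `jnp.argmax` would otherwise resolve deterministically.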

update(state, reward, observation, terminated, next_action, prediction_cumulants=None)

Perform one SARSA update step.

Computes the SARSA target r + gamma * Q(s', a') and updates the Horde. Only the previously-taken action's head receives the target; all other Q-heads get NaN (no update).

Args:
    state: Current SARSA state
    reward: Reward r received after taking last_action in last_obs
    observation: New observation s' (state we transitioned to)
    terminated: Whether s' is terminal (scalar bool/float)
    next_action: Action a' selected for s' (pre-computed)
    prediction_cumulants: Optional cumulants for prediction demons, shape (n_prediction_demons,). NaN for inactive demons.

Returns: SARSAUpdateResult with updated state, Q-values, TD error

Source code in src/alberta_framework/core/sarsa.py
@functools.partial(jax.jit, static_argnums=(0,))
def update(
    self,
    state: SARSAState,
    reward: Array,
    observation: Array,
    terminated: Array,
    next_action: Array,
    prediction_cumulants: Array | None = None,
) -> SARSAUpdateResult:
    """Perform one SARSA update step.

    Computes the SARSA target ``r + gamma * Q(s', a')`` and updates
    the Horde. Only the previously-taken action's head receives the
    target; all other Q-heads get NaN (no update).

    Args:
        state: Current SARSA state
        reward: Reward r received after taking last_action in last_obs
        observation: New observation s' (state we transitioned to)
        terminated: Whether s' is terminal (scalar bool/float)
        next_action: Action a' selected for s' (pre-computed)
        prediction_cumulants: Optional cumulants for prediction demons,
            shape ``(n_prediction_demons,)``. NaN for inactive demons.

    Returns:
        SARSAUpdateResult with updated state, Q-values, TD error
    """
    n_actions = self._sarsa_config.n_actions
    gamma = self._sarsa_config.gamma

    # Q(s', :) for all actions
    all_preds = self._horde.predict(state.learner_state, observation)
    q_next = all_preds[:n_actions]

    # SARSA target: r + gamma * Q(s', a') with terminal handling
    effective_gamma = jnp.where(terminated, 0.0, gamma)
    sarsa_target = reward + effective_gamma * q_next[next_action]

    # Build cumulants: NaN for all except last_action gets sarsa_target
    cumulants = jnp.full(self._horde.n_demons, jnp.nan, dtype=jnp.float32)
    # Only update the head corresponding to the action we took at s_t
    cumulants = cumulants.at[state.last_action].set(sarsa_target)

    # Add prediction demon cumulants if any
    if prediction_cumulants is not None:
        cumulants = cumulants.at[n_actions:].set(prediction_cumulants)

    # Horde update: learns from (s_t, cumulants, s')
    horde_result = self._horde.update(
        state.learner_state,
        state.last_observation,
        cumulants,
        observation,
    )

    # TD error for the taken action
    q_old = all_preds[state.last_action]
    td_error = sarsa_target - q_old

    # Epsilon decay
    cfg = self._sarsa_config
    new_step_count = state.step_count + 1
    new_epsilon = jax.lax.cond(
        cfg.epsilon_decay_steps > 0,
        lambda: jnp.maximum(
            cfg.epsilon_end,
            cfg.epsilon_start
            - (cfg.epsilon_start - cfg.epsilon_end)
            * new_step_count
            / cfg.epsilon_decay_steps,
        ),
        lambda: state.epsilon,
    )

    new_state = SARSAState(  # type: ignore[call-arg]
        learner_state=horde_result.state,
        last_action=next_action,
        last_observation=observation,
        epsilon=new_epsilon,
        rng_key=state.rng_key,
        step_count=new_step_count,
    )

    return SARSAUpdateResult(  # type: ignore[call-arg]
        state=new_state,
        action=next_action,
        q_values=q_next,
        td_error=td_error,
        reward=reward,
    )

SARSAArrayResult

Result from scan-based SARSA on pre-collected arrays.

Attributes:
    state: Final SARSA state
    q_values: Per-step Q-values, shape (num_steps, n_actions)
    td_errors: Per-step TD errors, shape (num_steps,)
    actions: Per-step actions taken, shape (num_steps,)

SARSAConfig

Configuration for SARSA agent.

Attributes:
    n_actions: Number of discrete actions
    gamma: Discount factor for SARSA targets (default: 0.99)
    epsilon_start: Initial exploration rate (default: 0.1)
    epsilon_end: Final exploration rate (default: 0.01)
    epsilon_decay_steps: Steps over which epsilon decays linearly. 0 = no decay (constant epsilon_start).

to_config()

Serialize to dict.

Source code in src/alberta_framework/core/sarsa.py
def to_config(self) -> dict[str, Any]:
    """Serialize to dict."""
    return {
        "n_actions": self.n_actions,
        "gamma": self.gamma,
        "epsilon_start": self.epsilon_start,
        "epsilon_end": self.epsilon_end,
        "epsilon_decay_steps": self.epsilon_decay_steps,
    }

from_config(config) classmethod

Reconstruct from config dict.

Source code in src/alberta_framework/core/sarsa.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "SARSAConfig":
    """Reconstruct from config dict."""
    return cls(**config)
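The linear epsilon schedule these fields define (applied inside `update()`) can be restated in plain Python, with the config defaults used as illustrative values:

```python
def epsilon_at(step: int, start: float = 0.1, end: float = 0.01,
               decay_steps: int = 1000) -> float:
    """Linear decay from start to end over decay_steps, then hold."""
    if decay_steps <= 0:
        return start  # 0 = no decay: constant epsilon_start
    return max(end, start - (start - end) * step / decay_steps)

assert epsilon_at(0) == 0.1
assert abs(epsilon_at(500) - 0.055) < 1e-9   # halfway through the decay
assert epsilon_at(2000) == 0.01              # clamped at epsilon_end
assert epsilon_at(500, decay_steps=0) == 0.1 # decay disabled
```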

SARSAContinuingResult(state, total_reward, rewards, q_values, td_errors) dataclass

Result from running SARSA in continuing mode.

Not a chex dataclass — used in Python loops with native Python types.

Attributes:
    state: Final SARSA state
    total_reward: Sum of rewards over all steps
    rewards: Per-step rewards
    q_values: Per-step Q-values
    td_errors: Per-step TD errors

SARSAEpisodeResult(state, total_reward, num_steps, rewards, q_values, td_errors) dataclass

Result from running one episode of SARSA.

Not a chex dataclass — used in Python loops with native Python types.

Attributes:
    state: Final SARSA state
    total_reward: Sum of rewards in the episode
    num_steps: Number of steps taken
    rewards: Per-step rewards
    q_values: Per-step Q-values
    td_errors: Per-step TD errors

SARSAState

State for the SARSA agent.

Attributes:
    learner_state: Underlying Horde/MultiHeadMLPLearner state
    last_action: Action taken at previous step (a_t)
    last_observation: Observation at previous step (s_t)
    epsilon: Current exploration rate
    rng_key: JAX random key for action selection
    step_count: Number of SARSA update steps taken

SARSAUpdateResult

Result of a single SARSA update step.

Attributes:
    state: Updated SARSA state (includes new action a_{t+1})
    action: Next action a_{t+1} selected for the new state
    q_values: Q-values for all actions at s_{t+1}
    td_error: TD error for the taken action
    reward: Reward received

AutostepParamState

Per-parameter Autostep state for use with arbitrary-shape parameters.

Used by Autostep.init_for_shape / Autostep.update_from_gradient for MLP (or other multi-parameter) learners. Unlike AutostepState, this type has no bias-specific fields -- each parameter (weight matrix, bias vector) gets its own AutostepParamState.

Attributes:
    step_sizes: Per-element step-sizes, same shape as the parameter
    traces: Per-element traces for gradient correlation
    normalizers: Running normalizer of meta-gradient magnitude |deltazh|
    meta_step_size: Meta learning rate mu
    tau: Time constant for normalizer adaptation

AutostepState

State for the Autostep optimizer.

Autostep is a tuning-free step-size adaptation algorithm that adapts per-weight step-sizes based on meta-gradient correlation, with self-regulated normalizers to stabilize the meta-update.

Reference: Mahmood et al. 2012, "Tuning-free step-size adaptation", Table 1

Attributes:
    step_sizes: Per-weight step-sizes (alpha_i)
    traces: Per-weight traces for gradient correlation (h_i)
    normalizers: Running normalizer of meta-gradient magnitude |deltaxh| (v_i)
    meta_step_size: Meta learning rate mu for adapting step-sizes
    tau: Time constant for normalizer adaptation (higher = slower decay)
    bias_step_size: Step-size for the bias term
    bias_trace: Trace for the bias term
    bias_normalizer: Normalizer for the bias meta-gradient

AutoTDIDBDState

State for the AutoTDIDBD optimizer.

AutoTDIDBD adds AutoStep-style normalization to TDIDBD for improved stability. Includes normalizers for the meta-weight updates and effective step-size normalization to prevent overshooting.

Reference: Kearney et al. 2019, Algorithm 6

Attributes:
    log_step_sizes: Log of per-weight step-sizes (log alpha_i)
    eligibility_traces: Eligibility traces z_i
    h_traces: Per-weight h traces for gradient correlation
    normalizers: Running max of absolute gradient correlations (eta_i)
    meta_step_size: Meta learning rate theta
    trace_decay: Eligibility trace decay parameter lambda
    normalizer_decay: Decay parameter tau for normalizers
    bias_log_step_size: Log step-size for the bias term
    bias_eligibility_trace: Eligibility trace for the bias
    bias_h_trace: h trace for the bias term
    bias_normalizer: Normalizer for the bias gradient correlation

BatchedLearningResult

Result from batched learning loop across multiple seeds.

Used with run_learning_loop_batched for vmap-based GPU parallelization.

Attributes:
    states: Batched learner states; each array has shape (num_seeds, ...)
    metrics: Metrics array with shape (num_seeds, num_steps, num_cols), where num_cols is 3 (no normalizer) or 4 (with normalizer)
    step_size_history: Optional step-size history with batched shapes, or None if tracking was disabled
    normalizer_history: Optional normalizer history with batched shapes, or None if tracking was disabled

BatchedMLPResult

Result from batched MLP learning loop across multiple seeds.

Used with run_mlp_learning_loop_batched for vmap-based GPU parallelization.

Attributes:
    states: Batched MLP learner states; each array has shape (num_seeds, ...)
    metrics: Metrics array with shape (num_seeds, num_steps, num_cols), where num_cols is 3 (no normalizer) or 4 (with normalizer)
    normalizer_history: Optional normalizer history with batched shapes, or None if tracking was disabled

DemonType

Bases: Enum

Type of GVF demon.

A prediction demon has a fixed policy and learns to predict. A control demon learns a policy (e.g. via SARSA) — Step 4.

GVFSpec

One GVF demon's question functions (Sutton et al. 2011).

Declarative, not callable — JAX pytree-compatible. Cumulant values are computed externally and passed as arrays.

Attributes:
    name: Human-readable name for this demon
    demon_type: Whether this is a prediction or control demon
    gamma: Pseudo-termination discount (0.0 = single-step prediction)
    lamda: Trace decay parameter (0.0 = no eligibility traces)
    cumulant_index: Index into targets array, or -1 for external cumulant
    terminal_reward: Terminal pseudo-reward z (default 0.0)

to_config()

Serialize to dict.

Returns: Dict with all fields needed to recreate the GVFSpec.

Source code in src/alberta_framework/core/types.py
def to_config(self) -> dict[str, Any]:
    """Serialize to dict.

    Returns:
        Dict with all fields needed to recreate the GVFSpec.
    """
    return {
        "name": self.name,
        "demon_type": self.demon_type.value,
        "gamma": self.gamma,
        "lamda": self.lamda,
        "cumulant_index": self.cumulant_index,
        "terminal_reward": self.terminal_reward,
    }

from_config(config) classmethod

Reconstruct from config dict.

Args: config: Dict as produced by to_config()

Returns: Reconstructed GVFSpec

Source code in src/alberta_framework/core/types.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "GVFSpec":
    """Reconstruct from config dict.

    Args:
        config: Dict as produced by ``to_config()``

    Returns:
        Reconstructed GVFSpec
    """
    config = dict(config)
    config["demon_type"] = DemonType(config["demon_type"])
    return cls(**config)

HordeSpec

Collection of GVF demons, one per head.

Attributes:
    demons: Tuple of GVFSpec, one per demon/head
    gammas: Pre-computed gamma array for JIT, shape (n_demons,)
    lamdas: Pre-computed lambda array for JIT, shape (n_demons,)

to_config()

Serialize to dict.

Returns: Dict with demons list, each serialized via GVFSpec.to_config().

Source code in src/alberta_framework/core/types.py
def to_config(self) -> dict[str, Any]:
    """Serialize to dict.

    Returns:
        Dict with demons list, each serialized via ``GVFSpec.to_config()``.
    """
    return {
        "demons": [d.to_config() for d in self.demons],
    }

from_config(config) classmethod

Reconstruct from config dict.

Args: config: Dict as produced by to_config()

Returns: Reconstructed HordeSpec via create_horde_spec

Source code in src/alberta_framework/core/types.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "HordeSpec":
    """Reconstruct from config dict.

    Args:
        config: Dict as produced by ``to_config()``

    Returns:
        Reconstructed HordeSpec via ``create_horde_spec``
    """
    demons = [GVFSpec.from_config(d) for d in config["demons"]]
    return create_horde_spec(demons)

IDBDParamState

Per-parameter IDBD state for use with arbitrary-shape parameters.

Used by IDBD.init_for_shape / IDBD.update_from_gradient for MLP (or other multi-parameter) learners. Unlike IDBDState, this type has no bias-specific fields -- each parameter (weight matrix, bias vector) gets its own IDBDParamState.

Implements Meyer's adaptation of IDBD for nonlinear models: replaces x^2 in the h-decay term with (dy/dw)^2 (squared prediction gradients), which generalizes IDBD to arbitrary architectures.

Reference: Meyer, https://github.com/ejmejm/phd_research

Attributes:
    log_step_sizes: Log of per-element step-sizes, same shape as the parameter
    traces: Per-element h traces for gradient correlation
    meta_step_size: Meta learning rate beta

IDBDState

State for the IDBD (Incremental Delta-Bar-Delta) optimizer.

IDBD maintains per-weight adaptive step-sizes that are meta-learned based on the correlation of successive gradients.

Reference: Sutton 1992, "Adapting Bias by Gradient Descent"

Attributes:
    log_step_sizes: Log of per-weight step-sizes (log alpha_i)
    traces: Per-weight traces h_i for gradient correlation
    meta_step_size: Meta learning rate beta for adapting step-sizes
    bias_step_size: Step-size for the bias term
    bias_trace: Trace for the bias term
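A minimal numpy sketch of the IDBD update these fields support, following Sutton (1992): the log step-sizes move in the direction of the correlation between the current gradient and the trace h, then the weights update with the per-weight step-sizes. This is an illustration only, not the framework's JAX implementation; the constants below are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5
w_true = rng.normal(size=d)        # fixed linear target (noiseless)
w = np.zeros(d)
beta = np.full(d, np.log(0.05))    # log_step_sizes (log alpha_i)
h = np.zeros(d)                    # traces h_i
theta = 0.01                       # meta_step_size

errors = []
for _ in range(3000):
    x = rng.normal(size=d)
    delta = w_true @ x - w @ x     # prediction error
    beta += theta * delta * x * h  # meta-update: gradient correlation
    alpha = np.exp(beta)
    w += alpha * delta * x         # weight update with per-weight alphas
    # h decays where the step-size "used up" the input, then accumulates
    h = h * np.maximum(0.0, 1.0 - alpha * x * x) + alpha * delta * x
    errors.append(delta * delta)

# Squared error collapses on this stationary problem
assert np.mean(errors[-100:]) < 1e-3 < np.mean(errors[:100])
```

On non-stationary streams the same mechanism raises step-sizes for features whose gradients stay correlated, which is the basis of the feature-relevance metrics elsewhere in the framework.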

LearnerState

State for a linear learner.

Attributes:
    weights: Weight vector for linear prediction
    bias: Bias term
    optimizer_state: State maintained by the optimizer
    normalizer_state: Optional state for online feature normalization

LMSState

State for the LMS (Least Mean Square) optimizer.

LMS uses a fixed step-size, so state only tracks the step-size parameter.

Attributes: step_size: Fixed learning rate alpha

MLPLearnerState

State for an MLP learner.

Attributes:
    params: MLP parameters (weights and biases for each layer)
    optimizer_states: Tuple of per-parameter optimizer states (weights + biases)
    traces: Tuple of per-parameter eligibility traces
    normalizer_state: Optional state for online feature normalization

MLPParams

Parameters for a multi-layer perceptron.

Uses tuples of arrays (not lists) for proper JAX PyTree handling.

Attributes: weights: Tuple of weight matrices, one per layer biases: Tuple of bias vectors, one per layer

NormalizerHistory

History of per-feature normalizer state recorded during training.

Used for analyzing how the normalizer (EMA or Welford) adapts to distribution shifts (reactive lag diagnostic).

Attributes:
    means: Per-feature mean estimates at each recording, shape (num_recordings, feature_dim)
    variances: Per-feature variance estimates at each recording, shape (num_recordings, feature_dim)
    recording_indices: Step indices where recordings were made, shape (num_recordings,)

NormalizerTrackingConfig

Configuration for recording per-feature normalizer state during training.

Attributes: interval: Record normalizer state every N steps

ObGDState

State for the ObGD (Observation-bounded Gradient Descent) optimizer.

ObGD prevents overshooting by dynamically bounding the effective step-size based on the magnitude of the TD error and eligibility traces. When the combined update magnitude would be too large, the step-size is scaled down.

For supervised learning (gamma=0, lamda=0), traces equal the current observation each step, making ObGD equivalent to LMS with dynamic step-size bounding.

Reference: Elsayed et al. 2024, "Streaming Deep Reinforcement Learning Finally Works"

Attributes:
    step_size: Base learning rate alpha
    kappa: Bounding sensitivity parameter (higher = more conservative)
    traces: Per-weight eligibility traces z_i
    bias_trace: Eligibility trace for the bias term
    gamma: Discount factor for trace decay
    lamda: Eligibility trace decay parameter lambda

StepSizeHistory

History of per-weight step-sizes recorded during training.

Attributes:
    step_sizes: Per-weight step-sizes at each recording, shape (num_recordings, num_weights)
    bias_step_sizes: Bias step-sizes at each recording, shape (num_recordings,), or None
    recording_indices: Step indices where recordings were made, shape (num_recordings,)
    normalizers: Autostep's per-weight normalizers (v_i) at each recording, shape (num_recordings, num_weights), or None. Only populated for the Autostep optimizer.

StepSizeTrackingConfig

Configuration for recording per-weight step-sizes during training.

Attributes: interval: Record step-sizes every N steps include_bias: Whether to also record the bias step-size

TDIDBDState

State for the TD-IDBD (Temporal-Difference IDBD) optimizer.

TD-IDBD extends IDBD to temporal-difference learning with eligibility traces. Maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation in the TD setting.

Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"

Attributes:
    log_step_sizes: Log of per-weight step-sizes (log alpha_i)
    eligibility_traces: Eligibility traces z_i for temporal credit assignment
    h_traces: Per-weight h traces for gradient correlation
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay parameter lambda
    bias_log_step_size: Log step-size for the bias term
    bias_eligibility_trace: Eligibility trace for the bias
    bias_h_trace: h trace for the bias term

TDLearnerState

State for a TD linear learner.

Attributes:
    weights: Weight vector for linear value function approximation
    bias: Bias term
    optimizer_state: State maintained by the TD optimizer

TDTimeStep

Single experience from a TD stream.

Represents a transition (s, r, s', gamma) for temporal-difference learning.

Attributes:
    observation: Feature vector phi(s)
    reward: Reward R received
    next_observation: Feature vector phi(s')
    gamma: Discount factor gamma_t (0 at terminal states)

TimeStep

Single experience from an experience stream.

Attributes: observation: Feature vector x_t target: Desired output y*_t (for supervised learning)

ScanStream

Bases: Protocol[StateT]

Protocol for JAX scan-compatible experience streams.

Streams generate temporally-uniform experience for continual learning. Unlike iterator-based streams, ScanStream uses pure functions that can be compiled with JAX's JIT and used with jax.lax.scan.

The stream should be non-stationary to test continual learning capabilities: the underlying target function changes over time.

Type Parameters: StateT: The state type maintained by this stream

Examples:

stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
key = jax.random.key(42)
state = stream.init(key)
timestep, new_state = stream.step(state, jnp.array(0))

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key for initialization

Returns: Initial stream state

Source code in src/alberta_framework/streams/base.py
def init(self, key: Array) -> StateT:
    """Initialize stream state.

    Args:
        key: JAX random key for initialization

    Returns:
        Initial stream state
    """
    ...

step(state, idx)

Generate one time step. Must be JIT-compatible.

This is a pure function that takes the current state and step index, and returns a TimeStep along with the updated state. The step index can be used for time-dependent behavior but is often ignored.

Args: state: Current stream state idx: Current step index (can be ignored for most streams)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/base.py
def step(self, state: StateT, idx: Array) -> tuple[TimeStep, StateT]:
    """Generate one time step. Must be JIT-compatible.

    This is a pure function that takes the current state and step index,
    and returns a TimeStep along with the updated state. The step index
    can be used for time-dependent behavior but is often ignored.

    Args:
        state: Current stream state
        idx: Current step index (can be ignored for most streams)

    Returns:
        Tuple of (timestep, new_state)
    """
    ...
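The protocol's shape can be exercised without JAX at all: `init(key) -> state` and `step(state, idx) -> (timestep, new_state)` are pure functions threading state through time. A hypothetical `CounterStream` (dicts stand in for pytrees; a plain loop stands in for `jax.lax.scan`):

```python
class CounterStream:
    """Toy stream satisfying the ScanStream contract (illustration only)."""

    @property
    def feature_dim(self) -> int:
        return 1

    def init(self, key):
        return {"count": 0}  # stream state; a real stream holds JAX arrays

    def step(self, state, idx):
        # Pure function: no mutation, returns (timestep, new_state)
        timestep = {"observation": [float(state["count"])], "target": [0.0]}
        return timestep, {"count": state["count"] + 1}

stream = CounterStream()
state = stream.init(key=None)
outputs = []
for t in range(3):  # jax.lax.scan plays this role under JIT
    ts, state = stream.step(state, t)
    outputs.append(ts["observation"][0])

assert outputs == [0.0, 1.0, 2.0]
assert state["count"] == 3
```

Because `step` never mutates its inputs, `jax.lax.scan` can carry the state and stack the emitted timesteps without any Python-loop overhead.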

AbruptChangeState

State for AbruptChangeStream.

Attributes:
    key: JAX random key for generating randomness
    true_weights: Current true target weights
    step_count: Number of steps taken

AbruptChangeStream(feature_dim, change_interval=1000, noise_std=0.1, feature_std=1.0)

Non-stationary stream with sudden target weight changes.

Target weights remain constant for a period, then abruptly change to new random values. Tests the learner's ability to detect and rapidly adapt to distribution shifts.

Attributes:
    feature_dim: Dimension of observation vectors
    change_interval: Number of steps between weight changes
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features

Args:
    feature_dim: Dimension of feature vectors
    change_interval: Steps between abrupt weight changes
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    change_interval: int = 1000,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the abrupt change stream.

    Args:
        feature_dim: Dimension of feature vectors
        change_interval: Steps between abrupt weight changes
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._change_interval = change_interval
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> AbruptChangeState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state
    """
    key, subkey = jr.split(key)
    weights = jr.normal(subkey, (self._feature_dim,), dtype=jnp.float32)
    return AbruptChangeState(
        key=key,
        true_weights=weights,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: AbruptChangeState, idx: Array) -> tuple[TimeStep, AbruptChangeState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_weights, key_x, key_noise = jr.split(state.key, 4)

    # Determine if we should change weights
    should_change = state.step_count % self._change_interval == 0

    # Generate new weights (always generated but only used if should_change)
    new_random_weights = jr.normal(key_weights, (self._feature_dim,), dtype=jnp.float32)

    # Use jnp.where to conditionally update weights (JIT-compatible)
    new_weights = jnp.where(should_change, new_random_weights, state.true_weights)

    # Generate observation
    x = self._feature_std * jr.normal(key_x, (self._feature_dim,), dtype=jnp.float32)

    # Compute target
    noise = self._noise_std * jr.normal(key_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = AbruptChangeState(
        key=key,
        true_weights=new_weights,
        step_count=state.step_count + 1,
    )

    return timestep, new_state
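The branch-free weight-change logic in `step()` can be restated in plain numpy (illustration only; interval and dimension below are arbitrary): a candidate weight vector is always drawn, and `where` selects it only when the step count hits a multiple of `change_interval`.

```python
import numpy as np

rng = np.random.default_rng(0)
feature_dim, change_interval = 4, 3
weights = rng.normal(size=feature_dim)

history = []
for step_count in range(7):
    should_change = step_count % change_interval == 0
    candidate = rng.normal(size=feature_dim)  # always generated (JIT-friendly)
    weights = np.where(should_change, candidate, weights)
    history.append(weights.copy())

# Weights change at steps 0, 3, 6 and are held constant in between
assert np.array_equal(history[0], history[1])
assert np.array_equal(history[1], history[2])
assert not np.array_equal(history[2], history[3])
assert np.array_equal(history[3], history[4])
```

Always computing the candidate and selecting with `where` avoids data-dependent control flow, which is what keeps the real `step()` compatible with `jax.jit` and `jax.lax.scan`.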

CyclicState

State for CyclicStream.

Attributes:
    key: JAX random key for generating randomness
    configurations: Pre-generated weight configurations
    step_count: Number of steps taken

CyclicStream(feature_dim, cycle_length=500, num_configurations=4, noise_std=0.1, feature_std=1.0)

Non-stationary stream that cycles between known weight configurations.

Weights cycle through a fixed set of configurations. Tests whether the learner can re-adapt quickly to previously seen targets.

Attributes:
    feature_dim: Dimension of observation vectors
    cycle_length: Number of steps per configuration before switching
    num_configurations: Number of weight configurations to cycle through
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features

Args:
    feature_dim: Dimension of feature vectors
    cycle_length: Steps spent in each configuration
    num_configurations: Number of configurations to cycle through
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    cycle_length: int = 500,
    num_configurations: int = 4,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the cyclic target stream.

    Args:
        feature_dim: Dimension of feature vectors
        cycle_length: Steps spent in each configuration
        num_configurations: Number of configurations to cycle through
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._cycle_length = cycle_length
    self._num_configurations = num_configurations
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state with pre-generated configurations

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> CyclicState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with pre-generated configurations
    """
    key, key_configs = jr.split(key)
    configurations = jr.normal(
        key_configs,
        (self._num_configurations, self._feature_dim),
        dtype=jnp.float32,
    )
    return CyclicState(
        key=key,
        configurations=configurations,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: CyclicState, idx: Array) -> tuple[TimeStep, CyclicState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_x, key_noise = jr.split(state.key, 3)

    # Get current configuration index
    config_idx = (state.step_count // self._cycle_length) % self._num_configurations
    true_weights = state.configurations[config_idx]

    # Generate observation
    x = self._feature_std * jr.normal(key_x, (self._feature_dim,), dtype=jnp.float32)

    # Compute target
    noise = self._noise_std * jr.normal(key_noise, (), dtype=jnp.float32)
    target = jnp.dot(true_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = CyclicState(
        key=key,
        configurations=state.configurations,
        step_count=state.step_count + 1,
    )

    return timestep, new_state
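The configuration index used in `step()` is simple modular arithmetic: integer-divide the step count by `cycle_length`, then wrap around `num_configurations`. With illustrative values:

```python
cycle_length, num_configurations = 2, 3

# Which configuration is active at each step
indices = [(t // cycle_length) % num_configurations for t in range(8)]
assert indices == [0, 0, 1, 1, 2, 2, 0, 0]  # wraps back to config 0 at step 6
```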

DynamicScaleShiftState

State for DynamicScaleShiftStream.

Attributes:
    key: JAX random key for generating randomness
    true_weights: Current true target weights
    current_scales: Current per-feature scaling factors
    step_count: Number of steps taken

DynamicScaleShiftStream(feature_dim, scale_change_interval=2000, weight_change_interval=1000, min_scale=0.01, max_scale=100.0, noise_std=0.1)

Non-stationary stream with abruptly changing feature scales.

Both target weights AND feature scales change at specified intervals. This tests whether OnlineNormalizer can track scale shifts faster than Autostep's internal v_i adaptation.

The target is computed from unscaled features to maintain consistent difficulty across scale changes (only the feature representation changes, not the underlying prediction task).

Attributes:
    feature_dim: Dimension of observation vectors
    scale_change_interval: Steps between scale changes
    weight_change_interval: Steps between weight changes
    min_scale: Minimum scale factor
    max_scale: Maximum scale factor
    noise_std: Standard deviation of observation noise

Args:
    feature_dim: Dimension of feature vectors
    scale_change_interval: Steps between abrupt scale changes
    weight_change_interval: Steps between abrupt weight changes
    min_scale: Minimum scale factor (log-uniform sampling)
    max_scale: Maximum scale factor (log-uniform sampling)
    noise_std: Std dev of target noise

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    scale_change_interval: int = 2000,
    weight_change_interval: int = 1000,
    min_scale: float = 0.01,
    max_scale: float = 100.0,
    noise_std: float = 0.1,
):
    """Initialize the dynamic scale shift stream.

    Args:
        feature_dim: Dimension of feature vectors
        scale_change_interval: Steps between abrupt scale changes
        weight_change_interval: Steps between abrupt weight changes
        min_scale: Minimum scale factor (log-uniform sampling)
        max_scale: Maximum scale factor (log-uniform sampling)
        noise_std: Std dev of target noise
    """
    self._feature_dim = feature_dim
    self._scale_change_interval = scale_change_interval
    self._weight_change_interval = weight_change_interval
    self._min_scale = min_scale
    self._max_scale = max_scale
    self._noise_std = noise_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args:
    key: JAX random key

Returns:
    Initial stream state with random weights and scales

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> DynamicScaleShiftState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random weights and scales
    """
    key, k_weights, k_scales = jr.split(key, 3)
    weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
    # Initial scales: log-uniform between min and max
    log_scales = jr.uniform(
        k_scales,
        (self._feature_dim,),
        minval=jnp.log(self._min_scale),
        maxval=jnp.log(self._max_scale),
    )
    scales = jnp.exp(log_scales).astype(jnp.float32)
    return DynamicScaleShiftState(
        key=key,
        true_weights=weights,
        current_scales=scales,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args:
    state: Current stream state
    idx: Current step index (unused)

Returns:
    Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(
    self, state: DynamicScaleShiftState, idx: Array
) -> tuple[TimeStep, DynamicScaleShiftState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, k_weights, k_scales, k_x, k_noise = jr.split(state.key, 5)

    # Check if scales should change
    should_change_scales = state.step_count % self._scale_change_interval == 0
    new_log_scales = jr.uniform(
        k_scales,
        (self._feature_dim,),
        minval=jnp.log(self._min_scale),
        maxval=jnp.log(self._max_scale),
    )
    new_random_scales = jnp.exp(new_log_scales).astype(jnp.float32)
    new_scales = jnp.where(should_change_scales, new_random_scales, state.current_scales)

    # Check if weights should change
    should_change_weights = state.step_count % self._weight_change_interval == 0
    new_random_weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
    new_weights = jnp.where(should_change_weights, new_random_weights, state.true_weights)

    # Generate raw features (unscaled)
    raw_x = jr.normal(k_x, (self._feature_dim,), dtype=jnp.float32)

    # Apply scaling to observation
    x = raw_x * new_scales

    # Target from true weights using RAW features (for consistent difficulty)
    noise = self._noise_std * jr.normal(k_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, raw_x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = DynamicScaleShiftState(
        key=key,
        true_weights=new_weights,
        current_scales=new_scales,
        step_count=state.step_count + 1,
    )
    return timestep, new_state
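Because the target is computed from the raw features while the learner only observes the scaled ones, rescaling changes the representation but not the underlying prediction task. A stdlib-only sketch of that invariance (all names here are illustrative, not framework code):

```python
import math
import random

rng = random.Random(0)
feature_dim = 4
weights = [rng.gauss(0.0, 1.0) for _ in range(feature_dim)]
raw_x = [rng.gauss(0.0, 1.0) for _ in range(feature_dim)]

# Log-uniform scale factors in [0.01, 100], as sampled in init()/step():
scales = [math.exp(rng.uniform(math.log(0.01), math.log(100.0)))
          for _ in range(feature_dim)]

x = [xi * si for xi, si in zip(raw_x, scales)]          # what the learner observes
target = sum(w * xi for w, xi in zip(weights, raw_x))   # target from RAW features

# Dividing the observation by the scales recovers the raw features,
# and hence the target: the task difficulty is scale-independent.
recovered = sum(w * (oi / si) for w, oi, si in zip(weights, x, scales))
```

This is why the stream isolates scale tracking: a learner (or OnlineNormalizer) only has to undo the per-feature scaling, not relearn the target function.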

PeriodicChangeState

State for PeriodicChangeStream.

Attributes:
    key: JAX random key for generating randomness
    base_weights: Base target weights (center of oscillation)
    phases: Per-weight phase offsets
    step_count: Number of steps taken

PeriodicChangeStream(feature_dim, period=1000, amplitude=1.0, noise_std=0.1, feature_std=1.0)

Non-stationary stream where target weights oscillate sinusoidally.

Target weights follow: w(t) = base + amplitude * sin(2π * t / period + phase) where each weight has a random phase offset for diversity.

This tests the learner's ability to track predictable periodic changes, which is qualitatively different from random drift or abrupt changes.

Attributes:
    feature_dim: Dimension of observation vectors
    period: Number of steps for one complete oscillation
    amplitude: Magnitude of weight oscillation
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features

Args:
    feature_dim: Dimension of feature vectors
    period: Steps for one complete oscillation cycle
    amplitude: Magnitude of weight oscillations around base
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    period: int = 1000,
    amplitude: float = 1.0,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the periodic change stream.

    Args:
        feature_dim: Dimension of feature vectors
        period: Steps for one complete oscillation cycle
        amplitude: Magnitude of weight oscillations around base
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._period = period
    self._amplitude = amplitude
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args:
    key: JAX random key

Returns:
    Initial stream state with random base weights and phases

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> PeriodicChangeState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random base weights and phases
    """
    key, key_weights, key_phases = jr.split(key, 3)
    base_weights = jr.normal(key_weights, (self._feature_dim,), dtype=jnp.float32)
    # Random phases in [0, 2π) for each weight
    phases = jr.uniform(key_phases, (self._feature_dim,), minval=0.0, maxval=2.0 * jnp.pi)
    return PeriodicChangeState(
        key=key,
        base_weights=base_weights,
        phases=phases,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args:
    state: Current stream state
    idx: Current step index (unused)

Returns:
    Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: PeriodicChangeState, idx: Array) -> tuple[TimeStep, PeriodicChangeState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_x, key_noise = jr.split(state.key, 3)

    # Compute oscillating weights: w(t) = base + amplitude * sin(2π * t / period + phase)
    t = state.step_count.astype(jnp.float32)
    oscillation = self._amplitude * jnp.sin(2.0 * jnp.pi * t / self._period + state.phases)
    true_weights = state.base_weights + oscillation

    # Generate observation
    x = self._feature_std * jr.normal(key_x, (self._feature_dim,), dtype=jnp.float32)

    # Compute target
    noise = self._noise_std * jr.normal(key_noise, (), dtype=jnp.float32)
    target = jnp.dot(true_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = PeriodicChangeState(
        key=key,
        base_weights=state.base_weights,
        phases=state.phases,
        step_count=state.step_count + 1,
    )

    return timestep, new_state
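The oscillation schedule can be sketched with nothing but the formula above (helper name illustrative):

```python
import math

def oscillating_weight(base: float, amplitude: float, period: int,
                       phase: float, t: int) -> float:
    """w(t) = base + amplitude * sin(2π * t / period + phase)."""
    return base + amplitude * math.sin(2.0 * math.pi * t / period + phase)

w0 = oscillating_weight(0.5, 1.0, 1000, 0.0, 0)        # at the start of a cycle
w_full = oscillating_weight(0.5, 1.0, 1000, 0.0, 1000)  # one full period later
```

After one full period each weight returns to its starting value, so a learner that can exploit the periodicity (rather than treating it as drift) faces a predictable target.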

RandomWalkState

State for RandomWalkStream.

Attributes:
    key: JAX random key for generating randomness
    true_weights: Current true target weights

RandomWalkStream(feature_dim, drift_rate=0.001, noise_std=0.1, feature_std=1.0)

Non-stationary stream where target weights drift via random walk.

The true target function is linear: y* = w_true @ x + noise where w_true evolves via random walk at each time step.

This tests the learner's ability to continuously track a moving target.

Attributes:
    feature_dim: Dimension of observation vectors
    drift_rate: Standard deviation of weight drift per step
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features

Args:
    feature_dim: Dimension of the feature/observation vectors
    drift_rate: Std dev of weight changes per step (controls non-stationarity)
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    drift_rate: float = 0.001,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the random walk target stream.

    Args:
        feature_dim: Dimension of the feature/observation vectors
        drift_rate: Std dev of weight changes per step (controls non-stationarity)
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._drift_rate = drift_rate
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args:
    key: JAX random key

Returns:
    Initial stream state with random weights

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> RandomWalkState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random weights
    """
    key, subkey = jr.split(key)
    weights = jr.normal(subkey, (self._feature_dim,), dtype=jnp.float32)
    return RandomWalkState(key=key, true_weights=weights)

step(state, idx)

Generate one time step.

Args:
    state: Current stream state
    idx: Current step index (unused)

Returns:
    Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: RandomWalkState, idx: Array) -> tuple[TimeStep, RandomWalkState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, k_drift, k_x, k_noise = jr.split(state.key, 4)

    # Drift weights
    drift = jr.normal(k_drift, state.true_weights.shape, dtype=jnp.float32)
    new_weights = state.true_weights + self._drift_rate * drift

    # Generate observation and target
    x = self._feature_std * jr.normal(k_x, (self._feature_dim,), dtype=jnp.float32)
    noise = self._noise_std * jr.normal(k_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = RandomWalkState(key=key, true_weights=new_weights)

    return timestep, new_state
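A quick way to see how much non-stationarity a given drift_rate induces: each weight performs a Gaussian random walk, so after n steps its displacement has standard deviation drift_rate · sqrt(n). A stdlib-only sketch (not framework code):

```python
import math
import random

drift_rate = 0.001
num_steps = 10_000
rng = random.Random(42)

# Simulate one weight of the random walk: w_{t+1} = w_t + drift_rate * N(0, 1)
w = 0.0
for _ in range(num_steps):
    w += drift_rate * rng.gauss(0.0, 1.0)

# Sum of n independent Gaussian increments has std drift_rate * sqrt(n),
# i.e. 0.001 * sqrt(10000) = 0.1 here. A tracker must keep pace with this.
expected_std = drift_rate * math.sqrt(num_steps)
```

This square-root growth is why a fixed step-size eventually lags the target and why meta-learned step-sizes (IDBD, Autostep) are tested on this stream.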

ScaleDriftState

State for ScaleDriftStream.

Attributes:
    key: JAX random key for generating randomness
    true_weights: Current true target weights
    log_scales: Current log-scale factors (random walk on log-scale)
    step_count: Number of steps taken

ScaleDriftStream(feature_dim, weight_drift_rate=0.001, scale_drift_rate=0.01, min_log_scale=-4.0, max_log_scale=4.0, noise_std=0.1)

Non-stationary stream where feature scales drift via random walk.

Both target weights and feature scales drift continuously. Weights drift in linear space while scales drift in log-space (bounded random walk). This tests continuous scale tracking where OnlineNormalizer's EMA may adapt differently than Autostep's v_i.

The target is computed from unscaled features to maintain consistent difficulty across scale changes.

Attributes:
    feature_dim: Dimension of observation vectors
    weight_drift_rate: Std dev of weight drift per step
    scale_drift_rate: Std dev of log-scale drift per step
    min_log_scale: Minimum log-scale (clips random walk)
    max_log_scale: Maximum log-scale (clips random walk)
    noise_std: Standard deviation of observation noise

Args:
    feature_dim: Dimension of feature vectors
    weight_drift_rate: Std dev of weight drift per step
    scale_drift_rate: Std dev of log-scale drift per step
    min_log_scale: Minimum log-scale (clips drift)
    max_log_scale: Maximum log-scale (clips drift)
    noise_std: Std dev of target noise

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    weight_drift_rate: float = 0.001,
    scale_drift_rate: float = 0.01,
    min_log_scale: float = -4.0,  # exp(-4) ~ 0.018
    max_log_scale: float = 4.0,  # exp(4) ~ 54.6
    noise_std: float = 0.1,
):
    """Initialize the scale drift stream.

    Args:
        feature_dim: Dimension of feature vectors
        weight_drift_rate: Std dev of weight drift per step
        scale_drift_rate: Std dev of log-scale drift per step
        min_log_scale: Minimum log-scale (clips drift)
        max_log_scale: Maximum log-scale (clips drift)
        noise_std: Std dev of target noise
    """
    self._feature_dim = feature_dim
    self._weight_drift_rate = weight_drift_rate
    self._scale_drift_rate = scale_drift_rate
    self._min_log_scale = min_log_scale
    self._max_log_scale = max_log_scale
    self._noise_std = noise_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args:
    key: JAX random key

Returns:
    Initial stream state with random weights and unit scales

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> ScaleDriftState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random weights and unit scales
    """
    key, k_weights = jr.split(key)
    weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
    # Initial log-scales at 0 (scale = 1)
    log_scales = jnp.zeros(self._feature_dim, dtype=jnp.float32)
    return ScaleDriftState(
        key=key,
        true_weights=weights,
        log_scales=log_scales,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args:
    state: Current stream state
    idx: Current step index (unused)

Returns:
    Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: ScaleDriftState, idx: Array) -> tuple[TimeStep, ScaleDriftState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, k_w_drift, k_s_drift, k_x, k_noise = jr.split(state.key, 5)

    # Drift target weights
    weight_drift = self._weight_drift_rate * jr.normal(
        k_w_drift, (self._feature_dim,), dtype=jnp.float32
    )
    new_weights = state.true_weights + weight_drift

    # Drift log-scales (bounded random walk)
    scale_drift = self._scale_drift_rate * jr.normal(
        k_s_drift, (self._feature_dim,), dtype=jnp.float32
    )
    new_log_scales = state.log_scales + scale_drift
    new_log_scales = jnp.clip(new_log_scales, self._min_log_scale, self._max_log_scale)

    # Generate raw features (unscaled)
    raw_x = jr.normal(k_x, (self._feature_dim,), dtype=jnp.float32)

    # Apply scaling to observation
    scales = jnp.exp(new_log_scales)
    x = raw_x * scales

    # Target from true weights using RAW features
    noise = self._noise_std * jr.normal(k_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, raw_x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = ScaleDriftState(
        key=key,
        true_weights=new_weights,
        log_scales=new_log_scales,
        step_count=state.step_count + 1,
    )
    return timestep, new_state
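The bounded log-space walk can be sketched as a scalar helper (illustrative, not framework API): propose a drifted log-scale, then clip it to [min_log_scale, max_log_scale]:

```python
def drift_log_scale(log_scale: float, noise: float, scale_drift_rate: float = 0.01,
                    min_log_scale: float = -4.0, max_log_scale: float = 4.0) -> float:
    """One step of the bounded random walk on a feature's log-scale."""
    proposed = log_scale + scale_drift_rate * noise
    return max(min_log_scale, min(max_log_scale, proposed))

inside = drift_log_scale(0.0, 1.0)            # small drift, unclipped
clipped_hi = drift_log_scale(3.999, 50.0)     # pinned at max_log_scale
clipped_lo = drift_log_scale(-3.999, -50.0)   # pinned at min_log_scale
```

Drifting in log-space keeps scales strictly positive and makes multiplicative scale changes symmetric, while the clip bounds the dynamic range to roughly [exp(-4), exp(4)] ≈ [0.018, 54.6].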

ScaledStreamState

State for ScaledStreamWrapper.

Attributes:
    inner_state: State of the wrapped stream

ScaledStreamWrapper(inner_stream, feature_scales)

Wrapper that applies per-feature scaling to any stream's observations.

This wrapper multiplies each feature of the observation by a corresponding scale factor. Useful for testing how learners handle features at different scales, which is important for understanding normalization benefits.

Examples:

stream = ScaledStreamWrapper(
    AbruptChangeStream(feature_dim=10, change_interval=1000),
    feature_scales=jnp.array([0.001, 0.01, 0.1, 1.0, 10.0,
                              100.0, 1000.0, 0.001, 0.01, 0.1])
)

Attributes:
    inner_stream: The wrapped stream instance
    feature_scales: Per-feature scale factors (must match feature_dim)

Args:
    inner_stream: Stream to wrap (must implement ScanStream protocol)
    feature_scales: Array of scale factors, one per feature. Must have
        shape (feature_dim,) matching the inner stream's feature_dim.

Raises:
    ValueError: If feature_scales length doesn't match inner stream's feature_dim

Source code in src/alberta_framework/streams/synthetic.py
def __init__(self, inner_stream: ScanStream[Any], feature_scales: Array):
    """Initialize the scaled stream wrapper.

    Args:
        inner_stream: Stream to wrap (must implement ScanStream protocol)
        feature_scales: Array of scale factors, one per feature. Must have
            shape (feature_dim,) matching the inner stream's feature_dim.

    Raises:
        ValueError: If feature_scales length doesn't match inner stream's feature_dim
    """
    self._inner_stream: ScanStream[Any] = inner_stream
    self._feature_scales = jnp.asarray(feature_scales, dtype=jnp.float32)

    if self._feature_scales.shape[0] != inner_stream.feature_dim:
        raise ValueError(
            f"feature_scales length ({self._feature_scales.shape[0]}) "
            f"must match inner stream's feature_dim ({inner_stream.feature_dim})"
        )

feature_dim property

Return the dimension of observation vectors.

inner_stream property

Return the wrapped stream.

feature_scales property

Return the per-feature scale factors.

init(key)

Initialize stream state.

Args:
    key: JAX random key

Returns:
    Initial stream state wrapping the inner stream's state

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> ScaledStreamState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state wrapping the inner stream's state
    """
    inner_state = self._inner_stream.init(key)
    return ScaledStreamState(inner_state=inner_state)

step(state, idx)

Generate one time step with scaled observations.

Args:
    state: Current stream state
    idx: Current step index

Returns:
    Tuple of (timestep with scaled observation, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: ScaledStreamState, idx: Array) -> tuple[TimeStep, ScaledStreamState]:
    """Generate one time step with scaled observations.

    Args:
        state: Current stream state
        idx: Current step index

    Returns:
        Tuple of (timestep with scaled observation, new_state)
    """
    timestep, new_inner_state = self._inner_stream.step(state.inner_state, idx)

    # Scale the observation
    scaled_observation = timestep.observation * self._feature_scales

    scaled_timestep = TimeStep(
        observation=scaled_observation,
        target=timestep.target,
    )

    new_state = ScaledStreamState(inner_state=new_inner_state)
    return scaled_timestep, new_state

SuttonExperiment1State

State for SuttonExperiment1Stream.

Attributes:
    key: JAX random key for generating randomness
    signs: Signs (+1/-1) for the relevant inputs
    step_count: Number of steps taken

SuttonExperiment1Stream(num_relevant=5, num_irrelevant=15, change_interval=20)

Non-stationary stream replicating Experiment 1 from Sutton 1992.

This stream implements the exact task from Sutton's IDBD paper:

- 20 real-valued inputs drawn from N(0, 1)
- Only first 5 inputs are relevant (weights are ±1)
- Last 15 inputs are irrelevant (weights are 0)
- Every change_interval steps, one of the 5 relevant signs is flipped

Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta"

Attributes:
    num_relevant: Number of relevant inputs (default 5)
    num_irrelevant: Number of irrelevant inputs (default 15)
    change_interval: Steps between sign changes (default 20)

Args:
    num_relevant: Number of relevant inputs with ±1 weights
    num_irrelevant: Number of irrelevant inputs with 0 weights
    change_interval: Number of steps between sign flips

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    num_relevant: int = 5,
    num_irrelevant: int = 15,
    change_interval: int = 20,
):
    """Initialize the Sutton Experiment 1 stream.

    Args:
        num_relevant: Number of relevant inputs with ±1 weights
        num_irrelevant: Number of irrelevant inputs with 0 weights
        change_interval: Number of steps between sign flips
    """
    self._num_relevant = num_relevant
    self._num_irrelevant = num_irrelevant
    self._change_interval = change_interval

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args:
    key: JAX random key

Returns:
    Initial stream state with all +1 signs

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> SuttonExperiment1State:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with all +1 signs
    """
    signs = jnp.ones(self._num_relevant, dtype=jnp.float32)
    return SuttonExperiment1State(
        key=key,
        signs=signs,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

At each step:

1. If at a change interval (and not step 0), flip one random sign
2. Generate random inputs from N(0, 1)
3. Compute target as sum of relevant inputs weighted by signs

Args:
    state: Current stream state
    idx: Current step index (unused)

Returns:
    Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(
    self, state: SuttonExperiment1State, idx: Array
) -> tuple[TimeStep, SuttonExperiment1State]:
    """Generate one time step.

    At each step:
    1. If at a change interval (and not step 0), flip one random sign
    2. Generate random inputs from N(0, 1)
    3. Compute target as sum of relevant inputs weighted by signs

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_x, key_which = jr.split(state.key, 3)

    # Determine if we should flip a sign (not at step 0)
    should_flip = (state.step_count > 0) & (state.step_count % self._change_interval == 0)

    # Select which sign to flip
    idx_to_flip = jr.randint(key_which, (), 0, self._num_relevant)

    # Create flip mask
    flip_mask = jnp.where(
        jnp.arange(self._num_relevant) == idx_to_flip,
        jnp.array(-1.0, dtype=jnp.float32),
        jnp.array(1.0, dtype=jnp.float32),
    )

    # Apply flip mask conditionally
    new_signs = jnp.where(should_flip, state.signs * flip_mask, state.signs)

    # Generate observation from N(0, 1)
    x = jr.normal(key_x, (self.feature_dim,), dtype=jnp.float32)

    # Compute target: sum of first num_relevant inputs weighted by signs
    target = jnp.dot(new_signs, x[: self._num_relevant])

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = SuttonExperiment1State(
        key=key,
        signs=new_signs,
        step_count=state.step_count + 1,
    )

    return timestep, new_state
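The task dynamics reduce to a signed sum plus an occasional sign flip. A stdlib-only sketch, assuming the flipped index is drawn uniformly as in the source (the helper name is illustrative, not framework API):

```python
import random

def sutton_exp1_step(signs, x, t, rng, change_interval=20):
    """One step of the Sutton (1992) Experiment 1 target process.

    `signs` holds the ±1 weights of the relevant inputs; every
    change_interval steps (except step 0) one uniformly chosen sign flips.
    The target is the signed sum of the relevant inputs.
    """
    signs = list(signs)
    if t > 0 and t % change_interval == 0:
        i = rng.randrange(len(signs))
        signs[i] = -signs[i]
    target = sum(s * xi for s, xi in zip(signs, x))
    return signs, target

rng = random.Random(0)
signs = [1.0] * 5
signs, t0 = sutton_exp1_step(signs, [1.0] * 5, 0, rng)    # no flip at step 0
signs, t20 = sutton_exp1_step(signs, [1.0] * 5, 20, rng)  # exactly one sign flips
```

With all-ones inputs the target drops from 5 to 3 after a single flip, regardless of which sign was chosen. IDBD's advantage on this task comes from driving the step-sizes of the 15 irrelevant inputs toward zero.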

Timer(name='Operation', verbose=True, print_fn=None)

Context manager for timing code execution.

Measures wall-clock time for a block of code and optionally prints the duration when the block completes.

Attributes:
    name: Description of what is being timed
    duration: Elapsed time in seconds (available after context exits)
    start_time: Timestamp when timing started
    end_time: Timestamp when timing ended

Examples:

with Timer("Training loop"):
    for i in range(1000):
        pass
# Output: Training loop completed in 0.01s

# Silent timing (no print):
with Timer("Silent", verbose=False) as t:
    time.sleep(0.1)
print(f"Elapsed: {t.duration:.2f}s")
# Output: Elapsed: 0.10s

# Custom print function:
with Timer("Custom", print_fn=lambda msg: print(f">> {msg}")):
    pass
# Output: >> Custom completed in 0.00s

Args:
    name: Description of the operation being timed
    verbose: Whether to print the duration when done
    print_fn: Custom print function (defaults to built-in print)

Source code in src/alberta_framework/utils/timing.py
def __init__(
    self,
    name: str = "Operation",
    verbose: bool = True,
    print_fn: Callable[[str], None] | None = None,
):
    """Initialize the timer.

    Args:
        name: Description of the operation being timed
        verbose: Whether to print the duration when done
        print_fn: Custom print function (defaults to built-in print)
    """
    self.name = name
    self.verbose = verbose
    self.print_fn = print_fn or print
    self.start_time: float = 0.0
    self.end_time: float = 0.0
    self.duration: float = 0.0

elapsed()

Get elapsed time since timer started (can be called during execution).

Returns:
    Elapsed time in seconds

Source code in src/alberta_framework/utils/timing.py
def elapsed(self) -> float:
    """Get elapsed time since timer started (can be called during execution).

    Returns:
        Elapsed time in seconds
    """
    return time.perf_counter() - self.start_time

GymnasiumStream(env, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0)

Experience stream from a Gymnasium environment using Python loop.

This class maintains iterator-based access for online learning scenarios where you need to interact with the environment in real-time.

For batch learning, use collect_trajectory() followed by learn_from_trajectory().

Attributes:
    mode: Prediction mode (REWARD, NEXT_STATE, VALUE)
    gamma: Discount factor for VALUE mode
    include_action_in_features: Whether to include action in features
    episode_count: Number of completed episodes

Args:
    env: Gymnasium environment instance
    mode: What to predict (REWARD, NEXT_STATE, VALUE)
    policy: Action selection function. If None, uses random policy
    gamma: Discount factor for VALUE mode
    include_action_in_features: If True, features = concat(obs, action).
        If False, features = obs only
    seed: Random seed for environment resets and random policy

Source code in src/alberta_framework/streams/gymnasium.py
def __init__(
    self,
    env: gymnasium.Env[Any, Any],
    mode: PredictionMode = PredictionMode.REWARD,
    policy: Callable[[Array], Any] | None = None,
    gamma: float = 0.99,
    include_action_in_features: bool = True,
    seed: int = 0,
):
    """Initialize the Gymnasium stream.

    Args:
        env: Gymnasium environment instance
        mode: What to predict (REWARD, NEXT_STATE, VALUE)
        policy: Action selection function. If None, uses random policy
        gamma: Discount factor for VALUE mode
        include_action_in_features: If True, features = concat(obs, action).
            If False, features = obs only
        seed: Random seed for environment resets and random policy
    """
    self._env = env
    self._mode = mode
    self._gamma = gamma
    self._include_action_in_features = include_action_in_features
    self._seed = seed
    self._reset_count = 0

    if policy is None:
        self._policy = make_random_policy(env, seed)
    else:
        self._policy = policy

    self._obs_dim = _flatten_space(env.observation_space)
    self._action_dim = _flatten_space(env.action_space)

    if include_action_in_features:
        self._feature_dim = self._obs_dim + self._action_dim
    else:
        self._feature_dim = self._obs_dim

    if mode == PredictionMode.NEXT_STATE:
        self._target_dim = self._obs_dim
    else:
        self._target_dim = 1

    self._current_obs: Array | None = None
    self._episode_count = 0
    self._step_count = 0
    self._value_estimator: Callable[[Array], float] | None = None
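The feature and target dimensions follow mechanically from the constructor arguments. A minimal sketch of that derivation (the helper name stream_dims is illustrative, and the enum values below are stand-ins for the framework's PredictionMode):

```python
from enum import Enum

class PredictionMode(Enum):
    # Illustrative stand-in for the framework's PredictionMode enum.
    REWARD = "reward"
    NEXT_STATE = "next_state"
    VALUE = "value"

def stream_dims(obs_dim: int, action_dim: int, mode: PredictionMode,
                include_action_in_features: bool = True) -> tuple[int, int]:
    """Feature/target dimensions as derived in GymnasiumStream.__init__."""
    feature_dim = obs_dim + action_dim if include_action_in_features else obs_dim
    target_dim = obs_dim if mode == PredictionMode.NEXT_STATE else 1
    return feature_dim, target_dim
```

For example, a 4-dimensional observation with a 2-dimensional flattened action gives 6 features when the action is included, and the target is scalar except in NEXT_STATE mode, where it matches the observation dimension.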

feature_dim property

Return the dimension of feature vectors.

target_dim property

Return the dimension of target vectors.

episode_count property

Return the number of completed episodes.

step_count property

Return the total number of steps taken.

mode property

Return the prediction mode.

set_value_estimator(estimator)

Set the value estimator for proper TD learning in VALUE mode.

Source code in src/alberta_framework/streams/gymnasium.py
def set_value_estimator(self, estimator: Callable[[Array], float]) -> None:
    """Set the value estimator for proper TD learning in VALUE mode."""
    self._value_estimator = estimator

PredictionMode

Bases: Enum

Mode for what the stream predicts.

REWARD: Predict immediate reward from (state, action)
NEXT_STATE: Predict next state from (state, action)
VALUE: Predict cumulative return (TD learning with bootstrap)

TDStream(env, policy=None, gamma=0.99, include_action_in_features=False, seed=0)

Experience stream for proper TD learning with value function bootstrap.

This stream integrates with a learner to use its predictions for bootstrapping in TD targets.

Usage:

    stream = TDStream(env)
    learner = LinearLearner(optimizer=IDBD())
    state = learner.init(stream.feature_dim)

    for step, timestep in enumerate(stream):
        result = learner.update(state, timestep.observation, timestep.target)
        state = result.state
        stream.update_value_function(lambda x: learner.predict(state, x))

Args:
    env: Gymnasium environment instance
    policy: Action selection function. If None, uses random policy
    gamma: Discount factor
    include_action_in_features: If True, learn Q(s,a). If False, learn V(s)
    seed: Random seed

Source code in src/alberta_framework/streams/gymnasium.py
def __init__(
    self,
    env: gymnasium.Env[Any, Any],
    policy: Callable[[Array], Any] | None = None,
    gamma: float = 0.99,
    include_action_in_features: bool = False,
    seed: int = 0,
):
    """Initialize the TD stream.

    Args:
        env: Gymnasium environment instance
        policy: Action selection function. If None, uses random policy
        gamma: Discount factor
        include_action_in_features: If True, learn Q(s,a). If False, learn V(s)
        seed: Random seed
    """
    self._env = env
    self._gamma = gamma
    self._include_action_in_features = include_action_in_features
    self._seed = seed
    self._reset_count = 0

    if policy is None:
        self._policy = make_random_policy(env, seed)
    else:
        self._policy = policy

    self._obs_dim = _flatten_space(env.observation_space)
    self._action_dim = _flatten_space(env.action_space)

    if include_action_in_features:
        self._feature_dim = self._obs_dim + self._action_dim
    else:
        self._feature_dim = self._obs_dim

    self._current_obs: Array | None = None
    self._episode_count = 0
    self._step_count = 0
    self._value_fn: Callable[[Array], float] = lambda x: 0.0

feature_dim property

Return the dimension of feature vectors.

episode_count property

Return the number of completed episodes.

step_count property

Return the total number of steps taken.

update_value_function(value_fn)

Update the value function used for TD bootstrapping.

Source code in src/alberta_framework/streams/gymnasium.py
def update_value_function(self, value_fn: Callable[[Array], float]) -> None:
    """Update the value function used for TD bootstrapping."""
    self._value_fn = value_fn

checkpoint_exists(path)

Check whether a checkpoint exists at the given path.

Args: path: Path to check for a checkpoint directory.

Returns: True if a checkpoint directory exists at the path.

Source code in src/alberta_framework/core/checkpoints.py
def checkpoint_exists(path: str | Path) -> bool:
    """Check whether a checkpoint exists at the given path.

    Args:
        path: Path to check for a checkpoint directory.

    Returns:
        True if a checkpoint directory exists at the path.
    """
    path = Path(path)
    # Orbax checkpoints are directories containing a state/ subdirectory
    return path.is_dir() and (path / "state").is_dir()

load_checkpoint(state_template, path)

Load checkpoint into a state matching the template's tree structure.

The template state (from learner.init()) provides the PyTree structure for deserialization.

Args:
    state_template: A state of the same type and structure as the saved state.
        Typically created via learner.init() with the same architecture.
    path: Path to the checkpoint directory.

Returns: Tuple of (loaded_state, user_metadata) where user_metadata is the dict passed to save_checkpoint, or an empty dict if none was provided.

Raises:
    FileNotFoundError: If checkpoint directory doesn't exist
    ValueError: If state structure doesn't match template
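The role of the template can be illustrated without Orbax: only the template's tree structure matters, not its values, because the structure tells the loader where each saved leaf belongs. A hypothetical sketch with plain dicts:

```python
def flatten(tree):
    """Collect leaves of a nested dict in sorted-key order."""
    if isinstance(tree, dict):
        leaves = []
        for k in sorted(tree):
            leaves.extend(flatten(tree[k]))
        return leaves
    return [tree]


def unflatten(template, leaves):
    """Rebuild the template's structure, filling in saved leaves."""
    it = iter(leaves)

    def build(node):
        if isinstance(node, dict):
            return {k: build(node[k]) for k in sorted(node)}
        return next(it)

    return build(template)


saved_leaves = flatten({"weights": [1.0, 2.0], "bias": 0.5})
template = {"weights": None, "bias": None}  # e.g. freshly created state
restored = unflatten(template, saved_leaves)
```

This is why a mismatched architecture fails: the template's tree no longer pairs up one-to-one with the saved leaves.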

Source code in src/alberta_framework/core/checkpoints.py
def load_checkpoint(
    state_template: Any,
    path: str | Path,
) -> tuple[Any, dict[str, Any]]:
    """Load checkpoint into a state matching the template's tree structure.

    The template state (from ``learner.init()``) provides the PyTree
    structure for deserialization.

    Args:
        state_template: A state of the same type and structure as the
            saved state. Typically created via ``learner.init()`` with
            the same architecture.
        path: Path to the checkpoint directory.

    Returns:
        Tuple of ``(loaded_state, user_metadata)`` where ``user_metadata``
        is the dict passed to ``save_checkpoint``, or an empty dict if
        none was provided.

    Raises:
        FileNotFoundError: If checkpoint directory doesn't exist
        ValueError: If state structure doesn't match template
    """
    path = Path(path)

    if not path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {path}")

    try:
        with ocp.Checkpointer(ocp.CompositeCheckpointHandler()) as ckptr:
            loaded = ckptr.restore(
                str(path),
                args=ocp.args.Composite(
                    state=ocp.args.StandardRestore(state_template),
                    metadata=ocp.args.JsonRestore(),
                ),
            )
    except ValueError as e:
        raise ValueError(
            f"State structure mismatch. "
            f"Ensure the learner architecture matches the saved checkpoint. "
            f"Original error: {e}"
        ) from e

    user_metadata = dict(loaded.metadata)
    user_metadata.pop(_VERSION_KEY, None)
    return loaded.state, user_metadata

load_checkpoint_metadata(path)

Load only the user metadata from a checkpoint, without a state template.

This is useful when metadata contains configuration needed to construct the state template (e.g. learner_config in rlsecd).

Args: path: Path to the checkpoint directory.

Returns: The user metadata dict, or an empty dict if none was stored.

Raises: FileNotFoundError: If checkpoint directory doesn't exist

Source code in src/alberta_framework/core/checkpoints.py
def load_checkpoint_metadata(path: str | Path) -> dict[str, Any]:
    """Load only the user metadata from a checkpoint, without a state template.

    This is useful when metadata contains configuration needed to construct
    the state template (e.g. learner_config in rlsecd).

    Args:
        path: Path to the checkpoint directory.

    Returns:
        The user metadata dict, or an empty dict if none was stored.

    Raises:
        FileNotFoundError: If checkpoint directory doesn't exist
    """
    path = Path(path)

    if not path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {path}")

    with ocp.Checkpointer(ocp.CompositeCheckpointHandler()) as ckptr:
        loaded = ckptr.restore(
            str(path),
            args=ocp.args.Composite(
                metadata=ocp.args.JsonRestore(),
            ),
        )

    user_metadata = dict(loaded.metadata)
    user_metadata.pop(_VERSION_KEY, None)
    return user_metadata

save_checkpoint(state, path, metadata=None)

Save learner state to disk.

Creates a checkpoint directory at path containing the serialized state PyTree and optional user metadata as JSON.

Args:
    state: Any learner state (LearnerState, MLPLearnerState, MultiHeadMLPState, TDLearnerState)
    path: Path for the checkpoint directory.
    metadata: Optional user metadata dict to store alongside the checkpoint (e.g. epoch, learner config, etc.)

Source code in src/alberta_framework/core/checkpoints.py
def save_checkpoint(
    state: Any,
    path: str | Path,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Save learner state to disk.

    Creates a checkpoint directory at ``path`` containing the serialized
    state PyTree and optional user metadata as JSON.

    Args:
        state: Any learner state (LearnerState, MLPLearnerState,
            MultiHeadMLPState, TDLearnerState)
        path: Path for the checkpoint directory.
        metadata: Optional user metadata dict to store alongside
            the checkpoint (e.g. epoch, learner config, etc.)
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    meta_to_save = {_VERSION_KEY: _FORMAT_VERSION}
    if metadata is not None:
        meta_to_save.update(metadata)

    with ocp.Checkpointer(ocp.CompositeCheckpointHandler()) as ckptr:
        ckptr.save(
            str(path),
            args=ocp.args.Composite(
                state=ocp.args.StandardSave(state),
                metadata=ocp.args.JsonSave(meta_to_save),
            ),
        )

compute_feature_relevance(state)

Extract per-feature relevance metrics from multi-head learner state.

All metrics are computed from existing state arrays via small matrix multiplies. Typical cost: ~10-50us after JIT for a (64,64) trunk with 5 heads and 12 features.

Args: state: Current multi-head MLP learner state.

Returns: FeatureRelevance dataclass with all Tier 1 metrics.
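The path-norm part of the computation can be sketched with plain numpy. The shapes and weights below are toy values; the real function reads them from the learner's state arrays:

```python
import numpy as np

# Hypothetical two-layer trunk and a single head
W0 = np.array([[1.0, -2.0],
               [0.0,  3.0]])       # (H0=2, feature_dim=2)
W1 = np.array([[0.5, -1.0]])       # (H1=1, H0=2)
head_w = np.array([[2.0]])         # (1, H1=1)

# Chain absolute weights from input features through the trunk to the head
path = np.abs(W0)                  # (H0, feature_dim)
path = np.abs(W1) @ path           # (H1, feature_dim)
relevance = (np.abs(head_w) @ path)[0]  # (feature_dim,)
print(relevance)  # [1. 8.]
```

Each entry sums the absolute weight products over every path from that input feature to the head's output, so a feature whose paths all pass through near-zero weights gets near-zero relevance.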

Source code in src/alberta_framework/core/diagnostics.py
def compute_feature_relevance(state: MultiHeadMLPState) -> FeatureRelevance:
    """Extract per-feature relevance metrics from multi-head learner state.

    All metrics are computed from existing state arrays via small matrix
    multiplies. Typical cost: ~10-50us after JIT for a (64,64) trunk
    with 5 heads and 12 features.

    Args:
        state: Current multi-head MLP learner state.

    Returns:
        ``FeatureRelevance`` dataclass with all Tier 1 metrics.
    """
    n_heads = len(state.head_params.weights)
    n_trunk_layers = len(state.trunk_params.weights)

    # --- Weight relevance (path-norm) ---
    # Build path through trunk: |W0|, then |W1| @ path, ...
    # trunk_params.weights[0] has shape (H0, feature_dim)
    if n_trunk_layers > 0:
        path = jnp.abs(state.trunk_params.weights[0])  # (H0, feature_dim)
        for i in range(1, n_trunk_layers):
            path = jnp.abs(state.trunk_params.weights[i]) @ path  # (H_i, feature_dim)

        # Per-head: |head_w[h]| @ path -> (1, feature_dim) -> squeeze to (feature_dim,)
        weight_relevance_list = []
        for h in range(n_heads):
            head_w = jnp.abs(state.head_params.weights[h])  # (1, H_last)
            rel = head_w @ path  # (1, feature_dim)
            weight_relevance_list.append(jnp.squeeze(rel, axis=0))
        weight_relevance = jnp.stack(weight_relevance_list)  # (n_heads, feature_dim)
    else:
        # No trunk: heads project directly from input features
        weight_relevance_list = []
        for h in range(n_heads):
            weight_relevance_list.append(jnp.abs(jnp.squeeze(state.head_params.weights[h], axis=0)))
        weight_relevance = jnp.stack(weight_relevance_list)

    # --- Step-size activity on input layer ---
    # Trunk optimizer states are interleaved: (w0, b0, w1, b1, ...)
    # Index 0 = input weights optimizer state
    if n_trunk_layers > 0:
        input_opt_state = state.trunk_optimizer_states[0]
        if hasattr(input_opt_state, "step_sizes"):
            # AutostepParamState: step_sizes has shape (H0, feature_dim)
            step_size_activity = jnp.mean(jnp.abs(input_opt_state.step_sizes), axis=0)
        elif hasattr(input_opt_state, "step_size"):
            # LMSState: scalar step_size, uniform across features
            feature_dim = state.trunk_params.weights[0].shape[1]
            step_size_activity = jnp.full(feature_dim, jnp.abs(input_opt_state.step_size))
        else:
            feature_dim = state.trunk_params.weights[0].shape[1]
            step_size_activity = jnp.zeros(feature_dim)
    else:
        # No trunk layers — use head info
        feature_dim = state.head_params.weights[0].shape[1]
        step_size_activity = jnp.zeros(feature_dim)

    # --- Trace activity on input layer ---
    # trunk_traces interleaved: (w0, b0, w1, b1, ...)
    # Index 0 = input weight traces, shape (H0, feature_dim)
    if n_trunk_layers > 0:
        input_traces = state.trunk_traces[0]  # (H0, feature_dim)
        trace_activity = jnp.mean(jnp.abs(input_traces), axis=0)  # (feature_dim,)
    else:
        feature_dim = state.head_params.weights[0].shape[1]
        trace_activity = jnp.zeros(feature_dim)

    # --- Normalizer state ---
    normalizer_mean = None
    normalizer_std = None
    if state.normalizer_state is not None:
        normalizer_mean = state.normalizer_state.mean
        normalizer_std = jnp.sqrt(state.normalizer_state.var + 1e-8)

    # --- Head reliance ---
    # |head_params.weights[h]| squeezed to (H_last,)
    head_reliance_list = []
    for h in range(n_heads):
        head_reliance_list.append(jnp.abs(jnp.squeeze(state.head_params.weights[h], axis=0)))
    head_reliance = jnp.stack(head_reliance_list)  # (n_heads, H_last)

    # --- Head mean step-size ---
    head_mean_step_size = None
    if n_heads > 0:
        first_head_w_opt = state.head_optimizer_states[0][0]
        if hasattr(first_head_w_opt, "step_sizes"):
            head_ss_list = []
            for h in range(n_heads):
                w_opt = state.head_optimizer_states[h][0]
                head_ss_list.append(jnp.mean(w_opt.step_sizes))
            head_mean_step_size = jnp.array(head_ss_list)

    return FeatureRelevance(
        weight_relevance=weight_relevance,
        step_size_activity=step_size_activity,
        trace_activity=trace_activity,
        normalizer_mean=normalizer_mean,
        normalizer_std=normalizer_std,
        head_reliance=head_reliance,
        head_mean_step_size=head_mean_step_size,
    )

compute_feature_sensitivity(learner, state, observation)

Compute per-head sensitivity to each input feature via Jacobian.

Uses jax.jacrev to compute d(pred_h)/d(obs_f) for all heads and features. This is a Tier 2 metric requiring a single forward pass plus one backward (VJP) pass per output (5 for 5 heads). Typical cost: ~100-500us for a (64,64) trunk.

jacrev is used because output dim (n_heads) < input dim (feature_dim), making reverse-mode more efficient.

Args:
    learner: The multi-head MLP learner instance.
    state: Current learner state.
    observation: Input feature vector, shape (feature_dim,).

Returns: Jacobian array of shape (n_heads, feature_dim) where entry [h, f] is the sensitivity of head h's prediction to feature f at this observation.
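jax.jacrev computes this Jacobian exactly; the same quantity can be approximated by central finite differences, shown here with a toy linear predictor standing in for learner.predict:

```python
import numpy as np

W = np.array([[1.0,  0.0, 2.0],
              [0.0, -1.0, 1.0]])  # toy (n_heads=2, feature_dim=3) predictor


def predict(obs):
    # Stand-in for learner.predict(state, obs)
    return W @ obs


obs = np.array([0.5, 1.0, -0.5])
eps = 1e-6
eye = np.eye(3)

# Perturb each feature in turn; stack columns into (n_heads, feature_dim)
jac = np.stack(
    [(predict(obs + eps * eye[f]) - predict(obs - eps * eye[f])) / (2 * eps)
     for f in range(3)],
    axis=1,
)
# For this linear predictor the Jacobian equals W at every observation
```

For a nonlinear trunk the Jacobian varies with the observation, which is why this is a per-observation (Tier 2) metric rather than a state-only one.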

Source code in src/alberta_framework/core/diagnostics.py
def compute_feature_sensitivity(
    learner: MultiHeadMLPLearner,
    state: MultiHeadMLPState,
    observation: Array,
) -> Array:
    """Compute per-head sensitivity to each input feature via Jacobian.

    Uses ``jax.jacrev`` to compute ``d(pred_h)/d(obs_f)`` for all heads
    and features. This is a Tier 2 metric requiring one forward pass
    per output (5 for 5 heads). Typical cost: ~100-500us for a (64,64)
    trunk.

    ``jacrev`` is used because output dim (n_heads) < input dim
    (feature_dim), making reverse-mode more efficient.

    Args:
        learner: The multi-head MLP learner instance.
        state: Current learner state.
        observation: Input feature vector, shape ``(feature_dim,)``.

    Returns:
        Jacobian array of shape ``(n_heads, feature_dim)`` where entry
        ``[h, f]`` is the sensitivity of head ``h``'s prediction to
        feature ``f`` at this observation.
    """

    def predict_fn(obs: Array) -> Array:
        preds: Array = learner.predict(state, obs)
        return preds

    jacobian: Array = jax.jacrev(predict_fn)(observation)  # (n_heads, feature_dim)
    return jacobian

relevance_to_dict(relevance, feature_names=None, head_names=None)

Convert FeatureRelevance to a JSON-serializable dict.

Produces a structured dict suitable for logging or inspection. Includes normalized_weight_relevance when normalizer state is available, which scales weight relevance by normalizer std to give relevance in raw input units.

Args:
    relevance: FeatureRelevance from compute_feature_relevance.
    feature_names: Optional list of feature names. If None, uses "feature_0", "feature_1", etc.
    head_names: Optional list of head names. If None, uses "head_0", "head_1", etc.

Returns: Nested dict with "trunk" and "per_head" sections.

Source code in src/alberta_framework/core/diagnostics.py
def relevance_to_dict(
    relevance: FeatureRelevance,
    feature_names: list[str] | None = None,
    head_names: list[str] | None = None,
) -> dict[str, Any]:
    """Convert FeatureRelevance to a JSON-serializable dict.

    Produces a structured dict suitable for logging or inspection.
    Includes ``normalized_weight_relevance`` when normalizer state is
    available, which scales weight relevance by normalizer std to give
    relevance in raw input units.

    Args:
        relevance: FeatureRelevance from ``compute_feature_relevance``.
        feature_names: Optional list of feature names. If None, uses
            ``"feature_0"``, ``"feature_1"``, etc.
        head_names: Optional list of head names. If None, uses
            ``"head_0"``, ``"head_1"``, etc.

    Returns:
        Nested dict with ``"trunk"`` and ``"per_head"`` sections.
    """
    n_heads, feature_dim = relevance.weight_relevance.shape
    h_last = relevance.head_reliance.shape[1]

    if feature_names is None:
        feature_names = [f"feature_{i}" for i in range(feature_dim)]
    if head_names is None:
        head_names = [f"head_{i}" for i in range(n_heads)]

    # Trunk-level metrics
    trunk: dict[str, Any] = {
        "step_size_activity": {
            feature_names[f]: float(relevance.step_size_activity[f]) for f in range(feature_dim)
        },
        "trace_activity": {
            feature_names[f]: float(relevance.trace_activity[f]) for f in range(feature_dim)
        },
    }

    if relevance.normalizer_mean is not None:
        trunk["normalizer_mean"] = {
            feature_names[f]: float(relevance.normalizer_mean[f]) for f in range(feature_dim)
        }
    if relevance.normalizer_std is not None:
        trunk["normalizer_std"] = {
            feature_names[f]: float(relevance.normalizer_std[f]) for f in range(feature_dim)
        }

    # Compute normalized weight relevance if normalizer is available
    has_norm_std = relevance.normalizer_std is not None

    # Per-head metrics
    per_head: dict[str, Any] = {}
    for h in range(n_heads):
        head_dict: dict[str, Any] = {
            "weight_relevance": {
                feature_names[f]: float(relevance.weight_relevance[h, f])
                for f in range(feature_dim)
            },
        }
        if has_norm_std and relevance.normalizer_std is not None:
            norm_rel = relevance.weight_relevance[h] * relevance.normalizer_std
            head_dict["normalized_weight_relevance"] = {
                feature_names[f]: float(norm_rel[f]) for f in range(feature_dim)
            }
        head_dict["head_reliance"] = {
            f"neuron_{j}": float(relevance.head_reliance[h, j]) for j in range(h_last)
        }
        if relevance.head_mean_step_size is not None:
            head_dict["mean_step_size"] = float(relevance.head_mean_step_size[h])

        per_head[head_names[h]] = head_dict

    return {"trunk": trunk, "per_head": per_head}

run_horde_learning_loop(horde, state, observations, cumulants, next_observations)

Run Horde learning loop using jax.lax.scan.

Scans over (obs, cumulants, next_obs) triples.

Args:
    horde: Horde learner
    state: Initial learner state
    observations: Input observations, shape (num_steps, feature_dim)
    cumulants: Per-demon cumulants, shape (num_steps, n_demons). NaN = inactive demon for that step.
    next_observations: Next observations, shape (num_steps, feature_dim)

Returns: HordeLearningResult with final state, per-demon metrics, and TD errors

Source code in src/alberta_framework/core/horde.py
def run_horde_learning_loop(
    horde: HordeLearner,
    state: MultiHeadMLPState,
    observations: Array,
    cumulants: Array,
    next_observations: Array,
) -> HordeLearningResult:
    """Run Horde learning loop using ``jax.lax.scan``.

    Scans over ``(obs, cumulants, next_obs)`` triples.

    Args:
        horde: Horde learner
        state: Initial learner state
        observations: Input observations, shape ``(num_steps, feature_dim)``
        cumulants: Per-demon cumulants, shape ``(num_steps, n_demons)``.
            NaN = inactive demon for that step.
        next_observations: Next observations, shape ``(num_steps, feature_dim)``

    Returns:
        HordeLearningResult with final state, per-demon metrics, and TD errors
    """

    def step_fn(
        carry: MultiHeadMLPState,
        inputs: tuple[Array, Array, Array],
    ) -> tuple[MultiHeadMLPState, tuple[Array, Array]]:
        l_state = carry
        obs, cums, next_obs = inputs
        result = horde.update(l_state, obs, cums, next_obs)
        return result.state, (result.per_demon_metrics, result.td_errors)

    t0 = time.time()
    final_state, (per_demon_metrics, td_errors) = jax.lax.scan(
        step_fn, state, (observations, cumulants, next_observations)
    )
    elapsed = time.time() - t0
    final_state = final_state.replace(uptime_s=final_state.uptime_s + elapsed)  # type: ignore[attr-defined]

    return HordeLearningResult(  # type: ignore[call-arg]
        state=final_state,
        per_demon_metrics=per_demon_metrics,
        td_errors=td_errors,
    )

run_horde_learning_loop_batched(horde, observations, cumulants, next_observations, keys)

Run Horde learning loop across seeds using jax.vmap.

Each seed produces an independently initialized state. All seeds share the same observations, cumulants, and next observations.

Args:
    horde: Horde learner
    observations: Shared observations, shape (num_steps, feature_dim)
    cumulants: Shared cumulants, shape (num_steps, n_demons)
    next_observations: Shared next observations, shape (num_steps, feature_dim)
    keys: JAX random keys, shape (n_seeds,) or (n_seeds, 2)

Returns: BatchedHordeResult with batched states, per-demon metrics, and TD errors

Source code in src/alberta_framework/core/horde.py
def run_horde_learning_loop_batched(
    horde: HordeLearner,
    observations: Array,
    cumulants: Array,
    next_observations: Array,
    keys: Array,
) -> BatchedHordeResult:
    """Run Horde learning loop across seeds using ``jax.vmap``.

    Each seed produces an independently initialized state. All seeds
    share the same observations, cumulants, and next observations.

    Args:
        horde: Horde learner
        observations: Shared observations, shape ``(num_steps, feature_dim)``
        cumulants: Shared cumulants, shape ``(num_steps, n_demons)``
        next_observations: Shared next observations,
            shape ``(num_steps, feature_dim)``
        keys: JAX random keys, shape ``(n_seeds,)`` or ``(n_seeds, 2)``

    Returns:
        BatchedHordeResult with batched states, per-demon metrics, and TD errors
    """
    feature_dim = observations.shape[1]

    def single_run(key: Array) -> tuple[MultiHeadMLPState, Array, Array]:
        init_state = horde.init(feature_dim, key)
        result = run_horde_learning_loop(
            horde, init_state, observations, cumulants, next_observations
        )
        return result.state, result.per_demon_metrics, result.td_errors

    t0 = time.time()
    batched_states, batched_metrics, batched_td_errors = jax.vmap(single_run)(keys)
    elapsed = time.time() - t0
    batched_states = batched_states.replace(  # type: ignore[attr-defined]
        uptime_s=batched_states.uptime_s + elapsed
    )

    return BatchedHordeResult(  # type: ignore[call-arg]
        states=batched_states,
        per_demon_metrics=batched_metrics,
        td_errors=batched_td_errors,
    )

sparse_init(key, shape, sparsity=0.9, init_type='uniform')

Create a sparsely initialized weight matrix.

Applies LeCun-scale initialization and then zeros out a fraction of weights per output neuron. This creates sparser gradient flows that improve stability in streaming learning settings.

Reference: Elsayed et al. 2024, sparse_init.py

Args:
    key: JAX random key
    shape: Weight matrix shape (fan_out, fan_in)
    sparsity: Fraction of input connections to zero out per output neuron (default: 0.9 means 90% sparse)
    init_type: Initialization distribution, "uniform" or "normal" (default: "uniform" for LeCun uniform)

Returns: Weight matrix of given shape with specified sparsity

Examples:

import jax.random as jr
from alberta_framework.core.initializers import sparse_init

key = jr.key(42)
weights = sparse_init(key, (128, 10), sparsity=0.9)
# weights has shape (128, 10), ~90% zeros per row
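The per-row masking guarantees an exact zero count for every output neuron. A numpy sketch of the same rounding and mask logic, independent of JAX:

```python
import numpy as np

rng = np.random.default_rng(0)
fan_out, fan_in, sparsity = 4, 10, 0.9
num_zeros = int(sparsity * fan_in + 0.5)  # round to nearest int: 9

# One independent permutation per output neuron; keep the positions
# whose permuted index lands at or above num_zeros.
masks = np.stack([
    (rng.permutation(fan_in) >= num_zeros).astype(np.float32)
    for _ in range(fan_out)
])
# Every row keeps exactly fan_in - num_zeros connections.
print(masks.sum(axis=1))  # [1. 1. 1. 1.]
```

Because the permutation is uniform, which connections survive is random, but how many survive per neuron is deterministic.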

Source code in src/alberta_framework/core/initializers.py
def sparse_init(
    key: Array,
    shape: tuple[int, int],
    sparsity: float = 0.9,
    init_type: str = "uniform",
) -> Float[Array, "fan_out fan_in"]:
    """Create a sparsely initialized weight matrix.

    Applies LeCun-scale initialization and then zeros out a fraction of
    weights per output neuron. This creates sparser gradient flows that
    improve stability in streaming learning settings.

    Reference: Elsayed et al. 2024, sparse_init.py

    Args:
        key: JAX random key
        shape: Weight matrix shape (fan_out, fan_in)
        sparsity: Fraction of input connections to zero out per output neuron
            (default: 0.9 means 90% sparse)
        init_type: Initialization distribution, "uniform" or "normal"
            (default: "uniform" for LeCun uniform)

    Returns:
        Weight matrix of given shape with specified sparsity

    Examples:
    ```python
    import jax.random as jr
    from alberta_framework.core.initializers import sparse_init

    key = jr.key(42)
    weights = sparse_init(key, (128, 10), sparsity=0.9)
    # weights has shape (128, 10), ~90% zeros per row
    ```
    """
    fan_out, fan_in = shape
    num_zeros = int(sparsity * fan_in + 0.5)  # round to nearest int

    # Split key for init and sparsity mask
    init_key, mask_key = jr.split(key)

    # LeCun-scale initialization
    scale = 1.0 / fan_in**0.5
    if init_type == "uniform":
        weights = jr.uniform(init_key, shape, dtype=jnp.float32, minval=-scale, maxval=scale)
    elif init_type == "normal":
        weights = jr.normal(init_key, shape, dtype=jnp.float32) * scale
    else:
        raise ValueError(f"init_type must be 'uniform' or 'normal', got '{init_type}'")

    # Create sparsity mask: for each output neuron, zero out num_zeros inputs
    # Use vmap over output neurons with independent random permutations
    row_keys = jr.split(mask_key, fan_out)

    def make_row_mask(row_key: Array) -> Float[Array, " fan_in"]:
        """Create a binary mask for a single output neuron."""
        perm = jr.permutation(row_key, fan_in)
        # mask[i] = 1 if perm[i] >= num_zeros, else 0
        mask = (perm >= num_zeros).astype(jnp.float32)
        return mask

    masks = jax.vmap(make_row_mask)(row_keys)  # (fan_out, fan_in)

    return weights * masks

metrics_to_dicts(metrics, normalized=False)

Convert metrics array to list of dicts for backward compatibility.

Args:
    metrics: Array of shape (num_steps, 3) or (num_steps, 4)
    normalized: If True, expects 4 columns including normalizer_mean_var

Returns: List of metric dictionaries

Source code in src/alberta_framework/core/learners.py
def metrics_to_dicts(metrics: Array, normalized: bool = False) -> list[dict[str, float]]:
    """Convert metrics array to list of dicts for backward compatibility.

    Args:
        metrics: Array of shape (num_steps, 3) or (num_steps, 4)
        normalized: If True, expects 4 columns including normalizer_mean_var

    Returns:
        List of metric dictionaries
    """
    result = []
    for row in metrics:
        d = {
            "squared_error": float(row[0]),
            "error": float(row[1]),
            "mean_step_size": float(row[2]),
        }
        if normalized and len(row) > 3:
            d["normalizer_mean_var"] = float(row[3])
        result.append(d)
    return result

run_learning_loop(learner, stream, num_steps, key, learner_state=None, step_size_tracking=None, normalizer_tracking=None)

Run the learning loop using jax.lax.scan.

This is a JIT-compiled learning loop that uses scan for efficiency. It returns metrics as a fixed-size array rather than a list of dicts.

Supports both plain and normalized learners. When the learner has a normalizer, metrics have 4 columns; otherwise 3 columns.

Args:
    learner: The learner to train
    stream: Experience stream providing (observation, target) pairs
    num_steps: Number of learning steps to run
    key: JAX random key for stream initialization
    learner_state: Initial state (if None, will be initialized from stream)
    step_size_tracking: Optional config for recording per-weight step-sizes. When provided, returns StepSizeHistory.
    normalizer_tracking: Optional config for recording per-feature normalizer state. When provided, returns NormalizerHistory with means and variances over time.

Returns:
    If no tracking: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 3) or (num_steps, 4) depending on normalizer
    If step_size_tracking only: Tuple of (final_state, metrics_array, step_size_history)
    If normalizer_tracking only: Tuple of (final_state, metrics_array, normalizer_history)
    If both: Tuple of (final_state, metrics_array, step_size_history, normalizer_history)

Raises: ValueError: If tracking interval is invalid
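The carry-threading pattern behind jax.lax.scan can be sketched in pure Python. A toy scalar LMS update stands in for learner.update here; the real loop carries full learner and stream states:

```python
def scan(step_fn, carry, xs):
    """Pure-Python stand-in for jax.lax.scan: thread a carry, stack outputs."""
    ys = []
    for x in xs:
        carry, y = step_fn(carry, x)
        ys.append(y)
    return carry, ys


def step_fn(weight, target):
    # Hypothetical scalar LMS update with a fixed step-size of 0.1
    error = target - weight
    return weight + 0.1 * error, error ** 2


final_weight, squared_errors = scan(step_fn, 0.0, [1.0, 1.0, 1.0])
# squared_errors shrinks each step as the weight tracks the target
```

jax.lax.scan compiles this pattern into a single fused loop on the accelerator, which is why metrics come back as one fixed-size array rather than a Python list.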

Source code in src/alberta_framework/core/learners.py
def run_learning_loop[StreamStateT](
    learner: LinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    key: Array,
    learner_state: LearnerState | None = None,
    step_size_tracking: StepSizeTrackingConfig | None = None,
    normalizer_tracking: NormalizerTrackingConfig | None = None,
) -> (
    tuple[LearnerState, Array]
    | tuple[LearnerState, Array, StepSizeHistory]
    | tuple[LearnerState, Array, NormalizerHistory]
    | tuple[LearnerState, Array, StepSizeHistory, NormalizerHistory]
):
    """Run the learning loop using jax.lax.scan.

    This is a JIT-compiled learning loop that uses scan for efficiency.
    It returns metrics as a fixed-size array rather than a list of dicts.

    Supports both plain and normalized learners. When the learner has a
    normalizer, metrics have 4 columns; otherwise 3 columns.

    Args:
        learner: The learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run
        key: JAX random key for stream initialization
        learner_state: Initial state (if None, will be initialized from stream)
        step_size_tracking: Optional config for recording per-weight step-sizes.
            When provided, returns StepSizeHistory.
        normalizer_tracking: Optional config for recording per-feature normalizer
            state. When provided, returns NormalizerHistory with means and
            variances over time.

    Returns:
        If no tracking:
            Tuple of (final_state, metrics_array) where metrics_array has shape
            (num_steps, 3) or (num_steps, 4) depending on normalizer
        If step_size_tracking only:
            Tuple of (final_state, metrics_array, step_size_history)
        If normalizer_tracking only:
            Tuple of (final_state, metrics_array, normalizer_history)
        If both:
            Tuple of (final_state, metrics_array, step_size_history, normalizer_history)

    Raises:
        ValueError: If tracking interval is invalid
    """
    # Validate tracking configs
    if step_size_tracking is not None:
        if step_size_tracking.interval < 1:
            raise ValueError(
                f"step_size_tracking.interval must be >= 1, got {step_size_tracking.interval}"
            )
        if step_size_tracking.interval > num_steps:
            raise ValueError(
                f"step_size_tracking.interval ({step_size_tracking.interval}) "
                f"must be <= num_steps ({num_steps})"
            )

    if normalizer_tracking is not None:
        if normalizer_tracking.interval < 1:
            raise ValueError(
                f"normalizer_tracking.interval must be >= 1, got {normalizer_tracking.interval}"
            )
        if normalizer_tracking.interval > num_steps:
            raise ValueError(
                f"normalizer_tracking.interval ({normalizer_tracking.interval}) "
                f"must be <= num_steps ({num_steps})"
            )

    # Initialize states
    if learner_state is None:
        learner_state = learner.init(stream.feature_dim)
    stream_state = stream.init(key)

    feature_dim = stream.feature_dim

    # No tracking - simple case
    if step_size_tracking is None and normalizer_tracking is None:

        def step_fn(
            carry: tuple[LearnerState, StreamStateT], idx: Array
        ) -> tuple[tuple[LearnerState, StreamStateT], Array]:
            l_state, s_state = carry
            timestep, new_s_state = stream.step(s_state, idx)
            result = learner.update(l_state, timestep.observation, timestep.target)
            return (result.state, new_s_state), result.metrics

        t0 = time.time()
        (final_learner, _), metrics = jax.lax.scan(
            step_fn, (learner_state, stream_state), jnp.arange(num_steps)
        )
        elapsed = time.time() - t0
        final_learner = final_learner.replace(uptime_s=final_learner.uptime_s + elapsed)  # type: ignore[attr-defined]

        return final_learner, metrics

    # Tracking enabled - need to set up history arrays
    ss_interval = step_size_tracking.interval if step_size_tracking else num_steps + 1
    norm_interval = normalizer_tracking.interval if normalizer_tracking else num_steps + 1

    ss_num_recordings = num_steps // ss_interval if step_size_tracking else 0
    norm_num_recordings = num_steps // norm_interval if normalizer_tracking else 0

    # Pre-allocate step-size history arrays
    ss_history = (
        jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
        if step_size_tracking
        else None
    )
    ss_bias_history = (
        jnp.zeros(ss_num_recordings, dtype=jnp.float32)
        if step_size_tracking and step_size_tracking.include_bias
        else None
    )
    ss_rec_indices = jnp.zeros(ss_num_recordings, dtype=jnp.int32) if step_size_tracking else None

    # Check if we need to track Autostep normalizers
    track_autostep_normalizers = hasattr(learner_state.optimizer_state, "normalizers")
    ss_normalizers = (
        jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
        if step_size_tracking and track_autostep_normalizers
        else None
    )

    # Pre-allocate normalizer state history arrays
    norm_means = (
        jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
        if normalizer_tracking
        else None
    )
    norm_vars = (
        jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
        if normalizer_tracking
        else None
    )
    norm_rec_indices = (
        jnp.zeros(norm_num_recordings, dtype=jnp.int32) if normalizer_tracking else None
    )

    def step_fn_with_tracking(
        carry: tuple[
            LearnerState,
            StreamStateT,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
        ],
        idx: Array,
    ) -> tuple[
        tuple[
            LearnerState,
            StreamStateT,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
        ],
        Array,
    ]:
        (
            l_state,
            s_state,
            ss_hist,
            ss_bias_hist,
            ss_rec,
            ss_norm,
            n_means,
            n_vars,
            n_rec,
        ) = carry

        # Perform learning step
        timestep, new_s_state = stream.step(s_state, idx)
        result = learner.update(l_state, timestep.observation, timestep.target)

        # Step-size tracking
        new_ss_hist = ss_hist
        new_ss_bias_hist = ss_bias_hist
        new_ss_rec = ss_rec
        new_ss_norm = ss_norm

        if ss_hist is not None:
            should_record_ss = (idx % ss_interval) == 0
            recording_idx = idx // ss_interval

            # Extract current step-sizes
            opt_state = result.state.optimizer_state
            if hasattr(opt_state, "log_step_sizes"):
                # IDBD stores log step-sizes
                weight_ss = jnp.exp(opt_state.log_step_sizes)
                bias_ss = opt_state.bias_step_size
            elif hasattr(opt_state, "step_sizes"):
                # Autostep stores step-sizes directly
                weight_ss = opt_state.step_sizes
                bias_ss = opt_state.bias_step_size
            else:
                # LMS has a single fixed step-size
                weight_ss = jnp.full(feature_dim, opt_state.step_size)
                bias_ss = opt_state.step_size

            new_ss_hist = jax.lax.cond(
                should_record_ss,
                lambda _: ss_hist.at[recording_idx].set(weight_ss),
                lambda _: ss_hist,
                None,
            )

            if ss_bias_hist is not None:
                new_ss_bias_hist = jax.lax.cond(
                    should_record_ss,
                    lambda _: ss_bias_hist.at[recording_idx].set(bias_ss),
                    lambda _: ss_bias_hist,
                    None,
                )

            if ss_rec is not None:
                new_ss_rec = jax.lax.cond(
                    should_record_ss,
                    lambda _: ss_rec.at[recording_idx].set(idx),
                    lambda _: ss_rec,
                    None,
                )

            # Track Autostep normalizers (v_i) if applicable
            if ss_norm is not None and hasattr(opt_state, "normalizers"):
                new_ss_norm = jax.lax.cond(
                    should_record_ss,
                    lambda _: ss_norm.at[recording_idx].set(opt_state.normalizers),
                    lambda _: ss_norm,
                    None,
                )

        # Normalizer state tracking
        new_n_means = n_means
        new_n_vars = n_vars
        new_n_rec = n_rec

        if n_means is not None:
            should_record_norm = (idx % norm_interval) == 0
            norm_recording_idx = idx // norm_interval

            norm_state = result.state.normalizer_state

            new_n_means = jax.lax.cond(
                should_record_norm,
                lambda _: n_means.at[norm_recording_idx].set(norm_state.mean),
                lambda _: n_means,
                None,
            )

            if n_vars is not None:
                new_n_vars = jax.lax.cond(
                    should_record_norm,
                    lambda _: n_vars.at[norm_recording_idx].set(norm_state.var),
                    lambda _: n_vars,
                    None,
                )

            if n_rec is not None:
                new_n_rec = jax.lax.cond(
                    should_record_norm,
                    lambda _: n_rec.at[norm_recording_idx].set(idx),
                    lambda _: n_rec,
                    None,
                )

        return (
            result.state,
            new_s_state,
            new_ss_hist,
            new_ss_bias_hist,
            new_ss_rec,
            new_ss_norm,
            new_n_means,
            new_n_vars,
            new_n_rec,
        ), result.metrics

    initial_carry = (
        learner_state,
        stream_state,
        ss_history,
        ss_bias_history,
        ss_rec_indices,
        ss_normalizers,
        norm_means,
        norm_vars,
        norm_rec_indices,
    )

    t0 = time.time()
    (
        (
            final_learner,
            _,
            final_ss_hist,
            final_ss_bias_hist,
            final_ss_rec,
            final_ss_norm,
            final_n_means,
            final_n_vars,
            final_n_rec,
        ),
        metrics,
    ) = jax.lax.scan(step_fn_with_tracking, initial_carry, jnp.arange(num_steps))
    elapsed = time.time() - t0
    final_learner = final_learner.replace(uptime_s=final_learner.uptime_s + elapsed)  # type: ignore[attr-defined]

    # Build return values based on what was tracked
    ss_history_result = None
    if step_size_tracking is not None and final_ss_hist is not None:
        ss_history_result = StepSizeHistory(
            step_sizes=final_ss_hist,
            bias_step_sizes=final_ss_bias_hist,
            recording_indices=final_ss_rec,
            normalizers=final_ss_norm,
        )

    norm_history_result = None
    if normalizer_tracking is not None and final_n_means is not None:
        norm_history_result = NormalizerHistory(
            means=final_n_means,
            variances=final_n_vars,
            recording_indices=final_n_rec,
        )

    # Return appropriate tuple based on what was tracked
    if ss_history_result is not None and norm_history_result is not None:
        return final_learner, metrics, ss_history_result, norm_history_result
    elif ss_history_result is not None:
        return final_learner, metrics, ss_history_result
    elif norm_history_result is not None:
        return final_learner, metrics, norm_history_result
    else:
        return final_learner, metrics
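The tracking branch above relies on a common JAX idiom: pre-allocate a history buffer with `num_steps // interval` rows, then conditionally write into it with `jax.lax.cond` inside the scan body. A minimal standalone sketch of that idiom, with a toy tracked value in place of the learner state (all values here are illustrative):

```python
import jax
import jax.numpy as jnp

num_steps, interval = 12, 3
num_recordings = num_steps // interval  # 4 pre-allocated rows

def step_fn(carry, idx):
    history, rec_indices = carry
    value = jnp.float32(idx) * 0.5           # stand-in for a tracked quantity
    should_record = (idx % interval) == 0
    rec_idx = idx // interval
    # Functional conditional update: both branches return an array of the
    # same shape, so the scan carry structure stays fixed.
    history = jax.lax.cond(
        should_record,
        lambda _: history.at[rec_idx].set(value),
        lambda _: history,
        None,
    )
    rec_indices = jax.lax.cond(
        should_record,
        lambda _: rec_indices.at[rec_idx].set(idx),
        lambda _: rec_indices,
        None,
    )
    return (history, rec_indices), value

init = (jnp.zeros(num_recordings), jnp.zeros(num_recordings, dtype=jnp.int32))
(history, rec_indices), _ = jax.lax.scan(step_fn, init, jnp.arange(num_steps))
# rec_indices holds the steps at which a recording was taken: 0, 3, 6, 9
```

Because `jax.lax.cond` requires both branches to produce identically shaped outputs, the "skip" branch returns the buffer unchanged rather than writing nothing.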

run_learning_loop_batched(learner, stream, num_steps, keys, learner_state=None, step_size_tracking=None, normalizer_tracking=None)

Run learning loop across multiple seeds in parallel using jax.vmap.

This function provides GPU parallelization for multi-seed experiments, typically achieving 2-5x speedup over sequential execution.

Supports both plain and normalized learners.

Args:
    learner: The learner to train
    stream: Experience stream providing (observation, target) pairs
    num_steps: Number of learning steps to run per seed
    keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
    learner_state: Initial state (if None, will be initialized from stream). The same initial state is used for all seeds.
    step_size_tracking: Optional config for recording per-weight step-sizes. When provided, history arrays have shape (num_seeds, num_recordings, ...)
    normalizer_tracking: Optional config for recording normalizer state. When provided, history arrays have shape (num_seeds, num_recordings, ...)

Returns:
    BatchedLearningResult containing:
    - states: Batched final states with shape (num_seeds, ...) for each array
    - metrics: Array of shape (num_seeds, num_steps, num_cols)
    - step_size_history: Batched history or None if tracking disabled
    - normalizer_history: Batched history or None if tracking disabled

Examples:

import jax.random as jr
from alberta_framework import LinearLearner, IDBD, RandomWalkStream
from alberta_framework import run_learning_loop_batched

stream = RandomWalkStream(feature_dim=10)
learner = LinearLearner(optimizer=IDBD())

# Run 30 seeds in parallel
keys = jr.split(jr.key(42), 30)
result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)

# result.metrics has shape (30, 10000, 3)
mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds

Source code in src/alberta_framework/core/learners.py
def run_learning_loop_batched[StreamStateT](
    learner: LinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    keys: Array,
    learner_state: LearnerState | None = None,
    step_size_tracking: StepSizeTrackingConfig | None = None,
    normalizer_tracking: NormalizerTrackingConfig | None = None,
) -> BatchedLearningResult:
    """Run learning loop across multiple seeds in parallel using jax.vmap.

    This function provides GPU parallelization for multi-seed experiments,
    typically achieving 2-5x speedup over sequential execution.

    Supports both plain and normalized learners.

    Args:
        learner: The learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run per seed
        keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
        learner_state: Initial state (if None, will be initialized from stream).
            The same initial state is used for all seeds.
        step_size_tracking: Optional config for recording per-weight step-sizes.
            When provided, history arrays have shape (num_seeds, num_recordings, ...)
        normalizer_tracking: Optional config for recording normalizer state.
            When provided, history arrays have shape (num_seeds, num_recordings, ...)

    Returns:
        BatchedLearningResult containing:
            - states: Batched final states with shape (num_seeds, ...) for each array
            - metrics: Array of shape (num_seeds, num_steps, num_cols)
            - step_size_history: Batched history or None if tracking disabled
            - normalizer_history: Batched history or None if tracking disabled

    Examples:
    ```python
    import jax.random as jr
    from alberta_framework import LinearLearner, IDBD, RandomWalkStream
    from alberta_framework import run_learning_loop_batched

    stream = RandomWalkStream(feature_dim=10)
    learner = LinearLearner(optimizer=IDBD())

    # Run 30 seeds in parallel
    keys = jr.split(jr.key(42), 30)
    result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)

    # result.metrics has shape (30, 10000, 3)
    mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
    ```
    """

    # Define single-seed function that returns consistent structure
    def single_seed_run(
        key: Array,
    ) -> tuple[LearnerState, Array, StepSizeHistory | None, NormalizerHistory | None]:
        result = run_learning_loop(
            learner, stream, num_steps, key, learner_state,
            step_size_tracking, normalizer_tracking,
        )

        # Unpack based on what tracking was enabled
        if step_size_tracking is not None and normalizer_tracking is not None:
            state, metrics, ss_history, norm_history = cast(
                tuple[LearnerState, Array, StepSizeHistory, NormalizerHistory],
                result,
            )
            return state, metrics, ss_history, norm_history
        elif step_size_tracking is not None:
            state, metrics, ss_history = cast(
                tuple[LearnerState, Array, StepSizeHistory], result
            )
            return state, metrics, ss_history, None
        elif normalizer_tracking is not None:
            state, metrics, norm_history = cast(
                tuple[LearnerState, Array, NormalizerHistory], result
            )
            return state, metrics, None, norm_history
        else:
            state, metrics = cast(tuple[LearnerState, Array], result)
            return state, metrics, None, None

    # vmap over the keys dimension
    t0 = time.time()
    batched_states, batched_metrics, batched_ss_history, batched_norm_history = jax.vmap(
        single_seed_run
    )(keys)
    elapsed = time.time() - t0
    batched_states = batched_states.replace(  # type: ignore[attr-defined]
        uptime_s=batched_states.uptime_s + elapsed
    )

    # Reconstruct batched histories if tracking was enabled
    if step_size_tracking is not None and batched_ss_history is not None:
        batched_step_size_history = StepSizeHistory(
            step_sizes=batched_ss_history.step_sizes,
            bias_step_sizes=batched_ss_history.bias_step_sizes,
            recording_indices=batched_ss_history.recording_indices,
            normalizers=batched_ss_history.normalizers,
        )
    else:
        batched_step_size_history = None

    if normalizer_tracking is not None and batched_norm_history is not None:
        batched_normalizer_history = NormalizerHistory(
            means=batched_norm_history.means,
            variances=batched_norm_history.variances,
            recording_indices=batched_norm_history.recording_indices,
        )
    else:
        batched_normalizer_history = None

    return BatchedLearningResult(
        states=batched_states,
        metrics=batched_metrics,
        step_size_history=batched_step_size_history,
        normalizer_history=batched_normalizer_history,
    )

run_mlp_learning_loop(learner, stream, num_steps, key, learner_state=None, normalizer_tracking=None)

Run the MLP learning loop using jax.lax.scan.

This is a JIT-compiled learning loop that uses scan for efficiency.

Args:
    learner: The MLP learner to train
    stream: Experience stream providing (observation, target) pairs
    num_steps: Number of learning steps to run
    key: JAX random key for stream and weight initialization
    learner_state: Initial state (if None, will be initialized from stream)
    normalizer_tracking: Optional config for recording per-feature normalizer state. When provided, returns NormalizerHistory.

Returns:
    If no tracking: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 3) or (num_steps, 4)
    If normalizer_tracking: Tuple of (final_state, metrics_array, normalizer_history)

Raises:
    ValueError: If normalizer_tracking.interval is invalid
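Because the return arity depends on whether tracking is enabled, callers unpack conditionally. A toy sketch of that dispatch pattern (not the real API — the stand-in `run` below only mimics the arity and the `num_steps // interval` recording count):

```python
def run(num_steps, interval=None):
    """Stand-in: (state, metrics) without tracking, plus history with it."""
    state = {"steps_seen": num_steps}    # hypothetical final state
    metrics = list(range(num_steps))     # one metric row per step
    if interval is None:
        return state, metrics
    # One history row per interval: recorded when idx % interval == 0
    history = [idx for idx in range(num_steps) if idx % interval == 0]
    return state, metrics, history

result = run(12)
assert len(result) == 2                  # no tracking: 2-tuple
state, metrics, history = run(12, interval=3)
assert history == [0, 3, 6, 9]           # num_steps // interval == 4 rows
```

When the interval divides `num_steps` evenly, the number of recordings is exactly `num_steps // interval`, matching the pre-allocated history length.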

Source code in src/alberta_framework/core/learners.py
def run_mlp_learning_loop[StreamStateT](
    learner: MLPLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    key: Array,
    learner_state: MLPLearnerState | None = None,
    normalizer_tracking: NormalizerTrackingConfig | None = None,
) -> (
    tuple[MLPLearnerState, Array]
    | tuple[MLPLearnerState, Array, NormalizerHistory]
):
    """Run the MLP learning loop using jax.lax.scan.

    This is a JIT-compiled learning loop that uses scan for efficiency.

    Args:
        learner: The MLP learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run
        key: JAX random key for stream and weight initialization
        learner_state: Initial state (if None, will be initialized from stream)
        normalizer_tracking: Optional config for recording per-feature normalizer
            state. When provided, returns NormalizerHistory.

    Returns:
        If no tracking:
            Tuple of (final_state, metrics_array) where metrics_array has shape
            (num_steps, 3) or (num_steps, 4)
        If normalizer_tracking:
            Tuple of (final_state, metrics_array, normalizer_history)

    Raises:
        ValueError: If normalizer_tracking.interval is invalid
    """
    # Validate tracking config
    if normalizer_tracking is not None:
        if normalizer_tracking.interval < 1:
            raise ValueError(
                f"normalizer_tracking.interval must be >= 1, got {normalizer_tracking.interval}"
            )
        if normalizer_tracking.interval > num_steps:
            raise ValueError(
                f"normalizer_tracking.interval ({normalizer_tracking.interval}) "
                f"must be <= num_steps ({num_steps})"
            )

    # Split key for initialization
    stream_key, init_key = jax.random.split(key)

    # Initialize states
    if learner_state is None:
        learner_state = learner.init(stream.feature_dim, init_key)
    stream_state = stream.init(stream_key)

    feature_dim = stream.feature_dim

    if normalizer_tracking is None:
        # Simple case without tracking
        def step_fn(
            carry: tuple[MLPLearnerState, StreamStateT], idx: Array
        ) -> tuple[tuple[MLPLearnerState, StreamStateT], Array]:
            l_state, s_state = carry
            timestep, new_s_state = stream.step(s_state, idx)
            result = learner.update(l_state, timestep.observation, timestep.target)
            return (result.state, new_s_state), result.metrics

        t0 = time.time()
        (final_learner, _), metrics = jax.lax.scan(
            step_fn, (learner_state, stream_state), jnp.arange(num_steps)
        )
        elapsed = time.time() - t0
        final_learner = final_learner.replace(uptime_s=final_learner.uptime_s + elapsed)  # type: ignore[attr-defined]

        return final_learner, metrics

    # Tracking enabled
    norm_interval = normalizer_tracking.interval
    norm_num_recordings = num_steps // norm_interval

    norm_means = jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
    norm_vars = jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
    norm_rec_indices = jnp.zeros(norm_num_recordings, dtype=jnp.int32)

    def step_fn_with_tracking(
        carry: tuple[MLPLearnerState, StreamStateT, Array, Array, Array],
        idx: Array,
    ) -> tuple[
        tuple[MLPLearnerState, StreamStateT, Array, Array, Array],
        Array,
    ]:
        l_state, s_state, n_means, n_vars, n_rec = carry

        # Perform learning step
        timestep, new_s_state = stream.step(s_state, idx)
        result = learner.update(l_state, timestep.observation, timestep.target)

        # Normalizer state tracking
        should_record = (idx % norm_interval) == 0
        recording_idx = idx // norm_interval

        norm_state = result.state.normalizer_state

        new_n_means = jax.lax.cond(
            should_record,
            lambda _: n_means.at[recording_idx].set(norm_state.mean),
            lambda _: n_means,
            None,
        )

        new_n_vars = jax.lax.cond(
            should_record,
            lambda _: n_vars.at[recording_idx].set(norm_state.var),
            lambda _: n_vars,
            None,
        )

        new_n_rec = jax.lax.cond(
            should_record,
            lambda _: n_rec.at[recording_idx].set(idx),
            lambda _: n_rec,
            None,
        )

        return (
            result.state,
            new_s_state,
            new_n_means,
            new_n_vars,
            new_n_rec,
        ), result.metrics

    initial_carry = (
        learner_state,
        stream_state,
        norm_means,
        norm_vars,
        norm_rec_indices,
    )

    t0 = time.time()
    (
        (final_learner, _, final_n_means, final_n_vars, final_n_rec),
        metrics,
    ) = jax.lax.scan(step_fn_with_tracking, initial_carry, jnp.arange(num_steps))
    elapsed = time.time() - t0
    final_learner = final_learner.replace(uptime_s=final_learner.uptime_s + elapsed)  # type: ignore[attr-defined]

    norm_history = NormalizerHistory(
        means=final_n_means,
        variances=final_n_vars,
        recording_indices=final_n_rec,
    )

    return final_learner, metrics, norm_history

run_mlp_learning_loop_batched(learner, stream, num_steps, keys, learner_state=None, normalizer_tracking=None)

Run MLP learning loop across multiple seeds in parallel using jax.vmap.

This function provides GPU parallelization for multi-seed MLP experiments, typically achieving 2-5x speedup over sequential execution.

Args:
    learner: The MLP learner to train
    stream: Experience stream providing (observation, target) pairs
    num_steps: Number of learning steps to run per seed
    keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
    learner_state: Initial state (if None, will be initialized from stream). The same initial state is used for all seeds.
    normalizer_tracking: Optional config for recording normalizer state. When provided, history arrays have shape (num_seeds, num_recordings, ...)

Returns:
    BatchedMLPResult containing:
    - states: Batched final states with shape (num_seeds, ...) for each array
    - metrics: Array of shape (num_seeds, num_steps, num_cols)
    - normalizer_history: Batched history or None if tracking disabled

Examples:

import jax.random as jr
from alberta_framework import MLPLearner, RandomWalkStream
from alberta_framework import run_mlp_learning_loop_batched

stream = RandomWalkStream(feature_dim=10)
learner = MLPLearner(hidden_sizes=(128, 128))

# Run 30 seeds in parallel
keys = jr.split(jr.key(42), 30)
result = run_mlp_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)

# result.metrics has shape (30, 10000, 3)
mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds

Source code in src/alberta_framework/core/learners.py
def run_mlp_learning_loop_batched[StreamStateT](
    learner: MLPLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    keys: Array,
    learner_state: MLPLearnerState | None = None,
    normalizer_tracking: NormalizerTrackingConfig | None = None,
) -> BatchedMLPResult:
    """Run MLP learning loop across multiple seeds in parallel using jax.vmap.

    This function provides GPU parallelization for multi-seed MLP experiments,
    typically achieving 2-5x speedup over sequential execution.

    Args:
        learner: The MLP learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run per seed
        keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
        learner_state: Initial state (if None, will be initialized from stream).
            The same initial state is used for all seeds.
        normalizer_tracking: Optional config for recording normalizer state.
            When provided, history arrays have shape (num_seeds, num_recordings, ...)

    Returns:
        BatchedMLPResult containing:
            - states: Batched final states with shape (num_seeds, ...) for each array
            - metrics: Array of shape (num_seeds, num_steps, num_cols)
            - normalizer_history: Batched history or None if tracking disabled

    Examples:
    ```python
    import jax.random as jr
    from alberta_framework import MLPLearner, RandomWalkStream
    from alberta_framework import run_mlp_learning_loop_batched

    stream = RandomWalkStream(feature_dim=10)
    learner = MLPLearner(hidden_sizes=(128, 128))

    # Run 30 seeds in parallel
    keys = jr.split(jr.key(42), 30)
    result = run_mlp_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)

    # result.metrics has shape (30, 10000, 3)
    mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
    ```
    """

    def single_seed_run(
        key: Array,
    ) -> tuple[MLPLearnerState, Array, NormalizerHistory | None]:
        result = run_mlp_learning_loop(
            learner, stream, num_steps, key, learner_state, normalizer_tracking
        )

        if normalizer_tracking is not None:
            state, metrics, norm_history = cast(
                tuple[MLPLearnerState, Array, NormalizerHistory], result
            )
            return state, metrics, norm_history
        else:
            state, metrics = cast(tuple[MLPLearnerState, Array], result)
            return state, metrics, None

    t0 = time.time()
    batched_states, batched_metrics, batched_norm_history = jax.vmap(single_seed_run)(keys)
    elapsed = time.time() - t0
    batched_states = batched_states.replace(  # type: ignore[attr-defined]
        uptime_s=batched_states.uptime_s + elapsed
    )

    if normalizer_tracking is not None and batched_norm_history is not None:
        batched_normalizer_history = NormalizerHistory(
            means=batched_norm_history.means,
            variances=batched_norm_history.variances,
            recording_indices=batched_norm_history.recording_indices,
        )
    else:
        batched_normalizer_history = None

    return BatchedMLPResult(
        states=batched_states,
        metrics=batched_metrics,
        normalizer_history=batched_normalizer_history,
    )

run_td_learning_loop(learner, stream, num_steps, key, learner_state=None)

Run the TD learning loop using jax.lax.scan.

This is a JIT-compiled learning loop that uses scan for efficiency. It returns metrics as a fixed-size array rather than a list of dicts.

Args:
    learner: The TD learner to train
    stream: TD experience stream providing (s, r, s', gamma) tuples
    num_steps: Number of learning steps to run
    key: JAX random key for stream initialization
    learner_state: Initial state (if None, will be initialized from stream)

Returns:
    Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 4) with columns [squared_td_error, td_error, mean_step_size, mean_eligibility_trace]
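The four metrics columns can be sliced directly from the returned array, e.g. to build a smoothed squared-TD-error learning curve. A sketch on a synthetic stand-in array of the documented shape (the values below are made up for illustration):

```python
import jax.numpy as jnp

num_steps = 8
# Synthetic stand-in for the (num_steps, 4) TD metrics array, with columns
# [squared_td_error, td_error, mean_step_size, mean_eligibility_trace]
metrics = jnp.stack(
    [
        jnp.linspace(1.0, 0.3, num_steps),   # squared TD error, decaying
        jnp.linspace(-1.0, 0.1, num_steps),  # signed TD error
        jnp.full(num_steps, 0.01),           # mean step-size
        jnp.full(num_steps, 0.5),            # mean eligibility trace
    ],
    axis=1,
)

sq_td = metrics[:, 0]
window = 4
# Trailing moving average via convolution with a uniform kernel
smoothed = jnp.convolve(sq_td, jnp.ones(window) / window, mode="valid")
assert metrics.shape == (num_steps, 4)
assert smoothed.shape == (num_steps - window + 1,)
```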

Source code in src/alberta_framework/core/learners.py
def run_td_learning_loop[StreamStateT](
    learner: TDLinearLearner,
    stream: TDStream[StreamStateT],
    num_steps: int,
    key: Array,
    learner_state: TDLearnerState | None = None,
) -> tuple[TDLearnerState, Array]:
    """Run the TD learning loop using jax.lax.scan.

    This is a JIT-compiled learning loop that uses scan for efficiency.
    It returns metrics as a fixed-size array rather than a list of dicts.

    Args:
        learner: The TD learner to train
        stream: TD experience stream providing (s, r, s', gamma) tuples
        num_steps: Number of learning steps to run
        key: JAX random key for stream initialization
        learner_state: Initial state (if None, will be initialized from stream)

    Returns:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 4) with columns [squared_td_error, td_error, mean_step_size,
        mean_eligibility_trace]
    """
    # Initialize states
    if learner_state is None:
        learner_state = learner.init(stream.feature_dim)
    stream_state = stream.init(key)

    def step_fn(
        carry: tuple[TDLearnerState, StreamStateT], idx: Array
    ) -> tuple[tuple[TDLearnerState, StreamStateT], Array]:
        l_state, s_state = carry
        timestep, new_s_state = stream.step(s_state, idx)
        result = learner.update(
            l_state,
            timestep.observation,
            timestep.reward,
            timestep.next_observation,
            timestep.gamma,
        )
        return (result.state, new_s_state), result.metrics

    t0 = time.time()
    (final_learner, _), metrics = jax.lax.scan(
        step_fn, (learner_state, stream_state), jnp.arange(num_steps)
    )
    elapsed = time.time() - t0
    final_learner = final_learner.replace(uptime_s=final_learner.uptime_s + elapsed)  # type: ignore[attr-defined]

    return final_learner, metrics

multi_head_metrics_to_dicts(result)

Convert per-head metrics array to list of dicts for online use.

Active heads get a dict with keys 'squared_error', 'error', 'mean_step_size'. Inactive heads get None.

Args:
    result: Update result from MultiHeadMLPLearner.update

Returns:
    List of n_heads entries, one per head
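The NaN-to-None convention can be exercised in isolation. This sketch applies the same rule to a toy (n_heads, 3) metrics array; the real function reads `result.per_head_metrics` instead of a hand-built array:

```python
import math
import jax.numpy as jnp

# Toy per-head metrics: head 1 is inactive (NaN row), heads 0 and 2 active
per_head_metrics = jnp.array(
    [
        [0.25, -0.5, 0.01],
        [jnp.nan, jnp.nan, jnp.nan],
        [0.04, 0.2, 0.02],
    ]
)

output = []
for i in range(per_head_metrics.shape[0]):
    se = float(per_head_metrics[i, 0])
    if math.isnan(se):
        output.append(None)  # head was inactive on this step
    else:
        output.append(
            {
                "squared_error": se,
                "error": float(per_head_metrics[i, 1]),
                "mean_step_size": float(per_head_metrics[i, 2]),
            }
        )

assert output[1] is None
assert output[0]["squared_error"] == 0.25
```

NaN in the squared-error column is the sentinel for "no target this step", so only the first column needs checking.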

Source code in src/alberta_framework/core/multi_head_learner.py
def multi_head_metrics_to_dicts(
    result: MultiHeadMLPUpdateResult,
) -> list[dict[str, float] | None]:
    """Convert per-head metrics array to list of dicts for online use.

    Active heads get a dict with keys ``'squared_error'``, ``'error'``,
    ``'mean_step_size'``. Inactive heads get ``None``.

    Args:
        result: Update result from ``MultiHeadMLPLearner.update``

    Returns:
        List of ``n_heads`` entries, one per head
    """
    output: list[dict[str, float] | None] = []
    for i in range(result.per_head_metrics.shape[0]):
        se = float(result.per_head_metrics[i, 0])
        if math.isnan(se):
            output.append(None)
        else:
            output.append(
                {
                    "squared_error": se,
                    "error": float(result.per_head_metrics[i, 1]),
                    "mean_step_size": float(result.per_head_metrics[i, 2]),
                }
            )
    return output

run_multi_head_learning_loop(learner, state, observations, targets)

Run multi-head learning loop using jax.lax.scan.

Scans over pre-provided observation and target arrays. This is designed for settings where data comes from an external source (e.g. security event logs) rather than from a ScanStream.

Args:
    learner: Multi-head MLP learner
    state: Initial learner state
    observations: Input observations, shape (num_steps, feature_dim)
    targets: Per-head targets, shape (num_steps, n_heads). NaN = inactive head for that step.

Returns:
    MultiHeadLearningResult with final state and per-head metrics of shape (num_steps, n_heads, 3)
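Constructing the targets array with NaN marking inactive heads might look like the following sketch (toy shapes; the label values are arbitrary):

```python
import jax.numpy as jnp

num_steps, n_heads = 4, 3
# Start with every (step, head) pair inactive
targets = jnp.full((num_steps, n_heads), jnp.nan)
# Head 0 has a label at every step; head 2 only at the final step
targets = targets.at[:, 0].set(jnp.array([1.0, 0.0, 1.0, 0.0]))
targets = targets.at[3, 2].set(1.0)

active_mask = ~jnp.isnan(targets)  # which (step, head) pairs will update
assert int(active_mask.sum()) == 5
```

Each head updates only on steps where its target is non-NaN, so sparse per-head labels fit naturally into the fixed-shape scan.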

Source code in src/alberta_framework/core/multi_head_learner.py
def run_multi_head_learning_loop(
    learner: MultiHeadMLPLearner,
    state: MultiHeadMLPState,
    observations: Array,
    targets: Array,
) -> MultiHeadLearningResult:
    """Run multi-head learning loop using ``jax.lax.scan``.

    Scans over pre-provided observation and target arrays. This is
    designed for settings where data comes from an external source
    (e.g. security event logs) rather than from a ``ScanStream``.

    Args:
        learner: Multi-head MLP learner
        state: Initial learner state
        observations: Input observations, shape ``(num_steps, feature_dim)``
        targets: Per-head targets, shape ``(num_steps, n_heads)``.
            NaN = inactive head for that step.

    Returns:
        ``MultiHeadLearningResult`` with final state and per-head metrics
        of shape ``(num_steps, n_heads, 3)``
    """

    def step_fn(
        carry: MultiHeadMLPState, inputs: tuple[Array, Array]
    ) -> tuple[MultiHeadMLPState, Array]:
        l_state = carry
        obs, tgt = inputs
        result = learner.update(l_state, obs, tgt)
        return result.state, result.per_head_metrics

    t0 = time.time()
    final_state, per_head_metrics = jax.lax.scan(
        step_fn, state, (observations, targets)
    )
    elapsed = time.time() - t0
    final_state = final_state.replace(uptime_s=final_state.uptime_s + elapsed)  # type: ignore[attr-defined]

    return MultiHeadLearningResult(
        state=final_state,
        per_head_metrics=per_head_metrics,
    )

run_multi_head_learning_loop_batched(learner, observations, targets, keys)

Run multi-head learning loop across seeds using jax.vmap.

Each seed produces an independently initialized state (different sparse weight masks). All seeds share the same observations and targets.

Args:
    learner: Multi-head MLP learner
    observations: Shared observations, shape (num_steps, feature_dim)
    targets: Shared targets, shape (num_steps, n_heads). NaN = inactive head.
    keys: JAX random keys, shape (n_seeds,) or (n_seeds, 2)

Returns: BatchedMultiHeadResult with batched states and per-head metrics of shape (n_seeds, num_steps, n_heads, 3)

Source code in src/alberta_framework/core/multi_head_learner.py
def run_multi_head_learning_loop_batched(
    learner: MultiHeadMLPLearner,
    observations: Array,
    targets: Array,
    keys: Array,
) -> BatchedMultiHeadResult:
    """Run multi-head learning loop across seeds using ``jax.vmap``.

    Each seed produces an independently initialized state (different
    sparse weight masks). All seeds share the same observations and
    targets.

    Args:
        learner: Multi-head MLP learner
        observations: Shared observations, shape ``(num_steps, feature_dim)``
        targets: Shared targets, shape ``(num_steps, n_heads)``.
            NaN = inactive head.
        keys: JAX random keys, shape ``(n_seeds,)`` or ``(n_seeds, 2)``

    Returns:
        ``BatchedMultiHeadResult`` with batched states and per-head metrics
        of shape ``(n_seeds, num_steps, n_heads, 3)``
    """
    feature_dim = observations.shape[1]

    def single_run(key: Array) -> tuple[MultiHeadMLPState, Array]:
        init_state = learner.init(feature_dim, key)
        result = run_multi_head_learning_loop(
            learner, init_state, observations, targets
        )
        return result.state, result.per_head_metrics

    t0 = time.time()
    batched_states, batched_metrics = jax.vmap(single_run)(keys)
    elapsed = time.time() - t0
    batched_states = batched_states.replace(  # type: ignore[attr-defined]
        uptime_s=batched_states.uptime_s + elapsed
    )

    return BatchedMultiHeadResult(
        states=batched_states,
        per_head_metrics=batched_metrics,
    )

normalizer_from_config(config)

Reconstruct a normalizer from a config dict.

Args: config: Dict with "type" key and constructor kwargs

Returns: Reconstructed normalizer instance

Raises: ValueError: If the normalizer type is unknown

Source code in src/alberta_framework/core/normalizers.py
def normalizer_from_config(config: dict[str, Any]) -> Normalizer[Any]:
    """Reconstruct a normalizer from a config dict.

    Args:
        config: Dict with ``"type"`` key and constructor kwargs

    Returns:
        Reconstructed normalizer instance

    Raises:
        ValueError: If the normalizer type is unknown
    """
    config = dict(config)
    type_name = config.pop("type")
    cls = _NORMALIZER_REGISTRY.get(type_name)
    if cls is None:
        raise ValueError(f"Unknown normalizer type: {type_name!r}")
    result: Normalizer[Any] = cls(**config)
    return result

bounder_from_config(config)

Reconstruct a bounder from a config dict.

Args: config: Dict with "type" key and constructor kwargs

Returns: Reconstructed bounder instance

Raises: ValueError: If the bounder type is unknown

Source code in src/alberta_framework/core/optimizers.py
def bounder_from_config(config: dict[str, Any]) -> Bounder:
    """Reconstruct a bounder from a config dict.

    Args:
        config: Dict with ``"type"`` key and constructor kwargs

    Returns:
        Reconstructed bounder instance

    Raises:
        ValueError: If the bounder type is unknown
    """
    config = dict(config)
    type_name = config.pop("type")
    cls = _BOUNDER_REGISTRY.get(type_name)
    if cls is None:
        raise ValueError(f"Unknown bounder type: {type_name!r}")
    result: Bounder = cls(**config)
    return result

optimizer_from_config(config)

Reconstruct an optimizer from a config dict.

Args: config: Dict with "type" key and constructor kwargs

Returns: Reconstructed optimizer instance

Raises: ValueError: If the optimizer type is unknown

Source code in src/alberta_framework/core/optimizers.py
def optimizer_from_config(config: dict[str, Any]) -> Optimizer[Any]:
    """Reconstruct an optimizer from a config dict.

    Args:
        config: Dict with ``"type"`` key and constructor kwargs

    Returns:
        Reconstructed optimizer instance

    Raises:
        ValueError: If the optimizer type is unknown
    """
    config = dict(config)
    type_name = config.pop("type")
    cls = _OPTIMIZER_REGISTRY.get(type_name)
    if cls is None:
        raise ValueError(f"Unknown optimizer type: {type_name!r}")
    result: Optimizer[Any] = cls(**config)
    return result

run_sarsa_continuing(agent, state, env, num_steps)

Run SARSA in continuing mode for a fixed number of steps.

At episode boundaries, the environment auto-resets. gamma is set to 0 at pseudo-boundaries (terminal/truncated) to prevent bootstrapping across resets, matching the ContinuingWrapper pattern.

Args:
    agent: SARSA agent
    state: Initial SARSA state
    env: Gymnasium environment
    num_steps: Number of steps to run

Returns: SARSAContinuingResult with step-level metrics

Source code in src/alberta_framework/core/sarsa.py
def run_sarsa_continuing(
    agent: SARSAAgent,
    state: SARSAState,
    env: Any,
    num_steps: int,
) -> SARSAContinuingResult:
    """Run SARSA in continuing mode for a fixed number of steps.

    At episode boundaries, the environment auto-resets. gamma is set to 0
    at pseudo-boundaries (terminal/truncated) to prevent bootstrapping
    across resets, matching the ``ContinuingWrapper`` pattern.

    Args:
        agent: SARSA agent
        state: Initial SARSA state
        env: Gymnasium environment
        num_steps: Number of steps to run

    Returns:
        SARSAContinuingResult with step-level metrics
    """
    obs, _info = env.reset()
    obs = jnp.asarray(obs, dtype=jnp.float32).flatten()

    # Select initial action
    action, new_key = agent.select_action(state, obs)
    state = state.replace(  # type: ignore[attr-defined]
        last_action=action,
        last_observation=obs,
        rng_key=new_key,
    )

    rewards: list[float] = []
    q_values_list: list[Array] = []
    td_errors: list[float] = []
    total_reward = 0.0

    for _ in range(num_steps):
        next_obs, reward, terminated, truncated, _info = env.step(int(action))
        next_obs = jnp.asarray(next_obs, dtype=jnp.float32).flatten()
        reward_arr = jnp.array(reward, dtype=jnp.float32)

        # Continuing mode: gamma=0 at pseudo-boundaries
        is_boundary = terminated or truncated
        term_arr = jnp.array(is_boundary, dtype=jnp.float32)

        if is_boundary:
            next_obs_reset, _info = env.reset()
            next_obs = jnp.asarray(next_obs_reset, dtype=jnp.float32).flatten()

        # Select next action
        next_action, new_key = agent.select_action(state, next_obs)
        state = state.replace(rng_key=new_key)  # type: ignore[attr-defined]

        # SARSA update
        result = agent.update(
            state, reward_arr, next_obs, term_arr, next_action
        )
        state = result.state

        rewards.append(float(reward))
        q_values_list.append(result.q_values)
        td_errors.append(float(result.td_error))
        total_reward += float(reward)

        action = next_action

    return SARSAContinuingResult(
        state=state,
        total_reward=total_reward,
        rewards=rewards,
        q_values=q_values_list,
        td_errors=td_errors,
    )

run_sarsa_episode(agent, state, env, max_steps=10000)

Run one episode of SARSA on a Gymnasium environment.

Python loop (env interaction not JIT-able). Follows the SARSA pattern: select a' before updating, so the update uses the on-policy next action.

Args:
    agent: SARSA agent
    state: Initial SARSA state
    env: Gymnasium environment
    max_steps: Maximum steps per episode

Returns: SARSAEpisodeResult with episode metrics

Source code in src/alberta_framework/core/sarsa.py
def run_sarsa_episode(
    agent: SARSAAgent,
    state: SARSAState,
    env: Any,
    max_steps: int = 10000,
) -> SARSAEpisodeResult:
    """Run one episode of SARSA on a Gymnasium environment.

    Python loop (env interaction not JIT-able). Follows the SARSA
    pattern: select a' *before* updating, so the update uses the
    on-policy next action.

    Args:
        agent: SARSA agent
        state: Initial SARSA state
        env: Gymnasium environment
        max_steps: Maximum steps per episode

    Returns:
        SARSAEpisodeResult with episode metrics
    """
    obs, _info = env.reset()
    obs = jnp.asarray(obs, dtype=jnp.float32).flatten()

    # Select initial action
    action, new_key = agent.select_action(state, obs)
    state = state.replace(  # type: ignore[attr-defined]
        last_action=action,
        last_observation=obs,
        rng_key=new_key,
    )

    rewards: list[float] = []
    q_values_list: list[Array] = []
    td_errors: list[float] = []
    total_reward = 0.0

    for _ in range(max_steps):
        # Step environment
        next_obs, reward, terminated, truncated, _info = env.step(int(action))
        next_obs = jnp.asarray(next_obs, dtype=jnp.float32).flatten()
        reward_arr = jnp.array(reward, dtype=jnp.float32)
        term_arr = jnp.array(terminated, dtype=jnp.float32)

        # Select next action a' (on-policy)
        next_action, new_key = agent.select_action(state, next_obs)
        state = state.replace(rng_key=new_key)  # type: ignore[attr-defined]

        # SARSA update
        result = agent.update(
            state, reward_arr, next_obs, term_arr, next_action
        )
        state = result.state

        rewards.append(float(reward))
        q_values_list.append(result.q_values)
        td_errors.append(float(result.td_error))
        total_reward += float(reward)

        action = next_action

        if terminated or truncated:
            break

    return SARSAEpisodeResult(
        state=state,
        total_reward=total_reward,
        num_steps=len(rewards),
        rewards=rewards,
        q_values=q_values_list,
        td_errors=td_errors,
    )

run_sarsa_from_arrays(agent, state, observations, rewards, terminated, next_observations)

Run SARSA on pre-collected arrays via jax.lax.scan.

JIT-compiled for maximum throughput. Actions are selected on-policy within the scan. This is the primary loop for security-gym data where observations are pre-collected.

Args:
    agent: SARSA agent
    state: Initial SARSA state (must have valid last_action, last_observation)
    observations: Current observations, shape (num_steps, feature_dim)
    rewards: Rewards, shape (num_steps,)
    terminated: Termination flags, shape (num_steps,)
    next_observations: Next observations, shape (num_steps, feature_dim)

Returns: SARSAArrayResult with per-step Q-values, TD errors, and actions

Source code in src/alberta_framework/core/sarsa.py
def run_sarsa_from_arrays(
    agent: SARSAAgent,
    state: SARSAState,
    observations: Float[Array, "num_steps feature_dim"],
    rewards: Float[Array, " num_steps"],
    terminated: Float[Array, " num_steps"],
    next_observations: Float[Array, "num_steps feature_dim"],
) -> SARSAArrayResult:
    """Run SARSA on pre-collected arrays via ``jax.lax.scan``.

    JIT-compiled for maximum throughput. Actions are selected on-policy
    within the scan. This is the primary loop for security-gym data
    where observations are pre-collected.

    Args:
        agent: SARSA agent
        state: Initial SARSA state (must have valid last_action, last_observation)
        observations: Current observations, shape ``(num_steps, feature_dim)``
        rewards: Rewards, shape ``(num_steps,)``
        terminated: Termination flags, shape ``(num_steps,)``
        next_observations: Next observations, shape ``(num_steps, feature_dim)``

    Returns:
        SARSAArrayResult with per-step Q-values, TD errors, and actions
    """

    @jax.jit
    def _scan_fn(
        carry: SARSAState,
        inputs: tuple[Array, Array, Array, Array],
    ) -> tuple[SARSAState, tuple[Array, Array, Array]]:
        s = carry
        obs, r, term, next_obs = inputs

        # Select next action for next_obs
        next_action, new_key = agent.select_action(s, next_obs)
        s = s.replace(rng_key=new_key)  # type: ignore[attr-defined]

        # Update using current obs/reward/next_obs
        result = agent.update(s, r, next_obs, term, next_action)

        return result.state, (result.q_values, result.td_error, result.action)

    t0 = time.time()
    final_state, (q_vals, td_errs, actions) = jax.lax.scan(
        _scan_fn, state, (observations, rewards, terminated, next_observations)
    )
    elapsed = time.time() - t0

    # Update uptime on the inner learner state
    final_learner = final_state.learner_state.replace(  # type: ignore[attr-defined]
        uptime_s=final_state.learner_state.uptime_s + elapsed,
    )
    final_state = final_state.replace(learner_state=final_learner)  # type: ignore[attr-defined]

    return SARSAArrayResult(  # type: ignore[call-arg]
        state=final_state,
        q_values=q_vals,
        td_errors=td_errs,
        actions=actions,
    )

agent_age_s(state)

Compute agent age in seconds (wall-clock time since birth).

Args: state: Any learner state with a birth_timestamp attribute

Returns: Seconds elapsed since the agent was initialized

Source code in src/alberta_framework/core/types.py
def agent_age_s(state: object) -> float:
    """Compute agent age in seconds (wall-clock time since birth).

    Args:
        state: Any learner state with a ``birth_timestamp`` attribute

    Returns:
        Seconds elapsed since the agent was initialized
    """
    return time.time() - getattr(state, "birth_timestamp", 0.0)

agent_uptime_s(state)

Return the agent's cumulative active uptime in seconds.

Args: state: Any learner state with an uptime_s attribute

Returns: Cumulative seconds the agent has spent inside learning loops

Source code in src/alberta_framework/core/types.py
def agent_uptime_s(state: object) -> float:
    """Return the agent's cumulative active uptime in seconds.

    Args:
        state: Any learner state with an ``uptime_s`` attribute

    Returns:
        Cumulative seconds the agent has spent inside learning loops
    """
    return float(getattr(state, "uptime_s", 0.0))
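
Both helpers are duck-typed via `getattr` with a default, so they work on any state object and fall back gracefully when the attribute is missing. A self-contained demonstration (the `Dummy` class is illustrative; the two functions are copied verbatim from above):

```python
import time

class Dummy:
    birth_timestamp = time.time() - 5.0  # "born" five seconds ago
    uptime_s = 1.25                      # 1.25s spent inside learning loops

def agent_age_s(state):
    return time.time() - getattr(state, "birth_timestamp", 0.0)

def agent_uptime_s(state):
    return float(getattr(state, "uptime_s", 0.0))

assert agent_age_s(Dummy()) >= 5.0
assert agent_uptime_s(Dummy()) == 1.25
assert agent_uptime_s(object()) == 0.0  # missing attribute falls back to 0.0
```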

create_autotdidbd_state(feature_dim, initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, normalizer_decay=10000.0)

Create initial AutoTDIDBD optimizer state.

Args:
    feature_dim: Dimension of the feature vector
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))
    normalizer_decay: Decay parameter tau for normalizers (default: 10000)

Returns: Initial AutoTDIDBD state

Source code in src/alberta_framework/core/types.py
def create_autotdidbd_state(
    feature_dim: int,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
    normalizer_decay: float = 10000.0,
) -> AutoTDIDBDState:
    """Create initial AutoTDIDBD optimizer state.

    Args:
        feature_dim: Dimension of the feature vector
        initial_step_size: Initial per-weight step-size
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))
        normalizer_decay: Decay parameter tau for normalizers (default: 10000)

    Returns:
        Initial AutoTDIDBD state
    """
    return AutoTDIDBDState(
        log_step_sizes=jnp.full(feature_dim, jnp.log(initial_step_size), dtype=jnp.float32),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(trace_decay, dtype=jnp.float32),
        normalizer_decay=jnp.array(normalizer_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
    )

create_horde_spec(demons)

Create a HordeSpec from a sequence of GVFSpec demons.

Pre-computes gamma and lambda arrays for efficient JIT usage.

Args: demons: Sequence of GVFSpec, one per demon/head

Returns: HordeSpec with pre-computed arrays

Source code in src/alberta_framework/core/types.py
def create_horde_spec(demons: Sequence[GVFSpec]) -> HordeSpec:
    """Create a HordeSpec from a sequence of GVFSpec demons.

    Pre-computes gamma and lambda arrays for efficient JIT usage.

    Args:
        demons: Sequence of GVFSpec, one per demon/head

    Returns:
        HordeSpec with pre-computed arrays
    """
    demons_tuple = tuple(demons)
    gammas = jnp.array([d.gamma for d in demons_tuple], dtype=jnp.float32)
    lamdas = jnp.array([d.lamda for d in demons_tuple], dtype=jnp.float32)
    return HordeSpec(demons=demons_tuple, gammas=gammas, lamdas=lamdas)

create_obgd_state(feature_dim, step_size=1.0, kappa=2.0, gamma=0.0, lamda=0.0)

Create initial ObGD optimizer state.

Args:
    feature_dim: Dimension of the feature vector
    step_size: Base learning rate (default: 1.0)
    kappa: Bounding sensitivity parameter (default: 2.0)
    gamma: Discount factor for trace decay (default: 0.0 for supervised)
    lamda: Eligibility trace decay parameter (default: 0.0 for supervised)

Returns: Initial ObGD state

Source code in src/alberta_framework/core/types.py
def create_obgd_state(
    feature_dim: int,
    step_size: float = 1.0,
    kappa: float = 2.0,
    gamma: float = 0.0,
    lamda: float = 0.0,
) -> ObGDState:
    """Create initial ObGD optimizer state.

    Args:
        feature_dim: Dimension of the feature vector
        step_size: Base learning rate (default: 1.0)
        kappa: Bounding sensitivity parameter (default: 2.0)
        gamma: Discount factor for trace decay (default: 0.0 for supervised)
        lamda: Eligibility trace decay parameter (default: 0.0 for supervised)

    Returns:
        Initial ObGD state
    """
    return ObGDState(
        step_size=jnp.array(step_size, dtype=jnp.float32),
        kappa=jnp.array(kappa, dtype=jnp.float32),
        traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        bias_trace=jnp.array(0.0, dtype=jnp.float32),
        gamma=jnp.array(gamma, dtype=jnp.float32),
        lamda=jnp.array(lamda, dtype=jnp.float32),
    )

create_tdidbd_state(feature_dim, initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0)

Create initial TD-IDBD optimizer state.

Args:
    feature_dim: Dimension of the feature vector
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))

Returns: Initial TD-IDBD state

Source code in src/alberta_framework/core/types.py
def create_tdidbd_state(
    feature_dim: int,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
) -> TDIDBDState:
    """Create initial TD-IDBD optimizer state.

    Args:
        feature_dim: Dimension of the feature vector
        initial_step_size: Initial per-weight step-size
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))

    Returns:
        Initial TD-IDBD state
    """
    return TDIDBDState(
        log_step_sizes=jnp.full(feature_dim, jnp.log(initial_step_size), dtype=jnp.float32),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(trace_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
    )

make_scale_range(feature_dim, min_scale=0.001, max_scale=1000.0, log_spaced=True)

Create a per-feature scale array spanning a range.

Utility function to generate scale factors for ScaledStreamWrapper.

Args:
    feature_dim: Number of features
    min_scale: Minimum scale factor
    max_scale: Maximum scale factor
    log_spaced: If True, scales are logarithmically spaced (default). If False, scales are linearly spaced.

Returns: Array of shape (feature_dim,) with scale factors

Examples:

scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
Source code in src/alberta_framework/streams/synthetic.py
def make_scale_range(
    feature_dim: int,
    min_scale: float = 0.001,
    max_scale: float = 1000.0,
    log_spaced: bool = True,
) -> Array:
    """Create a per-feature scale array spanning a range.

    Utility function to generate scale factors for ScaledStreamWrapper.

    Args:
        feature_dim: Number of features
        min_scale: Minimum scale factor
        max_scale: Maximum scale factor
        log_spaced: If True, scales are logarithmically spaced (default).
            If False, scales are linearly spaced.

    Returns:
        Array of shape (feature_dim,) with scale factors

    Examples
    --------
    ```python
    scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
    stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
    ```
    """
    if log_spaced:
        return jnp.logspace(
            jnp.log10(min_scale),
            jnp.log10(max_scale),
            feature_dim,
            dtype=jnp.float32,
        )
    else:
        return jnp.linspace(min_scale, max_scale, feature_dim, dtype=jnp.float32)

compare_learners(results, metric='squared_error')

Compare multiple learners on a given metric.

Args:
    results: Dictionary mapping learner name to metrics history
    metric: Metric to compare

Returns: Dictionary with summary statistics for each learner

Source code in src/alberta_framework/utils/metrics.py
def compare_learners(
    results: dict[str, list[dict[str, float]]],
    metric: str = "squared_error",
) -> dict[str, dict[str, float]]:
    """Compare multiple learners on a given metric.

    Args:
        results: Dictionary mapping learner name to metrics history
        metric: Metric to compare

    Returns:
        Dictionary with summary statistics for each learner
    """
    summary = {}
    for name, metrics_history in results.items():
        values = extract_metric(metrics_history, metric)
        summary[name] = {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
            "cumulative": float(np.sum(values)),
            "final_100_mean": (
                float(np.mean(values[-100:])) if len(values) >= 100 else float(np.mean(values))
            ),
        }
    return summary

compute_cumulative_error(metrics_history, error_key='squared_error')

Compute cumulative error over time.

Args:
    metrics_history: List of metric dictionaries from learning loop
    error_key: Key to extract error values

Returns: Array of cumulative errors at each time step

Source code in src/alberta_framework/utils/metrics.py
def compute_cumulative_error(
    metrics_history: list[dict[str, float]],
    error_key: str = "squared_error",
) -> NDArray[np.float64]:
    """Compute cumulative error over time.

    Args:
        metrics_history: List of metric dictionaries from learning loop
        error_key: Key to extract error values

    Returns:
        Array of cumulative errors at each time step
    """
    errors = np.array([m[error_key] for m in metrics_history])
    return np.cumsum(errors)

compute_running_mean(values, window_size=100)

Compute running mean of values.

Args:
    values: Array of values
    window_size: Size of the moving average window

Returns: Array of running mean values (same length as input, padded at start)

Source code in src/alberta_framework/utils/metrics.py
def compute_running_mean(
    values: NDArray[np.float64] | list[float],
    window_size: int = 100,
) -> NDArray[np.float64]:
    """Compute running mean of values.

    Args:
        values: Array of values
        window_size: Size of the moving average window

    Returns:
        Array of running mean values (same length as input, padded at start)
    """
    values_arr = np.asarray(values)
    cumsum = np.cumsum(np.insert(values_arr, 0, 0))
    running_mean = (cumsum[window_size:] - cumsum[:-window_size]) / window_size

    # Pad the beginning with the first computed mean
    if len(running_mean) > 0:
        padding = np.full(window_size - 1, running_mean[0])
        return np.concatenate([padding, running_mean])
    return values_arr

compute_tracking_error(metrics_history, window_size=100)

Compute tracking error (running mean of squared error).

This is the key metric for evaluating continual learners: how well can the learner track the non-stationary target?

Args:
    metrics_history: List of metric dictionaries from learning loop
    window_size: Size of the moving average window

Returns: Array of tracking errors at each time step

Source code in src/alberta_framework/utils/metrics.py
def compute_tracking_error(
    metrics_history: list[dict[str, float]],
    window_size: int = 100,
) -> NDArray[np.float64]:
    """Compute tracking error (running mean of squared error).

    This is the key metric for evaluating continual learners:
    how well can the learner track the non-stationary target?

    Args:
        metrics_history: List of metric dictionaries from learning loop
        window_size: Size of the moving average window

    Returns:
        Array of tracking errors at each time step
    """
    errors = np.array([m["squared_error"] for m in metrics_history])
    return compute_running_mean(errors, window_size)

extract_metric(metrics_history, key)

Extract a single metric from the history.

Args:
    metrics_history: List of metric dictionaries
    key: Key to extract

Returns: Array of values for that metric

Source code in src/alberta_framework/utils/metrics.py
def extract_metric(
    metrics_history: list[dict[str, float]],
    key: str,
) -> NDArray[np.float64]:
    """Extract a single metric from the history.

    Args:
        metrics_history: List of metric dictionaries
        key: Key to extract

    Returns:
        Array of values for that metric
    """
    return np.array([m[key] for m in metrics_history])

format_duration(seconds)

Format a duration in seconds as a human-readable string.

Args: seconds: Duration in seconds

Returns: Formatted string like "1.23s", "2m 30.5s", or "1h 5m 30s"

Examples:

format_duration(0.5)   # Returns: '0.50s'
format_duration(90.5)  # Returns: '1m 30.50s'
format_duration(3665)  # Returns: '1h 1m 5.00s'
Source code in src/alberta_framework/utils/timing.py
def format_duration(seconds: float) -> str:
    """Format a duration in seconds as a human-readable string.

    Args:
        seconds: Duration in seconds

    Returns:
        Formatted string like "1.23s", "2m 30.5s", or "1h 5m 30s"

    Examples
    --------
    ```python
    format_duration(0.5)   # Returns: '0.50s'
    format_duration(90.5)  # Returns: '1m 30.50s'
    format_duration(3665)  # Returns: '1h 1m 5.00s'
    ```
    """
    if seconds < 60:
        return f"{seconds:.2f}s"
    elif seconds < 3600:
        minutes = int(seconds // 60)
        secs = seconds % 60
        return f"{minutes}m {secs:.2f}s"
    else:
        hours = int(seconds // 3600)
        remaining = seconds % 3600
        minutes = int(remaining // 60)
        secs = remaining % 60
        return f"{hours}h {minutes}m {secs:.2f}s"

collect_trajectory(env, policy, num_steps, mode=PredictionMode.REWARD, include_action_in_features=True, seed=0)

Collect a trajectory from a Gymnasium environment.

This uses a Python loop to interact with the environment and collects observations and targets into JAX arrays that can be used with scan-based learning.

Args:
    env: Gymnasium environment instance
    policy: Action selection function. If None, uses random policy
    num_steps: Number of steps to collect
    mode: What to predict (REWARD, NEXT_STATE, VALUE)
    include_action_in_features: If True, features = concat(obs, action)
    seed: Random seed for environment resets and random policy

Returns: Tuple of (observations, targets) as JAX arrays with shape (num_steps, feature_dim) and (num_steps, target_dim)

Source code in src/alberta_framework/streams/gymnasium.py
def collect_trajectory(
    env: gymnasium.Env[Any, Any],
    policy: Callable[[Array], Any] | None,
    num_steps: int,
    mode: PredictionMode = PredictionMode.REWARD,
    include_action_in_features: bool = True,
    seed: int = 0,
) -> tuple[Array, Array]:
    """Collect a trajectory from a Gymnasium environment.

    This uses a Python loop to interact with the environment and collects
    observations and targets into JAX arrays that can be used with scan-based
    learning.

    Args:
        env: Gymnasium environment instance
        policy: Action selection function. If None, uses random policy
        num_steps: Number of steps to collect
        mode: What to predict (REWARD, NEXT_STATE, VALUE)
        include_action_in_features: If True, features = concat(obs, action)
        seed: Random seed for environment resets and random policy

    Returns:
        Tuple of (observations, targets) as JAX arrays with shape
        (num_steps, feature_dim) and (num_steps, target_dim)
    """
    if policy is None:
        policy = make_random_policy(env, seed)

    observations = []
    targets = []

    reset_count = 0
    raw_obs, _ = env.reset(seed=seed + reset_count)
    reset_count += 1
    current_obs = _flatten_observation(raw_obs, env.observation_space)

    for _ in range(num_steps):
        action = policy(current_obs)
        flat_action = _flatten_action(action, env.action_space)

        raw_next_obs, reward, terminated, truncated, _ = env.step(action)
        next_obs = _flatten_observation(raw_next_obs, env.observation_space)

        # Construct features
        if include_action_in_features:
            features = jnp.concatenate([current_obs, flat_action])
        else:
            features = current_obs

        # Construct target based on mode
        if mode == PredictionMode.REWARD:
            target = jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))
        elif mode == PredictionMode.NEXT_STATE:
            target = next_obs
        else:  # VALUE mode
            # TD target with 0 bootstrap (simple version)
            target = jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))

        observations.append(features)
        targets.append(target)

        if terminated or truncated:
            raw_obs, _ = env.reset(seed=seed + reset_count)
            reset_count += 1
            current_obs = _flatten_observation(raw_obs, env.observation_space)
        else:
            current_obs = next_obs

    return jnp.stack(observations), jnp.stack(targets)

learn_from_trajectory(learner, observations, targets, learner_state=None)

Learn from a pre-collected trajectory using jax.lax.scan.

This is a JIT-compiled learning function that processes a trajectory collected from a Gymnasium environment.

Args:
    learner: The learner to train
    observations: Array of observations with shape (num_steps, feature_dim)
    targets: Array of targets with shape (num_steps, target_dim)
    learner_state: Initial state (if None, will be initialized)

Returns: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 3) with columns [squared_error, error, mean_step_size]

Source code in src/alberta_framework/streams/gymnasium.py
def learn_from_trajectory(
    learner: LinearLearner,
    observations: Array,
    targets: Array,
    learner_state: LearnerState | None = None,
) -> tuple[LearnerState, Array]:
    """Learn from a pre-collected trajectory using jax.lax.scan.

    This is a JIT-compiled learning function that processes a trajectory
    collected from a Gymnasium environment.

    Args:
        learner: The learner to train
        observations: Array of observations with shape (num_steps, feature_dim)
        targets: Array of targets with shape (num_steps, target_dim)
        learner_state: Initial state (if None, will be initialized)

    Returns:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 3) with columns [squared_error, error, mean_step_size]
    """
    if learner_state is None:
        learner_state = learner.init(observations.shape[1])

    def step_fn(state: LearnerState, inputs: tuple[Array, Array]) -> tuple[LearnerState, Array]:
        obs, target = inputs
        result = learner.update(state, obs, target)
        return result.state, result.metrics

    t0 = time.time()
    final_state, metrics = jax.lax.scan(step_fn, learner_state, (observations, targets))
    elapsed = time.time() - t0
    final_state = final_state.replace(uptime_s=final_state.uptime_s + elapsed)  # type: ignore[attr-defined]

    return final_state, metrics
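The `jax.lax.scan` pattern used above can be reproduced with a plain LMS update standing in for the learner. This is a sketch; the fixed step-size and the two-column metrics row are illustrative, not the framework's `update` contract:

```python
import jax
import jax.numpy as jnp


def lms_step(w, inputs):
    """One LMS update: carry is the weight vector, output is a metrics row."""
    x, y = inputs
    error = y - jnp.dot(w, x)
    w = w + 0.1 * error * x
    # metrics row: [squared_error, error]
    return w, jnp.array([error**2, error])


observations = jnp.ones((100, 4)) * 0.5   # constant features
targets = jnp.ones(100)                   # constant target of 1.0

w0 = jnp.zeros(4)
final_w, metrics = jax.lax.scan(lms_step, w0, (observations, targets))
print(metrics.shape)  # (100, 2), one row per step
```

Because `scan` threads the carry through every step and stacks the per-step outputs, the whole trajectory compiles to a single XLA loop rather than a Python-level one.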

learn_from_trajectory_normalized(learner, observations, targets, learner_state=None)

Learn from a pre-collected trajectory with normalization using jax.lax.scan.

This is equivalent to learn_from_trajectory for a learner constructed with a normalizer (e.g. LinearLearner(optimizer=..., normalizer=EMANormalizer())). Retained for backward compatibility.

Args:
    learner: The learner to train (should have a normalizer configured)
    observations: Array of observations with shape (num_steps, feature_dim)
    targets: Array of targets with shape (num_steps, target_dim)
    learner_state: Initial state (if None, will be initialized)

Returns:
    Tuple of (final_state, metrics_array) where metrics_array has shape
    (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]

Source code in src/alberta_framework/streams/gymnasium.py
def learn_from_trajectory_normalized(
    learner: LinearLearner,
    observations: Array,
    targets: Array,
    learner_state: LearnerState | None = None,
) -> tuple[LearnerState, Array]:
    """Learn from a pre-collected trajectory with normalization using jax.lax.scan.

    This is equivalent to ``learn_from_trajectory`` for a learner constructed
    with a normalizer (e.g. ``LinearLearner(optimizer=..., normalizer=EMANormalizer())``).
    Retained for backward compatibility.

    Args:
        learner: The learner to train (should have a normalizer configured)
        observations: Array of observations with shape (num_steps, feature_dim)
        targets: Array of targets with shape (num_steps, target_dim)
        learner_state: Initial state (if None, will be initialized)

    Returns:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
    """
    return learn_from_trajectory(learner, observations, targets, learner_state)

make_epsilon_greedy_policy(base_policy, env, epsilon=0.1, seed=0)

Wrap a policy with epsilon-greedy exploration.

Args:
    base_policy: The greedy policy to wrap
    env: Gymnasium environment (for random action sampling)
    epsilon: Probability of taking a random action
    seed: Random seed

Returns:
    Epsilon-greedy policy

Source code in src/alberta_framework/streams/gymnasium.py
def make_epsilon_greedy_policy(
    base_policy: Callable[[Array], Any],
    env: gymnasium.Env[Any, Any],
    epsilon: float = 0.1,
    seed: int = 0,
) -> Callable[[Array], Any]:
    """Wrap a policy with epsilon-greedy exploration.

    Args:
        base_policy: The greedy policy to wrap
        env: Gymnasium environment (for random action sampling)
        epsilon: Probability of taking a random action
        seed: Random seed

    Returns:
        Epsilon-greedy policy
    """
    random_policy = make_random_policy(env, seed + 1)
    rng = jr.key(seed)

    def policy(obs: Array) -> Any:
        nonlocal rng
        rng, key = jr.split(rng)

        if jr.uniform(key) < epsilon:
            return random_policy(obs)
        return base_policy(obs)

    return policy
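The closure-over-rng pattern above can be checked in isolation. This sketch (with hypothetical `make_eps_policy` and string actions standing in for real policies) verifies that the random branch fires at roughly the configured rate:

```python
import jax.random as jr


def make_eps_policy(greedy, random_action, epsilon, seed):
    # Same pattern as above: split the key on every call, compare a uniform draw
    rng = jr.key(seed)

    def policy(obs):
        nonlocal rng
        rng, key = jr.split(rng)
        return random_action() if jr.uniform(key) < epsilon else greedy(obs)

    return policy


policy = make_eps_policy(lambda obs: "greedy", lambda: "random", epsilon=0.1, seed=0)
picks = [policy(None) for _ in range(2000)]
rate = picks.count("random") / len(picks)  # close to epsilon
```

Splitting the key on every call keeps the draws independent; reusing the same key would make every call return the same branch.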

make_gymnasium_stream(env_id, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0, **env_kwargs)

Factory function to create a GymnasiumStream from an environment ID.

Args:
    env_id: Gymnasium environment ID (e.g., "CartPole-v1")
    mode: What to predict (REWARD, NEXT_STATE, VALUE)
    policy: Action selection function. If None, uses random policy
    gamma: Discount factor for VALUE mode
    include_action_in_features: If True, features = concat(obs, action)
    seed: Random seed
    **env_kwargs: Additional arguments passed to gymnasium.make()

Returns:
    GymnasiumStream wrapping the environment

Source code in src/alberta_framework/streams/gymnasium.py
def make_gymnasium_stream(
    env_id: str,
    mode: PredictionMode = PredictionMode.REWARD,
    policy: Callable[[Array], Any] | None = None,
    gamma: float = 0.99,
    include_action_in_features: bool = True,
    seed: int = 0,
    **env_kwargs: Any,
) -> GymnasiumStream:
    """Factory function to create a GymnasiumStream from an environment ID.

    Args:
        env_id: Gymnasium environment ID (e.g., "CartPole-v1")
        mode: What to predict (REWARD, NEXT_STATE, VALUE)
        policy: Action selection function. If None, uses random policy
        gamma: Discount factor for VALUE mode
        include_action_in_features: If True, features = concat(obs, action)
        seed: Random seed
        **env_kwargs: Additional arguments passed to gymnasium.make()

    Returns:
        GymnasiumStream wrapping the environment
    """
    import gymnasium

    env = gymnasium.make(env_id, **env_kwargs)
    return GymnasiumStream(
        env=env,
        mode=mode,
        policy=policy,
        gamma=gamma,
        include_action_in_features=include_action_in_features,
        seed=seed,
    )

make_random_policy(env, seed=0)

Create a random action policy for an environment.

Args:
    env: Gymnasium environment
    seed: Random seed

Returns:
    A callable that takes an observation and returns a random action

Source code in src/alberta_framework/streams/gymnasium.py
def make_random_policy(env: gymnasium.Env[Any, Any], seed: int = 0) -> Callable[[Array], Any]:
    """Create a random action policy for an environment.

    Args:
        env: Gymnasium environment
        seed: Random seed

    Returns:
        A callable that takes an observation and returns a random action
    """
    import gymnasium

    rng = jr.key(seed)
    action_space = env.action_space

    def policy(_obs: Array) -> Any:
        nonlocal rng
        rng, key = jr.split(rng)

        if isinstance(action_space, gymnasium.spaces.Discrete):
            return int(jr.randint(key, (), 0, int(action_space.n)))
        elif isinstance(action_space, gymnasium.spaces.Box):
            # Sample uniformly between low and high
            low = jnp.asarray(action_space.low, dtype=jnp.float32)
            high = jnp.asarray(action_space.high, dtype=jnp.float32)
            return jr.uniform(key, action_space.shape, minval=low, maxval=high)
        elif isinstance(action_space, gymnasium.spaces.MultiDiscrete):
            nvec = action_space.nvec
            return [int(jr.randint(jr.fold_in(key, i), (), 0, n)) for i, n in enumerate(nvec)]
        else:
            raise ValueError(f"Unsupported action space: {type(action_space).__name__}")

    return policy
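For a Discrete space, the pattern above reduces to a keyed `jr.randint` per call. This minimal sketch (the function name is illustrative) shows the stateful key handling without a Gymnasium dependency:

```python
import jax.random as jr


def make_discrete_random_policy(n_actions, seed=0):
    # Mirrors the Discrete branch above: fresh key per call, uniform integer action
    rng = jr.key(seed)

    def policy(_obs):
        nonlocal rng
        rng, key = jr.split(rng)
        return int(jr.randint(key, (), 0, n_actions))

    return policy


policy = make_discrete_random_policy(4)
actions = [policy(None) for _ in range(50)]  # each in {0, 1, 2, 3}
```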