alberta_framework

Alberta Framework: A JAX-based research framework for continual AI.

The Alberta Framework provides foundational components for continual reinforcement learning research. Built on JAX for hardware acceleration, the framework emphasizes temporal uniformity — every component updates at every time step, with no special training phases or batch processing.

Roadmap
Step   Focus                                           Status
1      Meta-learned step-sizes (IDBD, Autostep)        Complete
2      Nonlinear function approximation (MLP, ObGD)    In Progress
3      GVF predictions, Horde architecture             Planned
4      Actor-critic with eligibility traces            Planned
5-6    Off-policy learning, average reward             Planned
7-12   Hierarchical, multi-agent, world models         Future

Examples:

import jax.random as jr
from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop

# Non-stationary stream where target weights drift over time
stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)

# Learner with IDBD meta-learned step-sizes
learner = LinearLearner(optimizer=IDBD())

# JIT-compiled training via jax.lax.scan
state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(42))

References
  • Sutton, R.S., et al. (2022). "The Alberta Plan for AI Research". https://arxiv.org/abs/2208.11173
  • Sutton, R.S. (1992). "Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta"
  • Mahmood, A.R., et al. (2012). "Tuning-free Step-size Adaptation"
  • Elsayed, M., et al. (2024). "Streaming Deep Reinforcement Learning Finally Works"

LinearLearner(optimizer=None, normalizer=None)

Linear function approximator with pluggable optimizer and optional normalizer.

Computes predictions as: y = w @ x + b

The learner maintains weights and bias, delegating the adaptation of learning rates to the optimizer (e.g., LMS or IDBD).

This follows the Alberta Plan philosophy of temporal uniformity: every component updates at every time step.

Attributes:
    optimizer: The optimizer to use for weight updates
    normalizer: Optional online feature normalizer

Args:
    optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
    normalizer: Optional feature normalizer (e.g. EMANormalizer, WelfordNormalizer)
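
A minimal usage sketch of the init/predict/update cycle (the feature values and target below are illustrative):

import jax.numpy as jnp
from alberta_framework import LinearLearner, IDBD

learner = LinearLearner(optimizer=IDBD())
state = learner.init(feature_dim=4)

x = jnp.array([1.0, 0.0, -1.0, 0.5])
y_hat = learner.predict(state, x)            # zeros at init, so y = w @ x + b = 0
result = learner.update(state, x, jnp.array(2.0))
state = result.state                         # carry the updated state forward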

Source code in src/alberta_framework/core/learners.py
def __init__(
    self,
    optimizer: AnyOptimizer | None = None,
    normalizer: (
        Normalizer[EMANormalizerState] | Normalizer[WelfordNormalizerState] | None
    ) = None,
):
    """Initialize the linear learner.

    Args:
        optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
        normalizer: Optional feature normalizer (e.g. EMANormalizer, WelfordNormalizer)
    """
    self._optimizer: AnyOptimizer = optimizer or LMS(step_size=0.01)
    self._normalizer = normalizer

normalizer property

The feature normalizer, or None if normalization is disabled.

init(feature_dim)

Initialize learner state.

Args:
    feature_dim: Dimension of the input feature vector

Returns:
    Initial learner state with zero weights and bias

Source code in src/alberta_framework/core/learners.py
def init(self, feature_dim: int) -> LearnerState:
    """Initialize learner state.

    Args:
        feature_dim: Dimension of the input feature vector

    Returns:
        Initial learner state with zero weights and bias
    """
    optimizer_state = self._optimizer.init(feature_dim)

    normalizer_state = None
    if self._normalizer is not None:
        normalizer_state = self._normalizer.init(feature_dim)

    return LearnerState(
        weights=jnp.zeros(feature_dim, dtype=jnp.float32),
        bias=jnp.array(0.0, dtype=jnp.float32),
        optimizer_state=optimizer_state,
        normalizer_state=normalizer_state,
    )

predict(state, observation)

Compute prediction for an observation.

Args:
    state: Current learner state
    observation: Input feature vector

Returns:
    Scalar prediction y = w @ x + b

Source code in src/alberta_framework/core/learners.py
def predict(self, state: LearnerState, observation: Observation) -> Prediction:
    """Compute prediction for an observation.

    Args:
        state: Current learner state
        observation: Input feature vector

    Returns:
        Scalar prediction ``y = w @ x + b``
    """
    return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)

update(state, observation, target)

Update learner given observation and target.

Performs one step of the learning algorithm:

1. Optionally normalize observation
2. Compute prediction
3. Compute error
4. Get weight updates from optimizer
5. Apply updates to weights and bias

Args:
    state: Current learner state
    observation: Input feature vector
    target: Desired output

Returns:
    UpdateResult with new state, prediction, error, and metrics

Source code in src/alberta_framework/core/learners.py
def update(
    self,
    state: LearnerState,
    observation: Observation,
    target: Target,
) -> UpdateResult:
    """Update learner given observation and target.

    Performs one step of the learning algorithm:
    1. Optionally normalize observation
    2. Compute prediction
    3. Compute error
    4. Get weight updates from optimizer
    5. Apply updates to weights and bias

    Args:
        state: Current learner state
        observation: Input feature vector
        target: Desired output

    Returns:
        UpdateResult with new state, prediction, error, and metrics
    """
    # Handle normalization
    new_normalizer_state = state.normalizer_state
    obs = observation
    if self._normalizer is not None and state.normalizer_state is not None:
        obs, new_normalizer_state = self._normalizer.normalize(
            state.normalizer_state, observation
        )

    # Make prediction
    prediction = self.predict(
        LearnerState(
            weights=state.weights,
            bias=state.bias,
            optimizer_state=state.optimizer_state,
            normalizer_state=new_normalizer_state,
        ),
        obs,
    )

    # Compute error (target - prediction)
    error = jnp.squeeze(target) - jnp.squeeze(prediction)

    # Get update from optimizer
    opt_update = self._optimizer.update(
        state.optimizer_state,
        error,
        obs,
    )

    # Apply updates
    new_weights = state.weights + opt_update.weight_delta
    new_bias = state.bias + opt_update.bias_delta

    new_state = LearnerState(
        weights=new_weights,
        bias=new_bias,
        optimizer_state=opt_update.new_state,
        normalizer_state=new_normalizer_state,
    )

    # Pack metrics as array for scan compatibility
    squared_error = error**2
    mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)

    if self._normalizer is not None and new_normalizer_state is not None:
        normalizer_mean_var = jnp.mean(new_normalizer_state.var)
        metrics = jnp.array(
            [squared_error, error, mean_step_size, normalizer_mean_var],
            dtype=jnp.float32,
        )
    else:
        metrics = jnp.array(
            [squared_error, error, mean_step_size], dtype=jnp.float32
        )

    return UpdateResult(
        state=new_state,
        prediction=prediction,
        error=jnp.atleast_1d(error),
        metrics=metrics,
    )

MLPLearner(hidden_sizes=(128, 128), optimizer=None, step_size=1.0, bounder=None, gamma=0.0, lamda=0.0, normalizer=None, sparsity=0.9, leaky_relu_slope=0.01, use_layer_norm=True)

Multi-layer perceptron with composable optimizer, bounder, and normalizer.

Architecture: Input -> [Dense(H) -> LayerNorm -> LeakyReLU] x N -> Dense(1)

When use_layer_norm=False, the architecture simplifies to: Input -> [Dense(H) -> LeakyReLU] x N -> Dense(1)

Uses parameterless layer normalization and sparse initialization following Elsayed et al. 2024. Accepts a pluggable optimizer (LMS, Autostep), an optional bounder (ObGDBounding), and an optional feature normalizer (EMANormalizer, WelfordNormalizer).

The update flow:

1. If normalizer: normalize observation, update normalizer state
2. Forward pass + jax.grad to get per-layer prediction gradients
3. Update eligibility traces: z = gamma * lamda * z + grad
4. Per-layer optimizer step: step, new_opt = optimizer.update_from_gradient(state, z)
5. If bounder: bound all steps globally
6. Apply: param += scale * error * step

Reference: Elsayed et al. 2024, "Streaming Deep Reinforcement Learning Finally Works"

Attributes:
    hidden_sizes: Tuple of hidden layer sizes
    optimizer: Optimizer for per-weight step-size adaptation
    bounder: Optional update bounder (e.g. ObGDBounding)
    normalizer: Optional feature normalizer
    use_layer_norm: Whether to apply parameterless layer normalization
    gamma: Discount factor for trace decay
    lamda: Eligibility trace decay parameter
    sparsity: Fraction of weights zeroed out per output neuron
    leaky_relu_slope: Negative slope for LeakyReLU activation

Args:
    hidden_sizes: Tuple of hidden layer sizes (default: two layers of 128)
    optimizer: Optimizer for weight updates. Defaults to LMS(step_size). Must support init_for_shape and update_from_gradient.
    step_size: Base learning rate (used only when optimizer is None, default: 1.0)
    bounder: Optional update bounder (e.g. ObGDBounding for ObGD-style bounding). When None, no bounding is applied.
    gamma: Discount factor for trace decay (default: 0.0 for supervised)
    lamda: Eligibility trace decay parameter (default: 0.0 for supervised)
    normalizer: Optional feature normalizer. When provided, features are normalized before prediction and learning.
    sparsity: Fraction of weights zeroed out per output neuron (default: 0.9)
    leaky_relu_slope: Negative slope for LeakyReLU (default: 0.01)
    use_layer_norm: Whether to apply parameterless layer normalization between hidden layers (default: True). Set to False for ablation studies.
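
A minimal usage sketch (assuming MLPLearner is importable from the package top level like LinearLearner; the dimensions and target are illustrative):

import jax.numpy as jnp
import jax.random as jr
from alberta_framework import MLPLearner  # assumed top-level export

learner = MLPLearner(hidden_sizes=(64, 64))   # defaults: LMS(1.0), no bounding
state = learner.init(feature_dim=8, key=jr.key(0))

x = jnp.ones(8)
result = learner.update(state, x, jnp.array(1.0))
state = result.state                          # result.error and result.metrics also available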

Source code in src/alberta_framework/core/learners.py
def __init__(
    self,
    hidden_sizes: tuple[int, ...] = (128, 128),
    optimizer: AnyOptimizer | None = None,
    step_size: float = 1.0,
    bounder: Bounder | None = None,
    gamma: float = 0.0,
    lamda: float = 0.0,
    normalizer: (
        Normalizer[EMANormalizerState] | Normalizer[WelfordNormalizerState] | None
    ) = None,
    sparsity: float = 0.9,
    leaky_relu_slope: float = 0.01,
    use_layer_norm: bool = True,
):
    """Initialize MLP learner.

    Args:
        hidden_sizes: Tuple of hidden layer sizes (default: two layers of 128)
        optimizer: Optimizer for weight updates. Defaults to LMS(step_size).
            Must support ``init_for_shape`` and ``update_from_gradient``.
        step_size: Base learning rate (used only when optimizer is None,
            default: 1.0)
        bounder: Optional update bounder (e.g. ObGDBounding for ObGD-style
            bounding). When None, no bounding is applied.
        gamma: Discount factor for trace decay (default: 0.0 for supervised)
        lamda: Eligibility trace decay parameter (default: 0.0 for supervised)
        normalizer: Optional feature normalizer. When provided, features are
            normalized before prediction and learning.
        sparsity: Fraction of weights zeroed out per output neuron (default: 0.9)
        leaky_relu_slope: Negative slope for LeakyReLU (default: 0.01)
        use_layer_norm: Whether to apply parameterless layer normalization
            between hidden layers (default: True). Set to False for ablation
            studies.
    """
    self._hidden_sizes = hidden_sizes
    self._optimizer: AnyOptimizer = optimizer or LMS(step_size=step_size)
    self._bounder = bounder
    self._gamma = gamma
    self._lamda = lamda
    self._normalizer = normalizer
    self._sparsity = sparsity
    self._leaky_relu_slope = leaky_relu_slope
    self._use_layer_norm = use_layer_norm

normalizer property

The feature normalizer, or None if normalization is disabled.

init(feature_dim, key)

Initialize MLP learner state with sparse weights.

Args:
    feature_dim: Dimension of the input feature vector
    key: JAX random key for weight initialization

Returns:
    Initial MLP learner state with sparse weights and zero biases

Source code in src/alberta_framework/core/learners.py
def init(self, feature_dim: int, key: Array) -> MLPLearnerState:
    """Initialize MLP learner state with sparse weights.

    Args:
        feature_dim: Dimension of the input feature vector
        key: JAX random key for weight initialization

    Returns:
        Initial MLP learner state with sparse weights and zero biases
    """
    # Build layer sizes: [feature_dim, hidden1, hidden2, ..., 1]
    layer_sizes = [feature_dim, *self._hidden_sizes, 1]

    weights_list = []
    biases_list = []
    traces_list = []
    opt_states_list = []

    for i in range(len(layer_sizes) - 1):
        fan_out = layer_sizes[i + 1]
        fan_in = layer_sizes[i]
        key, subkey = jax.random.split(key)
        w = sparse_init(subkey, (fan_out, fan_in), sparsity=self._sparsity)
        b = jnp.zeros(fan_out, dtype=jnp.float32)
        weights_list.append(w)
        biases_list.append(b)
        # Traces for weights and biases (interleaved: w0, b0, w1, b1, ...)
        traces_list.append(jnp.zeros_like(w))
        traces_list.append(jnp.zeros_like(b))
        # Optimizer states for weights and biases
        opt_states_list.append(self._optimizer.init_for_shape(w.shape))
        opt_states_list.append(self._optimizer.init_for_shape(b.shape))

    params = MLPParams(
        weights=tuple(weights_list),
        biases=tuple(biases_list),
    )

    normalizer_state = None
    if self._normalizer is not None:
        normalizer_state = self._normalizer.init(feature_dim)

    return MLPLearnerState(
        params=params,
        optimizer_states=tuple(opt_states_list),
        traces=tuple(traces_list),
        normalizer_state=normalizer_state,
    )

predict(state, observation)

Compute prediction for an observation.

Args:
    state: Current MLP learner state
    observation: Input feature vector

Returns:
    Scalar prediction

Source code in src/alberta_framework/core/learners.py
def predict(self, state: MLPLearnerState, observation: Observation) -> Prediction:
    """Compute prediction for an observation.

    Args:
        state: Current MLP learner state
        observation: Input feature vector

    Returns:
        Scalar prediction
    """
    y = self._forward(
        state.params.weights,
        state.params.biases,
        observation,
        self._leaky_relu_slope,
        self._use_layer_norm,
    )
    return jnp.atleast_1d(y)

update(state, observation, target)

Update MLP given observation and target.

Performs one step of the learning algorithm:

1. Optionally normalize observation
2. Compute prediction and error
3. Compute gradients via jax.grad on the forward pass
4. Update eligibility traces
5. Per-layer optimizer step from traces
6. Optionally bound steps
7. Apply bounded weight updates

Args:
    state: Current MLP learner state
    observation: Input feature vector
    target: Desired output

Returns:
    MLPUpdateResult with new state, prediction, error, and metrics

Source code in src/alberta_framework/core/learners.py
def update(
    self,
    state: MLPLearnerState,
    observation: Observation,
    target: Target,
) -> MLPUpdateResult:
    """Update MLP given observation and target.

    Performs one step of the learning algorithm:
    1. Optionally normalize observation
    2. Compute prediction and error
    3. Compute gradients via jax.grad on the forward pass
    4. Update eligibility traces
    5. Per-layer optimizer step from traces
    6. Optionally bound steps
    7. Apply bounded weight updates

    Args:
        state: Current MLP learner state
        observation: Input feature vector
        target: Desired output

    Returns:
        MLPUpdateResult with new state, prediction, error, and metrics
    """
    target_scalar = jnp.squeeze(target)

    # Handle normalization
    obs = observation
    new_normalizer_state = state.normalizer_state
    if self._normalizer is not None and state.normalizer_state is not None:
        obs, new_normalizer_state = self._normalizer.normalize(
            state.normalizer_state, observation
        )

    # Forward pass for prediction
    prediction_val = self._forward(
        state.params.weights,
        state.params.biases,
        obs,
        self._leaky_relu_slope,
        self._use_layer_norm,
    )
    prediction = jnp.atleast_1d(prediction_val)
    error = target_scalar - prediction_val

    # Compute gradients w.r.t. prediction
    slope = self._leaky_relu_slope
    ln = self._use_layer_norm

    def pred_fn(weights: tuple[Array, ...], biases: tuple[Array, ...]) -> Array:
        return self._forward(weights, biases, obs, slope, ln)

    weight_grads, bias_grads = jax.grad(pred_fn, argnums=(0, 1))(
        state.params.weights, state.params.biases
    )

    # Update eligibility traces: z = gamma * lamda * z + grad
    gamma_lamda = jnp.array(self._gamma * self._lamda, dtype=jnp.float32)
    n_layers = len(state.params.weights)

    new_traces = []
    for i in range(n_layers):
        # Weight trace (index 2*i)
        new_wt = gamma_lamda * state.traces[2 * i] + weight_grads[i]
        new_traces.append(new_wt)
        # Bias trace (index 2*i + 1)
        new_bt = gamma_lamda * state.traces[2 * i + 1] + bias_grads[i]
        new_traces.append(new_bt)

    # Per-parameter optimizer step from traces
    all_steps = []
    new_opt_states = []
    for j in range(len(new_traces)):
        step, new_opt = self._optimizer.update_from_gradient(
            state.optimizer_states[j], new_traces[j], error=error
        )
        all_steps.append(step)
        new_opt_states.append(new_opt)

    # Bounding (optional)
    bounding_metric = jnp.array(1.0, dtype=jnp.float32)
    if self._bounder is not None:
        all_params = []
        for i in range(n_layers):
            all_params.append(state.params.weights[i])
            all_params.append(state.params.biases[i])
        bounded_steps, bounding_metric = self._bounder.bound(
            tuple(all_steps), error, tuple(all_params)
        )
        all_steps = list(bounded_steps)

    # Apply updates: param += error * step
    new_weights = []
    new_biases = []
    for i in range(n_layers):
        new_w = state.params.weights[i] + error * all_steps[2 * i]
        new_weights.append(new_w)
        new_b = state.params.biases[i] + error * all_steps[2 * i + 1]
        new_biases.append(new_b)

    new_params = MLPParams(
        weights=tuple(new_weights), biases=tuple(new_biases)
    )
    new_state = MLPLearnerState(
        params=new_params,
        optimizer_states=tuple(new_opt_states),
        traces=tuple(new_traces),
        normalizer_state=new_normalizer_state,
    )

    squared_error = error**2

    if self._normalizer is not None and new_normalizer_state is not None:
        normalizer_mean_var = jnp.mean(new_normalizer_state.var)
        metrics = jnp.array(
            [squared_error, error, bounding_metric, normalizer_mean_var],
            dtype=jnp.float32,
        )
    else:
        metrics = jnp.array(
            [squared_error, error, bounding_metric], dtype=jnp.float32
        )

    return MLPUpdateResult(
        state=new_state,
        prediction=prediction,
        error=jnp.atleast_1d(error),
        metrics=metrics,
    )

MLPUpdateResult

Result of an MLP learner update step.

Attributes:
    state: Updated MLP learner state
    prediction: Prediction made before update
    error: Prediction error
    metrics: Array of metrics -- shape (3,) without normalizer, (4,) with normalizer

TDLinearLearner(optimizer=None)

Linear function approximator for TD learning.

Computes value predictions as: V(s) = w @ phi(s) + b

The learner maintains weights, bias, and eligibility traces, delegating the adaptation of learning rates to the TD optimizer (e.g., TDIDBD).

This follows the Alberta Plan philosophy of temporal uniformity: every component updates at every time step.

Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"

Attributes:
    optimizer: The TD optimizer to use for weight updates

Args:
    optimizer: TD optimizer for weight updates. Defaults to TDIDBD()
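
A minimal usage sketch of one TD transition (assuming TDLinearLearner and TDIDBD are top-level exports; the features, reward, and gamma are illustrative):

import jax.numpy as jnp
from alberta_framework import TDLinearLearner, TDIDBD  # assumed top-level exports

learner = TDLinearLearner(optimizer=TDIDBD(trace_decay=0.9))
state = learner.init(feature_dim=4)

phi_s = jnp.array([1.0, 0.0, 0.0, 0.0])       # phi(s)
phi_s2 = jnp.array([0.0, 1.0, 0.0, 0.0])      # phi(s')
result = learner.update(state, phi_s, jnp.array(1.0), phi_s2, jnp.array(0.99))
state = result.state                          # result.td_error holds delta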

Source code in src/alberta_framework/core/learners.py
def __init__(self, optimizer: AnyTDOptimizer | None = None):
    """Initialize the TD linear learner.

    Args:
        optimizer: TD optimizer for weight updates. Defaults to TDIDBD()
    """
    self._optimizer: AnyTDOptimizer = optimizer or TDIDBD()

init(feature_dim)

Initialize TD learner state.

Args:
    feature_dim: Dimension of the input feature vector

Returns:
    Initial TD learner state with zero weights and bias

Source code in src/alberta_framework/core/learners.py
def init(self, feature_dim: int) -> TDLearnerState:
    """Initialize TD learner state.

    Args:
        feature_dim: Dimension of the input feature vector

    Returns:
        Initial TD learner state with zero weights and bias
    """
    optimizer_state = self._optimizer.init(feature_dim)

    return TDLearnerState(
        weights=jnp.zeros(feature_dim, dtype=jnp.float32),
        bias=jnp.array(0.0, dtype=jnp.float32),
        optimizer_state=optimizer_state,
    )

predict(state, observation)

Compute value prediction for an observation.

Args:
    state: Current TD learner state
    observation: Input feature vector phi(s)

Returns:
    Scalar value prediction V(s) = w @ phi(s) + b

Source code in src/alberta_framework/core/learners.py
def predict(self, state: TDLearnerState, observation: Observation) -> Prediction:
    """Compute value prediction for an observation.

    Args:
        state: Current TD learner state
        observation: Input feature vector phi(s)

    Returns:
        Scalar value prediction ``V(s) = w @ phi(s) + b``
    """
    return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)

update(state, observation, reward, next_observation, gamma)

Update learner given a TD transition.

Performs one step of TD learning:

1. Compute V(s) and V(s')
2. Compute TD error delta = R + gamma*V(s') - V(s)
3. Get weight updates from TD optimizer
4. Apply updates to weights and bias

Args:
    state: Current TD learner state
    observation: Current observation phi(s)
    reward: Reward R received
    next_observation: Next observation phi(s')
    gamma: Discount factor gamma (0 at terminal states)

Returns:
    TDUpdateResult with new state, prediction, TD error, and metrics

Source code in src/alberta_framework/core/learners.py
def update(
    self,
    state: TDLearnerState,
    observation: Observation,
    reward: Array,
    next_observation: Observation,
    gamma: Array,
) -> TDUpdateResult:
    """Update learner given a TD transition.

    Performs one step of TD learning:
    1. Compute V(s) and V(s')
    2. Compute TD error delta = R + gamma*V(s') - V(s)
    3. Get weight updates from TD optimizer
    4. Apply updates to weights and bias

    Args:
        state: Current TD learner state
        observation: Current observation phi(s)
        reward: Reward R received
        next_observation: Next observation phi(s')
        gamma: Discount factor gamma (0 at terminal states)

    Returns:
        TDUpdateResult with new state, prediction, TD error, and metrics
    """
    # Compute predictions
    prediction = self.predict(state, observation)
    next_prediction = self.predict(state, next_observation)

    # Compute TD error: delta = R + gamma*V(s') - V(s)
    gamma_scalar = jnp.squeeze(gamma)
    td_error = (
        jnp.squeeze(reward)
        + gamma_scalar * jnp.squeeze(next_prediction)
        - jnp.squeeze(prediction)
    )

    # Get update from TD optimizer
    opt_update = self._optimizer.update(
        state.optimizer_state,
        td_error,
        observation,
        next_observation,
        gamma,
    )

    # Apply updates
    new_weights = state.weights + opt_update.weight_delta
    new_bias = state.bias + opt_update.bias_delta

    new_state = TDLearnerState(
        weights=new_weights,
        bias=new_bias,
        optimizer_state=opt_update.new_state,
    )

    # Pack metrics as array for scan compatibility
    squared_td_error = td_error**2
    mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)
    mean_elig_trace = opt_update.metrics.get("mean_eligibility_trace", 0.0)
    metrics = jnp.array(
        [squared_td_error, td_error, mean_step_size, mean_elig_trace],
        dtype=jnp.float32,
    )

    return TDUpdateResult(
        state=new_state,
        prediction=prediction,
        td_error=jnp.atleast_1d(td_error),
        metrics=metrics,
    )

TDUpdateResult

Result of a TD learner update step.

Attributes:
    state: Updated TD learner state
    prediction: Value prediction V(s) before update
    td_error: TD error delta = R + gamma*V(s') - V(s)
    metrics: Array of metrics [squared_td_error, td_error, mean_step_size, ...]

UpdateResult

Result of a learner update step.

Attributes:
    state: Updated learner state
    prediction: Prediction made before update
    error: Prediction error
    metrics: Array of metrics -- shape (3,) without normalizer, (4,) with normalizer

EMANormalizer(epsilon=1e-08, decay=0.99)

Bases: Normalizer[EMANormalizerState]

Online feature normalizer using exponential moving average.

Estimates mean and variance via EMA, suitable for non-stationary environments where recent observations should be weighted more heavily.

The effective decay ramps up from 0 to the target decay over early steps to prevent instability.

Attributes:
    epsilon: Small constant for numerical stability
    decay: Exponential decay for running estimates (0.99 = slower adaptation)

Args:
    epsilon: Small constant added to std for numerical stability
    decay: Exponential decay factor for running estimates. Lower values adapt faster to changes. 1.0 means pure online average (no decay).
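
A minimal usage sketch (assuming EMANormalizer is a top-level export; the observation is illustrative):

import jax.numpy as jnp
from alberta_framework import EMANormalizer  # assumed top-level export

norm = EMANormalizer(decay=0.99)
state = norm.init(feature_dim=3)

x = jnp.array([10.0, -2.0, 0.5])
x_norm, state = norm.normalize(state, x)     # normalize AND update statistics
x_frozen = norm.normalize_only(state, x)     # normalize without updating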

Source code in src/alberta_framework/core/normalizers.py
def __init__(
    self,
    epsilon: float = 1e-8,
    decay: float = 0.99,
):
    """Initialize the EMA normalizer.

    Args:
        epsilon: Small constant added to std for numerical stability
        decay: Exponential decay factor for running estimates.
               Lower values adapt faster to changes.
               1.0 means pure online average (no decay).
    """
    super().__init__(epsilon=epsilon)
    self._decay = decay

normalize_only(state, observation)

Normalize observation without updating statistics.

Useful for inference or when you want to normalize multiple observations with the same statistics.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns:
    Normalized observation

Source code in src/alberta_framework/core/normalizers.py
def normalize_only(
    self,
    state: StateT,
    observation: Array,
) -> Array:
    """Normalize observation without updating statistics.

    Useful for inference or when you want to normalize multiple
    observations with the same statistics.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Normalized observation
    """
    std = jnp.sqrt(state.var)
    return (observation - state.mean) / (std + self._epsilon)

update_only(state, observation)

Update statistics without returning normalized observation.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns:
    Updated normalizer state

Source code in src/alberta_framework/core/normalizers.py
def update_only(
    self,
    state: StateT,
    observation: Array,
) -> StateT:
    """Update statistics without returning normalized observation.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Updated normalizer state
    """
    _, new_state = self.normalize(state, observation)
    return new_state

init(feature_dim)

Initialize EMA normalizer state.

Args:
    feature_dim: Dimension of feature vectors

Returns:
    Initial normalizer state with zero mean and unit variance

Source code in src/alberta_framework/core/normalizers.py
def init(self, feature_dim: int) -> EMANormalizerState:
    """Initialize EMA normalizer state.

    Args:
        feature_dim: Dimension of feature vectors

    Returns:
        Initial normalizer state with zero mean and unit variance
    """
    return EMANormalizerState(
        mean=jnp.zeros(feature_dim, dtype=jnp.float32),
        var=jnp.ones(feature_dim, dtype=jnp.float32),
        sample_count=jnp.array(0.0, dtype=jnp.float32),
        decay=jnp.array(self._decay, dtype=jnp.float32),
    )

normalize(state, observation)

Normalize observation and update EMA running statistics.

Args:
    state: Current EMA normalizer state
    observation: Raw feature vector

Returns:
    Tuple of (normalized_observation, new_state)

Source code in src/alberta_framework/core/normalizers.py
def normalize(
    self,
    state: EMANormalizerState,
    observation: Array,
) -> tuple[Array, EMANormalizerState]:
    """Normalize observation and update EMA running statistics.

    Args:
        state: Current EMA normalizer state
        observation: Raw feature vector

    Returns:
        Tuple of (normalized_observation, new_state)
    """
    # Update count
    new_count = state.sample_count + 1.0

    # Compute effective decay (ramp up from 0 to target decay)
    # This prevents instability in early steps
    effective_decay = jnp.minimum(state.decay, 1.0 - 1.0 / (new_count + 1.0))

    # Update mean using exponential moving average
    delta = observation - state.mean
    new_mean = state.mean + (1.0 - effective_decay) * delta

    # Update variance using exponential moving average of squared deviations
    delta2 = observation - new_mean
    new_var = effective_decay * state.var + (1.0 - effective_decay) * delta * delta2

    # Ensure variance is positive
    new_var = jnp.maximum(new_var, self._epsilon)

    # Normalize using updated statistics
    std = jnp.sqrt(new_var)
    normalized = (observation - new_mean) / (std + self._epsilon)

    new_state = EMANormalizerState(
        mean=new_mean,
        var=new_var,
        sample_count=new_count,
        decay=state.decay,
    )

    return normalized, new_state

EMANormalizerState

State for EMA-based online feature normalization.

Uses exponential moving average to estimate running mean and variance, suitable for non-stationary distributions.

Attributes:
    mean: Running mean estimate per feature
    var: Running variance estimate per feature
    sample_count: Number of samples seen
    decay: Exponential decay factor for estimates (1.0 = no decay, pure online)

Normalizer(epsilon=1e-08)

Bases: ABC

Abstract base class for online feature normalizers.

Normalizes features using running estimates of mean and standard deviation: x_normalized = (x - mean) / (std + epsilon)

The normalizer updates its estimates at every time step, following temporal uniformity.

Subclasses must implement init and normalize. The normalize_only and update_only methods have default implementations.

Attributes:
    epsilon: Small constant for numerical stability

Args:
    epsilon: Small constant added to std for numerical stability
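
Because only init and normalize are abstract, a custom normalizer stays small. A hedged sketch (FixedNormalizer and its state are hypothetical, not part of the framework):

from typing import NamedTuple
import jax.numpy as jnp
from alberta_framework import Normalizer  # assumed top-level export

class FixedState(NamedTuple):
    mean: jnp.ndarray                     # frozen per-feature mean
    var: jnp.ndarray                      # frozen per-feature variance

class FixedNormalizer(Normalizer):
    """Hypothetical normalizer whose statistics never change (illustrative)."""

    def init(self, feature_dim: int) -> FixedState:
        return FixedState(mean=jnp.zeros(feature_dim), var=jnp.ones(feature_dim))

    def normalize(self, state: FixedState, observation):
        # Reuse the inherited normalize_only; the state is returned unchanged
        return self.normalize_only(state, observation), state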

Source code in src/alberta_framework/core/normalizers.py
def __init__(self, epsilon: float = 1e-8):
    """Initialize the normalizer.

    Args:
        epsilon: Small constant added to std for numerical stability
    """
    self._epsilon = epsilon

init(feature_dim) abstractmethod

Initialize normalizer state.

Args:
    feature_dim: Dimension of feature vectors

Returns:
    Initial normalizer state with zero mean and unit variance

Source code in src/alberta_framework/core/normalizers.py
@abstractmethod
def init(self, feature_dim: int) -> StateT:
    """Initialize normalizer state.

    Args:
        feature_dim: Dimension of feature vectors

    Returns:
        Initial normalizer state with zero mean and unit variance
    """
    ...

normalize(state, observation) abstractmethod

Normalize observation and update running statistics.

This method both normalizes the current observation AND updates the running statistics, maintaining temporal uniformity.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns:
    Tuple of (normalized_observation, new_state)

Source code in src/alberta_framework/core/normalizers.py
@abstractmethod
def normalize(
    self,
    state: StateT,
    observation: Array,
) -> tuple[Array, StateT]:
    """Normalize observation and update running statistics.

    This method both normalizes the current observation AND updates
    the running statistics, maintaining temporal uniformity.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Tuple of (normalized_observation, new_state)
    """
    ...

normalize_only(state, observation)

Normalize observation without updating statistics.

Useful for inference or when you want to normalize multiple observations with the same statistics.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns:
    Normalized observation

Source code in src/alberta_framework/core/normalizers.py
def normalize_only(
    self,
    state: StateT,
    observation: Array,
) -> Array:
    """Normalize observation without updating statistics.

    Useful for inference or when you want to normalize multiple
    observations with the same statistics.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Normalized observation
    """
    std = jnp.sqrt(state.var)
    return (observation - state.mean) / (std + self._epsilon)

update_only(state, observation)

Update statistics without returning normalized observation.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns:
    Updated normalizer state

Source code in src/alberta_framework/core/normalizers.py
def update_only(
    self,
    state: StateT,
    observation: Array,
) -> StateT:
    """Update statistics without returning normalized observation.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Updated normalizer state
    """
    _, new_state = self.normalize(state, observation)
    return new_state

WelfordNormalizer(epsilon=1e-08)

Bases: Normalizer[WelfordNormalizerState]

Online feature normalizer using Welford's algorithm.

Computes cumulative sample mean and variance with Bessel's correction, suitable for stationary distributions. Numerically stable for large sample counts.

Reference: Welford 1962, "Note on a Method for Calculating Corrected Sums of Squares and Products"

Attributes:
    epsilon: Small constant for numerical stability

Args:
    epsilon: Small constant added to std for numerical stability
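
The recurrence is compact; a scalar-feature sketch in plain Python (illustrative only, mirroring the state fields mean, p, and sample_count):

def welford_step(mean, p, count, x):
    count += 1
    delta = x - mean
    mean += delta / count                  # incremental mean update
    p += delta * (x - mean)                # accumulate squared deviations (M2)
    var = p / (count - 1) if count >= 2 else 1.0  # Bessel-corrected variance
    return mean, p, count, var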

Source code in src/alberta_framework/core/normalizers.py
def __init__(self, epsilon: float = 1e-8):
    """Initialize the normalizer.

    Args:
        epsilon: Small constant added to std for numerical stability
    """
    self._epsilon = epsilon

normalize_only(state, observation)

Normalize observation without updating statistics.

Useful for inference or when you want to normalize multiple observations with the same statistics.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns:
    Normalized observation

Source code in src/alberta_framework/core/normalizers.py
def normalize_only(
    self,
    state: StateT,
    observation: Array,
) -> Array:
    """Normalize observation without updating statistics.

    Useful for inference or when you want to normalize multiple
    observations with the same statistics.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Normalized observation
    """
    std = jnp.sqrt(state.var)
    return (observation - state.mean) / (std + self._epsilon)

update_only(state, observation)

Update statistics without returning normalized observation.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns:
    Updated normalizer state

Source code in src/alberta_framework/core/normalizers.py
def update_only(
    self,
    state: StateT,
    observation: Array,
) -> StateT:
    """Update statistics without returning normalized observation.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Updated normalizer state
    """
    _, new_state = self.normalize(state, observation)
    return new_state

init(feature_dim)

Initialize Welford normalizer state.

Args:
    feature_dim: Dimension of feature vectors

Returns:
    Initial normalizer state with zero mean and unit variance

Source code in src/alberta_framework/core/normalizers.py
def init(self, feature_dim: int) -> WelfordNormalizerState:
    """Initialize Welford normalizer state.

    Args:
        feature_dim: Dimension of feature vectors

    Returns:
        Initial normalizer state with zero mean and unit variance
    """
    return WelfordNormalizerState(
        mean=jnp.zeros(feature_dim, dtype=jnp.float32),
        var=jnp.ones(feature_dim, dtype=jnp.float32),
        sample_count=jnp.array(0.0, dtype=jnp.float32),
        p=jnp.zeros(feature_dim, dtype=jnp.float32),
    )

normalize(state, observation)

Normalize observation and update Welford running statistics.

Uses Welford's online algorithm:

1. Increment count
2. Update mean incrementally
3. Update sum of squared deviations (p / M2)
4. Compute variance with Bessel's correction when count >= 2

Args:
    state: Current Welford normalizer state
    observation: Raw feature vector

Returns:
    Tuple of (normalized_observation, new_state)

Source code in src/alberta_framework/core/normalizers.py
def normalize(
    self,
    state: WelfordNormalizerState,
    observation: Array,
) -> tuple[Array, WelfordNormalizerState]:
    """Normalize observation and update Welford running statistics.

    Uses Welford's online algorithm:
    1. Increment count
    2. Update mean incrementally
    3. Update sum of squared deviations (p / M2)
    4. Compute variance with Bessel's correction when count >= 2

    Args:
        state: Current Welford normalizer state
        observation: Raw feature vector

    Returns:
        Tuple of (normalized_observation, new_state)
    """
    new_count = state.sample_count + 1.0

    # Welford's incremental mean update
    delta = observation - state.mean
    new_mean = state.mean + delta / new_count

    # Update sum of squared deviations: p += (x - old_mean) * (x - new_mean)
    delta2 = observation - new_mean
    new_p = state.p + delta * delta2

    # Bessel-corrected variance; use unit variance when count < 2
    new_var = jnp.where(
        new_count >= 2.0,
        new_p / (new_count - 1.0),
        jnp.ones_like(new_p),
    )

    # Normalize using updated statistics
    std = jnp.sqrt(new_var)
    normalized = (observation - new_mean) / (std + self._epsilon)

    new_state = WelfordNormalizerState(
        mean=new_mean,
        var=new_var,
        sample_count=new_count,
        p=new_p,
    )

    return normalized, new_state

WelfordNormalizerState

State for Welford's online normalization algorithm.

Uses Welford's algorithm for numerically stable estimation of cumulative sample mean and variance with Bessel's correction.

Attributes:
    mean: Running mean estimate per feature
    var: Running variance estimate per feature (Bessel-corrected)
    sample_count: Number of samples seen
    p: Sum of squared deviations from the current mean (M2 accumulator)

IDBD(initial_step_size=0.01, meta_step_size=0.01)

Bases: Optimizer[IDBDState]

Incremental Delta-Bar-Delta optimizer.

IDBD maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation. When successive gradients agree in sign, the step-size for that weight increases. When they disagree, it decreases.

This implements Sutton's 1992 algorithm for adapting step-sizes online without requiring manual tuning.

Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta"

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate beta for adapting step-sizes

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate beta for adapting step-sizes
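
For intuition, a toy scalar trace of the meta-update in plain Python (illustrative numbers; clipping omitted): two same-sign gradients leave the trace h positive, so the step-size grows above its initial value.

import math

beta, log_alpha, h = 0.1, math.log(0.01), 0.0
for error, x in [(1.0, 1.0), (1.0, 1.0)]:    # successive gradients agree in sign
    log_alpha += beta * error * x * h        # meta-update uses the OLD trace
    alpha = math.exp(log_alpha)              # new step-size
    h = h * max(0.0, 1.0 - alpha * x * x) + alpha * error * x
print(math.exp(log_alpha))                   # slightly above the initial 0.01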

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
):
    """Initialize IDBD optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate beta for adapting step-sizes
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size

init_for_shape(shape)

Initialize optimizer state for parameters of arbitrary shape.

Used by MLP learners where parameters are matrices/vectors of varying shapes. Not all optimizers support this.

The return type varies by subclass (e.g. LMSState for LMS, AutostepParamState for Autostep) so the base signature uses Any.

Args:
    shape: Shape of the parameter array

Returns:
    Initial optimizer state with arrays matching the given shape

Raises:
    NotImplementedError: If the optimizer does not support this

Source code in src/alberta_framework/core/optimizers.py
def init_for_shape(self, shape: tuple[int, ...]) -> Any:
    """Initialize optimizer state for parameters of arbitrary shape.

    Used by MLP learners where parameters are matrices/vectors of
    varying shapes. Not all optimizers support this.

    The return type varies by subclass (e.g. ``LMSState`` for LMS,
    ``AutostepParamState`` for Autostep) so the base signature uses
    ``Any``.

    Args:
        shape: Shape of the parameter array

    Returns:
        Initial optimizer state with arrays matching the given shape

    Raises:
        NotImplementedError: If the optimizer does not support this
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not support init_for_shape. "
        "Only LMS and Autostep currently implement this."
    )

update_from_gradient(state, gradient, error=None)

Compute step delta from pre-computed gradient.

The returned delta does NOT include the error -- the caller is responsible for multiplying error * delta before applying.

The state type varies by subclass (e.g. LMSState for LMS, AutostepParamState for Autostep) so the base signature uses Any.

Args:
    state: Current optimizer state
    gradient: Pre-computed gradient (e.g. eligibility trace)
    error: Optional prediction error scalar. Optimizers with meta-learning (e.g. Autostep) use this for meta-gradient computation. LMS ignores it.

Returns:
    (step, new_state) where step has the same shape as gradient

Raises:
    NotImplementedError: If the optimizer does not support this

Source code in src/alberta_framework/core/optimizers.py
def update_from_gradient(
    self, state: Any, gradient: Array, error: Array | None = None
) -> tuple[Array, Any]:
    """Compute step delta from pre-computed gradient.

    The returned delta does NOT include the error -- the caller is
    responsible for multiplying ``error * delta`` before applying.

    The state type varies by subclass (e.g. ``LMSState`` for LMS,
    ``AutostepParamState`` for Autostep) so the base signature uses
    ``Any``.

    Args:
        state: Current optimizer state
        gradient: Pre-computed gradient (e.g. eligibility trace)
        error: Optional prediction error scalar. Optimizers with
            meta-learning (e.g. Autostep) use this for meta-gradient
            computation. LMS ignores it.

    Returns:
        ``(step, new_state)`` where step has the same shape as gradient

    Raises:
        NotImplementedError: If the optimizer does not support this
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not support update_from_gradient. "
        "Only LMS and Autostep currently implement this."
    )

init(feature_dim)

Initialize IDBD state.

Args:
    feature_dim: Dimension of weight vector

Returns:
    IDBD state with per-weight step-sizes and traces

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> IDBDState:
    """Initialize IDBD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        IDBD state with per-weight step-sizes and traces
    """
    return IDBDState(
        log_step_sizes=jnp.full(
            feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        bias_step_size=jnp.array(self._initial_step_size, dtype=jnp.float32),
        bias_trace=jnp.array(0.0, dtype=jnp.float32),
    )

update(state, error, observation)

Compute IDBD weight update with adaptive step-sizes.

Following Sutton 1992, Figure 2, the operation ordering is:

  1. Meta-update: log_alpha_i += beta * error * x_i * h_i (using OLD traces)
  2. Compute NEW step-sizes: alpha_i = exp(log_alpha_i)
  3. Update weights: w_i += alpha_i * error * x_i (using NEW alpha)
  4. Update traces: h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i (using NEW alpha)

The trace h_i tracks the correlation between current and past gradients. When gradients consistently point the same direction, h_i grows, leading to larger step-sizes.

Args:
    state: Current IDBD state
    error: Prediction error (scalar)
    observation: Feature vector

Returns:
    OptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: IDBDState,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute IDBD weight update with adaptive step-sizes.

    Following Sutton 1992, Figure 2, the operation ordering is:

    1. Meta-update: ``log_alpha_i += beta * error * x_i * h_i`` (using OLD traces)
    2. Compute NEW step-sizes: ``alpha_i = exp(log_alpha_i)``
    3. Update weights: ``w_i += alpha_i * error * x_i`` (using NEW alpha)
    4. Update traces: ``h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i``
       (using NEW alpha)

    The trace h_i tracks the correlation between current and past gradients.
    When gradients consistently point the same direction, h_i grows,
    leading to larger step-sizes.

    Args:
        state: Current IDBD state
        error: Prediction error (scalar)
        observation: Feature vector

    Returns:
        OptimizerUpdate with weight deltas and updated state
    """
    error_scalar = jnp.squeeze(error)
    beta = state.meta_step_size

    # 1. Meta-update: adapt step-sizes using OLD traces
    gradient_correlation = error_scalar * observation * state.traces
    new_log_step_sizes = state.log_step_sizes + beta * gradient_correlation

    # Clip log step-sizes to prevent numerical issues
    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)

    # 2. Compute NEW step-sizes
    new_alphas = jnp.exp(new_log_step_sizes)

    # 3. Weight updates using NEW alpha: alpha_i * error * x_i
    weight_delta = new_alphas * error_scalar * observation

    # 4. Update traces using NEW alpha: h_i = h_i * decay + alpha_i * error * x_i
    # decay = max(0, 1 - alpha_i * x_i^2)
    decay = jnp.maximum(0.0, 1.0 - new_alphas * observation**2)
    new_traces = state.traces * decay + new_alphas * error_scalar * observation

    # Bias updates (same ordering: meta-update first, then new alpha)
    bias_gradient_correlation = error_scalar * state.bias_trace
    new_bias_step_size = state.bias_step_size * jnp.exp(beta * bias_gradient_correlation)
    new_bias_step_size = jnp.clip(new_bias_step_size, 1e-6, 1.0)

    bias_delta = new_bias_step_size * error_scalar

    bias_decay = jnp.maximum(0.0, 1.0 - new_bias_step_size)
    new_bias_trace = state.bias_trace * bias_decay + new_bias_step_size * error_scalar

    new_state = IDBDState(
        log_step_sizes=new_log_step_sizes,
        traces=new_traces,
        meta_step_size=beta,
        bias_step_size=new_bias_step_size,
        bias_trace=new_bias_trace,
    )

    return OptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(new_alphas),
            "min_step_size": jnp.min(new_alphas),
            "max_step_size": jnp.max(new_alphas),
        },
    )

LMS(step_size=0.01)

Bases: Optimizer[LMSState]

Least Mean Square optimizer with fixed step-size.

The simplest gradient-based optimizer: w_{t+1} = w_t + alpha * delta * x_t

This serves as a baseline. The challenge is that the optimal step-size depends on the problem and changes as the task becomes non-stationary.

Attributes:
    step_size: Fixed learning rate alpha

Args:
    step_size: Fixed learning rate
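
The rule in one step (a sketch; alpha, delta, and x are illustrative values):

import jax.numpy as jnp

alpha = 0.01                          # fixed step-size
x = jnp.array([1.0, -0.5, 2.0])       # feature vector
delta = 0.3                           # prediction error
w = jnp.zeros(3)
w = w + alpha * delta * x             # w_{t+1} = w_t + alpha * delta * x_t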

Source code in src/alberta_framework/core/optimizers.py
def __init__(self, step_size: float = 0.01):
    """Initialize LMS optimizer.

    Args:
        step_size: Fixed learning rate
    """
    self._step_size = step_size

init(feature_dim)

Initialize LMS state.

Args:
    feature_dim: Dimension of weight vector (unused for LMS)

Returns:
    LMS state containing the step-size

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> LMSState:
    """Initialize LMS state.

    Args:
        feature_dim: Dimension of weight vector (unused for LMS)

    Returns:
        LMS state containing the step-size
    """
    return LMSState(step_size=jnp.array(self._step_size, dtype=jnp.float32))

init_for_shape(shape)

Initialize LMS state for arbitrary-shape parameters.

LMS state is shape-independent (single scalar), so this returns the same state regardless of shape.

Source code in src/alberta_framework/core/optimizers.py
def init_for_shape(self, shape: tuple[int, ...]) -> LMSState:
    """Initialize LMS state for arbitrary-shape parameters.

    LMS state is shape-independent (single scalar), so this returns
    the same state regardless of shape.
    """
    return LMSState(step_size=jnp.array(self._step_size, dtype=jnp.float32))

update_from_gradient(state, gradient, error=None)

Compute step from gradient: step = alpha * gradient.

Args:
    state: Current LMS state
    gradient: Pre-computed gradient (any shape)
    error: Unused by LMS (accepted for interface compatibility)

Returns:
    (step, state) -- state is unchanged for LMS

Source code in src/alberta_framework/core/optimizers.py
def update_from_gradient(
    self, state: LMSState, gradient: Array, error: Array | None = None
) -> tuple[Array, LMSState]:
    """Compute step from gradient: ``step = alpha * gradient``.

    Args:
        state: Current LMS state
        gradient: Pre-computed gradient (any shape)
        error: Unused by LMS (accepted for interface compatibility)

    Returns:
        ``(step, state)`` -- state is unchanged for LMS
    """
    del error  # LMS doesn't meta-learn
    return state.step_size * gradient, state

update(state, error, observation)

Compute LMS weight update.

Update rule: delta_w = alpha * error * x

Args:
    state: Current LMS state
    error: Prediction error (scalar)
    observation: Feature vector

Returns:
    OptimizerUpdate with weight and bias deltas

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: LMSState,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute LMS weight update.

    Update rule: ``delta_w = alpha * error * x``

    Args:
        state: Current LMS state
        error: Prediction error (scalar)
        observation: Feature vector

    Returns:
        OptimizerUpdate with weight and bias deltas
    """
    alpha = state.step_size
    error_scalar = jnp.squeeze(error)

    # Weight update: alpha * error * x
    weight_delta = alpha * error_scalar * observation

    # Bias update: alpha * error
    bias_delta = alpha * error_scalar

    return OptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=state,  # LMS state doesn't change
        metrics={"step_size": alpha},
    )

TDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, use_semi_gradient=True)

Bases: TDOptimizer[TDIDBDState]

TD-IDBD optimizer for temporal-difference learning.

Extends IDBD to TD learning with eligibility traces. Maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation in the TD setting.

Two variants are supported:

- Semi-gradient (default): Uses only phi(s) in meta-update, more stable
- Ordinary gradient: Uses both phi(s) and phi(s'), more accurate but sensitive

Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta
    trace_decay: Eligibility trace decay lambda
    use_semi_gradient: If True, use semi-gradient variant (default)

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay lambda (0 = TD(0))
    use_semi_gradient: If True, use semi-gradient variant (recommended)

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
    use_semi_gradient: bool = True,
):
    """Initialize TD-IDBD optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay lambda (0 = TD(0))
        use_semi_gradient: If True, use semi-gradient variant (recommended)
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
    self._trace_decay = trace_decay
    self._use_semi_gradient = use_semi_gradient

init(feature_dim)

Initialize TD-IDBD state.

Args:
    feature_dim: Dimension of weight vector

Returns:
    TD-IDBD state with per-weight step-sizes, traces, and h traces

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> TDIDBDState:
    """Initialize TD-IDBD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        TD-IDBD state with per-weight step-sizes, traces, and h traces
    """
    return TDIDBDState(
        log_step_sizes=jnp.full(
            feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(self._trace_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(self._initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
    )

update(state, td_error, observation, next_observation, gamma)

Compute TD-IDBD weight update with adaptive step-sizes.

Implements Algorithm 3 (semi-gradient) or Algorithm 4 (ordinary gradient) from Kearney et al. 2019.

Args:
    state: Current TD-IDBD state
    td_error: TD error delta = R + gamma*V(s') - V(s)
    observation: Current observation phi(s)
    next_observation: Next observation phi(s')
    gamma: Discount factor gamma (0 at terminal)

Returns:
    TDOptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: TDIDBDState,
    td_error: Array,
    observation: Array,
    next_observation: Array,
    gamma: Array,
) -> TDOptimizerUpdate:
    """Compute TD-IDBD weight update with adaptive step-sizes.

    Implements Algorithm 3 (semi-gradient) or Algorithm 4 (ordinary gradient)
    from Kearney et al. 2019.

    Args:
        state: Current TD-IDBD state
        td_error: TD error delta = R + gamma*V(s') - V(s)
        observation: Current observation phi(s)
        next_observation: Next observation phi(s')
        gamma: Discount factor gamma (0 at terminal)

    Returns:
        TDOptimizerUpdate with weight deltas and updated state
    """
    delta = jnp.squeeze(td_error)
    theta = state.meta_step_size
    lam = state.trace_decay
    gamma_scalar = jnp.squeeze(gamma)

    if self._use_semi_gradient:
        gradient_correlation = delta * observation * state.h_traces
        new_log_step_sizes = state.log_step_sizes + theta * gradient_correlation
    else:
        feature_diff = gamma_scalar * next_observation - observation
        gradient_correlation = delta * feature_diff * state.h_traces
        new_log_step_sizes = state.log_step_sizes - theta * gradient_correlation

    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)
    new_alphas = jnp.exp(new_log_step_sizes)

    new_eligibility_traces = gamma_scalar * lam * state.eligibility_traces + observation
    weight_delta = new_alphas * delta * new_eligibility_traces

    if self._use_semi_gradient:
        h_decay = jnp.maximum(0.0, 1.0 - new_alphas * observation * new_eligibility_traces)
        new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces
    else:
        feature_diff = gamma_scalar * next_observation - observation
        h_decay = jnp.maximum(0.0, 1.0 + new_alphas * new_eligibility_traces * feature_diff)
        new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces

    # Bias updates
    if self._use_semi_gradient:
        bias_gradient_correlation = delta * state.bias_h_trace
        new_bias_log_step_size = state.bias_log_step_size + theta * bias_gradient_correlation
    else:
        bias_feature_diff = gamma_scalar - 1.0
        bias_gradient_correlation = delta * bias_feature_diff * state.bias_h_trace
        new_bias_log_step_size = state.bias_log_step_size - theta * bias_gradient_correlation

    new_bias_log_step_size = jnp.clip(new_bias_log_step_size, -10.0, 2.0)
    new_bias_alpha = jnp.exp(new_bias_log_step_size)

    new_bias_eligibility_trace = gamma_scalar * lam * state.bias_eligibility_trace + 1.0
    bias_delta = new_bias_alpha * delta * new_bias_eligibility_trace

    if self._use_semi_gradient:
        bias_h_decay = jnp.maximum(0.0, 1.0 - new_bias_alpha * new_bias_eligibility_trace)
        new_bias_h_trace = (
            state.bias_h_trace * bias_h_decay
            + new_bias_alpha * delta * new_bias_eligibility_trace
        )
    else:
        bias_feature_diff = gamma_scalar - 1.0
        bias_h_decay = jnp.maximum(
            0.0, 1.0 + new_bias_alpha * new_bias_eligibility_trace * bias_feature_diff
        )
        new_bias_h_trace = (
            state.bias_h_trace * bias_h_decay
            + new_bias_alpha * delta * new_bias_eligibility_trace
        )

    new_state = TDIDBDState(
        log_step_sizes=new_log_step_sizes,
        eligibility_traces=new_eligibility_traces,
        h_traces=new_h_traces,
        meta_step_size=theta,
        trace_decay=lam,
        bias_log_step_size=new_bias_log_step_size,
        bias_eligibility_trace=new_bias_eligibility_trace,
        bias_h_trace=new_bias_h_trace,
    )

    return TDOptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(new_alphas),
            "min_step_size": jnp.min(new_alphas),
            "max_step_size": jnp.max(new_alphas),
            "mean_eligibility_trace": jnp.mean(jnp.abs(new_eligibility_traces)),
        },
    )
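
A minimal usage sketch, assuming TDIDBD is exported at the package top level like the other components; the caller applies the returned deltas to its own weights and carries new_state forward:

import jax.numpy as jnp
from alberta_framework import TDIDBD  # assumed top-level export

opt = TDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.9)
state = opt.init(feature_dim=4)

# One TD transition: phi(s), phi(s'), discount, and the TD error the
# caller computed from its current value estimates
phi_s = jnp.array([1.0, 0.0, 0.5, 0.0])
phi_next = jnp.array([0.0, 1.0, 0.0, 0.5])
gamma = jnp.array(0.99)
td_error = jnp.array(0.3)

upd = opt.update(state, td_error, phi_s, phi_next, gamma)
# weights += upd.weight_delta; bias += upd.bias_delta
state = upd.new_state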

AGCBounding(clip_factor=0.01, eps=0.001)

Bases: Bounder

Adaptive Gradient Clipping (Brock et al. 2021).

Clips per-output-unit based on the ratio of gradient norm to weight norm. Units where ||grad|| / max(||weight||, eps) > clip_factor get scaled down to respect the constraint.

Unlike ObGDBounding which applies a single global scale factor, AGC applies fine-grained, per-unit clipping that adapts to each layer's weight magnitude.

The metric returned is the fraction of units that were clipped (0.0 = no clipping, 1.0 = all units clipped).

Reference: Brock, A., De, S., Smith, S.L., & Simonyan, K. (2021). "High-Performance Large-Scale Image Recognition Without Normalization" (arXiv: 2102.06171)

Attributes:
- clip_factor: Maximum allowed gradient-to-weight ratio (lambda). Default 0.01.
- eps: Floor for weight norm to avoid division by zero. Default 1e-3.

Source code in src/alberta_framework/core/optimizers.py
def __init__(self, clip_factor: float = 0.01, eps: float = 1e-3):
    self._clip_factor = clip_factor
    self._eps = eps

bound(steps, error, params)

Bound proposed steps using per-unit adaptive gradient clipping.

For each parameter/step pair, computes unit-wise norms and clips units where |error| * ||step|| > clip_factor * max(||param||, eps).

Args:
- steps: Per-parameter step arrays from the optimizer
- error: Prediction error scalar
- params: Current parameter values (used for weight norms)

Returns: (clipped_steps, frac_clipped) where frac_clipped is the fraction of units that were clipped

Source code in src/alberta_framework/core/optimizers.py
def bound(
    self,
    steps: tuple[Array, ...],
    error: Array,
    params: tuple[Array, ...],
) -> tuple[tuple[Array, ...], Array]:
    """Bound proposed steps using per-unit adaptive gradient clipping.

    For each parameter/step pair, computes unit-wise norms and clips
    units where ``|error| * ||step|| > clip_factor * max(||param||, eps)``.

    Args:
        steps: Per-parameter step arrays from the optimizer
        error: Prediction error scalar
        params: Current parameter values (used for weight norms)

    Returns:
        ``(clipped_steps, frac_clipped)`` where frac_clipped is the
        fraction of units that were clipped
    """
    error_abs = jnp.abs(jnp.squeeze(error))
    clipped = []
    total_units = 0
    clipped_units = jnp.array(0.0)

    for step, param in zip(steps, params):
        p_norm = _unitwise_norm(param)
        s_norm = _unitwise_norm(step)
        g_norm = error_abs * s_norm
        max_norm = jnp.maximum(p_norm, self._eps) * self._clip_factor
        scale = max_norm / jnp.maximum(g_norm, 1e-6)
        needs_clip = g_norm > max_norm
        clipped_step = jnp.where(needs_clip, step * scale, step)
        clipped.append(clipped_step)

        total_units += needs_clip.size
        clipped_units = clipped_units + jnp.sum(needs_clip.astype(jnp.float32))

    frac_clipped = clipped_units / jnp.maximum(total_units, 1)
    return tuple(clipped), frac_clipped
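
For illustration, a sketch of calling bound directly; the step and parameter arrays here are arbitrary placeholders, and the top-level import is an assumption:

import jax.numpy as jnp
from alberta_framework import AGCBounding  # assumed top-level export

bounder = AGCBounding(clip_factor=0.01)
params = (jnp.ones((3, 4)), jnp.zeros(3))      # e.g. weight matrix + bias vector
steps = (0.5 * jnp.ones((3, 4)), jnp.ones(3))  # proposed per-parameter steps
error = jnp.array(2.0)

clipped_steps, frac_clipped = bounder.bound(steps, error, params)
# frac_clipped in [0, 1]: fraction of units whose step was scaled down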

Autostep(initial_step_size=0.01, meta_step_size=0.01, tau=10000.0)

Bases: Optimizer[AutostepState]

Autostep optimizer with tuning-free step-size adaptation.

Implements the exact algorithm from Mahmood et al. 2012, Table 1.

The algorithm maintains per-weight step-sizes that adapt based on meta-gradient correlation. The key innovations are:

- Self-regulated normalizers (v_i) that track the meta-gradient magnitude |delta * x_i * h_i| for stable meta-updates
- Overshoot prevention via effective step-size normalization M = max(sum(alpha_i * x_i^2), 1)

Per-sample update (Table 1):

  1. v_i = max(|delta*x_i*h_i|, v_i + (1/tau)*alpha_i*x_i^2*(|delta*x_i*h_i| - v_i))
  2. alpha_i *= exp(mu * delta*x_i*h_i / v_i) where v_i > 0
  3. M = max(sum(alpha_i * x_i^2), 1); alpha_i /= M
  4. w_i += alpha_i * delta * x_i (weight update with NEW alpha)
  5. h_i = h_i * (1 - alpha_i * x_i^2) + alpha_i * delta * x_i (trace update)

Reference: Mahmood, A.R., Sutton, R.S., Degris, T., & Pilarski, P.M. (2012). "Tuning-free step-size adaptation"

Attributes:
- initial_step_size: Initial per-weight step-size
- meta_step_size: Meta learning rate mu for adapting step-sizes
- tau: Time constant for normalizer adaptation (default: 10000)

Args:
- initial_step_size: Initial value for per-weight step-sizes
- meta_step_size: Meta learning rate for adapting step-sizes
- tau: Time constant for normalizer adaptation (default: 10000). Higher values mean slower normalizer decay.

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    tau: float = 10000.0,
):
    """Initialize Autostep optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate for adapting step-sizes
        tau: Time constant for normalizer adaptation (default: 10000).
            Higher values mean slower normalizer decay.
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
    self._tau = tau

init(feature_dim)

Initialize Autostep state.

Normalizers (v_i) and traces (h_i) are initialized to 0 per the paper.

Args: feature_dim: Dimension of weight vector

Returns: Autostep state with per-weight step-sizes, traces, and normalizers

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> AutostepState:
    """Initialize Autostep state.

    Normalizers (v_i) and traces (h_i) are initialized to 0 per the paper.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        Autostep state with per-weight step-sizes, traces, and normalizers
    """
    return AutostepState(
        step_sizes=jnp.full(feature_dim, self._initial_step_size, dtype=jnp.float32),
        traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        normalizers=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        tau=jnp.array(self._tau, dtype=jnp.float32),
        bias_step_size=jnp.array(self._initial_step_size, dtype=jnp.float32),
        bias_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_normalizer=jnp.array(0.0, dtype=jnp.float32),
    )

init_for_shape(shape)

Initialize Autostep state for arbitrary-shape parameters.

Args: shape: Shape of the parameter array

Returns: AutostepParamState with arrays matching the given shape

Source code in src/alberta_framework/core/optimizers.py
def init_for_shape(self, shape: tuple[int, ...]) -> AutostepParamState:
    """Initialize Autostep state for arbitrary-shape parameters.

    Args:
        shape: Shape of the parameter array

    Returns:
        AutostepParamState with arrays matching the given shape
    """
    return AutostepParamState(
        step_sizes=jnp.full(shape, self._initial_step_size, dtype=jnp.float32),
        traces=jnp.zeros(shape, dtype=jnp.float32),
        normalizers=jnp.zeros(shape, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        tau=jnp.array(self._tau, dtype=jnp.float32),
    )

update_from_gradient(state, gradient, error=None)

Compute Autostep update from pre-computed gradient (MLP path).

Implements the Table 1 algorithm generalized for arbitrary-shape parameters, where gradient plays the role of the eligibility trace z (prediction gradient).

When error is provided, the full paper algorithm is used: the meta-gradient is error * z * h. When error is None, the update falls back to an error-free approximation (z * h).

The returned step does NOT include the error -- the caller applies param += error * step after optional bounding.

Args:
- state: Current Autostep param state
- gradient: Pre-computed gradient / eligibility trace (same shape as state arrays)
- error: Optional prediction error scalar. When provided, enables the full paper algorithm with error-scaled meta-gradients.

Returns: (step, new_state) where step has the same shape as gradient

Source code in src/alberta_framework/core/optimizers.py
def update_from_gradient(
    self,
    state: AutostepParamState,
    gradient: Array,
    error: Array | None = None,
) -> tuple[Array, AutostepParamState]:
    """Compute Autostep update from pre-computed gradient (MLP path).

    Implements the Table 1 algorithm generalized for arbitrary-shape
    parameters, where ``gradient`` plays the role of the eligibility
    trace ``z`` (prediction gradient).

    When ``error`` is provided, the full paper algorithm is used:
    meta-gradient is ``error * z * h``. When ``error`` is None,
    falls back to error-free approximation (``z * h``).

    The returned step does NOT include the error -- the caller applies
    ``param += error * step`` after optional bounding.

    Args:
        state: Current Autostep param state
        gradient: Pre-computed gradient / eligibility trace (same shape as state arrays)
        error: Optional prediction error scalar. When provided, enables
            the full paper algorithm with error-scaled meta-gradients.

    Returns:
        ``(step, new_state)`` where step has the same shape as gradient
    """
    mu = state.meta_step_size
    tau = state.tau

    z = gradient  # eligibility trace
    z_sq = z**2

    # Compute meta-gradient: δ*z*h (or z*h if error is None)
    if error is not None:
        error_scalar = jnp.squeeze(error)
        meta_gradient = error_scalar * z * state.traces
    else:
        meta_gradient = z * state.traces

    abs_meta_gradient = jnp.abs(meta_gradient)

    # Eq. 4: v_i = max(|meta_grad|, v_i + (1/τ)*α_i*z_i²*(|meta_grad| - v_i))
    v_update = state.normalizers + (1.0 / tau) * state.step_sizes * z_sq * (
        abs_meta_gradient - state.normalizers
    )
    new_normalizers = jnp.maximum(abs_meta_gradient, v_update)

    # Eq. 5: α_i *= exp(μ * meta_grad / v_i) where v_i > 0
    safe_v = jnp.maximum(new_normalizers, 1e-38)
    new_step_sizes = jnp.where(
        new_normalizers > 0,
        state.step_sizes * jnp.exp(mu * meta_gradient / safe_v),
        state.step_sizes,
    )

    # Eq. 6-7: M = max(Σ α_i*z_i², 1); α_i /= M
    effective_step = jnp.sum(new_step_sizes * z_sq)
    m_factor = jnp.maximum(effective_step, 1.0)
    new_step_sizes = new_step_sizes / m_factor

    # Clip step-sizes for numerical safety
    new_step_sizes = jnp.clip(new_step_sizes, 1e-8, 1.0)

    # Compute step: α_i * z_i (error applied externally)
    step = new_step_sizes * z

    # Trace update: h_i = h_i*(1 - α_i*z_i²) + α_i*δ*z_i
    trace_decay = 1.0 - new_step_sizes * z_sq
    if error is not None:
        new_traces = state.traces * trace_decay + new_step_sizes * error_scalar * z
    else:
        new_traces = state.traces * trace_decay + new_step_sizes * z

    new_state = AutostepParamState(
        step_sizes=new_step_sizes,
        traces=new_traces,
        normalizers=new_normalizers,
        meta_step_size=mu,
        tau=tau,
    )

    return step, new_state
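
A sketch of this MLP path for a single parameter array, assuming Autostep is exported at the package top level; note the error-scaled application is left to the caller:

import jax.numpy as jnp
from alberta_framework import Autostep  # assumed top-level export

opt = Autostep(meta_step_size=0.01)
w = jnp.zeros((4, 8))                  # one parameter of an MLP
opt_state = opt.init_for_shape(w.shape)

grad = 0.1 * jnp.ones_like(w)          # prediction gradient for this parameter
error = jnp.array(0.5)

step, opt_state = opt.update_from_gradient(opt_state, grad, error=error)
w = w + error * step                   # caller applies the error-scaled step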

update(state, error, observation)

Compute Autostep weight update following Mahmood et al. 2012, Table 1.

The algorithm per sample:

  1. Eq. 4: v_i = max(|δ*x_i*h_i|, v_i + (1/τ)*α_i*x_i²*(|δ*x_i*h_i| - v_i))
  2. Eq. 5: α_i *= exp(μ * δ*x_i*h_i / v_i) where v_i > 0
  3. Eq. 6-7: M = max(Σ α_i*x_i² + α_bias, 1); α_i /= M, α_bias /= M
  4. Weight update: w_i += α_i * δ * x_i (with NEW alpha)
  5. Trace update: h_i = h_i*(1 - α_i*x_i²) + α_i*δ*x_i

Args:
- state: Current Autostep state
- error: Prediction error (scalar)
- observation: Feature vector

Returns: OptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: AutostepState,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute Autostep weight update following Mahmood et al. 2012, Table 1.

    The algorithm per sample:

    1. Eq. 4: ``v_i = max(|δ*x_i*h_i|, v_i + (1/τ)*α_i*x_i²*(|δ*x_i*h_i| - v_i))``
    2. Eq. 5: ``α_i *= exp(μ * δ*x_i*h_i / v_i)`` where ``v_i > 0``
    3. Eq. 6-7: ``M = max(Σ α_i*x_i² + α_bias, 1)``; ``α_i /= M``, ``α_bias /= M``
    4. Weight update: ``w_i += α_i * δ * x_i`` (with NEW alpha)
    5. Trace update: ``h_i = h_i*(1 - α_i*x_i²) + α_i*δ*x_i``

    Args:
        state: Current Autostep state
        error: Prediction error (scalar)
        observation: Feature vector

    Returns:
        OptimizerUpdate with weight deltas and updated state
    """
    error_scalar = jnp.squeeze(error)
    mu = state.meta_step_size
    tau = state.tau

    x = observation
    x_sq = x**2

    # --- Weights ---
    # Meta-gradient: δ*x_i*h_i
    meta_gradient = error_scalar * x * state.traces
    abs_meta_gradient = jnp.abs(meta_gradient)

    # Eq. 4: v_i update (self-regulated EMA)
    v_update = state.normalizers + (1.0 / tau) * state.step_sizes * x_sq * (
        abs_meta_gradient - state.normalizers
    )
    new_normalizers = jnp.maximum(abs_meta_gradient, v_update)

    # Eq. 5: α_i *= exp(μ * meta_grad / v_i) where v_i > 0
    safe_v = jnp.maximum(new_normalizers, 1e-38)
    new_step_sizes = jnp.where(
        new_normalizers > 0,
        state.step_sizes * jnp.exp(mu * meta_gradient / safe_v),
        state.step_sizes,
    )

    # --- Bias ---
    # Meta-gradient for bias (implicit x=1): δ*h_bias
    bias_meta_gradient = error_scalar * state.bias_trace
    abs_bias_meta_gradient = jnp.abs(bias_meta_gradient)

    # Eq. 4 for bias
    bias_v_update = state.bias_normalizer + (1.0 / tau) * state.bias_step_size * (
        abs_bias_meta_gradient - state.bias_normalizer
    )
    new_bias_normalizer = jnp.maximum(abs_bias_meta_gradient, bias_v_update)

    # Eq. 5 for bias
    safe_bias_v = jnp.maximum(new_bias_normalizer, 1e-38)
    new_bias_step_size = jnp.where(
        new_bias_normalizer > 0,
        state.bias_step_size * jnp.exp(mu * bias_meta_gradient / safe_bias_v),
        state.bias_step_size,
    )

    # Eq. 6-7: Overshoot prevention (joint over weights + bias)
    # M = max(Σ α_i*x_i² + α_bias*1², 1)
    effective_step = jnp.sum(new_step_sizes * x_sq) + new_bias_step_size
    m_factor = jnp.maximum(effective_step, 1.0)
    new_step_sizes = new_step_sizes / m_factor
    new_bias_step_size = new_bias_step_size / m_factor

    # Clip step-sizes for numerical safety
    new_step_sizes = jnp.clip(new_step_sizes, 1e-8, 1.0)
    new_bias_step_size = jnp.clip(new_bias_step_size, 1e-8, 1.0)

    # Weight update with NEW alpha: α_i * δ * x_i
    weight_delta = new_step_sizes * error_scalar * x

    # Bias update: α_bias * δ
    bias_delta = new_bias_step_size * error_scalar

    # Trace update: h_i = h_i*(1 - α_i*x_i²) + α_i*δ*x_i
    trace_decay = 1.0 - new_step_sizes * x_sq
    new_traces = state.traces * trace_decay + new_step_sizes * error_scalar * x

    # Bias trace: h_bias = h_bias*(1 - α_bias) + α_bias*δ
    bias_trace_decay = 1.0 - new_bias_step_size
    new_bias_trace = state.bias_trace * bias_trace_decay + new_bias_step_size * error_scalar

    new_state = AutostepState(
        step_sizes=new_step_sizes,
        traces=new_traces,
        normalizers=new_normalizers,
        meta_step_size=mu,
        tau=tau,
        bias_step_size=new_bias_step_size,
        bias_trace=new_bias_trace,
        bias_normalizer=new_bias_normalizer,
    )

    return OptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(new_step_sizes),
            "min_step_size": jnp.min(new_step_sizes),
            "max_step_size": jnp.max(new_step_sizes),
            "mean_normalizer": jnp.mean(new_normalizers),
        },
    )
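
The linear path handles the weights and bias jointly; a minimal sketch under the same export assumption:

import jax.numpy as jnp
from alberta_framework import Autostep

opt = Autostep(initial_step_size=0.01, meta_step_size=0.01)
state = opt.init(feature_dim=4)

x = jnp.array([1.0, -0.5, 0.0, 2.0])
error = jnp.array(0.7)                 # target - prediction

upd = opt.update(state, error, x)
# weights += upd.weight_delta; bias += upd.bias_delta
state = upd.new_state

In practice the optimizer is usually handed to a learner (e.g. LinearLearner(optimizer=Autostep())), which performs these applications internally.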

AutoTDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, normalizer_decay=10000.0)

Bases: TDOptimizer[AutoTDIDBDState]

AutoStep-style normalized TD-IDBD optimizer.

Adds AutoStep-style normalization to TDIDBD for improved stability and reduced sensitivity to the meta step-size theta.

Reference: Kearney et al. 2019, Algorithm 6 "AutoStep Style Normalized TIDBD(lambda)"

Attributes:
- initial_step_size: Initial per-weight step-size
- meta_step_size: Meta learning rate theta
- trace_decay: Eligibility trace decay lambda
- normalizer_decay: Decay parameter tau for normalizers

Args:
- initial_step_size: Initial value for per-weight step-sizes
- meta_step_size: Meta learning rate theta for adapting step-sizes
- trace_decay: Eligibility trace decay lambda (0 = TD(0))
- normalizer_decay: Decay parameter tau for normalizers (default: 10000)

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
    normalizer_decay: float = 10000.0,
):
    """Initialize AutoTDIDBD optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay lambda (0 = TD(0))
        normalizer_decay: Decay parameter tau for normalizers (default: 10000)
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
    self._trace_decay = trace_decay
    self._normalizer_decay = normalizer_decay

init(feature_dim)

Initialize AutoTDIDBD state.

Args: feature_dim: Dimension of weight vector

Returns: AutoTDIDBD state with per-weight step-sizes, traces, h traces, and normalizers

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> AutoTDIDBDState:
    """Initialize AutoTDIDBD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        AutoTDIDBD state with per-weight step-sizes, traces, h traces, and normalizers
    """
    return AutoTDIDBDState(
        log_step_sizes=jnp.full(
            feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(self._trace_decay, dtype=jnp.float32),
        normalizer_decay=jnp.array(self._normalizer_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(self._initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
    )

update(state, td_error, observation, next_observation, gamma)

Compute AutoTDIDBD weight update with normalized adaptive step-sizes.

Implements Algorithm 6 from Kearney et al. 2019.

Args:
- state: Current AutoTDIDBD state
- td_error: TD error delta = R + gamma*V(s') - V(s)
- observation: Current observation phi(s)
- next_observation: Next observation phi(s')
- gamma: Discount factor gamma (0 at terminal)

Returns: TDOptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: AutoTDIDBDState,
    td_error: Array,
    observation: Array,
    next_observation: Array,
    gamma: Array,
) -> TDOptimizerUpdate:
    """Compute AutoTDIDBD weight update with normalized adaptive step-sizes.

    Implements Algorithm 6 from Kearney et al. 2019.

    Args:
        state: Current AutoTDIDBD state
        td_error: TD error delta = R + gamma*V(s') - V(s)
        observation: Current observation phi(s)
        next_observation: Next observation phi(s')
        gamma: Discount factor gamma (0 at terminal)

    Returns:
        TDOptimizerUpdate with weight deltas and updated state
    """
    delta = jnp.squeeze(td_error)
    theta = state.meta_step_size
    lam = state.trace_decay
    tau = state.normalizer_decay
    gamma_scalar = jnp.squeeze(gamma)

    feature_diff = gamma_scalar * next_observation - observation
    alphas = jnp.exp(state.log_step_sizes)

    # Update normalizers
    abs_weight_update = jnp.abs(delta * feature_diff * state.h_traces)
    normalizer_decay_term = (
        (1.0 / tau)
        * alphas
        * feature_diff
        * state.eligibility_traces
        * (jnp.abs(delta * observation * state.h_traces) - state.normalizers)
    )
    new_normalizers = jnp.maximum(abs_weight_update, state.normalizers - normalizer_decay_term)
    new_normalizers = jnp.maximum(new_normalizers, 1e-8)

    # Normalized meta-update
    normalized_gradient = delta * feature_diff * state.h_traces / new_normalizers
    new_log_step_sizes = state.log_step_sizes - theta * normalized_gradient

    # Effective step-size normalization
    effective_step_size = -jnp.sum(
        jnp.exp(new_log_step_sizes) * feature_diff * state.eligibility_traces
    )
    normalization_factor = jnp.maximum(effective_step_size, 1.0)
    new_log_step_sizes = new_log_step_sizes - jnp.log(normalization_factor)

    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)
    new_alphas = jnp.exp(new_log_step_sizes)

    new_eligibility_traces = gamma_scalar * lam * state.eligibility_traces + observation
    weight_delta = new_alphas * delta * new_eligibility_traces

    # Update h traces
    h_decay = jnp.maximum(0.0, 1.0 + new_alphas * feature_diff * new_eligibility_traces)
    new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces

    # Bias updates
    bias_alpha = jnp.exp(state.bias_log_step_size)
    bias_feature_diff = gamma_scalar - 1.0

    abs_bias_weight_update = jnp.abs(delta * bias_feature_diff * state.bias_h_trace)
    bias_normalizer_decay_term = (
        (1.0 / tau)
        * bias_alpha
        * bias_feature_diff
        * state.bias_eligibility_trace
        * (jnp.abs(delta * state.bias_h_trace) - state.bias_normalizer)
    )
    new_bias_normalizer = jnp.maximum(
        abs_bias_weight_update, state.bias_normalizer - bias_normalizer_decay_term
    )
    new_bias_normalizer = jnp.maximum(new_bias_normalizer, 1e-8)

    normalized_bias_gradient = (
        delta * bias_feature_diff * state.bias_h_trace / new_bias_normalizer
    )
    new_bias_log_step_size = state.bias_log_step_size - theta * normalized_bias_gradient

    bias_effective_step_size = (
        -jnp.exp(new_bias_log_step_size) * bias_feature_diff * state.bias_eligibility_trace
    )
    bias_norm_factor = jnp.maximum(bias_effective_step_size, 1.0)
    new_bias_log_step_size = new_bias_log_step_size - jnp.log(bias_norm_factor)

    new_bias_log_step_size = jnp.clip(new_bias_log_step_size, -10.0, 2.0)
    new_bias_alpha = jnp.exp(new_bias_log_step_size)

    new_bias_eligibility_trace = gamma_scalar * lam * state.bias_eligibility_trace + 1.0
    bias_delta = new_bias_alpha * delta * new_bias_eligibility_trace

    bias_h_decay = jnp.maximum(
        0.0, 1.0 + new_bias_alpha * bias_feature_diff * new_bias_eligibility_trace
    )
    new_bias_h_trace = (
        state.bias_h_trace * bias_h_decay + new_bias_alpha * delta * new_bias_eligibility_trace
    )

    new_state = AutoTDIDBDState(
        log_step_sizes=new_log_step_sizes,
        eligibility_traces=new_eligibility_traces,
        h_traces=new_h_traces,
        normalizers=new_normalizers,
        meta_step_size=theta,
        trace_decay=lam,
        normalizer_decay=tau,
        bias_log_step_size=new_bias_log_step_size,
        bias_eligibility_trace=new_bias_eligibility_trace,
        bias_h_trace=new_bias_h_trace,
        bias_normalizer=new_bias_normalizer,
    )

    return TDOptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(new_alphas),
            "min_step_size": jnp.min(new_alphas),
            "max_step_size": jnp.max(new_alphas),
            "mean_eligibility_trace": jnp.mean(jnp.abs(new_eligibility_traces)),
            "mean_normalizer": jnp.mean(new_normalizers),
        },
    )
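
Since AutoTDIDBD implements the same TDOptimizer interface, it is a drop-in replacement for TDIDBD in the usage sketch above:

from alberta_framework import AutoTDIDBD  # assumed top-level export

opt = AutoTDIDBD(meta_step_size=0.01, trace_decay=0.9, normalizer_decay=10000.0)
state = opt.init(feature_dim=4)
# opt.update(state, td_error, phi_s, phi_next, gamma) as with TDIDBD;
# metrics additionally report "mean_normalizer"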

Bounder

Bases: ABC

Base class for update bounding strategies.

A bounder takes the proposed per-parameter step arrays from an optimizer and optionally scales them down to prevent overshooting.

bound(steps, error, params) abstractmethod

Bound proposed update steps.

Args:
- steps: Per-parameter step arrays from the optimizer
- error: Prediction error scalar
- params: Current parameter values (needed by some bounders like AGC)

Returns: (bounded_steps, metric) where metric is a scalar for reporting (e.g., scale factor for ObGD, mean clip ratio for AGC)

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def bound(
    self,
    steps: tuple[Array, ...],
    error: Array,
    params: tuple[Array, ...],
) -> tuple[tuple[Array, ...], Array]:
    """Bound proposed update steps.

    Args:
        steps: Per-parameter step arrays from the optimizer
        error: Prediction error scalar
        params: Current parameter values (needed by some bounders like AGC)

    Returns:
        ``(bounded_steps, metric)`` where metric is a scalar for reporting
        (e.g., scale factor for ObGD, mean clip ratio for AGC)
    """
    ...
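
To illustrate the interface, a hypothetical bounder (not part of the framework) that rescales all steps so their combined L2 norm stays below a cap:

import jax.numpy as jnp
from jax import Array
from alberta_framework import Bounder  # assumed top-level export

class GlobalNormBounding(Bounder):  # hypothetical example
    def __init__(self, max_norm: float = 1.0):
        self._max_norm = max_norm

    def bound(
        self,
        steps: tuple[Array, ...],
        error: Array,
        params: tuple[Array, ...],
    ) -> tuple[tuple[Array, ...], Array]:
        del error, params  # this bounder looks only at the steps themselves
        norm = jnp.sqrt(sum(jnp.sum(s**2) for s in steps))
        scale = jnp.minimum(1.0, self._max_norm / jnp.maximum(norm, 1e-8))
        return tuple(scale * s for s in steps), scale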

ObGD(step_size=1.0, kappa=2.0, gamma=0.0, lamda=0.0)

Bases: Optimizer[ObGDState]

Observation-bounded Gradient Descent optimizer.

ObGD prevents overshooting by dynamically bounding the effective step-size based on the magnitude of the prediction error and eligibility traces. When the combined update magnitude would be too large, the step-size is scaled down to prevent the prediction from overshooting the target.

This is the deep-network generalization of Autostep's overshooting prevention, designed for streaming reinforcement learning.

For supervised learning (gamma=0, lamda=0), traces equal the current observation each step, making ObGD equivalent to LMS with dynamic step-size bounding.

The ObGD algorithm:

  1. Update traces: z = gamma * lamda * z + observation
  2. Compute bound: M = alpha * kappa * max(|error|, 1) * (||z_w||_1 + |z_b|)
  3. Effective step: alpha_eff = min(alpha, alpha / M) (i.e. alpha / max(M, 1))
  4. Weight delta: delta_w = alpha_eff * error * z_w
  5. Bias delta: delta_b = alpha_eff * error * z_b

Reference: Elsayed et al. 2024, "Streaming Deep Reinforcement Learning Finally Works"

Attributes:
- step_size: Base learning rate alpha
- kappa: Bounding sensitivity parameter (higher = more conservative)
- gamma: Discount factor for trace decay (0 for supervised learning)
- lamda: Eligibility trace decay parameter (0 for supervised learning)

Args:
- step_size: Base learning rate (default: 1.0)
- kappa: Bounding sensitivity parameter (default: 2.0)
- gamma: Discount factor for trace decay (default: 0.0 for supervised)
- lamda: Eligibility trace decay parameter (default: 0.0 for supervised)

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    step_size: float = 1.0,
    kappa: float = 2.0,
    gamma: float = 0.0,
    lamda: float = 0.0,
):
    """Initialize ObGD optimizer.

    Args:
        step_size: Base learning rate (default: 1.0)
        kappa: Bounding sensitivity parameter (default: 2.0)
        gamma: Discount factor for trace decay (default: 0.0 for supervised)
        lamda: Eligibility trace decay parameter (default: 0.0 for supervised)
    """
    self._step_size = step_size
    self._kappa = kappa
    self._gamma = gamma
    self._lamda = lamda

init_for_shape(shape)

Initialize optimizer state for parameters of arbitrary shape.

Used by MLP learners where parameters are matrices/vectors of varying shapes. Not all optimizers support this.

The return type varies by subclass (e.g. LMSState for LMS, AutostepParamState for Autostep) so the base signature uses Any.

Args: shape: Shape of the parameter array

Returns: Initial optimizer state with arrays matching the given shape

Raises: NotImplementedError: If the optimizer does not support this

Source code in src/alberta_framework/core/optimizers.py
def init_for_shape(self, shape: tuple[int, ...]) -> Any:
    """Initialize optimizer state for parameters of arbitrary shape.

    Used by MLP learners where parameters are matrices/vectors of
    varying shapes. Not all optimizers support this.

    The return type varies by subclass (e.g. ``LMSState`` for LMS,
    ``AutostepParamState`` for Autostep) so the base signature uses
    ``Any``.

    Args:
        shape: Shape of the parameter array

    Returns:
        Initial optimizer state with arrays matching the given shape

    Raises:
        NotImplementedError: If the optimizer does not support this
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not support init_for_shape. "
        "Only LMS and Autostep currently implement this."
    )

update_from_gradient(state, gradient, error=None)

Compute step delta from pre-computed gradient.

The returned delta does NOT include the error -- the caller is responsible for multiplying error * delta before applying.

The state type varies by subclass (e.g. LMSState for LMS, AutostepParamState for Autostep) so the base signature uses Any.

Args:
- state: Current optimizer state
- gradient: Pre-computed gradient (e.g. eligibility trace)
- error: Optional prediction error scalar. Optimizers with meta-learning (e.g. Autostep) use this for meta-gradient computation. LMS ignores it.

Returns: (step, new_state) where step has the same shape as gradient

Raises: NotImplementedError: If the optimizer does not support this

Source code in src/alberta_framework/core/optimizers.py
def update_from_gradient(
    self, state: Any, gradient: Array, error: Array | None = None
) -> tuple[Array, Any]:
    """Compute step delta from pre-computed gradient.

    The returned delta does NOT include the error -- the caller is
    responsible for multiplying ``error * delta`` before applying.

    The state type varies by subclass (e.g. ``LMSState`` for LMS,
    ``AutostepParamState`` for Autostep) so the base signature uses
    ``Any``.

    Args:
        state: Current optimizer state
        gradient: Pre-computed gradient (e.g. eligibility trace)
        error: Optional prediction error scalar. Optimizers with
            meta-learning (e.g. Autostep) use this for meta-gradient
            computation. LMS ignores it.

    Returns:
        ``(step, new_state)`` where step has the same shape as gradient

    Raises:
        NotImplementedError: If the optimizer does not support this
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not support update_from_gradient. "
        "Only LMS and Autostep currently implement this."
    )

init(feature_dim)

Initialize ObGD state.

Args: feature_dim: Dimension of weight vector

Returns: ObGD state with eligibility traces

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> ObGDState:
    """Initialize ObGD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        ObGD state with eligibility traces
    """
    return ObGDState(
        step_size=jnp.array(self._step_size, dtype=jnp.float32),
        kappa=jnp.array(self._kappa, dtype=jnp.float32),
        traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        bias_trace=jnp.array(0.0, dtype=jnp.float32),
        gamma=jnp.array(self._gamma, dtype=jnp.float32),
        lamda=jnp.array(self._lamda, dtype=jnp.float32),
    )

update(state, error, observation)

Compute ObGD weight update with overshooting prevention.

The bounding mechanism scales down the step-size when the combined effect of error magnitude, trace norm, and step-size would cause the prediction to overshoot the target.

Args:
- state: Current ObGD state
- error: Prediction error (target - prediction)
- observation: Current observation/feature vector

Returns: OptimizerUpdate with bounded weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: ObGDState,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute ObGD weight update with overshooting prevention.

    The bounding mechanism scales down the step-size when the combined
    effect of error magnitude, trace norm, and step-size would cause
    the prediction to overshoot the target.

    Args:
        state: Current ObGD state
        error: Prediction error (target - prediction)
        observation: Current observation/feature vector

    Returns:
        OptimizerUpdate with bounded weight deltas and updated state
    """
    error_scalar = jnp.squeeze(error)
    alpha = state.step_size
    kappa = state.kappa

    # Update eligibility traces: z = gamma * lamda * z + observation
    new_traces = state.gamma * state.lamda * state.traces + observation
    new_bias_trace = state.gamma * state.lamda * state.bias_trace + 1.0

    # Compute z_sum (L1 norm of all traces)
    z_sum = jnp.sum(jnp.abs(new_traces)) + jnp.abs(new_bias_trace)

    # Compute bounding factor: M = alpha * kappa * max(|error|, 1) * z_sum
    delta_bar = jnp.maximum(jnp.abs(error_scalar), 1.0)
    dot_product = delta_bar * z_sum * alpha * kappa

    # Effective step-size: alpha / max(M, 1)
    alpha_eff = alpha / jnp.maximum(dot_product, 1.0)

    # Weight and bias deltas
    weight_delta = alpha_eff * error_scalar * new_traces
    bias_delta = alpha_eff * error_scalar * new_bias_trace

    new_state = ObGDState(
        step_size=alpha,
        kappa=kappa,
        traces=new_traces,
        bias_trace=new_bias_trace,
        gamma=state.gamma,
        lamda=state.lamda,
    )

    return OptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "step_size": alpha,
            "effective_step_size": alpha_eff,
            "bounding_factor": dot_product,
        },
    )
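
A minimal supervised-case sketch (the default gamma=0, lamda=0, so the traces reduce to the current observation), again assuming a top-level export:

import jax.numpy as jnp
from alberta_framework import ObGD  # assumed top-level export

opt = ObGD(step_size=1.0, kappa=2.0)
state = opt.init(feature_dim=4)

x = jnp.array([1.0, 0.0, -1.0, 0.5])
error = jnp.array(3.0)                 # large error, so bounding kicks in

upd = opt.update(state, error, x)
print(upd.metrics["effective_step_size"])  # < step_size when bounded
state = upd.new_state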

ObGDBounding(kappa=2.0)

Bases: Bounder

ObGD-style global update bounding (Elsayed et al. 2024).

Computes a global bounding factor from the L1 norm of all proposed steps and the error magnitude, then uniformly scales all steps down if the combined update would be too large.

For LMS with a single scalar step-size alpha: total_step = alpha * z_sum, giving M = alpha * kappa * max(|error|, 1) * z_sum -- identical to the original Elsayed et al. 2024 formula.

Attributes: kappa: Bounding sensitivity parameter (higher = more conservative)

Source code in src/alberta_framework/core/optimizers.py
def __init__(self, kappa: float = 2.0):
    self._kappa = kappa

bound(steps, error, params)

Bound proposed steps using ObGD formula.

Args:
- steps: Per-parameter step arrays
- error: Prediction error scalar
- params: Current parameter values (unused by ObGD)

Returns: (bounded_steps, scale) where scale is the bounding factor

Source code in src/alberta_framework/core/optimizers.py
def bound(
    self,
    steps: tuple[Array, ...],
    error: Array,
    params: tuple[Array, ...],
) -> tuple[tuple[Array, ...], Array]:
    """Bound proposed steps using ObGD formula.

    Args:
        steps: Per-parameter step arrays
        error: Prediction error scalar
        params: Current parameter values (unused by ObGD)

    Returns:
        ``(bounded_steps, scale)`` where scale is the bounding factor
    """
    del params  # ObGD bounds based on step/error magnitude only
    error_scalar = jnp.squeeze(error)
    total_step = jnp.array(0.0)
    for s in steps:
        total_step = total_step + jnp.sum(jnp.abs(s))
    delta_bar = jnp.maximum(jnp.abs(error_scalar), 1.0)
    bound_magnitude = self._kappa * delta_bar * total_step
    scale = 1.0 / jnp.maximum(bound_magnitude, 1.0)
    bounded = tuple(scale * s for s in steps)
    return bounded, scale
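
A direct-call sketch mirroring the AGCBounding example above; all step arrays receive the same global scale:

import jax.numpy as jnp
from alberta_framework import ObGDBounding  # assumed top-level export

bounder = ObGDBounding(kappa=2.0)
steps = (0.1 * jnp.ones((3, 4)), 0.1 * jnp.ones(3))
error = jnp.array(5.0)

bounded, scale = bounder.bound(steps, error, ())  # params unused by ObGD
# every step array is multiplied by the same scalar `scale`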

Optimizer

Bases: ABC

Base class for optimizers.

init(feature_dim) abstractmethod

Initialize optimizer state.

Args: feature_dim: Dimension of weight vector

Returns: Initial optimizer state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def init(self, feature_dim: int) -> StateT:
    """Initialize optimizer state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        Initial optimizer state
    """
    ...

update(state, error, observation) abstractmethod

Compute weight updates given prediction error.

Args:
- state: Current optimizer state
- error: Prediction error (target - prediction)
- observation: Current observation/feature vector

Returns: OptimizerUpdate with deltas and new state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def update(
    self,
    state: StateT,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute weight updates given prediction error.

    Args:
        state: Current optimizer state
        error: Prediction error (target - prediction)
        observation: Current observation/feature vector

    Returns:
        OptimizerUpdate with deltas and new state
    """
    ...

init_for_shape(shape)

Initialize optimizer state for parameters of arbitrary shape.

Used by MLP learners where parameters are matrices/vectors of varying shapes. Not all optimizers support this.

The return type varies by subclass (e.g. LMSState for LMS, AutostepParamState for Autostep) so the base signature uses Any.

Args: shape: Shape of the parameter array

Returns: Initial optimizer state with arrays matching the given shape

Raises: NotImplementedError: If the optimizer does not support this

Source code in src/alberta_framework/core/optimizers.py
def init_for_shape(self, shape: tuple[int, ...]) -> Any:
    """Initialize optimizer state for parameters of arbitrary shape.

    Used by MLP learners where parameters are matrices/vectors of
    varying shapes. Not all optimizers support this.

    The return type varies by subclass (e.g. ``LMSState`` for LMS,
    ``AutostepParamState`` for Autostep) so the base signature uses
    ``Any``.

    Args:
        shape: Shape of the parameter array

    Returns:
        Initial optimizer state with arrays matching the given shape

    Raises:
        NotImplementedError: If the optimizer does not support this
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not support init_for_shape. "
        "Only LMS and Autostep currently implement this."
    )

update_from_gradient(state, gradient, error=None)

Compute step delta from pre-computed gradient.

The returned delta does NOT include the error -- the caller is responsible for multiplying error * delta before applying.

The state type varies by subclass (e.g. LMSState for LMS, AutostepParamState for Autostep) so the base signature uses Any.

Args:
- state: Current optimizer state
- gradient: Pre-computed gradient (e.g. eligibility trace)
- error: Optional prediction error scalar. Optimizers with meta-learning (e.g. Autostep) use this for meta-gradient computation. LMS ignores it.

Returns: (step, new_state) where step has the same shape as gradient

Raises: NotImplementedError: If the optimizer does not support this

Source code in src/alberta_framework/core/optimizers.py
def update_from_gradient(
    self, state: Any, gradient: Array, error: Array | None = None
) -> tuple[Array, Any]:
    """Compute step delta from pre-computed gradient.

    The returned delta does NOT include the error -- the caller is
    responsible for multiplying ``error * delta`` before applying.

    The state type varies by subclass (e.g. ``LMSState`` for LMS,
    ``AutostepParamState`` for Autostep) so the base signature uses
    ``Any``.

    Args:
        state: Current optimizer state
        gradient: Pre-computed gradient (e.g. eligibility trace)
        error: Optional prediction error scalar. Optimizers with
            meta-learning (e.g. Autostep) use this for meta-gradient
            computation. LMS ignores it.

    Returns:
        ``(step, new_state)`` where step has the same shape as gradient

    Raises:
        NotImplementedError: If the optimizer does not support this
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not support update_from_gradient. "
        "Only LMS and Autostep currently implement this."
    )
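
To show the shape of the contract, a hypothetical constant-step optimizer (essentially plain LMS) written against this base class; the state type here is an illustrative NamedTuple, not a framework type:

from typing import NamedTuple

import jax.numpy as jnp
from jax import Array
from alberta_framework import Optimizer, OptimizerUpdate  # assumed exports

class SGDState(NamedTuple):  # hypothetical state type
    step_size: Array

class SGD(Optimizer[SGDState]):  # hypothetical example
    def __init__(self, step_size: float = 0.01):
        self._step_size = step_size

    def init(self, feature_dim: int) -> SGDState:
        del feature_dim  # a fixed step-size needs no per-weight state
        return SGDState(step_size=jnp.array(self._step_size, dtype=jnp.float32))

    def update(self, state: SGDState, error: Array, observation: Array) -> OptimizerUpdate:
        e = jnp.squeeze(error)
        return OptimizerUpdate(
            weight_delta=state.step_size * e * observation,
            bias_delta=state.step_size * e,
            new_state=state,
            metrics={"step_size": state.step_size},
        )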

TDOptimizer

Bases: ABC

Base class for TD optimizers.

TD optimizers handle temporal-difference learning with eligibility traces. They take TD error and both current and next observations as input.

init(feature_dim) abstractmethod

Initialize optimizer state.

Args: feature_dim: Dimension of weight vector

Returns: Initial optimizer state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def init(self, feature_dim: int) -> StateT:
    """Initialize optimizer state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        Initial optimizer state
    """
    ...

update(state, td_error, observation, next_observation, gamma) abstractmethod

Compute weight updates given TD error.

Args:
- state: Current optimizer state
- td_error: TD error delta = R + gamma*V(s') - V(s)
- observation: Current observation phi(s)
- next_observation: Next observation phi(s')
- gamma: Discount factor gamma (0 at terminal)

Returns: TDOptimizerUpdate with deltas and new state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def update(
    self,
    state: StateT,
    td_error: Array,
    observation: Array,
    next_observation: Array,
    gamma: Array,
) -> TDOptimizerUpdate:
    """Compute weight updates given TD error.

    Args:
        state: Current optimizer state
        td_error: TD error delta = R + gamma*V(s') - V(s)
        observation: Current observation phi(s)
        next_observation: Next observation phi(s')
        gamma: Discount factor gamma (0 at terminal)

    Returns:
        TDOptimizerUpdate with deltas and new state
    """
    ...

TDOptimizerUpdate

Result of a TD optimizer update step.

Attributes:
- weight_delta: Change to apply to weights
- bias_delta: Change to apply to bias
- new_state: Updated optimizer state
- metrics: Dictionary of metrics for logging

AutostepParamState

Per-parameter Autostep state for use with arbitrary-shape parameters.

Used by Autostep.init_for_shape / Autostep.update_from_gradient for MLP (or other multi-parameter) learners. Unlike AutostepState, this type has no bias-specific fields -- each parameter (weight matrix, bias vector) gets its own AutostepParamState.

Attributes:
- step_sizes: Per-element step-sizes, same shape as the parameter
- traces: Per-element traces for gradient correlation
- normalizers: Running normalizer of meta-gradient magnitude |delta*z*h|
- meta_step_size: Meta learning rate mu
- tau: Time constant for normalizer adaptation

AutostepState

State for the Autostep optimizer.

Autostep is a tuning-free step-size adaptation algorithm that adapts per-weight step-sizes based on meta-gradient correlation, with self-regulated normalizers to stabilize the meta-update.

Reference: Mahmood et al. 2012, "Tuning-free step-size adaptation", Table 1

Attributes:
- step_sizes: Per-weight step-sizes (alpha_i)
- traces: Per-weight traces for gradient correlation (h_i)
- normalizers: Running normalizer of meta-gradient magnitude |delta*x*h| (v_i)
- meta_step_size: Meta learning rate mu for adapting step-sizes
- tau: Time constant for normalizer adaptation (higher = slower decay)
- bias_step_size: Step-size for the bias term
- bias_trace: Trace for the bias term
- bias_normalizer: Normalizer for the bias meta-gradient

AutoTDIDBDState

State for the AutoTDIDBD optimizer.

AutoTDIDBD adds AutoStep-style normalization to TDIDBD for improved stability. Includes normalizers for the meta-weight updates and effective step-size normalization to prevent overshooting.

Reference: Kearney et al. 2019, Algorithm 6

Attributes:
- log_step_sizes: Log of per-weight step-sizes (log alpha_i)
- eligibility_traces: Eligibility traces z_i
- h_traces: Per-weight h traces for gradient correlation
- normalizers: Running max of absolute gradient correlations (eta_i)
- meta_step_size: Meta learning rate theta
- trace_decay: Eligibility trace decay parameter lambda
- normalizer_decay: Decay parameter tau for normalizers
- bias_log_step_size: Log step-size for the bias term
- bias_eligibility_trace: Eligibility trace for the bias
- bias_h_trace: h trace for the bias term
- bias_normalizer: Normalizer for the bias gradient correlation

BatchedLearningResult

Result from batched learning loop across multiple seeds.

Used with run_learning_loop_batched for vmap-based GPU parallelization.

Attributes:
- states: Batched learner states - each array has shape (num_seeds, ...)
- metrics: Metrics array with shape (num_seeds, num_steps, num_cols) where num_cols is 3 (no normalizer) or 4 (with normalizer)
- step_size_history: Optional step-size history with batched shapes, or None if tracking was disabled
- normalizer_history: Optional normalizer history with batched shapes, or None if tracking was disabled

BatchedMLPResult

Result from batched MLP learning loop across multiple seeds.

Used with run_mlp_learning_loop_batched for vmap-based GPU parallelization.

Attributes:
- states: Batched MLP learner states - each array has shape (num_seeds, ...)
- metrics: Metrics array with shape (num_seeds, num_steps, num_cols) where num_cols is 3 (no normalizer) or 4 (with normalizer)
- normalizer_history: Optional normalizer history with batched shapes, or None if tracking was disabled

IDBDState

State for the IDBD (Incremental Delta-Bar-Delta) optimizer.

IDBD maintains per-weight adaptive step-sizes that are meta-learned based on the correlation of successive gradients.

Reference: Sutton 1992, "Adapting Bias by Gradient Descent"

Attributes:
- log_step_sizes: Log of per-weight step-sizes (log alpha_i)
- traces: Per-weight traces h_i for gradient correlation
- meta_step_size: Meta learning rate beta for adapting step-sizes
- bias_step_size: Step-size for the bias term
- bias_trace: Trace for the bias term

LearnerState

State for a linear learner.

Attributes:
- weights: Weight vector for linear prediction
- bias: Bias term
- optimizer_state: State maintained by the optimizer
- normalizer_state: Optional state for online feature normalization

LMSState

State for the LMS (Least Mean Square) optimizer.

LMS uses a fixed step-size, so state only tracks the step-size parameter.

Attributes: step_size: Fixed learning rate alpha

MLPLearnerState

State for an MLP learner.

Attributes:
- params: MLP parameters (weights and biases for each layer)
- optimizer_states: Tuple of per-parameter optimizer states (weights + biases)
- traces: Tuple of per-parameter eligibility traces
- normalizer_state: Optional state for online feature normalization

MLPParams

Parameters for a multi-layer perceptron.

Uses tuples of arrays (not lists) for proper JAX PyTree handling.

Attributes:
- weights: Tuple of weight matrices, one per layer
- biases: Tuple of bias vectors, one per layer

NormalizerHistory

History of per-feature normalizer state recorded during training.

Used for analyzing how the normalizer (EMA or Welford) adapts to distribution shifts (reactive lag diagnostic).

Attributes:
- means: Per-feature mean estimates at each recording, shape (num_recordings, feature_dim)
- variances: Per-feature variance estimates at each recording, shape (num_recordings, feature_dim)
- recording_indices: Step indices where recordings were made, shape (num_recordings,)

NormalizerTrackingConfig

Configuration for recording per-feature normalizer state during training.

Attributes: interval: Record normalizer state every N steps

ObGDState

State for the ObGD (Observation-bounded Gradient Descent) optimizer.

ObGD prevents overshooting by dynamically bounding the effective step-size based on the magnitude of the TD error and eligibility traces. When the combined update magnitude would be too large, the step-size is scaled down.

For supervised learning (gamma=0, lamda=0), traces equal the current observation each step, making ObGD equivalent to LMS with dynamic step-size bounding.

Reference: Elsayed et al. 2024, "Streaming Deep Reinforcement Learning Finally Works"

Attributes:
- step_size: Base learning rate alpha
- kappa: Bounding sensitivity parameter (higher = more conservative)
- traces: Per-weight eligibility traces z_i
- bias_trace: Eligibility trace for the bias term
- gamma: Discount factor for trace decay
- lamda: Eligibility trace decay parameter lambda

StepSizeHistory

History of per-weight step-sizes recorded during training.

Attributes:
- step_sizes: Per-weight step-sizes at each recording, shape (num_recordings, num_weights)
- bias_step_sizes: Bias step-sizes at each recording, shape (num_recordings,) or None
- recording_indices: Step indices where recordings were made, shape (num_recordings,)
- normalizers: Autostep's per-weight normalizers (v_i) at each recording, shape (num_recordings, num_weights) or None. Only populated for the Autostep optimizer.

StepSizeTrackingConfig

Configuration for recording per-weight step-sizes during training.

Attributes:
- interval: Record step-sizes every N steps
- include_bias: Whether to also record the bias step-size

TDIDBDState

State for the TD-IDBD (Temporal-Difference IDBD) optimizer.

TD-IDBD extends IDBD to temporal-difference learning with eligibility traces. Maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation in the TD setting.

Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"

Attributes:
- log_step_sizes: Log of per-weight step-sizes (log alpha_i)
- eligibility_traces: Eligibility traces z_i for temporal credit assignment
- h_traces: Per-weight h traces for gradient correlation
- meta_step_size: Meta learning rate theta for adapting step-sizes
- trace_decay: Eligibility trace decay parameter lambda
- bias_log_step_size: Log step-size for the bias term
- bias_eligibility_trace: Eligibility trace for the bias
- bias_h_trace: h trace for the bias term

TDLearnerState

State for a TD linear learner.

Attributes:
- weights: Weight vector for linear value function approximation
- bias: Bias term
- optimizer_state: State maintained by the TD optimizer

TDTimeStep

Single experience from a TD stream.

Represents a transition (s, r, s', gamma) for temporal-difference learning.

Attributes:
- observation: Feature vector phi(s)
- reward: Reward R received
- next_observation: Feature vector phi(s')
- gamma: Discount factor gamma_t (0 at terminal states)

TimeStep

Single experience from an experience stream.

Attributes:
- observation: Feature vector x_t
- target: Desired output y*_t (for supervised learning)

ScanStream

Bases: Protocol[StateT]

Protocol for JAX scan-compatible experience streams.

Streams generate temporally-uniform experience for continual learning. Unlike iterator-based streams, ScanStream uses pure functions that can be compiled with JAX's JIT and used with jax.lax.scan.

The stream should be non-stationary to test continual learning capabilities: the underlying target function changes over time.

Type Parameters: StateT: The state type maintained by this stream

Examples:

import jax
import jax.numpy as jnp

stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
key = jax.random.key(42)
state = stream.init(key)
timestep, new_state = stream.step(state, jnp.array(0))

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key for initialization

Returns: Initial stream state

Source code in src/alberta_framework/streams/base.py
def init(self, key: Array) -> StateT:
    """Initialize stream state.

    Args:
        key: JAX random key for initialization

    Returns:
        Initial stream state
    """
    ...

step(state, idx)

Generate one time step. Must be JIT-compatible.

This is a pure function that takes the current state and step index, and returns a TimeStep along with the updated state. The step index can be used for time-dependent behavior but is often ignored.

Args:
- state: Current stream state
- idx: Current step index (can be ignored for most streams)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/base.py
def step(self, state: StateT, idx: Array) -> tuple[TimeStep, StateT]:
    """Generate one time step. Must be JIT-compatible.

    This is a pure function that takes the current state and step index,
    and returns a TimeStep along with the updated state. The step index
    can be used for time-dependent behavior but is often ignored.

    Args:
        state: Current stream state
        idx: Current step index (can be ignored for most streams)

    Returns:
        Tuple of (timestep, new_state)
    """
    ...
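
Because ScanStream is a Protocol, any class with a matching feature_dim/init/step satisfies it structurally. A hypothetical stationary stream for illustration (real framework streams are non-stationary):

import jax.numpy as jnp
import jax.random as jr
from jax import Array
from alberta_framework import TimeStep  # assumed top-level export

class ConstantTargetStream:  # hypothetical; satisfies ScanStream structurally
    def __init__(self, feature_dim: int):
        self._feature_dim = feature_dim
        self._weights = jnp.linspace(0.0, 1.0, feature_dim)

    @property
    def feature_dim(self) -> int:
        return self._feature_dim

    def init(self, key: Array) -> Array:
        return key  # the random key is the only state

    def step(self, state: Array, idx: Array) -> tuple[TimeStep, Array]:
        del idx  # unused
        key, subkey = jr.split(state)
        x = jr.normal(subkey, (self._feature_dim,), dtype=jnp.float32)
        target = jnp.dot(self._weights, x)
        return TimeStep(observation=x, target=jnp.atleast_1d(target)), key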

AbruptChangeState

State for AbruptChangeStream.

Attributes:
- key: JAX random key for generating randomness
- true_weights: Current true target weights
- step_count: Number of steps taken

AbruptChangeStream(feature_dim, change_interval=1000, noise_std=0.1, feature_std=1.0)

Non-stationary stream with sudden target weight changes.

Target weights remain constant for a period, then abruptly change to new random values. Tests the learner's ability to detect and rapidly adapt to distribution shifts.

Attributes:
- feature_dim: Dimension of observation vectors
- change_interval: Number of steps between weight changes
- noise_std: Standard deviation of observation noise
- feature_std: Standard deviation of features

Args:
- feature_dim: Dimension of feature vectors
- change_interval: Steps between abrupt weight changes
- noise_std: Std dev of target noise
- feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    change_interval: int = 1000,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the abrupt change stream.

    Args:
        feature_dim: Dimension of feature vectors
        change_interval: Steps between abrupt weight changes
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._change_interval = change_interval
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> AbruptChangeState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state
    """
    key, subkey = jr.split(key)
    weights = jr.normal(subkey, (self._feature_dim,), dtype=jnp.float32)
    return AbruptChangeState(
        key=key,
        true_weights=weights,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: AbruptChangeState, idx: Array) -> tuple[TimeStep, AbruptChangeState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_weights, key_x, key_noise = jr.split(state.key, 4)

    # Determine if we should change weights
    should_change = state.step_count % self._change_interval == 0

    # Generate new weights (always generated but only used if should_change)
    new_random_weights = jr.normal(key_weights, (self._feature_dim,), dtype=jnp.float32)

    # Use jnp.where to conditionally update weights (JIT-compatible)
    new_weights = jnp.where(should_change, new_random_weights, state.true_weights)

    # Generate observation
    x = self._feature_std * jr.normal(key_x, (self._feature_dim,), dtype=jnp.float32)

    # Compute target
    noise = self._noise_std * jr.normal(key_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = AbruptChangeState(
        key=key,
        true_weights=new_weights,
        step_count=state.step_count + 1,
    )

    return timestep, new_state

CyclicState

State for CyclicStream.

Attributes: key: JAX random key for generating randomness configurations: Pre-generated weight configurations step_count: Number of steps taken

CyclicStream(feature_dim, cycle_length=500, num_configurations=4, noise_std=0.1, feature_std=1.0)

Non-stationary stream that cycles between known weight configurations.

Weights cycle through a fixed set of configurations. Tests whether the learner can re-adapt quickly to previously seen targets.
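
Example (a minimal sketch; the CyclicStream import path is assumed, as above):

import jax.random as jr
from alberta_framework import IDBD, LinearLearner, run_learning_loop
from alberta_framework import CyclicStream  # import path assumed

stream = CyclicStream(feature_dim=10, cycle_length=500, num_configurations=4)
learner = LinearLearner(optimizer=IDBD())
state, metrics = run_learning_loop(learner, stream, num_steps=8000, key=jr.key(0))
# The full cycle repeats every cycle_length * num_configurations = 2000
# steps, so steps t and t + 2000 face the same target weights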

Attributes: feature_dim: Dimension of observation vectors cycle_length: Number of steps per configuration before switching num_configurations: Number of weight configurations to cycle through noise_std: Standard deviation of observation noise feature_std: Standard deviation of features

Args: feature_dim: Dimension of feature vectors cycle_length: Steps spent in each configuration num_configurations: Number of configurations to cycle through noise_std: Std dev of target noise feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    cycle_length: int = 500,
    num_configurations: int = 4,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the cyclic target stream.

    Args:
        feature_dim: Dimension of feature vectors
        cycle_length: Steps spent in each configuration
        num_configurations: Number of configurations to cycle through
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._cycle_length = cycle_length
    self._num_configurations = num_configurations
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state with pre-generated configurations

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> CyclicState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with pre-generated configurations
    """
    key, key_configs = jr.split(key)
    configurations = jr.normal(
        key_configs,
        (self._num_configurations, self._feature_dim),
        dtype=jnp.float32,
    )
    return CyclicState(
        key=key,
        configurations=configurations,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: CyclicState, idx: Array) -> tuple[TimeStep, CyclicState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_x, key_noise = jr.split(state.key, 3)

    # Get current configuration index
    config_idx = (state.step_count // self._cycle_length) % self._num_configurations
    true_weights = state.configurations[config_idx]

    # Generate observation
    x = self._feature_std * jr.normal(key_x, (self._feature_dim,), dtype=jnp.float32)

    # Compute target
    noise = self._noise_std * jr.normal(key_noise, (), dtype=jnp.float32)
    target = jnp.dot(true_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = CyclicState(
        key=key,
        configurations=state.configurations,
        step_count=state.step_count + 1,
    )

    return timestep, new_state

DynamicScaleShiftState

State for DynamicScaleShiftStream.

Attributes: key: JAX random key for generating randomness true_weights: Current true target weights current_scales: Current per-feature scaling factors step_count: Number of steps taken

DynamicScaleShiftStream(feature_dim, scale_change_interval=2000, weight_change_interval=1000, min_scale=0.01, max_scale=100.0, noise_std=0.1)

Non-stationary stream with abruptly changing feature scales.

Both target weights AND feature scales change at specified intervals. This tests whether OnlineNormalizer can track scale shifts faster than Autostep's internal v_i adaptation.

The target is computed from unscaled features to maintain consistent difficulty across scale changes (only the feature representation changes, not the underlying prediction task).
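
Example (a minimal sketch; Autostep and EMANormalizer are used with their defaults, and the import paths are assumed):

import jax.random as jr
from alberta_framework import LinearLearner, run_learning_loop
from alberta_framework import Autostep, DynamicScaleShiftStream, EMANormalizer  # paths assumed

stream = DynamicScaleShiftStream(feature_dim=10, scale_change_interval=2000)
learner = LinearLearner(optimizer=Autostep(), normalizer=EMANormalizer())
state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(0))
# With a normalizer attached, metrics has 4 columns; column 3 is the
# mean normalizer variance (see metrics_to_dicts)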

Attributes: feature_dim: Dimension of observation vectors scale_change_interval: Steps between scale changes weight_change_interval: Steps between weight changes min_scale: Minimum scale factor max_scale: Maximum scale factor noise_std: Standard deviation of observation noise

Args: feature_dim: Dimension of feature vectors scale_change_interval: Steps between abrupt scale changes weight_change_interval: Steps between abrupt weight changes min_scale: Minimum scale factor (log-uniform sampling) max_scale: Maximum scale factor (log-uniform sampling) noise_std: Std dev of target noise

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    scale_change_interval: int = 2000,
    weight_change_interval: int = 1000,
    min_scale: float = 0.01,
    max_scale: float = 100.0,
    noise_std: float = 0.1,
):
    """Initialize the dynamic scale shift stream.

    Args:
        feature_dim: Dimension of feature vectors
        scale_change_interval: Steps between abrupt scale changes
        weight_change_interval: Steps between abrupt weight changes
        min_scale: Minimum scale factor (log-uniform sampling)
        max_scale: Maximum scale factor (log-uniform sampling)
        noise_std: Std dev of target noise
    """
    self._feature_dim = feature_dim
    self._scale_change_interval = scale_change_interval
    self._weight_change_interval = weight_change_interval
    self._min_scale = min_scale
    self._max_scale = max_scale
    self._noise_std = noise_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state with random weights and scales

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> DynamicScaleShiftState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random weights and scales
    """
    key, k_weights, k_scales = jr.split(key, 3)
    weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
    # Initial scales: log-uniform between min and max
    log_scales = jr.uniform(
        k_scales,
        (self._feature_dim,),
        minval=jnp.log(self._min_scale),
        maxval=jnp.log(self._max_scale),
    )
    scales = jnp.exp(log_scales).astype(jnp.float32)
    return DynamicScaleShiftState(
        key=key,
        true_weights=weights,
        current_scales=scales,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(
    self, state: DynamicScaleShiftState, idx: Array
) -> tuple[TimeStep, DynamicScaleShiftState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, k_weights, k_scales, k_x, k_noise = jr.split(state.key, 5)

    # Check if scales should change
    should_change_scales = state.step_count % self._scale_change_interval == 0
    new_log_scales = jr.uniform(
        k_scales,
        (self._feature_dim,),
        minval=jnp.log(self._min_scale),
        maxval=jnp.log(self._max_scale),
    )
    new_random_scales = jnp.exp(new_log_scales).astype(jnp.float32)
    new_scales = jnp.where(should_change_scales, new_random_scales, state.current_scales)

    # Check if weights should change
    should_change_weights = state.step_count % self._weight_change_interval == 0
    new_random_weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
    new_weights = jnp.where(should_change_weights, new_random_weights, state.true_weights)

    # Generate raw features (unscaled)
    raw_x = jr.normal(k_x, (self._feature_dim,), dtype=jnp.float32)

    # Apply scaling to observation
    x = raw_x * new_scales

    # Target from true weights using RAW features (for consistent difficulty)
    noise = self._noise_std * jr.normal(k_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, raw_x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = DynamicScaleShiftState(
        key=key,
        true_weights=new_weights,
        current_scales=new_scales,
        step_count=state.step_count + 1,
    )
    return timestep, new_state

PeriodicChangeState

State for PeriodicChangeStream.

Attributes: key: JAX random key for generating randomness base_weights: Base target weights (center of oscillation) phases: Per-weight phase offsets step_count: Number of steps taken

PeriodicChangeStream(feature_dim, period=1000, amplitude=1.0, noise_std=0.1, feature_std=1.0)

Non-stationary stream where target weights oscillate sinusoidally.

Target weights follow: w(t) = base + amplitude * sin(2π * t / period + phase), where each weight has a random phase offset for diversity.

This tests the learner's ability to track predictable periodic changes, which is qualitatively different from random drift or abrupt changes.
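
Example (a minimal sketch; the PeriodicChangeStream import path is assumed):

import jax.random as jr
from alberta_framework import IDBD, LinearLearner, run_learning_loop
from alberta_framework import PeriodicChangeStream  # import path assumed

stream = PeriodicChangeStream(feature_dim=10, period=1000, amplitude=1.0)
learner = LinearLearner(optimizer=IDBD())
state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(0))
# Tracking error should itself be roughly periodic with the same period,
# peaking where the weights move fastest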

Attributes: feature_dim: Dimension of observation vectors period: Number of steps for one complete oscillation amplitude: Magnitude of weight oscillation noise_std: Standard deviation of observation noise feature_std: Standard deviation of features

Args: feature_dim: Dimension of feature vectors period: Steps for one complete oscillation cycle amplitude: Magnitude of weight oscillations around base noise_std: Std dev of target noise feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    period: int = 1000,
    amplitude: float = 1.0,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the periodic change stream.

    Args:
        feature_dim: Dimension of feature vectors
        period: Steps for one complete oscillation cycle
        amplitude: Magnitude of weight oscillations around base
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._period = period
    self._amplitude = amplitude
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state with random base weights and phases

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> PeriodicChangeState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random base weights and phases
    """
    key, key_weights, key_phases = jr.split(key, 3)
    base_weights = jr.normal(key_weights, (self._feature_dim,), dtype=jnp.float32)
    # Random phases in [0, 2π) for each weight
    phases = jr.uniform(key_phases, (self._feature_dim,), minval=0.0, maxval=2.0 * jnp.pi)
    return PeriodicChangeState(
        key=key,
        base_weights=base_weights,
        phases=phases,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: PeriodicChangeState, idx: Array) -> tuple[TimeStep, PeriodicChangeState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_x, key_noise = jr.split(state.key, 3)

    # Compute oscillating weights: w(t) = base + amplitude * sin(2π * t / period + phase)
    t = state.step_count.astype(jnp.float32)
    oscillation = self._amplitude * jnp.sin(2.0 * jnp.pi * t / self._period + state.phases)
    true_weights = state.base_weights + oscillation

    # Generate observation
    x = self._feature_std * jr.normal(key_x, (self._feature_dim,), dtype=jnp.float32)

    # Compute target
    noise = self._noise_std * jr.normal(key_noise, (), dtype=jnp.float32)
    target = jnp.dot(true_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = PeriodicChangeState(
        key=key,
        base_weights=state.base_weights,
        phases=state.phases,
        step_count=state.step_count + 1,
    )

    return timestep, new_state

RandomWalkState

State for RandomWalkStream.

Attributes: key: JAX random key for generating randomness true_weights: Current true target weights

RandomWalkStream(feature_dim, drift_rate=0.001, noise_std=0.1, feature_std=1.0)

Non-stationary stream where target weights drift via random walk.

The true target function is linear: y* = w_true @ x + noise where w_true evolves via random walk at each time step.

This tests the learner's ability to continuously track a moving target.

Attributes: feature_dim: Dimension of observation vectors drift_rate: Standard deviation of weight drift per step noise_std: Standard deviation of observation noise feature_std: Standard deviation of features

Args: feature_dim: Dimension of the feature/observation vectors drift_rate: Std dev of weight changes per step (controls non-stationarity) noise_std: Std dev of target noise feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    drift_rate: float = 0.001,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the random walk target stream.

    Args:
        feature_dim: Dimension of the feature/observation vectors
        drift_rate: Std dev of weight changes per step (controls non-stationarity)
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._drift_rate = drift_rate
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state with random weights

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> RandomWalkState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random weights
    """
    key, subkey = jr.split(key)
    weights = jr.normal(subkey, (self._feature_dim,), dtype=jnp.float32)
    return RandomWalkState(key=key, true_weights=weights)

step(state, idx)

Generate one time step.

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: RandomWalkState, idx: Array) -> tuple[TimeStep, RandomWalkState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, k_drift, k_x, k_noise = jr.split(state.key, 4)

    # Drift weights
    drift = jr.normal(k_drift, state.true_weights.shape, dtype=jnp.float32)
    new_weights = state.true_weights + self._drift_rate * drift

    # Generate observation and target
    x = self._feature_std * jr.normal(k_x, (self._feature_dim,), dtype=jnp.float32)
    noise = self._noise_std * jr.normal(k_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = RandomWalkState(key=key, true_weights=new_weights)

    return timestep, new_state

ScaleDriftState

State for ScaleDriftStream.

Attributes: key: JAX random key for generating randomness true_weights: Current true target weights log_scales: Current log-scale factors (random walk on log-scale) step_count: Number of steps taken

ScaleDriftStream(feature_dim, weight_drift_rate=0.001, scale_drift_rate=0.01, min_log_scale=-4.0, max_log_scale=4.0, noise_std=0.1)

Non-stationary stream where feature scales drift via random walk.

Both target weights and feature scales drift continuously. Weights drift in linear space while scales drift in log-space (bounded random walk). This tests continuous scale tracking, where OnlineNormalizer's EMA may adapt differently from Autostep's v_i.

The target is computed from unscaled features to maintain consistent difficulty across scale changes.
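
Example (a minimal sketch comparing a raw and a normalized learner; import paths and constructor defaults assumed as above):

import jax.random as jr
from alberta_framework import LinearLearner, run_learning_loop
from alberta_framework import Autostep, EMANormalizer, ScaleDriftStream  # paths assumed

stream = ScaleDriftStream(feature_dim=10, scale_drift_rate=0.01)
for normalizer in (None, EMANormalizer()):
    learner = LinearLearner(optimizer=Autostep(), normalizer=normalizer)
    state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(0))
    print(normalizer, float(metrics[:, 0].mean()))  # mean squared error per learner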

Attributes: feature_dim: Dimension of observation vectors weight_drift_rate: Std dev of weight drift per step scale_drift_rate: Std dev of log-scale drift per step min_log_scale: Minimum log-scale (clips random walk) max_log_scale: Maximum log-scale (clips random walk) noise_std: Standard deviation of observation noise

Args: feature_dim: Dimension of feature vectors weight_drift_rate: Std dev of weight drift per step scale_drift_rate: Std dev of log-scale drift per step min_log_scale: Minimum log-scale (clips drift) max_log_scale: Maximum log-scale (clips drift) noise_std: Std dev of target noise

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    weight_drift_rate: float = 0.001,
    scale_drift_rate: float = 0.01,
    min_log_scale: float = -4.0,  # exp(-4) ~ 0.018
    max_log_scale: float = 4.0,  # exp(4) ~ 54.6
    noise_std: float = 0.1,
):
    """Initialize the scale drift stream.

    Args:
        feature_dim: Dimension of feature vectors
        weight_drift_rate: Std dev of weight drift per step
        scale_drift_rate: Std dev of log-scale drift per step
        min_log_scale: Minimum log-scale (clips drift)
        max_log_scale: Maximum log-scale (clips drift)
        noise_std: Std dev of target noise
    """
    self._feature_dim = feature_dim
    self._weight_drift_rate = weight_drift_rate
    self._scale_drift_rate = scale_drift_rate
    self._min_log_scale = min_log_scale
    self._max_log_scale = max_log_scale
    self._noise_std = noise_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state with random weights and unit scales

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> ScaleDriftState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random weights and unit scales
    """
    key, k_weights = jr.split(key)
    weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
    # Initial log-scales at 0 (scale = 1)
    log_scales = jnp.zeros(self._feature_dim, dtype=jnp.float32)
    return ScaleDriftState(
        key=key,
        true_weights=weights,
        log_scales=log_scales,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: ScaleDriftState, idx: Array) -> tuple[TimeStep, ScaleDriftState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, k_w_drift, k_s_drift, k_x, k_noise = jr.split(state.key, 5)

    # Drift target weights
    weight_drift = self._weight_drift_rate * jr.normal(
        k_w_drift, (self._feature_dim,), dtype=jnp.float32
    )
    new_weights = state.true_weights + weight_drift

    # Drift log-scales (bounded random walk)
    scale_drift = self._scale_drift_rate * jr.normal(
        k_s_drift, (self._feature_dim,), dtype=jnp.float32
    )
    new_log_scales = state.log_scales + scale_drift
    new_log_scales = jnp.clip(new_log_scales, self._min_log_scale, self._max_log_scale)

    # Generate raw features (unscaled)
    raw_x = jr.normal(k_x, (self._feature_dim,), dtype=jnp.float32)

    # Apply scaling to observation
    scales = jnp.exp(new_log_scales)
    x = raw_x * scales

    # Target from true weights using RAW features
    noise = self._noise_std * jr.normal(k_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, raw_x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = ScaleDriftState(
        key=key,
        true_weights=new_weights,
        log_scales=new_log_scales,
        step_count=state.step_count + 1,
    )
    return timestep, new_state

ScaledStreamState

State for ScaledStreamWrapper.

Attributes: inner_state: State of the wrapped stream

ScaledStreamWrapper(inner_stream, feature_scales)

Wrapper that applies per-feature scaling to any stream's observations.

This wrapper multiplies each feature of the observation by a corresponding scale factor. Useful for testing how learners handle features at different scales, which is important for understanding normalization benefits.

Examples:

stream = ScaledStreamWrapper(
    AbruptChangeStream(feature_dim=10, change_interval=1000),
    feature_scales=jnp.array([0.001, 0.01, 0.1, 1.0, 10.0,
                              100.0, 1000.0, 0.001, 0.01, 0.1])
)

Attributes: inner_stream: The wrapped stream instance feature_scales: Per-feature scale factors (must match feature_dim)

Args: inner_stream: Stream to wrap (must implement ScanStream protocol) feature_scales: Array of scale factors, one per feature. Must have shape (feature_dim,) matching the inner stream's feature_dim.

Raises: ValueError: If feature_scales length doesn't match inner stream's feature_dim

Source code in src/alberta_framework/streams/synthetic.py
def __init__(self, inner_stream: ScanStream[Any], feature_scales: Array):
    """Initialize the scaled stream wrapper.

    Args:
        inner_stream: Stream to wrap (must implement ScanStream protocol)
        feature_scales: Array of scale factors, one per feature. Must have
            shape (feature_dim,) matching the inner stream's feature_dim.

    Raises:
        ValueError: If feature_scales length doesn't match inner stream's feature_dim
    """
    self._inner_stream: ScanStream[Any] = inner_stream
    self._feature_scales = jnp.asarray(feature_scales, dtype=jnp.float32)

    if self._feature_scales.shape[0] != inner_stream.feature_dim:
        raise ValueError(
            f"feature_scales length ({self._feature_scales.shape[0]}) "
            f"must match inner stream's feature_dim ({inner_stream.feature_dim})"
        )

feature_dim property

Return the dimension of observation vectors.

inner_stream property

Return the wrapped stream.

feature_scales property

Return the per-feature scale factors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state wrapping the inner stream's state

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> ScaledStreamState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state wrapping the inner stream's state
    """
    inner_state = self._inner_stream.init(key)
    return ScaledStreamState(inner_state=inner_state)

step(state, idx)

Generate one time step with scaled observations.

Args: state: Current stream state idx: Current step index

Returns: Tuple of (timestep with scaled observation, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: ScaledStreamState, idx: Array) -> tuple[TimeStep, ScaledStreamState]:
    """Generate one time step with scaled observations.

    Args:
        state: Current stream state
        idx: Current step index

    Returns:
        Tuple of (timestep with scaled observation, new_state)
    """
    timestep, new_inner_state = self._inner_stream.step(state.inner_state, idx)

    # Scale the observation
    scaled_observation = timestep.observation * self._feature_scales

    scaled_timestep = TimeStep(
        observation=scaled_observation,
        target=timestep.target,
    )

    new_state = ScaledStreamState(inner_state=new_inner_state)
    return scaled_timestep, new_state

SuttonExperiment1State

State for SuttonExperiment1Stream.

Attributes: key: JAX random key for generating randomness signs: Signs (+1/-1) for the relevant inputs step_count: Number of steps taken

SuttonExperiment1Stream(num_relevant=5, num_irrelevant=15, change_interval=20)

Non-stationary stream replicating Experiment 1 from Sutton 1992.

This stream implements the exact task from Sutton's IDBD paper:

- 20 real-valued inputs drawn from N(0, 1)
- Only the first 5 inputs are relevant (weights are ±1)
- The last 15 inputs are irrelevant (weights are 0)
- Every change_interval steps, one of the 5 relevant signs is flipped

Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta"
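
Example (a minimal sketch of the replication; the StepSizeTrackingConfig constructor keyword is assumed from its documented interval field):

import jax.random as jr
from alberta_framework import IDBD, LinearLearner, run_learning_loop
from alberta_framework import StepSizeTrackingConfig, SuttonExperiment1Stream  # paths assumed

stream = SuttonExperiment1Stream()  # 5 relevant + 15 irrelevant inputs
learner = LinearLearner(optimizer=IDBD())
state, metrics, ss_history = run_learning_loop(
    learner, stream, num_steps=20000, key=jr.key(0),
    step_size_tracking=StepSizeTrackingConfig(interval=1000),  # kwarg assumed
)
# ss_history.step_sizes has shape (20, 20): 20 recordings x 20 inputs.
# IDBD should drive the step-sizes of the 15 irrelevant inputs toward
# zero while keeping those of the 5 relevant inputs large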

Attributes: num_relevant: Number of relevant inputs (default 5) num_irrelevant: Number of irrelevant inputs (default 15) change_interval: Steps between sign changes (default 20)

Args: num_relevant: Number of relevant inputs with ±1 weights num_irrelevant: Number of irrelevant inputs with 0 weights change_interval: Number of steps between sign flips

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    num_relevant: int = 5,
    num_irrelevant: int = 15,
    change_interval: int = 20,
):
    """Initialize the Sutton Experiment 1 stream.

    Args:
        num_relevant: Number of relevant inputs with ±1 weights
        num_irrelevant: Number of irrelevant inputs with 0 weights
        change_interval: Number of steps between sign flips
    """
    self._num_relevant = num_relevant
    self._num_irrelevant = num_irrelevant
    self._change_interval = change_interval

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state with all +1 signs

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> SuttonExperiment1State:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with all +1 signs
    """
    signs = jnp.ones(self._num_relevant, dtype=jnp.float32)
    return SuttonExperiment1State(
        key=key,
        signs=signs,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

At each step:

1. If at a change interval (and not step 0), flip one random sign
2. Generate random inputs from N(0, 1)
3. Compute the target as the sum of relevant inputs weighted by signs

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(
    self, state: SuttonExperiment1State, idx: Array
) -> tuple[TimeStep, SuttonExperiment1State]:
    """Generate one time step.

    At each step:
    1. If at a change interval (and not step 0), flip one random sign
    2. Generate random inputs from N(0, 1)
    3. Compute target as sum of relevant inputs weighted by signs

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_x, key_which = jr.split(state.key, 3)

    # Determine if we should flip a sign (not at step 0)
    should_flip = (state.step_count > 0) & (state.step_count % self._change_interval == 0)

    # Select which sign to flip
    idx_to_flip = jr.randint(key_which, (), 0, self._num_relevant)

    # Create flip mask
    flip_mask = jnp.where(
        jnp.arange(self._num_relevant) == idx_to_flip,
        jnp.array(-1.0, dtype=jnp.float32),
        jnp.array(1.0, dtype=jnp.float32),
    )

    # Apply flip mask conditionally
    new_signs = jnp.where(should_flip, state.signs * flip_mask, state.signs)

    # Generate observation from N(0, 1)
    x = jr.normal(key_x, (self.feature_dim,), dtype=jnp.float32)

    # Compute target: sum of first num_relevant inputs weighted by signs
    target = jnp.dot(new_signs, x[: self._num_relevant])

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = SuttonExperiment1State(
        key=key,
        signs=new_signs,
        step_count=state.step_count + 1,
    )

    return timestep, new_state

Timer(name='Operation', verbose=True, print_fn=None)

Context manager for timing code execution.

Measures wall-clock time for a block of code and optionally prints the duration when the block completes.

Attributes: name: Description of what is being timed duration: Elapsed time in seconds (available after context exits) start_time: Timestamp when timing started end_time: Timestamp when timing ended

Examples:

import time
from alberta_framework.utils.timing import Timer

with Timer("Training loop"):
    for i in range(1000):
        pass
# Output: Training loop completed in 0.01s

# Silent timing (no print):
with Timer("Silent", verbose=False) as t:
    time.sleep(0.1)
print(f"Elapsed: {t.duration:.2f}s")
# Output: Elapsed: 0.10s

# Custom print function:
with Timer("Custom", print_fn=lambda msg: print(f">> {msg}")):
    pass
# Output: >> Custom completed in 0.00s

Args: name: Description of the operation being timed verbose: Whether to print the duration when done print_fn: Custom print function (defaults to built-in print)

Source code in src/alberta_framework/utils/timing.py
def __init__(
    self,
    name: str = "Operation",
    verbose: bool = True,
    print_fn: Callable[[str], None] | None = None,
):
    """Initialize the timer.

    Args:
        name: Description of the operation being timed
        verbose: Whether to print the duration when done
        print_fn: Custom print function (defaults to built-in print)
    """
    self.name = name
    self.verbose = verbose
    self.print_fn = print_fn or print
    self.start_time: float = 0.0
    self.end_time: float = 0.0
    self.duration: float = 0.0

elapsed()

Get elapsed time since timer started (can be called during execution).

Returns: Elapsed time in seconds

Source code in src/alberta_framework/utils/timing.py
def elapsed(self) -> float:
    """Get elapsed time since timer started (can be called during execution).

    Returns:
        Elapsed time in seconds
    """
    return time.perf_counter() - self.start_time

GymnasiumStream(env, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0)

Experience stream from a Gymnasium environment, driven by a Python loop.

This class maintains iterator-based access for online learning scenarios where you need to interact with the environment in real time.

For batch learning, use collect_trajectory() followed by learn_from_trajectory().
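
Example (a minimal sketch; iterating the stream is assumed to yield TimeSteps, mirroring the TDStream usage below):

import gymnasium as gym
from alberta_framework import IDBD, LinearLearner
from alberta_framework.streams.gymnasium import GymnasiumStream, PredictionMode

env = gym.make("CartPole-v1")
stream = GymnasiumStream(env, mode=PredictionMode.REWARD)
learner = LinearLearner(optimizer=IDBD())
state = learner.init(stream.feature_dim)

for step, timestep in enumerate(stream):  # iteration protocol assumed
    result = learner.update(state, timestep.observation, timestep.target)
    state = result.state
    if step >= 1000:
        break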

Attributes: mode: Prediction mode (REWARD, NEXT_STATE, VALUE) gamma: Discount factor for VALUE mode include_action_in_features: Whether to include action in features episode_count: Number of completed episodes

Args: env: Gymnasium environment instance mode: What to predict (REWARD, NEXT_STATE, VALUE) policy: Action selection function. If None, uses random policy gamma: Discount factor for VALUE mode include_action_in_features: If True, features = concat(obs, action). If False, features = obs only seed: Random seed for environment resets and random policy

Source code in src/alberta_framework/streams/gymnasium.py
def __init__(
    self,
    env: gymnasium.Env[Any, Any],
    mode: PredictionMode = PredictionMode.REWARD,
    policy: Callable[[Array], Any] | None = None,
    gamma: float = 0.99,
    include_action_in_features: bool = True,
    seed: int = 0,
):
    """Initialize the Gymnasium stream.

    Args:
        env: Gymnasium environment instance
        mode: What to predict (REWARD, NEXT_STATE, VALUE)
        policy: Action selection function. If None, uses random policy
        gamma: Discount factor for VALUE mode
        include_action_in_features: If True, features = concat(obs, action).
            If False, features = obs only
        seed: Random seed for environment resets and random policy
    """
    self._env = env
    self._mode = mode
    self._gamma = gamma
    self._include_action_in_features = include_action_in_features
    self._seed = seed
    self._reset_count = 0

    if policy is None:
        self._policy = make_random_policy(env, seed)
    else:
        self._policy = policy

    self._obs_dim = _flatten_space(env.observation_space)
    self._action_dim = _flatten_space(env.action_space)

    if include_action_in_features:
        self._feature_dim = self._obs_dim + self._action_dim
    else:
        self._feature_dim = self._obs_dim

    if mode == PredictionMode.NEXT_STATE:
        self._target_dim = self._obs_dim
    else:
        self._target_dim = 1

    self._current_obs: Array | None = None
    self._episode_count = 0
    self._step_count = 0
    self._value_estimator: Callable[[Array], float] | None = None

feature_dim property

Return the dimension of feature vectors.

target_dim property

Return the dimension of target vectors.

episode_count property

Return the number of completed episodes.

step_count property

Return the total number of steps taken.

mode property

Return the prediction mode.

set_value_estimator(estimator)

Set the value estimator for proper TD learning in VALUE mode.

Source code in src/alberta_framework/streams/gymnasium.py
def set_value_estimator(self, estimator: Callable[[Array], float]) -> None:
    """Set the value estimator for proper TD learning in VALUE mode."""
    self._value_estimator = estimator

PredictionMode

Bases: Enum

Mode for what the stream predicts.

REWARD: Predict immediate reward from (state, action)
NEXT_STATE: Predict next state from (state, action)
VALUE: Predict cumulative return (TD learning with bootstrap)
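
The mode determines the target dimension, as the GymnasiumStream constructor shows. A small sketch (CartPole-v1 observations are 4-dimensional):

import gymnasium as gym
from alberta_framework.streams.gymnasium import GymnasiumStream, PredictionMode

env = gym.make("CartPole-v1")
stream = GymnasiumStream(env, mode=PredictionMode.NEXT_STATE)
assert stream.target_dim == 4  # NEXT_STATE predicts the full observation
# REWARD and VALUE modes predict a scalar instead (target_dim == 1)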

TDStream(env, policy=None, gamma=0.99, include_action_in_features=False, seed=0)

Experience stream for proper TD learning with value function bootstrap.

This stream integrates with a learner to use its predictions for bootstrapping in TD targets.

Usage:

stream = TDStream(env)
learner = LinearLearner(optimizer=IDBD())
state = learner.init(stream.feature_dim)

for step, timestep in enumerate(stream):
    result = learner.update(state, timestep.observation, timestep.target)
    state = result.state
    stream.update_value_function(lambda x: learner.predict(state, x))

Args: env: Gymnasium environment instance policy: Action selection function. If None, uses random policy gamma: Discount factor include_action_in_features: If True, learn Q(s,a). If False, learn V(s) seed: Random seed

Source code in src/alberta_framework/streams/gymnasium.py
def __init__(
    self,
    env: gymnasium.Env[Any, Any],
    policy: Callable[[Array], Any] | None = None,
    gamma: float = 0.99,
    include_action_in_features: bool = False,
    seed: int = 0,
):
    """Initialize the TD stream.

    Args:
        env: Gymnasium environment instance
        policy: Action selection function. If None, uses random policy
        gamma: Discount factor
        include_action_in_features: If True, learn Q(s,a). If False, learn V(s)
        seed: Random seed
    """
    self._env = env
    self._gamma = gamma
    self._include_action_in_features = include_action_in_features
    self._seed = seed
    self._reset_count = 0

    if policy is None:
        self._policy = make_random_policy(env, seed)
    else:
        self._policy = policy

    self._obs_dim = _flatten_space(env.observation_space)
    self._action_dim = _flatten_space(env.action_space)

    if include_action_in_features:
        self._feature_dim = self._obs_dim + self._action_dim
    else:
        self._feature_dim = self._obs_dim

    self._current_obs: Array | None = None
    self._episode_count = 0
    self._step_count = 0
    self._value_fn: Callable[[Array], float] = lambda x: 0.0

feature_dim property

Return the dimension of feature vectors.

episode_count property

Return the number of completed episodes.

step_count property

Return the total number of steps taken.

update_value_function(value_fn)

Update the value function used for TD bootstrapping.

Source code in src/alberta_framework/streams/gymnasium.py
def update_value_function(self, value_fn: Callable[[Array], float]) -> None:
    """Update the value function used for TD bootstrapping."""
    self._value_fn = value_fn

sparse_init(key, shape, sparsity=0.9, init_type='uniform')

Create a sparsely initialized weight matrix.

Applies LeCun-scale initialization and then zeros out a fraction of weights per output neuron. This creates sparser gradient flows that improve stability in streaming learning settings.

Reference: Elsayed et al. 2024, sparse_init.py

Args: key: JAX random key shape: Weight matrix shape (fan_out, fan_in) sparsity: Fraction of input connections to zero out per output neuron (default: 0.9 means 90% sparse) init_type: Initialization distribution, "uniform" or "normal" (default: "uniform" for LeCun uniform)

Returns: Weight matrix of given shape with specified sparsity

Examples:

import jax.random as jr
from alberta_framework.core.initializers import sparse_init

key = jr.key(42)
weights = sparse_init(key, (128, 10), sparsity=0.9)
# weights has shape (128, 10), ~90% zeros per row

Source code in src/alberta_framework/core/initializers.py
def sparse_init(
    key: Array,
    shape: tuple[int, int],
    sparsity: float = 0.9,
    init_type: str = "uniform",
) -> Float[Array, "fan_out fan_in"]:
    """Create a sparsely initialized weight matrix.

    Applies LeCun-scale initialization and then zeros out a fraction of
    weights per output neuron. This creates sparser gradient flows that
    improve stability in streaming learning settings.

    Reference: Elsayed et al. 2024, sparse_init.py

    Args:
        key: JAX random key
        shape: Weight matrix shape (fan_out, fan_in)
        sparsity: Fraction of input connections to zero out per output neuron
            (default: 0.9 means 90% sparse)
        init_type: Initialization distribution, "uniform" or "normal"
            (default: "uniform" for LeCun uniform)

    Returns:
        Weight matrix of given shape with specified sparsity

    Examples:
    ```python
    import jax.random as jr
    from alberta_framework.core.initializers import sparse_init

    key = jr.key(42)
    weights = sparse_init(key, (128, 10), sparsity=0.9)
    # weights has shape (128, 10), ~90% zeros per row
    ```
    """
    fan_out, fan_in = shape
    num_zeros = int(sparsity * fan_in + 0.5)  # round to nearest int

    # Split key for init and sparsity mask
    init_key, mask_key = jr.split(key)

    # LeCun-scale initialization
    scale = 1.0 / fan_in**0.5
    if init_type == "uniform":
        weights = jr.uniform(init_key, shape, dtype=jnp.float32, minval=-scale, maxval=scale)
    elif init_type == "normal":
        weights = jr.normal(init_key, shape, dtype=jnp.float32) * scale
    else:
        raise ValueError(f"init_type must be 'uniform' or 'normal', got '{init_type}'")

    # Create sparsity mask: for each output neuron, zero out num_zeros inputs
    # Use vmap over output neurons with independent random permutations
    row_keys = jr.split(mask_key, fan_out)

    def make_row_mask(row_key: Array) -> Float[Array, " fan_in"]:
        """Create a binary mask for a single output neuron."""
        perm = jr.permutation(row_key, fan_in)
        # mask[i] = 1 if perm[i] >= num_zeros, else 0
        mask = (perm >= num_zeros).astype(jnp.float32)
        return mask

    masks = jax.vmap(make_row_mask)(row_keys)  # (fan_out, fan_in)

    return weights * masks

metrics_to_dicts(metrics, normalized=False)

Convert metrics array to list of dicts for backward compatibility.

Args: metrics: Array of shape (num_steps, 3) or (num_steps, 4) normalized: If True, expects 4 columns including normalizer_mean_var

Returns: List of metric dictionaries
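
Example (a minimal sketch; metrics_to_dicts is imported from its source module, and metrics comes from a run_learning_loop call like the package example):

from alberta_framework.core.learners import metrics_to_dicts

dicts = metrics_to_dicts(metrics)  # metrics from run_learning_loop
print(dicts[0]["squared_error"], dicts[0]["mean_step_size"])
# For a normalized learner, pass normalized=True to also expose
# dicts[i]["normalizer_mean_var"]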

Source code in src/alberta_framework/core/learners.py
def metrics_to_dicts(metrics: Array, normalized: bool = False) -> list[dict[str, float]]:
    """Convert metrics array to list of dicts for backward compatibility.

    Args:
        metrics: Array of shape (num_steps, 3) or (num_steps, 4)
        normalized: If True, expects 4 columns including normalizer_mean_var

    Returns:
        List of metric dictionaries
    """
    result = []
    for row in metrics:
        d = {
            "squared_error": float(row[0]),
            "error": float(row[1]),
            "mean_step_size": float(row[2]),
        }
        if normalized and len(row) > 3:
            d["normalizer_mean_var"] = float(row[3])
        result.append(d)
    return result

run_learning_loop(learner, stream, num_steps, key, learner_state=None, step_size_tracking=None, normalizer_tracking=None)

Run the learning loop using jax.lax.scan.

This is a JIT-compiled learning loop that uses scan for efficiency. It returns metrics as a fixed-size array rather than a list of dicts.

Supports both plain and normalized learners. When the learner has a normalizer, metrics have 4 columns; otherwise 3 columns.

Args: learner: The learner to train stream: Experience stream providing (observation, target) pairs num_steps: Number of learning steps to run key: JAX random key for stream initialization learner_state: Initial state (if None, will be initialized from stream) step_size_tracking: Optional config for recording per-weight step-sizes. When provided, returns StepSizeHistory. normalizer_tracking: Optional config for recording per-feature normalizer state. When provided, returns NormalizerHistory with means and variances over time.

Returns: If no tracking: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 3) or (num_steps, 4) depending on normalizer If step_size_tracking only: Tuple of (final_state, metrics_array, step_size_history) If normalizer_tracking only: Tuple of (final_state, metrics_array, normalizer_history) If both: Tuple of (final_state, metrics_array, step_size_history, normalizer_history)

Raises: ValueError: If tracking interval is invalid
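
Example (a minimal sketch with both trackers enabled; the tracking-config constructor keywords are assumed from their documented interval fields):

import jax.random as jr
from alberta_framework import EMANormalizer, IDBD, LinearLearner, RandomWalkStream, run_learning_loop
from alberta_framework import NormalizerTrackingConfig, StepSizeTrackingConfig  # paths assumed

stream = RandomWalkStream(feature_dim=10)
learner = LinearLearner(optimizer=IDBD(), normalizer=EMANormalizer())
state, metrics, ss_hist, norm_hist = run_learning_loop(
    learner, stream, num_steps=10000, key=jr.key(42),
    step_size_tracking=StepSizeTrackingConfig(interval=100),  # kwargs assumed
    normalizer_tracking=NormalizerTrackingConfig(interval=100),
)
# metrics: (10000, 4); ss_hist.step_sizes and norm_hist.means: (100, 10)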

Source code in src/alberta_framework/core/learners.py
def run_learning_loop[StreamStateT](
    learner: LinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    key: Array,
    learner_state: LearnerState | None = None,
    step_size_tracking: StepSizeTrackingConfig | None = None,
    normalizer_tracking: NormalizerTrackingConfig | None = None,
) -> (
    tuple[LearnerState, Array]
    | tuple[LearnerState, Array, StepSizeHistory]
    | tuple[LearnerState, Array, NormalizerHistory]
    | tuple[LearnerState, Array, StepSizeHistory, NormalizerHistory]
):
    """Run the learning loop using jax.lax.scan.

    This is a JIT-compiled learning loop that uses scan for efficiency.
    It returns metrics as a fixed-size array rather than a list of dicts.

    Supports both plain and normalized learners. When the learner has a
    normalizer, metrics have 4 columns; otherwise 3 columns.

    Args:
        learner: The learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run
        key: JAX random key for stream initialization
        learner_state: Initial state (if None, will be initialized from stream)
        step_size_tracking: Optional config for recording per-weight step-sizes.
            When provided, returns StepSizeHistory.
        normalizer_tracking: Optional config for recording per-feature normalizer
            state. When provided, returns NormalizerHistory with means and
            variances over time.

    Returns:
        If no tracking:
            Tuple of (final_state, metrics_array) where metrics_array has shape
            (num_steps, 3) or (num_steps, 4) depending on normalizer
        If step_size_tracking only:
            Tuple of (final_state, metrics_array, step_size_history)
        If normalizer_tracking only:
            Tuple of (final_state, metrics_array, normalizer_history)
        If both:
            Tuple of (final_state, metrics_array, step_size_history, normalizer_history)

    Raises:
        ValueError: If tracking interval is invalid
    """
    # Validate tracking configs
    if step_size_tracking is not None:
        if step_size_tracking.interval < 1:
            raise ValueError(
                f"step_size_tracking.interval must be >= 1, got {step_size_tracking.interval}"
            )
        if step_size_tracking.interval > num_steps:
            raise ValueError(
                f"step_size_tracking.interval ({step_size_tracking.interval}) "
                f"must be <= num_steps ({num_steps})"
            )

    if normalizer_tracking is not None:
        if normalizer_tracking.interval < 1:
            raise ValueError(
                f"normalizer_tracking.interval must be >= 1, got {normalizer_tracking.interval}"
            )
        if normalizer_tracking.interval > num_steps:
            raise ValueError(
                f"normalizer_tracking.interval ({normalizer_tracking.interval}) "
                f"must be <= num_steps ({num_steps})"
            )

    # Initialize states
    if learner_state is None:
        learner_state = learner.init(stream.feature_dim)
    stream_state = stream.init(key)

    feature_dim = stream.feature_dim

    # No tracking - simple case
    if step_size_tracking is None and normalizer_tracking is None:

        def step_fn(
            carry: tuple[LearnerState, StreamStateT], idx: Array
        ) -> tuple[tuple[LearnerState, StreamStateT], Array]:
            l_state, s_state = carry
            timestep, new_s_state = stream.step(s_state, idx)
            result = learner.update(l_state, timestep.observation, timestep.target)
            return (result.state, new_s_state), result.metrics

        (final_learner, _), metrics = jax.lax.scan(
            step_fn, (learner_state, stream_state), jnp.arange(num_steps)
        )

        return final_learner, metrics

    # Tracking enabled - need to set up history arrays
    ss_interval = step_size_tracking.interval if step_size_tracking else num_steps + 1
    norm_interval = normalizer_tracking.interval if normalizer_tracking else num_steps + 1

    ss_num_recordings = num_steps // ss_interval if step_size_tracking else 0
    norm_num_recordings = num_steps // norm_interval if normalizer_tracking else 0

    # Pre-allocate step-size history arrays
    ss_history = (
        jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
        if step_size_tracking
        else None
    )
    ss_bias_history = (
        jnp.zeros(ss_num_recordings, dtype=jnp.float32)
        if step_size_tracking and step_size_tracking.include_bias
        else None
    )
    ss_rec_indices = jnp.zeros(ss_num_recordings, dtype=jnp.int32) if step_size_tracking else None

    # Check if we need to track Autostep normalizers
    track_autostep_normalizers = hasattr(learner_state.optimizer_state, "normalizers")
    ss_normalizers = (
        jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
        if step_size_tracking and track_autostep_normalizers
        else None
    )

    # Pre-allocate normalizer state history arrays
    norm_means = (
        jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
        if normalizer_tracking
        else None
    )
    norm_vars = (
        jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
        if normalizer_tracking
        else None
    )
    norm_rec_indices = (
        jnp.zeros(norm_num_recordings, dtype=jnp.int32) if normalizer_tracking else None
    )

    def step_fn_with_tracking(
        carry: tuple[
            LearnerState,
            StreamStateT,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
        ],
        idx: Array,
    ) -> tuple[
        tuple[
            LearnerState,
            StreamStateT,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
        ],
        Array,
    ]:
        (
            l_state,
            s_state,
            ss_hist,
            ss_bias_hist,
            ss_rec,
            ss_norm,
            n_means,
            n_vars,
            n_rec,
        ) = carry

        # Perform learning step
        timestep, new_s_state = stream.step(s_state, idx)
        result = learner.update(l_state, timestep.observation, timestep.target)

        # Step-size tracking
        new_ss_hist = ss_hist
        new_ss_bias_hist = ss_bias_hist
        new_ss_rec = ss_rec
        new_ss_norm = ss_norm

        if ss_hist is not None:
            should_record_ss = (idx % ss_interval) == 0
            recording_idx = idx // ss_interval

            # Extract current step-sizes
            opt_state = result.state.optimizer_state
            if hasattr(opt_state, "log_step_sizes"):
                # IDBD stores log step-sizes
                weight_ss = jnp.exp(opt_state.log_step_sizes)
                bias_ss = opt_state.bias_step_size
            elif hasattr(opt_state, "step_sizes"):
                # Autostep stores step-sizes directly
                weight_ss = opt_state.step_sizes
                bias_ss = opt_state.bias_step_size
            else:
                # LMS has a single fixed step-size
                weight_ss = jnp.full(feature_dim, opt_state.step_size)
                bias_ss = opt_state.step_size

            new_ss_hist = jax.lax.cond(
                should_record_ss,
                lambda _: ss_hist.at[recording_idx].set(weight_ss),
                lambda _: ss_hist,
                None,
            )

            if ss_bias_hist is not None:
                new_ss_bias_hist = jax.lax.cond(
                    should_record_ss,
                    lambda _: ss_bias_hist.at[recording_idx].set(bias_ss),
                    lambda _: ss_bias_hist,
                    None,
                )

            if ss_rec is not None:
                new_ss_rec = jax.lax.cond(
                    should_record_ss,
                    lambda _: ss_rec.at[recording_idx].set(idx),
                    lambda _: ss_rec,
                    None,
                )

            # Track Autostep normalizers (v_i) if applicable
            if ss_norm is not None and hasattr(opt_state, "normalizers"):
                new_ss_norm = jax.lax.cond(
                    should_record_ss,
                    lambda _: ss_norm.at[recording_idx].set(opt_state.normalizers),
                    lambda _: ss_norm,
                    None,
                )

        # Normalizer state tracking
        new_n_means = n_means
        new_n_vars = n_vars
        new_n_rec = n_rec

        if n_means is not None:
            should_record_norm = (idx % norm_interval) == 0
            norm_recording_idx = idx // norm_interval

            norm_state = result.state.normalizer_state

            new_n_means = jax.lax.cond(
                should_record_norm,
                lambda _: n_means.at[norm_recording_idx].set(norm_state.mean),
                lambda _: n_means,
                None,
            )

            if n_vars is not None:
                new_n_vars = jax.lax.cond(
                    should_record_norm,
                    lambda _: n_vars.at[norm_recording_idx].set(norm_state.var),
                    lambda _: n_vars,
                    None,
                )

            if n_rec is not None:
                new_n_rec = jax.lax.cond(
                    should_record_norm,
                    lambda _: n_rec.at[norm_recording_idx].set(idx),
                    lambda _: n_rec,
                    None,
                )

        return (
            result.state,
            new_s_state,
            new_ss_hist,
            new_ss_bias_hist,
            new_ss_rec,
            new_ss_norm,
            new_n_means,
            new_n_vars,
            new_n_rec,
        ), result.metrics

    initial_carry = (
        learner_state,
        stream_state,
        ss_history,
        ss_bias_history,
        ss_rec_indices,
        ss_normalizers,
        norm_means,
        norm_vars,
        norm_rec_indices,
    )

    (
        (
            final_learner,
            _,
            final_ss_hist,
            final_ss_bias_hist,
            final_ss_rec,
            final_ss_norm,
            final_n_means,
            final_n_vars,
            final_n_rec,
        ),
        metrics,
    ) = jax.lax.scan(step_fn_with_tracking, initial_carry, jnp.arange(num_steps))

    # Build return values based on what was tracked
    ss_history_result = None
    if step_size_tracking is not None and final_ss_hist is not None:
        ss_history_result = StepSizeHistory(
            step_sizes=final_ss_hist,
            bias_step_sizes=final_ss_bias_hist,
            recording_indices=final_ss_rec,
            normalizers=final_ss_norm,
        )

    norm_history_result = None
    if normalizer_tracking is not None and final_n_means is not None:
        norm_history_result = NormalizerHistory(
            means=final_n_means,
            variances=final_n_vars,
            recording_indices=final_n_rec,
        )

    # Return appropriate tuple based on what was tracked
    if ss_history_result is not None and norm_history_result is not None:
        return final_learner, metrics, ss_history_result, norm_history_result
    elif ss_history_result is not None:
        return final_learner, metrics, ss_history_result
    elif norm_history_result is not None:
        return final_learner, metrics, norm_history_result
    else:
        return final_learner, metrics

run_learning_loop_batched(learner, stream, num_steps, keys, learner_state=None, step_size_tracking=None, normalizer_tracking=None)

Run learning loop across multiple seeds in parallel using jax.vmap.

This function provides GPU parallelization for multi-seed experiments, typically achieving 2-5x speedup over sequential execution.

Supports both plain and normalized learners.

Args:
    learner: The learner to train
    stream: Experience stream providing (observation, target) pairs
    num_steps: Number of learning steps to run per seed
    keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
    learner_state: Initial state (if None, will be initialized from stream).
        The same initial state is used for all seeds.
    step_size_tracking: Optional config for recording per-weight step-sizes.
        When provided, history arrays have shape (num_seeds, num_recordings, ...)
    normalizer_tracking: Optional config for recording normalizer state.
        When provided, history arrays have shape (num_seeds, num_recordings, ...)

Returns:
    BatchedLearningResult containing:
        - states: Batched final states with shape (num_seeds, ...) for each array
        - metrics: Array of shape (num_seeds, num_steps, num_cols)
        - step_size_history: Batched history or None if tracking disabled
        - normalizer_history: Batched history or None if tracking disabled

Examples:

import jax.random as jr
from alberta_framework import LinearLearner, IDBD, RandomWalkStream
from alberta_framework import run_learning_loop_batched

stream = RandomWalkStream(feature_dim=10)
learner = LinearLearner(optimizer=IDBD())

# Run 30 seeds in parallel
keys = jr.split(jr.key(42), 30)
result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)

# result.metrics has shape (30, 10000, 3)
mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds

Source code in src/alberta_framework/core/learners.py
def run_learning_loop_batched[StreamStateT](
    learner: LinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    keys: Array,
    learner_state: LearnerState | None = None,
    step_size_tracking: StepSizeTrackingConfig | None = None,
    normalizer_tracking: NormalizerTrackingConfig | None = None,
) -> BatchedLearningResult:
    """Run learning loop across multiple seeds in parallel using jax.vmap.

    This function provides GPU parallelization for multi-seed experiments,
    typically achieving 2-5x speedup over sequential execution.

    Supports both plain and normalized learners.

    Args:
        learner: The learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run per seed
        keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
        learner_state: Initial state (if None, will be initialized from stream).
            The same initial state is used for all seeds.
        step_size_tracking: Optional config for recording per-weight step-sizes.
            When provided, history arrays have shape (num_seeds, num_recordings, ...)
        normalizer_tracking: Optional config for recording normalizer state.
            When provided, history arrays have shape (num_seeds, num_recordings, ...)

    Returns:
        BatchedLearningResult containing:
            - states: Batched final states with shape (num_seeds, ...) for each array
            - metrics: Array of shape (num_seeds, num_steps, num_cols)
            - step_size_history: Batched history or None if tracking disabled
            - normalizer_history: Batched history or None if tracking disabled

    Examples:
    ```python
    import jax.random as jr
    from alberta_framework import LinearLearner, IDBD, RandomWalkStream
    from alberta_framework import run_learning_loop_batched

    stream = RandomWalkStream(feature_dim=10)
    learner = LinearLearner(optimizer=IDBD())

    # Run 30 seeds in parallel
    keys = jr.split(jr.key(42), 30)
    result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)

    # result.metrics has shape (30, 10000, 3)
    mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
    ```
    """

    # Define single-seed function that returns consistent structure
    def single_seed_run(
        key: Array,
    ) -> tuple[LearnerState, Array, StepSizeHistory | None, NormalizerHistory | None]:
        result = run_learning_loop(
            learner, stream, num_steps, key, learner_state,
            step_size_tracking, normalizer_tracking,
        )

        # Unpack based on what tracking was enabled
        if step_size_tracking is not None and normalizer_tracking is not None:
            state, metrics, ss_history, norm_history = cast(
                tuple[LearnerState, Array, StepSizeHistory, NormalizerHistory],
                result,
            )
            return state, metrics, ss_history, norm_history
        elif step_size_tracking is not None:
            state, metrics, ss_history = cast(
                tuple[LearnerState, Array, StepSizeHistory], result
            )
            return state, metrics, ss_history, None
        elif normalizer_tracking is not None:
            state, metrics, norm_history = cast(
                tuple[LearnerState, Array, NormalizerHistory], result
            )
            return state, metrics, None, norm_history
        else:
            state, metrics = cast(tuple[LearnerState, Array], result)
            return state, metrics, None, None

    # vmap over the keys dimension
    batched_states, batched_metrics, batched_ss_history, batched_norm_history = jax.vmap(
        single_seed_run
    )(keys)

    # Reconstruct batched histories if tracking was enabled
    if step_size_tracking is not None and batched_ss_history is not None:
        batched_step_size_history = StepSizeHistory(
            step_sizes=batched_ss_history.step_sizes,
            bias_step_sizes=batched_ss_history.bias_step_sizes,
            recording_indices=batched_ss_history.recording_indices,
            normalizers=batched_ss_history.normalizers,
        )
    else:
        batched_step_size_history = None

    if normalizer_tracking is not None and batched_norm_history is not None:
        batched_normalizer_history = NormalizerHistory(
            means=batched_norm_history.means,
            variances=batched_norm_history.variances,
            recording_indices=batched_norm_history.recording_indices,
        )
    else:
        batched_normalizer_history = None

    return BatchedLearningResult(
        states=batched_states,
        metrics=batched_metrics,
        step_size_history=batched_step_size_history,
        normalizer_history=batched_normalizer_history,
    )

run_mlp_learning_loop(learner, stream, num_steps, key, learner_state=None, normalizer_tracking=None)

Run the MLP learning loop using jax.lax.scan.

This is a JIT-compiled learning loop that uses scan for efficiency.

Args:
    learner: The MLP learner to train
    stream: Experience stream providing (observation, target) pairs
    num_steps: Number of learning steps to run
    key: JAX random key for stream and weight initialization
    learner_state: Initial state (if None, will be initialized from stream)
    normalizer_tracking: Optional config for recording per-feature normalizer state.
        When provided, returns NormalizerHistory.

Returns:
    If no tracking: Tuple of (final_state, metrics_array) where metrics_array
        has shape (num_steps, 3) or (num_steps, 4)
    If normalizer_tracking: Tuple of (final_state, metrics_array, normalizer_history)

Raises:
    ValueError: If normalizer_tracking.interval is invalid
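
Examples:

A minimal sketch of MLP training with normalizer tracking. It assumes run_mlp_learning_loop and NormalizerTrackingConfig are importable from the package root and that the config accepts an interval keyword (only its interval attribute appears in the source below).

import jax.random as jr
from alberta_framework import MLPLearner, RandomWalkStream
from alberta_framework import run_mlp_learning_loop, NormalizerTrackingConfig

stream = RandomWalkStream(feature_dim=10)
learner = MLPLearner(hidden_sizes=(128, 128))

# Record normalizer mean/var every 100 steps (assumed constructor signature)
tracking = NormalizerTrackingConfig(interval=100)
state, metrics, norm_history = run_mlp_learning_loop(
    learner, stream, num_steps=10000, key=jr.key(0), normalizer_tracking=tracking
)
# norm_history.means has shape (10000 // 100, 10)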

Source code in src/alberta_framework/core/learners.py
def run_mlp_learning_loop[StreamStateT](
    learner: MLPLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    key: Array,
    learner_state: MLPLearnerState | None = None,
    normalizer_tracking: NormalizerTrackingConfig | None = None,
) -> (
    tuple[MLPLearnerState, Array]
    | tuple[MLPLearnerState, Array, NormalizerHistory]
):
    """Run the MLP learning loop using jax.lax.scan.

    This is a JIT-compiled learning loop that uses scan for efficiency.

    Args:
        learner: The MLP learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run
        key: JAX random key for stream and weight initialization
        learner_state: Initial state (if None, will be initialized from stream)
        normalizer_tracking: Optional config for recording per-feature normalizer
            state. When provided, returns NormalizerHistory.

    Returns:
        If no tracking:
            Tuple of (final_state, metrics_array) where metrics_array has shape
            (num_steps, 3) or (num_steps, 4)
        If normalizer_tracking:
            Tuple of (final_state, metrics_array, normalizer_history)

    Raises:
        ValueError: If normalizer_tracking.interval is invalid
    """
    # Validate tracking config
    if normalizer_tracking is not None:
        if normalizer_tracking.interval < 1:
            raise ValueError(
                f"normalizer_tracking.interval must be >= 1, got {normalizer_tracking.interval}"
            )
        if normalizer_tracking.interval > num_steps:
            raise ValueError(
                f"normalizer_tracking.interval ({normalizer_tracking.interval}) "
                f"must be <= num_steps ({num_steps})"
            )

    # Split key for initialization
    stream_key, init_key = jax.random.split(key)

    # Initialize states
    if learner_state is None:
        learner_state = learner.init(stream.feature_dim, init_key)
    stream_state = stream.init(stream_key)

    feature_dim = stream.feature_dim

    if normalizer_tracking is None:
        # Simple case without tracking
        def step_fn(
            carry: tuple[MLPLearnerState, StreamStateT], idx: Array
        ) -> tuple[tuple[MLPLearnerState, StreamStateT], Array]:
            l_state, s_state = carry
            timestep, new_s_state = stream.step(s_state, idx)
            result = learner.update(l_state, timestep.observation, timestep.target)
            return (result.state, new_s_state), result.metrics

        (final_learner, _), metrics = jax.lax.scan(
            step_fn, (learner_state, stream_state), jnp.arange(num_steps)
        )

        return final_learner, metrics

    # Tracking enabled
    norm_interval = normalizer_tracking.interval
    norm_num_recordings = num_steps // norm_interval

    norm_means = jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
    norm_vars = jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
    norm_rec_indices = jnp.zeros(norm_num_recordings, dtype=jnp.int32)

    def step_fn_with_tracking(
        carry: tuple[MLPLearnerState, StreamStateT, Array, Array, Array],
        idx: Array,
    ) -> tuple[
        tuple[MLPLearnerState, StreamStateT, Array, Array, Array],
        Array,
    ]:
        l_state, s_state, n_means, n_vars, n_rec = carry

        # Perform learning step
        timestep, new_s_state = stream.step(s_state, idx)
        result = learner.update(l_state, timestep.observation, timestep.target)

        # Normalizer state tracking
        should_record = (idx % norm_interval) == 0
        recording_idx = idx // norm_interval

        norm_state = result.state.normalizer_state

        new_n_means = jax.lax.cond(
            should_record,
            lambda _: n_means.at[recording_idx].set(norm_state.mean),
            lambda _: n_means,
            None,
        )

        new_n_vars = jax.lax.cond(
            should_record,
            lambda _: n_vars.at[recording_idx].set(norm_state.var),
            lambda _: n_vars,
            None,
        )

        new_n_rec = jax.lax.cond(
            should_record,
            lambda _: n_rec.at[recording_idx].set(idx),
            lambda _: n_rec,
            None,
        )

        return (
            result.state,
            new_s_state,
            new_n_means,
            new_n_vars,
            new_n_rec,
        ), result.metrics

    initial_carry = (
        learner_state,
        stream_state,
        norm_means,
        norm_vars,
        norm_rec_indices,
    )

    (
        (final_learner, _, final_n_means, final_n_vars, final_n_rec),
        metrics,
    ) = jax.lax.scan(step_fn_with_tracking, initial_carry, jnp.arange(num_steps))

    norm_history = NormalizerHistory(
        means=final_n_means,
        variances=final_n_vars,
        recording_indices=final_n_rec,
    )

    return final_learner, metrics, norm_history

run_mlp_learning_loop_batched(learner, stream, num_steps, keys, learner_state=None, normalizer_tracking=None)

Run MLP learning loop across multiple seeds in parallel using jax.vmap.

This function provides GPU parallelization for multi-seed MLP experiments, typically achieving 2-5x speedup over sequential execution.

Args:
    learner: The MLP learner to train
    stream: Experience stream providing (observation, target) pairs
    num_steps: Number of learning steps to run per seed
    keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
    learner_state: Initial state (if None, will be initialized from stream).
        The same initial state is used for all seeds.
    normalizer_tracking: Optional config for recording normalizer state.
        When provided, history arrays have shape (num_seeds, num_recordings, ...)

Returns:
    BatchedMLPResult containing:
        - states: Batched final states with shape (num_seeds, ...) for each array
        - metrics: Array of shape (num_seeds, num_steps, num_cols)
        - normalizer_history: Batched history or None if tracking disabled

Examples:

import jax.random as jr
from alberta_framework import MLPLearner, RandomWalkStream
from alberta_framework import run_mlp_learning_loop_batched

stream = RandomWalkStream(feature_dim=10)
learner = MLPLearner(hidden_sizes=(128, 128))

# Run 30 seeds in parallel
keys = jr.split(jr.key(42), 30)
result = run_mlp_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)

# result.metrics has shape (30, 10000, 3)
mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds

Source code in src/alberta_framework/core/learners.py
def run_mlp_learning_loop_batched[StreamStateT](
    learner: MLPLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    keys: Array,
    learner_state: MLPLearnerState | None = None,
    normalizer_tracking: NormalizerTrackingConfig | None = None,
) -> BatchedMLPResult:
    """Run MLP learning loop across multiple seeds in parallel using jax.vmap.

    This function provides GPU parallelization for multi-seed MLP experiments,
    typically achieving 2-5x speedup over sequential execution.

    Args:
        learner: The MLP learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run per seed
        keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
        learner_state: Initial state (if None, will be initialized from stream).
            The same initial state is used for all seeds.
        normalizer_tracking: Optional config for recording normalizer state.
            When provided, history arrays have shape (num_seeds, num_recordings, ...)

    Returns:
        BatchedMLPResult containing:
            - states: Batched final states with shape (num_seeds, ...) for each array
            - metrics: Array of shape (num_seeds, num_steps, num_cols)
            - normalizer_history: Batched history or None if tracking disabled

    Examples:
    ```python
    import jax.random as jr
    from alberta_framework import MLPLearner, RandomWalkStream
    from alberta_framework import run_mlp_learning_loop_batched

    stream = RandomWalkStream(feature_dim=10)
    learner = MLPLearner(hidden_sizes=(128, 128))

    # Run 30 seeds in parallel
    keys = jr.split(jr.key(42), 30)
    result = run_mlp_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)

    # result.metrics has shape (30, 10000, 3)
    mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
    ```
    """

    def single_seed_run(
        key: Array,
    ) -> tuple[MLPLearnerState, Array, NormalizerHistory | None]:
        result = run_mlp_learning_loop(
            learner, stream, num_steps, key, learner_state, normalizer_tracking
        )

        if normalizer_tracking is not None:
            state, metrics, norm_history = cast(
                tuple[MLPLearnerState, Array, NormalizerHistory], result
            )
            return state, metrics, norm_history
        else:
            state, metrics = cast(tuple[MLPLearnerState, Array], result)
            return state, metrics, None

    batched_states, batched_metrics, batched_norm_history = jax.vmap(single_seed_run)(keys)

    if normalizer_tracking is not None and batched_norm_history is not None:
        batched_normalizer_history = NormalizerHistory(
            means=batched_norm_history.means,
            variances=batched_norm_history.variances,
            recording_indices=batched_norm_history.recording_indices,
        )
    else:
        batched_normalizer_history = None

    return BatchedMLPResult(
        states=batched_states,
        metrics=batched_metrics,
        normalizer_history=batched_normalizer_history,
    )

run_td_learning_loop(learner, stream, num_steps, key, learner_state=None)

Run the TD learning loop using jax.lax.scan.

This is a JIT-compiled learning loop that uses scan for efficiency. It returns metrics as a fixed-size array rather than a list of dicts.

Args:
    learner: The TD learner to train
    stream: TD experience stream providing (s, r, s', gamma) tuples
    num_steps: Number of learning steps to run
    key: JAX random key for stream initialization
    learner_state: Initial state (if None, will be initialized from stream)

Returns:
    Tuple of (final_state, metrics_array) where metrics_array has shape
    (num_steps, 4) with columns [squared_td_error, td_error, mean_step_size,
    mean_eligibility_trace]
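
Examples:

A hedged sketch: the concrete TDStream implementation (RandomWalkTDStream here) and the TDLinearLearner constructor are hypothetical, since only the loop signature is documented in this reference.

import jax.random as jr
from alberta_framework import TDLinearLearner, run_td_learning_loop
from alberta_framework import RandomWalkTDStream  # hypothetical TDStream implementation

stream = RandomWalkTDStream(feature_dim=10)  # yields (s, r, s', gamma) transitions
learner = TDLinearLearner()                  # assumed default construction
state, metrics = run_td_learning_loop(learner, stream, num_steps=10000, key=jr.key(0))

# Column 1 is the TD error at each step (see Returns above)
td_errors = metrics[:, 1]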

Source code in src/alberta_framework/core/learners.py
def run_td_learning_loop[StreamStateT](
    learner: TDLinearLearner,
    stream: TDStream[StreamStateT],
    num_steps: int,
    key: Array,
    learner_state: TDLearnerState | None = None,
) -> tuple[TDLearnerState, Array]:
    """Run the TD learning loop using jax.lax.scan.

    This is a JIT-compiled learning loop that uses scan for efficiency.
    It returns metrics as a fixed-size array rather than a list of dicts.

    Args:
        learner: The TD learner to train
        stream: TD experience stream providing (s, r, s', gamma) tuples
        num_steps: Number of learning steps to run
        key: JAX random key for stream initialization
        learner_state: Initial state (if None, will be initialized from stream)

    Returns:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 4) with columns [squared_td_error, td_error, mean_step_size,
        mean_eligibility_trace]
    """
    # Initialize states
    if learner_state is None:
        learner_state = learner.init(stream.feature_dim)
    stream_state = stream.init(key)

    def step_fn(
        carry: tuple[TDLearnerState, StreamStateT], idx: Array
    ) -> tuple[tuple[TDLearnerState, StreamStateT], Array]:
        l_state, s_state = carry
        timestep, new_s_state = stream.step(s_state, idx)
        result = learner.update(
            l_state,
            timestep.observation,
            timestep.reward,
            timestep.next_observation,
            timestep.gamma,
        )
        return (result.state, new_s_state), result.metrics

    (final_learner, _), metrics = jax.lax.scan(
        step_fn, (learner_state, stream_state), jnp.arange(num_steps)
    )

    return final_learner, metrics

create_autotdidbd_state(feature_dim, initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, normalizer_decay=10000.0)

Create initial AutoTDIDBD optimizer state.

Args:
    feature_dim: Dimension of the feature vector
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))
    normalizer_decay: Decay parameter tau for normalizers (default: 10000)

Returns:
    Initial AutoTDIDBD state
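
Examples:

A short sketch using only the parameters documented above; the import path follows the source reference below.

from alberta_framework.core.types import create_autotdidbd_state

state = create_autotdidbd_state(feature_dim=10, initial_step_size=0.01, meta_step_size=0.01)
# Step-sizes are stored in log space, so every entry starts at log(0.01)
assert state.log_step_sizes.shape == (10,)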

Source code in src/alberta_framework/core/types.py
def create_autotdidbd_state(
    feature_dim: int,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
    normalizer_decay: float = 10000.0,
) -> AutoTDIDBDState:
    """Create initial AutoTDIDBD optimizer state.

    Args:
        feature_dim: Dimension of the feature vector
        initial_step_size: Initial per-weight step-size
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))
        normalizer_decay: Decay parameter tau for normalizers (default: 10000)

    Returns:
        Initial AutoTDIDBD state
    """
    return AutoTDIDBDState(
        log_step_sizes=jnp.full(feature_dim, jnp.log(initial_step_size), dtype=jnp.float32),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(trace_decay, dtype=jnp.float32),
        normalizer_decay=jnp.array(normalizer_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
    )

create_obgd_state(feature_dim, step_size=1.0, kappa=2.0, gamma=0.0, lamda=0.0)

Create initial ObGD optimizer state.

Args:
    feature_dim: Dimension of the feature vector
    step_size: Base learning rate (default: 1.0)
    kappa: Bounding sensitivity parameter (default: 2.0)
    gamma: Discount factor for trace decay (default: 0.0 for supervised)
    lamda: Eligibility trace decay parameter (default: 0.0 for supervised)

Returns:
    Initial ObGD state
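
Examples:

A sketch contrasting the supervised defaults with an RL-style configuration; the import path follows the source reference below.

from alberta_framework.core.types import create_obgd_state

supervised = create_obgd_state(feature_dim=10)  # gamma=0, lamda=0 (supervised defaults)
rl = create_obgd_state(feature_dim=10, gamma=0.99, lamda=0.9)  # decaying eligibility traces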

Source code in src/alberta_framework/core/types.py
def create_obgd_state(
    feature_dim: int,
    step_size: float = 1.0,
    kappa: float = 2.0,
    gamma: float = 0.0,
    lamda: float = 0.0,
) -> ObGDState:
    """Create initial ObGD optimizer state.

    Args:
        feature_dim: Dimension of the feature vector
        step_size: Base learning rate (default: 1.0)
        kappa: Bounding sensitivity parameter (default: 2.0)
        gamma: Discount factor for trace decay (default: 0.0 for supervised)
        lamda: Eligibility trace decay parameter (default: 0.0 for supervised)

    Returns:
        Initial ObGD state
    """
    return ObGDState(
        step_size=jnp.array(step_size, dtype=jnp.float32),
        kappa=jnp.array(kappa, dtype=jnp.float32),
        traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        bias_trace=jnp.array(0.0, dtype=jnp.float32),
        gamma=jnp.array(gamma, dtype=jnp.float32),
        lamda=jnp.array(lamda, dtype=jnp.float32),
    )

create_tdidbd_state(feature_dim, initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0)

Create initial TD-IDBD optimizer state.

Args:
    feature_dim: Dimension of the feature vector
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))

Returns:
    Initial TD-IDBD state
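
Examples:

A sketch enabling TD(lambda) traces; the import path follows the source reference below.

from alberta_framework.core.types import create_tdidbd_state

state = create_tdidbd_state(feature_dim=10, trace_decay=0.9)  # TD(0.9); 0.0 recovers TD(0)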

Source code in src/alberta_framework/core/types.py
def create_tdidbd_state(
    feature_dim: int,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
) -> TDIDBDState:
    """Create initial TD-IDBD optimizer state.

    Args:
        feature_dim: Dimension of the feature vector
        initial_step_size: Initial per-weight step-size
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))

    Returns:
        Initial TD-IDBD state
    """
    return TDIDBDState(
        log_step_sizes=jnp.full(feature_dim, jnp.log(initial_step_size), dtype=jnp.float32),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(trace_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
    )

make_scale_range(feature_dim, min_scale=0.001, max_scale=1000.0, log_spaced=True)

Create a per-feature scale array spanning a range.

Utility function to generate scale factors for ScaledStreamWrapper.

Args:
    feature_dim: Number of features
    min_scale: Minimum scale factor
    max_scale: Maximum scale factor
    log_spaced: If True, scales are logarithmically spaced (default).
        If False, scales are linearly spaced.

Returns:
    Array of shape (feature_dim,) with scale factors

Examples:

scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
stream = ScaledStreamWrapper(RandomWalkStream(10), scales)

Source code in src/alberta_framework/streams/synthetic.py
def make_scale_range(
    feature_dim: int,
    min_scale: float = 0.001,
    max_scale: float = 1000.0,
    log_spaced: bool = True,
) -> Array:
    """Create a per-feature scale array spanning a range.

    Utility function to generate scale factors for ScaledStreamWrapper.

    Args:
        feature_dim: Number of features
        min_scale: Minimum scale factor
        max_scale: Maximum scale factor
        log_spaced: If True, scales are logarithmically spaced (default).
            If False, scales are linearly spaced.

    Returns:
        Array of shape (feature_dim,) with scale factors

    Examples
    --------
    ```python
    scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
    stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
    ```
    """
    if log_spaced:
        return jnp.logspace(
            jnp.log10(min_scale),
            jnp.log10(max_scale),
            feature_dim,
            dtype=jnp.float32,
        )
    else:
        return jnp.linspace(min_scale, max_scale, feature_dim, dtype=jnp.float32)

compare_learners(results, metric='squared_error')

Compare multiple learners on a given metric.

Args:
    results: Dictionary mapping learner name to metrics history
    metric: Metric to compare

Returns:
    Dictionary with summary statistics for each learner
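
Examples:

A sketch with hand-built metrics histories in the list-of-dicts format this utility expects.

from alberta_framework.utils.metrics import compare_learners

results = {
    "lms": [{"squared_error": 0.5}, {"squared_error": 0.4}],
    "idbd": [{"squared_error": 0.3}, {"squared_error": 0.2}],
}
summary = compare_learners(results, metric="squared_error")
# summary["idbd"]["mean"] == 0.25; with fewer than 100 steps,
# final_100_mean falls back to the overall mean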

Source code in src/alberta_framework/utils/metrics.py
def compare_learners(
    results: dict[str, list[dict[str, float]]],
    metric: str = "squared_error",
) -> dict[str, dict[str, float]]:
    """Compare multiple learners on a given metric.

    Args:
        results: Dictionary mapping learner name to metrics history
        metric: Metric to compare

    Returns:
        Dictionary with summary statistics for each learner
    """
    summary = {}
    for name, metrics_history in results.items():
        values = extract_metric(metrics_history, metric)
        summary[name] = {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
            "cumulative": float(np.sum(values)),
            "final_100_mean": (
                float(np.mean(values[-100:])) if len(values) >= 100 else float(np.mean(values))
            ),
        }
    return summary

compute_cumulative_error(metrics_history, error_key='squared_error')

Compute cumulative error over time.

Args:
    metrics_history: List of metric dictionaries from learning loop
    error_key: Key to extract error values

Returns:
    Array of cumulative errors at each time step

Source code in src/alberta_framework/utils/metrics.py
def compute_cumulative_error(
    metrics_history: list[dict[str, float]],
    error_key: str = "squared_error",
) -> NDArray[np.float64]:
    """Compute cumulative error over time.

    Args:
        metrics_history: List of metric dictionaries from learning loop
        error_key: Key to extract error values

    Returns:
        Array of cumulative errors at each time step
    """
    errors = np.array([m[error_key] for m in metrics_history])
    return np.cumsum(errors)

compute_running_mean(values, window_size=100)

Compute running mean of values.

Args:
    values: Array of values
    window_size: Size of the moving average window

Returns:
    Array of running mean values (same length as input, padded at start)
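
Examples:

A small worked example of the windowing and front padding.

import numpy as np
from alberta_framework.utils.metrics import compute_running_mean

values = np.array([1.0, 2.0, 3.0, 4.0])
smoothed = compute_running_mean(values, window_size=2)
# Window means are [1.5, 2.5, 3.5]; the front is padded with the
# first window mean, giving [1.5, 1.5, 2.5, 3.5]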

Source code in src/alberta_framework/utils/metrics.py
def compute_running_mean(
    values: NDArray[np.float64] | list[float],
    window_size: int = 100,
) -> NDArray[np.float64]:
    """Compute running mean of values.

    Args:
        values: Array of values
        window_size: Size of the moving average window

    Returns:
        Array of running mean values (same length as input, padded at start)
    """
    values_arr = np.asarray(values)
    cumsum = np.cumsum(np.insert(values_arr, 0, 0))
    running_mean = (cumsum[window_size:] - cumsum[:-window_size]) / window_size

    # Pad the beginning with the first computed mean
    if len(running_mean) > 0:
        padding = np.full(window_size - 1, running_mean[0])
        return np.concatenate([padding, running_mean])
    return values_arr

compute_tracking_error(metrics_history, window_size=100)

Compute tracking error (running mean of squared error).

This is the key metric for evaluating continual learners: how well can the learner track the non-stationary target?

Args:
    metrics_history: List of metric dictionaries from learning loop
    window_size: Size of the moving average window

Returns:
    Array of tracking errors at each time step
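
Examples:

A sketch using the list-of-dicts metrics format; the scan-based loops in this framework return arrays, so building dicts as below is purely illustrative.

from alberta_framework.utils.metrics import compute_tracking_error

metrics_history = [{"squared_error": e} for e in [1.0, 0.5, 0.25, 0.125]]
tracking = compute_tracking_error(metrics_history, window_size=2)
# Running mean of the squared errors: [0.75, 0.75, 0.375, 0.1875]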

Source code in src/alberta_framework/utils/metrics.py
def compute_tracking_error(
    metrics_history: list[dict[str, float]],
    window_size: int = 100,
) -> NDArray[np.float64]:
    """Compute tracking error (running mean of squared error).

    This is the key metric for evaluating continual learners:
    how well can the learner track the non-stationary target?

    Args:
        metrics_history: List of metric dictionaries from learning loop
        window_size: Size of the moving average window

    Returns:
        Array of tracking errors at each time step
    """
    errors = np.array([m["squared_error"] for m in metrics_history])
    return compute_running_mean(errors, window_size)

extract_metric(metrics_history, key)

Extract a single metric from the history.

Args:
    metrics_history: List of metric dictionaries
    key: Key to extract

Returns:
    Array of values for that metric

Source code in src/alberta_framework/utils/metrics.py
def extract_metric(
    metrics_history: list[dict[str, float]],
    key: str,
) -> NDArray[np.float64]:
    """Extract a single metric from the history.

    Args:
        metrics_history: List of metric dictionaries
        key: Key to extract

    Returns:
        Array of values for that metric
    """
    return np.array([m[key] for m in metrics_history])

format_duration(seconds)

Format a duration in seconds as a human-readable string.

Args:
    seconds: Duration in seconds

Returns:
    Formatted string like "1.23s", "2m 30.50s", or "1h 5m 30.00s"

Examples:

format_duration(0.5)   # Returns: '0.50s'
format_duration(90.5)  # Returns: '1m 30.50s'
format_duration(3665)  # Returns: '1h 1m 5.00s'

Source code in src/alberta_framework/utils/timing.py
def format_duration(seconds: float) -> str:
    """Format a duration in seconds as a human-readable string.

    Args:
        seconds: Duration in seconds

    Returns:
        Formatted string like "1.23s", "2m 30.50s", or "1h 5m 30.00s"

    Examples
    --------
    ```python
    format_duration(0.5)   # Returns: '0.50s'
    format_duration(90.5)  # Returns: '1m 30.50s'
    format_duration(3665)  # Returns: '1h 1m 5.00s'
    ```
    """
    if seconds < 60:
        return f"{seconds:.2f}s"
    elif seconds < 3600:
        minutes = int(seconds // 60)
        secs = seconds % 60
        return f"{minutes}m {secs:.2f}s"
    else:
        hours = int(seconds // 3600)
        remaining = seconds % 3600
        minutes = int(remaining // 60)
        secs = remaining % 60
        return f"{hours}h {minutes}m {secs:.2f}s"

collect_trajectory(env, policy, num_steps, mode=PredictionMode.REWARD, include_action_in_features=True, seed=0)

Collect a trajectory from a Gymnasium environment.

This uses a Python loop to interact with the environment and collects observations and targets into JAX arrays that can be used with scan-based learning.

Args:
    env: Gymnasium environment instance
    policy: Action selection function. If None, uses random policy
    num_steps: Number of steps to collect
    mode: What to predict (REWARD, NEXT_STATE, VALUE)
    include_action_in_features: If True, features = concat(obs, action)
    seed: Random seed for environment resets and random policy

Returns:
    Tuple of (observations, targets) as JAX arrays with shape
    (num_steps, feature_dim) and (num_steps, target_dim)
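
Examples:

A sketch collecting a short random-policy trajectory from CartPole; it assumes the gymnasium package is installed.

import gymnasium
from alberta_framework.streams.gymnasium import collect_trajectory

env = gymnasium.make("CartPole-v1")
# policy=None uses a random policy; REWARD mode predicts r from (obs, action)
observations, targets = collect_trajectory(env, policy=None, num_steps=1000, seed=0)
# observations: (1000, feature_dim); targets: (1000, 1) in REWARD mode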

Source code in src/alberta_framework/streams/gymnasium.py
def collect_trajectory(
    env: gymnasium.Env[Any, Any],
    policy: Callable[[Array], Any] | None,
    num_steps: int,
    mode: PredictionMode = PredictionMode.REWARD,
    include_action_in_features: bool = True,
    seed: int = 0,
) -> tuple[Array, Array]:
    """Collect a trajectory from a Gymnasium environment.

    This uses a Python loop to interact with the environment and collects
    observations and targets into JAX arrays that can be used with scan-based
    learning.

    Args:
        env: Gymnasium environment instance
        policy: Action selection function. If None, uses random policy
        num_steps: Number of steps to collect
        mode: What to predict (REWARD, NEXT_STATE, VALUE)
        include_action_in_features: If True, features = concat(obs, action)
        seed: Random seed for environment resets and random policy

    Returns:
        Tuple of (observations, targets) as JAX arrays with shape
        (num_steps, feature_dim) and (num_steps, target_dim)
    """
    if policy is None:
        policy = make_random_policy(env, seed)

    observations = []
    targets = []

    reset_count = 0
    raw_obs, _ = env.reset(seed=seed + reset_count)
    reset_count += 1
    current_obs = _flatten_observation(raw_obs, env.observation_space)

    for _ in range(num_steps):
        action = policy(current_obs)
        flat_action = _flatten_action(action, env.action_space)

        raw_next_obs, reward, terminated, truncated, _ = env.step(action)
        next_obs = _flatten_observation(raw_next_obs, env.observation_space)

        # Construct features
        if include_action_in_features:
            features = jnp.concatenate([current_obs, flat_action])
        else:
            features = current_obs

        # Construct target based on mode
        if mode == PredictionMode.REWARD:
            target = jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))
        elif mode == PredictionMode.NEXT_STATE:
            target = next_obs
        else:  # VALUE mode
            # TD target with 0 bootstrap (simple version)
            target = jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))

        observations.append(features)
        targets.append(target)

        if terminated or truncated:
            raw_obs, _ = env.reset(seed=seed + reset_count)
            reset_count += 1
            current_obs = _flatten_observation(raw_obs, env.observation_space)
        else:
            current_obs = next_obs

    return jnp.stack(observations), jnp.stack(targets)

learn_from_trajectory(learner, observations, targets, learner_state=None)

Learn from a pre-collected trajectory using jax.lax.scan.

This is a JIT-compiled learning function that processes a trajectory collected from a Gymnasium environment.

Args:
    learner: The learner to train
    observations: Array of observations with shape (num_steps, feature_dim)
    targets: Array of targets with shape (num_steps, target_dim)
    learner_state: Initial state (if None, will be initialized)

Returns:
    Tuple of (final_state, metrics_array) where metrics_array has shape
    (num_steps, 3) with columns [squared_error, error, mean_step_size]
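
Examples:

A self-contained sketch with synthetic arrays in the documented shapes; in practice these would come from collect_trajectory above.

import jax.numpy as jnp
from alberta_framework import LinearLearner, IDBD
from alberta_framework.streams.gymnasium import learn_from_trajectory

observations = jnp.ones((100, 4))  # (num_steps, feature_dim)
targets = jnp.zeros((100, 1))      # (num_steps, target_dim)

learner = LinearLearner(optimizer=IDBD())
state, metrics = learn_from_trajectory(learner, observations, targets)
# metrics[:, 0] is the squared prediction error at each step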

Source code in src/alberta_framework/streams/gymnasium.py
def learn_from_trajectory(
    learner: LinearLearner,
    observations: Array,
    targets: Array,
    learner_state: LearnerState | None = None,
) -> tuple[LearnerState, Array]:
    """Learn from a pre-collected trajectory using jax.lax.scan.

    This is a JIT-compiled learning function that processes a trajectory
    collected from a Gymnasium environment.

    Args:
        learner: The learner to train
        observations: Array of observations with shape (num_steps, feature_dim)
        targets: Array of targets with shape (num_steps, target_dim)
        learner_state: Initial state (if None, will be initialized)

    Returns:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 3) with columns [squared_error, error, mean_step_size]
    """
    if learner_state is None:
        learner_state = learner.init(observations.shape[1])

    def step_fn(state: LearnerState, inputs: tuple[Array, Array]) -> tuple[LearnerState, Array]:
        obs, target = inputs
        result = learner.update(state, obs, target)
        return result.state, result.metrics

    final_state, metrics = jax.lax.scan(step_fn, learner_state, (observations, targets))

    return final_state, metrics

learn_from_trajectory_normalized(learner, observations, targets, learner_state=None)

Learn from a pre-collected trajectory with normalization using jax.lax.scan.

This is equivalent to learn_from_trajectory for a learner constructed with a normalizer (e.g. LinearLearner(optimizer=..., normalizer=EMANormalizer())). Retained for backward compatibility.

Args:
    learner: The learner to train (should have a normalizer configured)
    observations: Array of observations with shape (num_steps, feature_dim)
    targets: Array of targets with shape (num_steps, target_dim)
    learner_state: Initial state (if None, will be initialized)

Returns:
    Tuple of (final_state, metrics_array) where metrics_array has shape
    (num_steps, 4) with columns [squared_error, error, mean_step_size,
    normalizer_mean_var]

Source code in src/alberta_framework/streams/gymnasium.py
def learn_from_trajectory_normalized(
    learner: LinearLearner,
    observations: Array,
    targets: Array,
    learner_state: LearnerState | None = None,
) -> tuple[LearnerState, Array]:
    """Learn from a pre-collected trajectory with normalization using jax.lax.scan.

    This is equivalent to ``learn_from_trajectory`` for a learner constructed
    with a normalizer (e.g. ``LinearLearner(optimizer=..., normalizer=EMANormalizer())``).
    Retained for backward compatibility.

    Args:
        learner: The learner to train (should have a normalizer configured)
        observations: Array of observations with shape (num_steps, feature_dim)
        targets: Array of targets with shape (num_steps, target_dim)
        learner_state: Initial state (if None, will be initialized)

    Returns:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
    """
    return learn_from_trajectory(learner, observations, targets, learner_state)

make_epsilon_greedy_policy(base_policy, env, epsilon=0.1, seed=0)

Wrap a policy with epsilon-greedy exploration.

Args:
    base_policy: The greedy policy to wrap
    env: Gymnasium environment (for random action sampling)
    epsilon: Probability of taking a random action
    seed: Random seed

Returns:
    Epsilon-greedy policy
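
Examples:

A sketch wrapping a trivial greedy policy; the constant "push right" action is illustrative only.

import gymnasium
from alberta_framework.streams.gymnasium import make_epsilon_greedy_policy

env = gymnasium.make("CartPole-v1")

def greedy(obs):
    return 1  # always push right (illustrative stand-in for a learned policy)

policy = make_epsilon_greedy_policy(greedy, env, epsilon=0.1, seed=0)
# With probability 0.1, policy(obs) returns a random action instead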

Source code in src/alberta_framework/streams/gymnasium.py
def make_epsilon_greedy_policy(
    base_policy: Callable[[Array], Any],
    env: gymnasium.Env[Any, Any],
    epsilon: float = 0.1,
    seed: int = 0,
) -> Callable[[Array], Any]:
    """Wrap a policy with epsilon-greedy exploration.

    Args:
        base_policy: The greedy policy to wrap
        env: Gymnasium environment (for random action sampling)
        epsilon: Probability of taking a random action
        seed: Random seed

    Returns:
        Epsilon-greedy policy
    """
    random_policy = make_random_policy(env, seed + 1)
    rng = jr.key(seed)

    def policy(obs: Array) -> Any:
        nonlocal rng
        rng, key = jr.split(rng)

        if jr.uniform(key) < epsilon:
            return random_policy(obs)
        return base_policy(obs)

    return policy

make_gymnasium_stream(env_id, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0, **env_kwargs)

Factory function to create a GymnasiumStream from an environment ID.

Args:
    env_id: Gymnasium environment ID (e.g., "CartPole-v1")
    mode: What to predict (REWARD, NEXT_STATE, VALUE)
    policy: Action selection function. If None, uses random policy
    gamma: Discount factor for VALUE mode
    include_action_in_features: If True, features = concat(obs, action)
    seed: Random seed
    **env_kwargs: Additional arguments passed to gymnasium.make()

Returns:
    GymnasiumStream wrapping the environment
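
Examples:

A sketch creating a reward-prediction stream; PredictionMode is assumed importable from the same module, as it is referenced there.

from alberta_framework.streams.gymnasium import make_gymnasium_stream, PredictionMode

stream = make_gymnasium_stream("CartPole-v1", mode=PredictionMode.REWARD, seed=0)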

Source code in src/alberta_framework/streams/gymnasium.py
def make_gymnasium_stream(
    env_id: str,
    mode: PredictionMode = PredictionMode.REWARD,
    policy: Callable[[Array], Any] | None = None,
    gamma: float = 0.99,
    include_action_in_features: bool = True,
    seed: int = 0,
    **env_kwargs: Any,
) -> GymnasiumStream:
    """Factory function to create a GymnasiumStream from an environment ID.

    Args:
        env_id: Gymnasium environment ID (e.g., "CartPole-v1")
        mode: What to predict (REWARD, NEXT_STATE, VALUE)
        policy: Action selection function. If None, uses random policy
        gamma: Discount factor for VALUE mode
        include_action_in_features: If True, features = concat(obs, action)
        seed: Random seed
        **env_kwargs: Additional arguments passed to gymnasium.make()

    Returns:
        GymnasiumStream wrapping the environment
    """
    import gymnasium

    env = gymnasium.make(env_id, **env_kwargs)
    return GymnasiumStream(
        env=env,
        mode=mode,
        policy=policy,
        gamma=gamma,
        include_action_in_features=include_action_in_features,
        seed=seed,
    )

make_random_policy(env, seed=0)

Create a random action policy for an environment.

Args:
    env: Gymnasium environment
    seed: Random seed

Returns:
    A callable that takes an observation and returns a random action
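
Examples:

A minimal sketch; CartPole-v1 has a Discrete(2) action space, so the policy returns ints in {0, 1}.

import gymnasium
import jax.numpy as jnp
from alberta_framework.streams.gymnasium import make_random_policy

env = gymnasium.make("CartPole-v1")
policy = make_random_policy(env, seed=0)
action = policy(jnp.zeros(4))  # the observation is ignored by a random policy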

Source code in src/alberta_framework/streams/gymnasium.py
def make_random_policy(env: gymnasium.Env[Any, Any], seed: int = 0) -> Callable[[Array], Any]:
    """Create a random action policy for an environment.

    Args:
        env: Gymnasium environment
        seed: Random seed

    Returns:
        A callable that takes an observation and returns a random action
    """
    import gymnasium

    rng = jr.key(seed)
    action_space = env.action_space

    def policy(_obs: Array) -> Any:
        nonlocal rng
        rng, key = jr.split(rng)

        if isinstance(action_space, gymnasium.spaces.Discrete):
            return int(jr.randint(key, (), 0, int(action_space.n)))
        elif isinstance(action_space, gymnasium.spaces.Box):
            # Sample uniformly between low and high
            low = jnp.asarray(action_space.low, dtype=jnp.float32)
            high = jnp.asarray(action_space.high, dtype=jnp.float32)
            return jr.uniform(key, action_space.shape, minval=low, maxval=high)
        elif isinstance(action_space, gymnasium.spaces.MultiDiscrete):
            nvec = action_space.nvec
            return [int(jr.randint(jr.fold_in(key, i), (), 0, n)) for i, n in enumerate(nvec)]
        else:
            raise ValueError(f"Unsupported action space: {type(action_space).__name__}")

    return policy