alberta_framework

Alberta Framework: A JAX-based research framework for continual AI.

The Alberta Framework provides foundational components for continual reinforcement learning research. Built on JAX for hardware acceleration, the framework emphasizes temporal uniformity — every component updates at every time step, with no special training phases or batch processing.

Roadmap

| Step | Focus                                     | Status   |
|------|-------------------------------------------|----------|
| 1    | Meta-learned step-sizes (IDBD, Autostep)  | Complete |
| 2    | Feature generation and testing            | Planned  |
| 3    | GVF predictions, Horde architecture       | Planned  |
| 4    | Actor-critic with eligibility traces      | Planned  |
| 5-6  | Off-policy learning, average reward       | Planned  |
| 7-12 | Hierarchical, multi-agent, world models   | Future   |

Examples:

import jax.random as jr
from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop

# Non-stationary stream where target weights drift over time
stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)

# Learner with IDBD meta-learned step-sizes
learner = LinearLearner(optimizer=IDBD())

# JIT-compiled training via jax.lax.scan
state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(42))

References
  • The Alberta Plan for AI Research (Sutton et al., 2022): https://arxiv.org/abs/2208.11173
  • Adapting Bias by Gradient Descent (Sutton, 1992)
  • Tuning-free Step-size Adaptation (Mahmood et al., 2012)

LinearLearner(optimizer=None)

Linear function approximator with pluggable optimizer.

Computes predictions as: y = w @ x + b

The learner maintains weights and bias, delegating the adaptation of learning rates to the optimizer (e.g., LMS or IDBD).

This follows the Alberta Plan philosophy of temporal uniformity: every component updates at every time step.

Attributes: optimizer: The optimizer to use for weight updates

Args: optimizer: Optimizer for weight updates. Defaults to LMS(0.01)

Source code in src/alberta_framework/core/learners.py
def __init__(self, optimizer: AnyOptimizer | None = None):
    """Initialize the linear learner.

    Args:
        optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
    """
    self._optimizer: AnyOptimizer = optimizer or LMS(step_size=0.01)
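
A minimal usage sketch of the step-by-step API (the feature vector and target below are illustrative placeholders):

import jax.numpy as jnp
from alberta_framework import IDBD, LinearLearner

learner = LinearLearner(optimizer=IDBD())
state = learner.init(feature_dim=3)

observation = jnp.array([0.5, -1.0, 2.0])  # illustrative feature vector
target = jnp.array(1.5)                    # illustrative target

prediction = learner.predict(state, observation)  # y = w @ x + b
result = learner.update(state, observation, target)
state = result.state                              # carry the updated state forward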

init(feature_dim)

Initialize learner state.

Args: feature_dim: Dimension of the input feature vector

Returns: Initial learner state with zero weights and bias

Source code in src/alberta_framework/core/learners.py
def init(self, feature_dim: int) -> LearnerState:
    """Initialize learner state.

    Args:
        feature_dim: Dimension of the input feature vector

    Returns:
        Initial learner state with zero weights and bias
    """
    optimizer_state = self._optimizer.init(feature_dim)

    return LearnerState(
        weights=jnp.zeros(feature_dim, dtype=jnp.float32),
        bias=jnp.array(0.0, dtype=jnp.float32),
        optimizer_state=optimizer_state,
    )

predict(state, observation)

Compute prediction for an observation.

Args:
    state: Current learner state
    observation: Input feature vector

Returns: Scalar prediction y = w @ x + b

Source code in src/alberta_framework/core/learners.py
def predict(self, state: LearnerState, observation: Observation) -> Prediction:
    """Compute prediction for an observation.

    Args:
        state: Current learner state
        observation: Input feature vector

    Returns:
        Scalar prediction `y = w @ x + b`
    """
    return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)

update(state, observation, target)

Update learner given observation and target.

Performs one step of the learning algorithm:

  1. Compute prediction
  2. Compute error
  3. Get weight updates from optimizer
  4. Apply updates to weights and bias

Args:
    state: Current learner state
    observation: Input feature vector
    target: Desired output

Returns: UpdateResult with new state, prediction, error, and metrics

Source code in src/alberta_framework/core/learners.py
def update(
    self,
    state: LearnerState,
    observation: Observation,
    target: Target,
) -> UpdateResult:
    """Update learner given observation and target.

    Performs one step of the learning algorithm:
    1. Compute prediction
    2. Compute error
    3. Get weight updates from optimizer
    4. Apply updates to weights and bias

    Args:
        state: Current learner state
        observation: Input feature vector
        target: Desired output

    Returns:
        UpdateResult with new state, prediction, error, and metrics
    """
    # Make prediction
    prediction = self.predict(state, observation)

    # Compute error (target - prediction)
    error = jnp.squeeze(target) - jnp.squeeze(prediction)

    # Get update from optimizer
    # Note: type ignore needed because we can't statically prove optimizer_state
    # matches the optimizer's expected state type (though they will at runtime)
    opt_update = self._optimizer.update(
        state.optimizer_state,
        error,
        observation,
    )

    # Apply updates
    new_weights = state.weights + opt_update.weight_delta
    new_bias = state.bias + opt_update.bias_delta

    new_state = LearnerState(
        weights=new_weights,
        bias=new_bias,
        optimizer_state=opt_update.new_state,
    )

    # Pack metrics as array for scan compatibility
    # Format: [squared_error, error, mean_step_size (if adaptive)]
    squared_error = error**2
    mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)
    metrics = jnp.array([squared_error, error, mean_step_size], dtype=jnp.float32)

    return UpdateResult(
        state=new_state,
        prediction=prediction,
        error=jnp.atleast_1d(error),
        metrics=metrics,
    )

NormalizedLearnerState

State for a learner with online feature normalization.

Attributes:
    learner_state: Underlying learner state (weights, bias, optimizer)
    normalizer_state: Online normalizer state (mean, var estimates)

NormalizedLinearLearner(optimizer=None, normalizer=None)

Linear learner with online feature normalization.

Wraps a LinearLearner with online feature normalization, following the Alberta Plan's approach to handling varying feature scales.

Normalization is applied to features before prediction and learning: x_normalized = (x - mean) / (std + epsilon)

The normalizer statistics update at every time step, maintaining temporal uniformity.

Attributes:
    learner: Underlying linear learner
    normalizer: Online feature normalizer

Args:
    optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
    normalizer: Feature normalizer. Defaults to OnlineNormalizer()

Source code in src/alberta_framework/core/learners.py
def __init__(
    self,
    optimizer: AnyOptimizer | None = None,
    normalizer: OnlineNormalizer | None = None,
):
    """Initialize the normalized linear learner.

    Args:
        optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
        normalizer: Feature normalizer. Defaults to OnlineNormalizer()
    """
    self._learner = LinearLearner(optimizer=optimizer or LMS(step_size=0.01))
    self._normalizer = normalizer or OnlineNormalizer()
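
A minimal usage sketch, assuming NormalizedLinearLearner and OnlineNormalizer are exported from the package top level like the components in the introductory example; the features and target are illustrative placeholders:

import jax.numpy as jnp
from alberta_framework import IDBD, NormalizedLinearLearner, OnlineNormalizer

learner = NormalizedLinearLearner(optimizer=IDBD(), normalizer=OnlineNormalizer(decay=0.99))
state = learner.init(feature_dim=2)

observation = jnp.array([1000.0, 0.001])  # illustrative, badly scaled raw features
target = jnp.array(2.0)

result = learner.update(state, observation, target)  # normalizes the features, then learns
state = result.state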

init(feature_dim)

Initialize normalized learner state.

Args: feature_dim: Dimension of the input feature vector

Returns: Initial state with zero weights and unit variance estimates

Source code in src/alberta_framework/core/learners.py
def init(self, feature_dim: int) -> NormalizedLearnerState:
    """Initialize normalized learner state.

    Args:
        feature_dim: Dimension of the input feature vector

    Returns:
        Initial state with zero weights and unit variance estimates
    """
    return NormalizedLearnerState(
        learner_state=self._learner.init(feature_dim),
        normalizer_state=self._normalizer.init(feature_dim),
    )

predict(state, observation)

Compute prediction for an observation.

Normalizes the observation using current statistics before prediction.

Args:
    state: Current normalized learner state
    observation: Raw (unnormalized) input feature vector

Returns: Scalar prediction y = w @ normalize(x) + b

Source code in src/alberta_framework/core/learners.py
def predict(
    self,
    state: NormalizedLearnerState,
    observation: Observation,
) -> Prediction:
    """Compute prediction for an observation.

    Normalizes the observation using current statistics before prediction.

    Args:
        state: Current normalized learner state
        observation: Raw (unnormalized) input feature vector

    Returns:
        Scalar prediction y = w @ normalize(x) + b
    """
    normalized_obs = self._normalizer.normalize_only(state.normalizer_state, observation)
    return self._learner.predict(state.learner_state, normalized_obs)

update(state, observation, target)

Update learner given observation and target.

Performs one step of the learning algorithm:

  1. Normalize observation (and update normalizer statistics)
  2. Compute prediction using normalized features
  3. Compute error
  4. Get weight updates from optimizer
  5. Apply updates

Args:
    state: Current normalized learner state
    observation: Raw (unnormalized) input feature vector
    target: Desired output

Returns: NormalizedUpdateResult with new state, prediction, error, and metrics

Source code in src/alberta_framework/core/learners.py
def update(
    self,
    state: NormalizedLearnerState,
    observation: Observation,
    target: Target,
) -> NormalizedUpdateResult:
    """Update learner given observation and target.

    Performs one step of the learning algorithm:
    1. Normalize observation (and update normalizer statistics)
    2. Compute prediction using normalized features
    3. Compute error
    4. Get weight updates from optimizer
    5. Apply updates

    Args:
        state: Current normalized learner state
        observation: Raw (unnormalized) input feature vector
        target: Desired output

    Returns:
        NormalizedUpdateResult with new state, prediction, error, and metrics
    """
    # Normalize observation and update normalizer state
    normalized_obs, new_normalizer_state = self._normalizer.normalize(
        state.normalizer_state, observation
    )

    # Delegate to underlying learner
    result = self._learner.update(
        state.learner_state,
        normalized_obs,
        target,
    )

    # Build combined state
    new_state = NormalizedLearnerState(
        learner_state=result.state,
        normalizer_state=new_normalizer_state,
    )

    # Add normalizer metrics to the metrics array
    normalizer_mean_var = jnp.mean(new_normalizer_state.var)
    metrics = jnp.concatenate([result.metrics, jnp.array([normalizer_mean_var])])

    return NormalizedUpdateResult(
        state=new_state,
        prediction=result.prediction,
        error=result.error,
        metrics=metrics,
    )

TDLinearLearner(optimizer=None)

Linear function approximator for TD learning.

Computes value predictions as: V(s) = w @ φ(s) + b

The learner maintains weights, bias, and eligibility traces, delegating the adaptation of learning rates to the TD optimizer (e.g., TDIDBD).

This follows the Alberta Plan philosophy of temporal uniformity: every component updates at every time step.

Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"

Attributes: optimizer: The TD optimizer to use for weight updates

Args: optimizer: TD optimizer for weight updates. Defaults to TDIDBD()

Source code in src/alberta_framework/core/learners.py
def __init__(self, optimizer: AnyTDOptimizer | None = None):
    """Initialize the TD linear learner.

    Args:
        optimizer: TD optimizer for weight updates. Defaults to TDIDBD()
    """
    self._optimizer: AnyTDOptimizer = optimizer or TDIDBD()
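
A minimal sketch of one TD(0) transition update, assuming TDLinearLearner and TDIDBD are exported from the package top level; φ(s), φ(s'), the reward, and γ are illustrative placeholders:

import jax.numpy as jnp
from alberta_framework import TDIDBD, TDLinearLearner

learner = TDLinearLearner(optimizer=TDIDBD(trace_decay=0.0))  # TD(0)
state = learner.init(feature_dim=4)

phi_s = jnp.array([1.0, 0.0, 0.5, 0.0])       # φ(s), illustrative
phi_s_next = jnp.array([0.0, 1.0, 0.0, 0.5])  # φ(s'), illustrative
reward = jnp.array(1.0)
gamma = jnp.array(0.99)                        # use 0.0 at terminal states

result = learner.update(state, phi_s, reward, phi_s_next, gamma)
state = result.state        # updated weights, eligibility traces, step-sizes
td_error = result.td_error  # δ = R + γV(s') - V(s)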

init(feature_dim)

Initialize TD learner state.

Args: feature_dim: Dimension of the input feature vector

Returns: Initial TD learner state with zero weights and bias

Source code in src/alberta_framework/core/learners.py
def init(self, feature_dim: int) -> TDLearnerState:
    """Initialize TD learner state.

    Args:
        feature_dim: Dimension of the input feature vector

    Returns:
        Initial TD learner state with zero weights and bias
    """
    optimizer_state = self._optimizer.init(feature_dim)

    return TDLearnerState(
        weights=jnp.zeros(feature_dim, dtype=jnp.float32),
        bias=jnp.array(0.0, dtype=jnp.float32),
        optimizer_state=optimizer_state,
    )

predict(state, observation)

Compute value prediction for an observation.

Args:
    state: Current TD learner state
    observation: Input feature vector φ(s)

Returns: Scalar value prediction V(s) = w @ φ(s) + b

Source code in src/alberta_framework/core/learners.py
def predict(self, state: TDLearnerState, observation: Observation) -> Prediction:
    """Compute value prediction for an observation.

    Args:
        state: Current TD learner state
        observation: Input feature vector φ(s)

    Returns:
        Scalar value prediction `V(s) = w @ φ(s) + b`
    """
    return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)

update(state, observation, reward, next_observation, gamma)

Update learner given a TD transition.

Performs one step of TD learning:

  1. Compute V(s) and V(s')
  2. Compute TD error δ = R + γV(s') - V(s)
  3. Get weight updates from TD optimizer
  4. Apply updates to weights and bias

Args:
    state: Current TD learner state
    observation: Current observation φ(s)
    reward: Reward R received
    next_observation: Next observation φ(s')
    gamma: Discount factor γ (0 at terminal states)

Returns: TDUpdateResult with new state, prediction, TD error, and metrics

Source code in src/alberta_framework/core/learners.py
def update(
    self,
    state: TDLearnerState,
    observation: Observation,
    reward: Array,
    next_observation: Observation,
    gamma: Array,
) -> TDUpdateResult:
    """Update learner given a TD transition.

    Performs one step of TD learning:
    1. Compute V(s) and V(s')
    2. Compute TD error δ = R + γV(s') - V(s)
    3. Get weight updates from TD optimizer
    4. Apply updates to weights and bias

    Args:
        state: Current TD learner state
        observation: Current observation φ(s)
        reward: Reward R received
        next_observation: Next observation φ(s')
        gamma: Discount factor γ (0 at terminal states)

    Returns:
        TDUpdateResult with new state, prediction, TD error, and metrics
    """
    # Compute predictions
    prediction = self.predict(state, observation)
    next_prediction = self.predict(state, next_observation)

    # Compute TD error: δ = R + γV(s') - V(s)
    gamma_scalar = jnp.squeeze(gamma)
    td_error = (
        jnp.squeeze(reward)
        + gamma_scalar * jnp.squeeze(next_prediction)
        - jnp.squeeze(prediction)
    )

    # Get update from TD optimizer
    opt_update = self._optimizer.update(
        state.optimizer_state,
        td_error,
        observation,
        next_observation,
        gamma,
    )

    # Apply updates
    new_weights = state.weights + opt_update.weight_delta
    new_bias = state.bias + opt_update.bias_delta

    new_state = TDLearnerState(
        weights=new_weights,
        bias=new_bias,
        optimizer_state=opt_update.new_state,
    )

    # Pack metrics as array for scan compatibility
    # Format: [squared_td_error, td_error, mean_step_size, mean_eligibility_trace]
    squared_td_error = td_error**2
    mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)
    mean_elig_trace = opt_update.metrics.get("mean_eligibility_trace", 0.0)
    metrics = jnp.array(
        [squared_td_error, td_error, mean_step_size, mean_elig_trace],
        dtype=jnp.float32,
    )

    return TDUpdateResult(
        state=new_state,
        prediction=prediction,
        td_error=jnp.atleast_1d(td_error),
        metrics=metrics,
    )

TDUpdateResult

Result of a TD learner update step.

Attributes:
    state: Updated TD learner state
    prediction: Value prediction V(s) before update
    td_error: TD error δ = R + γV(s') - V(s)
    metrics: Array of metrics [squared_td_error, td_error, mean_step_size, ...]

UpdateResult

Result of a learner update step.

Attributes:
    state: Updated learner state
    prediction: Prediction made before update
    error: Prediction error
    metrics: Array of metrics [squared_error, error, ...]

NormalizerState

State for online feature normalization.

Uses a Welford-style online update, combined with exponential decay, for numerically stable estimation of the running mean and variance.

Attributes:
    mean: Running mean estimate per feature
    var: Running variance estimate per feature
    sample_count: Number of samples seen
    decay: Exponential decay factor for estimates (1.0 = no decay, pure online)

OnlineNormalizer(epsilon=1e-08, decay=0.99)

Online feature normalizer for continual learning.

Normalizes features using running estimates of mean and standard deviation: x_normalized = (x - mean) / (std + epsilon)

The normalizer updates its estimates at every time step, following temporal uniformity. Uses exponential moving average for non-stationary environments.

Attributes:
    epsilon: Small constant for numerical stability
    decay: Exponential decay for running estimates (0.99 = slower adaptation)

Args:
    epsilon: Small constant added to std for numerical stability
    decay: Exponential decay factor for running estimates. Lower values adapt
        faster to changes. 1.0 means pure online average (no decay).

Source code in src/alberta_framework/core/normalizers.py
def __init__(
    self,
    epsilon: float = 1e-8,
    decay: float = 0.99,
):
    """Initialize the online normalizer.

    Args:
        epsilon: Small constant added to std for numerical stability
        decay: Exponential decay factor for running estimates.
               Lower values adapt faster to changes.
               1.0 means pure online average (no decay).
    """
    self._epsilon = epsilon
    self._decay = decay
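
A minimal sketch of the normalizer on its own, assuming OnlineNormalizer is exported from the package top level; the observation is an illustrative placeholder:

import jax.numpy as jnp
from alberta_framework import OnlineNormalizer

normalizer = OnlineNormalizer(epsilon=1e-8, decay=0.99)
norm_state = normalizer.init(feature_dim=3)

observation = jnp.array([10.0, -2.0, 0.3])  # illustrative raw features

# Normalize and update the running mean/variance estimates in one call
normalized, norm_state = normalizer.normalize(norm_state, observation)

# Or normalize with frozen statistics (e.g., for inference)
normalized_only = normalizer.normalize_only(norm_state, observation)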

init(feature_dim)

Initialize normalizer state.

Args: feature_dim: Dimension of feature vectors

Returns: Initial normalizer state with zero mean and unit variance

Source code in src/alberta_framework/core/normalizers.py
def init(self, feature_dim: int) -> NormalizerState:
    """Initialize normalizer state.

    Args:
        feature_dim: Dimension of feature vectors

    Returns:
        Initial normalizer state with zero mean and unit variance
    """
    return NormalizerState(
        mean=jnp.zeros(feature_dim, dtype=jnp.float32),
        var=jnp.ones(feature_dim, dtype=jnp.float32),
        sample_count=jnp.array(0.0, dtype=jnp.float32),
        decay=jnp.array(self._decay, dtype=jnp.float32),
    )

normalize(state, observation)

Normalize observation and update running statistics.

This method both normalizes the current observation AND updates the running statistics, maintaining temporal uniformity.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns: Tuple of (normalized_observation, new_state)

Source code in src/alberta_framework/core/normalizers.py
def normalize(
    self,
    state: NormalizerState,
    observation: Array,
) -> tuple[Array, NormalizerState]:
    """Normalize observation and update running statistics.

    This method both normalizes the current observation AND updates
    the running statistics, maintaining temporal uniformity.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Tuple of (normalized_observation, new_state)
    """
    # Update count
    new_count = state.sample_count + 1.0

    # Compute effective decay (ramp up from 0 to target decay)
    # This prevents instability in early steps
    effective_decay = jnp.minimum(state.decay, 1.0 - 1.0 / (new_count + 1.0))

    # Update mean using exponential moving average
    delta = observation - state.mean
    new_mean = state.mean + (1.0 - effective_decay) * delta

    # Update variance using exponential moving average of squared deviations
    # This is a simplified Welford's algorithm adapted for EMA
    delta2 = observation - new_mean
    new_var = effective_decay * state.var + (1.0 - effective_decay) * delta * delta2

    # Ensure variance is positive
    new_var = jnp.maximum(new_var, self._epsilon)

    # Normalize using updated statistics
    std = jnp.sqrt(new_var)
    normalized = (observation - new_mean) / (std + self._epsilon)

    new_state = NormalizerState(
        mean=new_mean,
        var=new_var,
        sample_count=new_count,
        decay=state.decay,
    )

    return normalized, new_state

normalize_only(state, observation)

Normalize observation without updating statistics.

Useful for inference or when you want to normalize multiple observations with the same statistics.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns: Normalized observation

Source code in src/alberta_framework/core/normalizers.py
def normalize_only(
    self,
    state: NormalizerState,
    observation: Array,
) -> Array:
    """Normalize observation without updating statistics.

    Useful for inference or when you want to normalize multiple
    observations with the same statistics.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Normalized observation
    """
    std = jnp.sqrt(state.var)
    return (observation - state.mean) / (std + self._epsilon)

update_only(state, observation)

Update statistics without returning normalized observation.

Args:
    state: Current normalizer state
    observation: Raw feature vector

Returns: Updated normalizer state

Source code in src/alberta_framework/core/normalizers.py
def update_only(
    self,
    state: NormalizerState,
    observation: Array,
) -> NormalizerState:
    """Update statistics without returning normalized observation.

    Args:
        state: Current normalizer state
        observation: Raw feature vector

    Returns:
        Updated normalizer state
    """
    _, new_state = self.normalize(state, observation)
    return new_state

IDBD(initial_step_size=0.01, meta_step_size=0.01)

Bases: Optimizer[IDBDState]

Incremental Delta-Bar-Delta optimizer.

IDBD maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation. When successive gradients agree in sign, the step-size for that weight increases. When they disagree, it decreases.

This implements Sutton's 1992 algorithm for adapting step-sizes online without requiring manual tuning.

Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta"

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate beta for adapting step-sizes

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate beta for adapting step-sizes

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
):
    """Initialize IDBD optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate beta for adapting step-sizes
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
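
A sketch showing IDBD in the learning loop from the introduction; it assumes the metrics returned by run_learning_loop keep the per-step [squared_error, error, mean_step_size] layout documented for LinearLearner.update, so the last column traces how the meta-learned step-sizes evolve:

import jax.random as jr
from alberta_framework import IDBD, LinearLearner, RandomWalkStream, run_learning_loop

stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
learner = LinearLearner(optimizer=IDBD(initial_step_size=0.01, meta_step_size=0.01))

state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(0))

mean_step_sizes = metrics[:, 2]  # assumed column order: [squared_error, error, mean_step_size]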

init(feature_dim)

Initialize IDBD state.

Args: feature_dim: Dimension of weight vector

Returns: IDBD state with per-weight step-sizes and traces

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> IDBDState:
    """Initialize IDBD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        IDBD state with per-weight step-sizes and traces
    """
    return IDBDState(
        log_step_sizes=jnp.full(
            feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        bias_step_size=jnp.array(self._initial_step_size, dtype=jnp.float32),
        bias_trace=jnp.array(0.0, dtype=jnp.float32),
    )

update(state, error, observation)

Compute IDBD weight update with adaptive step-sizes.

The IDBD algorithm:

  1. Compute step-sizes: alpha_i = exp(log_alpha_i)
  2. Update weights: w_i += alpha_i * error * x_i
  3. Update log step-sizes: log_alpha_i += beta * error * x_i * h_i
  4. Update traces: h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i

The trace h_i tracks the correlation between current and past gradients. When gradients consistently point the same direction, h_i grows, leading to larger step-sizes.

Args:
    state: Current IDBD state
    error: Prediction error (scalar)
    observation: Feature vector

Returns: OptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: IDBDState,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute IDBD weight update with adaptive step-sizes.

    The IDBD algorithm:

    1. Compute step-sizes: `alpha_i = exp(log_alpha_i)`
    2. Update weights: `w_i += alpha_i * error * x_i`
    3. Update log step-sizes: `log_alpha_i += beta * error * x_i * h_i`
    4. Update traces: `h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i`

    The trace h_i tracks the correlation between current and past gradients.
    When gradients consistently point the same direction, h_i grows,
    leading to larger step-sizes.

    Args:
        state: Current IDBD state
        error: Prediction error (scalar)
        observation: Feature vector

    Returns:
        OptimizerUpdate with weight deltas and updated state
    """
    error_scalar = jnp.squeeze(error)
    beta = state.meta_step_size

    # Current step-sizes (exponentiate log values)
    alphas = jnp.exp(state.log_step_sizes)

    # Weight updates: alpha_i * error * x_i
    weight_delta = alphas * error_scalar * observation

    # Meta-update: adapt step-sizes based on gradient correlation
    # log_alpha_i += beta * error * x_i * h_i
    gradient_correlation = error_scalar * observation * state.traces
    new_log_step_sizes = state.log_step_sizes + beta * gradient_correlation

    # Clip log step-sizes to prevent numerical issues
    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)

    # Update traces: h_i = h_i * decay + alpha_i * error * x_i
    # decay = max(0, 1 - alpha_i * x_i^2)
    decay = jnp.maximum(0.0, 1.0 - alphas * observation**2)
    new_traces = state.traces * decay + alphas * error_scalar * observation

    # Bias updates (similar logic but scalar)
    bias_alpha = state.bias_step_size
    bias_delta = bias_alpha * error_scalar

    # Update bias step-size
    bias_gradient_correlation = error_scalar * state.bias_trace
    new_bias_step_size = bias_alpha * jnp.exp(beta * bias_gradient_correlation)
    new_bias_step_size = jnp.clip(new_bias_step_size, 1e-6, 1.0)

    # Update bias trace
    bias_decay = jnp.maximum(0.0, 1.0 - bias_alpha)
    new_bias_trace = state.bias_trace * bias_decay + bias_alpha * error_scalar

    new_state = IDBDState(
        log_step_sizes=new_log_step_sizes,
        traces=new_traces,
        meta_step_size=beta,
        bias_step_size=new_bias_step_size,
        bias_trace=new_bias_trace,
    )

    return OptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(alphas),
            "min_step_size": jnp.min(alphas),
            "max_step_size": jnp.max(alphas),
        },
    )

LMS(step_size=0.01)

Bases: Optimizer[LMSState]

Least Mean Square optimizer with fixed step-size.

The simplest gradient-based optimizer: w_{t+1} = w_t + alpha * delta * x_t

This serves as a baseline. Its limitation is that the optimal step-size depends on the problem and drifts over time when the task is non-stationary.

Attributes: step_size: Fixed learning rate alpha

Args: step_size: Fixed learning rate

Source code in src/alberta_framework/core/optimizers.py
def __init__(self, step_size: float = 0.01):
    """Initialize LMS optimizer.

    Args:
        step_size: Fixed learning rate
    """
    self._step_size = step_size
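
A sketch comparing the fixed-step-size LMS baseline against IDBD on the same non-stationary stream, assuming LMS is exported from the package top level and using the metrics layout described above (column 0 is the squared error):

import jax.numpy as jnp
import jax.random as jr
from alberta_framework import IDBD, LMS, LinearLearner, RandomWalkStream, run_learning_loop

stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)

_, lms_metrics = run_learning_loop(
    LinearLearner(optimizer=LMS(step_size=0.01)), stream, num_steps=10000, key=jr.key(0)
)
_, idbd_metrics = run_learning_loop(
    LinearLearner(optimizer=IDBD()), stream, num_steps=10000, key=jr.key(0)
)

print(jnp.mean(lms_metrics[:, 0]), jnp.mean(idbd_metrics[:, 0]))  # mean squared error of each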

init(feature_dim)

Initialize LMS state.

Args: feature_dim: Dimension of weight vector (unused for LMS)

Returns: LMS state containing the step-size

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> LMSState:
    """Initialize LMS state.

    Args:
        feature_dim: Dimension of weight vector (unused for LMS)

    Returns:
        LMS state containing the step-size
    """
    return LMSState(step_size=jnp.array(self._step_size, dtype=jnp.float32))

update(state, error, observation)

Compute LMS weight update.

Update rule: delta_w = alpha * error * x

Args:
    state: Current LMS state
    error: Prediction error (scalar)
    observation: Feature vector

Returns: OptimizerUpdate with weight and bias deltas

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: LMSState,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute LMS weight update.

    Update rule: `delta_w = alpha * error * x`

    Args:
        state: Current LMS state
        error: Prediction error (scalar)
        observation: Feature vector

    Returns:
        OptimizerUpdate with weight and bias deltas
    """
    alpha = state.step_size
    error_scalar = jnp.squeeze(error)

    # Weight update: alpha * error * x
    weight_delta = alpha * error_scalar * observation

    # Bias update: alpha * error
    bias_delta = alpha * error_scalar

    return OptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=state,  # LMS state doesn't change
        metrics={"step_size": alpha},
    )

TDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, use_semi_gradient=True)

Bases: TDOptimizer[TDIDBDState]

TD-IDBD optimizer for temporal-difference learning.

Extends IDBD to TD learning with eligibility traces. Maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation in the TD setting.

Two variants are supported:

  • Semi-gradient (default): uses only φ(s) in the meta-update, more stable
  • Ordinary gradient: uses both φ(s) and φ(s'), more accurate but sensitive

Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"

The semi-gradient TD-IDBD algorithm (Algorithm 3 in the paper):

  1. Compute TD error: δ = R + γ*w^T*φ(s') - w^T*φ(s)
  2. Update meta-weights: β_i += θ*δ*φ_i(s)*h_i
  3. Compute step-sizes: α_i = exp(β_i)
  4. Update eligibility traces: z_i = γ*λ*z_i + φ_i(s)
  5. Update weights: w_i += α_i*δ*z_i
  6. Update h traces: h_i = h_i*[1 - α_i*φ_i(s)*z_i]^+ + α_i*δ*z_i

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta
    trace_decay: Eligibility trace decay lambda
    use_semi_gradient: If True, use semi-gradient variant (default)

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay lambda (0 = TD(0))
    use_semi_gradient: If True, use semi-gradient variant (recommended)

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
    use_semi_gradient: bool = True,
):
    """Initialize TD-IDBD optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay lambda (0 = TD(0))
        use_semi_gradient: If True, use semi-gradient variant (recommended)
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
    self._trace_decay = trace_decay
    self._use_semi_gradient = use_semi_gradient
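
A short sketch constructing the two documented variants; use_semi_gradient selects between Algorithm 3 and Algorithm 4 (TDLinearLearner and TDIDBD are assumed to be exported from the package top level):

from alberta_framework import TDIDBD, TDLinearLearner

# Semi-gradient TIDBD(λ): meta-update uses only φ(s) (Algorithm 3, default)
semi = TDLinearLearner(optimizer=TDIDBD(trace_decay=0.9, use_semi_gradient=True))

# Ordinary-gradient TIDBD(λ): meta-update uses γφ(s') - φ(s) (Algorithm 4)
ordinary = TDLinearLearner(optimizer=TDIDBD(trace_decay=0.9, use_semi_gradient=False))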

init(feature_dim)

Initialize TD-IDBD state.

Args: feature_dim: Dimension of weight vector

Returns: TD-IDBD state with per-weight step-sizes, traces, and h traces

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> TDIDBDState:
    """Initialize TD-IDBD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        TD-IDBD state with per-weight step-sizes, traces, and h traces
    """
    return TDIDBDState(
        log_step_sizes=jnp.full(
            feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(self._trace_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(self._initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
    )

update(state, td_error, observation, next_observation, gamma)

Compute TD-IDBD weight update with adaptive step-sizes.

Implements Algorithm 3 (semi-gradient) or Algorithm 4 (ordinary gradient) from Kearney et al. 2019.

Args:
    state: Current TD-IDBD state
    td_error: TD error δ = R + γV(s') - V(s)
    observation: Current observation φ(s)
    next_observation: Next observation φ(s')
    gamma: Discount factor γ (0 at terminal)

Returns: TDOptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: TDIDBDState,
    td_error: Array,
    observation: Array,
    next_observation: Array,
    gamma: Array,
) -> TDOptimizerUpdate:
    """Compute TD-IDBD weight update with adaptive step-sizes.

    Implements Algorithm 3 (semi-gradient) or Algorithm 4 (ordinary gradient)
    from Kearney et al. 2019.

    Args:
        state: Current TD-IDBD state
        td_error: TD error δ = R + γV(s') - V(s)
        observation: Current observation φ(s)
        next_observation: Next observation φ(s')
        gamma: Discount factor γ (0 at terminal)

    Returns:
        TDOptimizerUpdate with weight deltas and updated state
    """
    delta = jnp.squeeze(td_error)
    theta = state.meta_step_size
    lam = state.trace_decay
    gamma_scalar = jnp.squeeze(gamma)

    if self._use_semi_gradient:
        # Semi-gradient TD-IDBD (Algorithm 3)
        # β_i += θ*δ*φ_i(s)*h_i
        gradient_correlation = delta * observation * state.h_traces
        new_log_step_sizes = state.log_step_sizes + theta * gradient_correlation
    else:
        # Ordinary gradient TD-IDBD (Algorithm 4)
        # β_i -= θ*δ*[γ*φ_i(s') - φ_i(s)]*h_i
        # Note: negative sign because gradient direction is reversed
        feature_diff = gamma_scalar * next_observation - observation
        gradient_correlation = delta * feature_diff * state.h_traces
        new_log_step_sizes = state.log_step_sizes - theta * gradient_correlation

    # Clip log step-sizes to prevent numerical issues
    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)

    # Get updated step-sizes for weight update
    new_alphas = jnp.exp(new_log_step_sizes)

    # Update eligibility traces: z_i = γ*λ*z_i + φ_i(s)
    new_eligibility_traces = gamma_scalar * lam * state.eligibility_traces + observation

    # Compute weight delta: α_i*δ*z_i
    weight_delta = new_alphas * delta * new_eligibility_traces

    if self._use_semi_gradient:
        # Semi-gradient h update (Algorithm 3, line 9)
        # h_i = h_i*[1 - α_i*φ_i(s)*z_i]^+ + α_i*δ*z_i
        h_decay = jnp.maximum(0.0, 1.0 - new_alphas * observation * new_eligibility_traces)
        new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces
    else:
        # Ordinary gradient h update (Algorithm 4, line 9)
        # h_i = h_i*[1 + α_i*z_i*(γ*φ_i(s') - φ_i(s))]^+ + α_i*δ*z_i
        feature_diff = gamma_scalar * next_observation - observation
        h_decay = jnp.maximum(0.0, 1.0 + new_alphas * new_eligibility_traces * feature_diff)
        new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces

    # Bias updates (similar logic but scalar)
    if self._use_semi_gradient:
        # Semi-gradient bias meta-update
        bias_gradient_correlation = delta * state.bias_h_trace
        new_bias_log_step_size = state.bias_log_step_size + theta * bias_gradient_correlation
    else:
        # Ordinary gradient bias meta-update
        # For bias, φ(s) = 1, so feature_diff = γ - 1
        bias_feature_diff = gamma_scalar - 1.0
        bias_gradient_correlation = delta * bias_feature_diff * state.bias_h_trace
        new_bias_log_step_size = state.bias_log_step_size - theta * bias_gradient_correlation

    new_bias_log_step_size = jnp.clip(new_bias_log_step_size, -10.0, 2.0)
    new_bias_alpha = jnp.exp(new_bias_log_step_size)

    # Update bias eligibility trace
    new_bias_eligibility_trace = gamma_scalar * lam * state.bias_eligibility_trace + 1.0

    # Bias weight delta
    bias_delta = new_bias_alpha * delta * new_bias_eligibility_trace

    if self._use_semi_gradient:
        # Semi-gradient bias h update
        bias_h_decay = jnp.maximum(0.0, 1.0 - new_bias_alpha * new_bias_eligibility_trace)
        new_bias_h_trace = (
            state.bias_h_trace * bias_h_decay
            + new_bias_alpha * delta * new_bias_eligibility_trace
        )
    else:
        # Ordinary gradient bias h update
        bias_feature_diff = gamma_scalar - 1.0
        bias_h_decay = jnp.maximum(
            0.0, 1.0 + new_bias_alpha * new_bias_eligibility_trace * bias_feature_diff
        )
        new_bias_h_trace = (
            state.bias_h_trace * bias_h_decay
            + new_bias_alpha * delta * new_bias_eligibility_trace
        )

    new_state = TDIDBDState(
        log_step_sizes=new_log_step_sizes,
        eligibility_traces=new_eligibility_traces,
        h_traces=new_h_traces,
        meta_step_size=theta,
        trace_decay=lam,
        bias_log_step_size=new_bias_log_step_size,
        bias_eligibility_trace=new_bias_eligibility_trace,
        bias_h_trace=new_bias_h_trace,
    )

    return TDOptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(new_alphas),
            "min_step_size": jnp.min(new_alphas),
            "max_step_size": jnp.max(new_alphas),
            "mean_eligibility_trace": jnp.mean(jnp.abs(new_eligibility_traces)),
        },
    )

Autostep(initial_step_size=0.01, meta_step_size=0.01, normalizer_decay=0.99)

Bases: Optimizer[AutostepState]

Autostep optimizer with tuning-free step-size adaptation.

Autostep normalizes gradients to prevent large updates and adapts per-weight step-sizes based on gradient correlation. The key innovation is automatic normalization that makes the algorithm robust to different feature scales.

The algorithm maintains:

  • Per-weight step-sizes that adapt based on gradient correlation
  • Running max of absolute gradients for normalization
  • Traces for detecting consistent gradient directions

Reference: Mahmood, A.R., Sutton, R.S., Degris, T., & Pilarski, P.M. (2012). "Tuning-free step-size adaptation"

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate mu for adapting step-sizes
    normalizer_decay: Decay factor tau for gradient normalizers

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate for adapting step-sizes
    normalizer_decay: Decay factor for gradient normalizers (higher = slower decay)

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    normalizer_decay: float = 0.99,
):
    """Initialize Autostep optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate for adapting step-sizes
        normalizer_decay: Decay factor for gradient normalizers (higher = slower decay)
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
    self._normalizer_decay = normalizer_decay
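
A sketch swapping Autostep in for IDBD in the same learning loop, assuming Autostep is exported from the package top level:

import jax.random as jr
from alberta_framework import Autostep, LinearLearner, RandomWalkStream, run_learning_loop

learner = LinearLearner(optimizer=Autostep(meta_step_size=0.01, normalizer_decay=0.99))
stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)

state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(1))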

init(feature_dim)

Initialize Autostep state.

Args: feature_dim: Dimension of weight vector

Returns: Autostep state with per-weight step-sizes, traces, and normalizers

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> AutostepState:
    """Initialize Autostep state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        Autostep state with per-weight step-sizes, traces, and normalizers
    """
    return AutostepState(
        step_sizes=jnp.full(feature_dim, self._initial_step_size, dtype=jnp.float32),
        traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        normalizer_decay=jnp.array(self._normalizer_decay, dtype=jnp.float32),
        bias_step_size=jnp.array(self._initial_step_size, dtype=jnp.float32),
        bias_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
    )

update(state, error, observation)

Compute Autostep weight update with normalized gradients.

The Autostep algorithm:

  1. Compute gradient: g_i = error * x_i
  2. Normalize gradient: g_i' = g_i / max(|g_i|, v_i)
  3. Update weights: w_i += alpha_i * g_i'
  4. Update step-sizes: alpha_i *= exp(mu * g_i' * h_i)
  5. Update traces: h_i = h_i * (1 - alpha_i) + alpha_i * g_i'
  6. Update normalizers: v_i = max(|g_i|, v_i * tau)

Args:
    state: Current Autostep state
    error: Prediction error (scalar)
    observation: Feature vector

Returns: OptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: AutostepState,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute Autostep weight update with normalized gradients.

    The Autostep algorithm:

    1. Compute gradient: `g_i = error * x_i`
    2. Normalize gradient: `g_i' = g_i / max(|g_i|, v_i)`
    3. Update weights: `w_i += alpha_i * g_i'`
    4. Update step-sizes: `alpha_i *= exp(mu * g_i' * h_i)`
    5. Update traces: `h_i = h_i * (1 - alpha_i) + alpha_i * g_i'`
    6. Update normalizers: `v_i = max(|g_i|, v_i * tau)`

    Args:
        state: Current Autostep state
        error: Prediction error (scalar)
        observation: Feature vector

    Returns:
        OptimizerUpdate with weight deltas and updated state
    """
    error_scalar = jnp.squeeze(error)
    mu = state.meta_step_size
    tau = state.normalizer_decay

    # Compute raw gradient
    gradient = error_scalar * observation

    # Normalize gradient using running max
    abs_gradient = jnp.abs(gradient)
    normalizer = jnp.maximum(abs_gradient, state.normalizers)
    normalized_gradient = gradient / (normalizer + 1e-8)

    # Compute weight delta using normalized gradient
    weight_delta = state.step_sizes * normalized_gradient

    # Update step-sizes based on gradient correlation
    gradient_correlation = normalized_gradient * state.traces
    new_step_sizes = state.step_sizes * jnp.exp(mu * gradient_correlation)

    # Clip step-sizes to prevent instability
    new_step_sizes = jnp.clip(new_step_sizes, 1e-8, 1.0)

    # Update traces with decay based on step-size
    trace_decay = 1.0 - state.step_sizes
    new_traces = state.traces * trace_decay + state.step_sizes * normalized_gradient

    # Update normalizers with decay
    new_normalizers = jnp.maximum(abs_gradient, state.normalizers * tau)

    # Bias updates (similar logic)
    bias_gradient = error_scalar
    abs_bias_gradient = jnp.abs(bias_gradient)
    bias_normalizer = jnp.maximum(abs_bias_gradient, state.bias_normalizer)
    normalized_bias_gradient = bias_gradient / (bias_normalizer + 1e-8)

    bias_delta = state.bias_step_size * normalized_bias_gradient

    bias_correlation = normalized_bias_gradient * state.bias_trace
    new_bias_step_size = state.bias_step_size * jnp.exp(mu * bias_correlation)
    new_bias_step_size = jnp.clip(new_bias_step_size, 1e-8, 1.0)

    bias_trace_decay = 1.0 - state.bias_step_size
    new_bias_trace = (
        state.bias_trace * bias_trace_decay + state.bias_step_size * normalized_bias_gradient
    )

    new_bias_normalizer = jnp.maximum(abs_bias_gradient, state.bias_normalizer * tau)

    new_state = AutostepState(
        step_sizes=new_step_sizes,
        traces=new_traces,
        normalizers=new_normalizers,
        meta_step_size=mu,
        normalizer_decay=tau,
        bias_step_size=new_bias_step_size,
        bias_trace=new_bias_trace,
        bias_normalizer=new_bias_normalizer,
    )

    return OptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(state.step_sizes),
            "min_step_size": jnp.min(state.step_sizes),
            "max_step_size": jnp.max(state.step_sizes),
            "mean_normalizer": jnp.mean(state.normalizers),
        },
    )

AutoTDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, normalizer_decay=10000.0)

Bases: TDOptimizer[AutoTDIDBDState]

AutoStep-style normalized TD-IDBD optimizer.

Adds AutoStep-style normalization to TDIDBD for improved stability and reduced sensitivity to the meta step-size theta. It includes:

  1. Normalization of the meta-weight update by a running trace of recent updates
  2. Effective step-size normalization to prevent overshooting

Reference: Kearney et al. 2019, Algorithm 6 "AutoStep Style Normalized TIDBD(λ)"

The AutoTDIDBD algorithm:

  1. Compute TD error: δ = R + γ*w^T*φ(s') - w^T*φ(s)
  2. Update normalizers: η_i = max(|δ*[γφ_i(s')-φ_i(s)]*h_i|, η_i - (1/τ)*α_i*[γφ_i(s')-φ_i(s)]*z_i*(|δ*φ_i(s)*h_i| - η_i))
  3. Normalized meta-update: β_i -= θ*(1/η_i)*δ*[γφ_i(s')-φ_i(s)]*h_i
  4. Effective step-size normalization: M = max(-exp(β)*[γφ(s')-φ(s)]^T*z, 1), then β_i -= log(M)
  5. Update weights and traces as in TIDBD

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta
    trace_decay: Eligibility trace decay lambda
    normalizer_decay: Decay parameter tau for normalizers

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay lambda (0 = TD(0))
    normalizer_decay: Decay parameter tau for normalizers (default: 10000)

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
    normalizer_decay: float = 10000.0,
):
    """Initialize AutoTDIDBD optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay lambda (0 = TD(0))
        normalizer_decay: Decay parameter tau for normalizers (default: 10000)
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
    self._trace_decay = trace_decay
    self._normalizer_decay = normalizer_decay
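
A sketch using AutoTDIDBD as a drop-in replacement for TDIDBD (assuming both AutoTDIDBD and TDLinearLearner are exported from the package top level); the normalization is intended to reduce sensitivity to the meta step-size theta:

from alberta_framework import AutoTDIDBD, TDLinearLearner

learner = TDLinearLearner(
    optimizer=AutoTDIDBD(
        initial_step_size=0.01,
        meta_step_size=0.01,
        trace_decay=0.9,
        normalizer_decay=10000.0,
    )
)
state = learner.init(feature_dim=8)
# state is then advanced with learner.update(state, phi_s, reward, phi_s_next, gamma) per transition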

init(feature_dim)

Initialize AutoTDIDBD state.

Args: feature_dim: Dimension of weight vector

Returns: AutoTDIDBD state with per-weight step-sizes, traces, h traces, and normalizers

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> AutoTDIDBDState:
    """Initialize AutoTDIDBD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        AutoTDIDBD state with per-weight step-sizes, traces, h traces, and normalizers
    """
    return AutoTDIDBDState(
        log_step_sizes=jnp.full(
            feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(self._trace_decay, dtype=jnp.float32),
        normalizer_decay=jnp.array(self._normalizer_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(self._initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
    )

update(state, td_error, observation, next_observation, gamma)

Compute AutoTDIDBD weight update with normalized adaptive step-sizes.

Implements Algorithm 6 from Kearney et al. 2019.

Args:
    state: Current AutoTDIDBD state
    td_error: TD error δ = R + γV(s') - V(s)
    observation: Current observation φ(s)
    next_observation: Next observation φ(s')
    gamma: Discount factor γ (0 at terminal)

Returns: TDOptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: AutoTDIDBDState,
    td_error: Array,
    observation: Array,
    next_observation: Array,
    gamma: Array,
) -> TDOptimizerUpdate:
    """Compute AutoTDIDBD weight update with normalized adaptive step-sizes.

    Implements Algorithm 6 from Kearney et al. 2019.

    Args:
        state: Current AutoTDIDBD state
        td_error: TD error δ = R + γV(s') - V(s)
        observation: Current observation φ(s)
        next_observation: Next observation φ(s')
        gamma: Discount factor γ (0 at terminal)

    Returns:
        TDOptimizerUpdate with weight deltas and updated state
    """
    delta = jnp.squeeze(td_error)
    theta = state.meta_step_size
    lam = state.trace_decay
    tau = state.normalizer_decay
    gamma_scalar = jnp.squeeze(gamma)

    # Feature difference: γ*φ(s') - φ(s)
    feature_diff = gamma_scalar * next_observation - observation

    # Current step-sizes
    alphas = jnp.exp(state.log_step_sizes)

    # Update normalizers (Algorithm 6, lines 5-7)
    # η_i = max(|δ*[γφ_i(s')-φ_i(s)]*h_i|,
    #           η_i - (1/τ)*α_i*[γφ_i(s')-φ_i(s)]*z_i*(|δ*φ_i(s)*h_i| - η_i))
    abs_weight_update = jnp.abs(delta * feature_diff * state.h_traces)
    normalizer_decay_term = (
        (1.0 / tau)
        * alphas
        * feature_diff
        * state.eligibility_traces
        * (jnp.abs(delta * observation * state.h_traces) - state.normalizers)
    )
    new_normalizers = jnp.maximum(abs_weight_update, state.normalizers - normalizer_decay_term)
    # Ensure normalizers don't go to zero
    new_normalizers = jnp.maximum(new_normalizers, 1e-8)

    # Normalized meta-update (Algorithm 6, line 9)
    # β_i -= θ*(1/η_i)*δ*[γφ_i(s')-φ_i(s)]*h_i
    normalized_gradient = delta * feature_diff * state.h_traces / new_normalizers
    new_log_step_sizes = state.log_step_sizes - theta * normalized_gradient

    # Effective step-size normalization (Algorithm 6, lines 10-11)
    # M = max(-exp(β_i)*[γφ_i(s')-φ_i(s)]^T*z_i, 1)
    # β_i -= log(M)
    effective_step_size = -jnp.sum(
        jnp.exp(new_log_step_sizes) * feature_diff * state.eligibility_traces
    )
    normalization_factor = jnp.maximum(effective_step_size, 1.0)
    new_log_step_sizes = new_log_step_sizes - jnp.log(normalization_factor)

    # Clip log step-sizes to prevent numerical issues
    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)

    # Get updated step-sizes
    new_alphas = jnp.exp(new_log_step_sizes)

    # Update eligibility traces: z_i = γ*λ*z_i + φ_i(s)
    new_eligibility_traces = gamma_scalar * lam * state.eligibility_traces + observation

    # Compute weight delta: α_i*δ*z_i
    weight_delta = new_alphas * delta * new_eligibility_traces

    # Update h traces (ordinary gradient variant, Algorithm 6 line 15)
    # h_i = h_i*[1 + α_i*[γφ_i(s')-φ_i(s)]*z_i]^+ + α_i*δ*z_i
    h_decay = jnp.maximum(0.0, 1.0 + new_alphas * feature_diff * new_eligibility_traces)
    new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces

    # Bias updates
    bias_alpha = jnp.exp(state.bias_log_step_size)
    bias_feature_diff = gamma_scalar - 1.0  # For bias, φ(s) = 1

    # Bias normalizer update
    abs_bias_weight_update = jnp.abs(delta * bias_feature_diff * state.bias_h_trace)
    bias_normalizer_decay_term = (
        (1.0 / tau)
        * bias_alpha
        * bias_feature_diff
        * state.bias_eligibility_trace
        * (jnp.abs(delta * state.bias_h_trace) - state.bias_normalizer)
    )
    new_bias_normalizer = jnp.maximum(
        abs_bias_weight_update, state.bias_normalizer - bias_normalizer_decay_term
    )
    new_bias_normalizer = jnp.maximum(new_bias_normalizer, 1e-8)

    # Normalized bias meta-update
    normalized_bias_gradient = (
        delta * bias_feature_diff * state.bias_h_trace / new_bias_normalizer
    )
    new_bias_log_step_size = state.bias_log_step_size - theta * normalized_bias_gradient

    # Effective step-size normalization for bias
    bias_effective_step_size = (
        -jnp.exp(new_bias_log_step_size) * bias_feature_diff * state.bias_eligibility_trace
    )
    bias_norm_factor = jnp.maximum(bias_effective_step_size, 1.0)
    new_bias_log_step_size = new_bias_log_step_size - jnp.log(bias_norm_factor)

    new_bias_log_step_size = jnp.clip(new_bias_log_step_size, -10.0, 2.0)
    new_bias_alpha = jnp.exp(new_bias_log_step_size)

    # Update bias eligibility trace
    new_bias_eligibility_trace = gamma_scalar * lam * state.bias_eligibility_trace + 1.0

    # Bias weight delta
    bias_delta = new_bias_alpha * delta * new_bias_eligibility_trace

    # Bias h trace update
    bias_h_decay = jnp.maximum(
        0.0, 1.0 + new_bias_alpha * bias_feature_diff * new_bias_eligibility_trace
    )
    new_bias_h_trace = (
        state.bias_h_trace * bias_h_decay + new_bias_alpha * delta * new_bias_eligibility_trace
    )

    new_state = AutoTDIDBDState(
        log_step_sizes=new_log_step_sizes,
        eligibility_traces=new_eligibility_traces,
        h_traces=new_h_traces,
        normalizers=new_normalizers,
        meta_step_size=theta,
        trace_decay=lam,
        normalizer_decay=tau,
        bias_log_step_size=new_bias_log_step_size,
        bias_eligibility_trace=new_bias_eligibility_trace,
        bias_h_trace=new_bias_h_trace,
        bias_normalizer=new_bias_normalizer,
    )

    return TDOptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(new_alphas),
            "min_step_size": jnp.min(new_alphas),
            "max_step_size": jnp.max(new_alphas),
            "mean_eligibility_trace": jnp.mean(jnp.abs(new_eligibility_traces)),
            "mean_normalizer": jnp.mean(new_normalizers),
        },
    )

Optimizer

Bases: ABC

Base class for optimizers.

init(feature_dim) abstractmethod

Initialize optimizer state.

Args: feature_dim: Dimension of weight vector

Returns: Initial optimizer state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def init(self, feature_dim: int) -> StateT:
    """Initialize optimizer state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        Initial optimizer state
    """
    ...

update(state, error, observation) abstractmethod

Compute weight updates given prediction error.

Args:
    state: Current optimizer state
    error: Prediction error (target - prediction)
    observation: Current observation/feature vector

Returns: OptimizerUpdate with deltas and new state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def update(
    self,
    state: StateT,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute weight updates given prediction error.

    Args:
        state: Current optimizer state
        error: Prediction error (target - prediction)
        observation: Current observation/feature vector

    Returns:
        OptimizerUpdate with deltas and new state
    """
    ...
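
A minimal sketch of a custom optimizer written against this interface: a normalized-LMS rule that reuses LMSState as its state and returns an OptimizerUpdate. The import path, the reuse of LMSState, and the NormalizedLMS class itself are illustrative assumptions, not part of the documented API:

import jax.numpy as jnp
from alberta_framework.core.optimizers import LMSState, Optimizer, OptimizerUpdate

class NormalizedLMS(Optimizer[LMSState]):
    """Hypothetical NLMS optimizer: LMS with the update scaled by the feature norm."""

    def __init__(self, step_size: float = 0.1):
        self._step_size = step_size

    def init(self, feature_dim: int) -> LMSState:
        return LMSState(step_size=jnp.array(self._step_size, dtype=jnp.float32))

    def update(self, state: LMSState, error, observation) -> OptimizerUpdate:
        error_scalar = jnp.squeeze(error)
        # Scale the LMS update by the squared feature norm for robustness to feature scale
        scale = state.step_size / (jnp.dot(observation, observation) + 1.0)
        return OptimizerUpdate(
            weight_delta=scale * error_scalar * observation,
            bias_delta=scale * error_scalar,
            new_state=state,  # fixed step-size, so the state is unchanged
            metrics={"step_size": state.step_size},
        )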

TDOptimizer

Bases: ABC

Base class for TD optimizers.

TD optimizers handle temporal-difference learning with eligibility traces. They take TD error and both current and next observations as input.

init(feature_dim) abstractmethod

Initialize optimizer state.

Args: feature_dim: Dimension of weight vector

Returns: Initial optimizer state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def init(self, feature_dim: int) -> StateT:
    """Initialize optimizer state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        Initial optimizer state
    """
    ...

update(state, td_error, observation, next_observation, gamma) abstractmethod

Compute weight updates given TD error.

Args: state: Current optimizer state td_error: TD error δ = R + γV(s') - V(s) observation: Current observation φ(s) next_observation: Next observation φ(s') gamma: Discount factor γ (0 at terminal)

Returns: TDOptimizerUpdate with deltas and new state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def update(
    self,
    state: StateT,
    td_error: Array,
    observation: Array,
    next_observation: Array,
    gamma: Array,
) -> TDOptimizerUpdate:
    """Compute weight updates given TD error.

    Args:
        state: Current optimizer state
        td_error: TD error δ = R + γV(s') - V(s)
        observation: Current observation φ(s)
        next_observation: Next observation φ(s')
        gamma: Discount factor γ (0 at terminal)

    Returns:
        TDOptimizerUpdate with deltas and new state
    """
    ...
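
A concrete TD optimizer is typically driven by a learner that forms the TD error from its current value estimates and then applies the returned deltas. The helper below is purely illustrative (not part of the API); td_opt stands for any concrete TDOptimizer, such as TDIDBD or AutoTDIDBD, whose constructors are not shown on this page.

import jax.numpy as jnp


def td_step(td_opt, opt_state, weights, bias, phi_s, reward, phi_s_next, gamma):
    """Illustrative single TD transition using any TDOptimizer."""
    v_s = jnp.dot(weights, phi_s) + bias             # V(s)
    v_s_next = jnp.dot(weights, phi_s_next) + bias   # V(s')
    td_error = reward + gamma * v_s_next - v_s       # delta = R + gamma*V(s') - V(s)

    upd = td_opt.update(opt_state, td_error, phi_s, phi_s_next, gamma)
    return weights + upd.weight_delta, bias + upd.bias_delta, upd.new_state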

TDOptimizerUpdate

Result of a TD optimizer update step.

Attributes: weight_delta: Change to apply to weights bias_delta: Change to apply to bias new_state: Updated optimizer state metrics: Dictionary of metrics for logging

AutostepState

State for the Autostep optimizer.

Autostep is a tuning-free step-size adaptation algorithm that normalizes gradients to prevent large updates and adapts step-sizes based on gradient correlation.

Reference: Mahmood et al. 2012, "Tuning-free step-size adaptation"

Attributes: step_sizes: Per-weight step-sizes (alpha_i) traces: Per-weight traces for gradient correlation (h_i) normalizers: Running max absolute gradient per weight (v_i) meta_step_size: Meta learning rate mu for adapting step-sizes normalizer_decay: Decay factor for the normalizer (tau) bias_step_size: Step-size for the bias term bias_trace: Trace for the bias term bias_normalizer: Normalizer for the bias gradient

AutoTDIDBDState

State for the AutoTDIDBD optimizer.

AutoTDIDBD adds AutoStep-style normalization to TDIDBD for improved stability. Includes normalizers for the meta-weight updates and effective step-size normalization to prevent overshooting.

Reference: Kearney et al. 2019, Algorithm 6

Attributes: log_step_sizes: Log of per-weight step-sizes (log alpha_i) eligibility_traces: Eligibility traces z_i h_traces: Per-weight h traces for gradient correlation normalizers: Running max of absolute gradient correlations (eta_i) meta_step_size: Meta learning rate theta trace_decay: Eligibility trace decay parameter lambda normalizer_decay: Decay parameter tau for normalizers bias_log_step_size: Log step-size for the bias term bias_eligibility_trace: Eligibility trace for the bias bias_h_trace: h trace for the bias term bias_normalizer: Normalizer for the bias gradient correlation

BatchedLearningResult

Result from batched learning loop across multiple seeds.

Used with run_learning_loop_batched for vmap-based GPU parallelization.

Attributes: states: Batched learner states - each array has shape (num_seeds, ...) metrics: Metrics array with shape (num_seeds, num_steps, 3) where columns are [squared_error, error, mean_step_size] step_size_history: Optional step-size history with batched shapes, or None if tracking was disabled
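
For example, a multi-seed run can be reduced to a single learning curve by averaging the squared error over the seed axis. The sketch below assumes run_learning_loop_batched (documented at the end of this page) returns the BatchedLearningResult described here and is exported at the package root alongside run_learning_loop.

import jax.numpy as jnp
import jax.random as jr

from alberta_framework import IDBD, LinearLearner, RandomWalkStream, run_learning_loop_batched

stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
learner = LinearLearner(optimizer=IDBD())

keys = jr.split(jr.key(0), 8)  # one independent key per seed
result = run_learning_loop_batched(learner, stream, num_steps=5_000, keys=keys)

# result.metrics has shape (num_seeds, num_steps, 3); column 0 is squared_error.
learning_curve = jnp.mean(result.metrics[:, :, 0], axis=0)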

BatchedNormalizedResult

Result from batched normalized learning loop across multiple seeds.

Used with run_normalized_learning_loop_batched for vmap-based GPU parallelization.

Attributes: states: Batched normalized learner states - each array has shape (num_seeds, ...) metrics: Metrics array with shape (num_seeds, num_steps, 4) where columns are [squared_error, error, mean_step_size, normalizer_mean_var] step_size_history: Optional step-size history with batched shapes, or None if tracking was disabled normalizer_history: Optional normalizer history with batched shapes, or None if tracking was disabled

IDBDState

State for the IDBD (Incremental Delta-Bar-Delta) optimizer.

IDBD maintains per-weight adaptive step-sizes that are meta-learned based on the correlation of successive gradients.

Reference: Sutton 1992, "Adapting Bias by Gradient Descent"

Attributes: log_step_sizes: Log of per-weight step-sizes (log alpha_i) traces: Per-weight traces h_i for gradient correlation meta_step_size: Meta learning rate beta for adapting step-sizes bias_step_size: Step-size for the bias term bias_trace: Trace for the bias term
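
For reference, the core IDBD update that these fields support can be sketched as follows (Sutton 1992). This is an illustrative restatement of the rule using the attribute names above, not the framework's exact implementation.

import jax.numpy as jnp


def idbd_update_sketch(log_step_sizes, traces, weights, x, error, meta_step_size):
    """One IDBD step: adapt log step-sizes, update weights, refresh traces."""
    # Meta step on the log step-sizes: correlated successive gradients grow alpha_i.
    log_step_sizes = log_step_sizes + meta_step_size * error * x * traces
    alphas = jnp.exp(log_step_sizes)

    # Ordinary per-weight LMS step with the adapted step-sizes.
    weights = weights + alphas * error * x

    # Decaying trace of recent weight updates (decay clipped at zero).
    decay = jnp.maximum(0.0, 1.0 - alphas * x * x)
    traces = traces * decay + alphas * error * x

    return log_step_sizes, traces, weights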

LearnerState

State for a linear learner.

Attributes: weights: Weight vector for linear prediction bias: Bias term optimizer_state: State maintained by the optimizer

LMSState

State for the LMS (Least Mean Square) optimizer.

LMS uses a fixed step-size, so state only tracks the step-size parameter.

Attributes: step_size: Fixed learning rate alpha

NormalizerHistory

History of per-feature normalizer state recorded during training.

Used for analyzing how the OnlineNormalizer adapts to distribution shifts (reactive lag diagnostic).

Attributes: means: Per-feature mean estimates at each recording, shape (num_recordings, feature_dim) variances: Per-feature variance estimates at each recording, shape (num_recordings, feature_dim) recording_indices: Step indices where recordings were made, shape (num_recordings,)

NormalizerTrackingConfig

Configuration for recording per-feature normalizer state during training.

Attributes: interval: Record normalizer state every N steps

StepSizeHistory

History of per-weight step-sizes recorded during training.

Attributes: step_sizes: Per-weight step-sizes at each recording, shape (num_recordings, num_weights) bias_step_sizes: Bias step-sizes at each recording, shape (num_recordings,) or None recording_indices: Step indices where recordings were made, shape (num_recordings,) normalizers: Autostep's per-weight normalizers (v_i) at each recording, shape (num_recordings, num_weights) or None. Only populated for Autostep optimizer.

StepSizeTrackingConfig

Configuration for recording per-weight step-sizes during training.

Attributes: interval: Record step-sizes every N steps include_bias: Whether to also record the bias step-size

TDIDBDState

State for the TD-IDBD (Temporal-Difference IDBD) optimizer.

TD-IDBD extends IDBD to temporal-difference learning with eligibility traces. Maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation in the TD setting.

Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"

Attributes: log_step_sizes: Log of per-weight step-sizes (log alpha_i) eligibility_traces: Eligibility traces z_i for temporal credit assignment h_traces: Per-weight h traces for gradient correlation meta_step_size: Meta learning rate theta for adapting step-sizes trace_decay: Eligibility trace decay parameter lambda bias_log_step_size: Log step-size for the bias term bias_eligibility_trace: Eligibility trace for the bias bias_h_trace: h trace for the bias term

TDLearnerState

State for a TD linear learner.

Attributes: weights: Weight vector for linear value function approximation bias: Bias term optimizer_state: State maintained by the TD optimizer

TDTimeStep

Single experience from a TD stream.

Represents a transition (s, r, s', gamma) for temporal-difference learning.

Attributes: observation: Feature vector φ(s) reward: Reward R received next_observation: Feature vector φ(s') gamma: Discount factor γ_t (0 at terminal states)

TimeStep

Single experience from an experience stream.

Attributes: observation: Feature vector x_t target: Desired output y*_t (for supervised learning)

ScanStream

Bases: Protocol[StateT]

Protocol for JAX scan-compatible experience streams.

Streams generate temporally-uniform experience for continual learning. Unlike iterator-based streams, ScanStream uses pure functions that can be compiled with JAX's JIT and used with jax.lax.scan.

The stream should be non-stationary to test continual learning capabilities - the underlying target function changes over time.

Type Parameters: StateT: The state type maintained by this stream

Examples:

stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
key = jax.random.key(42)
state = stream.init(key)
timestep, new_state = stream.step(state, jnp.array(0))

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key for initialization

Returns: Initial stream state

Source code in src/alberta_framework/streams/base.py
def init(self, key: Array) -> StateT:
    """Initialize stream state.

    Args:
        key: JAX random key for initialization

    Returns:
        Initial stream state
    """
    ...

step(state, idx)

Generate one time step. Must be JIT-compatible.

This is a pure function that takes the current state and step index, and returns a TimeStep along with the updated state. The step index can be used for time-dependent behavior but is often ignored.

Args: state: Current stream state idx: Current step index (can be ignored for most streams)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/base.py
def step(self, state: StateT, idx: Array) -> tuple[TimeStep, StateT]:
    """Generate one time step. Must be JIT-compatible.

    This is a pure function that takes the current state and step index,
    and returns a TimeStep along with the updated state. The step index
    can be used for time-dependent behavior but is often ignored.

    Args:
        state: Current stream state
        idx: Current step index (can be ignored for most streams)

    Returns:
        Tuple of (timestep, new_state)
    """
    ...
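
Any object with this shape can be passed to the learning loops. Below is a minimal sketch of a stationary linear-target stream satisfying the protocol; it assumes TimeStep can be imported from the package root, which is not confirmed on this page.

from typing import NamedTuple

import jax.numpy as jnp
import jax.random as jr
from jax import Array

from alberta_framework import TimeStep  # assumed export


class ConstantTargetState(NamedTuple):
    key: Array
    true_weights: Array


class ConstantTargetStream:
    """Stationary stream y* = w @ x + noise; mostly useful as a sanity check."""

    def __init__(self, feature_dim: int, noise_std: float = 0.1):
        self._feature_dim = feature_dim
        self._noise_std = noise_std

    @property
    def feature_dim(self) -> int:
        return self._feature_dim

    def init(self, key: Array) -> ConstantTargetState:
        key, subkey = jr.split(key)
        weights = jr.normal(subkey, (self._feature_dim,), dtype=jnp.float32)
        return ConstantTargetState(key=key, true_weights=weights)

    def step(self, state: ConstantTargetState, idx: Array) -> tuple[TimeStep, ConstantTargetState]:
        del idx  # unused
        key, k_x, k_noise = jr.split(state.key, 3)
        x = jr.normal(k_x, (self._feature_dim,), dtype=jnp.float32)
        noise = self._noise_std * jr.normal(k_noise, (), dtype=jnp.float32)
        target = jnp.dot(state.true_weights, x) + noise
        timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
        return timestep, ConstantTargetState(key=key, true_weights=state.true_weights)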

AbruptChangeState

State for AbruptChangeStream.

Attributes: key: JAX random key for generating randomness true_weights: Current true target weights step_count: Number of steps taken

AbruptChangeStream(feature_dim, change_interval=1000, noise_std=0.1, feature_std=1.0)

Non-stationary stream with sudden target weight changes.

Target weights remain constant for a period, then abruptly change to new random values. Tests the learner's ability to detect and rapidly adapt to distribution shifts.

Attributes: feature_dim: Dimension of observation vectors change_interval: Number of steps between weight changes noise_std: Standard deviation of observation noise feature_std: Standard deviation of features

Args: feature_dim: Dimension of feature vectors change_interval: Steps between abrupt weight changes noise_std: Std dev of target noise feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    change_interval: int = 1000,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the abrupt change stream.

    Args:
        feature_dim: Dimension of feature vectors
        change_interval: Steps between abrupt weight changes
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._change_interval = change_interval
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> AbruptChangeState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state
    """
    key, subkey = jr.split(key)
    weights = jr.normal(subkey, (self._feature_dim,), dtype=jnp.float32)
    return AbruptChangeState(
        key=key,
        true_weights=weights,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: AbruptChangeState, idx: Array) -> tuple[TimeStep, AbruptChangeState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_weights, key_x, key_noise = jr.split(state.key, 4)

    # Determine if we should change weights
    should_change = state.step_count % self._change_interval == 0

    # Generate new weights (always generated but only used if should_change)
    new_random_weights = jr.normal(key_weights, (self._feature_dim,), dtype=jnp.float32)

    # Use jnp.where to conditionally update weights (JIT-compatible)
    new_weights = jnp.where(should_change, new_random_weights, state.true_weights)

    # Generate observation
    x = self._feature_std * jr.normal(key_x, (self._feature_dim,), dtype=jnp.float32)

    # Compute target
    noise = self._noise_std * jr.normal(key_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = AbruptChangeState(
        key=key,
        true_weights=new_weights,
        step_count=state.step_count + 1,
    )

    return timestep, new_state
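
A typical use, assuming AbruptChangeStream is exported at the package root like RandomWalkStream:

import jax.random as jr

from alberta_framework import AbruptChangeStream, IDBD, LinearLearner, run_learning_loop

stream = AbruptChangeStream(feature_dim=10, change_interval=1000)
learner = LinearLearner(optimizer=IDBD())
state, metrics = run_learning_loop(learner, stream, num_steps=10_000, key=jr.key(0))

# Squared error (column 0) is expected to spike every 1000 steps and then
# decay as the step-sizes re-adapt to the new target weights.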

CyclicState

State for CyclicStream.

Attributes: key: JAX random key for generating randomness configurations: Pre-generated weight configurations step_count: Number of steps taken

CyclicStream(feature_dim, cycle_length=500, num_configurations=4, noise_std=0.1, feature_std=1.0)

Non-stationary stream that cycles between known weight configurations.

Weights cycle through a fixed set of configurations. Tests whether the learner can re-adapt quickly to previously seen targets.

Attributes: feature_dim: Dimension of observation vectors cycle_length: Number of steps per configuration before switching num_configurations: Number of weight configurations to cycle through noise_std: Standard deviation of observation noise feature_std: Standard deviation of features

Args: feature_dim: Dimension of feature vectors cycle_length: Steps spent in each configuration num_configurations: Number of configurations to cycle through noise_std: Std dev of target noise feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    cycle_length: int = 500,
    num_configurations: int = 4,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the cyclic target stream.

    Args:
        feature_dim: Dimension of feature vectors
        cycle_length: Steps spent in each configuration
        num_configurations: Number of configurations to cycle through
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._cycle_length = cycle_length
    self._num_configurations = num_configurations
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state with pre-generated configurations

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> CyclicState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with pre-generated configurations
    """
    key, key_configs = jr.split(key)
    configurations = jr.normal(
        key_configs,
        (self._num_configurations, self._feature_dim),
        dtype=jnp.float32,
    )
    return CyclicState(
        key=key,
        configurations=configurations,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: CyclicState, idx: Array) -> tuple[TimeStep, CyclicState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_x, key_noise = jr.split(state.key, 3)

    # Get current configuration index
    config_idx = (state.step_count // self._cycle_length) % self._num_configurations
    true_weights = state.configurations[config_idx]

    # Generate observation
    x = self._feature_std * jr.normal(key_x, (self._feature_dim,), dtype=jnp.float32)

    # Compute target
    noise = self._noise_std * jr.normal(key_noise, (), dtype=jnp.float32)
    target = jnp.dot(true_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = CyclicState(
        key=key,
        configurations=state.configurations,
        step_count=state.step_count + 1,
    )

    return timestep, new_state

DynamicScaleShiftState

State for DynamicScaleShiftStream.

Attributes: key: JAX random key for generating randomness true_weights: Current true target weights current_scales: Current per-feature scaling factors step_count: Number of steps taken

DynamicScaleShiftStream(feature_dim, scale_change_interval=2000, weight_change_interval=1000, min_scale=0.01, max_scale=100.0, noise_std=0.1)

Non-stationary stream with abruptly changing feature scales.

Both target weights AND feature scales change at specified intervals. This tests whether OnlineNormalizer can track scale shifts faster than Autostep's internal v_i adaptation.

The target is computed from unscaled features to maintain consistent difficulty across scale changes (only the feature representation changes, not the underlying prediction task).

Attributes: feature_dim: Dimension of observation vectors scale_change_interval: Steps between scale changes weight_change_interval: Steps between weight changes min_scale: Minimum scale factor max_scale: Maximum scale factor noise_std: Standard deviation of observation noise

Args: feature_dim: Dimension of feature vectors scale_change_interval: Steps between abrupt scale changes weight_change_interval: Steps between abrupt weight changes min_scale: Minimum scale factor (log-uniform sampling) max_scale: Maximum scale factor (log-uniform sampling) noise_std: Std dev of target noise

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    scale_change_interval: int = 2000,
    weight_change_interval: int = 1000,
    min_scale: float = 0.01,
    max_scale: float = 100.0,
    noise_std: float = 0.1,
):
    """Initialize the dynamic scale shift stream.

    Args:
        feature_dim: Dimension of feature vectors
        scale_change_interval: Steps between abrupt scale changes
        weight_change_interval: Steps between abrupt weight changes
        min_scale: Minimum scale factor (log-uniform sampling)
        max_scale: Maximum scale factor (log-uniform sampling)
        noise_std: Std dev of target noise
    """
    self._feature_dim = feature_dim
    self._scale_change_interval = scale_change_interval
    self._weight_change_interval = weight_change_interval
    self._min_scale = min_scale
    self._max_scale = max_scale
    self._noise_std = noise_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state with random weights and scales

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> DynamicScaleShiftState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random weights and scales
    """
    key, k_weights, k_scales = jr.split(key, 3)
    weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
    # Initial scales: log-uniform between min and max
    log_scales = jr.uniform(
        k_scales,
        (self._feature_dim,),
        minval=jnp.log(self._min_scale),
        maxval=jnp.log(self._max_scale),
    )
    scales = jnp.exp(log_scales).astype(jnp.float32)
    return DynamicScaleShiftState(
        key=key,
        true_weights=weights,
        current_scales=scales,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(
    self, state: DynamicScaleShiftState, idx: Array
) -> tuple[TimeStep, DynamicScaleShiftState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, k_weights, k_scales, k_x, k_noise = jr.split(state.key, 5)

    # Check if scales should change
    should_change_scales = state.step_count % self._scale_change_interval == 0
    new_log_scales = jr.uniform(
        k_scales,
        (self._feature_dim,),
        minval=jnp.log(self._min_scale),
        maxval=jnp.log(self._max_scale),
    )
    new_random_scales = jnp.exp(new_log_scales).astype(jnp.float32)
    new_scales = jnp.where(should_change_scales, new_random_scales, state.current_scales)

    # Check if weights should change
    should_change_weights = state.step_count % self._weight_change_interval == 0
    new_random_weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
    new_weights = jnp.where(should_change_weights, new_random_weights, state.true_weights)

    # Generate raw features (unscaled)
    raw_x = jr.normal(k_x, (self._feature_dim,), dtype=jnp.float32)

    # Apply scaling to observation
    x = raw_x * new_scales

    # Target from true weights using RAW features (for consistent difficulty)
    noise = self._noise_std * jr.normal(k_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, raw_x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = DynamicScaleShiftState(
        key=key,
        true_weights=new_weights,
        current_scales=new_scales,
        step_count=state.step_count + 1,
    )
    return timestep, new_state
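
A sketch of the intended comparison, assuming DynamicScaleShiftStream is exported at the package root and that an Autostep optimizer class with a no-argument constructor exists (only its state type is documented on this page):

import jax.random as jr

from alberta_framework import Autostep, DynamicScaleShiftStream, IDBD, LinearLearner, run_learning_loop

stream = DynamicScaleShiftStream(feature_dim=10, scale_change_interval=2000)

for name, optimizer in [("IDBD", IDBD()), ("Autostep", Autostep())]:
    learner = LinearLearner(optimizer=optimizer)
    state, metrics = run_learning_loop(learner, stream, num_steps=20_000, key=jr.key(0))
    print(name, float(metrics[:, 0].mean()))  # mean squared error over the whole run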

PeriodicChangeState

State for PeriodicChangeStream.

Attributes: key: JAX random key for generating randomness base_weights: Base target weights (center of oscillation) phases: Per-weight phase offsets step_count: Number of steps taken

PeriodicChangeStream(feature_dim, period=1000, amplitude=1.0, noise_std=0.1, feature_std=1.0)

Non-stationary stream where target weights oscillate sinusoidally.

Target weights follow: w(t) = base + amplitude * sin(2π * t / period + phase) where each weight has a random phase offset for diversity.

This tests the learner's ability to track predictable periodic changes, which is qualitatively different from random drift or abrupt changes.

Attributes: feature_dim: Dimension of observation vectors period: Number of steps for one complete oscillation amplitude: Magnitude of weight oscillation noise_std: Standard deviation of observation noise feature_std: Standard deviation of features

Args: feature_dim: Dimension of feature vectors period: Steps for one complete oscillation cycle amplitude: Magnitude of weight oscillations around base noise_std: Std dev of target noise feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    period: int = 1000,
    amplitude: float = 1.0,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the periodic change stream.

    Args:
        feature_dim: Dimension of feature vectors
        period: Steps for one complete oscillation cycle
        amplitude: Magnitude of weight oscillations around base
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._period = period
    self._amplitude = amplitude
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state with random base weights and phases

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> PeriodicChangeState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random base weights and phases
    """
    key, key_weights, key_phases = jr.split(key, 3)
    base_weights = jr.normal(key_weights, (self._feature_dim,), dtype=jnp.float32)
    # Random phases in [0, 2π) for each weight
    phases = jr.uniform(key_phases, (self._feature_dim,), minval=0.0, maxval=2.0 * jnp.pi)
    return PeriodicChangeState(
        key=key,
        base_weights=base_weights,
        phases=phases,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: PeriodicChangeState, idx: Array) -> tuple[TimeStep, PeriodicChangeState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_x, key_noise = jr.split(state.key, 3)

    # Compute oscillating weights: w(t) = base + amplitude * sin(2π * t / period + phase)
    t = state.step_count.astype(jnp.float32)
    oscillation = self._amplitude * jnp.sin(2.0 * jnp.pi * t / self._period + state.phases)
    true_weights = state.base_weights + oscillation

    # Generate observation
    x = self._feature_std * jr.normal(key_x, (self._feature_dim,), dtype=jnp.float32)

    # Compute target
    noise = self._noise_std * jr.normal(key_noise, (), dtype=jnp.float32)
    target = jnp.dot(true_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = PeriodicChangeState(
        key=key,
        base_weights=state.base_weights,
        phases=state.phases,
        step_count=state.step_count + 1,
    )

    return timestep, new_state

RandomWalkState

State for RandomWalkStream.

Attributes: key: JAX random key for generating randomness true_weights: Current true target weights

RandomWalkStream(feature_dim, drift_rate=0.001, noise_std=0.1, feature_std=1.0)

Non-stationary stream where target weights drift via random walk.

The true target function is linear: y* = w_true @ x + noise where w_true evolves via random walk at each time step.

This tests the learner's ability to continuously track a moving target.

Attributes: feature_dim: Dimension of observation vectors drift_rate: Standard deviation of weight drift per step noise_std: Standard deviation of observation noise feature_std: Standard deviation of features

Args: feature_dim: Dimension of the feature/observation vectors drift_rate: Std dev of weight changes per step (controls non-stationarity) noise_std: Std dev of target noise feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    drift_rate: float = 0.001,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the random walk target stream.

    Args:
        feature_dim: Dimension of the feature/observation vectors
        drift_rate: Std dev of weight changes per step (controls non-stationarity)
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._drift_rate = drift_rate
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state with random weights

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> RandomWalkState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random weights
    """
    key, subkey = jr.split(key)
    weights = jr.normal(subkey, (self._feature_dim,), dtype=jnp.float32)
    return RandomWalkState(key=key, true_weights=weights)

step(state, idx)

Generate one time step.

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: RandomWalkState, idx: Array) -> tuple[TimeStep, RandomWalkState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, k_drift, k_x, k_noise = jr.split(state.key, 4)

    # Drift weights
    drift = jr.normal(k_drift, state.true_weights.shape, dtype=jnp.float32)
    new_weights = state.true_weights + self._drift_rate * drift

    # Generate observation and target
    x = self._feature_std * jr.normal(k_x, (self._feature_dim,), dtype=jnp.float32)
    noise = self._noise_std * jr.normal(k_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = RandomWalkState(key=key, true_weights=new_weights)

    return timestep, new_state

ScaleDriftState

State for ScaleDriftStream.

Attributes: key: JAX random key for generating randomness true_weights: Current true target weights log_scales: Current log-scale factors (random walk on log-scale) step_count: Number of steps taken

ScaleDriftStream(feature_dim, weight_drift_rate=0.001, scale_drift_rate=0.01, min_log_scale=-4.0, max_log_scale=4.0, noise_std=0.1)

Non-stationary stream where feature scales drift via random walk.

Both target weights and feature scales drift continuously. Weights drift in linear space while scales drift in log-space (bounded random walk). This tests continuous scale tracking where OnlineNormalizer's EMA may adapt differently than Autostep's v_i.

The target is computed from unscaled features to maintain consistent difficulty across scale changes.

Attributes: feature_dim: Dimension of observation vectors weight_drift_rate: Std dev of weight drift per step scale_drift_rate: Std dev of log-scale drift per step min_log_scale: Minimum log-scale (clips random walk) max_log_scale: Maximum log-scale (clips random walk) noise_std: Standard deviation of observation noise

Args: feature_dim: Dimension of feature vectors weight_drift_rate: Std dev of weight drift per step scale_drift_rate: Std dev of log-scale drift per step min_log_scale: Minimum log-scale (clips drift) max_log_scale: Maximum log-scale (clips drift) noise_std: Std dev of target noise

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    weight_drift_rate: float = 0.001,
    scale_drift_rate: float = 0.01,
    min_log_scale: float = -4.0,  # exp(-4) ~ 0.018
    max_log_scale: float = 4.0,  # exp(4) ~ 54.6
    noise_std: float = 0.1,
):
    """Initialize the scale drift stream.

    Args:
        feature_dim: Dimension of feature vectors
        weight_drift_rate: Std dev of weight drift per step
        scale_drift_rate: Std dev of log-scale drift per step
        min_log_scale: Minimum log-scale (clips drift)
        max_log_scale: Maximum log-scale (clips drift)
        noise_std: Std dev of target noise
    """
    self._feature_dim = feature_dim
    self._weight_drift_rate = weight_drift_rate
    self._scale_drift_rate = scale_drift_rate
    self._min_log_scale = min_log_scale
    self._max_log_scale = max_log_scale
    self._noise_std = noise_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state with random weights and unit scales

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> ScaleDriftState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random weights and unit scales
    """
    key, k_weights = jr.split(key)
    weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
    # Initial log-scales at 0 (scale = 1)
    log_scales = jnp.zeros(self._feature_dim, dtype=jnp.float32)
    return ScaleDriftState(
        key=key,
        true_weights=weights,
        log_scales=log_scales,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: ScaleDriftState, idx: Array) -> tuple[TimeStep, ScaleDriftState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, k_w_drift, k_s_drift, k_x, k_noise = jr.split(state.key, 5)

    # Drift target weights
    weight_drift = self._weight_drift_rate * jr.normal(
        k_w_drift, (self._feature_dim,), dtype=jnp.float32
    )
    new_weights = state.true_weights + weight_drift

    # Drift log-scales (bounded random walk)
    scale_drift = self._scale_drift_rate * jr.normal(
        k_s_drift, (self._feature_dim,), dtype=jnp.float32
    )
    new_log_scales = state.log_scales + scale_drift
    new_log_scales = jnp.clip(new_log_scales, self._min_log_scale, self._max_log_scale)

    # Generate raw features (unscaled)
    raw_x = jr.normal(k_x, (self._feature_dim,), dtype=jnp.float32)

    # Apply scaling to observation
    scales = jnp.exp(new_log_scales)
    x = raw_x * scales

    # Target from true weights using RAW features
    noise = self._noise_std * jr.normal(k_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, raw_x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = ScaleDriftState(
        key=key,
        true_weights=new_weights,
        log_scales=new_log_scales,
        step_count=state.step_count + 1,
    )
    return timestep, new_state

ScaledStreamState

State for ScaledStreamWrapper.

Attributes: inner_state: State of the wrapped stream

ScaledStreamWrapper(inner_stream, feature_scales)

Wrapper that applies per-feature scaling to any stream's observations.

This wrapper multiplies each feature of the observation by a corresponding scale factor. Useful for testing how learners handle features at different scales, which is important for understanding normalization benefits.

Examples:

stream = ScaledStreamWrapper(
    AbruptChangeStream(feature_dim=10, change_interval=1000),
    feature_scales=jnp.array([0.001, 0.01, 0.1, 1.0, 10.0,
                              100.0, 1000.0, 0.001, 0.01, 0.1])
)

Attributes: inner_stream: The wrapped stream instance feature_scales: Per-feature scale factors (must match feature_dim)

Args: inner_stream: Stream to wrap (must implement ScanStream protocol) feature_scales: Array of scale factors, one per feature. Must have shape (feature_dim,) matching the inner stream's feature_dim.

Raises: ValueError: If feature_scales length doesn't match inner stream's feature_dim

Source code in src/alberta_framework/streams/synthetic.py
def __init__(self, inner_stream: ScanStream[Any], feature_scales: Array):
    """Initialize the scaled stream wrapper.

    Args:
        inner_stream: Stream to wrap (must implement ScanStream protocol)
        feature_scales: Array of scale factors, one per feature. Must have
            shape (feature_dim,) matching the inner stream's feature_dim.

    Raises:
        ValueError: If feature_scales length doesn't match inner stream's feature_dim
    """
    self._inner_stream: ScanStream[Any] = inner_stream
    self._feature_scales = jnp.asarray(feature_scales, dtype=jnp.float32)

    if self._feature_scales.shape[0] != inner_stream.feature_dim:
        raise ValueError(
            f"feature_scales length ({self._feature_scales.shape[0]}) "
            f"must match inner stream's feature_dim ({inner_stream.feature_dim})"
        )

feature_dim property

Return the dimension of observation vectors.

inner_stream property

Return the wrapped stream.

feature_scales property

Return the per-feature scale factors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state wrapping the inner stream's state

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> ScaledStreamState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state wrapping the inner stream's state
    """
    inner_state = self._inner_stream.init(key)
    return ScaledStreamState(inner_state=inner_state)

step(state, idx)

Generate one time step with scaled observations.

Args: state: Current stream state idx: Current step index

Returns: Tuple of (timestep with scaled observation, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: ScaledStreamState, idx: Array) -> tuple[TimeStep, ScaledStreamState]:
    """Generate one time step with scaled observations.

    Args:
        state: Current stream state
        idx: Current step index

    Returns:
        Tuple of (timestep with scaled observation, new_state)
    """
    timestep, new_inner_state = self._inner_stream.step(state.inner_state, idx)

    # Scale the observation
    scaled_observation = timestep.observation * self._feature_scales

    scaled_timestep = TimeStep(
        observation=scaled_observation,
        target=timestep.target,
    )

    new_state = ScaledStreamState(inner_state=new_inner_state)
    return scaled_timestep, new_state

SuttonExperiment1State

State for SuttonExperiment1Stream.

Attributes: key: JAX random key for generating randomness signs: Signs (+1/-1) for the relevant inputs step_count: Number of steps taken

SuttonExperiment1Stream(num_relevant=5, num_irrelevant=15, change_interval=20)

Non-stationary stream replicating Experiment 1 from Sutton 1992.

This stream implements the exact task from Sutton's IDBD paper:

- 20 real-valued inputs drawn from N(0, 1)
- Only first 5 inputs are relevant (weights are ±1)
- Last 15 inputs are irrelevant (weights are 0)
- Every change_interval steps, one of the 5 relevant signs is flipped

Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta"

Attributes: num_relevant: Number of relevant inputs (default 5) num_irrelevant: Number of irrelevant inputs (default 15) change_interval: Steps between sign changes (default 20)

Args: num_relevant: Number of relevant inputs with ±1 weights num_irrelevant: Number of irrelevant inputs with 0 weights change_interval: Number of steps between sign flips

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    num_relevant: int = 5,
    num_irrelevant: int = 15,
    change_interval: int = 20,
):
    """Initialize the Sutton Experiment 1 stream.

    Args:
        num_relevant: Number of relevant inputs with ±1 weights
        num_irrelevant: Number of irrelevant inputs with 0 weights
        change_interval: Number of steps between sign flips
    """
    self._num_relevant = num_relevant
    self._num_irrelevant = num_irrelevant
    self._change_interval = change_interval

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args: key: JAX random key

Returns: Initial stream state with all +1 signs

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> SuttonExperiment1State:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with all +1 signs
    """
    signs = jnp.ones(self._num_relevant, dtype=jnp.float32)
    return SuttonExperiment1State(
        key=key,
        signs=signs,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

At each step:

1. If at a change interval (and not step 0), flip one random sign
2. Generate random inputs from N(0, 1)
3. Compute target as sum of relevant inputs weighted by signs

Args: state: Current stream state idx: Current step index (unused)

Returns: Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(
    self, state: SuttonExperiment1State, idx: Array
) -> tuple[TimeStep, SuttonExperiment1State]:
    """Generate one time step.

    At each step:
    1. If at a change interval (and not step 0), flip one random sign
    2. Generate random inputs from N(0, 1)
    3. Compute target as sum of relevant inputs weighted by signs

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_x, key_which = jr.split(state.key, 3)

    # Determine if we should flip a sign (not at step 0)
    should_flip = (state.step_count > 0) & (state.step_count % self._change_interval == 0)

    # Select which sign to flip
    idx_to_flip = jr.randint(key_which, (), 0, self._num_relevant)

    # Create flip mask
    flip_mask = jnp.where(
        jnp.arange(self._num_relevant) == idx_to_flip,
        jnp.array(-1.0, dtype=jnp.float32),
        jnp.array(1.0, dtype=jnp.float32),
    )

    # Apply flip mask conditionally
    new_signs = jnp.where(should_flip, state.signs * flip_mask, state.signs)

    # Generate observation from N(0, 1)
    x = jr.normal(key_x, (self.feature_dim,), dtype=jnp.float32)

    # Compute target: sum of first num_relevant inputs weighted by signs
    target = jnp.dot(new_signs, x[: self._num_relevant])

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = SuttonExperiment1State(
        key=key,
        signs=new_signs,
        step_count=state.step_count + 1,
    )

    return timestep, new_state
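
To reproduce the flavour of the original comparison, run the same stream under a fixed step-size and under IDBD. This assumes SuttonExperiment1Stream and LMS are exported at the package root.

import jax.random as jr

from alberta_framework import IDBD, LMS, LinearLearner, SuttonExperiment1Stream, run_learning_loop

stream = SuttonExperiment1Stream()  # 5 relevant and 15 irrelevant inputs

for name, optimizer in [("LMS", LMS(step_size=0.01)), ("IDBD", IDBD())]:
    learner = LinearLearner(optimizer=optimizer)
    state, metrics = run_learning_loop(learner, stream, num_steps=20_000, key=jr.key(1))
    # Compare asymptotic tracking error over the last 1000 steps.
    print(name, float(metrics[-1000:, 0].mean()))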

Timer(name='Operation', verbose=True, print_fn=None)

Context manager for timing code execution.

Measures wall-clock time for a block of code and optionally prints the duration when the block completes.

Attributes: name: Description of what is being timed duration: Elapsed time in seconds (available after context exits) start_time: Timestamp when timing started end_time: Timestamp when timing ended

Examples:

with Timer("Training loop"):
    for i in range(1000):
        pass
# Output: Training loop completed in 0.01s

# Silent timing (no print):
with Timer("Silent", verbose=False) as t:
    time.sleep(0.1)
print(f"Elapsed: {t.duration:.2f}s")
# Output: Elapsed: 0.10s

# Custom print function:
with Timer("Custom", print_fn=lambda msg: print(f">> {msg}")):
    pass
# Output: >> Custom completed in 0.00s

Args: name: Description of the operation being timed verbose: Whether to print the duration when done print_fn: Custom print function (defaults to built-in print)

Source code in src/alberta_framework/utils/timing.py
def __init__(
    self,
    name: str = "Operation",
    verbose: bool = True,
    print_fn: Callable[[str], None] | None = None,
):
    """Initialize the timer.

    Args:
        name: Description of the operation being timed
        verbose: Whether to print the duration when done
        print_fn: Custom print function (defaults to built-in print)
    """
    self.name = name
    self.verbose = verbose
    self.print_fn = print_fn or print
    self.start_time: float = 0.0
    self.end_time: float = 0.0
    self.duration: float = 0.0

elapsed()

Get elapsed time since timer started (can be called during execution).

Returns: Elapsed time in seconds

Source code in src/alberta_framework/utils/timing.py
def elapsed(self) -> float:
    """Get elapsed time since timer started (can be called during execution).

    Returns:
        Elapsed time in seconds
    """
    return time.perf_counter() - self.start_time
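
When timing JAX code, remember that dispatch is asynchronous; calling jax.block_until_ready on the result makes the measured duration include the actual device computation. A small sketch using the classes documented on this page (the Timer import path is assumed from "Source code in src/alberta_framework/utils/timing.py"):

import jax
import jax.random as jr

from alberta_framework import IDBD, LinearLearner, RandomWalkStream, run_learning_loop
from alberta_framework.utils.timing import Timer  # assumed import path

stream = RandomWalkStream(feature_dim=10)
learner = LinearLearner(optimizer=IDBD())

with Timer("10k learning steps"):
    state, metrics = run_learning_loop(learner, stream, num_steps=10_000, key=jr.key(0))
    jax.block_until_ready(metrics)  # wait for the device computation to finish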

GymnasiumStream(env, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0)

Experience stream from a Gymnasium environment using a Python loop.

This class provides iterator-based access for online learning scenarios where you need to interact with the environment in real time.

For batch learning, use collect_trajectory() followed by learn_from_trajectory().

Attributes: mode: Prediction mode (REWARD, NEXT_STATE, VALUE) gamma: Discount factor for VALUE mode include_action_in_features: Whether to include action in features episode_count: Number of completed episodes

Args: env: Gymnasium environment instance mode: What to predict (REWARD, NEXT_STATE, VALUE) policy: Action selection function. If None, uses random policy gamma: Discount factor for VALUE mode include_action_in_features: If True, features = concat(obs, action). If False, features = obs only seed: Random seed for environment resets and random policy

Source code in src/alberta_framework/streams/gymnasium.py
def __init__(
    self,
    env: gymnasium.Env[Any, Any],
    mode: PredictionMode = PredictionMode.REWARD,
    policy: Callable[[Array], Any] | None = None,
    gamma: float = 0.99,
    include_action_in_features: bool = True,
    seed: int = 0,
):
    """Initialize the Gymnasium stream.

    Args:
        env: Gymnasium environment instance
        mode: What to predict (REWARD, NEXT_STATE, VALUE)
        policy: Action selection function. If None, uses random policy
        gamma: Discount factor for VALUE mode
        include_action_in_features: If True, features = concat(obs, action).
            If False, features = obs only
        seed: Random seed for environment resets and random policy
    """
    self._env = env
    self._mode = mode
    self._gamma = gamma
    self._include_action_in_features = include_action_in_features
    self._seed = seed
    self._reset_count = 0

    if policy is None:
        self._policy = make_random_policy(env, seed)
    else:
        self._policy = policy

    self._obs_dim = _flatten_space(env.observation_space)
    self._action_dim = _flatten_space(env.action_space)

    if include_action_in_features:
        self._feature_dim = self._obs_dim + self._action_dim
    else:
        self._feature_dim = self._obs_dim

    if mode == PredictionMode.NEXT_STATE:
        self._target_dim = self._obs_dim
    else:
        self._target_dim = 1

    self._current_obs: Array | None = None
    self._episode_count = 0
    self._step_count = 0
    self._value_estimator: Callable[[Array], float] | None = None

feature_dim property

Return the dimension of feature vectors.

target_dim property

Return the dimension of target vectors.

episode_count property

Return the number of completed episodes.

step_count property

Return the total number of steps taken.

mode property

Return the prediction mode.

set_value_estimator(estimator)

Set the value estimator for proper TD learning in VALUE mode.

Source code in src/alberta_framework/streams/gymnasium.py
def set_value_estimator(self, estimator: Callable[[Array], float]) -> None:
    """Set the value estimator for proper TD learning in VALUE mode."""
    self._value_estimator = estimator

PredictionMode

Bases: Enum

Mode for what the stream predicts.

REWARD: Predict immediate reward from (state, action) NEXT_STATE: Predict next state from (state, action) VALUE: Predict cumulative return (TD learning with bootstrap)
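
A sketch of online reward prediction from a Gymnasium environment. It assumes GymnasiumStream is iterable over timesteps (it is described above as iterator-based) and that GymnasiumStream and PredictionMode can be imported from alberta_framework.streams.gymnasium; gymnasium must be installed.

import gymnasium

from alberta_framework import IDBD, LinearLearner
from alberta_framework.streams.gymnasium import GymnasiumStream, PredictionMode  # assumed path

env = gymnasium.make("CartPole-v1")
stream = GymnasiumStream(env, mode=PredictionMode.REWARD)

learner = LinearLearner(optimizer=IDBD())
state = learner.init(stream.feature_dim)

for _, timestep in zip(range(1_000), stream):
    result = learner.update(state, timestep.observation, timestep.target)
    state = result.state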

TDStream(env, policy=None, gamma=0.99, include_action_in_features=False, seed=0)

Experience stream for proper TD learning with value function bootstrap.

This stream integrates with a learner to use its predictions for bootstrapping in TD targets.

Usage:

stream = TDStream(env)
learner = LinearLearner(optimizer=IDBD())
state = learner.init(stream.feature_dim)

for step, timestep in enumerate(stream):
    result = learner.update(state, timestep.observation, timestep.target)
    state = result.state
    stream.update_value_function(lambda x: learner.predict(state, x))

Args: env: Gymnasium environment instance policy: Action selection function. If None, uses random policy gamma: Discount factor include_action_in_features: If True, learn Q(s,a). If False, learn V(s) seed: Random seed

Source code in src/alberta_framework/streams/gymnasium.py
def __init__(
    self,
    env: gymnasium.Env[Any, Any],
    policy: Callable[[Array], Any] | None = None,
    gamma: float = 0.99,
    include_action_in_features: bool = False,
    seed: int = 0,
):
    """Initialize the TD stream.

    Args:
        env: Gymnasium environment instance
        policy: Action selection function. If None, uses random policy
        gamma: Discount factor
        include_action_in_features: If True, learn Q(s,a). If False, learn V(s)
        seed: Random seed
    """
    self._env = env
    self._gamma = gamma
    self._include_action_in_features = include_action_in_features
    self._seed = seed
    self._reset_count = 0

    if policy is None:
        self._policy = make_random_policy(env, seed)
    else:
        self._policy = policy

    self._obs_dim = _flatten_space(env.observation_space)
    self._action_dim = _flatten_space(env.action_space)

    if include_action_in_features:
        self._feature_dim = self._obs_dim + self._action_dim
    else:
        self._feature_dim = self._obs_dim

    self._current_obs: Array | None = None
    self._episode_count = 0
    self._step_count = 0
    self._value_fn: Callable[[Array], float] = lambda x: 0.0

feature_dim property

Return the dimension of feature vectors.

episode_count property

Return the number of completed episodes.

step_count property

Return the total number of steps taken.

update_value_function(value_fn)

Update the value function used for TD bootstrapping.

Source code in src/alberta_framework/streams/gymnasium.py
def update_value_function(self, value_fn: Callable[[Array], float]) -> None:
    """Update the value function used for TD bootstrapping."""
    self._value_fn = value_fn

metrics_to_dicts(metrics, normalized=False)

Convert metrics array to list of dicts for backward compatibility.

Args: metrics: Array of shape (num_steps, 3) or (num_steps, 4) normalized: If True, expects 4 columns including normalizer_mean_var

Returns: List of metric dictionaries

Source code in src/alberta_framework/core/learners.py
def metrics_to_dicts(metrics: Array, normalized: bool = False) -> list[dict[str, float]]:
    """Convert metrics array to list of dicts for backward compatibility.

    Args:
        metrics: Array of shape (num_steps, 3) or (num_steps, 4)
        normalized: If True, expects 4 columns including normalizer_mean_var

    Returns:
        List of metric dictionaries
    """
    result = []
    for row in metrics:
        d = {
            "squared_error": float(row[0]),
            "error": float(row[1]),
            "mean_step_size": float(row[2]),
        }
        if normalized and len(row) > 3:
            d["normalizer_mean_var"] = float(row[3])
        result.append(d)
    return result
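
For example, continuing from an earlier run_learning_loop call (with learner, stream, and jr defined as in the previous examples):

state, metrics = run_learning_loop(learner, stream, num_steps=1_000, key=jr.key(0))

records = metrics_to_dicts(metrics)
print(records[0]["squared_error"], records[0]["mean_step_size"])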

run_learning_loop(learner, stream, num_steps, key, learner_state=None, step_size_tracking=None)

Run the learning loop using jax.lax.scan.

This is a JIT-compiled learning loop that uses scan for efficiency. It returns metrics as a fixed-size array rather than a list of dicts.

Args: learner: The learner to train stream: Experience stream providing (observation, target) pairs num_steps: Number of learning steps to run key: JAX random key for stream initialization learner_state: Initial state (if None, will be initialized from stream) step_size_tracking: Optional config for recording per-weight step-sizes. When provided, returns a 3-tuple including StepSizeHistory.

Returns: If step_size_tracking is None: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 3) with columns [squared_error, error, mean_step_size] If step_size_tracking is provided: Tuple of (final_state, metrics_array, step_size_history)

Raises: ValueError: If step_size_tracking.interval is less than 1 or greater than num_steps

Source code in src/alberta_framework/core/learners.py
def run_learning_loop[StreamStateT](
    learner: LinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    key: Array,
    learner_state: LearnerState | None = None,
    step_size_tracking: StepSizeTrackingConfig | None = None,
) -> tuple[LearnerState, Array] | tuple[LearnerState, Array, StepSizeHistory]:
    """Run the learning loop using jax.lax.scan.

    This is a JIT-compiled learning loop that uses scan for efficiency.
    It returns metrics as a fixed-size array rather than a list of dicts.

    Args:
        learner: The learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run
        key: JAX random key for stream initialization
        learner_state: Initial state (if None, will be initialized from stream)
        step_size_tracking: Optional config for recording per-weight step-sizes.
            When provided, returns a 3-tuple including StepSizeHistory.

    Returns:
        If step_size_tracking is None:
            Tuple of (final_state, metrics_array) where metrics_array has shape
            (num_steps, 3) with columns [squared_error, error, mean_step_size]
        If step_size_tracking is provided:
            Tuple of (final_state, metrics_array, step_size_history)

    Raises:
        ValueError: If step_size_tracking.interval is less than 1 or greater than num_steps
    """
    # Validate tracking config
    if step_size_tracking is not None:
        if step_size_tracking.interval < 1:
            raise ValueError(
                f"step_size_tracking.interval must be >= 1, got {step_size_tracking.interval}"
            )
        if step_size_tracking.interval > num_steps:
            raise ValueError(
                f"step_size_tracking.interval ({step_size_tracking.interval}) "
                f"must be <= num_steps ({num_steps})"
            )

    # Initialize states
    if learner_state is None:
        learner_state = learner.init(stream.feature_dim)
    stream_state = stream.init(key)

    feature_dim = stream.feature_dim

    if step_size_tracking is None:
        # Original behavior without tracking
        def step_fn(
            carry: tuple[LearnerState, StreamStateT], idx: Array
        ) -> tuple[tuple[LearnerState, StreamStateT], Array]:
            l_state, s_state = carry
            timestep, new_s_state = stream.step(s_state, idx)
            result = learner.update(l_state, timestep.observation, timestep.target)
            return (result.state, new_s_state), result.metrics

        (final_learner, _), metrics = jax.lax.scan(
            step_fn, (learner_state, stream_state), jnp.arange(num_steps)
        )

        return final_learner, metrics

    else:
        # Step-size tracking enabled
        interval = step_size_tracking.interval
        include_bias = step_size_tracking.include_bias
        num_recordings = num_steps // interval

        # Pre-allocate history arrays
        step_size_history = jnp.zeros((num_recordings, feature_dim), dtype=jnp.float32)
        bias_history = jnp.zeros(num_recordings, dtype=jnp.float32) if include_bias else None
        recording_indices = jnp.zeros(num_recordings, dtype=jnp.int32)

        # Check if we need to track Autostep normalizers
        # We detect this at trace time by checking the initial optimizer state
        track_normalizers = hasattr(learner_state.optimizer_state, "normalizers")
        normalizer_history = (
            jnp.zeros((num_recordings, feature_dim), dtype=jnp.float32)
            if track_normalizers
            else None
        )

        def step_fn_with_tracking(
            carry: tuple[LearnerState, StreamStateT, Array, Array | None, Array, Array | None],
            idx: Array,
        ) -> tuple[
            tuple[LearnerState, StreamStateT, Array, Array | None, Array, Array | None],
            Array,
        ]:
            l_state, s_state, ss_history, b_history, rec_indices, norm_history = carry

            # Perform learning step
            timestep, new_s_state = stream.step(s_state, idx)
            result = learner.update(l_state, timestep.observation, timestep.target)

            # Check if we should record at this step (idx % interval == 0)
            should_record = (idx % interval) == 0
            recording_idx = idx // interval

            # Extract current step-sizes
            # Use hasattr checks at trace time (this works because the type is fixed)
            opt_state = result.state.optimizer_state
            if hasattr(opt_state, "log_step_sizes"):
                # IDBD stores log step-sizes
                weight_ss = jnp.exp(opt_state.log_step_sizes)
                bias_ss = opt_state.bias_step_size
            elif hasattr(opt_state, "step_sizes"):
                # Autostep stores step-sizes directly
                weight_ss = opt_state.step_sizes
                bias_ss = opt_state.bias_step_size
            else:
                # LMS has a single fixed step-size
                weight_ss = jnp.full(feature_dim, opt_state.step_size)
                bias_ss = opt_state.step_size

            # Conditionally update history arrays
            new_ss_history = jax.lax.cond(
                should_record,
                lambda _: ss_history.at[recording_idx].set(weight_ss),
                lambda _: ss_history,
                None,
            )

            new_b_history = b_history
            if b_history is not None:
                new_b_history = jax.lax.cond(
                    should_record,
                    lambda _: b_history.at[recording_idx].set(bias_ss),
                    lambda _: b_history,
                    None,
                )

            new_rec_indices = jax.lax.cond(
                should_record,
                lambda _: rec_indices.at[recording_idx].set(idx),
                lambda _: rec_indices,
                None,
            )

            # Track Autostep normalizers (v_i) if applicable
            new_norm_history = norm_history
            if norm_history is not None and hasattr(opt_state, "normalizers"):
                new_norm_history = jax.lax.cond(
                    should_record,
                    lambda _: norm_history.at[recording_idx].set(opt_state.normalizers),
                    lambda _: norm_history,
                    None,
                )

            return (
                result.state,
                new_s_state,
                new_ss_history,
                new_b_history,
                new_rec_indices,
                new_norm_history,
            ), result.metrics

        initial_carry = (
            learner_state,
            stream_state,
            step_size_history,
            bias_history,
            recording_indices,
            normalizer_history,
        )

        (
            (
                final_learner,
                _,
                final_ss_history,
                final_b_history,
                final_rec_indices,
                final_norm_history,
            ),
            metrics,
        ) = jax.lax.scan(step_fn_with_tracking, initial_carry, jnp.arange(num_steps))

        history = StepSizeHistory(
            step_sizes=final_ss_history,
            bias_step_sizes=final_b_history,
            recording_indices=final_rec_indices,
            normalizers=final_norm_history,
        )

        return final_learner, metrics, history
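
A minimal sketch of the step-size tracking path described above. It assumes `StepSizeTrackingConfig` is exported at the package top level and that its documented `interval` and `include_bias` attributes are accepted as constructor keywords; check your installed version for the exact import path.

```python
import jax.random as jr
from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop
from alberta_framework import StepSizeTrackingConfig  # assumption: exported at top level

stream = RandomWalkStream(feature_dim=10)
learner = LinearLearner(optimizer=IDBD())

# Record per-weight step-sizes every 100 steps (constructor kwargs assumed)
tracking = StepSizeTrackingConfig(interval=100, include_bias=True)

# With tracking enabled the loop returns a 3-tuple
state, metrics, history = run_learning_loop(
    learner, stream, num_steps=10000, key=jr.key(0), step_size_tracking=tracking
)

# history.step_sizes: (num_steps // interval, feature_dim) = (100, 10)
# history.recording_indices: the step index at which each row was recorded
print(history.step_sizes.shape, history.recording_indices[:3])
```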

run_learning_loop_batched(learner, stream, num_steps, keys, learner_state=None, step_size_tracking=None)

Run learning loop across multiple seeds in parallel using jax.vmap.

This function provides GPU parallelization for multi-seed experiments, typically achieving 2-5x speedup over sequential execution.

Args:
    learner: The learner to train
    stream: Experience stream providing (observation, target) pairs
    num_steps: Number of learning steps to run per seed
    keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
    learner_state: Initial state (if None, will be initialized from stream).
        The same initial state is used for all seeds.
    step_size_tracking: Optional config for recording per-weight step-sizes.
        When provided, history arrays have shape (num_seeds, num_recordings, ...)

Returns:
    BatchedLearningResult containing:
        - states: Batched final states with shape (num_seeds, ...) for each array
        - metrics: Array of shape (num_seeds, num_steps, 3)
        - step_size_history: Batched history or None if tracking disabled

Examples:

import jax.random as jr
from alberta_framework import LinearLearner, IDBD, RandomWalkStream
from alberta_framework import run_learning_loop_batched

stream = RandomWalkStream(feature_dim=10)
learner = LinearLearner(optimizer=IDBD())

# Run 30 seeds in parallel
keys = jr.split(jr.key(42), 30)
result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)

# result.metrics has shape (30, 10000, 3)
mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds

Source code in src/alberta_framework/core/learners.py
def run_learning_loop_batched[StreamStateT](
    learner: LinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    keys: Array,
    learner_state: LearnerState | None = None,
    step_size_tracking: StepSizeTrackingConfig | None = None,
) -> BatchedLearningResult:
    """Run learning loop across multiple seeds in parallel using jax.vmap.

    This function provides GPU parallelization for multi-seed experiments,
    typically achieving 2-5x speedup over sequential execution.

    Args:
        learner: The learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run per seed
        keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
        learner_state: Initial state (if None, will be initialized from stream).
            The same initial state is used for all seeds.
        step_size_tracking: Optional config for recording per-weight step-sizes.
            When provided, history arrays have shape (num_seeds, num_recordings, ...)

    Returns:
        BatchedLearningResult containing:
            - states: Batched final states with shape (num_seeds, ...) for each array
            - metrics: Array of shape (num_seeds, num_steps, 3)
            - step_size_history: Batched history or None if tracking disabled

    Examples:
    ```python
    import jax.random as jr
    from alberta_framework import LinearLearner, IDBD, RandomWalkStream
    from alberta_framework import run_learning_loop_batched

    stream = RandomWalkStream(feature_dim=10)
    learner = LinearLearner(optimizer=IDBD())

    # Run 30 seeds in parallel
    keys = jr.split(jr.key(42), 30)
    result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)

    # result.metrics has shape (30, 10000, 3)
    mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
    ```
    """

    # Define single-seed function that returns consistent structure
    def single_seed_run(key: Array) -> tuple[LearnerState, Array, StepSizeHistory | None]:
        result = run_learning_loop(
            learner, stream, num_steps, key, learner_state, step_size_tracking
        )
        if step_size_tracking is not None:
            state, metrics, history = cast(tuple[LearnerState, Array, StepSizeHistory], result)
            return state, metrics, history
        else:
            state, metrics = cast(tuple[LearnerState, Array], result)
            # Return None for history to maintain consistent output structure
            return state, metrics, None

    # vmap over the keys dimension
    batched_states, batched_metrics, batched_history = jax.vmap(single_seed_run)(keys)

    # Reconstruct batched history if tracking was enabled
    if step_size_tracking is not None and batched_history is not None:
        batched_step_size_history = StepSizeHistory(
            step_sizes=batched_history.step_sizes,
            bias_step_sizes=batched_history.bias_step_sizes,
            recording_indices=batched_history.recording_indices,
            normalizers=batched_history.normalizers,
        )
    else:
        batched_step_size_history = None

    return BatchedLearningResult(
        states=batched_states,
        metrics=batched_metrics,
        step_size_history=batched_step_size_history,
    )
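
The sketch below extends the example above with step-size tracking to show the extra leading seeds dimension on the recorded histories; as before, the top-level export of `StepSizeTrackingConfig` and its `interval` keyword are assumptions.

```python
import jax.random as jr
from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop_batched
from alberta_framework import StepSizeTrackingConfig  # assumption: exported at top level

stream = RandomWalkStream(feature_dim=10)
learner = LinearLearner(optimizer=IDBD())
keys = jr.split(jr.key(0), 8)

result = run_learning_loop_batched(
    learner, stream, num_steps=5000, keys=keys,
    step_size_tracking=StepSizeTrackingConfig(interval=50),
)

# Histories gain a leading seeds dimension: (8, 5000 // 50, 10)
print(result.step_size_history.step_sizes.shape)
```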

run_normalized_learning_loop(learner, stream, num_steps, key, learner_state=None, step_size_tracking=None, normalizer_tracking=None)

Run the learning loop with normalization using jax.lax.scan.

Args:
    learner: The normalized learner to train
    stream: Experience stream providing (observation, target) pairs
    num_steps: Number of learning steps to run
    key: JAX random key for stream initialization
    learner_state: Initial state (if None, will be initialized from stream)
    step_size_tracking: Optional config for recording per-weight step-sizes.
        When provided, returns StepSizeHistory including Autostep normalizers if applicable.
    normalizer_tracking: Optional config for recording per-feature normalizer state.
        When provided, returns NormalizerHistory with means and variances over time.

Returns:
    If no tracking:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
    If step_size_tracking only:
        Tuple of (final_state, metrics_array, step_size_history)
    If normalizer_tracking only:
        Tuple of (final_state, metrics_array, normalizer_history)
    If both:
        Tuple of (final_state, metrics_array, step_size_history, normalizer_history)

Raises: ValueError: If tracking interval is invalid

Source code in src/alberta_framework/core/learners.py
def run_normalized_learning_loop[StreamStateT](
    learner: NormalizedLinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    key: Array,
    learner_state: NormalizedLearnerState | None = None,
    step_size_tracking: StepSizeTrackingConfig | None = None,
    normalizer_tracking: NormalizerTrackingConfig | None = None,
) -> (
    tuple[NormalizedLearnerState, Array]
    | tuple[NormalizedLearnerState, Array, StepSizeHistory]
    | tuple[NormalizedLearnerState, Array, NormalizerHistory]
    | tuple[NormalizedLearnerState, Array, StepSizeHistory, NormalizerHistory]
):
    """Run the learning loop with normalization using jax.lax.scan.

    Args:
        learner: The normalized learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run
        key: JAX random key for stream initialization
        learner_state: Initial state (if None, will be initialized from stream)
        step_size_tracking: Optional config for recording per-weight step-sizes.
            When provided, returns StepSizeHistory including Autostep normalizers if applicable.
        normalizer_tracking: Optional config for recording per-feature normalizer state.
            When provided, returns NormalizerHistory with means and variances over time.

    Returns:
        If no tracking:
            Tuple of (final_state, metrics_array) where metrics_array has shape
            (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
        If step_size_tracking only:
            Tuple of (final_state, metrics_array, step_size_history)
        If normalizer_tracking only:
            Tuple of (final_state, metrics_array, normalizer_history)
        If both:
            Tuple of (final_state, metrics_array, step_size_history, normalizer_history)

    Raises:
        ValueError: If tracking interval is invalid
    """
    # Validate tracking configs
    if step_size_tracking is not None:
        if step_size_tracking.interval < 1:
            raise ValueError(
                f"step_size_tracking.interval must be >= 1, got {step_size_tracking.interval}"
            )
        if step_size_tracking.interval > num_steps:
            raise ValueError(
                f"step_size_tracking.interval ({step_size_tracking.interval}) "
                f"must be <= num_steps ({num_steps})"
            )

    if normalizer_tracking is not None:
        if normalizer_tracking.interval < 1:
            raise ValueError(
                f"normalizer_tracking.interval must be >= 1, got {normalizer_tracking.interval}"
            )
        if normalizer_tracking.interval > num_steps:
            raise ValueError(
                f"normalizer_tracking.interval ({normalizer_tracking.interval}) "
                f"must be <= num_steps ({num_steps})"
            )

    # Initialize states
    if learner_state is None:
        learner_state = learner.init(stream.feature_dim)
    stream_state = stream.init(key)

    feature_dim = stream.feature_dim

    # No tracking - simple case
    if step_size_tracking is None and normalizer_tracking is None:

        def step_fn(
            carry: tuple[NormalizedLearnerState, StreamStateT], idx: Array
        ) -> tuple[tuple[NormalizedLearnerState, StreamStateT], Array]:
            l_state, s_state = carry
            timestep, new_s_state = stream.step(s_state, idx)
            result = learner.update(l_state, timestep.observation, timestep.target)
            return (result.state, new_s_state), result.metrics

        (final_learner, _), metrics = jax.lax.scan(
            step_fn, (learner_state, stream_state), jnp.arange(num_steps)
        )

        return final_learner, metrics

    # Tracking enabled - need to set up history arrays
    ss_interval = step_size_tracking.interval if step_size_tracking else num_steps + 1
    norm_interval = normalizer_tracking.interval if normalizer_tracking else num_steps + 1

    ss_num_recordings = num_steps // ss_interval if step_size_tracking else 0
    norm_num_recordings = num_steps // norm_interval if normalizer_tracking else 0

    # Pre-allocate step-size history arrays
    ss_history = (
        jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
        if step_size_tracking
        else None
    )
    ss_bias_history = (
        jnp.zeros(ss_num_recordings, dtype=jnp.float32)
        if step_size_tracking and step_size_tracking.include_bias
        else None
    )
    ss_rec_indices = jnp.zeros(ss_num_recordings, dtype=jnp.int32) if step_size_tracking else None

    # Check if we need to track Autostep normalizers
    track_autostep_normalizers = hasattr(learner_state.learner_state.optimizer_state, "normalizers")
    ss_normalizers = (
        jnp.zeros((ss_num_recordings, feature_dim), dtype=jnp.float32)
        if step_size_tracking and track_autostep_normalizers
        else None
    )

    # Pre-allocate normalizer state history arrays
    norm_means = (
        jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
        if normalizer_tracking
        else None
    )
    norm_vars = (
        jnp.zeros((norm_num_recordings, feature_dim), dtype=jnp.float32)
        if normalizer_tracking
        else None
    )
    norm_rec_indices = (
        jnp.zeros(norm_num_recordings, dtype=jnp.int32) if normalizer_tracking else None
    )

    def step_fn_with_tracking(
        carry: tuple[
            NormalizedLearnerState,
            StreamStateT,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
        ],
        idx: Array,
    ) -> tuple[
        tuple[
            NormalizedLearnerState,
            StreamStateT,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
            Array | None,
        ],
        Array,
    ]:
        (
            l_state,
            s_state,
            ss_hist,
            ss_bias_hist,
            ss_rec,
            ss_norm,
            n_means,
            n_vars,
            n_rec,
        ) = carry

        # Perform learning step
        timestep, new_s_state = stream.step(s_state, idx)
        result = learner.update(l_state, timestep.observation, timestep.target)

        # Step-size tracking
        new_ss_hist = ss_hist
        new_ss_bias_hist = ss_bias_hist
        new_ss_rec = ss_rec
        new_ss_norm = ss_norm

        if ss_hist is not None:
            should_record_ss = (idx % ss_interval) == 0
            recording_idx = idx // ss_interval

            # Extract current step-sizes from the inner learner state
            opt_state = result.state.learner_state.optimizer_state
            if hasattr(opt_state, "log_step_sizes"):
                # IDBD stores log step-sizes
                weight_ss = jnp.exp(opt_state.log_step_sizes)
                bias_ss = opt_state.bias_step_size
            elif hasattr(opt_state, "step_sizes"):
                # Autostep stores step-sizes directly
                weight_ss = opt_state.step_sizes
                bias_ss = opt_state.bias_step_size
            else:
                # LMS has a single fixed step-size
                weight_ss = jnp.full(feature_dim, opt_state.step_size)
                bias_ss = opt_state.step_size

            new_ss_hist = jax.lax.cond(
                should_record_ss,
                lambda _: ss_hist.at[recording_idx].set(weight_ss),
                lambda _: ss_hist,
                None,
            )

            if ss_bias_hist is not None:
                new_ss_bias_hist = jax.lax.cond(
                    should_record_ss,
                    lambda _: ss_bias_hist.at[recording_idx].set(bias_ss),
                    lambda _: ss_bias_hist,
                    None,
                )

            if ss_rec is not None:
                new_ss_rec = jax.lax.cond(
                    should_record_ss,
                    lambda _: ss_rec.at[recording_idx].set(idx),
                    lambda _: ss_rec,
                    None,
                )

            # Track Autostep normalizers (v_i) if applicable
            if ss_norm is not None and hasattr(opt_state, "normalizers"):
                new_ss_norm = jax.lax.cond(
                    should_record_ss,
                    lambda _: ss_norm.at[recording_idx].set(opt_state.normalizers),
                    lambda _: ss_norm,
                    None,
                )

        # Normalizer state tracking
        new_n_means = n_means
        new_n_vars = n_vars
        new_n_rec = n_rec

        if n_means is not None:
            should_record_norm = (idx % norm_interval) == 0
            norm_recording_idx = idx // norm_interval

            norm_state = result.state.normalizer_state

            new_n_means = jax.lax.cond(
                should_record_norm,
                lambda _: n_means.at[norm_recording_idx].set(norm_state.mean),
                lambda _: n_means,
                None,
            )

            if n_vars is not None:
                new_n_vars = jax.lax.cond(
                    should_record_norm,
                    lambda _: n_vars.at[norm_recording_idx].set(norm_state.var),
                    lambda _: n_vars,
                    None,
                )

            if n_rec is not None:
                new_n_rec = jax.lax.cond(
                    should_record_norm,
                    lambda _: n_rec.at[norm_recording_idx].set(idx),
                    lambda _: n_rec,
                    None,
                )

        return (
            result.state,
            new_s_state,
            new_ss_hist,
            new_ss_bias_hist,
            new_ss_rec,
            new_ss_norm,
            new_n_means,
            new_n_vars,
            new_n_rec,
        ), result.metrics

    initial_carry = (
        learner_state,
        stream_state,
        ss_history,
        ss_bias_history,
        ss_rec_indices,
        ss_normalizers,
        norm_means,
        norm_vars,
        norm_rec_indices,
    )

    (
        (
            final_learner,
            _,
            final_ss_hist,
            final_ss_bias_hist,
            final_ss_rec,
            final_ss_norm,
            final_n_means,
            final_n_vars,
            final_n_rec,
        ),
        metrics,
    ) = jax.lax.scan(step_fn_with_tracking, initial_carry, jnp.arange(num_steps))

    # Build return values based on what was tracked
    ss_history_result = None
    if step_size_tracking is not None and final_ss_hist is not None:
        ss_history_result = StepSizeHistory(
            step_sizes=final_ss_hist,
            bias_step_sizes=final_ss_bias_hist,
            recording_indices=final_ss_rec,
            normalizers=final_ss_norm,
        )

    norm_history_result = None
    if normalizer_tracking is not None and final_n_means is not None:
        norm_history_result = NormalizerHistory(
            means=final_n_means,
            variances=final_n_vars,
            recording_indices=final_n_rec,
        )

    # Return appropriate tuple based on what was tracked
    if ss_history_result is not None and norm_history_result is not None:
        return final_learner, metrics, ss_history_result, norm_history_result
    elif ss_history_result is not None:
        return final_learner, metrics, ss_history_result
    elif norm_history_result is not None:
        return final_learner, metrics, norm_history_result
    else:
        return final_learner, metrics
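
A minimal sketch of the four-way return described above, with both trackers enabled. Importing `run_normalized_learning_loop`, `StepSizeTrackingConfig`, and `NormalizerTrackingConfig` from the package top level, and passing only `interval` to each config (leaving the remaining fields at their defaults), are assumptions.

```python
import jax.random as jr
from alberta_framework import NormalizedLinearLearner, IDBD, RandomWalkStream
from alberta_framework import run_normalized_learning_loop  # assumption: exported at top level
from alberta_framework import StepSizeTrackingConfig, NormalizerTrackingConfig  # assumption

stream = RandomWalkStream(feature_dim=10)
learner = NormalizedLinearLearner(optimizer=IDBD())

# With both trackers enabled the loop returns a 4-tuple
state, metrics, ss_history, norm_history = run_normalized_learning_loop(
    learner,
    stream,
    num_steps=10000,
    key=jr.key(0),
    step_size_tracking=StepSizeTrackingConfig(interval=100),
    normalizer_tracking=NormalizerTrackingConfig(interval=100),
)

# metrics: (10000, 4) -> [squared_error, error, mean_step_size, normalizer_mean_var]
# norm_history.means and norm_history.variances: (10000 // 100, feature_dim)
print(norm_history.means.shape, ss_history.step_sizes.shape)
```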

run_normalized_learning_loop_batched(learner, stream, num_steps, keys, learner_state=None, step_size_tracking=None, normalizer_tracking=None)

Run normalized learning loop across multiple seeds in parallel using jax.vmap.

This function provides GPU parallelization for multi-seed experiments with normalized learners, typically achieving 2-5x speedup over sequential execution.

Args:
    learner: The normalized learner to train
    stream: Experience stream providing (observation, target) pairs
    num_steps: Number of learning steps to run per seed
    keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
    learner_state: Initial state (if None, will be initialized from stream).
        The same initial state is used for all seeds.
    step_size_tracking: Optional config for recording per-weight step-sizes.
        When provided, history arrays have shape (num_seeds, num_recordings, ...)
    normalizer_tracking: Optional config for recording normalizer state.
        When provided, history arrays have shape (num_seeds, num_recordings, ...)

Returns:
    BatchedNormalizedResult containing:
        - states: Batched final states with shape (num_seeds, ...) for each array
        - metrics: Array of shape (num_seeds, num_steps, 4)
        - step_size_history: Batched history or None if tracking disabled
        - normalizer_history: Batched history or None if tracking disabled

Examples:

import jax.random as jr
from alberta_framework import NormalizedLinearLearner, IDBD, RandomWalkStream
from alberta_framework import run_normalized_learning_loop_batched

stream = RandomWalkStream(feature_dim=10)
learner = NormalizedLinearLearner(optimizer=IDBD())

# Run 30 seeds in parallel
keys = jr.split(jr.key(42), 30)
result = run_normalized_learning_loop_batched(
    learner, stream, num_steps=10000, keys=keys
)

# result.metrics has shape (30, 10000, 4)
mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds

Source code in src/alberta_framework/core/learners.py
def run_normalized_learning_loop_batched[StreamStateT](
    learner: NormalizedLinearLearner,
    stream: ScanStream[StreamStateT],
    num_steps: int,
    keys: Array,
    learner_state: NormalizedLearnerState | None = None,
    step_size_tracking: StepSizeTrackingConfig | None = None,
    normalizer_tracking: NormalizerTrackingConfig | None = None,
) -> BatchedNormalizedResult:
    """Run normalized learning loop across multiple seeds in parallel using jax.vmap.

    This function provides GPU parallelization for multi-seed experiments with
    normalized learners, typically achieving 2-5x speedup over sequential execution.

    Args:
        learner: The normalized learner to train
        stream: Experience stream providing (observation, target) pairs
        num_steps: Number of learning steps to run per seed
        keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
        learner_state: Initial state (if None, will be initialized from stream).
            The same initial state is used for all seeds.
        step_size_tracking: Optional config for recording per-weight step-sizes.
            When provided, history arrays have shape (num_seeds, num_recordings, ...)
        normalizer_tracking: Optional config for recording normalizer state.
            When provided, history arrays have shape (num_seeds, num_recordings, ...)

    Returns:
        BatchedNormalizedResult containing:
            - states: Batched final states with shape (num_seeds, ...) for each array
            - metrics: Array of shape (num_seeds, num_steps, 4)
            - step_size_history: Batched history or None if tracking disabled
            - normalizer_history: Batched history or None if tracking disabled

    Examples:
    ```python
    import jax.random as jr
    from alberta_framework import NormalizedLinearLearner, IDBD, RandomWalkStream
    from alberta_framework import run_normalized_learning_loop_batched

    stream = RandomWalkStream(feature_dim=10)
    learner = NormalizedLinearLearner(optimizer=IDBD())

    # Run 30 seeds in parallel
    keys = jr.split(jr.key(42), 30)
    result = run_normalized_learning_loop_batched(
        learner, stream, num_steps=10000, keys=keys
    )

    # result.metrics has shape (30, 10000, 4)
    mean_error = result.metrics[:, :, 0].mean(axis=0)  # Average over seeds
    ```
    """

    # Define single-seed function that returns consistent structure
    def single_seed_run(
        key: Array,
    ) -> tuple[NormalizedLearnerState, Array, StepSizeHistory | None, NormalizerHistory | None]:
        result = run_normalized_learning_loop(
            learner, stream, num_steps, key, learner_state, step_size_tracking, normalizer_tracking
        )

        # Unpack based on what tracking was enabled
        if step_size_tracking is not None and normalizer_tracking is not None:
            state, metrics, ss_history, norm_history = cast(
                tuple[NormalizedLearnerState, Array, StepSizeHistory, NormalizerHistory],
                result,
            )
            return state, metrics, ss_history, norm_history
        elif step_size_tracking is not None:
            state, metrics, ss_history = cast(
                tuple[NormalizedLearnerState, Array, StepSizeHistory], result
            )
            return state, metrics, ss_history, None
        elif normalizer_tracking is not None:
            state, metrics, norm_history = cast(
                tuple[NormalizedLearnerState, Array, NormalizerHistory], result
            )
            return state, metrics, None, norm_history
        else:
            state, metrics = cast(tuple[NormalizedLearnerState, Array], result)
            return state, metrics, None, None

    # vmap over the keys dimension
    batched_states, batched_metrics, batched_ss_history, batched_norm_history = jax.vmap(
        single_seed_run
    )(keys)

    # Reconstruct batched histories if tracking was enabled
    if step_size_tracking is not None and batched_ss_history is not None:
        batched_step_size_history = StepSizeHistory(
            step_sizes=batched_ss_history.step_sizes,
            bias_step_sizes=batched_ss_history.bias_step_sizes,
            recording_indices=batched_ss_history.recording_indices,
            normalizers=batched_ss_history.normalizers,
        )
    else:
        batched_step_size_history = None

    if normalizer_tracking is not None and batched_norm_history is not None:
        batched_normalizer_history = NormalizerHistory(
            means=batched_norm_history.means,
            variances=batched_norm_history.variances,
            recording_indices=batched_norm_history.recording_indices,
        )
    else:
        batched_normalizer_history = None

    return BatchedNormalizedResult(
        states=batched_states,
        metrics=batched_metrics,
        step_size_history=batched_step_size_history,
        normalizer_history=batched_normalizer_history,
    )

run_td_learning_loop(learner, stream, num_steps, key, learner_state=None)

Run the TD learning loop using jax.lax.scan.

This is a JIT-compiled learning loop that uses scan for efficiency. It returns metrics as a fixed-size array rather than a list of dicts.

Args:
    learner: The TD learner to train
    stream: TD experience stream providing (s, r, s', γ) tuples
    num_steps: Number of learning steps to run
    key: JAX random key for stream initialization
    learner_state: Initial state (if None, will be initialized from stream)

Returns: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 4) with columns [squared_td_error, td_error, mean_step_size, mean_eligibility_trace]

Source code in src/alberta_framework/core/learners.py
def run_td_learning_loop[StreamStateT](
    learner: TDLinearLearner,
    stream: TDStream[StreamStateT],
    num_steps: int,
    key: Array,
    learner_state: TDLearnerState | None = None,
) -> tuple[TDLearnerState, Array]:
    """Run the TD learning loop using jax.lax.scan.

    This is a JIT-compiled learning loop that uses scan for efficiency.
    It returns metrics as a fixed-size array rather than a list of dicts.

    Args:
        learner: The TD learner to train
        stream: TD experience stream providing (s, r, s', γ) tuples
        num_steps: Number of learning steps to run
        key: JAX random key for stream initialization
        learner_state: Initial state (if None, will be initialized from stream)

    Returns:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 4) with columns [squared_td_error, td_error, mean_step_size,
        mean_eligibility_trace]
    """
    # Initialize states
    if learner_state is None:
        learner_state = learner.init(stream.feature_dim)
    stream_state = stream.init(key)

    def step_fn(
        carry: tuple[TDLearnerState, StreamStateT], idx: Array
    ) -> tuple[tuple[TDLearnerState, StreamStateT], Array]:
        l_state, s_state = carry
        timestep, new_s_state = stream.step(s_state, idx)
        result = learner.update(
            l_state,
            timestep.observation,
            timestep.reward,
            timestep.next_observation,
            timestep.gamma,
        )
        return (result.state, new_s_state), result.metrics

    (final_learner, _), metrics = jax.lax.scan(
        step_fn, (learner_state, stream_state), jnp.arange(num_steps)
    )

    return final_learner, metrics

create_normalizer_state(feature_dim, decay=0.99)

Create initial normalizer state.

Convenience function for creating normalizer state without instantiating the OnlineNormalizer class.

Args: feature_dim: Dimension of feature vectors decay: Exponential decay factor

Returns: Initial normalizer state

Source code in src/alberta_framework/core/normalizers.py
def create_normalizer_state(
    feature_dim: int,
    decay: float = 0.99,
) -> NormalizerState:
    """Create initial normalizer state.

    Convenience function for creating normalizer state without
    instantiating the OnlineNormalizer class.

    Args:
        feature_dim: Dimension of feature vectors
        decay: Exponential decay factor

    Returns:
        Initial normalizer state
    """
    return NormalizerState(
        mean=jnp.zeros(feature_dim, dtype=jnp.float32),
        var=jnp.ones(feature_dim, dtype=jnp.float32),
        sample_count=jnp.array(0.0, dtype=jnp.float32),
        decay=jnp.array(decay, dtype=jnp.float32),
    )

create_autotdidbd_state(feature_dim, initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, normalizer_decay=10000.0)

Create initial AutoTDIDBD optimizer state.

Args:
    feature_dim: Dimension of the feature vector
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))
    normalizer_decay: Decay parameter tau for normalizers (default: 10000)

Returns: Initial AutoTDIDBD state

Source code in src/alberta_framework/core/types.py
def create_autotdidbd_state(
    feature_dim: int,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
    normalizer_decay: float = 10000.0,
) -> AutoTDIDBDState:
    """Create initial AutoTDIDBD optimizer state.

    Args:
        feature_dim: Dimension of the feature vector
        initial_step_size: Initial per-weight step-size
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))
        normalizer_decay: Decay parameter tau for normalizers (default: 10000)

    Returns:
        Initial AutoTDIDBD state
    """
    return AutoTDIDBDState(
        log_step_sizes=jnp.full(feature_dim, jnp.log(initial_step_size), dtype=jnp.float32),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(trace_decay, dtype=jnp.float32),
        normalizer_decay=jnp.array(normalizer_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
    )

create_tdidbd_state(feature_dim, initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0)

Create initial TD-IDBD optimizer state.

Args:
    feature_dim: Dimension of the feature vector
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))

Returns: Initial TD-IDBD state

Source code in src/alberta_framework/core/types.py
def create_tdidbd_state(
    feature_dim: int,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
) -> TDIDBDState:
    """Create initial TD-IDBD optimizer state.

    Args:
        feature_dim: Dimension of the feature vector
        initial_step_size: Initial per-weight step-size
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))

    Returns:
        Initial TD-IDBD state
    """
    return TDIDBDState(
        log_step_sizes=jnp.full(feature_dim, jnp.log(initial_step_size), dtype=jnp.float32),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(trace_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
    )

make_scale_range(feature_dim, min_scale=0.001, max_scale=1000.0, log_spaced=True)

Create a per-feature scale array spanning a range.

Utility function to generate scale factors for ScaledStreamWrapper.

Args:
    feature_dim: Number of features
    min_scale: Minimum scale factor
    max_scale: Maximum scale factor
    log_spaced: If True, scales are logarithmically spaced (default).
        If False, scales are linearly spaced.

Returns: Array of shape (feature_dim,) with scale factors

Examples:

scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
Source code in src/alberta_framework/streams/synthetic.py
def make_scale_range(
    feature_dim: int,
    min_scale: float = 0.001,
    max_scale: float = 1000.0,
    log_spaced: bool = True,
) -> Array:
    """Create a per-feature scale array spanning a range.

    Utility function to generate scale factors for ScaledStreamWrapper.

    Args:
        feature_dim: Number of features
        min_scale: Minimum scale factor
        max_scale: Maximum scale factor
        log_spaced: If True, scales are logarithmically spaced (default).
            If False, scales are linearly spaced.

    Returns:
        Array of shape (feature_dim,) with scale factors

    Examples
    --------
    ```python
    scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
    stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
    ```
    """
    if log_spaced:
        return jnp.logspace(
            jnp.log10(min_scale),
            jnp.log10(max_scale),
            feature_dim,
            dtype=jnp.float32,
        )
    else:
        return jnp.linspace(min_scale, max_scale, feature_dim, dtype=jnp.float32)

compare_learners(results, metric='squared_error')

Compare multiple learners on a given metric.

Args:
    results: Dictionary mapping learner name to metrics history
    metric: Metric to compare

Returns: Dictionary with summary statistics for each learner

Source code in src/alberta_framework/utils/metrics.py
def compare_learners(
    results: dict[str, list[dict[str, float]]],
    metric: str = "squared_error",
) -> dict[str, dict[str, float]]:
    """Compare multiple learners on a given metric.

    Args:
        results: Dictionary mapping learner name to metrics history
        metric: Metric to compare

    Returns:
        Dictionary with summary statistics for each learner
    """
    summary = {}
    for name, metrics_history in results.items():
        values = extract_metric(metrics_history, metric)
        summary[name] = {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
            "cumulative": float(np.sum(values)),
            "final_100_mean": (
                float(np.mean(values[-100:])) if len(values) >= 100 else float(np.mean(values))
            ),
        }
    return summary
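
A short sketch of the input format `compare_learners` expects: a list of per-step metric dicts for each learner. The metric arrays below are random placeholders, and the import path follows the stated source location; the function may also be re-exported at the package top level.

```python
import numpy as np
from alberta_framework.utils.metrics import compare_learners  # path per "Source code" above

def to_history(metrics):
    """Convert a (num_steps, 3) metrics array into the list-of-dicts form used here."""
    return [
        {"squared_error": float(r[0]), "error": float(r[1]), "mean_step_size": float(r[2])}
        for r in metrics
    ]

# Placeholder metrics for two hypothetical runs
lms_metrics = np.random.rand(1000, 3)
idbd_metrics = np.random.rand(1000, 3)

results = {"LMS": to_history(lms_metrics), "IDBD": to_history(idbd_metrics)}
summary = compare_learners(results, metric="squared_error")
print(summary["IDBD"]["mean"], summary["IDBD"]["final_100_mean"])
```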

compute_cumulative_error(metrics_history, error_key='squared_error')

Compute cumulative error over time.

Args: metrics_history: List of metric dictionaries from learning loop error_key: Key to extract error values

Returns: Array of cumulative errors at each time step

Source code in src/alberta_framework/utils/metrics.py
def compute_cumulative_error(
    metrics_history: list[dict[str, float]],
    error_key: str = "squared_error",
) -> NDArray[np.float64]:
    """Compute cumulative error over time.

    Args:
        metrics_history: List of metric dictionaries from learning loop
        error_key: Key to extract error values

    Returns:
        Array of cumulative errors at each time step
    """
    errors = np.array([m[error_key] for m in metrics_history])
    return np.cumsum(errors)

compute_running_mean(values, window_size=100)

Compute running mean of values.

Args: values: Array of values window_size: Size of the moving average window

Returns: Array of running mean values (same length as input, padded at start)

Source code in src/alberta_framework/utils/metrics.py
def compute_running_mean(
    values: NDArray[np.float64] | list[float],
    window_size: int = 100,
) -> NDArray[np.float64]:
    """Compute running mean of values.

    Args:
        values: Array of values
        window_size: Size of the moving average window

    Returns:
        Array of running mean values (same length as input, padded at start)
    """
    values_arr = np.asarray(values)
    cumsum = np.cumsum(np.insert(values_arr, 0, 0))
    running_mean = (cumsum[window_size:] - cumsum[:-window_size]) / window_size

    # Pad the beginning with the first computed mean
    if len(running_mean) > 0:
        padding = np.full(window_size - 1, running_mean[0])
        return np.concatenate([padding, running_mean])
    return values_arr
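
A small worked example of the padding behaviour: the output keeps the input length, and the first `window_size - 1` entries repeat the first full-window mean. The import path follows the stated source location.

```python
import numpy as np
from alberta_framework.utils.metrics import compute_running_mean

values = np.arange(10, dtype=np.float64)
smoothed = compute_running_mean(values, window_size=4)
# First full-window mean is (0 + 1 + 2 + 3) / 4 = 1.5, repeated at the padded start
print(smoothed)  # [1.5 1.5 1.5 1.5 2.5 3.5 4.5 5.5 6.5 7.5]
```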

compute_tracking_error(metrics_history, window_size=100)

Compute tracking error (running mean of squared error).

This is the key metric for evaluating continual learners: how well can the learner track the non-stationary target?

Args: metrics_history: List of metric dictionaries from learning loop window_size: Size of the moving average window

Returns: Array of tracking errors at each time step

Source code in src/alberta_framework/utils/metrics.py
def compute_tracking_error(
    metrics_history: list[dict[str, float]],
    window_size: int = 100,
) -> NDArray[np.float64]:
    """Compute tracking error (running mean of squared error).

    This is the key metric for evaluating continual learners:
    how well can the learner track the non-stationary target?

    Args:
        metrics_history: List of metric dictionaries from learning loop
        window_size: Size of the moving average window

    Returns:
        Array of tracking errors at each time step
    """
    errors = np.array([m["squared_error"] for m in metrics_history])
    return compute_running_mean(errors, window_size)

extract_metric(metrics_history, key)

Extract a single metric from the history.

Args: metrics_history: List of metric dictionaries key: Key to extract

Returns: Array of values for that metric

Source code in src/alberta_framework/utils/metrics.py
def extract_metric(
    metrics_history: list[dict[str, float]],
    key: str,
) -> NDArray[np.float64]:
    """Extract a single metric from the history.

    Args:
        metrics_history: List of metric dictionaries
        key: Key to extract

    Returns:
        Array of values for that metric
    """
    return np.array([m[key] for m in metrics_history])

format_duration(seconds)

Format a duration in seconds as a human-readable string.

Args: seconds: Duration in seconds

Returns: Formatted string like "1.23s", "2m 30.50s", or "1h 5m 30.00s"

Examples:

format_duration(0.5)   # Returns: '0.50s'
format_duration(90.5)  # Returns: '1m 30.50s'
format_duration(3665)  # Returns: '1h 1m 5.00s'
Source code in src/alberta_framework/utils/timing.py
def format_duration(seconds: float) -> str:
    """Format a duration in seconds as a human-readable string.

    Args:
        seconds: Duration in seconds

    Returns:
        Formatted string like "1.23s", "2m 30.50s", or "1h 5m 30.00s"

    Examples
    --------
    ```python
    format_duration(0.5)   # Returns: '0.50s'
    format_duration(90.5)  # Returns: '1m 30.50s'
    format_duration(3665)  # Returns: '1h 1m 5.00s'
    ```
    """
    if seconds < 60:
        return f"{seconds:.2f}s"
    elif seconds < 3600:
        minutes = int(seconds // 60)
        secs = seconds % 60
        return f"{minutes}m {secs:.2f}s"
    else:
        hours = int(seconds // 3600)
        remaining = seconds % 3600
        minutes = int(remaining // 60)
        secs = remaining % 60
        return f"{hours}h {minutes}m {secs:.2f}s"

collect_trajectory(env, policy, num_steps, mode=PredictionMode.REWARD, include_action_in_features=True, seed=0)

Collect a trajectory from a Gymnasium environment.

This uses a Python loop to interact with the environment and collects observations and targets into JAX arrays that can be used with scan-based learning.

Args:
    env: Gymnasium environment instance
    policy: Action selection function. If None, uses random policy
    num_steps: Number of steps to collect
    mode: What to predict (REWARD, NEXT_STATE, VALUE)
    include_action_in_features: If True, features = concat(obs, action)
    seed: Random seed for environment resets and random policy

Returns: Tuple of (observations, targets) as JAX arrays with shape (num_steps, feature_dim) and (num_steps, target_dim)

Source code in src/alberta_framework/streams/gymnasium.py
def collect_trajectory(
    env: gymnasium.Env[Any, Any],
    policy: Callable[[Array], Any] | None,
    num_steps: int,
    mode: PredictionMode = PredictionMode.REWARD,
    include_action_in_features: bool = True,
    seed: int = 0,
) -> tuple[Array, Array]:
    """Collect a trajectory from a Gymnasium environment.

    This uses a Python loop to interact with the environment and collects
    observations and targets into JAX arrays that can be used with scan-based
    learning.

    Args:
        env: Gymnasium environment instance
        policy: Action selection function. If None, uses random policy
        num_steps: Number of steps to collect
        mode: What to predict (REWARD, NEXT_STATE, VALUE)
        include_action_in_features: If True, features = concat(obs, action)
        seed: Random seed for environment resets and random policy

    Returns:
        Tuple of (observations, targets) as JAX arrays with shape
        (num_steps, feature_dim) and (num_steps, target_dim)
    """
    if policy is None:
        policy = make_random_policy(env, seed)

    observations = []
    targets = []

    reset_count = 0
    raw_obs, _ = env.reset(seed=seed + reset_count)
    reset_count += 1
    current_obs = _flatten_observation(raw_obs, env.observation_space)

    for _ in range(num_steps):
        action = policy(current_obs)
        flat_action = _flatten_action(action, env.action_space)

        raw_next_obs, reward, terminated, truncated, _ = env.step(action)
        next_obs = _flatten_observation(raw_next_obs, env.observation_space)

        # Construct features
        if include_action_in_features:
            features = jnp.concatenate([current_obs, flat_action])
        else:
            features = current_obs

        # Construct target based on mode
        if mode == PredictionMode.REWARD:
            target = jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))
        elif mode == PredictionMode.NEXT_STATE:
            target = next_obs
        else:  # VALUE mode
            # TD target with 0 bootstrap (simple version)
            target = jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))

        observations.append(features)
        targets.append(target)

        if terminated or truncated:
            raw_obs, _ = env.reset(seed=seed + reset_count)
            reset_count += 1
            current_obs = _flatten_observation(raw_obs, env.observation_space)
        else:
            current_obs = next_obs

    return jnp.stack(observations), jnp.stack(targets)

learn_from_trajectory(learner, observations, targets, learner_state=None)

Learn from a pre-collected trajectory using jax.lax.scan.

This is a JIT-compiled learning function that processes a trajectory collected from a Gymnasium environment.

Args:
    learner: The learner to train
    observations: Array of observations with shape (num_steps, feature_dim)
    targets: Array of targets with shape (num_steps, target_dim)
    learner_state: Initial state (if None, will be initialized)

Returns: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 3) with columns [squared_error, error, mean_step_size]

Source code in src/alberta_framework/streams/gymnasium.py
def learn_from_trajectory(
    learner: LinearLearner,
    observations: Array,
    targets: Array,
    learner_state: LearnerState | None = None,
) -> tuple[LearnerState, Array]:
    """Learn from a pre-collected trajectory using jax.lax.scan.

    This is a JIT-compiled learning function that processes a trajectory
    collected from a Gymnasium environment.

    Args:
        learner: The learner to train
        observations: Array of observations with shape (num_steps, feature_dim)
        targets: Array of targets with shape (num_steps, target_dim)
        learner_state: Initial state (if None, will be initialized)

    Returns:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 3) with columns [squared_error, error, mean_step_size]
    """
    if learner_state is None:
        learner_state = learner.init(observations.shape[1])

    def step_fn(state: LearnerState, inputs: tuple[Array, Array]) -> tuple[LearnerState, Array]:
        obs, target = inputs
        result = learner.update(state, obs, target)
        return result.state, result.metrics

    final_state, metrics = jax.lax.scan(step_fn, learner_state, (observations, targets))

    return final_state, metrics
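
A sketch pairing `collect_trajectory` with `learn_from_trajectory` on a standard Gymnasium task; the import path for the two helpers follows the stated source location, and "CartPole-v1" is just an illustrative environment choice.

```python
import gymnasium
from alberta_framework import LinearLearner, IDBD
from alberta_framework.streams.gymnasium import collect_trajectory, learn_from_trajectory

env = gymnasium.make("CartPole-v1")

# Collect 5000 (features, reward) pairs under the default random policy
observations, targets = collect_trajectory(env, policy=None, num_steps=5000, seed=0)

# Fit a linear reward predictor with IDBD step-size adaptation
learner = LinearLearner(optimizer=IDBD())
state, metrics = learn_from_trajectory(learner, observations, targets)

# metrics has shape (5000, 3): [squared_error, error, mean_step_size]
print(metrics[:, 0].mean())
```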

learn_from_trajectory_normalized(learner, observations, targets, learner_state=None)

Learn from a pre-collected trajectory with normalization using jax.lax.scan.

Args:
    learner: The normalized learner to train
    observations: Array of observations with shape (num_steps, feature_dim)
    targets: Array of targets with shape (num_steps, target_dim)
    learner_state: Initial state (if None, will be initialized)

Returns: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]

Source code in src/alberta_framework/streams/gymnasium.py
def learn_from_trajectory_normalized(
    learner: NormalizedLinearLearner,
    observations: Array,
    targets: Array,
    learner_state: NormalizedLearnerState | None = None,
) -> tuple[NormalizedLearnerState, Array]:
    """Learn from a pre-collected trajectory with normalization using jax.lax.scan.

    Args:
        learner: The normalized learner to train
        observations: Array of observations with shape (num_steps, feature_dim)
        targets: Array of targets with shape (num_steps, target_dim)
        learner_state: Initial state (if None, will be initialized)

    Returns:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
    """
    if learner_state is None:
        learner_state = learner.init(observations.shape[1])

    def step_fn(
        state: NormalizedLearnerState, inputs: tuple[Array, Array]
    ) -> tuple[NormalizedLearnerState, Array]:
        obs, target = inputs
        result = learner.update(state, obs, target)
        return result.state, result.metrics

    final_state, metrics = jax.lax.scan(step_fn, learner_state, (observations, targets))

    return final_state, metrics

make_epsilon_greedy_policy(base_policy, env, epsilon=0.1, seed=0)

Wrap a policy with epsilon-greedy exploration.

Args:
    base_policy: The greedy policy to wrap
    env: Gymnasium environment (for random action sampling)
    epsilon: Probability of taking a random action
    seed: Random seed

Returns: Epsilon-greedy policy

Source code in src/alberta_framework/streams/gymnasium.py
def make_epsilon_greedy_policy(
    base_policy: Callable[[Array], Any],
    env: gymnasium.Env[Any, Any],
    epsilon: float = 0.1,
    seed: int = 0,
) -> Callable[[Array], Any]:
    """Wrap a policy with epsilon-greedy exploration.

    Args:
        base_policy: The greedy policy to wrap
        env: Gymnasium environment (for random action sampling)
        epsilon: Probability of taking a random action
        seed: Random seed

    Returns:
        Epsilon-greedy policy
    """
    random_policy = make_random_policy(env, seed + 1)
    rng = jr.key(seed)

    def policy(obs: Array) -> Any:
        nonlocal rng
        rng, key = jr.split(rng)

        if jr.uniform(key) < epsilon:
            return random_policy(obs)
        return base_policy(obs)

    return policy
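
A brief sketch of wrapping a hand-written greedy policy; the `greedy_policy` below (push toward the pole's lean on CartPole) is purely illustrative, and the import path follows the stated source location.

```python
import gymnasium
import jax.numpy as jnp
from alberta_framework.streams.gymnasium import make_epsilon_greedy_policy

env = gymnasium.make("CartPole-v1")

def greedy_policy(obs):
    # Illustrative heuristic: action 1 (push right) when the pole leans right
    return int(obs[2] > 0.0)

policy = make_epsilon_greedy_policy(greedy_policy, env, epsilon=0.1, seed=0)
action = policy(jnp.zeros(4))  # random action with probability 0.1, otherwise greedy
```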

make_gymnasium_stream(env_id, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0, **env_kwargs)

Factory function to create a GymnasiumStream from an environment ID.

Args:
    env_id: Gymnasium environment ID (e.g., "CartPole-v1")
    mode: What to predict (REWARD, NEXT_STATE, VALUE)
    policy: Action selection function. If None, uses random policy
    gamma: Discount factor for VALUE mode
    include_action_in_features: If True, features = concat(obs, action)
    seed: Random seed
    **env_kwargs: Additional arguments passed to gymnasium.make()

Returns: GymnasiumStream wrapping the environment

Source code in src/alberta_framework/streams/gymnasium.py
def make_gymnasium_stream(
    env_id: str,
    mode: PredictionMode = PredictionMode.REWARD,
    policy: Callable[[Array], Any] | None = None,
    gamma: float = 0.99,
    include_action_in_features: bool = True,
    seed: int = 0,
    **env_kwargs: Any,
) -> GymnasiumStream:
    """Factory function to create a GymnasiumStream from an environment ID.

    Args:
        env_id: Gymnasium environment ID (e.g., "CartPole-v1")
        mode: What to predict (REWARD, NEXT_STATE, VALUE)
        policy: Action selection function. If None, uses random policy
        gamma: Discount factor for VALUE mode
        include_action_in_features: If True, features = concat(obs, action)
        seed: Random seed
        **env_kwargs: Additional arguments passed to gymnasium.make()

    Returns:
        GymnasiumStream wrapping the environment
    """
    import gymnasium

    env = gymnasium.make(env_id, **env_kwargs)
    return GymnasiumStream(
        env=env,
        mode=mode,
        policy=policy,
        gamma=gamma,
        include_action_in_features=include_action_in_features,
        seed=seed,
    )

make_random_policy(env, seed=0)

Create a random action policy for an environment.

Args: env: Gymnasium environment seed: Random seed

Returns: A callable that takes an observation and returns a random action

Source code in src/alberta_framework/streams/gymnasium.py
def make_random_policy(env: gymnasium.Env[Any, Any], seed: int = 0) -> Callable[[Array], Any]:
    """Create a random action policy for an environment.

    Args:
        env: Gymnasium environment
        seed: Random seed

    Returns:
        A callable that takes an observation and returns a random action
    """
    import gymnasium

    rng = jr.key(seed)
    action_space = env.action_space

    def policy(_obs: Array) -> Any:
        nonlocal rng
        rng, key = jr.split(rng)

        if isinstance(action_space, gymnasium.spaces.Discrete):
            return int(jr.randint(key, (), 0, int(action_space.n)))
        elif isinstance(action_space, gymnasium.spaces.Box):
            # Sample uniformly between low and high
            low = jnp.asarray(action_space.low, dtype=jnp.float32)
            high = jnp.asarray(action_space.high, dtype=jnp.float32)
            return jr.uniform(key, action_space.shape, minval=low, maxval=high)
        elif isinstance(action_space, gymnasium.spaces.MultiDiscrete):
            nvec = action_space.nvec
            return [int(jr.randint(jr.fold_in(key, i), (), 0, n)) for i, n in enumerate(nvec)]
        else:
            raise ValueError(f"Unsupported action space: {type(action_space).__name__}")

    return policy