core

Core components for the Alberta Framework.

LinearLearner(optimizer=None)

Linear function approximator with pluggable optimizer.

Computes predictions as: y = w @ x + b

The learner maintains weights and bias, delegating the adaptation of learning rates to the optimizer (e.g., LMS or IDBD).

This follows the Alberta Plan philosophy of temporal uniformity: every component updates at every time step.

Attributes: optimizer: The optimizer to use for weight updates

Args: optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
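A minimal usage sketch. The import paths below are inferred from the "Source code" locations cited on this page and may differ in your install; treat them as assumptions.

```python
import jax.numpy as jnp

from alberta_framework.core.learners import LinearLearner
from alberta_framework.core.optimizers import LMS

learner = LinearLearner(optimizer=LMS(step_size=0.01))
state = learner.init(feature_dim=4)

x = jnp.array([1.0, 0.5, -0.2, 0.0])
y_hat = learner.predict(state, x)            # 0.0 at init: weights and bias start at zero

result = learner.update(state, x, target=jnp.array(1.0))
state = result.state                         # carry the updated state forward
```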

Source code in src/alberta_framework/core/learners.py
def __init__(self, optimizer: AnyOptimizer | None = None):
    """Initialize the linear learner.

    Args:
        optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
    """
    self._optimizer: AnyOptimizer = optimizer or LMS(step_size=0.01)

init(feature_dim)

Initialize learner state.

Args: feature_dim: Dimension of the input feature vector

Returns: Initial learner state with zero weights and bias

Source code in src/alberta_framework/core/learners.py
def init(self, feature_dim: int) -> LearnerState:
    """Initialize learner state.

    Args:
        feature_dim: Dimension of the input feature vector

    Returns:
        Initial learner state with zero weights and bias
    """
    optimizer_state = self._optimizer.init(feature_dim)

    return LearnerState(
        weights=jnp.zeros(feature_dim, dtype=jnp.float32),
        bias=jnp.array(0.0, dtype=jnp.float32),
        optimizer_state=optimizer_state,
    )

predict(state, observation)

Compute prediction for an observation.

Args:
    state: Current learner state
    observation: Input feature vector

Returns: Scalar prediction y = w @ x + b

Source code in src/alberta_framework/core/learners.py
def predict(self, state: LearnerState, observation: Observation) -> Prediction:
    """Compute prediction for an observation.

    Args:
        state: Current learner state
        observation: Input feature vector

    Returns:
        Scalar prediction `y = w @ x + b`
    """
    return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)

update(state, observation, target)

Update learner given observation and target.

Performs one step of the learning algorithm:

1. Compute prediction
2. Compute error
3. Get weight updates from optimizer
4. Apply updates to weights and bias

Args:
    state: Current learner state
    observation: Input feature vector
    target: Desired output

Returns: UpdateResult with new state, prediction, error, and metrics

Source code in src/alberta_framework/core/learners.py
def update(
    self,
    state: LearnerState,
    observation: Observation,
    target: Target,
) -> UpdateResult:
    """Update learner given observation and target.

    Performs one step of the learning algorithm:
    1. Compute prediction
    2. Compute error
    3. Get weight updates from optimizer
    4. Apply updates to weights and bias

    Args:
        state: Current learner state
        observation: Input feature vector
        target: Desired output

    Returns:
        UpdateResult with new state, prediction, error, and metrics
    """
    # Make prediction
    prediction = self.predict(state, observation)

    # Compute error (target - prediction)
    error = jnp.squeeze(target) - jnp.squeeze(prediction)

    # Get update from optimizer
    # Note: type ignore needed because we can't statically prove optimizer_state
    # matches the optimizer's expected state type (though they will at runtime)
    opt_update = self._optimizer.update(
        state.optimizer_state,
        error,
        observation,
    )

    # Apply updates
    new_weights = state.weights + opt_update.weight_delta
    new_bias = state.bias + opt_update.bias_delta

    new_state = LearnerState(
        weights=new_weights,
        bias=new_bias,
        optimizer_state=opt_update.new_state,
    )

    # Pack metrics as array for scan compatibility
    # Format: [squared_error, error, mean_step_size (if adaptive)]
    squared_error = error**2
    mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)
    metrics = jnp.array([squared_error, error, mean_step_size], dtype=jnp.float32)

    return UpdateResult(
        state=new_state,
        prediction=prediction,
        error=jnp.atleast_1d(error),
        metrics=metrics,
    )
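Because the metrics are packed as a fixed-shape array, the whole update loop can be driven with `jax.lax.scan`. A sketch, assuming `LearnerState` is a JAX pytree (which the scan-compatibility comment above suggests) and the same assumed import paths as before:

```python
import jax
import jax.numpy as jnp

from alberta_framework.core.learners import LinearLearner
from alberta_framework.core.optimizers import LMS

learner = LinearLearner(optimizer=LMS(step_size=0.05))

def step(state, timestep):
    x, y = timestep
    result = learner.update(state, x, y)
    return result.state, result.metrics      # [squared_error, error, mean_step_size]

xs = jnp.ones((1000, 4))                      # toy observation stream
ys = jnp.full((1000,), 2.0)                   # toy target stream
final_state, metrics = jax.lax.scan(step, learner.init(4), (xs, ys))
print(metrics[-1, 0])                         # squared error on the final step
```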

TDLinearLearner(optimizer=None)

Linear function approximator for TD learning.

Computes value predictions as: V(s) = w @ φ(s) + b

The learner maintains weights, bias, and eligibility traces, delegating the adaptation of learning rates to the TD optimizer (e.g., TDIDBD).

This follows the Alberta Plan philosophy of temporal uniformity: every component updates at every time step.

Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"

Attributes: optimizer: The TD optimizer to use for weight updates

Args: optimizer: TD optimizer for weight updates. Defaults to TDIDBD()
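A minimal TD usage sketch, with the same caveat that the import paths are assumptions inferred from the cited source locations:

```python
import jax.numpy as jnp

from alberta_framework.core.learners import TDLinearLearner
from alberta_framework.core.optimizers import TDIDBD

learner = TDLinearLearner(optimizer=TDIDBD(trace_decay=0.9))
state = learner.init(feature_dim=4)

result = learner.update(
    state,
    observation=jnp.array([1.0, 0.0, 0.0, 0.0]),       # φ(s)
    reward=jnp.array(1.0),
    next_observation=jnp.array([0.0, 1.0, 0.0, 0.0]),  # φ(s')
    gamma=jnp.array(0.99),                              # pass 0.0 on terminal transitions
)
state = result.state
```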

Source code in src/alberta_framework/core/learners.py
def __init__(self, optimizer: AnyTDOptimizer | None = None):
    """Initialize the TD linear learner.

    Args:
        optimizer: TD optimizer for weight updates. Defaults to TDIDBD()
    """
    self._optimizer: AnyTDOptimizer = optimizer or TDIDBD()

init(feature_dim)

Initialize TD learner state.

Args: feature_dim: Dimension of the input feature vector

Returns: Initial TD learner state with zero weights and bias

Source code in src/alberta_framework/core/learners.py
def init(self, feature_dim: int) -> TDLearnerState:
    """Initialize TD learner state.

    Args:
        feature_dim: Dimension of the input feature vector

    Returns:
        Initial TD learner state with zero weights and bias
    """
    optimizer_state = self._optimizer.init(feature_dim)

    return TDLearnerState(
        weights=jnp.zeros(feature_dim, dtype=jnp.float32),
        bias=jnp.array(0.0, dtype=jnp.float32),
        optimizer_state=optimizer_state,
    )

predict(state, observation)

Compute value prediction for an observation.

Args:
    state: Current TD learner state
    observation: Input feature vector φ(s)

Returns: Scalar value prediction V(s) = w @ φ(s) + b

Source code in src/alberta_framework/core/learners.py
def predict(self, state: TDLearnerState, observation: Observation) -> Prediction:
    """Compute value prediction for an observation.

    Args:
        state: Current TD learner state
        observation: Input feature vector φ(s)

    Returns:
        Scalar value prediction `V(s) = w @ φ(s) + b`
    """
    return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)

update(state, observation, reward, next_observation, gamma)

Update learner given a TD transition.

Performs one step of TD learning:

1. Compute V(s) and V(s')
2. Compute TD error δ = R + γV(s') - V(s)
3. Get weight updates from TD optimizer
4. Apply updates to weights and bias

Args:
    state: Current TD learner state
    observation: Current observation φ(s)
    reward: Reward R received
    next_observation: Next observation φ(s')
    gamma: Discount factor γ (0 at terminal states)

Returns: TDUpdateResult with new state, prediction, TD error, and metrics

Source code in src/alberta_framework/core/learners.py
def update(
    self,
    state: TDLearnerState,
    observation: Observation,
    reward: Array,
    next_observation: Observation,
    gamma: Array,
) -> TDUpdateResult:
    """Update learner given a TD transition.

    Performs one step of TD learning:
    1. Compute V(s) and V(s')
    2. Compute TD error δ = R + γV(s') - V(s)
    3. Get weight updates from TD optimizer
    4. Apply updates to weights and bias

    Args:
        state: Current TD learner state
        observation: Current observation φ(s)
        reward: Reward R received
        next_observation: Next observation φ(s')
        gamma: Discount factor γ (0 at terminal states)

    Returns:
        TDUpdateResult with new state, prediction, TD error, and metrics
    """
    # Compute predictions
    prediction = self.predict(state, observation)
    next_prediction = self.predict(state, next_observation)

    # Compute TD error: δ = R + γV(s') - V(s)
    gamma_scalar = jnp.squeeze(gamma)
    td_error = (
        jnp.squeeze(reward)
        + gamma_scalar * jnp.squeeze(next_prediction)
        - jnp.squeeze(prediction)
    )

    # Get update from TD optimizer
    opt_update = self._optimizer.update(
        state.optimizer_state,
        td_error,
        observation,
        next_observation,
        gamma,
    )

    # Apply updates
    new_weights = state.weights + opt_update.weight_delta
    new_bias = state.bias + opt_update.bias_delta

    new_state = TDLearnerState(
        weights=new_weights,
        bias=new_bias,
        optimizer_state=opt_update.new_state,
    )

    # Pack metrics as array for scan compatibility
    # Format: [squared_td_error, td_error, mean_step_size, mean_eligibility_trace]
    squared_td_error = td_error**2
    mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)
    mean_elig_trace = opt_update.metrics.get("mean_eligibility_trace", 0.0)
    metrics = jnp.array(
        [squared_td_error, td_error, mean_step_size, mean_elig_trace],
        dtype=jnp.float32,
    )

    return TDUpdateResult(
        state=new_state,
        prediction=prediction,
        td_error=jnp.atleast_1d(td_error),
        metrics=metrics,
    )

TDUpdateResult

Result of a TD learner update step.

Attributes:
    state: Updated TD learner state
    prediction: Value prediction V(s) before update
    td_error: TD error δ = R + γV(s') - V(s)
    metrics: Array of metrics [squared_td_error, td_error, mean_step_size, ...]

IDBD(initial_step_size=0.01, meta_step_size=0.01)

Bases: Optimizer[IDBDState]

Incremental Delta-Bar-Delta optimizer.

IDBD maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation. When successive gradients agree in sign, the step-size for that weight increases. When they disagree, it decreases.

This implements Sutton's 1992 algorithm for adapting step-sizes online without requiring manual tuning.

Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta"

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate beta for adapting step-sizes

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate beta for adapting step-sizes
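A small sketch of calling the optimizer directly (import path assumed as above). On the first update all step-sizes still equal their initial value; repeated same-sign errors on a feature would then grow its step-size through the h trace:

```python
import jax.numpy as jnp

from alberta_framework.core.optimizers import IDBD

opt = IDBD(initial_step_size=0.05, meta_step_size=0.01)
opt_state = opt.init(feature_dim=3)

update = opt.update(opt_state, error=jnp.array(0.5), observation=jnp.array([1.0, 1.0, 0.0]))
print(update.metrics["mean_step_size"])  # 0.05: traces start at zero, so no adaptation yet
opt_state = update.new_state
```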

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
):
    """Initialize IDBD optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate beta for adapting step-sizes
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size

init(feature_dim)

Initialize IDBD state.

Args: feature_dim: Dimension of weight vector

Returns: IDBD state with per-weight step-sizes and traces

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> IDBDState:
    """Initialize IDBD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        IDBD state with per-weight step-sizes and traces
    """
    return IDBDState(
        log_step_sizes=jnp.full(
            feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        bias_step_size=jnp.array(self._initial_step_size, dtype=jnp.float32),
        bias_trace=jnp.array(0.0, dtype=jnp.float32),
    )

update(state, error, observation)

Compute IDBD weight update with adaptive step-sizes.

The IDBD algorithm:

1. Compute step-sizes: alpha_i = exp(log_alpha_i)
2. Update weights: w_i += alpha_i * error * x_i
3. Update log step-sizes: log_alpha_i += beta * error * x_i * h_i
4. Update traces: h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i

The trace h_i tracks the correlation between current and past gradients. When gradients consistently point in the same direction, h_i grows, leading to larger step-sizes.

Args:
    state: Current IDBD state
    error: Prediction error (scalar)
    observation: Feature vector

Returns: OptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: IDBDState,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute IDBD weight update with adaptive step-sizes.

    The IDBD algorithm:

    1. Compute step-sizes: `alpha_i = exp(log_alpha_i)`
    2. Update weights: `w_i += alpha_i * error * x_i`
    3. Update log step-sizes: `log_alpha_i += beta * error * x_i * h_i`
    4. Update traces: `h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i`

    The trace h_i tracks the correlation between current and past gradients.
    When gradients consistently point the same direction, h_i grows,
    leading to larger step-sizes.

    Args:
        state: Current IDBD state
        error: Prediction error (scalar)
        observation: Feature vector

    Returns:
        OptimizerUpdate with weight deltas and updated state
    """
    error_scalar = jnp.squeeze(error)
    beta = state.meta_step_size

    # Current step-sizes (exponentiate log values)
    alphas = jnp.exp(state.log_step_sizes)

    # Weight updates: alpha_i * error * x_i
    weight_delta = alphas * error_scalar * observation

    # Meta-update: adapt step-sizes based on gradient correlation
    # log_alpha_i += beta * error * x_i * h_i
    gradient_correlation = error_scalar * observation * state.traces
    new_log_step_sizes = state.log_step_sizes + beta * gradient_correlation

    # Clip log step-sizes to prevent numerical issues
    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)

    # Update traces: h_i = h_i * decay + alpha_i * error * x_i
    # decay = max(0, 1 - alpha_i * x_i^2)
    decay = jnp.maximum(0.0, 1.0 - alphas * observation**2)
    new_traces = state.traces * decay + alphas * error_scalar * observation

    # Bias updates (similar logic but scalar)
    bias_alpha = state.bias_step_size
    bias_delta = bias_alpha * error_scalar

    # Update bias step-size
    bias_gradient_correlation = error_scalar * state.bias_trace
    new_bias_step_size = bias_alpha * jnp.exp(beta * bias_gradient_correlation)
    new_bias_step_size = jnp.clip(new_bias_step_size, 1e-6, 1.0)

    # Update bias trace
    bias_decay = jnp.maximum(0.0, 1.0 - bias_alpha)
    new_bias_trace = state.bias_trace * bias_decay + bias_alpha * error_scalar

    new_state = IDBDState(
        log_step_sizes=new_log_step_sizes,
        traces=new_traces,
        meta_step_size=beta,
        bias_step_size=new_bias_step_size,
        bias_trace=new_bias_trace,
    )

    return OptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(alphas),
            "min_step_size": jnp.min(alphas),
            "max_step_size": jnp.max(alphas),
        },
    )

LMS(step_size=0.01)

Bases: Optimizer[LMSState]

Least Mean Square optimizer with fixed step-size.

The simplest gradient-based optimizer: w_{t+1} = w_t + alpha * delta * x_t

This serves as a baseline. The challenge is that the optimal step-size is problem-dependent and drifts over time when the task is non-stationary.

Attributes: step_size: Fixed learning rate alpha

Args: step_size: Fixed learning rate
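A worked instance of the rule (import path assumed as above):

```python
import jax.numpy as jnp

from alberta_framework.core.optimizers import LMS

opt = LMS(step_size=0.1)
opt_state = opt.init(feature_dim=2)

update = opt.update(opt_state, error=jnp.array(2.0), observation=jnp.array([1.0, -0.5]))
# weight_delta = 0.1 * 2.0 * [1.0, -0.5] = [0.2, -0.1]
# bias_delta   = 0.1 * 2.0 = 0.2
```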

Source code in src/alberta_framework/core/optimizers.py
def __init__(self, step_size: float = 0.01):
    """Initialize LMS optimizer.

    Args:
        step_size: Fixed learning rate
    """
    self._step_size = step_size

init(feature_dim)

Initialize LMS state.

Args: feature_dim: Dimension of weight vector (unused for LMS)

Returns: LMS state containing the step-size

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> LMSState:
    """Initialize LMS state.

    Args:
        feature_dim: Dimension of weight vector (unused for LMS)

    Returns:
        LMS state containing the step-size
    """
    return LMSState(step_size=jnp.array(self._step_size, dtype=jnp.float32))

update(state, error, observation)

Compute LMS weight update.

Update rule: delta_w = alpha * error * x

Args:
    state: Current LMS state
    error: Prediction error (scalar)
    observation: Feature vector

Returns: OptimizerUpdate with weight and bias deltas

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: LMSState,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute LMS weight update.

    Update rule: `delta_w = alpha * error * x`

    Args:
        state: Current LMS state
        error: Prediction error (scalar)
        observation: Feature vector

    Returns:
        OptimizerUpdate with weight and bias deltas
    """
    alpha = state.step_size
    error_scalar = jnp.squeeze(error)

    # Weight update: alpha * error * x
    weight_delta = alpha * error_scalar * observation

    # Bias update: alpha * error
    bias_delta = alpha * error_scalar

    return OptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=state,  # LMS state doesn't change
        metrics={"step_size": alpha},
    )

TDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, use_semi_gradient=True)

Bases: TDOptimizer[TDIDBDState]

TD-IDBD optimizer for temporal-difference learning.

Extends IDBD to TD learning with eligibility traces. Maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation in the TD setting.

Two variants are supported:

- Semi-gradient (default): uses only φ(s) in the meta-update; more stable
- Ordinary gradient: uses both φ(s) and φ(s'); more accurate but more sensitive

Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"

The semi-gradient TD-IDBD algorithm (Algorithm 3 in the paper):

1. Compute TD error: δ = R + γ*w^T*φ(s') - w^T*φ(s)
2. Update meta-weights: β_i += θ*δ*φ_i(s)*h_i
3. Compute step-sizes: α_i = exp(β_i)
4. Update eligibility traces: z_i = γ*λ*z_i + φ_i(s)
5. Update weights: w_i += α_i*δ*z_i
6. Update h traces: h_i = h_i*[1 - α_i*φ_i(s)*z_i]^+ + α_i*δ*z_i

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta
    trace_decay: Eligibility trace decay lambda
    use_semi_gradient: If True, use semi-gradient variant (default)

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay lambda (0 = TD(0))
    use_semi_gradient: If True, use semi-gradient variant (recommended)
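A sketch of constructing the two variants and calling the optimizer directly (import path assumed as above); in practice the optimizer is usually plugged into TDLinearLearner:

```python
import jax.numpy as jnp

from alberta_framework.core.optimizers import TDIDBD

semi = TDIDBD(meta_step_size=0.01, trace_decay=0.9)                      # Algorithm 3
ordinary = TDIDBD(meta_step_size=0.01, trace_decay=0.9,
                  use_semi_gradient=False)                               # Algorithm 4

opt_state = semi.init(feature_dim=2)
update = semi.update(
    opt_state,
    td_error=jnp.array(0.3),
    observation=jnp.array([1.0, 0.0]),       # φ(s)
    next_observation=jnp.array([0.0, 1.0]),  # φ(s')
    gamma=jnp.array(0.99),
)
opt_state = update.new_state
```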

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
    use_semi_gradient: bool = True,
):
    """Initialize TD-IDBD optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay lambda (0 = TD(0))
        use_semi_gradient: If True, use semi-gradient variant (recommended)
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
    self._trace_decay = trace_decay
    self._use_semi_gradient = use_semi_gradient

init(feature_dim)

Initialize TD-IDBD state.

Args: feature_dim: Dimension of weight vector

Returns: TD-IDBD state with per-weight step-sizes, traces, and h traces

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> TDIDBDState:
    """Initialize TD-IDBD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        TD-IDBD state with per-weight step-sizes, traces, and h traces
    """
    return TDIDBDState(
        log_step_sizes=jnp.full(
            feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(self._trace_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(self._initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
    )

update(state, td_error, observation, next_observation, gamma)

Compute TD-IDBD weight update with adaptive step-sizes.

Implements Algorithm 3 (semi-gradient) or Algorithm 4 (ordinary gradient) from Kearney et al. 2019.

Args:
    state: Current TD-IDBD state
    td_error: TD error δ = R + γV(s') - V(s)
    observation: Current observation φ(s)
    next_observation: Next observation φ(s')
    gamma: Discount factor γ (0 at terminal)

Returns: TDOptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: TDIDBDState,
    td_error: Array,
    observation: Array,
    next_observation: Array,
    gamma: Array,
) -> TDOptimizerUpdate:
    """Compute TD-IDBD weight update with adaptive step-sizes.

    Implements Algorithm 3 (semi-gradient) or Algorithm 4 (ordinary gradient)
    from Kearney et al. 2019.

    Args:
        state: Current TD-IDBD state
        td_error: TD error δ = R + γV(s') - V(s)
        observation: Current observation φ(s)
        next_observation: Next observation φ(s')
        gamma: Discount factor γ (0 at terminal)

    Returns:
        TDOptimizerUpdate with weight deltas and updated state
    """
    delta = jnp.squeeze(td_error)
    theta = state.meta_step_size
    lam = state.trace_decay
    gamma_scalar = jnp.squeeze(gamma)

    if self._use_semi_gradient:
        # Semi-gradient TD-IDBD (Algorithm 3)
        # β_i += θ*δ*φ_i(s)*h_i
        gradient_correlation = delta * observation * state.h_traces
        new_log_step_sizes = state.log_step_sizes + theta * gradient_correlation
    else:
        # Ordinary gradient TD-IDBD (Algorithm 4)
        # β_i -= θ*δ*[γ*φ_i(s') - φ_i(s)]*h_i
        # Note: negative sign because gradient direction is reversed
        feature_diff = gamma_scalar * next_observation - observation
        gradient_correlation = delta * feature_diff * state.h_traces
        new_log_step_sizes = state.log_step_sizes - theta * gradient_correlation

    # Clip log step-sizes to prevent numerical issues
    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)

    # Get updated step-sizes for weight update
    new_alphas = jnp.exp(new_log_step_sizes)

    # Update eligibility traces: z_i = γ*λ*z_i + φ_i(s)
    new_eligibility_traces = gamma_scalar * lam * state.eligibility_traces + observation

    # Compute weight delta: α_i*δ*z_i
    weight_delta = new_alphas * delta * new_eligibility_traces

    if self._use_semi_gradient:
        # Semi-gradient h update (Algorithm 3, line 9)
        # h_i = h_i*[1 - α_i*φ_i(s)*z_i]^+ + α_i*δ*z_i
        h_decay = jnp.maximum(0.0, 1.0 - new_alphas * observation * new_eligibility_traces)
        new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces
    else:
        # Ordinary gradient h update (Algorithm 4, line 9)
        # h_i = h_i*[1 + α_i*z_i*(γ*φ_i(s') - φ_i(s))]^+ + α_i*δ*z_i
        feature_diff = gamma_scalar * next_observation - observation
        h_decay = jnp.maximum(0.0, 1.0 + new_alphas * new_eligibility_traces * feature_diff)
        new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces

    # Bias updates (similar logic but scalar)
    if self._use_semi_gradient:
        # Semi-gradient bias meta-update
        bias_gradient_correlation = delta * state.bias_h_trace
        new_bias_log_step_size = state.bias_log_step_size + theta * bias_gradient_correlation
    else:
        # Ordinary gradient bias meta-update
        # For bias, φ(s) = 1, so feature_diff = γ - 1
        bias_feature_diff = gamma_scalar - 1.0
        bias_gradient_correlation = delta * bias_feature_diff * state.bias_h_trace
        new_bias_log_step_size = state.bias_log_step_size - theta * bias_gradient_correlation

    new_bias_log_step_size = jnp.clip(new_bias_log_step_size, -10.0, 2.0)
    new_bias_alpha = jnp.exp(new_bias_log_step_size)

    # Update bias eligibility trace
    new_bias_eligibility_trace = gamma_scalar * lam * state.bias_eligibility_trace + 1.0

    # Bias weight delta
    bias_delta = new_bias_alpha * delta * new_bias_eligibility_trace

    if self._use_semi_gradient:
        # Semi-gradient bias h update
        bias_h_decay = jnp.maximum(0.0, 1.0 - new_bias_alpha * new_bias_eligibility_trace)
        new_bias_h_trace = (
            state.bias_h_trace * bias_h_decay
            + new_bias_alpha * delta * new_bias_eligibility_trace
        )
    else:
        # Ordinary gradient bias h update
        bias_feature_diff = gamma_scalar - 1.0
        bias_h_decay = jnp.maximum(
            0.0, 1.0 + new_bias_alpha * new_bias_eligibility_trace * bias_feature_diff
        )
        new_bias_h_trace = (
            state.bias_h_trace * bias_h_decay
            + new_bias_alpha * delta * new_bias_eligibility_trace
        )

    new_state = TDIDBDState(
        log_step_sizes=new_log_step_sizes,
        eligibility_traces=new_eligibility_traces,
        h_traces=new_h_traces,
        meta_step_size=theta,
        trace_decay=lam,
        bias_log_step_size=new_bias_log_step_size,
        bias_eligibility_trace=new_bias_eligibility_trace,
        bias_h_trace=new_bias_h_trace,
    )

    return TDOptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(new_alphas),
            "min_step_size": jnp.min(new_alphas),
            "max_step_size": jnp.max(new_alphas),
            "mean_eligibility_trace": jnp.mean(jnp.abs(new_eligibility_traces)),
        },
    )

AutoTDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, normalizer_decay=10000.0)

Bases: TDOptimizer[AutoTDIDBDState]

AutoStep-style normalized TD-IDBD optimizer.

Adds AutoStep-style normalization to TDIDBD for improved stability and reduced sensitivity to the meta step-size theta. Includes:

1. Normalization of the meta-weight update by a running trace of recent updates
2. Effective step-size normalization to prevent overshooting

Reference: Kearney et al. 2019, Algorithm 6 "AutoStep Style Normalized TIDBD(λ)"

The AutoTDIDBD algorithm:

1. Compute TD error: δ = R + γ*w^T*φ(s') - w^T*φ(s)
2. Update normalizers: η_i = max(|δ*[γφ_i(s')-φ_i(s)]*h_i|, η_i - (1/τ)*α_i*[γφ_i(s')-φ_i(s)]*z_i*(|δ*φ_i(s)*h_i| - η_i))
3. Normalized meta-update: β_i -= θ*(1/η_i)*δ*[γφ_i(s')-φ_i(s)]*h_i
4. Effective step-size normalization: M = max(-exp(β)*[γφ(s')-φ(s)]^T*z, 1), then β_i -= log(M)
5. Update weights and traces as in TIDBD

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta
    trace_decay: Eligibility trace decay lambda
    normalizer_decay: Decay parameter tau for normalizers

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay lambda (0 = TD(0))
    normalizer_decay: Decay parameter tau for normalizers (default: 10000)
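Since AutoTDIDBD implements the same TDOptimizer interface, it should work as a drop-in replacement for TDIDBD in TDLinearLearner. A sketch (import paths assumed as above):

```python
from alberta_framework.core.learners import TDLinearLearner
from alberta_framework.core.optimizers import AutoTDIDBD

learner = TDLinearLearner(
    optimizer=AutoTDIDBD(meta_step_size=0.01, trace_decay=0.9, normalizer_decay=10000.0)
)
state = learner.init(feature_dim=8)
```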

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
    normalizer_decay: float = 10000.0,
):
    """Initialize AutoTDIDBD optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay lambda (0 = TD(0))
        normalizer_decay: Decay parameter tau for normalizers (default: 10000)
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
    self._trace_decay = trace_decay
    self._normalizer_decay = normalizer_decay

init(feature_dim)

Initialize AutoTDIDBD state.

Args: feature_dim: Dimension of weight vector

Returns: AutoTDIDBD state with per-weight step-sizes, traces, h traces, and normalizers

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> AutoTDIDBDState:
    """Initialize AutoTDIDBD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        AutoTDIDBD state with per-weight step-sizes, traces, h traces, and normalizers
    """
    return AutoTDIDBDState(
        log_step_sizes=jnp.full(
            feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(self._trace_decay, dtype=jnp.float32),
        normalizer_decay=jnp.array(self._normalizer_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(self._initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
    )

update(state, td_error, observation, next_observation, gamma)

Compute AutoTDIDBD weight update with normalized adaptive step-sizes.

Implements Algorithm 6 from Kearney et al. 2019.

Args:
    state: Current AutoTDIDBD state
    td_error: TD error δ = R + γV(s') - V(s)
    observation: Current observation φ(s)
    next_observation: Next observation φ(s')
    gamma: Discount factor γ (0 at terminal)

Returns: TDOptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: AutoTDIDBDState,
    td_error: Array,
    observation: Array,
    next_observation: Array,
    gamma: Array,
) -> TDOptimizerUpdate:
    """Compute AutoTDIDBD weight update with normalized adaptive step-sizes.

    Implements Algorithm 6 from Kearney et al. 2019.

    Args:
        state: Current AutoTDIDBD state
        td_error: TD error δ = R + γV(s') - V(s)
        observation: Current observation φ(s)
        next_observation: Next observation φ(s')
        gamma: Discount factor γ (0 at terminal)

    Returns:
        TDOptimizerUpdate with weight deltas and updated state
    """
    delta = jnp.squeeze(td_error)
    theta = state.meta_step_size
    lam = state.trace_decay
    tau = state.normalizer_decay
    gamma_scalar = jnp.squeeze(gamma)

    # Feature difference: γ*φ(s') - φ(s)
    feature_diff = gamma_scalar * next_observation - observation

    # Current step-sizes
    alphas = jnp.exp(state.log_step_sizes)

    # Update normalizers (Algorithm 6, lines 5-7)
    # η_i = max(|δ*[γφ_i(s')-φ_i(s)]*h_i|,
    #           η_i - (1/τ)*α_i*[γφ_i(s')-φ_i(s)]*z_i*(|δ*φ_i(s)*h_i| - η_i))
    abs_weight_update = jnp.abs(delta * feature_diff * state.h_traces)
    normalizer_decay_term = (
        (1.0 / tau)
        * alphas
        * feature_diff
        * state.eligibility_traces
        * (jnp.abs(delta * observation * state.h_traces) - state.normalizers)
    )
    new_normalizers = jnp.maximum(abs_weight_update, state.normalizers - normalizer_decay_term)
    # Ensure normalizers don't go to zero
    new_normalizers = jnp.maximum(new_normalizers, 1e-8)

    # Normalized meta-update (Algorithm 6, line 9)
    # β_i -= θ*(1/η_i)*δ*[γφ_i(s')-φ_i(s)]*h_i
    normalized_gradient = delta * feature_diff * state.h_traces / new_normalizers
    new_log_step_sizes = state.log_step_sizes - theta * normalized_gradient

    # Effective step-size normalization (Algorithm 6, lines 10-11)
    # M = max(-exp(β_i)*[γφ_i(s')-φ_i(s)]^T*z_i, 1)
    # β_i -= log(M)
    effective_step_size = -jnp.sum(
        jnp.exp(new_log_step_sizes) * feature_diff * state.eligibility_traces
    )
    normalization_factor = jnp.maximum(effective_step_size, 1.0)
    new_log_step_sizes = new_log_step_sizes - jnp.log(normalization_factor)

    # Clip log step-sizes to prevent numerical issues
    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)

    # Get updated step-sizes
    new_alphas = jnp.exp(new_log_step_sizes)

    # Update eligibility traces: z_i = γ*λ*z_i + φ_i(s)
    new_eligibility_traces = gamma_scalar * lam * state.eligibility_traces + observation

    # Compute weight delta: α_i*δ*z_i
    weight_delta = new_alphas * delta * new_eligibility_traces

    # Update h traces (ordinary gradient variant, Algorithm 6 line 15)
    # h_i = h_i*[1 + α_i*[γφ_i(s')-φ_i(s)]*z_i]^+ + α_i*δ*z_i
    h_decay = jnp.maximum(0.0, 1.0 + new_alphas * feature_diff * new_eligibility_traces)
    new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces

    # Bias updates
    bias_alpha = jnp.exp(state.bias_log_step_size)
    bias_feature_diff = gamma_scalar - 1.0  # For bias, φ(s) = 1

    # Bias normalizer update
    abs_bias_weight_update = jnp.abs(delta * bias_feature_diff * state.bias_h_trace)
    bias_normalizer_decay_term = (
        (1.0 / tau)
        * bias_alpha
        * bias_feature_diff
        * state.bias_eligibility_trace
        * (jnp.abs(delta * state.bias_h_trace) - state.bias_normalizer)
    )
    new_bias_normalizer = jnp.maximum(
        abs_bias_weight_update, state.bias_normalizer - bias_normalizer_decay_term
    )
    new_bias_normalizer = jnp.maximum(new_bias_normalizer, 1e-8)

    # Normalized bias meta-update
    normalized_bias_gradient = (
        delta * bias_feature_diff * state.bias_h_trace / new_bias_normalizer
    )
    new_bias_log_step_size = state.bias_log_step_size - theta * normalized_bias_gradient

    # Effective step-size normalization for bias
    bias_effective_step_size = (
        -jnp.exp(new_bias_log_step_size) * bias_feature_diff * state.bias_eligibility_trace
    )
    bias_norm_factor = jnp.maximum(bias_effective_step_size, 1.0)
    new_bias_log_step_size = new_bias_log_step_size - jnp.log(bias_norm_factor)

    new_bias_log_step_size = jnp.clip(new_bias_log_step_size, -10.0, 2.0)
    new_bias_alpha = jnp.exp(new_bias_log_step_size)

    # Update bias eligibility trace
    new_bias_eligibility_trace = gamma_scalar * lam * state.bias_eligibility_trace + 1.0

    # Bias weight delta
    bias_delta = new_bias_alpha * delta * new_bias_eligibility_trace

    # Bias h trace update
    bias_h_decay = jnp.maximum(
        0.0, 1.0 + new_bias_alpha * bias_feature_diff * new_bias_eligibility_trace
    )
    new_bias_h_trace = (
        state.bias_h_trace * bias_h_decay + new_bias_alpha * delta * new_bias_eligibility_trace
    )

    new_state = AutoTDIDBDState(
        log_step_sizes=new_log_step_sizes,
        eligibility_traces=new_eligibility_traces,
        h_traces=new_h_traces,
        normalizers=new_normalizers,
        meta_step_size=theta,
        trace_decay=lam,
        normalizer_decay=tau,
        bias_log_step_size=new_bias_log_step_size,
        bias_eligibility_trace=new_bias_eligibility_trace,
        bias_h_trace=new_bias_h_trace,
        bias_normalizer=new_bias_normalizer,
    )

    return TDOptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(new_alphas),
            "min_step_size": jnp.min(new_alphas),
            "max_step_size": jnp.max(new_alphas),
            "mean_eligibility_trace": jnp.mean(jnp.abs(new_eligibility_traces)),
            "mean_normalizer": jnp.mean(new_normalizers),
        },
    )

Optimizer

Bases: ABC

Base class for optimizers.

init(feature_dim) abstractmethod

Initialize optimizer state.

Args: feature_dim: Dimension of weight vector

Returns: Initial optimizer state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def init(self, feature_dim: int) -> StateT:
    """Initialize optimizer state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        Initial optimizer state
    """
    ...

update(state, error, observation) abstractmethod

Compute weight updates given prediction error.

Args:
    state: Current optimizer state
    error: Prediction error (target - prediction)
    observation: Current observation/feature vector

Returns: OptimizerUpdate with deltas and new state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def update(
    self,
    state: StateT,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute weight updates given prediction error.

    Args:
        state: Current optimizer state
        error: Prediction error (target - prediction)
        observation: Current observation/feature vector

    Returns:
        OptimizerUpdate with deltas and new state
    """
    ...
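A sketch of a custom optimizer built on this interface: a sign-SGD variant that steps by the sign of the gradient. It reuses LMSState purely as a holder for the fixed step-size; the class name SignSGD is hypothetical and the import path is an assumption based on this page:

```python
import jax.numpy as jnp
from jax import Array

from alberta_framework.core.optimizers import LMSState, Optimizer, OptimizerUpdate

class SignSGD(Optimizer[LMSState]):
    """Update weights by the sign of the per-weight gradient, scaled by a fixed step-size."""

    def __init__(self, step_size: float = 0.01):
        self._step_size = step_size

    def init(self, feature_dim: int) -> LMSState:
        return LMSState(step_size=jnp.array(self._step_size, dtype=jnp.float32))

    def update(self, state: LMSState, error: Array, observation: Array) -> OptimizerUpdate:
        error_scalar = jnp.squeeze(error)
        # The (negative) gradient of squared error w.r.t. the weights is error * x;
        # step a fixed distance along its sign.
        weight_delta = state.step_size * jnp.sign(error_scalar * observation)
        bias_delta = state.step_size * jnp.sign(error_scalar)
        return OptimizerUpdate(
            weight_delta=weight_delta,
            bias_delta=bias_delta,
            new_state=state,  # no adaptive state to update
            metrics={"step_size": state.step_size},
        )
```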

TDOptimizer

Bases: ABC

Base class for TD optimizers.

TD optimizers handle temporal-difference learning with eligibility traces. They take the TD error and both the current and next observations as input.

init(feature_dim) abstractmethod

Initialize optimizer state.

Args: feature_dim: Dimension of weight vector

Returns: Initial optimizer state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def init(self, feature_dim: int) -> StateT:
    """Initialize optimizer state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        Initial optimizer state
    """
    ...

update(state, td_error, observation, next_observation, gamma) abstractmethod

Compute weight updates given TD error.

Args:
    state: Current optimizer state
    td_error: TD error δ = R + γV(s') - V(s)
    observation: Current observation φ(s)
    next_observation: Next observation φ(s')
    gamma: Discount factor γ (0 at terminal)

Returns: TDOptimizerUpdate with deltas and new state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def update(
    self,
    state: StateT,
    td_error: Array,
    observation: Array,
    next_observation: Array,
    gamma: Array,
) -> TDOptimizerUpdate:
    """Compute weight updates given TD error.

    Args:
        state: Current optimizer state
        td_error: TD error δ = R + γV(s') - V(s)
        observation: Current observation φ(s)
        next_observation: Next observation φ(s')
        gamma: Discount factor γ (0 at terminal)

    Returns:
        TDOptimizerUpdate with deltas and new state
    """
    ...

TDOptimizerUpdate

Result of a TD optimizer update step.

Attributes:
    weight_delta: Change to apply to weights
    bias_delta: Change to apply to bias
    new_state: Updated optimizer state
    metrics: Dictionary of metrics for logging

AutoTDIDBDState

State for the AutoTDIDBD optimizer.

AutoTDIDBD adds AutoStep-style normalization to TDIDBD for improved stability. Includes normalizers for the meta-weight updates and effective step-size normalization to prevent overshooting.

Reference: Kearney et al. 2019, Algorithm 6

Attributes:
    log_step_sizes: Log of per-weight step-sizes (log alpha_i)
    eligibility_traces: Eligibility traces z_i
    h_traces: Per-weight h traces for gradient correlation
    normalizers: Running max of absolute gradient correlations (eta_i)
    meta_step_size: Meta learning rate theta
    trace_decay: Eligibility trace decay parameter lambda
    normalizer_decay: Decay parameter tau for normalizers
    bias_log_step_size: Log step-size for the bias term
    bias_eligibility_trace: Eligibility trace for the bias
    bias_h_trace: h trace for the bias term
    bias_normalizer: Normalizer for the bias gradient correlation

IDBDState

State for the IDBD (Incremental Delta-Bar-Delta) optimizer.

IDBD maintains per-weight adaptive step-sizes that are meta-learned based on the correlation of successive gradients.

Reference: Sutton 1992, "Adapting Bias by Gradient Descent"

Attributes:
    log_step_sizes: Log of per-weight step-sizes (log alpha_i)
    traces: Per-weight traces h_i for gradient correlation
    meta_step_size: Meta learning rate beta for adapting step-sizes
    bias_step_size: Step-size for the bias term
    bias_trace: Trace for the bias term

LearnerState

State for a linear learner.

Attributes:
    weights: Weight vector for linear prediction
    bias: Bias term
    optimizer_state: State maintained by the optimizer

LMSState

State for the LMS (Least Mean Square) optimizer.

LMS uses a fixed step-size, so state only tracks the step-size parameter.

Attributes: step_size: Fixed learning rate alpha

TDIDBDState

State for the TD-IDBD (Temporal-Difference IDBD) optimizer.

TD-IDBD extends IDBD to temporal-difference learning with eligibility traces. Maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation in the TD setting.

Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"

Attributes:
    log_step_sizes: Log of per-weight step-sizes (log alpha_i)
    eligibility_traces: Eligibility traces z_i for temporal credit assignment
    h_traces: Per-weight h traces for gradient correlation
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay parameter lambda
    bias_log_step_size: Log step-size for the bias term
    bias_eligibility_trace: Eligibility trace for the bias
    bias_h_trace: h trace for the bias term

TDLearnerState

State for a TD linear learner.

Attributes:
    weights: Weight vector for linear value function approximation
    bias: Bias term
    optimizer_state: State maintained by the TD optimizer

TDTimeStep

Single experience from a TD stream.

Represents a transition (s, r, s', gamma) for temporal-difference learning.

Attributes:
    observation: Feature vector φ(s)
    reward: Reward R received
    next_observation: Feature vector φ(s')
    gamma: Discount factor γ_t (0 at terminal states)

TimeStep

Single experience from an experience stream.

Attributes:
    observation: Feature vector x_t
    target: Desired output y*_t (for supervised learning)
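A construction sketch for the record types. The module path and the keyword-argument constructors are assumptions inferred from the attribute listings above:

```python
import jax.numpy as jnp

# Module path assumed; these records live somewhere under alberta_framework.core.
from alberta_framework.core.types import TDTimeStep, TimeStep

step = TimeStep(
    observation=jnp.array([1.0, 0.0]),       # x_t
    target=jnp.array(0.5),                   # y*_t
)

td_step = TDTimeStep(
    observation=jnp.array([1.0, 0.0]),       # φ(s)
    reward=jnp.array(1.0),                   # R
    next_observation=jnp.array([0.0, 1.0]),  # φ(s')
    gamma=jnp.array(0.99),                   # 0.0 at terminal states
)
```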