core

Core components for the Alberta Framework.

BatchedHordeResult

Result from batched Horde learning loop.

Attributes:
    states: Batched multi-head MLP learner states
    per_demon_metrics: Per-demon metrics, shape (n_seeds, num_steps, n_demons, 3)
    td_errors: TD errors, shape (n_seeds, num_steps, n_demons)

HordeLearner(horde_spec, hidden_sizes=(128, 128), optimizer=None, step_size=1.0, bounder=None, normalizer=None, sparsity=0.9, leaky_relu_slope=0.01, use_layer_norm=True, head_optimizer=None)

Horde: GVF demons sharing a trunk (Sutton et al. 2011).

Wraps MultiHeadMLPLearner and adds:

    - Per-demon gamma/lambda from HordeSpec
    - TD target computation for temporal demons (gamma > 0)
    - GVF metadata

The trunk uses gamma=0, lamda=0 (no temporal trace decay on shared features). Each head uses its own gamma * lambda product for trace decay, set via per_head_gamma_lamda on the inner learner.

For all-gamma=0 Hordes (e.g. rlsecd's 5 prediction heads), this produces identical results to MultiHeadMLPLearner since the TD target reduces to just the cumulant.
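Concretely, the target computation reduces to elementwise arithmetic over the demon axis. A minimal NumPy sketch with made-up values (illustrative only, not framework code):

```python
import numpy as np

# Hypothetical per-demon discounts: two temporal demons and one gamma=0 demon.
gammas = np.array([0.9, 0.5, 0.0])
cumulants = np.array([1.0, -2.0, 3.0])   # per-demon pseudo-rewards
next_preds = np.array([10.0, 4.0, 7.0])  # V(s') from each head

# TD targets: r + gamma * V(s'), computed per demon in one vectorized step.
targets = cumulants + gammas * next_preds
# For the gamma=0 demon, the target is just its cumulant (targets[2] == 3.0),
# which is why an all-gamma=0 Horde matches MultiHeadMLPLearner exactly.
```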

Single-Step (Daemon) Usage

Both predict() and update() work with single unbatched observations (1D arrays). JIT-compiled automatically.

Attributes:
    horde_spec: The HordeSpec defining all demons
    n_demons: Number of demons (heads)

Args:
    horde_spec: Specification of all GVF demons
    hidden_sizes: Tuple of hidden layer sizes (default: two layers of 128)
    optimizer: Optimizer for weight updates. Defaults to LMS(step_size).
    step_size: Base learning rate (used only when optimizer is None)
    bounder: Optional update bounder (e.g. ObGDBounding)
    normalizer: Optional feature normalizer
    sparsity: Fraction of weights zeroed out per neuron (default: 0.9)
    leaky_relu_slope: Negative slope for LeakyReLU (default: 0.01)
    use_layer_norm: Whether to apply parameterless layer normalization
    head_optimizer: Optional separate optimizer for heads

Source code in src/alberta_framework/core/horde.py
def __init__(
    self,
    horde_spec: HordeSpec,
    hidden_sizes: tuple[int, ...] = (128, 128),
    optimizer: AnyOptimizer | None = None,
    step_size: float = 1.0,
    bounder: Bounder | None = None,
    normalizer: (
        Normalizer[EMANormalizerState] | Normalizer[WelfordNormalizerState] | None
    ) = None,
    sparsity: float = 0.9,
    leaky_relu_slope: float = 0.01,
    use_layer_norm: bool = True,
    head_optimizer: AnyOptimizer | None = None,
):
    """Initialize the Horde learner.

    Args:
        horde_spec: Specification of all GVF demons
        hidden_sizes: Tuple of hidden layer sizes (default: two layers of 128)
        optimizer: Optimizer for weight updates. Defaults to LMS(step_size).
        step_size: Base learning rate (used only when optimizer is None)
        bounder: Optional update bounder (e.g. ObGDBounding)
        normalizer: Optional feature normalizer
        sparsity: Fraction of weights zeroed out per neuron (default: 0.9)
        leaky_relu_slope: Negative slope for LeakyReLU (default: 0.01)
        use_layer_norm: Whether to apply parameterless layer normalization
        head_optimizer: Optional separate optimizer for heads
    """
    self._horde_spec = horde_spec
    self._hidden_sizes = hidden_sizes
    self._step_size = step_size
    self._sparsity = sparsity
    self._leaky_relu_slope = leaky_relu_slope
    self._use_layer_norm = use_layer_norm

    # Compute per-head gamma*lambda products
    per_head_gl = tuple(
        float(d.gamma * d.lamda) for d in horde_spec.demons
    )

    self._learner = MultiHeadMLPLearner(
        n_heads=len(horde_spec.demons),
        hidden_sizes=hidden_sizes,
        optimizer=optimizer,
        step_size=step_size,
        bounder=bounder,
        gamma=0.0,  # trunk: no trace decay
        lamda=0.0,
        normalizer=normalizer,
        sparsity=sparsity,
        leaky_relu_slope=leaky_relu_slope,
        use_layer_norm=use_layer_norm,
        head_optimizer=head_optimizer,
        per_head_gamma_lamda=per_head_gl,
    )

horde_spec property

The HordeSpec defining all demons.

n_demons property

Number of demons (heads).

learner property

The underlying MultiHeadMLPLearner.

to_config()

Serialize learner configuration to dict.

Returns: Dict with horde_spec and all MultiHeadMLPLearner constructor args.

Source code in src/alberta_framework/core/horde.py
def to_config(self) -> dict[str, Any]:
    """Serialize learner configuration to dict.

    Returns:
        Dict with horde_spec and all MultiHeadMLPLearner constructor args.
    """
    learner_config = self._learner.to_config()
    # Remove fields managed by HordeLearner
    learner_config.pop("type", None)
    learner_config.pop("n_heads", None)
    learner_config.pop("gamma", None)
    learner_config.pop("lamda", None)
    learner_config.pop("per_head_gamma_lamda", None)

    return {
        "type": "HordeLearner",
        "horde_spec": self._horde_spec.to_config(),
        **learner_config,
    }

from_config(config) classmethod

Reconstruct from config dict.

Args: config: Dict as produced by to_config()

Returns: Reconstructed HordeLearner

Source code in src/alberta_framework/core/horde.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "HordeLearner":
    """Reconstruct from config dict.

    Args:
        config: Dict as produced by ``to_config()``

    Returns:
        Reconstructed HordeLearner
    """
    from alberta_framework.core.normalizers import normalizer_from_config
    from alberta_framework.core.optimizers import (
        bounder_from_config,
        optimizer_from_config,
    )

    config = dict(config)
    config.pop("type", None)

    horde_spec = HordeSpec.from_config(config.pop("horde_spec"))
    optimizer = optimizer_from_config(config.pop("optimizer"))
    bounder_cfg = config.pop("bounder", None)
    bounder = bounder_from_config(bounder_cfg) if bounder_cfg is not None else None
    normalizer_cfg = config.pop("normalizer", None)
    normalizer = (
        normalizer_from_config(normalizer_cfg) if normalizer_cfg is not None else None
    )
    head_opt_cfg = config.pop("head_optimizer", None)
    head_optimizer = (
        optimizer_from_config(head_opt_cfg) if head_opt_cfg is not None else None
    )

    return cls(
        horde_spec=horde_spec,
        hidden_sizes=tuple(config.pop("hidden_sizes")),
        optimizer=optimizer,
        bounder=bounder,
        normalizer=normalizer,
        head_optimizer=head_optimizer,
        **config,
    )

init(feature_dim, key)

Initialize Horde learner state.

Args:
    feature_dim: Dimension of the input feature vector
    key: JAX random key for weight initialization

Returns: Initial MultiHeadMLPState

Source code in src/alberta_framework/core/horde.py
def init(self, feature_dim: int, key: Array) -> MultiHeadMLPState:
    """Initialize Horde learner state.

    Args:
        feature_dim: Dimension of the input feature vector
        key: JAX random key for weight initialization

    Returns:
        Initial MultiHeadMLPState
    """
    return self._learner.init(feature_dim, key)

predict(state, observation)

Compute predictions from all demons.

Args:
    state: Current learner state
    observation: Input feature vector

Returns: Array of shape (n_demons,) with one prediction per demon

Source code in src/alberta_framework/core/horde.py
@functools.partial(jax.jit, static_argnums=(0,))
def predict(self, state: MultiHeadMLPState, observation: Array) -> Array:
    """Compute predictions from all demons.

    Args:
        state: Current learner state
        observation: Input feature vector

    Returns:
        Array of shape ``(n_demons,)`` with one prediction per demon
    """
    return self._learner.predict(state, observation)  # type: ignore[no-any-return]

update(state, observation, cumulants, next_observation)

Update Horde given observation, cumulants, and next observation.

Computes TD targets r + gamma * V(s') for each demon, then delegates to MultiHeadMLPLearner.update(). For gamma=0 demons, the target equals the cumulant.

Args:
    state: Current state
    observation: Input feature vector, shape (feature_dim,)
    cumulants: Per-demon pseudo-rewards, shape (n_demons,). NaN = inactive demon.
    next_observation: Next feature vector, shape (feature_dim,). Used for V(s') bootstrapping. For all-gamma=0 Hordes, this is required but doesn't affect results.

Returns: HordeUpdateResult with updated state, predictions, TD errors, TD targets, and per-demon metrics

Source code in src/alberta_framework/core/horde.py
@functools.partial(jax.jit, static_argnums=(0,))
def update(
    self,
    state: MultiHeadMLPState,
    observation: Array,
    cumulants: Array,
    next_observation: Array,
) -> HordeUpdateResult:
    """Update Horde given observation, cumulants, and next observation.

    Computes TD targets ``r + gamma * V(s')`` for each demon, then
    delegates to ``MultiHeadMLPLearner.update()``. For gamma=0 demons,
    the target equals the cumulant.

    Args:
        state: Current state
        observation: Input feature vector, shape ``(feature_dim,)``
        cumulants: Per-demon pseudo-rewards, shape ``(n_demons,)``.
            NaN = inactive demon.
        next_observation: Next feature vector, shape ``(feature_dim,)``.
            Used for V(s') bootstrapping. For all-gamma=0 Hordes,
            this is required but doesn't affect results.

    Returns:
        HordeUpdateResult with updated state, predictions, TD errors,
        TD targets, and per-demon metrics
    """
    # 1. Compute V(s') for bootstrapping
    next_preds = self._learner.predict(state, next_observation)

    # 2. TD targets: r + gamma * V(s')
    # For gamma=0 demons: target = cumulant (single-step prediction)
    # NaN cumulants stay NaN (inactive demons)
    gammas = self._horde_spec.gammas
    targets = cumulants + gammas * next_preds

    # 3. Delegate to MultiHeadMLPLearner
    result = self._learner.update(state, observation, targets)

    return HordeUpdateResult(  # type: ignore[call-arg]
        state=result.state,
        predictions=result.predictions,
        td_errors=result.errors,
        td_targets=targets,
        per_demon_metrics=result.per_head_metrics,
        trunk_bounding_metric=result.trunk_bounding_metric,
    )

HordeLearningResult

Result from a Horde scan-based learning loop.

Attributes:
    state: Final multi-head MLP learner state
    per_demon_metrics: Per-demon metrics over time, shape (num_steps, n_demons, 3)
    td_errors: TD errors over time, shape (num_steps, n_demons)

HordeUpdateResult

Result of a single Horde update step.

Attributes:
    state: Updated multi-head MLP learner state
    predictions: Predictions from all demons, shape (n_demons,)
    td_errors: TD errors (target - prediction), shape (n_demons,). NaN for inactive demons.
    td_targets: Computed TD targets r + gamma * V(s'), shape (n_demons,). NaN for inactive demons.
    per_demon_metrics: Per-demon metrics, shape (n_demons, 3). Columns: [squared_error, raw_error, mean_step_size]. NaN for inactive demons.
    trunk_bounding_metric: Scalar trunk bounding metric

LinearLearner(optimizer=None, normalizer=None)

Linear function approximator with pluggable optimizer and optional normalizer.

Computes predictions as: y = w @ x + b

The learner maintains weights and bias, delegating the adaptation of learning rates to the optimizer (e.g., LMS or IDBD).

This follows the Alberta Plan philosophy of temporal uniformity: every component updates at every time step.

Attributes:
    optimizer: The optimizer to use for weight updates
    normalizer: Optional online feature normalizer

Args:
    optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
    normalizer: Optional feature normalizer (e.g. EMANormalizer, WelfordNormalizer)

Source code in src/alberta_framework/core/learners.py
def __init__(
    self,
    optimizer: AnyOptimizer | None = None,
    normalizer: (
        Normalizer[EMANormalizerState] | Normalizer[WelfordNormalizerState] | None
    ) = None,
):
    """Initialize the linear learner.

    Args:
        optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
        normalizer: Optional feature normalizer (e.g. EMANormalizer, WelfordNormalizer)
    """
    self._optimizer: AnyOptimizer = optimizer or LMS(step_size=0.01)
    self._normalizer = normalizer

normalizer property

The feature normalizer, or None if normalization is disabled.

init(feature_dim)

Initialize learner state.

Args: feature_dim: Dimension of the input feature vector

Returns: Initial learner state with zero weights and bias

Source code in src/alberta_framework/core/learners.py
def init(self, feature_dim: int) -> LearnerState:
    """Initialize learner state.

    Args:
        feature_dim: Dimension of the input feature vector

    Returns:
        Initial learner state with zero weights and bias
    """
    optimizer_state = self._optimizer.init(feature_dim)

    normalizer_state = None
    if self._normalizer is not None:
        normalizer_state = self._normalizer.init(feature_dim)

    return LearnerState(
        weights=jnp.zeros(feature_dim, dtype=jnp.float32),
        bias=jnp.array(0.0, dtype=jnp.float32),
        optimizer_state=optimizer_state,
        normalizer_state=normalizer_state,
        step_count=jnp.array(0, dtype=jnp.int32),
        birth_timestamp=time.time(),
        uptime_s=0.0,
    )

predict(state, observation)

Compute prediction for an observation.

Args:
    state: Current learner state
    observation: Input feature vector

Returns: Scalar prediction y = w @ x + b

Source code in src/alberta_framework/core/learners.py
def predict(self, state: LearnerState, observation: Observation) -> Prediction:
    """Compute prediction for an observation.

    Args:
        state: Current learner state
        observation: Input feature vector

    Returns:
        Scalar prediction ``y = w @ x + b``
    """
    return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)

update(state, observation, target)

Update learner given observation and target.

Performs one step of the learning algorithm:

    1. Optionally normalize observation
    2. Compute prediction
    3. Compute error
    4. Get weight updates from optimizer
    5. Apply updates to weights and bias

Args:
    state: Current learner state
    observation: Input feature vector
    target: Desired output

Returns: UpdateResult with new state, prediction, error, and metrics

Source code in src/alberta_framework/core/learners.py
def update(
    self,
    state: LearnerState,
    observation: Observation,
    target: Target,
) -> UpdateResult:
    """Update learner given observation and target.

    Performs one step of the learning algorithm:
    1. Optionally normalize observation
    2. Compute prediction
    3. Compute error
    4. Get weight updates from optimizer
    5. Apply updates to weights and bias

    Args:
        state: Current learner state
        observation: Input feature vector
        target: Desired output

    Returns:
        UpdateResult with new state, prediction, error, and metrics
    """
    # Handle normalization
    new_normalizer_state = state.normalizer_state
    obs = observation
    if self._normalizer is not None and state.normalizer_state is not None:
        obs, new_normalizer_state = self._normalizer.normalize(
            state.normalizer_state, observation
        )

    # Make prediction
    prediction = self.predict(
        LearnerState(
            weights=state.weights,
            bias=state.bias,
            optimizer_state=state.optimizer_state,
            normalizer_state=new_normalizer_state,
            step_count=state.step_count,
            birth_timestamp=state.birth_timestamp,
            uptime_s=state.uptime_s,
        ),
        obs,
    )

    # Compute error (target - prediction)
    error = jnp.squeeze(target) - jnp.squeeze(prediction)

    # Get update from optimizer
    opt_update = self._optimizer.update(
        state.optimizer_state,
        error,
        obs,
    )

    # Apply updates
    new_weights = state.weights + opt_update.weight_delta
    new_bias = state.bias + opt_update.bias_delta

    new_state = LearnerState(
        weights=new_weights,
        bias=new_bias,
        optimizer_state=opt_update.new_state,
        normalizer_state=new_normalizer_state,
        step_count=state.step_count + 1,
        birth_timestamp=state.birth_timestamp,
        uptime_s=state.uptime_s,
    )

    # Pack metrics as array for scan compatibility
    squared_error = error**2
    mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)

    if self._normalizer is not None and new_normalizer_state is not None:
        normalizer_mean_var = jnp.mean(new_normalizer_state.var)
        metrics = jnp.array(
            [squared_error, error, mean_step_size, normalizer_mean_var],
            dtype=jnp.float32,
        )
    else:
        metrics = jnp.array(
            [squared_error, error, mean_step_size], dtype=jnp.float32
        )

    return UpdateResult(
        state=new_state,
        prediction=prediction,
        error=jnp.atleast_1d(error),
        metrics=metrics,
    )
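For intuition, the whole predict/error/update cycle can be sketched standalone in NumPy. This is illustrative only (the lms_update helper is hypothetical): it uses a fixed LMS step-size in place of the pluggable optimizer and omits the normalizer and metrics packing.

```python
import numpy as np

def lms_update(w, b, x, target, step_size=0.1):
    """One predict/error/update cycle for a linear learner y = w @ x + b."""
    prediction = w @ x + b           # 2. compute prediction
    error = target - prediction      # 3. compute error (target - prediction)
    w = w + step_size * error * x    # 4./5. LMS weight delta, applied
    b = b + step_size * error        #      bias delta
    return w, b, error

w, b = np.zeros(3), 0.0
x = np.array([1.0, 0.5, -1.0])
for _ in range(500):
    w, b, error = lms_update(w, b, x, target=2.0)
# error shrinks toward zero as w @ x + b approaches the target
```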

TDLinearLearner(optimizer=None)

Linear function approximator for TD learning.

Computes value predictions as: V(s) = w @ phi(s) + b

The learner maintains weights, bias, and eligibility traces, delegating the adaptation of learning rates to the TD optimizer (e.g., TDIDBD).

This follows the Alberta Plan philosophy of temporal uniformity: every component updates at every time step.

Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"

Attributes: optimizer: The TD optimizer to use for weight updates

Args: optimizer: TD optimizer for weight updates. Defaults to TDIDBD()

Source code in src/alberta_framework/core/learners.py
def __init__(self, optimizer: AnyTDOptimizer | None = None):
    """Initialize the TD linear learner.

    Args:
        optimizer: TD optimizer for weight updates. Defaults to TDIDBD()
    """
    self._optimizer: AnyTDOptimizer = optimizer or TDIDBD()

init(feature_dim)

Initialize TD learner state.

Args: feature_dim: Dimension of the input feature vector

Returns: Initial TD learner state with zero weights and bias

Source code in src/alberta_framework/core/learners.py
def init(self, feature_dim: int) -> TDLearnerState:
    """Initialize TD learner state.

    Args:
        feature_dim: Dimension of the input feature vector

    Returns:
        Initial TD learner state with zero weights and bias
    """
    optimizer_state = self._optimizer.init(feature_dim)

    return TDLearnerState(
        weights=jnp.zeros(feature_dim, dtype=jnp.float32),
        bias=jnp.array(0.0, dtype=jnp.float32),
        optimizer_state=optimizer_state,
        step_count=jnp.array(0, dtype=jnp.int32),
        birth_timestamp=time.time(),
        uptime_s=0.0,
    )

predict(state, observation)

Compute value prediction for an observation.

Args:
    state: Current TD learner state
    observation: Input feature vector phi(s)

Returns: Scalar value prediction V(s) = w @ phi(s) + b

Source code in src/alberta_framework/core/learners.py
def predict(self, state: TDLearnerState, observation: Observation) -> Prediction:
    """Compute value prediction for an observation.

    Args:
        state: Current TD learner state
        observation: Input feature vector phi(s)

    Returns:
        Scalar value prediction ``V(s) = w @ phi(s) + b``
    """
    return jnp.atleast_1d(jnp.dot(state.weights, observation) + state.bias)

update(state, observation, reward, next_observation, gamma)

Update learner given a TD transition.

Performs one step of TD learning:

    1. Compute V(s) and V(s')
    2. Compute TD error delta = R + gamma*V(s') - V(s)
    3. Get weight updates from TD optimizer
    4. Apply updates to weights and bias

Args:
    state: Current TD learner state
    observation: Current observation phi(s)
    reward: Reward R received
    next_observation: Next observation phi(s')
    gamma: Discount factor gamma (0 at terminal states)

Returns: TDUpdateResult with new state, prediction, TD error, and metrics

Source code in src/alberta_framework/core/learners.py
def update(
    self,
    state: TDLearnerState,
    observation: Observation,
    reward: Array,
    next_observation: Observation,
    gamma: Array,
) -> TDUpdateResult:
    """Update learner given a TD transition.

    Performs one step of TD learning:
    1. Compute V(s) and V(s')
    2. Compute TD error delta = R + gamma*V(s') - V(s)
    3. Get weight updates from TD optimizer
    4. Apply updates to weights and bias

    Args:
        state: Current TD learner state
        observation: Current observation phi(s)
        reward: Reward R received
        next_observation: Next observation phi(s')
        gamma: Discount factor gamma (0 at terminal states)

    Returns:
        TDUpdateResult with new state, prediction, TD error, and metrics
    """
    # Compute predictions
    prediction = self.predict(state, observation)
    next_prediction = self.predict(state, next_observation)

    # Compute TD error: delta = R + gamma*V(s') - V(s)
    gamma_scalar = jnp.squeeze(gamma)
    td_error = (
        jnp.squeeze(reward)
        + gamma_scalar * jnp.squeeze(next_prediction)
        - jnp.squeeze(prediction)
    )

    # Get update from TD optimizer
    opt_update = self._optimizer.update(
        state.optimizer_state,
        td_error,
        observation,
        next_observation,
        gamma,
    )

    # Apply updates
    new_weights = state.weights + opt_update.weight_delta
    new_bias = state.bias + opt_update.bias_delta

    new_state = TDLearnerState(
        weights=new_weights,
        bias=new_bias,
        optimizer_state=opt_update.new_state,
        step_count=state.step_count + 1,
        birth_timestamp=state.birth_timestamp,
        uptime_s=state.uptime_s,
    )

    # Pack metrics as array for scan compatibility
    squared_td_error = td_error**2
    mean_step_size = opt_update.metrics.get("mean_step_size", 0.0)
    mean_elig_trace = opt_update.metrics.get("mean_eligibility_trace", 0.0)
    metrics = jnp.array(
        [squared_td_error, td_error, mean_step_size, mean_elig_trace],
        dtype=jnp.float32,
    )

    return TDUpdateResult(
        state=new_state,
        prediction=prediction,
        td_error=jnp.atleast_1d(td_error),
        metrics=metrics,
    )
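As a standalone illustration of the TD recursion above, here is a minimal NumPy sketch on a two-state chain. The td0_update helper is hypothetical and uses a fixed step-size where the framework would delegate to a TD optimizer such as TDIDBD:

```python
import numpy as np

def td0_update(w, b, obs, reward, next_obs, gamma, step_size=0.1):
    """One TD(0) step: delta = R + gamma*V(s') - V(s), then a semi-gradient update."""
    v = w @ obs + b
    v_next = w @ next_obs + b
    td_error = reward + gamma * v_next - v
    w = w + step_size * td_error * obs   # weight delta along phi(s)
    b = b + step_size * td_error         # bias delta
    return w, b, td_error

# Two-state chain: s0 -> s1 (reward 0), then s1 -> terminal (reward 1, gamma=0).
phi = {0: np.array([1.0, 0.0]), 1: np.array([0.0, 1.0])}
w, b = np.zeros(2), 0.0
for _ in range(2000):
    w, b, _ = td0_update(w, b, phi[0], 0.0, phi[1], gamma=1.0)
    w, b, _ = td0_update(w, b, phi[1], 1.0, phi[0], gamma=0.0)
# Both V(s0) and V(s1) converge to 1.0, the undiscounted return from each state.
```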

TDUpdateResult

Result of a TD learner update step.

Attributes:
    state: Updated TD learner state
    prediction: Value prediction V(s) before update
    td_error: TD error delta = R + gamma*V(s') - V(s)
    metrics: Array of metrics [squared_td_error, td_error, mean_step_size, ...]

IDBD(initial_step_size=0.01, meta_step_size=0.01, h_decay_mode='prediction_grads')

Bases: Optimizer[IDBDState]

Incremental Delta-Bar-Delta optimizer.

IDBD maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation. When successive gradients agree in sign, the step-size for that weight increases. When they disagree, it decreases.

This implements Sutton's 1992 algorithm for adapting step-sizes online without requiring manual tuning.

Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta"

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate beta for adapting step-sizes

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate beta for adapting step-sizes
    h_decay_mode: Mode for computing the h-decay term in the MLP path. "prediction_grads": h_decay = z^2 (squared prediction gradients); this is the principled generalization, since for linear models z = x, so z^2 = x^2, recovering Sutton 1992. "loss_grads": h_decay = (error * z)^2 (Fisher approximation of the Hessian diagonal). Only affects the MLP path (update_from_gradient); the linear update() method always uses x^2.

Raises: ValueError: If h_decay_mode is not one of the valid modes

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    h_decay_mode: str = "prediction_grads",
):
    """Initialize IDBD optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate beta for adapting step-sizes
        h_decay_mode: Mode for computing the h-decay term in MLP path.
            ``"prediction_grads"``: h_decay = z^2 (squared prediction
            gradients). This is the principled generalization — for
            linear models, z = x so z^2 = x^2, recovering Sutton 1992.
            ``"loss_grads"``: h_decay = (error * z)^2 (Fisher
            approximation of the Hessian diagonal).
            Only affects the MLP path (``update_from_gradient``);
            the linear ``update()`` method always uses x^2.

    Raises:
        ValueError: If ``h_decay_mode`` is not one of the valid modes
    """
    if h_decay_mode not in ("prediction_grads", "loss_grads"):
        raise ValueError(
            f"Invalid h_decay_mode: {h_decay_mode!r}. "
            "Must be 'prediction_grads' or 'loss_grads'."
        )
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
    self._h_decay_mode = h_decay_mode

to_config()

Serialize configuration to dict.

Source code in src/alberta_framework/core/optimizers.py
def to_config(self) -> dict[str, Any]:
    """Serialize configuration to dict."""
    config: dict[str, Any] = {
        "type": "IDBD",
        "initial_step_size": self._initial_step_size,
        "meta_step_size": self._meta_step_size,
    }
    if self._h_decay_mode != "prediction_grads":
        config["h_decay_mode"] = self._h_decay_mode
    return config

init(feature_dim)

Initialize IDBD state.

Args: feature_dim: Dimension of weight vector

Returns: IDBD state with per-weight step-sizes and traces

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> IDBDState:
    """Initialize IDBD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        IDBD state with per-weight step-sizes and traces
    """
    return IDBDState(
        log_step_sizes=jnp.full(
            feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        bias_step_size=jnp.array(self._initial_step_size, dtype=jnp.float32),
        bias_trace=jnp.array(0.0, dtype=jnp.float32),
    )

init_for_shape(shape)

Initialize IDBD state for arbitrary-shape parameters.

Args: shape: Shape of the parameter array

Returns: IDBDParamState with arrays matching the given shape

Source code in src/alberta_framework/core/optimizers.py
def init_for_shape(self, shape: tuple[int, ...]) -> IDBDParamState:
    """Initialize IDBD state for arbitrary-shape parameters.

    Args:
        shape: Shape of the parameter array

    Returns:
        IDBDParamState with arrays matching the given shape
    """
    return IDBDParamState(
        log_step_sizes=jnp.full(
            shape, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        traces=jnp.zeros(shape, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
    )

update_from_gradient(state, gradient, error=None)

Compute IDBD update from pre-computed gradient (MLP path).

Implements Meyer's adaptation of IDBD to nonlinear models. The key insight: replace x^2 in the h-decay term with (dy/dw)^2 (squared prediction gradients), which generalizes IDBD to arbitrary architectures.

This follows Meyer's implementation, which differs from the linear IDBD (Sutton 1992) in two ways to better handle deep networks:

  1. The meta-update uses z * h (prediction gradient times trace) without the current error, rather than error * z * h.
  2. The h-trace accumulates loss gradients (-error * z) rather than error-scaled prediction gradients (error * z).

These changes address problems with IDBD in deep networks where the step-size being factored into both h and beta updates causes compounding effects.

Reference: Meyer, https://github.com/ejmejm/phd_research

Operation order (meta-update first, then new alpha for trace):

  1. Compute h_decay based on mode: z^2 or (error * z)^2
  2. Meta-update with OLD traces: log_alpha += beta * z * h
  3. Clip log step-sizes to [-10.0, 2.0]
  4. New step-sizes: alpha = exp(log_alpha)
  5. Step: alpha * z (error applied externally by caller)
  6. Trace update: h = h * max(0, 1 - alpha * h_decay) + alpha * g where g = -error * z (loss gradient direction)

When error is None (trunk path in multi-head), the gradient is already in loss gradient direction (accumulated cotangents), so the trace uses alpha * z directly.

Args:
    state: Current IDBD param state
    gradient: Pre-computed prediction gradient / eligibility trace (same shape as state arrays)
    error: Optional prediction error scalar. When provided, used for h_decay (loss_grads mode) and h-trace sign.

Returns: (step, new_state) where step has the same shape as gradient

Source code in src/alberta_framework/core/optimizers.py
def update_from_gradient(
    self,
    state: IDBDParamState,
    gradient: Array,
    error: Array | None = None,
) -> tuple[Array, IDBDParamState]:
    """Compute IDBD update from pre-computed gradient (MLP path).

    Implements Meyer's adaptation of IDBD to nonlinear models. The key
    insight: replace ``x^2`` in the h-decay term with ``(dy/dw)^2``
    (squared prediction gradients), which generalizes IDBD to arbitrary
    architectures.

    This follows Meyer's implementation, which differs from the linear
    IDBD (Sutton 1992) in two ways to better handle deep networks:

    1. The meta-update uses ``z * h`` (prediction gradient times trace)
       without the current error, rather than ``error * z * h``.
    2. The h-trace accumulates loss gradients (``-error * z``) rather
       than error-scaled prediction gradients (``error * z``).

    These changes address problems with IDBD in deep networks where
    the step-size being factored into both h and beta updates causes
    compounding effects.

    Reference: Meyer, https://github.com/ejmejm/phd_research

    Operation order (meta-update first, then new alpha for trace):

    1. Compute h_decay based on mode: ``z^2`` or ``(error * z)^2``
    2. Meta-update with OLD traces: ``log_alpha += beta * z * h``
    3. Clip log step-sizes to ``[-10.0, 2.0]``
    4. New step-sizes: ``alpha = exp(log_alpha)``
    5. Step: ``alpha * z`` (error applied externally by caller)
    6. Trace update: ``h = h * max(0, 1 - alpha * h_decay) + alpha * g``
       where ``g = -error * z`` (loss gradient direction)

    When ``error`` is None (trunk path in multi-head), the gradient
    is already in loss gradient direction (accumulated cotangents),
    so the trace uses ``alpha * z`` directly.

    Args:
        state: Current IDBD param state
        gradient: Pre-computed prediction gradient / eligibility trace
            (same shape as state arrays)
        error: Optional prediction error scalar. When provided,
            used for h_decay (loss_grads mode) and h-trace sign.

    Returns:
        ``(step, new_state)`` where step has the same shape as gradient
    """
    beta = state.meta_step_size
    z = gradient

    # 1. Compute h_decay based on mode
    if self._h_decay_mode == "loss_grads" and error is not None:
        h_decay = (jnp.squeeze(error) * z) ** 2
    else:
        # prediction_grads mode, or loss_grads without error
        h_decay = z**2

    # 2. Meta-update with OLD traces (Meyer: prediction_grads * h, no error)
    meta_gradient = z * state.traces
    new_log_step_sizes = state.log_step_sizes + beta * meta_gradient

    # 3. Clip log step-sizes
    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)

    # 4. New step-sizes
    new_alphas = jnp.exp(new_log_step_sizes)

    # 5. Step: alpha * z (error applied externally)
    step = new_alphas * z

    # 6. Trace update: h = h * decay + alpha * loss_grads
    # Meyer uses loss_grads = -error * z when error is available;
    # when error is None (trunk path), z is already loss gradient direction.
    decay = jnp.maximum(0.0, 1.0 - new_alphas * h_decay)
    if error is not None:
        new_traces = state.traces * decay - new_alphas * jnp.squeeze(error) * z
    else:
        new_traces = state.traces * decay + new_alphas * z

    new_state = IDBDParamState(
        log_step_sizes=new_log_step_sizes,
        traces=new_traces,
        meta_step_size=beta,
    )

    return step, new_state
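The six-step ordering above can be sketched in plain NumPy (a minimal illustration of the loss-grads variant with a scalar error, not the framework API; the function name and toy values here are mine):

```python
import numpy as np

def idbd_grad_step(log_alpha, h, z, error, beta=0.01):
    """One gradient-based IDBD step (loss_grads mode), following the numbered steps."""
    # 1. h_decay in loss_grads mode: (error * z)^2
    h_decay = (error * z) ** 2
    # 2. Meta-update with OLD trace: log_alpha += beta * z * h
    log_alpha = log_alpha + beta * z * h
    # 3. Clip log step-sizes
    log_alpha = np.clip(log_alpha, -10.0, 2.0)
    # 4. New step-sizes
    alpha = np.exp(log_alpha)
    # 5. Step = alpha * z (error applied externally by the caller)
    step = alpha * z
    # 6. Trace update with loss gradient g = -error * z
    h = h * np.maximum(0.0, 1.0 - alpha * h_decay) - alpha * error * z
    return step, log_alpha, h
```

Starting from zero traces and zero log step-sizes, one call with `z = 1` and `error = 0.5` returns `step = 1.0` and `h = -0.5`, showing the h-trace moving in the loss-gradient direction.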

update(state, error, observation)

Compute IDBD weight update with adaptive step-sizes.

Following Sutton 1992, Figure 2, the operation ordering is:

  1. Meta-update: log_alpha_i += beta * error * x_i * h_i (using OLD traces)
  2. Compute NEW step-sizes: alpha_i = exp(log_alpha_i)
  3. Update weights: w_i += alpha_i * error * x_i (using NEW alpha)
  4. Update traces: h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i (using NEW alpha)

The trace h_i tracks the correlation between current and past gradients. When gradients consistently point the same direction, h_i grows, leading to larger step-sizes.

Args:
    state: Current IDBD state
    error: Prediction error (scalar)
    observation: Feature vector

Returns: OptimizerUpdate with weight deltas and updated state


Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: IDBDState,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute IDBD weight update with adaptive step-sizes.

    Following Sutton 1992, Figure 2, the operation ordering is:

    1. Meta-update: ``log_alpha_i += beta * error * x_i * h_i`` (using OLD traces)
    2. Compute NEW step-sizes: ``alpha_i = exp(log_alpha_i)``
    3. Update weights: ``w_i += alpha_i * error * x_i`` (using NEW alpha)
    4. Update traces: ``h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i``
       (using NEW alpha)

    The trace h_i tracks the correlation between current and past gradients.
    When gradients consistently point the same direction, h_i grows,
    leading to larger step-sizes.

    Args:
        state: Current IDBD state
        error: Prediction error (scalar)
        observation: Feature vector

    Returns:
        OptimizerUpdate with weight deltas and updated state
    """
    error_scalar = jnp.squeeze(error)
    beta = state.meta_step_size

    # 1. Meta-update: adapt step-sizes using OLD traces
    gradient_correlation = error_scalar * observation * state.traces
    new_log_step_sizes = state.log_step_sizes + beta * gradient_correlation

    # Clip log step-sizes to prevent numerical issues
    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)

    # 2. Compute NEW step-sizes
    new_alphas = jnp.exp(new_log_step_sizes)

    # 3. Weight updates using NEW alpha: alpha_i * error * x_i
    weight_delta = new_alphas * error_scalar * observation

    # 4. Update traces using NEW alpha: h_i = h_i * decay + alpha_i * error * x_i
    # decay = max(0, 1 - alpha_i * x_i^2)
    decay = jnp.maximum(0.0, 1.0 - new_alphas * observation**2)
    new_traces = state.traces * decay + new_alphas * error_scalar * observation

    # Bias updates (same ordering: meta-update first, then new alpha)
    bias_gradient_correlation = error_scalar * state.bias_trace
    new_bias_step_size = state.bias_step_size * jnp.exp(beta * bias_gradient_correlation)
    new_bias_step_size = jnp.clip(new_bias_step_size, 1e-6, 1.0)

    bias_delta = new_bias_step_size * error_scalar

    bias_decay = jnp.maximum(0.0, 1.0 - new_bias_step_size)
    new_bias_trace = state.bias_trace * bias_decay + new_bias_step_size * error_scalar

    new_state = IDBDState(
        log_step_sizes=new_log_step_sizes,
        traces=new_traces,
        meta_step_size=beta,
        bias_step_size=new_bias_step_size,
        bias_trace=new_bias_trace,
    )

    return OptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(new_alphas),
            "min_step_size": jnp.min(new_alphas),
            "max_step_size": jnp.max(new_alphas),
        },
    )
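The four-step per-weight ordering from Sutton 1992 (ignoring the bias path) can be sketched standalone; this NumPy function is an illustration of the rule, not the framework's `update`:

```python
import numpy as np

def idbd_update(w, log_alpha, h, x, error, beta=0.01):
    """One classic IDBD step: meta-update first, then weights/traces with NEW alpha."""
    # 1. Meta-update using OLD traces, with clipping for numerical safety
    log_alpha = np.clip(log_alpha + beta * error * x * h, -10.0, 2.0)
    # 2. Compute NEW step-sizes
    alpha = np.exp(log_alpha)
    # 3. Weight update using NEW alpha
    w = w + alpha * error * x
    # 4. Trace update using NEW alpha: decay = max(0, 1 - alpha * x^2)
    h = h * np.maximum(0.0, 1.0 - alpha * x ** 2) + alpha * error * x
    return w, log_alpha, h
```

With zero initial traces the meta-update is a no-op, so the first call behaves exactly like LMS with step-size `exp(log_alpha)`.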

LMS(step_size=0.01)

Bases: Optimizer[LMSState]

Least Mean Square optimizer with fixed step-size.

The simplest gradient-based optimizer: w_{t+1} = w_t + alpha * delta * x_t

This serves as a baseline. The challenge is that the optimal step-size depends on the problem and changes as the task becomes non-stationary.

Attributes: step_size: Fixed learning rate alpha

Args: step_size: Fixed learning rate

Source code in src/alberta_framework/core/optimizers.py
def __init__(self, step_size: float = 0.01):
    """Initialize LMS optimizer.

    Args:
        step_size: Fixed learning rate
    """
    self._step_size = step_size

to_config()

Serialize configuration to dict.

Source code in src/alberta_framework/core/optimizers.py
def to_config(self) -> dict[str, Any]:
    """Serialize configuration to dict."""
    return {"type": "LMS", "step_size": self._step_size}

init(feature_dim)

Initialize LMS state.

Args: feature_dim: Dimension of weight vector (unused for LMS)

Returns: LMS state containing the step-size

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> LMSState:
    """Initialize LMS state.

    Args:
        feature_dim: Dimension of weight vector (unused for LMS)

    Returns:
        LMS state containing the step-size
    """
    return LMSState(step_size=jnp.array(self._step_size, dtype=jnp.float32))

init_for_shape(shape)

Initialize LMS state for arbitrary-shape parameters.

LMS state is shape-independent (single scalar), so this returns the same state regardless of shape.

Source code in src/alberta_framework/core/optimizers.py
def init_for_shape(self, shape: tuple[int, ...]) -> LMSState:
    """Initialize LMS state for arbitrary-shape parameters.

    LMS state is shape-independent (single scalar), so this returns
    the same state regardless of shape.
    """
    return LMSState(step_size=jnp.array(self._step_size, dtype=jnp.float32))

update_from_gradient(state, gradient, error=None)

Compute step from gradient: step = alpha * gradient.

Args:
    state: Current LMS state
    gradient: Pre-computed gradient (any shape)
    error: Unused by LMS (accepted for interface compatibility)

Returns: (step, state) -- state is unchanged for LMS

Source code in src/alberta_framework/core/optimizers.py
def update_from_gradient(
    self, state: LMSState, gradient: Array, error: Array | None = None
) -> tuple[Array, LMSState]:
    """Compute step from gradient: ``step = alpha * gradient``.

    Args:
        state: Current LMS state
        gradient: Pre-computed gradient (any shape)
        error: Unused by LMS (accepted for interface compatibility)

    Returns:
        ``(step, state)`` -- state is unchanged for LMS
    """
    del error  # LMS doesn't meta-learn
    return state.step_size * gradient, state

update(state, error, observation)

Compute LMS weight update.

Update rule: delta_w = alpha * error * x

Args:
    state: Current LMS state
    error: Prediction error (scalar)
    observation: Feature vector

Returns: OptimizerUpdate with weight and bias deltas

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: LMSState,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute LMS weight update.

    Update rule: ``delta_w = alpha * error * x``

    Args:
        state: Current LMS state
        error: Prediction error (scalar)
        observation: Feature vector

    Returns:
        OptimizerUpdate with weight and bias deltas
    """
    alpha = state.step_size
    error_scalar = jnp.squeeze(error)

    # Weight update: alpha * error * x
    weight_delta = alpha * error_scalar * observation

    # Bias update: alpha * error
    bias_delta = alpha * error_scalar

    return OptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=state,  # LMS state doesn't change
        metrics={"step_size": alpha},
    )
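The fixed-step-size rule is simple enough to demonstrate end to end. A hedged NumPy sketch (function name and toy problem are mine), showing repeated LMS updates driving the prediction error toward zero on a stationary target:

```python
import numpy as np

def lms_update(w, b, x, target, alpha=0.1):
    """One LMS step: delta_w = alpha * error * x, delta_b = alpha * error."""
    error = target - (w @ x + b)
    return w + alpha * error * x, b + alpha * error

# Repeated updates against a fixed target shrink the error geometrically
w, b = np.zeros(2), 0.0
x, target = np.array([1.0, 1.0]), 1.0
for _ in range(50):
    w, b = lms_update(w, b, x, target)
```

On this toy problem the error contracts by a constant factor per step; as the class docstring notes, the catch is that a good fixed `alpha` is problem-dependent and stops being good when the task drifts.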

TDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, use_semi_gradient=True)

Bases: TDOptimizer[TDIDBDState]

TD-IDBD optimizer for temporal-difference learning.

Extends IDBD to TD learning with eligibility traces. Maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation in the TD setting.

Two variants are supported:

- Semi-gradient (default): Uses only phi(s) in the meta-update, more stable
- Ordinary gradient: Uses both phi(s) and phi(s'), more accurate but sensitive

Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta
    trace_decay: Eligibility trace decay lambda
    use_semi_gradient: If True, use semi-gradient variant (default)

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay lambda (0 = TD(0))
    use_semi_gradient: If True, use semi-gradient variant (recommended)

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
    use_semi_gradient: bool = True,
):
    """Initialize TD-IDBD optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay lambda (0 = TD(0))
        use_semi_gradient: If True, use semi-gradient variant (recommended)
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
    self._trace_decay = trace_decay
    self._use_semi_gradient = use_semi_gradient

init(feature_dim)

Initialize TD-IDBD state.

Args: feature_dim: Dimension of weight vector

Returns: TD-IDBD state with per-weight step-sizes, traces, and h traces

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> TDIDBDState:
    """Initialize TD-IDBD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        TD-IDBD state with per-weight step-sizes, traces, and h traces
    """
    return TDIDBDState(
        log_step_sizes=jnp.full(
            feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(self._trace_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(self._initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
    )

update(state, td_error, observation, next_observation, gamma)

Compute TD-IDBD weight update with adaptive step-sizes.

Implements Algorithm 3 (semi-gradient) or Algorithm 4 (ordinary gradient) from Kearney et al. 2019.

Args:
    state: Current TD-IDBD state
    td_error: TD error delta = R + gamma*V(s') - V(s)
    observation: Current observation phi(s)
    next_observation: Next observation phi(s')
    gamma: Discount factor gamma (0 at terminal)

Returns: TDOptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: TDIDBDState,
    td_error: Array,
    observation: Array,
    next_observation: Array,
    gamma: Array,
) -> TDOptimizerUpdate:
    """Compute TD-IDBD weight update with adaptive step-sizes.

    Implements Algorithm 3 (semi-gradient) or Algorithm 4 (ordinary gradient)
    from Kearney et al. 2019.

    Args:
        state: Current TD-IDBD state
        td_error: TD error delta = R + gamma*V(s') - V(s)
        observation: Current observation phi(s)
        next_observation: Next observation phi(s')
        gamma: Discount factor gamma (0 at terminal)

    Returns:
        TDOptimizerUpdate with weight deltas and updated state
    """
    delta = jnp.squeeze(td_error)
    theta = state.meta_step_size
    lam = state.trace_decay
    gamma_scalar = jnp.squeeze(gamma)

    if self._use_semi_gradient:
        gradient_correlation = delta * observation * state.h_traces
        new_log_step_sizes = state.log_step_sizes + theta * gradient_correlation
    else:
        feature_diff = gamma_scalar * next_observation - observation
        gradient_correlation = delta * feature_diff * state.h_traces
        new_log_step_sizes = state.log_step_sizes - theta * gradient_correlation

    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)
    new_alphas = jnp.exp(new_log_step_sizes)

    new_eligibility_traces = gamma_scalar * lam * state.eligibility_traces + observation
    weight_delta = new_alphas * delta * new_eligibility_traces

    if self._use_semi_gradient:
        h_decay = jnp.maximum(0.0, 1.0 - new_alphas * observation * new_eligibility_traces)
        new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces
    else:
        feature_diff = gamma_scalar * next_observation - observation
        h_decay = jnp.maximum(0.0, 1.0 + new_alphas * new_eligibility_traces * feature_diff)
        new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces

    # Bias updates
    if self._use_semi_gradient:
        bias_gradient_correlation = delta * state.bias_h_trace
        new_bias_log_step_size = state.bias_log_step_size + theta * bias_gradient_correlation
    else:
        bias_feature_diff = gamma_scalar - 1.0
        bias_gradient_correlation = delta * bias_feature_diff * state.bias_h_trace
        new_bias_log_step_size = state.bias_log_step_size - theta * bias_gradient_correlation

    new_bias_log_step_size = jnp.clip(new_bias_log_step_size, -10.0, 2.0)
    new_bias_alpha = jnp.exp(new_bias_log_step_size)

    new_bias_eligibility_trace = gamma_scalar * lam * state.bias_eligibility_trace + 1.0
    bias_delta = new_bias_alpha * delta * new_bias_eligibility_trace

    if self._use_semi_gradient:
        bias_h_decay = jnp.maximum(0.0, 1.0 - new_bias_alpha * new_bias_eligibility_trace)
        new_bias_h_trace = (
            state.bias_h_trace * bias_h_decay
            + new_bias_alpha * delta * new_bias_eligibility_trace
        )
    else:
        bias_feature_diff = gamma_scalar - 1.0
        bias_h_decay = jnp.maximum(
            0.0, 1.0 + new_bias_alpha * new_bias_eligibility_trace * bias_feature_diff
        )
        new_bias_h_trace = (
            state.bias_h_trace * bias_h_decay
            + new_bias_alpha * delta * new_bias_eligibility_trace
        )

    new_state = TDIDBDState(
        log_step_sizes=new_log_step_sizes,
        eligibility_traces=new_eligibility_traces,
        h_traces=new_h_traces,
        meta_step_size=theta,
        trace_decay=lam,
        bias_log_step_size=new_bias_log_step_size,
        bias_eligibility_trace=new_bias_eligibility_trace,
        bias_h_trace=new_bias_h_trace,
    )

    return TDOptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(new_alphas),
            "min_step_size": jnp.min(new_alphas),
            "max_step_size": jnp.max(new_alphas),
            "mean_eligibility_trace": jnp.mean(jnp.abs(new_eligibility_traces)),
        },
    )
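The semi-gradient weight path of Algorithm 3 (leaving out the bias terms) can be sketched as follows; this is an illustrative NumPy reduction of the code above, and phi(s') enters only through the externally computed TD error:

```python
import numpy as np

def tidbd_semi_update(log_alpha, z, h, delta, x, gamma, theta=0.01, lam=0.0):
    """One semi-gradient TIDBD(lambda) step for the weight vector only."""
    # Meta-update with OLD h traces; semi-gradient uses only phi(s)
    log_alpha = np.clip(log_alpha + theta * delta * x * h, -10.0, 2.0)
    alpha = np.exp(log_alpha)
    # Accumulating eligibility trace
    z = gamma * lam * z + x
    # Weight delta with NEW alpha
    w_delta = alpha * delta * z
    # h-trace update: decay = max(0, 1 - alpha * x * z)
    h = h * np.maximum(0.0, 1.0 - alpha * x * z) + alpha * delta * z
    return w_delta, log_alpha, z, h
```

With `lam = 0` the eligibility trace collapses to the current features, recovering TD(0)-style per-weight IDBD.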

AutoTDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, normalizer_decay=10000.0)

Bases: TDOptimizer[AutoTDIDBDState]

AutoStep-style normalized TD-IDBD optimizer.

Adds AutoStep-style normalization to TDIDBD for improved stability and reduced sensitivity to the meta step-size theta.

Reference: Kearney et al. 2019, Algorithm 6 "AutoStep Style Normalized TIDBD(lambda)"

Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta
    trace_decay: Eligibility trace decay lambda
    normalizer_decay: Decay parameter tau for normalizers

Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay lambda (0 = TD(0))
    normalizer_decay: Decay parameter tau for normalizers (default: 10000)

Source code in src/alberta_framework/core/optimizers.py
def __init__(
    self,
    initial_step_size: float = 0.01,
    meta_step_size: float = 0.01,
    trace_decay: float = 0.0,
    normalizer_decay: float = 10000.0,
):
    """Initialize AutoTDIDBD optimizer.

    Args:
        initial_step_size: Initial value for per-weight step-sizes
        meta_step_size: Meta learning rate theta for adapting step-sizes
        trace_decay: Eligibility trace decay lambda (0 = TD(0))
        normalizer_decay: Decay parameter tau for normalizers (default: 10000)
    """
    self._initial_step_size = initial_step_size
    self._meta_step_size = meta_step_size
    self._trace_decay = trace_decay
    self._normalizer_decay = normalizer_decay

init(feature_dim)

Initialize AutoTDIDBD state.

Args: feature_dim: Dimension of weight vector

Returns: AutoTDIDBD state with per-weight step-sizes, traces, h traces, and normalizers

Source code in src/alberta_framework/core/optimizers.py
def init(self, feature_dim: int) -> AutoTDIDBDState:
    """Initialize AutoTDIDBD state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        AutoTDIDBD state with per-weight step-sizes, traces, h traces, and normalizers
    """
    return AutoTDIDBDState(
        log_step_sizes=jnp.full(
            feature_dim, jnp.log(self._initial_step_size), dtype=jnp.float32
        ),
        eligibility_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        h_traces=jnp.zeros(feature_dim, dtype=jnp.float32),
        normalizers=jnp.ones(feature_dim, dtype=jnp.float32),
        meta_step_size=jnp.array(self._meta_step_size, dtype=jnp.float32),
        trace_decay=jnp.array(self._trace_decay, dtype=jnp.float32),
        normalizer_decay=jnp.array(self._normalizer_decay, dtype=jnp.float32),
        bias_log_step_size=jnp.array(jnp.log(self._initial_step_size), dtype=jnp.float32),
        bias_eligibility_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_h_trace=jnp.array(0.0, dtype=jnp.float32),
        bias_normalizer=jnp.array(1.0, dtype=jnp.float32),
    )

update(state, td_error, observation, next_observation, gamma)

Compute AutoTDIDBD weight update with normalized adaptive step-sizes.

Implements Algorithm 6 from Kearney et al. 2019.

Args:
    state: Current AutoTDIDBD state
    td_error: TD error delta = R + gamma*V(s') - V(s)
    observation: Current observation phi(s)
    next_observation: Next observation phi(s')
    gamma: Discount factor gamma (0 at terminal)

Returns: TDOptimizerUpdate with weight deltas and updated state

Source code in src/alberta_framework/core/optimizers.py
def update(
    self,
    state: AutoTDIDBDState,
    td_error: Array,
    observation: Array,
    next_observation: Array,
    gamma: Array,
) -> TDOptimizerUpdate:
    """Compute AutoTDIDBD weight update with normalized adaptive step-sizes.

    Implements Algorithm 6 from Kearney et al. 2019.

    Args:
        state: Current AutoTDIDBD state
        td_error: TD error delta = R + gamma*V(s') - V(s)
        observation: Current observation phi(s)
        next_observation: Next observation phi(s')
        gamma: Discount factor gamma (0 at terminal)

    Returns:
        TDOptimizerUpdate with weight deltas and updated state
    """
    delta = jnp.squeeze(td_error)
    theta = state.meta_step_size
    lam = state.trace_decay
    tau = state.normalizer_decay
    gamma_scalar = jnp.squeeze(gamma)

    feature_diff = gamma_scalar * next_observation - observation
    alphas = jnp.exp(state.log_step_sizes)

    # Update normalizers
    abs_weight_update = jnp.abs(delta * feature_diff * state.h_traces)
    normalizer_decay_term = (
        (1.0 / tau)
        * alphas
        * feature_diff
        * state.eligibility_traces
        * (jnp.abs(delta * observation * state.h_traces) - state.normalizers)
    )
    new_normalizers = jnp.maximum(abs_weight_update, state.normalizers - normalizer_decay_term)
    new_normalizers = jnp.maximum(new_normalizers, 1e-8)

    # Normalized meta-update
    normalized_gradient = delta * feature_diff * state.h_traces / new_normalizers
    new_log_step_sizes = state.log_step_sizes - theta * normalized_gradient

    # Effective step-size normalization
    effective_step_size = -jnp.sum(
        jnp.exp(new_log_step_sizes) * feature_diff * state.eligibility_traces
    )
    normalization_factor = jnp.maximum(effective_step_size, 1.0)
    new_log_step_sizes = new_log_step_sizes - jnp.log(normalization_factor)

    new_log_step_sizes = jnp.clip(new_log_step_sizes, -10.0, 2.0)
    new_alphas = jnp.exp(new_log_step_sizes)

    new_eligibility_traces = gamma_scalar * lam * state.eligibility_traces + observation
    weight_delta = new_alphas * delta * new_eligibility_traces

    # Update h traces
    h_decay = jnp.maximum(0.0, 1.0 + new_alphas * feature_diff * new_eligibility_traces)
    new_h_traces = state.h_traces * h_decay + new_alphas * delta * new_eligibility_traces

    # Bias updates
    bias_alpha = jnp.exp(state.bias_log_step_size)
    bias_feature_diff = gamma_scalar - 1.0

    abs_bias_weight_update = jnp.abs(delta * bias_feature_diff * state.bias_h_trace)
    bias_normalizer_decay_term = (
        (1.0 / tau)
        * bias_alpha
        * bias_feature_diff
        * state.bias_eligibility_trace
        * (jnp.abs(delta * state.bias_h_trace) - state.bias_normalizer)
    )
    new_bias_normalizer = jnp.maximum(
        abs_bias_weight_update, state.bias_normalizer - bias_normalizer_decay_term
    )
    new_bias_normalizer = jnp.maximum(new_bias_normalizer, 1e-8)

    normalized_bias_gradient = (
        delta * bias_feature_diff * state.bias_h_trace / new_bias_normalizer
    )
    new_bias_log_step_size = state.bias_log_step_size - theta * normalized_bias_gradient

    bias_effective_step_size = (
        -jnp.exp(new_bias_log_step_size) * bias_feature_diff * state.bias_eligibility_trace
    )
    bias_norm_factor = jnp.maximum(bias_effective_step_size, 1.0)
    new_bias_log_step_size = new_bias_log_step_size - jnp.log(bias_norm_factor)

    new_bias_log_step_size = jnp.clip(new_bias_log_step_size, -10.0, 2.0)
    new_bias_alpha = jnp.exp(new_bias_log_step_size)

    new_bias_eligibility_trace = gamma_scalar * lam * state.bias_eligibility_trace + 1.0
    bias_delta = new_bias_alpha * delta * new_bias_eligibility_trace

    bias_h_decay = jnp.maximum(
        0.0, 1.0 + new_bias_alpha * bias_feature_diff * new_bias_eligibility_trace
    )
    new_bias_h_trace = (
        state.bias_h_trace * bias_h_decay + new_bias_alpha * delta * new_bias_eligibility_trace
    )

    new_state = AutoTDIDBDState(
        log_step_sizes=new_log_step_sizes,
        eligibility_traces=new_eligibility_traces,
        h_traces=new_h_traces,
        normalizers=new_normalizers,
        meta_step_size=theta,
        trace_decay=lam,
        normalizer_decay=tau,
        bias_log_step_size=new_bias_log_step_size,
        bias_eligibility_trace=new_bias_eligibility_trace,
        bias_h_trace=new_bias_h_trace,
        bias_normalizer=new_bias_normalizer,
    )

    return TDOptimizerUpdate(
        weight_delta=weight_delta,
        bias_delta=bias_delta,
        new_state=new_state,
        metrics={
            "mean_step_size": jnp.mean(new_alphas),
            "min_step_size": jnp.min(new_alphas),
            "max_step_size": jnp.max(new_alphas),
            "mean_eligibility_trace": jnp.mean(jnp.abs(new_eligibility_traces)),
            "mean_normalizer": jnp.mean(new_normalizers),
        },
    )
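The effective step-size normalization in the middle of this update is the key stabilizer: if the combined step would over-correct the TD error, every step-size is scaled down by the same factor. A hedged NumPy sketch of just that piece (function name is mine):

```python
import numpy as np

def normalize_step_sizes(log_alpha, feature_diff, z):
    """Scale step-sizes so the effective step-size is at most 1."""
    # effective step-size of the pending update, in TD-error units
    effective = -np.sum(np.exp(log_alpha) * feature_diff * z)
    # only shrink, never grow: factor is clamped at 1
    factor = np.maximum(effective, 1.0)
    return log_alpha - np.log(factor)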

Optimizer

Bases: ABC

Base class for optimizers.

to_config() abstractmethod

Serialize optimizer configuration to dict.

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def to_config(self) -> dict[str, Any]:
    """Serialize optimizer configuration to dict."""
    ...

init(feature_dim) abstractmethod

Initialize optimizer state.

Args: feature_dim: Dimension of weight vector

Returns: Initial optimizer state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def init(self, feature_dim: int) -> StateT:
    """Initialize optimizer state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        Initial optimizer state
    """
    ...

update(state, error, observation) abstractmethod

Compute weight updates given prediction error.

Args:
    state: Current optimizer state
    error: Prediction error (target - prediction)
    observation: Current observation/feature vector

Returns: OptimizerUpdate with deltas and new state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def update(
    self,
    state: StateT,
    error: Array,
    observation: Array,
) -> OptimizerUpdate:
    """Compute weight updates given prediction error.

    Args:
        state: Current optimizer state
        error: Prediction error (target - prediction)
        observation: Current observation/feature vector

    Returns:
        OptimizerUpdate with deltas and new state
    """
    ...

init_for_shape(shape)

Initialize optimizer state for parameters of arbitrary shape.

Used by MLP learners where parameters are matrices/vectors of varying shapes. Not all optimizers support this.

The return type varies by subclass (e.g. LMSState for LMS, AutostepParamState for Autostep) so the base signature uses Any.

Args: shape: Shape of the parameter array

Returns: Initial optimizer state with arrays matching the given shape

Raises: NotImplementedError: If the optimizer does not support this

Source code in src/alberta_framework/core/optimizers.py
def init_for_shape(self, shape: tuple[int, ...]) -> Any:
    """Initialize optimizer state for parameters of arbitrary shape.

    Used by MLP learners where parameters are matrices/vectors of
    varying shapes. Not all optimizers support this.

    The return type varies by subclass (e.g. ``LMSState`` for LMS,
    ``AutostepParamState`` for Autostep) so the base signature uses
    ``Any``.

    Args:
        shape: Shape of the parameter array

    Returns:
        Initial optimizer state with arrays matching the given shape

    Raises:
        NotImplementedError: If the optimizer does not support this
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not support init_for_shape. "
        "Only LMS, IDBD, and Autostep currently implement this."
    )

update_from_gradient(state, gradient, error=None)

Compute step delta from pre-computed gradient.

The returned delta does NOT include the error -- the caller is responsible for multiplying error * delta before applying.

The state type varies by subclass (e.g. LMSState for LMS, AutostepParamState for Autostep) so the base signature uses Any.

Args:
    state: Current optimizer state
    gradient: Pre-computed gradient (e.g. eligibility trace)
    error: Optional prediction error scalar. Optimizers with meta-learning
        (e.g. Autostep) use this for meta-gradient computation. LMS ignores it.

Returns: (step, new_state) where step has the same shape as gradient

Raises: NotImplementedError: If the optimizer does not support this

Source code in src/alberta_framework/core/optimizers.py
def update_from_gradient(
    self, state: Any, gradient: Array, error: Array | None = None
) -> tuple[Array, Any]:
    """Compute step delta from pre-computed gradient.

    The returned delta does NOT include the error -- the caller is
    responsible for multiplying ``error * delta`` before applying.

    The state type varies by subclass (e.g. ``LMSState`` for LMS,
    ``AutostepParamState`` for Autostep) so the base signature uses
    ``Any``.

    Args:
        state: Current optimizer state
        gradient: Pre-computed gradient (e.g. eligibility trace)
        error: Optional prediction error scalar. Optimizers with
            meta-learning (e.g. Autostep) use this for meta-gradient
            computation. LMS ignores it.

    Returns:
        ``(step, new_state)`` where step has the same shape as gradient

    Raises:
        NotImplementedError: If the optimizer does not support this
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not support update_from_gradient. "
        "Only LMS, IDBD, and Autostep currently implement this."
    )

TDOptimizer

Bases: ABC

Base class for TD optimizers.

TD optimizers handle temporal-difference learning with eligibility traces. They take TD error and both current and next observations as input.

init(feature_dim) abstractmethod

Initialize optimizer state.

Args: feature_dim: Dimension of weight vector

Returns: Initial optimizer state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def init(self, feature_dim: int) -> StateT:
    """Initialize optimizer state.

    Args:
        feature_dim: Dimension of weight vector

    Returns:
        Initial optimizer state
    """
    ...

update(state, td_error, observation, next_observation, gamma) abstractmethod

Compute weight updates given TD error.

Args:
    state: Current optimizer state
    td_error: TD error delta = R + gamma*V(s') - V(s)
    observation: Current observation phi(s)
    next_observation: Next observation phi(s')
    gamma: Discount factor gamma (0 at terminal)

Returns: TDOptimizerUpdate with deltas and new state

Source code in src/alberta_framework/core/optimizers.py
@abstractmethod
def update(
    self,
    state: StateT,
    td_error: Array,
    observation: Array,
    next_observation: Array,
    gamma: Array,
) -> TDOptimizerUpdate:
    """Compute weight updates given TD error.

    Args:
        state: Current optimizer state
        td_error: TD error delta = R + gamma*V(s') - V(s)
        observation: Current observation phi(s)
        next_observation: Next observation phi(s')
        gamma: Discount factor gamma (0 at terminal)

    Returns:
        TDOptimizerUpdate with deltas and new state
    """
    ...

TDOptimizerUpdate

Result of a TD optimizer update step.

Attributes:
    weight_delta: Change to apply to weights
    bias_delta: Change to apply to bias
    new_state: Updated optimizer state
    metrics: Dictionary of metrics for logging

AutoTDIDBDState

State for the AutoTDIDBD optimizer.

AutoTDIDBD adds AutoStep-style normalization to TDIDBD for improved stability. Includes normalizers for the meta-weight updates and effective step-size normalization to prevent overshooting.

Reference: Kearney et al. 2019, Algorithm 6

Attributes:
    log_step_sizes: Log of per-weight step-sizes (log alpha_i)
    eligibility_traces: Eligibility traces z_i
    h_traces: Per-weight h traces for gradient correlation
    normalizers: Running max of absolute gradient correlations (eta_i)
    meta_step_size: Meta learning rate theta
    trace_decay: Eligibility trace decay parameter lambda
    normalizer_decay: Decay parameter tau for normalizers
    bias_log_step_size: Log step-size for the bias term
    bias_eligibility_trace: Eligibility trace for the bias
    bias_h_trace: h trace for the bias term
    bias_normalizer: Normalizer for the bias gradient correlation

DemonType

Bases: Enum

Type of GVF demon.

A prediction demon has a fixed policy and learns to predict. A control demon learns a policy (e.g. via SARSA) — Step 4.

GVFSpec

One GVF demon's question functions (Sutton et al. 2011).

Declarative, not callable — JAX pytree-compatible. Cumulant values are computed externally and passed as arrays.

Attributes:
    name: Human-readable name for this demon
    demon_type: Whether this is a prediction or control demon
    gamma: Pseudo-termination discount (0.0 = single-step prediction)
    lamda: Trace decay parameter (0.0 = no eligibility traces)
    cumulant_index: Index into targets array, or -1 for external cumulant
    terminal_reward: Terminal pseudo-reward z (default 0.0)

to_config()

Serialize to dict.

Returns: Dict with all fields needed to recreate the GVFSpec.

Source code in src/alberta_framework/core/types.py
def to_config(self) -> dict[str, Any]:
    """Serialize to dict.

    Returns:
        Dict with all fields needed to recreate the GVFSpec.
    """
    return {
        "name": self.name,
        "demon_type": self.demon_type.value,
        "gamma": self.gamma,
        "lamda": self.lamda,
        "cumulant_index": self.cumulant_index,
        "terminal_reward": self.terminal_reward,
    }

from_config(config) classmethod

Reconstruct from config dict.

Args: config: Dict as produced by to_config()

Returns: Reconstructed GVFSpec

Source code in src/alberta_framework/core/types.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "GVFSpec":
    """Reconstruct from config dict.

    Args:
        config: Dict as produced by ``to_config()``

    Returns:
        Reconstructed GVFSpec
    """
    config = dict(config)
    config["demon_type"] = DemonType(config["demon_type"])
    return cls(**config)
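The `to_config()`/`from_config()` pair gives a lossless round trip through plain dicts, with the enum lowered to its string value. The following is a self-contained sketch of the same pattern using minimal stand-in classes (simplified fields, not the framework's actual definitions):

```python
# Stand-in dataclass/enum illustrating the serialization round trip above.
from dataclasses import dataclass, asdict
from enum import Enum
from typing import Any


class DemonType(Enum):
    PREDICTION = "prediction"
    CONTROL = "control"


@dataclass(frozen=True)
class GVFSpec:
    name: str
    demon_type: DemonType
    gamma: float = 0.0
    lamda: float = 0.0
    cumulant_index: int = -1
    terminal_reward: float = 0.0

    def to_config(self) -> dict[str, Any]:
        config = asdict(self)
        config["demon_type"] = self.demon_type.value  # Enum -> plain string
        return config

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "GVFSpec":
        config = dict(config)  # copy so the caller's dict is not mutated
        config["demon_type"] = DemonType(config["demon_type"])
        return cls(**config)


spec = GVFSpec(name="reward-pred", demon_type=DemonType.PREDICTION, gamma=0.9)
assert GVFSpec.from_config(spec.to_config()) == spec  # lossless round trip
```

Lowering the enum to a string keeps the config JSON-serializable, which is why `from_config` has to re-wrap it in `DemonType` before calling the constructor.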

HordeSpec

Collection of GVF demons, one per head.

Attributes:
    demons: Tuple of GVFSpec, one per demon/head
    gammas: Pre-computed gamma array for JIT, shape (n_demons,)
    lamdas: Pre-computed lambda array for JIT, shape (n_demons,)

to_config()

Serialize to dict.

Returns: Dict with demons list, each serialized via GVFSpec.to_config().

Source code in src/alberta_framework/core/types.py
def to_config(self) -> dict[str, Any]:
    """Serialize to dict.

    Returns:
        Dict with demons list, each serialized via ``GVFSpec.to_config()``.
    """
    return {
        "demons": [d.to_config() for d in self.demons],
    }

from_config(config) classmethod

Reconstruct from config dict.

Args: config: Dict as produced by to_config()

Returns: Reconstructed HordeSpec via create_horde_spec

Source code in src/alberta_framework/core/types.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "HordeSpec":
    """Reconstruct from config dict.

    Args:
        config: Dict as produced by ``to_config()``

    Returns:
        Reconstructed HordeSpec via ``create_horde_spec``
    """
    demons = [GVFSpec.from_config(d) for d in config["demons"]]
    return create_horde_spec(demons)

IDBDState

State for the IDBD (Incremental Delta-Bar-Delta) optimizer.

IDBD maintains per-weight adaptive step-sizes that are meta-learned based on the correlation of successive gradients.

Reference: Sutton 1992, "Adapting Bias by Gradient Descent"

Attributes:
    log_step_sizes: Log of per-weight step-sizes (log alpha_i)
    traces: Per-weight traces h_i for gradient correlation
    meta_step_size: Meta learning rate beta for adapting step-sizes
    bias_step_size: Step-size for the bias term
    bias_trace: Trace for the bias term
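To make the state fields concrete, here is a minimal NumPy rendition of the IDBD update from Sutton 1992, showing how `log_step_sizes` and the `h` traces interact. This is the bare algorithm only; the framework's optimizer API and naming differ.

```python
import numpy as np

def idbd_step(w, log_alpha, h, x, y, theta=0.01):
    """One IDBD step; returns new (w, log_alpha, h) and the error delta."""
    delta = y - w @ x                                # prediction error
    log_alpha = log_alpha + theta * delta * x * h    # meta-learn step-sizes
    alpha = np.exp(log_alpha)                        # per-weight step-sizes
    w = w + alpha * delta * x                        # LMS update, adapted alphas
    h = h * np.maximum(0.0, 1.0 - alpha * x * x) + alpha * delta * x
    return w, log_alpha, h, delta

# Noiseless linear target: IDBD tracks w_true while adapting its alphas.
rng = np.random.default_rng(0)
w_true = np.array([1.0, -2.0, 0.5])
w, log_alpha, h = np.zeros(3), np.full(3, np.log(0.05)), np.zeros(3)
for _ in range(2000):
    x = rng.normal(size=3)
    w, log_alpha, h, _ = idbd_step(w, log_alpha, h, x, w_true @ x)
assert np.allclose(w, w_true, atol=1e-2)
```

Keeping step-sizes in log space guarantees they stay positive, which is why the state stores `log_step_sizes` rather than the alphas themselves.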

LearnerState

State for a linear learner.

Attributes:
    weights: Weight vector for linear prediction
    bias: Bias term
    optimizer_state: State maintained by the optimizer
    normalizer_state: Optional state for online feature normalization

LMSState

State for the LMS (Least Mean Square) optimizer.

LMS uses a fixed step-size, so state only tracks the step-size parameter.

Attributes:
    step_size: Fixed learning rate alpha

TDIDBDState

State for the TD-IDBD (Temporal-Difference IDBD) optimizer.

TD-IDBD extends IDBD to temporal-difference learning with eligibility traces. Maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation in the TD setting.

Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"

Attributes:
    log_step_sizes: Log of per-weight step-sizes (log alpha_i)
    eligibility_traces: Eligibility traces z_i for temporal credit assignment
    h_traces: Per-weight h traces for gradient correlation
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay parameter lambda
    bias_log_step_size: Log step-size for the bias term
    bias_eligibility_trace: Eligibility trace for the bias
    bias_h_trace: h trace for the bias term

TDLearnerState

State for a TD linear learner.

Attributes:
    weights: Weight vector for linear value function approximation
    bias: Bias term
    optimizer_state: State maintained by the TD optimizer

TDTimeStep

Single experience from a TD stream.

Represents a transition (s, r, s', gamma) for temporal-difference learning.

Attributes:
    observation: Feature vector phi(s)
    reward: Reward R received
    next_observation: Feature vector phi(s')
    gamma: Discount factor gamma_t (0 at terminal states)
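A transition with these four fields is exactly what a linear TD(0) update consumes. The sketch below uses the textbook update rule with TDTimeStep's field names; it is illustrative, not the framework's learner code.

```python
import numpy as np

def td0_update(w, observation, reward, next_observation, gamma, alpha=0.1):
    """TD(0): delta = R + gamma * v(s') - v(s); step w along delta * phi(s)."""
    delta = reward + gamma * (w @ next_observation) - w @ observation
    return w + alpha * delta * observation, delta

w = np.zeros(2)
obs, next_obs = np.array([1.0, 0.0]), np.array([0.0, 1.0])
w, delta = td0_update(w, obs, reward=1.0, next_observation=next_obs, gamma=0.9)
assert delta == 1.0                 # v is zero everywhere initially, so delta = R
assert np.allclose(w, [0.1, 0.0])   # only phi(s)'s active feature moves
```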

TimeStep

Single experience from an experience stream.

Attributes:
    observation: Feature vector x_t
    target: Desired output y*_t (for supervised learning)

run_horde_learning_loop(horde, state, observations, cumulants, next_observations)

Run Horde learning loop using jax.lax.scan.

Scans over (obs, cumulants, next_obs) triples.

Args:
    horde: Horde learner
    state: Initial learner state
    observations: Input observations, shape (num_steps, feature_dim)
    cumulants: Per-demon cumulants, shape (num_steps, n_demons). NaN = inactive demon for that step.
    next_observations: Next observations, shape (num_steps, feature_dim)

Returns: HordeLearningResult with final state, per-demon metrics, and TD errors

Source code in src/alberta_framework/core/horde.py
def run_horde_learning_loop(
    horde: HordeLearner,
    state: MultiHeadMLPState,
    observations: Array,
    cumulants: Array,
    next_observations: Array,
) -> HordeLearningResult:
    """Run Horde learning loop using ``jax.lax.scan``.

    Scans over ``(obs, cumulants, next_obs)`` triples.

    Args:
        horde: Horde learner
        state: Initial learner state
        observations: Input observations, shape ``(num_steps, feature_dim)``
        cumulants: Per-demon cumulants, shape ``(num_steps, n_demons)``.
            NaN = inactive demon for that step.
        next_observations: Next observations, shape ``(num_steps, feature_dim)``

    Returns:
        HordeLearningResult with final state, per-demon metrics, and TD errors
    """

    def step_fn(
        carry: MultiHeadMLPState,
        inputs: tuple[Array, Array, Array],
    ) -> tuple[MultiHeadMLPState, tuple[Array, Array]]:
        l_state = carry
        obs, cums, next_obs = inputs
        result = horde.update(l_state, obs, cums, next_obs)
        return result.state, (result.per_demon_metrics, result.td_errors)

    t0 = time.time()
    final_state, (per_demon_metrics, td_errors) = jax.lax.scan(
        step_fn, state, (observations, cumulants, next_observations)
    )
    elapsed = time.time() - t0
    final_state = final_state.replace(uptime_s=final_state.uptime_s + elapsed)  # type: ignore[attr-defined]

    return HordeLearningResult(  # type: ignore[call-arg]
        state=final_state,
        per_demon_metrics=per_demon_metrics,
        td_errors=td_errors,
    )
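The scan pattern above is worth seeing in isolation: the carry is the learner state, `xs` is the tuple of stacked per-step inputs, and scan stacks whatever the step function emits. This toy uses a made-up accumulator "update" in place of `horde.update`, purely to show the shapes flowing through `jax.lax.scan`.

```python
import jax
import jax.numpy as jnp

num_steps, feature_dim, n_demons = 5, 3, 2
observations = jnp.ones((num_steps, feature_dim))
cumulants = jnp.ones((num_steps, n_demons))
next_observations = jnp.ones((num_steps, feature_dim))

def step_fn(carry, inputs):
    obs, cums, next_obs = inputs
    # Toy "update": accumulate observations into the carry. A real learner
    # would return its new state plus per-step metrics here.
    new_carry = carry + obs
    td_errors = cums - (new_carry @ obs) * jnp.ones(n_demons)
    return new_carry, td_errors

init = jnp.zeros(feature_dim)
final_state, td_errors = jax.lax.scan(
    step_fn, init, (observations, cumulants, next_observations)
)
assert td_errors.shape == (num_steps, n_demons)  # per-step outputs, stacked
```

Because scan stacks the second element of the step function's return value, per-step metrics come out with a leading `num_steps` axis for free, matching the `(num_steps, n_demons)` shapes documented on HordeLearningResult's fields.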

run_horde_learning_loop_batched(horde, observations, cumulants, next_observations, keys)

Run Horde learning loop across seeds using jax.vmap.

Each seed produces an independently initialized state. All seeds share the same observations, cumulants, and next observations.

Args:
    horde: Horde learner
    observations: Shared observations, shape (num_steps, feature_dim)
    cumulants: Shared cumulants, shape (num_steps, n_demons)
    next_observations: Shared next observations, shape (num_steps, feature_dim)
    keys: JAX random keys, shape (n_seeds,) or (n_seeds, 2)

Returns: BatchedHordeResult with batched states, per-demon metrics, and TD errors

Source code in src/alberta_framework/core/horde.py
def run_horde_learning_loop_batched(
    horde: HordeLearner,
    observations: Array,
    cumulants: Array,
    next_observations: Array,
    keys: Array,
) -> BatchedHordeResult:
    """Run Horde learning loop across seeds using ``jax.vmap``.

    Each seed produces an independently initialized state. All seeds
    share the same observations, cumulants, and next observations.

    Args:
        horde: Horde learner
        observations: Shared observations, shape ``(num_steps, feature_dim)``
        cumulants: Shared cumulants, shape ``(num_steps, n_demons)``
        next_observations: Shared next observations,
            shape ``(num_steps, feature_dim)``
        keys: JAX random keys, shape ``(n_seeds,)`` or ``(n_seeds, 2)``

    Returns:
        BatchedHordeResult with batched states, per-demon metrics, and TD errors
    """
    feature_dim = observations.shape[1]

    def single_run(key: Array) -> tuple[MultiHeadMLPState, Array, Array]:
        init_state = horde.init(feature_dim, key)
        result = run_horde_learning_loop(
            horde, init_state, observations, cumulants, next_observations
        )
        return result.state, result.per_demon_metrics, result.td_errors

    t0 = time.time()
    batched_states, batched_metrics, batched_td_errors = jax.vmap(single_run)(keys)
    elapsed = time.time() - t0
    batched_states = batched_states.replace(  # type: ignore[attr-defined]
        uptime_s=batched_states.uptime_s + elapsed
    )

    return BatchedHordeResult(  # type: ignore[call-arg]
        states=batched_states,
        per_demon_metrics=batched_metrics,
        td_errors=batched_td_errors,
    )
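The vmap-over-seeds pattern stands alone as well: `jax.random.split` produces the batched `keys` argument, each key drives an independent initialization, and the shared data is closed over rather than mapped. The "run" below is a trivial stand-in for the real learning loop, kept only to show the batching.

```python
import jax
import jax.numpy as jnp

observations = jnp.ones((10, 4))  # shared across all seeds

def single_run(key):
    init_w = jax.random.normal(key, (4,))  # per-seed initialization
    preds = observations @ init_w          # shared data, per-seed state
    return init_w, preds

keys = jax.random.split(jax.random.PRNGKey(0), 8)  # n_seeds = 8, shape (8, 2)
ws, preds = jax.vmap(single_run)(keys)
assert ws.shape == (8, 4)     # batched per-seed states
assert preds.shape == (8, 10) # batched per-seed outputs
```

Every batched output gains a leading `n_seeds` axis, which is why BatchedHordeResult documents metrics with shape `(n_seeds, num_steps, n_demons, 3)`.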

create_horde_spec(demons)

Create a HordeSpec from a sequence of GVFSpec demons.

Pre-computes gamma and lambda arrays for efficient JIT usage.

Args: demons: Sequence of GVFSpec, one per demon/head

Returns: HordeSpec with pre-computed arrays

Source code in src/alberta_framework/core/types.py
def create_horde_spec(demons: Sequence[GVFSpec]) -> HordeSpec:
    """Create a HordeSpec from a sequence of GVFSpec demons.

    Pre-computes gamma and lambda arrays for efficient JIT usage.

    Args:
        demons: Sequence of GVFSpec, one per demon/head

    Returns:
        HordeSpec with pre-computed arrays
    """
    demons_tuple = tuple(demons)
    gammas = jnp.array([d.gamma for d in demons_tuple], dtype=jnp.float32)
    lamdas = jnp.array([d.lamda for d in demons_tuple], dtype=jnp.float32)
    return HordeSpec(demons=demons_tuple, gammas=gammas, lamdas=lamdas)
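What `create_horde_spec` pre-computes can be shown standalone: each demon's scalar `gamma` and `lamda` are stacked into flat float32 arrays so JIT-compiled code can index them without Python-level loops. The demon parameter values below are made up for illustration.

```python
import jax.numpy as jnp

# Hypothetical (gamma, lamda) pairs, one per demon.
demon_params = [(0.0, 0.0), (0.9, 0.8), (0.99, 0.0)]
gammas = jnp.array([g for g, _ in demon_params], dtype=jnp.float32)
lamdas = jnp.array([l for _, l in demon_params], dtype=jnp.float32)

assert gammas.shape == lamdas.shape == (3,)
# Per-head trace decay is the elementwise product gamma * lamda:
assert jnp.allclose(gammas * lamdas, jnp.array([0.0, 0.72, 0.0]))
```

A gamma-0 demon thus also gets a zero trace-decay product, consistent with the gamma * lambda per-head trace decay described for HordeLearner.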