alberta_framework
¶
Alberta Framework: A JAX-based research framework for continual AI.
The Alberta Framework provides foundational components for continual reinforcement learning research. Built on JAX for hardware acceleration, the framework emphasizes temporal uniformity — every component updates at every time step, with no special training phases or batch processing.
Roadmap
| Step | Focus | Status |
|---|---|---|
| 1 | Meta-learned step-sizes (IDBD, Autostep) | Complete |
| 2 | Nonlinear function approximation (MLP, ObGD) | In Progress |
| 3 | GVF predictions, Horde architecture | Planned |
| 4 | Actor-critic with eligibility traces | Planned |
| 5-6 | Off-policy learning, average reward | Planned |
| 7-12 | Hierarchical, multi-agent, world models | Future |
Examples:

```python
import jax.random as jr
from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop

# Non-stationary stream where target weights drift over time
stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)

# Learner with IDBD meta-learned step-sizes
learner = LinearLearner(optimizer=IDBD())

# JIT-compiled training via jax.lax.scan
state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(42))
```
References
- The Alberta Plan for AI Research (Sutton et al., 2022): https://arxiv.org/abs/2208.11173
- Adapting Bias by Gradient Descent (Sutton, 1992)
- Tuning-free Step-size Adaptation (Mahmood et al., 2012)
- Streaming Deep Reinforcement Learning Finally Works (Elsayed et al., 2024)
LinearLearner(optimizer=None, normalizer=None)
¶
Linear function approximator with pluggable optimizer and optional normalizer.
Computes predictions as: y = w @ x + b
The learner maintains weights and bias, delegating the adaptation of learning rates to the optimizer (e.g., LMS or IDBD).
This follows the Alberta Plan philosophy of temporal uniformity: every component updates at every time step.
Attributes:
- optimizer: The optimizer to use for weight updates
- normalizer: Optional online feature normalizer

Args:
- optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
- normalizer: Optional feature normalizer (e.g. EMANormalizer, WelfordNormalizer)
Source code in src/alberta_framework/core/learners.py
normalizer
property
¶
The feature normalizer, or None if normalization is disabled.
init(feature_dim)
¶
Initialize learner state.
Args: feature_dim: Dimension of the input feature vector
Returns: Initial learner state with zero weights and bias
Source code in src/alberta_framework/core/learners.py
predict(state, observation)
¶
Compute prediction for an observation.
Args:
- state: Current learner state
- observation: Input feature vector
Returns:
Scalar prediction y = w @ x + b
Source code in src/alberta_framework/core/learners.py
update(state, observation, target)
¶
Update learner given observation and target.
Performs one step of the learning algorithm:
1. Optionally normalize observation
2. Compute prediction
3. Compute error
4. Get weight updates from optimizer
5. Apply updates to weights and bias

Args:
- state: Current learner state
- observation: Input feature vector
- target: Desired output
Returns: UpdateResult with new state, prediction, error, and metrics
Source code in src/alberta_framework/core/learners.py
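To make the five steps concrete, here is a minimal standalone sketch of one such update with a fixed LMS step-size (normalization omitted; the function name and `alpha` are illustrative, not part of the framework's API):

```python
import jax.numpy as jnp

def linear_update(w, b, x, target, alpha=0.01):
    y = w @ x + b              # 2. compute prediction
    error = target - y         # 3. compute error
    w = w + alpha * error * x  # 4-5. LMS-style step applied to weights
    b = b + alpha * error      #      and bias
    return w, b, y, error

w, b, y, error = linear_update(jnp.zeros(3), 0.0, jnp.ones(3), target=1.0)
```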
MLPLearner(hidden_sizes=(128, 128), optimizer=None, step_size=1.0, bounder=None, gamma=0.0, lamda=0.0, normalizer=None, sparsity=0.9, leaky_relu_slope=0.01, use_layer_norm=True)
¶
Multi-layer perceptron with composable optimizer, bounder, and normalizer.
Architecture: Input -> [Dense(H) -> LayerNorm -> LeakyReLU] x N -> Dense(1)
When use_layer_norm=False, the architecture simplifies to:
Input -> [Dense(H) -> LeakyReLU] x N -> Dense(1)
Uses parameterless layer normalization and sparse initialization following Elsayed et al. 2024. Accepts a pluggable optimizer (LMS, Autostep), an optional bounder (ObGDBounding), and an optional feature normalizer (EMANormalizer, WelfordNormalizer).
The update flow:
1. If normalizer: normalize observation, update normalizer state
2. Forward pass + jax.grad to get per-layer prediction gradients
3. Update eligibility traces: z = gamma * lamda * z + grad
4. Per-layer optimizer step: step, new_opt = optimizer.update_from_gradient(state, z)
5. If bounder: bound all steps globally
6. Apply: param += scale * error * step
Reference: Elsayed et al. 2024, "Streaming Deep Reinforcement Learning Finally Works"
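As a rough sketch of steps 3-6 for a single parameter array, with a generic `step_fn` standing in for the optimizer's update_from_gradient (names are illustrative, and the bounder stage is only marked by a comment):

```python
import jax.numpy as jnp

def traced_step(z, grad, step_fn, error, gamma=0.0, lamda=0.0):
    z = gamma * lamda * z + grad  # 3. update eligibility trace
    step = step_fn(z)             # 4. per-layer optimizer step
    # 5. a bounder, if configured, would rescale `step` here
    return z, error * step        # 6. delta added to the parameter

# Example with an LMS-style step function:
z, delta = traced_step(jnp.zeros(4), jnp.ones(4), lambda g: 0.01 * g, error=0.5)
```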
Attributes:
- hidden_sizes: Tuple of hidden layer sizes
- optimizer: Optimizer for per-weight step-size adaptation
- bounder: Optional update bounder (e.g. ObGDBounding)
- normalizer: Optional feature normalizer
- use_layer_norm: Whether to apply parameterless layer normalization
- gamma: Discount factor for trace decay
- lamda: Eligibility trace decay parameter
- sparsity: Fraction of weights zeroed out per output neuron
- leaky_relu_slope: Negative slope for LeakyReLU activation
Args:
hidden_sizes: Tuple of hidden layer sizes (default: two layers of 128)
optimizer: Optimizer for weight updates. Defaults to LMS(step_size).
Must support init_for_shape and update_from_gradient.
step_size: Base learning rate (used only when optimizer is None,
default: 1.0)
bounder: Optional update bounder (e.g. ObGDBounding for ObGD-style
bounding). When None, no bounding is applied.
gamma: Discount factor for trace decay (default: 0.0 for supervised)
lamda: Eligibility trace decay parameter (default: 0.0 for supervised)
normalizer: Optional feature normalizer. When provided, features are
normalized before prediction and learning.
sparsity: Fraction of weights zeroed out per output neuron (default: 0.9)
leaky_relu_slope: Negative slope for LeakyReLU (default: 0.01)
use_layer_norm: Whether to apply parameterless layer normalization
between hidden layers (default: True). Set to False for ablation
studies.
Source code in src/alberta_framework/core/learners.py
normalizer
property
¶
The feature normalizer, or None if normalization is disabled.
init(feature_dim, key)
¶
Initialize MLP learner state with sparse weights.
Args:
- feature_dim: Dimension of the input feature vector
- key: JAX random key for weight initialization
Returns: Initial MLP learner state with sparse weights and zero biases
Source code in src/alberta_framework/core/learners.py
predict(state, observation)
¶
Compute prediction for an observation.
Args:
- state: Current MLP learner state
- observation: Input feature vector
Returns: Scalar prediction
Source code in src/alberta_framework/core/learners.py
update(state, observation, target)
¶
Update MLP given observation and target.
Performs one step of the learning algorithm:
1. Optionally normalize observation
2. Compute prediction and error
3. Compute gradients via jax.grad on the forward pass
4. Update eligibility traces
5. Per-layer optimizer step from traces
6. Optionally bound steps
7. Apply bounded weight updates

Args:
- state: Current MLP learner state
- observation: Input feature vector
- target: Desired output
Returns: MLPUpdateResult with new state, prediction, error, and metrics
Source code in src/alberta_framework/core/learners.py
MLPUpdateResult
¶
Result of an MLP learner update step.
Attributes:
- state: Updated MLP learner state
- prediction: Prediction made before update
- error: Prediction error
- metrics: Array of metrics -- shape (3,) without normalizer, (4,) with normalizer
TDLinearLearner(optimizer=None)
¶
Linear function approximator for TD learning.
Computes value predictions as: V(s) = w @ phi(s) + b
The learner maintains weights, bias, and eligibility traces, delegating the adaptation of learning rates to the TD optimizer (e.g., TDIDBD).
This follows the Alberta Plan philosophy of temporal uniformity: every component updates at every time step.
Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"
Attributes: optimizer: The TD optimizer to use for weight updates
Args: optimizer: TD optimizer for weight updates. Defaults to TDIDBD()
Source code in src/alberta_framework/core/learners.py
init(feature_dim)
¶
Initialize TD learner state.
Args: feature_dim: Dimension of the input feature vector
Returns: Initial TD learner state with zero weights and bias
Source code in src/alberta_framework/core/learners.py
predict(state, observation)
¶
Compute value prediction for an observation.
Args:
- state: Current TD learner state
- observation: Input feature vector phi(s)
Returns:
Scalar value prediction V(s) = w @ phi(s) + b
Source code in src/alberta_framework/core/learners.py
update(state, observation, reward, next_observation, gamma)
¶
Update learner given a TD transition.
Performs one step of TD learning:
1. Compute V(s) and V(s')
2. Compute TD error delta = R + gamma*V(s') - V(s)
3. Get weight updates from TD optimizer
4. Apply updates to weights and bias

Args:
- state: Current TD learner state
- observation: Current observation phi(s)
- reward: Reward R received
- next_observation: Next observation phi(s')
- gamma: Discount factor gamma (0 at terminal states)
Returns: TDUpdateResult with new state, prediction, TD error, and metrics
Source code in src/alberta_framework/core/learners.py
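For reference, a minimal standalone sketch of these four steps with a fixed step-size (the real learner delegates step 3 to a TD optimizer such as TDIDBD; names here are illustrative):

```python
import jax.numpy as jnp

def td_update(w, b, phi, reward, phi_next, gamma, alpha=0.01):
    v = w @ phi + b                          # 1. V(s)
    v_next = w @ phi_next + b                #    V(s')
    td_error = reward + gamma * v_next - v   # 2. TD error delta
    w = w + alpha * td_error * phi           # 3-4. apply updates
    b = b + alpha * td_error
    return w, b, td_error
```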
TDUpdateResult
¶
Result of a TD learner update step.
Attributes:
- state: Updated TD learner state
- prediction: Value prediction V(s) before update
- td_error: TD error delta = R + gamma*V(s') - V(s)
- metrics: Array of metrics [squared_td_error, td_error, mean_step_size, ...]
UpdateResult
¶
Result of a learner update step.
Attributes:
- state: Updated learner state
- prediction: Prediction made before update
- error: Prediction error
- metrics: Array of metrics -- shape (3,) without normalizer, (4,) with normalizer
EMANormalizer(epsilon=1e-08, decay=0.99)
¶
Bases: Normalizer[EMANormalizerState]
Online feature normalizer using exponential moving average.
Estimates mean and variance via EMA, suitable for non-stationary environments where recent observations should be weighted more heavily.
The effective decay ramps up from 0 to the target decay over early steps to prevent instability.
Attributes:
- epsilon: Small constant for numerical stability
- decay: Exponential decay for running estimates (0.99 = slower adaptation)

Args:
- epsilon: Small constant added to std for numerical stability
- decay: Exponential decay factor for running estimates. Lower values adapt faster to changes. 1.0 means pure online average (no decay).
Source code in src/alberta_framework/core/normalizers.py
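A minimal sketch of one EMA statistics update (the early-step decay ramp is omitted, and the exact variance formulation in the framework may differ):

```python
import jax.numpy as jnp

def ema_normalize(mean, var, x, decay=0.99, epsilon=1e-8):
    mean = decay * mean + (1.0 - decay) * x              # update running mean
    var = decay * var + (1.0 - decay) * (x - mean) ** 2  # update running variance
    x_norm = (x - mean) / (jnp.sqrt(var) + epsilon)      # normalize
    return x_norm, mean, var
```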
normalize_only(state, observation)
¶
Normalize observation without updating statistics.
Useful for inference or when you want to normalize multiple observations with the same statistics.
Args:
- state: Current normalizer state
- observation: Raw feature vector
Returns: Normalized observation
Source code in src/alberta_framework/core/normalizers.py
update_only(state, observation)
¶
Update statistics without returning normalized observation.
Args:
- state: Current normalizer state
- observation: Raw feature vector
Returns: Updated normalizer state
Source code in src/alberta_framework/core/normalizers.py
init(feature_dim)
¶
Initialize EMA normalizer state.
Args: feature_dim: Dimension of feature vectors
Returns: Initial normalizer state with zero mean and unit variance
Source code in src/alberta_framework/core/normalizers.py
normalize(state, observation)
¶
Normalize observation and update EMA running statistics.
Args:
- state: Current EMA normalizer state
- observation: Raw feature vector
Returns: Tuple of (normalized_observation, new_state)
Source code in src/alberta_framework/core/normalizers.py
EMANormalizerState
¶
State for EMA-based online feature normalization.
Uses exponential moving average to estimate running mean and variance, suitable for non-stationary distributions.
Attributes:
- mean: Running mean estimate per feature
- var: Running variance estimate per feature
- sample_count: Number of samples seen
- decay: Exponential decay factor for estimates (1.0 = no decay, pure online)
Normalizer(epsilon=1e-08)
¶
Bases: ABC
Abstract base class for online feature normalizers.
Normalizes features using running estimates of mean and standard deviation:
x_normalized = (x - mean) / (std + epsilon)
The normalizer updates its estimates at every time step, following temporal uniformity.
Subclasses must implement init and normalize. The normalize_only
and update_only methods have default implementations.
Attributes: epsilon: Small constant for numerical stability
Args: epsilon: Small constant added to std for numerical stability
Source code in src/alberta_framework/core/normalizers.py
init(feature_dim)
abstractmethod
¶
Initialize normalizer state.
Args: feature_dim: Dimension of feature vectors
Returns: Initial normalizer state with zero mean and unit variance
Source code in src/alberta_framework/core/normalizers.py
normalize(state, observation)
abstractmethod
¶
Normalize observation and update running statistics.
This method both normalizes the current observation AND updates the running statistics, maintaining temporal uniformity.
Args:
- state: Current normalizer state
- observation: Raw feature vector
Returns: Tuple of (normalized_observation, new_state)
Source code in src/alberta_framework/core/normalizers.py
normalize_only(state, observation)
¶
Normalize observation without updating statistics.
Useful for inference or when you want to normalize multiple observations with the same statistics.
Args:
- state: Current normalizer state
- observation: Raw feature vector
Returns: Normalized observation
Source code in src/alberta_framework/core/normalizers.py
update_only(state, observation)
¶
Update statistics without returning normalized observation.
Args:
- state: Current normalizer state
- observation: Raw feature vector
Returns: Updated normalizer state
Source code in src/alberta_framework/core/normalizers.py
WelfordNormalizer(epsilon=1e-08)
¶
Bases: Normalizer[WelfordNormalizerState]
Online feature normalizer using Welford's algorithm.
Computes cumulative sample mean and variance with Bessel's correction, suitable for stationary distributions. Numerically stable for large sample counts.
Reference: Welford 1962, "Note on a Method for Calculating Corrected Sums of Squares and Products"
Attributes: epsilon: Small constant for numerical stability
Args: epsilon: Small constant added to std for numerical stability
Source code in src/alberta_framework/core/normalizers.py
normalize_only(state, observation)
¶
Normalize observation without updating statistics.
Useful for inference or when you want to normalize multiple observations with the same statistics.
Args:
- state: Current normalizer state
- observation: Raw feature vector
Returns: Normalized observation
Source code in src/alberta_framework/core/normalizers.py
update_only(state, observation)
¶
Update statistics without returning normalized observation.
Args:
- state: Current normalizer state
- observation: Raw feature vector
Returns: Updated normalizer state
Source code in src/alberta_framework/core/normalizers.py
init(feature_dim)
¶
Initialize Welford normalizer state.
Args: feature_dim: Dimension of feature vectors
Returns: Initial normalizer state with zero mean and unit variance
Source code in src/alberta_framework/core/normalizers.py
normalize(state, observation)
¶
Normalize observation and update Welford running statistics.
Uses Welford's online algorithm:
1. Increment count
2. Update mean incrementally
3. Update sum of squared deviations (p, the M2 accumulator)
4. Compute variance with Bessel's correction when count >= 2

Args:
- state: Current Welford normalizer state
- observation: Raw feature vector
Returns: Tuple of (normalized_observation, new_state)
Source code in src/alberta_framework/core/normalizers.py
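A minimal sketch of these steps (names are illustrative; `p` is the M2 accumulator):

```python
import jax.numpy as jnp

def welford_normalize(count, mean, p, x, epsilon=1e-8):
    count = count + 1             # 1. increment count
    delta = x - mean
    mean = mean + delta / count   # 2. incremental mean update
    p = p + delta * (x - mean)    # 3. update sum of squared deviations
    # 4. Bessel-corrected variance once count >= 2, unit variance before
    var = jnp.where(count >= 2, p / jnp.maximum(count - 1, 1), jnp.ones_like(p))
    x_norm = (x - mean) / (jnp.sqrt(var) + epsilon)
    return x_norm, count, mean, p
```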
WelfordNormalizerState
¶
State for Welford's online normalization algorithm.
Uses Welford's algorithm for numerically stable estimation of cumulative sample mean and variance with Bessel's correction.
Attributes:
- mean: Running mean estimate per feature
- var: Running variance estimate per feature (Bessel-corrected)
- sample_count: Number of samples seen
- p: Sum of squared deviations from the current mean (M2 accumulator)
IDBD(initial_step_size=0.01, meta_step_size=0.01)
¶
Incremental Delta-Bar-Delta optimizer.
IDBD maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation. When successive gradients agree in sign, the step-size for that weight increases. When they disagree, it decreases.
This implements Sutton's 1992 algorithm for adapting step-sizes online without requiring manual tuning.
Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta"
Attributes:
- initial_step_size: Initial per-weight step-size
- meta_step_size: Meta learning rate beta for adapting step-sizes

Args:
- initial_step_size: Initial value for per-weight step-sizes
- meta_step_size: Meta learning rate beta for adapting step-sizes
Source code in src/alberta_framework/core/optimizers.py
init_for_shape(shape)
¶
Initialize optimizer state for parameters of arbitrary shape.
Used by MLP learners where parameters are matrices/vectors of varying shapes. Not all optimizers support this.
The return type varies by subclass (e.g. LMSState for LMS,
AutostepParamState for Autostep) so the base signature uses
Any.
Args: shape: Shape of the parameter array
Returns: Initial optimizer state with arrays matching the given shape
Raises: NotImplementedError: If the optimizer does not support this
Source code in src/alberta_framework/core/optimizers.py
update_from_gradient(state, gradient, error=None)
¶
Compute step delta from pre-computed gradient.
The returned delta does NOT include the error -- the caller is
responsible for multiplying error * delta before applying.
The state type varies by subclass (e.g. LMSState for LMS,
AutostepParamState for Autostep) so the base signature uses
Any.
Args:
- state: Current optimizer state
- gradient: Pre-computed gradient (e.g. eligibility trace)
- error: Optional prediction error scalar. Optimizers with meta-learning (e.g. Autostep) use this for meta-gradient computation. LMS ignores it.
Returns:
(step, new_state) where step has the same shape as gradient
Raises: NotImplementedError: If the optimizer does not support this
Source code in src/alberta_framework/core/optimizers.py
init(feature_dim)
¶
Initialize IDBD state.
Args: feature_dim: Dimension of weight vector
Returns: IDBD state with per-weight step-sizes and traces
Source code in src/alberta_framework/core/optimizers.py
update(state, error, observation)
¶
Compute IDBD weight update with adaptive step-sizes.
Following Sutton 1992, Figure 2, the operation ordering is:
1. Meta-update: log_alpha_i += beta * error * x_i * h_i (using OLD traces)
2. Compute NEW step-sizes: alpha_i = exp(log_alpha_i)
3. Update weights: w_i += alpha_i * error * x_i (using NEW alphas)
4. Update traces: h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i (using NEW alphas)
The trace h_i tracks the correlation between current and past gradients. When gradients consistently point the same direction, h_i grows, leading to larger step-sizes.
Args:
- state: Current IDBD state
- error: Prediction error (scalar)
- observation: Feature vector
Returns: OptimizerUpdate with weight deltas and updated state
Source code in src/alberta_framework/core/optimizers.py
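A standalone sketch of this ordering (bias terms omitted; variable names are illustrative, not the framework's state fields):

```python
import jax.numpy as jnp

def idbd_update(log_alpha, h, w, x, error, beta=0.01):
    log_alpha = log_alpha + beta * error * x * h  # 1. meta-update with OLD traces
    alpha = jnp.exp(log_alpha)                    # 2. NEW step-sizes
    w = w + alpha * error * x                     # 3. weight update with NEW alphas
    h = h * jnp.maximum(0.0, 1.0 - alpha * x**2) + alpha * error * x  # 4. traces
    return log_alpha, h, w
```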
LMS(step_size=0.01)
¶
Least Mean Square optimizer with fixed step-size.
The simplest gradient-based optimizer: w_{t+1} = w_t + alpha * delta * x_t
This serves as a baseline. The challenge is that the optimal step-size depends on the problem and changes as the task becomes non-stationary.
Attributes: step_size: Fixed learning rate alpha
Args: step_size: Fixed learning rate
Source code in src/alberta_framework/core/optimizers.py
init(feature_dim)
¶
Initialize LMS state.
Args: feature_dim: Dimension of weight vector (unused for LMS)
Returns: LMS state containing the step-size
Source code in src/alberta_framework/core/optimizers.py
init_for_shape(shape)
¶
Initialize LMS state for arbitrary-shape parameters.
LMS state is shape-independent (single scalar), so this returns the same state regardless of shape.
Source code in src/alberta_framework/core/optimizers.py
update_from_gradient(state, gradient, error=None)
¶
Compute step from gradient: step = alpha * gradient.
Args:
- state: Current LMS state
- gradient: Pre-computed gradient (any shape)
- error: Unused by LMS (accepted for interface compatibility)
Returns:
(step, state) -- state is unchanged for LMS
Source code in src/alberta_framework/core/optimizers.py
update(state, error, observation)
¶
Compute LMS weight update.
Update rule: delta_w = alpha * error * x
Args:
- state: Current LMS state
- error: Prediction error (scalar)
- observation: Feature vector
Returns: OptimizerUpdate with weight and bias deltas
Source code in src/alberta_framework/core/optimizers.py
TDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, use_semi_gradient=True)
¶
Bases: TDOptimizer[TDIDBDState]
TD-IDBD optimizer for temporal-difference learning.
Extends IDBD to TD learning with eligibility traces. Maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation in the TD setting.
Two variants are supported:
- Semi-gradient (default): uses only phi(s) in the meta-update; more stable
- Ordinary gradient: uses both phi(s) and phi(s'); more accurate but sensitive
Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"
Attributes:
- initial_step_size: Initial per-weight step-size
- meta_step_size: Meta learning rate theta
- trace_decay: Eligibility trace decay lambda
- use_semi_gradient: If True, use semi-gradient variant (default)

Args:
- initial_step_size: Initial value for per-weight step-sizes
- meta_step_size: Meta learning rate theta for adapting step-sizes
- trace_decay: Eligibility trace decay lambda (0 = TD(0))
- use_semi_gradient: If True, use semi-gradient variant (recommended)
Source code in src/alberta_framework/core/optimizers.py
init(feature_dim)
¶
Initialize TD-IDBD state.
Args: feature_dim: Dimension of weight vector
Returns: TD-IDBD state with per-weight step-sizes, traces, and h traces
Source code in src/alberta_framework/core/optimizers.py
update(state, td_error, observation, next_observation, gamma)
¶
Compute TD-IDBD weight update with adaptive step-sizes.
Implements Algorithm 3 (semi-gradient) or Algorithm 4 (ordinary gradient) from Kearney et al. 2019.
Args:
- state: Current TD-IDBD state
- td_error: TD error delta = R + gamma*V(s') - V(s)
- observation: Current observation phi(s)
- next_observation: Next observation phi(s')
- gamma: Discount factor gamma (0 at terminal)
Returns: TDOptimizerUpdate with weight deltas and updated state
Source code in src/alberta_framework/core/optimizers.py
AGCBounding(clip_factor=0.01, eps=0.001)
¶
Bases: Bounder
Adaptive Gradient Clipping (Brock et al. 2021).
Clips per-output-unit based on the ratio of gradient norm to weight norm.
Units where ||grad|| / max(||weight||, eps) > clip_factor get scaled
down to respect the constraint.
Unlike ObGDBounding which applies a single global scale factor, AGC applies fine-grained, per-unit clipping that adapts to each layer's weight magnitude.
The metric returned is the fraction of units that were clipped (0.0 = no clipping, 1.0 = all units clipped).
Reference: Brock, A., De, S., Smith, S.L., & Simonyan, K. (2021). "High-Performance Large-Scale Image Recognition Without Normalization" (arXiv: 2102.06171)
Attributes:
- clip_factor: Maximum allowed gradient-to-weight ratio (lambda). Default 0.01.
- eps: Floor for weight norm to avoid division by zero. Default 1e-3.
Source code in src/alberta_framework/core/optimizers.py
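A rough sketch of the clipping rule for one 2D weight matrix, assuming each row corresponds to one output unit (illustrative only, not the framework's implementation):

```python
import jax.numpy as jnp

def agc_clip(step, param, error, clip_factor=0.01, eps=1e-3):
    # per-unit norms of the proposed update and of the current weights
    step_norm = jnp.abs(error) * jnp.linalg.norm(step, axis=-1, keepdims=True)
    weight_norm = jnp.maximum(jnp.linalg.norm(param, axis=-1, keepdims=True), eps)
    max_norm = clip_factor * weight_norm
    scale = jnp.where(step_norm > max_norm, max_norm / step_norm, 1.0)
    frac_clipped = jnp.mean(step_norm > max_norm)  # reported metric
    return step * scale, frac_clipped
```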
bound(steps, error, params)
¶
Bound proposed steps using per-unit adaptive gradient clipping.
For each parameter/step pair, computes unit-wise norms and clips
units where |error| * ||step|| > clip_factor * max(||param||, eps).
Args:
- steps: Per-parameter step arrays from the optimizer
- error: Prediction error scalar
- params: Current parameter values (used for weight norms)
Returns:
(clipped_steps, frac_clipped) where frac_clipped is the
fraction of units that were clipped
Source code in src/alberta_framework/core/optimizers.py
Autostep(initial_step_size=0.01, meta_step_size=0.01, tau=10000.0)
¶
Bases: Optimizer[AutostepState]
Autostep optimizer with tuning-free step-size adaptation.
Implements the exact algorithm from Mahmood et al. 2012, Table 1.
The algorithm maintains per-weight step-sizes that adapt based on
meta-gradient correlation. The key innovations are:
- Self-regulated normalizers (v_i) that track meta-gradient magnitude
|delta * x_i * h_i| for stable meta-updates
- Overshoot prevention via effective step-size normalization
M = max(sum(alpha_i * x_i^2), 1)
Per-sample update (Table 1):
1. v_i = max(|delta*x_i*h_i|, v_i + (1/tau)*alpha_i*x_i^2*(|delta*x_i*h_i| - v_i))
2. alpha_i *= exp(mu * delta*x_i*h_i / v_i) where v_i > 0
3. M = max(sum(alpha_i * x_i^2), 1); alpha_i /= M
4. w_i += alpha_i * delta * x_i (weight update with NEW alpha)
5. h_i = h_i * (1 - alpha_i * x_i^2) + alpha_i * delta * x_i (trace update)
Reference: Mahmood, A.R., Sutton, R.S., Degris, T., & Pilarski, P.M. (2012). "Tuning-free step-size adaptation"
Attributes:
- initial_step_size: Initial per-weight step-size
- meta_step_size: Meta learning rate mu for adapting step-sizes
- tau: Time constant for normalizer adaptation (default: 10000)

Args:
- initial_step_size: Initial value for per-weight step-sizes
- meta_step_size: Meta learning rate for adapting step-sizes
- tau: Time constant for normalizer adaptation (default: 10000). Higher values mean slower normalizer decay.
Source code in src/alberta_framework/core/optimizers.py
init(feature_dim)
¶
Initialize Autostep state.
Normalizers (v_i) and traces (h_i) are initialized to 0 per the paper.
Args: feature_dim: Dimension of weight vector
Returns: Autostep state with per-weight step-sizes, traces, and normalizers
Source code in src/alberta_framework/core/optimizers.py
init_for_shape(shape)
¶
Initialize Autostep state for arbitrary-shape parameters.
Args: shape: Shape of the parameter array
Returns: AutostepParamState with arrays matching the given shape
Source code in src/alberta_framework/core/optimizers.py
update_from_gradient(state, gradient, error=None)
¶
Compute Autostep update from pre-computed gradient (MLP path).
Implements the Table 1 algorithm generalized for arbitrary-shape
parameters, where gradient plays the role of the eligibility
trace z (prediction gradient).
When error is provided, the full paper algorithm is used: the meta-gradient is error * z * h. When error is None, the update falls back to an error-free approximation (z * h).
The returned step does NOT include the error -- the caller applies
param += error * step after optional bounding.
Args:
- state: Current Autostep param state
- gradient: Pre-computed gradient / eligibility trace (same shape as state arrays)
- error: Optional prediction error scalar. When provided, enables the full paper algorithm with error-scaled meta-gradients.
Returns:
(step, new_state) where step has the same shape as gradient
Source code in src/alberta_framework/core/optimizers.py
update(state, error, observation)
¶
Compute Autostep weight update following Mahmood et al. 2012, Table 1.
The algorithm per sample:
- Eq. 4: v_i = max(|δ*x_i*h_i|, v_i + (1/τ)*α_i*x_i²*(|δ*x_i*h_i| - v_i))
- Eq. 5: α_i *= exp(μ*δ*x_i*h_i / v_i) where v_i > 0
- Eq. 6-7: M = max(Σ α_i*x_i² + α_bias, 1); α_i /= M, α_bias /= M
- Weight update: w_i += α_i*δ*x_i (with NEW alpha)
- Trace update: h_i = h_i*(1 - α_i*x_i²) + α_i*δ*x_i
Args:
- state: Current Autostep state
- error: Prediction error (scalar)
- observation: Feature vector
Returns: OptimizerUpdate with weight deltas and updated state
Source code in src/alberta_framework/core/optimizers.py
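A standalone sketch of this per-sample update (bias terms omitted; names are illustrative, not the framework's state fields):

```python
import jax.numpy as jnp

def autostep_update(alpha, h, v, w, x, delta, mu=0.01, tau=1e4):
    m = jnp.abs(delta * x * h)                                    # |delta*x_i*h_i|
    v = jnp.maximum(m, v + (1.0 / tau) * alpha * x**2 * (m - v))  # Eq. 4
    alpha = jnp.where(v > 0, alpha * jnp.exp(mu * delta * x * h / v), alpha)  # Eq. 5
    alpha = alpha / jnp.maximum(jnp.sum(alpha * x**2), 1.0)       # Eq. 6-7
    w = w + alpha * delta * x                                     # weight update, NEW alpha
    h = h * (1.0 - alpha * x**2) + alpha * delta * x              # trace update
    return alpha, h, v, w
```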
AutoTDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, normalizer_decay=10000.0)
¶
Bases: TDOptimizer[AutoTDIDBDState]
AutoStep-style normalized TD-IDBD optimizer.
Adds AutoStep-style normalization to TDIDBD for improved stability and reduced sensitivity to the meta step-size theta.
Reference: Kearney et al. 2019, Algorithm 6 "AutoStep Style Normalized TIDBD(lambda)"
Attributes:
- initial_step_size: Initial per-weight step-size
- meta_step_size: Meta learning rate theta
- trace_decay: Eligibility trace decay lambda
- normalizer_decay: Decay parameter tau for normalizers

Args:
- initial_step_size: Initial value for per-weight step-sizes
- meta_step_size: Meta learning rate theta for adapting step-sizes
- trace_decay: Eligibility trace decay lambda (0 = TD(0))
- normalizer_decay: Decay parameter tau for normalizers (default: 10000)
Source code in src/alberta_framework/core/optimizers.py
init(feature_dim)
¶
Initialize AutoTDIDBD state.
Args: feature_dim: Dimension of weight vector
Returns: AutoTDIDBD state with per-weight step-sizes, traces, h traces, and normalizers
Source code in src/alberta_framework/core/optimizers.py
update(state, td_error, observation, next_observation, gamma)
¶
Compute AutoTDIDBD weight update with normalized adaptive step-sizes.
Implements Algorithm 6 from Kearney et al. 2019.
Args:
- state: Current AutoTDIDBD state
- td_error: TD error delta = R + gamma*V(s') - V(s)
- observation: Current observation phi(s)
- next_observation: Next observation phi(s')
- gamma: Discount factor gamma (0 at terminal)
Returns: TDOptimizerUpdate with weight deltas and updated state
Source code in src/alberta_framework/core/optimizers.py
Bounder
¶
Bases: ABC
Base class for update bounding strategies.
A bounder takes the proposed per-parameter step arrays from an optimizer and optionally scales them down to prevent overshooting.
bound(steps, error, params)
abstractmethod
¶
Bound proposed update steps.
Args:
- steps: Per-parameter step arrays from the optimizer
- error: Prediction error scalar
- params: Current parameter values (needed by some bounders like AGC)
Returns:
(bounded_steps, metric) where metric is a scalar for reporting
(e.g., scale factor for ObGD, mean clip ratio for AGC)
Source code in src/alberta_framework/core/optimizers.py
ObGD(step_size=1.0, kappa=2.0, gamma=0.0, lamda=0.0)
¶
Observation-bounded Gradient Descent optimizer.
ObGD prevents overshooting by dynamically bounding the effective step-size based on the magnitude of the prediction error and eligibility traces. When the combined update magnitude would be too large, the step-size is scaled down to prevent the prediction from overshooting the target.
This is the deep-network generalization of Autostep's overshooting prevention, designed for streaming reinforcement learning.
For supervised learning (gamma=0, lamda=0), traces equal the current observation each step, making ObGD equivalent to LMS with dynamic step-size bounding.
The ObGD algorithm:
1. Update traces: z = gamma * lamda * z + observation
2. Compute bound: M = alpha * kappa * max(|error|, 1) * (||z_w||_1 + |z_b|)
3. Effective step: alpha_eff = min(alpha, alpha / M), i.e. alpha / max(M, 1)
4. Weight delta: delta_w = alpha_eff * error * z_w
5. Bias delta: delta_b = alpha_eff * error * z_b
Reference: Elsayed et al. 2024, "Streaming Deep Reinforcement Learning Finally Works"
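A standalone sketch of these five steps (names are illustrative; the framework returns the deltas via OptimizerUpdate rather than applying them):

```python
import jax.numpy as jnp

def obgd_step(z_w, z_b, x, error, alpha=1.0, kappa=2.0, gamma=0.0, lamda=0.0):
    z_w = gamma * lamda * z_w + x             # 1. weight traces
    z_b = gamma * lamda * z_b + 1.0           #    bias trace
    M = alpha * kappa * jnp.maximum(jnp.abs(error), 1.0) * (
        jnp.sum(jnp.abs(z_w)) + jnp.abs(z_b))  # 2. bound
    alpha_eff = alpha / jnp.maximum(M, 1.0)   # 3. effective step-size
    return z_w, z_b, alpha_eff * error * z_w, alpha_eff * error * z_b  # 4-5. deltas
```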
Attributes:
- step_size: Base learning rate alpha
- kappa: Bounding sensitivity parameter (higher = more conservative)
- gamma: Discount factor for trace decay (0 for supervised learning)
- lamda: Eligibility trace decay parameter (0 for supervised learning)

Args:
- step_size: Base learning rate (default: 1.0)
- kappa: Bounding sensitivity parameter (default: 2.0)
- gamma: Discount factor for trace decay (default: 0.0 for supervised)
- lamda: Eligibility trace decay parameter (default: 0.0 for supervised)
Source code in src/alberta_framework/core/optimizers.py
init_for_shape(shape)
¶
Initialize optimizer state for parameters of arbitrary shape.
Used by MLP learners where parameters are matrices/vectors of varying shapes. Not all optimizers support this.
The return type varies by subclass (e.g. LMSState for LMS,
AutostepParamState for Autostep) so the base signature uses
Any.
Args: shape: Shape of the parameter array
Returns: Initial optimizer state with arrays matching the given shape
Raises: NotImplementedError: If the optimizer does not support this
Source code in src/alberta_framework/core/optimizers.py
update_from_gradient(state, gradient, error=None)
¶
Compute step delta from pre-computed gradient.
The returned delta does NOT include the error -- the caller is
responsible for multiplying error * delta before applying.
The state type varies by subclass (e.g. LMSState for LMS,
AutostepParamState for Autostep) so the base signature uses
Any.
Args:
- state: Current optimizer state
- gradient: Pre-computed gradient (e.g. eligibility trace)
- error: Optional prediction error scalar. Optimizers with meta-learning (e.g. Autostep) use this for meta-gradient computation. LMS ignores it.
Returns:
(step, new_state) where step has the same shape as gradient
Raises: NotImplementedError: If the optimizer does not support this
Source code in src/alberta_framework/core/optimizers.py
init(feature_dim)
¶
Initialize ObGD state.
Args: feature_dim: Dimension of weight vector
Returns: ObGD state with eligibility traces
Source code in src/alberta_framework/core/optimizers.py
update(state, error, observation)
¶
Compute ObGD weight update with overshooting prevention.
The bounding mechanism scales down the step-size when the combined effect of error magnitude, trace norm, and step-size would cause the prediction to overshoot the target.
Args:
- state: Current ObGD state
- error: Prediction error (target - prediction)
- observation: Current observation/feature vector
Returns: OptimizerUpdate with bounded weight deltas and updated state
Source code in src/alberta_framework/core/optimizers.py
ObGDBounding(kappa=2.0)
¶
Bases: Bounder
ObGD-style global update bounding (Elsayed et al. 2024).
Computes a global bounding factor from the L1 norm of all proposed steps and the error magnitude, then uniformly scales all steps down if the combined update would be too large.
For LMS with a single scalar step-size alpha, total_step = alpha * z_sum, giving M = alpha * kappa * max(|error|, 1) * z_sum, identical to the original Elsayed et al. 2024 formula.
Attributes: kappa: Bounding sensitivity parameter (higher = more conservative)
Source code in src/alberta_framework/core/optimizers.py
bound(steps, error, params)
¶
Bound proposed steps using ObGD formula.
Args:
- steps: Per-parameter step arrays
- error: Prediction error scalar
- params: Current parameter values (unused by ObGD)
Returns:
(bounded_steps, scale) where scale is the bounding factor
Source code in src/alberta_framework/core/optimizers.py
Optimizer
¶
Bases: ABC
Base class for optimizers.
init(feature_dim)
abstractmethod
¶
Initialize optimizer state.
Args: feature_dim: Dimension of weight vector
Returns: Initial optimizer state
update(state, error, observation)
abstractmethod
¶
Compute weight updates given prediction error.
Args:
- state: Current optimizer state
- error: Prediction error (target - prediction)
- observation: Current observation/feature vector
Returns: OptimizerUpdate with deltas and new state
Source code in src/alberta_framework/core/optimizers.py
init_for_shape(shape)
¶
Initialize optimizer state for parameters of arbitrary shape.
Used by MLP learners where parameters are matrices/vectors of varying shapes. Not all optimizers support this.
The return type varies by subclass (e.g. LMSState for LMS,
AutostepParamState for Autostep) so the base signature uses
Any.
Args: shape: Shape of the parameter array
Returns: Initial optimizer state with arrays matching the given shape
Raises: NotImplementedError: If the optimizer does not support this
Source code in src/alberta_framework/core/optimizers.py
update_from_gradient(state, gradient, error=None)
¶
Compute step delta from pre-computed gradient.
The returned delta does NOT include the error -- the caller is
responsible for multiplying error * delta before applying.
The state type varies by subclass (e.g. LMSState for LMS,
AutostepParamState for Autostep) so the base signature uses
Any.
Args:
- state: Current optimizer state
- gradient: Pre-computed gradient (e.g. eligibility trace)
- error: Optional prediction error scalar. Optimizers with meta-learning (e.g. Autostep) use this for meta-gradient computation. LMS ignores it.
Returns:
(step, new_state) where step has the same shape as gradient
Raises: NotImplementedError: If the optimizer does not support this
Source code in src/alberta_framework/core/optimizers.py
TDOptimizer
¶
Bases: ABC
Base class for TD optimizers.
TD optimizers handle temporal-difference learning with eligibility traces. They take TD error and both current and next observations as input.
init(feature_dim)
abstractmethod
¶
Initialize optimizer state.
Args: feature_dim: Dimension of weight vector
Returns: Initial optimizer state
update(state, td_error, observation, next_observation, gamma)
abstractmethod
¶
Compute weight updates given TD error.
Args:
- state: Current optimizer state
- td_error: TD error delta = R + gamma*V(s') - V(s)
- observation: Current observation phi(s)
- next_observation: Next observation phi(s')
- gamma: Discount factor gamma (0 at terminal)
Returns: TDOptimizerUpdate with deltas and new state
Source code in src/alberta_framework/core/optimizers.py
TDOptimizerUpdate
¶
Result of a TD optimizer update step.
Attributes:
- weight_delta: Change to apply to weights
- bias_delta: Change to apply to bias
- new_state: Updated optimizer state
- metrics: Dictionary of metrics for logging
AutostepParamState
¶
Per-parameter Autostep state for use with arbitrary-shape parameters.
Used by Autostep.init_for_shape / Autostep.update_from_gradient
for MLP (or other multi-parameter) learners. Unlike AutostepState,
this type has no bias-specific fields -- each parameter (weight matrix,
bias vector) gets its own AutostepParamState.
Attributes:
- step_sizes: Per-element step-sizes, same shape as the parameter
- traces: Per-element traces for gradient correlation
- normalizers: Running normalizer of meta-gradient magnitude |delta*z*h|
- meta_step_size: Meta learning rate mu
- tau: Time constant for normalizer adaptation
AutostepState
¶
State for the Autostep optimizer.
Autostep is a tuning-free step-size adaptation algorithm that adapts per-weight step-sizes based on meta-gradient correlation, with self-regulated normalizers to stabilize the meta-update.
Reference: Mahmood et al. 2012, "Tuning-free step-size adaptation", Table 1
Attributes:
- step_sizes: Per-weight step-sizes (alpha_i)
- traces: Per-weight traces for gradient correlation (h_i)
- normalizers: Running normalizer of meta-gradient magnitude |delta*x*h| (v_i)
- meta_step_size: Meta learning rate mu for adapting step-sizes
- tau: Time constant for normalizer adaptation (higher = slower decay)
- bias_step_size: Step-size for the bias term
- bias_trace: Trace for the bias term
- bias_normalizer: Normalizer for the bias meta-gradient
AutoTDIDBDState
¶
State for the AutoTDIDBD optimizer.
AutoTDIDBD adds AutoStep-style normalization to TDIDBD for improved stability. Includes normalizers for the meta-weight updates and effective step-size normalization to prevent overshooting.
Reference: Kearney et al. 2019, Algorithm 6
Attributes:
- log_step_sizes: Log of per-weight step-sizes (log alpha_i)
- eligibility_traces: Eligibility traces z_i
- h_traces: Per-weight h traces for gradient correlation
- normalizers: Running max of absolute gradient correlations (eta_i)
- meta_step_size: Meta learning rate theta
- trace_decay: Eligibility trace decay parameter lambda
- normalizer_decay: Decay parameter tau for normalizers
- bias_log_step_size: Log step-size for the bias term
- bias_eligibility_trace: Eligibility trace for the bias
- bias_h_trace: h trace for the bias term
- bias_normalizer: Normalizer for the bias gradient correlation
BatchedLearningResult
¶
Result from batched learning loop across multiple seeds.
Used with run_learning_loop_batched for vmap-based GPU parallelization.
Attributes:
- states: Batched learner states - each array has shape (num_seeds, ...)
- metrics: Metrics array with shape (num_seeds, num_steps, num_cols) where num_cols is 3 (no normalizer) or 4 (with normalizer)
- step_size_history: Optional step-size history with batched shapes, or None if tracking was disabled
- normalizer_history: Optional normalizer history with batched shapes, or None if tracking was disabled
BatchedMLPResult
¶
Result from batched MLP learning loop across multiple seeds.
Used with run_mlp_learning_loop_batched for vmap-based GPU parallelization.
Attributes:
- states: Batched MLP learner states - each array has shape (num_seeds, ...)
- metrics: Metrics array with shape (num_seeds, num_steps, num_cols) where num_cols is 3 (no normalizer) or 4 (with normalizer)
- normalizer_history: Optional normalizer history with batched shapes, or None if tracking was disabled
IDBDState
¶
State for the IDBD (Incremental Delta-Bar-Delta) optimizer.
IDBD maintains per-weight adaptive step-sizes that are meta-learned based on the correlation of successive gradients.
Reference: Sutton 1992, "Adapting Bias by Gradient Descent"
Attributes:
- log_step_sizes: Log of per-weight step-sizes (log alpha_i)
- traces: Per-weight traces h_i for gradient correlation
- meta_step_size: Meta learning rate beta for adapting step-sizes
- bias_step_size: Step-size for the bias term
- bias_trace: Trace for the bias term
LearnerState
¶
State for a linear learner.
Attributes:
- weights: Weight vector for linear prediction
- bias: Bias term
- optimizer_state: State maintained by the optimizer
- normalizer_state: Optional state for online feature normalization
LMSState
¶
State for the LMS (Least Mean Square) optimizer.
LMS uses a fixed step-size, so state only tracks the step-size parameter.
Attributes: step_size: Fixed learning rate alpha
MLPLearnerState
¶
State for an MLP learner.
Attributes:
- params: MLP parameters (weights and biases for each layer)
- optimizer_states: Tuple of per-parameter optimizer states (weights + biases)
- traces: Tuple of per-parameter eligibility traces
- normalizer_state: Optional state for online feature normalization
MLPParams
¶
Parameters for a multi-layer perceptron.
Uses tuples of arrays (not lists) for proper JAX PyTree handling.
Attributes:
- weights: Tuple of weight matrices, one per layer
- biases: Tuple of bias vectors, one per layer
NormalizerHistory
¶
History of per-feature normalizer state recorded during training.
Used for analyzing how the normalizer (EMA or Welford) adapts to distribution shifts (reactive lag diagnostic).
Attributes:
- means: Per-feature mean estimates at each recording, shape (num_recordings, feature_dim)
- variances: Per-feature variance estimates at each recording, shape (num_recordings, feature_dim)
- recording_indices: Step indices where recordings were made, shape (num_recordings,)
NormalizerTrackingConfig
¶
Configuration for recording per-feature normalizer state during training.
Attributes: interval: Record normalizer state every N steps
ObGDState
¶
State for the ObGD (Observation-bounded Gradient Descent) optimizer.
ObGD prevents overshooting by dynamically bounding the effective step-size based on the magnitude of the TD error and eligibility traces. When the combined update magnitude would be too large, the step-size is scaled down.
For supervised learning (gamma=0, lamda=0), traces equal the current observation each step, making ObGD equivalent to LMS with dynamic step-size bounding.
Reference: Elsayed et al. 2024, "Streaming Deep Reinforcement Learning Finally Works"
Attributes:
- step_size: Base learning rate alpha
- kappa: Bounding sensitivity parameter (higher = more conservative)
- traces: Per-weight eligibility traces z_i
- bias_trace: Eligibility trace for the bias term
- gamma: Discount factor for trace decay
- lamda: Eligibility trace decay parameter lambda
StepSizeHistory
¶
History of per-weight step-sizes recorded during training.
Attributes:
- step_sizes: Per-weight step-sizes at each recording, shape (num_recordings, num_weights)
- bias_step_sizes: Bias step-sizes at each recording, shape (num_recordings,) or None
- recording_indices: Step indices where recordings were made, shape (num_recordings,)
- normalizers: Autostep's per-weight normalizers (v_i) at each recording, shape (num_recordings, num_weights) or None. Only populated for the Autostep optimizer.
StepSizeTrackingConfig
¶
Configuration for recording per-weight step-sizes during training.
Attributes:
- interval: Record step-sizes every N steps
- include_bias: Whether to also record the bias step-size
TDIDBDState
¶
State for the TD-IDBD (Temporal-Difference IDBD) optimizer.
TD-IDBD extends IDBD to temporal-difference learning with eligibility traces. Maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation in the TD setting.
Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"
Attributes:
- log_step_sizes: Log of per-weight step-sizes (log alpha_i)
- eligibility_traces: Eligibility traces z_i for temporal credit assignment
- h_traces: Per-weight h traces for gradient correlation
- meta_step_size: Meta learning rate theta for adapting step-sizes
- trace_decay: Eligibility trace decay parameter lambda
- bias_log_step_size: Log step-size for the bias term
- bias_eligibility_trace: Eligibility trace for the bias
- bias_h_trace: h trace for the bias term
TDLearnerState
¶
State for a TD linear learner.
Attributes:
- weights: Weight vector for linear value function approximation
- bias: Bias term
- optimizer_state: State maintained by the TD optimizer
TDTimeStep
¶
Single experience from a TD stream.
Represents a transition (s, r, s', gamma) for temporal-difference learning.
Attributes:
- observation: Feature vector phi(s)
- reward: Reward R received
- next_observation: Feature vector phi(s')
- gamma: Discount factor gamma_t (0 at terminal states)
TimeStep
¶
Single experience from an experience stream.
Attributes:
- observation: Feature vector x_t
- target: Desired output y*_t (for supervised learning)
ScanStream
¶
Bases: Protocol[StateT]
Protocol for JAX scan-compatible experience streams.
Streams generate temporally-uniform experience for continual learning. Unlike iterator-based streams, ScanStream uses pure functions that can be compiled with JAX's JIT and used with jax.lax.scan.
The stream should be non-stationary to test continual learning capabilities - the underlying target function changes over time.
Type Parameters: StateT: The state type maintained by this stream
Examples:

```python
stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
key = jax.random.key(42)
state = stream.init(key)
timestep, new_state = stream.step(state, jnp.array(0))
```
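Because init and step are pure functions, a stream can be rolled out under JIT with jax.lax.scan; run_learning_loop builds on the same pattern. A minimal sketch (the rollout function and collected outputs are illustrative):

```python
import jax
import jax.numpy as jnp

def rollout(stream, key, num_steps):
    def body(state, idx):
        timestep, state = stream.step(state, idx)
        return state, timestep.target  # carry state, collect targets
    return jax.lax.scan(body, stream.init(key), jnp.arange(num_steps))
```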
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key for initialization
Returns: Initial stream state
step(state, idx)
¶
Generate one time step. Must be JIT-compatible.
This is a pure function that takes the current state and step index, and returns a TimeStep along with the updated state. The step index can be used for time-dependent behavior but is often ignored.
Args:
- state: Current stream state
- idx: Current step index (can be ignored for most streams)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/base.py
AbruptChangeState
¶
State for AbruptChangeStream.
Attributes:
- key: JAX random key for generating randomness
- true_weights: Current true target weights
- step_count: Number of steps taken
AbruptChangeStream(feature_dim, change_interval=1000, noise_std=0.1, feature_std=1.0)
¶
Non-stationary stream with sudden target weight changes.
Target weights remain constant for a period, then abruptly change to new random values. Tests the learner's ability to detect and rapidly adapt to distribution shifts.
Attributes:
- feature_dim: Dimension of observation vectors
- change_interval: Number of steps between weight changes
- noise_std: Standard deviation of observation noise
- feature_std: Standard deviation of features

Args:
- feature_dim: Dimension of feature vectors
- change_interval: Steps between abrupt weight changes
- noise_std: Std dev of target noise
- feature_std: Std dev of feature values
Source code in src/alberta_framework/streams/synthetic.py
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step.
Args:
- state: Current stream state
- idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
CyclicState
¶
State for CyclicStream.
Attributes:
- key: JAX random key for generating randomness
- configurations: Pre-generated weight configurations
- step_count: Number of steps taken
CyclicStream(feature_dim, cycle_length=500, num_configurations=4, noise_std=0.1, feature_std=1.0)
¶
Non-stationary stream that cycles between known weight configurations.
Weights cycle through a fixed set of configurations. Tests whether the learner can re-adapt quickly to previously seen targets.
Attributes:
- feature_dim: Dimension of observation vectors
- cycle_length: Number of steps per configuration before switching
- num_configurations: Number of weight configurations to cycle through
- noise_std: Standard deviation of observation noise
- feature_std: Standard deviation of features

Args:
- feature_dim: Dimension of feature vectors
- cycle_length: Steps spent in each configuration
- num_configurations: Number of configurations to cycle through
- noise_std: Std dev of target noise
- feature_std: Std dev of feature values
Source code in src/alberta_framework/streams/synthetic.py
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with pre-generated configurations
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step.
Args:
- state: Current stream state
- idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
DynamicScaleShiftState
¶
State for DynamicScaleShiftStream.
Attributes:
- key: JAX random key for generating randomness
- true_weights: Current true target weights
- current_scales: Current per-feature scaling factors
- step_count: Number of steps taken
DynamicScaleShiftStream(feature_dim, scale_change_interval=2000, weight_change_interval=1000, min_scale=0.01, max_scale=100.0, noise_std=0.1)
¶
Non-stationary stream with abruptly changing feature scales.
Both target weights AND feature scales change at specified intervals. This tests whether OnlineNormalizer can track scale shifts faster than Autostep's internal v_i adaptation.
The target is computed from unscaled features to maintain consistent difficulty across scale changes (only the feature representation changes, not the underlying prediction task).
Attributes:
- feature_dim: Dimension of observation vectors
- scale_change_interval: Steps between scale changes
- weight_change_interval: Steps between weight changes
- min_scale: Minimum scale factor
- max_scale: Maximum scale factor
- noise_std: Standard deviation of observation noise

Args:
- feature_dim: Dimension of feature vectors
- scale_change_interval: Steps between abrupt scale changes
- weight_change_interval: Steps between abrupt weight changes
- min_scale: Minimum scale factor (log-uniform sampling)
- max_scale: Maximum scale factor (log-uniform sampling)
- noise_std: Std dev of target noise
Source code in src/alberta_framework/streams/synthetic.py
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with random weights and scales
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step.
Args:
- state: Current stream state
- idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
PeriodicChangeState
¶
State for PeriodicChangeStream.
Attributes:
- key: JAX random key for generating randomness
- base_weights: Base target weights (center of oscillation)
- phases: Per-weight phase offsets
- step_count: Number of steps taken
PeriodicChangeStream(feature_dim, period=1000, amplitude=1.0, noise_std=0.1, feature_std=1.0)
¶
Non-stationary stream where target weights oscillate sinusoidally.
Target weights follow: w(t) = base + amplitude * sin(2π * t / period + phase) where each weight has a random phase offset for diversity.
This tests the learner's ability to track predictable periodic changes, which is qualitatively different from random drift or abrupt changes.
Attributes:
- feature_dim: Dimension of observation vectors
- period: Number of steps for one complete oscillation
- amplitude: Magnitude of weight oscillation
- noise_std: Standard deviation of observation noise
- feature_std: Standard deviation of features

Args:
- feature_dim: Dimension of feature vectors
- period: Steps for one complete oscillation cycle
- amplitude: Magnitude of weight oscillations around base
- noise_std: Std dev of target noise
- feature_std: Std dev of feature values
Source code in src/alberta_framework/streams/synthetic.py
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with random base weights and phases
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step.
Args: state: Current stream state idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
RandomWalkState
¶
State for RandomWalkStream.
Attributes: key: JAX random key for generating randomness true_weights: Current true target weights
RandomWalkStream(feature_dim, drift_rate=0.001, noise_std=0.1, feature_std=1.0)
¶
Non-stationary stream where target weights drift via random walk.
The true target function is linear: y* = w_true @ x + noise
where w_true evolves via random walk at each time step.
This tests the learner's ability to continuously track a moving target.
Attributes: feature_dim: Dimension of observation vectors drift_rate: Standard deviation of weight drift per step noise_std: Standard deviation of observation noise feature_std: Standard deviation of features
Args: feature_dim: Dimension of the feature/observation vectors drift_rate: Std dev of weight changes per step (controls non-stationarity) noise_std: Std dev of target noise feature_std: Std dev of feature values
Source code in src/alberta_framework/streams/synthetic.py
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with random weights
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step.
Args: state: Current stream state idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
ScaleDriftState
¶
State for ScaleDriftStream.
Attributes: key: JAX random key for generating randomness true_weights: Current true target weights log_scales: Current log-scale factors (random walk on log-scale) step_count: Number of steps taken
ScaleDriftStream(feature_dim, weight_drift_rate=0.001, scale_drift_rate=0.01, min_log_scale=-4.0, max_log_scale=4.0, noise_std=0.1)
¶
Non-stationary stream where feature scales drift via random walk.
Both target weights and feature scales drift continuously. Weights drift in linear space while scales drift in log-space (bounded random walk). This tests continuous scale tracking where OnlineNormalizer's EMA may adapt differently than Autostep's v_i.
The target is computed from unscaled features to maintain consistent difficulty across scale changes.
Attributes: feature_dim: Dimension of observation vectors weight_drift_rate: Std dev of weight drift per step scale_drift_rate: Std dev of log-scale drift per step min_log_scale: Minimum log-scale (clips random walk) max_log_scale: Maximum log-scale (clips random walk) noise_std: Standard deviation of observation noise
Args: feature_dim: Dimension of feature vectors weight_drift_rate: Std dev of weight drift per step scale_drift_rate: Std dev of log-scale drift per step min_log_scale: Minimum log-scale (clips drift) max_log_scale: Maximum log-scale (clips drift) noise_std: Std dev of target noise
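A sketch of one step of the scale dynamics described above (a clipped random walk in log-space; names and values mirror the constructor defaults but are illustrative):

import jax.numpy as jnp
import jax.random as jr

key = jr.key(0)
log_scales = jnp.zeros(10)                             # unit scales at init
noise = jr.normal(key, (10,)) * 0.01                   # scale_drift_rate
log_scales = jnp.clip(log_scales + noise, -4.0, 4.0)   # min/max_log_scale bounds
scales = jnp.exp(log_scales)                           # per-feature scale factors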
Source code in src/alberta_framework/streams/synthetic.py
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with random weights and unit scales
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step.
Args: state: Current stream state idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
ScaledStreamState
¶
State for ScaledStreamWrapper.
Attributes: inner_state: State of the wrapped stream
ScaledStreamWrapper(inner_stream, feature_scales)
¶
Wrapper that applies per-feature scaling to any stream's observations.
This wrapper multiplies each feature of the observation by a corresponding scale factor. Useful for testing how learners handle features at different scales, which is important for understanding normalization benefits.
Examples:
import jax.numpy as jnp
from alberta_framework.streams.synthetic import AbruptChangeStream, ScaledStreamWrapper

stream = ScaledStreamWrapper(
    AbruptChangeStream(feature_dim=10, change_interval=1000),
    feature_scales=jnp.array([0.001, 0.01, 0.1, 1.0, 10.0,
                              100.0, 1000.0, 0.001, 0.01, 0.1]),
)
Attributes: inner_stream: The wrapped stream instance feature_scales: Per-feature scale factors (must match feature_dim)
Args: inner_stream: Stream to wrap (must implement ScanStream protocol) feature_scales: Array of scale factors, one per feature. Must have shape (feature_dim,) matching the inner stream's feature_dim.
Raises: ValueError: If feature_scales length doesn't match inner stream's feature_dim
Source code in src/alberta_framework/streams/synthetic.py
feature_dim
property
¶
Return the dimension of observation vectors.
inner_stream
property
¶
Return the wrapped stream.
feature_scales
property
¶
Return the per-feature scale factors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state wrapping the inner stream's state
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step with scaled observations.
Args: state: Current stream state idx: Current step index
Returns: Tuple of (timestep with scaled observation, new_state)
Source code in src/alberta_framework/streams/synthetic.py
SuttonExperiment1State
¶
State for SuttonExperiment1Stream.
Attributes: key: JAX random key for generating randomness signs: Signs (+1/-1) for the relevant inputs step_count: Number of steps taken
SuttonExperiment1Stream(num_relevant=5, num_irrelevant=15, change_interval=20)
¶
Non-stationary stream replicating Experiment 1 from Sutton 1992.
This stream implements the exact task from Sutton's IDBD paper:
- 20 real-valued inputs drawn from N(0, 1)
- Only the first 5 inputs are relevant (weights are ±1)
- The last 15 inputs are irrelevant (weights are 0)
- Every change_interval steps, one of the 5 relevant signs is flipped
Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta"
Attributes: num_relevant: Number of relevant inputs (default 5) num_irrelevant: Number of irrelevant inputs (default 15) change_interval: Steps between sign changes (default 20)
Args: num_relevant: Number of relevant inputs with ±1 weights num_irrelevant: Number of irrelevant inputs with 0 weights change_interval: Number of steps between sign flips
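Example (a minimal sketch; the import from streams.synthetic follows the source location noted below):

import jax.random as jr
from alberta_framework import LinearLearner, IDBD, run_learning_loop
from alberta_framework.streams.synthetic import SuttonExperiment1Stream

# 20 inputs: 5 relevant with ±1 weights, 15 irrelevant; one sign flips every 20 steps
stream = SuttonExperiment1Stream(num_relevant=5, num_irrelevant=15, change_interval=20)
learner = LinearLearner(optimizer=IDBD())
state, metrics = run_learning_loop(learner, stream, num_steps=20000, key=jr.key(0))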
Source code in src/alberta_framework/streams/synthetic.py
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with all +1 signs
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step.
At each step:
1. If at a change interval (and not step 0), flip one random sign
2. Generate random inputs from N(0, 1)
3. Compute the target as the sum of relevant inputs weighted by signs
Args: state: Current stream state idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
Timer(name='Operation', verbose=True, print_fn=None)
¶
Context manager for timing code execution.
Measures wall-clock time for a block of code and optionally prints the duration when the block completes.
Attributes: name: Description of what is being timed duration: Elapsed time in seconds (available after context exits) start_time: Timestamp when timing started end_time: Timestamp when timing ended
Examples:
import time
from alberta_framework.utils.timing import Timer

with Timer("Training loop"):
    for i in range(1000):
        pass
# Output: Training loop completed in 0.01s

# Silent timing (no print):
with Timer("Silent", verbose=False) as t:
    time.sleep(0.1)
print(f"Elapsed: {t.duration:.2f}s")
# Output: Elapsed: 0.10s

# Custom print function:
with Timer("Custom", print_fn=lambda msg: print(f">> {msg}")):
    pass
# Output: >> Custom completed in 0.00s
Args: name: Description of the operation being timed verbose: Whether to print the duration when done print_fn: Custom print function (defaults to built-in print)
Source code in src/alberta_framework/utils/timing.py
elapsed()
¶
Get elapsed time since timer started (can be called during execution).
Returns: Elapsed time in seconds
GymnasiumStream(env, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0)
¶
Experience stream from a Gymnasium environment using Python loop.
This class maintains iterator-based access for online learning scenarios where you need to interact with the environment in real-time.
For batch learning, use collect_trajectory() followed by learn_from_trajectory().
Attributes: mode: Prediction mode (REWARD, NEXT_STATE, VALUE) gamma: Discount factor for VALUE mode include_action_in_features: Whether to include action in features episode_count: Number of completed episodes
Args: env: Gymnasium environment instance mode: What to predict (REWARD, NEXT_STATE, VALUE) policy: Action selection function. If None, uses random policy gamma: Discount factor for VALUE mode include_action_in_features: If True, features = concat(obs, action). If False, features = obs only seed: Random seed for environment resets and random policy
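Example (a sketch assuming the same timestep iteration protocol shown for TDStream below; imports follow the source location):

import gymnasium as gym
from alberta_framework import LinearLearner, IDBD
from alberta_framework.streams.gymnasium import GymnasiumStream, PredictionMode

env = gym.make("CartPole-v1")
stream = GymnasiumStream(env, mode=PredictionMode.REWARD)
learner = LinearLearner(optimizer=IDBD())
state = learner.init(stream.feature_dim)
for step, timestep in enumerate(stream):
    result = learner.update(state, timestep.observation, timestep.target)
    state = result.state
    if step >= 999:  # the online loop has no natural end; stop after 1000 steps
        break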
Source code in src/alberta_framework/streams/gymnasium.py
feature_dim
property
¶
Return the dimension of feature vectors.
target_dim
property
¶
Return the dimension of target vectors.
episode_count
property
¶
Return the number of completed episodes.
step_count
property
¶
Return the total number of steps taken.
mode
property
¶
Return the prediction mode.
set_value_estimator(estimator)
¶
Set the value estimator for proper TD learning in VALUE mode.
PredictionMode
¶
Bases: Enum
Mode for what the stream predicts.
- REWARD: Predict immediate reward from (state, action)
- NEXT_STATE: Predict next state from (state, action)
- VALUE: Predict cumulative return (TD learning with bootstrap)
TDStream(env, policy=None, gamma=0.99, include_action_in_features=False, seed=0)
¶
Experience stream for proper TD learning with value function bootstrap.
This stream integrates with a learner to use its predictions for bootstrapping in TD targets.
Usage:
stream = TDStream(env)
learner = LinearLearner(optimizer=IDBD())
state = learner.init(stream.feature_dim)
for step, timestep in enumerate(stream):
    result = learner.update(state, timestep.observation, timestep.target)
    state = result.state
    stream.update_value_function(lambda x: learner.predict(state, x))
Args: env: Gymnasium environment instance policy: Action selection function. If None, uses random policy gamma: Discount factor include_action_in_features: If True, learn Q(s,a). If False, learn V(s) seed: Random seed
Source code in src/alberta_framework/streams/gymnasium.py
sparse_init(key, shape, sparsity=0.9, init_type='uniform')
¶
Create a sparsely initialized weight matrix.
Applies LeCun-scale initialization and then zeros out a fraction of weights per output neuron. This creates sparser gradient flows that improve stability in streaming learning settings.
Reference: Elsayed et al. 2024, sparse_init.py
Args: key: JAX random key shape: Weight matrix shape (fan_out, fan_in) sparsity: Fraction of input connections to zero out per output neuron (default: 0.9 means 90% sparse) init_type: Initialization distribution, "uniform" or "normal" (default: "uniform" for LeCun uniform)
Returns: Weight matrix of given shape with specified sparsity
Examples:
import jax.random as jr
from alberta_framework.core.initializers import sparse_init
key = jr.key(42)
weights = sparse_init(key, (128, 10), sparsity=0.9)
# weights has shape (128, 10), ~90% zeros per row
Source code in src/alberta_framework/core/initializers.py
metrics_to_dicts(metrics, normalized=False)
¶
Convert metrics array to list of dicts for backward compatibility.
Args: metrics: Array of shape (num_steps, 3) or (num_steps, 4) normalized: If True, expects 4 columns including normalizer_mean_var
Returns: List of metric dictionaries
Source code in src/alberta_framework/core/learners.py
run_learning_loop(learner, stream, num_steps, key, learner_state=None, step_size_tracking=None, normalizer_tracking=None)
¶
Run the learning loop using jax.lax.scan.
This is a JIT-compiled learning loop that uses scan for efficiency. It returns metrics as a fixed-size array rather than a list of dicts.
Supports both plain and normalized learners. When the learner has a normalizer, metrics have 4 columns; otherwise 3 columns.
Args: learner: The learner to train stream: Experience stream providing (observation, target) pairs num_steps: Number of learning steps to run key: JAX random key for stream initialization learner_state: Initial state (if None, will be initialized from stream) step_size_tracking: Optional config for recording per-weight step-sizes. When provided, returns StepSizeHistory. normalizer_tracking: Optional config for recording per-feature normalizer state. When provided, returns NormalizerHistory with means and variances over time.
Returns:
- If no tracking: Tuple of (final_state, metrics_array), where metrics_array has shape (num_steps, 3) or (num_steps, 4) depending on the normalizer
- If step_size_tracking only: Tuple of (final_state, metrics_array, step_size_history)
- If normalizer_tracking only: Tuple of (final_state, metrics_array, normalizer_history)
- If both: Tuple of (final_state, metrics_array, step_size_history, normalizer_history)
Raises: ValueError: If tracking interval is invalid
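Example of the 4-column case (a sketch assuming EMANormalizer is exported at the package top level; column names follow learn_from_trajectory_normalized below):

import jax.random as jr
from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop
from alberta_framework import EMANormalizer  # assumed top-level export

stream = RandomWalkStream(feature_dim=10)
learner = LinearLearner(optimizer=IDBD(), normalizer=EMANormalizer())
state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(42))
# metrics has shape (10000, 4):
# [squared_error, error, mean_step_size, normalizer_mean_var]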
Source code in src/alberta_framework/core/learners.py
run_learning_loop_batched(learner, stream, num_steps, keys, learner_state=None, step_size_tracking=None, normalizer_tracking=None)
¶
Run learning loop across multiple seeds in parallel using jax.vmap.
This function provides GPU parallelization for multi-seed experiments, typically achieving 2-5x speedup over sequential execution.
Supports both plain and normalized learners.
Args: learner: The learner to train stream: Experience stream providing (observation, target) pairs num_steps: Number of learning steps to run per seed keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2) learner_state: Initial state (if None, will be initialized from stream). The same initial state is used for all seeds. step_size_tracking: Optional config for recording per-weight step-sizes. When provided, history arrays have shape (num_seeds, num_recordings, ...) normalizer_tracking: Optional config for recording normalizer state. When provided, history arrays have shape (num_seeds, num_recordings, ...)
Returns: BatchedLearningResult containing:
- states: Batched final states with shape (num_seeds, ...) for each array
- metrics: Array of shape (num_seeds, num_steps, num_cols)
- step_size_history: Batched history, or None if tracking is disabled
- normalizer_history: Batched history, or None if tracking is disabled
Examples:
import jax.random as jr
from alberta_framework import LinearLearner, IDBD, RandomWalkStream
from alberta_framework import run_learning_loop_batched
stream = RandomWalkStream(feature_dim=10)
learner = LinearLearner(optimizer=IDBD())
# Run 30 seeds in parallel
keys = jr.split(jr.key(42), 30)
result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)
# result.metrics has shape (30, 10000, 3)
mean_error = result.metrics[:, :, 0].mean(axis=0) # Average over seeds
Source code in src/alberta_framework/core/learners.py
run_mlp_learning_loop(learner, stream, num_steps, key, learner_state=None, normalizer_tracking=None)
¶
Run the MLP learning loop using jax.lax.scan.
This is a JIT-compiled learning loop that uses scan for efficiency.
Args: learner: The MLP learner to train stream: Experience stream providing (observation, target) pairs num_steps: Number of learning steps to run key: JAX random key for stream and weight initialization learner_state: Initial state (if None, will be initialized from stream) normalizer_tracking: Optional config for recording per-feature normalizer state. When provided, returns NormalizerHistory.
Returns:
- If no tracking: Tuple of (final_state, metrics_array), where metrics_array has shape (num_steps, 3) or (num_steps, 4)
- If normalizer_tracking: Tuple of (final_state, metrics_array, normalizer_history)
Raises: ValueError: If normalizer_tracking.interval is invalid
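Example (a minimal sketch, assuming the non-batched MLP loop is exported alongside its batched variant below):

import jax.random as jr
from alberta_framework import MLPLearner, RandomWalkStream
from alberta_framework import run_mlp_learning_loop  # assumed top-level export

stream = RandomWalkStream(feature_dim=10)
learner = MLPLearner(hidden_sizes=(128, 128))
state, metrics = run_mlp_learning_loop(learner, stream, num_steps=10000, key=jr.key(42))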
Source code in src/alberta_framework/core/learners.py
run_mlp_learning_loop_batched(learner, stream, num_steps, keys, learner_state=None, normalizer_tracking=None)
¶
Run MLP learning loop across multiple seeds in parallel using jax.vmap.
This function provides GPU parallelization for multi-seed MLP experiments, typically achieving 2-5x speedup over sequential execution.
Args: learner: The MLP learner to train stream: Experience stream providing (observation, target) pairs num_steps: Number of learning steps to run per seed keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2) learner_state: Initial state (if None, will be initialized from stream). The same initial state is used for all seeds. normalizer_tracking: Optional config for recording normalizer state. When provided, history arrays have shape (num_seeds, num_recordings, ...)
Returns: BatchedMLPResult containing:
- states: Batched final states with shape (num_seeds, ...) for each array
- metrics: Array of shape (num_seeds, num_steps, num_cols)
- normalizer_history: Batched history, or None if tracking is disabled
Examples:
import jax.random as jr
from alberta_framework import MLPLearner, RandomWalkStream
from alberta_framework import run_mlp_learning_loop_batched
stream = RandomWalkStream(feature_dim=10)
learner = MLPLearner(hidden_sizes=(128, 128))
# Run 30 seeds in parallel
keys = jr.split(jr.key(42), 30)
result = run_mlp_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)
# result.metrics has shape (30, 10000, 3)
mean_error = result.metrics[:, :, 0].mean(axis=0) # Average over seeds
Source code in src/alberta_framework/core/learners.py
run_td_learning_loop(learner, stream, num_steps, key, learner_state=None)
¶
Run the TD learning loop using jax.lax.scan.
This is a JIT-compiled learning loop that uses scan for efficiency. It returns metrics as a fixed-size array rather than a list of dicts.
Args: learner: The TD learner to train stream: TD experience stream providing (s, r, s', gamma) tuples num_steps: Number of learning steps to run key: JAX random key for stream initialization learner_state: Initial state (if None, will be initialized from stream)
Returns: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 4) with columns [squared_td_error, td_error, mean_step_size, mean_eligibility_trace]
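Example of consuming the 4-column metrics array (a sketch; the zeros array stands in for real output, and learner/stream construction is omitted):

import jax.numpy as jnp
from alberta_framework.utils.metrics import compute_running_mean

# final_state, metrics = run_td_learning_loop(learner, stream, num_steps, key)
metrics = jnp.zeros((10000, 4))          # stand-in for the returned metrics
squared_td_error = metrics[:, 0]         # column 0: squared TD error
mean_step_size = metrics[:, 2]           # column 2: mean step-size
smoothed = compute_running_mean(squared_td_error, window_size=100)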
Source code in src/alberta_framework/core/learners.py
create_autotdidbd_state(feature_dim, initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, normalizer_decay=10000.0)
¶
Create initial AutoTDIDBD optimizer state.
Args: feature_dim: Dimension of the feature vector initial_step_size: Initial per-weight step-size meta_step_size: Meta learning rate theta for adapting step-sizes trace_decay: Eligibility trace decay parameter lambda (0 = TD(0)) normalizer_decay: Decay parameter tau for normalizers (default: 10000)
Returns: Initial AutoTDIDBD state
Source code in src/alberta_framework/core/types.py
create_obgd_state(feature_dim, step_size=1.0, kappa=2.0, gamma=0.0, lamda=0.0)
¶
Create initial ObGD optimizer state.
Args: feature_dim: Dimension of the feature vector step_size: Base learning rate (default: 1.0) kappa: Bounding sensitivity parameter (default: 2.0) gamma: Discount factor for trace decay (default: 0.0 for supervised) lamda: Eligibility trace decay parameter (default: 0.0 for supervised)
Returns: Initial ObGD state
Source code in src/alberta_framework/core/types.py
create_tdidbd_state(feature_dim, initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0)
¶
Create initial TD-IDBD optimizer state.
Args: feature_dim: Dimension of the feature vector initial_step_size: Initial per-weight step-size meta_step_size: Meta learning rate theta for adapting step-sizes trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))
Returns: Initial TD-IDBD state
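Example constructing each optimizer state with the documented defaults (a sketch; how these states plug into a learner is not shown here, and imports follow the source locations noted for these functions):

from alberta_framework.core.types import (
    create_autotdidbd_state,
    create_obgd_state,
    create_tdidbd_state,
)

tdidbd = create_tdidbd_state(feature_dim=10, initial_step_size=0.01, meta_step_size=0.01)
autotdidbd = create_autotdidbd_state(feature_dim=10, normalizer_decay=10000.0)
obgd = create_obgd_state(feature_dim=10, step_size=1.0, kappa=2.0, gamma=0.99, lamda=0.9)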
Source code in src/alberta_framework/core/types.py
make_scale_range(feature_dim, min_scale=0.001, max_scale=1000.0, log_spaced=True)
¶
Create a per-feature scale array spanning a range.
Utility function to generate scale factors for ScaledStreamWrapper.
Args: feature_dim: Number of features min_scale: Minimum scale factor max_scale: Maximum scale factor log_spaced: If True, scales are logarithmically spaced (default). If False, scales are linearly spaced.
Returns: Array of shape (feature_dim,) with scale factors
Examples:
scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
Source code in src/alberta_framework/streams/synthetic.py
compare_learners(results, metric='squared_error')
¶
Compare multiple learners on a given metric.
Args: results: Dictionary mapping learner name to metrics history metric: Metric to compare
Returns: Dictionary with summary statistics for each learner
Source code in src/alberta_framework/utils/metrics.py
compute_cumulative_error(metrics_history, error_key='squared_error')
¶
Compute cumulative error over time.
Args: metrics_history: List of metric dictionaries from learning loop error_key: Key to extract error values
Returns: Array of cumulative errors at each time step
Source code in src/alberta_framework/utils/metrics.py
compute_running_mean(values, window_size=100)
¶
Compute running mean of values.
Args: values: Array of values window_size: Size of the moving average window
Returns: Array of running mean values (same length as input, padded at start)
Source code in src/alberta_framework/utils/metrics.py
compute_tracking_error(metrics_history, window_size=100)
¶
Compute tracking error (running mean of squared error).
This is the key metric for evaluating continual learners: how well can the learner track the non-stationary target?
Args: metrics_history: List of metric dictionaries from learning loop window_size: Size of the moving average window
Returns: Array of tracking errors at each time step
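Example bridging the scan-based array output to these list-of-dicts utilities via metrics_to_dicts (a sketch; imports follow the source locations noted above):

import jax.random as jr
from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop
from alberta_framework.core.learners import metrics_to_dicts
from alberta_framework.utils.metrics import compute_tracking_error

stream = RandomWalkStream(feature_dim=10)
learner = LinearLearner(optimizer=IDBD())
state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(0))
history = metrics_to_dicts(metrics)                       # array -> list of dicts
tracking = compute_tracking_error(history, window_size=100)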
Source code in src/alberta_framework/utils/metrics.py
extract_metric(metrics_history, key)
¶
Extract a single metric from the history.
Args: metrics_history: List of metric dictionaries key: Key to extract
Returns: Array of values for that metric
Source code in src/alberta_framework/utils/metrics.py
format_duration(seconds)
¶
Format a duration in seconds as a human-readable string.
Args: seconds: Duration in seconds
Returns: Formatted string like "1.23s", "2m 30.5s", or "1h 5m 30s"
Examples:
format_duration(0.5) # Returns: '0.50s'
format_duration(90.5) # Returns: '1m 30.50s'
format_duration(3665) # Returns: '1h 1m 5.00s'
Source code in src/alberta_framework/utils/timing.py
collect_trajectory(env, policy, num_steps, mode=PredictionMode.REWARD, include_action_in_features=True, seed=0)
¶
Collect a trajectory from a Gymnasium environment.
This uses a Python loop to interact with the environment and collects observations and targets into JAX arrays that can be used with scan-based learning.
Args: env: Gymnasium environment instance policy: Action selection function. If None, uses random policy num_steps: Number of steps to collect mode: What to predict (REWARD, NEXT_STATE, VALUE) include_action_in_features: If True, features = concat(obs, action) seed: Random seed for environment resets and random policy
Returns: Tuple of (observations, targets) as JAX arrays with shape (num_steps, feature_dim) and (num_steps, target_dim)
Source code in src/alberta_framework/streams/gymnasium.py
learn_from_trajectory(learner, observations, targets, learner_state=None)
¶
Learn from a pre-collected trajectory using jax.lax.scan.
This is a JIT-compiled learning function that processes a trajectory collected from a Gymnasium environment.
Args: learner: The learner to train observations: Array of observations with shape (num_steps, feature_dim) targets: Array of targets with shape (num_steps, target_dim) learner_state: Initial state (if None, will be initialized)
Returns: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 3) with columns [squared_error, error, mean_step_size]
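Example of the two-phase batch workflow (a sketch; imports follow the source location, and CartPole-v1 is illustrative):

import gymnasium as gym
from alberta_framework import LinearLearner, IDBD
from alberta_framework.streams.gymnasium import collect_trajectory, learn_from_trajectory

env = gym.make("CartPole-v1")
# Phase 1: collect with a Python loop (policy=None uses the random policy)
observations, targets = collect_trajectory(env, policy=None, num_steps=1000)
# Phase 2: learn from the arrays with a JIT-compiled scan
learner = LinearLearner(optimizer=IDBD())
state, metrics = learn_from_trajectory(learner, observations, targets)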
Source code in src/alberta_framework/streams/gymnasium.py
learn_from_trajectory_normalized(learner, observations, targets, learner_state=None)
¶
Learn from a pre-collected trajectory with normalization using jax.lax.scan.
This is equivalent to learn_from_trajectory for a learner constructed with a normalizer (e.g. LinearLearner(optimizer=..., normalizer=EMANormalizer())). Retained for backward compatibility.
Args: learner: The learner to train (should have a normalizer configured) observations: Array of observations with shape (num_steps, feature_dim) targets: Array of targets with shape (num_steps, target_dim) learner_state: Initial state (if None, will be initialized)
Returns: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
Source code in src/alberta_framework/streams/gymnasium.py
make_epsilon_greedy_policy(base_policy, env, epsilon=0.1, seed=0)
¶
Wrap a policy with epsilon-greedy exploration.
Args: base_policy: The greedy policy to wrap env: Gymnasium environment (for random action sampling) epsilon: Probability of taking a random action seed: Random seed
Returns: Epsilon-greedy policy
Source code in src/alberta_framework/streams/gymnasium.py
make_gymnasium_stream(env_id, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0, **env_kwargs)
¶
Factory function to create a GymnasiumStream from an environment ID.
Args: env_id: Gymnasium environment ID (e.g., "CartPole-v1") mode: What to predict (REWARD, NEXT_STATE, VALUE) policy: Action selection function. If None, uses random policy gamma: Discount factor for VALUE mode include_action_in_features: If True, features = concat(obs, action) seed: Random seed **env_kwargs: Additional arguments passed to gymnasium.make()
Returns: GymnasiumStream wrapping the environment
Source code in src/alberta_framework/streams/gymnasium.py
make_random_policy(env, seed=0)
¶
Create a random action policy for an environment.
Args: env: Gymnasium environment seed: Random seed
Returns: A callable that takes an observation and returns a random action
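Example combining the two policy helpers (a sketch; the constant greedy policy is a hypothetical stand-in for a learned one):

import gymnasium as gym
from alberta_framework.streams.gymnasium import (
    make_epsilon_greedy_policy,
    make_random_policy,
)

env = gym.make("CartPole-v1")
random_policy = make_random_policy(env, seed=0)

def greedy_policy(obs):  # hypothetical stand-in for a learned policy
    return 0
policy = make_epsilon_greedy_policy(greedy_policy, env, epsilon=0.1, seed=0)

obs, info = env.reset(seed=0)
action = policy(obs)  # greedy action with prob 0.9, random with prob 0.1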