alberta_framework
¶
Alberta Framework: A JAX-based research framework for continual AI.
The Alberta Framework provides foundational components for continual reinforcement learning research. Built on JAX for hardware acceleration, the framework emphasizes temporal uniformity — every component updates at every time step, with no special training phases or batch processing.
Roadmap
| Step | Focus | Status |
|---|---|---|
| 1 | Meta-learned step-sizes (IDBD, Autostep) | Complete |
| 2 | Feature generation and testing | Planned |
| 3 | GVF predictions, Horde architecture | Planned |
| 4 | Actor-critic with eligibility traces | Planned |
| 5-6 | Off-policy learning, average reward | Planned |
| 7-12 | Hierarchical, multi-agent, world models | Future |
Examples:
import jax.random as jr
from alberta_framework import LinearLearner, IDBD, RandomWalkStream, run_learning_loop
# Non-stationary stream where target weights drift over time
stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
# Learner with IDBD meta-learned step-sizes
learner = LinearLearner(optimizer=IDBD())
# JIT-compiled training via jax.lax.scan
state, metrics = run_learning_loop(learner, stream, num_steps=10000, key=jr.key(42))
References
- The Alberta Plan for AI Research (Sutton et al., 2022): https://arxiv.org/abs/2208.11173
- Adapting Bias by Gradient Descent (Sutton, 1992)
- Tuning-free Step-size Adaptation (Mahmood et al., 2012)
LinearLearner(optimizer=None)
¶
Linear function approximator with pluggable optimizer.
Computes predictions as: y = w @ x + b
The learner maintains weights and bias, delegating the adaptation of learning rates to the optimizer (e.g., LMS or IDBD).
This follows the Alberta Plan philosophy of temporal uniformity: every component updates at every time step.
Attributes: optimizer: The optimizer to use for weight updates
Args: optimizer: Optimizer for weight updates. Defaults to LMS(0.01)
Source code in src/alberta_framework/core/learners.py
init(feature_dim)
¶
Initialize learner state.
Args: feature_dim: Dimension of the input feature vector
Returns: Initial learner state with zero weights and bias
Source code in src/alberta_framework/core/learners.py
predict(state, observation)
¶
Compute prediction for an observation.
Args: state: Current learner state observation: Input feature vector
Returns:
Scalar prediction y = w @ x + b
Source code in src/alberta_framework/core/learners.py
update(state, observation, target)
¶
Update learner given observation and target.
Performs one step of the learning algorithm:
1. Compute prediction
2. Compute error
3. Get weight updates from optimizer
4. Apply updates to weights and bias
Args:
    state: Current learner state
    observation: Input feature vector
    target: Desired output
Returns: UpdateResult with new state, prediction, error, and metrics
Source code in src/alberta_framework/core/learners.py
NormalizedLearnerState
¶
State for a learner with online feature normalization.
Attributes: learner_state: Underlying learner state (weights, bias, optimizer) normalizer_state: Online normalizer state (mean, var estimates)
NormalizedLinearLearner(optimizer=None, normalizer=None)
¶
Linear learner with online feature normalization.
Wraps a LinearLearner with online feature normalization, following the Alberta Plan's approach to handling varying feature scales.
Normalization is applied to features before prediction and learning: x_normalized = (x - mean) / (std + epsilon)
The normalizer statistics update at every time step, maintaining temporal uniformity.
Attributes: learner: Underlying linear learner normalizer: Online feature normalizer
Args: optimizer: Optimizer for weight updates. Defaults to LMS(0.01) normalizer: Feature normalizer. Defaults to OnlineNormalizer()
Source code in src/alberta_framework/core/learners.py
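A minimal usage sketch, mirroring the module-level example above and assuming run_normalized_learning_loop is exported at the package level like its batched variant:

import jax.random as jr
from alberta_framework import NormalizedLinearLearner, IDBD, RandomWalkStream, run_normalized_learning_loop

stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
learner = NormalizedLinearLearner(optimizer=IDBD())
state, metrics = run_normalized_learning_loop(learner, stream, num_steps=10000, key=jr.key(0))
# metrics has shape (10000, 4): [squared_error, error, mean_step_size, normalizer_mean_var]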
init(feature_dim)
¶
Initialize normalized learner state.
Args: feature_dim: Dimension of the input feature vector
Returns: Initial state with zero weights and unit variance estimates
Source code in src/alberta_framework/core/learners.py
predict(state, observation)
¶
Compute prediction for an observation.
Normalizes the observation using current statistics before prediction.
Args: state: Current normalized learner state observation: Raw (unnormalized) input feature vector
Returns: Scalar prediction y = w @ normalize(x) + b
Source code in src/alberta_framework/core/learners.py
update(state, observation, target)
¶
Update learner given observation and target.
Performs one step of the learning algorithm:
1. Normalize observation (and update normalizer statistics)
2. Compute prediction using normalized features
3. Compute error
4. Get weight updates from optimizer
5. Apply updates
Args:
    state: Current normalized learner state
    observation: Raw (unnormalized) input feature vector
    target: Desired output
Returns: NormalizedUpdateResult with new state, prediction, error, and metrics
Source code in src/alberta_framework/core/learners.py
TDLinearLearner(optimizer=None)
¶
Linear function approximator for TD learning.
Computes value predictions as: V(s) = w @ φ(s) + b
The learner maintains weights, bias, and eligibility traces, delegating the adaptation of learning rates to the TD optimizer (e.g., TDIDBD).
This follows the Alberta Plan philosophy of temporal uniformity: every component updates at every time step.
Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"
Attributes: optimizer: The TD optimizer to use for weight updates
Args: optimizer: TD optimizer for weight updates. Defaults to TDIDBD()
Source code in src/alberta_framework/core/learners.py
init(feature_dim)
¶
Initialize TD learner state.
Args: feature_dim: Dimension of the input feature vector
Returns: Initial TD learner state with zero weights and bias
Source code in src/alberta_framework/core/learners.py
predict(state, observation)
¶
Compute value prediction for an observation.
Args: state: Current TD learner state observation: Input feature vector φ(s)
Returns:
Scalar value prediction V(s) = w @ φ(s) + b
Source code in src/alberta_framework/core/learners.py
update(state, observation, reward, next_observation, gamma)
¶
Update learner given a TD transition.
Performs one step of TD learning:
1. Compute V(s) and V(s')
2. Compute TD error δ = R + γV(s') - V(s)
3. Get weight updates from TD optimizer
4. Apply updates to weights and bias
Args:
    state: Current TD learner state
    observation: Current observation φ(s)
    reward: Reward R received
    next_observation: Next observation φ(s')
    gamma: Discount factor γ (0 at terminal states)
Returns: TDUpdateResult with new state, prediction, TD error, and metrics
Source code in src/alberta_framework/core/learners.py
TDUpdateResult
¶
Result of a TD learner update step.
Attributes:
    state: Updated TD learner state
    prediction: Value prediction V(s) before update
    td_error: TD error δ = R + γV(s') - V(s)
    metrics: Array of metrics [squared_td_error, td_error, mean_step_size, ...]
UpdateResult
¶
Result of a learner update step.
Attributes:
    state: Updated learner state
    prediction: Prediction made before update
    error: Prediction error
    metrics: Array of metrics [squared_error, error, ...]
NormalizerState
¶
State for online feature normalization.
Uses Welford's online algorithm for numerically stable estimation of running mean and variance.
Attributes:
    mean: Running mean estimate per feature
    var: Running variance estimate per feature
    sample_count: Number of samples seen
    decay: Exponential decay factor for estimates (1.0 = no decay, pure online)
OnlineNormalizer(epsilon=1e-08, decay=0.99)
¶
Online feature normalizer for continual learning.
Normalizes features using running estimates of mean and standard deviation:
x_normalized = (x - mean) / (std + epsilon)
The normalizer updates its estimates at every time step, following temporal uniformity. Uses exponential moving average for non-stationary environments.
Attributes: epsilon: Small constant for numerical stability decay: Exponential decay for running estimates (0.99 = slower adaptation)
Args: epsilon: Small constant added to std for numerical stability decay: Exponential decay factor for running estimates. Lower values adapt faster to changes. 1.0 means pure online average (no decay).
Source code in src/alberta_framework/core/normalizers.py
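For intuition, a standalone sketch of one exponentially decayed normalization step. This is illustrative only: the function and variable names are not the library's, and the actual implementation uses the Welford-style update described under NormalizerState.

import jax.numpy as jnp

def ema_normalize(mean, var, x, decay=0.99, epsilon=1e-8):
    # Decay old estimates toward the new observation
    new_mean = decay * mean + (1.0 - decay) * x
    new_var = decay * var + (1.0 - decay) * (x - new_mean) ** 2
    # Standardize with the updated statistics
    return (x - new_mean) / (jnp.sqrt(new_var) + epsilon), new_mean, new_var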
init(feature_dim)
¶
Initialize normalizer state.
Args: feature_dim: Dimension of feature vectors
Returns: Initial normalizer state with zero mean and unit variance
Source code in src/alberta_framework/core/normalizers.py
normalize(state, observation)
¶
Normalize observation and update running statistics.
This method both normalizes the current observation AND updates the running statistics, maintaining temporal uniformity.
Args: state: Current normalizer state observation: Raw feature vector
Returns: Tuple of (normalized_observation, new_state)
Source code in src/alberta_framework/core/normalizers.py
normalize_only(state, observation)
¶
Normalize observation without updating statistics.
Useful for inference or when you want to normalize multiple observations with the same statistics.
Args: state: Current normalizer state observation: Raw feature vector
Returns: Normalized observation
Source code in src/alberta_framework/core/normalizers.py
update_only(state, observation)
¶
Update statistics without returning normalized observation.
Args: state: Current normalizer state observation: Raw feature vector
Returns: Updated normalizer state
Source code in src/alberta_framework/core/normalizers.py
IDBD(initial_step_size=0.01, meta_step_size=0.01)
¶
Incremental Delta-Bar-Delta optimizer.
IDBD maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation. When successive gradients agree in sign, the step-size for that weight increases. When they disagree, it decreases.
This implements Sutton's 1992 algorithm for adapting step-sizes online without requiring manual tuning.
Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta"
Attributes: initial_step_size: Initial per-weight step-size meta_step_size: Meta learning rate beta for adapting step-sizes
Args: initial_step_size: Initial value for per-weight step-sizes meta_step_size: Meta learning rate beta for adapting step-sizes
Source code in src/alberta_framework/core/optimizers.py
init(feature_dim)
¶
Initialize IDBD state.
Args: feature_dim: Dimension of weight vector
Returns: IDBD state with per-weight step-sizes and traces
Source code in src/alberta_framework/core/optimizers.py
update(state, error, observation)
¶
Compute IDBD weight update with adaptive step-sizes.
The IDBD algorithm:
- Compute step-sizes: alpha_i = exp(log_alpha_i)
- Update weights: w_i += alpha_i * error * x_i
- Update log step-sizes: log_alpha_i += beta * error * x_i * h_i
- Update traces: h_i = h_i * max(0, 1 - alpha_i * x_i^2) + alpha_i * error * x_i
The trace h_i tracks the correlation between current and past gradients. When gradients consistently point the same direction, h_i grows, leading to larger step-sizes.
Args:
    state: Current IDBD state
    error: Prediction error (scalar)
    observation: Feature vector
Returns: OptimizerUpdate with weight deltas and updated state
Source code in src/alberta_framework/core/optimizers.py
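A standalone sketch of one such step, following the list above in order (illustrative; not the library's source):

import jax.numpy as jnp

def idbd_step(w, log_alpha, h, x, error, beta=0.01):
    alpha = jnp.exp(log_alpha)                        # per-weight step-sizes
    new_w = w + alpha * error * x                     # weight update
    new_log_alpha = log_alpha + beta * error * x * h  # meta-update from gradient correlation
    new_h = h * jnp.maximum(0.0, 1.0 - alpha * x**2) + alpha * error * x  # decaying trace
    return new_w, new_log_alpha, new_h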
LMS(step_size=0.01)
¶
Least Mean Square optimizer with fixed step-size.
The simplest gradient-based optimizer: w_{t+1} = w_t + alpha * delta * x_t
This serves as a baseline. The challenge is that the optimal step-size depends on the problem and changes as the task becomes non-stationary.
Attributes: step_size: Fixed learning rate alpha
Args: step_size: Fixed learning rate
Source code in src/alberta_framework/core/optimizers.py
init(feature_dim)
¶
Initialize LMS state.
Args: feature_dim: Dimension of weight vector (unused for LMS)
Returns: LMS state containing the step-size
Source code in src/alberta_framework/core/optimizers.py
update(state, error, observation)
¶
Compute LMS weight update.
Update rule: delta_w = alpha * error * x
Args:
    state: Current LMS state
    error: Prediction error (scalar)
    observation: Feature vector
Returns: OptimizerUpdate with weight and bias deltas
Source code in src/alberta_framework/core/optimizers.py
TDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, use_semi_gradient=True)
¶
Bases: TDOptimizer[TDIDBDState]
TD-IDBD optimizer for temporal-difference learning.
Extends IDBD to TD learning with eligibility traces. Maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation in the TD setting.
Two variants are supported:
- Semi-gradient (default): Uses only φ(s) in the meta-update; more stable
- Ordinary gradient: Uses both φ(s) and φ(s'); more accurate but more sensitive
Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"
The semi-gradient TD-IDBD algorithm (Algorithm 3 in paper):
1. Compute TD error: δ = R + γ*w^T*φ(s') - w^T*φ(s)
2. Update meta-weights: β_i += θ*δ*φ_i(s)*h_i
3. Compute step-sizes: α_i = exp(β_i)
4. Update eligibility traces: z_i = γ*λ*z_i + φ_i(s)
5. Update weights: w_i += α_i*δ*z_i
6. Update h traces: h_i = h_i*[1 - α_i*φ_i(s)*z_i]^+ + α_i*δ*z_i
Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta
    trace_decay: Eligibility trace decay lambda
    use_semi_gradient: If True, use semi-gradient variant (default)
Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay lambda (0 = TD(0))
    use_semi_gradient: If True, use semi-gradient variant (recommended)
Source code in src/alberta_framework/core/optimizers.py
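A standalone sketch of the six numbered steps for a single transition (illustrative; not the library's source):

import jax.numpy as jnp

def semi_gradient_tdidbd_step(w, beta, z, h, phi, reward, phi_next, gamma, theta=0.01, lam=0.0):
    delta = reward + gamma * (w @ phi_next) - w @ phi  # 1. TD error
    beta = beta + theta * delta * phi * h              # 2. meta-weights
    alpha = jnp.exp(beta)                              # 3. per-weight step-sizes
    z = gamma * lam * z + phi                          # 4. eligibility traces
    w = w + alpha * delta * z                          # 5. weight update
    h = h * jnp.maximum(0.0, 1.0 - alpha * phi * z) + alpha * delta * z  # 6. h traces
    return w, beta, z, h, delta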
init(feature_dim)
¶
Initialize TD-IDBD state.
Args: feature_dim: Dimension of weight vector
Returns: TD-IDBD state with per-weight step-sizes, traces, and h traces
Source code in src/alberta_framework/core/optimizers.py
update(state, td_error, observation, next_observation, gamma)
¶
Compute TD-IDBD weight update with adaptive step-sizes.
Implements Algorithm 3 (semi-gradient) or Algorithm 4 (ordinary gradient) from Kearney et al. 2019.
Args:
    state: Current TD-IDBD state
    td_error: TD error δ = R + γV(s') - V(s)
    observation: Current observation φ(s)
    next_observation: Next observation φ(s')
    gamma: Discount factor γ (0 at terminal)
Returns: TDOptimizerUpdate with weight deltas and updated state
Source code in src/alberta_framework/core/optimizers.py
Autostep(initial_step_size=0.01, meta_step_size=0.01, normalizer_decay=0.99)
¶
Bases: Optimizer[AutostepState]
Autostep optimizer with tuning-free step-size adaptation.
Autostep normalizes gradients to prevent large updates and adapts per-weight step-sizes based on gradient correlation. The key innovation is automatic normalization that makes the algorithm robust to different feature scales.
The algorithm maintains:
- Per-weight step-sizes that adapt based on gradient correlation
- Running max of absolute gradients for normalization
- Traces for detecting consistent gradient directions
Reference: Mahmood, A.R., Sutton, R.S., Degris, T., & Pilarski, P.M. (2012). "Tuning-free step-size adaptation"
Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate mu for adapting step-sizes
    normalizer_decay: Decay factor tau for gradient normalizers
Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate for adapting step-sizes
    normalizer_decay: Decay factor for gradient normalizers (higher = slower decay)
Source code in src/alberta_framework/core/optimizers.py
init(feature_dim)
¶
Initialize Autostep state.
Args: feature_dim: Dimension of weight vector
Returns: Autostep state with per-weight step-sizes, traces, and normalizers
Source code in src/alberta_framework/core/optimizers.py
update(state, error, observation)
¶
Compute Autostep weight update with normalized gradients.
The Autostep algorithm:
- Compute gradient: g_i = error * x_i
- Normalize gradient: g_i' = g_i / max(|g_i|, v_i)
- Update weights: w_i += alpha_i * g_i'
- Update step-sizes: alpha_i *= exp(mu * g_i' * h_i)
- Update traces: h_i = h_i * (1 - alpha_i) + alpha_i * g_i'
- Update normalizers: v_i = max(|g_i|, v_i * tau)
Args:
    state: Current Autostep state
    error: Prediction error (scalar)
    observation: Feature vector
Returns: OptimizerUpdate with weight deltas and updated state
Source code in src/alberta_framework/core/optimizers.py
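A standalone sketch of one Autostep update following the bullets above (illustrative, not the library's source; assumes the normalizers v start positive so the division is well-defined):

import jax.numpy as jnp

def autostep_step(w, alpha, h, v, x, error, mu=0.01, tau=0.99):
    g = error * x                                   # per-weight gradient
    g_norm = g / jnp.maximum(jnp.abs(g), v)         # normalized gradient
    new_w = w + alpha * g_norm                      # weight update
    new_alpha = alpha * jnp.exp(mu * g_norm * h)    # step-size adaptation
    new_h = h * (1.0 - new_alpha) + new_alpha * g_norm  # correlation trace
    new_v = jnp.maximum(jnp.abs(g), v * tau)        # decayed running max of |gradient|
    return new_w, new_alpha, new_h, new_v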
AutoTDIDBD(initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, normalizer_decay=10000.0)
¶
Bases: TDOptimizer[AutoTDIDBDState]
AutoStep-style normalized TD-IDBD optimizer.
Adds AutoStep-style normalization to TDIDBD for improved stability and reduced sensitivity to the meta step-size theta. Includes:
1. Normalization of the meta-weight update by a running trace of recent updates
2. Effective step-size normalization to prevent overshooting
Reference: Kearney et al. 2019, Algorithm 6 "AutoStep Style Normalized TIDBD(λ)"
The AutoTDIDBD algorithm:
1. Compute TD error: δ = R + γ*w^T*φ(s') - w^T*φ(s)
2. Update normalizers: η_i = max(|δ*[γφ_i(s')-φ_i(s)]*h_i|, η_i - (1/τ)*α_i*[γφ_i(s')-φ_i(s)]*z_i*(|δ*φ_i(s)*h_i| - η_i))
3. Normalized meta-update: β_i -= θ*(1/η_i)*δ*[γφ_i(s')-φ_i(s)]*h_i
4. Effective step-size normalization: M = max(-exp(β)*[γφ(s')-φ(s)]^T*z, 1), then β_i -= log(M)
5. Update weights and traces as in TIDBD
Attributes:
    initial_step_size: Initial per-weight step-size
    meta_step_size: Meta learning rate theta
    trace_decay: Eligibility trace decay lambda
    normalizer_decay: Decay parameter tau for normalizers
Args:
    initial_step_size: Initial value for per-weight step-sizes
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay lambda (0 = TD(0))
    normalizer_decay: Decay parameter tau for normalizers (default: 10000)
Source code in src/alberta_framework/core/optimizers.py
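A construction sketch (assuming TDLinearLearner and AutoTDIDBD are exported at the package level, which this page does not state):

import jax.numpy as jnp
from alberta_framework import TDLinearLearner, AutoTDIDBD  # assumed exports

learner = TDLinearLearner(optimizer=AutoTDIDBD(trace_decay=0.9, meta_step_size=0.01))
state = learner.init(feature_dim=8)
phi_s, phi_next = jnp.ones(8), jnp.zeros(8)  # toy transition features
result = learner.update(state, phi_s, reward=1.0, next_observation=phi_next, gamma=0.99)
# result carries state, prediction, td_error, and metrics per TDUpdateResult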
init(feature_dim)
¶
Initialize AutoTDIDBD state.
Args: feature_dim: Dimension of weight vector
Returns: AutoTDIDBD state with per-weight step-sizes, traces, h traces, and normalizers
Source code in src/alberta_framework/core/optimizers.py
update(state, td_error, observation, next_observation, gamma)
¶
Compute AutoTDIDBD weight update with normalized adaptive step-sizes.
Implements Algorithm 6 from Kearney et al. 2019.
Args:
    state: Current AutoTDIDBD state
    td_error: TD error δ = R + γV(s') - V(s)
    observation: Current observation φ(s)
    next_observation: Next observation φ(s')
    gamma: Discount factor γ (0 at terminal)
Returns: TDOptimizerUpdate with weight deltas and updated state
Source code in src/alberta_framework/core/optimizers.py
Optimizer
¶
Bases: ABC
Base class for optimizers.
init(feature_dim)
abstractmethod
¶
Initialize optimizer state.
Args: feature_dim: Dimension of weight vector
Returns: Initial optimizer state
update(state, error, observation)
abstractmethod
¶
Compute weight updates given prediction error.
Args:
    state: Current optimizer state
    error: Prediction error (target - prediction)
    observation: Current observation/feature vector
Returns: OptimizerUpdate with deltas and new state
Source code in src/alberta_framework/core/optimizers.py
TDOptimizer
¶
Bases: ABC
Base class for TD optimizers.
TD optimizers handle temporal-difference learning with eligibility traces. They take TD error and both current and next observations as input.
init(feature_dim)
abstractmethod
¶
Initialize optimizer state.
Args: feature_dim: Dimension of weight vector
Returns: Initial optimizer state
update(state, td_error, observation, next_observation, gamma)
abstractmethod
¶
Compute weight updates given TD error.
Args:
    state: Current optimizer state
    td_error: TD error δ = R + γV(s') - V(s)
    observation: Current observation φ(s)
    next_observation: Next observation φ(s')
    gamma: Discount factor γ (0 at terminal)
Returns: TDOptimizerUpdate with deltas and new state
Source code in src/alberta_framework/core/optimizers.py
TDOptimizerUpdate
¶
Result of a TD optimizer update step.
Attributes:
    weight_delta: Change to apply to weights
    bias_delta: Change to apply to bias
    new_state: Updated optimizer state
    metrics: Dictionary of metrics for logging
AutostepState
¶
State for the Autostep optimizer.
Autostep is a tuning-free step-size adaptation algorithm that normalizes gradients to prevent large updates and adapts step-sizes based on gradient correlation.
Reference: Mahmood et al. 2012, "Tuning-free step-size adaptation"
Attributes:
    step_sizes: Per-weight step-sizes (alpha_i)
    traces: Per-weight traces for gradient correlation (h_i)
    normalizers: Running max absolute gradient per weight (v_i)
    meta_step_size: Meta learning rate mu for adapting step-sizes
    normalizer_decay: Decay factor for the normalizer (tau)
    bias_step_size: Step-size for the bias term
    bias_trace: Trace for the bias term
    bias_normalizer: Normalizer for the bias gradient
AutoTDIDBDState
¶
State for the AutoTDIDBD optimizer.
AutoTDIDBD adds AutoStep-style normalization to TDIDBD for improved stability. Includes normalizers for the meta-weight updates and effective step-size normalization to prevent overshooting.
Reference: Kearney et al. 2019, Algorithm 6
Attributes:
    log_step_sizes: Log of per-weight step-sizes (log alpha_i)
    eligibility_traces: Eligibility traces z_i
    h_traces: Per-weight h traces for gradient correlation
    normalizers: Running max of absolute gradient correlations (eta_i)
    meta_step_size: Meta learning rate theta
    trace_decay: Eligibility trace decay parameter lambda
    normalizer_decay: Decay parameter tau for normalizers
    bias_log_step_size: Log step-size for the bias term
    bias_eligibility_trace: Eligibility trace for the bias
    bias_h_trace: h trace for the bias term
    bias_normalizer: Normalizer for the bias gradient correlation
BatchedLearningResult
¶
Result from batched learning loop across multiple seeds.
Used with run_learning_loop_batched for vmap-based GPU parallelization.
Attributes:
    states: Batched learner states; each array has shape (num_seeds, ...)
    metrics: Metrics array with shape (num_seeds, num_steps, 3) where columns are [squared_error, error, mean_step_size]
    step_size_history: Optional step-size history with batched shapes, or None if tracking was disabled
BatchedNormalizedResult
¶
Result from batched normalized learning loop across multiple seeds.
Used with run_normalized_learning_loop_batched for vmap-based GPU parallelization.
Attributes:
    states: Batched normalized learner states; each array has shape (num_seeds, ...)
    metrics: Metrics array with shape (num_seeds, num_steps, 4) where columns are [squared_error, error, mean_step_size, normalizer_mean_var]
    step_size_history: Optional step-size history with batched shapes, or None if tracking was disabled
    normalizer_history: Optional normalizer history with batched shapes, or None if tracking was disabled
IDBDState
¶
State for the IDBD (Incremental Delta-Bar-Delta) optimizer.
IDBD maintains per-weight adaptive step-sizes that are meta-learned based on the correlation of successive gradients.
Reference: Sutton 1992, "Adapting Bias by Gradient Descent"
Attributes:
    log_step_sizes: Log of per-weight step-sizes (log alpha_i)
    traces: Per-weight traces h_i for gradient correlation
    meta_step_size: Meta learning rate beta for adapting step-sizes
    bias_step_size: Step-size for the bias term
    bias_trace: Trace for the bias term
LearnerState
¶
State for a linear learner.
Attributes:
    weights: Weight vector for linear prediction
    bias: Bias term
    optimizer_state: State maintained by the optimizer
LMSState
¶
State for the LMS (Least Mean Square) optimizer.
LMS uses a fixed step-size, so state only tracks the step-size parameter.
Attributes: step_size: Fixed learning rate alpha
NormalizerHistory
¶
History of per-feature normalizer state recorded during training.
Used for analyzing how the OnlineNormalizer adapts to distribution shifts (reactive lag diagnostic).
Attributes:
    means: Per-feature mean estimates at each recording, shape (num_recordings, feature_dim)
    variances: Per-feature variance estimates at each recording, shape (num_recordings, feature_dim)
    recording_indices: Step indices where recordings were made, shape (num_recordings,)
NormalizerTrackingConfig
¶
Configuration for recording per-feature normalizer state during training.
Attributes: interval: Record normalizer state every N steps
StepSizeHistory
¶
History of per-weight step-sizes recorded during training.
Attributes:
    step_sizes: Per-weight step-sizes at each recording, shape (num_recordings, num_weights)
    bias_step_sizes: Bias step-sizes at each recording, shape (num_recordings,) or None
    recording_indices: Step indices where recordings were made, shape (num_recordings,)
    normalizers: Autostep's per-weight normalizers (v_i) at each recording, shape (num_recordings, num_weights) or None. Only populated for the Autostep optimizer.
StepSizeTrackingConfig
¶
Configuration for recording per-weight step-sizes during training.
Attributes: interval: Record step-sizes every N steps include_bias: Whether to also record the bias step-size
TDIDBDState
¶
State for the TD-IDBD (Temporal-Difference IDBD) optimizer.
TD-IDBD extends IDBD to temporal-difference learning with eligibility traces. Maintains per-weight adaptive step-sizes that are meta-learned based on gradient correlation in the TD setting.
Reference: Kearney et al. 2019, "Learning Feature Relevance Through Step Size Adaptation in Temporal-Difference Learning"
Attributes:
    log_step_sizes: Log of per-weight step-sizes (log alpha_i)
    eligibility_traces: Eligibility traces z_i for temporal credit assignment
    h_traces: Per-weight h traces for gradient correlation
    meta_step_size: Meta learning rate theta for adapting step-sizes
    trace_decay: Eligibility trace decay parameter lambda
    bias_log_step_size: Log step-size for the bias term
    bias_eligibility_trace: Eligibility trace for the bias
    bias_h_trace: h trace for the bias term
TDLearnerState
¶
State for a TD linear learner.
Attributes:
    weights: Weight vector for linear value function approximation
    bias: Bias term
    optimizer_state: State maintained by the TD optimizer
TDTimeStep
¶
Single experience from a TD stream.
Represents a transition (s, r, s', gamma) for temporal-difference learning.
Attributes:
    observation: Feature vector φ(s)
    reward: Reward R received
    next_observation: Feature vector φ(s')
    gamma: Discount factor γ_t (0 at terminal states)
TimeStep
¶
Single experience from an experience stream.
Attributes: observation: Feature vector x_t target: Desired output y*_t (for supervised learning)
ScanStream
¶
Bases: Protocol[StateT]
Protocol for JAX scan-compatible experience streams.
Streams generate temporally-uniform experience for continual learning. Unlike iterator-based streams, ScanStream uses pure functions that can be compiled with JAX's JIT and used with jax.lax.scan.
The stream should be non-stationary to test continual learning capabilities - the underlying target function changes over time.
Type Parameters: StateT: The state type maintained by this stream
Examples:
stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
key = jax.random.key(42)
state = stream.init(key)
timestep, new_state = stream.step(state, jnp.array(0))
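A sketch of a custom stream satisfying the protocol. This assumes TimeStep is importable from the package and constructible with (observation, target); the state container here is an illustrative NamedTuple.

import jax
import jax.numpy as jnp
from typing import NamedTuple
from alberta_framework import TimeStep  # assumed export

class NoisySumState(NamedTuple):
    key: jax.Array

class NoisySumStream:
    def __init__(self, feature_dim: int):
        self._feature_dim = feature_dim

    @property
    def feature_dim(self) -> int:
        return self._feature_dim

    def init(self, key) -> NoisySumState:
        return NoisySumState(key=key)

    def step(self, state: NoisySumState, idx):
        key, k_obs, k_noise = jax.random.split(state.key, 3)
        obs = jax.random.normal(k_obs, (self._feature_dim,))
        # Stationary toy target; a real stream should drift over time
        target = obs.sum() + 0.1 * jax.random.normal(k_noise)
        return TimeStep(observation=obs, target=target), NoisySumState(key=key)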
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key for initialization
Returns: Initial stream state
step(state, idx)
¶
Generate one time step. Must be JIT-compatible.
This is a pure function that takes the current state and step index, and returns a TimeStep along with the updated state. The step index can be used for time-dependent behavior but is often ignored.
Args: state: Current stream state idx: Current step index (can be ignored for most streams)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/base.py
AbruptChangeState
¶
State for AbruptChangeStream.
Attributes:
    key: JAX random key for generating randomness
    true_weights: Current true target weights
    step_count: Number of steps taken
AbruptChangeStream(feature_dim, change_interval=1000, noise_std=0.1, feature_std=1.0)
¶
Non-stationary stream with sudden target weight changes.
Target weights remain constant for a period, then abruptly change to new random values. Tests the learner's ability to detect and rapidly adapt to distribution shifts.
Attributes:
    feature_dim: Dimension of observation vectors
    change_interval: Number of steps between weight changes
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features
Args:
    feature_dim: Dimension of feature vectors
    change_interval: Steps between abrupt weight changes
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values
Source code in src/alberta_framework/streams/synthetic.py
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step.
Args: state: Current stream state idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
CyclicState
¶
State for CyclicStream.
Attributes:
    key: JAX random key for generating randomness
    configurations: Pre-generated weight configurations
    step_count: Number of steps taken
CyclicStream(feature_dim, cycle_length=500, num_configurations=4, noise_std=0.1, feature_std=1.0)
¶
Non-stationary stream that cycles between known weight configurations.
Weights cycle through a fixed set of configurations. Tests whether the learner can re-adapt quickly to previously seen targets.
Attributes:
    feature_dim: Dimension of observation vectors
    cycle_length: Number of steps per configuration before switching
    num_configurations: Number of weight configurations to cycle through
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features
Args:
    feature_dim: Dimension of feature vectors
    cycle_length: Steps spent in each configuration
    num_configurations: Number of configurations to cycle through
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values
Source code in src/alberta_framework/streams/synthetic.py
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with pre-generated configurations
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step.
Args: state: Current stream state idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
DynamicScaleShiftState
¶
State for DynamicScaleShiftStream.
Attributes:
    key: JAX random key for generating randomness
    true_weights: Current true target weights
    current_scales: Current per-feature scaling factors
    step_count: Number of steps taken
DynamicScaleShiftStream(feature_dim, scale_change_interval=2000, weight_change_interval=1000, min_scale=0.01, max_scale=100.0, noise_std=0.1)
¶
Non-stationary stream with abruptly changing feature scales.
Both target weights AND feature scales change at specified intervals. This tests whether OnlineNormalizer can track scale shifts faster than Autostep's internal v_i adaptation.
The target is computed from unscaled features to maintain consistent difficulty across scale changes (only the feature representation changes, not the underlying prediction task).
Attributes:
    feature_dim: Dimension of observation vectors
    scale_change_interval: Steps between scale changes
    weight_change_interval: Steps between weight changes
    min_scale: Minimum scale factor
    max_scale: Maximum scale factor
    noise_std: Standard deviation of observation noise
Args:
    feature_dim: Dimension of feature vectors
    scale_change_interval: Steps between abrupt scale changes
    weight_change_interval: Steps between abrupt weight changes
    min_scale: Minimum scale factor (log-uniform sampling)
    max_scale: Maximum scale factor (log-uniform sampling)
    noise_std: Std dev of target noise
Source code in src/alberta_framework/streams/synthetic.py
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with random weights and scales
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step.
Args: state: Current stream state idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
PeriodicChangeState
¶
State for PeriodicChangeStream.
Attributes:
    key: JAX random key for generating randomness
    base_weights: Base target weights (center of oscillation)
    phases: Per-weight phase offsets
    step_count: Number of steps taken
PeriodicChangeStream(feature_dim, period=1000, amplitude=1.0, noise_std=0.1, feature_std=1.0)
¶
Non-stationary stream where target weights oscillate sinusoidally.
Target weights follow: w(t) = base + amplitude * sin(2π * t / period + phase) where each weight has a random phase offset for diversity.
This tests the learner's ability to track predictable periodic changes, which is qualitatively different from random drift or abrupt changes.
Attributes:
    feature_dim: Dimension of observation vectors
    period: Number of steps for one complete oscillation
    amplitude: Magnitude of weight oscillation
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features
Args:
    feature_dim: Dimension of feature vectors
    period: Steps for one complete oscillation cycle
    amplitude: Magnitude of weight oscillations around base
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values
Source code in src/alberta_framework/streams/synthetic.py
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with random base weights and phases
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step.
Args: state: Current stream state idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
RandomWalkState
¶
State for RandomWalkStream.
Attributes: key: JAX random key for generating randomness true_weights: Current true target weights
RandomWalkStream(feature_dim, drift_rate=0.001, noise_std=0.1, feature_std=1.0)
¶
Non-stationary stream where target weights drift via random walk.
The true target function is linear: y* = w_true @ x + noise
where w_true evolves via random walk at each time step.
This tests the learner's ability to continuously track a moving target.
Attributes:
    feature_dim: Dimension of observation vectors
    drift_rate: Standard deviation of weight drift per step
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features
Args:
    feature_dim: Dimension of the feature/observation vectors
    drift_rate: Std dev of weight changes per step (controls non-stationarity)
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values
Source code in src/alberta_framework/streams/synthetic.py
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with random weights
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step.
Args: state: Current stream state idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
ScaleDriftState
¶
State for ScaleDriftStream.
Attributes:
    key: JAX random key for generating randomness
    true_weights: Current true target weights
    log_scales: Current log-scale factors (random walk on log-scale)
    step_count: Number of steps taken
ScaleDriftStream(feature_dim, weight_drift_rate=0.001, scale_drift_rate=0.01, min_log_scale=-4.0, max_log_scale=4.0, noise_std=0.1)
¶
Non-stationary stream where feature scales drift via random walk.
Both target weights and feature scales drift continuously. Weights drift in linear space while scales drift in log-space (bounded random walk). This tests continuous scale tracking where OnlineNormalizer's EMA may adapt differently than Autostep's v_i.
The target is computed from unscaled features to maintain consistent difficulty across scale changes.
Attributes:
    feature_dim: Dimension of observation vectors
    weight_drift_rate: Std dev of weight drift per step
    scale_drift_rate: Std dev of log-scale drift per step
    min_log_scale: Minimum log-scale (clips random walk)
    max_log_scale: Maximum log-scale (clips random walk)
    noise_std: Standard deviation of observation noise
Args:
    feature_dim: Dimension of feature vectors
    weight_drift_rate: Std dev of weight drift per step
    scale_drift_rate: Std dev of log-scale drift per step
    min_log_scale: Minimum log-scale (clips drift)
    max_log_scale: Maximum log-scale (clips drift)
    noise_std: Std dev of target noise
Source code in src/alberta_framework/streams/synthetic.py
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with random weights and unit scales
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step.
Args: state: Current stream state idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
ScaledStreamState
¶
State for ScaledStreamWrapper.
Attributes: inner_state: State of the wrapped stream
ScaledStreamWrapper(inner_stream, feature_scales)
¶
Wrapper that applies per-feature scaling to any stream's observations.
This wrapper multiplies each feature of the observation by a corresponding scale factor. Useful for testing how learners handle features at different scales, which is important for understanding normalization benefits.
Examples:
stream = ScaledStreamWrapper(
AbruptChangeStream(feature_dim=10, change_interval=1000),
feature_scales=jnp.array([0.001, 0.01, 0.1, 1.0, 10.0,
100.0, 1000.0, 0.001, 0.01, 0.1])
)
Attributes: inner_stream: The wrapped stream instance feature_scales: Per-feature scale factors (must match feature_dim)
Args: inner_stream: Stream to wrap (must implement ScanStream protocol) feature_scales: Array of scale factors, one per feature. Must have shape (feature_dim,) matching the inner stream's feature_dim.
Raises: ValueError: If feature_scales length doesn't match inner stream's feature_dim
Source code in src/alberta_framework/streams/synthetic.py
feature_dim
property
¶
Return the dimension of observation vectors.
inner_stream
property
¶
Return the wrapped stream.
feature_scales
property
¶
Return the per-feature scale factors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state wrapping the inner stream's state
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step with scaled observations.
Args: state: Current stream state idx: Current step index
Returns: Tuple of (timestep with scaled observation, new_state)
Source code in src/alberta_framework/streams/synthetic.py
SuttonExperiment1State
¶
State for SuttonExperiment1Stream.
Attributes:
    key: JAX random key for generating randomness
    signs: Signs (+1/-1) for the relevant inputs
    step_count: Number of steps taken
SuttonExperiment1Stream(num_relevant=5, num_irrelevant=15, change_interval=20)
¶
Non-stationary stream replicating Experiment 1 from Sutton 1992.
This stream implements the exact task from Sutton's IDBD paper:
- 20 real-valued inputs drawn from N(0, 1)
- Only the first 5 inputs are relevant (weights are ±1)
- The last 15 inputs are irrelevant (weights are 0)
- Every change_interval steps, one of the 5 relevant signs is flipped
Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta"
Attributes:
    num_relevant: Number of relevant inputs (default 5)
    num_irrelevant: Number of irrelevant inputs (default 15)
    change_interval: Steps between sign changes (default 20)
Args:
    num_relevant: Number of relevant inputs with ±1 weights
    num_irrelevant: Number of irrelevant inputs with 0 weights
    change_interval: Number of steps between sign flips
Source code in src/alberta_framework/streams/synthetic.py
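A replication sketch comparing LMS and IDBD on this task (assuming SuttonExperiment1Stream is exported at the package level like the other streams):

import jax.random as jr
from alberta_framework import LinearLearner, IDBD, LMS, run_learning_loop
from alberta_framework import SuttonExperiment1Stream  # assumed export

stream = SuttonExperiment1Stream()  # 5 relevant + 15 irrelevant inputs
for name, opt in [("LMS", LMS(step_size=0.01)), ("IDBD", IDBD())]:
    _, metrics = run_learning_loop(LinearLearner(optimizer=opt), stream, num_steps=20000, key=jr.key(0))
    print(f"{name}: mean squared error {metrics[:, 0].mean():.4f}")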
feature_dim
property
¶
Return the dimension of observation vectors.
init(key)
¶
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with all +1 signs
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
¶
Generate one time step.
At each step:
1. If at a change interval (and not step 0), flip one random sign
2. Generate random inputs from N(0, 1)
3. Compute target as the sum of relevant inputs weighted by signs
Args: state: Current stream state idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
Timer(name='Operation', verbose=True, print_fn=None)
¶
Context manager for timing code execution.
Measures wall-clock time for a block of code and optionally prints the duration when the block completes.
Attributes:
    name: Description of what is being timed
    duration: Elapsed time in seconds (available after context exits)
    start_time: Timestamp when timing started
    end_time: Timestamp when timing ended
Examples:
with Timer("Training loop"):
for i in range(1000):
pass
# Output: Training loop completed in 0.01s
# Silent timing (no print):
with Timer("Silent", verbose=False) as t:
time.sleep(0.1)
print(f"Elapsed: {t.duration:.2f}s")
# Output: Elapsed: 0.10s
# Custom print function:
with Timer("Custom", print_fn=lambda msg: print(f">> {msg}")):
pass
# Output: >> Custom completed in 0.00s
Args:
    name: Description of the operation being timed
    verbose: Whether to print the duration when done
    print_fn: Custom print function (defaults to built-in print)
Source code in src/alberta_framework/utils/timing.py
elapsed()
¶
Get elapsed time since timer started (can be called during execution).
Returns: Elapsed time in seconds
GymnasiumStream(env, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0)
¶
Experience stream from a Gymnasium environment using Python loop.
This class maintains iterator-based access for online learning scenarios where you need to interact with the environment in real-time.
For batch learning, use collect_trajectory() followed by learn_from_trajectory().
Attributes:
    mode: Prediction mode (REWARD, NEXT_STATE, VALUE)
    gamma: Discount factor for VALUE mode
    include_action_in_features: Whether to include action in features
    episode_count: Number of completed episodes
Args:
    env: Gymnasium environment instance
    mode: What to predict (REWARD, NEXT_STATE, VALUE)
    policy: Action selection function. If None, uses a random policy
    gamma: Discount factor for VALUE mode
    include_action_in_features: If True, features = concat(obs, action). If False, features = obs only
    seed: Random seed for environment resets and the random policy
Source code in src/alberta_framework/streams/gymnasium.py
feature_dim
property
¶
Return the dimension of feature vectors.
target_dim
property
¶
Return the dimension of target vectors.
episode_count
property
¶
Return the number of completed episodes.
step_count
property
¶
Return the total number of steps taken.
mode
property
¶
Return the prediction mode.
set_value_estimator(estimator)
¶
Set the value estimator for proper TD learning in VALUE mode.
PredictionMode
¶
Bases: Enum
Mode for what the stream predicts.
REWARD: Predict immediate reward from (state, action)
NEXT_STATE: Predict next state from (state, action)
VALUE: Predict cumulative return (TD learning with bootstrap)
TDStream(env, policy=None, gamma=0.99, include_action_in_features=False, seed=0)
¶
Experience stream for proper TD learning with value function bootstrap.
This stream integrates with a learner to use its predictions for bootstrapping in TD targets.
Usage:
stream = TDStream(env)
learner = LinearLearner(optimizer=IDBD())
state = learner.init(stream.feature_dim)
for step, timestep in enumerate(stream):
    result = learner.update(state, timestep.observation, timestep.target)
    state = result.state
    stream.update_value_function(lambda x: learner.predict(state, x))
Args:
    env: Gymnasium environment instance
    policy: Action selection function. If None, uses a random policy
    gamma: Discount factor
    include_action_in_features: If True, learn Q(s,a). If False, learn V(s)
    seed: Random seed
Source code in src/alberta_framework/streams/gymnasium.py
metrics_to_dicts(metrics, normalized=False)
¶
Convert metrics array to list of dicts for backward compatibility.
Args: metrics: Array of shape (num_steps, 3) or (num_steps, 4) normalized: If True, expects 4 columns including normalizer_mean_var
Returns: List of metric dictionaries
Source code in src/alberta_framework/core/learners.py
run_learning_loop(learner, stream, num_steps, key, learner_state=None, step_size_tracking=None)
¶
Run the learning loop using jax.lax.scan.
This is a JIT-compiled learning loop that uses scan for efficiency. It returns metrics as a fixed-size array rather than a list of dicts.
Args:
    learner: The learner to train
    stream: Experience stream providing (observation, target) pairs
    num_steps: Number of learning steps to run
    key: JAX random key for stream initialization
    learner_state: Initial state (if None, will be initialized from stream)
    step_size_tracking: Optional config for recording per-weight step-sizes. When provided, returns a 3-tuple including StepSizeHistory.
Returns:
    If step_size_tracking is None: tuple of (final_state, metrics_array), where metrics_array has shape (num_steps, 3) with columns [squared_error, error, mean_step_size].
    If step_size_tracking is provided: tuple of (final_state, metrics_array, step_size_history).
Raises: ValueError: If step_size_tracking.interval is less than 1 or greater than num_steps
Source code in src/alberta_framework/core/learners.py
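A tracking sketch (assuming StepSizeTrackingConfig and Autostep are exported at the package level, and that the config is constructed with the interval field documented under StepSizeTrackingConfig):

import jax.random as jr
from alberta_framework import LinearLearner, Autostep, RandomWalkStream, run_learning_loop
from alberta_framework import StepSizeTrackingConfig  # assumed export

stream = RandomWalkStream(feature_dim=10)
learner = LinearLearner(optimizer=Autostep())
state, metrics, history = run_learning_loop(
    learner, stream, num_steps=10000, key=jr.key(0),
    step_size_tracking=StepSizeTrackingConfig(interval=100),
)
# history.step_sizes has shape (num_recordings, num_weights); see StepSizeHistory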
run_learning_loop_batched(learner, stream, num_steps, keys, learner_state=None, step_size_tracking=None)
¶
Run learning loop across multiple seeds in parallel using jax.vmap.
This function provides GPU parallelization for multi-seed experiments, typically achieving 2-5x speedup over sequential execution.
Args:
    learner: The learner to train
    stream: Experience stream providing (observation, target) pairs
    num_steps: Number of learning steps to run per seed
    keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
    learner_state: Initial state (if None, will be initialized from stream). The same initial state is used for all seeds.
    step_size_tracking: Optional config for recording per-weight step-sizes. When provided, history arrays have shape (num_seeds, num_recordings, ...)
Returns:
    BatchedLearningResult containing:
    - states: Batched final states with shape (num_seeds, ...) for each array
    - metrics: Array of shape (num_seeds, num_steps, 3)
    - step_size_history: Batched history, or None if tracking disabled
Examples:
import jax.random as jr
from alberta_framework import LinearLearner, IDBD, RandomWalkStream
from alberta_framework import run_learning_loop_batched
stream = RandomWalkStream(feature_dim=10)
learner = LinearLearner(optimizer=IDBD())
# Run 30 seeds in parallel
keys = jr.split(jr.key(42), 30)
result = run_learning_loop_batched(learner, stream, num_steps=10000, keys=keys)
# result.metrics has shape (30, 10000, 3)
mean_error = result.metrics[:, :, 0].mean(axis=0) # Average over seeds
Source code in src/alberta_framework/core/learners.py
run_normalized_learning_loop(learner, stream, num_steps, key, learner_state=None, step_size_tracking=None, normalizer_tracking=None)
¶
Run the learning loop with normalization using jax.lax.scan.
Args:
    learner: The normalized learner to train
    stream: Experience stream providing (observation, target) pairs
    num_steps: Number of learning steps to run
    key: JAX random key for stream initialization
    learner_state: Initial state (if None, will be initialized from stream)
    step_size_tracking: Optional config for recording per-weight step-sizes. When provided, returns StepSizeHistory including Autostep normalizers if applicable.
    normalizer_tracking: Optional config for recording per-feature normalizer state. When provided, returns NormalizerHistory with means and variances over time.
Returns:
    If no tracking: tuple of (final_state, metrics_array), where metrics_array has shape (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var].
    If step_size_tracking only: tuple of (final_state, metrics_array, step_size_history).
    If normalizer_tracking only: tuple of (final_state, metrics_array, normalizer_history).
    If both: tuple of (final_state, metrics_array, step_size_history, normalizer_history).
Raises: ValueError: If tracking interval is invalid
Source code in src/alberta_framework/core/learners.py
run_normalized_learning_loop_batched(learner, stream, num_steps, keys, learner_state=None, step_size_tracking=None, normalizer_tracking=None)
¶
Run normalized learning loop across multiple seeds in parallel using jax.vmap.
This function provides GPU parallelization for multi-seed experiments with normalized learners, typically achieving 2-5x speedup over sequential execution.
Args:
    learner: The normalized learner to train
    stream: Experience stream providing (observation, target) pairs
    num_steps: Number of learning steps to run per seed
    keys: JAX random keys with shape (num_seeds,) or (num_seeds, 2)
    learner_state: Initial state (if None, will be initialized from stream). The same initial state is used for all seeds.
    step_size_tracking: Optional config for recording per-weight step-sizes. When provided, history arrays have shape (num_seeds, num_recordings, ...)
    normalizer_tracking: Optional config for recording normalizer state. When provided, history arrays have shape (num_seeds, num_recordings, ...)
Returns:
    BatchedNormalizedResult containing:
    - states: Batched final states with shape (num_seeds, ...) for each array
    - metrics: Array of shape (num_seeds, num_steps, 4)
    - step_size_history: Batched history, or None if tracking disabled
    - normalizer_history: Batched history, or None if tracking disabled
Examples:
import jax.random as jr
from alberta_framework import NormalizedLinearLearner, IDBD, RandomWalkStream
from alberta_framework import run_normalized_learning_loop_batched
stream = RandomWalkStream(feature_dim=10)
learner = NormalizedLinearLearner(optimizer=IDBD())
# Run 30 seeds in parallel
keys = jr.split(jr.key(42), 30)
result = run_normalized_learning_loop_batched(
learner, stream, num_steps=10000, keys=keys
)
# result.metrics has shape (30, 10000, 4)
mean_error = result.metrics[:, :, 0].mean(axis=0) # Average over seeds
Source code in src/alberta_framework/core/learners.py
run_td_learning_loop(learner, stream, num_steps, key, learner_state=None)
¶
Run the TD learning loop using jax.lax.scan.
This is a JIT-compiled learning loop that uses scan for efficiency. It returns metrics as a fixed-size array rather than a list of dicts.
Args:
    learner: The TD learner to train
    stream: TD experience stream providing (s, r, s', γ) tuples
    num_steps: Number of learning steps to run
    key: JAX random key for stream initialization
    learner_state: Initial state (if None, will be initialized from stream)
Returns: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 4) with columns [squared_td_error, td_error, mean_step_size, mean_eligibility_trace]
Source code in src/alberta_framework/core/learners.py
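A usage sketch. This section defines no scan-compatible TD stream, so td_stream below is a placeholder for any stream yielding TDTimeStep transitions; TDLinearLearner and TDIDBD are assumed to be package-level exports.

import jax.random as jr
from alberta_framework import TDLinearLearner, TDIDBD, run_td_learning_loop  # assumed exports

learner = TDLinearLearner(optimizer=TDIDBD(trace_decay=0.9))
state, metrics = run_td_learning_loop(learner, td_stream, num_steps=10000, key=jr.key(0))
# metrics columns: [squared_td_error, td_error, mean_step_size, mean_eligibility_trace]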
create_normalizer_state(feature_dim, decay=0.99)
¶
Create initial normalizer state.
Convenience function for creating normalizer state without instantiating the OnlineNormalizer class.
Args:
- feature_dim: Dimension of feature vectors
- decay: Exponential decay factor
Returns: Initial normalizer state
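Example (sketch; the import path follows the source location below and may differ if the function is re-exported at the package top level):
from alberta_framework.core.normalizers import create_normalizer_state
norm_state = create_normalizer_state(feature_dim=10, decay=0.999)
# norm_state now holds running mean/variance estimates for 10 features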
Source code in src/alberta_framework/core/normalizers.py
create_autotdidbd_state(feature_dim, initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0, normalizer_decay=10000.0)
¶
Create initial AutoTDIDBD optimizer state.
Args:
- feature_dim: Dimension of the feature vector
- initial_step_size: Initial per-weight step-size
- meta_step_size: Meta learning rate theta for adapting step-sizes
- trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))
- normalizer_decay: Decay parameter tau for normalizers (default: 10000)
Returns: Initial AutoTDIDBD state
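Example (sketch; parameter values are illustrative, and the import path follows the source location below):
from alberta_framework.core.types import create_autotdidbd_state
opt_state = create_autotdidbd_state(feature_dim=10, trace_decay=0.9)
# per-weight step-sizes start at 0.01; theta and tau keep their defaults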
Source code in src/alberta_framework/core/types.py
create_tdidbd_state(feature_dim, initial_step_size=0.01, meta_step_size=0.01, trace_decay=0.0)
¶
Create initial TD-IDBD optimizer state.
Args:
- feature_dim: Dimension of the feature vector
- initial_step_size: Initial per-weight step-size
- meta_step_size: Meta learning rate theta for adapting step-sizes
- trace_decay: Eligibility trace decay parameter lambda (0 = TD(0))
Returns: Initial TD-IDBD state
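Example (sketch, mirroring create_autotdidbd_state above; import path follows the source location below):
from alberta_framework.core.types import create_tdidbd_state
opt_state = create_tdidbd_state(feature_dim=10)  # trace_decay=0.0, i.e. TD(0)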
Source code in src/alberta_framework/core/types.py
make_scale_range(feature_dim, min_scale=0.001, max_scale=1000.0, log_spaced=True)
¶
Create a per-feature scale array spanning a range.
Utility function to generate scale factors for ScaledStreamWrapper.
Args:
- feature_dim: Number of features
- min_scale: Minimum scale factor
- max_scale: Maximum scale factor
- log_spaced: If True (default), scales are logarithmically spaced; if False, linearly spaced
Returns: Array of shape (feature_dim,) with scale factors
Examples:
from alberta_framework import RandomWalkStream
from alberta_framework.streams.synthetic import make_scale_range, ScaledStreamWrapper  # ScaledStreamWrapper location assumed
scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
Source code in src/alberta_framework/streams/synthetic.py
compare_learners(results, metric='squared_error')
¶
Compare multiple learners on a given metric.
Args:
- results: Dictionary mapping learner name to metrics history
- metric: Metric to compare
Returns: Dictionary with summary statistics for each learner
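Example (sketch): the metrics histories below use the list-of-dicts form expected by the other utilities in this module; the exact keys of the returned summary are not documented here.
from alberta_framework.utils.metrics import compare_learners
results = {
    "lms": [{"squared_error": 1.0}, {"squared_error": 0.6}],
    "idbd": [{"squared_error": 0.9}, {"squared_error": 0.3}],
}
summary = compare_learners(results, metric="squared_error")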
Source code in src/alberta_framework/utils/metrics.py
compute_cumulative_error(metrics_history, error_key='squared_error')
¶
Compute cumulative error over time.
Args:
- metrics_history: List of metric dictionaries from learning loop
- error_key: Key to extract error values
Returns: Array of cumulative errors at each time step
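Example (sketch; assuming the cumulative error is the running sum of the extracted values):
from alberta_framework.utils.metrics import compute_cumulative_error
history = [{"squared_error": 1.0}, {"squared_error": 0.5}, {"squared_error": 0.25}]
cumulative = compute_cumulative_error(history)  # expected: [1.0, 1.5, 1.75]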
Source code in src/alberta_framework/utils/metrics.py
compute_running_mean(values, window_size=100)
¶
Compute running mean of values.
Args:
- values: Array of values
- window_size: Size of the moving average window
Returns: Array of running mean values (same length as input, padded at start)
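Example (sketch):
import jax.numpy as jnp
from alberta_framework.utils.metrics import compute_running_mean
values = jnp.arange(1000.0)
smoothed = compute_running_mean(values, window_size=100)  # same length as input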
Source code in src/alberta_framework/utils/metrics.py
compute_tracking_error(metrics_history, window_size=100)
¶
Compute tracking error (running mean of squared error).
This is the key metric for evaluating continual learners: how well can the learner track the non-stationary target?
Args:
- metrics_history: List of metric dictionaries from learning loop
- window_size: Size of the moving average window
Returns: Array of tracking errors at each time step
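Example (sketch; the history below is a toy stand-in for metrics collected from a learning loop):
from alberta_framework.utils.metrics import compute_tracking_error
history = [{"squared_error": float(t % 5)} for t in range(1000)]
tracking = compute_tracking_error(history, window_size=100)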
Source code in src/alberta_framework/utils/metrics.py
extract_metric(metrics_history, key)
¶
Extract a single metric from the history.
Args:
- metrics_history: List of metric dictionaries
- key: Key to extract
Returns: Array of values for that metric
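Example (sketch):
from alberta_framework.utils.metrics import extract_metric
history = [{"error": 0.5, "squared_error": 0.25}, {"error": -0.2, "squared_error": 0.04}]
errors = extract_metric(history, "error")  # -> array([0.5, -0.2])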
Source code in src/alberta_framework/utils/metrics.py
format_duration(seconds)
¶
Format a duration in seconds as a human-readable string.
Args: seconds: Duration in seconds
Returns: Formatted string like "1.23s", "2m 30.5s", or "1h 5m 30s"
Examples:
from alberta_framework.utils.timing import format_duration
format_duration(0.5) # Returns: '0.50s'
format_duration(90.5) # Returns: '1m 30.50s'
format_duration(3665) # Returns: '1h 1m 5.00s'
Source code in src/alberta_framework/utils/timing.py
collect_trajectory(env, policy, num_steps, mode=PredictionMode.REWARD, include_action_in_features=True, seed=0)
¶
Collect a trajectory from a Gymnasium environment.
This uses a Python loop to interact with the environment and collects observations and targets into JAX arrays that can be used with scan-based learning.
Args:
- env: Gymnasium environment instance
- policy: Action selection function. If None, uses a random policy
- num_steps: Number of steps to collect
- mode: What to predict (REWARD, NEXT_STATE, VALUE)
- include_action_in_features: If True, features = concat(obs, action)
- seed: Random seed for environment resets and the random policy
Returns: Tuple of (observations, targets) as JAX arrays with shape (num_steps, feature_dim) and (num_steps, target_dim)
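Example (sketch; assumes Gymnasium is installed, and uses CartPole-v1 with the random policy and the default REWARD mode):
import gymnasium as gym
from alberta_framework.streams.gymnasium import collect_trajectory
env = gym.make("CartPole-v1")
observations, targets = collect_trajectory(env, policy=None, num_steps=1000, seed=0)
# observations: (1000, feature_dim), targets: (1000, target_dim)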
Source code in src/alberta_framework/streams/gymnasium.py
learn_from_trajectory(learner, observations, targets, learner_state=None)
¶
Learn from a pre-collected trajectory using jax.lax.scan.
This is a JIT-compiled learning function that processes a trajectory collected from a Gymnasium environment.
Args:
- learner: The learner to train
- observations: Array of observations with shape (num_steps, feature_dim)
- targets: Array of targets with shape (num_steps, target_dim)
- learner_state: Initial state (if None, will be initialized)
Returns: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 3) with columns [squared_error, error, mean_step_size]
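Example (sketch; reuses the observations and targets from the collect_trajectory sketch above):
from alberta_framework import LinearLearner, IDBD
from alberta_framework.streams.gymnasium import learn_from_trajectory
learner = LinearLearner(optimizer=IDBD())
state, metrics = learn_from_trajectory(learner, observations, targets)
# metrics[:, 0] is the squared error at each step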
Source code in src/alberta_framework/streams/gymnasium.py
learn_from_trajectory_normalized(learner, observations, targets, learner_state=None)
¶
Learn from a pre-collected trajectory with normalization using jax.lax.scan.
Args:
- learner: The normalized learner to train
- observations: Array of observations with shape (num_steps, feature_dim)
- targets: Array of targets with shape (num_steps, target_dim)
- learner_state: Initial state (if None, will be initialized)
Returns: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
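Example (sketch, as above but with a normalized learner):
from alberta_framework import NormalizedLinearLearner, IDBD
from alberta_framework.streams.gymnasium import learn_from_trajectory_normalized
learner = NormalizedLinearLearner(optimizer=IDBD())
state, metrics = learn_from_trajectory_normalized(learner, observations, targets)
# metrics has a fourth column tracking normalizer_mean_var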
Source code in src/alberta_framework/streams/gymnasium.py
make_epsilon_greedy_policy(base_policy, env, epsilon=0.1, seed=0)
¶
Wrap a policy with epsilon-greedy exploration.
Args:
- base_policy: The greedy policy to wrap
- env: Gymnasium environment (for random action sampling)
- epsilon: Probability of taking a random action
- seed: Random seed
Returns: Epsilon-greedy policy
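Example (sketch; the base policy below is a toy stand-in for a learned greedy policy):
import gymnasium as gym
from alberta_framework.streams.gymnasium import make_epsilon_greedy_policy
env = gym.make("CartPole-v1")
def always_left(obs):
    return 0  # toy greedy policy for illustration
policy = make_epsilon_greedy_policy(always_left, env, epsilon=0.1, seed=0)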
Source code in src/alberta_framework/streams/gymnasium.py
make_gymnasium_stream(env_id, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0, **env_kwargs)
¶
Factory function to create a GymnasiumStream from an environment ID.
Args:
- env_id: Gymnasium environment ID (e.g., "CartPole-v1")
- mode: What to predict (REWARD, NEXT_STATE, VALUE)
- policy: Action selection function. If None, uses a random policy
- gamma: Discount factor for VALUE mode
- include_action_in_features: If True, features = concat(obs, action)
- seed: Random seed
- **env_kwargs: Additional arguments passed to gymnasium.make()
Returns: GymnasiumStream wrapping the environment
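Example (sketch; the resulting stream should be usable wherever the synthetic streams are, though that pairing is not shown on this page):
from alberta_framework.streams.gymnasium import make_gymnasium_stream
stream = make_gymnasium_stream("CartPole-v1", seed=0)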
Source code in src/alberta_framework/streams/gymnasium.py
make_random_policy(env, seed=0)
¶
Create a random action policy for an environment.
Args:
- env: Gymnasium environment
- seed: Random seed
Returns: A callable that takes an observation and returns a random action
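Example (sketch; assumes Gymnasium is installed):
import gymnasium as gym
from alberta_framework.streams.gymnasium import make_random_policy
env = gym.make("CartPole-v1")
policy = make_random_policy(env, seed=0)
obs, _ = env.reset(seed=0)
action = policy(obs)  # random action, independent of obs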