streams
Experience streams for continual learning.
ScanStream
Bases: Protocol[StateT]
Protocol for JAX scan-compatible experience streams.
Streams generate temporally-uniform experience for continual learning. Unlike iterator-based streams, ScanStream uses pure functions that can be compiled with JAX's JIT and used with jax.lax.scan.
The stream should be non-stationary to test continual learning capabilities - the underlying target function changes over time.
Type Parameters: StateT: The state type maintained by this stream
Examples:
stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
key = jax.random.key(42)
state = stream.init(key)
timestep, new_state = stream.step(state, jnp.array(0))
feature_dim (property)
Return the dimension of observation vectors.
init(key)
Initialize stream state.
Args: key: JAX random key for initialization
Returns: Initial stream state
step(state, idx)
Generate one time step. Must be JIT-compatible.
This is a pure function that takes the current state and step index, and returns a TimeStep along with the updated state. The step index can be used for time-dependent behavior but is often ignored.
Args:
    state: Current stream state
    idx: Current step index (can be ignored for most streams)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/base.py
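The init/step contract above can be driven with jax.lax.scan by swapping the returned (timestep, new_state) into scan's (carry, y) order. A minimal sketch with a hypothetical toy ConstantStream (not part of the framework):

```python
import jax
import jax.numpy as jnp

class ConstantStream:
    """Hypothetical toy stream illustrating the init/step contract."""
    def __init__(self, feature_dim):
        self._dim = feature_dim

    @property
    def feature_dim(self):
        return self._dim

    def init(self, key):
        # State is just the PRNG key for this toy stream
        return key

    def step(self, state, idx):
        obs = jnp.ones(self._dim)
        target = jnp.sum(obs)
        return (obs, target), state

stream = ConstantStream(feature_dim=4)
state0 = stream.init(jax.random.key(0))

def scan_body(state, idx):
    # scan expects (carry, y); step returns (timestep, new_state)
    timestep, new_state = stream.step(state, idx)
    return new_state, timestep

final_state, (observations, targets) = jax.lax.scan(
    scan_body, state0, jnp.arange(100)
)
# observations: (100, 4), targets: (100,)
```

The small wrapper is needed because the documented step signature returns (timestep, new_state), the reverse of scan's (carry, y) convention.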
AbruptChangeState
State for AbruptChangeStream.
Attributes:
    key: JAX random key for generating randomness
    true_weights: Current true target weights
    step_count: Number of steps taken
AbruptChangeStream(feature_dim, change_interval=1000, noise_std=0.1, feature_std=1.0)
Non-stationary stream with sudden target weight changes.
Target weights remain constant for a period, then abruptly change to new random values. Tests the learner's ability to detect and rapidly adapt to distribution shifts.
Attributes:
    feature_dim: Dimension of observation vectors
    change_interval: Number of steps between weight changes
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features
Args:
    feature_dim: Dimension of feature vectors
    change_interval: Steps between abrupt weight changes
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values
Source code in src/alberta_framework/streams/synthetic.py
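The dynamics described above can be sketched in a few lines. abrupt_change_step below is an illustrative reimplementation of the documented behavior, not the framework's actual source:

```python
import jax
import jax.numpy as jnp

def abrupt_change_step(key, true_weights, step_count,
                       change_interval=1000, noise_std=0.1, feature_std=1.0):
    """One step of the documented dynamics: weights are resampled at
    each change boundary, otherwise held fixed (illustrative sketch)."""
    key, k_w, k_x, k_n = jax.random.split(key, 4)
    at_boundary = (step_count > 0) & (step_count % change_interval == 0)
    fresh = jax.random.normal(k_w, true_weights.shape)
    w = jnp.where(at_boundary, fresh, true_weights)
    x = feature_std * jax.random.normal(k_x, true_weights.shape)
    y = w @ x + noise_std * jax.random.normal(k_n)
    return (x, y), (key, w, step_count + 1)

key = jax.random.key(0)
key, sub = jax.random.split(key)
w0 = jax.random.normal(sub, (10,))
# At step 0, weights are held fixed (no change boundary yet)
(x, y), (key, w1, count) = abrupt_change_step(key, w0, jnp.array(0),
                                              change_interval=5)
```

Using jnp.where rather than Python branching keeps the step JIT- and scan-compatible.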
feature_dim (property)
Return the dimension of observation vectors.
init(key)
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
Generate one time step.
Args:
    state: Current stream state
    idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
CyclicState
State for CyclicStream.
Attributes:
    key: JAX random key for generating randomness
    configurations: Pre-generated weight configurations
    step_count: Number of steps taken
CyclicStream(feature_dim, cycle_length=500, num_configurations=4, noise_std=0.1, feature_std=1.0)
Non-stationary stream that cycles between known weight configurations.
Weights cycle through a fixed set of configurations. Tests whether the learner can re-adapt quickly to previously seen targets.
Attributes:
    feature_dim: Dimension of observation vectors
    cycle_length: Number of steps per configuration before switching
    num_configurations: Number of weight configurations to cycle through
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features
Args:
    feature_dim: Dimension of feature vectors
    cycle_length: Steps spent in each configuration
    num_configurations: Number of configurations to cycle through
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values
Source code in src/alberta_framework/streams/synthetic.py
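The cycling rule described above amounts to an integer-division index into the pre-generated configurations. A sketch (active_configuration is a hypothetical helper, not framework API):

```python
import jax.numpy as jnp

def active_configuration(configurations, step_count, cycle_length):
    """Pick the active weight configuration for the current step
    (sketch of the documented cycling rule, not the actual source)."""
    num_configurations = configurations.shape[0]
    idx = (step_count // cycle_length) % num_configurations
    return configurations[idx]

# Four 3-dim configurations, each filled with its own index
configs = jnp.stack([jnp.full(3, float(i)) for i in range(4)])  # (4, 3)
w = active_configuration(configs, step_count=1200, cycle_length=500)
# 1200 // 500 = 2, so configuration 2 is active
```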
feature_dim (property)
Return the dimension of observation vectors.
init(key)
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with pre-generated configurations
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
Generate one time step.
Args:
    state: Current stream state
    idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
DynamicScaleShiftState
State for DynamicScaleShiftStream.
Attributes:
    key: JAX random key for generating randomness
    true_weights: Current true target weights
    current_scales: Current per-feature scaling factors
    step_count: Number of steps taken
DynamicScaleShiftStream(feature_dim, scale_change_interval=2000, weight_change_interval=1000, min_scale=0.01, max_scale=100.0, noise_std=0.1)
Non-stationary stream with abruptly changing feature scales.
Both target weights AND feature scales change at specified intervals. This tests whether OnlineNormalizer can track scale shifts faster than Autostep's internal v_i adaptation.
The target is computed from unscaled features to maintain consistent difficulty across scale changes (only the feature representation changes, not the underlying prediction task).
Attributes:
    feature_dim: Dimension of observation vectors
    scale_change_interval: Steps between scale changes
    weight_change_interval: Steps between weight changes
    min_scale: Minimum scale factor
    max_scale: Maximum scale factor
    noise_std: Standard deviation of observation noise
Args:
    feature_dim: Dimension of feature vectors
    scale_change_interval: Steps between abrupt scale changes
    weight_change_interval: Steps between abrupt weight changes
    min_scale: Minimum scale factor (log-uniform sampling)
    max_scale: Maximum scale factor (log-uniform sampling)
    noise_std: Std dev of target noise
Source code in src/alberta_framework/streams/synthetic.py
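Log-uniform scale sampling, as noted in the Args above, can be done by sampling uniformly in log-space and exponentiating. A sketch (sample_scales is a hypothetical helper, not framework API):

```python
import jax
import jax.numpy as jnp

def sample_scales(key, feature_dim, min_scale=0.01, max_scale=100.0):
    """Draw per-feature scales log-uniformly in [min_scale, max_scale]
    (illustrative sketch)."""
    log_s = jax.random.uniform(
        key, (feature_dim,),
        minval=jnp.log(min_scale), maxval=jnp.log(max_scale),
    )
    return jnp.exp(log_s)

scales = sample_scales(jax.random.key(0), feature_dim=8)
```

Log-uniform sampling spreads scales evenly across orders of magnitude, so tiny and huge scales are equally likely.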
feature_dim (property)
Return the dimension of observation vectors.
init(key)
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with random weights and scales
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
Generate one time step.
Args:
    state: Current stream state
    idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
PeriodicChangeState
State for PeriodicChangeStream.
Attributes:
    key: JAX random key for generating randomness
    base_weights: Base target weights (center of oscillation)
    phases: Per-weight phase offsets
    step_count: Number of steps taken
PeriodicChangeStream(feature_dim, period=1000, amplitude=1.0, noise_std=0.1, feature_std=1.0)
Non-stationary stream where target weights oscillate sinusoidally.
Target weights follow: w(t) = base + amplitude * sin(2π * t / period + phase), where each weight has a random phase offset for diversity.
This tests the learner's ability to track predictable periodic changes, which is qualitatively different from random drift or abrupt changes.
Attributes:
    feature_dim: Dimension of observation vectors
    period: Number of steps for one complete oscillation
    amplitude: Magnitude of weight oscillation
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features
Args:
    feature_dim: Dimension of feature vectors
    period: Steps for one complete oscillation cycle
    amplitude: Magnitude of weight oscillations around base
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values
Source code in src/alberta_framework/streams/synthetic.py
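The oscillation formula above, expressed directly in code (periodic_weights is a hypothetical helper for illustration):

```python
import jax.numpy as jnp

def periodic_weights(base, phases, t, period=1000, amplitude=1.0):
    """w(t) = base + amplitude * sin(2*pi*t/period + phase)."""
    return base + amplitude * jnp.sin(2.0 * jnp.pi * t / period + phases)

base = jnp.zeros(4)
phases = jnp.zeros(4)
w_start = periodic_weights(base, phases, t=0)
w_quarter = periodic_weights(base, phases, t=250)  # quarter period: sin = 1
```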
feature_dim (property)
Return the dimension of observation vectors.
init(key)
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with random base weights and phases
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
Generate one time step.
Args:
    state: Current stream state
    idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
RandomWalkState
State for RandomWalkStream.
Attributes:
    key: JAX random key for generating randomness
    true_weights: Current true target weights
RandomWalkStream(feature_dim, drift_rate=0.001, noise_std=0.1, feature_std=1.0)
Non-stationary stream where target weights drift via random walk.
The true target function is linear: y* = w_true @ x + noise
where w_true evolves via random walk at each time step.
This tests the learner's ability to continuously track a moving target.
Attributes:
    feature_dim: Dimension of observation vectors
    drift_rate: Standard deviation of weight drift per step
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features
Args:
    feature_dim: Dimension of the feature/observation vectors
    drift_rate: Std dev of weight changes per step (controls non-stationarity)
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values
Source code in src/alberta_framework/streams/synthetic.py
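The drift-then-emit dynamics described above can be sketched as follows (random_walk_step is an illustrative reimplementation, not the framework's actual source):

```python
import jax
import jax.numpy as jnp

def random_walk_step(key, true_weights, drift_rate=0.001,
                     noise_std=0.1, feature_std=1.0):
    """One step of the documented dynamics: the target weights take a
    Gaussian random-walk step, then a noisy linear target is emitted
    (illustrative sketch)."""
    key, k_d, k_x, k_n = jax.random.split(key, 4)
    w = true_weights + drift_rate * jax.random.normal(k_d, true_weights.shape)
    x = feature_std * jax.random.normal(k_x, true_weights.shape)
    y = w @ x + noise_std * jax.random.normal(k_n)
    return (x, y), (key, w)

key = jax.random.key(7)
w = jnp.zeros(5)
(x, y), (key, w) = random_walk_step(key, w)
```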
feature_dim (property)
Return the dimension of observation vectors.
init(key)
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with random weights
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
Generate one time step.
Args:
    state: Current stream state
    idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
ScaleDriftState
State for ScaleDriftStream.
Attributes:
    key: JAX random key for generating randomness
    true_weights: Current true target weights
    log_scales: Current log-scale factors (random walk on log-scale)
    step_count: Number of steps taken
ScaleDriftStream(feature_dim, weight_drift_rate=0.001, scale_drift_rate=0.01, min_log_scale=-4.0, max_log_scale=4.0, noise_std=0.1)
Non-stationary stream where feature scales drift via random walk.
Both target weights and feature scales drift continuously. Weights drift in linear space while scales drift in log-space (bounded random walk). This tests continuous scale tracking where OnlineNormalizer's EMA may adapt differently than Autostep's v_i.
The target is computed from unscaled features to maintain consistent difficulty across scale changes.
Attributes:
    feature_dim: Dimension of observation vectors
    weight_drift_rate: Std dev of weight drift per step
    scale_drift_rate: Std dev of log-scale drift per step
    min_log_scale: Minimum log-scale (clips random walk)
    max_log_scale: Maximum log-scale (clips random walk)
    noise_std: Standard deviation of observation noise
Args:
    feature_dim: Dimension of feature vectors
    weight_drift_rate: Std dev of weight drift per step
    scale_drift_rate: Std dev of log-scale drift per step
    min_log_scale: Minimum log-scale (clips drift)
    max_log_scale: Maximum log-scale (clips drift)
    noise_std: Std dev of target noise
Source code in src/alberta_framework/streams/synthetic.py
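The bounded log-space random walk on scales can be sketched as follows (drift_log_scales is a hypothetical helper, not framework API):

```python
import jax
import jax.numpy as jnp

def drift_log_scales(key, log_scales, scale_drift_rate=0.01,
                     min_log_scale=-4.0, max_log_scale=4.0):
    """Bounded random walk on log-scales, as described above
    (illustrative sketch): drift in log-space, then clip."""
    step = scale_drift_rate * jax.random.normal(key, log_scales.shape)
    new_log = jnp.clip(log_scales + step, min_log_scale, max_log_scale)
    return new_log, jnp.exp(new_log)  # exp gives the multiplicative scales

log_s = jnp.zeros(6)
log_s, scales = drift_log_scales(jax.random.key(1), log_s)
```

Drifting in log-space keeps scales strictly positive and makes multiplicative changes symmetric (halving and doubling are equally likely steps).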
feature_dim (property)
Return the dimension of observation vectors.
init(key)
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with random weights and unit scales
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
Generate one time step.
Args:
    state: Current stream state
    idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
ScaledStreamState
State for ScaledStreamWrapper.
Attributes:
    inner_state: State of the wrapped stream
ScaledStreamWrapper(inner_stream, feature_scales)
Wrapper that applies per-feature scaling to any stream's observations.
This wrapper multiplies each feature of the observation by a corresponding scale factor. Useful for testing how learners handle features at different scales, which is important for understanding normalization benefits.
Examples:
stream = ScaledStreamWrapper(
AbruptChangeStream(feature_dim=10, change_interval=1000),
feature_scales=jnp.array([0.001, 0.01, 0.1, 1.0, 10.0,
100.0, 1000.0, 0.001, 0.01, 0.1])
)
Attributes:
    inner_stream: The wrapped stream instance
    feature_scales: Per-feature scale factors (must match feature_dim)
Args:
    inner_stream: Stream to wrap (must implement ScanStream protocol)
    feature_scales: Array of scale factors, one per feature. Must have shape (feature_dim,) matching the inner stream's feature_dim.
Raises:
    ValueError: If feature_scales length doesn't match inner stream's feature_dim
Source code in src/alberta_framework/streams/synthetic.py
feature_dim (property)
Return the dimension of observation vectors.
inner_stream (property)
Return the wrapped stream.
feature_scales (property)
Return the per-feature scale factors.
init(key)
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state wrapping the inner stream's state
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
Generate one time step with scaled observations.
Args:
    state: Current stream state
    idx: Current step index
Returns: Tuple of (timestep with scaled observation, new_state)
Source code in src/alberta_framework/streams/synthetic.py
SuttonExperiment1State
State for SuttonExperiment1Stream.
Attributes:
    key: JAX random key for generating randomness
    signs: Signs (+1/-1) for the relevant inputs
    step_count: Number of steps taken
SuttonExperiment1Stream(num_relevant=5, num_irrelevant=15, change_interval=20)
Non-stationary stream replicating Experiment 1 from Sutton 1992.
This stream implements the exact task from Sutton's IDBD paper:
    - 20 real-valued inputs drawn from N(0, 1)
    - Only first 5 inputs are relevant (weights are ±1)
    - Last 15 inputs are irrelevant (weights are 0)
    - Every change_interval steps, one of the 5 relevant signs is flipped
Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta"
Attributes:
    num_relevant: Number of relevant inputs (default 5)
    num_irrelevant: Number of irrelevant inputs (default 15)
    change_interval: Steps between sign changes (default 20)
Args:
    num_relevant: Number of relevant inputs with ±1 weights
    num_irrelevant: Number of irrelevant inputs with 0 weights
    change_interval: Number of steps between sign flips
Source code in src/alberta_framework/streams/synthetic.py
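The task described above can be sketched directly from the paper's recipe (sutton_exp1_step is an illustrative reimplementation, not the framework's actual source):

```python
import jax
import jax.numpy as jnp

def sutton_exp1_step(key, signs, step_count, change_interval=20,
                     num_relevant=5, num_irrelevant=15):
    """One step of the Sutton (1992) Experiment 1 task as described
    above (illustrative sketch): occasionally flip one relevant sign,
    then emit N(0,1) inputs and a target from the relevant inputs."""
    key, k_flip, k_x = jax.random.split(key, 3)
    at_boundary = (step_count > 0) & (step_count % change_interval == 0)
    flip_idx = jax.random.randint(k_flip, (), 0, num_relevant)
    flipped = signs.at[flip_idx].multiply(-1)
    signs = jnp.where(at_boundary, flipped, signs)
    x = jax.random.normal(k_x, (num_relevant + num_irrelevant,))
    y = signs @ x[:num_relevant]  # irrelevant inputs have zero weight
    return (x, y), (key, signs, step_count + 1)

key = jax.random.key(0)
signs = jnp.ones(5)
# Step 0 is not a change boundary, so all signs stay +1
(x, y), (key, signs, count) = sutton_exp1_step(key, signs, jnp.array(0))
```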
feature_dim (property)
Return the dimension of observation vectors.
init(key)
Initialize stream state.
Args: key: JAX random key
Returns: Initial stream state with all +1 signs
Source code in src/alberta_framework/streams/synthetic.py
step(state, idx)
Generate one time step.
At each step:
    1. If at a change interval (and not step 0), flip one random sign
    2. Generate random inputs from N(0, 1)
    3. Compute target as sum of relevant inputs weighted by signs
Args:
    state: Current stream state
    idx: Current step index (unused)
Returns: Tuple of (timestep, new_state)
Source code in src/alberta_framework/streams/synthetic.py
GymnasiumStream(env, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0)
Experience stream from a Gymnasium environment using a Python loop.
This class maintains iterator-based access for online learning scenarios where you need to interact with the environment in real-time.
For batch learning, use collect_trajectory() followed by learn_from_trajectory().
Attributes:
    mode: Prediction mode (REWARD, NEXT_STATE, VALUE)
    gamma: Discount factor for VALUE mode
    include_action_in_features: Whether to include action in features
    episode_count: Number of completed episodes
Args:
    env: Gymnasium environment instance
    mode: What to predict (REWARD, NEXT_STATE, VALUE)
    policy: Action selection function. If None, uses random policy
    gamma: Discount factor for VALUE mode
    include_action_in_features: If True, features = concat(obs, action). If False, features = obs only
    seed: Random seed for environment resets and random policy
Source code in src/alberta_framework/streams/gymnasium.py
feature_dim (property)
Return the dimension of feature vectors.
target_dim (property)
Return the dimension of target vectors.
episode_count (property)
Return the number of completed episodes.
step_count (property)
Return the total number of steps taken.
mode (property)
Return the prediction mode.
set_value_estimator(estimator)
Set the value estimator for proper TD learning in VALUE mode.
PredictionMode
Bases: Enum
Mode for what the stream predicts.
    REWARD: Predict immediate reward from (state, action)
    NEXT_STATE: Predict next state from (state, action)
    VALUE: Predict cumulative return (TD learning with bootstrap)
TDStream(env, policy=None, gamma=0.99, include_action_in_features=False, seed=0)
Experience stream for proper TD learning with value function bootstrap.
This stream integrates with a learner to use its predictions for bootstrapping in TD targets.
Usage:
    stream = TDStream(env)
    learner = LinearLearner(optimizer=IDBD())
    state = learner.init(stream.feature_dim)
    for step, timestep in enumerate(stream):
        result = learner.update(state, timestep.observation, timestep.target)
        state = result.state
        stream.update_value_function(lambda x: learner.predict(state, x))
Args:
    env: Gymnasium environment instance
    policy: Action selection function. If None, uses random policy
    gamma: Discount factor
    include_action_in_features: If True, learn Q(s,a). If False, learn V(s)
    seed: Random seed
Source code in src/alberta_framework/streams/gymnasium.py
make_scale_range(feature_dim, min_scale=0.001, max_scale=1000.0, log_spaced=True)
Create a per-feature scale array spanning a range.
Utility function to generate scale factors for ScaledStreamWrapper.
Args:
    feature_dim: Number of features
    min_scale: Minimum scale factor
    max_scale: Maximum scale factor
    log_spaced: If True, scales are logarithmically spaced (default). If False, scales are linearly spaced.
Returns: Array of shape (feature_dim,) with scale factors
Examples:
scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
Source code in src/alberta_framework/streams/synthetic.py
collect_trajectory(env, policy, num_steps, mode=PredictionMode.REWARD, include_action_in_features=True, seed=0)
Collect a trajectory from a Gymnasium environment.
This uses a Python loop to interact with the environment and collects observations and targets into JAX arrays that can be used with scan-based learning.
Args:
    env: Gymnasium environment instance
    policy: Action selection function. If None, uses random policy
    num_steps: Number of steps to collect
    mode: What to predict (REWARD, NEXT_STATE, VALUE)
    include_action_in_features: If True, features = concat(obs, action)
    seed: Random seed for environment resets and random policy
Returns: Tuple of (observations, targets) as JAX arrays with shape (num_steps, feature_dim) and (num_steps, target_dim)
Source code in src/alberta_framework/streams/gymnasium.py
learn_from_trajectory(learner, observations, targets, learner_state=None)
Learn from a pre-collected trajectory using jax.lax.scan.
This is a JIT-compiled learning function that processes a trajectory collected from a Gymnasium environment.
Args:
    learner: The learner to train
    observations: Array of observations with shape (num_steps, feature_dim)
    targets: Array of targets with shape (num_steps, target_dim)
    learner_state: Initial state (if None, will be initialized)
Returns: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 3) with columns [squared_error, error, mean_step_size]
Source code in src/alberta_framework/streams/gymnasium.py
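The scan pattern this function uses can be sketched with a plain LMS learner standing in for the framework's learner (an assumption for illustration); a constant step_size stands in for mean_step_size in the metrics columns:

```python
import jax
import jax.numpy as jnp

def lms_update(w, obs, target, step_size=0.01):
    """Plain LMS update standing in for the framework's learner
    (an assumption for illustration)."""
    error = target - w @ obs
    w = w + step_size * error * obs
    metrics = jnp.array([error**2, error, step_size])
    return w, metrics

obs = jax.random.normal(jax.random.key(0), (200, 5))
targets = obs @ jnp.arange(5.0)   # noiseless linear targets

def body(w, xy):
    x, y = xy
    return lms_update(w, x, y)

w_final, metrics = jax.lax.scan(body, jnp.zeros(5), (obs, targets))
# metrics has shape (200, 3): one [squared_error, error, step_size] row per step
```

Because the whole trajectory is scanned in one compiled call, per-step metrics come back stacked along the leading axis, matching the documented (num_steps, 3) layout.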
learn_from_trajectory_normalized(learner, observations, targets, learner_state=None)
Learn from a pre-collected trajectory with normalization using jax.lax.scan.
Args:
    learner: The normalized learner to train
    observations: Array of observations with shape (num_steps, feature_dim)
    targets: Array of targets with shape (num_steps, target_dim)
    learner_state: Initial state (if None, will be initialized)
Returns: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
Source code in src/alberta_framework/streams/gymnasium.py
make_epsilon_greedy_policy(base_policy, env, epsilon=0.1, seed=0)
Wrap a policy with epsilon-greedy exploration.
Args:
    base_policy: The greedy policy to wrap
    env: Gymnasium environment (for random action sampling)
    epsilon: Probability of taking a random action
    seed: Random seed
Returns: Epsilon-greedy policy
Source code in src/alberta_framework/streams/gymnasium.py
make_gymnasium_stream(env_id, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0, **env_kwargs)
Factory function to create a GymnasiumStream from an environment ID.
Args:
    env_id: Gymnasium environment ID (e.g., "CartPole-v1")
    mode: What to predict (REWARD, NEXT_STATE, VALUE)
    policy: Action selection function. If None, uses random policy
    gamma: Discount factor for VALUE mode
    include_action_in_features: If True, features = concat(obs, action)
    seed: Random seed
    **env_kwargs: Additional arguments passed to gymnasium.make()
Returns: GymnasiumStream wrapping the environment
Source code in src/alberta_framework/streams/gymnasium.py
make_random_policy(env, seed=0)
Create a random action policy for an environment.
Args:
    env: Gymnasium environment
    seed: Random seed
Returns: A callable that takes an observation and returns a random action