streams

Experience streams for continual learning.

ScanStream

Bases: Protocol[StateT]

Protocol for JAX scan-compatible experience streams.

Streams generate temporally-uniform experience for continual learning. Unlike iterator-based streams, ScanStream uses pure functions that can be compiled with JAX's JIT and used with jax.lax.scan.

The stream should be non-stationary so that it exercises continual learning capabilities: the underlying target function changes over time.

Type Parameters:
    StateT: The state type maintained by this stream

Examples:

stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
key = jax.random.key(42)
state = stream.init(key)
timestep, new_state = stream.step(state, jnp.array(0))
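
The same init/step pair composes directly with jax.lax.scan for compiled rollouts. Continuing the example above with a hypothetical driver (a sketch, not framework code):

def scan_fn(state, idx):
    # step is pure, so scan traces it once and compiles the whole rollout
    timestep, new_state = stream.step(state, idx)
    return new_state, timestep.target

final_state, targets = jax.lax.scan(scan_fn, state, jnp.arange(10_000))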

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args:
    key: JAX random key for initialization

Returns:
    Initial stream state

Source code in src/alberta_framework/streams/base.py
def init(self, key: Array) -> StateT:
    """Initialize stream state.

    Args:
        key: JAX random key for initialization

    Returns:
        Initial stream state
    """
    ...

step(state, idx)

Generate one time step. Must be JIT-compatible.

This is a pure function that takes the current state and step index, and returns a TimeStep along with the updated state. The step index can be used for time-dependent behavior but is often ignored.

Args:
    state: Current stream state
    idx: Current step index (can be ignored for most streams)

Returns:
    Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/base.py
def step(self, state: StateT, idx: Array) -> tuple[TimeStep, StateT]:
    """Generate one time step. Must be JIT-compatible.

    This is a pure function that takes the current state and step index,
    and returns a TimeStep along with the updated state. The step index
    can be used for time-dependent behavior but is often ignored.

    Args:
        state: Current stream state
        idx: Current step index (can be ignored for most streams)

    Returns:
        Tuple of (timestep, new_state)
    """
    ...

AbruptChangeState

State for AbruptChangeStream.

Attributes:
    key: JAX random key for generating randomness
    true_weights: Current true target weights
    step_count: Number of steps taken

AbruptChangeStream(feature_dim, change_interval=1000, noise_std=0.1, feature_std=1.0)

Non-stationary stream with sudden target weight changes.

Target weights remain constant for a period, then abruptly change to new random values. Tests the learner's ability to detect and rapidly adapt to distribution shifts.
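
Example (a sketch; assumes jax and jax.numpy as jnp are imported):

stream = AbruptChangeStream(feature_dim=10, change_interval=1000)
state = stream.init(jax.random.key(0))
timestep, state = stream.step(state, jnp.array(0))
# Weights are redrawn whenever step_count % change_interval == 0, so a
# tracking learner sees an error spike at each 1000-step boundary.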

Attributes:
    feature_dim: Dimension of observation vectors
    change_interval: Number of steps between weight changes
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features

Args:
    feature_dim: Dimension of feature vectors
    change_interval: Steps between abrupt weight changes
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    change_interval: int = 1000,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the abrupt change stream.

    Args:
        feature_dim: Dimension of feature vectors
        change_interval: Steps between abrupt weight changes
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._change_interval = change_interval
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args:
    key: JAX random key

Returns:
    Initial stream state

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> AbruptChangeState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state
    """
    key, subkey = jr.split(key)
    weights = jr.normal(subkey, (self._feature_dim,), dtype=jnp.float32)
    return AbruptChangeState(
        key=key,
        true_weights=weights,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args:
    state: Current stream state
    idx: Current step index (unused)

Returns:
    Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: AbruptChangeState, idx: Array) -> tuple[TimeStep, AbruptChangeState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_weights, key_x, key_noise = jr.split(state.key, 4)

    # Determine if we should change weights
    should_change = state.step_count % self._change_interval == 0

    # Generate new weights (always generated but only used if should_change)
    new_random_weights = jr.normal(key_weights, (self._feature_dim,), dtype=jnp.float32)

    # Use jnp.where to conditionally update weights (JIT-compatible)
    new_weights = jnp.where(should_change, new_random_weights, state.true_weights)

    # Generate observation
    x = self._feature_std * jr.normal(key_x, (self._feature_dim,), dtype=jnp.float32)

    # Compute target
    noise = self._noise_std * jr.normal(key_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = AbruptChangeState(
        key=key,
        true_weights=new_weights,
        step_count=state.step_count + 1,
    )

    return timestep, new_state

CyclicState

State for CyclicStream.

Attributes:
    key: JAX random key for generating randomness
    configurations: Pre-generated weight configurations
    step_count: Number of steps taken

CyclicStream(feature_dim, cycle_length=500, num_configurations=4, noise_std=0.1, feature_std=1.0)

Non-stationary stream that cycles between known weight configurations.

Weights cycle through a fixed set of configurations. Tests whether the learner can re-adapt quickly to previously seen targets.
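
The active configuration follows (step_count // cycle_length) % num_configurations, as in this worked sketch:

cycle_length, num_configurations = 500, 4
for t in (0, 499, 500, 1999, 2000):
    print(t, (t // cycle_length) % num_configurations)
# prints 0 0, 499 0, 500 1, 1999 3, 2000 0 (the cycle wraps around)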

Attributes:
    feature_dim: Dimension of observation vectors
    cycle_length: Number of steps per configuration before switching
    num_configurations: Number of weight configurations to cycle through
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features

Args:
    feature_dim: Dimension of feature vectors
    cycle_length: Steps spent in each configuration
    num_configurations: Number of configurations to cycle through
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    cycle_length: int = 500,
    num_configurations: int = 4,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the cyclic target stream.

    Args:
        feature_dim: Dimension of feature vectors
        cycle_length: Steps spent in each configuration
        num_configurations: Number of configurations to cycle through
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._cycle_length = cycle_length
    self._num_configurations = num_configurations
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args:
    key: JAX random key

Returns:
    Initial stream state with pre-generated configurations

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> CyclicState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with pre-generated configurations
    """
    key, key_configs = jr.split(key)
    configurations = jr.normal(
        key_configs,
        (self._num_configurations, self._feature_dim),
        dtype=jnp.float32,
    )
    return CyclicState(
        key=key,
        configurations=configurations,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args:
    state: Current stream state
    idx: Current step index (unused)

Returns:
    Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: CyclicState, idx: Array) -> tuple[TimeStep, CyclicState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_x, key_noise = jr.split(state.key, 3)

    # Get current configuration index
    config_idx = (state.step_count // self._cycle_length) % self._num_configurations
    true_weights = state.configurations[config_idx]

    # Generate observation
    x = self._feature_std * jr.normal(key_x, (self._feature_dim,), dtype=jnp.float32)

    # Compute target
    noise = self._noise_std * jr.normal(key_noise, (), dtype=jnp.float32)
    target = jnp.dot(true_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = CyclicState(
        key=key,
        configurations=state.configurations,
        step_count=state.step_count + 1,
    )

    return timestep, new_state

DynamicScaleShiftState

State for DynamicScaleShiftStream.

Attributes:
    key: JAX random key for generating randomness
    true_weights: Current true target weights
    current_scales: Current per-feature scaling factors
    step_count: Number of steps taken

DynamicScaleShiftStream(feature_dim, scale_change_interval=2000, weight_change_interval=1000, min_scale=0.01, max_scale=100.0, noise_std=0.1)

Non-stationary stream with abruptly changing feature scales.

Both target weights AND feature scales change at specified intervals. This tests whether OnlineNormalizer can track scale shifts faster than Autostep's internal v_i adaptation.

The target is computed from unscaled features to maintain consistent difficulty across scale changes (only the feature representation changes, not the underlying prediction task).
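
Example (a sketch; assumes jax and jax.numpy as jnp are imported):

stream = DynamicScaleShiftStream(feature_dim=10)
state = stream.init(jax.random.key(0))
timestep, state = stream.step(state, jnp.array(0))
# timestep.observation is raw_x * scales, so features can span roughly
# four orders of magnitude; timestep.target uses raw_x and stays O(1).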

Attributes:
    feature_dim: Dimension of observation vectors
    scale_change_interval: Steps between scale changes
    weight_change_interval: Steps between weight changes
    min_scale: Minimum scale factor
    max_scale: Maximum scale factor
    noise_std: Standard deviation of observation noise

Args:
    feature_dim: Dimension of feature vectors
    scale_change_interval: Steps between abrupt scale changes
    weight_change_interval: Steps between abrupt weight changes
    min_scale: Minimum scale factor (log-uniform sampling)
    max_scale: Maximum scale factor (log-uniform sampling)
    noise_std: Std dev of target noise

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    scale_change_interval: int = 2000,
    weight_change_interval: int = 1000,
    min_scale: float = 0.01,
    max_scale: float = 100.0,
    noise_std: float = 0.1,
):
    """Initialize the dynamic scale shift stream.

    Args:
        feature_dim: Dimension of feature vectors
        scale_change_interval: Steps between abrupt scale changes
        weight_change_interval: Steps between abrupt weight changes
        min_scale: Minimum scale factor (log-uniform sampling)
        max_scale: Maximum scale factor (log-uniform sampling)
        noise_std: Std dev of target noise
    """
    self._feature_dim = feature_dim
    self._scale_change_interval = scale_change_interval
    self._weight_change_interval = weight_change_interval
    self._min_scale = min_scale
    self._max_scale = max_scale
    self._noise_std = noise_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args:
    key: JAX random key

Returns:
    Initial stream state with random weights and scales

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> DynamicScaleShiftState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random weights and scales
    """
    key, k_weights, k_scales = jr.split(key, 3)
    weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
    # Initial scales: log-uniform between min and max
    log_scales = jr.uniform(
        k_scales,
        (self._feature_dim,),
        minval=jnp.log(self._min_scale),
        maxval=jnp.log(self._max_scale),
    )
    scales = jnp.exp(log_scales).astype(jnp.float32)
    return DynamicScaleShiftState(
        key=key,
        true_weights=weights,
        current_scales=scales,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args:
    state: Current stream state
    idx: Current step index (unused)

Returns:
    Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(
    self, state: DynamicScaleShiftState, idx: Array
) -> tuple[TimeStep, DynamicScaleShiftState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, k_weights, k_scales, k_x, k_noise = jr.split(state.key, 5)

    # Check if scales should change
    should_change_scales = state.step_count % self._scale_change_interval == 0
    new_log_scales = jr.uniform(
        k_scales,
        (self._feature_dim,),
        minval=jnp.log(self._min_scale),
        maxval=jnp.log(self._max_scale),
    )
    new_random_scales = jnp.exp(new_log_scales).astype(jnp.float32)
    new_scales = jnp.where(should_change_scales, new_random_scales, state.current_scales)

    # Check if weights should change
    should_change_weights = state.step_count % self._weight_change_interval == 0
    new_random_weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
    new_weights = jnp.where(should_change_weights, new_random_weights, state.true_weights)

    # Generate raw features (unscaled)
    raw_x = jr.normal(k_x, (self._feature_dim,), dtype=jnp.float32)

    # Apply scaling to observation
    x = raw_x * new_scales

    # Target from true weights using RAW features (for consistent difficulty)
    noise = self._noise_std * jr.normal(k_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, raw_x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = DynamicScaleShiftState(
        key=key,
        true_weights=new_weights,
        current_scales=new_scales,
        step_count=state.step_count + 1,
    )
    return timestep, new_state

PeriodicChangeState

State for PeriodicChangeStream.

Attributes:
    key: JAX random key for generating randomness
    base_weights: Base target weights (center of oscillation)
    phases: Per-weight phase offsets
    step_count: Number of steps taken

PeriodicChangeStream(feature_dim, period=1000, amplitude=1.0, noise_std=0.1, feature_std=1.0)

Non-stationary stream where target weights oscillate sinusoidally.

Target weights follow w(t) = base + amplitude * sin(2π * t / period + phase), where each weight has a random phase offset for diversity.

This tests the learner's ability to track predictable periodic changes, which is qualitatively different from random drift or abrupt changes.
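
The weight trajectory for one feature can be reproduced straight from the formula (a sketch; assumes jax.numpy is imported as jnp):

base, amplitude, period, phase = 0.5, 1.0, 1000, 0.0
t = jnp.arange(2000, dtype=jnp.float32)
w_t = base + amplitude * jnp.sin(2.0 * jnp.pi * t / period + phase)
# w_t oscillates within [-0.5, 1.5], completing two full cycles in 2000 steps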

Attributes:
    feature_dim: Dimension of observation vectors
    period: Number of steps for one complete oscillation
    amplitude: Magnitude of weight oscillation
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features

Args:
    feature_dim: Dimension of feature vectors
    period: Steps for one complete oscillation cycle
    amplitude: Magnitude of weight oscillations around base
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    period: int = 1000,
    amplitude: float = 1.0,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the periodic change stream.

    Args:
        feature_dim: Dimension of feature vectors
        period: Steps for one complete oscillation cycle
        amplitude: Magnitude of weight oscillations around base
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._period = period
    self._amplitude = amplitude
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args:
    key: JAX random key

Returns:
    Initial stream state with random base weights and phases

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> PeriodicChangeState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random base weights and phases
    """
    key, key_weights, key_phases = jr.split(key, 3)
    base_weights = jr.normal(key_weights, (self._feature_dim,), dtype=jnp.float32)
    # Random phases in [0, 2π) for each weight
    phases = jr.uniform(key_phases, (self._feature_dim,), minval=0.0, maxval=2.0 * jnp.pi)
    return PeriodicChangeState(
        key=key,
        base_weights=base_weights,
        phases=phases,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args:
    state: Current stream state
    idx: Current step index (unused)

Returns:
    Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: PeriodicChangeState, idx: Array) -> tuple[TimeStep, PeriodicChangeState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_x, key_noise = jr.split(state.key, 3)

    # Compute oscillating weights: w(t) = base + amplitude * sin(2π * t / period + phase)
    t = state.step_count.astype(jnp.float32)
    oscillation = self._amplitude * jnp.sin(2.0 * jnp.pi * t / self._period + state.phases)
    true_weights = state.base_weights + oscillation

    # Generate observation
    x = self._feature_std * jr.normal(key_x, (self._feature_dim,), dtype=jnp.float32)

    # Compute target
    noise = self._noise_std * jr.normal(key_noise, (), dtype=jnp.float32)
    target = jnp.dot(true_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = PeriodicChangeState(
        key=key,
        base_weights=state.base_weights,
        phases=state.phases,
        step_count=state.step_count + 1,
    )

    return timestep, new_state

RandomWalkState

State for RandomWalkStream.

Attributes:
    key: JAX random key for generating randomness
    true_weights: Current true target weights

RandomWalkStream(feature_dim, drift_rate=0.001, noise_std=0.1, feature_std=1.0)

Non-stationary stream where target weights drift via random walk.

The true target function is linear: y* = w_true @ x + noise, where w_true evolves via a random walk at each time step.

This tests the learner's ability to continuously track a moving target.
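
Example (a sketch; assumes jax and jax.numpy as jnp are imported):

stream = RandomWalkStream(feature_dim=10, drift_rate=0.001)
state = stream.init(jax.random.key(42))
timestep, state = stream.step(state, jnp.array(0))
# Each step adds N(0, drift_rate^2) noise to every weight, so after n steps
# the weights drift by roughly drift_rate * sqrt(n) per dimension.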

Attributes:
    feature_dim: Dimension of observation vectors
    drift_rate: Standard deviation of weight drift per step
    noise_std: Standard deviation of observation noise
    feature_std: Standard deviation of features

Args:
    feature_dim: Dimension of the feature/observation vectors
    drift_rate: Std dev of weight changes per step (controls non-stationarity)
    noise_std: Std dev of target noise
    feature_std: Std dev of feature values

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    drift_rate: float = 0.001,
    noise_std: float = 0.1,
    feature_std: float = 1.0,
):
    """Initialize the random walk target stream.

    Args:
        feature_dim: Dimension of the feature/observation vectors
        drift_rate: Std dev of weight changes per step (controls non-stationarity)
        noise_std: Std dev of target noise
        feature_std: Std dev of feature values
    """
    self._feature_dim = feature_dim
    self._drift_rate = drift_rate
    self._noise_std = noise_std
    self._feature_std = feature_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args:
    key: JAX random key

Returns:
    Initial stream state with random weights

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> RandomWalkState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random weights
    """
    key, subkey = jr.split(key)
    weights = jr.normal(subkey, (self._feature_dim,), dtype=jnp.float32)
    return RandomWalkState(key=key, true_weights=weights)

step(state, idx)

Generate one time step.

Args:
    state: Current stream state
    idx: Current step index (unused)

Returns:
    Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: RandomWalkState, idx: Array) -> tuple[TimeStep, RandomWalkState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, k_drift, k_x, k_noise = jr.split(state.key, 4)

    # Drift weights
    drift = jr.normal(k_drift, state.true_weights.shape, dtype=jnp.float32)
    new_weights = state.true_weights + self._drift_rate * drift

    # Generate observation and target
    x = self._feature_std * jr.normal(k_x, (self._feature_dim,), dtype=jnp.float32)
    noise = self._noise_std * jr.normal(k_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = RandomWalkState(key=key, true_weights=new_weights)

    return timestep, new_state

ScaleDriftState

State for ScaleDriftStream.

Attributes:
    key: JAX random key for generating randomness
    true_weights: Current true target weights
    log_scales: Current log-scale factors (random walk on log-scale)
    step_count: Number of steps taken

ScaleDriftStream(feature_dim, weight_drift_rate=0.001, scale_drift_rate=0.01, min_log_scale=-4.0, max_log_scale=4.0, noise_std=0.1)

Non-stationary stream where feature scales drift via random walk.

Both target weights and feature scales drift continuously. Weights drift in linear space while scales drift in log-space (bounded random walk). This tests continuous scale tracking where OnlineNormalizer's EMA may adapt differently than Autostep's v_i.

The target is computed from unscaled features to maintain consistent difficulty across scale changes.
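
A sketch of the bounded log-space random walk that drives the scales, mirroring the step() logic below (assumes jax.numpy as jnp and jax.random as jr are imported):

log_scales = jnp.zeros(10)
drift = 0.01 * jr.normal(jr.key(0), (10,))
log_scales = jnp.clip(log_scales + drift, -4.0, 4.0)
scales = jnp.exp(log_scales)  # per-feature multipliers in [exp(-4), exp(4)]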

Attributes:
    feature_dim: Dimension of observation vectors
    weight_drift_rate: Std dev of weight drift per step
    scale_drift_rate: Std dev of log-scale drift per step
    min_log_scale: Minimum log-scale (clips random walk)
    max_log_scale: Maximum log-scale (clips random walk)
    noise_std: Standard deviation of observation noise

Args:
    feature_dim: Dimension of feature vectors
    weight_drift_rate: Std dev of weight drift per step
    scale_drift_rate: Std dev of log-scale drift per step
    min_log_scale: Minimum log-scale (clips drift)
    max_log_scale: Maximum log-scale (clips drift)
    noise_std: Std dev of target noise

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    feature_dim: int,
    weight_drift_rate: float = 0.001,
    scale_drift_rate: float = 0.01,
    min_log_scale: float = -4.0,  # exp(-4) ~ 0.018
    max_log_scale: float = 4.0,  # exp(4) ~ 54.6
    noise_std: float = 0.1,
):
    """Initialize the scale drift stream.

    Args:
        feature_dim: Dimension of feature vectors
        weight_drift_rate: Std dev of weight drift per step
        scale_drift_rate: Std dev of log-scale drift per step
        min_log_scale: Minimum log-scale (clips drift)
        max_log_scale: Maximum log-scale (clips drift)
        noise_std: Std dev of target noise
    """
    self._feature_dim = feature_dim
    self._weight_drift_rate = weight_drift_rate
    self._scale_drift_rate = scale_drift_rate
    self._min_log_scale = min_log_scale
    self._max_log_scale = max_log_scale
    self._noise_std = noise_std

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args:
    key: JAX random key

Returns:
    Initial stream state with random weights and unit scales

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> ScaleDriftState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with random weights and unit scales
    """
    key, k_weights = jr.split(key)
    weights = jr.normal(k_weights, (self._feature_dim,), dtype=jnp.float32)
    # Initial log-scales at 0 (scale = 1)
    log_scales = jnp.zeros(self._feature_dim, dtype=jnp.float32)
    return ScaleDriftState(
        key=key,
        true_weights=weights,
        log_scales=log_scales,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

Args:
    state: Current stream state
    idx: Current step index (unused)

Returns:
    Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: ScaleDriftState, idx: Array) -> tuple[TimeStep, ScaleDriftState]:
    """Generate one time step.

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, k_w_drift, k_s_drift, k_x, k_noise = jr.split(state.key, 5)

    # Drift target weights
    weight_drift = self._weight_drift_rate * jr.normal(
        k_w_drift, (self._feature_dim,), dtype=jnp.float32
    )
    new_weights = state.true_weights + weight_drift

    # Drift log-scales (bounded random walk)
    scale_drift = self._scale_drift_rate * jr.normal(
        k_s_drift, (self._feature_dim,), dtype=jnp.float32
    )
    new_log_scales = state.log_scales + scale_drift
    new_log_scales = jnp.clip(new_log_scales, self._min_log_scale, self._max_log_scale)

    # Generate raw features (unscaled)
    raw_x = jr.normal(k_x, (self._feature_dim,), dtype=jnp.float32)

    # Apply scaling to observation
    scales = jnp.exp(new_log_scales)
    x = raw_x * scales

    # Target from true weights using RAW features
    noise = self._noise_std * jr.normal(k_noise, (), dtype=jnp.float32)
    target = jnp.dot(new_weights, raw_x) + noise

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = ScaleDriftState(
        key=key,
        true_weights=new_weights,
        log_scales=new_log_scales,
        step_count=state.step_count + 1,
    )
    return timestep, new_state

ScaledStreamState

State for ScaledStreamWrapper.

Attributes:
    inner_state: State of the wrapped stream

ScaledStreamWrapper(inner_stream, feature_scales)

Wrapper that applies per-feature scaling to any stream's observations.

This wrapper multiplies each feature of the observation by a corresponding scale factor. Useful for testing how learners handle features at different scales, which is important for understanding normalization benefits.

Examples:

stream = ScaledStreamWrapper(
    AbruptChangeStream(feature_dim=10, change_interval=1000),
    feature_scales=jnp.array([0.001, 0.01, 0.1, 1.0, 10.0,
                              100.0, 1000.0, 0.001, 0.01, 0.1])
)

Attributes:
    inner_stream: The wrapped stream instance
    feature_scales: Per-feature scale factors (must match feature_dim)

Args:
    inner_stream: Stream to wrap (must implement ScanStream protocol)
    feature_scales: Array of scale factors, one per feature. Must have
        shape (feature_dim,) matching the inner stream's feature_dim.

Raises:
    ValueError: If feature_scales length doesn't match inner stream's feature_dim

Source code in src/alberta_framework/streams/synthetic.py
def __init__(self, inner_stream: ScanStream[Any], feature_scales: Array):
    """Initialize the scaled stream wrapper.

    Args:
        inner_stream: Stream to wrap (must implement ScanStream protocol)
        feature_scales: Array of scale factors, one per feature. Must have
            shape (feature_dim,) matching the inner stream's feature_dim.

    Raises:
        ValueError: If feature_scales length doesn't match inner stream's feature_dim
    """
    self._inner_stream: ScanStream[Any] = inner_stream
    self._feature_scales = jnp.asarray(feature_scales, dtype=jnp.float32)

    if self._feature_scales.shape[0] != inner_stream.feature_dim:
        raise ValueError(
            f"feature_scales length ({self._feature_scales.shape[0]}) "
            f"must match inner stream's feature_dim ({inner_stream.feature_dim})"
        )

feature_dim property

Return the dimension of observation vectors.

inner_stream property

Return the wrapped stream.

feature_scales property

Return the per-feature scale factors.

init(key)

Initialize stream state.

Args:
    key: JAX random key

Returns:
    Initial stream state wrapping the inner stream's state

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> ScaledStreamState:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state wrapping the inner stream's state
    """
    inner_state = self._inner_stream.init(key)
    return ScaledStreamState(inner_state=inner_state)

step(state, idx)

Generate one time step with scaled observations.

Args:
    state: Current stream state
    idx: Current step index

Returns:
    Tuple of (timestep with scaled observation, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(self, state: ScaledStreamState, idx: Array) -> tuple[TimeStep, ScaledStreamState]:
    """Generate one time step with scaled observations.

    Args:
        state: Current stream state
        idx: Current step index

    Returns:
        Tuple of (timestep with scaled observation, new_state)
    """
    timestep, new_inner_state = self._inner_stream.step(state.inner_state, idx)

    # Scale the observation
    scaled_observation = timestep.observation * self._feature_scales

    scaled_timestep = TimeStep(
        observation=scaled_observation,
        target=timestep.target,
    )

    new_state = ScaledStreamState(inner_state=new_inner_state)
    return scaled_timestep, new_state

SuttonExperiment1State

State for SuttonExperiment1Stream.

Attributes:
    key: JAX random key for generating randomness
    signs: Signs (+1/-1) for the relevant inputs
    step_count: Number of steps taken

SuttonExperiment1Stream(num_relevant=5, num_irrelevant=15, change_interval=20)

Non-stationary stream replicating Experiment 1 from Sutton 1992.

This stream implements the exact task from Sutton's IDBD paper:

- 20 real-valued inputs drawn from N(0, 1)
- Only the first 5 inputs are relevant (weights are ±1)
- The last 15 inputs are irrelevant (weights are 0)
- Every change_interval steps, one of the 5 relevant signs is flipped

Reference: Sutton, R.S. (1992). "Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta"
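
Example with the paper's default configuration (a sketch; assumes jax and jax.numpy as jnp are imported):

stream = SuttonExperiment1Stream()  # 20 inputs, one sign flip every 20 steps
state = stream.init(jax.random.key(0))
timestep, state = stream.step(state, jnp.array(0))
# observation has shape (20,); the target is the sum of the first 5 inputs,
# each weighted by its current +1/-1 sign.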

Attributes:
    num_relevant: Number of relevant inputs (default 5)
    num_irrelevant: Number of irrelevant inputs (default 15)
    change_interval: Steps between sign changes (default 20)

Args:
    num_relevant: Number of relevant inputs with ±1 weights
    num_irrelevant: Number of irrelevant inputs with 0 weights
    change_interval: Number of steps between sign flips

Source code in src/alberta_framework/streams/synthetic.py
def __init__(
    self,
    num_relevant: int = 5,
    num_irrelevant: int = 15,
    change_interval: int = 20,
):
    """Initialize the Sutton Experiment 1 stream.

    Args:
        num_relevant: Number of relevant inputs with ±1 weights
        num_irrelevant: Number of irrelevant inputs with 0 weights
        change_interval: Number of steps between sign flips
    """
    self._num_relevant = num_relevant
    self._num_irrelevant = num_irrelevant
    self._change_interval = change_interval

feature_dim property

Return the dimension of observation vectors.

init(key)

Initialize stream state.

Args:
    key: JAX random key

Returns:
    Initial stream state with all +1 signs

Source code in src/alberta_framework/streams/synthetic.py
def init(self, key: Array) -> SuttonExperiment1State:
    """Initialize stream state.

    Args:
        key: JAX random key

    Returns:
        Initial stream state with all +1 signs
    """
    signs = jnp.ones(self._num_relevant, dtype=jnp.float32)
    return SuttonExperiment1State(
        key=key,
        signs=signs,
        step_count=jnp.array(0, dtype=jnp.int32),
    )

step(state, idx)

Generate one time step.

At each step:

1. If at a change interval (and not step 0), flip one random sign
2. Generate random inputs from N(0, 1)
3. Compute target as sum of relevant inputs weighted by signs

Args:
    state: Current stream state
    idx: Current step index (unused)

Returns:
    Tuple of (timestep, new_state)

Source code in src/alberta_framework/streams/synthetic.py
def step(
    self, state: SuttonExperiment1State, idx: Array
) -> tuple[TimeStep, SuttonExperiment1State]:
    """Generate one time step.

    At each step:
    1. If at a change interval (and not step 0), flip one random sign
    2. Generate random inputs from N(0, 1)
    3. Compute target as sum of relevant inputs weighted by signs

    Args:
        state: Current stream state
        idx: Current step index (unused)

    Returns:
        Tuple of (timestep, new_state)
    """
    del idx  # unused
    key, key_x, key_which = jr.split(state.key, 3)

    # Determine if we should flip a sign (not at step 0)
    should_flip = (state.step_count > 0) & (state.step_count % self._change_interval == 0)

    # Select which sign to flip
    idx_to_flip = jr.randint(key_which, (), 0, self._num_relevant)

    # Create flip mask
    flip_mask = jnp.where(
        jnp.arange(self._num_relevant) == idx_to_flip,
        jnp.array(-1.0, dtype=jnp.float32),
        jnp.array(1.0, dtype=jnp.float32),
    )

    # Apply flip mask conditionally
    new_signs = jnp.where(should_flip, state.signs * flip_mask, state.signs)

    # Generate observation from N(0, 1)
    x = jr.normal(key_x, (self.feature_dim,), dtype=jnp.float32)

    # Compute target: sum of first num_relevant inputs weighted by signs
    target = jnp.dot(new_signs, x[: self._num_relevant])

    timestep = TimeStep(observation=x, target=jnp.atleast_1d(target))
    new_state = SuttonExperiment1State(
        key=key,
        signs=new_signs,
        step_count=state.step_count + 1,
    )

    return timestep, new_state

GymnasiumStream(env, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0)

Experience stream from a Gymnasium environment, driven by a Python loop.

This class maintains iterator-based access for online learning scenarios where you need to interact with the environment in real-time.

For batch learning, use collect_trajectory() followed by learn_from_trajectory().
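
A sketch of that batch path, using the helpers documented later on this page (assumes a LinearLearner instance named learner):

import gymnasium

env = gymnasium.make("CartPole-v1")
observations, targets = collect_trajectory(env, policy=None, num_steps=10_000)
final_state, metrics = learn_from_trajectory(learner, observations, targets)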

Attributes:
    mode: Prediction mode (REWARD, NEXT_STATE, VALUE)
    gamma: Discount factor for VALUE mode
    include_action_in_features: Whether to include action in features
    episode_count: Number of completed episodes

Args:
    env: Gymnasium environment instance
    mode: What to predict (REWARD, NEXT_STATE, VALUE)
    policy: Action selection function. If None, uses random policy
    gamma: Discount factor for VALUE mode
    include_action_in_features: If True, features = concat(obs, action).
        If False, features = obs only
    seed: Random seed for environment resets and random policy

Source code in src/alberta_framework/streams/gymnasium.py
def __init__(
    self,
    env: gymnasium.Env[Any, Any],
    mode: PredictionMode = PredictionMode.REWARD,
    policy: Callable[[Array], Any] | None = None,
    gamma: float = 0.99,
    include_action_in_features: bool = True,
    seed: int = 0,
):
    """Initialize the Gymnasium stream.

    Args:
        env: Gymnasium environment instance
        mode: What to predict (REWARD, NEXT_STATE, VALUE)
        policy: Action selection function. If None, uses random policy
        gamma: Discount factor for VALUE mode
        include_action_in_features: If True, features = concat(obs, action).
            If False, features = obs only
        seed: Random seed for environment resets and random policy
    """
    self._env = env
    self._mode = mode
    self._gamma = gamma
    self._include_action_in_features = include_action_in_features
    self._seed = seed
    self._reset_count = 0

    if policy is None:
        self._policy = make_random_policy(env, seed)
    else:
        self._policy = policy

    self._obs_dim = _flatten_space(env.observation_space)
    self._action_dim = _flatten_space(env.action_space)

    if include_action_in_features:
        self._feature_dim = self._obs_dim + self._action_dim
    else:
        self._feature_dim = self._obs_dim

    if mode == PredictionMode.NEXT_STATE:
        self._target_dim = self._obs_dim
    else:
        self._target_dim = 1

    self._current_obs: Array | None = None
    self._episode_count = 0
    self._step_count = 0
    self._value_estimator: Callable[[Array], float] | None = None

feature_dim property

Return the dimension of feature vectors.

target_dim property

Return the dimension of target vectors.

episode_count property

Return the number of completed episodes.

step_count property

Return the total number of steps taken.

mode property

Return the prediction mode.

set_value_estimator(estimator)

Set the value estimator for proper TD learning in VALUE mode.

Source code in src/alberta_framework/streams/gymnasium.py
def set_value_estimator(self, estimator: Callable[[Array], float]) -> None:
    """Set the value estimator for proper TD learning in VALUE mode."""
    self._value_estimator = estimator

PredictionMode

Bases: Enum

Mode for what the stream predicts.

REWARD: Predict immediate reward from (state, action)
NEXT_STATE: Predict next state from (state, action)
VALUE: Predict cumulative return (TD learning with bootstrap)

TDStream(env, policy=None, gamma=0.99, include_action_in_features=False, seed=0)

Experience stream for proper TD learning with value function bootstrap.

This stream integrates with a learner to use its predictions for bootstrapping in TD targets.

Usage:

stream = TDStream(env)
learner = LinearLearner(optimizer=IDBD())
state = learner.init(stream.feature_dim)

for step, timestep in enumerate(stream):
    result = learner.update(state, timestep.observation, timestep.target)
    state = result.state
    stream.update_value_function(lambda x: learner.predict(state, x))

Args:
    env: Gymnasium environment instance
    policy: Action selection function. If None, uses random policy
    gamma: Discount factor
    include_action_in_features: If True, learn Q(s,a). If False, learn V(s)
    seed: Random seed

Source code in src/alberta_framework/streams/gymnasium.py
def __init__(
    self,
    env: gymnasium.Env[Any, Any],
    policy: Callable[[Array], Any] | None = None,
    gamma: float = 0.99,
    include_action_in_features: bool = False,
    seed: int = 0,
):
    """Initialize the TD stream.

    Args:
        env: Gymnasium environment instance
        policy: Action selection function. If None, uses random policy
        gamma: Discount factor
        include_action_in_features: If True, learn Q(s,a). If False, learn V(s)
        seed: Random seed
    """
    self._env = env
    self._gamma = gamma
    self._include_action_in_features = include_action_in_features
    self._seed = seed
    self._reset_count = 0

    if policy is None:
        self._policy = make_random_policy(env, seed)
    else:
        self._policy = policy

    self._obs_dim = _flatten_space(env.observation_space)
    self._action_dim = _flatten_space(env.action_space)

    if include_action_in_features:
        self._feature_dim = self._obs_dim + self._action_dim
    else:
        self._feature_dim = self._obs_dim

    self._current_obs: Array | None = None
    self._episode_count = 0
    self._step_count = 0
    self._value_fn: Callable[[Array], float] = lambda x: 0.0

feature_dim property

Return the dimension of feature vectors.

episode_count property

Return the number of completed episodes.

step_count property

Return the total number of steps taken.

update_value_function(value_fn)

Update the value function used for TD bootstrapping.

Source code in src/alberta_framework/streams/gymnasium.py
def update_value_function(self, value_fn: Callable[[Array], float]) -> None:
    """Update the value function used for TD bootstrapping."""
    self._value_fn = value_fn

make_scale_range(feature_dim, min_scale=0.001, max_scale=1000.0, log_spaced=True)

Create a per-feature scale array spanning a range.

Utility function to generate scale factors for ScaledStreamWrapper.

Args:
    feature_dim: Number of features
    min_scale: Minimum scale factor
    max_scale: Maximum scale factor
    log_spaced: If True, scales are logarithmically spaced (default).
        If False, scales are linearly spaced.

Returns:
    Array of shape (feature_dim,) with scale factors

Examples:

scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
stream = ScaledStreamWrapper(RandomWalkStream(10), scales)

Source code in src/alberta_framework/streams/synthetic.py
def make_scale_range(
    feature_dim: int,
    min_scale: float = 0.001,
    max_scale: float = 1000.0,
    log_spaced: bool = True,
) -> Array:
    """Create a per-feature scale array spanning a range.

    Utility function to generate scale factors for ScaledStreamWrapper.

    Args:
        feature_dim: Number of features
        min_scale: Minimum scale factor
        max_scale: Maximum scale factor
        log_spaced: If True, scales are logarithmically spaced (default).
            If False, scales are linearly spaced.

    Returns:
        Array of shape (feature_dim,) with scale factors

    Examples
    --------
    ```python
    scales = make_scale_range(10, min_scale=0.01, max_scale=100.0)
    stream = ScaledStreamWrapper(RandomWalkStream(10), scales)
    ```
    """
    if log_spaced:
        return jnp.logspace(
            jnp.log10(min_scale),
            jnp.log10(max_scale),
            feature_dim,
            dtype=jnp.float32,
        )
    else:
        return jnp.linspace(min_scale, max_scale, feature_dim, dtype=jnp.float32)

collect_trajectory(env, policy, num_steps, mode=PredictionMode.REWARD, include_action_in_features=True, seed=0)

Collect a trajectory from a Gymnasium environment.

This uses a Python loop to interact with the environment and collects observations and targets into JAX arrays that can be used with scan-based learning.

Args:
    env: Gymnasium environment instance
    policy: Action selection function. If None, uses random policy
    num_steps: Number of steps to collect
    mode: What to predict (REWARD, NEXT_STATE, VALUE)
    include_action_in_features: If True, features = concat(obs, action)
    seed: Random seed for environment resets and random policy

Returns:
    Tuple of (observations, targets) as JAX arrays with shape
    (num_steps, feature_dim) and (num_steps, target_dim)
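
For example, in NEXT_STATE mode the targets are the flattened next observations (a sketch; assumes gymnasium is imported):

env = gymnasium.make("Pendulum-v1")
obs, targets = collect_trajectory(env, policy=None, num_steps=1000,
                                  mode=PredictionMode.NEXT_STATE)
# obs: (1000, feature_dim); targets: (1000, obs_dim), the flattened next state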

Source code in src/alberta_framework/streams/gymnasium.py
def collect_trajectory(
    env: gymnasium.Env[Any, Any],
    policy: Callable[[Array], Any] | None,
    num_steps: int,
    mode: PredictionMode = PredictionMode.REWARD,
    include_action_in_features: bool = True,
    seed: int = 0,
) -> tuple[Array, Array]:
    """Collect a trajectory from a Gymnasium environment.

    This uses a Python loop to interact with the environment and collects
    observations and targets into JAX arrays that can be used with scan-based
    learning.

    Args:
        env: Gymnasium environment instance
        policy: Action selection function. If None, uses random policy
        num_steps: Number of steps to collect
        mode: What to predict (REWARD, NEXT_STATE, VALUE)
        include_action_in_features: If True, features = concat(obs, action)
        seed: Random seed for environment resets and random policy

    Returns:
        Tuple of (observations, targets) as JAX arrays with shape
        (num_steps, feature_dim) and (num_steps, target_dim)
    """
    if policy is None:
        policy = make_random_policy(env, seed)

    observations = []
    targets = []

    reset_count = 0
    raw_obs, _ = env.reset(seed=seed + reset_count)
    reset_count += 1
    current_obs = _flatten_observation(raw_obs, env.observation_space)

    for _ in range(num_steps):
        action = policy(current_obs)
        flat_action = _flatten_action(action, env.action_space)

        raw_next_obs, reward, terminated, truncated, _ = env.step(action)
        next_obs = _flatten_observation(raw_next_obs, env.observation_space)

        # Construct features
        if include_action_in_features:
            features = jnp.concatenate([current_obs, flat_action])
        else:
            features = current_obs

        # Construct target based on mode
        if mode == PredictionMode.REWARD:
            target = jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))
        elif mode == PredictionMode.NEXT_STATE:
            target = next_obs
        else:  # VALUE mode
            # TD target with 0 bootstrap (simple version)
            target = jnp.atleast_1d(jnp.array(reward, dtype=jnp.float32))

        observations.append(features)
        targets.append(target)

        if terminated or truncated:
            raw_obs, _ = env.reset(seed=seed + reset_count)
            reset_count += 1
            current_obs = _flatten_observation(raw_obs, env.observation_space)
        else:
            current_obs = next_obs

    return jnp.stack(observations), jnp.stack(targets)

learn_from_trajectory(learner, observations, targets, learner_state=None)

Learn from a pre-collected trajectory using jax.lax.scan.

This is a JIT-compiled learning function that processes a trajectory collected from a Gymnasium environment.

Args:
    learner: The learner to train
    observations: Array of observations with shape (num_steps, feature_dim)
    targets: Array of targets with shape (num_steps, target_dim)
    learner_state: Initial state (if None, will be initialized)

Returns:
    Tuple of (final_state, metrics_array) where metrics_array has shape
    (num_steps, 3) with columns [squared_error, error, mean_step_size]

Source code in src/alberta_framework/streams/gymnasium.py
def learn_from_trajectory(
    learner: LinearLearner,
    observations: Array,
    targets: Array,
    learner_state: LearnerState | None = None,
) -> tuple[LearnerState, Array]:
    """Learn from a pre-collected trajectory using jax.lax.scan.

    This is a JIT-compiled learning function that processes a trajectory
    collected from a Gymnasium environment.

    Args:
        learner: The learner to train
        observations: Array of observations with shape (num_steps, feature_dim)
        targets: Array of targets with shape (num_steps, target_dim)
        learner_state: Initial state (if None, will be initialized)

    Returns:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 3) with columns [squared_error, error, mean_step_size]
    """
    if learner_state is None:
        learner_state = learner.init(observations.shape[1])

    def step_fn(state: LearnerState, inputs: tuple[Array, Array]) -> tuple[LearnerState, Array]:
        obs, target = inputs
        result = learner.update(state, obs, target)
        return result.state, result.metrics

    final_state, metrics = jax.lax.scan(step_fn, learner_state, (observations, targets))

    return final_state, metrics

learn_from_trajectory_normalized(learner, observations, targets, learner_state=None)

Learn from a pre-collected trajectory with normalization using jax.lax.scan.

Args:
    learner: The normalized learner to train
    observations: Array of observations with shape (num_steps, feature_dim)
    targets: Array of targets with shape (num_steps, target_dim)
    learner_state: Initial state (if None, will be initialized)

Returns:
    Tuple of (final_state, metrics_array) where metrics_array has shape
    (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]

Source code in src/alberta_framework/streams/gymnasium.py
def learn_from_trajectory_normalized(
    learner: NormalizedLinearLearner,
    observations: Array,
    targets: Array,
    learner_state: NormalizedLearnerState | None = None,
) -> tuple[NormalizedLearnerState, Array]:
    """Learn from a pre-collected trajectory with normalization using jax.lax.scan.

    Args:
        learner: The normalized learner to train
        observations: Array of observations with shape (num_steps, feature_dim)
        targets: Array of targets with shape (num_steps, target_dim)
        learner_state: Initial state (if None, will be initialized)

    Returns:
        Tuple of (final_state, metrics_array) where metrics_array has shape
        (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
    """
    if learner_state is None:
        learner_state = learner.init(observations.shape[1])

    def step_fn(
        state: NormalizedLearnerState, inputs: tuple[Array, Array]
    ) -> tuple[NormalizedLearnerState, Array]:
        obs, target = inputs
        result = learner.update(state, obs, target)
        return result.state, result.metrics

    final_state, metrics = jax.lax.scan(step_fn, learner_state, (observations, targets))

    return final_state, metrics

make_epsilon_greedy_policy(base_policy, env, epsilon=0.1, seed=0)

Wrap a policy with epsilon-greedy exploration.

Args:
    base_policy: The greedy policy to wrap
    env: Gymnasium environment (for random action sampling)
    epsilon: Probability of taking a random action
    seed: Random seed

Returns:
    Epsilon-greedy policy
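
A minimal sketch (assumes env is a Gymnasium environment; the greedy policy here is a hypothetical stand-in for a learned one):

def greedy(obs):
    return 0  # hypothetical fixed action, standing in for a learned policy

policy = make_epsilon_greedy_policy(greedy, env, epsilon=0.1)
# policy(obs) takes a random action with probability 0.1, else greedy(obs)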

Source code in src/alberta_framework/streams/gymnasium.py
def make_epsilon_greedy_policy(
    base_policy: Callable[[Array], Any],
    env: gymnasium.Env[Any, Any],
    epsilon: float = 0.1,
    seed: int = 0,
) -> Callable[[Array], Any]:
    """Wrap a policy with epsilon-greedy exploration.

    Args:
        base_policy: The greedy policy to wrap
        env: Gymnasium environment (for random action sampling)
        epsilon: Probability of taking a random action
        seed: Random seed

    Returns:
        Epsilon-greedy policy
    """
    random_policy = make_random_policy(env, seed + 1)
    rng = jr.key(seed)

    def policy(obs: Array) -> Any:
        nonlocal rng
        rng, key = jr.split(rng)

        if jr.uniform(key) < epsilon:
            return random_policy(obs)
        return base_policy(obs)

    return policy

make_gymnasium_stream(env_id, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0, **env_kwargs)

Factory function to create a GymnasiumStream from an environment ID.

Args:
    env_id: Gymnasium environment ID (e.g., "CartPole-v1")
    mode: What to predict (REWARD, NEXT_STATE, VALUE)
    policy: Action selection function. If None, uses random policy
    gamma: Discount factor for VALUE mode
    include_action_in_features: If True, features = concat(obs, action)
    seed: Random seed
    **env_kwargs: Additional arguments passed to gymnasium.make()

Returns:
    GymnasiumStream wrapping the environment
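
A one-line sketch:

stream = make_gymnasium_stream("CartPole-v1", mode=PredictionMode.REWARD, seed=0)
# stream.feature_dim counts obs (+ action) dimensions; stream.target_dim is 1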

Source code in src/alberta_framework/streams/gymnasium.py
def make_gymnasium_stream(
    env_id: str,
    mode: PredictionMode = PredictionMode.REWARD,
    policy: Callable[[Array], Any] | None = None,
    gamma: float = 0.99,
    include_action_in_features: bool = True,
    seed: int = 0,
    **env_kwargs: Any,
) -> GymnasiumStream:
    """Factory function to create a GymnasiumStream from an environment ID.

    Args:
        env_id: Gymnasium environment ID (e.g., "CartPole-v1")
        mode: What to predict (REWARD, NEXT_STATE, VALUE)
        policy: Action selection function. If None, uses random policy
        gamma: Discount factor for VALUE mode
        include_action_in_features: If True, features = concat(obs, action)
        seed: Random seed
        **env_kwargs: Additional arguments passed to gymnasium.make()

    Returns:
        GymnasiumStream wrapping the environment
    """
    import gymnasium

    env = gymnasium.make(env_id, **env_kwargs)
    return GymnasiumStream(
        env=env,
        mode=mode,
        policy=policy,
        gamma=gamma,
        include_action_in_features=include_action_in_features,
        seed=seed,
    )

make_random_policy(env, seed=0)

Create a random action policy for an environment.

Args:
    env: Gymnasium environment
    seed: Random seed

Returns:
    A callable that takes an observation and returns a random action

Source code in src/alberta_framework/streams/gymnasium.py
def make_random_policy(env: gymnasium.Env[Any, Any], seed: int = 0) -> Callable[[Array], Any]:
    """Create a random action policy for an environment.

    Args:
        env: Gymnasium environment
        seed: Random seed

    Returns:
        A callable that takes an observation and returns a random action
    """
    import gymnasium

    rng = jr.key(seed)
    action_space = env.action_space

    def policy(_obs: Array) -> Any:
        nonlocal rng
        rng, key = jr.split(rng)

        if isinstance(action_space, gymnasium.spaces.Discrete):
            return int(jr.randint(key, (), 0, int(action_space.n)))
        elif isinstance(action_space, gymnasium.spaces.Box):
            # Sample uniformly between low and high
            low = jnp.asarray(action_space.low, dtype=jnp.float32)
            high = jnp.asarray(action_space.high, dtype=jnp.float32)
            return jr.uniform(key, action_space.shape, minval=low, maxval=high)
        elif isinstance(action_space, gymnasium.spaces.MultiDiscrete):
            nvec = action_space.nvec
            return [int(jr.randint(jr.fold_in(key, i), (), 0, n)) for i, n in enumerate(nvec)]
        else:
            raise ValueError(f"Unsupported action space: {type(action_space).__name__}")

    return policy