gymnasium
Gymnasium environment wrappers as experience streams.
This module wraps Gymnasium environments to provide temporally-uniform experience streams compatible with the Alberta Framework's learners.
Gymnasium environments cannot be JIT-compiled, so this module provides two patterns:

1. Trajectory collection: collect data with a Python loop, then learn with jax.lax.scan (see the sketch below)
2. Online learning: a Python loop for cases requiring real-time environment interaction
Supports multiple prediction modes:

- REWARD: predict immediate reward from (state, action)
- NEXT_STATE: predict next state from (state, action)
- VALUE: predict cumulative return via TD learning
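A minimal sketch of the trajectory-collection pattern, assuming the LinearLearner/IDBD pairing shown in the TDStream usage further down and a package-root import path for them (an assumption; the stream functions themselves are documented on this page):

```python
import gymnasium as gym

from alberta_framework import LinearLearner, IDBD  # import path assumed
from alberta_framework.streams.gymnasium import (
    PredictionMode,
    collect_trajectory,
    learn_from_trajectory,
)

# Phase 1: interact with the environment in a Python loop (not JIT-compiled).
env = gym.make("CartPole-v1")
observations, targets = collect_trajectory(
    env, policy=None, num_steps=1000, mode=PredictionMode.REWARD
)

# Phase 2: learn from the collected arrays with jax.lax.scan (JIT-compiled).
learner = LinearLearner(optimizer=IDBD())
final_state, metrics = learn_from_trajectory(learner, observations, targets)
```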
PredictionMode
Bases: Enum
Mode for what the stream predicts.
- REWARD: Predict immediate reward from (state, action)
- NEXT_STATE: Predict next state from (state, action)
- VALUE: Predict cumulative return (TD learning with bootstrap)
GymnasiumStream(env, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0)
Experience stream from a Gymnasium environment using a Python loop.
This class maintains iterator-based access for online learning scenarios where you need to interact with the environment in real-time.
For batch learning, use collect_trajectory() followed by learn_from_trajectory().
Attributes:

- mode: Prediction mode (REWARD, NEXT_STATE, VALUE)
- gamma: Discount factor for VALUE mode
- include_action_in_features: Whether to include the action in features
- episode_count: Number of completed episodes
Args:

- env: Gymnasium environment instance
- mode: What to predict (REWARD, NEXT_STATE, VALUE)
- policy: Action selection function. If None, uses a random policy
- gamma: Discount factor for VALUE mode
- include_action_in_features: If True, features = concat(obs, action). If False, features = obs only
- seed: Random seed for environment resets and the random policy
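A minimal online-learning sketch, assuming each yielded timestep exposes .observation and .target fields as in the TDStream usage further down (LinearLearner/IDBD import path assumed):

```python
import gymnasium as gym

from alberta_framework import LinearLearner, IDBD  # import path assumed
from alberta_framework.streams.gymnasium import GymnasiumStream, PredictionMode

env = gym.make("CartPole-v1")
stream = GymnasiumStream(env, mode=PredictionMode.NEXT_STATE, seed=0)

learner = LinearLearner(optimizer=IDBD())
state = learner.init(stream.feature_dim)

# One learner update per real environment step.
for step, timestep in enumerate(stream):
    result = learner.update(state, timestep.observation, timestep.target)
    state = result.state
    if step >= 10_000:
        break
```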
feature_dim
property
Return the dimension of feature vectors.
target_dim
property
Return the dimension of target vectors.
episode_count
property
Return the number of completed episodes.
step_count
property
Return the total number of steps taken.
mode
property
Return the prediction mode.
set_value_estimator(estimator)
Set the value estimator for proper TD learning in VALUE mode.
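For example (a sketch; stream, learner, and state as in the online loop above, with the stream in VALUE mode):

```python
# Bootstrap TD targets with the learner's current predictions.
stream.set_value_estimator(lambda x: learner.predict(state, x))
```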
TDStream(env, policy=None, gamma=0.99, include_action_in_features=False, seed=0)
Experience stream for proper TD learning with value function bootstrap.
This stream integrates with a learner to use its predictions for bootstrapping in TD targets.
Usage:

```python
stream = TDStream(env)
learner = LinearLearner(optimizer=IDBD())
state = learner.init(stream.feature_dim)

for step, timestep in enumerate(stream):
    result = learner.update(state, timestep.observation, timestep.target)
    state = result.state
    stream.update_value_function(lambda x: learner.predict(state, x))
```
Args:

- env: Gymnasium environment instance
- policy: Action selection function. If None, uses a random policy
- gamma: Discount factor
- include_action_in_features: If True, learn Q(s,a). If False, learn V(s)
- seed: Random seed
make_random_policy(env, seed=0)
Create a random action policy for an environment.
Args:

- env: Gymnasium environment
- seed: Random seed
Returns: A callable that takes an observation and returns a random action
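A minimal usage sketch:

```python
import gymnasium as gym

from alberta_framework.streams.gymnasium import make_random_policy

env = gym.make("CartPole-v1")
policy = make_random_policy(env, seed=0)

obs, _ = env.reset(seed=0)
action = policy(obs)  # a random action from env.action_space
```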
make_epsilon_greedy_policy(base_policy, env, epsilon=0.1, seed=0)
Wrap a policy with epsilon-greedy exploration.
Args:

- base_policy: The greedy policy to wrap
- env: Gymnasium environment (for random action sampling)
- epsilon: Probability of taking a random action
- seed: Random seed
Returns: Epsilon-greedy policy
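A sketch with a placeholder greedy policy (greedy_policy below is hypothetical; any observation -> action callable works, and env is as created above):

```python
from alberta_framework.streams.gymnasium import make_epsilon_greedy_policy

def greedy_policy(obs):
    return 0  # placeholder: always pick the first discrete action

policy = make_epsilon_greedy_policy(greedy_policy, env, epsilon=0.1, seed=0)
```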
collect_trajectory(env, policy, num_steps, mode=PredictionMode.REWARD, include_action_in_features=True, seed=0)
Collect a trajectory from a Gymnasium environment.
This uses a Python loop to interact with the environment and collects observations and targets into JAX arrays that can be used with scan-based learning.
Args:

- env: Gymnasium environment instance
- policy: Action selection function. If None, uses a random policy
- num_steps: Number of steps to collect
- mode: What to predict (REWARD, NEXT_STATE, VALUE)
- include_action_in_features: If True, features = concat(obs, action)
- seed: Random seed for environment resets and the random policy
Returns: Tuple of (observations, targets) as JAX arrays with shape (num_steps, feature_dim) and (num_steps, target_dim)
learn_from_trajectory(learner, observations, targets, learner_state=None)
Learn from a pre-collected trajectory using jax.lax.scan.
This is a JIT-compiled learning function that processes a trajectory collected from a Gymnasium environment.
Args:

- learner: The learner to train
- observations: Array of observations with shape (num_steps, feature_dim)
- targets: Array of targets with shape (num_steps, target_dim)
- learner_state: Initial state (if None, will be initialized)
Returns: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 3) with columns [squared_error, error, mean_step_size]
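The returned metrics array can be summarized directly; a sketch, with learner, observations, and targets as collected above:

```python
final_state, metrics = learn_from_trajectory(learner, observations, targets)

# Columns: [squared_error, error, mean_step_size]
print("mean squared error: ", float(metrics[:, 0].mean()))
print("final mean step size:", float(metrics[-1, 2]))
```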
learn_from_trajectory_normalized(learner, observations, targets, learner_state=None)
Learn from a pre-collected trajectory with normalization using jax.lax.scan.
Args:

- learner: The normalized learner to train
- observations: Array of observations with shape (num_steps, feature_dim)
- targets: Array of targets with shape (num_steps, target_dim)
- learner_state: Initial state (if None, will be initialized)
Returns: Tuple of (final_state, metrics_array) where metrics_array has shape (num_steps, 4) with columns [squared_error, error, mean_step_size, normalizer_mean_var]
make_gymnasium_stream(env_id, mode=PredictionMode.REWARD, policy=None, gamma=0.99, include_action_in_features=True, seed=0, **env_kwargs)
Factory function to create a GymnasiumStream from an environment ID.
Args:

- env_id: Gymnasium environment ID (e.g., "CartPole-v1")
- mode: What to predict (REWARD, NEXT_STATE, VALUE)
- policy: Action selection function. If None, uses a random policy
- gamma: Discount factor for VALUE mode
- include_action_in_features: If True, features = concat(obs, action)
- seed: Random seed
- **env_kwargs: Additional arguments passed to gymnasium.make()
Returns: GymnasiumStream wrapping the environment
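A minimal factory sketch; max_episode_steps is one example of an env_kwarg forwarded to gymnasium.make():

```python
from alberta_framework.streams.gymnasium import make_gymnasium_stream, PredictionMode

stream = make_gymnasium_stream(
    "CartPole-v1",
    mode=PredictionMode.VALUE,
    gamma=0.99,
    seed=0,
    max_episode_steps=500,  # forwarded to gymnasium.make()
)
print(stream.feature_dim, stream.target_dim)
```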