Gymnasium Integration

The framework can wrap Gymnasium RL environments as experience streams for prediction learning.

Optional Dependency

This requires the gymnasium extra: pip install alberta-framework[gymnasium]

Overview

A Gymnasium environment becomes a prediction problem by choosing what to predict:

  • Rewards: Predict immediate reward from (state, action)
  • Next states: Predict next observation from (state, action)
  • Values: Predict cumulative return (TD learning)

Basic Usage

import jax.random as jr

from alberta_framework import LinearLearner, IDBD, run_learning_loop
from alberta_framework.streams.gymnasium import (
    make_gymnasium_stream,
    PredictionMode,
)

# Create a reward prediction stream
stream = make_gymnasium_stream(
    "CartPole-v1",
    mode=PredictionMode.REWARD,
    include_action_in_features=True,
    seed=42,
)

# Train a predictor
learner = LinearLearner(optimizer=IDBD())
state, metrics = run_learning_loop(
    learner=learner,
    stream=stream,
    num_steps=10000,
    key=jr.PRNGKey(0),
)

Prediction Modes

REWARD Mode

Predict the immediate reward:

  • Features: Current state (optionally with action)
  • Target: Reward received

stream = make_gymnasium_stream(
    "CartPole-v1",
    mode=PredictionMode.REWARD,
)

NEXT_STATE Mode

Predict the next observation:

  • Features: Current state and action
  • Target: Next state vector

stream = make_gymnasium_stream(
    "CartPole-v1",
    mode=PredictionMode.NEXT_STATE,
    include_action_in_features=True,  # Required for this mode
)

VALUE Mode

Predict cumulative return (for TD learning):

  • Features: Current state
  • Target: Bootstrapped value estimate

stream = make_gymnasium_stream(
    "CartPole-v1",
    mode=PredictionMode.VALUE,
    gamma=0.99,  # Discount factor
)

TD Learning with TDStream

For proper TD learning with a bootstrapped value function, use TDStream:

from alberta_framework.streams.gymnasium import TDStream
import gymnasium as gym

env = gym.make("CartPole-v1")
stream = TDStream(
    env=env,
    gamma=0.99,
    seed=42,
)

# The stream automatically computes TD targets:
# target = reward + gamma * V(next_state)
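
For intuition, the per-step target could be formed as in the sketch below. This is only an illustration of the bootstrapping rule above, not TDStream's internal code; reward, next_obs, done, and value_fn are placeholder names:

def td_target(reward, next_obs, done, value_fn, gamma=0.99):
    # Bootstrap from the current value estimate, except when the episode has terminated
    bootstrap = 0.0 if done else value_fn(next_obs)
    return reward + gamma * bootstrap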

Updating the Value Function

TD learning requires keeping the stream's value estimator in sync with the learner:

# state is an existing learner state, e.g. from a prior run_learning_loop call (see Basic Usage)
for step, timestep in enumerate(stream):
    # Make prediction
    prediction = learner.predict(state, timestep.observation)

    # Compute TD error
    error = timestep.target - prediction

    # Update learner
    result = learner.update(state, error, timestep.observation)
    state = result.new_state

    # Update stream's value function estimate
    stream.update_value_function(
        lambda obs: learner.predict(state, obs)
    )
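
Note that a fresh closure over the updated state is passed on every step, so the stream's TD targets always bootstrap from the learner's latest weights.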

Custom Policies

By default, streams use a random policy. You can supply a custom policy instead:

from alberta_framework.streams.gymnasium import (
    make_random_policy,
    make_epsilon_greedy_policy,
)

# Random policy (env created earlier with gym.make("CartPole-v1"))
random_policy = make_random_policy(env, seed=42)

# Epsilon-greedy wrapping another policy
def my_policy(obs):
    # your own action-selection logic (placeholder)
    return my_action_selection(obs)

eps_policy = make_epsilon_greedy_policy(
    base_policy=my_policy,
    env=env,
    epsilon=0.1,
    seed=42,
)

# Use with stream
stream = make_gymnasium_stream(
    "CartPole-v1",
    policy=eps_policy,
)
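
A policy is simply a callable mapping an observation to an action, so you can also write one by hand. A minimal sketch for CartPole-v1, assuming the stream accepts any such callable (as it does for eps_policy above):

def lean_direction_policy(obs):
    # CartPole-v1 observation: [cart position, cart velocity, pole angle, pole angular velocity]
    # Push the cart toward the side the pole is leaning (action 1 = push right)
    return 1 if obs[2] > 0 else 0

stream = make_gymnasium_stream("CartPole-v1", policy=lean_direction_policy)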

Episode Handling

Streams automatically handle episode boundaries:

  • Reset the environment when an episode ends
  • Continue generating experience seamlessly
  • Track the episode count via stream.episode_count

stream = make_gymnasium_stream("CartPole-v1")

for i, timestep in enumerate(stream):
    if i >= 10000:
        break

print(f"Completed {stream.episode_count} episodes")
print(f"Total steps: {stream.step_count}")

Feature Construction

The stream flattens observations and actions into feature vectors:

# CartPole-v1: 4-dim state
# With action (discrete 2): 4 + 2 = 6-dim features (one-hot action)

stream = make_gymnasium_stream(
    "CartPole-v1",
    include_action_in_features=True,
)
print(f"Feature dimension: {stream.feature_dim}")  # 6

Supported Environments

The framework supports:

  • Box observation spaces: Continuous state vectors
  • Discrete action spaces: One-hot encoded into features

Environments with complex observation spaces (Dict, Tuple) are flattened automatically.
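
The flattening behaviour can be reproduced with Gymnasium's own space utilities. The snippet below is only a sketch of what flattening a Dict observation looks like, not the framework's implementation:

import numpy as np
from gymnasium import spaces

# A Dict observation space with one continuous and one discrete component
obs_space = spaces.Dict({
    "position": spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32),
    "mode": spaces.Discrete(4),
})

sample = obs_space.sample()
flat = spaces.flatten(obs_space, sample)  # 1-D vector; Discrete becomes one-hot
print(flat.shape)  # (7,) = 3 (Box) + 4 (one-hot Discrete)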