sarsa
SARSA agent: on-policy control via Horde (Sutton & Barto Ch. 10).
Wraps HordeLearner with epsilon-greedy action selection and SARSA
target computation. Each action maps to a control demon (head) in the
Horde. The SARSA target r + gamma * Q(s', a') is computed externally
and passed as the cumulant to the Horde, so control demons use gamma=0
internally (single-step prediction of the externally-computed target).
This avoids modifying the Horde's TD target logic: the real discount
lives in SARSAConfig.gamma, while each control demon sees its
cumulant as a supervised target.
Optionally, prediction demons can coexist with control demons in the same Horde — they learn alongside the Q-heads without interference.
Reference: Sutton & Barto 2018, Section 10.1 (Episodic Semi-gradient SARSA)
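The target-as-cumulant arrangement can be sketched in a few lines; `sarsa_cumulant` is a hypothetical helper for illustration, not part of the framework's API:

```python
def sarsa_cumulant(reward, gamma, q_next, terminated):
    """Externally computed SARSA target r + gamma * Q(s', a').

    At a terminal state there is no bootstrap term, so the target is
    just the reward.
    """
    return reward + (0.0 if terminated else gamma * q_next)

# With an internal discount of 0, a control demon's TD target collapses
# to its cumulant, so it learns the externally computed value as a
# one-step supervised target:
#   demon_target = cumulant + 0 * v(s') = cumulant
target = sarsa_cumulant(reward=1.0, gamma=0.99, q_next=2.0, terminated=False)
```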
SARSAConfig
Configuration for SARSA agent.
Attributes:
n_actions: Number of discrete actions
gamma: Discount factor for SARSA targets (default: 0.99)
epsilon_start: Initial exploration rate (default: 0.1)
epsilon_end: Final exploration rate (default: 0.01)
epsilon_decay_steps: Steps over which epsilon decays linearly. 0 = no decay (constant epsilon_start).
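The epsilon schedule these attributes describe can be sketched as a stand-alone function (illustrative, not the agent's internal code):

```python
def epsilon_at(step, epsilon_start=0.1, epsilon_end=0.01, epsilon_decay_steps=0):
    """Linearly interpolate from epsilon_start to epsilon_end over
    epsilon_decay_steps steps, then hold at epsilon_end.

    epsilon_decay_steps == 0 means no decay: epsilon stays at
    epsilon_start forever.
    """
    if epsilon_decay_steps == 0:
        return epsilon_start
    frac = min(step / epsilon_decay_steps, 1.0)
    return epsilon_start + frac * (epsilon_end - epsilon_start)
```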
to_config()
Serialize to dict.
Source code in src/alberta_framework/core/sarsa.py
SARSAState
State for the SARSA agent.
Attributes:
learner_state: Underlying Horde/MultiHeadMLPLearner state
last_action: Action taken at previous step (a_t)
last_observation: Observation at previous step (s_t)
epsilon: Current exploration rate
rng_key: JAX random key for action selection
step_count: Number of SARSA update steps taken
SARSAUpdateResult
Result of a single SARSA update step.
Attributes:
state: Updated SARSA state (includes new action a_{t+1})
action: Next action a_{t+1} selected for the new state
q_values: Q-values for all actions at s_{t+1}
td_error: TD error for the taken action
reward: Reward received
SARSAEpisodeResult(state, total_reward, num_steps, rewards, q_values, td_errors)
dataclass
Result from running one episode of SARSA.
Not a chex dataclass — used in Python loops with native Python types.
Attributes:
state: Final SARSA state
total_reward: Sum of rewards in the episode
num_steps: Number of steps taken
rewards: Per-step rewards
q_values: Per-step Q-values
td_errors: Per-step TD errors
SARSAContinuingResult(state, total_reward, rewards, q_values, td_errors)
dataclass
Result from running SARSA in continuing mode.
Not a chex dataclass — used in Python loops with native Python types.
Attributes:
state: Final SARSA state
total_reward: Sum of rewards over all steps
rewards: Per-step rewards
q_values: Per-step Q-values
td_errors: Per-step TD errors
SARSAArrayResult
Result from scan-based SARSA on pre-collected arrays.
Attributes:
state: Final SARSA state
q_values: Per-step Q-values, shape (num_steps, n_actions)
td_errors: Per-step TD errors, shape (num_steps,)
actions: Per-step actions taken, shape (num_steps,)
SARSAAgent(sarsa_config, hidden_sizes=(128, 128), optimizer=None, step_size=1.0, bounder=None, normalizer=None, sparsity=0.9, leaky_relu_slope=0.01, use_layer_norm=True, head_optimizer=None, prediction_demons=None, lamda=0.0)
On-policy SARSA control agent via Horde architecture.
Wraps HordeLearner with epsilon-greedy action selection and
SARSA target computation. Each action maps to a control demon (head)
in the Horde. The SARSA target r + gamma * Q(s', a') is computed
externally and passed as the cumulant, so control demons use gamma=0
internally.
Optionally, additional prediction demons can coexist with the control demons — they learn alongside the Q-heads.
Single-Step (Daemon) Usage
Both select_action() and update() work with single unbatched
observations (1D arrays). JIT-compiled automatically.
Attributes:
sarsa_config: SARSA configuration
horde: The underlying HordeLearner
n_actions: Number of discrete actions
Args:
sarsa_config: SARSA configuration (n_actions, gamma, epsilon)
hidden_sizes: Tuple of hidden layer sizes (default: two layers of 128)
optimizer: Optimizer for weight updates. Defaults to LMS(step_size).
step_size: Base learning rate (used only when optimizer is None)
bounder: Optional update bounder (e.g. ObGDBounding)
normalizer: Optional feature normalizer
sparsity: Fraction of weights zeroed out per neuron (default: 0.9)
leaky_relu_slope: Negative slope for LeakyReLU (default: 0.01)
use_layer_norm: Whether to apply parameterless layer normalization
head_optimizer: Optional separate optimizer for heads
prediction_demons: Optional additional prediction demons to learn alongside the Q-heads. These are appended after the control demons in the Horde.
lamda: Trace decay for control demon heads (default: 0.0)
sarsa_config
property
The SARSA configuration.
horde
property
The underlying HordeLearner.
n_actions
property
Number of discrete actions.
to_config()
Serialize agent configuration to dict.
from_config(config)
classmethod
Reconstruct from config dict.
init(feature_dim, key)
Initialize SARSA agent state.
Args:
feature_dim: Dimension of the input feature vector
key: JAX random key
Returns: Initial SARSAState with zeroed last_action/observation
select_action(state, observation)
Select action via epsilon-greedy over Q-values.
JIT-compiled. Uses Gumbel trick for uniform tie-breaking among
equal Q-values (avoids left-side bias from jnp.argmax).
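The tie-breaking idea can be sketched in pure Python; `tiebreak_argmax` is an illustrative stand-in for the agent's JIT-compiled selection, not its actual code:

```python
import math
import random

def tiebreak_argmax(q_values, rng):
    """Greedy action with uniform tie-breaking among maximal Q-values.

    Restricting i.i.d. Gumbel noise to the maximal entries and taking its
    argmax samples uniformly among them (a special case of the Gumbel-max
    trick), avoiding the left-side bias of a plain argmax over repeated
    values.
    """
    q_max = max(q_values)
    best, best_g = -1, -math.inf
    for i, q in enumerate(q_values):
        if q == q_max:
            # Standard Gumbel(0, 1) sample: -log(-log(U)), U ~ Uniform(0, 1)
            g = -math.log(-math.log(rng.random()))
            if g > best_g:
                best, best_g = i, g
    return best
```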
Args:
state: Current SARSA state (uses rng_key and epsilon)
observation: Input feature vector
Returns: Tuple of (action, new_rng_key)
update(state, reward, observation, terminated, next_action, prediction_cumulants=None)
Perform one SARSA update step.
Computes the SARSA target r + gamma * Q(s', a') and updates
the Horde. Only the previously-taken action's head receives the
target; all other Q-heads get NaN (no update).
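A minimal sketch of this masking, assuming the NaN-means-skip semantics stated above; `control_cumulants` is an illustrative helper, not the framework's API:

```python
import math

def control_cumulants(n_actions, taken_action, sarsa_target):
    """Cumulant vector for the Horde's control demons.

    Only the head of the previously-taken action receives the SARSA
    target; every other head sees NaN, which marks it as inactive
    (no update this step).
    """
    return [sarsa_target if a == taken_action else math.nan
            for a in range(n_actions)]
```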
Args:
state: Current SARSA state
reward: Reward r received after taking last_action in last_obs
observation: New observation s' (state we transitioned to)
terminated: Whether s' is terminal (scalar bool/float)
next_action: Action a' selected for s' (pre-computed)
prediction_cumulants: Optional cumulants for prediction demons,
shape (n_prediction_demons,). NaN for inactive demons.
Returns: SARSAUpdateResult with updated state, Q-values, TD error
run_sarsa_episode(agent, state, env, max_steps=10000)
Run one episode of SARSA on a Gymnasium environment.
Python loop (env interaction not JIT-able). Follows the SARSA pattern: select a' before updating, so the update uses the on-policy next action.
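The select-before-update ordering can be sketched as below. The `agent` and `env` method names here are hypothetical stand-ins (with toy stubs so the loop runs), not the framework's API; the point is that a' is chosen before the update that bootstraps from it:

```python
def sarsa_episode(agent, state, env, max_steps=10_000):
    obs, _ = env.reset()
    state, action = agent.start(state, obs)            # choose a_0
    total_reward, steps = 0.0, 0
    for _ in range(max_steps):
        obs, reward, terminated, truncated, _ = env.step(action)
        next_action = agent.select_action(state, obs)  # pick a' first
        state = agent.update(state, reward, obs, terminated, next_action)
        total_reward += reward
        steps += 1
        action = next_action                           # on-policy: act with a'
        if terminated or truncated:
            break
    return state, total_reward, steps

class ToyEnv:
    """Three-step toy environment, just to exercise the loop."""
    def reset(self):
        self.t = 0
        return 0, {}
    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= 3, False, {}

class StubAgent:
    """Stub agent standing in for the real one; always picks action 0."""
    def start(self, state, obs):
        return state, 0
    def select_action(self, state, obs):
        return 0
    def update(self, state, reward, obs, terminated, next_action):
        return state
```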
Args:
agent: SARSA agent
state: Initial SARSA state
env: Gymnasium environment
max_steps: Maximum steps per episode
Returns: SARSAEpisodeResult with episode metrics
run_sarsa_continuing(agent, state, env, num_steps)
Run SARSA in continuing mode for a fixed number of steps.
At episode boundaries, the environment auto-resets. gamma is set to 0
at pseudo-boundaries (terminal/truncated) to prevent bootstrapping
across resets, matching the ContinuingWrapper pattern.
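The boundary handling reduces to a one-line rule (an illustrative helper, not the framework's code):

```python
def bootstrap_discount(gamma, terminated, truncated):
    """Discount applied to Q(s', a') in the SARSA target.

    Forced to 0 at pseudo-boundaries (terminal or truncated) so no value
    bootstraps across an auto-reset in continuing mode.
    """
    return 0.0 if (terminated or truncated) else gamma
```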
Args:
agent: SARSA agent
state: Initial SARSA state
env: Gymnasium environment
num_steps: Number of steps to run
Returns: SARSAContinuingResult with step-level metrics
run_sarsa_from_arrays(agent, state, observations, rewards, terminated, next_observations)
Run SARSA on pre-collected arrays via jax.lax.scan.
JIT-compiled for maximum throughput. Actions are selected on-policy within the scan. This is the primary loop for security-gym data where observations are pre-collected.
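A pure-Python sketch of this scan pattern, with `scan` standing in for `jax.lax.scan` and a toy per-step function in place of the real SARSA update (names and the per-step output are illustrative):

```python
def scan(step_fn, init_carry, xs):
    """Pure-Python stand-in for jax.lax.scan: thread a carry through the
    sequence and stack the per-step outputs."""
    carry, ys = init_carry, []
    for x in xs:
        carry, y = step_fn(carry, x)
        ys.append(y)
    return carry, ys

def step(running_reward, transition):
    # transition = (observation, reward, terminated, next_observation),
    # one row from the pre-collected arrays described in Args below.
    _, reward, _, _ = transition
    running_reward += reward
    return running_reward, running_reward  # (new carry, per-step output)

transitions = [((0.0,), 1.0, False, (1.0,)),
               ((1.0,), 2.0, True,  (0.0,))]
final, per_step = scan(step, 0.0, transitions)
```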
Args:
agent: SARSA agent
state: Initial SARSA state (must have valid last_action, last_observation)
observations: Current observations, shape (num_steps, feature_dim)
rewards: Rewards, shape (num_steps,)
terminated: Termination flags, shape (num_steps,)
next_observations: Next observations, shape (num_steps, feature_dim)
Returns: SARSAArrayResult with per-step Q-values, TD errors, and actions