multi_head_learner
¶
Multi-head MLP learner for multi-task continual learning.
Implements a shared-trunk, multi-head MLP architecture where hidden layers are shared across prediction heads. Each head can be independently active or inactive at each time step (NaN targets = inactive).
Architecture: Input -> [Dense(H) -> LayerNorm -> LeakyReLU] x N -> {Head_i: Dense(1)} x n_heads
When use_layer_norm=False:
Input -> [Dense(H) -> LeakyReLU] x N -> {Head_i: Dense(1)} x n_heads
The update uses VJP with accumulated cotangents to perform a single backward pass through the trunk regardless of the number of heads.
Reference: Elsayed et al. 2024, "Streaming Deep Reinforcement Learning Finally Works"
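The accumulated-cotangent trick can be sketched in a few lines of JAX. This is an illustrative standalone example, not the library's actual code: `jax.vjp` runs the trunk forward once, each active head's error is folded into a single cotangent for the trunk output, and one call to the VJP function performs the single backward pass. By linearity of the VJP, this equals the sum of per-head backward passes.

```python
import jax
import jax.numpy as jnp

def trunk(w, x):
    return jnp.tanh(w @ x)  # stand-in for the shared hidden layers

w = jnp.eye(3)
x = jnp.array([0.1, 0.2, 0.3])
head_ws = jnp.ones((2, 3))           # two linear heads reading the trunk output
errors = jnp.array([0.5, jnp.nan])   # head 1 inactive this step (NaN target)

h, vjp_fn = jax.vjp(lambda w: trunk(w, x), w)
active = ~jnp.isnan(errors)
safe_err = jnp.where(active, errors, 0.0)
# Accumulate per-head cotangents: sum_i error_i * head_w_i
cotangent = (safe_err[:, None] * head_ws).sum(axis=0)
(trunk_grad,) = vjp_fn(cotangent)    # one backward pass, regardless of n_heads
```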
MultiHeadMLPState
¶
State for a multi-head MLP learner.
The trunk (shared hidden layers) and heads (per-task output layers) maintain separate parameters, optimizer states, and eligibility traces.
Trunk optimizer states and traces use an interleaved layout
(w0, b0, w1, b1, ...) matching the MLPLearner convention.
Head optimizer states and traces use a nested layout
((w_opt, b_opt), ...) indexed by head.
Attributes:
trunk_params: Shared hidden layer parameters
head_params: Per-head output layer parameters.
weights[i] / biases[i] = head i.
trunk_optimizer_states: Interleaved (w0, b0, w1, b1, ...)
optimizer states for trunk layers
head_optimizer_states: Per-head ((w_opt, b_opt), ...)
trunk_traces: Interleaved (w0, b0, w1, b1, ...)
eligibility traces for trunk layers
head_traces: Per-head ((w_trace, b_trace), ...)
normalizer_state: Optional online feature normalizer state
step_count: Scalar step counter
MultiHeadMLPUpdateResult
¶
Result of a multi-head MLP learner update step.
Attributes:
state: Updated multi-head MLP learner state
predictions: Predictions from all heads, shape (n_heads,)
errors: Prediction errors, shape (n_heads,). NaN for inactive heads.
per_head_metrics: Per-head metrics, shape (n_heads, 3).
Columns: [squared_error, raw_error, mean_step_size].
NaN for inactive heads.
trunk_bounding_metric: Scalar trunk bounding metric
MultiHeadLearningResult
¶
Result from multi-head learning loop.
Attributes:
state: Final multi-head MLP learner state
per_head_metrics: Per-head metrics over time,
shape (num_steps, n_heads, 3)
BatchedMultiHeadResult
¶
Result from batched multi-head learning loop.
Attributes:
states: Batched multi-head MLP learner states
per_head_metrics: Per-head metrics,
shape (n_seeds, num_steps, n_heads, 3)
MultiHeadMLPLearner(n_heads, hidden_sizes=(128, 128), optimizer=None, step_size=1.0, bounder=None, gamma=0.0, lamda=0.0, normalizer=None, sparsity=0.9, leaky_relu_slope=0.01, use_layer_norm=True, head_optimizer=None, per_head_gamma_lamda=None)
¶
Multi-head MLP with shared trunk and independent prediction heads.
Architecture:
Input -> [Dense(H) -> LayerNorm -> LeakyReLU] x N -> {Head_i: Dense(1)} x n_heads
All hidden layers are shared (the trunk). Each head is an independent linear projection from the last hidden representation to a scalar.
The update method uses VJP with accumulated cotangents so that
only one backward pass through the trunk is needed regardless of the
number of active heads.
Trunk trace constraint: When hidden_sizes is non-empty (MLP mode),
trunk gamma * lamda must be 0. The VJP backward pass folds per-head
errors into the trunk cotangent before trace accumulation, so traces
accumulate error-weighted gradients. For gamma * lamda = 0 this is
correct (traces reset each step). For gamma * lamda > 0 it would
produce biased trace updates that violate forward-view equivalence
(Sutton & Barto Ch. 12). Use HordeLearner for per-head trace decay:
it sets trunk gamma=0, lamda=0 and applies per-head
gamma * lamda only to the head layers. For linear baselines
(hidden_sizes=()), there is no trunk, so any gamma * lamda is fine.
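A small numerical illustration of the constraint (a sketch, not library code). When the error is folded into the gradient before trace accumulation, a decay of 0 rebuilds the trace each step, so it equals exactly error_t * grad_t. With a positive decay, step 0's error stays baked into the step-1 trace, which is the bias described above:

```python
import jax.numpy as jnp

def trace_step(z, decay, err, grad):
    # Error is folded into the gradient before accumulation, mirroring
    # the trunk cotangent described above.
    return decay * z + err * grad

grads = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
errs = jnp.array([0.5, -0.2])

# decay = 0: the step-1 trace is exactly errs[1] * grads[1] -- unbiased
z0 = trace_step(jnp.zeros(3), 0.0, errs[1], grads[1])

# decay > 0: errs[0] is still multiplied into the step-1 trace, so the
# trace no longer factors as (error) x (accumulated gradients)
z1 = trace_step(trace_step(jnp.zeros(3), 0.9, errs[0], grads[0]),
                0.9, errs[1], grads[1])
```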
Attributes:
n_heads: Number of prediction heads
hidden_sizes: Tuple of hidden layer sizes. Pass () for a multi-head
linear model (heads project directly from input features).
optimizer: Optimizer for per-weight step-size adaptation
bounder: Optional update bounder (e.g. ObGDBounding)
normalizer: Optional feature normalizer
use_layer_norm: Whether to apply parameterless layer normalization
gamma: Discount factor for trace decay
lamda: Eligibility trace decay parameter
sparsity: Fraction of weights zeroed out per output neuron
leaky_relu_slope: Negative slope for LeakyReLU activation
Single-Step (Daemon) Usage
Both predict() and update() work with single unbatched
observations (1D arrays). This is the intended usage for daemon-style
deployments where one observation arrives at a time.
Both methods are JIT-compiled automatically. The first call triggers JAX's tracing; subsequent calls use the cached compilation. For low-latency startup, run a warmup call so the first real event is fast:
# At daemon startup, after learner.init():
dummy_obs = jnp.zeros(feature_dim)
dummy_targets = jnp.full(n_heads, jnp.nan)
learner.predict(state, dummy_obs).block_until_ready() # Warmup trace
learner.update(state, dummy_obs, dummy_targets) # Warmup trace
# First real event will now be fast
NaN target masking works per-step: pass jnp.nan for any head
that should not update. Inactive heads preserve their params,
traces, and optimizer states.
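The masking semantics can be sketched with a plain LMS step on linear heads (an assumption-level illustration, not the learner's actual update): NaN targets zero out the corresponding head's error, so that head's parameters are left untouched.

```python
import jax.numpy as jnp

def masked_head_update(head_w, targets, preds, features, lr=0.1):
    active = ~jnp.isnan(targets)
    errors = jnp.where(active, targets - preds, 0.0)
    # Each head's update is zero when its target is NaN
    return head_w + lr * errors[:, None] * features[None, :]

head_w = jnp.zeros((3, 4))
features = jnp.ones(4)
preds = head_w @ features
targets = jnp.array([1.0, jnp.nan, -1.0])  # head 1 inactive this step
new_w = masked_head_update(head_w, targets, preds, features)
```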
Args:
n_heads: Number of prediction heads
hidden_sizes: Tuple of hidden layer sizes (default: two layers of 128)
optimizer: Optimizer for weight updates. Defaults to LMS(step_size).
Must support init_for_shape and update_from_gradient.
step_size: Base learning rate (used only when optimizer is None)
bounder: Optional update bounder (e.g. ObGDBounding)
gamma: Discount factor for trace decay (default: 0.0 for supervised)
lamda: Eligibility trace decay parameter (default: 0.0)
normalizer: Optional feature normalizer
sparsity: Fraction of weights zeroed out per neuron (default: 0.9)
leaky_relu_slope: Negative slope for LeakyReLU (default: 0.01)
use_layer_norm: Whether to apply parameterless layer normalization
(default: True)
head_optimizer: Optional separate optimizer for the output heads.
When None (default), all layers use optimizer. When set,
trunk (hidden) layers use optimizer while each head uses
head_optimizer. This enables hybrid configurations like
stable LMS for the trunk with adaptive Autostep for the heads.
per_head_gamma_lamda: Optional per-head trace decay factors.
When set, each head uses its own gamma * lamda product
for trace decay instead of the global gamma * lamda.
Length must equal n_heads. Used by HordeLearner
to assign per-demon discount/trace parameters.
Source code in src/alberta_framework/core/multi_head_learner.py
n_heads
property
¶
Number of prediction heads.
normalizer
property
¶
The feature normalizer, or None if normalization is disabled.
to_config()
¶
Serialize learner configuration to dict.
Returns:
Dict with all constructor arguments needed to recreate
the learner via from_config().
Source code in src/alberta_framework/core/multi_head_learner.py
from_config(config)
classmethod
¶
Reconstruct learner from a config dict.
Args:
config: Dict as produced by to_config()
Returns:
Reconstructed MultiHeadMLPLearner instance
Source code in src/alberta_framework/core/multi_head_learner.py
init(feature_dim, key)
¶
Initialize multi-head MLP learner state with sparse weights.
Args:
feature_dim: Dimension of the input feature vector
key: JAX random key for weight initialization
Returns:
Initial state with sparse trunk weights, zero biases, and per-head output layers
Source code in src/alberta_framework/core/multi_head_learner.py
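One way to realize "sparsity fraction of weights zeroed out per output neuron" is a per-row permutation mask at init time. This is a sketch of the assumed scheme, not the library's initializer:

```python
import jax
import jax.numpy as jnp

def sparse_init(key, out_dim, in_dim, sparsity=0.9):
    w_key, m_key = jax.random.split(key)
    w = jax.random.normal(w_key, (out_dim, in_dim)) / jnp.sqrt(in_dim)
    n_zero = int(sparsity * in_dim)
    # Independently permute each row's indices; zero the first n_zero slots
    perms = jax.vmap(lambda k: jax.random.permutation(k, in_dim))(
        jax.random.split(m_key, out_dim)
    )
    mask = perms >= n_zero
    return w * mask

w = sparse_init(jax.random.PRNGKey(0), out_dim=8, in_dim=10, sparsity=0.9)
```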
predict(state, observation)
¶
Compute predictions from all heads.
JIT-compiled automatically. First call triggers tracing; subsequent calls with the same learner instance use the cached compilation.
Args:
state: Current multi-head MLP learner state
observation: Input feature vector
Returns:
Array of shape (n_heads,) with one prediction per head
Source code in src/alberta_framework/core/multi_head_learner.py
update(state, observation, targets)
¶
Update multi-head MLP given observation and per-head targets.
JIT-compiled automatically. Uses VJP with accumulated cotangents for a single backward pass through the trunk. Error from each active head is folded into the trunk gradient before trace accumulation.
Args:
state: Current state
observation: Input feature vector
targets: Per-head targets, shape (n_heads,).
NaN = inactive head.
Returns:
MultiHeadMLPUpdateResult with updated state, predictions, errors,
and per-head metrics
Source code in src/alberta_framework/core/multi_head_learner.py
multi_head_metrics_to_dicts(result)
¶
Convert per-head metrics array to list of dicts for online use.
Active heads get a dict with keys 'squared_error', 'error',
'mean_step_size'. Inactive heads get None.
Args:
result: Update result from MultiHeadMLPLearner.update
Returns:
List of n_heads entries, one per head
Source code in src/alberta_framework/core/multi_head_learner.py
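The conversion amounts to mapping each metrics row to a dict, with a NaN row signalling an inactive head. A minimal sketch under the column layout documented above ([squared_error, raw_error, mean_step_size]):

```python
import jax.numpy as jnp

def metrics_to_dicts(per_head_metrics):
    out = []
    for row in per_head_metrics:
        if bool(jnp.isnan(row[0])):
            out.append(None)  # inactive head this step
        else:
            out.append({
                "squared_error": float(row[0]),
                "error": float(row[1]),
                "mean_step_size": float(row[2]),
            })
    return out

dicts = metrics_to_dicts(jnp.array([[0.25, 0.5, 0.1],
                                    [jnp.nan, jnp.nan, jnp.nan]]))
```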
run_multi_head_learning_loop(learner, state, observations, targets)
¶
Run multi-head learning loop using jax.lax.scan.
Scans over pre-provided observation and target arrays. This is
designed for settings where data comes from an external source
(e.g. security event logs) rather than from a ScanStream.
Args:
learner: Multi-head MLP learner
state: Initial learner state
observations: Input observations, shape (num_steps, feature_dim)
targets: Per-head targets, shape (num_steps, n_heads).
NaN = inactive head for that step.
Returns:
MultiHeadLearningResult with final state and per-head metrics
of shape (num_steps, n_heads, 3)
Source code in src/alberta_framework/core/multi_head_learner.py
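The scan pattern can be shown with a stripped-down stand-in: linear heads, plain LMS, NaN-masked targets (the library's learner adds the trunk, optimizer, and traces on top of this). Each scan step emits a (n_heads, 3) metrics row, so stacking over steps yields the (num_steps, n_heads, 3) shape documented above:

```python
import jax
import jax.numpy as jnp

def step(head_w, inputs, lr=0.1):
    obs, targets = inputs
    preds = head_w @ obs
    active = ~jnp.isnan(targets)
    errors = jnp.where(active, targets - preds, 0.0)
    head_w = head_w + lr * errors[:, None] * obs[None, :]
    # Columns: [squared_error, raw_error, mean_step_size]
    metrics = jnp.stack([errors**2, errors, jnp.full_like(errors, lr)], axis=-1)
    return head_w, metrics

n_steps, feat, n_heads = 50, 4, 2
obs = jax.random.normal(jax.random.PRNGKey(0), (n_steps, feat))
true_w = jnp.ones((n_heads, feat))
targets = obs @ true_w.T  # both heads active at every step
final_w, metrics = jax.lax.scan(step, jnp.zeros((n_heads, feat)), (obs, targets))
```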
run_multi_head_learning_loop_batched(learner, observations, targets, keys)
¶
Run multi-head learning loop across seeds using jax.vmap.
Each seed produces an independently initialized state (different sparse weight masks). All seeds share the same observations and targets.
Args:
learner: Multi-head MLP learner
observations: Shared observations, shape (num_steps, feature_dim)
targets: Shared targets, shape (num_steps, n_heads).
NaN = inactive head.
keys: JAX random keys, shape (n_seeds,) or (n_seeds, 2)
Returns:
BatchedMultiHeadResult with batched states and per-head metrics
of shape (n_seeds, num_steps, n_heads, 3)
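The seed-batching pattern, sketched with the same stand-in learner (linear heads, plain LMS): `in_axes=(0, None, None)` maps over the key axis only, so every seed gets its own initialization while observations and targets are shared, producing the leading (n_seeds, ...) axis on all outputs.

```python
import jax
import jax.numpy as jnp

def init_and_run(key, obs, targets, lr=0.1):
    n_heads, feat = targets.shape[1], obs.shape[1]
    w0 = jax.random.normal(key, (n_heads, feat)) * 0.01  # per-seed init

    def step(w, inputs):
        o, t = inputs
        err = jnp.where(jnp.isnan(t), 0.0, t - w @ o)
        return w + lr * err[:, None] * o[None, :], err**2

    return jax.lax.scan(step, w0, (obs, targets))

keys = jax.random.split(jax.random.PRNGKey(0), 3)       # 3 seeds
obs = jax.random.normal(jax.random.PRNGKey(1), (20, 4))
targets = obs @ jnp.ones((2, 4)).T                      # shared across seeds
final_ws, sq_errs = jax.vmap(init_and_run, in_axes=(0, None, None))(
    keys, obs, targets
)
```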