Daemon / Single-Step Usage

This guide covers using alberta-framework learners in daemon-style deployments where observations arrive one at a time (e.g. rlsecd).

Single-Step API

Both MLPLearner and MultiHeadMLPLearner accept single unbatched 1D observations. This is the intended usage for online, per-event processing:

import jax.numpy as jnp
import jax.random as jr
from alberta_framework import MultiHeadMLPLearner, LMS, ObGDBounding, EMANormalizer

learner = MultiHeadMLPLearner(
    n_heads=5,
    hidden_sizes=(64, 64),
    optimizer=LMS(step_size=0.01),
    bounder=ObGDBounding(),
    normalizer=EMANormalizer(),
)
state = learner.init(feature_dim=12, key=jr.key(0))

# Single observation in, predictions out
observation = jnp.ones(12)
predictions = learner.predict(state, observation)  # shape (5,)

# Single-step update with NaN masking for inactive heads
targets = jnp.array([1.0, 0.5, jnp.nan, 0.3, jnp.nan])  # heads at indices 2 and 4 are inactive
result = learner.update(state, observation, targets)
state = result.state  # carry forward
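
The NaN convention above can be pictured as a mask over the per-head errors. The following is a minimal sketch in plain jax.numpy, not the library's implementation; `masked_squared_error` is a hypothetical helper introduced only for illustration.

```python
import jax.numpy as jnp


def masked_squared_error(predictions, targets):
    """Hypothetical helper: squared error per head, zeroed where the target is NaN."""
    active = ~jnp.isnan(targets)                    # True for heads that have a label
    safe_targets = jnp.where(active, targets, 0.0)  # replace NaN before any arithmetic
    errors = jnp.where(active, (predictions - safe_targets) ** 2, 0.0)
    return errors, active


predictions = jnp.array([0.9, 0.4, 0.7, 0.1, 0.2])
targets = jnp.array([1.0, 0.5, jnp.nan, 0.3, jnp.nan])
errors, active = masked_squared_error(predictions, targets)
```

Replacing NaN before the subtraction, rather than only masking afterwards, keeps NaN out of the values and out of any gradient taken through this function.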

JIT Compilation

predict() and update() are JIT-compiled automatically on both MLPLearner and MultiHeadMLPLearner. The first call triggers JAX's tracing and compilation; subsequent calls with the same input shapes and dtypes reuse the cached compilation.

For low-latency startup (avoiding a slow first real event), run a warmup call during initialization:

# Warmup at daemon startup
feature_dim, n_heads = 12, 5  # match the learner's configuration
dummy_obs = jnp.zeros(feature_dim)
dummy_targets = jnp.full(n_heads, jnp.nan)  # all heads masked
learner.predict(state, dummy_obs).block_until_ready()
learner.update(state, dummy_obs, dummy_targets)  # result discarded, state is not advanced
# First real event is now fast (~0.3ms vs ~20ms without JIT)
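
To see the effect of warmup on your own hardware, a timing harness along these lines can help. It uses a stand-in jitted function rather than the learner, so the numbers it prints say nothing about alberta-framework itself; `toy_predict` and `timed_call` are hypothetical names.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_predict(weights, observation):
    """Stand-in for a jitted predict(): one dense layer with tanh."""
    return jnp.tanh(weights @ observation)


def timed_call(fn, *args):
    """Run fn(*args), wait for the result, and return (result, seconds)."""
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()  # dispatch is asynchronous; wait before stopping the clock
    return out, time.perf_counter() - start


weights = jnp.ones((5, 12))
observation = jnp.zeros(12)

_, first = timed_call(toy_predict, weights, observation)     # tracing + compilation
out, second = timed_call(toy_predict, weights, observation)  # cached executable
print(f"first call: {first * 1e3:.2f} ms, second call: {second * 1e3:.2f} ms")
```

Because the compilation cache is keyed on input shapes and dtypes, warmup arrays should match real events in both; a float64 or differently shaped dummy would warm a cache entry that real traffic never hits.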

Scan loops are unaffected

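The jax.lax.scan-based learning loops (e.g. run_multi_head_learning_loop) already compile the outer scan. A jitted function called inside an already-traced context, such as a scan body, is traced into the outer computation rather than compiled and dispatched separately, so the built-in JIT on predict/update adds no runtime overhead in scan contexts.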
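
The interaction between a jitted step and an outer scan can be checked on a toy function. This sketch does not use the library's loops; `step` is a hypothetical stand-in for a jitted update.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(weight, x):
    """Toy jitted update: move the weight toward x, emit the pre-update error."""
    error = x - weight
    return weight + 0.1 * error, error


xs = jnp.arange(5.0)

# The jitted step used as a scan body: it is traced into the scan.
final_scan, errors_scan = jax.lax.scan(step, jnp.float32(0.0), xs)

# The same computation driven from Python, one jitted call per element.
weight = jnp.float32(0.0)
errors_loop = []
for x in xs:
    weight, err = step(weight, x)
    errors_loop.append(err)
errors_loop = jnp.stack(errors_loop)
```

Both paths produce the same carry and the same per-step errors, so a function decorated for single-step use can be passed to scan unchanged.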

Checkpoints

Save and restore learner state across daemon restarts:

from alberta_framework import save_checkpoint, load_checkpoint

# Save state + daemon metadata
save_checkpoint(state, "agent.ckpt", metadata={
    "total_updates": 100_000,
    "daemon_version": "1.0",
})

# Load (template provides PyTree structure)
template = learner.init(feature_dim=12, key=jr.key(0))
loaded_state, meta = load_checkpoint(template, "agent.ckpt")
print(meta["total_updates"])  # 100000

Config Serialization

Round-trip the learner configuration (architecture, optimizer, bounder, normalizer) as a JSON-serializable dict:

import json

# Save config
config = learner.to_config()
with open("learner_config.json", "w") as f:
    json.dump(config, f)

# Reconstruct learner from config
with open("learner_config.json") as f:
    config = json.load(f)
learner = MultiHeadMLPLearner.from_config(config)
state = learner.init(feature_dim=12, key=jr.key(0))

This pairs with checkpoints: save the config alongside the state so a daemon can fully reconstruct itself on restart without hardcoding architecture parameters.
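
One way to keep the two files together is to derive the config path from the checkpoint path and write the JSON atomically. This is a sketch under assumptions: the `.config.json` naming convention and the helper names are hypothetical, and a plain dict stands in for the output of to_config().

```python
import json
import os
import tempfile
from pathlib import Path


def config_path_for(checkpoint_path):
    """Hypothetical convention: agent.ckpt -> agent.config.json, same directory."""
    return Path(checkpoint_path).with_suffix(".config.json")


def write_config(config, checkpoint_path):
    """Write the config next to the checkpoint via a temp file and a rename."""
    target = config_path_for(checkpoint_path)
    tmp = target.with_name(target.name + ".tmp")
    with open(tmp, "w") as f:
        json.dump(config, f, indent=2)
    os.replace(tmp, target)  # atomic when both paths are on the same filesystem
    return target


def read_config(checkpoint_path):
    """Load the config stored next to the given checkpoint."""
    with open(config_path_for(checkpoint_path)) as f:
        return json.load(f)


# A plain dict stands in for learner.to_config(); its keys are made up.
example_config = {"n_heads": 5, "hidden_sizes": [64, 64]}
workdir = tempfile.mkdtemp()
ckpt = os.path.join(workdir, "agent.ckpt")
written = write_config(example_config, ckpt)
restored = read_config(ckpt)
```

Writing to a temporary file first means a daemon killed mid-write leaves the previous config intact instead of a truncated one.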

Feature Diagnostics

For periodic reporting (e.g. every 60s), extract per-feature relevance from the learner state at zero cost:

from alberta_framework import compute_feature_relevance, relevance_to_dict

relevance = compute_feature_relevance(state)
report = relevance_to_dict(
    relevance,
    feature_names=["src_ip", "dst_port", "payload_len", ...],
    head_names=["is_malicious", "attack_type", "stage", "severity", "value"],
)
# report is a JSON-serializable dict ready for logging/storage
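
For reporting on a wall-clock interval such as the 60 s mentioned above, a small trigger object is one option. `PeriodicTrigger` is a hypothetical helper, not part of alberta-framework; the injected clock exists only so the behavior can be checked without sleeping.

```python
import time


class PeriodicTrigger:
    """Hypothetical helper: due() returns True at most once per interval."""

    def __init__(self, interval, clock=time.monotonic):
        self.interval = interval
        self.clock = clock
        self.next_due = clock() + interval

    def due(self):
        now = self.clock()
        if now >= self.next_due:
            self.next_due = now + self.interval  # schedule relative to this firing
            return True
        return False


# Drive the trigger with a fake clock to show when it fires.
fake_now = [0.0]
trigger = PeriodicTrigger(60.0, clock=lambda: fake_now[0])
fired = []
for t in (10.0, 59.9, 60.0, 61.0, 120.0):
    fake_now[0] = t
    fired.append(trigger.due())
```

In a daemon, the default time.monotonic clock is the safer choice because it is unaffected by system clock adjustments.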

For deeper analysis, compute input sensitivity via the Jacobian (one forward pass per head, ~100-500 µs):

from alberta_framework import compute_feature_sensitivity

jacobian = compute_feature_sensitivity(learner, state, observation)
# shape: (n_heads, feature_dim); entry [h, f] is the sensitivity of head h to input f
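
To make the shape and meaning of such a Jacobian concrete, here is a sketch on a toy two-layer network using jax.jacobian. It is not how compute_feature_sensitivity is implemented; `toy_forward` and the parameter layout are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def toy_forward(params, observation):
    """Stand-in two-layer network: feature_dim -> hidden -> n_heads."""
    hidden = jnp.tanh(params["w1"] @ observation + params["b1"])
    return params["w2"] @ hidden + params["b2"]


feature_dim, hidden_dim, n_heads = 12, 8, 5
k1, k2 = jax.random.split(jax.random.key(0))
params = {
    "w1": jax.random.normal(k1, (hidden_dim, feature_dim)) * 0.1,
    "b1": jnp.zeros(hidden_dim),
    "w2": jax.random.normal(k2, (n_heads, hidden_dim)) * 0.1,
    "b2": jnp.zeros(n_heads),
}
observation = jnp.ones(feature_dim)

# Differentiate the outputs with respect to the observation (argument 1).
jacobian = jax.jacobian(toy_forward, argnums=1)(params, observation)

# Rank features by mean absolute sensitivity across heads, most sensitive first.
ranking = jnp.argsort(-jnp.abs(jacobian).mean(axis=0))
```

The Jacobian is local: it describes sensitivity around this particular observation, so rankings can differ from one input to the next.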

Complete Daemon Pattern

Putting it all together:

import jax.numpy as jnp
import jax.random as jr
from alberta_framework import (
    MultiHeadMLPLearner, LMS, ObGDBounding, EMANormalizer,
    save_checkpoint, load_checkpoint,
    compute_feature_relevance, relevance_to_dict,
)

FEATURE_DIM = 12
N_HEADS = 5
CHECKPOINT_PATH = "agent.ckpt"

# 1. Create or restore learner
learner = MultiHeadMLPLearner(
    n_heads=N_HEADS,
    hidden_sizes=(64, 64),
    optimizer=LMS(step_size=0.01),
    bounder=ObGDBounding(),
    normalizer=EMANormalizer(),
)
state = learner.init(feature_dim=FEATURE_DIM, key=jr.key(42))

# Optional: restore from checkpoint
# template = learner.init(feature_dim=FEATURE_DIM, key=jr.key(0))
# state, meta = load_checkpoint(template, CHECKPOINT_PATH)

# 2. Warmup JIT
dummy_obs = jnp.zeros(FEATURE_DIM)
dummy_targets = jnp.full(N_HEADS, jnp.nan)
learner.predict(state, dummy_obs).block_until_ready()
learner.update(state, dummy_obs, dummy_targets)

# 3. Event loop (event_source and log_diagnostics are supplied by the application)
for event in event_source:
    obs = jnp.array(event.features, dtype=jnp.float32)

    # Predict
    predictions = learner.predict(state, obs)

    # Update (if labels available)
    if event.has_labels:
        targets = jnp.array(event.targets, dtype=jnp.float32)
        result = learner.update(state, obs, targets)
        state = result.state

    # Periodic checkpoint
    if event.step % 10_000 == 0:
        save_checkpoint(state, CHECKPOINT_PATH)

    # Periodic diagnostics
    if event.step % 1_000 == 0:
        rel = compute_feature_relevance(state)
        report = relevance_to_dict(rel)
        log_diagnostics(report)
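
The loop above relies on two names it does not define, event_source and log_diagnostics. For local testing, stand-ins along the following lines may be enough. The `Event` fields mirror the attributes the loop reads (features, targets, has_labels, step); everything else here, including `synthetic_events`, is hypothetical.

```python
import json
import random
from dataclasses import dataclass
from typing import Iterator, List, Optional


@dataclass
class Event:
    """Hypothetical event record with the attributes the loop reads."""
    step: int
    features: List[float]
    targets: Optional[List[float]] = None

    @property
    def has_labels(self) -> bool:
        return self.targets is not None


def synthetic_events(n_events, feature_dim=12, n_heads=5, seed=0) -> Iterator[Event]:
    """Yield random events; every third one is labeled, with NaN for inactive heads."""
    rng = random.Random(seed)
    for step in range(1, n_events + 1):  # start at 1 so step % N == 0 skips startup
        features = [rng.random() for _ in range(feature_dim)]
        targets = None
        if step % 3 == 0:
            targets = [rng.random() if rng.random() < 0.5 else float("nan")
                       for _ in range(n_heads)]
        yield Event(step=step, features=features, targets=targets)


def log_diagnostics(report) -> str:
    """Serialize a report dict to one JSON line, print it, and return it."""
    line = json.dumps(report, sort_keys=True)
    print(line)
    return line


event_source = synthetic_events(n_events=9)
events = list(event_source)
```

In a real deployment event_source would wrap a socket, queue, or log tailer, and log_diagnostics would write to the daemon's logging backend.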