Create a sparsely initialized weight matrix.
Applies LeCun-scale initialization and then zeros out a fraction of
weights per output neuron. This creates sparser gradient flows that
improve stability in streaming learning settings.
Reference: Elsayed et al. 2024, sparse_init.py
Args:
    key: JAX random key
    shape: Weight matrix shape (fan_out, fan_in)
    sparsity: Fraction of input connections to zero out per output neuron
        (default: 0.9 means 90% sparse)
    init_type: Initialization distribution, "uniform" or "normal"
        (default: "uniform" for LeCun uniform)

Returns:
    Weight matrix of given shape with specified sparsity

Examples:

```python
import jax.random as jr
from alberta_framework.core.initializers import sparse_init

key = jr.key(42)
weights = sparse_init(key, (128, 10), sparsity=0.9)
# weights has shape (128, 10), ~90% zeros per row
```
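To sanity-check the result, you can count the zeroed inputs in each output row. The sketch below assumes the same import path as above and uses the normal-distribution variant so that any exact zero must come from the sparsity mask rather than the draw:

```python
import jax.numpy as jnp
import jax.random as jr
from alberta_framework.core.initializers import sparse_init

key = jr.key(0)
fan_out, fan_in = 64, 20
weights = sparse_init(key, (fan_out, fan_in), sparsity=0.9, init_type="normal")

# Each row should contain round(0.9 * 20) = 18 zeroed inputs.
zeros_per_row = jnp.sum(weights == 0.0, axis=1)
print(zeros_per_row)  # array of 18s, shape (64,)
```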
Source code in src/alberta_framework/core/initializers.py
````python
def sparse_init(
    key: Array,
    shape: tuple[int, int],
    sparsity: float = 0.9,
    init_type: str = "uniform",
) -> Float[Array, "fan_out fan_in"]:
    """Create a sparsely initialized weight matrix.

    Applies LeCun-scale initialization and then zeros out a fraction of
    weights per output neuron. This creates sparser gradient flows that
    improve stability in streaming learning settings.

    Reference: Elsayed et al. 2024, sparse_init.py

    Args:
        key: JAX random key
        shape: Weight matrix shape (fan_out, fan_in)
        sparsity: Fraction of input connections to zero out per output neuron
            (default: 0.9 means 90% sparse)
        init_type: Initialization distribution, "uniform" or "normal"
            (default: "uniform" for LeCun uniform)

    Returns:
        Weight matrix of given shape with specified sparsity

    Examples:
        ```python
        import jax.random as jr
        from alberta_framework.core.initializers import sparse_init

        key = jr.key(42)
        weights = sparse_init(key, (128, 10), sparsity=0.9)
        # weights has shape (128, 10), ~90% zeros per row
        ```
    """
    fan_out, fan_in = shape
    num_zeros = int(sparsity * fan_in + 0.5)  # round to nearest int

    # Split key for init and sparsity mask
    init_key, mask_key = jr.split(key)

    # LeCun-scale initialization
    scale = 1.0 / fan_in**0.5
    if init_type == "uniform":
        weights = jr.uniform(init_key, shape, dtype=jnp.float32, minval=-scale, maxval=scale)
    elif init_type == "normal":
        weights = jr.normal(init_key, shape, dtype=jnp.float32) * scale
    else:
        raise ValueError(f"init_type must be 'uniform' or 'normal', got '{init_type}'")

    # Create sparsity mask: for each output neuron, zero out num_zeros inputs
    # Use vmap over output neurons with independent random permutations
    row_keys = jr.split(mask_key, fan_out)

    def make_row_mask(row_key: Array) -> Float[Array, " fan_in"]:
        """Create a binary mask for a single output neuron."""
        perm = jr.permutation(row_key, fan_in)
        # mask[i] = 1 if perm[i] >= num_zeros, else 0
        mask = (perm >= num_zeros).astype(jnp.float32)
        return mask

    masks = jax.vmap(make_row_mask)(row_keys)  # (fan_out, fan_in)
    return weights * masks
````
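A typical use of the initializer is to build the weight matrix of a linear unit that is then updated online. The sketch below is illustrative only: `sparse_init` comes from the module above, while the forward pass, loss, and SGD step are hypothetical stand-ins for whatever streaming learner the framework provides.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from alberta_framework.core.initializers import sparse_init

fan_out, fan_in = 32, 128
w_key, x_key = jr.split(jr.key(0))

# Sparse weights plus a dense bias for a single linear unit (illustrative).
params = {
    "w": sparse_init(w_key, (fan_out, fan_in), sparsity=0.9),
    "b": jnp.zeros((fan_out,)),
}

def predict(params, x):
    return params["w"] @ x + params["b"]

def loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

# One SGD step on a single streamed example (hypothetical training loop).
x = jr.normal(x_key, (fan_in,))
y = jnp.zeros((fan_out,))
grads = jax.grad(loss)(params, x, y)
params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
```

Note that only the initialization is sparse: gradient updates can repopulate the zeroed entries unless the learner also applies a mask during training.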