Attention entropy collapse is a failure mode in ML where the attention distributions become nearly one-hot, concentrating their mass on a few keys. In practice, this is often poorly monitored and patched after the fact with normalization. In this note, I discuss how to extend the attention algorithm to compute the entropy online, inside the same GPU kernel.
Consider data \(x=(x_1, \ldots, x_n)\) drawn from \(p(x)\). Recall the entropy \(H(p) = -\mathbb{E}_{x \sim p}[\log p(x)]\). It is maximized by the uniform distribution and collapses to zero for a delta distribution. For any model \(f\), the cross-entropy decomposes as \(-\mathbb{E}_{x \sim p}[\log f(x)] = H(p) + D_{KL}(p \,\|\, f)\), so the entropy is the gap between the cross-entropy and the KL divergence.
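This decomposition is easy to check numerically (a minimal sketch with arbitrary categorical distributions):

import numpy as np

rng = np.random.default_rng(0)
p = rng.random(8); p /= p.sum()   # data distribution
f = rng.random(8); f /= f.sum()   # model distribution

entropy = -(p * np.log(p)).sum()
cross_entropy = -(p * np.log(f)).sum()
kl = (p * np.log(p / f)).sum()

assert np.isclose(cross_entropy, entropy + kl)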
Let \(f_\theta\) be a parametric model for the \(n\) conditional distributions \(p(x_i \mid \bar x_i)\), where \(\bar x_i = (x_1, \ldots, x_{i-1}, \emptyset, x_{i+1}, \ldots, x_n)\) is the context with the \(i\)-th entry masked out. This is the typical subproblem when modeling \(p(x)\) in ML. Consider the simple attention model \begin{align*} q_i &= g_i\left(W_q x_i\right)\,, \quad k_j = g_j\left(W_k x_j\right), \quad v_j = W_v x_j \\[1em] \alpha_{ij} &= \frac{\exp(q_i \cdot k_j)}{\sum_{l=1}^n \exp(q_i \cdot k_l)}\,,\quad f_\theta(x)_i = \sum_{j=1}^n \alpha_{ij} v_j\,. \end{align*} The most commonly used \(g_i\) are compositions of rotary positional encodings \(g_i(y) = e^{2\pi \mathrm{i}\, \gamma i}\, y\) (rotations with a position-dependent angle, acting on pairs of coordinates viewed as complex numbers) and normalization \(g(y) = \gamma \odot \frac{y - \mu}{\sigma} + \beta\), with \(\mu, \sigma\) the mean and standard deviation of \(y\).
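For later reference, this model and its per-row attention entropy \(H_i = -\sum_j \alpha_{ij} \log \alpha_{ij}\) take a few lines of PyTorch (a minimal sketch; \(g\) is taken to be the identity and all names are illustrative):

import torch

def attention_with_entropy(q, k, v):
    # q, k, v: [n, d]; returns f_theta(x) and the entropy of each row alpha_i
    s = q @ k.T                                   # s_ij = q_i . k_j
    alpha = torch.softmax(s, dim=-1)              # attention distribution per query
    out = alpha @ v                               # f_theta(x)_i = sum_j alpha_ij v_j
    entropy = -(alpha * alpha.log()).sum(-1)      # in nats; collapses to 0 when one-hot
    return out, entropy

n, d = 128, 64
x = torch.randn(n, d)
Wq, Wk, Wv = (torch.randn(d, d) / d ** 0.5 for _ in range(3))
out, ent = attention_with_entropy(x @ Wq, x @ Wk, x @ Wv)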
In practice, attention is computed with tiling and an online softmax:
1) Rabe, Staats, 'Self-Attention Does Not Need \(O(n^2)\) Memory', https://arxiv.org/pdf/2112.05682
2) Dao, Fu, et al., 'FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness', https://arxiv.org/pdf/2205.14135
3+) FlashAttention-2, FlashAttention-3, ...
The algorithm is best explained in Triton jargon (I find it more intuitive than CUDA). The notation is mostly consistent with the code below.
The work is divided into parallel programs each processing a block of queries \(q_i\).
Note that from now on \(i\) denotes a block of indices rather than a single index.
Each program loops over tiles of keys and values and computes the softmax online by updating accumulators \(m_i\), \(l_i\), \(acc_i\) as
\begin{align*}
s_{ij} &\leftarrow q_i \cdot k_j \\
m_{ij} &\leftarrow \max\left(m_i, \max_{\text{tile}} s_{ij}\right) \\
l_{ij} &\leftarrow \sum_{\text{tile}} \exp(s_{ij} - m_{ij}) \\
acc_i &\leftarrow acc_i \exp(m_i - m_{ij}) + \sum_{\text{tile}} \exp(s_{ij} - m_{ij})\, v_j \\
l_i &\leftarrow l_i \exp(m_i - m_{ij}) + l_{ij} \\
m_i &\leftarrow m_{ij}
\end{align*}
This computes the max \(m_i\) of dot products across all tiles and the partition function \(l_i = \sum_{j=1}^n \exp(s_{ij} - m_i)\). The output is \(f_\theta(x)_i = acc_i / l_i\).
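The recurrence is easy to sanity-check outside the kernel (a minimal NumPy sketch for a single query row; the tile size and names are arbitrary):

import numpy as np

def online_softmax_attention(s, v, tile=32):
    # s: [n] scores for one query row, v: [n, d]; single pass over tiles
    m, l, acc = -np.inf, 0.0, np.zeros(v.shape[1])
    for t in range(0, len(s), tile):
        s_t, v_t = s[t:t + tile], v[t:t + tile]
        m_new = max(m, s_t.max())
        alpha = np.exp(m - m_new)        # rescales old accumulators to the new max
        p = np.exp(s_t - m_new)          # unnormalized tile probabilities
        acc = acc * alpha + p @ v_t
        l = l * alpha + p.sum()
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
s, v = rng.normal(size=256), rng.normal(size=(256, 64))
p = np.exp(s - s.max()); p /= p.sum()
assert np.allclose(online_softmax_attention(s, v), p @ v)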
This algorithm can be extended to compute the entropy online at near-zero additional overhead.
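To see why, let \(Z_i = \sum_{l=1}^n \exp(s_{il})\) so that \(\alpha_{ij} = \exp(s_{ij})/Z_i\). The entropy of row \(i\) expands as \begin{align*} H_i = -\sum_{j=1}^n \alpha_{ij} \log \alpha_{ij} = \log Z_i - \sum_{j=1}^n \alpha_{ij}\, s_{ij} = m_i + \log l_i - \mathbb{E}_{\alpha_i}[s]\,. \end{align*} Online softmax already maintains \(m_i\) and \(l_i\); the only new quantity is the softmax-weighted mean of the scores. One convenient per-tile update for a running mean \(r_i\) (the form used in the code below) is \begin{align*} r_i \leftarrow \frac{r_i\, l_i \exp(m_i - m_{ij}) + \sum_{\text{tile}} \exp(s_{ij} - m_{ij})\, s_{ij}}{l_i \exp(m_i - m_{ij}) + l_{ij}}\,, \end{align*} after which \(H_i = m_i + \log l_i - r_i\). The kernel works in base 2 (exp2/log2, with \(\log_2 e\) folded into qk_scale), so the stored entropy is in bits; multiply by \(\ln 2\) for nats.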
@triton.jit
def _attn_fwd(sm_scale, M,
              Z, H,
              ENTROPY,
              desc_q, desc_k, desc_v, desc_o, N_CTX,
              COMPUTE_ENTROPY: tl.constexpr,
              HEAD_DIM: tl.constexpr,
              BLOCK_M: tl.constexpr,
              BLOCK_N: tl.constexpr,
              FP8_OUTPUT: tl.constexpr,
              STAGE: tl.constexpr,
              warp_specialize: tl.constexpr,
              IS_HOPPER: tl.constexpr,
              ):
    ...
    offset_y = off_z * (N_CTX * H) + off_h * N_CTX
    qo_offset_y = offset_y + start_m * BLOCK_M
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize the online-softmax accumulators
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    if COMPUTE_ENTROPY:
        # running softmax-weighted mean of the (base-2 scaled) scores
        r_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    ...
    # epilogue
    m_i += tl.math.log2(l_i)  # m_i now holds log2 of the partition function
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(m_ptrs, m_i)
    desc_o.store([qo_offset_y, 0], acc.to(dtype))
    if COMPUTE_ENTROPY:
        entropy_ptrs = ENTROPY + off_hz * N_CTX + offs_m
        # entropy in bits: log2(Z) - E[scores]; multiply by ln(2) for nats
        tl.store(entropy_ptrs, m_i - r_i)
@triton.jit
def _attn_fwd_inner(
        acc, l_i, m_i,
        r_i,
        q,
        desc_k, desc_v,
        offset_y, dtype: tl.constexpr, start_m, qk_scale,
        COMPUTE_ENTROPY: tl.constexpr,
        BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
        STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
        N_CTX: tl.constexpr, warp_specialize: tl.constexpr, IS_HOPPER: tl.constexpr,
):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # causal = False
    else:
        lo, hi = 0, N_CTX
    offsetk_y = offset_y + lo
    if dtype == tl.float8e5:
        offsetv_y = offset_y * HEAD_DIM + lo
    else:
        offsetv_y = offset_y + lo
    # loop over k, v and update accumulator
    for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = desc_k.load([offsetk_y, 0]).T
        qk = tl.dot(q, k)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
            qk = qk * qk_scale - m_ij[:, None]
        p = tl.math.exp2(qk)
        # -- compute correction factor and tile row sums --
        alpha = tl.math.exp2(m_i - m_ij)
        l_ij = tl.sum(p, 1)
        if COMPUTE_ENTROPY:
            # tile contribution to the softmax-weighted sum of the scaled
            # scores (qk holds scores shifted by m_ij); done while p is fp32
            r_ij = tl.sum(p * (qk + m_ij[:, None]), 1)
        # -- update output accumulator --
        if not IS_HOPPER and warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128:
            BM: tl.constexpr = acc.shape[0]
            BN: tl.constexpr = acc.shape[1]
            acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
            acc0 = acc0 * alpha[:, None]
            acc1 = acc1 * alpha[:, None]
            acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
        else:
            acc = acc * alpha[:, None]
        # prepare p and v for the dot
        if dtype == tl.float8e5:
            v = desc_v.load([0, offsetv_y]).T
        else:
            v = desc_v.load([offsetv_y, 0])
        p = p.to(dtype)
        # note that this non transposed v for FP8 is only supported on Blackwell
        acc = tl.dot(p, v, acc)
        # update m_i and l_i
        # place this at the end of the loop to reduce register pressure
        l_new = l_i * alpha + l_ij
        if COMPUTE_ENTROPY:
            # online update of the running mean r_i = E[scores] over the keys
            # processed so far: the old mean is rescaled by its old normalizer
            # l_i * alpha and renormalized by the new one l_new
            r_i = (r_i * l_i * alpha + r_ij) / l_new
        l_i = l_new
        m_i = m_ij
        offsetk_y += BLOCK_N
        offsetv_y += BLOCK_N
    return acc, l_i, m_i, r_i
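Since the host-side launch code is elided above, the check against a direct computation is only sketched here (non-causal case; ent_bits stands for the tensor read back from the ENTROPY buffer by a hypothetical launcher, and multiplying by \(\ln 2\) converts the kernel's base-2 entropy to nats):

import torch

def reference_entropy(q, k, sm_scale):
    # direct O(n^2) entropy of each attention row, in nats
    alpha = torch.softmax((q @ k.transpose(-2, -1)) * sm_scale, dim=-1)
    return -(alpha * alpha.log()).sum(-1)

# hypothetical check, with ent_bits read back from ENTROPY:
# torch.testing.assert_close(ent_bits * 0.6931471805599453,  # ln(2)
#                            reference_entropy(q, k, sm_scale))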