Intro

Attention entropy collapse is a failure mode in ML where attention distributions concentrate almost all of their mass on a few keys, so their entropy drops toward zero. In practice, it is often poorly monitored and patched with normalization. In this note, I show how to extend the fused attention algorithm to compute the entropy of every attention row online, inside the same GPU kernel.


Entropy

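Consider data \(x=(x_1, \ldots, x_n)\) drawn from \(p(x)\). Recall that the entropy is \(H(p) = -\mathbb{E}_{x \sim p}[\log p(x)]\). For a distribution over \(n\) discrete states, it is maximized by the uniform distribution, where it equals \(\log n\), and collapses to zero for a delta distribution. For any model \(f\), the entropy is the gap between the cross-entropy \(H(p, f) = -\mathbb{E}_{x \sim p}[\log f(x)]\) and the KL divergence \(D_{KL}(p \| f)\), that is \(H(p) = H(p, f) - D_{KL}(p \| f)\). Since the KL divergence is non-negative (by Jensen's inequality), the entropy lower-bounds the cross-entropy of any model.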
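The identity and the two extreme cases can be checked numerically. A minimal sketch in plain Python, assuming distributions are given as lists of probabilities over the same finite support; the particular distributions are arbitrary illustrations.

```python
import math


def entropy(p):
    """H(p) = -sum_x p(x) log p(x) in nats, using the convention 0 log 0 = 0."""
    return sum(-px * math.log(px) for px in p if px > 0)


def cross_entropy(p, f):
    """H(p, f) = -sum_x p(x) log f(x); assumes f(x) > 0 wherever p(x) > 0."""
    return sum(-px * math.log(fx) for px, fx in zip(p, f) if px > 0)


def kl_divergence(p, f):
    """D_KL(p || f) = sum_x p(x) log(p(x) / f(x)); non-negative by Jensen's inequality."""
    return sum(px * math.log(px / fx) for px, fx in zip(p, f) if px > 0)


p = [0.7, 0.2, 0.1]        # a toy data distribution over 3 states
f = [0.5, 0.25, 0.25]      # an arbitrary model
uniform = [1 / 3] * 3      # maximum entropy, log 3
delta = [1.0, 0.0, 0.0]    # zero entropy

print(entropy(p), cross_entropy(p, f) - kl_divergence(p, f))  # equal up to rounding
print(entropy(uniform), math.log(3))
print(entropy(delta))
```

Swapping in any other strictly positive `f` leaves the first line of output unchanged, which is the point of the identity: the model only moves mass between the cross-entropy and the KL term.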


Model

Let \(f_\theta\) be a parametric model for the \(n\) conditional distributions \(p(x_i \mid \bar x_i)\), where \(\bar x_i = (x_1, \ldots, x_{i-1}, \emptyset, x_{i+1}, \ldots, x_n)\) is the context with \(x_i\) masked out. In ML, this is the typical subproblem of modeling \(p(x)\). Consider the simple attention model \begin{align*} q_i &= g_i\left(W_q x_i\right)\,, \quad k_j = g_j\left(W_k x_j\right), \quad v_j = W_v x_j \\[1em] \alpha_{ij} &= \frac{\exp(q_i \cdot k_j)}{\sum_{l=1}^n \exp(q_i \cdot k_l)}\,,\quad f_\theta(x)_i = \sum_{j=1}^n \alpha_{ij} v_j\,. \end{align*} The most commonly used \(g_i\) are compositions of rotary positional encodings \(g_i(y) = e^{2\pi\sqrt{-1}\, i\, \omega} \odot y\), where \(y\) is viewed as a complex vector and \(\omega\) is a vector of frequencies, and normalization \(g_i(y) = \gamma \odot \frac{y - \bar \mu}{\bar \sigma} + \beta\). The usual softmax scale \(1/\sqrt{d}\) can be absorbed into \(g_i\).
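As a baseline for the fused version, the model can be written as a naive \(O(n^2)\) reference that materializes every \(\alpha_{ij}\). This is a sketch in plain Python under stated assumptions: vectors are lists, `g` defaults to the identity (no positional encoding or normalization), and the helper names `matvec`, `softmax` and `attention` are illustrative, not from any library.

```python
import math


def matvec(W, x):
    """Multiply a matrix W (list of rows) by a vector x."""
    return [sum(w * xk for w, xk in zip(row, x)) for row in W]


def softmax(s):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(s)
    e = [math.exp(sj - m) for sj in s]
    z = sum(e)
    return [ej / z for ej in e]


def attention(x, W_q, W_k, W_v, g=lambda i, y: y):
    """Naive reference for f_theta(x)_i = sum_j alpha_ij v_j.

    g(i, y) plays the role of g_i and defaults to the identity (an assumption).
    Returns the outputs and the full attention matrix alpha, one row per query."""
    q = [g(i, matvec(W_q, xi)) for i, xi in enumerate(x)]
    k = [g(j, matvec(W_k, xj)) for j, xj in enumerate(x)]
    v = [matvec(W_v, xj) for xj in x]
    alpha = [softmax([sum(a * b for a, b in zip(qi, kj)) for kj in k]) for qi in q]
    out = [[sum(a_ij * vj[c] for a_ij, vj in zip(row, v)) for c in range(len(v[0]))]
           for row in alpha]
    return out, alpha


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
eye = [[1.0, 0.0], [0.0, 1.0]]
zero = [[0.0, 0.0], [0.0, 0.0]]

# With W_q = 0 every score is 0, so each row of alpha is uniform (maximum entropy)
# and every output is the mean of the values, here [2/3, 2/3].
out, alpha = attention(x, zero, eye, eye)
```

Scaling `W_q` up instead (for example `[[10.0, 0.0], [0.0, 10.0]]`) sharpens the rows of `alpha`, which is the regime where entropy collapse shows up.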


Fused Attention

Attention is computed with tiling and an online softmax:
1) Rabe, Staats, 'Self-Attention does not need \(O(n^2)\) memory' https://arxiv.org/pdf/2112.05682
2) Dao, Fu, et al., 'FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness' https://arxiv.org/pdf/2205.14135
3+) FlashAttention-2, FlashAttention-3, ...
The algorithm is best explained in Triton jargon (I find it more intuitive than CUDA). Notation is mostly consistent with the code below.
The work is divided into parallel programs, each processing a block of queries \(q_i\). From now on, \(i\) denotes a block of query indices rather than a single index, and \(j\) a tile of key indices. Each program loops over tiles of keys and values and computes the softmax online by updating the accumulators \(m_i, l_i, acc_i\) as \begin{align*} s_{ij} &\leftarrow q_i \cdot k_j \\ m_{ij} &\leftarrow \max\big(m_i, \max_{\text{tile}} s_{ij}\big) \\ l_{ij} &\leftarrow l_i \exp(m_i - m_{ij}) + \sum_{\text{tile}} \exp(s_{ij} - m_{ij}) \\ acc_i &\leftarrow acc_i \exp(m_i - m_{ij}) + \sum_{\text{tile}} \exp(s_{ij} - m_{ij})\, v_j \\ m_i &\leftarrow m_{ij} \\ l_i &\leftarrow l_{ij} \end{align*} The factor \(\exp(m_i - m_{ij})\) rescales everything accumulated so far whenever a tile raises the running max. After the last tile, \(m_i\) is the max of the dot products across all tiles and \(l_i = \sum_{j=1}^n \exp(s_{ij} - m_i)\) is the partition function relative to it. The output is \(f_\theta(x)_i = acc_i / l_i\). This algorithm can be extended to compute the entropy online with near-zero additional overhead.
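The update rules can be sketched for a single query row in plain Python and compared against a direct softmax. This is an illustration under assumptions, not the kernel: natural `exp` instead of `exp2`, no score scaling, scalar loops instead of tiles of a block, and an arbitrary tile size; `online_attention_row` is a hypothetical helper name.

```python
import math


def online_attention_row(scores, values, tile=2):
    """One query row of attention, processing keys/values tile by tile.

    scores[j] = q_i . k_j and values[j] = v_j. Keeps only the running max m,
    the partition function l (relative to m) and the unnormalized output acc."""
    m, l, acc = -math.inf, 0.0, [0.0] * len(values[0])
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m, max(s_tile))
        alpha = math.exp(m - m_new)  # rescales earlier tiles; exp(-inf) = 0 on the first tile
        p = [math.exp(s - m_new) for s in s_tile]
        l = l * alpha + sum(p)
        acc = [a * alpha + sum(pj * vj[c] for pj, vj in zip(p, v_tile))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc], m, l


scores = [0.3, -1.2, 2.5, 0.0, 1.1]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5], [-1.0, 3.0]]
out, m, l = online_attention_row(scores, values, tile=2)

# Direct softmax over all scores at once, for comparison
w = [math.exp(s) for s in scores]
ref = [sum(wj * vj[c] for wj, vj in zip(w, values)) / sum(w) for c in range(2)]
```

The result does not depend on the tile size, and \(m + \log l\) equals the log-sum-exp of all scores, which is the quantity the entropy derivation reuses.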


Online Entropy

With \(p_{ij} = \exp(s_{ij} - m_i)/l_i\), we can rearrange the entropy computation as \begin{align*} H(p_i) &= -\sum_{j=1}^n p_{ij} \log p_{ij} \\ &= -\sum_{j=1}^n p_{ij} (s_{ij} - m_i - \log l_i) \\ &= m_i + \log l_i - \frac{1}{l_i}\sum_{j=1}^n \exp(s_{ij} - m_i)\, s_{ij} \end{align*} where the last step uses \(\sum_j p_{ij} = 1\). The first term \(m_i + \log l_i\), the log-sum-exp of the scores, is already computed online. The last term can be computed online as well with an additional accumulator \(r_i = \sum_j \exp(s_{ij} - m_i)\, s_{ij}\). It is exactly \(acc_i\) with \(v_j\) replaced by \(s_{ij}\), so it is updated in the same way: \begin{align*} r_i &\leftarrow r_i \exp(m_i - m_{ij}) + \sum_{\text{tile}} \exp(s_{ij} - m_{ij})\, s_{ij} \end{align*} The final result for the entropy is then \(H(p_i) = m_i + \log l_i - r_i / l_i\).
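A plain-Python sketch of the extra accumulator for one query row, checked against a two-pass computation. As before, it assumes natural logarithms, unscaled scores and an arbitrary tile size; `online_entropy_row` and `entropy_direct` are illustrative names.

```python
import math


def online_entropy_row(scores, tile=2):
    """Entropy of softmax(scores) in one pass over tiles of scores.

    Besides m and l, keeps r = sum_j exp(s_j - m) * s_j, rescaled exactly like l
    (and like acc in the attention update)."""
    m, l, r = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        m_new = max(m, max(s_tile))
        alpha = math.exp(m - m_new)
        p = [math.exp(s - m_new) for s in s_tile]
        l = l * alpha + sum(p)
        r = r * alpha + sum(pj * s for pj, s in zip(p, s_tile))
        m = m_new
    return m + math.log(l) - r / l  # H = m + log l - r / l


def entropy_direct(scores):
    """Two-pass reference: normalize first, then H = -sum_j p_j log p_j."""
    m = max(scores)
    z = sum(math.exp(s - m) for s in scores)
    return sum(-(math.exp(s - m) / z) * math.log(math.exp(s - m) / z) for s in scores)


scores = [0.3, -1.2, 2.5, 0.0, 1.1]
print(online_entropy_row(scores), entropy_direct(scores))  # agree up to rounding
print(online_entropy_row([1.0] * 4), math.log(4))           # uniform row: log n
print(online_entropy_row([0.0, 0.0, 50.0]))                 # one dominant key: close to 0
```

One caveat of this form: for a very peaked row, \(r_i / l_i\) is close to \(m_i\) and the subtraction cancels, so the absolute error scales with the magnitude of the scores times the working precision.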

GPU kernel

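Any implementation of fused attention can be extended with this idea. It is particularly simple here because it only concerns the forward kernel: the entropy is a monitoring statistic and needs no gradient. Below is an extension of the Fused Attention tutorial kernel from OpenAI's Triton. As in that tutorial, \(\log_2 e\) is folded into qk_scale and exponentials are computed with exp2, so \(m_i\) and \(r_i\) are measured in scaled scores \(t_{ij} = s_{ij} \log_2 e\) (with \(s_{ij}\) the dot product times sm_scale) and logarithms are base 2. See my GitHub for the full code and usage.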

import triton
import triton.language as tl


@triton.jit
def _attn_fwd(sm_scale, M,
    Z, H,
    ENTROPY,
    desc_q, desc_k, desc_v, desc_o, N_CTX,
    COMPUTE_ENTROPY: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    FP8_OUTPUT: tl.constexpr,
    STAGE: tl.constexpr,
    warp_specialize: tl.constexpr,
    IS_HOPPER: tl.constexpr,
):
    ...
    offset_y = off_z * (N_CTX * H) + off_h * N_CTX
    qo_offset_y = offset_y + start_m * BLOCK_M
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
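    # initialize running max m_i, partition function l_i and output accumulator acc
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # entropy accumulator; allocated unconditionally so it is always defined when threaded
    # through _attn_fwd_inner, and simply unused when COMPUTE_ENTROPY is False
    r_i = tl.zeros([BLOCK_M], dtype=tl.float32)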

    ...
    # epilogue
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(m_ptrs, m_i)
    desc_o.store([qo_offset_y, 0], acc.to(dtype))
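    if COMPUTE_ENTROPY:
        # m_i already includes log2(l_i) here, so H = m_i - r_i / l_i in bits;
        # multiply by ln(2) to store the entropy in nats
        entropy_ptrs = ENTROPY + off_hz * N_CTX + offs_m
        tl.store(entropy_ptrs, (m_i - r_i / l_i) * 0.6931471805599453)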

    

@triton.jit
def _attn_fwd_inner(
    acc, l_i, m_i,
    r_i,
    q,
    desc_k, desc_v,
    offset_y, dtype: tl.constexpr, start_m, qk_scale,
    COMPUTE_ENTROPY: tl.constexpr,
    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
    N_CTX: tl.constexpr, warp_specialize: tl.constexpr, IS_HOPPER: tl.constexpr
):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # causal = False
    else:
        lo, hi = 0, N_CTX
    offsetk_y = offset_y + lo
    if dtype == tl.float8e5:
        offsetv_y = offset_y * HEAD_DIM + lo
    else:
        offsetv_y = offset_y + lo
    # loop over k, v and update accumulator
    for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = desc_k.load([offsetk_y, 0]).T
        qk = tl.dot(q, k)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]

        else:
            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
            qk = qk * qk_scale - m_ij[:, None]

        p = tl.math.exp2(qk)
        # -- compute correction factor
        alpha = tl.math.exp2(m_i - m_ij)
        l_ij = tl.sum(p, 1)
        if COMPUTE_ENTROPY:
            # -- update entropy accumulator r_i = sum_j exp2(t_ij - m_i) * t_ij --
            # qk + m_ij recovers the unshifted scaled score t_ij. r_i is rescaled by the
            # same alpha as acc and l_i, and uses the fp32 p before it is cast to dtype.
            r_i = r_i * alpha + tl.sum(p * (qk + m_ij[:, None]), 1)
        # -- update output accumulator --
        if not IS_HOPPER and warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128:
            BM: tl.constexpr = acc.shape[0]
            BN: tl.constexpr = acc.shape[1]
            acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
            acc0 = acc0 * alpha[:, None]
            acc1 = acc1 * alpha[:, None]
            acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
        else:
            acc = acc * alpha[:, None]
        # prepare p and v for the dot
        if dtype == tl.float8e5:
            v = desc_v.load([0, offsetv_y]).T
        else:
            v = desc_v.load([offsetv_y, 0])
        p = p.to(dtype)
        # note that this non transposed v for FP8 is only supported on Blackwell
        acc = tl.dot(p, v, acc)
        # update m_i and l_i
        # place this at the end of the loop to reduce register pressure
        l_i = l_i * alpha + l_ij
        m_i = m_ij
        offsetk_y += BLOCK_N
        offsetv_y += BLOCK_N
    return acc, l_i, m_i, r_i
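The base-2 bookkeeping is easy to get wrong, so it can be checked on the CPU. The following is a hypothetical plain-Python emulation of one query row (no Triton involved): it applies the same shifted-score, exp2 and rescaling steps per block, assumes the entropy is recovered as \((m_i - r_i/l_i)\ln 2\) once \(\log_2 l_i\) has been added to \(m_i\), and compares the result with a direct natural-log computation. The block size, scale and scores are arbitrary.

```python
import math

LOG2E = 1.4426950408889634  # log2(e): folded into qk_scale so exp2 can be used
LN2 = 0.6931471805599453    # ln(2): converts an entropy in bits to nats


def kernel_style_entropy(qk_row, sm_scale, block_n=2):
    """Hypothetical emulation of the base-2 entropy path for one query row.

    qk_row holds the raw dot products q . k_j; the softmax is over sm_scale * qk_row."""
    qk_scale = sm_scale * LOG2E
    m_i, l_i, r_i = -math.inf, 1.0, 0.0  # l_i starts at 1.0; alpha = 0 erases it on block 1
    for start in range(0, len(qk_row), block_n):
        qk = [s * qk_scale for s in qk_row[start:start + block_n]]  # scaled scores t_ij
        m_ij = max(m_i, max(qk))
        qk = [t - m_ij for t in qk]                                  # shifted scores
        p = [2.0 ** t for t in qk]
        alpha = 2.0 ** (m_i - m_ij)
        r_i = r_i * alpha + sum(pj * (t + m_ij) for pj, t in zip(p, qk))
        l_i = l_i * alpha + sum(p)
        m_i = m_ij
    m_i += math.log2(l_i)            # base-2 log-sum-exp
    return (m_i - r_i / l_i) * LN2   # bits -> nats


def reference_entropy(qk_row, sm_scale):
    """Entropy in nats of softmax(sm_scale * qk_row), computed directly."""
    s = [x * sm_scale for x in qk_row]
    m = max(s)
    z = sum(math.exp(v - m) for v in s)
    return sum(-(math.exp(v - m) / z) * (v - m - math.log(z)) for v in s)


qk_row = [3.0, -1.0, 4.5, 0.5, 2.0, -2.5]
print(kernel_style_entropy(qk_row, 0.5), reference_entropy(qk_row, 0.5))
```

The same comparison is a reasonable unit test for the real kernel: compute the reference entropy from the materialized attention matrix on a small input and compare it with the values stored in ENTROPY.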