Intro

Attention entropy collapse is a failure mode in ML where the attention distributions become nearly deterministic: most of the probability mass concentrates on a few keys and the entropy of the attention weights drops toward zero. In practice this is rarely monitored directly and is usually patched with normalization. In this note, I discuss how to extend the fused attention algorithm to compute this entropy online, within the same GPU kernel.


Entropy

Consider data \(x=(x_1, \ldots, x_n)\) drawn from \(p(x)\). Recall the entropy \(H(p) = -\mathbb{E}_{x \sim p}[\log p(x)]\). It is maximized by the uniform distribution and collapses to zero for a delta distribution. For any model \(f\), entropy is the gap between the cross-entropy \(-\mathbb{E}_{x \sim p}[\log f(x)]\) and the KL divergence \(D_{KL}(p \,\|\, f)\): the cross-entropy decomposes as \(-\mathbb{E}_{x \sim p}[\log f(x)] = H(p) + D_{KL}(p \,\|\, f)\).
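
As a quick numerical sanity check of this identity, here is a minimal NumPy sketch (the distributions p and f are arbitrary categorical distributions, not from the rest of this note):

import numpy as np

rng = np.random.default_rng(0)
p = rng.random(8); p /= p.sum()          # "data" distribution p
f = rng.random(8); f /= f.sum()          # "model" distribution f

entropy = -np.sum(p * np.log(p))         # H(p)
cross_entropy = -np.sum(p * np.log(f))   # H(p, f)
kl = np.sum(p * np.log(p / f))           # D_KL(p || f)

assert np.isclose(entropy, cross_entropy - kl)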


Model

Let \(f_\theta\) be a parametric model for the \(n\) conditional distributions \(p(x_i \mid \bar x_i)\), where \(\bar x_i\) is the context \((x_1, \ldots, x_{i-1}, \emptyset, x_{i+1}, \ldots, x_n)\). In ML, this is the typical subproblem of modeling \(p(x)\). Consider the simple attention model \begin{align*} q_i &= g_i\left(W_q x_i\right)\,, \quad k_j = g_j\left(W_k x_j\right), \quad v_j = W_v x_j \\[1em] \alpha_{ij} &= \frac{\exp(q_i \cdot k_j)}{\sum_{l=1}^n \exp(q_i \cdot k_l)}\,,\quad f_\theta(x)_i = \sum_{j=1}^n \alpha_{ij} v_j\,. \end{align*} The most commonly used \(g_i\) are compositions of positional encodings, \(g_i(y) = e^{2\pi \mathrm{i}\, \gamma_i}\, y\) in complex notation (a rotation whose angle depends on the position \(i\)), and normalization, \(g_i(y) = \gamma \odot \frac{y - \bar \mu}{\bar \sigma} + \beta\).
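
For concreteness, here is a naive PyTorch sketch of this layer (a single head, no \(g\); the function name and shapes are mine, not from any library):

import torch

def naive_attention(x, W_q, W_k, W_v):
    # x: (n, d); W_q, W_k, W_v: (d, d)
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    s = q @ k.T                            # logits s_ij = q_i . k_j
    alpha = torch.softmax(s, dim=-1)       # attention weights alpha_ij
    out = alpha @ v                        # f_theta(x)_i = sum_j alpha_ij v_j
    # per-query attention entropy H(p_i) = -sum_j alpha_ij log alpha_ij
    entropy = -(alpha * torch.log(alpha)).sum(dim=-1)
    return out, entropy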


Fused Attention

Attention is computed with tiling and an online softmax:
1) Rabe & Staats, 'Self-Attention Does Not Need \(O(n^2)\) Memory', https://arxiv.org/pdf/2112.05682
2) Dao, Fu, et al., 'FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness', https://arxiv.org/pdf/2205.14135
3+) FlashAttention-2, FlashAttention-3, ...
The algorithm is best explained in Triton terms (I find this more intuitive than CUDA). The notation below mostly follows OpenAI's Triton tutorial kernel.
The work is divided into parallel programs, each processing a block of queries \(q_i\). Note that from now on \(i\) denotes a block of row indices rather than a single index, and the sums and maxima below run over the current tile of keys/values. Each program loops over tiles of keys and values and maintains the softmax online by updating accumulators \(m_i, l_i, acc_i\) as \begin{align*} s_{ij} &\leftarrow q_i \cdot k_j \\ m_{ij} &\leftarrow \max\left(m_i, \max_{\text{tile}} s_{ij}\right) \\ \alpha_i &\leftarrow \exp(m_i - m_{ij}) \\ l_i &\leftarrow l_i\, \alpha_i + \sum_{\text{tile}} \exp(s_{ij} - m_{ij}) \\ acc_i &\leftarrow acc_i\, \alpha_i + \sum_{\text{tile}} \exp(s_{ij} - m_{ij})\, v_j \\ m_i &\leftarrow m_{ij} \end{align*} After the last tile, \(m_i\) is the max of the dot products across all tiles and \(l_i = \sum_{j=1}^n \exp(s_{ij} - m_i)\) is the partition function. The output is \(f_\theta(x)_i = acc_i / l_i\). (In the kernel the logits are additionally multiplied by sm_scale and the exponentials are taken in base 2; this does not change the algorithm.) This algorithm can be extended to compute the entropy online with near zero additional overhead.
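
To make the update rules concrete, here is a toy single-query version of the online softmax in plain PyTorch (the tile size and all names are mine; the real kernel processes blocks of queries in parallel):

import torch

def online_softmax_attention(q, K, V, tile=128):
    # q: (d,), K, V: (n, d); computes softmax(K @ q) @ V tile by tile
    m = torch.tensor(float("-inf"))       # running max m_i
    l = torch.tensor(0.0)                 # running partition function l_i
    acc = torch.zeros(V.shape[1])         # unnormalized output accumulator acc_i
    for start in range(0, K.shape[0], tile):
        s = K[start:start + tile] @ q     # tile of logits s_ij
        m_new = torch.maximum(m, s.max())
        alpha = torch.exp(m - m_new)      # correction factor for the old accumulators
        p = torch.exp(s - m_new)
        l = l * alpha + p.sum()
        acc = acc * alpha + p @ V[start:start + tile]
        m = m_new
    return acc / l

# sanity check against the naive computation
q, K, V = torch.randn(64), torch.randn(1000, 64), torch.randn(1000, 64)
ref = torch.softmax(K @ q, dim=0) @ V
assert torch.allclose(online_softmax_attention(q, K, V), ref, atol=1e-4)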


Online Entropy

With \(p_{ij} = \exp(s_{ij} - m_i) / l_i\), we can rearrange the entropy computation as \begin{align*} H(p_i) &= -\sum_{j=1}^n p_{ij} \log p_{ij} \\ &= -\sum_{j=1}^n p_{ij} (s_{ij} - m_i - \log l_i) \\ &= m_i + \log l_i - \frac{1}{l_i}\sum_{j=1}^n \exp(s_{ij} - m_i)\, s_{ij}\,. \end{align*} The first term \(m_i + \log l_i\) is the log-sum-exp, which the online softmax already computes. The last term can be computed online as well using one additional accumulator \(r_i = \sum_j \exp(s_{ij} - m_i)\, s_{ij}\) with the same update rule as \(acc_i\): \begin{align*} r_i &\leftarrow r_i\, \alpha_i + \sum_{\text{tile}} \exp(s_{ij} - m_{ij})\, s_{ij}\,. \end{align*} The final result for the entropy is then \(H(p_i) = m_i + \log l_i - r_i / l_i\).
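
The toy sketch from above, extended with the accumulator \(r_i\) (again a single-query illustration under my naming, not the kernel itself):

import torch

def online_attention_entropy(q, K, V, tile=128):
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    acc = torch.zeros(V.shape[1])
    r = torch.tensor(0.0)                 # r_i = sum_j exp(s_ij - m_i) * s_ij
    for start in range(0, K.shape[0], tile):
        s = K[start:start + tile] @ q
        m_new = torch.maximum(m, s.max())
        alpha = torch.exp(m - m_new)
        p = torch.exp(s - m_new)
        l = l * alpha + p.sum()
        acc = acc * alpha + p @ V[start:start + tile]
        r = r * alpha + (p * s).sum()     # same rescaling as acc_i
        m = m_new
    entropy = m + torch.log(l) - r / l    # H(p_i)
    return acc / l, entropy

# sanity check against the naive entropy
q, K, V = torch.randn(64), torch.randn(1000, 64), torch.randn(1000, 64)
p_ref = torch.softmax(K @ q, dim=0)
h_ref = -(p_ref * torch.log(p_ref)).sum()
assert torch.allclose(online_attention_entropy(q, K, V)[1], h_ref, atol=1e-4)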

GPU Kernel

Any implementation of fused attention can be extended with this idea. It is particularly simple here because only the forward kernel is involved. Below is a minimal extension of OpenAI's Triton tutorial kernel; '...' marks unchanged code that is elided.

@triton.jit
def _attn_fwd(
    sm_scale, M, Z, H,
    Ent,
    desc_q, desc_k, desc_v, desc_o, 
    N_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    FP8_OUTPUT: tl.constexpr,
    STAGE: tl.constexpr,
    warp_specialize: tl.constexpr,
    IS_HOPPER: tl.constexpr,
    COMPUTE_ENTROPY: tl.constexpr,
):
    ...
    offset_y = off_z * (N_CTX * H) + off_h * N_CTX
    qo_offset_y = offset_y + start_m * BLOCK_M
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize the online-softmax state: running max m_i, partition function l_i, output acc
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # entropy accumulator r_i = sum_j exp2(s_ij - m_i) * s_ij; only used when COMPUTE_ENTROPY,
    # but initialized unconditionally so it can be passed through _attn_fwd_inner either way
    r_i = tl.zeros([BLOCK_M], dtype=tl.float32)

    ...
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i, r_i = _attn_fwd_inner(
            acc, l_i, m_i, r_i, q,  #
            desc_k, desc_v,  #
            offset_y, dtype, start_m, qk_scale,  #
            BLOCK_M, HEAD_DIM, BLOCK_N,  #
            4 - STAGE, offs_m, offs_n, N_CTX,  #
            warp_specialize, IS_HOPPER,
            COMPUTE_ENTROPY,
        )
    # stage 2: on-band
    if STAGE & 2:
        acc, l_i, m_i, r_i = _attn_fwd_inner(
            acc, l_i, m_i, r_i, q,  #
            desc_k, desc_v,  #
            offset_y, dtype, start_m, qk_scale,  #
            BLOCK_M, HEAD_DIM, BLOCK_N,  #
            2, offs_m, offs_n, N_CTX,  #
            warp_specialize, IS_HOPPER,
            COMPUTE_ENTROPY,
        )
    # epilogue
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(m_ptrs, m_i)
    desc_o.store([qo_offset_y, 0], acc.to(dtype))
    if COMPUTE_ENTROPY:
        entropy_ptrs = Ent + off_hz * N_CTX + offs_m
        # m_i now holds the base-2 log-sum-exp m_i + log2(l_i); the kernel works in log2 space,
        # so H = ln(2) * (m_i + log2(l_i) - r_i / l_i) converts the entropy back to nats
        tl.store(entropy_ptrs, (m_i - r_i / l_i) * 0.6931471805599453)

    

@triton.jit
def _attn_fwd_inner(
    acc, l_i, m_i,  r_i,
    q, desc_k, desc_v,
    offset_y, dtype: tl.constexpr, start_m, qk_scale,
    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
    N_CTX: tl.constexpr, warp_specialize: tl.constexpr, IS_HOPPER: tl.constexpr,
    COMPUTE_ENTROPY: tl.constexpr,
):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # causal = False
    else:
        lo, hi = 0, N_CTX
    offsetk_y = offset_y + lo
    if dtype == tl.float8e5:
        offsetv_y = offset_y * HEAD_DIM + lo
    else:
        offsetv_y = offset_y + lo
    # loop over k, v and update accumulator
    for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = desc_k.load([offsetk_y, 0]).T
        qk = tl.dot(q, k)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            # qk2 holds the scaled (base-2) masked logits s_ij; kept for the entropy accumulator
            qk2 = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk2, 1))
            qk = qk2 - m_ij[:, None]
        else:
            qk2 = qk * qk_scale
            m_ij = tl.maximum(m_i, tl.max(qk2, 1))
            qk = qk2 - m_ij[:, None]
        p = tl.math.exp2(qk)
        # -- compute correction factor
        alpha = tl.math.exp2(m_i - m_ij)
        l_ij = tl.sum(p, 1)
        if COMPUTE_ENTROPY:
            # accumulate r_i = sum_j exp2(s_ij - m_i) * s_ij using the fp32 p and the scaled
            # logits qk2 (before p is downcast below); the old value is rescaled by alpha like acc
            r_i = r_i * alpha + tl.sum(p * qk2, 1)
        # -- update output accumulator --
        if not IS_HOPPER and warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128:
            BM: tl.constexpr = acc.shape[0]
            BN: tl.constexpr = acc.shape[1]
            acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
            acc0 = acc0 * alpha[:, None]
            acc1 = acc1 * alpha[:, None]
            acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
        else:
            acc = acc * alpha[:, None]
        # prepare p and v for the dot
        if dtype == tl.float8e5:
            v = desc_v.load([0, offsetv_y]).T
        else:
            v = desc_v.load([offsetv_y, 0])
        p = p.to(dtype)
        # note that this non transposed v for FP8 is only supported on Blackwell
        acc = tl.dot(p, v, acc)
        # update m_i and l_i
        # place this at the end of the loop to reduce register pressure
        l_i = l_i * alpha + l_ij
        m_i = m_ij

        offsetk_y += BLOCK_N
        offsetv_y += BLOCK_N
    return acc, l_i, m_i,  r_i
    

class _attention(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx, q, k, v,
        causal, sm_scale,
        warp_specialize=True,
        compute_entropy=False,
        
    ):
        # shape constraints
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        # when v is in float8_e5m2 it is transposed.
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}
        o = torch.empty_like(q)
        stage = 3 if causal else 1
        extra_kern_args = {}
        # Tuning for AMD target
        if is_hip():
            waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
            extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}

        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        # per-query attention entropy, shape (batch, heads, seq_len), always float32
        ent = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
    
        # Use device_descriptor for Hopper + warpspec.
        if supports_host_descriptor() and not (is_hopper() and warp_specialize):
            # Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor
            y_dim = q.shape[0] * q.shape[1] * q.shape[2]

            dummy_block = [1, 1]
            desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
            if q.dtype == torch.float8_e5m2:
                desc_v = TensorDescriptor(v, shape=[HEAD_DIM_K, y_dim], strides=[q.shape[2], 1],
                                            block_shape=dummy_block)
            else:
                desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1],
                                            block_shape=dummy_block)
            desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
            desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
        else:
            desc_q = q
            desc_v = v
            desc_k = k
            desc_o = o

        def alloc_fn(size: int, align: int, _):
            return torch.empty(size, dtype=torch.int8, device="cuda")

        triton.set_allocator(alloc_fn)

        def grid(META):
            return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)

        ctx.grid = grid
        if is_blackwell() and warp_specialize:
            if HEAD_DIM_K == 128 and q.dtype == torch.float16:
                extra_kern_args["maxnreg"] = 168
            else:
                extra_kern_args["maxnreg"] = 80
        _attn_fwd[grid](
            sm_scale, M,  #
            q.shape[0], q.shape[1],  #
            ent,
            desc_q, desc_k, desc_v, desc_o,  #
            N_CTX=q.shape[2],  #
            COMPUTE_ENTROPY=compute_entropy,
            HEAD_DIM=HEAD_DIM_K,  #
            FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
            STAGE=stage,  #
            warp_specialize=warp_specialize,  #
            IS_HOPPER=is_hopper(),  #
            **extra_kern_args)

        ctx.save_for_backward(q, k, v, o, M)
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        ctx.causal = causal
        # the entropy is a diagnostic; mark it non-differentiable and return it alongside o
        # (the backward, unchanged and not shown here, only has to ignore the extra grad output)
        ctx.mark_non_differentiable(ent)
        return o, ent
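
A hypothetical usage sketch, assuming a CUDA device, the tutorial's imports, and its usual attention = _attention.apply wrapper (shapes are (batch, heads, seq_len, head_dim)):

attention = _attention.apply

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
sm_scale = 1.0 / 64 ** 0.5
# causal=True, warp_specialize=True, compute_entropy=True
o, ent = attention(q, k, v, True, sm_scale, True, True)
# ent: (batch, heads, seq_len) float32, one attention entropy per query, in nats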