MUSTAFAR reading note


KV Cache and Pruning Strategies - Study Notes

What is KV Cache?

Purpose

KV cache is a memory optimization technique used in transformer models during text generation to avoid redundant computations.

How it Works

  • Problem: Without caching, generating each new token requires recomputing Key (K) and Value (V) matrices for all previous tokens
  • Solution: Store the K and V representations of previous tokens and compute K and V only for the new token (see the sketch after this list)
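
As a minimal sketch of the idea, the snippet below runs one decode step for a single attention head: it computes Q, K, and V only for the newest token, appends the new K and V rows to the cache, and attends over everything cached so far. The head dimension, projection weights, and embeddings are random stand-ins, not values from any real model.

Sketch (Python/NumPy):

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

d = 8                                           # hypothetical head dimension
rng = np.random.default_rng(0)
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))  # stand-in weights

def decode_step(x_new, K, V):
    """One decode step: compute K/V only for the new token, append them
    to the cache, and attend over every cached position."""
    q = x_new @ W_q                             # query for the new token only
    K = np.vstack([K, x_new @ W_k])             # append the new K row
    V = np.vstack([V, x_new @ W_v])             # append the new V row
    attn = softmax(q @ K.T / np.sqrt(d))        # scores over all cached tokens
    return attn @ V, K, V

K = np.empty((0, d))
V = np.empty((0, d))
for _ in range(3):                              # feed three tokens one at a time
    x = rng.standard_normal((1, d))             # stand-in embedding for the new token
    out, K, V = decode_step(x, K, V)
print(K.shape, V.shape)                         # (3, 8) (3, 8): one cached row per token

Earlier K and V rows are only ever read from the cache; nothing for previous tokens is recomputed.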

Example Process

Generating "The cat sat on"

Step 1: Generate "cat"
- Input: "The"
- Compute K₁, V₁ for "The"
- Cache: K=[K₁], V=[V₁]

Step 2: Generate "sat" 
- Input: "The cat"
- Compute K₂, V₂ for "cat" 
- Cache: K=[K₁, K₂], V=[V₁, V₂]
- Reuse K₁, V₁ (no recomputation!)

Step 3: Generate "on"
- Input: "The cat sat"
- Compute K₃, V₃ for "sat"
- Cache: K=[K₁, K₂, K₃], V=[V₁, V₂, V₃]
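
The same walkthrough can be mirrored with a toy loop that only tracks cache growth; the embeddings and projection weights are again random stand-ins.

Sketch (Python/NumPy):

import numpy as np

rng = np.random.default_rng(0)
d = 4                                              # toy channel count
emb = {w: rng.standard_normal((1, d)) for w in ["The", "cat", "sat"]}
W_k, W_v = rng.standard_normal((d, d)), rng.standard_normal((d, d))

K, V = np.empty((0, d)), np.empty((0, d))
for step, word in enumerate(["The", "cat", "sat"], start=1):
    # Only the newest token's K/V are computed; earlier rows are reused as-is.
    K = np.vstack([K, emb[word] @ W_k])
    V = np.vstack([V, emb[word] @ W_v])
    print(f"Step {step}: K shape {K.shape}, V shape {V.shape}")
# Step 1: (1, 4) ... Step 3: (3, 4), matching K=[K₁, K₂, K₃] above.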

KV Cache Structure

Matrix Dimensions

  • Format: [tokens × channels]
  • Tokens: Sequence positions (words/subwords in the input)
  • Channels: Feature dimensions (hidden size of the model, e.g., 768, 1024, 4096)
  • Growth: Cache grows as the sequence lengthens: [1×channels] → [2×channels] → [3×channels]… (the footprint sketch below puts numbers on this)
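
To see why the [tokens × channels] growth matters, here is a rough back-of-the-envelope footprint calculation. It assumes every layer caches full hidden-size channels in fp16 and ignores grouped-query attention or any other sharing, so the 7B-style numbers are illustrative only.

Sketch (Python/NumPy):

def kv_cache_bytes(n_layers, n_tokens, n_channels, bytes_per_elem=2):
    """Rough KV cache size: K and V each store [tokens x channels] per layer.
    Assumes every layer caches full hidden-size channels (no grouped-query sharing)."""
    return 2 * n_layers * n_tokens * n_channels * bytes_per_elem

# Hypothetical 7B-class config: 32 layers, 4096 channels, fp16 values.
size = kv_cache_bytes(n_layers=32, n_tokens=4096, n_channels=4096)
print(f"{size / 2**30:.1f} GiB")                 # ~2.0 GiB for a 4096-token sequence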

Key Properties

  • Both K and V caches have identical dimensions
  • Channels size is determined by model architecture
  • Each element represents the intersection of a token and a channel

Pruning Strategies

Core Concepts

  • Pruning Direction: Which axis to remove elements from
  • Output-Awareness: Using scoring metrics to estimate element importance
  • Local Dense Window: Keep the most recent 32 tokens untouched during decoding (sketched below)
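
A minimal sketch of these three ideas together, using element magnitude weighted by per-token attention mass as a simplified stand-in for the paper's output-aware score; the cache, attention weights, and 50% target below are made up for illustration.

Sketch (Python/NumPy):

import numpy as np

rng = np.random.default_rng(0)
cache = rng.standard_normal((64, 8))        # toy [tokens x channels] cache
attn_mass = rng.random(64)                  # stand-in for per-token attention mass

# Output-aware-style score: element magnitude weighted by how much the token
# is attended to (a rough proxy, not the paper's exact metric).
scores = np.abs(cache) * attn_mass[:, None]

# Keep roughly half of the entries, but never prune the last 32 tokens.
keep = scores >= np.median(scores)
keep[-32:, :] = True                        # local dense window stays untouched
pruned = np.where(keep, cache, 0.0)
print(f"kept {keep.mean():.0%} of entries")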

1. Per-Channel Pruning

Definition: For each channel (column), selectively remove some token entries

How it works:

  • Look at each channel across all tokens
  • Apply different sparsity patterns to different channels
  • Remove elements within each channel vector

Example:

Original:
        Ch1  Ch2  Ch3  Ch4  Ch5  Ch6
Token1: [a,   b,   c,   d,   e,   f]
Token2: [g,   h,   i,   j,   k,   l]
Token3: [m,   n,   o,   p,   q,   r]
Token4: [s,   t,   u,   v,   w,   x]

After Per-Channel Pruning:
        Ch1  Ch2  Ch3  Ch4  Ch5  Ch6
Token1: [a,   -,   c,   d,   -,   f]
Token2: [g,   h,   -,   -,   k,   l]
Token3: [-,   n,   o,   p,   q,   -]
Token4: [s,   -,   u,   v,   -,   x]
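
A sketch of per-channel pruning on a [tokens × channels] cache, using plain magnitude as a stand-in importance score and a made-up keep ratio; each column ends up with its own sparsity pattern, as in the example above.

Sketch (Python/NumPy):

import numpy as np

def prune_per_channel(cache, keep_ratio=0.5):
    """For each channel (column), keep only the largest-magnitude token
    entries; everything else is zeroed. Magnitude is a simple stand-in
    for an output-aware importance score."""
    n_tokens, n_channels = cache.shape
    k = max(1, int(n_tokens * keep_ratio))
    pruned = np.zeros_like(cache)
    for c in range(n_channels):
        col = cache[:, c]
        top = np.argsort(np.abs(col))[-k:]   # indices of the k largest entries
        pruned[top, c] = col[top]            # each column gets its own pattern
    return pruned

rng = np.random.default_rng(0)
K = rng.standard_normal((4, 6))              # 4 tokens x 6 channels, as above
print(prune_per_channel(K, keep_ratio=0.5))  # 2 survivors per column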

2. Per-Token Pruning

Definition: For each token (row), selectively remove some channel entries

How it works:

  • Look at each token across all channels
  • Apply different sparsity patterns to different tokens
  • Remove elements within each token’s representation

Example:

Original:
        Ch1  Ch2  Ch3  Ch4  Ch5  Ch6
Token1: [a,   b,   c,   d,   e,   f]
Token2: [g,   h,   i,   j,   k,   l]
Token3: [m,   n,   o,   p,   q,   r]
Token4: [s,   t,   u,   v,   w,   x]

After Per-Token Pruning:
        Ch1  Ch2  Ch3  Ch4  Ch5  Ch6
Token1: [a,   b,   -,   -,   e,   f]  ← 4/6 kept (~67%)
Token2: [g,   -,   i,   j,   -,   l]  ← 4/6 kept (~67%)
Token3: [m,   n,   -,   p,   q,   r]  ← 5/6 kept (~83%)
Token4: [-,   t,   u,   -,   w,   x]  ← 4/6 kept (~67%)
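
And the per-token counterpart: each row keeps its own top-magnitude channels, again with magnitude as a stand-in score and an arbitrary keep ratio chosen to roughly match the 4-of-6 pattern above.

Sketch (Python/NumPy):

import numpy as np

def prune_per_token(cache, keep_ratio=0.67):
    """For each token (row), keep only the largest-magnitude channel
    entries; the rest are zeroed. Each row ends up with its own pattern."""
    n_tokens, n_channels = cache.shape
    k = max(1, int(round(n_channels * keep_ratio)))
    pruned = np.zeros_like(cache)
    for t in range(n_tokens):
        row = cache[t]
        top = np.argsort(np.abs(row))[-k:]   # k most important channels for this token
        pruned[t, top] = row[top]
    return pruned

rng = np.random.default_rng(0)
V = rng.standard_normal((4, 6))              # 4 tokens x 6 channels, as above
print(prune_per_token(V, keep_ratio=0.67))   # ~4 of 6 entries survive per row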

Key Differences Between Pruning Strategies

Aspect              Per-Channel Pruning              Per-Token Pruning
Direction           Vertical (across tokens)         Horizontal (across channels)
Unit                Channel vector                   Token vector
Sparsity Pattern    Different for each channel       Different for each token
What’s Removed      Token entries within channels    Channel entries within tokens

Important Notes

  • Both strategies create unstructured sparsity (irregular patterns)
  • Each channel captures different features/aspects of the representation
  • Each token has its own unique representation across channels
  • The choice between strategies depends on the specific use case and model characteristics
  • Recent tokens (last 32) are typically preserved for accuracy