Practical Guide to LLM: KV Cache
Intuition, shapes, and a bit of implementation.
This is part of a series of short posts that aims to provide intuitions and implementation hints for LLM basics. I will be posting more as I refresh my knowledge. If you like this style of learning, follow for more :)
Note: as we are both learning, please point out if anything seems inaccurate so that we can improve together.
What is KV?
K (Key) is a matrix representation of the context that can be looked up.
V (Value) is a matrix representation of the context content to be retrieved.
Attention in an LLM = use a query (from the prompt) to match topics of interest in K (Key) and retrieve the relevant content from V (Value).
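To make this lookup intuition concrete, here is a minimal numpy sketch of scaled dot-product attention: the query is scored against each key, the scores are normalized with softmax, and the result is a weighted blend of the values. (Toy shapes and random data are assumptions for illustration only.)

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    # Scores: how well the query matches each key ("topic lookup").
    scores = q @ k.T / np.sqrt(k.shape[-1])
    # Weights sum to 1 over the context; blend the values (the content).
    weights = softmax(scores, axis=-1)
    return weights @ v

# Toy shapes: 1 query token, 4 context tokens, embedding dim 8.
rng = np.random.default_rng(0)
q = rng.standard_normal((1, 8))
k = rng.standard_normal((4, 8))
v = rng.standard_normal((4, 8))
out = attention(q, k, v)
print(out.shape)  # (1, 8)
```

The output has the same shape as the query: one blended content vector per query token.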
In practice, most of the time, K and V are just linear projections of the input embeddings (a matrix product with learned weights):
import torch
import torch.nn as nn

class ExampleAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, q, k, v):
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)
        # Do attention calculation here.
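Because K and V for past tokens depend only on those tokens' embeddings and fixed weights, they never change during generation; that is the observation behind the KV cache. Here is a minimal numpy sketch of the caching idea, using hypothetical fixed weight matrices W_k and W_v in place of the learned projections:

```python
import numpy as np

rng = np.random.default_rng(0)
embed_dim = 8
# Hypothetical fixed weights standing in for the learned k_proj / v_proj.
W_k = rng.standard_normal((embed_dim, embed_dim))
W_v = rng.standard_normal((embed_dim, embed_dim))

k_cache, v_cache = [], []
for step in range(3):
    x = rng.standard_normal((1, embed_dim))  # embedding of the newest token
    # Only the newest token's K and V are computed; past rows are reused.
    k_cache.append(x @ W_k)
    v_cache.append(x @ W_v)
    K = np.concatenate(k_cache)  # shape (step + 1, embed_dim)
    V = np.concatenate(v_cache)
print(K.shape, V.shape)  # (3, 8) (3, 8)
```

Each decode step does one row's worth of projection work instead of reprojecting the whole context, which is the speedup the cache buys.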