-
KV CacheResearch/NLP_reference 2024. 9. 29. 11:19
https://medium.com/@joaolages/kv-caching-explained-276520203249
Transformers KV Caching Explained
- How caching Key and Value states makes transformers faster
Caching the Key (K) and Value (V) states of generative transformers has been around for a while, but maybe you need to understand what it is exactly, and the great inference speedups that it provides.
The Key and Value states are used for calculating the scaled dot-product attention, as is seen in the image below.
KV caching occurs during multiple token generation steps and only happens in the decoder (i.e., in decoder-only models like GPT, or in the decoder part of encoder-decoder models like T5). Models like BERT are not generative and therefore do not have KV caching.
The decoder works in an auto-regressive fashion, as depicted in this GPT-2 text generation example.
This auto-regressive behavior repeats some operations, and we can better understand this by zooming in on the masked scaled dot-product attention computation that is calculated in the decoder.
Since the decoder is causal (i.e., the attention of a token only depends on its preceding tokens), at each generation step we are recalculating the same previous token attention, when we actually just want to calculate the attention for the new token.
This is where KV comes into play. By caching the previous Keys and Values, we can focus on only calculating the attention for the new token.
Why is this optimization important? As seen in the picture above, the matrices obtained with KV caching are way smaller, which leads to faster matrix multiplications. The only downside is that it needs more GPU VRAM (or CPU RAM if GPU is not being used) to cache the Key and Value states.
Let’s use transformers 🤗 to compare the generation speed of GPT-2 with and without KV caching.
import numpy as np import time import torch from transformers import AutoModelForCausalLM, AutoTokenizer device = "cuda" if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained("gpt2") model = AutoModelForCausalLM.from_pretrained("gpt2").to(device) for use_cache in (True, False): times = [] for _ in range(10): # measuring 10 generations start = time.time() model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1000) times.append(time.time() - start) print(f"{'with' if use_cache else 'without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")
On a Google Colab notebook, using a Tesla T4 GPU, these were the reported average and standard deviation times, for generating 1000 new tokens:
with KV caching: 11.885 +- 0.272 seconds
without KV caching: 56.197 +- 1.855 secondsThe difference in inference speed was huge while the GPU VRAM usage was neglectable, as reported here, so make sure to use KV caching in your transformer model!
Thankfully, this is the default behavior in transformers 🤗.
Long Context로 인한 Large KV Cache의 문제점과 해결 방안: Part 1- KV cache의 메모리 요구량
Auto-regressive 모델이란 이전 단계의 출력들을 이용하여 다음 단계의 출력을 예측하는 모델이다. GPT는 auto-regressive 모델로 이전에 생성된 토큰를 기반으로 다음 토큰을 생성한다. GPT는 이전 토큰 생성 시 발생된 중간값인 activations(e.g. KV cache)를 캐싱하여 이전 토큰 값을 재계산하기 위한 GPU의 FLOPS를 절감하는 대신, KV cache을 위한 추가적인 메모리 공간이 필요하다. LLM의 context window size가 증가할수록 KV cache의 크기 또한 선형적으로 증가하므로 context window size는 메모리 용량에 제한을 받는다. 본 포스트는 LLM이 long context를 지원할 경우 KV cache 메모리 요구량이 급격하게 증가하면서 발생하는 메모리 용량 증가 문제에 대해서 소개한다.
LLM의 Context Window Size의 증가
필자가 예전 포스트(LLM의 Context Window Size가 크다고 좋은 것일까?)에서 설명했던 것과 같이 LLM의 context window size는 LLM의 활용성을 높이는데 중요한 역할을 담당한다. 이로 인해 gpt-3.5-turbo(ChatGPT)가 최초 출시한 2022년 말엔 context window size가 4K tokens을 지원하였지만 2023년 11월에 발표된 gpt-4-turbo는 128K tokens를 지원하기 시작하였다. 즉, GPT의 context window size는 무려 1년만에 32배 증가하였다. 그러나 LLM가 long context를 지원할수록 GPU의 FLOPS와 메모리 사용량을 급격하게 증가시켜 production 수준에서 리소스 부족 문제를 발생시킨다.
KV cache란?
KV Cache란 토큰 생성 시 계산되는 Key/Value 텐서를 GPU 메모리에 저장한 후 재사용하는 것으로 이전 토큰의 Key/Value 텐서를 재계산되는 것을 막아 연산량을 줄이는 방법이다.
KV Caching은 compute & memory trade-off의 대표적인 예로 컴퓨팅 양을 줄이는 대신 생성된 Key 텐서와 Value 텐서를 버리지 않고 저장해야 하기 때문에 메모리 사용량이 증가한다. KV Caching에 필요한 메모리 양은 context window size와 batch size에 의해 결정된다.
KV Cache의 메모리 요구량 계산
KV cache(MHA)의 메모리 요구량은 다음과 같이 계산식으로 계산할 수 있다.
LLaMA2의 모델 specification을 이용하여 MHA(Multi-Head Attention)의 KV Cache의 메모리 요구량을 계산하면 다음과 같이 batch size와 sequence length(=context window length)의 식으로 구성됨을 확인할 수 있다.
(Note: LLAMA2–70B은 원래 GQA(Grouped-Query Attention)를 사용하였다. 본 포스트에서는 KV cache가 요구하는 메모리 양이 매우 큼을 확인하기 위해 GQA로 최적화하기 전인 MHA를 사용한 경우를 고려하였다.)
아래 그림은 LLaMA2-70B가 MHA를 사용하였을 때 sequence length와 batch size별 KV cache의 메모리 요구량의 변화를 나타낸 것이다. LLaMA2는 기본적으로 sequence length=4K를 지원하므로 batch 당 KV cache의 메모리 요구량은 1.25GB인 것을 알 수 있다. 만일 LLaMA2의 sequence length를 128K로 증가시킨다면 KV cache를 위해 batch당 무려 40GB가 필요하다!! 특히 long context 조건에서 KV cache의 메모리 요구량은 batch size에 비례하여 급격하게 증가함을 알 수 있다. 예를들어 sequence length=128K & batch size=32인 경우 KV cache는 1TB이 넘는 것을 알 수 있다.
LLM의 sequence length는 지속적으로 커지고 있어 LLM serving 시 long context를 처리하는 것은 매우 중요한 이슈가 되고 있다. 아래 그림은 single node(A100–80GB x8. 총 GPU 메모리 용량: 640GB )를 기준으로 LLaMA2–70B 모델 serving에 필요한 메모리 용량(weight + KV cache)을 나타낸 것이다. Sequence length가 4K일 경우, single node 수준에서 batch size=256을 처리할 수 있는 반면(weight + KV cache:460GB < 총 GPU 메모리 용량: 640GB), sequence length가 128K일 경우, singe node 수준에서 겨우 batch size=8을 처리할 수 밖에 없다. 이와 같은 결과를 통해 다음과 같은 사실을 알 수 있다.
- Long Context와 Large batch size인 조건에서 KV cache가 weight보다 훨씬 더 많은 메모리를 소비한다.
- KV cache의 메모리 소비량이 매우 커질 경우, 추론 시 GPU 메모리 용량이 bottleneck으로 작용할 수 있음을 의미한다.
LLM serving 시 sequence length와 batch size 결정하기
LLM의 weight는 고정된 값인 반면 KV cache는 sequence length와 batch size에 따라 변화한다. GPU 메모리는 대부분 weight와 KV cache로 채워지며 GPU 메모리 용량에 따라 지원하는 sequence length와 batch size가 결정된다. 따라서 토큰 당 KV cache의 용량을 알 수 있다면 GPU 메모리 용량에 따른 지원 가능한 sequence length과 batch size를 계산할 수 있다. LLaMA2–13B(MHA)을 single GPU(A100-80GB)에서 서비스한다면 지원 가능한 sequence length와 batch size는 다음과 같다.
(1) 최대 sequence length
- batch size=1일 경우, 0.82MB*1*seq_len=54GB이므로 최대 sequence length = (approx.) 65854이다.
(2) 최대 batch size
- sequence length=4K일 경우, 0.82MB*batch_size*4096=54GB이므로 batch size = (approx.) 16이다.
만일 A100 GPU(80GB)가 아닌 A100 GPU(40GB)를 사용한다면 메모리 용량의 제약으로 최대 sequence length와 최대 batch size는 모두 1/4 가량 줄어드는 것을 확인할 수 있다.
결론
LLM이 long context를 지원할 경우, KV cache가 급격하게 커지면서 GPU 메모리 용량에 따라 추론 시 LLM의 최대 sequence length와 최대 batch size가 결정되는 것을 확인하였다. Long context를 처리하거나 생성을 해야 할 때 GPU 메모리 부족 문제는 batch 처리를 어렵게 만들어 하드웨어 효율성을 낮추는 문제를 초래한다. 이러한 관점에서 제한된 GPU 메모리 용량를 효율적으로 사용하기 위해 다음과 같은 여러가지 방법을 사용할 수 있다.
- 모델 weight의 memory footprint를 줄이는 방법 (e.g. quantization)
- KV cache의 memory footprint를 줄이는 방법(e.g. GQA, MQA)
- Model Parallelism을 사용하여 모델을 여러 GPU로 분할 처리하는 방법 (e.g. tenosr parallelism 등)
다음 포스트에서는 LLM inference 최적화를 위해 KV cache의 memory footprint를 줄이는 방법에 대해서 알아보도록 할 예정이다.
레퍼런스
[1] Mastering LLM Techniques: Inference Optimization
[2] EfficientML.ai Lecture 12 — Transformer and LLM Part-1
[3] EfficientML.ai Lecture 13 — Transformer and LLM Part-2
'Research > NLP_reference' 카테고리의 다른 글
[Attention Rollout] Explainability for Vision Transformers (0) 2024.11.22 Gemma 2 (0) 2024.09.08 Llama 3.1 (0) 2024.09.08 Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA (0) 2024.09.08 How to Successfully Run a LLM Fine-Tuning Project (0) 2024.09.07