2024-05-31
High Performance LLM Serving
How to optimize LLM inference and create a cost-effective, robust and custom solution
These are my notes on the GPU optimization workshop organized by Chip Huyen.
Sharan Chetlur talks about inference optimization techniques and how to achieve high performance while serving LLMs.
-
The majority of apps are real-time (online) systems, so they require acceptable latencies, should be highly accurate (helpful), and should be cost-effective, since deploying these large models is expensive
-
These models will keep getting bigger and better; see the Chinchilla scaling laws.
-
So to address these challenges we need a high-performance, robust and customizable serving solution
-
Quantization came up again here, as it's becoming necessary (it's an all-round win with just a bit of effort) as long as you maintain accuracy
-
So you can train LLMs in the BF16 dtype, but inference can be done in lower-precision dtypes like INT8, INT4 or even lower. This is post-training quantization, and it usually just needs a calibration pass.
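A minimal sketch of what weight-only post-training quantization looks like, assuming simple symmetric per-tensor scaling calibrated from the weight's max absolute value; real toolkits use more sophisticated per-channel/per-group schemes, but the idea is the same:

```python
import torch

def quantize_int8(w: torch.Tensor):
    """Symmetric per-tensor PTQ: calibrate a scale, then round to int8."""
    scale = w.abs().max() / 127.0                      # calibration: fit the observed range into int8
    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return q, scale

def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Recover an approximate BF16 weight (in practice this is fused into the kernel)."""
    return (q.float() * scale).to(torch.bfloat16)

w = torch.randn(4096, 4096, dtype=torch.bfloat16)      # a BF16-trained weight matrix
q, scale = quantize_int8(w.float())
w_hat = dequantize(q, scale)
print("max abs error:", (w.float() - w_hat.float()).abs().max().item())
```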
-
Post-training quantization makes compute faster, and communication between GPUs happens at higher effective throughput since fewer bytes have to move
-
There are other optimization techniques as well, like quantization-aware training (retraining the model on a small dataset) and sparsity
-
At inference time, an LLM request has 2 phases
-
Prefill - processes the prompt, generates the 1st token and initializes the KV cache (this cache stores the attention keys and values of the tokens processed so far). This phase is compute-heavy.
-
Generate - generates the next token using the last generated token and the KV cache, then updates the cache. This phase is memory-bound.
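A minimal greedy-decoding sketch (using Hugging Face transformers and gpt2 as a stand-in model, not the workshop's stack) that makes the two phases visible: one big forward pass over the whole prompt builds the KV cache, then each subsequent forward pass feeds only the last token plus that cache:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

prompt_ids = tok("Explain the multiverse theory", return_tensors="pt").input_ids

with torch.no_grad():
    # Prefill: process the whole prompt at once (compute-heavy), emit the 1st token,
    # and initialize the KV cache (past_key_values).
    out = model(prompt_ids, use_cache=True)
    past = out.past_key_values
    next_tok = out.logits[:, -1].argmax(dim=-1, keepdim=True)

    generated = [next_tok]
    for _ in range(32):
        # Generate: one token per forward pass, reusing and updating the cache
        # (memory-bound: weights and KV cache are re-read for very little compute).
        out = model(next_tok, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_tok = out.logits[:, -1].argmax(dim=-1, keepdim=True)
        generated.append(next_tok)

print(tok.decode(torch.cat(generated, dim=-1)[0]))
```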
-
Whether the attention block itself is compute- or memory-bound depends on the implementation
-
Since the two phases are compute- and memory-bound respectively, a serving solution should ship custom CUDA kernels for high performance (e.g. FlashAttention). Remember, we need to keep the CUDA cores busy.
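A rough back-of-envelope on why prefill is compute-bound and decode is memory-bound, assuming a single 4096x4096 BF16 weight matrix and a 512-token prompt (numbers are illustrative, not from the talk):

```python
def arithmetic_intensity(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte moved for an (m,k) x (k,n) matmul in BF16 (2 bytes/element)."""
    flops = 2 * m * k * n
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved

d = 4096  # hidden size of a hypothetical layer
print("prefill (512 prompt tokens):", round(arithmetic_intensity(512, d, d)))   # ~410 FLOPs/byte
print("decode  (1 token):          ", round(arithmetic_intensity(1, d, d), 2))  # ~1 FLOP/byte
# An A100 needs very roughly >150 FLOPs/byte to be compute-bound, so prefill can
# saturate the tensor cores while decode mostly waits on memory.
```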
-
Static batching - traditionally, requests are accumulated over some time window and executed as a batch until completion. This works when every request is a fixed amount of work, but LLM outputs differ massively in length (e.g. "the tallest ferris wheel" vs "explain the multiverse theory")
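A toy simulation (not real serving code) of why static batching wastes capacity: the whole batch runs until the longest request finishes, so short requests hold their slot idle.

```python
from dataclasses import dataclass

@dataclass
class Request:
    prompt: str
    max_new_tokens: int     # stand-in for "how long this answer turns out to be"
    tokens_done: int = 0

    @property
    def finished(self) -> bool:
        return self.tokens_done >= self.max_new_tokens

def run_static_batch(batch: list[Request]) -> int:
    """Run the whole batch to completion; finished requests just sit idle."""
    iterations = 0
    while not all(r.finished for r in batch):
        for r in batch:
            if not r.finished:
                r.tokens_done += 1      # one decode step for this request
        iterations += 1
    return iterations

batch = [Request("the tallest ferris wheel", 8),
         Request("explain the multiverse theory", 400)]
print(run_static_batch(batch))  # 400 iterations: the 8-token request idles for 392 of them
```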
-
In-flight batching - in LLM inference there are multiple forward passes per request, and how many is unknown and unbounded (we don't know upfront how many tokens need to be generated). So in-flight batching treats each iteration (one forward pass), rather than the whole request, as the unit of scheduling (see the sketch after this list).
- So imagine there are 3 requests: req-1 and req-2 are in some state of generation, req-3 is a new request.
- We check for active requests that have emitted the end-of-sequence token (or hit some other stop condition) and evict them (e.g. req-1)
- replace the freed slot with the new request (req-3)
- run the next iteration
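A toy sketch of that iteration-level scheduling loop, mirroring the req-1/req-2/req-3 walkthrough above; real schedulers (e.g. in TensorRT-LLM or vLLM) additionally track KV-cache blocks and mix prefill and decode work in the same step.

```python
from collections import deque
from dataclasses import dataclass

@dataclass
class Request:
    rid: str
    max_new_tokens: int     # stand-in for hitting the end-of-sequence token
    tokens_done: int = 0

    @property
    def finished(self) -> bool:
        return self.tokens_done >= self.max_new_tokens

def run_inflight(waiting: deque, max_batch: int = 4) -> None:
    """Toy in-flight batching: schedule per iteration, not per request."""
    active: list[Request] = []
    while waiting or active:
        # 1) evict requests that have met their stop condition
        for r in [r for r in active if r.finished]:
            active.remove(r)
            print(f"{r.rid} finished")
        # 2) backfill freed slots with new requests (their first step is prefill)
        while waiting and len(active) < max_batch:
            active.append(waiting.popleft())
        # 3) run one iteration: every active request advances by one token
        for r in active:
            r.tokens_done += 1

run_inflight(deque([Request("req-1", 5), Request("req-2", 9), Request("req-3", 3)]),
             max_batch=2)
```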
-
Sometimes in a single iteration some requests are in the generation phase while others are in the prefill phase; tokens from the different phases can then be concatenated into one batch for higher throughput.
-
Paged KV cache - traditionally these caches were contiguous in memory, which wastes memory since space is reserved based on the max sequence length. A paged KV cache is instead partitioned into blocks that do not need to be contiguous, so memory waste only happens in the last block.
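A toy block-table allocator to make the idea concrete (this is an assumption-laden sketch of the mechanism, not vLLM's actual code): each sequence holds a list of block ids, blocks come from a shared free pool on demand, and only the last, partially filled block can waste space.

```python
BLOCK_SIZE = 16  # tokens per KV block (vLLM's default block size is 16)

class PagedKVCache:
    def __init__(self, num_blocks: int):
        self.free_blocks = list(range(num_blocks))
        self.block_tables: dict[str, list[int]] = {}   # sequence id -> its block ids
        self.seq_lens: dict[str, int] = {}

    def append_token(self, seq_id: str) -> None:
        """Reserve KV space for one more token of this sequence."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.seq_lens.get(seq_id, 0)
        if length % BLOCK_SIZE == 0:                   # last block full -> grab a new one
            table.append(self.free_blocks.pop())
        self.seq_lens[seq_id] = length + 1

    def free(self, seq_id: str) -> None:
        """Return all of a finished sequence's blocks to the pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.seq_lens.pop(seq_id, None)

cache = PagedKVCache(num_blocks=1024)
for _ in range(40):                  # a 40-token sequence
    cache.append_token("req-1")
print(cache.block_tables["req-1"])   # 3 non-contiguous blocks: 16 + 16 + 8 tokens
```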
-
KV cache reuse - this block-based representation also allows memory sharing. For example, a system prompt that is common to 2 requests can be cached once and reused for the 2nd request. See PagedAttention by vLLM.
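A sketch of how block-level reuse can work, assuming full prompt blocks are looked up by a hash of their token prefix and ref-counted instead of recomputed; this mirrors the idea behind prefix caching, not any library's actual implementation.

```python
from collections import defaultdict

BLOCK_SIZE = 16

class PrefixCache:
    """Map a hash of each full block of prompt tokens to an already-filled KV block."""
    def __init__(self):
        self.block_by_hash: dict[int, int] = {}
        self.ref_count: defaultdict[int, int] = defaultdict(int)
        self.next_block = 0

    def blocks_for_prompt(self, tokens: list[int]) -> list[int]:
        table = []
        for i in range(0, len(tokens) - len(tokens) % BLOCK_SIZE, BLOCK_SIZE):
            h = hash(tuple(tokens[: i + BLOCK_SIZE]))  # hash covers the whole prefix
            if h not in self.block_by_hash:            # miss: this block must be prefilled
                self.block_by_hash[h] = self.next_block
                self.next_block += 1
            block = self.block_by_hash[h]
            self.ref_count[block] += 1                 # hit: shared, not copied
            table.append(block)
        return table

cache = PrefixCache()
system = list(range(32))                 # 32-token system prompt -> 2 full blocks
req1 = cache.blocks_for_prompt(system + [101, 102])
req2 = cache.blocks_for_prompt(system + [201, 202, 203])
print(req1, req2)                        # both requests start with the same two block ids
```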
-
Speculative decoding - makes educated guesses about future tokens with a cheap draft, then lets the large model verify all of those guesses in a single forward pass, so several tokens can be accepted per pass. See the Hitchhiker's Guide to Speculative Decoding on the PyTorch blog.
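A minimal draft-and-verify sketch with greedy acceptance, assuming two HF-style causal LMs that expose .logits and share a tokenizer (e.g. distilgpt2 drafting for gpt2); the full algorithm uses a probabilistic accept/reject rule so the output distribution matches the target model, which this sketch skips.

```python
import torch

@torch.no_grad()
def speculative_step(target_model, draft_model, input_ids, k: int = 4):
    """One step: the small model guesses k tokens, the big model checks them all at once."""
    # 1) Draft: the small model proposes k tokens autoregressively (cheap).
    draft_ids = input_ids
    for _ in range(k):
        logits = draft_model(draft_ids).logits[:, -1]
        draft_ids = torch.cat([draft_ids, logits.argmax(-1, keepdim=True)], dim=-1)
    proposed = draft_ids[:, input_ids.shape[1]:]          # the k guessed tokens

    # 2) Verify: a single target forward pass scores every proposed position.
    target_logits = target_model(draft_ids).logits
    target_pick = target_logits[:, input_ids.shape[1] - 1:-1].argmax(-1)  # shape [1, k]

    # 3) Accept the longest prefix where draft and target agree, then append one
    #    token from the target itself (its correction, or a bonus token if all match).
    matches = (proposed == target_pick)[0].long()
    n_accept = int(matches.cumprod(0).sum())
    if n_accept == k:
        bonus = target_logits[:, -1].argmax(-1, keepdim=True)
    else:
        bonus = target_pick[:, n_accept:n_accept + 1]
    return torch.cat([input_ids, proposed[:, :n_accept], bonus], dim=-1)
```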