Overview
- 1-bit LLM variant, namely BitNet b1.58
- every single parameter (or weight) of the LLM is ternary \(\{-1, 0, 1\}\)
- it matches the full-precision (i.e. FP16 or BF16) Transformer LLM
- with the same model size and training tokens
- in terms of both perplexity and end-task performance
- while being more cost-effective in terms of latency, memory, throughput, and energy consumption
- the 1.58-bit LLM defines a new scaling law and recipe
- for training new generations of LLMs that are both high-performance and cost-effective
- enabling a new computing paradigm
- and opens the door for designing specific hardware optimized for 1-bit LLMs
Perplexity
Perplexity of a probability distribution
\(\mathrm{PP}(p)=\prod_{x}p(x)^{-p(x)}=b^{-\sum_{x}p(x)\log_{b}p(x)}\)
Perplexity of a probability model
\(b^{-\frac{1}{N}\sum_{i=1}^{N}\log_{b}q(x_{i})}=\left(\prod_{i}q(x_{i})\right)^{-1/N}\)
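A quick numerical illustration of the model-perplexity formula (toy probabilities, natural log as the base):

```python
import math

# q(x_i): the model's probability for each observed token (made-up values)
token_probs = [0.2, 0.5, 0.1, 0.4]

# perplexity = b^(-(1/N) * sum(log_b q(x_i))); with b = e this is
# exp of the average negative log-likelihood per token
avg_nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
perplexity = math.exp(avg_nll)

print(round(perplexity, 2))  # ~3.98, i.e. roughly as uncertain as a 4-way uniform choice
```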
Quantization
In-Depth
- the architecture uses INT8 addition calculations when performing matrix multiplication, in contrast to LLaMA LLM's FP16 addition and multiplication operations
- BitNet replaces the traditional Linear layers in Multi-head attention and Feed-forward networks with specialized layers called BitLinear that use ternary precision (or even binary)
- the main obstacle to training in ternary precision is that the weight values are discretized and thus non-differentiable
- BitLinear solves this with the Straight-Through Estimator (STE)
- STE allows gradients to flow through the non-differentiable discretization operation
- by approximating its gradient as 1 (see the sketch after this list)
- e.g. instead of stopping the gradient at the discretization step
- the STE lets the gradient pass through as if the rounding never occurred
- enabling weight updates using standard gradient-based optimization techniques
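A minimal PyTorch sketch of that idea, using a hypothetical `RoundSTE` autograd function (not the paper's code): round in the forward pass, pass the gradient through unchanged in the backward pass.

```python
import torch

class RoundSTE(torch.autograd.Function):
    """Round in the forward pass; identity gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-Through Estimator: pretend rounding had gradient 1
        return grad_output

x = torch.randn(4, requires_grad=True)
RoundSTE.apply(x).sum().backward()
print(x.grad)  # all ones: the gradient flowed straight through the rounding
```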
Training
- BitNet trains in full precision
- but quantizes the weights into ternary values on the fly
- using symmetric per-tensor quantization
- first we compute the average of the absolute values of the weight matrix
- and use its reciprocal as the scale \(scale_w\)
- then multiply the weights by the scale, round the values, and clamp them between -1 and 1
- finally rescale them (divide by \(scale_w\)) to continue in full precision
\(scale_w = \frac{1}{\frac{1}{nm}\sum_{ij}|W_{ij}|}\) \(W_q = clamp_{[-1,1]}(round(W * scale_w))\)
\(W_{dequantized} = W_q / scale_w\)
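A minimal PyTorch sketch of this per-tensor absmean weight quantization (the function name and the small `eps` guard are my own additions, not from the paper):

```python
import torch

def weight_quant(w: torch.Tensor, eps: float = 1e-5):
    """Symmetric per-tensor ternary quantization of a weight matrix."""
    # scale_w = 1 / mean(|W|); eps guards against an all-zero matrix
    scale_w = 1.0 / w.abs().mean().clamp(min=eps)
    # multiply by the scale, round, and clamp to {-1, 0, 1}
    w_q = (w * scale_w).round().clamp(-1, 1)
    # dequantize by dividing by the scale to continue training in full precision
    w_dequant = w_q / scale_w
    return w_q, w_dequant, scale_w
```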
- activations are then quantized to a specified bit-width (e.g. 8-bit)
- using absmax per-token quantization
- in practice this involves scaling the activations into the range [-128, 127] for an 8-bit width
\(scale_x = \frac{127}{|X|_{max,\ dim=-1}}\) \(X_q = clamp_{[-128,127]}(round(X * scale_x))\)
\(X_{dequantized} = X_q / scale_x\)
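A matching sketch of absmax per-token activation quantization (again, the helper name and `eps` are assumptions):

```python
import torch

def activation_quant(x: torch.Tensor, eps: float = 1e-5):
    """Absmax per-token quantization of activations into the signed 8-bit range."""
    # per-token scale: 127 / max(|x|) over the last (hidden) dimension
    scale_x = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=eps)
    # scale, round, and clamp into [-128, 127]
    x_q = (x * scale_x).round().clamp(-128, 127)
    # dequantize by dividing by the per-token scale
    x_dequant = x_q / scale_x
    return x_q, x_dequant, scale_x
```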
- we apply Layer Normalization (LN) before quantizing the activations to maintain the variance of the output
\(LN(x) = \frac{x - E(x)}{\sqrt{Var(x)+\varepsilon}}\)
- where \(\varepsilon\) is a small number that keeps the division numerically stable
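As a quick check that the formula above matches PyTorch's built-in (toy tensor, no learnable affine parameters):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8)
eps = 1e-5

# manual LN(x) = (x - E[x]) / sqrt(Var[x] + eps), per token (last dimension)
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = (x - mu) / torch.sqrt(var + eps)

builtin = F.layer_norm(x, normalized_shape=(8,), eps=eps)
print(torch.allclose(manual, builtin, atol=1e-5))  # True
```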
- the `round()` function is not differentiable
- so in the backward pass we use `detach()` as a trick to implement a differentiable straight-through estimator, as in the sketch below
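Putting the pieces together, a minimal sketch of a BitLinear-style forward pass using the `detach()` idiom (a simplified illustration, not the official implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinear(nn.Linear):
    """Drop-in replacement for nn.Linear with ternary weights and 8-bit activations."""

    def forward(self, x):
        w = self.weight
        # LayerNorm before quantizing the activations (no learnable affine here)
        x = F.layer_norm(x, x.shape[-1:])

        # absmax per-token activation fake-quantization (8-bit range)
        scale_x = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        x_q = (x * scale_x).round().clamp(-128, 127) / scale_x
        # STE via detach(): forward uses x_q, backward flows through x unchanged
        x = x + (x_q - x).detach()

        # absmean per-tensor ternary weight fake-quantization
        scale_w = 1.0 / w.abs().mean().clamp(min=1e-5)
        w_q = (w * scale_w).round().clamp(-1, 1) / scale_w
        w = w + (w_q - w).detach()

        return F.linear(x, w, self.bias)
```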
Inference
- during inference we simply quantize the weights to ternary values without rescaling
- the same approach is applied to the activations, at 8-bit precision
- then perform a matrix multiplication with an efficient kernel
- followed by dividing by both the weight and activation scales
- this improves inference speed on optimized hardware (see the sketch below)
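A minimal sketch of that inference path, assuming the ternary weight matrix and its scale were saved at training time (a plain matmul stands in for the optimized INT8/ternary kernel):

```python
import torch

@torch.no_grad()
def bitlinear_infer(x: torch.Tensor, w_ternary: torch.Tensor, scale_w: torch.Tensor):
    """y = (X_q @ W_q^T) / (scale_x * scale_w), with integer-valued operands."""
    # absmax per-token quantization of the activations into the 8-bit range
    scale_x = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    x_q = (x * scale_x).round().clamp(-128, 127)
    # integer-valued matrix multiplication (an efficient kernel in practice)
    y = x_q @ w_ternary.t()
    # undo both scales to return to full precision
    return y / (scale_x * scale_w)
```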