BitNet

Overview

  • a 1-bit LLM variant, namely BitNet b1.58
    • every single parameter (or weight) of the LLM is ternary \(\{-1, 0, 1\}\)
      • it matches the full-precision (i.e. FP16 or BF16) Transformer LLM
        • with the same model size and training tokens
          • in terms of both perplexity and end-task performance
            • while being more cost-effective in terms of latency, memory, throughput, and energy consumption
  • the 1.58-bit LLM defines a new scaling law and recipe
    • for training new generations of LLMs that are both high-performance and cost-effective
  • it enables a new computing paradigm
    • and opens the door to designing specific hardware optimized for 1-bit LLMs

Perplexity

Perplexity of a probability distribution

\({\displaystyle \mathrm {PP} (p)=\prod _{x}p(x)^{-p(x)}=b^{-\sum _{x}p(x)\log _{b}p(x)},}\)

Perplexity of a probability model

\({\displaystyle b^{-{\frac {1}{N}}\sum _{i=1}^{N}\log _{b}q(x_{i})}=\left(\prod _{i}q(x_{i})\right)^{-1/N},}\)
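The second formula can be computed directly from a model's per-token probabilities; a minimal numpy sketch (the helper name is mine, with base \(b = e\)):

```python
import numpy as np

def perplexity(probs):
    """Perplexity of a model that assigned probability q(x_i) to each observed token."""
    probs = np.asarray(probs, dtype=np.float64)
    return float(np.exp(-np.mean(np.log(probs))))

# a model assigning uniform probability 1/4 to each of 4 tokens is
# "as confused as" a fair 4-sided die: its perplexity is 4
print(round(perplexity([0.25, 0.25, 0.25, 0.25]), 6))  # → 4.0
```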

Quantization

In-Depth

  • the architecture uses INT8 addition when performing matrix multiplication

    in contrast to the LLaMA LLM's FP16 addition and multiplication operations
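A toy illustration of why ternary weights remove the multiplications: each output element is just a signed sum of activations. This is a hypothetical sketch, not BitNet's actual kernel.

```python
import numpy as np

def ternary_matvec(W, x):
    """W: ternary matrix with entries in {-1, 0, +1}; x: int8 activations."""
    out = np.zeros(W.shape[0], dtype=np.int32)
    for i in range(W.shape[0]):
        # only additions and subtractions -- no multiplications needed
        out[i] = x[W[i] == 1].sum() - x[W[i] == -1].sum()
    return out

W = np.array([[1, 0, -1], [0, 1, 1]], dtype=np.int8)
x = np.array([10, -3, 7], dtype=np.int8)
print(ternary_matvec(W, x))  # same result as a regular matrix-vector product
```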

BitNet replaces the traditional Linear layers in Multi-head attention and Feed-forward networks with specialized layers called BitLinear that use ternary (or even binary) precision

the main obstacle to training in ternary precision is that the weight values are discretized and thus non-differentiable

BitLinear solves this with the Straight-Through Estimator (STE)

  • STE allows gradients to flow through the non-differentiable discretization operation
    • by approximating its gradient as 1

e.g. instead of stopping the gradient at the discretization step

  • the STE lets the gradient pass through as if the rounding never occurred
    • enabling weight updates using standard gradient-based optimization techniques

Training

BitNet trains in full precision

  • but quantizes the weights into ternary values on the fly
    • using symmetric per-tensor quantization
  • first we compute the average of the absolute values of the weight matrix

    • and use its reciprocal as the scale
  • then multiply the weights by the scale, round the values, and clamp them between -1 and 1

  • finally rescale them to continue in full precision

\(scale_w = \frac{1}{\frac{1}{nm}\sum_{ij}|W_{ij}|}\) \(W_q = clamp_{[-1,1]}(round(W \cdot scale_w))\)

    \(W_{dequantized} = W_q / scale_w\)

  • activations are then quantized to a specified bit-width (e.g. 8-bit)

    • using absmax per token quantization

in practice this involves scaling the activations into the range [-128, 127] for an 8-bit width

\(scale_x = \frac{127}{|X|_{max,\,dim=-1}}\) \(X_q = clamp_{[-128,127]}(round(X \cdot scale_x))\)

    \(X_{dequantized} = X_q / scale_x\)
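The training-time quantize/dequantize steps above can be sketched in numpy as follows (function and variable names are mine, not from the BitNet source):

```python
import numpy as np

def quantize_weights(W):
    """Absmean per-tensor ternarization: W_q has entries in {-1, 0, 1}."""
    scale_w = 1.0 / np.mean(np.abs(W))           # scale_w = 1 / mean(|W|)
    W_q = np.clip(np.round(W * scale_w), -1, 1)
    return W_q, scale_w

def quantize_activations(X):
    """Absmax per-token quantization to int8 (scale computed along dim=-1)."""
    scale_x = 127.0 / np.max(np.abs(X), axis=-1, keepdims=True)
    X_q = np.clip(np.round(X * scale_x), -128, 127).astype(np.int8)
    return X_q, scale_x

W = np.array([[0.9, -0.05, -1.1], [0.4, 0.01, -0.6]])
X = np.array([[0.1, -2.0, 0.5]])

W_q, scale_w = quantize_weights(W)
X_q, scale_x = quantize_activations(X)

# dequantize by dividing by the scales, so training continues in full precision
W_deq = W_q / scale_w
X_deq = X_q / scale_x
print(W_q)   # entries are all in {-1, 0, 1}
print(X_q)   # each row's absolute max maps to ±127
```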

We apply Layer Normalization (LN) before quantizing the activations to maintain the variance of the output

\(LN(x) = \frac{x-E(x)}{\sqrt{Var(x)+\varepsilon}}\)

  • where \(\varepsilon\) is a small number to prevent division by zero
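The LN formula is a one-liner in numpy; a minimal sketch (helper name is mine):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # subtract the mean, divide by the standard deviation;
    # eps keeps the denominator away from zero
    return (x - x.mean(axis=-1, keepdims=True)) / np.sqrt(
        x.var(axis=-1, keepdims=True) + eps)

x = np.array([1.0, 2.0, 3.0, 4.0])
y = layer_norm(x)
print(round(float(y.mean()), 6), round(float(y.var()), 4))  # ≈ 0 mean, ≈ 1 variance
```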

  • the `round()` function is not differentiable

    • we use `detach()` as a trick to implement a differentiable straight-through estimator
      • in the backwards pass
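A minimal sketch of the STE without an autograd framework: the forward pass applies the non-differentiable `round()`, while the backward pass pretends the rounding was the identity and passes the upstream gradient through unchanged. (In PyTorch the same effect is commonly obtained with `w + (w.round() - w).detach()`; the function names here are mine.)

```python
import numpy as np

def ste_forward(w):
    return np.round(w)        # discretize in the forward pass

def ste_backward(grad_out):
    return grad_out * 1.0     # d(round)/dw approximated as 1

w = np.array([0.3, -0.8, 1.6])
print(ste_forward(w))                            # discretized values
print(ste_backward(np.array([0.5, 0.5, 0.5])))   # gradient flows through unchanged
```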

Inference

during inference we simply quantize the weights to ternary values without rescaling

  • the same absmax approach is used for the activations in 8-bit precision
    • we then perform a matrix multiplication with an efficient kernel
      • followed by dividing by both the weight and activation scales

this improves inference speed, especially on hardware optimized for integer arithmetic
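The inference path above can be sketched as an integer matmul followed by a single rescale that divides by both scales (names, shapes, and the toy values are illustrative, not BitNet's actual kernel):

```python
import numpy as np

def bitlinear_infer(W_q, scale_w, X_q, scale_x):
    # accumulate in int32 to avoid overflow of the int8/ternary operands
    acc = X_q.astype(np.int32) @ W_q.T.astype(np.int32)
    # undo both quantizations in one step
    return acc / (scale_w * scale_x)

W_q = np.array([[1, 0, -1], [0, 1, 1]])          # ternary weights
scale_w = 2.0                                    # per-tensor weight scale
X_q = np.array([[127, -64, 32]], dtype=np.int8)  # int8 activations
scale_x = np.array([[63.5]])                     # per-token activation scale
print(bitlinear_infer(W_q, scale_w, X_q, scale_x))
```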