Julia Flux Embeddings

Flux Embedding Vectors

  • Flux provides an `Embedding` layer
    • the layer accepts an index and returns the corresponding vector
      • or several indices, and correspondingly several vectors
    • the possible embedding vectors are learned parameters

Flux.Embedding - Type `Embedding(in => out; init=randn32)`

  • a lookup table that stores embeddings of dimension `out` for a vocab of size `in`
    • as a trainable matrix
  • this layer is often used to store word embeddings and retrieve them using indices
    • input to the layer can be a vocabulary index in `1:in`
      • an array of indices, or the corresponding one-hot encoding
  • for indices `x`, the result is of size `(out, size(x)...)`
    • allowing for several batch dimensions; for a one-hot matrix `ohx`
      • the result is of size `(out, size(ohx)[2:end]...)`
using Flux
emb = Embedding(26 => 4, init=Flux.identity_init(gain=22))
emb([3, 1, 20, 14, 4, 15, 7])  # vocabulary indices, in 1:26
ans == emb(Flux.onehotbatch("cat&dog", 'a':'z', 'n'))
emb(rand(1:26, (10, 1, 12))) |> size  # three batch dimensions
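Since the lookup table is a trainable parameter (stored in the layer's `weight` field), gradients flow back into it through the lookup. A minimal sketch:

```julia
using Flux

# Toy example: 10-word vocabulary embedded in 3 dimensions.
emb = Embedding(10 => 3)

# Looking up index i returns column i of the trainable weight matrix.
emb(3) == emb.weight[:, 3]          # true

# Gradients flow back into the weight matrix through the lookup.
grads = Flux.gradient(m -> sum(m([2, 5])), emb)
size(grads[1].weight)               # (3, 10)
```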

Transformers.jl Embed

  • we define an abstract type for embeddings
    • and a concrete implementation of the abstract type
      • the `Embed` layer, where `hidden_size` is the hidden dimension and `vocab_size` is the vocabulary size

just a wrapper for the embedding matrix

using Functors, NNlib

abstract type AbstractEmbedding end

struct Embed{F, E <: AbstractArray} <: AbstractEmbedding
    scale::F
    embeddings::E
end

@functor Embed (embeddings,)
Embed(embeddings::AbstractArray; scale = nothing) = Embed(scale, embeddings)
Embed(hidden_size::Int, vocab_size::Int; scale = nothing) = Embed(scale, randn(Float32, hidden_size, vocab_size))

# no scaling when scale === nothing
(embed::Embed{Nothing})(x) = NNlib.gather(embed.embeddings, x)
# otherwise multiply the gathered embeddings by the scale factor
function (embed::Embed)(x)
    y = NNlib.gather(embed.embeddings, x)
    return y .* convert(eltype(y), embed.scale)
end
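A self-contained usage sketch; `EmbedSketch` mirrors the `Embed` above (renamed here so it does not clash with the definition in the text):

```julia
using NNlib

# Sketch mirroring the Embed layer above, with the methods inlined.
struct EmbedSketch{F, E <: AbstractArray}
    scale::F
    embeddings::E
end
EmbedSketch(hidden_size::Int, vocab_size::Int; scale = nothing) =
    EmbedSketch(scale, randn(Float32, hidden_size, vocab_size))

(e::EmbedSketch{Nothing})(x) = NNlib.gather(e.embeddings, x)
(e::EmbedSketch)(x) = NNlib.gather(e.embeddings, x) .* convert(Float32, e.scale)

emb = EmbedSketch(4, 10)          # hidden_size = 4, vocab_size = 10
size(emb([1, 5, 5]))              # (4, 3): one column per index

# transformer models often scale embeddings by sqrt(hidden_size):
scaled = EmbedSketch(4, 10; scale = sqrt(4f0))
size(scaled([1, 5, 5]))           # (4, 3)
```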
  • Flux models are deeply nested structures, and Functors.jl provides the tools needed to explore those objects
    • it applies functions to the parameters they contain, e.g. moving them to the GPU, and re-builds the structures afterwards
  • many of the functions used by Flux.jl are provided by NNlib
    • such as softmax, sigmoid, batched multiplication, convolutions, and pooling
    • for use with automatic differentiation, NNlib also defines gradients using ChainRules.jl, which are picked up by Zygote.jl
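A quick look at a few of those NNlib functions:

```julia
using NNlib

softmax([1.0, 2.0, 3.0])                        # normalizes to sum to 1
sigmoid(0.0)                                    # 0.5
size(batched_mul(rand(2, 3, 5), rand(3, 4, 5))) # (2, 4, 5): batch in last dim
```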

Transformers.jl also contains `EmbedDecoder`, a layer that shares weights with an embedding layer and returns the logits
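
The weight-tying idea can be sketched in plain Julia; this is an illustration, not the actual Transformers.jl implementation:

```julia
# Hypothetical sketch of weight tying: the decoder re-uses the
# embedding matrix (hidden_size × vocab_size) to map hidden states
# back to vocabulary logits.
struct TiedDecoder{M<:AbstractMatrix}
    embeddings::M                  # shared with the embedding layer
end

# logits of size (vocab_size, ...) from hidden states of size (hidden_size, ...)
(d::TiedDecoder)(h::AbstractVecOrMat) = d.embeddings' * h

E = randn(Float32, 4, 10)          # hidden_size = 4, vocab_size = 10
dec = TiedDecoder(E)
size(dec(randn(Float32, 4, 3)))    # (10, 3): one logit per vocabulary entry
```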

  • `ChainRulesCore.rrule` extends the primal computation with a reverse-mode differentiation rule
    • custom rules can be checked with ChainRulesTestUtils.jl

`rrule` dispatches on the `typeof` of the function we are writing the rule for

  • as well as the types of its arguments
  • it returns the primal result `y` and the pullback function
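
The dispatch pattern above can be sketched with a toy scaling function; this is a hedged example, not the actual `EmbedDecoder` rule from Transformers.jl:

```julia
using ChainRulesCore

myscale(a, x) = a .* x

# Dispatch on typeof(myscale) plus the argument types:
function ChainRulesCore.rrule(::typeof(myscale), a::Number, x::AbstractArray)
    y = myscale(a, x)            # the primal result
    function myscale_pullback(ȳ)
        f̄ = NoTangent()          # no derivative w.r.t. the function itself
        ā = sum(ȳ .* x)          # ∂L/∂a
        x̄ = a .* ȳ               # ∂L/∂x
        return f̄, ā, x̄
    end
    return y, myscale_pullback
end

y, back = rrule(myscale, 2.0, [1.0, 2.0])   # y == [2.0, 4.0]
back([1.0, 1.0])                            # (NoTangent(), 3.0, [2.0, 2.0])
```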

Transformers.jl defines an `rrule` for `EmbedDecoder`

Functors.jl basic usage and implementation

  • Functors.jl allows looking into the fields of instances of any struct
    • and modifying them
      • achieved through `Functors.fmap`
        • a structure- and type-preserving map
          • transforms every leaf node by applying a function `f`
            • and otherwise traverses `x` recursively using `Functors.functor`
              • `functor(x)` returns a tuple containing a NamedTuple of the children of `x` and a reconstruction function
        • to opt out of this behaviour
          • mark a custom type as non-traversable with `Functors.@leaf`

to include only certain fields of a struct, one can pass a tuple of field names to `@functor`

using Functors

struct Baz
    x
    y
end

@functor Baz (x,)

model = Baz(1, 2)
# Baz(1, 2)
fmap(float, model)
# Baz(1.0, 2)

NNlib gather

Flux's Embedding layer uses NNlib.gather as its backend

NNlib.gather(src, idx) -> dst

using NNlib

NNlib.gather([1,20,300,4000], [2,4,2])
#3-element Vector{Int64}:
#   20
# 4000
#   20
NNlib.gather([1 2 3; 4 5 6], [1,3,1,3,1])
#2×5 Matrix{Int64}:
# 1  3  1  3  1
# 4  6  4  6  4
  • the reverse operation of `scatter`
    • gathers data from source `src` and writes it in a destination `dst` according to the index array `idx`
      • for each `k` in `CartesianIndices(idx)`, assign values to `dst` according to `dst[:, ..., k] .= src[:, ..., idx[k]...]`
      • notice that if `idx` is a vector containing integers and `src` is a matrix
        • the previous expression simplifies to `dst[:, k] .= src[:, idx[k]]`
        • and `k` will run over `1:length(idx)`
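
The simplified vector-index case can be checked in plain Julia without NNlib:

```julia
# Plain-Julia version of the vector-index rule described above.
src = [1 2 3; 4 5 6]
idx = [1, 3, 1, 3, 1]

dst = similar(src, size(src, 1), length(idx))
for k in 1:length(idx)
    dst[:, k] .= src[:, idx[k]]   # copy column idx[k] of src into column k
end
dst
# 2×5 Matrix{Int64}:
#  1  3  1  3  1
#  4  6  4  6  4
```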