Flux Embedding Vectors
- Flux provides Embedding Vectors
- these layers accept an index, and return a vector
- or several indices, and several vectors
- the possible embedding vectors are learned parameters
Flux.Embedding - `Embedding(in => out; init=randn32)`
- a lookup table that stores embeddings of dimension `out` for a vocabulary of size `in`
- as a trainable matrix
- this layer is often used to store word embeddings and retrieve them using indices
- the input to the layer can be a vocabulary index in `1:in`, an array of indices, or the corresponding one-hot encoding
- for indices `x`, the result is of size `(out, size(x)...)`, allowing for several batch dimensions
- for a one-hot encoded `ohx`, the result is of size `(out, size(ohx)[2:end]...)`
using Flux
emb = Embedding(26 => 4; init=Flux.identity_init(gain=22))
emb([3, 1, 20, 14, 4, 15, 7])  # vocabulary indices, in 1:26; returns a 4×7 Matrix
ans == emb(Flux.onehotbatch("cat&dog", 'a':'z', 'n'))  # true: the same lookup via one-hot input
emb(rand(1:26, (10, 1, 12))) |> size  # three batch dimensions: (4, 10, 1, 12)
Transformers.jl Embed
- we define an abstract type for embeddings
- and a concrete implementation of the abstract type: the `Embed` layer, where `hidden_size` is the hidden dimension and `vocab_size` is the vocabulary size
- it is just a wrapper for the embedding matrix
using Functors: @functor
import NNlib

abstract type AbstractEmbedding end

struct Embed{F, E <: AbstractArray} <: AbstractEmbedding
    scale::F
    embeddings::E
end
@functor Embed (embeddings,)

Embed(embeddings::AbstractArray; scale = nothing) = Embed(scale, embeddings)
Embed(hidden_size::Int, vocab_size::Int; scale = nothing) =
    Embed(scale, randn(Float32, hidden_size, vocab_size))

# without a scale, just gather the embedding columns for the given indices
(embed::Embed{Nothing})(x) = NNlib.gather(embed.embeddings, x)

# with a scale, multiply the gathered vectors by it
function (embed::Embed)(x)
    y = NNlib.gather(embed.embeddings, x)
    return y .* convert(eltype(y), embed.scale)
end
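For illustration, a minimal usage sketch of the layer defined above (the scale value here is an arbitrary choice, not a Transformers.jl default):
embed = Embed(4, 10; scale = inv(sqrt(4f0)))  # 4-dim embeddings for a vocabulary of 10
embed([1, 3, 5]) |> size                      # (4, 3): one column per index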
- Flux models are deeply nested structures, and Functors.jl provides the tools needed to explore those objects
- it applies functions to the parameters they contain (e.g. moving them to the GPU) and re-builds them, as in the sketch below
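A minimal sketch of that idea, assuming a plain Dense layer (doubling stands in for a real transformation such as a device transfer):
using Flux, Functors
layer = Dense(2 => 3)
# double every parameter array, leaving non-array leaves (like the activation function) alone
layer2 = fmap(x -> x isa AbstractArray ? 2 .* x : x, layer)
layer2.weight == 2 .* layer.weight  # true; the Dense structure is rebuilt around the new arrays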
- many of the functions used by Flux.jl are provided by NNlib
- such as softmax, sigmoid, batched multiplication, convolutions, and pooling (see the example below)
- for use with auto-differentiation, NNlib also defines gradients using ChainRules.jl, which will be picked up by Zygote.jl
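For instance, calling one of those functions directly:
using NNlib
softmax([1.0, 2.0, 3.0])  # 3-element Vector whose entries sum to 1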
Transformers.jl also contains `EmbedDecoder`, a layer that shares weights with an embedding layer and returns the logits
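A rough sketch of the weight-tying idea (a hypothetical `TiedDecoder`, not the actual `EmbedDecoder` definition): the decoder reuses the embedding matrix, transposed, to map hidden states back to vocabulary logits.
# hypothetical sketch of weight tying, not the Transformers.jl implementation
struct TiedDecoder{E <: AbstractEmbedding}
    embed::E
end
# (hidden, n) hidden states -> (vocab, n) logits via the shared matrix
(d::TiedDecoder)(h) = d.embed.embeddings' * h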
- a `ChainRulesCore.rrule` method is defined alongside the primal computation
- it can be checked with ChainRulesTestUtils
- `rrule` dispatches on the `typeof` of the function we are writing the `rrule` for, as well as the types of its arguments
- it returns the primal result `y` and the pullback function
- Transformers.jl defines an `rrule` for `EmbedDecoder`; a toy version of the pattern is sketched below
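A minimal sketch of the pattern for a hypothetical `myscale` function (not the actual `EmbedDecoder` rule):
using ChainRulesCore

myscale(x::AbstractVector, s::Number) = x .* s

# dispatch on typeof(myscale) and on the argument types;
# return the primal result y together with a pullback closure
function ChainRulesCore.rrule(::typeof(myscale), x::AbstractVector, s::Number)
    y = myscale(x, s)
    function myscale_pullback(ȳ)
        x̄ = ȳ .* s       # gradient w.r.t. x
        s̄ = sum(ȳ .* x)  # gradient w.r.t. s, accumulated over the vector
        return (NoTangent(), x̄, s̄)
    end
    return y, myscale_pullback
end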
Functors.jl basic usage and implementation
- Functors.jl allows us to look into the fields of the instances of any struct
- and modify them
- this is achieved through `Functors.fmap`
- a structure- and type-preserving map
- it transforms every leaf node by applying a function `f`
- and otherwise traverses `x` recursively using `Functors.functor`
- `functor(x)` returns a tuple containing a NamedTuple of the children of `x` and a reconstruction function
- to opt out of this behaviour, mark a custom type as non-traversable with `@leaf` (see the sketch below)
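A small sketch, assuming the `Functors.@leaf` macro available in recent Functors.jl versions:
using Functors

struct Point
    x::Float64
    y::Float64
end
Functors.@leaf Point  # fmap now treats Point values as atoms and does not recurse into them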
to include only certain fields of a struct, one can pass a tuple of fieldnames to `@functor`:
using Functors

struct Baz
    x
    y
end
@functor Baz (x,)

model = Baz(1, 2)
# Baz(1, 2)
fmap(float, model)
# Baz(1.0, 2)  -- only the registered field `x` is transformed
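Under the hood, `functor` returns the children and the rebuild closure that `fmap` uses for the traversal:
children, rebuild = Functors.functor(model)
children            # (x = 1,): only the field registered with @functor
rebuild((x = 10,))  # Baz(10, 2): reconstructs the struct around new children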
NNlib gather
Flux's Embedding layer uses NNlib.gather as its backend
NNlib.gather(src, idx) -> dst
using NNlib

NNlib.gather([1, 20, 300, 4000], [2, 4, 2])
# 3-element Vector{Int64}:
#    20
#  4000
#    20

NNlib.gather([1 2 3; 4 5 6], [1, 3, 1, 3, 1])
# 2×5 Matrix{Int64}:
#  1  3  1  3  1
#  4  6  4  6  4
- the reverse operation of `scatter`
- gathers data from the source `src` and writes it into a destination `dst` according to the index array `idx`
- for each `k` in `CartesianIndices(idx)`, values are assigned to `dst` according to `dst[:, ..., k] .= src[:, ..., idx[k]...]` (demonstrated below)
- notice that if `idx` is a vector containing integers and `src` is a matrix, the previous expression simplifies to `dst[:, k] .= src[:, idx[k]]`, and `k` runs over `1:length(idx)`
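For example, a matrix of indices adds its dimensions to the result, matching the `CartesianIndices(idx)` rule above:
src = [1 2 3; 4 5 6]            # size (2, 3)
idx = [1 3; 2 1]                # a 2×2 index matrix
NNlib.gather(src, idx) |> size  # (2, 2, 2): one src column per entry of idx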