I love the idea behind tf.einsum
-- a single unifying framework for a collection of common matrix operations that amount to some combination of multiplication and additive reduction over tensor axes. If you're unfamiliar with einsum notation, give Tim Rocktäschel's wonderful einsum overview a read to understand why you should care.
If you'd prefer the tl;dr version, below is a sample einsum call:
# Batched matrix multiplication using einsum
tf.einsum('ijk,ikl->ijl', A, B)
The comma-separated groups of characters to the left of the arrow describe the function's inputs: each group corresponds to one tensor argument, and each character within a group labels one of that tensor's axes. The group to the right of the arrow describes the output's axes.
The computation indicated by the above expression is shown in the for-loop expansion below:
import numpy as np

# First, we construct an output whose shape is given
# by the indices on the right-hand side
i, j, k = A.shape
l = B.shape[-1]
output = np.zeros((i, j, l))

# Then, we use nested for loops over
# the indices present on the right.
# In our case, this is (i, j, l)
for _i in range(i):
    for _j in range(j):
        for _l in range(l):
            # For the remaining axes, we use an
            # intermediate variable to compute
            # a sum reduction over that axis
            total = 0
            for _k in range(k):
                # In the inner-most loop, we multiply
                # the entries indicated by the
                # left-hand side of the equation
                total += A[_i, _j, _k] * B[_i, _k, _l]
            # And finally, set the value at the
            # location indicated by the right-hand side
            output[_i, _j, _l] = total
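As a quick sanity check -- with arbitrary sizes of my choosing -- the einsum expression matches an ordinary batched matrix multiplication:

import numpy as np

# Arbitrary sizes for a quick check
i, j, k, l = 2, 3, 4, 5
A = np.random.rand(i, j, k)
B = np.random.rand(i, k, l)

assert np.allclose(np.einsum('ijk,ikl->ijl', A, B), A @ B)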
See Olexa Bilaniuk's explanation of the internal workings of einsum for a deeper dive.
It's a relatively simple framework but it's a surprisingly flexible one, and we can use it to compute everything from the trace of a matrix to dot-product attention in a single expression.
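For example (a quick sketch, with placeholder shapes of my choosing):

import tensorflow as tf

A = tf.random.normal([4, 4])
queries = tf.random.normal([8, 64])   # [seq_len, hidden]
keys = tf.random.normal([10, 64])     # [seq_len, hidden]

# Trace of a matrix: sum over the repeated diagonal index
trace = tf.einsum('ii', A)

# Unnormalized dot-product attention scores between queries and keys
scores = tf.einsum('qh,kh->qk', queries, keys)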
In practice, I dislike that tf.einsum necessitates referencing each axis with a single-letter variable name. I find myself wanting a more verbose syntax that takes the place of the comments used to keep track of dimensions in your source.
To facilitate this, I've put together a tiny wrapper that implements a "Named Einsum" syntax. In total it's only about 30 functional lines of code -- feel free to mix and re-use it for your own machine learning escapades. Note that although the gist uses np.einsum, since we're only performing a translation of the input string, you can use the same code in PyTorch or TensorFlow by simply swapping np.einsum for torch.einsum or tf.einsum.
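The gist itself isn't reproduced here, but a minimal sketch of such a wrapper might look something like the following -- it simply assigns each dot-delimited axis name a single letter and hands the translated string to np.einsum (the gist's actual implementation may differ):

import numpy as np

def named_einsum(expression, *tensors):
    # Translate a dot-delimited "named" einsum expression into
    # standard single-letter notation, then delegate to np.einsum
    letters = iter('abcdefghijklmnopqrstuvwxyz')
    mapping = {}

    def letter_for(axis):
        # Assign each unique axis name its own single-letter label
        if axis not in mapping:
            mapping[axis] = next(letters)
        return mapping[axis]

    def translate(tensor_spec):
        # e.g. 'batch.source.heads' -> 'abc'
        return ''.join(letter_for(axis) for axis in tensor_spec.split('.'))

    inputs, output = expression.split('->')
    lhs = ','.join(translate(spec) for spec in inputs.split(','))
    return np.einsum(lhs + '->' + translate(output), *tensors)

With that in place, named_einsum('batch.row.inner,batch.inner.col->batch.row.col', A, B) computes the batched matrix multiplication from the first example.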
Rather than requiring single-letter axis names, this syntax delimits the axes of a tensor with an explicit . between axis names. Let's see the new syntax in action on a sample problem -- an implementation of multi-head attention. For purposes of illustration, let's assume we'll be managing the requisite tensor reshapes external to our multi_head_attention function.
Let's start with vanilla einsum:
import tensorflow as tf

def multi_head_attention(keys, queries, values):
    weights = tf.nn.softmax(
        tf.einsum('ijlm,iklm->ijkl', keys, queries),
        axis=1
    )
    per_head_context = tf.einsum('iklm,ijkl->ijlm', values, weights)
    return per_head_context
Were I to run across this function in the wild, my first question would be "what are the shapes of our inputs?" Even if you were familiar with multi-head attention, this may be a confusing read unless the author was kind enough to give us additional detail about input shapes in the form of a comment. And even then you'd have to do some mental translation between the single-letter variable names and the tensor axes they correspond to. Trying to choose meaningful single-letter variable names doesn't go far enough toward making this code legible -- an h passed to tf.einsum might refer to height, n_heads, head_size, or the hidden dimension.
Compare this to our alternate syntax:
import tensorflow as tf

def multi_head_attention(keys, queries, values):
    per_token_weights = tf.nn.softmax(
        named_einsum(
            'batch.source.heads.hidden,'
            'batch.target.heads.hidden'
            '->batch.source.target.heads',
            keys,
            queries
        ),
        axis=1
    )
    per_head_context = named_einsum(
        'batch.target.heads.hidden,'
        'batch.source.target.heads'
        '->batch.source.heads.hidden',
        values,
        per_token_weights
    )
    return per_head_context
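For completeness, here's a usage sketch with arbitrary example shapes of my choosing (assuming a named_einsum that delegates to tf.einsum, per the swap described above, is in scope); it makes the input shapes explicit:

import tensorflow as tf

batch, source, target, heads, hidden = 2, 7, 5, 4, 64
keys = tf.random.normal([batch, source, heads, hidden])
queries = tf.random.normal([batch, target, heads, hidden])
values = tf.random.normal([batch, target, heads, hidden])

context = multi_head_attention(keys, queries, values)
print(context.shape)  # (batch, source, heads, hidden) -> (2, 7, 4, 64)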
Sure, it's more verbose, but that's the point! We're trading keystrokes for legibility -- well-written code has more in common with traditional writing than most like to admit. Here we can explicitly call out the role each axis plays in the einsum operation, and hopefully end up with a result that's less likely to cause a new reader's eyes to glaze over.
In spite of my gripes with reading einsum code that's awash in seas of i's, j's, k's, and l's, I think the use of einsum is largely a step in the right direction.
I would love to see extensions to the einsum syntax to also handle tensor flattening and reshaping:
# Proposed syntax
named_einsum(
    'batch.source.heads.head_size->batch.source.(heads.head_size)',
    tensor
)
Or tensor concatenation:
named_einsum('batch.a,batch.b->batch.(a,b)', tensor_a, tensor_b)
Or perhaps replace tf.expand_dims() by mixing in concrete integers with axis names:
named_einsum('row.col->row.col.1', tensor)
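For reference, here's roughly what each of those proposed expressions would correspond to using today's TensorFlow ops -- placeholder shapes, and the mapping is my own sketch rather than an existing API:

import tensorflow as tf

batch, source, heads, head_size = 2, 5, 4, 16
tensor = tf.random.normal([batch, source, heads, head_size])
a = tf.random.normal([batch, 3])
b = tf.random.normal([batch, 7])
matrix = tf.random.normal([3, 4])  # [row, col]

# 'batch.source.heads.head_size->batch.source.(heads.head_size)'
flattened = tf.reshape(tensor, [batch, source, heads * head_size])

# 'batch.a,batch.b->batch.(a,b)'
concatenated = tf.concat([a, b], axis=1)

# 'row.col->row.col.1'
expanded = tf.expand_dims(matrix, axis=-1)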
More broadly, it feels like the family of tensor operations common in deep learning is an excellent fit for a domain-specific language. The typical paradigm of referring to axes by their index makes deep learning code hard to follow and dissect. Deep learning code has a user experience problem -- and I'd love to see more proposals for how we can make the experience of reading deep learning code better through the use of creative syntax.