2 The transformer architecture
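For a vector \(v \in \mathbb {R}^{n}\), the softmax operator \(\operatorname {softmax} : \mathbb {R}^{n} \to \mathbb {R}^{n}\) is defined entrywise by
\[ \operatorname {softmax}(v)_i = \frac{\exp (v_i)}{\sum _{j=1}^{n} \exp (v_j)}, \qquad i = 1, \ldots , n. \]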
For a matrix \(A \in \mathbb {R}^{n \times n}\) the softmax operator is applied row-wise: \(\operatorname {softmax}(A)_{i,:} = \operatorname {softmax}(A_{i,:})\).
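As a small illustration of the row-wise convention, the following Python sketch computes the softmax of a vector and of a matrix using only the standard library. The function names `softmax` and `softmax_rows` are illustrative choices, and subtracting the row maximum is a standard numerical-stability trick that does not change the mathematical value.

```python
import math

def softmax(v):
    """Softmax of a vector given as a list of floats."""
    # Subtracting the maximum leaves the result unchanged mathematically
    # and avoids overflow in exp.
    shift = max(v)
    exps = [math.exp(x - shift) for x in v]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_rows(A):
    """Row-wise softmax of a matrix given as a list of rows."""
    return [softmax(row) for row in A]

# Each row is normalised independently of the others.
P = softmax_rows([[0.0, 0.0], [0.0, math.log(3.0)]])
```

Here the first row of `P` is uniform, and the second row puts weight roughly 1/4 and 3/4 on its two entries.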
For input dimension \(d_{\mathrm{in}} \in \mathbb {N}\), output dimension \(d_{\mathrm{out}} \in \mathbb {N}\), embedding dimension \(m \in \mathbb {N}\), and matrices \(Q, K \in \mathbb {R}^{d_{\mathrm{in}} \times m}\) and \(V \in \mathbb {R}^{d_{\mathrm{in}} \times d_{\mathrm{out}}}\), an attention is the mapping \(A_{Q,K,V} : \mathbb {R}^{n \times d_{\mathrm{in}}} \to \mathbb {R}^{n \times d_{\mathrm{out}}}\) defined by
\[ A_{Q,K,V}(X) = \operatorname {softmax}\bigl (X Q K^{\top } X^{\top }\bigr )\, X V. \]
We write \(\mathcal A_{d_{\mathrm{in}}, m, d_{\mathrm{out}}}\) for the set of all such attentions.
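To make the dimensions concrete, here is a minimal Python sketch of a single attention map. It assumes the usual unscaled form \(\operatorname {softmax}(X Q K^{\top } X^{\top }) X V\), which is the form compatible with the shapes of \(Q\), \(K\), \(V\) stated above; the helper names (`matmul`, `transpose`, `softmax_rows`, `attention`) are illustrative and not part of the text.

```python
import math

def matmul(A, B):
    """Product of two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(A):
    """Row-wise softmax, shifted by the row maximum for stability."""
    out = []
    for row in A:
        shift = max(row)
        exps = [math.exp(x - shift) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def attention(X, Q, K, V):
    """Assumed form: softmax(X Q K^T X^T) X V for X of shape n x d_in."""
    scores = matmul(matmul(X, Q), transpose(matmul(X, K)))  # n x n
    return matmul(softmax_rows(scores), matmul(X, V))       # n x d_out

# n = 2, d_in = 2, m = 1, d_out = 1. With Q = K = 0 all scores are equal,
# so every output row is the average of the rows of X V.
X = [[1.0, 0.0], [0.0, 1.0]]
Q = [[0.0], [0.0]]
K = [[0.0], [0.0]]
V = [[2.0], [4.0]]
out = attention(X, Q, K, V)
```

In this toy instance `X V` has rows 2 and 4, so both rows of `out` equal their average, 3.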
A multi-layer perceptron (MLP) is represented by some continuous function \(\varphi : \mathbb {R}^{a} \to \mathbb {R}^{b}\) for positive integers \(a,b\) (modelling, following the universal approximation theorem, any function approximable by a neural network). It is applied to a matrix row-wise: for \(X \in \mathbb {R}^{n \times a}\), \(\varphi (X) = (\varphi (X_1), \ldots , \varphi (X_n)) \in \mathbb {R}^{n \times b}\).
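The row-wise application of an MLP can be sketched as follows. The particular map `phi` is a hypothetical stand-in (a single ReLU unit from \(\mathbb {R}^{2}\) to \(\mathbb {R}^{1}\)); any continuous function on rows could be used in its place.

```python
def apply_rowwise(phi, X):
    """Apply phi: R^a -> R^b to every row of an n x a matrix X."""
    return [phi(row) for row in X]

def phi(x):
    # Hypothetical example map R^2 -> R^1: ReLU of the coordinate difference.
    return [max(0.0, x[0] - x[1])]

X = [[3.0, 1.0], [1.0, 3.0]]
Y = apply_rowwise(phi, X)  # n x 1 matrix, one output row per input row
```

Each output row depends only on the corresponding input row, so the MLP by itself cannot mix information across the \(n\) rows.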
A transformer is a mapping \(\mathsf{TF}: \mathbb {R}^{n \times d} \to \mathbb {R}\) specified by an attention unit \(A_{Q,K,V}\) and two MLPs \(\varphi _1 : \mathbb {R}^{n \times d} \to \mathbb {R}^{n \times d_{\mathrm{in}}}\) and \(\varphi _2 : \mathbb {R}^{n \times d_{\mathrm{out}}} \to \mathbb {R}\). On an embedding matrix \(E \in \mathbb {R}^{n \times d}\) it outputs
\[ \mathsf{TF}(E) = \varphi _2\bigl (A_{Q,K,V}(\varphi _1(E))\bigr ). \]
This models a single-attention-unit transformer (first MLP, then the attention unit, then the second MLP).
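The composition order (first MLP, attention unit, second MLP) can be sketched in Python as below. This is only an illustration under stated assumptions: the attention is taken in the unscaled form \(\operatorname {softmax}(X Q K^{\top } X^{\top }) X V\), and the two maps `identity` and `total` are hypothetical placeholders for \(\varphi _1\) and \(\varphi _2\).

```python
import math

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(A):
    out = []
    for row in A:
        shift = max(row)  # stability shift; does not change the result
        exps = [math.exp(x - shift) for x in row]
        total_exp = sum(exps)
        out.append([e / total_exp for e in exps])
    return out

def attention(X, Q, K, V):
    # Assumed form: softmax(X Q K^T X^T) X V.
    scores = matmul(matmul(X, Q), transpose(matmul(X, K)))
    return matmul(softmax_rows(scores), matmul(X, V))

def transformer(E, phi1, Q, K, V, phi2):
    H = [phi1(row) for row in E]   # first MLP, applied row-wise: n x d_in
    Z = attention(H, Q, K, V)      # attention unit: n x d_out
    return phi2(Z)                 # second MLP: whole matrix -> scalar

# Hypothetical placeholder MLPs.
def identity(row):
    return list(row)

def total(Z):
    return sum(sum(row) for row in Z)

E = [[1.0, 0.0], [0.0, 1.0]]
Q = [[0.0], [0.0]]
K = [[0.0], [0.0]]
V = [[2.0], [4.0]]
y = transformer(E, identity, Q, K, V, total)
```

With these toy parameters the attention output has both rows equal to 3, so `y` is 6.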
A transformer \(\mathsf{TF}\) solves a (decision) problem whose answer on instance \(E \in \mathbb {R}^{n \times d}\) is \(\mathsf{TF}(E)\). For a problem on vectors \(v_1, \ldots , v_n\) with answer a Boolean, \(\mathsf{TF}\) solves it if for every input \(E\) with \(E_{i,:} = v_i\) for all \(i\) one has \(\mathsf{TF}(E) = 1\) exactly when the answer is "yes" and \(\mathsf{TF}(E) = 0\) otherwise.