Transformers: What They Are and Why They Matter
Transformers leverage the concept of self-attention to develop simpler models.
A transformer is a machine learning model that relies solely on attention, rather than recurrence or convolution, to model the relationships between the elements of a sequence. It is a game changer among deep learning architectures because it eliminates the need for recurrent connections and convolutions.
The transformer's main architecture consists of an encoder stack followed by a decoder stack, both built from multihead attention and feed-forward layers. Here's a deep dive into the structure of this model that will help you understand its internal components. The model explained below is based on the original transformer described in the paper “Attention Is All You Need” by Ashish Vaswani et al.
Attention Based on the Scaled Dot Product
In my AI Exchange article “Attention Models: What They Are and Why They Matter,” I discussed how attention weights are computed from the query Q and keys K and then used to weight the values V. The transformer uses scaled dot-product attention, which computes its output $O_{\text{sdp}}$ as:
$$O_{\text{sdp}} = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V$$
Here, the dot product is divided by $\sqrt{d_k}$, where $d_k$ is the number of dimensions of the query and key vectors. Without this scaling, large dot products would push the softmax into regions where its gradients are extremely small, so the scaling keeps their magnitude in check. The softmax function ensures that the weights computed from the dot products lie between 0 and 1 and sum to 1. It normalizes the weights into a probability distribution that assigns higher weights to more relevant keys and lower weights to less relevant keys. An example of this model with three keys and three corresponding values is shown in Figure 1.
Figure 1: Scaled-dot product attention; the idea is to keep the magnitude of the dot product small. Source: Mehreen Saeed
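To make the formula concrete, here is a minimal sketch of scaled dot-product attention in plain Python, written for readability rather than speed. It assumes queries, keys, and values are given as lists of vectors; the function names and the numbers in the example are illustrative, not taken from any library. Like Figure 1, the example uses one query with three keys and three values.

```python
import math

def softmax(scores):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(Q, K, V):
    """Compute softmax(Q K^T / sqrt(d_k)) V one query at a time.

    Q: list of query vectors, each of length d_k
    K: list of key vectors, each of length d_k
    V: list of value vectors (one per key), each of length d_v
    Returns (outputs, weights): one output vector and one weight row per query.
    """
    d_k = len(K[0])
    outputs, weights = [], []
    for q in Q:
        # One score per key: the dot product q . k divided by sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        w = softmax(scores)  # each weight lies in [0, 1] and the row sums to 1
        weights.append(w)
        # The output is the weighted sum of the value vectors
        outputs.append([sum(wj * v[c] for wj, v in zip(w, V)) for c in range(len(V[0]))])
    return outputs, weights

# One query, three keys, and three values (made-up numbers)
Q = [[1.0, 0.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, w = scaled_dot_product_attention(Q, K, V)
print([round(x, 3) for x in w[0]])
print([round(x, 3) for x in out[0]])  # [3.0, 4.0]
```

The first and third keys have the same dot product with the query, so they receive equal weight, and the second key, which is orthogonal to the query, receives less.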
Multihead Attention
In multihead attention, scaled dot-product attention is computed multiple times on different linear projections of the query, keys, and values. The query, keys, and values are first transformed linearly using a different set of projection matrices for each head, and each set of projections becomes the input to its own scaled dot-product attention.
The output of each scaled dot-product attention is called a “head.” Figure 2 shows an example of a two-head attention model. The operations involved in computing each head are independent, so they can be done in parallel, which speeds up computation. The head outputs are then concatenated and passed through one more linear projection to produce the final output. The network learns the projection matrices W during the training phase.
Figure 2: Multihead attention; the output of each scaled-dot product attention is called the “head.” Source: Mehreen Saeed
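In the original paper, each head is $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$ and the layer output is $\text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O$. The plain-Python sketch below follows that recipe under stated assumptions: the sizes are hypothetical ($d_{\text{model}} = 4$, two heads of size 2), and random matrices stand in for the learned projection weights. It computes the heads in a simple loop, although they could be computed in parallel.

```python
import math
import random

def matmul(A, B):
    """Multiply A (n x m) by B (m x p); matrices are lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(K[0])
    K_T = [list(col) for col in zip(*K)]
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in matmul(Q, K_T)]
    return matmul(weights, V)

def multihead_attention(Q, K, V, head_weights, W_O):
    """head_weights holds one (W_Q, W_K, W_V) triple of projection matrices per head."""
    # Each head runs attention on its own projections; the heads are independent,
    # so a real implementation can compute them in parallel.
    heads = [attention(matmul(Q, W_Q), matmul(K, W_K), matmul(V, W_V))
             for W_Q, W_K, W_V in head_weights]
    # Concatenate the heads position by position, then apply the output projection W_O
    concat = [[x for head in heads for x in head[i]] for i in range(len(Q))]
    return matmul(concat, W_O)

def random_matrix(rows, cols, rng):
    """Stand-in for learned weights: uniform random values."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

# Hypothetical sizes: d_model = 4, two heads with d_k = d_v = 2
rng = random.Random(0)
d_model, n_heads, d_k = 4, 2, 2
head_weights = [tuple(random_matrix(d_model, d_k, rng) for _ in range(3))
                for _ in range(n_heads)]
W_O = random_matrix(n_heads * d_k, d_model, rng)

X = random_matrix(3, d_model, rng)  # a sequence of three token vectors
out = multihead_attention(X, X, X, head_weights, W_O)  # Q = K = V = X
print(len(out), len(out[0]))  # 3 4
```

Choosing $d_k = d_{\text{model}} / h$ keeps the concatenated heads the same width as the input, which is the choice the original paper makes.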
The Transformer Architecture
The original version of the transformer was a composition of an encoder stack and a decoder stack.
The Encoder Stack
The encoder stack consists of six identical layers, where the output from one layer becomes the input to the next. Each encoder layer has two sublayers: a multihead attention layer and a fully connected feed-forward neural network. The encoder applies a residual connection around each sublayer, followed by layer normalization. The residual connection adds the input of a sublayer to that sublayer's output, so each sublayer produces LayerNorm(x + Sublayer(x)), as shown in Figure 3.
The encoder stack implements self-attention: the query, keys, and values all come from the same sequence. Each encoder layer uses the output of the previous layer as its query, keys, and values, so every position in the sequence of tokens can attend to all positions in that sequence.
Figure 3 shows a single encoder layer.
Figure 3: The encoder stack used in a transformer; the output from one layer becomes the input to the next. Source: Mehreen Saeed
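The following sketch wires one encoder layer together in plain Python to show where the residual connections and normalization sit. It makes several simplifying assumptions that are not part of the original design: attention is single-head with no learned projections, the feed-forward network omits its biases, and layer normalization has no learned gain or bias. Random matrices stand in for trained weights. The six layers share the same structure but get their own weights.

```python
import math
import random

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(xs):
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    s = sum(e)
    return [v / s for v in e]

def self_attention(X):
    """Simplified single-head self-attention with Q = K = V = X (no projections)."""
    d = len(X[0])
    scores = matmul(X, [list(c) for c in zip(*X)])
    weights = [softmax([s / math.sqrt(d) for s in row]) for row in scores]
    return matmul(weights, X)

def feed_forward(X, W1, W2):
    """Position-wise network max(0, x W1) W2 (biases omitted for brevity)."""
    hidden = [[max(0.0, h) for h in row] for row in matmul(X, W1)]
    return matmul(hidden, W2)

def layer_norm(X, eps=1e-6):
    """Normalize each row to zero mean and unit variance (no learned gain or bias)."""
    out = []
    for row in X:
        mean = sum(row) / len(row)
        var = sum((v - mean) ** 2 for v in row) / len(row)
        out.append([(v - mean) / math.sqrt(var + eps) for v in row])
    return out

def add(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def encoder_layer(X, W1, W2):
    # Sublayer 1: self-attention, residual connection, then normalization
    X = layer_norm(add(X, self_attention(X)))
    # Sublayer 2: feed-forward network, residual connection, then normalization
    return layer_norm(add(X, feed_forward(X, W1, W2)))

def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(1)
d_model, d_ff = 4, 8  # hypothetical sizes
layers = [(random_matrix(d_model, d_ff, rng), random_matrix(d_ff, d_model, rng))
          for _ in range(6)]
X = random_matrix(3, d_model, rng)  # three token vectors

# Stack six layers: the output of one layer becomes the input to the next
for W1, W2 in layers:
    X = encoder_layer(X, W1, W2)
```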
The Decoder Stack
Like the encoder stack, the decoder stack is built from six identical layers, and the output from one layer becomes the input to the next. However, each decoder layer has three sublayers, each wrapped in a residual connection followed by layer normalization. The first two sublayers are multihead attention layers, and the third is a fully connected feed-forward neural network.
The first multihead attention sublayer takes its input from the previous decoder layer and implements self-attention with a masking operation. The mask allows each position to attend only to itself and the preceding positions in the sequence, not to the positions that come after it. For a sentence, this means that at each word's position, the first sublayer takes into account that word and all the words before it but ignores all the words that come after it. This is equivalent to predicting a word by using only the previous words in the sentence.
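The masking step can be sketched in a few lines. Following the original paper, disallowed scores are set to minus infinity before the softmax, so those positions receive exactly zero weight. The token list and the equal raw scores below are hypothetical; in a real decoder, the scores would come from $QK^T/\sqrt{d_k}$.

```python
import math

def causal_mask(n):
    """mask[i][j] is True when position i may attend to position j, i.e. j <= i."""
    return [[j <= i for j in range(n)] for i in range(n)]

def masked_attention_weights(scores, mask):
    """Set disallowed scores to -inf, then apply a row-wise softmax."""
    weights = []
    for row, allowed in zip(scores, mask):
        masked = [s if ok else -math.inf for s, ok in zip(row, allowed)]
        m = max(masked)  # finite, because every position may attend to itself
        exps = [math.exp(s - m) for s in masked]  # exp(-inf) == 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

tokens = ["the", "cat", "sat"]
scores = [[0.0] * len(tokens) for _ in tokens]  # equal raw scores for illustration
w = masked_attention_weights(scores, causal_mask(len(tokens)))
for tok, row in zip(tokens, w):
    print(tok, [round(x, 2) for x in row])
# the [1.0, 0.0, 0.0]
# cat [0.5, 0.5, 0.0]
# sat [0.33, 0.33, 0.33]
```

Each row spreads its weight only over the current word and the words before it, which is exactly the restriction described above.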
In the second multihead attention sublayer, the keys and values come from the output of the encoder stack, and the queries come from the previous decoder layer. As a result, every position in the decoder can attend to all positions of the input sequence.
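A small sketch shows how this encoder-decoder attention differs from self-attention: the queries and the keys come from sequences of different lengths, so the weight matrix is rectangular. For brevity it assumes a single head with no learned projections, and the vectors standing in for the encoder output and decoder states are made up.

```python
import math

def attention(Q, K, V):
    """Scaled dot-product attention; returns (output, weights)."""
    d_k = len(K[0])
    weights = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    output = [[sum(wj * v[c] for wj, v in zip(row, V)) for c in range(len(V[0]))]
              for row in weights]
    return output, weights

# Hypothetical vectors: 4 encoder positions (source sentence), 2 decoder positions
encoder_output = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
decoder_states = [[0.2, 0.8], [0.9, 0.1]]

# Queries from the decoder; keys and values from the encoder output
out, w = attention(decoder_states, encoder_output, encoder_output)
print(len(w), len(w[0]))      # 2 4: one weight per source position for each decoder position
print(len(out), len(out[0]))  # 2 2
```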
The third decoder sublayer is the fully connected feed-forward layer, similar to the one used in the encoder. Figure 4 shows the entire decoder stack.
Figure 4: The decoder stack used in the transformer; each decoder layer is built from three sublayers, with residual connections around them, followed by a normalization layer. Source: Mehreen Saeed
The Transformer with Encoder and Decoder Stacks
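Figure 5 shows the overall architecture of the transformer with the encoder and decoder stacks. The initial input is an embedding of the input sentence with positional encodings added to it, because attention by itself has no notion of word order. The output from the encoder stack is fed to the second attention sublayer of every decoder layer. Finally, a linear layer followed by a softmax turns the decoder output into a probability distribution over the vocabulary, which is the transformer's prediction for the next word of the output sequence.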
Figure 5: The transformer architecture; the final output produced by the transformer is the prediction for the next word in the input sequence. Source: Mehreen Saeed
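The original paper adds fixed sinusoidal positional encodings to the input embeddings: $PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{\text{model}}})$ and $PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{\text{model}}})$. The sketch below computes them in plain Python; the helper names are illustrative, and learned positional embeddings are a common alternative that this sketch does not cover.

```python
import math

def positional_encoding(num_positions, d_model):
    """Sinusoidal encodings: even dimensions use sin, odd dimensions use cos,
    and each pair of dimensions uses a different frequency."""
    pe = []
    for pos in range(num_positions):
        row = []
        for j in range(d_model):
            i = j // 2  # dimensions 2i and 2i + 1 share one frequency
            angle = pos / (10000 ** (2 * i / d_model))
            row.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
        pe.append(row)
    return pe

def add_positions(embeddings):
    """Add the positional encoding to each token embedding of size d_model."""
    pe = positional_encoding(len(embeddings), len(embeddings[0]))
    return [[e + p for e, p in zip(emb, pos)] for emb, pos in zip(embeddings, pe)]

pe = positional_encoding(num_positions=3, d_model=4)
print([round(x, 3) for x in pe[0]])  # [0.0, 1.0, 0.0, 1.0]
```

Because each position gets a distinct pattern of values, two identical words at different positions enter the encoder with different vectors.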
What Is the Importance of a Transformer, and What Are the Real-Life Use Cases?
The concept of attention was a breakthrough for deep learning networks. Because the transformer relies on attention rather than recurrence, it can process all positions of a sequence in parallel during training, and the heads of multihead attention can also be computed in parallel, which speeds up computation. As a whole, the transformer eliminates the need for complex convolutions and recurrences. This is a step away from the previously used recurrent neural networks and long short-term memory (LSTM) networks.
The transformer model has shown success in numerous natural language processing applications, including machine translation, sentence paraphrasing, document summarization, language analysis for detecting hate speech, automatic content moderation, and more.
The potential of transformers in solving AI problems extends beyond the NLP domain. Transformers and their variants have been applied successfully to problems such as time series analysis and prediction, image recognition, video understanding, and biological sequence analysis.
Transformers have also demonstrated their worth in synthesis applications. A few examples include generating novel molecules, creating images from text prompts, and creating high-quality 3D meshes.
What Are the Challenges and Future Directions of Transformer Technology?
The main challenge in training transformers is that they are costly in terms of time and memory, especially for long input sequences: self-attention compares every position with every other position, so its cost grows quadratically with sequence length. To address these shortcomings, researchers have proposed other variants of transformers. Some of these variants are the Compressive Transformer, which compresses past memories into a more compact representation; the Reformer, which improves efficiency with locality-sensitive hashing attention and reversible layers; and the Extended Transformer Construction (ETC), which is based on a sparse global-local attention mechanism.
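A quick count shows why long sequences are expensive and how restricting attention helps. The sliding-window pattern below is a generic, hypothetical illustration of sparse attention, in which each position attends only to neighbors within a fixed window. It is not the specific global-local scheme of ETC or the hashing scheme of the Reformer.

```python
def full_attention_scores(n):
    """Full self-attention computes one score for every (query, key) pair: n * n."""
    return n * n

def local_attention_mask(n, window):
    """Illustrative sliding-window pattern: position i may attend to j when |i - j| <= window."""
    return [[abs(i - j) <= window for j in range(n)] for i in range(n)]

def local_attention_scores(n, window):
    """Count the allowed pairs without building the mask: row i covers keys
    max(0, i - window) through min(n - 1, i + window)."""
    return sum(min(n - 1, i + window) - max(0, i - window) + 1 for i in range(n))

# Doubling n quadruples the full cost; the windowed cost grows roughly linearly.
for n in (512, 1024, 2048):
    print(n, full_attention_scores(n), local_attention_scores(n, window=64))
```

The trade-off is that a purely local pattern cannot directly relate distant positions, which is why practical sparse-attention designs typically add some form of global or long-range connection.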
In the future, we are likely to see an increasing number of practical applications and use cases of transformers in the consumer market. Researchers are focused on reducing transformers' time and memory requirements so that the technology can be trained on machines with smaller memory, including edge devices.
Getting Started with Transformers
Transformers represent a breakthrough in the field of deep learning networks. Unlike the typical networks based on convolutions or recurrences, transformers leverage the concept of self-attention to develop simpler models.
Implementing a transformer for an AI application is not hard. PyTorch and TensorFlow are popular open-source libraries that provide building blocks for your implementation, such as PyTorch's torch.nn.Transformer module and TensorFlow's tf.keras.layers.MultiHeadAttention layer. Once you have formulated the problem and gathered the data, go ahead and try out this model yourself.