Attention Models: What They Are and Why They Matter
With attention models you can design ML algorithms that learn which parts of the input are essential to solving a problem and which are irrelevant.
The concept of attention in AI and machine learning was inspired by psychology, where it describes the cognitive process of focusing on, or “attending to,” selected pieces of information while ignoring the rest.
When reading a book, people concentrate on a few selected words to understand the full meaning of a sentence. Similarly, when presented with an image of a scene, people need to focus on only a few objects to comprehend its theme, ignoring the minute details of the picture.
In machine learning, “attention” refers to assigning relevance and importance to specific parts of the input data while ignoring the rest, mimicking the cognitive process of attention.
Why Is Attention Important, and What Are the Real-World Use Cases?
The attention mechanism has revolutionized the world of deep learning and helped to solve many challenging real-world problems. Research has shown that adding an attention layer to different types of deep learning neural architectures, such as encoder-decoder networks and recurrent neural networks, improves their performance.
Attention is now also an inherent part of many deep learning networks. These include transformers, memory networks, and graph attention networks.
In practice, attention models have successfully solved many real-life problems in the domains of natural language processing and computer vision. Applications include language translation, document classification, text paraphrasing and summarization, image recognition, visual question answering, image generation, text-to-image synthesis, and more.
Beyond NLP and computer vision, attention-based networks are used in healthcare to analyze medical data and classify diseases in MRI images, and in conversational AI to build chatbots, among other applications.
How Does the Attention Model Work?
The example dataset below illustrates how the attention mechanism works. It provides a feature or predictor x for determining the value of the target variable y. A least squares linear regression model predicts the target with a fixed weight w, regardless of the input value x_test, with b as a constant:

y = w * x_test + b

The attention mechanism instead takes into account how relevant the input value x_test is to the predictors in the dataset. Rather than using a static weight w, it generates weights according to how similar x_test is to the data points in the training instances. The generalized attention model predicts the target y as a weighted sum over the training instances i:

y = Σ_i f(a(Q, K_i)) * V_i

The main ingredients of this generalized attention model are given below, and the figure that follows provides a worked-out example.
- Keys K: These correspond to the values of the predictors in the training dataset. In the figure, keys are the values in column x.
- Values V: These correspond to the target values in the training dataset. Column y of the figure corresponds to values.
- Query Q: This is the value of the test data point x_test.
- Alignment function a: This function determines the similarity of keys with the query.
- Distribution function f: This function converts the alignment scores into normalized attention weights.
Figure 1: An example of the generalized attention model. The alignment function has arbitrarily been chosen as a product of key and query. The distribution function simply normalizes the weights to sum to one. Source: Mehreen Saeed
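To make the figure concrete, here is a minimal NumPy sketch of the same scheme: the alignment function is the product of key and query, and the distribution function is a plain sum-to-one normalization, as in the caption. The data values are made up for illustration and are not taken from the figure.

```python
import numpy as np

# Toy training set: column x holds the predictors (keys),
# column y holds the targets (values). Numbers are illustrative.
keys = np.array([1.0, 2.0, 3.0, 4.0])    # column x
values = np.array([2.1, 3.9, 6.2, 8.0])  # column y

query = 2.5  # the test point x_test

# Alignment function a: the product of key and query,
# matching the arbitrary choice in Figure 1.
scores = keys * query

# Distribution function f: normalize the weights to sum to one.
weights = scores / scores.sum()

# Predicted target: the weighted sum of the values.
y_pred = np.dot(weights, values)
print(weights, y_pred)
```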
Alignment Functions
The alignment function computes the similarity between the query Q and keys K. You can choose the function that works best for your use case. For example, you can use dot product, cosine similarity, or scaled dot product as an alignment function.
The main guideline is that the function should return a high value when the query is close to a key and a low value when they are very different.
There are also other alignment functions, such as concat alignment, where the key and query are concatenated before being scored by a learned layer. Another option is location-based attention, which does not consider the keys and depends on the query alone.
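As a rough sketch, the three common alignment functions named above can be written as follows, assuming the query and key are NumPy vectors of equal length:

```python
import numpy as np

def dot_product(query, key):
    # High when the vectors point in similar directions
    # and have large magnitudes.
    return np.dot(query, key)

def scaled_dot_product(query, key):
    # Dot product scaled by the square root of the dimensionality,
    # which keeps the scores in a numerically stable range.
    return np.dot(query, key) / np.sqrt(len(key))

def cosine_similarity(query, key):
    # Dot product of the normalized vectors: depends only on
    # direction, not magnitude.
    return np.dot(query, key) / (np.linalg.norm(query) * np.linalg.norm(key))
```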
Distribution Function
Use a distribution function to map the alignment scores to attention weights between 0 and 1. The softmax function is the usual choice because it also normalizes the weights to sum to 1; the logistic sigmoid bounds each weight individually but does not make them sum to 1. The output values of the distribution function can be interpreted as probabilistic relevance scores.
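As a minimal sketch, a NumPy softmax can be written as follows; subtracting the maximum score does not change the result but keeps the exponentials numerically stable.

```python
import numpy as np

def softmax(scores):
    # Max-shift avoids overflow without changing the output.
    exp_scores = np.exp(scores - np.max(scores))
    return exp_scores / exp_scores.sum()

weights = softmax(np.array([2.0, 1.0, 0.1]))
print(weights, weights.sum())  # all weights in (0, 1), summing to 1
```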
Different Types of Attention Models
There are several variations of the generalized attention model. “An Attentive Survey of Attention Models” describes the following taxonomy of attention models:
Sequence-Based
Attention models differ based on the types of input and output sequences involved, which mainly determines how you define the query Q and the keys K.
- Distinctive attention: This is used in tasks such as language translation, speech recognition, and sentence paraphrasing when the key and query states correspond to distinct input and output sequences. The distinctive attention mechanism was used by Bahdanau et al. in “Neural Machine Translation by Jointly Learning to Align and Translate.”
- Co-attention: This model learns attention weights by using multiple input sequences at the same time. For example, co-attention was used for visual question answering, where both image sequences and words in questions were used to model attention.
- Self-attention: In self-attention, the key and query states belong to the same sequence. For example, in a document categorization task, the input is a sequence of words, and the output is a category rather than a sequence. Here, the attention mechanism learns the relevance of each token in the input sequence to every other token of that same sequence, as in the sketch after this list.
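To make self-attention concrete, here is a minimal NumPy sketch of scaled dot-product self-attention. Real models first map the sequence through learned projection matrices (commonly written W_q, W_k, and W_v); those are omitted here for brevity, so the queries, keys, and values are all the raw input sequence itself.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X):
    # Queries, keys, and values all come from the same sequence X
    # (shape: tokens x features) -- that is what makes it "self" attention.
    d = X.shape[-1]
    scores = X @ X.T / np.sqrt(d)       # scaled dot-product alignment
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ X                  # relevance-weighted mix of tokens

X = np.random.default_rng(0).normal(size=(5, 8))  # 5 tokens, 8 features
print(self_attention(X).shape)  # (5, 8)
```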
Abstraction Level–Based
You can also define attention models based on the hierarchical levels at which the attention weights are computed.
- Single-level: In single-level attention, the weights are calculated only for the original input sequence.
- Multilevel: For multilevel attention, the attention weights are repeatedly calculated at different levels of abstraction, with the lower abstraction level generating the query state for the higher level. A good example is the word-level and sentence-level attention used in hierarchical document categorization.
Position-Based
Position-based attention models determine the actual positions in the input sequence on which attention weights are calculated.
- Soft/global attention: Global attention (also called “soft attention”) uses all the tokens in the input sequence to compute the attention weights and decides the relevance of each token. For example, in text translation models, all of the words in a sentence are considered to calculate the attention weights in order to decide the relevance of each word.
- Hard attention: Hard attention computes the attention weights from some tokens of the input sequence while leaving out the others. Deciding which tokens to use is also a part of hard attention. So for the same text translation task, some words in the input sentence are left out when computing the relevance of different words.
- Local attention: Local attention is a middle ground between hard and soft attention. It picks a position in the input sequence and places a window around that position to compute the attention weights, as in the sketch after this list.
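The helper below is a hypothetical sketch of that windowing idea: it keeps only the alignment scores inside a window around a chosen position and normalizes those, giving every token outside the window zero weight. Luong et al.'s local attention additionally weights the window with a Gaussian centered on a predicted position; that refinement is omitted here.

```python
import numpy as np

def local_attention_weights(scores, center, width):
    # Keep only scores inside a window around the chosen position
    # and normalize those; everything outside gets weight 0.
    lo = max(0, center - width // 2)
    hi = min(len(scores), center + width // 2 + 1)
    masked = np.full_like(scores, -np.inf)
    masked[lo:hi] = scores[lo:hi]
    e = np.exp(masked - masked[lo:hi].max())  # max-shift for stability
    return e / e.sum()

scores = np.array([0.5, 1.2, 3.0, 0.7, 0.1, 2.2])
print(local_attention_weights(scores, center=2, width=3))
```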
Representation-Based
There are different types of representation-based attention models:
- Multi-representational: In multi-representational attention, the input is represented using several different feature sets, and the attention weights assign relevance to each of these representations.
- Multidimensional: Similarly, when the input data has a multidimensional representation, multidimensional attention models determine the relevance of each individual dimension.
What Are the Challenges and Future Directions of Attention Models?
Attention models implemented in conjunction with deep learning networks have their challenges. Training and deploying these large models is costly in time and computational power. This includes the self-attention mechanism in transformers, whose cost grows quadratically with the length of the input sequence. Research is now geared toward reducing these costs.
Online attention is another area of future research, targeting real-time applications such as translating speech while a person is still talking. Generating correct, partially translated text before the speaker finishes the entire sentence is a challenge in itself.
Determining the structure and network architecture of a deep learning model that employs attention is another aspect that needs further research. This has given rise to auto-learning attention, where the goal is to use attention itself to determine the optimal structure and design of an attention model.
The attention mechanism has also motivated research into alternatives to attention models. Active memory is one such alternative: instead of focusing on a specialized area of memory as attention models do, it operates uniformly and in parallel on all of it.
Get Started with Attention
An attention layer integrated within a mathematical model assigns importance weights to different tokens in an input sequence. This determines which parts of the input are more relevant to solving a given problem and which input tokens can be ignored. The concept of attention works in conjunction with other machine learning models, which are typically based on deep learning architectures.
To get started implementing your own attention model, the first step is to understand how it is incorporated into different deep learning networks. The deep learning model you choose, along with the type of attention model you incorporate within it, depends mostly on the nature of the problem you are solving and the application you are developing.