Many machine learning models deployed for vision and natural language processing (NLP) tasks now have more than a billion parameters. This added capacity generally leads to better results. But there is a catch: as the capacity increases, so does the computational complexity.
The ability to increase the number of parameters while keeping the same computational budget allows the overall system to strike a better trade-off between prediction accuracy and computational efficiency, both at training and at test time.
To address these challenges, a recent paper proposes a structured memory that can be easily integrated into a neural network. The memory is very large by design and therefore significantly increases the capacity of the architecture, by up to a billion parameters, with negligible computational overhead.
The new layer is intended for problems where existing architectures underfit despite the vast amount of available data, or where they are too slow to be practical.
Overview of the Key-Value Memory Layer
In a key-value memory layer, the input x is processed by a query network that produces a query vector q, which is compared with all the keys. The k keys with the highest scores are selected, and the output is a sparse weighted sum of the memories (values) associated with those keys.
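The lookup just described can be sketched in a few lines of NumPy. This is a minimal illustration, assuming a flat (non-product) set of keys, dot-product scores, and a softmax over the top-k scores to obtain the weights; the function name `memory_read` and all shapes are hypothetical. Note that this naive version compares q with every key, so its cost grows linearly with the number of memory slots, which is the cost the product-key structure described next avoids.

```python
import numpy as np

def softmax(z):
    # numerically stable softmax over a 1-D array
    e = np.exp(z - z.max())
    return e / e.sum()

def memory_read(q, keys, values, k):
    """Naive key-value memory read for a single query (hypothetical helper).

    q:      (d,) query vector produced by the query network
    keys:   (N, d) learned keys, one per memory slot
    values: (N, d_v) learned values (the memories)
    k:      number of keys to select
    """
    scores = keys @ q                            # compare q with every key: N dot products
    top = np.argpartition(-scores, k - 1)[:k]    # indices of the k highest scores (unordered)
    w = softmax(scores[top])                     # weights over the selected keys only
    out = w @ values[top]                        # sparse weighted sum of k values
    return out, top, w

# Usage with random parameters (illustrative sizes)
rng = np.random.default_rng(0)
keys = rng.standard_normal((1000, 16))
values = rng.standard_normal((1000, 8))
q = rng.standard_normal(16)
out, top, w = memory_read(q, keys, values, k=4)
print(out.shape, sorted(top.tolist()))
```

Only k of the N value rows contribute to the output, which is what makes the result sparse.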
The authors define each key as the concatenation of two sub-keys, each drawn from a much smaller set of sub-keys (so-called product keys). The resulting large number of keys can be thought of as memory slots. Despite the large number of memory slots, finding the exact closest keys to the input is very efficient, typically requiring O(√|K|) vector comparisons, where |K| is the total number of memory slots.
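The following NumPy sketch illustrates why the search is cheap. It assumes the query is split into two halves and that a key's score is the dot product of q with the concatenated key, so the score decomposes into one half-score per sub-key. With C sub-keys per set there are |K| = C² keys, yet only 2C sub-key comparisons plus k² candidate sums are needed, which is O(√|K|) when k is small. The function and variable names are illustrative, not taken from the authors' code.

```python
import numpy as np

def product_key_topk(q, subkeys1, subkeys2, k):
    """Exact top-k search over the product key set {concat(subkeys1[i], subkeys2[j])}.

    subkeys1, subkeys2: (C, d/2) arrays, so the full key set has C*C keys.
    The score of key (i, j) is q . concat(c1_i, c2_j) = q1 . c1_i + q2 . c2_j.
    """
    half = q.shape[0] // 2
    q1, q2 = q[:half], q[half:]
    s1 = subkeys1 @ q1                     # C comparisons for the first half
    s2 = subkeys2 @ q2                     # C comparisons for the second half
    i1 = np.argsort(-s1)[:k]               # k best sub-keys of each set
    i2 = np.argsort(-s2)[:k]
    # A key outside these k x k combinations scores no higher than at least k
    # of them, so searching only the combinations gives the exact top-k.
    cand = s1[i1][:, None] + s2[i2][None, :]           # (k, k) candidate scores
    best = np.argsort(-cand, axis=None)[:k]            # k best candidates (flattened)
    a, b = np.unravel_index(best, cand.shape)
    C = subkeys2.shape[0]
    slots = i1[a] * C + i2[b]              # memory slot index of key (i, j) is i*C + j
    return slots, cand[a, b]

# Usage: C = 64 sub-keys per set gives |K| = 4096 keys,
# searched with 2 * 64 comparisons plus k * k sums
rng = np.random.default_rng(0)
C, d, k = 64, 32, 8
sub1 = rng.standard_normal((C, d // 2))
sub2 = rng.standard_normal((C, d // 2))
q = rng.standard_normal(d)
slots, scores = product_key_topk(q, sub1, sub2, k)
print(sorted(slots.tolist()))
```

The returned slot indices can then be used to gather the corresponding values, exactly as in a flat key-value memory.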
Only a handful of memory slots are updated for each input at training time. This sparsity of key selection and parameter updates makes both training and inference very efficient.
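To see why only a few slots change, consider the gradient of the memory output m = Σ w_i v_i with respect to the value table. This NumPy sketch (a hypothetical helper, `value_gradient`, derived by hand rather than with an autodiff framework) computes it for a single input. For clarity it materializes a dense gradient array; in practice only the k nonzero rows would need to be stored and applied.

```python
import numpy as np

def value_gradient(num_slots, top, w, grad_out):
    """Gradient of m = sum_i w[i] * values[top[i]] with respect to the whole value table.

    top:      (k,) indices of the selected memory slots
    w:        (k,) weights of the selected slots
    grad_out: (d_v,) upstream gradient dL/dm
    """
    grad = np.zeros((num_slots, grad_out.shape[0]))
    np.add.at(grad, top, np.outer(w, grad_out))   # row top[i] receives w[i] * dL/dm
    return grad

# Usage: 100,000 memory slots, but only the k = 4 selected rows get a nonzero gradient
top = np.array([3, 17, 42, 99])
w = np.array([0.4, 0.3, 0.2, 0.1])
grad_out = np.ones(8)
g = value_gradient(100_000, top, w, grad_out)
touched = np.flatnonzero(np.abs(g).sum(axis=1))
print(touched)
```

Every other row has a zero gradient, so an optimizer that supports sparse updates only needs to touch the k selected rows for this input.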
To validate their claims, the authors experimented with the widely used BERT (Bidirectional Encoder Representations from Transformers) and GPT-2 (Generative Pre-trained Transformer 2) models, integrating the memory layer into these transformer architectures.
Augmenting Large-Scale Language Models
BERT and GPT-2 were selected because they demonstrated that increasing model capacity translates directly into large improvements in language modelling, which in turn yields better performance on both language understanding and text generation tasks.
The transformer network is the current workhorse of NLP. It is built by stacking blocks, each composed of a self-attention layer followed by a fully connected feed-forward network (dubbed FFN).
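A single block of this kind can be sketched in NumPy as follows. This is a deliberately simplified illustration, assuming one attention head, no masking, no biases, layer normalization without learned scale and shift, and the post-normalization layout of the original Transformer; all names are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # normalize each token vector (learned scale and shift omitted)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(z, axis=-1):
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(h, Wq, Wk, Wv):
    # single-head attention: queries, keys and values are all projections of the tokens h
    q, k, v = h @ Wq, h @ Wk, h @ Wv
    att = softmax(q @ k.T / np.sqrt(k.shape[-1]))   # (n, n) attention weights
    return att @ v

def ffn(h, W1, W2):
    # position-wise fully connected network (the "FFN") with a ReLU
    return np.maximum(h @ W1, 0.0) @ W2

def transformer_block(h, p):
    # sublayer 1: self-attention with a residual connection
    h = layer_norm(h + self_attention(h, p["Wq"], p["Wk"], p["Wv"]))
    # sublayer 2: FFN with a residual connection
    return layer_norm(h + ffn(h, p["W1"], p["W2"]))

# Usage: a stack of 2 blocks on n = 5 tokens of dimension d = 16
rng = np.random.default_rng(0)
n, d, d_ff = 5, 16, 64

def init_block(scale=0.1):
    return {"Wq": scale * rng.standard_normal((d, d)),
            "Wk": scale * rng.standard_normal((d, d)),
            "Wv": scale * rng.standard_normal((d, d)),
            "W1": scale * rng.standard_normal((d, d_ff)),
            "W2": scale * rng.standard_normal((d_ff, d))}

h = rng.standard_normal((n, d))
for params in [init_block() for _ in range(2)]:
    h = transformer_block(h, params)
print(h.shape)
```

Stacking more blocks only repeats the loop; each block keeps the (n, d) shape of its input.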
The components of the memory layer bear similarities to the query, key and value networks used in these self-attention layers, with two notable differences:
- the keys and values do not correspond to input tokens but are free embedding vectors, and
- the number of values (memory size) is very large.
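Written side by side, the two differences become explicit. The notation below is a summary of our own, not the paper's: h_1, …, h_n are the token representations of a sequence of length n, T_k returns the indices of the k largest scores, and attention scaling and multiple heads are omitted.

$$
\text{self-attention:}\quad \alpha_{tj} = \operatorname{softmax}_{j}\big((W_Q h_t)^\top (W_K h_j)\big), \qquad \mathrm{out}_t = \sum_{j=1}^{n} \alpha_{tj}\, W_V h_j
$$

$$
\text{memory layer:}\quad \mathcal{I} = \mathcal{T}_k\big(q(x)^\top k_1, \dots, q(x)^\top k_{|K|}\big), \qquad w = \operatorname{softmax}\big((q(x)^\top k_i)_{i \in \mathcal{I}}\big), \qquad m(x) = \sum_{i \in \mathcal{I}} w_i\, v_i
$$

In the first line the keys W_K h_j and values W_V h_j are computed from the n input tokens; in the second, the keys k_i and values v_i are free parameters, and |K| can be far larger than n.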
This work borrows ideas from product quantization (PQ), an approximate search technique that maps database vectors into compact codes. It also exploits the idea of representing a large set of key vectors with a drastically smaller number of vectors, which are updated by regular back-propagation.
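For readers unfamiliar with product quantization, this sketch shows the underlying idea, assuming the common setup of splitting each vector into two halves with a separate codebook (a set of C centroids) per half. The helper names `pq_encode` and `pq_decode` are hypothetical, and the codebooks here are random, whereas in PQ they are usually obtained by clustering. Because the squared distance splits into one term per half, choosing the nearest centroid for each half independently yields the nearest of all C² reconstructions, and 2C stored vectors stand in for C² combinations.

```python
import numpy as np

def pq_encode(x, codebook1, codebook2):
    """Product-quantize x with two codebooks of C centroids each (hypothetical helper).

    The first half of x is assigned to its nearest centroid in codebook1 and the
    second half to its nearest centroid in codebook2; the code is the index pair.
    """
    half = x.shape[0] // 2
    i = np.argmin(((codebook1 - x[:half]) ** 2).sum(axis=1))
    j = np.argmin(((codebook2 - x[half:]) ** 2).sum(axis=1))
    return int(i), int(j)

def pq_decode(code, codebook1, codebook2):
    # reconstruct the approximation of x from its two-integer code
    i, j = code
    return np.concatenate([codebook1[i], codebook2[j]])

# Usage: 2 * 16 stored centroids can represent 16 * 16 = 256 distinct vectors
rng = np.random.default_rng(0)
C, d = 16, 8
cb1 = rng.standard_normal((C, d // 2))
cb2 = rng.standard_normal((C, d // 2))
x = rng.standard_normal(d)
code = pq_encode(x, cb1, cb2)
x_hat = pq_decode(code, cb1, cb2)
print(code, np.linalg.norm(x - x_hat))
```

In the memory layer, the two sub-key sets play a role similar to these codebooks, except that they are learned by back-propagation.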
The training set used for the experiments was composed of 28 billion words (140 GB of data) extracted from about 40 million English news articles indexed by Common Crawl. The validation and test sets each consist of 5,000 news articles removed from the training set.
The authors found it beneficial to set the Adam learning rate to 10^-3. The models were implemented in PyTorch and trained on 32 Volta GPUs.
The main contributions of this work are to:
- Propose a new layer (a key-value memory) that drastically increases the capacity of a neural network with negligible computational overhead.
- Provide results showing important gains on large-scale language modelling, where a 12-layer model with memory reaches the performance of a 24-layer BERT-large model with half the running time.
- Demonstrate why adding memory to the model is more beneficial than increasing the number of layers.
Read the original paper for more details about the work.