Introduction
REALM is a paper referenced in the T5 paper titled How Much Knowledge Can You Pack Into The Parameters of a Language Model?
TLDR: While solving Question-Answering problems, REALM retrieves documents that contain the relevant information and conditions on them to produce the answer.
NOTE: This post is more like my running notes while reading the paper than a comprehensive blog. I will update this blog once I learn a little more about the transformer architecture.
The paper introduces a latent knowledge retriever, which can attend over and retrieve documents from a large corpus, and which can be trained in an unsupervised manner using masked language modelling, backpropagating through a retrieval step that considers millions of documents.
Figure: the training process for REALM.
Key point: train the retriever using a performance-based signal from unsupervised text: a retrieval that improves the language model's perplexity should be rewarded, while an uninformative retrieval should be penalized.
Retrieval-based LM => moar computational resources => moar money
- Solution: the computation performed for each document is cached and can be reused. The best documents are selected using Maximum Inner Product Search (MIPS). Read the paper here.
REALM retriever can be used on downstream tasks via transfer learning.
REALM is SOTA on NQ-Open (Natural Questions), WQ (WebQuestions) and CuratedTrec.
Approach
Retrieve-then-predict generative process
Pre-training: Masked LM. Fine-tuning: Open-QA task
Function to be computed: $p(y \mid x)$, the distribution over possible outputs $y$ given the input $x$, decomposed into two steps: retrieve, then predict.
Given $x$, retrieve possibly helpful documents $z$ from the knowledge corpus $\mathcal{Z}$; this is modelled as $p(z \mid x)$. Then condition on both $x$ and $z$ to generate the output $y$, i.e. $p(y \mid z, x)$. The overall likelihood $p(y \mid x)$ is obtained by treating $z$ as a latent variable and marginalizing over all documents $z$:
$$p(y \mid x) = \sum_{z \in \mathcal{Z}} p(y \mid z, x)\, p(z \mid x)$$
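As a sanity check of this marginalization, here is a minimal numpy sketch, assuming we already have the retrieval distribution $p(z \mid x)$ and the per-document likelihoods $p(y \mid z, x)$ for a tiny candidate set (the array names and values are made up for illustration):

```python
import numpy as np

# Hypothetical values for a toy corpus of 4 candidate documents.
p_z_given_x = np.array([0.70, 0.20, 0.05, 0.05])   # retrieval distribution p(z|x)
p_y_given_zx = np.array([0.90, 0.10, 0.01, 0.01])  # encoder likelihoods p(y|z,x)

# Marginal likelihood: treat z as a latent variable and sum it out.
p_y_given_x = np.sum(p_y_given_zx * p_z_given_x)
print(p_y_given_x)   # ~0.651 for these toy numbers
```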
Architecture
Neural Knowledge Retriever, which models the retrieval distribution $p(z \mid x)$
Knowledge-Augmented Encoder, which models the prediction distribution $p(y \mid z, x)$
Neural Knowledge Retriever
The retriever is a dense inner product model:
$$p(z \mid x) = \frac{\exp f(x, z)}{\sum_{z'} \exp f(x, z')}, \qquad f(x, z) = \mathrm{Embed}_{\text{input}}(x)^\top \mathrm{Embed}_{\text{doc}}(z)$$
$\mathrm{Embed}_{\text{input}}$ and $\mathrm{Embed}_{\text{doc}}$ are embedding functions that map $x$ and $z$ to $d$-dimensional vectors. $f(x, z)$ is called the relevance score; it is the inner product of the two vector embeddings. The retrieval distribution is the softmax over all relevance scores.
The embeddings are implemented using BERT-style Transformers: join the text spans with [SEP], prefix with [CLS], and append [SEP] as the end token. Pass this into the Transformer, which gives one output vector per token (including one for [CLS]), then perform a linear projection to reduce the dimensionality of the vector.
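A small numpy sketch of the dense inner product retriever. The random vectors stand in for the [CLS] outputs of the BERT encoders and the projection matrices play the role of the dimensionality reduction; this is a toy illustration, not the paper's actual code.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, d = 768, 128   # BERT hidden size and reduced embedding size

# Stand-ins for the [CLS] vectors of the query and of 5 toy documents.
cls_x = rng.normal(size=hidden)
cls_docs = rng.normal(size=(5, hidden))

# Linear projections that reduce dimensionality (one for queries, one for docs).
W_input = rng.normal(size=(d, hidden)) / np.sqrt(hidden)
W_doc = rng.normal(size=(d, hidden)) / np.sqrt(hidden)

embed_input = W_input @ cls_x        # Embed_input(x), shape (d,)
embed_docs = cls_docs @ W_doc.T      # Embed_doc(z) for every doc, shape (5, d)

# Relevance score f(x, z) is the inner product; the retrieval
# distribution p(z|x) is the softmax over all relevance scores.
scores = embed_docs @ embed_input
p_z_given_x = np.exp(scores - scores.max())
p_z_given_x /= p_z_given_x.sum()
print(p_z_given_x)   # sums to 1 over the 5 documents
```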
Knowledge-Augmented Encoder
Given the input $x$ and a retrieved document $z$, this defines $p(y \mid z, x)$.
Join $x$ and $z$ into a single sequence and feed it into a Transformer (distinct from the one used in the retriever). Here, training is different for pre-training vs fine-tuning.
For pre-training, predict the [MASK] tokens. Use the same masked language modelling (MLM) loss as in BERT (Devlin et al.)
For Open-QA fine-tuning, we need to produce the answer string $y$. Assumption: $y$ occurs as a contiguous sequence of tokens in some document in the corpus, so $p(y \mid z, x)$ is modelled by scoring candidate answer spans in $z$.
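A minimal sketch of how the masked input and a retrieved document might be joined into one sequence for the encoder, following BERT's [CLS]/[SEP] convention. The helper name and whitespace tokenization are mine; the example text is the pyramid example from the paper's figure.

```python
def join_for_encoder(x_tokens, z_tokens):
    """Concatenate the (masked) input x with a retrieved document z into a
    single sequence for the knowledge-augmented encoder."""
    return ["[CLS]"] + x_tokens + ["[SEP]"] + z_tokens + ["[SEP]"]

# Pre-training: the encoder must predict the [MASK] token from x plus z.
x = "The [MASK] at the top of the pyramid".split()
z = "The pyramidion on top allows for less material higher up the pyramid".split()
print(join_for_encoder(x, z))
```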
Training
Both pre-training and fine-tuning maximize $\log p(y \mid x)$: compute gradients with respect to the retriever and encoder parameters and optimize using SGD.
Challenge: computing $p(y \mid x)$ exactly requires a sum over every document in the corpus.
Approximate it by summing only over the top $k$ documents with the highest probability under $p(z \mid x)$.
Question: how to find the top $k$ docs? Answer: use MIPS. This needs $\mathrm{Embed}_{\text{doc}}(z)$ precomputed for all docs and a search index built over them. Problem? The embeddings change with each step of SGD. Solution: asynchronously refresh the index every 500 steps. Use the (slightly stale) MIPS index to select the top $k$ docs; for just these docs, recompute $p(z \mid x)$ using the new parameters.
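A rough sketch of this top-$k$ approximation, with made-up helper names: `stale_doc_embs` is what the index builder produced, while `embed_doc_fresh` recomputes a single document embedding with the current parameters. A real system would use an approximate MIPS library instead of the brute-force `argsort` below.

```python
import numpy as np

def topk_retrieval(query_emb, stale_doc_embs, embed_doc_fresh, k=8):
    """Approximate the marginal over documents using only the top-k candidates."""
    # MIPS against the stale index only decides WHICH documents to consider.
    candidate_ids = np.argsort(stale_doc_embs @ query_emb)[-k:]
    # p(z|x) over those candidates is recomputed with fresh embeddings, so the
    # gradient sees retrieval scores from the current parameters.
    fresh_scores = np.array([embed_doc_fresh(i) @ query_emb for i in candidate_ids])
    p = np.exp(fresh_scores - fresh_scores.max())
    return candidate_ids, p / p.sum()

# Toy usage: "fresh" embeddings are just the stale ones plus a little drift.
rng = np.random.default_rng(0)
stale = rng.normal(size=(1000, 128))
drift = 0.01 * rng.normal(size=(1000, 128))
ids, p = topk_retrieval(rng.normal(size=128), stale, lambda i: stale[i] + drift[i])
print(ids, p)
```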
Implementing async MIPS refreshes
Two jobs running in parallel:
Primary trainer: performs gradient updates on the parameters
Secondary index builder: embeds and indexes the documents
Figure: async MIPS implementation.
The async refresh is used only during pre-training. For fine-tuning, the index is built once from the pre-trained $\mathrm{Embed}_{\text{doc}}$ and reused; only the query side $\mathrm{Embed}_{\text{input}}$ continues to be updated.
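A compressed sketch of the two jobs, with toy stand-ins for the documents, the parameter update, and the index. It is shown synchronously in one process only to keep the sketch short; in REALM the index builder runs as a separate asynchronous job.

```python
import numpy as np

REFRESH_EVERY = 500                      # pre-training only; fine-tuning indexes once
rng = np.random.default_rng(0)
corpus = rng.normal(size=(1000, 768))    # toy stand-in for the document collection

def build_index(params):
    """Index-builder job: re-embed every document with a snapshot of the
    current parameters and return the embedding matrix used for MIPS."""
    return corpus @ params

def primary_trainer(num_steps=2000):
    """Primary trainer job: SGD loop that only occasionally swaps in a new index."""
    params = rng.normal(size=(768, 128))
    index = build_index(params)                          # initial index
    for step in range(1, num_steps + 1):
        params -= 1e-4 * rng.normal(size=params.shape)   # stand-in for an SGD update
        if step % REFRESH_EVERY == 0:
            index = build_index(params)   # in REALM this happens asynchronously
    return params, index

params, index = primary_trainer()
```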
What does the retriever learn?
The retriever is rewarded for promoting documents that improve the encoder's prediction of the masked tokens.
This can be seen by inspecting the gradient of the log-likelihood with respect to the retriever's relevance scores.
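Restating the paper's gradient analysis (paraphrased here from the paper, so treat the exact form as a restatement rather than a quotation):

$$\nabla \log p(y \mid x) = \sum_{z \in \mathcal{Z}} r(z)\, \nabla f(x, z), \qquad r(z) = \left[ \frac{p(y \mid z, x)}{p(y \mid x)} - 1 \right] p(z \mid x)$$

$r(z)$ is positive exactly when document $z$ improves the prediction of $y$ over the current marginal, so the relevance score $f(x, z)$ of helpful documents is pushed up and that of unhelpful ones is pushed down.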
Injecting inductive biases into pre-training
- Salient span masking: some [MASK] tokens need only local context to predict, which does not exercise retrieval. Instead, select salient spans (named entities and dates) and mask one of them; this performs better (a toy sketch follows this list).
- Null document: add an empty null document to the top $k$ retrieved documents, to allow predictions even when no retrieved context is required.
- Prohibiting trivial retrievals: if the knowledge corpus $\mathcal{Z}$ is the same as the pre-training corpus $\mathcal{X}$, the model could predict the masked tokens by simply retrieving the unmasked copy of $x$ from its source document. Exclude this trivial candidate during pre-training.
- Initialization: warm up $\mathrm{Embed}_{\text{input}}$ and $\mathrm{Embed}_{\text{doc}}$ using the Inverse Cloze Task (ICT), i.e. the model is trained to retrieve the document a sentence came from.
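A toy illustration of salient span masking, keeping only a regular expression for years as the "salient span" detector. The paper uses a trained NER tagger for entities plus a regex for dates; the function name and example sentence here are mine.

```python
import random
import re

random.seed(0)

def mask_salient_span(text):
    """Mask one salient span (here: just 4-digit years) instead of a random
    token, so predicting it needs world knowledge rather than local context."""
    spans = [m.span() for m in re.finditer(r"\b\d{4}\b", text)]
    if not spans:
        return text
    start, end = random.choice(spans)
    return text[:start] + "[MASK]" + text[end:]

print(mask_salient_span("REALM was published in 2020 at ICML."))
# -> REALM was published in [MASK] at ICML.
```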
Experiments
- REALM outperforms all previous Open-QA approaches by a significant margin (4-16% absolute accuracy).
Future Work
Structured knowledge, where the model would learn which entities are informative
Multi-lingual setting: retrieving knowledge in a high-resource language to better represent text in a low-resource language
Multi-modal setting: retrieving images or videos that provide knowledge not present in text
Resources
There are a number of other resources you can use to learn more about this paper such as:
Comments
Overall, I enjoyed reading this paper. However, there are two key points that concern me: