The release of GPT-3 and the OpenAI API has renewed interest in zero-shot learning behavior and the use of "priming", where language models are encouraged to produce outputs of a particular, structured form by conditioning text generation on cleverly designed text contexts. This raises a related question: how might we learn to retrieve relevant documents for language models to condition on when generating text? And can conditioning on retrieved documents improve performance on downstream tasks?
Although GPT-3's successes make it clear that language modeling at scale produces models that store a considerable amount of world knowledge in their parameters, conditioning on long-form text returned from a query might be a more computationally efficient mechanism for storing factual or highly domain-specific information, and a viable alternative to parametric external memory stores like those described in "Large Memory Layers with Product Keys".
In fact, retrieval methods based on TF-IDF or BM25 queries have been used to improve performance on open-domain question answering tasks for more than a decade, and the subfield of open-domain question answering in general has been a solid test bed for retrieval methods.
The space of question-answering can be divided into two broad categories: "closed-domain" question answering assumes that the relevant context from which the answer can be found is directly provided, while in the "open-domain" formulation locating documents with relevant context is a portion of the task to be solved. In open-domain question answering literature, a "retriever" is used to narrow a potentially massive pool of candidate contexts down to a more manageable set that a "reader" then processes to produce an answer. The first two papers on our list are recent works that approach retrieval through the lens of open-domain question answering.
- Dense Passage Retrieval for Open-Domain Question Answering
- Leveraging Passage Retrieval with Generative Models for Open Domain Question Answering
The third and fourth papers included in this overview of representation learning and retrieval methods take a different tack, looking instead at retrieval as a component of a general backbone for solving downstream tasks:
- Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks
- Pre-training via Paraphrasing
Without further ado, let's dive into the first work on our list!
Dense Passage Retrieval for Open-Domain Question Answering
"Dense Passage Retrieval for Open-Domain Question Answering" by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih shows that using dense vector representations in place of sparse TF-IDF or BM25 vectors leads to measurable gains on a passage retrieval task, and then illustrates that better passage retrieval also translates to better downstream accuracy on several open-domain question answering tasks.
Question answering tasks can be approached either in a free-form generation setting (where a model may generate an answer to a question in an unconstrained manner) or in a span extraction setting (where the answer is assumed to be a sequence of tokens from a corpus of documents to be queried). The authors use the span extraction formulation for their work.
Model Architecture
The Dense Passage Retrieval (DPR) model uses two distinct BERT encoders, one for the question and one for the candidate passages, and scores each passage by the dot product between the dense question vector and the dense passage vector.
During training, a metric-learning-style objective encourages relevant (positive) passages to have high similarity with the encoded question and irrelevant (negative) passages to have low similarity. The loss is simply the negative log-likelihood of the single relevant passage among a pool of negative candidates. The positive passages for the other questions in a mini-batch serve as negatives for each question ("in-batch negatives"), which shares computational cost across examples and enables higher throughput at training time.
At runtime, the passage encoder is fixed and applied once, offline, to the full pool of candidate passages, so encoding passages incurs no computational cost on a per-query basis. Facebook's approximate nearest neighbor library (FAISS) is then used to index the passage vectors and enable efficient retrieval. On a 20-core CPU with 512GB of RAM, candidate passages can be returned for nearly 1k queries per second.
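To make the scoring step concrete, here is a minimal sketch of dual-encoder ranking. It assumes the two encoders have already produced vectors (DPR uses each BERT's [CLS] output); the random vectors and the helper name `score_passages` are hypothetical stand-ins, not DPR's code.

```python
import numpy as np

def score_passages(q_vec: np.ndarray, p_vecs: np.ndarray) -> np.ndarray:
    """Dot-product similarity sim(q, p) = E_Q(q) . E_P(p) for one question
    against a matrix of passage embeddings (one row per passage)."""
    return p_vecs @ q_vec

# Hypothetical stand-ins for the outputs of the two separate BERT encoders.
rng = np.random.default_rng(0)
dim = 8
q_vec = rng.normal(size=dim)          # E_Q(question)
p_vecs = rng.normal(size=(5, dim))    # E_P(passage_i) for 5 passages

scores = score_passages(q_vec, p_vecs)
ranking = np.argsort(-scores)         # passage indices, most similar first
print(ranking)
```

Because the similarity is a plain inner product, passage vectors never need to see the question, which is what makes offline indexing possible.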
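The in-batch negative trick can be sketched in a few lines: with B questions and their B positive passages in a batch, the B x B similarity matrix gives each question one positive (on the diagonal) and B - 1 negatives. This is a simplified sketch that assumes precomputed vectors; as we understand it, the paper also mixes in hard negatives retrieved by BM25, which are omitted here.

```python
import numpy as np

def in_batch_nll(q_vecs: np.ndarray, p_vecs: np.ndarray) -> float:
    """Mean negative log-likelihood of each question's positive passage.

    Row i of q_vecs is question i; row i of p_vecs is its positive passage.
    Every other row of p_vecs serves as an in-batch negative for question i.
    """
    sims = q_vecs @ p_vecs.T                        # (B, B) similarity matrix
    sims = sims - sims.max(axis=1, keepdims=True)   # stabilize the softmax
    log_probs = sims - np.log(np.exp(sims).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))      # positives lie on the diagonal

# Toy batch: each question aligns strongly with its own passage only.
q = np.eye(4)
p = 10.0 * np.eye(4)
print(in_batch_nll(q, p))   # close to zero: positives dominate the softmax
```

A single matrix multiply scores every question against every passage in the batch, so the negatives come essentially for free.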
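The offline/online split looks roughly like the sketch below. It uses brute-force exact inner-product search in NumPy as a stand-in for FAISS (FAISS's `IndexFlatIP` performs the same exact search, while its approximate indexes trade some recall for speed). The class name `PassageIndex` is hypothetical.

```python
import numpy as np

class PassageIndex:
    """Exact maximum-inner-product search over precomputed passage vectors.

    The passage encoder runs once, offline; at query time only the question
    has to be encoded and compared against the stored matrix.
    """

    def __init__(self, passage_vecs: np.ndarray):
        self.passage_vecs = passage_vecs  # (N, d), computed once

    def search(self, query_vecs: np.ndarray, k: int):
        """Return (scores, ids) of the top-k passages for each query row."""
        sims = query_vecs @ self.passage_vecs.T                 # (Q, N)
        top = np.argpartition(-sims, kth=k - 1, axis=1)[:, :k]  # unordered top-k
        top_sims = np.take_along_axis(sims, top, axis=1)
        order = np.argsort(-top_sims, axis=1)                   # sort within top-k
        return (np.take_along_axis(top_sims, order, axis=1),
                np.take_along_axis(top, order, axis=1))

rng = np.random.default_rng(0)
passages = rng.normal(size=(1000, 16))     # stand-in for encoded passages
index = PassageIndex(passages)
scores, ids = index.search(passages[:3], k=5)
```

`argpartition` avoids fully sorting all N scores per query, which mirrors why top-k retrieval stays cheap even for large pools.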
For question answering, the system also needs a reader component, but this portion of the model is less relevant in the context of language modeling and general representation learning. Section 6 of the source paper contains the details, but the reader is essentially a BERT-base encoder with only a few thousand additional weights for learning to produce:
- The relevance of each passage (represented by a [CLS] token) from the pool of retrieved candidates
- The likelihood of a token being the start token of the answer
- The likelihood of a token being the end token of the answer
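The three outputs above can be sketched as three learned vectors applied to the encoder's token representations: with hidden size h = 768 for BERT-base, they add 3 x 768 = 2,304 weights, consistent with "a few thousand". This is an illustrative sketch with random inputs; the function names and the `max_len` span limit are our assumptions, not the paper's code.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    z = np.exp(x - x.max(axis=axis, keepdims=True))
    return z / z.sum(axis=axis, keepdims=True)

def reader_heads(token_reps, w_start, w_end, w_selected):
    """token_reps: (k, L, h) encoder outputs for k passages of L tokens.

    Returns passage-selection probabilities (k,) from each passage's [CLS]
    vector (position 0), plus start and end probabilities (k, L) per token.
    """
    p_selected = softmax(token_reps[:, 0, :] @ w_selected, axis=0)
    p_start = softmax(token_reps @ w_start, axis=1)
    p_end = softmax(token_reps @ w_end, axis=1)
    return p_selected, p_start, p_end

def best_span(p_start_i, p_end_i, max_len=10):
    """Span (s, t) maximizing P_start(s) * P_end(t) with s <= t < s + max_len."""
    best, best_score = (0, 0), -1.0
    for s in range(len(p_start_i)):
        for t in range(s, min(s + max_len, len(p_end_i))):
            score = p_start_i[s] * p_end_i[t]
            if score > best_score:
                best, best_score = (s, t), score
    return best, best_score

rng = np.random.default_rng(0)
h = 768
reps = rng.normal(size=(3, 12, h))   # 3 retrieved passages, 12 tokens each
w_start, w_end, w_selected = (rng.normal(size=h) for _ in range(3))
p_sel, p_s, p_e = reader_heads(reps, w_start, w_end, w_selected)
i = int(np.argmax(p_sel))            # most relevant passage
span, span_score = best_span(p_s[i], p_e[i])
```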
Experimental Results
With as few as 1k training examples, retrieval with the DPR model outperforms the traditional BM25 baseline that uses term frequency heuristics.
The authors note clear qualitative differences between the results returned by DPR and BM25: DPR returns more relevant results for broad, topical queries, while BM25 better captures specific rare terms (e.g. "Thoros of Myr"). For this reason they also test a hybrid BM25 + DPR solution that reranks the union of the top 2k candidates from each retriever by a weighted combination of the BM25 score and the DPR score.
Improved retrieval metrics also translate into better downstream QA performance: DPR improves over REALM and several graph-network-based methods on four of the five question answering tasks considered.
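A hybrid ranker of this kind reduces to a weighted sum over a candidate pool. The sketch below assumes each retriever can score any candidate passage and takes candidate pools from both retrievers (passing an empty DPR pool reduces it to reranking BM25's candidates alone). The function name and the default weight `lam` are hypothetical; in practice the weight would be tuned.

```python
from typing import Callable, Iterable, List

def hybrid_rerank(bm25_top: Iterable[str], dpr_top: Iterable[str],
                  bm25_score: Callable[[str], float],
                  dpr_score: Callable[[str], float],
                  lam: float = 1.0, top_n: int = 100) -> List[str]:
    """Rerank candidate passage ids by BM25(q, p) + lam * sim(q, p)."""
    candidates = set(bm25_top) | set(dpr_top)
    combined = {pid: bm25_score(pid) + lam * dpr_score(pid) for pid in candidates}
    return sorted(candidates, key=lambda pid: combined[pid], reverse=True)[:top_n]

# Toy scores for one question: "a" matches a rare term, "b"/"c" are topical.
bm25 = {"a": 12.0, "b": 3.0, "c": 0.5}
dense = {"a": 0.1, "b": 0.9, "c": 0.95}
print(hybrid_rerank(["a", "b"], ["b", "c"], bm25.get, dense.get, lam=10.0))
```

The weight controls how much the dense score can override a strong lexical match: with `lam=10.0` passage "a" stays on top, while a larger weight lets the topical passages overtake it.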
If our end goal is more realistic text generation, question answering is a necessary subtask to be solved (either explicitly or implicitly), so while this paper happens to benchmark on downstream QA tasks, the components of its retriever certainly have broader applicability.
Leveraging Passage Retrieval with Generative Models for Open Domain Question Answering
Gautier Izacard and Edouard Grave's work "Leveraging Passage Retrieval with Generative Models for Open Domain Question Answering" from July 2020 builds on the prior work of the Dense Passage Retriever paper but makes the connection to language modeling more explicit.
Unlike "Dense Passage Retrieval for Open-Domain Question Answering", Izacard and Grave's work uses free-form answer text generation rather than structuring question answering as a span extraction task. In some sense this is a more flexible formulation than labeled span extraction, as no span annotations are required and the text of the correct answer does not need to occur in any of the source passages.
The same Dense Passage Retrieval component generates passage candidates in this work, but the model used as the reader is more complex.
For the reader component, each retrieved passage is first combined with the question and the passage title, using special prefix tokens ("question:", "title:", and "context:") to convey the role of each piece of text to the model.
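As an illustration, building one encoder input per passage might look like the sketch below. The function name is hypothetical, and the exact spacing and tokenization used in the paper may differ.

```python
from typing import List, Tuple

def format_fid_inputs(question: str, passages: List[Tuple[str, str]]) -> List[str]:
    """One encoder input per retrieved (title, text) passage, each prefixed
    with the question so every passage is read in light of it."""
    return [f"question: {question} title: {title} context: {text}"
            for title, text in passages]

inputs = format_fid_inputs(
    "who wrote hamlet?",
    [("Hamlet", "Hamlet is a tragedy by William Shakespeare.")],
)
print(inputs[0])
```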
Each formatted question-passage input is embedded independently by a fine-tuned T5 encoder, and the resulting representations are concatenated. The fine-tuned T5 decoder then attends over this concatenation of all provided contexts and is trained to output the expected answer using a typical language modeling objective. The authors refer to this formulation as "Fusion-in-Decoder".
100 passages are retrieved and truncated to 250 tokens each. This means the decoder potentially attends to 25k tokens worth of context, but the operation still uses a tractable amount of memory: the encoder processes each passage separately, so there is no attention between passages, and generated answers are short, so the decoder's cross-attention over the long context runs for only a few steps.
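The "fusion" happens purely through shapes: passages are encoded one at a time and then stacked along the sequence axis into a single memory for the decoder. In this sketch a toy embedding lookup stands in for the T5 encoder (an assumption for illustration only), which makes the key property easy to check: changing one passage leaves the other passages' encodings untouched.

```python
import numpy as np

def fid_encode(encoder, passage_token_ids: np.ndarray) -> np.ndarray:
    """passage_token_ids: (n_passages, seq_len), one row per question+passage
    input. Rows are encoded independently, then concatenated into one long
    memory that the decoder cross-attends over."""
    per_passage = [encoder(row) for row in passage_token_ids]  # each (seq_len, d)
    return np.concatenate(per_passage, axis=0)                # (n_passages * seq_len, d)

rng = np.random.default_rng(0)
vocab, d, n_passages, seq_len = 50, 4, 3, 6
embeddings = rng.normal(size=(vocab, d))

def toy_encoder(ids: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the T5 encoder: a plain embedding lookup."""
    return embeddings[ids]

ids = rng.integers(0, vocab, size=(n_passages, seq_len))
memory = fid_encode(toy_encoder, ids)
```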
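Some back-of-the-envelope arithmetic shows why this stays tractable. The sketch counts attention score entries for a single head and layer (ignoring constants), comparing joint encoding of all 25k tokens with Fusion-in-Decoder's per-passage encoding; the 10-token answer length is an arbitrary example.

```python
n_passages, passage_len = 100, 250
total_context = n_passages * passage_len        # 25,000 tokens seen by the decoder

# Self-attention score entries if all passages were encoded as one sequence
joint_encoder = total_context ** 2              # 625,000,000
# Fusion-in-Decoder: each passage attends only to its own 250 tokens
fid_encoder = n_passages * passage_len ** 2     # 6,250,000

def decoder_cross_attention(answer_len: int) -> int:
    """Cross-attention score entries: every generated token attends to all 25k."""
    return answer_len * total_context

print(joint_encoder // fid_encoder)             # 100x fewer encoder entries
print(decoder_cross_attention(10))              # 250,000 for a 10-token answer
```

Encoder cost grows linearly with the number of passages rather than quadratically, and the decoder's cost is linear in the context length times the (short) answer length.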
Experimental Results
The large 770M parameter version of their model posts extremely strong results on NaturalQuestions, TriviaQA, and SQuAD Open.
They also note that the performance of their QA system scales with the number of passages shown to the decoder, which they consider evidence of the decoder aggregating information from many passages.
Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks
The next paper on our list, "Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks" (by Patrick Lewis, Ethan Perez, Aleksandra Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, and Douwe Kiela), has the distinction of applying retrieval to a broader set of tasks than just question answering.
Model Architecture
Like the fusion-in-decoder model, the Retrieval-Augmented Generation (RAG) model is composed of a DPR model for retrieval and a generator that conditions on the top-k documents returned by the retriever. The generator plays the role of the "reader" we've discussed in the context of question answering.
For the retriever, the authors use pre-trained checkpoints from the DPR paper. For the generator, they use pre-trained checkpoints from BART, a seq2seq model trained on a text denoising objective in which the model must restore the original order of shuffled input sentences and fill in masked spans.
At fine-tuning time, these two pre-trained models are chained together and trained jointly. Half of the DPR model (the document encoder) is frozen to avoid having to update the cached document vectors; only the query encoder and the generator are updated.
The authors experiment with two RAG variants. Both query the document corpus with the DPR model once per input; RAG-Sequence then uses the same retrieved document to generate the entire output, while RAG-Token can draw on a different retrieved document for each generated token.
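The difference between the two variants comes down to where the marginalization over the top-k documents happens. Written as we understand the paper's formulation, RAG-Sequence computes p(y|x) = sum_z p(z|x) * prod_i p(y_i | x, z, y_<i), while RAG-Token computes prod_i sum_z p(z|x) * p(y_i | x, z, y_<i). The probabilities below are made up for illustration, and real implementations work in log space.

```python
import numpy as np

def rag_sequence_likelihood(doc_probs, token_probs) -> float:
    """doc_probs: (k,) retriever probabilities p(z|x) for the top-k documents.
    token_probs: (k, T) generator probabilities of each target token given
    each document. One document must explain the whole output sequence."""
    doc_probs, token_probs = np.asarray(doc_probs), np.asarray(token_probs)
    return float(np.sum(doc_probs * np.prod(token_probs, axis=1)))

def rag_token_likelihood(doc_probs, token_probs) -> float:
    """Marginalize over documents separately for every output token."""
    doc_probs, token_probs = np.asarray(doc_probs), np.asarray(token_probs)
    return float(np.prod(doc_probs @ token_probs))

# Document 0 supports the first token, document 1 supports the second.
docs = [0.5, 0.5]
tokens = [[0.9, 0.1],
          [0.1, 0.9]]
seq_p = rag_sequence_likelihood(docs, tokens)   # 0.5*0.09 + 0.5*0.09 = 0.09
tok_p = rag_token_likelihood(docs, tokens)      # 0.5 * 0.5 = 0.25
```

In this toy case RAG-Token assigns the output a higher likelihood because it can combine evidence from both documents; for a single-token output the two formulations coincide.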
Experimental Results
RAG posts solid results on tasks ranging from question answering to question generation and classification-style fact verification, although it underperforms fusion-in-decoder by a large margin on the question answering datasets both papers benchmark on (somewhat unsurprisingly, as RAG was a precursor to fusion-in-decoder).
For open-domain question answering tasks, they note that compared to entirely parametric models that do not include a retriever, RAG tends to produce answers that are more specific and more likely to be factually correct. By conditioning on concrete input documents, specific facts no longer need to be stored in the parametric memory of a massive model; they can be copied directly from a retrieved document.
In addition to question answering tasks, Lewis et al. test RAG on Jeopardy question generation, MS-MARCO (an abstractive QA task), and FEVER (a classification-style fact verification task), posting strong results despite not using the contexts provided by the MS-MARCO task and making minimal task-specific changes.
Finally, the authors note that perhaps one of the most interesting affordances of retrieval-based methods is the ability to swap out the corpus that documents are retrieved from. As the world changes, typical language models that store factual information in parametric memory need to be retrained or otherwise adjusted to reflect the new facts. In a model backed by a retriever, simply swapping out the text corpus that backs the model is sufficient.
Pre-training via Paraphrasing
Mike Lewis, Marjan Ghazvininejad, Gargi Ghosh, Armen Aghajanyan, Sida Wang, and Luke Zettlemoyer's work "Pre-training via Paraphrasing" differs from the previously discussed works in that it is framed as a "from-scratch" approach to general representation learning. Their model, MARGE (a Multilingual Autoencoder that Retrieves and Generates), also differs from DPR-based models in that a single encoder is used for both queries and documents rather than separate query and document encoders.
Model Architecture
MARGE works by treating a set of source documents and learned similarities to a target document of interest as an auto-encoder bottleneck for a document reconstruction task.
An encoder takes a given document and produces a summary document vector and a set of contextualized token representations. The same encoder embeds both the target and the source (evidence) documents. The cosine similarity between each source document and the target document is then used as an attention bias in the decoder.
In a manner similar to fusion-in-decoder, the decoder attends jointly over the contextualized token representations of the top-k source documents most similar to the target document.
The authors note that reconstruction might be impossible if the retrieved documents are not sufficiently relevant at initialization, and since gradients do not propagate to documents outside the top-k retrieved set, there was concern that the model would be unable to escape this mode. In practice, however, even the random features at initialization were sufficient to bring documents that contain similar words to the target document into the retrieved set, avoiding the chicken-and-egg problem.
Because it is infeasible to re-embed every document after each encoder update to determine which documents should be in the relevant set, the similarity matrix between source and target documents is cached and only updated infrequently. To save computation, the cached similarity matrix is also used to construct batches in which the source documents of several target documents overlap as much as possible.
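The two architectural ideas above, top-k selection by cosine similarity and a relevance bias inside cross-attention, can be sketched for a single decoder query vector and a single attention head. The scalar `beta` is treated here as a fixed hyperparameter for simplicity (our assumption), and all names and inputs are hypothetical toy values.

```python
import numpy as np

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def relevance_biased_attention(query, doc_keys, doc_vecs, target_vec, beta=1.0, k=2):
    """Cross-attention from one decoder query over the tokens of the top-k
    source documents, adding beta * cos(target, doc_j) to every logit that
    belongs to a token of document j.

    query: (d,)  doc_keys: list of (L_j, d) token representations
    doc_vecs: (n_docs, d) document summary vectors  target_vec: (d,)
    """
    sims = np.array([cosine(target_vec, v) for v in doc_vecs])
    top = np.argsort(-sims)[:k]                         # retrieve top-k sources
    keys = np.concatenate([doc_keys[j] for j in top])   # fuse their tokens
    bias = np.concatenate([np.full(len(doc_keys[j]), beta * sims[j]) for j in top])
    logits = keys @ query / np.sqrt(len(query)) + bias
    weights = np.exp(logits - logits.max())
    return top, weights / weights.sum()

doc_vecs = np.array([[1.0, 0, 0, 0], [0, 1.0, 0, 0], [-1.0, 0, 0, 0]])
target = np.array([1.0, 0, 0, 0])
doc_keys = [np.ones((3, 4)) for _ in range(3)]   # 3 tokens per source document
top, weights = relevance_biased_attention(np.zeros(4), doc_keys, doc_vecs, target, beta=20.0)
```

Because the bias is a differentiable function of the encoder's document vectors, the reconstruction loss can push relevant documents' similarities up, which is how the retrieval scores themselves get trained.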
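One simple way to build such batches from cached neighbor lists is a greedy grouping, sketched below. This is an illustrative heuristic of our own, not the procedure from the paper; it assumes each target's retrieved source ids have already been read from the similarity cache.

```python
from typing import Dict, List, Set

def overlap_batches(neighbors: Dict[str, Set[str]], batch_size: int) -> List[List[str]]:
    """Greedily group target documents whose cached source sets overlap.

    neighbors maps each target id to the ids of its retrieved source
    documents, taken from the (infrequently refreshed) similarity cache.
    """
    remaining = set(neighbors)
    batches = []
    while remaining:
        seed = min(remaining)                 # deterministic starting point
        remaining.remove(seed)
        batch, shared = [seed], set(neighbors[seed])
        while remaining and len(batch) < batch_size:
            # Add the target that shares the most sources with the batch so far
            best = max(sorted(remaining), key=lambda t: len(neighbors[t] & shared))
            remaining.remove(best)
            batch.append(best)
            shared |= neighbors[best]
        batches.append(batch)
    return batches

neighbors = {"t1": {"a", "b"}, "t2": {"x", "y"}, "t3": {"a", "b", "c"}, "t4": {"x", "z"}}
batches = overlap_batches(neighbors, batch_size=2)
```

Grouping "t1" with "t3" means their shared sources "a" and "b" only need to be encoded once per batch.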
Experimental Results
Perhaps the most impressive feature of MARGE is the breadth of downstream tasks on which it posts strong benchmarks. Unsurprisingly, MARGE performs exceptionally well on the MLSum multilingual summarization benchmark.
Because documents from any language can be part of the retrieved set for a given target document (and the authors apply some simple heuristics to encourage this behavior), MARGE is also well suited to translation tasks, performing comparably to mBART on English-to-German and Chinese-to-English translation tasks.
The authors additionally test on a multilingual question answering benchmark (MLQA) and a paraphrase detection benchmark, where MARGE performs predictably well given its paraphrase-like training objective.
Acknowledgements and Conclusion
Thank you to Aran Komatsuzaki for the excellent paper discussions and for highlighting the promise of retrieval-based methods in a recent Reddit thread. I hope the coming months will see broader exploration of retrieval-augmented approaches as parameter-efficient alternatives for text representation learning.
If you enjoyed this article, you may be interested in some of the following related works in the retrieval and language modeling space: