Fusion-in-Decoder Training (FiD)
FiD fine-tunes an encoder-decoder model using a combination of 1) documents retrieved by a frozen retriever, and 2) ground-truth <chat_history, question, passage, answer> tuples (for single-turn QA, you can use chat_history=''). Under the hood, this training script:
- uses the frozen retriever to retrieve k documents for each question in the training data.
- augments the supporting passages to be passage_aug = passage + retrieved_passages.
- for each passage p in passage_aug, concatenates it with the chat history and question to form the input input_i = chat_history + question + p.
- encodes each input_i using the encoder in parallel.
- concatenates the hidden states of the encoder and feeds them into the decoder.
- trains the encoder-decoder model using standard cross-entropy loss on the ground-truth answer.
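The input construction and fusion steps above can be sketched in a few lines of plain Python. This is a minimal illustration, not the training script's code: build_fid_inputs, toy_encode, and fuse are hypothetical names, the "question:" / "passage:" separators are an assumed prompt format, and the toy encoder merely stands in for a real encoder such as T5's.

```python
from typing import List

Vector = List[float]


def build_fid_inputs(chat_history: str, question: str, gold_passage: str,
                     retrieved_passages: List[str], k: int) -> List[str]:
    """Form one encoder input per supporting passage (hypothetical helper)."""
    # passage_aug = passage + retrieved_passages, capped at k retrieved documents.
    passage_aug = [gold_passage] + retrieved_passages[:k]
    # input_i = chat_history + question + p; the separator strings are an assumption.
    return [f"{chat_history} question: {question} passage: {p}".strip()
            for p in passage_aug]


def toy_encode(text: str) -> List[Vector]:
    """Stand-in for a real encoder: one 2-d 'hidden state' per whitespace token."""
    return [[float(len(tok)), float(i)] for i, tok in enumerate(text.split())]


def fuse(encoded: List[List[Vector]]) -> List[Vector]:
    """Concatenate the per-passage hidden states along the sequence axis."""
    return [state for seq in encoded for state in seq]


inputs = build_fid_inputs(
    chat_history="",
    question="Who proposed FiD?",
    gold_passage="gold passage text",
    retrieved_passages=["retrieved a", "retrieved b", "retrieved c"],
    k=2,
)
# Each input is encoded independently; only the decoder sees all passages at once.
fused = fuse([toy_encode(x) for x in inputs])
print(len(inputs), len(fused))
```

Because each input_i is encoded on its own, encoder cost grows linearly with the number of passages, while the decoder attends over the single fused sequence.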
Visually:
Architecture of the Fusion-in-Decoder method. (Izacard and Grave, 2020)
Running FiD Trainer
At a high level, FiD training requires:
- a training, evaluation, and test dataset of <question, passage, answer> pairs
- an encoder-decoder model (e.g. lmsys/fastchat-t5-3b-v1.0) to be trained
- an embedding model (e.g. intfloat/e5-base-v2) used for training AND automatic E2E evaluation during training
# other training hyperparameters are omitted from this command
python scripts/train/qa_llm/train_w_gt_fid.py \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--model_name_or_path lmsys/fastchat-t5-3b-v1.0 \
--embedding_model intfloat/e5-base-v2 \
--embedding_max_num_to_retrieve 3 \
--output_dir model_checkpoints/my_SwR_qa_model \
--train_file <example/train_w_qa.jsonl> \
--eval_file <example/eval_w_qa.jsonl> \
--test_file <example/test_w_qa.jsonl> \
--full_dataset_file_path <example/documents.pkl> \
--full_dataset_index_path <example/index>
For a full list of arguments, you can run python scripts/train/qa_llm/train_w_gt_fid.py -h. In this example:
- --per_device_train_batch_size, --model_name_or_path, and other training arguments are from the Hugging Face TrainingArguments class. Since our trainers are implemented on top of Hugging Face's Trainer class, they are compatible with most of the arguments there.
- --embedding_model is used to perform retrieval during training and evaluation.
- --embedding_max_num_to_retrieve dictates the size of passage_aug during training. In practice, we bound len(passage_aug) = embedding_max_num_to_retrieve + 1.
- --output_dir is the directory where the trained model, training history, and evaluation results will be saved.
- --train_file, --eval_file, and --test_file are the paths to the training, evaluation, and test datasets. See RQA Data Format for more details on the format of these files.
- --full_dataset_file_path and --full_dataset_index_path are the paths to the documents and their indices. These are used by eval_embedding_model to perform retrieval during evaluation. See RQA Data Format for more details on the format of these files.
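To get a feel for how --embedding_max_num_to_retrieve interacts with the batch size, the rough estimate below applies the bound len(passage_aug) = embedding_max_num_to_retrieve + 1 to the example command. It is a back-of-the-envelope sketch: fid_encoder_load is a hypothetical helper, max_input_tokens = 512 is an assumed per-input length that does not come from the script, and the way the trainer actually batches passages may differ.

```python
def fid_encoder_load(per_device_batch_size: int,
                     max_num_to_retrieve: int,
                     max_input_tokens: int) -> dict:
    """Estimate per-step encoder work for FiD (hypothetical helper)."""
    # One gold passage plus the retrieved ones: len(passage_aug).
    num_passages = max_num_to_retrieve + 1
    return {
        "passages_per_example": num_passages,
        # Every passage of every example is a separate encoder sequence.
        "encoder_sequences_per_step": per_device_batch_size * num_passages,
        # Upper bound on the fused sequence the decoder attends over.
        "fused_tokens_per_example": num_passages * max_input_tokens,
    }


# Values from the example command; max_input_tokens=512 is an assumption.
load = fid_encoder_load(per_device_batch_size=4,
                        max_num_to_retrieve=3,
                        max_input_tokens=512)
print(load)
```

Under these assumptions, raising --embedding_max_num_to_retrieve increases both the encoder workload and the decoder's attention span linearly, so it may need to be traded off against the per-device batch size.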
Note
For complete examples (e.g., obtaining files like <example/train_w_qa.jsonl> or other training hyperparameters), you can use Databricks and Faire as references.
References
Gautier Izacard and Edouard Grave. 2020. Leveraging Passage Retrieval with Generative Models for Open Domain Question Answering.