Training a Retriever

Training a Retriever#

Given a user query, a retriever selects k most relevant passages from a collection of documents. LocalRQA implements trainers for encoders that distill from a down-stream LM and trainers that perform contrastive learning using a dataset of <q,p> pairs (and optionally hard negative examples):

At a high level, we provide ready-to-use training scripts for each algorithm above. These scripts allow you to specify the training data, model, and other hyperparameters in a single command line. For instance, with CTL training:

python scripts/train/retriever/train_ctl_retriever.py \
--full_dataset_file_path <example/documents.pkl> \
--train_file <example/train_w_q.jsonl> \
--eval_file <example/eval_w_q.jsonl> \
--model_name_or_path intfloat/e5-base-v2 \
--pooling_type mean \
--learning_rate 1e-4 \
--per_device_train_batch_size 256 \
--per_device_eval_batch_size 128 \
--hard_neg_ratio 0.05 \
--metric_for_best_model eval_retr/document_recall/recall4 \
--output_dir model_checkpoints/my_CTL_ret_model

this will finetune intfloat/e5-base-v2 using the training data from <example/train_w_q.jsonl>, and then save the model at model_checkpoints/my_CTL_ret_model.

Note

During training, our scripts will also perform automatic retriever evaluations on the validation set, i.e., <example/eval_w_q.jsonl>. The evaluation results will be printed to the console and saved to the output directory.

For more details on retriever evaluation, please refer to Retriever Evaluation.

For more details on each training algorithm/script, please refer to their respective sections.