Question Answering Finetuning
HappyQuestionAnswering contains three methods for training and evaluating a model:
- train(): fine-tune a question answering model to become better at a certain task
- eval(): determine how well the model performs on a labeled dataset
- test(): run the model on an unlabeled dataset to produce predictions
train()
Input:
- input_filepath (string): a path to a CSV file formatted as described in Table 3.0
- args (QATrainArgs): a dataclass of type QATrainArgs which contains the fields shown in Table 3.1
- eval_filepath (string): optional. By default, evaluation data is generated from the supplied training data. To use standalone evaluation data instead, provide a path to a CSV file in the same format as input_filepath.
Table 3.0
- context (string): background information for answering the question
- question (string): the question that will be asked
- answer_text (string): the answer in string format
- answer_start (int): the character index in the context at which the answer starts
context | question | answer_text | answer_start |
---|---|---|---|
October 31st is the date | what is the date? | October 31st | 0 |
The date is November 23rd | what is the date? | November 23rd | 12 |
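A file in the Table 3.0 layout can be produced with Python's standard csv module. The sketch below is illustrative only: the helper name write_qa_csv is hypothetical and not part of Happy Transformer, and it assumes that answer_start is the index of the first occurrence of answer_text in context, which agrees with the two rows shown above.

```python
import csv
import io

def write_qa_csv(f, rows):
    """Write (context, question, answer_text) triples in the Table 3.0 layout.

    Hypothetical helper, not part of Happy Transformer. answer_start is
    assumed to be the first character index where answer_text occurs.
    """
    writer = csv.writer(f)
    writer.writerow(["context", "question", "answer_text", "answer_start"])
    for context, question, answer_text in rows:
        start = context.find(answer_text)
        if start == -1:
            raise ValueError(f"answer {answer_text!r} not found in context")
        writer.writerow([context, question, answer_text, start])

rows = [
    ("October 31st is the date", "what is the date?", "October 31st"),
    ("The date is November 23rd", "what is the date?", "November 23rd"),
]

# Write to an in-memory buffer so the result can be inspected directly.
buffer = io.StringIO()
write_qa_csv(buffer, rows)
print(buffer.getvalue())
```

To write to disk instead, pass a file opened with open(path, "w", newline="") in place of the buffer.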
Information about the learning parameters can be found here
Information about saving/loading preprocessed data can be found here
Table 3.1
Parameter | Default |
---|---|
learning_rate | 5e-5 |
num_train_epochs | 1 |
batch_size | 1 |
weight_decay | 0 |
save_path | "" |
load_path | "" |
fp16 | False |
eval_ratio | 0.1 |
save_steps | 0.0 |
eval_steps | 0.1 |
logging_steps | 0.1 |
output_dir | "happy_transformer" |
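When eval_filepath is omitted, part of the training data is held out for evaluation, and eval_ratio (default 0.1) controls how much. The following is a rough sketch of that kind of split, not the library's actual implementation: the function name split_train_eval, the shuffling, the fixed seed, and the rounding rule are all assumptions made for illustration.

```python
import random

def split_train_eval(rows, eval_ratio=0.1, seed=0):
    """Hold out a fraction of rows for evaluation (illustrative only)."""
    if not 0.0 < eval_ratio < 1.0:
        raise ValueError("eval_ratio must be between 0 and 1")
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    # Assumed rounding rule: keep at least one evaluation row.
    n_eval = max(1, round(len(shuffled) * eval_ratio))
    # Return (training rows, evaluation rows).
    return shuffled[n_eval:], shuffled[:n_eval]

rows = [f"example {i}" for i in range(50)]
train_rows, eval_rows = split_train_eval(rows, eval_ratio=0.1)
print(len(train_rows), len(eval_rows))  # 45 5
```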
Output: None
Example 3.3:
from happytransformer import HappyQuestionAnswering, QATrainArgs
# --------------------------------------#
happy_qa = HappyQuestionAnswering()
args = QATrainArgs(num_train_epochs=1)
happy_qa.train("../../data/qa/train-eval.csv", args=args)
eval()
Input:
- input_filepath (string): a path to a CSV file formatted as described in Table 3.0
- args (QAEvalArgs): a dataclass of type QAEvalArgs which contains the fields shown in Table 3.2
Information about saving/loading preprocessed data can be found here
Table 3.2
Parameter | Default |
---|---|
save_path | "" |
load_path | "" |
Output: A dataclass with the variable “loss”
Example 3.4:
from happytransformer import HappyQuestionAnswering, QAEvalArgs
# --------------------------------------#
happy_qa = HappyQuestionAnswering()
args = QAEvalArgs()  # The default settings as an example
result = happy_qa.eval("../../data/qa/train-eval.csv", args=args)
print(type(result)) # <class 'happytransformer.happy_trainer.EvalResult'>
print(result)  # EvalResult(loss=0.11738169193267822)
print(result.loss)  # 0.11738169193267822
test()
Input:
- input_filepath (string): a path to a CSV file formatted as described in Table 3.3
- args (QATestArgs): A dataclass that contains the values shown in table x above
Table 3.3
- context (string): background information for answering the question
- question (string): the question that will be asked
context | question |
---|---|
October 31st is the date | what is the date? |
The date is November 23rd | what is the date? |
Output: A list of named tuples with keys: “answer”, “score”, “start” and “end”
Example 3.5:
from happytransformer import HappyQuestionAnswering, QATestArgs
# --------------------------------------#
happy_qa = HappyQuestionAnswering()
result = happy_qa.test("../../data/qa/test.csv")
print(type(result))  # <class 'list'>
print(result) # [QuestionAnsweringResult(answer='October 31st', score=0.9939756989479065, start=0, end=12), QuestionAnsweringResult(answer='November 23rd', score=0.967872679233551, start=12, end=25)]
print(result[0]) # QuestionAnsweringResult(answer='October 31st', score=0.9939756989479065, start=0, end=12)
print(result[0].answer) # October 31st
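Because each result carries start and end character offsets, a prediction can be checked against the context it was extracted from. The sketch below uses a local namedtuple as a stand-in for the objects returned by test() so that it runs without loading a model. The stand-in type is hypothetical, and the hard-coded values are copied from the sample output above purely for illustration.

```python
from collections import namedtuple

# Hypothetical stand-in for the objects returned by test();
# the field names follow the documented keys.
StandInResult = namedtuple("StandInResult", ["answer", "score", "start", "end"])

contexts = [
    "October 31st is the date",
    "The date is November 23rd",
]
results = [
    StandInResult(answer="October 31st", score=0.9939756989479065, start=0, end=12),
    StandInResult(answer="November 23rd", score=0.967872679233551, start=12, end=25),
]

# The slice context[start:end] should reproduce the predicted answer text.
for context, result in zip(contexts, results):
    span = context[result.start:result.end]
    print(f"{result.answer!r}: span matches = {span == result.answer}")
```

With real output from test(), the contexts would be read from the same CSV file that was passed as input_filepath, in the same row order.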