Text-to-Text Fine-Tuning
HappyTextToText contains two methods for training and evaluating models:
- train(): fine-tunes the model to convert one standalone piece of text into another
- eval(): evaluates how well the model performs
train()
Inputs:
- input_filepath (string): a file path to a CSV file formatted as shown in Table 7.1
- args (TTTrainArgs): a dataclass with the fields shown in Table 7.2
- eval_filepath (string): by default, an evaluation dataset is generated from the supplied training data. You may instead provide a file path to a CSV file in the same format as input_filepath to use standalone evaluation data (a usage sketch follows Example 7.3).
Table 7.1
Contains two columns with the following header values: input and target.

| input | target |
|---|---|
| grammar: I has poor grammars | I have poor grammar |
| grammar: I wants too plays | I want to play |
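If you need to produce a file in this format, the short sketch below writes the two example rows from Table 7.1 with Python's built-in csv module; the output file name is only an illustration, not a path the library expects.

import csv

# Write a CSV file in the Table 7.1 format: an "input" column and a "target" column.
# The file name below is only an example.
rows = [
    ("grammar: I has poor grammars", "I have poor grammar"),
    ("grammar: I wants too plays", "I want to play"),
]
with open("train-eval-grammar.csv", "w", newline="") as csv_file:
    writer = csv.writer(csv_file)
    writer.writerow(["input", "target"])
    writer.writerows(rows)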
Table 7.2
| Parameter | Default |
|---|---|
| learning_rate | 5e-5 |
| num_train_epochs | 1 |
| batch_size | 1 |
| preprocessing_processes | 1 |
| save_path | "" |
| load_path | "" |
| max_input_length | None |
| max_output_length | None |
| fp16 | False |
| eval_ratio | 0.1 |
| save_steps | 0.0 |
| eval_steps | 0.1 |
| logging_steps | 0.1 |
| output_dir | "happy_transformer" |
Information about the learning parameters can be found here
- preprocessing_processes: the number of processes used for preprocessing. We recommend 1-4.
- max_input_length: the maximum number of tokens for the input; anything longer is truncated. By default, the maximum number of tokens the model can handle is used.
- max_output_length: the same as max_input_length, but for the output.
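As a sketch, the fields above can be combined in a single TTTrainArgs. The values below are purely illustrative, not recommendations.

from happytransformer import TTTrainArgs

# Illustrative values only: cap inputs and outputs at 128 tokens and use
# two preprocessing processes (within the recommended 1-4 range).
args = TTTrainArgs(
    num_train_epochs=1,
    preprocessing_processes=2,
    max_input_length=128,
    max_output_length=128,
)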
Example 7.3:
from happytransformer import HappyTextToText, TTTrainArgs
# --------------------------------------#
happy_tt = HappyTextToText()
args = TTTrainArgs(num_train_epochs=1)
happy_tt.train("../../data/tt/train-eval-grammar.csv", args=args)
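If you keep a separate evaluation file, it can be supplied through eval_filepath as described above. The sketch below assumes hypothetical file paths; both files follow the Table 7.1 format.

from happytransformer import HappyTextToText, TTTrainArgs

# Hypothetical paths; both CSV files use the Table 7.1 format.
happy_tt = HappyTextToText()
args = TTTrainArgs(num_train_epochs=1)
happy_tt.train(
    "../../data/tt/train-grammar.csv",
    args=args,
    eval_filepath="../../data/tt/eval-grammar.csv",
)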
eval()
Inputs:
- input_filepath (string): a file path to a CSV file with the same format as described for the training data in Table 7.1
- args (TTEvalArgs): a dataclass with the fields shown in Table 7.3
Table 7.3
| Parameter | Default |
|---|---|
| preprocessing_processes | 1 |
| max_input_length | 1024 |
| max_output_length | 1024 |
See the parameter descriptions under Table 7.2 for more information.
Output: an object with a single field called "loss"
Example 7.4
from happytransformer import HappyTextToText, TTEvalArgs
# --------------------------------------#
happy_tt = HappyTextToText()
args = TTEvalArgs(preprocessing_processes=1)
result = happy_tt.eval("../../data/tt/train-eval-grammar.csv", args=args)
print(type(result)) # <class 'happytransformer.happy_trainer.EvalResult'>
print(result) # EvalResult(loss=3.2277376651763916)
print(result.loss) # 3.2277376651763916
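A common way to use eval() is to measure the loss before and after fine-tuning to confirm that training helped. The sketch below reuses the same CSV for both steps purely for brevity; in practice, separate training and evaluation files are preferable.

from happytransformer import HappyTextToText, TTTrainArgs, TTEvalArgs

happy_tt = HappyTextToText()

# Loss before fine-tuning (hypothetical file path).
before = happy_tt.eval("../../data/tt/train-eval-grammar.csv", args=TTEvalArgs())

# Fine-tune, then measure the loss again on the same data.
happy_tt.train("../../data/tt/train-eval-grammar.csv", args=TTTrainArgs(num_train_epochs=1))
after = happy_tt.eval("../../data/tt/train-eval-grammar.csv", args=TTEvalArgs())

print(f"loss before: {before.loss:.4f}, after: {after.loss:.4f}")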
Tutorials
- Top T5 Models
- Grammar Correction
- Fine-tune a Grammar Correction Model