Word Prediction Finetuning
HappyWordPrediction contains two methods for training
- train(): fine-tune the model to understand a body of text better
- eval(): determine how well the model performs
train()
inputs:
- input_filepath (string): a file path to a text file that contains nothing but text to train the model with
- args (WPTrainArgs): a dataclass with the fields and default values shown in Table 4.1
- eval_filepath (string): by default, an evaluation dataset is generated from the supplied training data, but you may instead provide a file path to a text or CSV file, as described for input_filepath, to use standalone evaluation data
Table 4.0

| text |
|--------------------------|
| This is a training case. |
| This is another training case |
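A CSV training file needs a single column called "text", with one training case per row, as shown in Table 4.0. A minimal sketch that writes such a file with Python's standard csv module (the filename train.csv is just an example):

```python
import csv

# One training case per row, under a single "text" column
rows = ["This is a training case.", "This is another training case"]

with open("train.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["text"])  # header row
    for row in rows:
        writer.writerow([row])
```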
Table 4.1

| Parameter | Default |
|---|---|
| learning_rate | 5e-5 |
| num_train_epochs | 1 |
| batch_size | 1 |
| weight_decay | 0 |
| save_path | "" |
| load_path | "" |
| mlm_probability | 0.15 |
| line_by_line | False |
| fp16 | False |
| eval_ratio | 0.1 |
| save_steps | 0.0 |
| eval_steps | 0.1 |
| logging_steps | 0.1 |
| output_dir | "happy_transformer" |
| max_length | None |
Information about the learning parameters can be found here
Information about saving/loading preprocessed data can be found here
mlm_probability: the probability of masking a token during training.
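To build intuition for mlm_probability, here is a toy sketch in plain Python (not Happy Transformer's internals) that independently masks each token with the given probability:

```python
import random

def mask_tokens(tokens, mlm_probability=0.15, mask_token="[MASK]", seed=1):
    """Independently replace each token with mask_token
    with probability mlm_probability."""
    rng = random.Random(seed)
    return [mask_token if rng.random() < mlm_probability else t for t in tokens]

tokens = "the quick brown fox jumps over the lazy dog".split()
print(mask_tokens(tokens))
```

Note that this only illustrates how tokens are selected; the Hugging Face masking collator that masked language modeling training typically builds on also replaces some selected tokens with random tokens or leaves them unchanged rather than always inserting the mask token.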
line_by_line: if False, the training data is concatenated and then divided into sections that are the length of the model's input size, except for the last section, which may be shorter. If True, each input contains the text from a single line within the training data; the text may be truncated if the line is too long (e.g., BERT's maximum input size is 512 tokens).
Example 4.4:

```python
from happytransformer import HappyWordPrediction, WPTrainArgs
# --------------------------------------#
happy_wp = HappyWordPrediction()
args = WPTrainArgs(num_train_epochs=1)
happy_wp.train("../../data/wp/train-eval.txt", args=args)
```
eval()
Input:
- input_filepath (string): a file path to a text file that contains nothing but text to evaluate with
- args (WPEvalArgs): a dataclass with the fields shown in Table 4.2
Table 4.2

| Parameter | Default |
|---|---|
| save_path | "" |
| load_path | "" |
| line_by_line | False |
| max_length | None |
See the explanations under Table 4.1 for more information
Output: an object with the field "loss"
Example 4.5

```python
from happytransformer import HappyWordPrediction, WPEvalArgs
# --------------------------------------#
happy_wp = HappyWordPrediction()
args = WPEvalArgs(line_by_line=False)
result = happy_wp.eval("../../data/wp/train-eval.txt", args=args)
print(type(result))  # <class 'happytransformer.happy_trainer.EvalResult'>
print(result)  # EvalResult(eval_loss=0.459536075592041)
print(result.loss)  # 0.459536075592041
```
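The returned loss is a cross-entropy value over the masked-token predictions, so exponentiating it gives the model's perplexity, which is often easier to interpret. A quick sketch, using the loss value copied from Example 4.5:

```python
import math

eval_loss = 0.459536075592041  # loss value from Example 4.5
perplexity = math.exp(eval_loss)
print(round(perplexity, 3))  # roughly 1.583
```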