Text Classification Finetuning
HappyTextClassification contains three methods for training and evaluation:
- train(): fine-tune the model to become better at a certain task
- eval(): determine how well the model performs on a labeled dataset
- test(): run the model on an unlabeled dataset to produce predictions
train()
Inputs:
- input_filepath (string): a file path to a CSV file as described in table 2.0
- args (TCTrainArgs): a dataclass with the possible values shown in table 2.1
- eval_filepath (string): by default, an evaluating dataset is generated from the supplied training data, but you may provide a path to a CSV file in the same format as input_filepath to use standalone evaluating data (see the sketch after example 2.2)
Table 2.0
- text (string): text to be classified
- label (int): the corresponding label. Must be greater than or equal to 0
text | label |
---|---|
Wow what a great place to eat | 1 |
Horrible food | 0 |
Terrible service | 0 |
I’m coming here again | 1 |
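
For reference, a training CSV in this format can be produced with Python's standard csv module. The rows below are taken from the sample table; "train-eval.csv" is a placeholder file name:

```python
import csv

# Rows in the table 2.0 format: a "text" column and an
# integer "label" column (labels must be >= 0).
cases = [
    ("Wow what a great place to eat", 1),
    ("Horrible food", 0),
    ("Terrible service", 0),
    ("I'm coming here again", 1),
]

with open("train-eval.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["text", "label"])
    writer.writerows(cases)
```
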
Table 2.1
Information about the learning parameters can be found here. Information about saving/loading preprocessed data can be found here.
Parameter | Default |
---|---|
learning_rate | 5e-5 |
num_train_epochs | 1 |
batch_size | 1 |
weight_decay | 0 |
save_path | "" |
load_path | "" |
fp16 | False |
eval_ratio | 0.1 |
save_steps | 0.0 |
eval_steps | 0.1 |
logging_steps | 0.1 |
output_dir | "happy_transformer" |
Output: None
Example 2.2:

```python
from happytransformer import HappyTextClassification, TCTrainArgs
# --------------------------------------#
happy_tc = HappyTextClassification(model_type="DISTILBERT",
                                   model_name="distilbert-base-uncased-finetuned-sst-2-english",
                                   num_labels=2)  # Don't forget to set num_labels!
args = TCTrainArgs(num_train_epochs=1)
happy_tc.train("../../data/tc/train-eval.csv", args=args)
```
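
Building on example 2.2, here is a short sketch that overrides a few of the table 2.1 defaults and supplies standalone evaluating data, assuming eval_filepath is passed to train() as a keyword argument, as the input list above suggests. "train.csv" and "eval.csv" are placeholder paths:

```python
from happytransformer import HappyTextClassification, TCTrainArgs
# --------------------------------------#
happy_tc = HappyTextClassification(model_type="DISTILBERT",
                                   model_name="distilbert-base-uncased-finetuned-sst-2-english",
                                   num_labels=2)

# Each field name matches a parameter from table 2.1.
args = TCTrainArgs(learning_rate=2e-5,
                   num_train_epochs=3,
                   batch_size=8)

# eval_filepath points to a separate CSV in the table 2.0 format,
# so no evaluating split is carved out of the training data.
happy_tc.train("train.csv", args=args, eval_filepath="eval.csv")
```
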
eval()
Input:
- input_filepath (string): a file path to a CSV file as described in table 2.0
Output:
An object with the field "loss"
Example 2.3:

```python
from happytransformer import HappyTextClassification, TCEvalArgs
# --------------------------------------#
happy_tc = HappyTextClassification(model_type="DISTILBERT",
                                   model_name="distilbert-base-uncased-finetuned-sst-2-english",
                                   num_labels=2)  # Don't forget to set num_labels!
result = happy_tc.eval("../../data/tc/train-eval.csv")
print(type(result))  # <class 'happytransformer.happy_trainer.EvalResult'>
print(result)  # EvalResult(eval_loss=0.007262040860950947)
print(result.loss)  # 0.007262040860950947
```
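
TCEvalArgs is imported in example 2.3 but unused there. Assuming eval() accepts an args dataclass the same way train() does, a minimal sketch:

```python
from happytransformer import HappyTextClassification, TCEvalArgs
# --------------------------------------#
happy_tc = HappyTextClassification(model_type="DISTILBERT",
                                   model_name="distilbert-base-uncased-finetuned-sst-2-english",
                                   num_labels=2)

# Assumption: eval() mirrors train() and takes an args keyword.
args = TCEvalArgs()  # defaults; override fields as needed
result = happy_tc.eval("../../data/tc/train-eval.csv", args=args)
print(result.loss)
```
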
test()
Input:
- input_filepath (string): a file path to a CSV file as described in table 2.2
Output: A list of named tuples with the fields "label" and "score"
The list is in order by ascending CSV index.
Table 2.2
- text (string): text that will be classified
text |
---|
Wow what a great place to eat |
Horrible food |
Terrible service |
I’m coming here again |
Example 2.4:

```python
from happytransformer import HappyTextClassification
# --------------------------------------#
happy_tc = HappyTextClassification(model_type="DISTILBERT",
                                   model_name="distilbert-base-uncased-finetuned-sst-2-english",
                                   num_labels=2)  # Don't forget to set num_labels!
result = happy_tc.test("../../data/tc/test.csv")
print(type(result))  # <class 'list'>
print(result)  # [TextClassificationResult(label='POSITIVE', score=0.9998401999473572), TextClassificationResult(label='NEGATIVE', score=0.9772131443023682)...
print(type(result[0]))  # <class 'happytransformer.happy_text_classification.TextClassificationResult'>
print(result[0])  # TextClassificationResult(label='POSITIVE', score=0.9998401999473572)
print(result[0].label)  # POSITIVE
```
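
Because the results come back in ascending CSV order, they can be zipped back onto the input rows. A minimal sketch, reusing the test.csv path from example 2.4:

```python
import csv

from happytransformer import HappyTextClassification

happy_tc = HappyTextClassification(model_type="DISTILBERT",
                                   model_name="distilbert-base-uncased-finetuned-sst-2-english",
                                   num_labels=2)

results = happy_tc.test("../../data/tc/test.csv")

# Row i of the CSV lines up with results[i], since the output
# list is ordered by ascending CSV index.
with open("../../data/tc/test.csv", newline="") as f:
    rows = list(csv.DictReader(f))

for row, result in zip(rows, results):
    print(f"{row['text']!r} -> {result.label} ({result.score:.4f})")
```
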
Example 2.5:

```python
from happytransformer import HappyTextClassification
# --------------------------------------#
happy_tc = HappyTextClassification(model_type="DISTILBERT",
                                   model_name="distilbert-base-uncased-finetuned-sst-2-english",
                                   num_labels=2)  # Don't forget to set num_labels!
before_loss = happy_tc.eval("../../data/tc/train-eval.csv").loss
happy_tc.train("../../data/tc/train-eval.csv")
after_loss = happy_tc.eval("../../data/tc/train-eval.csv").loss
print("Before loss: ", before_loss)  # 0.007262040860950947
print("After loss: ", after_loss)  # 0.000162081079906784
# Since after_loss < before_loss, the model learned!
# Note: typically you evaluate with a separate dataset,
# but for simplicity we used the same one.
```
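
To act on the note in example 2.5 about evaluating with a separate dataset, one approach is to split a table 2.0 style CSV before training. A minimal sketch using only the standard library, mirroring the 0.1 eval_ratio default from table 2.1; "train-eval.csv", "train.csv", and "eval.csv" are placeholder paths:

```python
import csv
import random

# Split one table 2.0 style CSV into separate training and
# evaluating files, so eval() is not run on the training rows.
with open("train-eval.csv", newline="") as f:
    reader = csv.reader(f)
    header = next(reader)
    rows = list(reader)

random.seed(0)
random.shuffle(rows)
split = int(len(rows) * 0.9)  # 90% train / 10% eval, like eval_ratio=0.1

for path, chunk in [("train.csv", rows[:split]), ("eval.csv", rows[split:])]:
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(chunk)
```
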