Text Classification Finetuning

HappyTextClassification contains three methods for training, evaluating and testing models:

  • train(): fine-tune the model to become better at a certain task
  • eval(): determine how well the model performs on a labeled dataset
  • test(): run the model on an unlabeled dataset to produce predictions

train()

Inputs:

  1. input_filepath (string): a file path to a CSV file as described in Table 2.0
  2. args (TCTrainArgs): a dataclass with the possible values shown in Table 2.1
  3. eval_filepath (string): by default, an evaluating dataset is generated from the supplied training data, but you may instead provide a file path to a CSV file in the same format as input_filepath to use standalone evaluating data (see the sketch after Example 2.2)

Table 2.0

  1. text (string): text to be classified
  2. label (int): the corresponding label. Must be greater than or equal to 0

text                           label
Wow what a great place to eat  1
Horrible food                  0
Terrible service               0
I’m coming here again          1
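
The CSV itself can be produced with ordinary tooling; below is a minimal sketch using Python's built-in csv module (the file name and rows are illustrative):

import csv

# Rows in the Table 2.0 format: a "text" column and a non-negative integer "label" column.
rows = [
    ("Wow what a great place to eat", 1),
    ("Horrible food", 0),
    ("Terrible service", 0),
    ("I'm coming here again", 1),
]

with open("train.csv", "w", newline="") as csv_file:
    writer = csv.writer(csv_file)
    writer.writerow(["text", "label"])  # header row must use these column names
    writer.writerows(rows)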

Table 2.1

Information about the learning parameters can be found here. Information about saving/loading preprocessed data can be found here.

Parameter          Default
learning_rate      5e-5
num_train_epochs   1
batch_size         1
weight_decay       0
save_path          ""
load_path          ""
fp16               False
eval_ratio         0.1
save_steps         0.0
eval_steps         0.1
logging_steps      0.1
output_dir         "happy_transformer"
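
As a rough sketch, several of the parameters above can be combined in a single TCTrainArgs; the values below are illustrative, not recommendations:

from happytransformer import TCTrainArgs

# Override a few of the defaults from Table 2.1; unspecified parameters keep their defaults.
args = TCTrainArgs(learning_rate=3e-5,
                   num_train_epochs=2,
                   batch_size=8,
                   eval_ratio=0.2)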

Output: None

Example 2.2:

from happytransformer import HappyTextClassification, TCTrainArgs
# --------------------------------------#
happy_tc = HappyTextClassification(model_type="DISTILBERT",
                                   model_name="distilbert-base-uncased-finetuned-sst-2-english",
                                   num_labels=2)  # Don't forget to set num_labels! 
args = TCTrainArgs(num_train_epochs=1)
happy_tc.train("../../data/tc/train-eval.csv", args=args)
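
If you would rather evaluate on standalone data than rely on the automatic eval_ratio split, a sketch of the same call with eval_filepath is shown below (both file paths are illustrative):

from happytransformer import HappyTextClassification, TCTrainArgs
# --------------------------------------#
happy_tc = HappyTextClassification(model_type="DISTILBERT",
                                   model_name="distilbert-base-uncased-finetuned-sst-2-english",
                                   num_labels=2)
args = TCTrainArgs(num_train_epochs=1)
# eval.csv follows the same text/label format described in Table 2.0
happy_tc.train("../../data/tc/train.csv", args=args, eval_filepath="../../data/tc/eval.csv")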

eval()

Input:

  1. input_filepath (string): a file path to a CSV file as described in Table 2.0

Output:

An object with the field "loss"

Example 2.3:

from happytransformer import HappyTextClassification, TCEvalArgs
# --------------------------------------#
happy_tc = HappyTextClassification(model_type="DISTILBERT",
                                   model_name="distilbert-base-uncased-finetuned-sst-2-english",
                                   num_labels=2)  # Don't forget to set num_labels!
result = happy_tc.eval("../../data/tc/train-eval.csv")
print(type(result))  # <class 'happytransformer.happy_trainer.EvalResult'>
print(result)  # EvalResult(eval_loss=0.007262040860950947)
print(result.loss)  # 0.007262040860950947

test()

Input:

  1. input_filepath (string): a file path to a CSV file as described in Table 2.2

Output: A list of result objects with the fields "label" and "score"

The list is ordered by ascending CSV index.

Table 2.2

  1. text (string): text that will be classified
text
Wow what a great place to eat
Horrible food
Terrible service
I’m coming here again

Example 2.4:

from happytransformer import HappyTextClassification
# --------------------------------------#
happy_tc = HappyTextClassification(model_type="DISTILBERT",
                                   model_name="distilbert-base-uncased-finetuned-sst-2-english",
                                   num_labels=2)  # Don't forget to set num_labels!
result = happy_tc.test("../../data/tc/test.csv")
print(type(result))  # <class 'list'>
print(result)  # [TextClassificationResult(label='POSITIVE', score=0.9998401999473572), TextClassificationResult(label='NEGATIVE', score=0.9772131443023682)...
print(type(result[0]))  # <class 'happytransformer.happy_text_classification.TextClassificationResult'>
print(result[0])  # TextClassificationResult(label='POSITIVE', score=0.9998401999473572)
print(result[0].label)  # POSITIVE
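
Because the list follows the CSV order, each prediction can be matched back to its input row; a minimal sketch using Python's csv module (the file path is illustrative):

import csv
from happytransformer import HappyTextClassification
# --------------------------------------#
happy_tc = HappyTextClassification(model_type="DISTILBERT",
                                   model_name="distilbert-base-uncased-finetuned-sst-2-english",
                                   num_labels=2)
results = happy_tc.test("../../data/tc/test.csv")

with open("../../data/tc/test.csv", newline="") as csv_file:
    texts = [row["text"] for row in csv.DictReader(csv_file)]

# results[i] corresponds to texts[i] since test() preserves the CSV order
for text, result in zip(texts, results):
    print(f"{result.label} ({result.score:.3f}): {text}")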


Example 2.5:

from happytransformer import HappyTextClassification
# --------------------------------------#
happy_tc = HappyTextClassification(model_type="DISTILBERT",
                                   model_name="distilbert-base-uncased-finetuned-sst-2-english",
                                   num_labels=2)  # Don't forget to set num_labels!
before_loss = happy_tc.eval("../../data/tc/train-eval.csv").loss
happy_tc.train("../../data/tc/train-eval.csv")
after_loss = happy_tc.eval("../../data/tc/train-eval.csv").loss
print("Before loss: ", before_loss)  # 0.007262040860950947
print("After loss: ", after_loss)  # 0.000162081079906784
# Since after_loss < before_loss, the model learned!
# Note: typically you evaluate with a separate dataset
# but for simplicity we used the same one