Word Prediction Basic Usage
predict_mask()
The method predict_mask() takes three arguments:
- text (string): a body of text that contains a single masked token
- targets (list of strings): a list of candidate answers. Only these tokens are scored; all other answers are ignored
- top_k (int): the number of results that will be returned
Returns: A list of objects with fields “token” and “score”
Note: if targets are provided, then top_k will be ignored and a score for each target will be returned.
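The interaction between targets and top_k described in the note can be sketched in plain Python. This is only an illustration of the documented behaviour, not the library's implementation: the Prediction class, the select_predictions() helper, and the toy scores are all hypothetical, and sorting the results by descending score is an assumption based on the example outputs.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Prediction:
    """Hypothetical stand-in for the objects returned by predict_mask()."""
    token: str
    score: float


def select_predictions(scores: Dict[str, float],
                       targets: Optional[List[str]] = None,
                       top_k: int = 1) -> List[Prediction]:
    """Apply the documented selection rule to a toy token -> score mapping."""
    if targets is not None:
        # Targets given: return a score for every target and ignore top_k.
        chosen = [Prediction(t, scores[t]) for t in targets]
        return sorted(chosen, key=lambda p: p.score, reverse=True)
    # No targets: rank every token and keep the top_k best.
    ranked = sorted((Prediction(t, s) for t, s in scores.items()),
                    key=lambda p: p.score, reverse=True)
    return ranked[:top_k]


# Made-up scores, used only for illustration.
toy_scores = {"health": 0.128, "science": 0.080,
              "healthcare": 0.074, "technology": 0.009}

top_two = select_predictions(toy_scores, top_k=2)
print([p.token for p in top_two])  # ['health', 'science']

# top_k=1 is ignored here because targets are supplied.
targeted = select_predictions(toy_scores,
                              targets=["technology", "healthcare"], top_k=1)
print([p.token for p in targeted])  # ['healthcare', 'technology']
```

In the second call both targets come back even though top_k is 1, which mirrors the rule stated in the note.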
Example 4.1:
from happytransformer import HappyWordPrediction
#--------------------------------------#
happy_wp = HappyWordPrediction() # default uses distilbert-base-uncased
result = happy_wp.predict_mask("I think therefore I [MASK]")
print(type(result)) # <class 'list'>
print(result) # [WordPredictionResult(token='am', score=0.10172799974679947)]
print(type(result[0])) # <class 'happytransformer.happy_word_prediction.WordPredictionResult'>
print(result[0]) # WordPredictionResult(token='am', score=0.10172799974679947)
print(result[0].token) # am
print(result[0].score) # 0.10172799974679947
Example 4.2:
from happytransformer import HappyWordPrediction
#--------------------------------------#
happy_wp = HappyWordPrediction()
result = happy_wp.predict_mask("To better the world I would invest in [MASK] and education.", top_k=2)
print(result) # [WordPredictionResult(token='health', score=0.1280556619167328), WordPredictionResult(token='science', score=0.07976455241441727)]
print(result[1]) # WordPredictionResult(token='science', score=0.07976455241441727)
print(result[1].token) # science
Example 4.3:
from happytransformer import HappyWordPrediction
#--------------------------------------#
happy_wp = HappyWordPrediction()
targets = ["technology", "healthcare"]
result = happy_wp.predict_mask("To better the world I would invest in [MASK] and education.", targets=targets, top_k=2) # per the note above, top_k is ignored when targets are given
print(result) # [WordPredictionResult(token='healthcare', score=0.07380751520395279), WordPredictionResult(token='technology', score=0.009395276196300983)]
print(result[1]) # WordPredictionResult(token='technology', score=0.009395276196300983)
print(result[1].token) # technology
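Because each result exposes "token" and "score" fields, the returned list is easy to post-process. The following is a minimal sketch, assuming a hypothetical StubResult stand-in (so it runs without downloading a model) and an arbitrary threshold; the scores are copied from the output of Example 4.3.

```python
from collections import namedtuple

# Hypothetical stand-in for the result objects; in practice `result`
# would be the list returned by happy_wp.predict_mask(...).
StubResult = namedtuple("StubResult", ["token", "score"])

result = [
    StubResult(token="healthcare", score=0.07380751520395279),
    StubResult(token="technology", score=0.009395276196300983),
]

# Map each token to its score for direct lookup.
score_by_token = {r.token: r.score for r in result}

# Keep only predictions at or above an illustrative, arbitrary threshold.
min_score = 0.05
confident_tokens = [r.token for r in result if r.score >= min_score]

print(score_by_token["technology"])  # 0.009395276196300983
print(confident_tokens)  # ['healthcare']
```

The same pattern applies to results obtained with top_k, since they carry the same two fields.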