# Using the high-level pipeline API (使用流水线API)
from transformers import pipeline, DistilBertTokenizer, DistilBertForQuestionAnswering
import os

# Directory holding the model files; the default points at a locally cached
# snapshot of distilbert-base-cased-distilled-squad on the original author's
# machine. Set the QA_MODEL_PATH environment variable to use another location
# without editing the script.
path = os.environ.get(
    "QA_MODEL_PATH",
    r"E:\temp\models\models--distilbert--distilbert-base-cased-distilled-squad"
    r"\snapshots\50ba811384f02cb99cdabe5cdc02f7ddc4f69e10",
)
# Both objects are loaded from the same directory so the tokenizer vocabulary
# matches the model weights.
tokenizer = DistilBertTokenizer.from_pretrained(path)
model = DistilBertForQuestionAnswering.from_pretrained(path)
def main():
question_answerer = pipeline(
"question-answering", model=model, tokenizer=tokenizer, framework="pt"
)
result = question_answerer(
{
"question": "What is the name of the repository ?",
"context": "Pipeline has been included in the huggingface/transformers repository",
}
)
print(result)
if __name__ == "__main__":
main()
# Using the model directly with PyTorch (使用PyTorch)
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
import torch
import os

# Directory holding the model files; the default points at a locally cached
# snapshot of distilbert-base-cased-distilled-squad on the original author's
# machine. Set the QA_MODEL_PATH environment variable to use another location
# without editing the script.
path = os.environ.get(
    "QA_MODEL_PATH",
    r"E:\temp\models\models--distilbert--distilbert-base-cased-distilled-squad"
    r"\snapshots\50ba811384f02cb99cdabe5cdc02f7ddc4f69e10",
)
# Both objects are loaded from the same directory so the tokenizer vocabulary
# matches the model weights.
tokenizer = DistilBertTokenizer.from_pretrained(path)
model = DistilBertForQuestionAnswering.from_pretrained(path)
# Upper bound on the predicted answer length, in tokens. 15 matches the
# default ``max_answer_len`` of the transformers question-answering pipeline.
MAX_ANSWER_TOKENS = 15

# Encode the (question, context) pair as PyTorch tensors (a batch of one).
inputs = tokenizer(
    "What is the name of the repository?",
    "Pipeline has been included in the huggingface/transformers repository",
    return_tensors="pt",
)
# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    outputs = model(**inputs)

# Picking start and end with two independent argmax calls can give
# end < start. That slices an empty token run and decodes to "". Instead,
# score every (start, end) pair jointly and keep only well-formed spans.
start_logits = outputs.start_logits[0]  # row 0: the single encoded pair
end_logits = outputs.end_logits[0]
span_scores = start_logits[:, None] + end_logits[None, :]  # indexed [start, end]
positions = torch.arange(span_scores.size(0), device=span_scores.device)
span_offset = positions[None, :] - positions[:, None]  # end - start
valid_span = (span_offset >= 0) & (span_offset < MAX_ANSWER_TOKENS)
span_scores = span_scores.masked_fill(~valid_span, float("-inf"))
# argmax over a 2-D tensor returns a flat index; split it into (row, column).
answer_start_index, answer_end_index = divmod(
    int(torch.argmax(span_scores)), span_scores.size(1)
)
predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
result = tokenizer.decode(predict_answer_tokens)
print(result)