问答

创建日期:2024-06-21
更新日期:2025-02-01

使用流水线API

from transformers import pipeline, DistilBertTokenizer, DistilBertForQuestionAnswering

# Local directory that both from_pretrained calls below load from.
# NOTE(review): the directory name suggests the
# distilbert-base-cased-distilled-squad checkpoint — confirm.
path = "E:\\temp\\models\\models--distilbert--distilbert-base-cased-distilled-squad\\snapshots\\50ba811384f02cb99cdabe5cdc02f7ddc4f69e10"

# Loaded once at module level; main() reads both objects as globals.
tokenizer = DistilBertTokenizer.from_pretrained(path)
model = DistilBertForQuestionAnswering.from_pretrained(path)

def main():
    """Answer one hard-coded question with the pipeline API and print the result.

    Builds a ``"question-answering"`` pipeline from the module-level ``model``
    and ``tokenizer``, runs it on a fixed question/context pair, and prints
    whatever the pipeline returns.
    """
    question_answerer = pipeline(
        "question-answering", model=model, tokenizer=tokenizer, framework="pt"
    )
    # Use the documented keyword arguments: passing a positional SQuAD-style
    # dict is deprecated in recent transformers releases.
    result = question_answerer(
        question="What is the name of the repository ?",
        context="Pipeline has been included in the huggingface/transformers repository",
    )
    print(result)

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

使用PyTorch

from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
import torch

# Directory that both from_pretrained calls below load from.
path = "E:\\temp\\models\\models--distilbert--distilbert-base-cased-distilled-squad\\snapshots\\50ba811384f02cb99cdabe5cdc02f7ddc4f69e10"

model = DistilBertForQuestionAnswering.from_pretrained(path)
tokenizer = DistilBertTokenizer.from_pretrained(path)

# Encode the question and its context together as PyTorch tensors.
question = "What is the name of the repository?"
context = "Pipeline has been included in the huggingface/transformers repository"
inputs = tokenizer(question, context, return_tensors="pt")

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    outputs = model(**inputs)

# Highest-scoring start and end positions for the first sequence in the batch.
answer_start_index = outputs.start_logits.argmax(dim=-1)[0].item()
answer_end_index = outputs.end_logits.argmax(dim=-1)[0].item()

# The end index is inclusive, hence the +1 when slicing out the answer ids.
answer_span = slice(answer_start_index, answer_end_index + 1)
predict_answer_tokens = inputs["input_ids"][0, answer_span]

result = tokenizer.decode(predict_answer_tokens)
print(result)