情感分析（Sentiment Analysis）

创建日期：2024-06-21
更新日期：2025-02-01

使用流水线（pipeline）API

# Sentiment analysis with the high-level transformers `pipeline` API, using a
# DistilBERT checkpoint loaded from a local snapshot directory.
import os

from transformers import (
    pipeline,
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
)

# Model location. Defaults to the original local snapshot directory; set the
# SENTIMENT_MODEL_PATH environment variable to use a different directory (or a
# Hugging Face Hub model id, which `from_pretrained` also accepts) without
# editing this script.
path = os.environ.get(
    "SENTIMENT_MODEL_PATH",
    "E:\\temp\\models\\models--distilbert--distilbert-base-uncased-finetuned-sst-2-english\\snapshots\\714eb0fa89d2f80546fda750413ed43d93601a13",
)

tokenizer = DistilBertTokenizer.from_pretrained(path)
model = DistilBertForSequenceClassification.from_pretrained(path)

# Reuse the model and tokenizer loaded above instead of letting the pipeline
# load its own; framework="pt" explicitly selects the PyTorch backend.
classifier = pipeline(
    "sentiment-analysis", model=model, tokenizer=tokenizer, framework="pt"
)
# NOTE(review): per the transformers docs this should be a list of
# {"label": ..., "score": ...} dicts, one per input string — confirm on upgrade.
result = classifier(
    "We are very happy to introduce pipeline to the transformers repository."
)

print(result)

直接使用 PyTorch 调用模型（不使用流水线）

# Sentiment analysis by calling the model directly with PyTorch (no pipeline):
# tokenize one sentence, run a forward pass, and map the top logit to a label.
import os

import torch
from transformers import (
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
)

# Model location. Defaults to the original local snapshot directory; set the
# SENTIMENT_MODEL_PATH environment variable to override it (a local directory
# or a Hugging Face Hub model id) without editing this script.
path = os.environ.get(
    "SENTIMENT_MODEL_PATH",
    "E:\\temp\\models\\models--distilbert--distilbert-base-uncased-finetuned-sst-2-english\\snapshots\\714eb0fa89d2f80546fda750413ed43d93601a13",
)

tokenizer = DistilBertTokenizer.from_pretrained(path)
model = DistilBertForSequenceClassification.from_pretrained(path)

# return_tensors="pt" makes the tokenizer return PyTorch tensors, which are
# unpacked as keyword arguments into the model's forward call below.
inputs = tokenizer(
    "We are very happy to introduce pipeline to the transformers repository.",
    return_tensors="pt",
)

# Inference only: disable gradient tracking for the forward pass.
with torch.no_grad():
    logits = model(**inputs).logits

# Reduce along the class axis (last dim; assumes logits are
# (batch, num_labels) as documented for sequence-classification models) so the
# index is always a class id. A bare argmax() flattens the tensor and would
# return a flat offset for more than one row. `.item()` assumes one sentence.
predicted_class_id = logits.argmax(dim=-1).item()
result = model.config.id2label[predicted_class_id]
print(result)