使用CLS向量进行分类的步骤如下:
from transformers import BertTokenizer, BertForSequenceClassification
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
这里的num_labels
参数是分类任务的类别数,根据实际情况进行设置。
text = "这是一段要分类的文本"
inputs = tokenizer.encode_plus(text, add_special_tokens=True, truncation=True, max_length=512, padding='max_length', return_tensors='pt')
encode_plus
方法将文本编码为BERT模型所需的输入格式。add_special_tokens=True
表示添加特殊的开始和结束标记,truncation=True
表示对文本进行截断,max_length
指定了最大长度,padding='max_length'
表示进行填充。
outputs = model(**inputs)
logits = outputs.logits
**inputs
将编码后的输入传递给模型。outputs.logits
得到模型的输出,即分类的预测得分。
predictions = torch.argmax(logits, dim=1)
torch.argmax
方法返回最大值的索引,即预测的类别。
完整的代码示例如下:
from transformers import BertTokenizer, BertForSequenceClassification
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
text = "这是一段要分类的文本"
inputs = tokenizer.encode_plus(text, add_special_tokens=True, truncation=True, max_length=512, padding='max_length', return_tensors='pt')
outputs = model(**inputs)
logits = outputs.logits
predictions = torch.argmax(logits, dim=1)
print(predictions)
这段代码使用了预训练的BERT模型对输入文本进行分类,并打印了预测的结果。