是的,BertForSequenceClassification模型对CLS向量进行分类。下面是一个使用PyTorch的示例代码:
import torch
from transformers import BertTokenizer, BertForSequenceClassification
# 加载预训练的Bert模型和tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
# 输入文本
text = "This is an example sentence."
# 对文本进行tokenize和编码
inputs = tokenizer.encode_plus(
text,
add_special_tokens=True,
padding='max_length',
truncation=True,
max_length=128,
return_tensors='pt'
)
# 获取输入的tensor
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
# 使用BertForSequenceClassification进行分类
outputs = model(input_ids, attention_mask=attention_mask)
logits = outputs.logits
# 获取预测结果
predictions = torch.argmax(logits, dim=1).item()
print(predictions)
在上述代码中,BertForSequenceClassification模型接收文本输入后,会返回一个包含logits的输出。通过在logits上使用argmax函数,可以获取到最终的分类结果。