# BERT joint intent recognition and slot filling
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
class JointIntentSlotBERT(nn.Module):
    """BERT model for joint intent classification and slot filling.

    A shared BERT encoder feeds two task-specific heads:
      - intent head: consumes the pooled [CLS] representation
      - slot head:   consumes every token's hidden state

    Args:
        bert_name: name or path of the pretrained BERT checkpoint.
        intent_labels: number of intent classes.
        slot_labels: number of slot tags (including BIO markers).
        dropout: dropout probability used in both heads (default 0.1,
            matching the previous hard-coded value).
        intent_loss_weight: weight of the intent loss in the joint loss;
            the slot loss receives ``1 - intent_loss_weight``. Default 0.5
            reproduces the previous fixed 0.5/0.5 split.
    """

    def __init__(self, bert_name, intent_labels, slot_labels,
                 dropout=0.1, intent_loss_weight=0.5):
        super().__init__()
        # Shared BERT encoder
        self.bert = BertModel.from_pretrained(bert_name)
        hidden_size = self.bert.config.hidden_size
        self.intent_loss_weight = intent_loss_weight
        # Intent classification head: pooled [CLS] -> intent distribution
        self.intent_classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, intent_labels),
        )
        # Slot filling head: per-token hidden state -> slot-tag distribution
        self.slot_classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, slot_labels),
        )
        # Loss functions. Slot positions labeled -1 (padding / special
        # tokens) are excluded from the slot loss.
        self.intent_loss_fn = nn.CrossEntropyLoss()
        self.slot_loss_fn = nn.CrossEntropyLoss(ignore_index=-1)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                intent_labels=None, slot_labels=None):
        """Run the encoder and both task heads.

        Args:
            input_ids: [B, L] input token IDs.
            attention_mask: [B, L] attention mask.
            token_type_ids: [B, L] token type IDs.
            intent_labels: [B] gold intent labels (training only).
            slot_labels: [B, L] gold slot tags (training only); use -1 for
                positions the slot loss should ignore.

        Returns:
            intent_logits: [B, num_intents] intent predictions.
            slot_logits: [B, L, num_slots] slot predictions.
            loss: weighted joint loss, or ``None`` when either label set
                is not provided.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        sequence_output = outputs.last_hidden_state  # [B, L, D]
        cls_output = outputs.pooler_output           # [B, D]

        # Intent classification uses the pooled [CLS] representation.
        intent_logits = self.intent_classifier(cls_output)    # [B, num_intents]
        # Slot filling scores every token position.
        slot_logits = self.slot_classifier(sequence_output)   # [B, L, num_slots]

        loss = None
        if intent_labels is not None and slot_labels is not None:
            intent_loss = self.intent_loss_fn(intent_logits, intent_labels)
            # Flatten batch and sequence dims for the per-token slot loss.
            num_slots = slot_logits.size(-1)
            slot_loss = self.slot_loss_fn(
                slot_logits.view(-1, num_slots), slot_labels.view(-1)
            )
            # Weighted joint objective.
            w = self.intent_loss_weight
            loss = w * intent_loss + (1.0 - w) * slot_loss
        # Single exit point (the original had a duplicated, redundant return).
        return intent_logits, slot_logits, loss
# 使用示例
def joint_bert_example():
    """End-to-end demo: one training step and one inference pass.

    NOTE: downloads 'bert-base-uncased' on first run; intended as a usage
    illustration, not a training script.
    """
    # --- model setup ---
    num_intents = 20  # 20 intent classes
    num_slots = 50    # 50 slot tags (including BIO markers)
    model = JointIntentSlotBERT(
        bert_name='bert-base-uncased',
        intent_labels=num_intents,
        slot_labels=num_slots,
    )

    # --- data preparation ---
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    utterance = "book a flight to Beijing tomorrow"
    intent_label = 5  # e.g. book_flight
    # Word-level BIO slot tags for the utterance's words.
    word_slot_labels = [0, 0, 0, 0, 0, 8, 9, 0, 12, 13]

    # Tokenize (adds [CLS]/[SEP] and may split words into wordpieces).
    inputs = tokenizer(
        utterance,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=128,
    )
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    # Align word-level slot labels to the tokenized sequence. [CLS], [SEP]
    # and any leftover positions get -1, which the model's slot loss ignores.
    # Without this, the label tensor length would not match the logits and
    # CrossEntropyLoss would fail. (Assumes roughly one wordpiece per word —
    # adequate for this demo utterance; real code needs wordpiece alignment.)
    seq_len = input_ids.size(1)
    aligned_slots = [-1] * seq_len
    for i in range(min(len(word_slot_labels), max(seq_len - 2, 0))):
        aligned_slots[i + 1] = word_slot_labels[i]  # positions 1..seq_len-2

    # --- training step ---
    model.train()
    intent_logits, slot_logits, loss = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        intent_labels=torch.tensor([intent_label]),
        slot_labels=torch.tensor([aligned_slots]),
    )
    print(f"联合损失:{loss.item():.4f}")

    # --- inference ---
    model.eval()
    with torch.no_grad():
        intent_logits, slot_logits, _ = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
    # Intent prediction
    intent_pred = intent_logits.argmax(dim=-1).item()
    print(f"预测意图:{intent_pred}")
    # Slot predictions (per token, first batch element)
    slot_preds = slot_logits.argmax(dim=-1)[0].tolist()
    print(f"预测槽位:{slot_preds}")

    # Decode slots. The map below covers only the first few tags of the
    # 50-tag inventory; unknown ids fall back to a placeholder instead of
    # raising KeyError (the original built this dict from a list containing
    # a literal `...`, which crashed on any id >= 5).
    id2slot = {0: 'O', 1: 'B-from', 2: 'I-from', 3: 'B-to', 4: 'I-to'}
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    for token, slot_id in zip(tokens, slot_preds):
        if slot_id != 0:  # skip the 'O' tag
            print(f" {token}: {id2slot.get(slot_id, f'SLOT-{slot_id}')}")