Audio-Visual Fusion Speech Recognition Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Wav2Vec2Model, ViTModel
class AudioVisualSpeechRecognizer(nn.Module):
"""
视听融合语音识别模型
支持早期融合(特征级)和晚期融合(决策级)
"""
def __init__(self, audio_model="facebook/wav2vec2-base",
visual_model="google/vit-base-patch16-224",
fusion_type="late", # "early" or "late"
num_classes=40): # 音素数量
"""
初始化
Args:
audio_model: 音频预训练模型
visual_model: 视觉预训练模型
fusion_type: 融合类型
num_classes: 输出类别数
"""
super().__init__()
self.fusion_type = fusion_type
        # Audio encoder
        self.audio_encoder = Wav2Vec2Model.from_pretrained(audio_model)
        self.audio_proj = nn.Linear(768, 512)  # project into the shared fusion space
        # Visual encoder
        self.visual_encoder = ViTModel.from_pretrained(visual_model)
        self.visual_proj = nn.Linear(768, 512)  # project into the shared fusion space
        # Fusion module
        if fusion_type == "early":
            # Early fusion: concatenate the features, then pass them through a fusion network
self.fusion_network = nn.Sequential(
nn.Linear(1024, 768), # 512+512
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(768, 512),
nn.ReLU(),
nn.Dropout(0.3)
)
self.classifier = nn.Linear(512, num_classes)
else:
            # Late fusion: classify each modality independently, then fuse the decisions
self.audio_classifier = nn.Linear(512, num_classes)
self.visual_classifier = nn.Linear(512, num_classes)
            # Attention network that predicts one fusion weight per modality
self.attention_weights = nn.Sequential(
nn.Linear(1024, 256),
nn.ReLU(),
nn.Linear(256, 2),
nn.Softmax(dim=-1)
)
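
    # Note on the late-fusion design: per batch element the model computes
    #   fused_logits = w_a * audio_logits + w_v * visual_logits,
    # where (w_a, w_v) = softmax(MLP(mean-pooled [audio; visual] features)),
    # so the weights sum to 1 and the model can lean on whichever modality
    # is more reliable for the current input.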
def extract_audio_features(self, audio_input):
"""
提取音频特征
Args:
audio_input: 音频波形 [batch, time]
Returns:
features: 音频特征 [batch, seq_len, 512]
"""
outputs = self.audio_encoder(audio_input)
hidden_states = outputs.last_hidden_state # [batch, seq_len, 768]
features = self.audio_proj(hidden_states) # [batch, seq_len, 512]
return features
def extract_visual_features(self, visual_input):
"""
提取视觉特征(唇部图像序列)
Args:
visual_input: 视频帧序列 [batch, num_frames, 3, 224, 224]
Returns:
features: 视觉特征 [batch, seq_len, 512]
"""
batch, num_frames, C, H, W = visual_input.shape
        # Reshape to [batch*num_frames, C, H, W]
        frames = visual_input.view(-1, C, H, W)
        # Extract per-frame features
        outputs = self.visual_encoder(frames)
        frame_features = outputs.last_hidden_state[:, 0, :]  # CLS token: [batch*num_frames, 768]
        # Project and reshape back to [batch, num_frames, 512]
features = self.visual_proj(frame_features)
features = features.view(batch, num_frames, -1)
return features
def forward(self, audio_input, visual_input, labels=None):
"""
前向传播
Args:
audio_input: 音频输入
visual_input: 视觉输入
labels: 标签(用于训练)
Returns:
output: 识别结果
"""
        # Extract features
audio_features = self.extract_audio_features(audio_input)
visual_features = self.extract_visual_features(visual_input)
        # Align time dimensions: wav2vec2 yields ~50 feature frames/s vs.
        # ~10-30 fps video, so plain truncation would misalign the streams in
        # time. Interpolate the audio features down to the video length instead.
        target_len = visual_features.size(1)
        if audio_features.size(1) != target_len:
            audio_features = F.interpolate(
                audio_features.transpose(1, 2), size=target_len, mode="linear"
            ).transpose(1, 2)  # [batch, target_len, 512]
if self.fusion_type == "early":
            # Early fusion: concatenate the features along the feature dimension
concatenated = torch.cat([audio_features, visual_features], dim=-1)
fused_features = self.fusion_network(concatenated)
logits = self.classifier(fused_features)
output = {
"logits": logits,
"fusion_type": "early",
"fused_features": fused_features
}
else:
            # Late fusion: classify each modality independently
            audio_logits = self.audio_classifier(audio_features)
            visual_logits = self.visual_classifier(visual_features)
            # Attention over modalities: one scalar weight per modality
            combined = torch.cat([audio_features, visual_features], dim=-1)
            weights = self.attention_weights(combined.mean(dim=1))  # [batch, 2]
            # Weighted combination of the per-modality logits
fused_logits = (
weights[:, 0:1].unsqueeze(1) * audio_logits +
weights[:, 1:2].unsqueeze(1) * visual_logits
)
output = {
"logits": fused_logits,
"audio_logits": audio_logits,
"visual_logits": visual_logits,
"fusion_weights": weights,
"fusion_type": "late"
}
        # Compute the loss
if labels is not None:
loss = F.cross_entropy(
output["logits"].view(-1, output["logits"].size(-1)),
labels.view(-1),
ignore_index=-100
)
output["loss"] = loss
return output
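
# A minimal sketch (not part of the model above) of a helper for probing the
# learned fusion weights under acoustic noise: it mixes white Gaussian noise
# into a waveform at a requested signal-to-noise ratio. The name
# `add_noise_at_snr` is a hypothetical illustration, not a library API.
def add_noise_at_snr(waveform, snr_db):
    """Return `waveform` with additive white Gaussian noise at `snr_db` dB SNR."""
    signal_power = waveform.pow(2).mean(dim=-1, keepdim=True)
    noise = torch.randn_like(waveform)
    noise_power = noise.pow(2).mean(dim=-1, keepdim=True)
    # Scale the noise so that signal_power / scaled_noise_power = 10^(snr/10)
    scale = torch.sqrt(signal_power / (noise_power * 10 ** (snr_db / 10)))
    return waveform + scale * noise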
# Usage example
if __name__ == "__main__":
    # Initialize the model
    model = AudioVisualSpeechRecognizer(fusion_type="late")
    model.eval()
    # Simulated inputs
    batch_size = 4
    audio_length = 16000 * 3  # 3 seconds of audio at 16 kHz
    num_frames = 30  # 30 video frames (10 fps over 3 s)
    audio_input = torch.randn(batch_size, audio_length)
    visual_input = torch.randn(batch_size, num_frames, 3, 224, 224)
print("视听融合语音识别示例:")
print("="*70 + "\n")
# 推理
with torch.no_grad():
output = model(audio_input, visual_input)
print(f"融合类型:{output['fusion_type']}")
print(f"融合权重:{output['fusion_weights']}")
print(f" - 音频权重:{output['fusion_weights'][0, 0].item():.3f}")
print(f" - 视觉权重:{output['fusion_weights'][0, 1].item():.3f}")
print(f"输出 logits 形状:{output['logits'].shape}")
print("\n" + "="*70)
print("\n关键观察:")
print("1. 模型自动学习音频和视觉的融合权重")
print("2. 嘈杂环境下视觉权重会自动增加")
print("3. 清晰环境下音频权重占主导")
print("4. 多模态融合提升鲁棒性和准确性")