Cross-Modal Attention Implementation
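Cross-modal attention follows the standard scaled dot-product formulation, Attention(Q, K, V) = softmax(Q Kᵀ / √d_head) · V, except that Q is projected from one modality while K and V are projected from another. The code below implements this directly in PyTorch.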
import torch
import torch.nn as nn
import math
class CrossModalAttention(nn.Module):
    """
    Cross-modal attention layer.
    Query comes from modality A (e.g., text);
    Key and Value come from modality B (e.g., image).
    """
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        # Linear projections for Q, K, V
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        # Dropout applied to the attention weights
        self.dropout = nn.Dropout(dropout)
        # Scaling factor
        self.scale = math.sqrt(self.head_dim)
    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None):
        """
        Forward pass.

        Args:
            query: [B, L_q, D] - query sequence (e.g., text)
            key: [B, L_k, D] - key sequence (e.g., image regions)
            value: [B, L_k, D] - value sequence
            key_padding_mask: [B, L_k] - key padding mask (True = ignore)
            attn_mask: [L_q, L_k] - attention mask (0 = blocked)

        Returns:
            output: [B, L_q, D] - attention output
            attn_weights: [B, num_heads, L_q, L_k] - attention weights
        """
        B, L_q, D = query.shape
        L_k = key.shape[1]
        # Linear projections
        Q = self.q_proj(query)  # [B, L_q, D]
        K = self.k_proj(key)    # [B, L_k, D]
        V = self.v_proj(value)  # [B, L_k, D]
        # Split into heads
        Q = Q.view(B, L_q, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, L_q, d]
        K = K.view(B, L_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, L_k, d]
        V = V.view(B, L_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, L_k, d]
        # Scaled dot-product attention scores
        attn_scores = Q @ K.transpose(-2, -1) / self.scale  # [B, H, L_q, L_k]
        # Apply masks
        if attn_mask is not None:
            attn_scores = attn_scores.masked_fill(attn_mask == 0, float('-inf'))
        if key_padding_mask is not None:
            attn_scores = attn_scores.masked_fill(
                key_padding_mask.view(B, 1, 1, L_k),
                float('-inf')
            )
        # Softmax normalization
        attn_weights = torch.softmax(attn_scores, dim=-1)  # [B, H, L_q, L_k]
        attn_weights = self.dropout(attn_weights)
        # Weighted sum of values
        output = attn_weights @ V  # [B, H, L_q, d]
        # Merge heads
        output = output.transpose(1, 2).contiguous().view(B, L_q, D)  # [B, L_q, D]
        # Output projection
        output = self.out_proj(output)
        return output, attn_weights
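As an aside, on PyTorch 2.0+ the score/softmax/weighted-sum core above can be delegated to F.scaled_dot_product_attention, which fuses the same computation into one optimized kernel (note it does not return the attention weights). A minimal sketch, assuming PyTorch >= 2.0; the helper name sdpa_core is illustrative:

import torch.nn.functional as F

def sdpa_core(Q, K, V, key_padding_mask=None, dropout_p=0.0):
    # Q: [B, H, L_q, d]; K, V: [B, H, L_k, d]; key_padding_mask: [B, L_k] (True = ignore)
    mask = None
    if key_padding_mask is not None:
        # SDPA's boolean mask uses True = "may attend", so invert the padding convention
        mask = ~key_padding_mask.view(key_padding_mask.shape[0], 1, 1, -1)
    # Fused equivalent of softmax(Q @ K^T / sqrt(d)) @ V, with dropout on the weights
    return F.scaled_dot_product_attention(Q, K, V, attn_mask=mask, dropout_p=dropout_p)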
class GatedCrossAttention(nn.Module):
    """
    Gated cross-modal attention (as used in Flamingo).
    A gating mechanism controls the flow of cross-modal information,
    letting the model decide dynamically how much of it to use.
    """
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.cross_attn = CrossModalAttention(embed_dim, num_heads, dropout)
        # Learnable gate parameter (initialized to zero)
        self.gate = nn.Parameter(torch.zeros(1))
        # LayerNorm
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, query, key, value, **kwargs):
        """
        Forward pass.

        Args:
            query: [B, L_q, D] - query (language-model hidden states)
            key: [B, L_k, D] - key (visual features)
            value: [B, L_k, D] - value (visual features)

        Returns:
            output: [B, L_q, D] - gated attention output
            attn_weights: [B, num_heads, L_q, L_k] - attention weights
        """
        # Keep the residual branch
        residual = query
        # Pre-LayerNorm
        query_norm = self.norm(query)
        # Cross-attention
        attn_output, attn_weights = self.cross_attn(
            query_norm, key, value, **kwargs
        )
        # Gated fusion: tanh(gate) * cross_attn_output
        gated_output = torch.tanh(self.gate) * attn_output
        # Residual connection
        output = residual + gated_output
        return output, attn_weights
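Because the gate is initialized to zero and tanh(0) = 0, the block is an exact identity mapping over the query at initialization. This is the property that lets Flamingo insert these layers into a frozen, pretrained language model without disturbing its original behavior; the gate then learns how much visual information to let through. A minimal sanity check (illustrative sizes):

# At init, the cross-modal branch is multiplied by tanh(0) = 0 and contributes nothing.
_layer = GatedCrossAttention(embed_dim=64, num_heads=4)
_q = torch.randn(2, 5, 64)
_kv = torch.randn(2, 9, 64)
_out, _ = _layer(_q, _kv, _kv)
assert torch.allclose(_out, _q)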
# Usage example
def cross_attention_example():
    """Cross-modal attention example."""
    # Instantiate the modules
    embed_dim = 768
    num_heads = 12
    cross_attn = CrossModalAttention(embed_dim, num_heads)
    gated_cross_attn = GatedCrossAttention(embed_dim, num_heads)
    # Dummy data
    B = 4
    text_len = 50        # text sequence length
    image_regions = 196  # number of image regions (14x14)
    text_features = torch.randn(B, text_len, embed_dim)        # text features
    image_features = torch.randn(B, image_regions, embed_dim)  # image features
    # Standard cross-attention
    output, attn_weights = cross_attn(
        query=text_features,
        key=image_features,
        value=image_features
    )
    print(f"Input shape: {text_features.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Attention weight shape: {attn_weights.shape}")
    # Visualize attention (first sample, first head)
    import matplotlib.pyplot as plt
    attn_map = attn_weights[0, 0].detach().cpu().numpy()  # [L_q, L_k]
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(attn_map, cmap='viridis', aspect='auto')
    plt.xlabel('Image region')
    plt.ylabel('Text token')
    plt.title('Cross-attention heatmap')
    plt.colorbar()
    # Mean attention received by each image region, averaged over text tokens
    # (averaging over image regions instead would be flat at 1/L_k,
    # since each softmax row sums to 1)
    plt.subplot(1, 2, 2)
    mean_attn = attn_map.mean(axis=0)
    plt.plot(mean_attn)
    plt.xlabel('Image region index')
    plt.ylabel('Mean attention weight')
    plt.title('Mean attention per image region')
    plt.tight_layout()
    plt.show()
    # Gated cross-attention
    gated_output, gated_weights = gated_cross_attn(
        query=text_features,
        key=image_features,
        value=image_features
    )
    print(f"Gate parameter: {gated_cross_attn.gate.item():.4f}")
    print(f"Gate factor: {torch.tanh(gated_cross_attn.gate).item():.4f}")