# Sliding window attention implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SlidingWindowAttention(nn.Module):
    """Multi-head self-attention restricted to a sliding window.

    Each query token attends only to key tokens within ``window_size``
    positions on either side of it, i.e. a symmetric band of
    ``2 * window_size + 1`` tokens (Longformer-style local attention).

    NOTE: this is a *masked dense* implementation — the full ``[L, L]``
    score matrix is still materialized, so time and memory remain O(L^2).
    The O(L * w) figure applies only to a genuinely banded/sparse kernel.
    """

    def __init__(self, embed_dim: int, num_heads: int, window_size: int = 512):
        """
        Args:
            embed_dim: model width D; must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            window_size: one-sided window radius w; token i may attend to
                token j iff ``|i - j| <= w``.

        Raises:
            ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        # Fail fast instead of silently truncating head_dim, which would
        # otherwise surface later as a confusing reshape error in forward().
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.head_dim = embed_dim // num_heads

        # Separate Q/K/V projections plus an output projection.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # 1/sqrt(d_head) scaling of the dot-product scores.
        self.scale = self.head_dim ** -0.5

    def forward(self, hidden_states, attention_mask=None):
        """Apply sliding-window attention.

        Args:
            hidden_states: ``[B, L, D]`` input hidden states.
            attention_mask: optional ``[B, L]`` mask; positions equal to 0
                are treated as padding and cannot be attended to.

        Returns:
            ``[B, L, D]`` attention output.
        """
        B, L, D = hidden_states.shape

        # Project and split heads: [B, L, D] -> [B, H, L, d].
        Q = self.q_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        # Build the band mask directly on the input's device; the original
        # built it on CPU and copied it over on every call.
        window_mask = self._create_sliding_window_mask(
            L, self.window_size, device=hidden_states.device
        )

        # Scaled dot-product scores: [B, H, L, L].
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale

        # Use the dtype's most negative finite value instead of a
        # hard-coded -1e9, which overflows in float16.
        neg_inf = torch.finfo(attn_scores.dtype).min

        # Positions outside the window may not be attended to.
        attn_scores = attn_scores.masked_fill(~window_mask, neg_inf)

        # Optional padding mask, broadcast over heads and query positions.
        if attention_mask is not None:
            pad_mask = attention_mask[:, None, None, :]  # [B, 1, 1, L]
            attn_scores = attn_scores.masked_fill(pad_mask == 0, neg_inf)

        # Normalize and aggregate values.
        attn_weights = F.softmax(attn_scores, dim=-1)  # [B, H, L, L]
        output = torch.matmul(attn_weights, V)         # [B, H, L, d]

        # Merge heads back: [B, H, L, d] -> [B, L, D].
        output = output.transpose(1, 2).contiguous().view(B, L, D)
        return self.out_proj(output)

    def _create_sliding_window_mask(self, seq_len, window_size, device=None):
        """Return a ``[seq_len, seq_len]`` bool mask, True where ``|i - j| <= window_size``.

        Args:
            seq_len: sequence length L.
            window_size: one-sided window radius.
            device: optional device to allocate the mask on (new, defaults
                to CPU like the original).
        """
        positions = torch.arange(seq_len, device=device)
        # |i - j| distance matrix via broadcasting: [L, L].
        distance = (positions[:, None] - positions[None, :]).abs()
        return distance <= window_size
# Usage example
def sliding_window_example():
    """Run SlidingWindowAttention on a long random sequence and report
    shapes plus the theoretical attention-complexity comparison."""
    # Configuration for a long-sequence demo.
    batch_size, seq_len = 2, 2048
    embed_dim, num_heads, window_size = 768, 12, 512

    # Build the attention module.
    attn = SlidingWindowAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        window_size=window_size,
    )

    # Random input and a single forward pass.
    hidden_states = torch.randn(batch_size, seq_len, embed_dim)
    output = attn(hidden_states)
    print(f"输入形状:{hidden_states.shape}")
    print(f"输出形状:{output.shape}")

    # Theoretical score-computation cost: full attention touches L^2 pairs,
    # a windowed pattern only L * w.  NOTE: this speedup is what a truly
    # sparse/banded kernel would achieve; the masked-dense implementation
    # above still computes the full L x L score matrix.
    standard_complexity = seq_len ** 2
    sliding_complexity = seq_len * window_size
    speedup = standard_complexity / sliding_complexity
    print(f"\n标准注意力复杂度:{standard_complexity:,}")
    print(f"滑动窗口复杂度:{sliding_complexity:,}")
    print(f"加速比:{speedup:.1f}x")
# Run the demo only when executed as a script.
if __name__ == "__main__":
    sliding_window_example()