# Complete MAML (Model-Agnostic Meta-Learning) implementation
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
@dataclass
class MAMLConfig:
    """Hyper-parameter bundle for MAML training and evaluation."""
    input_dim: int
    output_dim: int
    hidden_dim: int = 256
    inner_lr: float = 0.01   # task-specific (inner-loop) learning rate
    outer_lr: float = 0.001  # meta (outer-loop) learning rate
    inner_steps: int = 5     # gradient updates per inner loop
    num_tasks: int = 4       # tasks per meta-batch
    k_shot: int = 5          # support examples per class
    q_query: int = 15        # query-set size
class MAMLNetwork(nn.Module):
    """Feed-forward base learner used by MAML: two ReLU hidden layers."""

    def __init__(self, config: MAMLConfig):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x)

    def functional_forward(self, x: torch.Tensor,
                           params: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Forward pass using externally supplied parameters.

        Used by the inner loop so that adapted parameters can be plugged in
        without mutating the module's own weights.

        Args:
            x: input batch
            params: parameter dict keyed like ``named_parameters()``
        Returns:
            network output (logits)
        """
        hidden = x
        # The two hidden Linear layers sit at Sequential indices 0 and 2,
        # each followed by a ReLU.
        for idx in (0, 2):
            hidden = F.relu(F.linear(hidden,
                                     params[f'network.{idx}.weight'],
                                     params[f'network.{idx}.bias']))
        # Output layer (index 4) has no activation.
        return F.linear(hidden, params['network.4.weight'],
                        params['network.4.bias'])
class MAML:
    """
    Model-Agnostic Meta-Learning (MAML).

    Core algorithm:
      1. Sample a batch of tasks.
      2. For each task:
         - compute the loss on the support set,
         - take gradient steps to obtain task-specific parameters,
         - evaluate those parameters on the query set.
      3. Aggregate the query-set losses over all tasks.
      4. Update the meta parameters with the aggregated loss.

    Key insight: learn an initialization that is "sensitive to gradients",
    i.e. from which a few gradient steps yield large improvements.
    """

    def __init__(self, config: MAMLConfig):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Meta parameters: the learned initialization.
        self.model = MAMLNetwork(config).to(self.device)
        self.meta_optimizer = optim.Adam(self.model.parameters(), lr=config.outer_lr)
        self.training_stats = {
            'meta_loss': [],
            'task_losses': [],
            'accuracies': []
        }

    def _get_model_params(self, model: nn.Module) -> Dict[str, torch.Tensor]:
        """Return a differentiable copy of the model's parameters.

        ``clone()`` keeps each copy connected to the original parameter in the
        autograd graph, so meta-gradients can flow back to the model.
        """
        return {name: param.clone() for name, param in model.named_parameters()}

    def _inner_loop(self, support_x: torch.Tensor,
                    support_y: torch.Tensor,
                    params: Dict[str, torch.Tensor],
                    create_graph: bool = True) -> Dict[str, torch.Tensor]:
        """
        Inner loop: task-specific adaptation via SGD on the support set.

        Args:
            support_x: support-set inputs, shape (K*N, D)
            support_y: support-set labels, shape (K*N,)
            params: initial parameters to adapt
            create_graph: keep the graph through the inner gradients so the
                outer loop can take second-order derivatives. Pass ``False``
                for a first-order (FOMAML) approximation.

        Returns:
            updated_params: parameters after ``config.inner_steps`` updates.
        """
        updated_params = params
        for _ in range(self.config.inner_steps):
            logits = self.model.functional_forward(support_x, updated_params)
            loss = F.cross_entropy(logits, support_y)
            # BUG FIX: the original passed create_graph only on step 0
            # (``True if step == 0 else False``). The graph must be retained
            # on EVERY inner step, otherwise meta-gradients through later
            # steps are cut and training silently degrades.
            grads = torch.autograd.grad(loss, list(updated_params.values()),
                                        create_graph=create_graph)
            # Plain SGD step on the task-specific parameters.
            updated_params = {
                name: param - self.config.inner_lr * grad
                for (name, param), grad in zip(updated_params.items(), grads)
            }
        return updated_params

    def _outer_loop(self, task_batch: List[Dict]) -> Tuple[torch.Tensor, float]:
        """
        Outer loop: compute the meta loss over a batch of tasks.

        Args:
            task_batch: each task is a dict with keys
                ``support_x``, ``support_y``, ``query_x``, ``query_y``.

        Returns:
            Tuple of (meta_loss tensor, mean per-task query loss). The
            original annotation claimed a single tensor, but two values were
            always returned; the annotation is now accurate.
        """
        meta_loss = 0.0
        task_losses = []
        for task in task_batch:
            support_x = task['support_x'].to(self.device)
            support_y = task['support_y'].to(self.device)
            query_x = task['query_x'].to(self.device)
            query_y = task['query_y'].to(self.device)
            # 1. Differentiable copy of the current meta parameters.
            base_params = self._get_model_params(self.model)
            # 2. Inner updates yield task-specific parameters.
            adapted_params = self._inner_loop(support_x, support_y, base_params)
            # 3. Evaluate the adapted parameters on the query set.
            query_logits = self.model.functional_forward(query_x, adapted_params)
            query_loss = F.cross_entropy(query_logits, query_y)
            meta_loss = meta_loss + query_loss
            task_losses.append(query_loss.item())
        # Average over the task batch.
        meta_loss = meta_loss / len(task_batch)
        return meta_loss, float(np.mean(task_losses))

    def meta_train(self,
                   task_sampler: Callable,
                   n_iterations: int = 1000,
                   verbose: bool = True) -> List[float]:
        """
        Meta-training loop.

        Args:
            task_sampler: callable ``(num_tasks, k_shot, q_query) -> list of
                task dicts`` (see :meth:`_outer_loop` for the dict layout)
            n_iterations: number of meta-updates
            verbose: print progress every 50 iterations

        Returns:
            meta_losses: meta loss per iteration.
        """
        meta_losses = []
        for iteration in range(n_iterations):
            # Sample a fresh batch of tasks.
            task_batch = task_sampler(self.config.num_tasks, self.config.k_shot,
                                      self.config.q_query)
            # Outer (meta) update.
            self.meta_optimizer.zero_grad()
            meta_loss, avg_task_loss = self._outer_loop(task_batch)
            meta_loss.backward()
            # Gradient clipping stabilizes the second-order updates.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.meta_optimizer.step()
            # Bookkeeping.
            self.training_stats['meta_loss'].append(meta_loss.item())
            self.training_stats['task_losses'].append(avg_task_loss)
            if verbose and (iteration + 1) % 50 == 0:
                avg_meta_loss = np.mean(self.training_stats['meta_loss'][-50:])
                print(f"Iteration {iteration + 1}/{n_iterations} | "
                      f"Meta Loss: {avg_meta_loss:.4f} | "
                      f"Task Loss: {avg_task_loss:.4f}")
            meta_losses.append(meta_loss.item())
        return meta_losses

    def adapt(self, support_x: torch.Tensor,
              support_y: torch.Tensor,
              n_steps: Optional[int] = None) -> nn.Module:
        """
        Adapt the meta-learned initialization to a new task.

        Args:
            support_x: support-set inputs
            support_y: support-set labels
            n_steps: number of gradient updates (defaults to
                ``config.inner_steps``)

        Returns:
            adapted_model: an independent copy of the model fine-tuned on the
            support set; the meta parameters are left untouched.
        """
        if n_steps is None:
            n_steps = self.config.inner_steps
        # Work on a deep copy so the meta parameters are preserved.
        # (The original also built an unused parameter-dict copy here;
        # that dead code has been removed.)
        adapted_model = deepcopy(self.model)
        support_x = support_x.to(self.device)
        support_y = support_y.to(self.device)
        for _ in range(n_steps):
            logits = adapted_model(support_x)
            loss = F.cross_entropy(logits, support_y)
            adapted_model.zero_grad()
            loss.backward()
            # Manual SGD step; guard against parameters with no gradient.
            with torch.no_grad():
                for param in adapted_model.parameters():
                    if param.grad is not None:
                        param -= self.config.inner_lr * param.grad
        return adapted_model

    def evaluate(self,
                 adapted_model: nn.Module,
                 query_x: torch.Tensor,
                 query_y: torch.Tensor) -> Dict[str, float]:
        """
        Evaluate an adapted model on a query set.

        Returns:
            metrics: dict with key ``'accuracy'``.
        """
        adapted_model.eval()
        query_x = query_x.to(self.device)
        query_y = query_y.to(self.device)
        with torch.no_grad():
            logits = adapted_model(query_x)
            preds = torch.argmax(logits, dim=1)
            accuracy = (preds == query_y).float().mean().item()
        return {'accuracy': accuracy}
# Usage example: few-shot classification
if __name__ == "__main__":
    # 5-way, 5-shot configuration.
    config = MAMLConfig(
        input_dim=64,   # e.g. a flattened 8x8 image
        output_dim=5,   # 5-way classification
        hidden_dim=128,
        inner_lr=0.01,
        outer_lr=0.001,
        inner_steps=5,
        num_tasks=4,
        k_shot=5,       # 5-shot
        q_query=15,
    )
    maml = MAML(config)

    def dummy_task_sampler(n_tasks, k_shot, q_query):
        """Synthetic task sampler; a real setup would use a dataset such as
        miniImageNet."""
        n_ways = config.output_dim

        def make_task():
            # One simulated 5-way, k-shot episode.
            return {
                'support_x': torch.randn(k_shot * n_ways, config.input_dim),
                'support_y': torch.arange(n_ways).repeat_interleave(k_shot),
                'query_x': torch.randn(q_query * n_ways, config.input_dim),
                'query_y': torch.arange(n_ways).repeat_interleave(q_query),
            }

        return [make_task() for _ in range(n_tasks)]

    print("开始 MAML meta-training...")
    meta_losses = maml.meta_train(dummy_task_sampler, n_iterations=200, verbose=True)
    print("\nMeta-training 完成!")
    print(f"最终 meta 损失:{meta_losses[-1]:.4f}")

    # Fast adaptation on a brand-new task.
    print("\n测试新任务适配...")
    new_task = dummy_task_sampler(1, config.k_shot, config.q_query)[0]
    adapted_model = maml.adapt(new_task['support_x'], new_task['support_y'])
    metrics = maml.evaluate(adapted_model, new_task['query_x'], new_task['query_y'])
    print(f"适配后准确率:{metrics['accuracy']:.2%}")

    print("\n关键观察:")
    print("1. MAML 学习'对梯度敏感'的初始化参数")
    print("2. 新任务上只需 1-5 步梯度更新即可达到高性能")
    print("3. 模型无关:适用于任何基于梯度的模型")
    print("4. 二阶优化:内层更新的梯度需要二阶导数")
    print("5. 少样本分类、回归、强化学习均可应用")