# Agent model lightweighting: complete distillation and quantization implementation
import time
import json
import hashlib
import secrets
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import numpy as np
from collections import deque, defaultdict
import statistics
import threading
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import copy
class DistillationType(Enum):
    """Supported knowledge-distillation strategies."""
    LOGITS = "logits"    # logits distillation (soft targets from teacher outputs)
    FEATURE = "feature"  # feature distillation (match intermediate activations)
    SELF = "self"        # self-distillation
    ONLINE = "online"    # online distillation
class QuantizationMethod(Enum):
    """Supported quantization methods."""
    PTQ = "ptq"    # post-training quantization
    QAT = "qat"    # quantization-aware training
    AWQ = "awq"    # AWQ quantization
    GPTQ = "gptq"  # GPTQ quantization
class PruningMethod(Enum):
    """Supported pruning methods."""
    MAGNITUDE = "magnitude"    # magnitude-based pruning
    STRUCTURED = "structured"  # structured pruning
    MOVEMENT = "movement"      # movement pruning
@dataclass
class DistillationConfig:
    """Configuration for a knowledge-distillation run."""
    teacher_model: Any               # frozen model that supplies the targets
    student_model: Any               # smaller model being trained
    distillation_type: DistillationType
    temperature: float               # softmax temperature used to soften logits
    alpha: float                     # distillation-loss weight
    beta: float                      # hard-label (cross-entropy) loss weight
    feature_layers: List[str]        # dotted layer names hooked for feature distillation
    distill_epochs: int              # number of distillation epochs
@dataclass
class QuantizationConfig:
    """Configuration for model quantization."""
    method: QuantizationMethod
    weight_bits: int          # weight precision: 4, 8
    activation_bits: int      # activation precision: 8, 16
    per_channel: bool         # per-output-channel (True) vs per-tensor (False) weight scales
    symmetric: bool           # symmetric (zero-point fixed at 0) vs asymmetric
    calibration_samples: int  # number of samples for PTQ calibration
    quantize_embedding: bool  # whether to quantize embeddings (not consumed by the code shown here)
@dataclass
class PruningConfig:
    """Configuration for weight pruning."""
    method: PruningMethod
    sparsity: float          # target sparsity, 0.0-1.0
    structured: bool         # structured vs unstructured pruning flag
    prune_layers: List[str]  # names of layers to prune
    fine_tune_epochs: int    # fine-tuning epochs after pruning
@dataclass
class CompressionMetrics:
    """Before/after metrics describing one compression run."""
    original_size: float        # model size before compression, MB
    compressed_size: float      # model size after compression, MB
    compression_ratio: float    # original_size / compressed_size
    original_params: int        # parameter count before compression
    pruned_params: int          # parameter count after pruning
    sparsity: float             # fraction of parameters removed
    original_accuracy: float    # accuracy before compression
    compressed_accuracy: float  # accuracy after compression
    accuracy_drop: float        # original_accuracy - compressed_accuracy
    inference_speedup: float    # inference speed multiplier vs original
    memory_reduction: float     # memory saved, percent
class KnowledgeDistiller:
    """Knowledge distiller.

    Supports:
    1. Logits distillation
    2. Feature distillation
    3. Self-distillation
    4. Online distillation
    """

    def __init__(self, config: "DistillationConfig"):
        """Store the config and freeze the teacher model."""
        self.config = config
        self.teacher = config.teacher_model
        self.student = config.student_model
        self.temperature = config.temperature
        self.alpha = config.alpha  # distillation-loss weight
        self.beta = config.beta    # hard-label-loss weight
        # Freeze the teacher: it only supplies targets and is never trained.
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()

    def distill_logits(self, student_logits, teacher_logits, hard_labels):
        """Logits distillation loss.

        Returns (total_loss, distill_loss, hard_loss) where
        total = alpha * T^2 * KL(softened student || softened teacher)
              + beta * CE(student_logits, hard_labels).
        """
        # Soften both distributions with the temperature.
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        # The T^2 factor keeps gradient magnitudes comparable across temperatures.
        distill_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (self.temperature ** 2)
        # Standard cross-entropy against ground-truth labels.
        hard_loss = F.cross_entropy(student_logits, hard_labels)
        total_loss = self.alpha * distill_loss + self.beta * hard_loss
        return total_loss, distill_loss, hard_loss

    def distill_features(self, student_features, teacher_features):
        """Feature-alignment (MSE) loss averaged over the hooked layers.

        Raises ValueError when the two lists differ in length.
        NOTE(review): assumes each student/teacher feature pair has a
        broadcast-compatible shape — a projection head is normally needed
        when widths differ; confirm against the configured feature_layers.
        """
        if len(student_features) != len(teacher_features):
            raise ValueError("Feature layers mismatch")
        if not student_features:
            # No hooked layers -> no feature loss (avoids division by zero).
            return torch.tensor(0.0)
        feature_loss = 0.0
        for sf, tf in zip(student_features, teacher_features):
            feature_loss += F.mse_loss(sf, tf)
        feature_loss /= len(student_features)
        return feature_loss

    def train_epoch(self, dataloader, optimizer, device):
        """Run one distillation epoch; returns averaged loss metrics.

        Fix vs original: the float accumulator was shadowed by the per-batch
        loss tensor (`total_loss += total_loss.item()`), so the reported
        average was just twice the last batch's loss.
        """
        self.student.train()
        running_loss = 0.0
        distill_losses = []
        hard_losses = []
        for data, labels in dataloader:
            data, labels = data.to(device), labels.to(device)
            optimizer.zero_grad()
            # Teacher forward pass never needs gradients.
            with torch.no_grad():
                teacher_logits = self.teacher(data)
            student_logits = self.student(data)
            if self.config.distillation_type == DistillationType.LOGITS:
                loss, distill_loss, hard_loss = self.distill_logits(
                    student_logits, teacher_logits, labels
                )
            elif self.config.distillation_type == DistillationType.FEATURE:
                # Only extract features when they are actually used.
                teacher_features = self._extract_features(self.teacher, data)
                # Student features must keep the autograd graph, otherwise the
                # feature loss cannot train the student (original ran the
                # student under no_grad here).
                student_features = self._extract_features(
                    self.student, data, enable_grad=True
                )
                feature_loss = self.distill_features(student_features, teacher_features)
                hard_loss = F.cross_entropy(student_logits, labels)
                loss = self.alpha * feature_loss + self.beta * hard_loss
                distill_loss = feature_loss
            else:
                # SELF / ONLINE fall back to plain cross-entropy training.
                loss = F.cross_entropy(student_logits, labels)
                distill_loss = torch.tensor(0.0)
                hard_loss = loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            distill_losses.append(distill_loss.item())
            hard_losses.append(hard_loss.item())
        return {
            'avg_loss': running_loss / len(dataloader),
            'avg_distill_loss': statistics.mean(distill_losses),
            'avg_hard_loss': statistics.mean(hard_losses)
        }

    def _extract_features(self, model, data, enable_grad=False):
        """Collect the outputs of config.feature_layers via forward hooks.

        enable_grad=True keeps the autograd graph on the collected features
        (needed for the student during feature distillation).
        """
        features = []
        hook_handles = []

        def create_hook(layer_name):
            def hook(module, inputs, output):
                features.append(output)
            return hook

        for layer_name in self.config.feature_layers:
            layer = self._get_layer(model, layer_name)
            if layer:
                hook_handles.append(layer.register_forward_hook(create_hook(layer_name)))
        try:
            with torch.set_grad_enabled(enable_grad):
                model(data)
        finally:
            # Always detach hooks, even if the forward pass raises.
            for handle in hook_handles:
                handle.remove()
        return features

    def _get_layer(self, model, layer_name):
        """Resolve a dotted layer path (e.g. 'encoder.fc1'); None if missing."""
        module = model
        for part in layer_name.split('.'):
            if not hasattr(module, part):
                return None
            module = getattr(module, part)
        return module

    def distill(self, train_loader, val_loader, optimizer, scheduler, device, epochs):
        """Full distillation loop; saves the best student and returns history."""
        print(f"开始知识蒸馏,共 {epochs} 个 epoch...")
        best_accuracy = 0.0
        history = []
        for epoch in range(epochs):
            train_metrics = self.train_epoch(train_loader, optimizer, device)
            val_accuracy = self.evaluate(val_loader, device)
            metrics = {
                'epoch': epoch + 1,
                **train_metrics,
                'val_accuracy': val_accuracy
            }
            history.append(metrics)
            print(f"Epoch {epoch+1}/{epochs} - "
                  f"Loss: {metrics['avg_loss']:.4f}, "
                  f"Distill: {metrics['avg_distill_loss']:.4f}, "
                  f"Hard: {metrics['avg_hard_loss']:.4f}, "
                  f"Val Acc: {val_accuracy:.4f}")
            # Checkpoint the best student by validation accuracy.
            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                torch.save(self.student.state_dict(), 'best_student.pth')
            if scheduler:
                scheduler.step()
        print(f"蒸馏完成,最佳验证准确率:{best_accuracy:.4f}")
        return history

    def evaluate(self, dataloader, device):
        """Return the student's top-1 accuracy over dataloader (0.0 if empty)."""
        self.student.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, labels in dataloader:
                data, labels = data.to(device), labels.to(device)
                outputs = self.student(data)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return correct / total if total else 0.0
class ModelQuantizer:
    """Model quantizer.

    Supports:
    1. PTQ post-training quantization
    2. QAT quantization-aware training
    3. INT8/INT4 fake quantization (quantize -> dequantize; weights stay float)
    """

    def __init__(self, config: "QuantizationConfig"):
        self.config = config
        self.quantized_model = None
        self.calibration_data = None
        self.scale_factors = {}  # per-layer activation scale, filled by calibrate()
        self.zero_points = {}    # per-layer activation zero point, filled by calibrate()

    def prepare_quantization(self, model):
        """Prepare a model for quantization.

        QAT attaches fake-quant modules to the model in place; PTQ works on a
        deep copy so the original stays untouched.
        """
        if self.config.method == QuantizationMethod.QAT:
            self.quantized_model = self._insert_fake_quant(model)
        else:
            self.quantized_model = copy.deepcopy(model)
        return self.quantized_model

    def _insert_fake_quant(self, model):
        """Attach a FakeQuantize module to every nn.Linear (QAT)."""

        class FakeQuantize(nn.Module):
            """Simulated quantization with learnable scale/zero-point."""

            def __init__(self, bits, symmetric):
                super().__init__()
                self.bits = bits
                self.symmetric = symmetric
                self.scale = nn.Parameter(torch.tensor(1.0))
                self.zero_point = nn.Parameter(torch.tensor(0.0))

            def forward(self, x):
                if self.symmetric:
                    qmin = -(2 ** (self.bits - 1))
                    qmax = 2 ** (self.bits - 1) - 1
                else:
                    qmin = 0
                    qmax = 2 ** self.bits - 1
                # Quantize with the zero point, clamp to the representable
                # range, then dequantize. (The original ignored zero_point,
                # making the asymmetric path wrong; with the initial
                # zero_point of 0 the symmetric behaviour is unchanged.)
                x_quant = torch.round(x / self.scale + self.zero_point)
                x_clamped = torch.clamp(x_quant, qmin, qmax)
                return (x_clamped - self.zero_point) * self.scale

        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                # Registered as a submodule so its parameters are trainable.
                module.weight_fake_quant = FakeQuantize(
                    self.config.weight_bits,
                    self.config.symmetric
                )
        return model

    def calibrate(self, model, calibration_loader, device):
        """PTQ calibration: observe activation ranges, derive scale/zero-point.

        Returns (scale_factors, zero_points) keyed by module name.
        """
        print(f"开始校准,使用 {len(calibration_loader)} 个批次...")
        model.eval()
        activation_stats = defaultdict(list)

        def create_hook(name):
            def hook(module, inputs, output):
                activation_stats[name].append(output.detach().cpu())
            return hook

        # Hook every Linear/Conv2d to record its outputs.
        hooks = []
        for name, module in model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                hooks.append(module.register_forward_hook(create_hook(name)))
        try:
            with torch.no_grad():
                for data, _ in calibration_loader:
                    model(data.to(device))
        finally:
            # Always remove hooks, even if a forward pass raises.
            for handle in hooks:
                handle.remove()
        # Turn observed ranges into quantization parameters.
        for name, activations in activation_stats.items():
            all_activations = torch.cat(activations, dim=0)
            if self.config.symmetric:
                max_abs = torch.max(torch.abs(all_activations))
                scale = max_abs / (2 ** (self.config.activation_bits - 1) - 1)
                zero_point = torch.tensor(0.0)
            else:
                min_val = torch.min(all_activations)
                max_val = torch.max(all_activations)
                scale = (max_val - min_val) / (2 ** self.config.activation_bits - 1)
                # Guard against a degenerate constant activation (scale == 0).
                zero_point = torch.round(-min_val / torch.clamp(scale, min=1e-12))
            self.scale_factors[name] = scale
            self.zero_points[name] = zero_point
        print(f"校准完成,收集了 {len(activation_stats)} 层的统计信息")
        return self.scale_factors, self.zero_points

    def quantize_weights(self, model):
        """Fake-quantize every >=2-D 'weight' parameter in place.

        Weights are rounded to the signed integer grid implied by
        config.weight_bits and immediately dequantized, so dtypes/sizes are
        unchanged. Fixes vs original: torch.max() does not accept a tuple of
        dims (per-channel conv quantization always crashed) -> torch.amax;
        the replacement state dict is seeded from model.state_dict() so
        buffers survive the strict load; zero scales are clamped.
        """
        print(f"开始权重量化,方法:{self.config.method.value}")
        qmax = 2 ** (self.config.weight_bits - 1) - 1
        qmin = -(2 ** (self.config.weight_bits - 1))
        # Start from the full state dict so non-parameter entries (buffers)
        # are preserved for load_state_dict.
        quantized_state_dict = model.state_dict()
        for name, param in model.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                weight = param.detach().cpu()
                if self.config.per_channel:
                    # One scale per output channel: reduce over every
                    # non-channel dimension (works for 2-D and conv weights).
                    reduce_dims = tuple(range(1, weight.dim()))
                    scale = torch.amax(torch.abs(weight), dim=reduce_dims, keepdim=True) / qmax
                else:
                    # One scale for the whole tensor.
                    scale = torch.max(torch.abs(weight)) / qmax
                # Avoid division by zero for all-zero weights/channels.
                scale = torch.clamp(scale, min=1e-12)
                weight_quant = torch.clamp(torch.round(weight / scale), qmin, qmax)
                quantized_state_dict[name] = weight_quant * scale
            else:
                quantized_state_dict[name] = param.detach().cpu()
        model.load_state_dict(quantized_state_dict)
        print(f"权重量化完成,精度:{self.config.weight_bits}bit")
        return model

    def get_compression_metrics(self, original_model, quantized_model):
        """Return size/parameter comparison between two models as a dict."""
        original_size = self._calculate_model_size(original_model)
        compressed_size = self._calculate_model_size(quantized_model)
        original_params = sum(p.numel() for p in original_model.parameters())
        pruned_params = sum(p.numel() for p in quantized_model.parameters())
        compression_ratio = original_size / compressed_size if compressed_size > 0 else 0
        sparsity = 1.0 - (pruned_params / original_params) if original_params > 0 else 0
        memory_reduction = (original_size - compressed_size) / original_size * 100
        return {
            'original_size_mb': original_size,
            'compressed_size_mb': compressed_size,
            'compression_ratio': compression_ratio,
            'original_params': original_params,
            'quantized_params': pruned_params,
            'sparsity': sparsity,
            'memory_reduction_percent': memory_reduction
        }

    def _calculate_model_size(self, model):
        """Return the total parameter storage of `model` in MB."""
        total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
        return total_bytes / (1024 ** 2)
# Usage example / smoke test (runs only when executed as a script).
if __name__ == "__main__":
    print("=== Agent 模型轻量化与蒸馏量化 ===\n")
    print("=== 创建蒸馏配置 ===")

    # Small example teacher/student MLPs (784 -> 10, MNIST-sized input).
    class TeacherModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(784, 512)
            self.fc2 = nn.Linear(512, 256)
            self.fc3 = nn.Linear(256, 10)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            return self.fc3(x)

    class StudentModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(784, 128)
            self.fc2 = nn.Linear(128, 64)
            self.fc3 = nn.Linear(64, 10)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            return self.fc3(x)

    teacher = TeacherModel()
    student = StudentModel()

    # Distillation configuration
    distill_config = DistillationConfig(
        teacher_model=teacher,
        student_model=student,
        distillation_type=DistillationType.LOGITS,
        temperature=4.0,
        alpha=0.7,
        beta=0.3,
        feature_layers=['fc2'],
        distill_epochs=10
    )
    print(f"教师模型参数量:{sum(p.numel() for p in teacher.parameters()):,}")
    print(f"学生模型参数量:{sum(p.numel() for p in student.parameters()):,}")
    print(f"蒸馏类型:{distill_config.distillation_type.value}")
    print(f"温度:{distill_config.temperature}")
    print(f"蒸馏损失权重:alpha={distill_config.alpha}, beta={distill_config.beta}")
    print()
    print("=== 创建量化配置 ===")

    # Quantization configuration
    quant_config = QuantizationConfig(
        method=QuantizationMethod.PTQ,
        weight_bits=8,
        activation_bits=8,
        per_channel=True,
        symmetric=True,
        calibration_samples=100,
        quantize_embedding=True
    )
    quantizer = ModelQuantizer(quant_config)
    print(f"量化方法:{quant_config.method.value}")
    print(f"权重量化精度:{quant_config.weight_bits}bit")
    print(f"激活量化精度:{quant_config.activation_bits}bit")
    print(f"逐通道量化:{quant_config.per_channel}")
    print(f"对称量化:{quant_config.symmetric}")
    print()
    print("=== 测试权重量化 ===")

    # Quantization smoke test
    original_model = TeacherModel()
    quantizer.prepare_quantization(original_model)
    # Synthetic calibration data (one forward pass only; no hooks registered)
    dummy_data = torch.randn(32, 784)
    with torch.no_grad():
        _ = original_model(dummy_data)
    # Quantize the weights in place
    quantized_model = quantizer.quantize_weights(original_model)
    # Compression metrics — note dequantized weights stay fp32, so the sizes
    # match and the reported ratio is 1.0 for this fake-quantization demo.
    metrics = quantizer.get_compression_metrics(TeacherModel(), quantized_model)
    print(f"\n压缩指标:")
    print(f" 原始大小:{metrics['original_size_mb']:.2f} MB")
    print(f" 量化后大小:{metrics['compressed_size_mb']:.2f} MB")
    print(f" 压缩比:{metrics['compression_ratio']:.2f}x")
    print(f" 内存减少:{metrics['memory_reduction_percent']:.1f}%")
    print(f" 原始参数:{metrics['original_params']:,}")
    print(f" 量化参数:{metrics['quantized_params']:,}")
    print(f" 稀疏度:{metrics['sparsity']:.2%}")
    print(f"\n关键观察:")
    print("1. 模型轻量化:架构优化、参数剪枝、低秩分解")
    print("2. 知识蒸馏:Logits 蒸馏、特征蒸馏、自蒸馏")
    print("3. 模型量化:INT8/INT4、PTQ/QAT、AWQ/GPTQ")
    print("4. 模型压缩:权重压缩、端侧部署、推理优化")
    print("5. 精简高效:轻量化 + 蒸馏 + 量化 + 压缩 = 可信赖")
    print("\n精简高效的使命:让 AI 模型更小、更快、更强")