"""Complete Actor-Critic algorithm implementation."""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Dict, List
from dataclasses import dataclass
import gymnasium as gym
@dataclass
class ActorCriticConfig:
"""配置类"""
state_dim: int
action_dim: int
hidden_dim: int = 256
actor_lr: float = 3e-4
critic_lr: float = 1e-3
gamma: float = 0.99
max_grad_norm: float = 0.5
class ActorNetwork(nn.Module):
"""Actor 网络:输出动作概率分布"""
def __init__(self, config: ActorCriticConfig):
super().__init__()
self.network = nn.Sequential(
nn.Linear(config.state_dim, config.hidden_dim),
nn.ReLU(),
nn.Linear(config.hidden_dim, config.hidden_dim),
nn.ReLU(),
nn.Linear(config.hidden_dim, config.action_dim),
)
def forward(self, state: torch.Tensor) -> torch.Tensor:
"""输出动作 logits"""
return self.network(state)
def get_action(self, state: torch.Tensor,
deterministic: bool = False) -> Tuple[int, torch.Tensor]:
"""
采样动作
Returns:
action: 采样动作
log_prob: 动作的对数概率
"""
        logits = self.forward(state)
        dist = torch.distributions.Categorical(logits=logits)
        if deterministic:
            # Greedy: take the mode of the distribution
            action = torch.argmax(logits, dim=-1)
        else:
            action = dist.sample()
        # Log-probability of the chosen action; shape (batch,) in both branches
        log_prob = dist.log_prob(action)
        return action.item(), log_prob
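
# Quick shape sanity check (an illustrative sketch, not part of the algorithm;
# the config values below are arbitrary):
#   cfg = ActorCriticConfig(state_dim=4, action_dim=2)
#   actor = ActorNetwork(cfg)
#   action, log_prob = actor.get_action(torch.zeros(1, 4))
#   # action is a Python int in {0, 1}; log_prob is a tensor of shape (1,)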
class CriticNetwork(nn.Module):
"""Critic 网络:估计状态值 V(s)"""
def __init__(self, config: ActorCriticConfig):
super().__init__()
self.network = nn.Sequential(
nn.Linear(config.state_dim, config.hidden_dim),
nn.ReLU(),
nn.Linear(config.hidden_dim, config.hidden_dim),
nn.ReLU(),
nn.Linear(config.hidden_dim, 1)
)
def forward(self, state: torch.Tensor) -> torch.Tensor:
"""输出状态值 V(s)"""
return self.network(state)
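
# The critic maps a batch of states to a column of value estimates, e.g. with
# state_dim=4, CriticNetwork(cfg)(torch.zeros(8, 4)) returns a tensor of shape (8, 1).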
class ActorCriticAgent:
"""
Actor-Critic Agent
核心机制:
1. Actor 学习策略 π(a|s)
2. Critic 估计值函数 V(s)
3. 使用 TD 误差 δ = r + γV(s') - V(s) 作为优势估计
4. 策略梯度:∇J ≈ δ * ∇log π(a|s)
"""
def __init__(self, config: ActorCriticConfig):
self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Initialize networks
self.actor = ActorNetwork(config).to(self.device)
self.critic = CriticNetwork(config).to(self.device)
        # Separate optimizers for actor and critic
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=config.actor_lr)
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=config.critic_lr)
self.training_stats = {
'actor_loss': [],
'critic_loss': [],
'rewards': []
}
def select_action(self, state: np.ndarray,
deterministic: bool = False) -> int:
"""选择动作"""
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        action, _ = self.actor.get_action(state_tensor, deterministic)
        return action
def compute_td_target(self, reward: float,
next_state: np.ndarray,
done: bool) -> torch.Tensor:
"""
计算 TD 目标
TD Target: r + γ * V(s') * (1 - done)
"""
        with torch.no_grad():
            next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0).to(self.device)
            next_value = self.critic(next_state_tensor).item()
        # Bootstrap only from non-terminal next states
        td_target = reward + self.config.gamma * next_value * (1.0 - float(done))
        # Shape (1, 1) so it matches the critic output and avoids a broadcast warning
        return torch.tensor([[td_target]], dtype=torch.float32, device=self.device)
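    # Terminal case (illustrative): with reward = 1.0 and done = True the target is
    # just 1.0 -- the (1 - done) factor zeroes out the bootstrap term γV(s').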
def update(self, state: np.ndarray,
action: int,
reward: float,
next_state: np.ndarray,
done: bool) -> Dict[str, float]:
"""
单步更新
1. 计算 TD 误差:δ = r + γV(s') - V(s)
2. Critic 损失:L_critic = δ²
3. Actor 损失:L_actor = -δ * log π(a|s)
"""
        # Convert inputs to tensors
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
action_tensor = torch.LongTensor([action]).to(self.device)
        # 1. Current state value V(s)
current_value = self.critic(state_tensor)
        # 2. TD target
td_target = self.compute_td_target(reward, next_state, done)
        # 3. TD error, used as the advantage estimate
        td_error = td_target - current_value
        advantage = td_error.detach()  # stop gradients so the actor loss cannot update the critic
        # 4. Critic update
critic_loss = F.mse_loss(current_value, td_target)
self.critic_optimizer.zero_grad()
critic_loss.backward()
nn.utils.clip_grad_norm_(self.critic.parameters(), self.config.max_grad_norm)
self.critic_optimizer.step()
        # 5. Actor update -- use the log-probability of the action actually taken;
        # re-sampling a fresh action here would compute a gradient for the wrong transition
        logits = self.actor(state_tensor)
        dist = torch.distributions.Categorical(logits=logits)
        log_prob = dist.log_prob(action_tensor)
        actor_loss = -(log_prob * advantage).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
nn.utils.clip_grad_norm_(self.actor.parameters(), self.config.max_grad_norm)
self.actor_optimizer.step()
        # Record statistics
self.training_stats['actor_loss'].append(actor_loss.item())
self.training_stats['critic_loss'].append(critic_loss.item())
self.training_stats['rewards'].append(reward)
return {
'actor_loss': actor_loss.item(),
'critic_loss': critic_loss.item(),
'td_error': td_error.item()
}
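    # Note: this is the fully online, one-step (TD(0)) form of actor-critic. Common
    # variants accumulate n-step returns or GAE over a rollout before updating,
    # trading some freshness for lower-variance advantage estimates.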
def train(self,
env: gym.Env,
n_episodes: int = 1000,
max_steps: int = 1000,
verbose: bool = True) -> List[float]:
"""
训练循环
Returns:
episode_rewards: 每集的总奖励
"""
episode_rewards = []
for episode in range(n_episodes):
state, _ = env.reset()
episode_reward = 0
for step in range(max_steps):
                # Select an action from the current policy
action = self.select_action(state)
                # Step the environment
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
                # One-step actor-critic update
                self.update(state, action, reward, next_state, done)
state = next_state
episode_reward += reward
if done:
break
episode_rewards.append(episode_reward)
if verbose and (episode + 1) % 50 == 0:
avg_reward = np.mean(episode_rewards[-50:])
avg_actor_loss = np.mean(self.training_stats['actor_loss'][-50:])
avg_critic_loss = np.mean(self.training_stats['critic_loss'][-50:])
print(f"Episode {episode + 1}/{n_episodes} | "
f"Avg Reward: {avg_reward:.2f} | "
f"Actor Loss: {avg_actor_loss:.4f} | "
f"Critic Loss: {avg_critic_loss:.4f}")
return episode_rewards
# Usage example
if __name__ == "__main__":
    # Create the environment
env = gym.make('CartPole-v1')
    # Configuration
config = ActorCriticConfig(
state_dim=env.observation_space.shape[0],
action_dim=env.action_space.n,
hidden_dim=256,
actor_lr=3e-4,
critic_lr=1e-3,
gamma=0.99
)
    # Create the agent
agent = ActorCriticAgent(config)
    # Train
    print("Training the Actor-Critic agent...")
episode_rewards = agent.train(env, n_episodes=500, max_steps=500)
print(f"\n训练完成!")
print(f"最后 100 集平均奖励:{np.mean(episode_rewards[-100:]):.2f}")
print(f"最高奖励:{max(episode_rewards):.2f}")
    # Evaluate
    print("\nEvaluating the trained policy...")
state, _ = env.reset()
test_reward = 0
for _ in range(500):
action = agent.select_action(state, deterministic=True)
state, reward, terminated, truncated, _ = env.step(action)
test_reward += reward
if terminated or truncated:
break
print(f"测试奖励:{test_reward:.2f}")
print("\n关键观察:")
print("1. Actor-Critic 结合了策略梯度和值函数的优势")
print("2. Critic 提供低方差的 Advantage 估计 (TD 误差)")
print("3. 支持在线学习,无需经验回放")
print("4. 适用于离散和连续动作空间")
print("5. 是现代深度 RL 算法 (A3C, PPO, SAC) 的基础")