多智能体强化学习与图注意力网络:UAV 集群冲突消解的端到端方案
在上一篇文章中,我们梳理了 UAV 冲突消解的算法全景图。其中强化学习(尤其是 MARL)被标记为 50+ 架无人机集群的"最现实选择"。本文将聚焦这一路线,从单智能体 RL 的基础出发,深入多智能体场景的核心挑战,解析 MADDPG、QMIX、COMA、MAPPO 等主流算法,并重点探讨 GAT(Graph Attention Network)如何为 MARL 提供可扩展的拓扑感知能力,最终实现端到端的冲突消解策略。
1. 从单智能体到多智能体:为什么 MARL 那么难?
1.1 单智能体 RL 回顾
先从熟悉的单智能体 RL 开始。单智能体 MDP 由四元组 (S, A, P, R) 定义:状态空间 S、动作空间 A、转移概率 P(s'|s,a)、奖励函数 R(s,a),再配一个折扣因子 γ ∈ [0,1)。
- 状态价值函数:V^π(s) = E_π[ Σₜ γᵗ rₜ | s₀ = s ]
- 动作价值函数:Q^π(s,a) = E_π[ Σₜ γᵗ rₜ | s₀ = s, a₀ = a ]
- 最优策略:π* = argmax_π V^π(s),对所有状态 s 同时成立
单智能体 RL 的核心假设:环境是平稳的——无论你训练多少个 episode,环境动态 P(s'|s,a) 始终不变。
1.2 多智能体的三大本质困难
多智能体场景打破了这个假设,带来了三个根本性困难:
① 环境非平稳性(Non-Stationarity)
当智能体 i 学习时,其他智能体的策略也在同步更新——从 i 的视角看,"环境"(包含其他所有智能体)的动态在不断漂移。
在单智能体 RL 中,给定当前状态和动作,下一状态的分布是固定的。但在多智能体场景中,同一状态-动作对可能对应完全不同的下一状态分布——因为其他智能体可能在不同 episode 中采取不同动作。
这直接导致 Experience Replay Buffer 失效:存储的经验来自"过时"的对手策略,继续用它们训练可能引发训练不稳定甚至策略崩溃。
② 信用分配问题(Credit Assignment)
当多个智能体共享同一个团队奖励时,难以分辨每个智能体对结果的真实贡献。
例如:多架 UAV 协同避开了一个障碍,每个智能体各自贡献了多少?如果只奖励某几架,会导致其他智能体停止学习。
③ 联合动作空间指数爆炸
联合动作空间大小为 |A|ⁿ,随智能体数量 n 指数增长:10 架 UAV、每架 5 个离散动作,联合动作组合就有 5¹⁰ ≈ 977 万种,任何集中式搜索或联合 Q 表都会迅速失控。
1.3 MARL 算法分类
针对以上困难,学界发展出三条主要路线:
| 路线 | 代表算法 | 核心思想 | 代表性论文 |
|---|---|---|---|
| 独立学习(IL) | IQL(Independent Q-Learning) | 各学各的,忽略他人影响 | Tan, 1993 |
| 中心化训练+去中心化执行(CTDE) | MADDPG, COMA, MAPPO | 训练时用全局信息,执行时用本地观测 | Lowe et al., 2017; Foerster et al., 2018 |
| 值分解(CTDE 的一个子类) | VDN, QMIX | 将联合 Q 单调分解为各智能体的局部 Q | Sunehag et al., 2018; Rashid et al., 2018 |
CTDE 是当前 UAV 冲突消解的主流范式:既能利用训练时的全局信息提升学习效率,又能在执行时保持通信受限下的实时决策能力。
2. CTDE 框架:训练用上帝视角,执行用本地观测
2.1 中心化 Critic 的设计哲学
CTDE 的核心洞察是:训练阶段和执行阶段可以有不同的信息可用性。
┌─────────────────────────────────────────────────────────┐
│ 中心化训练(Centralized Training) │
│ Critic(s₁,...,sₙ, a₁,...,aₙ) → Q(s,a) │
│ ✅ 可访问全局状态 & 所有智能体的动作 │
│ ✅ 环境是"平稳的"(给定全局状态-动作对) │
│ │
│ 去中心化执行(Decentralized Execution) │
│ πᵢ(oᵢ → aᵢ) │
│ ✅ 只依赖本地观测 oᵢ │
│ ✅ 通信失败时仍可运行 │
└─────────────────────────────────────────────────────────┘
2.2 MADDPG:连续动作空间的 CTDE 开创者
MADDPG(Multi-Agent DDPG) 由 Lowe 等人于 2017 年提出(NeurIPS 2017),是连续动作空间多智能体深度强化学习的里程碑。
核心公式:每个智能体 i 的确定性策略梯度为
∇θᵢ J(θᵢ) = E[ ∇θᵢ πᵢ(aᵢ|oᵢ) · ∇aᵢ Qᵢ(s, a₁, …, aₙ) ],其中梯度在 aᵢ = πᵢ(oᵢ) 处取值。
每个智能体拥有自己的中心化 Critic Qᵢ(s, a₁, …, aₙ),训练时可见全局状态与所有智能体的动作;Actor πᵢ 只依赖本地观测 oᵢ。
关键区别:与单智能体 DDPG 相比,Critic 的输入额外包含其他智能体的动作——给定联合动作后,环境对 Critic 而言重新变得平稳,这正是 CTDE 化解非平稳性的方式。
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
# ============================================================
# MADDPG 核心实现(用于 UAV 冲突消解场景)
# ============================================================
class ReplayBuffer:
"""共享经验回放池(所有智能体的经验统一存储)"""
def __init__(self, capacity=100000):
self.buffer = deque(maxlen=capacity)
def push(self, state, actions, reward, next_state, done):
self.buffer.append((state, actions, reward, next_state, done))
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
return (np.array(states), np.array(actions),
np.array(rewards), np.array(next_states), np.array(dones))
def __len__(self):
return len(self.buffer)
class Actor(nn.Module):
"""演员网络:本地观测 → 动作(去中心化执行)"""
def __init__(self, obs_dim, action_dim, hidden_dim=64):
super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh()  # 连续动作输出(速度变化量),限幅到 [-1, 1]
        )
def forward(self, obs):
return self.net(obs)
class Critic(nn.Module):
"""评论家网络:全局状态 + 联合动作 → Q值(中心化训练)"""
def __init__(self, total_obs_dim, total_action_dim, n_agents, hidden_dim=64):
super().__init__()
        # 输入:全局状态 + 所有智能体动作的拼接
        # 注意 total_action_dim 已经是 n_agents * action_dim,不能再乘 n_agents
        input_dim = total_obs_dim + total_action_dim
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1) # 输出单个 Q 值
)
def forward(self, states, all_actions):
"""
states: (batch, total_obs_dim) 全局状态
all_actions: (batch, n_agents * action_dim) 所有智能体的动作
"""
x = torch.cat([states, all_actions], dim=1)
return self.net(x)
class MADDPGAgent:
"""MADDPG 智能体"""
def __init__(self, obs_dim, action_dim, n_agents, agent_id,
lr_actor=1e-3, lr_critic=1e-3, gamma=0.95, tau=0.01):
        self.agent_id = agent_id
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.n_agents = n_agents
        self.gamma = gamma
        self.tau = tau
# 演员网络(本地策略)
self.actor = Actor(obs_dim, action_dim)
self.actor_target = Actor(obs_dim, action_dim)
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr_actor)
# 评论家网络(全局 Q)
total_obs = obs_dim * n_agents
total_act = action_dim * n_agents
self.critic = Critic(total_obs, total_act, n_agents)
self.critic_target = Critic(total_obs, total_act, n_agents)
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr_critic)
# 目标网络初始化
self.hard_update(self.actor_target, self.actor)
self.hard_update(self.critic_target, self.critic)
def hard_update(self, target, source):
"""硬更新(一次性复制)"""
target.load_state_dict(source.state_dict())
def soft_update(self, target, source):
"""软更新(指数滑动平均)"""
for tp, sp in zip(target.parameters(), source.parameters()):
tp.data.copy_(self.tau * sp.data + (1 - self.tau) * tp.data)
    def select_action(self, obs, noise=0.1):
        """选择动作(探索时加高斯噪声)"""
        obs_t = torch.FloatTensor(obs).unsqueeze(0)
        with torch.no_grad():
            action = self.actor(obs_t).squeeze(0).numpy()
        action += noise * np.random.randn(self.action_dim)
        return np.clip(action, -1, 1)
def update(self, agents, replay_buffer, batch):
"""单步更新"""
states, all_actions, rewards, next_states, dones = batch
# ----- Critic 更新 -----
# 目标动作用目标演员网络生成
        next_actions = []
        for agent_id, agent in enumerate(agents):
            d = agent.obs_dim  # 按各智能体的观测维度切片(此处假设所有智能体同维)
            next_obs = torch.FloatTensor(next_states[:, agent_id * d:(agent_id + 1) * d])
            next_actions.append(agent.actor_target(next_obs))
next_actions_cat = torch.cat(next_actions, dim=1)
        # 目标 Q 值:y = r + γ (1 − done) Q_target(s', a')
        with torch.no_grad():
            next_Q = self.critic_target(
                torch.FloatTensor(next_states),
                next_actions_cat
            )
            # 此处假设共享标量团队奖励;若存的是个体奖励,取 rewards[:, self.agent_id]
            r = torch.FloatTensor(rewards).view(-1, 1)
            d = torch.FloatTensor(dones.astype(np.float32)).view(-1, 1)
            target_Q = r + self.gamma * (1 - d) * next_Q
        expected_Q = self.critic(
            torch.FloatTensor(states),
            torch.FloatTensor(all_actions)
        )
        critic_loss = nn.MSELoss()(expected_Q, target_Q)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
        # ----- Actor 更新 -----
        # 当前智能体的动作由当前策略重算,其他智能体沿用 replay buffer 中的动作
        current_obs = torch.FloatTensor(
            states[:, self.agent_id * self.obs_dim:(self.agent_id + 1) * self.obs_dim])
        current_action = self.actor(current_obs)
# 构造完整的动作向量(当前智能体用当前策略,其他用历史动作)
actions_fixed = torch.FloatTensor(all_actions).clone()
actions_fixed[:, self.agent_id*self.action_dim:(self.agent_id+1)*self.action_dim] = current_action
actor_loss = -self.critic(
torch.FloatTensor(states),
actions_fixed
).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
# ----- 目标网络软更新 -----
self.soft_update(self.actor_target, self.actor)
self.soft_update(self.critic_target, self.critic)
return actor_loss.item(), critic_loss.item()
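把上面的组件串起来,单步交互 + 采样更新的最小用法大致如下(示意代码:环境交互用随机数占位,4 架 UAV、观测维 4、动作维 2 均为演示假设):
# 最小用法示意(非完整训练脚本;数值输入用随机数占位)
n_agents, obs_dim, action_dim = 4, 4, 2
agents = [MADDPGAgent(obs_dim, action_dim, n_agents, agent_id=i)
          for i in range(n_agents)]
buffer = ReplayBuffer(capacity=100000)

obs = np.random.randn(n_agents, obs_dim).astype(np.float32)       # 占位:来自环境
acts = np.concatenate([agents[i].select_action(obs[i]) for i in range(n_agents)])
next_obs = np.random.randn(n_agents, obs_dim).astype(np.float32)  # 占位:env.step 的返回
buffer.push(obs.flatten(), acts, 0.0, next_obs.flatten(), False)  # 状态/动作均拼接后存储

if len(buffer) >= 256:
    batch = buffer.sample(256)
    for agent in agents:
        agent.update(agents, buffer, batch)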
2.3 QMIX:值分解解决信用分配
MADDPG 解决了连续动作空间的问题,但 Critic 需要全局状态和联合动作作为输入,规模随智能体数量增长;而且单个标量 Q 值并没有直接回答"每个智能体各贡献了多少"。
QMIX(Rashid et al., ICML 2018)的核心创新是:把联合 Q 值表示成各智能体局部 Q 值的单调混合。
Q_tot(τ, u) = f_mix(Q₁(τ₁, u₁), …, Qₙ(τₙ, uₙ); s),其中混合函数 f_mix 以全局状态 s 为条件,并满足单调性约束 ∂Q_tot/∂Qᵢ ≥ 0(由非负权值的 hyper-network 实现)。
单调性约束保证了一个关键性质(IGM,Individual-Global-Max):去中心化执行时,各智能体对自身 Qᵢ 独立做贪心 argmax,得到的联合动作与对 Q_tot 做全局 argmax 一致。
class QMIXMixingNetwork(nn.Module):
"""
单调混合网络:将各智能体的 Q_i 混合为全局 Q_tot
关键约束:所有权值非负(保证单调性)
"""
    def __init__(self, n_agents, state_dim, embed_dim=32):
        super().__init__()
        self.n_agents = n_agents
        self.embed_dim = embed_dim
        # Hyper-network:由全局状态 s 生成混合网络的权值(而非由 Q 值生成)
        self.hyper_w1 = nn.Linear(state_dim, n_agents * embed_dim)
        self.hyper_b1 = nn.Linear(state_dim, embed_dim)
        self.hyper_w2 = nn.Linear(state_dim, embed_dim)
        # 第二层偏置即 QMIX 论文中的 V(s),用一个小 MLP 生成
        self.hyper_b2 = nn.Sequential(
            nn.Linear(state_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, 1)
        )
    def forward(self, q_values, state):
        """
        q_values: (batch, n_agents) 各智能体的 Q 值
        state:    (batch, state_dim) 全局状态(hyper-network 的输入)
        """
        batch_size = q_values.size(0)
        # 第一层:对权值取绝对值,保证 ∂Q_tot/∂Q_i ≥ 0(单调性)
        w1 = torch.abs(self.hyper_w1(state)).view(batch_size, self.n_agents, self.embed_dim)
        b1 = self.hyper_b1(state).unsqueeze(1)                          # (batch, 1, embed_dim)
        hidden = torch.relu(torch.bmm(q_values.unsqueeze(1), w1) + b1)  # (batch, 1, embed_dim)
        # 第二层:权值同样非负
        w2 = torch.abs(self.hyper_w2(state)).view(batch_size, self.embed_dim, 1)
        b2 = self.hyper_b2(state).unsqueeze(1)                          # (batch, 1, 1)
        q_tot = torch.bmm(hidden, w2) + b2                              # (batch, 1, 1)
        return q_tot.view(batch_size, 1)                                # (batch, 1)
class QMIXAgent:
"""QMIX 算法"""
def __init__(self, obs_dim, action_dim, n_agents, agent_id):
self.agent_id = agent_id
self.action_dim = action_dim
# 每个智能体的 RNN(处理动作-观测历史)
self.rnn = nn.GRUCell(obs_dim + action_dim, obs_dim)
# Q 网络
self.q_net = nn.Sequential(
nn.Linear(obs_dim, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, action_dim)
)
self.target_rnn = nn.GRUCell(obs_dim + action_dim, obs_dim)
self.target_q_net = nn.Sequential(
nn.Linear(obs_dim, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, action_dim)
)
self.hard_update()
def hard_update(self):
self.target_rnn.load_state_dict(self.rnn.state_dict())
self.target_q_net.load_state_dict(self.q_net.state_dict())
def get_q_values(self, hidden, obs, last_action):
"""给定 (hidden, obs, last_action) 输出 Q(s,a)"""
rnn_input = torch.cat([obs, last_action], dim=1)
new_hidden = self.rnn(rnn_input, hidden)
q_values = self.q_net(new_hidden)
return q_values, new_hidden
def select_action_epsilon_greedy(self, q_values, epsilon):
"""ε-贪心策略"""
if random.random() < epsilon:
return random.randint(0, self.action_dim - 1)
return q_values.argmax(dim=1).item()
def train_qmix():
    """QMIX 训练循环(骨架;env 为假设的多智能体环境接口)"""
    n_agents = 8
    n_episodes = 50000
    obs_dim, action_dim = 12, 5
    agents = [QMIXAgent(obs_dim=obs_dim, action_dim=action_dim,
                        n_agents=n_agents, agent_id=i)
              for i in range(n_agents)]
    # 全局状态简化为所有智能体观测的拼接
    mixer = QMIXMixingNetwork(n_agents, state_dim=n_agents * obs_dim)
    optimizers = [optim.Adam(agent.q_net.parameters(), lr=2e-4) for agent in agents]
    mixer_optimizer = optim.Adam(mixer.parameters(), lr=2e-4)
    replay = ReplayBuffer(capacity=100000)
    for ep in range(n_episodes):
        # 环境交互
        states = env.reset()  # (n_agents, obs_dim)
        hidden = [torch.zeros(1, obs_dim) for _ in range(n_agents)]
        last_actions = [torch.zeros(1, action_dim) for _ in range(n_agents)]
        episode_reward = 0
        done = False
        while not done:
actions = []
for i, agent in enumerate(agents):
q_vals, hidden[i] = agent.get_q_values(hidden[i],
torch.FloatTensor(states[i]).unsqueeze(0),
last_actions[i])
a = agent.select_action_epsilon_greedy(q_vals.squeeze(0), epsilon=0.1)
actions.append(a)
last_actions[i] = torch.zeros(1, 5)
last_actions[i][0, a] = 1.0
next_states, rewards, done = env.step(actions)
            replay.push(states, actions, rewards, next_states, done)  # 存动作索引而非 one-hot
states = next_states
episode_reward += sum(rewards)
# 学习
if len(replay) > 1024:
batch = replay.sample(32)
# QMIX 损失计算 ...
# 单调混合 + 中心化训练 ...
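上面被省略的 QMIX 损失计算,核心逻辑大致如下(示意片段:chosen_q、target_max_q、target_mixer 等名称为本文自拟的假设,需要按自己的 batch 组织方式准备;target_mixer 是另行维护的一份目标混合网络):
def qmix_td_loss(chosen_q, target_max_q, state, next_state,
                 reward, done, mixer, target_mixer, gamma=0.99):
    """QMIX 的 TD 损失(示意)
    chosen_q:     (batch, n_agents) 各智能体所选动作的 Q_i
    target_max_q: (batch, n_agents) 下一时刻各智能体 max_a Q_i(来自 target 网络)
    state/next_state: (batch, n_agents * obs_dim) 全局状态(观测拼接)
    """
    q_tot = mixer(chosen_q, state)                             # (batch, 1)
    with torch.no_grad():
        target_q_tot = target_mixer(target_max_q, next_state)  # 目标网络不回传梯度
        y = reward + gamma * (1 - done) * target_q_tot         # TD 目标
    return nn.MSELoss()(q_tot, y)
对该损失调用 backward() 后,梯度会同时流入各智能体的 Q 网络和混合网络——这正是"中心化训练"的体现。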
2.4 MAPPO:策略梯度在高并行场景的胜利
MAPPO(Multi-Agent PPO) 将 PPO 算法扩展到多智能体场景,近年来在 UAV 集群任务中表现优异(2022–2024 年多篇顶会论文)。
PPO 的核心优势:裁剪代理目标近似了信任域约束(Trust Region),保证策略每次只做小步更新,从而避免了 DDPG 系列对超参数的高度敏感。
PPO-clip 目标:
L^CLIP(θ) = Eₜ[ min( rₜ(θ)Âₜ, clip(rₜ(θ), 1−ε, 1+ε)Âₜ ) ]
其中 rₜ(θ) = π_θ(aₜ|sₜ) / π_θold(aₜ|sₜ) 是新旧策略的概率比,Âₜ 是优势估计(通常用 GAE)。MAPPO 沿用该目标,但用可见全局状态的中心化价值网络计算优势,Actor 仍只依赖本地观测。
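对应的裁剪损失用 PyTorch 写出来只有几行(示意:log_probs、old_log_probs、advantages 等张量假定已在 rollout 阶段算好):
import torch

def ppo_clip_loss(log_probs, old_log_probs, advantages, clip_eps=0.2):
    """PPO-clip 策略损失(最小化形式,取负号)"""
    ratio = torch.exp(log_probs - old_log_probs)          # r_t(θ)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    return -torch.min(surr1, surr2).mean()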
MAPPO 在 UAV 冲突消解中的典型配置:
| 参数 | 推荐值 | 说明 |
|---|---|---|
| Clip ratio | 0.2 | PPO 默认 |
| Horizon | 128–256 | 每个 epoch 的 rollout 步数 |
| PPO epochs | 2–4 | 每个 batch 重复更新次数 |
| GAE λ | 0.95 | 优势估计的偏差-方差平衡 |
| 隐层维度 | 64–128 | 对 UAV 场景足够 |
| 归一化 | 观测 + 奖励归一化 | 关键!对多智能体收敛性影响极大 |
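把上表固化成一个配置对象,便于实验管理(字段名为本文自拟,仅作示意):
from dataclasses import dataclass

@dataclass
class MAPPOConfig:
    clip_ratio: float = 0.2        # PPO 裁剪系数
    horizon: int = 256             # 每轮 rollout 步数(128–256)
    ppo_epochs: int = 4            # 每个 batch 的重复更新次数(2–4)
    gae_lambda: float = 0.95       # GAE 的 λ
    hidden_dim: int = 64           # 隐层维度(64–128)
    normalize_obs: bool = True     # 观测归一化
    normalize_reward: bool = True  # 奖励归一化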
3. GAT:让 MARL 学会"关注谁"
3.1 为什么 MARL 需要图结构
在 UAV 集群中,并非所有智能体都同等重要。以冲突消解为例:
- 与我即将相撞的 UAV → 高度关注
- 远在视线之外的 UAV → 可以忽略
- 正在逼近的移动障碍物 → 需要动态关注
但传统 MARL(如 MADDPG、QMIX)将所有邻机一视同仁:要么全连接(Critic 输入随机数 N 平方增长,50 架以上难以扩展),要么固定取 K 近邻(无法区分不同邻机的重要性)。
GAT 的引入解决了两个核心问题:
- 自适应邻机权重:通过注意力机制学习哪些邻机对当前决策更重要
- 可扩展性:不随无人机数量二次增长,支持动态拓扑
3.2 GAT 核心原理
GAT 在每一层对节点 i 做如下更新:先为每个邻居 j ∈ N(i) 计算注意力分数
eᵢⱼ = LeakyReLU(aᵀ[Whᵢ ‖ Whⱼ])
再在邻域内做 softmax 归一化得到 αᵢⱼ,最后加权聚合:hᵢ' = σ(Σⱼ αᵢⱼ W hⱼ)。
其中 W 是共享的线性变换,a 是可学习的注意力向量,‖ 表示向量拼接,邻域 N(i) 由通信/感知范围决定。
为什么用 LeakyReLU:在注意力分数上引入非线性的同时,保留负分数一侧的梯度,不会像 ReLU 那样把负值全部截断为零。
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadGAT(nn.Module):
"""
多头图注意力网络(Multi-Head GAT)
多个注意力头并行学习不同的邻机关系模式
"""
def __init__(self, in_features, out_features, n_heads=4, dropout=0.1, alpha=0.2):
super().__init__()
self.in_features = in_features
self.out_features = out_features # 每个头的输出维度
self.n_heads = n_heads
self.alpha = alpha
self.dropout = dropout
# 多头并行:每个头有独立的 W 和 a
self.W = nn.ModuleList([
nn.Linear(in_features, out_features)
for _ in range(n_heads)
])
self.a = nn.ModuleList([
nn.Linear(2 * out_features, 1)
for _ in range(n_heads)
])
self.leaky_relu = nn.LeakyReLU(alpha)
self.dropout_layer = nn.Dropout(dropout)
def forward(self, h, adj):
"""
h: (batch, N, D_in) - N 个节点的特征
adj: (batch, N, N) - 邻接矩阵(二值或加权)
返回: (batch, N, D_out) - 更新后的节点特征
"""
batch_size, N, D_in = h.shape
D_out = self.out_features
# 多头注意力
head_outputs = []
for head in range(self.n_heads):
# 线性变换
Wh = self.W[head](h) # (batch, N, D_out)
# 计算注意力分数
# 拼接 [Wh_i || Wh_j] 对于所有 (i,j) 对
Wh_i = Wh.unsqueeze(2).expand(batch_size, N, N, D_out) # (batch, N, N, D_out)
Wh_j = Wh.unsqueeze(1).expand(batch_size, N, N, D_out) # (batch, N, N, D_out)
concat_feats = torch.cat([Wh_i, Wh_j], dim=3) # (batch, N, N, 2*D_out)
e = self.leaky_relu(self.a[head](concat_feats).squeeze(-1)) # (batch, N, N)
            # Mask 非邻接节点(设为 -1e9,softmax 后注意力权重趋近 0)
            e = e.masked_fill(adj == 0, -1e9)  # adj: (batch, N, N)
# Softmax 归一化
attention = F.softmax(e, dim=-1) # (batch, N, N)
attention = self.dropout_layer(attention)
# 加权聚集
h_new = torch.bmm(attention, Wh) # (batch, N, D_out)
head_outputs.append(h_new)
# 多头拼接(Multi-Head Aggregation)
concatenated = torch.cat(head_outputs, dim=-1) # (batch, N, D_out * n_heads)
# 如果输入输出维度相同,用残差连接
if D_in == D_out * self.n_heads:
return F.elu(concatenated) + h
else:
return F.elu(concatenated)
class GATLayerForUAV(nn.Module):
"""
针对 UAV 冲突消解优化的 GAT 层
输入:每架 UAV 的局部状态特征
输出:考虑了邻机关系的增强状态表示
"""
def __init__(self, state_dim, hidden_dim=64, n_heads=3):
super().__init__()
        # 节点特征编码器(此实现简化为只编码自身状态;
        # 相对位置/相对速度等信息可拼接进输入后相应扩大 in_features)
        self.feature_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
# GAT
self.gat = MultiHeadGAT(
in_features=hidden_dim,
out_features=hidden_dim,
n_heads=n_heads,
dropout=0.1
)
def build_adj_from_distance(self, positions, self_mask=True, threshold=50.0):
"""
根据欧氏距离构建邻接矩阵
positions: (batch, N, 3)
threshold: 超过此距离认为无连接
"""
batch_size, N, _ = positions.shape
# 计算两两之间的距离
pos_i = positions.unsqueeze(2) # (batch, N, 1, 3)
pos_j = positions.unsqueeze(1) # (batch, 1, N, 3)
dist = torch.norm(pos_i - pos_j, dim=-1) # (batch, N, N)
adj = (dist < threshold).float()
if self_mask:
adj = adj * (1 - torch.eye(N, device=adj.device).unsqueeze(0))
return adj
def forward(self, uav_states, positions):
"""
uav_states: (batch, N, state_dim)
positions: (batch, N, 3)
"""
batch_size, N, _ = uav_states.shape
# 构建邻接矩阵(基于物理距离)
adj = self.build_adj_from_distance(positions, threshold=50.0) # (batch, N, N)
        # 编码 UAV 状态为 GAT 输入特征(共享编码器,逐节点作用)
        h = self.feature_encoder(uav_states)  # (batch, N, hidden_dim)
# GAT 更新
h_updated = self.gat(h, adj) # (batch, N, hidden_dim)
return h_updated, adj
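可以先做一次形状自检,确认这一层的输入输出符合预期(随机占位数据;多头拼接后输出维度是 hidden_dim × n_heads):
# 形状自检:8 架 UAV,状态维度 12
layer = GATLayerForUAV(state_dim=12, hidden_dim=64, n_heads=3)
uav_states = torch.randn(2, 8, 12)     # (batch=2, N=8, state_dim=12)
positions = torch.rand(2, 8, 3) * 100  # 100m 见方空域内的随机位置
h, adj = layer(uav_states, positions)
print(h.shape, adj.shape)  # torch.Size([2, 8, 192]) torch.Size([2, 8, 8])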
3.3 GAT + MARL 的融合架构
将 GAT 嵌入 MARL 的策略网络,形成 GAT-MARL 架构:
┌──────────────────────────────────────────────────────────────┐
│ 融合架构(GAT + MADDPG) │
│ │
│ 本地观测 oᵢ ──┐ │
│ ▼ │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ 状态编码器 │───▶│ GAT 层 │ │
│ └─────────────┘ │ (K=3 头) │ │
│ │ │ │
│ 邻机观测 oⱼ ──────────────▶│ αᵢⱼ=f(hᵢ,hⱼ) │ │
│ └─────────────┘ │
│ │ │
│ ▼ │
│ 增强特征 h̃ᵢ │
│ │ │
│ ┌──────┴──────┐ │
│ ▼ ▼ │
│ ┌──────────┐ ┌──────────┐ │
│ │ Actor │ │ Critic │ │
│ │ π(aᵢ|h̃ᵢ) │ │ Q(s,a) │ │
│ └──────────┘ └──────────┘ │
└──────────────────────────────────────────────────────────────┘
完整实现代码:
class GATPolicyNetwork(nn.Module):
"""
GAT 增强的策略网络(GAT + Actor)
用于 UAV 冲突消解的端到端策略
"""
def __init__(self, obs_dim, action_dim, hidden_dim=64, n_gat_heads=3):
super().__init__()
# 1. 本地观测编码器
self.obs_encoder = nn.Sequential(
nn.Linear(obs_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
# 2. 邻居状态编码器(编码相对信息)
self.neighbor_encoder = nn.Sequential(
nn.Linear(obs_dim + 3, hidden_dim), # 邻居观测 + 相对位置
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
# 3. 多层 GAT
self.gat1 = MultiHeadGAT(hidden_dim, hidden_dim // n_gat_heads, n_heads=n_gat_heads)
self.gat2 = MultiHeadGAT(hidden_dim, hidden_dim // n_gat_heads, n_heads=n_gat_heads)
# 4. 注意力聚合层
self.attention_pool = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 1)
)
# 5. 策略头
self.actor_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, action_dim),
nn.Tanh()
)
# 6. 价值头(用于 PPO/MAPPO)
self.value_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, 1)
)
def forward(self, obs, positions, adj_mask):
"""
obs: (batch, N, obs_dim) - 所有 UAV 的观测
positions: (batch, N, 3) - 所有 UAV 的位置
adj_mask: (batch, N, N) - 邻接矩阵
返回: 策略分布, 价值估计
"""
batch_size, N, _ = obs.shape
# 本地特征编码
h_local = self.obs_encoder(obs) # (batch, N, hidden_dim)
# 为每个 UAV 构建邻居感知特征
h_nodes = []
for i in range(N):
            # 注意:不要原地修改 adj_mask(切片返回视图,原地写会污染输入);
            # 此简化实现直接聚合全部其他节点,邻接筛选交给后续 GAT 的 mask
# 收集邻居观测
neighbor_features = []
for j in range(N):
if j != i:
# 邻居状态 + 相对位置
rel_pos = positions[:, j] - positions[:, i] # (batch, 3)
feat = torch.cat([obs[:, j], rel_pos], dim=1) # (batch, obs_dim+3)
neighbor_features.append(feat)
if neighbor_features:
neighbor_tensor = torch.stack(neighbor_features, dim=1) # (batch, N-1, obs_dim+3)
h_neighbor = self.neighbor_encoder(neighbor_tensor) # (batch, N-1, hidden_dim)
# 加权平均邻居特征
# 简单方案:平均
h_neighbor_agg = h_neighbor.mean(dim=1) # (batch, hidden_dim)
else:
h_neighbor_agg = torch.zeros(batch_size, h_local.shape[-1], device=h_local.device)
# 拼接本地 + 邻居
h_combined = h_local[:, i] + h_neighbor_agg # 残差连接
h_nodes.append(h_combined)
h = torch.stack(h_nodes, dim=1) # (batch, N, hidden_dim)
# GAT 层
h = F.elu(self.gat1(h, adj_mask)) # (batch, N, hidden_dim)
h = F.elu(self.gat2(h, adj_mask)) # (batch, N, hidden_dim)
# 注意力池化(为 Critic 生成全局状态表示)
attn_scores = self.attention_pool(h) # (batch, N, 1)
attn_weights = F.softmax(attn_scores, dim=1) # (batch, N, 1)
h_global = (h * attn_weights).sum(dim=1) # (batch, hidden_dim)
# 策略输出
policy = self.actor_head(h) # (batch, N, action_dim)
value = self.value_head(h_global) # (batch, 1)
return policy, value
    def get_action(self, obs, positions, adj_mask, deterministic=False):
        """推理时选择动作;输入为单个时间步的 (N, ...) 张量"""
        with torch.no_grad():
            policy, value = self.forward(
                obs.unsqueeze(0), positions.unsqueeze(0), adj_mask.unsqueeze(0))
        action = policy.squeeze(0)  # (N, action_dim)
        if not deterministic:
            # 探索:加高斯噪声后重新限幅
            action = (action + 0.1 * torch.randn_like(action)).clamp(-1, 1)
        return action, value.squeeze(0)
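同样做一次前向形状自检(随机输入;这里取 4 个注意力头,使 hidden_dim 可被头数整除,GAT 层内的残差连接得以生效):
net = GATPolicyNetwork(obs_dim=36, action_dim=3, hidden_dim=64, n_gat_heads=4)
obs = torch.randn(8, 36)           # 8 架 UAV 的观测
positions = torch.rand(8, 3) * 100
adj = torch.ones(8, 8)             # 演示用:全连接邻接矩阵
actions, value = net.get_action(obs, positions, adj)
print(actions.shape, value.shape)  # torch.Size([8, 3]) torch.Size([1])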
class GAT_MADDPG:
"""
GAT-MADDPG:图注意力增强的多智能体 DDPG
"""
def __init__(self, n_agents, obs_dim, action_dim, hidden_dim=64):
self.n_agents = n_agents
self.agents = []
for i in range(n_agents):
            agent = {
                'policy': GATPolicyNetwork(obs_dim, action_dim, hidden_dim),
                'target_policy': GATPolicyNetwork(obs_dim, action_dim, hidden_dim),
            }
            # target 网络从主网络同步初始化,优化器直接绑定主网络参数
            agent['target_policy'].load_state_dict(agent['policy'].state_dict())
            agent['optimizer'] = optim.Adam(agent['policy'].parameters(), lr=1e-3)
            self.agents.append(agent)
self.replay = ReplayBuffer(capacity=50000)
    def update(self, batch):
        """MADDPG 风格的更新骨架(演示用)。
        Critic/Actor 损失的具体计算与 2.2 节 MADDPGAgent.update 完全同构,
        区别仅在于 Q 值与策略都建立在 GAT 提取的特征之上;此处保留流程注释。"""
        states, all_actions, rewards, next_states, done = batch
        total_actor_loss, total_critic_loss = 0.0, 0.0
        for i, agent in enumerate(self.agents):
            # 1) 还原所有 UAV 的位置(假设各智能体观测的前 3 维为位置),
            #    据此构建当前/下一时刻的邻接矩阵:
            #    adj = self.build_adj(positions);next_adj = self.build_adj(next_positions)
            # 2) Critic 更新:target_Q = r + γ (1 − done) Q_target(s', a'),
            #    其中 a' 由各智能体的 target_policy 给出;critic_loss = MSE(Q(s,a), target_Q)
            # 3) Actor 更新:当前智能体用 GAT 策略网络的新动作,
            #    其他智能体沿用 replay buffer 中的历史动作;
            #    actor_loss = -Q(s, a_i_new, a_-i_fixed)
            # 4) 软更新各自的 target 网络
            pass
        return total_actor_loss / self.n_agents, total_critic_loss / self.n_agents
    def build_adj(self, positions, threshold=50.0):
        """从位置构建邻接矩阵:(batch, N, 3) -> (batch, N, N)"""
        dist = torch.cdist(positions, positions)  # 两两欧氏距离
        adj = (dist < threshold).float()
        eye = torch.eye(positions.size(1), device=positions.device)
        return adj * (1 - eye.unsqueeze(0))       # 去掉自环
4. 端到端冲突消解训练流程
4.1 仿真环境设计
使用 UAV 冲突消解专用仿真环境:
import numpy as np
import gym
from gym import spaces
class UAVConflictEnv(gym.Env):
"""
多 UAV 冲突消解仿真环境
"""
metadata = {'render_modes': []}
def __init__(self, n_agents=8, area_size=200.0, safe_radius=5.0, max_steps=500):
super().__init__()
self.n_agents = n_agents
self.area_size = area_size
self.safe_radius = safe_radius
self.max_steps = max_steps
self.dt = 0.1
# 观测空间:自身状态(6) + 邻居信息(最多5个邻居 * 6 = 30)
self.obs_dim = 6 + 30
self.action_dim = 3 # (Δvx, Δvy, Δvz)
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(self.n_agents, self.obs_dim), dtype=np.float32
)
self.action_space = spaces.Box(
low=-1, high=1, shape=(self.n_agents, self.action_dim), dtype=np.float32
)
def reset(self):
"""随机初始化 UAV 位置(确保初始有冲突风险)"""
self.positions = np.random.uniform(
-self.area_size/2, self.area_size/2,
size=(self.n_agents, 3)
).astype(np.float32)
# 随机速度方向(朝向各自目标)
self.velocities = np.random.uniform(5, 15, size=(self.n_agents, 3)).astype(np.float32)
self.velocities = self.velocities / np.linalg.norm(self.velocities, axis=1, keepdims=True) * 10
# 目标点(随机分布在区域内)
self.goals = np.random.uniform(
-self.area_size/2, self.area_size/2,
size=(self.n_agents, 3)
).astype(np.float32)
self.step_count = 0
return self._get_obs()
def step(self, actions):
"""执行动作,更新状态"""
actions = np.clip(actions, -1, 1)
delta_v = actions * 2.0 # 速度变化量
        # 更新速度,并按模长限幅(逐分量 clip 会错误地禁止负向速度分量)
        self.velocities += delta_v
        v_mag = np.linalg.norm(self.velocities, axis=1, keepdims=True)
        self.velocities *= np.minimum(1.0, 20.0 / (v_mag + 1e-8))
# 更新位置
self.positions += self.velocities * self.dt
        # 边界反射:法向速度取反,并把位置夹回区域内
        half = self.area_size / 2
        out_of_bounds = np.abs(self.positions) > half
        self.velocities[out_of_bounds] *= -1
        self.positions = np.clip(self.positions, -half, half)
# 计算奖励
reward, collision = self._compute_reward()
        self.step_count += 1
        done = collision or self.step_count >= self.max_steps
return self._get_obs(), reward, done, {}
def _compute_reward(self):
"""奖励函数设计"""
reward = np.zeros(self.n_agents)
collision = False
for i in range(self.n_agents):
# 接近目标的奖励
dist_to_goal = np.linalg.norm(self.positions[i] - self.goals[i])
reward[i] += 0.5 * (1.0 / (dist_to_goal + 1.0))
# 保持安全距离
for j in range(i + 1, self.n_agents):
dist = np.linalg.norm(self.positions[i] - self.positions[j])
if dist < self.safe_radius:
reward[i] -= 50
reward[j] -= 50
collision = True
elif dist < self.safe_radius * 3:
reward[i] -= 1.0 / (dist + 1e-3)
reward[j] -= 1.0 / (dist + 1e-3)
return reward, collision
def _get_obs(self):
"""生成各 UAV 的观测"""
obs = []
for i in range(self.n_agents):
# 自身信息:位置(3) + 速度(3)
self_info = np.concatenate([self.positions[i], self.velocities[i]])
# 邻居信息(距离最近的5架)
rel_info = []
for j in range(self.n_agents):
if i != j:
rel_pos = self.positions[j] - self.positions[i]
rel_vel = self.velocities[j] - self.velocities[i]
rel_info.append(np.concatenate([rel_pos, rel_vel]))
# 取最近的5个邻居
if len(rel_info) > 5:
dists = [np.linalg.norm(r[:3]) for r in rel_info]
sorted_idx = np.argsort(dists)[:5]
rel_info = [rel_info[k] for k in sorted_idx]
# padding
while len(rel_info) < 5:
rel_info.append(np.zeros(6))
neighbor_info = np.concatenate(rel_info) # (30,)
obs.append(np.concatenate([self_info, neighbor_info]))
return np.array(obs, dtype=np.float32)
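训练之前,可以先用随机动作对环境做一次冒烟测试,确认接口与形状符合预期:
env = UAVConflictEnv(n_agents=8)
obs = env.reset()                       # (8, 36)
total_reward = np.zeros(env.n_agents)
done = False
while not done:
    actions = np.random.uniform(-1, 1, size=(env.n_agents, env.action_dim))
    obs, rewards, done, _ = env.step(actions)
    total_reward += rewards
print(obs.shape, total_reward.mean())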
4.2 完整训练循环
def build_adj_matrix(positions, threshold=50.0):
    """numpy 版邻接矩阵:(N, 3) -> (N, N),按距离阈值连边并去自环"""
    dist = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
    adj = (dist < threshold).astype(np.float32)
    np.fill_diagonal(adj, 0.0)
    return adj

def train_gat_maddpg():
    """完整训练流程"""
env = UAVConflictEnv(n_agents=8)
n_agents = 8
# 初始化 GAT-MADDPG
maddpg = GAT_MADDPG(n_agents=n_agents, obs_dim=36, action_dim=3)
# 训练配置
n_episodes = 10000
batch_size = 256
update_freq = 100
target_update_freq = 200
    episode_rewards = []
    actor_loss, critic_loss = 0.0, 0.0  # 首次日志打印前的占位值
for ep in range(n_episodes):
obs = env.reset()
episode_reward = 0
done = False
while not done:
# 构建邻接矩阵(基于当前距离)
positions = env.positions # (N, 3)
adj = build_adj_matrix(positions, threshold=50.0)
            # 各智能体选择动作:策略网络吃全体 UAV 的观测/位置/邻接矩阵,
            # 每个智能体只取属于自己的那一行动作
            actions = []
            for i, agent in enumerate(maddpg.agents):
                action_all, _ = agent['policy'].get_action(
                    torch.FloatTensor(obs),
                    torch.FloatTensor(positions),
                    torch.FloatTensor(adj)
                )
                actions.append(action_all[i].numpy())
actions = np.array(actions)
# 环境交互
next_obs, rewards, done, _ = env.step(actions)
# 存储经验
maddpg.replay.push(obs, actions, rewards, next_obs, done)
obs = next_obs
episode_reward += rewards.mean()
episode_rewards.append(episode_reward)
# 周期性更新
if ep % update_freq == 0 and len(maddpg.replay) > batch_size:
batch = maddpg.replay.sample(batch_size)
actor_loss, critic_loss = maddpg.update(batch)
if ep % 1000 == 0:
avg_reward = np.mean(episode_rewards[-1000:])
print(f"Episode {ep:5d} | "
f"Avg Reward: {avg_reward:8.2f} | "
f"Actor Loss: {actor_loss:.4f} | "
f"Critic Loss: {critic_loss:.4f}")
# 保存模型
torch.save({f'agent_{i}': maddpg.agents[i]['policy'].state_dict()
for i in range(n_agents)}, 'gat_maddpg_uav.pth')
return episode_rewards
4.3 训练结果可视化
训练完成后,以下指标反映策略质量:
- 冲突率:每 episode 发生碰撞的比例,期望随训练收敛到接近 0
- 平均 episode 回报:逐步上升
- 策略熵:初期高(随机探索),后期低(策略收敛)
- 到目标距离:收敛后 UAV 能稳定到达目标
def plot_training_curves(episode_rewards):
import matplotlib.pyplot as plt
# 移动平均
window = 50
smoothed = np.convolve(episode_rewards,
np.ones(window)/window, mode='valid')
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(episode_rewards, alpha=0.3, color='blue', label='Raw')
plt.plot(smoothed, color='red', lw=2, label=f'{window}-ep Moving Avg')
plt.xlabel('Episode')
plt.ylabel('Episode Reward')
plt.title('GAT-MADDPG Training Curve')
plt.legend()
plt.grid(True, alpha=0.3)
plt.subplot(1, 2, 2)
# 冲突率
conflicts = [1 if r < -50 else 0 for r in episode_rewards]
conflict_rate = np.convolve(conflicts, np.ones(window)/window, mode='valid')
plt.plot(conflict_rate, color='orange', lw=2, label='Collision Rate')
plt.xlabel('Episode')
plt.ylabel('Collision Rate')
plt.title('Safety Metric')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('gat_maddpg_training.png', dpi=150)
5. 前沿进展:从 GAT-MARL 到更强大的架构
5.1 Transformer + MARL
2023–2024 年,Transformer 架构在邻机关系建模中被越来越多的工作采用,开始与 GAT 形成竞争:
- GAT → Attention:用 Multi-Head Self-Attention 替代图注意力
- 通信消息:各 UAV 的隐状态通过 Transformer Encoder 聚合
- 优势:不需要预定义的邻接矩阵,关系完全由数据驱动学习
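用 PyTorch 内置的 nn.MultiheadAttention 可以写出这种"无邻接矩阵"聚合的最小示意(各 UAV 的隐状态互相做 self-attention,关系完全从数据中学出):
import torch
import torch.nn as nn

n_uav, hidden_dim = 8, 64
attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=4, batch_first=True)

h = torch.randn(1, n_uav, hidden_dim)   # 各 UAV 的隐状态,(batch, N, D)
h_agg, attn_weights = attn(h, h, h)     # 无需显式邻接矩阵
print(h_agg.shape, attn_weights.shape)  # (1, 8, 64) (1, 8, 8)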
代表工作:
- MAFAT(Multi-Agent Foundation Action Transformer):用预训练的 Transformer 初始化多智能体策略
- MAT(Multi-Agent Transformer):将合作式多智能体决策建模为序列建模问题,用 encoder-decoder 逐个生成各智能体的动作
5.2 离线 MARL(Offline MARL)
真实 UAV 训练数据获取成本高、危险大。离线强化学习(Offline RL) 从固定数据集中学习策略,避免在线交互的风险:
- CQL(Conservative Q-Learning):对数据集外动作的 Q 值施加保守惩罚,抑制过估计
- IQL(Implicit Q-Learning):用 expectile 回归隐式逼近最优 Q,避免对数据集外动作求值(缩写与前文的 Independent Q-Learning 相同,但是不同的概念)
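以 CQL 为例,其保守正则项在代码里大致是在普通 TD 损失之外再加一项 logsumexp 惩罚(离散动作版示意;q_values 为 (batch, n_actions) 的 Q 网络输出,q_data 为数据集动作对应的 Q 值,均为假设的输入):
import torch

def cql_penalty(q_values, q_data, alpha=1.0):
    """CQL 保守项:压低所有动作的 logsumexp Q,抬高数据集中动作的 Q"""
    logsumexp_q = torch.logsumexp(q_values, dim=1)  # (batch,)
    return alpha * (logsumexp_q - q_data).mean()

# total_loss = td_loss + cql_penalty(q_values, q_data)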
5.3 安全约束层(Safety Layer)
直接将安全约束嵌入策略输出之后:RL 策略先给出"意图动作" a_RL,安全层再求解一个最小修正问题
a_safe = argmin_a ‖a − a_RL‖²  s.t.  g(s, a) ≤ 0
其中 g(s, a) 编码安全约束(例如与邻机的最小间隔)。约束未激活时安全层是恒等映射;只有在即将违反约束时才修正动作,因此几乎不干扰已学到的策略。
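一个最简单的安全层实现是几何投影(示意:只处理"与最近邻机保持最小间隔"这一种约束,把速度指令中朝向邻机的分量裁掉;真实系统通常改用 QP 求解器同时处理多条约束):
import numpy as np

def safety_layer(a_rl, rel_pos, d_min=5.0):
    """若与最近邻机距离低于 d_min,去除动作中朝向该邻机的分量
    a_rl:    (3,) RL 输出的速度指令
    rel_pos: (3,) 最近邻机相对自身的位置向量
    """
    dist = np.linalg.norm(rel_pos)
    if dist >= d_min:
        return a_rl                      # 约束未激活,恒等映射
    direction = rel_pos / (dist + 1e-8)  # 指向邻机的单位向量
    closing = np.dot(a_rl, direction)    # 朝邻机逼近的速度分量
    if closing <= 0:
        return a_rl                      # 已在远离,无需修正
    return a_rl - closing * direction    # 投影掉逼近分量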
6. 总结:GAT-MARL 的技术全貌
从单智能体 RL 到多智能体强化学习,再到图注意力增强,我们走通了一条可扩展的端到端冲突消解路线:
| 层次 | 技术 | 解决的问题 |
|---|---|---|
| 学习框架 | CTDE(中心化训练+去中心化执行) | 环境非平稳性 |
| 算法 | MADDPG / MAPPO / QMIX | 信用分配 + 连续/离散动作 |
| 拓扑建模 | GAT | 自适应邻机权重 + 可扩展性 |
| 安全约束 | Safety Layer / 硬约束 | 碰撞保证(vs 软奖励) |
| 训练范式 | PPO 裁剪目标(近似 TRPO 的信任域约束) | 训练稳定性 |
未来最值得关注的方向:
- Foundation Model + UAV:用大语言模型做任务级指令理解 + MARL 做低层控制
- 真实飞行验证:仿真到真实的迁移(Sim-to-Real)仍是核心挑战
- 通信受限场景:无通信或通信时延下的 GAT 鲁棒性
参考文献:
- Lowe, R., et al. (2017). Multi-agent actor-critic for mixed cooperative-competitive environments (MADDPG). Conference on Neural Information Processing Systems (NeurIPS).
- Foerster, J., et al. (2018). Counterfactual multi-agent policy gradients (COMA). AAAI Conference on Artificial Intelligence.
- Rashid, T., et al. (2018). QMIX: Monotonic value function factorisation for deep multi-agent reinforcement learning. International Conference on Machine Learning (ICML).
- Veličković, P., et al. (2018). Graph attention networks. International Conference on Learning Representations (ICLR).
- Everett, M., et al. (2021). Collision avoidance in dense traffic with deep reinforcement learning. IEEE International Conference on Robotics and Automation (ICRA).
- Hu, E. J., et al. (2021). LoRA: Low-Rank Adaptation of Large Language Models. International Conference on Learning Representations (ICLR).
- Fan, T., et al. (2020). Distributed Multi-Robot Collision Avoidance via Deep Reinforcement Learning for Navigation in Complex Scenarios. The International Journal of Robotics Research (IJRR).
- Mao, H., et al. (2020). Learning Multi-Agent Communication with Double Attentional Deep Reinforcement Learning. Autonomous Agents and Multi-Agent Systems (JAAMAS).
- Yu, L., et al. (2025). Hybrid Transformer Based Multi-Agent Reinforcement Learning for Multiple Unpiloted Aerial Vehicle Coordination in Air Corridors. IEEE Transactions on Mobile Computing (TMC).
- Zhu, Y., et al. (2025). Multi-Task Multi-Agent Reinforcement Learning With Task-Entity Transformers and Value Decomposition Training. IEEE Transactions on Automation Science and Engineering (TASE).
- Jiang, C., et al. (2024). Distributed Sampling-Based Model Predictive Control via Belief Propagation for Multi-Robot Formation Navigation. IEEE Robotics and Automation Letters (RA-L).
- Goeckner, A., et al. (2024). Graph Neural Network-based Multi-Agent Reinforcement Learning for Resilient Distributed Coordination of Multi-Robot Systems. IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS).