マルチエージェント強化学習とグラフ アテンション ネットワーク: UAV クラスターの競合解決のためのエンドツーエンドのソリューション
前の記事 では、UAV の競合解決のアルゴリズムのパノラマを整理しました。その中でも、強化学習 (特に MARL) は、50 機以上のドローンの群れにとって「最も現実的なオプション」とされています。この記事では、シングル エージェント RL の基礎から始めて、マルチエージェント シナリオの中核的な課題に入り、MADDPG、QMIX、COMA、MAPPO などの主流アルゴリズムを分析し、GAT (グラフ アテンション ネットワーク) が MARL にスケーラブルなトポロジ認識機能を提供し、最終的にエンドツーエンドの競合解決戦略を達成する方法に焦点を当てて、このルートに焦点を当てます。
1. シングルエージェントからマルチエージェントへ: MARL はなぜそれほど難しいのでしょうか?
1.1 単一エージェントの RL レビュー
おなじみのシングルエージェント RL から始めましょう。シングルエージェント MDP は、 の 4 つの要素で記述されます。
シングルエージェント RL の中核となる前提: 環境は安定している - トレーニングするエピソードの数に関係なく、環境 のダイナミクスは常に変化しません。
1.2 マルチエージェントの 3 つの本質的な問題点
マルチエージェントのシナリオはこの前提を打ち破り、次の 3 つの根本的な問題を引き起こします。
① 環境の非定常性(Non-Stationarity)
エージェント がポリシー を学習しているとき、他のエージェントのポリシー も変更されます。これはつまり:$$
\mathcal{P}_i(s’\mid s, a_1,\dots,a_n) \neq \mathcal{P}_i(s’\mid s, a_1,\dots,a_n, a_1’,\dots,a_n’)
r_t = f(\mathbf{s}_t, \mathbf{a}t, \mathbf{s}{t+1})
\nabla_{\theta_i} J(\theta_i) = \mathbb{E}{\mathbf{s} \sim \mathcal{D}}\left[
\nabla{\theta_i} \log \pi_i(a_i \mid o_i) \cdot
Q_i^\pi(\mathbf{s}, a_1, \dots, a_n) \Big|_{a_i = \pi_i(o_i)}
\右]
がエージェント の行動観察軌跡である場合、 は次を満たす 単調混合ネットワークです。
単調性制約は重要な特性を保証します。 分散実行中、各エージェントの の独立した貪欲な最大化は、 のグローバルな最大化と同等です。
class QMIXMixingNetwork(nn.Module):
"""
单调混合网络:将各智能体的 Q_i 混合为全局 Q_tot
关键约束:所有权值非负(保证单调性)
"""
def __init__(self, n_agents, embed_dim=64):
super().__init__()
# Hyper-network 生成混合网络的权值
self.hyper_w1 = nn.Sequential(
nn.Linear(n_agents, embed_dim),
nn.ReLU(),
nn.Linear(embed_dim, n_agents * embed_dim), # 输出 (n_agents × embed_dim) 权值
)
self.hyper_b1 = nn.Linear(n_agents, embed_dim)
self.hyper_w2 = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
nn.ReLU(),
nn.Linear(embed_dim, embed_dim)
)
self.hyper_b2 = nn.Linear(embed_dim, 1)
def forward(self, q_values, state):
"""
q_values: (batch, n_agents) 各智能体的 Q 值
state: (batch, state_dim) 全局状态(用于生成 hyper-network 输入)
"""
batch_size = q_values.size(0)
# 第一层:W₁ * Q + b₁
w1 = torch.abs(self.hyper_w1(state)) # (batch, n_agents * embed_dim)
w1 = w1.view(batch_size, q_values.size(1), -1) # (batch, n_agents, embed_dim)
b1 = self.hyper_b1(state).unsqueeze(1) # (batch, 1, embed_dim)
q_hidden = torch.relu(torch.bmm(q_values.unsqueeze(1), w1) + b1) # (batch, 1, embed_dim)
# 第二层:W₂ * h + b₂
w2 = torch.abs(self.hyper_w2(q_hidden.squeeze(1))) # (batch, embed_dim, embed_dim)
b2 = self.hyper_b2(q_hidden.squeeze(1)).unsqueeze(1) # (batch, 1, 1)
q_tot = torch.bmm(q_hidden, w2.unsqueeze(1)) + b2 # (batch, 1, 1)
return q_tot.squeeze(-1) # (batch,)
class QMIXAgent:
"""QMIX 算法"""
def __init__(self, obs_dim, action_dim, n_agents, agent_id):
self.agent_id = agent_id
self.action_dim = action_dim
# 每个智能体的 RNN(处理动作-观测历史)
self.rnn = nn.GRUCell(obs_dim + action_dim, obs_dim)
# Q 网络
self.q_net = nn.Sequential(
nn.Linear(obs_dim, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, action_dim)
)
self.target_rnn = nn.GRUCell(obs_dim + action_dim, obs_dim)
self.target_q_net = nn.Sequential(
nn.Linear(obs_dim, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, action_dim)
)
self.hard_update()
def hard_update(self):
self.target_rnn.load_state_dict(self.rnn.state_dict())
self.target_q_net.load_state_dict(self.q_net.state_dict())
def get_q_values(self, hidden, obs, last_action):
"""给定 (hidden, obs, last_action) 输出 Q(s,a)"""
rnn_input = torch.cat([obs, last_action], dim=1)
new_hidden = self.rnn(rnn_input, hidden)
q_values = self.q_net(new_hidden)
return q_values, new_hidden
def select_action_epsilon_greedy(self, q_values, epsilon):
"""ε-贪心策略"""
if random.random() < epsilon:
return random.randint(0, self.action_dim - 1)
return q_values.argmax(dim=1).item()
def train_qmix():
"""QMIX 训练循环(伪代码)"""
n_agents = 8
n_episodes = 50000
agents = [QMIXAgent(obs_dim=12, action_dim=5, n_agents=n_agents, agent_id=i)
for i in range(n_agents)]
mixer = QMIXMixingNetwork(n_agents)
optimizers = [optim.Adam(agent.q_net.parameters(), lr=2e-4) for agent in agents]
mixer_optimizer = optim.Adam(mixer.parameters(), lr=2e-4)
replay = ReplayBuffer(capacity=100000)
for ep in range(n_episodes):
# 环境交互
states = env.reset() # (n_agents, obs_dim)
hidden = [torch.zeros(1, 12) for _ in range(n_agents)]
last_actions = [torch.zeros(1, 5) for _ in range(n_agents)]
episode_reward = 0
while not done:
actions = []
for i, agent in enumerate(agents):
q_vals, hidden[i] = agent.get_q_values(hidden[i],
torch.FloatTensor(states[i]).unsqueeze(0),
last_actions[i])
a = agent.select_action_epsilon_greedy(q_vals.squeeze(0), epsilon=0.1)
actions.append(a)
last_actions[i] = torch.zeros(1, 5)
last_actions[i][0, a] = 1.0
next_states, rewards, done = env.step(actions)
replay.push(states, last_actions, rewards, next_states, done)
states = next_states
episode_reward += sum(rewards)
# 学习
if len(replay) > 1024:
batch = replay.sample(32)
# QMIX 损失计算 ...
# 单调混合 + 中心化训练 ...
2.4 MAPPO: 高度に並行したシナリオにおける政策勾配の勝利
MAPPO (マルチエージェント PPO) は、PPO アルゴリズムをマルチエージェント シナリオに拡張し、近年の UAV クラスター タスクで良好なパフォーマンスを示しています (2022 年から 2024 年までの複数の主要なカンファレンス論文)。
PPO の主な利点: 信頼領域の制約により、トレーニングの安定性が確保され、DDPG シリーズのハイパーパラメーターによる災害が回避されます。
PPO -クリップ ターゲット:
は確率比、 は GAE (Generalized Advantage Estimation) です。UAV 競合解決における MAPPO の一般的な構成:
| パラメータ | 推奨値 | 説明 |
|---|
| クリップ率 | 0.2 | PPO のデフォルト |
| ホライゾン | 128–256 | エポックごとのロールアウト ステップの数 |
| PPO エポック | 2–4 | バッチごとに繰り返される更新の数 |
| GAE | 0.95 | 優勢推定のバイアス分散バランス |
| 隠れ層の寸法 | 64–128 | UAV シナリオには十分 |
| 正規化 | OBS + 報酬の正規化 | 鍵!マルチエージェントのコンバージェンスに大きな影響 |
3. GAT: MARL に「誰に従うべきか」を学ばせる
3.1 なぜ MARL にはグラフ構造が必要なのでしょうか?
UAV クラスターでは、すべてのエージェントが同じように重要であるわけではありません。競合の解決を例に挙げます。
- UAV が私に衝突しようとしている → 重大な懸念
- UAV は視界の外にある → 無視しても問題ありません
- 動く障害物に近づく → 動的な注意が必要
ただし、従来の MARL (MADDPG、QMIX など) は、完全に接続されたトポロジ ( 通信) または固定トポロジ (リング、最近傍など) のいずれかで、すべての近隣ノードを平等に扱います。
GAT の導入により、次の 2 つの主要な問題が解決されます。
- 適応近隣重み: アテンション メカニズムを通じて、現在の決定にとってどの近隣がより重要であるかを学習します。
- 拡張性: ドローンの数に応じて増加せず、動的なトポロジーをサポートします
3.2 GAT の基本原則
GAT は、各層のノード の特徴 に対して 近隣集約 を実行し、重みはアテンション メカニズムによって動的に計算されます。$$
\alpha_{ij} = \frac{\exp\left(\text{LeakyReLU}\left(\mathbf{a}^\top[\mathbf{W}\mathbf{h}_i \Vert \mathbf{W}\mathbf{h}j]\right)\right)}
{\sum{k \in \mathcal{N}_i} \exp\left(\text{LeakyReLU}\left(\mathbf{a}^\top[\mathbf{W}\mathbf{h}_i \Vert \mathbf{W}\mathbf{h}_k]\right)\right)}
\mathbf{h}i’ = \sigma\left(\sum{j \in \mathcal{N}i} \alpha{ij} \mathbf{W}\mathbf{h}_j\right)
\pi_{安全な}(s) = \text{Proj}{\mathcal{A}{安全な}(s)} \pi(s)