news 2026/1/17 16:04:20

Ray 分布式训练的多智能体路径规划强化学习踩坑记录

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Ray 分布式训练的多智能体路径规划强化学习踩坑记录

Ray 分布式训练的多智能体路径规划强化学习项目

本文基于本仓库代码(train.py / worker.py / environment.py / model.py等),介绍如何用Ray 分布式 Actor–Learner训练一个带可学习通信模块的去中心化多智能体路径规划(MAPF)策略,并总结工程实现中的关键点与常见问题。

关键词(建议保留):MAPF多智能体强化学习DQNRayActor-Learner分布式训练通信注意力PyTorchdtype/AMP

1. 背景:去中心化 MAPF 与分布式强化学习

**MAPF(Multi-Agent Path Finding)**的典型目标是:在带障碍的网格地图中,多个智能体从各自起点出发到达各自目标点,要求尽量少碰撞/冲突、尽快完成。

本仓库采用去中心化执行(每个 agent 根据局部观测决策),训练阶段使用分布式 off-policy 强化学习(DQN 风格),通过多 Actor 并行采样 + 单 Learner 更新参数的方式提升数据吞吐。

2. 工程总览:核心文件与职责

  • train.py
    • 训练入口:启动 Ray,创建并启动GlobalBufferLearner、多个Actor
    • 启动时打印torch.cuda.is_available()与 GPU 信息,并优先选择默认设备(可用则 GPU)。
  • worker.py
    • @ray.remote远程组件:
      • GlobalBuffer:全局优先级经验回放池(Prioritized Replay)+ 后台 batch 预取。
      • Learner:执行训练更新、维护 target network、对外提供最新权重。
      • Actor:与环境交互采样,产生 episode 经验并写入回放池。
  • environment.py
    • 网格环境实现:地图生成、观测构造、冲突检测、奖励计算、可视化辅助。
  • model.py
    • 网络结构:CNN 编码器 + GRU(时序记忆)+ 通信模块(多头注意力)+ Dueling Q 头。
  • buffer.py
    • SumTree:优先级采样结构。
    • LocalBuffer:单个 episode 的暂存与 TD-error 计算。
  • configs.py
    • 环境参数、训练参数、通信参数、课程学习参数、测试参数统一配置。

3. 算法与训练架构:Ray Actor–Learner(DQN 风格)

整体流程可以理解为一个“高吞吐数据生成 + 稳定参数更新”的流水线:

  1. Actors 并行采样
  • 每个Actor持有一个Environment与一份Network(推理用)。
  • 循环执行:
    • env.reset()获取初始观测
    • model.step(obs, pos)选择动作(epsilon-greedy)
    • env.step(actions)与环境交互
    • 将 transition 写入LocalBuffer
  • episode 结束时LocalBuffer.finish()打包整段轨迹并GlobalBuffer.add.remote(data)发送到全局回放池。
  1. GlobalBuffer 统一存储与优先级采样
  • GlobalBuffer用大数组存放多 episode 数据(obs/action/reward/hidden/mask 等)。
  • 维护SumTree实现Prioritized Experience Replay
  • 后台线程prepare_data()会提前准备训练 batch,减小 Learner 等待。
  1. Learner 单点训练与参数广播
  • Learner在初始化时选择设备:torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  • 持有modeltar_model(target network),周期性同步。
  • GlobalBuffer拉取 batch,计算 TD loss,反向传播更新。
  • Actor定期调用learner.get_weights()拉取最新参数并更新本地推理网络。

这种结构的优势在于:

  • 多 Actor 并行采样提高数据吞吐
  • Learner 单点更新便于控制优化器与 target network 同步
  • 回放池解耦采样与训练,提升稳定性

4. 环境设计:网格世界、冲突规则与奖励

environment.py中:

  • 动作空间(5 维):停留 / 上 / 下 / 左 / 右
  • 地图生成:按障碍密度随机生成 0/1 网格,并确保至少存在可用连通区域用于采样起点/终点。
  • 冲突处理
    • 越界/撞墙:回退并给 collision 惩罚
    • 交换位置冲突(swap):双方回退并惩罚
    • 其他同格冲突处理(文件后半段)
  • 奖励函数configs.reward_fn控制,例如:
    • move:小负值
    • collision:更大负值
    • finish:正奖励

环境还构造了启发式相关特征(如到目标的距离梯度),用于增强观测信息。

5. 模型设计:CNN + GRU + 通信注意力 + Dueling Q

model.pyNetwork主要由四部分组成:

  1. 局部观测编码(CNN)
  • 将局部栅格观测编码为 latent 向量。
  • 使用残差块与CPCA(通道/空间注意力模块)提升表征能力。
  1. 时序记忆(GRUCell)
  • 在 step 推理时维护 hidden state,使 agent 具备一定记忆能力。
  • 在训练 forward 时按序列展开,并取指定 step 的 hidden 用于 Q 估计。
  1. 通信模块(CommBlock + Multi-Head Attention)
  • 根据 agent 之间相对距离与视野构造通信 mask。
  • 对通信邻居做多头注意力聚合并用 GRUCell 更新隐藏表征。
  1. Dueling Q 头
  • V(s)+A(s,a)组合得到Q(s,a),提升稳定性。

6. 如何运行与复现实验

6.1 训练

在已安装依赖的环境中运行:

python train.py

启动时会打印:

  • torch.cuda.is_available()
  • GPU 数量与名称(如可用)
  • default device selected: cuda|cpu

6.2 配置项

直接修改configs.py

  • 训练规模:num_actors,batch_size,learning_starts,training_times
  • 环境规模:init_env_settings,max_num_agents,max_map_lenght
  • 通信配置:max_comm_agents,num_comm_layers,num_comm_heads

6.3 生成测试集/评测

python test.py

测试集位于./test_set,评测时会从./models读取权重(详见test.py内的test_model)。

7. 工程踩坑:CPU/GPU 与 dtype(FP16/FP32)一致性

分布式训练中最常见的问题之一,是dtype 或 device 不一致导致的运行时报错,典型表现例如:

  • Input type (Half) and bias type (float) should be the same
  • mat1 and mat2 must have the same dtype, but got Half and float

这类问题的根因通常是:

  • 回放池/采样数据是 FP16
  • 模型参数是 FP32
  • AMP/autocast 使部分中间结果变成 FP16
  • CPU 上对 FP16 的算子支持不完整

解决思路(建议择一策略贯彻到底):

  • 策略 A:全链路 FP32(最稳,CPU/GPU 都可)

    • 采样 batch 用 float32
    • 模型参数 float32
    • 禁用 AMP(或仅在 GPU 上谨慎启用)
  • 策略 B:全链路 AMP/GPU(性能更好,但约束更多)

    • Learner 必须在 GPU
    • 输入/中间状态/损失计算路径遵守 AMP 规则
    • 关键张量与参数 dtype 要统一

本仓库已在模型训练前向中做了 dtype 对齐处理,以降低 dtype 混用导致的报错概率。

8. 下一步可以改进什么

  • 增加requirements.txtenvironment.yml,让依赖版本可复现。
  • 将 checkpoint 保存/加载流程与评测流程在 README/博客中进一步标准化。
  • 为训练与评测增加更清晰的日志与可视化(例如 TensorBoard)。

参考与致谢

  • 原始 DHC 项目与示意图来源:
    • https://github.com/ZiyuanMa/DHC
      .yml`,让依赖版本可复现。
  • 将 checkpoint 保存/加载流程与评测流程在 README/博客中进一步标准化。
  • 为训练与评测增加更清晰的日志与可视化(例如 TensorBoard)。

参考与致谢

  • 原始 DHC 项目与示意图来源:
    • https://github.com/ZiyuanMa/DHC
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/1/17 1:08:14

Font Awesome 品牌图标

Font Awesome 品牌图标(Brands Icons)详解 Font Awesome 的 Brands Icons 是专门用于展示知名品牌、公司、社交媒体、软件和技术平台的图标集合。这些图标都是矢量形式,基于官方商标设计,但 Font Awesome 强调:所有品…

作者头像 李华
网站建设 2026/1/15 2:02:22

SQLite Studio终极指南:一站式数据库管理解决方案

SQLite Studio终极指南:一站式数据库管理解决方案 【免费下载链接】sqlite-studio SQLite database explorer 项目地址: https://gitcode.com/gh_mirrors/sq/sqlite-studio 在数据驱动的时代,高效管理SQLite数据库成为开发者和数据分析师的必备技…

作者头像 李华
网站建设 2026/1/8 10:47:02

SQLPad查询结果缓存终极配置技巧:让你的重复查询速度提升10倍

SQLPad查询结果缓存终极配置技巧:让你的重复查询速度提升10倍 【免费下载链接】sqlpad Web-based SQL editor. Legacy project in maintenance mode. 项目地址: https://gitcode.com/gh_mirrors/sq/sqlpad 还在为SQLPad中重复查询的缓慢响应而烦恼吗&#xf…

作者头像 李华
网站建设 2026/1/15 1:33:12

Windows系统性能加速神器:OpenSpeedy完全使用手册

Windows系统性能加速神器:OpenSpeedy完全使用手册 【免费下载链接】OpenSpeedy 项目地址: https://gitcode.com/gh_mirrors/op/OpenSpeedy 还在为电脑卡顿、程序响应慢而烦恼吗?今天为大家推荐一款真正免费且高效的Windows系统加速工具——OpenS…

作者头像 李华
网站建设 2026/1/9 9:56:27

2026软件测试岗必问的100个面试题【含答案】

一、测试理论 1.什么是软件测试? 答:软件测试是通过执行预定的步骤和使用指定的数据,以确定软件系统在特定条件下是否满足预期的行为。 2.测试驱动开发(TDD)是什么? 答:测试驱动开发是一种开…

作者头像 李华
网站建设 2026/1/17 4:55:50

UABEA完全攻略:解锁Unity游戏资源提取与修改的终极指南

UABEA(Unity Asset Bundle Extractor Avalonia)是一款专为新版本Unity设计的开源资源提取工具,能够深度解析和操作Unity游戏中的各种资源文件。无论你是游戏开发者、资源分析师,还是游戏爱好者,UABEA都能为你打开一扇通…

作者头像 李华