torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/models/multiagent.py
@@ -0,0 +1,1067 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import abc
from collections.abc import Sequence
from copy import deepcopy
from textwrap import indent

import numpy as np
import torch
from tensordict import TensorDict
from torch import nn
from torchrl.data.utils import DEVICE_TYPING
from torchrl.modules.models import ConvNet, MLP
from torchrl.modules.models.utils import _reset_parameters_recursive


class MultiAgentNetBase(nn.Module):
    """A base class for multi-agent networks.

    .. note:: to initialize the MARL module parameters with the `torch.nn.init`
        module, please refer to the :meth:`get_stateful_net` and :meth:`from_stateful_net`
        methods.

    """

    _empty_net: nn.Module

    def __init__(
        self,
        *,
        n_agents: int,
        centralized: bool | None = None,
        share_params: bool | None = None,
        agent_dim: int | None = None,
        vmap_randomness: str = "different",
        use_td_params: bool = True,
        **kwargs,
    ):
        super().__init__()

        # For backward compatibility
        centralized = kwargs.pop("centralised", centralized)
        if centralized is None:
            raise TypeError("centralized arg must be passed.")
        if share_params is None:
            raise TypeError("share_params arg must be passed.")
        if agent_dim is None:
            raise TypeError("agent_dim arg must be passed.")

        self.use_td_params = use_td_params
        self.n_agents = n_agents
        self.share_params = share_params
        self.centralized = centralized
        self.agent_dim = agent_dim
        self._vmap_randomness = vmap_randomness

        agent_networks = [
            self._build_single_net(**kwargs)
            for _ in range(self.n_agents if not self.share_params else 1)
        ]
        initialized = True
        for p in agent_networks[0].parameters():
            if isinstance(p, torch.nn.UninitializedParameter):
                initialized = False
                break
        self.initialized = initialized
        self._make_params(agent_networks)

        # We make sure all params and buffers are on the 'meta' device.
        # To do this, we set the device keyword arg to 'meta' and temporarily change
        # the default device. Finally, we convert all params to plain 'meta' tensors
        # that are no longer nn.Parameters.
        kwargs["device"] = "meta"
        with torch.device("meta"):
            try:
                self._empty_net = self._build_single_net(**kwargs)
            except NotImplementedError as err:
                if "Cannot copy out of meta tensor" in str(err):
                    raise RuntimeError(
                        "The network was built using `factory().to(device)`; build the network directly "
                        "on device using `factory(device=device)` instead."
                    ) from err
                raise
            # Remove all parameters
            TensorDict.from_module(self._empty_net).data.to("meta").to_module(
                self._empty_net
            )
        if not self.use_td_params:
            self.params.to_module(self._empty_net)

    @property
    def vmap_randomness(self):
        if self.initialized:
            return self._vmap_randomness
        # The _BatchedUninitializedParameter and buffer classes are not batched
        # by vmap, so using "different" would raise an exception because vmap can't
        # find the batch dimension. This is OK though, since we won't have the same
        # config for every element (as one might expect from "same").
        return "same"

    def _make_params(self, agent_networks):
        if self.share_params:
            self.params = TensorDict.from_module(
                agent_networks[0], as_module=self.use_td_params
            )
        else:
            self.params = TensorDict.from_modules(
                *agent_networks, as_module=self.use_td_params
            )

    @abc.abstractmethod
    def _build_single_net(self, *, device, **kwargs):
        ...

    @abc.abstractmethod
    def _pre_forward_check(self, inputs):
        ...

    @staticmethod
    def vmap_func_module(module, *args, **kwargs):
        def exec_module(params, *input):
            with params.to_module(module):
                return module(*input)

        return torch.vmap(exec_module, *args, **kwargs)

    def forward(self, *inputs: tuple[torch.Tensor]) -> torch.Tensor:
        if len(inputs) > 1:
            inputs = torch.cat([*inputs], -1)
        else:
            inputs = inputs[0]

        # Convert agent_dim to a positive index for consistent output placement.
        # This ensures the agent dimension stays at the same position relative
        # to batch dimensions, even if the network changes the number of dimensions
        # (e.g., ConvNet collapses spatial dims).
        # NOTE: this must be computed BEFORE _pre_forward_check, which may modify the
        # input shape (e.g., centralized mode flattens the agent dimension).
        agent_dim_positive = self.agent_dim
        if agent_dim_positive < 0:
            agent_dim_positive = inputs.ndim + agent_dim_positive

        inputs = self._pre_forward_check(inputs)

        # If parameters are not shared, each agent has its own network
        if not self.share_params:
            if self.centralized:
                output = self.vmap_func_module(
                    self._empty_net,
                    (0, None),
                    (agent_dim_positive,),
                    randomness=self.vmap_randomness,
                )(self.params, inputs)
            else:
                output = self.vmap_func_module(
                    self._empty_net,
                    (0, agent_dim_positive),
                    (agent_dim_positive,),
                    randomness=self.vmap_randomness,
                )(self.params, inputs)

        # If parameters are shared, agents use the same network
        else:
            with self.params.to_module(self._empty_net):
                output = self._empty_net(inputs)

            if self.centralized:
                # If the parameters are shared and the network is centralized, all agents
                # have the same output. We expand it to maintain the agent dimension, but
                # the values will be the same for all agents.
                n_agent_outputs = output.shape[-1]
                output = output.view(*output.shape[:-1], n_agent_outputs)
                # Insert the agent dimension at the correct position
                output = output.unsqueeze(agent_dim_positive)
                # Build the expanded shape
                expand_shape = list(output.shape)
                expand_shape[agent_dim_positive] = self.n_agents
                output = output.expand(*expand_shape)

        if output.shape[agent_dim_positive] != self.n_agents:
            raise ValueError(
                f"Multi-agent network expected output with shape[{agent_dim_positive}]={self.n_agents}"
                f" but got {output.shape}"
            )

        return output

    def get_stateful_net(self, copy: bool = True):
        """Returns a stateful version of the network.

        This can be used to initialize parameters.

        Such networks will often not be callable out-of-the-box and will require a `vmap` call
        to be executable.

        Args:
            copy (bool, optional): if ``True``, a deepcopy of the network is made.
                Defaults to ``True``.

        If the parameters are modified in-place (recommended) there is no need to copy the
        parameters back into the MARL module.
        See :meth:`from_stateful_net` for details on how to re-populate the MARL model with
        parameters that have been re-initialized out-of-place.

        Examples:
            >>> from torchrl.modules import MultiAgentMLP
            >>> import torch
            >>> n_agents = 6
            >>> n_agent_inputs = 3
            >>> n_agent_outputs = 2
            >>> batch = 64
            >>> obs = torch.zeros(batch, n_agents, n_agent_inputs)
            >>> mlp = MultiAgentMLP(
            ...     n_agent_inputs=n_agent_inputs,
            ...     n_agent_outputs=n_agent_outputs,
            ...     n_agents=n_agents,
            ...     centralized=False,
            ...     share_params=False,
            ...     depth=2,
            ... )
            >>> snet = mlp.get_stateful_net()
            >>> def init(module):
            ...     if hasattr(module, "weight"):
            ...         torch.nn.init.kaiming_normal_(module.weight)
            >>> snet.apply(init)
            >>> # If the module has been updated out-of-place (not the case here) we can reset the params
            >>> mlp.from_stateful_net(snet)

        """
        if copy:
            try:
                net = deepcopy(self._empty_net)
            except RuntimeError as err:
                raise RuntimeError(
                    "Failed to deepcopy the module, consider using copy=False."
                ) from err
        else:
            net = self._empty_net
        self.params.to_module(net)
        return net

    def from_stateful_net(self, stateful_net: nn.Module):
        """Populates the parameters given a stateful version of the network.

        See :meth:`get_stateful_net` for details on how to gather a stateful version of the network.

        Args:
            stateful_net (nn.Module): the stateful network from which the params should be
                gathered.

        """
        params = TensorDict.from_module(stateful_net, as_module=True)
        keyset0 = set(params.keys(True, True))
        keyset1 = set(self.params.keys(True, True))
        if keyset0 != keyset1:
            raise RuntimeError(
                f"The keys of params and provided module differ: "
                f"{keyset1 - keyset0} are in self.params and not in the module, "
                f"{keyset0 - keyset1} are in the module but not in self.params."
            )
        self.params.data.update_(params.data)

    def __repr__(self):
        empty_net = self._empty_net
        with self.params.to_module(empty_net):
            module_repr = indent(str(empty_net), 4 * " ")
        n_agents = indent(f"n_agents={self.n_agents}", 4 * " ")
        share_params = indent(f"share_params={self.share_params}", 4 * " ")
        centralized = indent(f"centralized={self.centralized}", 4 * " ")
        agent_dim = indent(f"agent_dim={self.agent_dim}", 4 * " ")
        return f"{self.__class__.__name__}(\n{module_repr},\n{n_agents},\n{share_params},\n{centralized},\n{agent_dim})"

    def reset_parameters(self):
        """Resets the parameters of the model."""

        def vmap_reset_module(module, *args, **kwargs):
            def reset_module(params):
                with params.to_module(module):
                    _reset_parameters_recursive(module)
                    return params

            return torch.vmap(reset_module, *args, **kwargs)

        if not self.share_params:
            vmap_reset_module(self._empty_net, randomness="different")(self.params)
        else:
            with self.params.to_module(self._empty_net):
                _reset_parameters_recursive(self._empty_net)
class MultiAgentMLP(MultiAgentNetBase):
    """Multi-agent MLP.

    This is an MLP that can be used in multi-agent contexts.
    For example, as a policy or as a value function.
    See `examples/multiagent` for examples.

    It expects inputs with shape (*B, n_agents, n_agent_inputs).
    It returns outputs with shape (*B, n_agents, n_agent_outputs).

    If `share_params` is True, the same MLP will be used to make the forward pass for all agents (homogeneous policies).
    Otherwise, each agent will use a different MLP to process its input (heterogeneous policies).

    If `centralized` is True, each agent will use the inputs of all agents to compute its output
    (n_agent_inputs * n_agents will be the number of inputs for one agent).
    Otherwise, each agent will only use its own data as input.

    Args:
        n_agent_inputs (int or None): number of inputs for each agent. If ``None``,
            the number of inputs is lazily instantiated during the first call.
        n_agent_outputs (int): number of outputs for each agent.
        n_agents (int): number of agents.

    Keyword Args:
        centralized (bool): If `centralized` is True, each agent will use the inputs of all agents to compute its output
            (n_agent_inputs * n_agents will be the number of inputs for one agent).
            Otherwise, each agent will only use its own data as input.
        share_params (bool): If `share_params` is True, the same MLP will be used to make the forward pass
            for all agents (homogeneous policies). Otherwise, each agent will use a different MLP to process
            its input (heterogeneous policies).
        device (str or torch.device, optional): device to create the module on.
        depth (int, optional): depth of the network. A depth of 0 will produce a single linear layer network with the
            desired input and output size. A depth of 1 will create 2 linear layers, etc. If no depth is indicated,
            the depth information should be contained in the num_cells argument (see below). If num_cells is an
            iterable and depth is indicated, both should match: len(num_cells) must be equal to depth.
            default: 3.
        num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If
            an integer is provided, every layer will have the same number of cells. If an iterable is provided,
            the linear layers' out_features will match the content of num_cells.
            default: 32.
        activation_class (Type[nn.Module]): activation class to be used.
            default: nn.Tanh.
        use_td_params (bool, optional): if ``True``, the parameters can be found in `self.params`, which is a
            :class:`~tensordict.nn.TensorDictParams` object (which inherits both from `TensorDict` and `nn.Module`).
            If ``False``, parameters are contained in `self._empty_net`. All things considered, these two approaches
            should be roughly identical but not interchangeable: for instance, a ``state_dict`` created with
            ``use_td_params=True`` cannot be used when ``use_td_params=False``.
        **kwargs: additional arguments for :class:`torchrl.modules.models.MLP` can be passed to customize the MLPs.

    .. note:: to initialize the MARL module parameters with the `torch.nn.init`
        module, please refer to the :meth:`get_stateful_net` and :meth:`from_stateful_net`
        methods.

    Examples:
        >>> from torchrl.modules import MultiAgentMLP
        >>> import torch
        >>> n_agents = 6
        >>> n_agent_inputs = 3
        >>> n_agent_outputs = 2
        >>> batch = 64
        >>> obs = torch.zeros(batch, n_agents, n_agent_inputs)
        >>> # instantiate a local network shared by all agents (e.g. a parameter-shared policy)
        >>> mlp = MultiAgentMLP(
        ...     n_agent_inputs=n_agent_inputs,
        ...     n_agent_outputs=n_agent_outputs,
        ...     n_agents=n_agents,
        ...     centralized=False,
        ...     share_params=True,
        ...     depth=2,
        ... )
        >>> print(mlp)
        MultiAgentMLP(
            (agent_networks): ModuleList(
                (0): MLP(
                    (0): Linear(in_features=3, out_features=32, bias=True)
                    (1): Tanh()
                    (2): Linear(in_features=32, out_features=32, bias=True)
                    (3): Tanh()
                    (4): Linear(in_features=32, out_features=2, bias=True)
                )
            )
        )
        >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)

        Now let's instantiate a centralized network shared by all agents (e.g. a centralized value function):
        >>> mlp = MultiAgentMLP(
        ...     n_agent_inputs=n_agent_inputs,
        ...     n_agent_outputs=n_agent_outputs,
        ...     n_agents=n_agents,
        ...     centralized=True,
        ...     share_params=True,
        ...     depth=2,
        ... )
        >>> print(mlp)
        MultiAgentMLP(
            (agent_networks): ModuleList(
                (0): MLP(
                    (0): Linear(in_features=18, out_features=32, bias=True)
                    (1): Tanh()
                    (2): Linear(in_features=32, out_features=32, bias=True)
                    (3): Tanh()
                    (4): Linear(in_features=32, out_features=2, bias=True)
                )
            )
        )

        The input to the first layer is n_agents * n_agent_inputs: in this case the net acts
        as a centralized MLP (like a single huge agent).
        >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)

        Outputs will be identical for all agents.
        We can do both examples just shown but with an independent set of parameters for each agent.
        Let's show the centralized=False case:
        >>> mlp = MultiAgentMLP(
        ...     n_agent_inputs=n_agent_inputs,
        ...     n_agent_outputs=n_agent_outputs,
        ...     n_agents=n_agents,
        ...     centralized=False,
        ...     share_params=False,
        ...     depth=2,
        ... )
        >>> print(mlp)
        MultiAgentMLP(
            (agent_networks): ModuleList(
                (0-5): 6 x MLP(
                    (0): Linear(in_features=3, out_features=32, bias=True)
                    (1): Tanh()
                    (2): Linear(in_features=32, out_features=32, bias=True)
                    (3): Tanh()
                    (4): Linear(in_features=32, out_features=2, bias=True)
                )
            )
        )

        This is the same as in the first example, but now we have 6 MLPs, one per agent!
        >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
    """

    def __init__(
        self,
        n_agent_inputs: int | None,
        n_agent_outputs: int,
        n_agents: int,
        *,
        centralized: bool | None = None,
        share_params: bool | None = None,
        device: DEVICE_TYPING | None = None,
        depth: int | None = None,
        num_cells: Sequence | int | None = None,
        activation_class: type[nn.Module] | None = nn.Tanh,
        use_td_params: bool = True,
        **kwargs,
    ):
        self.n_agents = n_agents
        self.n_agent_inputs = n_agent_inputs
        self.n_agent_outputs = n_agent_outputs
        self.share_params = share_params
        self.centralized = centralized
        self.num_cells = num_cells
        self.activation_class = activation_class
        self.depth = depth

        super().__init__(
            n_agents=n_agents,
            centralized=centralized,
            share_params=share_params,
            device=device,
            agent_dim=-2,
            use_td_params=use_td_params,
            **kwargs,
        )

    def _pre_forward_check(self, inputs):
        if inputs.shape[-2] != self.n_agents:
            raise ValueError(
                f"Multi-agent network expected input with shape[-2]={self.n_agents},"
                f" but got {inputs.shape}"
            )
        # If the model is centralized, agents have full observability
        if self.centralized:
            inputs = inputs.flatten(-2, -1)
        return inputs

    def _build_single_net(self, *, device, **kwargs):
        n_agent_inputs = self.n_agent_inputs
        if self.centralized and n_agent_inputs is not None:
            n_agent_inputs = self.n_agent_inputs * self.n_agents
        return MLP(
            in_features=n_agent_inputs,
            out_features=self.n_agent_outputs,
            depth=self.depth,
            num_cells=self.num_cells,
            activation_class=self.activation_class,
            device=device,
            **kwargs,
        )
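Because `n_agent_inputs` may be ``None``, the per-agent `MLP` is built with lazy layers and the input size is discovered on the first forward pass. A short sketch of that documented lazy behavior (shapes are illustrative only):

```python
import torch
from torchrl.modules import MultiAgentMLP

# n_agent_inputs=None defers input-size inference to the first call:
# the underlying MLP uses lazy linear layers until then.
mlp = MultiAgentMLP(
    n_agent_inputs=None,
    n_agent_outputs=2,
    n_agents=3,
    centralized=False,
    share_params=True,
    depth=2,
)
obs = torch.randn(10, 3, 5)  # per-agent input size (5) is inferred here
out = mlp(obs)
assert out.shape == (10, 3, 2)
```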
class MultiAgentConvNet(MultiAgentNetBase):
    """Multi-agent CNN.

    In MARL settings, agents may or may not share the same policy for their actions: we say that the parameters can be shared or not. Similarly, a network may take the entire observation space (across agents) or operate on a per-agent basis to compute its output, which we refer to as "centralized" and "non-centralized", respectively.

    It expects inputs with shape ``(*B, n_agents, channels, x, y)``.

    .. note:: to initialize the MARL module parameters with the `torch.nn.init`
        module, please refer to the :meth:`~.get_stateful_net` and :meth:`~.from_stateful_net`
        methods.

    Args:
        n_agents (int): number of agents.
        centralized (bool): If ``True``, each agent will use the inputs of all agents to compute its output, resulting in input of shape ``(*B, n_agents * channels, x, y)``. Otherwise, each agent will only use its own data as input.
        share_params (bool): If ``True``, the same :class:`~torchrl.modules.ConvNet` will be used to make the forward pass
            for all agents (homogeneous policies). Otherwise, each agent will use a different :class:`~torchrl.modules.ConvNet` to process
            its input (heterogeneous policies).

    Keyword Args:
        in_features (int, optional): the input feature dimension. If left to ``None``,
            a lazy module is used.
        device (str or torch.device, optional): device to create the module on.
        num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If
            an integer is provided, every layer will have the same number of cells. If an iterable is provided,
            the linear layers' ``out_features`` will match the content of ``num_cells``.
        kernel_sizes (int, Sequence[Union[int, Sequence[int]]]): Kernel size(s) of the convolutional network.
            Defaults to ``5``.
        strides (int or Sequence[int]): Stride(s) of the convolutional network. If iterable, the length must match the
            depth, defined by the num_cells or depth arguments.
            Defaults to ``2``.
        activation_class (Type[nn.Module]): activation class to be used.
            Defaults to :class:`torch.nn.ELU`.
        use_td_params (bool, optional): if ``True``, the parameters can be found in `self.params`, which is a
            :class:`~tensordict.nn.TensorDictParams` object (which inherits both from `TensorDict` and `nn.Module`).
            If ``False``, parameters are contained in `self._empty_net`. All things considered, these two approaches
            should be roughly identical but not interchangeable: for instance, a ``state_dict`` created with
            ``use_td_params=True`` cannot be used when ``use_td_params=False``.
        **kwargs: additional arguments for :class:`~torchrl.modules.models.ConvNet` can be passed to customize the ConvNet.

    Examples:
        >>> import torch
        >>> from torchrl.modules import MultiAgentConvNet
        >>> batch = (3, 2)
        >>> n_agents = 7
        >>> channels, x, y = 3, 100, 100
        >>> obs = torch.randn(*batch, n_agents, channels, x, y)
        >>> # Let's consider a centralized network with shared parameters.
        >>> cnn = MultiAgentConvNet(
        ...     n_agents,
        ...     centralized=True,
        ...     share_params=True,
        ... )
        >>> print(cnn)
        MultiAgentConvNet(
            (agent_networks): ModuleList(
                (0): ConvNet(
                    (0): LazyConv2d(0, 32, kernel_size=(5, 5), stride=(2, 2))
                    (1): ELU(alpha=1.0)
                    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
                    (3): ELU(alpha=1.0)
                    (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
                    (5): ELU(alpha=1.0)
                    (6): SquashDims()
                )
            )
        )
        >>> result = cnn(obs)
        >>> # The final dimension of the resulting tensor is determined by the layer definition arguments and the shape of the input 'obs'.
        >>> print(result.shape)
        torch.Size([3, 2, 7, 2592])
        >>> # Since both observations and parameters are shared, we expect all agents to have identical outputs (e.g. for a value function)
        >>> print(all(result[0, 0, 0] == result[0, 0, 1]))
        True

        >>> # Alternatively, a local network with parameter sharing (e.g. a decentralized weight-sharing policy)
        >>> cnn = MultiAgentConvNet(
        ...     n_agents,
        ...     centralized=False,
        ...     share_params=True,
        ... )
        >>> print(cnn)
        MultiAgentConvNet(
            (agent_networks): ModuleList(
                (0): ConvNet(
                    (0): Conv2d(4, 32, kernel_size=(5, 5), stride=(2, 2))
                    (1): ELU(alpha=1.0)
                    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
                    (3): ELU(alpha=1.0)
                    (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
                    (5): ELU(alpha=1.0)
                    (6): SquashDims()
                )
            )
        )
        >>> result = cnn(obs)
        >>> print(result.shape)
        torch.Size([3, 2, 7, 2592])
        >>> # Parameters are shared but not observations, hence each agent has a different output.
        >>> print(all(result[0, 0, 0] == result[0, 0, 1]))
        False

        >>> # Or multiple local networks identical in structure but with differing weights.
        >>> cnn = MultiAgentConvNet(
        ...     n_agents,
        ...     centralized=False,
        ...     share_params=False,
        ... )
        >>> print(cnn)
        MultiAgentConvNet(
            (agent_networks): ModuleList(
                (0-6): 7 x ConvNet(
                    (0): Conv2d(4, 32, kernel_size=(5, 5), stride=(2, 2))
                    (1): ELU(alpha=1.0)
                    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
                    (3): ELU(alpha=1.0)
                    (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
                    (5): ELU(alpha=1.0)
                    (6): SquashDims()
                )
            )
        )
        >>> result = cnn(obs)
        >>> print(result.shape)
        torch.Size([3, 2, 7, 2592])
        >>> print(all(result[0, 0, 0] == result[0, 0, 1]))
        False

        >>> # Or where inputs are shared but not parameters.
        >>> cnn = MultiAgentConvNet(
        ...     n_agents,
        ...     centralized=True,
        ...     share_params=False,
        ... )
        >>> print(cnn)
        MultiAgentConvNet(
            (agent_networks): ModuleList(
                (0-6): 7 x ConvNet(
                    (0): Conv2d(28, 32, kernel_size=(5, 5), stride=(2, 2))
                    (1): ELU(alpha=1.0)
                    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
                    (3): ELU(alpha=1.0)
                    (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
                    (5): ELU(alpha=1.0)
                    (6): SquashDims()
                )
            )
        )
        >>> result = cnn(obs)
        >>> print(result.shape)
        torch.Size([3, 2, 7, 2592])
        >>> print(all(result[0, 0, 0] == result[0, 0, 1]))
        False
    """

    def __init__(
        self,
        n_agents: int,
        centralized: bool | None = None,
        share_params: bool | None = None,
        *,
        in_features: int | None = None,
        device: DEVICE_TYPING | None = None,
        num_cells: Sequence[int] | None = None,
        kernel_sizes: Sequence[int | Sequence[int]] | int = 5,
        strides: Sequence | int = 2,
        paddings: Sequence | int = 0,
        activation_class: type[nn.Module] = nn.ELU,
        use_td_params: bool = True,
        **kwargs,
    ):
        self.in_features = in_features
        self.num_cells = num_cells
        self.strides = strides
        self.kernel_sizes = kernel_sizes
        self.paddings = paddings
        self.activation_class = activation_class
        super().__init__(
            n_agents=n_agents,
            centralized=centralized,
            share_params=share_params,
            device=device,
            agent_dim=-4,
            use_td_params=use_td_params,
            **kwargs,
        )

    def _build_single_net(self, *, device, **kwargs):
        in_features = self.in_features
        if self.centralized and in_features is not None:
            in_features = in_features * self.n_agents
        return ConvNet(
            in_features=in_features,
            num_cells=self.num_cells,
            kernel_sizes=self.kernel_sizes,
            strides=self.strides,
            paddings=self.paddings,
            activation_class=self.activation_class,
            device=device,
            **kwargs,
        )

    def _pre_forward_check(self, inputs):
        if len(inputs.shape) < 4:
            raise ValueError(
                """Multi-agent network expects (*batch_size, n_agents, channels, x, y)"""
            )
        if inputs.shape[-4] != self.n_agents:
            raise ValueError(
                f"""Multi-agent network expects {self.n_agents} agents at shape[-4] but got {inputs.shape[-4]}"""
            )
        if self.centralized:
            # If the model is centralized, agents have full observability
            inputs = torch.flatten(inputs, -4, -3)
        return inputs
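When ``centralized=True``, `_pre_forward_check` merges the agent and channel dimensions so that a single ConvNet sees all agents' frames as extra channels. A quick shape check of that reshape (illustrative snippet, not from the diff):

```python
import torch

B, n_agents, C, x, y = (3, 2), 7, 3, 100, 100
obs = torch.randn(*B, n_agents, C, x, y)
# What _pre_forward_check does when centralized=True:
flat = torch.flatten(obs, -4, -3)
assert flat.shape == (*B, n_agents * C, x, y)  # (3, 2, 21, 100, 100)
```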
class Mixer(nn.Module):
    """A multi-agent value mixer.

    It transforms the local value of each agent's chosen action, of shape (*B, self.n_agents, 1),
    into a global value of shape (*B, 1).
    Used with the :class:`torchrl.objectives.QMixerLoss`.
    See `examples/multiagent/qmix_vdn.py` for examples.

    Args:
        n_agents (int): number of agents.
        needs_state (bool): whether the mixer takes a global state as input.
        state_shape (tuple or torch.Size): the shape of the state (excluding any leading batch dimensions).
        device (str or torch.device): torch device for the network.

    Examples:
        Creating a VDN mixer:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from tensordict.nn import TensorDictModule
        >>> from torchrl.modules.models.multiagent import VDNMixer
        >>> n_agents = 4
        >>> vdn = TensorDictModule(
        ...     module=VDNMixer(
        ...         n_agents=n_agents,
        ...         device="cpu",
        ...     ),
        ...     in_keys=[("agents", "chosen_action_value")],
        ...     out_keys=["chosen_action_value"],
        ... )
        >>> td = TensorDict({"agents": TensorDict({"chosen_action_value": torch.zeros(32, n_agents, 1)}, [32, n_agents])}, [32])
        >>> td
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([32, 4]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)
        >>> vdn(td)
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([32, 4]),
                    device=None,
                    is_shared=False),
                chosen_action_value: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)

        Creating a QMix mixer:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from tensordict.nn import TensorDictModule
        >>> from torchrl.modules.models.multiagent import QMixer
        >>> n_agents = 4
        >>> qmix = TensorDictModule(
        ...     module=QMixer(
        ...         state_shape=(64, 64, 3),
        ...         mixing_embed_dim=32,
        ...         n_agents=n_agents,
        ...         device="cpu",
        ...     ),
        ...     in_keys=[("agents", "chosen_action_value"), "state"],
        ...     out_keys=["chosen_action_value"],
        ... )
        >>> td = TensorDict({"agents": TensorDict({"chosen_action_value": torch.zeros(32, n_agents, 1)}, [32, n_agents]), "state": torch.zeros(32, 64, 64, 3)}, [32])
        >>> td
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([32, 4]),
                    device=None,
                    is_shared=False),
                state: Tensor(shape=torch.Size([32, 64, 64, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)
        >>> qmix(td)
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([32, 4]),
                    device=None,
                    is_shared=False),
                chosen_action_value: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                state: Tensor(shape=torch.Size([32, 64, 64, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)
    """

    def __init__(
        self,
        n_agents: int,
        needs_state: bool,
        state_shape: tuple[int, ...] | torch.Size,
        device: DEVICE_TYPING,
    ):
        super().__init__()

        self.n_agents = n_agents
        self.device = device
        self.needs_state = needs_state
        self.state_shape = state_shape

    def forward(self, *inputs: tuple[torch.Tensor]) -> torch.Tensor:
        """Forward pass of the mixer.

        Args:
            *inputs: The first input should be the value of the chosen action, of shape (*B, self.n_agents, 1),
                representing the local q value of each agent.
                The second input (optional, used only in some mixers)
                is the shared state of all agents, of shape (*B, *self.state_shape).

        Returns:
            The global value of the chosen actions obtained after mixing, with shape (*B, 1).

        """
        if not self.needs_state:
            if len(inputs) > 1:
                raise ValueError(
                    "Mixer that doesn't need state was passed more than 1 input"
                )
            chosen_action_value = inputs[0]
        else:
            if len(inputs) != 2:
                raise ValueError(
                    "Mixer that needs state was not passed exactly 2 inputs"
                )

            chosen_action_value, state = inputs

            if state.shape[-len(self.state_shape) :] != self.state_shape:
                raise ValueError(
                    f"Mixer network expected state with ending shape {self.state_shape},"
                    f" but got state shape {state.shape}"
                )

        if chosen_action_value.shape[-2:] != (self.n_agents, 1):
            raise ValueError(
                f"Mixer network expected chosen_action_value with last 2 dimensions {(self.n_agents, 1)},"
                f" but got {chosen_action_value.shape}"
            )
        batch_dims = chosen_action_value.shape[:-2]

        if not self.needs_state:
            output = self.mix(chosen_action_value, None)
        else:
            output = self.mix(chosen_action_value, state)

        if output.shape != (*batch_dims, 1):
            raise ValueError(
                f"Mixer network expected output with same shape as input minus the multi-agent dimension,"
                f" but got {output.shape}"
            )

        return output

    def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor):
        """Forward pass for the mixer.

        Args:
            chosen_action_value: Tensor of shape [*B, n_agents, 1]

        Returns:
            chosen_action_value: Tensor of shape [*B, 1]
        """
        raise NotImplementedError
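Subclasses only need to provide `mix`; `Mixer.forward` supplies all of the shape checking. To make the contract concrete, here is a hypothetical subclass (illustration only, not part of torchrl) that averages the local Q-values instead of summing or hyper-network mixing them:

```python
import torch
from torchrl.modules.models.multiagent import Mixer

class MeanMixer(Mixer):
    """Hypothetical mixer: Q_tot is the mean of the agents' local Q-values."""

    def __init__(self, n_agents: int, device="cpu"):
        super().__init__(
            needs_state=False,
            state_shape=torch.Size([]),
            n_agents=n_agents,
            device=device,
        )

    def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor):
        # (*B, n_agents, 1) -> (*B, 1)
        return chosen_action_value.mean(dim=-2)

q_locals = torch.ones(32, 4, 1)
assert MeanMixer(n_agents=4)(q_locals).shape == (32, 1)
```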
class VDNMixer(Mixer):
    """Value-Decomposition Network mixer.

    Mixes the local Q values of the agents into a global Q value by summing them together.
    From the paper https://arxiv.org/abs/1706.05296 .

    It transforms the local value of each agent's chosen action, of shape (*B, self.n_agents, 1),
    into a global value of shape (*B, 1).
    Used with the :class:`torchrl.objectives.QMixerLoss`.
    See `examples/multiagent/qmix_vdn.py` for examples.

    Args:
        n_agents (int): number of agents.
        device (str or torch.device): torch device for the network.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from tensordict.nn import TensorDictModule
        >>> from torchrl.modules.models.multiagent import VDNMixer
        >>> n_agents = 4
        >>> vdn = TensorDictModule(
        ...     module=VDNMixer(
        ...         n_agents=n_agents,
        ...         device="cpu",
        ...     ),
        ...     in_keys=[("agents", "chosen_action_value")],
        ...     out_keys=["chosen_action_value"],
        ... )
        >>> td = TensorDict({"agents": TensorDict({"chosen_action_value": torch.zeros(32, n_agents, 1)}, [32, n_agents])}, [32])
        >>> td
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([32, 4]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)
        >>> vdn(td)
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([32, 4]),
                    device=None,
                    is_shared=False),
                chosen_action_value: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)
    """

    def __init__(
        self,
        n_agents: int,
        device: DEVICE_TYPING,
    ):
        super().__init__(
            needs_state=False,
            state_shape=torch.Size([]),
            n_agents=n_agents,
            device=device,
        )

    def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor):
        return chosen_action_value.sum(dim=-2)
class QMixer(Mixer):
    """QMix mixer.

    Mixes the local Q values of the agents into a global Q value through a monotonic
    hyper-network whose parameters are obtained from a global state.
    From the paper https://arxiv.org/abs/1803.11485 .

    It transforms the local value of each agent's chosen action, of shape (*B, self.n_agents, 1),
    into a global value of shape (*B, 1).
    Used with the :class:`torchrl.objectives.QMixerLoss`.
    See `examples/multiagent/qmix_vdn.py` for examples.

    Args:
        state_shape (tuple or torch.Size): the shape of the state (excluding any leading batch dimensions).
        mixing_embed_dim (int): the size of the mixing embedding dimension.
        n_agents (int): number of agents.
        device (str or torch.device): torch device for the network.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from tensordict.nn import TensorDictModule
        >>> from torchrl.modules.models.multiagent import QMixer
        >>> n_agents = 4
        >>> qmix = TensorDictModule(
        ...     module=QMixer(
        ...         state_shape=(64, 64, 3),
        ...         mixing_embed_dim=32,
        ...         n_agents=n_agents,
        ...         device="cpu",
        ...     ),
        ...     in_keys=[("agents", "chosen_action_value"), "state"],
        ...     out_keys=["chosen_action_value"],
        ... )
        >>> td = TensorDict({"agents": TensorDict({"chosen_action_value": torch.zeros(32, n_agents, 1)}, [32, n_agents]), "state": torch.zeros(32, 64, 64, 3)}, [32])
        >>> td
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([32, 4]),
                    device=None,
                    is_shared=False),
                state: Tensor(shape=torch.Size([32, 64, 64, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)
        >>> qmix(td)
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([32, 4]),
                    device=None,
                    is_shared=False),
                chosen_action_value: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                state: Tensor(shape=torch.Size([32, 64, 64, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([32]),
            device=None,
            is_shared=False)
    """

    def __init__(
        self,
        state_shape: tuple[int, ...] | torch.Size,
        mixing_embed_dim: int,
        n_agents: int,
        device: DEVICE_TYPING,
    ):
        super().__init__(
            needs_state=True, state_shape=state_shape, n_agents=n_agents, device=device
        )

        self.embed_dim = mixing_embed_dim
        self.state_dim = int(np.prod(state_shape))

        self.hyper_w_1 = nn.Linear(
            self.state_dim, self.embed_dim * self.n_agents, device=self.device
        )
        self.hyper_w_final = nn.Linear(
            self.state_dim, self.embed_dim, device=self.device
        )

        # State-dependent bias for the hidden layer
        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim, device=self.device)

        # V(s) instead of a bias for the last layer
        self.V = nn.Sequential(
            nn.Linear(self.state_dim, self.embed_dim, device=self.device),
            nn.ReLU(),
            nn.Linear(self.embed_dim, 1, device=self.device),
        )

    def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor):
        bs = chosen_action_value.shape[:-2]
        state = state.view(-1, self.state_dim)
        chosen_action_value = chosen_action_value.view(-1, 1, self.n_agents)
        # First layer
        w1 = torch.abs(self.hyper_w_1(state))
        b1 = self.hyper_b_1(state)
        w1 = w1.view(-1, self.n_agents, self.embed_dim)
        b1 = b1.view(-1, 1, self.embed_dim)
        hidden = nn.functional.elu(
            torch.bmm(chosen_action_value, w1) + b1
        )  # [-1, 1, self.embed_dim]
        # Second layer
        w_final = torch.abs(self.hyper_w_final(state))
        w_final = w_final.view(-1, self.embed_dim, 1)
        # State-dependent bias
        v = self.V(state).view(-1, 1, 1)
        # Compute final output
        y = torch.bmm(hidden, w_final) + v  # [-1, 1, 1]
        # Reshape and return
        q_tot = y.view(*bs, 1)
        return q_tot
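The `torch.abs` applied to the hyper-network outputs is what enforces QMIX's monotonicity constraint: with non-negative mixing weights and a monotone ELU, ∂Q_tot/∂Q_i ≥ 0 for every agent i. A quick numerical check of that property (an illustrative snippet, not part of the package):

```python
import torch
from torchrl.modules.models.multiagent import QMixer

mixer = QMixer(state_shape=(6,), mixing_embed_dim=8, n_agents=3, device="cpu")
state = torch.randn(5, 6)
q_locals = torch.randn(5, 3, 1, requires_grad=True)
q_tot = mixer.mix(q_locals, state).sum()
q_tot.backward()
# Non-negative weights + monotone activation => monotone mixing.
assert (q_locals.grad >= 0).all()
```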