torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/libs/pettingzoo.py
@@ -0,0 +1,1042 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import copy
8
+ import importlib
9
+ import warnings
10
+
11
+ import numpy as np
12
+ import packaging
13
+ import torch
14
+ from tensordict import TensorDictBase
15
+
16
+ from torchrl.data.tensor_specs import Categorical, Composite, OneHot, Unbounded
17
+ from torchrl.envs.common import _EnvWrapper
18
+ from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, set_gym_backend
19
+ from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType
20
+
21
+ _has_pettingzoo = importlib.util.find_spec("pettingzoo") is not None
22
+
23
+
24
+ def _get_envs():
25
+ if not _has_pettingzoo:
26
+ raise ImportError("PettingZoo is not installed in your virtual environment.")
27
+ try:
28
+ from pettingzoo.utils.all_modules import all_environments
29
+ except ModuleNotFoundError as err:
30
+ warnings.warn(
31
+ f"PettingZoo failed to load all modules with error message {err}, trying to load individual modules."
32
+ )
33
+ all_environments = _load_available_envs()
34
+
35
+ return list(all_environments.keys())
36
+
37
+
38
+ def _load_available_envs() -> dict:
39
+ all_environments = {}
40
+ try:
41
+ from pettingzoo.mpe.all_modules import mpe_environments
42
+
43
+ all_environments.update(mpe_environments)
44
+ except ModuleNotFoundError as err:
45
+ warnings.warn(f"MPE environments failed to load with error message {err}.")
46
+ try:
47
+ from pettingzoo.sisl.all_modules import sisl_environments
48
+
49
+ all_environments.update(sisl_environments)
50
+ except ModuleNotFoundError as err:
51
+ warnings.warn(f"SISL environments failed to load with error message {err}.")
52
+ try:
53
+ from pettingzoo.classic.all_modules import classic_environments
54
+
55
+ all_environments.update(classic_environments)
56
+ except ModuleNotFoundError as err:
57
+ warnings.warn(f"Classic environments failed to load with error message {err}.")
58
+ try:
59
+ from pettingzoo.atari.all_modules import atari_environments
60
+
61
+ all_environments.update(atari_environments)
62
+ except ModuleNotFoundError as err:
63
+ warnings.warn(f"Atari environments failed to load with error message {err}.")
64
+ try:
65
+ from pettingzoo.butterfly.all_modules import butterfly_environments
66
+
67
+ all_environments.update(butterfly_environments)
68
+ except ModuleNotFoundError as err:
69
+ warnings.warn(
70
+ f"Butterfly environments failed to load with error message {err}."
71
+ )
72
+ return all_environments
73
+
74
+
75
+ def _extract_nested_with_index(data: np.ndarray | dict[str, np.ndarray], index: int):
76
+ if isinstance(data, np.ndarray):
77
+ return data[index]
78
+ elif isinstance(data, dict):
79
+ return {
80
+ key: _extract_nested_with_index(value, index) for key, value in data.items()
81
+ }
82
+ else:
83
+ raise NotImplementedError(f"Invalid type of data {data}")
84
+
85
+
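# Illustrative sketch (not part of the packaged file): how _extract_nested_with_index
# recursively indexes numpy arrays nested inside dicts, e.g. a per-group action batch
# whose space mixes flat and dict-structured components. The example data is hypothetical.
import numpy as np
from torchrl.envs.libs.pettingzoo import _extract_nested_with_index

group_action_np = {
    "movement": np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]),  # 3 agents x 2 dims
    "communication": np.array([2, 0, 1]),                         # 3 agents, discrete
}
# Index out agent 1: the dict structure is kept, the leading agent dimension is dropped.
agent_action = _extract_nested_with_index(group_action_np, index=1)
assert np.allclose(agent_action["movement"], np.array([0.0, 1.0]))
assert agent_action["communication"] == 0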
86
+ class PettingZooWrapper(_EnvWrapper):
87
+ """PettingZoo environment wrapper.
88
+
89
+ To install PettingZoo, follow the guide `here <https://github.com/Farama-Foundation/PettingZoo#installation>`__.
90
+
91
+ This class is a general torchrl wrapper for all PettingZoo environments.
92
+ It is able to wrap both ``pettingzoo.AECEnv`` and ``pettingzoo.ParallelEnv``.
93
+
94
+ Let's look at this in more detail:
95
+
96
+ In wrapped ``pettingzoo.ParallelEnv`` all agents will step at each environment step.
97
+ If the number of agents during the task varies, please set ``use_mask=True``.
98
+ ``"mask"`` will be provided
99
+ as an output in each group and should be used to mask out dead agents.
100
+ The environment will be reset as soon as one agent is done (unless ``done_on_any`` is ``False``).
101
+
102
+ In wrapped ``pettingzoo.AECEnv``, at each step only one agent will act.
103
+ For this reason, it is compulsory to set ``use_mask=True`` for this type of environment.
104
+ ``"mask"`` will be provided as an output for each group and can be used to mask out non-acting agents.
105
+ The environment will be reset only when all agents are done (unless ``done_on_any`` is ``True``).
106
+
107
+ If there are any unavailable actions for an agent,
108
+ the environment will also automatically update the mask of its ``action_spec`` and output an ``"action_mask"``
109
+ for each group to reflect the latest available actions. This should be passed to a masked distribution during
110
+ training.
111
+
112
+ As a feature of torchrl multiagent, you are able to control the grouping of agents in your environment.
113
+ You can group agents together (stacking their tensors) to leverage vectorization when passing them through the same
114
+ neural network. You can split agents into different groups when they are heterogeneous or should be processed by
115
+ different neural networks. To group, you just need to pass a ``group_map`` at env construction time.
116
+
117
+ By default, agents in pettingzoo will be grouped by name.
118
+ For example, with agents ``["agent_0","agent_1","agent_2","adversary_0"]``, the tensordicts will look like:
119
+
120
+ >>> print(env.rand_action(env.reset()))
121
+ TensorDict(
122
+ fields={
123
+ agent: TensorDict(
124
+ fields={
125
+ action: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.int64, is_shared=False),
126
+ action_mask: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.bool, is_shared=False),
127
+ done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
128
+ observation: Tensor(shape=torch.Size([3, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False),
129
+ terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
130
+ truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
131
+ batch_size=torch.Size([3]))},
132
+ adversary: TensorDict(
133
+ fields={
134
+ action: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.int64, is_shared=False),
135
+ action_mask: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.bool, is_shared=False),
136
+ done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
137
+ observation: Tensor(shape=torch.Size([1, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False),
138
+ terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
139
+ truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
140
+ batch_size=torch.Size([1]))},
141
+ batch_size=torch.Size([]))
142
+ >>> print(env.group_map)
143
+ {"agent": ["agent_0", "agent_1", "agent_2"], "adversary": ["adversary_0"]}
144
+
145
+ Otherwise, a group map can be specified or selected from some premade options.
146
+ See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
147
+ For example, you can provide ``MarlGroupMapType.ONE_GROUP_PER_AGENT``, indicating that each agent should
148
+ have its own tensordict (similar to the pettingzoo parallel API).
149
+
150
+ Grouping is useful for leveraging vectorization among agents whose data goes through the same
151
+ neural network.
152
+
153
+ Args:
154
+ env (``pettingzoo.utils.env.ParallelEnv`` or ``pettingzoo.utils.env.AECEnv``): the pettingzoo environment to wrap.
155
+ return_state (bool, optional): whether to return the global state from pettingzoo
156
+ (not available in all environments). Defaults to ``False``.
157
+ group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for
158
+ input/output. By default, agents will be grouped by their name. Otherwise, a group map can be specified
159
+ or selected from some premade options. See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
160
+ use_mask (bool, optional): whether the environment should output a ``"mask"``. This is compulsory in
161
+ wrapped ``pettingzoo.AECEnv`` to mask out non-acting agents and should be also used
162
+ for ``pettingzoo.ParallelEnv`` when the number of agents can vary. Defaults to ``False``.
163
+ categorical_actions (bool, optional): if the environment's actions are discrete, whether to transform
164
+ them to categorical or one-hot.
165
+ seed (int, optional): the seed. Defaults to ``None``.
166
+ done_on_any (bool, optional): whether the environment's done keys are set by aggregating the agent keys
167
+ using ``any()`` (when ``True``) or ``all()`` (when ``False``). Default (``None``) is to use ``any()`` for
168
+ parallel environments and ``all()`` for AEC ones.
169
+
170
+ Examples:
171
+ >>> # Parallel env
172
+ >>> from torchrl.envs.libs.pettingzoo import PettingZooWrapper
173
+ >>> from pettingzoo.butterfly import pistonball_v6
174
+ >>> kwargs = {"n_pistons": 21, "continuous": True}
175
+ >>> env = PettingZooWrapper(
176
+ ... env=pistonball_v6.parallel_env(**kwargs),
177
+ ... return_state=True,
178
+ ... group_map=None, # Use default for parallel (all pistons grouped together)
179
+ ... )
180
+ >>> print(env.group_map)
181
+ ... {'piston': ['piston_0', 'piston_1', ..., 'piston_20']}
182
+ >>> env.rollout(10)
183
+ >>> # AEC env
184
+ >>> from pettingzoo.classic import tictactoe_v3
185
+ >>> from torchrl.envs.libs.pettingzoo import PettingZooWrapper
186
+ >>> from torchrl.envs.utils import MarlGroupMapType
187
+ >>> env = PettingZooWrapper(
188
+ ... env=tictactoe_v3.env(),
189
+ ... use_mask=True, # Must use it since one player plays at a time
190
+ ... group_map=None # Use default for AEC (one group per player)
191
+ ... )
192
+ >>> print(env.group_map)
193
+ ... {'player_1': ['player_1'], 'player_2': ['player_2']}
194
+ >>> env.rollout(10)
195
+ """
196
+
197
+ git_url = "https://github.com/Farama-Foundation/PettingZoo"
198
+ libname = "pettingzoo"
199
+
200
+ @_classproperty
201
+ def available_envs(cls):
202
+ if not _has_pettingzoo:
203
+ return []
204
+ return list(_get_envs())
205
+
206
+ def __init__(
207
+ self,
208
+ env: (
209
+ pettingzoo.utils.env.ParallelEnv # noqa: F821
210
+ | pettingzoo.utils.env.AECEnv # noqa: F821
211
+ ) = None,
212
+ return_state: bool = False,
213
+ group_map: MarlGroupMapType | dict[str, list[str]] | None = None,
214
+ use_mask: bool = False,
215
+ categorical_actions: bool = True,
216
+ seed: int | None = None,
217
+ done_on_any: bool | None = None,
218
+ **kwargs,
219
+ ):
220
+ if env is not None:
221
+ kwargs["env"] = env
222
+
223
+ self.group_map = group_map
224
+ self.return_state = return_state
225
+ self.seed = seed
226
+ self.use_mask = use_mask
227
+ self.categorical_actions = categorical_actions
228
+ self.done_on_any = done_on_any
229
+
230
+ super().__init__(**kwargs, allow_done_after_reset=True)
231
+
232
+ def _get_default_group_map(self, agent_names: list[str]):
233
+ # This function performs the default grouping in pettingzoo
234
+ if not self.parallel:
235
+ # In AEC envs we will have one group per agent by default
236
+ group_map = MarlGroupMapType.ONE_GROUP_PER_AGENT.get_group_map(agent_names)
237
+ else:
238
+ # In parallel envs, by default
239
+ # Agents with names "str_int" will be grouped in group name "str"
240
+ group_map = {}
241
+ for agent_name in agent_names:
242
+ # See if the agent follows the convention "name_int"
243
+ follows_convention = True
244
+ agent_name_split = agent_name.split("_")
245
+ if len(agent_name_split) == 1:
246
+ follows_convention = False
247
+ try:
248
+ int(agent_name_split[-1])
249
+ except ValueError:
250
+ follows_convention = False
251
+
252
+ # If not, just put it in a single group
253
+ if not follows_convention:
254
+ group_map[agent_name] = [agent_name]
255
+ # Otherwise, group it with other agents that follow the same convention
256
+ else:
257
+ group_name = "_".join(agent_name_split[:-1])
258
+ if group_name in group_map:
259
+ group_map[group_name].append(agent_name)
260
+ else:
261
+ group_map[group_name] = [agent_name]
262
+
263
+ return group_map
264
+
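# Illustrative sketch (not part of the packaged file): the default grouping convention
# that _get_default_group_map applies to parallel envs. Agent names ending in "_<int>"
# share the prefix group; any other name becomes its own group (AEC envs instead
# default to MarlGroupMapType.ONE_GROUP_PER_AGENT).
def _default_parallel_grouping_sketch(agent_names):
    group_map = {}
    for name in agent_names:
        prefix, _, suffix = name.rpartition("_")
        group = prefix if prefix and suffix.isdigit() else name
        group_map.setdefault(group, []).append(name)
    return group_map

assert _default_parallel_grouping_sketch(
    ["agent_0", "agent_1", "adversary_0", "referee"]
) == {
    "agent": ["agent_0", "agent_1"],
    "adversary": ["adversary_0"],
    "referee": ["referee"],
}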
265
+ @property
266
+ def lib(self):
267
+ import pettingzoo
268
+
269
+ return pettingzoo
270
+
271
+ def _build_env(
272
+ self,
273
+ env: (
274
+ pettingzoo.utils.env.ParallelEnv # noqa: F821
275
+ | pettingzoo.utils.env.AECEnv # noqa: F821
276
+ ),
277
+ ):
278
+ import pettingzoo
279
+
280
+ if packaging.version.parse(pettingzoo.__version__).base_version != "1.24.3":
281
+ warnings.warn(
282
+ "PettingZoo in TorchRL is tested using version == 1.24.3 , "
283
+ "If you are using a different version and are experiencing compatibility issues,"
284
+ "please raise an issue in the TorchRL github."
285
+ )
286
+
287
+ self.parallel = isinstance(env, pettingzoo.utils.env.ParallelEnv)
288
+ if not self.parallel and not self.use_mask:
289
+ raise ValueError("For AEC environments you need to set use_mask=True")
290
+ if len(self.batch_size):
291
+ raise RuntimeError(
292
+ f"PettingZoo does not support custom batch_size {self.batch_size}."
293
+ )
294
+
295
+ return env
296
+
297
+ @set_gym_backend("gymnasium")
298
+ def _make_specs(
299
+ self,
300
+ env: (
301
+ pettingzoo.utils.env.ParallelEnv # noqa: F821
302
+ | pettingzoo.utils.env.AECEnv # noqa: F821
303
+ ),
304
+ ) -> None:
305
+ # Set default for done on any or all
306
+ if self.done_on_any is None:
307
+ self.done_on_any = self.parallel
308
+
309
+ # Create and check group map
310
+ if self.group_map is None:
311
+ self.group_map = self._get_default_group_map(self.possible_agents)
312
+ elif isinstance(self.group_map, MarlGroupMapType):
313
+ self.group_map = self.group_map.get_group_map(self.possible_agents)
314
+ check_marl_grouping(self.group_map, self.possible_agents)
315
+ self.has_action_mask = {group: False for group in self.group_map.keys()}
316
+
317
+ action_spec = Composite()
318
+ observation_spec = Composite()
319
+ reward_spec = Composite()
320
+ done_spec = Composite(
321
+ {
322
+ "done": Categorical(
323
+ n=2,
324
+ shape=torch.Size((1,)),
325
+ dtype=torch.bool,
326
+ device=self.device,
327
+ ),
328
+ "terminated": Categorical(
329
+ n=2,
330
+ shape=torch.Size((1,)),
331
+ dtype=torch.bool,
332
+ device=self.device,
333
+ ),
334
+ "truncated": Categorical(
335
+ n=2,
336
+ shape=torch.Size((1,)),
337
+ dtype=torch.bool,
338
+ device=self.device,
339
+ ),
340
+ },
341
+ )
342
+ for group, agents in self.group_map.items():
343
+ (
344
+ group_observation_spec,
345
+ group_action_spec,
346
+ group_reward_spec,
347
+ group_done_spec,
348
+ ) = self._make_group_specs(group_name=group, agent_names=agents)
349
+ action_spec[group] = group_action_spec
350
+ observation_spec[group] = group_observation_spec
351
+ reward_spec[group] = group_reward_spec
352
+ done_spec[group] = group_done_spec
353
+
354
+ self.action_spec = action_spec
355
+ self.observation_spec = observation_spec
356
+ self.reward_spec = reward_spec
357
+ self.done_spec = done_spec
358
+
359
+ def _make_group_specs(self, group_name: str, agent_names: list[str]):
360
+ n_agents = len(agent_names)
361
+ action_specs = []
362
+ observation_specs = []
363
+ for agent in agent_names:
364
+ action_specs.append(
365
+ Composite(
366
+ {
367
+ "action": _gym_to_torchrl_spec_transform(
368
+ self.action_space(agent),
369
+ remap_state_to_observation=False,
370
+ categorical_action_encoding=self.categorical_actions,
371
+ device=self.device,
372
+ )
373
+ },
374
+ )
375
+ )
376
+ observation_specs.append(
377
+ Composite(
378
+ {
379
+ "observation": _gym_to_torchrl_spec_transform(
380
+ self.observation_space(agent),
381
+ remap_state_to_observation=False,
382
+ device=self.device,
383
+ )
384
+ }
385
+ )
386
+ )
387
+ group_action_spec = torch.stack(action_specs, dim=0)
388
+ group_observation_spec = torch.stack(observation_specs, dim=0)
389
+
390
+ # Sometimes the observation spec contains an action mask.
391
+ # Or sometimes the info spec contains an action mask.
392
+ # We uniform this by removing it from both places and optionally set it in a standard location.
393
+ group_observation_inner_spec = group_observation_spec["observation"]
394
+ if (
395
+ isinstance(group_observation_inner_spec, Composite)
396
+ and "action_mask" in group_observation_inner_spec.keys()
397
+ ):
398
+ self.has_action_mask[group_name] = True
399
+ del group_observation_inner_spec["action_mask"]
400
+ group_observation_spec["action_mask"] = Categorical(
401
+ n=2,
402
+ shape=group_action_spec["action"].shape
403
+ if not self.categorical_actions
404
+ else group_action_spec["action"].to_one_hot_spec().shape,
405
+ dtype=torch.bool,
406
+ device=self.device,
407
+ )
408
+
409
+ if self.use_mask:
410
+ group_observation_spec["mask"] = Categorical(
411
+ n=2,
412
+ shape=torch.Size((n_agents,)),
413
+ dtype=torch.bool,
414
+ device=self.device,
415
+ )
416
+
417
+ group_reward_spec = Composite(
418
+ {
419
+ "reward": Unbounded(
420
+ shape=torch.Size((n_agents, 1)),
421
+ device=self.device,
422
+ dtype=torch.float32,
423
+ )
424
+ },
425
+ shape=torch.Size((n_agents,)),
426
+ )
427
+ group_done_spec = Composite(
428
+ {
429
+ "done": Categorical(
430
+ n=2,
431
+ shape=torch.Size((n_agents, 1)),
432
+ dtype=torch.bool,
433
+ device=self.device,
434
+ ),
435
+ "terminated": Categorical(
436
+ n=2,
437
+ shape=torch.Size((n_agents, 1)),
438
+ dtype=torch.bool,
439
+ device=self.device,
440
+ ),
441
+ "truncated": Categorical(
442
+ n=2,
443
+ shape=torch.Size((n_agents, 1)),
444
+ dtype=torch.bool,
445
+ device=self.device,
446
+ ),
447
+ },
448
+ shape=torch.Size((n_agents,)),
449
+ )
450
+ return (
451
+ group_observation_spec,
452
+ group_action_spec,
453
+ group_reward_spec,
454
+ group_done_spec,
455
+ )
456
+
457
+ def _check_kwargs(self, kwargs: dict):
458
+ import pettingzoo
459
+
460
+ if "env" not in kwargs:
461
+ raise TypeError("Could not find environment key 'env' in kwargs.")
462
+ env = kwargs["env"]
463
+ if not isinstance(
464
+ env, (pettingzoo.utils.env.ParallelEnv, pettingzoo.utils.env.AECEnv)
465
+ ):
466
+ raise TypeError("env is not of type expected.")
467
+
468
+ def _init_env(self):
469
+ # Add info
470
+ if self.parallel:
471
+ _, info_dict = self._reset_parallel(seed=self.seed)
472
+ else:
473
+ _, info_dict = self._reset_aec(seed=self.seed)
474
+
475
+ for group, agents in self.group_map.items():
476
+ info_specs = []
477
+ for agent in agents:
478
+ info_specs.append(
479
+ Composite(
480
+ {
481
+ "info": Composite(
482
+ {
483
+ key: Unbounded(
484
+ shape=torch.as_tensor(value).shape,
485
+ device=self.device,
486
+ )
487
+ for key, value in info_dict[agent].items()
488
+ }
489
+ )
490
+ },
491
+ device=self.device,
492
+ )
493
+ )
494
+ info_specs = torch.stack(info_specs, dim=0)
495
+ if ("info", "action_mask") in info_specs.keys(True, True):
496
+ if not self.has_action_mask[group]:
497
+ self.has_action_mask[group] = True
498
+ group_action_spec = self.input_spec[
499
+ "full_action_spec", group, "action"
500
+ ]
501
+ self.observation_spec[group]["action_mask"] = Categorical(
502
+ n=2,
503
+ shape=group_action_spec.shape
504
+ if not self.categorical_actions
505
+ else group_action_spec.to_one_hot_spec().shape,
506
+ dtype=torch.bool,
507
+ device=self.device,
508
+ )
509
+ group_inner_info_spec = info_specs["info"]
510
+ del group_inner_info_spec["action_mask"]
511
+
512
+ if len(info_specs["info"].keys()):
513
+ self.observation_spec[group].update(info_specs)
514
+
515
+ if self.return_state:
516
+ try:
517
+ state_spec = _gym_to_torchrl_spec_transform(
518
+ self.state_space,
519
+ remap_state_to_observation=False,
520
+ device=self.device,
521
+ )
522
+ except AttributeError:
523
+ state_example = torch.as_tensor(self.state(), device=self.device)
524
+ state_spec = Unbounded(
525
+ shape=state_example.shape,
526
+ dtype=state_example.dtype,
527
+ device=self.device,
528
+ )
529
+ self.observation_spec["state"] = state_spec
530
+
531
+ # Caching
532
+ self.cached_reset_output_zero = self.observation_spec.zero()
533
+ self.cached_reset_output_zero.update(self.output_spec["full_done_spec"].zero())
534
+
535
+ self.cached_step_output_zero = self.observation_spec.zero()
536
+ self.cached_step_output_zero.update(self.output_spec["full_reward_spec"].zero())
537
+ self.cached_step_output_zero.update(self.output_spec["full_done_spec"].zero())
538
+
539
+ def _set_seed(self, seed: int | None) -> None:
540
+ self.seed = seed
541
+ self.reset(seed=self.seed)
542
+
543
+ def _reset(
544
+ self, tensordict: TensorDictBase | None = None, **kwargs
545
+ ) -> TensorDictBase:
546
+ if tensordict is not None:
547
+ _reset = tensordict.get("_reset", None)
548
+ if _reset is not None and not _reset.all():
549
+ raise RuntimeError(
550
+ f"An attempt to call {type(self)}._reset was made when no "
551
+ f"reset signal could be found. Expected '_reset' entry to "
552
+ f"be `tensor(True)` or `None` but got `{_reset}`."
553
+ )
554
+ if self.parallel:
555
+ # This resets when any is done
556
+ observation_dict, info_dict = self._reset_parallel(**kwargs)
557
+ else:
558
+ # This resets when all are done
559
+ observation_dict, info_dict = self._reset_aec(**kwargs)
560
+
561
+ # We start with zeroed data and fill in the data for alive agents
562
+ tensordict_out = self.cached_reset_output_zero.clone()
563
+ # Update the "mask" for non-acting agents
564
+ self._update_agent_mask(tensordict_out)
565
+ # Update the "action_mask" for non-available actions
566
+ observation_dict, info_dict = self._update_action_mask(
567
+ tensordict_out, observation_dict, info_dict
568
+ )
569
+
570
+ # Now we get the data (obs and info)
571
+ for group, agent_names in self.group_map.items():
572
+ group_observation = tensordict_out.get((group, "observation"))
573
+ group_info = tensordict_out.get((group, "info"), None)
574
+
575
+ for index, agent in enumerate(agent_names):
576
+ group_observation[index] = self.observation_spec[group, "observation"][
577
+ index
578
+ ].encode(observation_dict[agent])
579
+ if group_info is not None:
580
+ agent_info_dict = info_dict[agent]
581
+ for agent_info, value in agent_info_dict.items():
582
+ group_info.get(agent_info)[index] = torch.as_tensor(
583
+ value, device=self.device
584
+ )
585
+
586
+ if self.return_state:
587
+ state = torch.as_tensor(self.state(), device=self.device)
588
+ tensordict_out.set("state", state)
589
+
590
+ return tensordict_out
591
+
592
+ def _reset_aec(self, **kwargs) -> tuple[dict, dict]:
593
+ self._env.reset(**kwargs)
594
+
595
+ observation_dict = {
596
+ agent: self._env.observe(agent) for agent in self.possible_agents
597
+ }
598
+ info_dict = self._env.infos
599
+ return observation_dict, info_dict
600
+
601
+ def _reset_parallel(self, **kwargs) -> tuple[dict, dict]:
602
+ return self._env.reset(**kwargs)
603
+
604
+ def _step(
605
+ self,
606
+ tensordict: TensorDictBase,
607
+ ) -> TensorDictBase:
608
+ if self.parallel:
609
+ (
610
+ observation_dict,
611
+ rewards_dict,
612
+ terminations_dict,
613
+ truncations_dict,
614
+ info_dict,
615
+ ) = self._step_parallel(tensordict)
616
+ else:
617
+ (
618
+ observation_dict,
619
+ rewards_dict,
620
+ terminations_dict,
621
+ truncations_dict,
622
+ info_dict,
623
+ ) = self._step_aec(tensordict)
624
+
625
+ # We start with zeroed data and fill in the data for alive agents
626
+ tensordict_out = self.cached_step_output_zero.clone()
627
+ # Update the "mask" for non-acting agents
628
+ self._update_agent_mask(tensordict_out)
629
+ # Update the "action_mask" for non-available actions
630
+ observation_dict, info_dict = self._update_action_mask(
631
+ tensordict_out, observation_dict, info_dict
632
+ )
633
+
634
+ # Now we get the data
635
+ for group, agent_names in self.group_map.items():
636
+ group_observation = tensordict_out.get((group, "observation"))
637
+ group_reward = tensordict_out.get((group, "reward"))
638
+ group_done = tensordict_out.get((group, "done"))
639
+ group_terminated = tensordict_out.get((group, "terminated"))
640
+ group_truncated = tensordict_out.get((group, "truncated"))
641
+ group_info = tensordict_out.get((group, "info"), None)
642
+
643
+ for index, agent in enumerate(agent_names):
644
+ if agent in observation_dict: # Live agents
645
+ group_observation[index] = self.observation_spec[
646
+ group, "observation"
647
+ ][index].encode(observation_dict[agent])
648
+ group_reward[index] = torch.tensor(
649
+ rewards_dict[agent],
650
+ device=self.device,
651
+ dtype=torch.float32,
652
+ )
653
+ group_done[index] = torch.tensor(
654
+ terminations_dict[agent] or truncations_dict[agent],
655
+ device=self.device,
656
+ dtype=torch.bool,
657
+ )
658
+ group_truncated[index] = torch.tensor(
659
+ truncations_dict[agent],
660
+ device=self.device,
661
+ dtype=torch.bool,
662
+ )
663
+ group_terminated[index] = torch.tensor(
664
+ terminations_dict[agent],
665
+ device=self.device,
666
+ dtype=torch.bool,
667
+ )
668
+
669
+ if group_info is not None:
670
+ agent_info_dict = info_dict[agent]
671
+ for agent_info, value in agent_info_dict.items():
672
+ group_info.get(agent_info)[index] = torch.tensor(
673
+ value, device=self.device
674
+ )
675
+
676
+ elif self.use_mask:
677
+ if agent in self.agents:
678
+ raise ValueError(
679
+ f"Dead agent {agent} not found in step observation but still available in {self.agents}"
680
+ )
681
+ # Dead agent
682
+ terminated = (
683
+ terminations_dict[agent] if agent in terminations_dict else True
684
+ )
685
+ truncated = (
686
+ truncations_dict[agent] if agent in truncations_dict else True
687
+ )
688
+ done = terminated or truncated
689
+ group_done[index] = done
690
+ group_terminated[index] = terminated
691
+ group_truncated[index] = truncated
692
+
693
+ else:
694
+ # Dead agent, if we are not masking it out, this is not allowed
695
+ raise ValueError(
696
+ "Dead agents found in the environment,"
697
+ " you need to set use_mask=True to allow this."
698
+ )
699
+
700
+ # set done values
701
+ done, terminated, truncated = self._aggregate_done(
702
+ tensordict_out, use_any=self.done_on_any
703
+ )
704
+
705
+ tensordict_out.set("done", done)
706
+ tensordict_out.set("terminated", terminated)
707
+ tensordict_out.set("truncated", truncated)
708
+
709
+ if self.return_state:
710
+ state = torch.as_tensor(self.state(), device=self.device)
711
+ tensordict_out.set("state", state)
712
+
713
+ return tensordict_out
714
+
715
+ def _aggregate_done(self, tensordict_out, use_any):
716
+ done = False if use_any else True
717
+ truncated = False if use_any else True
718
+ terminated = False if use_any else True
719
+ for key in self.done_keys:
720
+ if isinstance(key, tuple): # Only look at group keys
721
+ if use_any:
722
+ if key[-1] == "done":
723
+ done = done | tensordict_out.get(key).any()
724
+ if key[-1] == "terminated":
725
+ terminated = terminated | tensordict_out.get(key).any()
726
+ if key[-1] == "truncated":
727
+ truncated = truncated | tensordict_out.get(key).any()
728
+ if done and terminated and truncated:
729
+ # no need to proceed further, all values are flipped
730
+ break
731
+ else:
732
+ if key[-1] == "done":
733
+ done = done & tensordict_out.get(key).all()
734
+ if key[-1] == "terminated":
735
+ terminated = terminated & tensordict_out.get(key).all()
736
+ if key[-1] == "truncated":
737
+ truncated = truncated & tensordict_out.get(key).all()
738
+ if not done and not terminated and not truncated:
739
+ # no need to proceed further, all values are flipped
740
+ break
741
+ return (
742
+ torch.tensor([done], device=self.device),
743
+ torch.tensor([terminated], device=self.device),
744
+ torch.tensor([truncated], device=self.device),
745
+ )
746
+
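# Illustrative sketch (not part of the packaged file): the env-level aggregation
# performed by _aggregate_done above. With done_on_any=True (the parallel-env default)
# the global done flips as soon as any agent in any group is done; with
# done_on_any=False (the AEC default) it flips only once every agent is done.
import torch

group_dones = {
    "agent": torch.tensor([[True], [False], [False]]),  # 3 agents, one already done
    "adversary": torch.tensor([[False]]),                # 1 agent, still alive
}
done_if_any = any(d.any() for d in group_dones.values())  # True  (done_on_any=True)
done_if_all = all(d.all() for d in group_dones.values())  # False (done_on_any=False)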
747
+ def _step_parallel(
748
+ self,
749
+ tensordict: TensorDictBase,
750
+ ) -> tuple[dict, dict, dict, dict, dict]:
751
+ action_dict = {}
752
+ for group, agents in self.group_map.items():
753
+ group_action = tensordict.get((group, "action"))
754
+ group_action_np = self.input_spec[
755
+ "full_action_spec", group, "action"
756
+ ].to_numpy(group_action)
757
+ for index, agent in enumerate(agents):
758
+ # group_action_np can be a dict or an array. We need to recursively index it
759
+ action = _extract_nested_with_index(group_action_np, index)
760
+ action_dict[agent] = action
761
+
762
+ return self._env.step(action_dict)
763
+
764
+ def _step_aec(
765
+ self,
766
+ tensordict: TensorDictBase,
767
+ ) -> tuple[dict, dict, dict, dict, dict]:
768
+ for group, agents in self.group_map.items():
769
+ if self.agent_selection in agents:
770
+ agent_index = agents.index(self._env.agent_selection)
771
+ group_action = tensordict.get((group, "action"))
772
+ group_action_np = self.input_spec[
773
+ "full_action_spec", group, "action"
774
+ ].to_numpy(group_action)
775
+ # group_action_np can be a dict or an array. We need to recursively index it
776
+ action = _extract_nested_with_index(group_action_np, agent_index)
777
+ break
778
+
779
+ self._env.step(action)
780
+ terminations_dict = self._env.terminations
781
+ truncations_dict = self._env.truncations
782
+ info_dict = self._env.infos
783
+ rewards_dict = self._env.rewards
784
+ observation_dict = {
785
+ agent: self._env.observe(agent) for agent in self.possible_agents
786
+ }
787
+ return (
788
+ observation_dict,
789
+ rewards_dict,
790
+ terminations_dict,
791
+ truncations_dict,
792
+ info_dict,
793
+ )
794
+
795
+ def _update_action_mask(self, td, observation_dict, info_dict):
796
+ # Since we remove the action_mask keys we need to copy the data
797
+ observation_dict = copy.deepcopy(observation_dict)
798
+ info_dict = copy.deepcopy(info_dict)
799
+ # In AEC only one agent acts, in parallel env self.agents contains the agents alive
800
+ agents_acting = self.agents if self.parallel else [self.agent_selection]
801
+
802
+ for group, agents in self.group_map.items():
803
+ if self.has_action_mask[group]:
804
+ group_mask = td.get((group, "action_mask"))
805
+ group_mask += True
806
+ for index, agent in enumerate(agents):
807
+ agent_obs = observation_dict[agent]
808
+ agent_info = info_dict[agent]
809
+ if isinstance(agent_obs, dict) and "action_mask" in agent_obs:
810
+ if agent in agents_acting:
811
+ group_mask[index] = torch.tensor(
812
+ agent_obs["action_mask"],
813
+ device=self.device,
814
+ dtype=torch.bool,
815
+ )
816
+ del agent_obs["action_mask"]
817
+ elif isinstance(agent_info, dict) and "action_mask" in agent_info:
818
+ if agent in agents_acting:
819
+ group_mask[index] = torch.tensor(
820
+ agent_info["action_mask"],
821
+ device=self.device,
822
+ dtype=torch.bool,
823
+ )
824
+ del agent_info["action_mask"]
825
+
826
+ group_action_spec = self.input_spec["full_action_spec", group, "action"]
827
+ if isinstance(group_action_spec, (Categorical, OneHot)):
828
+ # We update the mask for available actions
829
+ group_action_spec.update_mask(group_mask.clone())
830
+
831
+ return observation_dict, info_dict
832
+
833
+ def _update_agent_mask(self, td):
834
+ if self.use_mask:
835
+ # In AEC only one agent acts, in parallel env self.agents contains the agents alive
836
+ agents_acting = self.agents if self.parallel else [self.agent_selection]
837
+ for group, agents in self.group_map.items():
838
+ group_mask = td.get((group, "mask"))
839
+ group_mask += True
840
+
841
+ # We now add dead agents to the mask
842
+ for index, agent in enumerate(agents):
843
+ if agent not in agents_acting:
844
+ group_mask[index] = False
845
+
846
+ def close(self, *, raise_if_closed: bool = True) -> None:
847
+ self._env.close()
848
+
849
+
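# Illustrative sketch (not part of the packaged file): wrapping a parallel PettingZoo
# env with an explicit group map so every agent gets its own tensordict entry.
# Assumes pettingzoo[mpe] is installed; simple_spread_v3 and N=3 are just example choices.
from pettingzoo.mpe import simple_spread_v3

from torchrl.envs.libs.pettingzoo import PettingZooWrapper
from torchrl.envs.utils import MarlGroupMapType

env = PettingZooWrapper(
    env=simple_spread_v3.parallel_env(N=3),
    group_map=MarlGroupMapType.ONE_GROUP_PER_AGENT,  # one group per agent
    categorical_actions=True,
)
print(env.group_map)  # e.g. {'agent_0': ['agent_0'], 'agent_1': ['agent_1'], 'agent_2': ['agent_2']}
rollout = env.rollout(5)  # 5-step rollout with random actions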
850
+ class PettingZooEnv(PettingZooWrapper):
851
+ """PettingZoo Environment.
852
+
853
+ To install PettingZoo, follow the guide `here <https://github.com/Farama-Foundation/PettingZoo#installation>`__.
854
+
855
+ This class is a general torchrl wrapper for all PettingZoo environments.
856
+ It is able to wrap both ``pettingzoo.AECEnv`` and ``pettingzoo.ParallelEnv``.
857
+
858
+ Let's look at this in more detail:
859
+
860
+ For wrapping a ``pettingzoo.ParallelEnv``, provide the name of your PettingZoo task (in the ``task`` argument)
861
+ and specify ``parallel=True``. This will construct the ``pettingzoo.ParallelEnv`` version of that task
862
+ (if it is supported in pettingzoo) and wrap it for torchrl.
863
+ In wrapped ``pettingzoo.ParallelEnv`` all agents will step at each environment step.
864
+ If the number of agents during the task varies, please set ``use_mask=True``.
865
+ ``"mask"`` will be provided
866
+ as an output in each group and should be used to mask out dead agents.
867
+ The environment will be reset as soon as one agent is done (unless ``done_on_any`` is ``False``).
868
+
869
+ For wrapping a ``pettingzoo.AECEnv``, provide the name of your PettingZoo task (in the ``task`` argument)
870
+ and specify ``parallel=False``. This will construct the ``pettingzoo.AECEnv`` version of that task
871
+ and wrap it for torchrl.
872
+ In wrapped ``pettingzoo.AECEnv``, at each step only one agent will act.
873
+ For this reason, it is compulsory to set ``use_mask=True`` for this type of environment.
874
+ ``"mask"`` will be provided as an output for each group and can be used to mask out non-acting agents.
875
+ The environment will be reset only when all agents are done (unless ``done_on_any`` is ``True``).
876
+
877
+ If there are any unavailable actions for an agent,
878
+ the environment will also automatically update the mask of its ``action_spec`` and output an ``"action_mask"``
879
+ for each group to reflect the latest available actions. This should be passed to a masked distribution during
880
+ training.
881
+
882
+ As a feature of torchrl multiagent, you are able to control the grouping of agents in your environment.
883
+ You can group agents together (stacking their tensors) to leverage vectorization when passing them through the same
884
+ neural network. You can split agents into different groups when they are heterogeneous or should be processed by
885
+ different neural networks. To group, you just need to pass a ``group_map`` at env construction time.
886
+
887
+ By default, agents in pettingzoo will be grouped by name.
888
+ For example, with agents ``["agent_0","agent_1","agent_2","adversary_0"]``, the tensordicts will look like:
889
+
890
+ >>> print(env.rand_action(env.reset()))
891
+ TensorDict(
892
+ fields={
893
+ agent: TensorDict(
894
+ fields={
895
+ action: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.int64, is_shared=False),
896
+ action_mask: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.bool, is_shared=False),
897
+ done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
898
+ observation: Tensor(shape=torch.Size([3, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False),
899
+ terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
900
+ truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
901
+ batch_size=torch.Size([3]))},
902
+ adversary: TensorDict(
903
+ fields={
904
+ action: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.int64, is_shared=False),
905
+ action_mask: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.bool, is_shared=False),
906
+ done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
907
+ observation: Tensor(shape=torch.Size([1, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False),
908
+ terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
909
+ truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
910
+ batch_size=torch.Size([1]))},
911
+ batch_size=torch.Size([]))
912
+ >>> print(env.group_map)
913
+ {"agent": ["agent_0", "agent_1", "agent_2"], "adversary": ["adversary_0"]}
914
+
915
+ Otherwise, a group map can be specified or selected from some premade options.
916
+ See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
917
+ For example, you can provide ``MarlGroupMapType.ONE_GROUP_PER_AGENT``, indicating that each agent should
918
+ have its own tensordict (similar to the pettingzoo parallel API).
919
+
920
+ Grouping is useful for leveraging vectorization among agents whose data goes through the same
921
+ neural network.
922
+
923
+ Args:
924
+ task (str): the name of the pettingzoo task to create in the "<env>/<task>" format (for example, "sisl/multiwalker_v9")
925
+ or "<task>" format (for example, "multiwalker_v9").
926
+ parallel (bool): whether to construct the ``pettingzoo.ParallelEnv`` version of the task or the ``pettingzoo.AECEnv``.
927
+ return_state (bool, optional): whether to return the global state from pettingzoo
928
+ (not available in all environments). Defaults to ``False``.
929
+ group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for
930
+ input/output. By default, agents will be grouped by their name. Otherwise, a group map can be specified
931
+ or selected from some premade options. See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
932
+ use_mask (bool, optional): whether the environment should output a ``"mask"``. This is compulsory in
933
+ wrapped ``pettingzoo.AECEnv`` to mask out non-acting agents and should be also used
934
+ for ``pettingzoo.ParallelEnv`` when the number of agents can vary. Defaults to ``False``.
935
+ categorical_actions (bool, optional): if the environment's actions are discrete, whether to transform
936
+ them to categorical or one-hot.
937
+ seed (int, optional): the seed. Defaults to ``None``.
938
+ done_on_any (bool, optional): whether the environment's done keys are set by aggregating the agent keys
939
+ using ``any()`` (when ``True``) or ``all()`` (when ``False``). Default (``None``) is to use ``any()`` for
940
+ parallel environments and ``all()`` for AEC ones.
941
+
942
+ Examples:
943
+ >>> # Parallel env
944
+ >>> from torchrl.envs.libs.pettingzoo import PettingZooEnv
945
+ >>> kwargs = {"n_pistons": 21, "continuous": True}
946
+ >>> env = PettingZooEnv(
947
+ ... task="pistonball_v6",
948
+ ... parallel=True,
949
+ ... return_state=True,
950
+ ... group_map=None, # Use default (all pistons grouped together)
951
+ ... **kwargs,
952
+ ... )
953
+ >>> print(env.group_map)
954
+ ... {'piston': ['piston_0', 'piston_1', ..., 'piston_20']}
955
+ >>> env.rollout(10)
956
+ >>> # AEC env
957
+ >>> from torchrl.envs.libs.pettingzoo import PettingZooEnv
958
+ >>> from torchrl.envs.utils import MarlGroupMapType
959
+ >>> env = PettingZooEnv(
960
+ ... task="tictactoe_v3",
961
+ ... parallel=False,
962
+ ... use_mask=True, # Must use it since one player plays at a time
963
+ ... group_map=None # Use default for AEC (one group per player)
964
+ ... )
965
+ >>> print(env.group_map)
966
+ ... {'player_1': ['player_1'], 'player_2': ['player_2']}
967
+ >>> env.rollout(10)
968
+ """
969
+
970
+ def __init__(
971
+ self,
972
+ task: str,
973
+ parallel: bool,
974
+ return_state: bool = False,
975
+ group_map: MarlGroupMapType | dict[str, list[str]] | None = None,
976
+ use_mask: bool = False,
977
+ categorical_actions: bool = True,
978
+ seed: int | None = None,
979
+ done_on_any: bool | None = None,
980
+ **kwargs,
981
+ ):
982
+ if not _has_pettingzoo:
983
+ raise ImportError(
984
+ f"pettingzoo python package was not found. Please install this dependency. "
985
+ f"More info: {self.git_url}."
986
+ )
987
+ kwargs["task"] = task
988
+ kwargs["parallel"] = parallel
989
+ kwargs["return_state"] = return_state
990
+ kwargs["group_map"] = group_map
991
+ kwargs["use_mask"] = use_mask
992
+ kwargs["categorical_actions"] = categorical_actions
993
+ kwargs["seed"] = seed
994
+ kwargs["done_on_any"] = done_on_any
995
+
996
+ super().__init__(**kwargs)
997
+
998
+ def _check_kwargs(self, kwargs: dict):
999
+ if "task" not in kwargs:
1000
+ raise TypeError("Could not find environment key 'task' in kwargs.")
1001
+ if "parallel" not in kwargs:
1002
+ raise TypeError("Could not find environment key 'parallel' in kwargs.")
1003
+
1004
+ def _build_env(
1005
+ self,
1006
+ task: str,
1007
+ parallel: bool,
1008
+ **kwargs,
1009
+ ) -> (
1010
+ pettingzoo.utils.env.ParallelEnv # noqa: F821
1011
+ | pettingzoo.utils.env.AECEnv # noqa: F821
1012
+ ):
1013
+ self.task_name = task
1014
+
1015
+ try:
1016
+ from pettingzoo.utils.all_modules import all_environments
1017
+ except ModuleNotFoundError as err:
1018
+ warnings.warn(
1019
+ f"PettingZoo failed to load all modules with error message {err}, trying to load individual modules."
1020
+ )
1021
+ all_environments = _load_available_envs()
1022
+
1023
+ if task not in all_environments:
1024
+ # Try looking at the literal translation of values
1025
+ task_module = None
1026
+ for value in all_environments.values():
1027
+ if value.__name__.split(".")[-1] == task:
1028
+ task_module = value
1029
+ break
1030
+ if task_module is None:
1031
+ raise RuntimeError(
1032
+ f"Specified task not in available environments {all_environments}"
1033
+ )
1034
+ else:
1035
+ task_module = all_environments[task]
1036
+
1037
+ if parallel:
1038
+ petting_zoo_env = task_module.parallel_env(**kwargs)
1039
+ else:
1040
+ petting_zoo_env = task_module.env(**kwargs)
1041
+
1042
+ return super()._build_env(env=petting_zoo_env)
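# Illustrative sketch (not part of the packaged file): the two task-name forms resolved
# by PettingZooEnv._build_env above. Assumes pettingzoo[sisl] is installed;
# multiwalker_v9 is the example task named in the class docstring.
from torchrl.envs.libs.pettingzoo import PettingZooEnv

# Registry-key form "<family>/<task>", i.e. a key of
# pettingzoo.utils.all_modules.all_environments.
env = PettingZooEnv(task="sisl/multiwalker_v9", parallel=True)

# Bare "<task>" form, resolved by matching each registered module's name.
env = PettingZooEnv(task="multiwalker_v9", parallel=True)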