torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/libs/meltingpot.py
@@ -0,0 +1,599 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import importlib.util
+ from collections.abc import Mapping, Sequence
+
+ import torch
+ from tensordict import TensorDict, TensorDictBase
+
+ from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec
+ from torchrl.envs.common import _EnvWrapper
+ from torchrl.envs.libs.dm_control import _dmcontrol_to_torchrl_spec_transform
+ from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType
+
+ _has_meltingpot = importlib.util.find_spec("meltingpot") is not None
+
+ PLAYER_STR_FORMAT = "player_{index}"
+ _WORLD_PREFIX = "WORLD."
+
+
+ def _get_envs():
+     if not _has_meltingpot:
+         raise ImportError("meltingpot is not installed in your virtual environment.")
+     from meltingpot.configs import substrates as substrate_configs
+
+     return list(substrate_configs.SUBSTRATES)
+
+
+ def _filter_global_state_from_dict(obs_dict: dict, world: bool) -> dict:  # noqa
+     return {
+         key: value
+         for key, value in obs_dict.items()
+         if ((_WORLD_PREFIX not in key) if not world else (_WORLD_PREFIX in key))
+     }
+
+
+ def _remove_world_observations_from_obs_spec(
+     observation_spec: Sequence[Mapping[str, dm_env.specs.Array]],  # noqa
+ ) -> Sequence[Mapping[str, dm_env.specs.Array]]:  # noqa
+     return [
+         _filter_global_state_from_dict(agent_obs, world=False)
+         for agent_obs in observation_spec
+     ]
+
+
+ def _global_state_spec_from_obs_spec(
+     observation_spec: Sequence[Mapping[str, dm_env.specs.Array]]  # noqa
+ ) -> Mapping[str, dm_env.specs.Array]:  # noqa
+     # We only look at agent 0 since world entries are the same for all agents
+     world_entries = _filter_global_state_from_dict(observation_spec[0], world=True)
+     if len(world_entries) != 1 or _WORLD_PREFIX + "RGB" not in world_entries:
+         raise ValueError(
+             f"Expected only one world entry named {_WORLD_PREFIX}RGB in observation_spec, but got {world_entries}"
+         )
+     return _remove_world_prefix(world_entries)
+
+
+ def _remove_world_prefix(world_entries: dict) -> dict:
+     return {key[len(_WORLD_PREFIX) :]: value for key, value in world_entries.items()}
+
+
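+ # For example: given obs = {"RGB": rgb, "WORLD.RGB": world_rgb}, the helpers
+ # above split per-agent keys from world keys: world=False keeps {"RGB": rgb},
+ # world=True keeps {"WORLD.RGB": world_rgb}, and _remove_world_prefix turns
+ # the latter back into {"RGB": world_rgb}.
+
+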
+ class MeltingpotWrapper(_EnvWrapper):
+     """Meltingpot environment wrapper.
+
+     GitHub: https://github.com/google-deepmind/meltingpot
+
+     Paper: https://arxiv.org/abs/2211.13746
+
+     Melting Pot assesses generalization to novel social situations involving both familiar and unfamiliar individuals,
+     and has been designed to test a broad range of social interactions such as: cooperation, competition, deception,
+     reciprocation, trust, stubbornness and so on. Melting Pot offers researchers a set of over 50 multi-agent
+     reinforcement learning substrates (multi-agent games) on which to train agents, and over 256 unique test scenarios
+     on which to evaluate these trained agents.
+
+     Args:
+         env (``meltingpot.utils.substrates.substrate.Substrate``): the meltingpot substrate to wrap.
+
+     Keyword Args:
+         max_steps (int, optional): Horizon of the task. Defaults to ``None`` (infinite horizon).
+             Each Meltingpot substrate can be terminating or not. If ``max_steps`` is specified,
+             the scenario is also terminated (and the ``"terminated"`` flag is set) whenever this horizon is reached.
+             Unlike gym's ``TimeLimit`` transform or torchrl's :class:`~torchrl.envs.transforms.StepCounter`,
+             this argument will not set the ``"truncated"`` entry in the tensordict.
+         categorical_actions (bool, optional): if the environment actions are discrete, whether to transform
+             them to categorical or one-hot. Defaults to ``True``.
+         group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for
+             input/output. By default, they will all be put in one group named ``"agents"``.
+             Otherwise, a group map can be specified or selected from some premade options.
+             See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
+
+     Attributes:
+         group_map (Dict[str, List[str]]): how to group agents in tensordicts for
+             input/output. See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
+         agent_names (list of str): names of the agents in the environment.
+         agent_names_to_indices_map (Dict[str, int]): dictionary mapping agent names to their index in the environment.
+         available_envs (List[str]): the list of the scenarios available to build.
+
+     .. warning::
+         Meltingpot returns a single ``done`` flag which does not distinguish between
+         when the env reached ``max_steps`` and termination.
+         If you deem the ``truncation`` signal necessary, set ``max_steps`` to
+         ``None`` and use a :class:`~torchrl.envs.transforms.StepCounter` transform.
+
+     Examples:
+         >>> from meltingpot import substrate
+         >>> from torchrl.envs.libs.meltingpot import MeltingpotWrapper
+         >>> substrate_config = substrate.get_config("commons_harvest__open")
+         >>> mp_env = substrate.build_from_config(
+         ...     substrate_config, roles=substrate_config.default_player_roles
+         ... )
+         >>> env_torchrl = MeltingpotWrapper(env=mp_env)
+         >>> print(env_torchrl.rollout(max_steps=5))
+         TensorDict(
+             fields={
+                 RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
+                 agents: TensorDict(
+                     fields={
+                         action: Tensor(shape=torch.Size([5, 7]), device=cpu, dtype=torch.int64, is_shared=False),
+                         observation: TensorDict(
+                             fields={
+                                 COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
+                                 READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
+                                 RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)},
+                             batch_size=torch.Size([5, 7]),
+                             device=cpu,
+                             is_shared=False)},
+                     batch_size=torch.Size([5, 7]),
+                     device=cpu,
+                     is_shared=False),
+                 done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 next: TensorDict(
+                     fields={
+                         RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
+                         agents: TensorDict(
+                             fields={
+                                 observation: TensorDict(
+                                     fields={
+                                         COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
+                                         READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
+                                         RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)},
+                                     batch_size=torch.Size([5, 7]),
+                                     device=cpu,
+                                     is_shared=False),
+                                 reward: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False)},
+                             batch_size=torch.Size([5, 7]),
+                             device=cpu,
+                             is_shared=False),
+                         done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                         terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                     batch_size=torch.Size([5]),
+                     device=cpu,
+                     is_shared=False),
+                 terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+             batch_size=torch.Size([5]),
+             device=cpu,
+             is_shared=False)
+
+     """
+
+     git_url = "https://github.com/google-deepmind/meltingpot"
+     libname = "meltingpot"
+
+     @property
+     def lib(self):
+         import meltingpot
+
+         return meltingpot
+
+     @_classproperty
+     def available_envs(cls):
+         if not _has_meltingpot:
+             return []
+         return _get_envs()
+
+     def __init__(
+         self,
+         env: meltingpot.utils.substrates.substrate.Substrate = None,  # noqa
+         categorical_actions: bool = True,
+         group_map: MarlGroupMapType
+         | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
+         max_steps: int | None = None,
+         **kwargs,
+     ):
+         if env is not None:
+             kwargs["env"] = env
+         self.group_map = group_map
+         self.categorical_actions = categorical_actions
+         self.max_steps = max_steps
+         self.num_cycles = 0
+         super().__init__(**kwargs)
+
+     def _build_env(
+         self,
+         env: meltingpot.utils.substrates.substrate.Substrate,  # noqa
+     ):
+         return env
+
+     def _make_group_map(self):
+         if isinstance(self.group_map, MarlGroupMapType):
+             self.group_map = self.group_map.get_group_map(self.agent_names)
+         check_marl_grouping(self.group_map, self.agent_names)
+
+     def _make_specs(
+         self, env: meltingpot.utils.substrates.substrate.Substrate  # noqa
+     ) -> None:
+         mp_obs_spec = self._env.observation_spec()  # List of dict of arrays
+         mp_obs_spec_no_world = _remove_world_observations_from_obs_spec(
+             mp_obs_spec
+         )  # List of dict of arrays
+         mp_global_state_spec = _global_state_spec_from_obs_spec(
+             mp_obs_spec
+         )  # Dict of arrays
+         mp_act_spec = self._env.action_spec()  # List of discrete arrays
+         mp_rew_spec = self._env.reward_spec()  # List of arrays
+
+         torchrl_agent_obs_specs = [
+             _dmcontrol_to_torchrl_spec_transform(agent_obs_spec)
+             for agent_obs_spec in mp_obs_spec_no_world
+         ]
+         torchrl_agent_act_specs = [
+             _dmcontrol_to_torchrl_spec_transform(
+                 agent_act_spec, categorical_discrete_encoding=self.categorical_actions
+             )
+             for agent_act_spec in mp_act_spec
+         ]
+         torchrl_state_spec = _dmcontrol_to_torchrl_spec_transform(mp_global_state_spec)
+         torchrl_rew_spec = [
+             _dmcontrol_to_torchrl_spec_transform(agent_rew_spec)
+             for agent_rew_spec in mp_rew_spec
+         ]
+
+         # Create and check group map
+         _num_players = len(torchrl_rew_spec)
+         self.agent_names = [
+             PLAYER_STR_FORMAT.format(index=index) for index in range(_num_players)
+         ]
+         self.agent_names_to_indices_map = {
+             agent_name: i for i, agent_name in enumerate(self.agent_names)
+         }
+         self._make_group_map()
+
+         action_spec = Composite()
+         observation_spec = Composite()
+         reward_spec = Composite()
+
+         for group in self.group_map.keys():
+             (
+                 group_observation_spec,
+                 group_action_spec,
+                 group_reward_spec,
+             ) = self._make_group_specs(
+                 group,
+                 torchrl_agent_obs_specs,
+                 torchrl_agent_act_specs,
+                 torchrl_rew_spec,
+             )
+             action_spec[group] = group_action_spec
+             observation_spec[group] = group_observation_spec
+             reward_spec[group] = group_reward_spec
+
+         observation_spec.update(torchrl_state_spec)
+         self.done_spec = Composite(
+             {
+                 "done": Categorical(n=2, shape=torch.Size((1,)), dtype=torch.bool),
+             },
+         )
+         self.action_spec = action_spec
+         self.observation_spec = observation_spec
+         self.reward_spec = reward_spec
+
+     def _make_group_specs(
+         self,
+         group: str,
+         torchrl_agent_obs_specs: list[TensorSpec],
+         torchrl_agent_act_specs: list[TensorSpec],
+         torchrl_rew_spec: list[TensorSpec],
+     ):
+         # Agent specs
+         action_specs = []
+         observation_specs = []
+         reward_specs = []
+
+         for agent_name in self.group_map[group]:
+             agent_index = self.agent_names_to_indices_map[agent_name]
+             action_specs.append(
+                 Composite(
+                     {
+                         "action": torchrl_agent_act_specs[
+                             agent_index
+                         ]  # shape = (n_actions_per_agent,)
+                     },
+                 )
+             )
+             observation_specs.append(
+                 Composite(
+                     {
+                         "observation": torchrl_agent_obs_specs[
+                             agent_index
+                         ]  # shape = (n_obs_per_agent,)
+                     },
+                 )
+             )
+             reward_specs.append(
+                 Composite({"reward": torchrl_rew_spec[agent_index]})  # shape = (1,)
+             )
+
+         # Create multi-agent specs
+         group_action_spec = torch.stack(
+             action_specs, dim=0
+         )  # shape = (n_agents_in_group, n_actions_per_agent)
+         group_observation_spec = torch.stack(
+             observation_specs, dim=0
+         )  # shape = (n_agents_in_group, n_obs_per_agent)
+         group_reward_spec = torch.stack(
+             reward_specs, dim=0
+         )  # shape = (n_agents_in_group, 1)
+         return (
+             group_observation_spec,
+             group_action_spec,
+             group_reward_spec,
+         )
+
+     def _check_kwargs(self, kwargs: dict):
+         meltingpot = self.lib
+
+         if "env" not in kwargs:
+             raise TypeError("Could not find environment key 'env' in kwargs.")
+         env = kwargs["env"]
+         if not isinstance(env, meltingpot.utils.substrates.substrate.Substrate):
+             raise TypeError(
+                 "env is not of type 'meltingpot.utils.substrates.substrate.Substrate'."
+             )
+
+     def _init_env(self):
+         # Caching
+         self.cached_full_done_spec_zero = self.full_done_spec.zero()
+
+     def _set_seed(self, seed: int | None) -> None:
+         raise NotImplementedError(
+             "It is currently unclear how to set a seed in Meltingpot. "
+             "See https://github.com/google-deepmind/meltingpot/issues/129 to track the issue."
+         )
+
+     def _reset(
+         self, tensordict: TensorDictBase | None = None, **kwargs
+     ) -> TensorDictBase:
+         self.num_cycles = 0
+         timestep = self._env.reset()
+         obs = timestep.observation
+
+         td = self.cached_full_done_spec_zero.clone()
+
+         for group, agent_names in self.group_map.items():
+             agent_tds = []
+             for index_in_group, agent_name in enumerate(agent_names):
+                 global_index = self.agent_names_to_indices_map[agent_name]
+                 agent_obs = self.observation_spec[group, "observation"][
+                     index_in_group
+                 ].encode(_filter_global_state_from_dict(obs[global_index], world=False))
+                 agent_td = TensorDict(
+                     source={
+                         "observation": agent_obs,
+                     },
+                     batch_size=self.batch_size,
+                     device=self.device,
+                 )
+
+                 agent_tds.append(agent_td)
+             agent_tds = torch.stack(agent_tds, dim=0)
+             td.set(group, agent_tds)
+
+         # Global state
+         td.update(
+             _remove_world_prefix(_filter_global_state_from_dict(obs[0], world=True))
+         )
+
+         tensordict_out = TensorDict(
+             source=td,
+             batch_size=self.batch_size,
+             device=self.device,
+         )
+         return tensordict_out
+
+     def _step(
+         self,
+         tensordict: TensorDictBase,
+     ) -> TensorDictBase:
+         action_dict = {}
+         for group, agents in self.group_map.items():
+             group_action = tensordict.get((group, "action"))
+             group_action_np = self.full_action_spec[group, "action"].to_numpy(
+                 group_action
+             )
+             for index, agent in enumerate(agents):
+                 action_dict[agent] = group_action_np[index]
+
+         actions = [action_dict[agent] for agent in self.agent_names]
+         timestep = self._env.step(actions)
+         self.num_cycles += 1
+
+         rewards = timestep.reward
+         done = timestep.last() or (
+             (self.num_cycles >= self.max_steps) if self.max_steps is not None else False
+         )
+         obs = timestep.observation
+
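+         # Meltingpot exposes a single done signal, so "terminated" mirrors
+         # "done" below even when the episode stopped because max_steps was
+         # reached (see the truncation warning in the class docstring).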
+         td = TensorDict(
+             {
+                 "done": self.full_done_spec["done"].encode(done),
+                 "terminated": self.full_done_spec["terminated"].encode(done),
+             },
+             batch_size=self.batch_size,
+         )
+         # Global state
+         td.update(
+             _remove_world_prefix(_filter_global_state_from_dict(obs[0], world=True))
+         )
+
+         for group, agent_names in self.group_map.items():
+             agent_tds = []
+             for index_in_group, agent_name in enumerate(agent_names):
+                 global_index = self.agent_names_to_indices_map[agent_name]
+                 agent_obs = self.observation_spec[group, "observation"][
+                     index_in_group
+                 ].encode(_filter_global_state_from_dict(obs[global_index], world=False))
+                 agent_reward = self.full_reward_spec[group, "reward"][
+                     index_in_group
+                 ].encode(rewards[global_index])
+                 agent_td = TensorDict(
+                     source={
+                         "observation": agent_obs,
+                         "reward": agent_reward,
+                     },
+                     batch_size=self.batch_size,
+                     device=self.device,
+                 )
+
+                 agent_tds.append(agent_td)
+             agent_tds = torch.stack(agent_tds, dim=0)
+             td.set(group, agent_tds)
+
+         return td
+
+     def get_rgb_image(self) -> torch.Tensor:
+         """Returns an RGB image of the environment.
+
+         Returns:
+             a ``torch.Tensor`` containing the image in HWC format.
+
+         """
+         return torch.from_numpy(self._env.observation()[0][_WORLD_PREFIX + "RGB"])
+
+
+ class MeltingpotEnv(MeltingpotWrapper):
+     """Meltingpot environment wrapper.
+
+     GitHub: https://github.com/google-deepmind/meltingpot
+
+     Paper: https://arxiv.org/abs/2211.13746
+
+     Melting Pot assesses generalization to novel social situations involving both familiar and unfamiliar individuals,
+     and has been designed to test a broad range of social interactions such as: cooperation, competition, deception,
+     reciprocation, trust, stubbornness and so on. Melting Pot offers researchers a set of over 50 multi-agent
+     reinforcement learning substrates (multi-agent games) on which to train agents, and over 256 unique test scenarios
+     on which to evaluate these trained agents.
+
+     Args:
+         substrate (str or ml_collections.config_dict.ConfigDict): the meltingpot substrate to build.
+             Can be a string from :attr:`~.available_envs` or a ConfigDict for the substrate.
+
+     Keyword Args:
+         max_steps (int, optional): Horizon of the task. Defaults to ``None`` (infinite horizon).
+             Each Meltingpot substrate can be terminating or not. If ``max_steps`` is specified,
+             the scenario is also terminated (and the ``"terminated"`` flag is set) whenever this horizon is reached.
+             Unlike gym's ``TimeLimit`` transform or torchrl's :class:`~torchrl.envs.transforms.StepCounter`,
+             this argument will not set the ``"truncated"`` entry in the tensordict.
+         categorical_actions (bool, optional): if the environment actions are discrete, whether to transform
+             them to categorical or one-hot. Defaults to ``True``.
+         group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for
+             input/output. By default, they will all be put in one group named ``"agents"``.
+             Otherwise, a group map can be specified or selected from some premade options.
+             See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
+
+     Attributes:
+         group_map (Dict[str, List[str]]): how to group agents in tensordicts for
+             input/output. See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
+         agent_names (list of str): names of the agents in the environment.
+         agent_names_to_indices_map (Dict[str, int]): dictionary mapping agent names to their index in the environment.
+         available_envs (List[str]): the list of the scenarios available to build.
+
+     .. warning::
+         Meltingpot returns a single ``done`` flag which does not distinguish between
+         when the env reached ``max_steps`` and termination.
+         If you deem the ``truncation`` signal necessary, set ``max_steps`` to
+         ``None`` and use a :class:`~torchrl.envs.transforms.StepCounter` transform.
+
+     Examples:
+         >>> from torchrl.envs.libs.meltingpot import MeltingpotEnv
+         >>> env_torchrl = MeltingpotEnv("commons_harvest__open")
+         >>> print(env_torchrl.rollout(max_steps=5))
+         TensorDict(
+             fields={
+                 RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
+                 agents: TensorDict(
+                     fields={
+                         action: Tensor(shape=torch.Size([5, 7]), device=cpu, dtype=torch.int64, is_shared=False),
+                         observation: TensorDict(
+                             fields={
+                                 COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
+                                 READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
+                                 RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)},
+                             batch_size=torch.Size([5, 7]),
+                             device=cpu,
+                             is_shared=False)},
+                     batch_size=torch.Size([5, 7]),
+                     device=cpu,
+                     is_shared=False),
+                 done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 next: TensorDict(
+                     fields={
+                         RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
+                         agents: TensorDict(
+                             fields={
+                                 observation: TensorDict(
+                                     fields={
+                                         COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
+                                         READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
+                                         RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)},
+                                     batch_size=torch.Size([5, 7]),
+                                     device=cpu,
+                                     is_shared=False),
+                                 reward: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False)},
+                             batch_size=torch.Size([5, 7]),
+                             device=cpu,
+                             is_shared=False),
+                         done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                         terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                     batch_size=torch.Size([5]),
+                     device=cpu,
+                     is_shared=False),
+                 terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+             batch_size=torch.Size([5]),
+             device=cpu,
+             is_shared=False)
+
+     """
+
+     def __init__(
+         self,
+         substrate: str | ml_collections.config_dict.ConfigDict,  # noqa
+         *,
+         max_steps: int | None = None,
+         categorical_actions: bool = True,
+         group_map: MarlGroupMapType
+         | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
+         **kwargs,
+     ):
+         if not _has_meltingpot:
+             raise ImportError(
+                 f"meltingpot python package was not found. Please install this dependency. "
+                 f"More info: {self.git_url}."
+             )
+         super().__init__(
+             substrate=substrate,
+             max_steps=max_steps,
+             categorical_actions=categorical_actions,
+             group_map=group_map,
+             **kwargs,
+         )
+
+     def _check_kwargs(self, kwargs: dict):
+         if "substrate" not in kwargs:
+             raise TypeError("Could not find environment key 'substrate' in kwargs.")
+
+     def _build_env(
+         self,
+         substrate: str | ml_collections.config_dict.ConfigDict,  # noqa
+     ) -> meltingpot.utils.substrates.substrate.Substrate:  # noqa
+         from meltingpot import substrate as mp_substrate
+
+         if isinstance(substrate, str):
+             substrate_config = mp_substrate.get_config(substrate)
+         else:
+             substrate_config = substrate
+
+         return super()._build_env(
+             env=mp_substrate.build_from_config(
+                 substrate_config, roles=substrate_config.default_player_roles
+             )
+         )
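
For reference, here is a minimal usage sketch combining the pieces documented above: building a substrate by name, choosing a non-default group map, and recovering a ``"truncated"`` signal with ``StepCounter``, as the docstring warnings recommend. It assumes ``meltingpot`` is installed; the substrate name is taken from the examples above, and the 100-step horizon is arbitrary.

>>> from torchrl.envs import TransformedEnv
>>> from torchrl.envs.libs.meltingpot import MeltingpotEnv
>>> from torchrl.envs.transforms import StepCounter
>>> from torchrl.envs.utils import MarlGroupMapType
>>> env = MeltingpotEnv(
...     "commons_harvest__open",
...     max_steps=None,  # leave truncation handling to StepCounter
...     group_map=MarlGroupMapType.ONE_GROUP_PER_AGENT,  # one group per agent
... )
>>> env = TransformedEnv(env, StepCounter(max_steps=100))  # sets "truncated"
>>> rollout = env.rollout(max_steps=10)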