torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/libs/unity_mlagents.py
@@ -0,0 +1,891 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import importlib.util
+
+import torch
+from tensordict import TensorDict, TensorDictBase
+
+from torchrl.data.tensor_specs import (
+    BoundedContinuous,
+    Categorical,
+    Composite,
+    MultiCategorical,
+    MultiOneHot,
+    Unbounded,
+)
+from torchrl.envs.common import _EnvWrapper
+from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType
+
+_has_unity_mlagents = importlib.util.find_spec("mlagents_envs") is not None
+
+
+def _get_registered_envs():
+    if not _has_unity_mlagents:
+        raise ImportError(
+            "mlagents_envs not found. Consider downloading and installing "
+            f"mlagents from {UnityMLAgentsWrapper.git_url}."
+        )
+
+    from mlagents_envs.registry import default_registry
+
+    return list(default_registry.keys())
+
+
+class UnityMLAgentsWrapper(_EnvWrapper):
+    """Unity ML-Agents environment wrapper.
+
+    GitHub: https://github.com/Unity-Technologies/ml-agents
+
+    Documentation: https://unity-technologies.github.io/ml-agents/Python-LLAPI/
+
+    Args:
+        env (mlagents_envs.environment.UnityEnvironment): the ML-Agents
+            environment to wrap.
+
+    Keyword Args:
+        device (torch.device, optional): if provided, the device on which the data
+            is to be cast. Defaults to ``None``.
+        batch_size (torch.Size, optional): the batch size of the environment.
+            Defaults to ``torch.Size([])``.
+        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
+            for envs to be ``done`` just after :meth:`reset` is called.
+            Defaults to ``False``.
+        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
+            group agents in tensordicts for input/output. See
+            :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. If not
+            specified, agents are grouped according to the group ID given by the
+            Unity environment. Defaults to ``None``.
+        categorical_actions (bool, optional): if ``True``, categorical specs
+            will be converted to the TorchRL equivalent
+            (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
+            will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.
+
+    Attributes:
+        available_envs: list of registered environments available to build
+
+    Examples:
+        >>> from mlagents_envs.environment import UnityEnvironment
+        >>> base_env = UnityEnvironment()
+        >>> from torchrl.envs import UnityMLAgentsWrapper
+        >>> env = UnityMLAgentsWrapper(base_env)
+        >>> td = env.reset()
+        >>> td = env.step(td.update(env.full_action_spec.rand()))
+    """
+
+    git_url = "https://github.com/Unity-Technologies/ml-agents"
+    libname = "mlagents_envs"
+    _lib = None
+
+    @_classproperty
+    def lib(cls):
+        if cls._lib is not None:
+            return cls._lib
+
+        import mlagents_envs.environment
+
+        cls._lib = mlagents_envs
+        return mlagents_envs
+
+    def __init__(
+        self,
+        env=None,
+        *,
+        group_map: MarlGroupMapType | dict[str, list[str]] | None = None,
+        categorical_actions: bool = False,
+        **kwargs,
+    ):
+        if env is not None:
+            kwargs["env"] = env
+
+        self.group_map = group_map
+        self.categorical_actions = categorical_actions
+        super().__init__(**kwargs)
+
+    def _check_kwargs(self, kwargs: dict):
+        mlagents_envs = self.lib
+        if "env" not in kwargs:
+            raise TypeError("Could not find environment key 'env' in kwargs.")
+        env = kwargs["env"]
+        if not isinstance(env, mlagents_envs.environment.UnityEnvironment):
+            raise TypeError(
+                "env is not of type 'mlagents_envs.environment.UnityEnvironment'"
+            )
+
+    def _build_env(self, env, requires_grad: bool = False, **kwargs):
+        self.requires_grad = requires_grad
+        return env
+
+    def _init_env(self):
+        self._update_action_mask()
+
+    # Collects the behavior and the group ID that the Unity environment
+    # assigns to each agent.
+    def _collect_agents(self, env):
+        agent_name_to_behavior_map = {}
+        agent_name_to_group_id_map = {}
+
+        for steps_idx in [0, 1]:
+            for behavior in env.behavior_specs.keys():
+                steps = env.get_steps(behavior)[steps_idx]
+                agent_ids = steps.agent_id
+                group_ids = steps.group_id
+
+                for agent_id, group_id in zip(agent_ids, group_ids):
+                    agent_name = f"agent_{agent_id}"
+                    if agent_name in agent_name_to_behavior_map:
+                        # Sometimes in an MLAgents environment, an agent may
+                        # show up in both the decision steps and the terminal
+                        # steps. When that happens, just skip the duplicate.
+                        continue
+                    agent_name_to_behavior_map[agent_name] = behavior
+                    agent_name_to_group_id_map[agent_name] = group_id
+
+        return (
+            agent_name_to_behavior_map,
+            agent_name_to_group_id_map,
+        )
+
+    # Creates a group map where agents are grouped by their group_id.
+    def _make_default_group_map(self, agent_name_to_group_id_map):
+        group_map = {}
+        for agent_name, group_id in agent_name_to_group_id_map.items():
+            group_name = f"group_{group_id}"
+            if group_name not in group_map:
+                group_map[group_name] = []
+            group_map[group_name].append(agent_name)
+        return group_map
+
+    def _make_group_map(self, group_map, agent_name_to_group_id_map):
+        if group_map is None:
+            group_map = self._make_default_group_map(agent_name_to_group_id_map)
+        elif isinstance(group_map, MarlGroupMapType):
+            group_map = group_map.get_group_map(agent_name_to_group_id_map.keys())
+        check_marl_grouping(group_map, agent_name_to_group_id_map.keys())
+        agent_name_to_group_name_map = {}
+        for group_name, agents in group_map.items():
+            for agent_name in agents:
+                agent_name_to_group_name_map[agent_name] = group_name
+        return group_map, agent_name_to_group_name_map
+
+    def _make_specs(
+        self, env: mlagents_envs.environment.UnityEnvironment  # noqa: F821
+    ) -> None:
+        # NOTE: We need to reset here because mlagents only initializes the
+        # agents and behaviors after reset. In order to build specs, we make the
+        # following assumptions about the mlagents environment:
+        # * all behaviors are defined on the first step
+        # * all agents request an action on the first step
+        # However, mlagents allows you to break these assumptions, so we probably
+        # will need to detect changes to the behaviors and agents on each step.
+        env.reset()
+        (
+            self.agent_name_to_behavior_map,
+            self.agent_name_to_group_id_map,
+        ) = self._collect_agents(env)
+
+        (self.group_map, self.agent_name_to_group_name_map) = self._make_group_map(
+            self.group_map, self.agent_name_to_group_id_map
+        )
+
+        action_spec = {}
+        observation_spec = {}
+        reward_spec = {}
+        done_spec = {}
+
+        for group_name, agents in self.group_map.items():
+            group_action_spec = {}
+            group_observation_spec = {}
+            group_reward_spec = {}
+            group_done_spec = {}
+            for agent_name in agents:
+                behavior = self.agent_name_to_behavior_map[agent_name]
+                behavior_spec = env.behavior_specs[behavior]
+
+                # Create action spec
+                agent_action_spec = Composite()
+                env_action_spec = behavior_spec.action_spec
+                discrete_branches = env_action_spec.discrete_branches
+                continuous_size = env_action_spec.continuous_size
+                if len(discrete_branches) > 0:
+                    discrete_action_spec_cls = (
+                        MultiCategorical if self.categorical_actions else MultiOneHot
+                    )
+                    agent_action_spec["discrete_action"] = discrete_action_spec_cls(
+                        discrete_branches,
+                        dtype=torch.int32,
+                        device=self.device,
+                    )
+                if continuous_size > 0:
+                    # In mlagents, continuous actions can take values between -1
+                    # and 1 by default:
+                    # https://github.com/Unity-Technologies/ml-agents/blob/22a59aad34ef46a5de05469735426feed758f8f5/ml-agents-envs/mlagents_envs/base_env.py#L395
+                    agent_action_spec["continuous_action"] = BoundedContinuous(
+                        -1, 1, (continuous_size,), self.device, torch.float32
+                    )
+                group_action_spec[agent_name] = agent_action_spec
+
+                # Create observation spec
+                agent_observation_spec = Composite()
+                for obs_idx, env_observation_spec in enumerate(
+                    behavior_spec.observation_specs
+                ):
+                    if len(env_observation_spec.name) == 0:
+                        obs_name = f"observation_{obs_idx}"
+                    else:
+                        obs_name = env_observation_spec.name
+                    agent_observation_spec[obs_name] = Unbounded(
+                        env_observation_spec.shape,
+                        dtype=torch.float32,
+                        device=self.device,
+                    )
+                group_observation_spec[agent_name] = agent_observation_spec
+
+                # Create reward spec
+                agent_reward_spec = Composite()
+                agent_reward_spec["reward"] = Unbounded(
+                    (1,),
+                    dtype=torch.float32,
+                    device=self.device,
+                )
+                agent_reward_spec["group_reward"] = Unbounded(
+                    (1,),
+                    dtype=torch.float32,
+                    device=self.device,
+                )
+                group_reward_spec[agent_name] = agent_reward_spec
+
+                # Create done spec
+                agent_done_spec = Composite()
+                for done_key in ["done", "terminated", "truncated"]:
+                    agent_done_spec[done_key] = Categorical(
+                        2, (1,), dtype=torch.bool, device=self.device
+                    )
+                group_done_spec[agent_name] = agent_done_spec
+
+            action_spec[group_name] = group_action_spec
+            observation_spec[group_name] = group_observation_spec
+            reward_spec[group_name] = group_reward_spec
+            done_spec[group_name] = group_done_spec
+
+        self.action_spec = Composite(action_spec)
+        self.observation_spec = Composite(observation_spec)
+        self.reward_spec = Composite(reward_spec)
+        self.done_spec = Composite(done_spec)
+
+    def _set_seed(self, seed: int | None) -> None:
+        if seed is not None:
+            raise NotImplementedError("This environment has no seed.")
+
+    def _check_agent_exists(self, agent_name, group_id):
+        if agent_name not in self.agent_name_to_group_id_map:
+            raise RuntimeError(
+                "Unity environment added a new agent. This is not yet "
+                "supported in torchrl."
+            )
+        if self.agent_name_to_group_id_map[agent_name] != group_id:
+            raise RuntimeError(
+                "Unity environment changed the group of an agent. This "
+                "is not yet supported in torchrl."
+            )
+
+    def _update_action_mask(self):
+        for behavior, behavior_spec in self._env.behavior_specs.items():
+            env_action_spec = behavior_spec.action_spec
+            discrete_branches = env_action_spec.discrete_branches
+
+            if len(discrete_branches) > 0:
+                steps = self._env.get_steps(behavior)[0]
+                env_action_mask = steps.action_mask
+                if env_action_mask is not None:
+                    # mlagents mask entries are True for *unavailable* actions,
+                    # whereas TorchRL masks mark allowed actions, hence the
+                    # logical_not.
+                    combined_action_mask = torch.cat(
+                        [
+                            torch.tensor(m, device=self.device, dtype=torch.bool)
+                            for m in env_action_mask
+                        ],
+                        dim=-1,
+                    ).logical_not()
+
+                    for agent_id, group_id, agent_action_mask in zip(
+                        steps.agent_id, steps.group_id, combined_action_mask
+                    ):
+                        agent_name = f"agent_{agent_id}"
+                        self._check_agent_exists(agent_name, group_id)
+                        group_name = self.agent_name_to_group_name_map[agent_name]
+                        self.full_action_spec[
+                            group_name, agent_name, "discrete_action"
+                        ].update_mask(agent_action_mask)
+
+    def _make_td_out(self, tensordict_in, is_reset=False):
+        source = {}
+        for behavior, behavior_spec in self._env.behavior_specs.items():
+            for idx, steps in enumerate(self._env.get_steps(behavior)):
+                is_terminal = idx == 1
+                for steps_idx, (agent_id, group_id) in enumerate(
+                    zip(steps.agent_id, steps.group_id)
+                ):
+                    agent_name = f"agent_{agent_id}"
+                    self._check_agent_exists(agent_name, group_id)
+                    group_name = self.agent_name_to_group_name_map[agent_name]
+                    if group_name not in source:
+                        source[group_name] = {}
+                    if agent_name not in source[group_name]:
+                        source[group_name][agent_name] = {}
+
+                    # Add observations
+                    for obs_idx, (
+                        behavior_observation,
+                        env_observation_spec,
+                    ) in enumerate(zip(steps.obs, behavior_spec.observation_specs)):
+                        observation = torch.tensor(
+                            behavior_observation[steps_idx],
+                            device=self.device,
+                            dtype=torch.float32,
+                        )
+                        if len(env_observation_spec.name) == 0:
+                            obs_name = f"observation_{obs_idx}"
+                        else:
+                            obs_name = env_observation_spec.name
+                        source[group_name][agent_name][obs_name] = observation
+
+                    # Add rewards
+                    if not is_reset:
+                        source[group_name][agent_name]["reward"] = torch.tensor(
+                            [steps.reward[steps_idx]],
+                            device=self.device,
+                            dtype=torch.float32,
+                        )
+                        source[group_name][agent_name]["group_reward"] = torch.tensor(
+                            [steps.group_reward[steps_idx]],
+                            device=self.device,
+                            dtype=torch.float32,
+                        )
+
+                    # Add done
+                    done = is_terminal and not is_reset
+                    source[group_name][agent_name]["done"] = torch.tensor(
+                        done, device=self.device, dtype=torch.bool
+                    )
+                    source[group_name][agent_name]["truncated"] = torch.tensor(
+                        done and steps.interrupted[steps_idx],
+                        device=self.device,
+                        dtype=torch.bool,
+                    )
+                    source[group_name][agent_name]["terminated"] = torch.tensor(
+                        done and not steps.interrupted[steps_idx],
+                        device=self.device,
+                        dtype=torch.bool,
+                    )
+
+        if tensordict_in is not None:
+            # In MLAgents, a given step will only contain information for agents
+            # which either terminated or requested a decision during the step.
+            # Some agents may have neither terminated nor requested a decision,
+            # so we need to fill in their information from the previous step.
+            for group_name, agents in self.group_map.items():
+                for agent_name in agents:
+                    if group_name not in source.keys():
+                        source[group_name] = {}
+                    if agent_name not in source[group_name].keys():
+                        agent_dict = {}
+                        agent_behavior = self.agent_name_to_behavior_map[agent_name]
+                        behavior_spec = self._env.behavior_specs[agent_behavior]
+                        td_agent_in = tensordict_in[group_name, agent_name]
+
+                        # Add observations
+                        for obs_idx, env_observation_spec in enumerate(
+                            behavior_spec.observation_specs
+                        ):
+                            if len(env_observation_spec.name) == 0:
+                                obs_name = f"observation_{obs_idx}"
+                            else:
+                                obs_name = env_observation_spec.name
+                            agent_dict[obs_name] = td_agent_in[obs_name]
+
+                        # Add rewards
+                        if not is_reset:
+                            # Since the agent didn't request a decision, the
+                            # reward is 0
+                            agent_dict["reward"] = torch.zeros(
+                                (1,), device=self.device, dtype=torch.float32
+                            )
+                            agent_dict["group_reward"] = torch.zeros(
+                                (1,), device=self.device, dtype=torch.float32
+                            )
+
+                        # Add done
+                        agent_dict["done"] = torch.tensor(
+                            False, device=self.device, dtype=torch.bool
+                        )
+                        agent_dict["terminated"] = torch.tensor(
+                            False, device=self.device, dtype=torch.bool
+                        )
+                        agent_dict["truncated"] = torch.tensor(
+                            False, device=self.device, dtype=torch.bool
+                        )
+
+                        source[group_name][agent_name] = agent_dict
+
+        tensordict_out = TensorDict(
+            source=source,
+            batch_size=self.batch_size,
+            device=self.device,
+        )
+
+        return tensordict_out
+
+    def _get_action_from_tensor(self, tensor):
+        if not self.categorical_actions:
+            action = torch.argmax(tensor, dim=-1)
+        else:
+            action = tensor
+        return action
+
+    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
+        # Apply actions
+        for behavior, behavior_spec in self._env.behavior_specs.items():
+            env_action_spec = behavior_spec.action_spec
+            steps = self._env.get_steps(behavior)[0]
+
+            for agent_id, group_id in zip(steps.agent_id, steps.group_id):
+                agent_name = f"agent_{agent_id}"
+                self._check_agent_exists(agent_name, group_id)
+                group_name = self.agent_name_to_group_name_map[agent_name]
+
+                agent_action_spec = self.full_action_spec[group_name, agent_name]
+                action_tuple = self.lib.base_env.ActionTuple()
+                discrete_branches = env_action_spec.discrete_branches
+                continuous_size = env_action_spec.continuous_size
+
+                if len(discrete_branches) > 0:
+                    discrete_spec = agent_action_spec["discrete_action"]
+                    discrete_action = tensordict[
+                        group_name, agent_name, "discrete_action"
+                    ]
+                    if not self.categorical_actions:
+                        discrete_action = discrete_spec.to_categorical(discrete_action)
+                    action_tuple.add_discrete(discrete_action[None, ...].numpy())
+
+                if continuous_size > 0:
+                    continuous_action = tensordict[
+                        group_name, agent_name, "continuous_action"
+                    ]
+                    action_tuple.add_continuous(continuous_action[None, ...].numpy())
+
+                self._env.set_action_for_agent(behavior, agent_id, action_tuple)
+
+        self._env.step()
+        self._update_action_mask()
+        return self._make_td_out(tensordict)
+
+    def _to_tensor(self, value):
+        return torch.tensor(value, device=self.device, dtype=torch.float32)
+
+    def _reset(
+        self, tensordict: TensorDictBase | None = None, **kwargs
+    ) -> TensorDictBase:
+        self._env.reset()
+        return self._make_td_out(tensordict, is_reset=True)
+
+    def close(self, *, raise_if_closed: bool = True):
+        self._env.close()
+
+    @_classproperty
+    def available_envs(cls):
+        if not _has_unity_mlagents:
+            return []
+        return _get_registered_envs()
+
+
+class UnityMLAgentsEnv(UnityMLAgentsWrapper):
+    """Unity ML-Agents environment wrapper.
+
+    GitHub: https://github.com/Unity-Technologies/ml-agents
+
+    Documentation: https://unity-technologies.github.io/ml-agents/Python-LLAPI/
+
+    This class can be provided any of the optional initialization arguments
+    that the :class:`mlagents_envs.environment.UnityEnvironment` class accepts.
+    For a list of these arguments, see:
+    https://unity-technologies.github.io/ml-agents/Python-LLAPI-Documentation/#__init__
+
+    If both ``file_name`` and ``registered_name`` are given, an error is raised.
+
+    If neither ``file_name`` nor ``registered_name`` is given, the environment
+    setup waits on a localhost port, and the user must launch a Unity ML-Agents
+    environment binary to connect to it.
+
+    Args:
+        file_name (str, optional): if provided, the path to the Unity
+            environment binary. Defaults to ``None``.
+        registered_name (str, optional): if provided, the Unity environment
+            binary is loaded from the default ML-Agents registry. The list of
+            registered environments is in :attr:`~.available_envs`. Defaults to
+            ``None``.
+
+    Keyword Args:
+        device (torch.device, optional): if provided, the device on which the data
+            is to be cast. Defaults to ``None``.
+        batch_size (torch.Size, optional): the batch size of the environment.
+            Defaults to ``torch.Size([])``.
+        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
+            for envs to be ``done`` just after :meth:`reset` is called.
+            Defaults to ``False``.
+        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
+            group agents in tensordicts for input/output. See
+            :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. If not
+            specified, agents are grouped according to the group ID given by the
+            Unity environment. Defaults to ``None``.
+        categorical_actions (bool, optional): if ``True``, categorical specs
+            will be converted to the TorchRL equivalent
+            (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
+            will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.
+
+    Attributes:
+        available_envs: list of registered environments available to build
+
+    Examples:
+        >>> from torchrl.envs import UnityMLAgentsEnv
+        >>> env = UnityMLAgentsEnv(registered_name='3DBall')
+        >>> td = env.reset()
+        >>> td = env.step(td.update(env.full_action_spec.rand()))
+        >>> td
+        TensorDict(
+            fields={
+                group_0: TensorDict(
+                    fields={
+                        agent_0: TensorDict(
+                            fields={
+                                VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False),
+                        agent_10: TensorDict(
+                            fields={
+                                VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False),
+                        agent_11: TensorDict(
+                            fields={
+                                VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False),
+                        agent_1: TensorDict(
+                            fields={
+                                VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False),
+                        agent_2: TensorDict(
+                            fields={
+                                VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False),
+                        agent_3: TensorDict(
+                            fields={
+                                VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False),
+                        agent_4: TensorDict(
+                            fields={
+                                VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False),
+                        agent_5: TensorDict(
+                            fields={
+                                VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False),
+                        agent_6: TensorDict(
+                            fields={
+                                VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False),
+                        agent_7: TensorDict(
+                            fields={
+                                VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False),
+                        agent_8: TensorDict(
+                            fields={
+                                VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False),
+                        agent_9: TensorDict(
+                            fields={
+                                VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False)},
+                    batch_size=torch.Size([]),
+                    device=None,
+                    is_shared=False),
+                next: TensorDict(
+                    fields={
+                        group_0: TensorDict(
+                            fields={
+                                agent_0: TensorDict(
+                                    fields={
+                                        VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                                    batch_size=torch.Size([]),
+                                    device=None,
+                                    is_shared=False),
+                                agent_10: TensorDict(
+                                    fields={
+                                        VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                                    batch_size=torch.Size([]),
+                                    device=None,
+                                    is_shared=False),
+                                agent_11: TensorDict(
+                                    fields={
+                                        VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                                    batch_size=torch.Size([]),
+                                    device=None,
+                                    is_shared=False),
+                                agent_1: TensorDict(
+                                    fields={
+                                        VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                                    batch_size=torch.Size([]),
+                                    device=None,
+                                    is_shared=False),
+                                agent_2: TensorDict(
+                                    fields={
+                                        VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                                    batch_size=torch.Size([]),
+                                    device=None,
+                                    is_shared=False),
+                                agent_3: TensorDict(
+                                    fields={
+                                        VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                                    batch_size=torch.Size([]),
+                                    device=None,
+                                    is_shared=False),
+                                agent_4: TensorDict(
+                                    fields={
+                                        VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                                    batch_size=torch.Size([]),
+                                    device=None,
+                                    is_shared=False),
+                                agent_5: TensorDict(
+                                    fields={
+                                        VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                                    batch_size=torch.Size([]),
+                                    device=None,
+                                    is_shared=False),
+                                agent_6: TensorDict(
+                                    fields={
+                                        VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                                    batch_size=torch.Size([]),
+                                    device=None,
+                                    is_shared=False),
+                                agent_7: TensorDict(
+                                    fields={
+                                        VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                                    batch_size=torch.Size([]),
+                                    device=None,
+                                    is_shared=False),
+                                agent_8: TensorDict(
+                                    fields={
+                                        VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                                    batch_size=torch.Size([]),
+                                    device=None,
+                                    is_shared=False),
+                                agent_9: TensorDict(
+                                    fields={
+                                        VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                                    batch_size=torch.Size([]),
+                                    device=None,
+                                    is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False)},
+                    batch_size=torch.Size([]),
+                    device=None,
+                    is_shared=False)},
+            batch_size=torch.Size([]),
+            device=None,
+            is_shared=False)
+    """
+
+    def __init__(
+        self,
+        file_name: str | None = None,
+        registered_name: str | None = None,
+        *,
+        group_map: MarlGroupMapType | dict[str, list[str]] | None = None,
+        categorical_actions=False,
+        **kwargs,
+    ):
+        kwargs["file_name"] = file_name
+        kwargs["registered_name"] = registered_name
+        super().__init__(
+            group_map=group_map,
+            categorical_actions=categorical_actions,
+            **kwargs,
+        )
+
+    def _build_env(
+        self,
+        file_name: str | None,
+        registered_name: str | None,
+        **kwargs,
+    ) -> mlagents_envs.environment.UnityEnvironment:  # noqa: F821
+        if not _has_unity_mlagents:
+            raise ImportError(
+                "mlagents_envs not found, unable to create environment. "
+                "Consider downloading and installing mlagents from "
+                f"{self.git_url}"
+            )
+        # Pop requires_grad before constructing the env so that it is not
+        # forwarded to the mlagents constructor.
+        requires_grad = kwargs.pop("requires_grad", False)
+        if file_name is not None and registered_name is not None:
+            raise ValueError(
+                "Both `file_name` and `registered_name` were specified, which "
+                "is not allowed. Specify one of them or neither."
+            )
+        elif registered_name is not None:
+            from mlagents_envs.registry import default_registry
+
+            env = default_registry[registered_name].make(**kwargs)
+        else:
+            env = self.lib.environment.UnityEnvironment(file_name, **kwargs)
+        return super()._build_env(
+            env,
+            requires_grad=requires_grad,
+        )
+
+    @property
+    def file_name(self):
+        return self._constructor_kwargs["file_name"]
+
+    @property
+    def registered_name(self):
+        return self._constructor_kwargs["registered_name"]
+
+    def _check_kwargs(self, kwargs: dict):
+        pass
+
+    def __repr__(self) -> str:
+        if self.registered_name is not None:
+            env_name = self.registered_name
+        else:
+            env_name = self.file_name
+        return f"{self.__class__.__name__}(env={env_name}, batch_size={self.batch_size}, device={self.device})"
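
For orientation, here is a minimal sketch of driving the added wrapper end to end. It assumes mlagents_envs is installed and that the registered 3DBall build used in the docstring example above can be fetched by the ML-Agents registry; it is an illustration of the API shown in the diff, not part of the package.

    from torchrl.envs import UnityMLAgentsEnv

    # Build the env from the default ML-Agents registry. Grouping and action
    # encoding keep their defaults: groups come from Unity's group IDs, and
    # discrete actions (if any) are one-hot encoded.
    env = UnityMLAgentsEnv(registered_name="3DBall")
    try:
        # rollout() draws random actions from full_action_spec at each step
        # when no policy is passed.
        rollout = env.rollout(max_steps=10)
        print(rollout)
    finally:
        env.close()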
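The group_map and categorical_actions keyword arguments documented in both class docstrings can be combined as below; a sketch under the same assumptions as above, with MarlGroupMapType.ALL_IN_ONE_GROUP collecting every agent under a single "agents" key instead of the group_<id> keys shown in the docstring example.

    from torchrl.envs import UnityMLAgentsEnv
    from torchrl.envs.utils import MarlGroupMapType

    env = UnityMLAgentsEnv(
        registered_name="3DBall",
        # One "agents" group holding all agents, rather than group_<id> keys.
        group_map=MarlGroupMapType.ALL_IN_ONE_GROUP,
        # Only affects behaviors with discrete action branches: their specs
        # become MultiCategorical instead of one-hot MultiOneHot.
        categorical_actions=True,
    )
    print(env.group_map)         # e.g. {"agents": ["agent_0", "agent_1", ...]}
    print(env.full_action_spec)  # Composite spec nested as (group, agent, key)
    env.close()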