torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/libs/openspiel.py
@@ -0,0 +1,652 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import importlib.util

import torch
from tensordict import TensorDict, TensorDictBase

from torchrl.data.tensor_specs import (
    Categorical,
    Composite,
    NonTensor,
    OneHot,
    Unbounded,
)
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType

_has_pyspiel = importlib.util.find_spec("pyspiel") is not None


def _get_envs():
    if not _has_pyspiel:
        raise ImportError(
            "open_spiel not found. Consider downloading and installing "
            f"open_spiel from {OpenSpielWrapper.git_url}."
        )

    import pyspiel

    return [game.short_name for game in pyspiel.registered_games()]


class OpenSpielWrapper(_EnvWrapper):
    """Google DeepMind OpenSpiel environment wrapper.

    GitHub: https://github.com/google-deepmind/open_spiel

    Documentation: https://openspiel.readthedocs.io/en/latest/index.html

    Args:
        env (pyspiel.State): the game to wrap.

    Keyword Args:
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``None``.
        batch_size (torch.Size, optional): the batch size of the environment.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.
        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
            group agents in tensordicts for input/output. See
            :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
            Defaults to
            :class:`~torchrl.envs.utils.MarlGroupMapType.ALL_IN_ONE_GROUP`.
        categorical_actions (bool, optional): if ``True``, categorical specs
            will be converted to the TorchRL equivalent
            (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
            will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.
        return_state (bool, optional): if ``True``, "state" is included in the
            output of :meth:`reset` and :meth:`~step`. The state can be given
            to :meth:`reset` to reset to that state, rather than resetting to
            the initial state.
            Defaults to ``False``.

    Attributes:
        available_envs: environments available to build

    Examples:
        >>> import pyspiel
        >>> from torchrl.envs import OpenSpielWrapper
        >>> from tensordict import TensorDict
        >>> base_env = pyspiel.load_game('chess').new_initial_state()
        >>> env = OpenSpielWrapper(base_env, return_state=True)
        >>> td = env.reset()
        >>> td = env.step(env.full_action_spec.rand())
        >>> print(td)
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([2, 4672]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        agents: TensorDict(
                            fields={
                                observation: Tensor(shape=torch.Size([2, 20, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False),
                                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([2]),
                            device=None,
                            is_shared=False),
                        current_player: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        state: NonTensorData(data=FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
                        3009
                        , batch_size=torch.Size([]), device=None),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)
        >>> print(env.available_envs)
        ['2048', 'add_noise', 'amazons', 'backgammon', ...]

    :meth:`reset` can restore a specific state, rather than the initial
    state, as long as ``return_state=True``.

        >>> import pyspiel
        >>> from torchrl.envs import OpenSpielWrapper
        >>> from tensordict import TensorDict
        >>> base_env = pyspiel.load_game('chess').new_initial_state()
        >>> env = OpenSpielWrapper(base_env, return_state=True)
        >>> td = env.reset()
        >>> td = env.step(env.full_action_spec.rand())
        >>> td_restore = td["next"]
        >>> td = env.step(env.full_action_spec.rand())
        >>> # The current state is not equal to `td_restore`
        >>> (td["next"] == td_restore).all()
        False
        >>> td = env.reset(td_restore)
        >>> # After resetting, the current state is equal to `td_restore`
        >>> (td == td_restore).all()
        True
    """

    git_url = "https://github.com/google-deepmind/open_spiel"
    libname = "pyspiel"
    _lib = None

    @_classproperty
    def lib(cls):
        if cls._lib is not None:
            return cls._lib

        import pyspiel

        cls._lib = pyspiel
        return pyspiel

    @_classproperty
    def available_envs(cls):
        if not _has_pyspiel:
            return []
        return _get_envs()

    def __init__(
        self,
        env=None,
        *,
        group_map: MarlGroupMapType
        | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
        categorical_actions: bool = False,
        return_state: bool = False,
        **kwargs,
    ):
        if env is not None:
            kwargs["env"] = env

        self.group_map = group_map
        self.categorical_actions = categorical_actions
        self.return_state = return_state
        self._cached_game = None
        super().__init__(**kwargs)

        # `reset` allows resetting to any state, including a terminal state
        self._allow_done_after_reset = True

    def _check_kwargs(self, kwargs: dict):
        pyspiel = self.lib
        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(env, pyspiel.State):
            raise TypeError("env is not of type 'pyspiel.State'.")

    def _build_env(self, env, requires_grad: bool = False, **kwargs):
        game = env.get_game()
        game_type = game.get_type()

        if game.max_chance_outcomes() != 0:
            raise NotImplementedError(
                f"The game '{game_type.short_name}' has chance nodes, which are not yet supported."
            )
        if game_type.dynamics == self.lib.GameType.Dynamics.MEAN_FIELD:
            # NOTE: It is unclear from the OpenSpiel documentation what
            # "mean field" means exactly, and there is no documentation on
            # the several games which have it.
            raise RuntimeError(
                f"Mean field games like '{game_type.name}' are not yet supported."
            )
        self.parallel = game_type.dynamics == self.lib.GameType.Dynamics.SIMULTANEOUS
        self.requires_grad = requires_grad
        return env

    def _init_env(self):
        self._update_action_mask()

    def _get_game(self):
        if self._cached_game is None:
            self._cached_game = self._env.get_game()
        return self._cached_game

    def _make_group_map(self, group_map, agent_names):
        if group_map is None:
            group_map = MarlGroupMapType.ONE_GROUP_PER_AGENT.get_group_map(agent_names)
        elif isinstance(group_map, MarlGroupMapType):
            group_map = group_map.get_group_map(agent_names)
        check_marl_grouping(group_map, agent_names)
        return group_map

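    # For a two-player game the two predefined group maps resolve as follows
    # (a sketch; see `MarlGroupMapType` for the exact behavior):
    #
    #     MarlGroupMapType.ALL_IN_ONE_GROUP
    #         -> {"agents": ["player_0", "player_1"]}
    #     MarlGroupMapType.ONE_GROUP_PER_AGENT
    #         -> {"player_0": ["player_0"], "player_1": ["player_1"]}
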
    def _make_group_specs(
        self,
        env,
        group: str,
    ):
        observation_specs = []
        action_specs = []
        reward_specs = []
        game = env.get_game()

        for _ in self.group_map[group]:
            observation_spec = Composite()

            if self.has_observation:
                observation_spec["observation"] = Unbounded(
                    shape=(*game.observation_tensor_shape(),),
                    device=self.device,
                    domain="continuous",
                )

            if self.has_information_state:
                observation_spec["information_state"] = Unbounded(
                    shape=(*game.information_state_tensor_shape(),),
                    device=self.device,
                    domain="continuous",
                )

            observation_specs.append(observation_spec)

            action_spec_cls = Categorical if self.categorical_actions else OneHot
            action_specs.append(
                Composite(
                    action=action_spec_cls(
                        env.num_distinct_actions(),
                        dtype=torch.int64,
                        device=self.device,
                    )
                )
            )

            reward_specs.append(
                Composite(
                    reward=Unbounded(
                        shape=(1,),
                        device=self.device,
                        domain="continuous",
                    )
                )
            )

        group_observation_spec = torch.stack(
            observation_specs, dim=0
        )  # shape = (n_agents, n_obs_per_agent)
        group_action_spec = torch.stack(
            action_specs, dim=0
        )  # shape = (n_agents, n_actions_per_agent)
        group_reward_spec = torch.stack(reward_specs, dim=0)  # shape = (n_agents, 1)

        return (
            group_observation_spec,
            group_action_spec,
            group_reward_spec,
        )

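    # Sketch of the grouped layout these stacked specs produce, using the
    # chess example from the class docstring (two players, 4672 distinct
    # actions, default ``ALL_IN_ONE_GROUP`` map, one-hot actions):
    #
    #     full_action_spec["agents", "action"]       -> OneHot,    shape [2, 4672]
    #     full_reward_spec["agents", "reward"]       -> Unbounded, shape [2, 1]
    #     observation_spec["agents", "observation"]  -> Unbounded, shape [2, 20, 8, 8]
    #
    # With ``categorical_actions=True`` the action leaf is instead a
    # ``Categorical`` spec of shape ``[2]`` over 4672 values.
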
    def _make_specs(self, env: pyspiel.State) -> None:  # noqa: F821
        self.agent_names = [f"player_{index}" for index in range(env.num_players())]
        self.agent_names_to_indices_map = {
            agent_name: i for i, agent_name in enumerate(self.agent_names)
        }
        self.group_map = self._make_group_map(self.group_map, self.agent_names)
        self.done_spec = Categorical(
            n=2,
            shape=torch.Size((1,)),
            dtype=torch.bool,
            device=self.device,
        )
        game = env.get_game()
        game_type = game.get_type()
        # In OpenSpiel, a game's state may have either an "observation" tensor,
        # an "information state" tensor, or both. If the OpenSpiel game does not
        # have one of these, then its corresponding accessor functions raise an
        # error, so we must avoid calling them.
        self.has_observation = game_type.provides_observation_tensor
        self.has_information_state = game_type.provides_information_state_tensor

        observation_spec = {}
        action_spec = {}
        reward_spec = {}

        for group in self.group_map.keys():
            (
                group_observation_spec,
                group_action_spec,
                group_reward_spec,
            ) = self._make_group_specs(
                env,
                group,
            )
            observation_spec[group] = group_observation_spec
            action_spec[group] = group_action_spec
            reward_spec[group] = group_reward_spec

        if self.return_state:
            observation_spec["state"] = NonTensor([])

        observation_spec["current_player"] = Unbounded(
            shape=(),
            dtype=torch.int,
            device=self.device,
            domain="discrete",
        )

        self.observation_spec = Composite(observation_spec)
        self.action_spec = Composite(action_spec)
        self.reward_spec = Composite(reward_spec)

    def _set_seed(self, seed: int | None) -> None:
        if seed is not None:
            raise NotImplementedError("This environment has no seed.")

    def current_player(self):
        return self._env.current_player()

    def _update_action_mask(self):
        if self._env.is_terminal():
            agents_acting = []
        else:
            agents_acting = (
                self.agent_names
                if self.parallel
                else [self.agent_names[self._env.current_player()]]
            )
        for group, agents in self.group_map.items():
            action_masks = []
            for agent in agents:
                agent_index = self.agent_names_to_indices_map[agent]
                if agent in agents_acting:
                    action_mask = torch.zeros(
                        self._env.num_distinct_actions(),
                        device=self.device,
                        dtype=torch.bool,
                    )
                    action_mask[self._env.legal_actions(agent_index)] = True
                else:
                    action_mask = torch.zeros(
                        self._env.num_distinct_actions(),
                        device=self.device,
                        dtype=torch.bool,
                    )
                    # In OpenSpiel parallel games, non-acting players are
                    # expected to take action 0.
                    # https://openspiel.readthedocs.io/en/latest/api_reference/state_apply_action.html
                    action_mask[0] = True
                action_masks.append(action_mask)
            self.full_action_spec[group, "action"].update_mask(
                torch.stack(action_masks, dim=0)
            )

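    # The masks set above are what make `full_action_spec.rand()` sample only
    # legal moves: a masked `OneHot`/`Categorical` spec draws exclusively from
    # entries whose mask is ``True``. This is why the docstring examples can
    # loop on `env.step(env.full_action_spec.rand())` without ever submitting
    # an illegal action.
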
    def _make_td_out(self, exclude_reward=False):
        done = torch.tensor(
            self._env.is_terminal(), device=self.device, dtype=torch.bool
        )
        current_player = torch.tensor(
            self.current_player(), device=self.device, dtype=torch.int
        )

        source = {
            "done": done,
            "terminated": done.clone(),
            "current_player": current_player,
        }

        if self.return_state:
            source["state"] = self._env.serialize()

        reward = self._env.returns()

        for group, agent_names in self.group_map.items():
            agent_tds = []

            for agent in agent_names:
                agent_index = self.agent_names_to_indices_map[agent]
                agent_source = {}
                if self.has_observation:
                    observation_shape = self._get_game().observation_tensor_shape()
                    agent_source["observation"] = self._to_tensor(
                        self._env.observation_tensor(agent_index)
                    ).reshape(observation_shape)

                if self.has_information_state:
                    information_state_shape = (
                        self._get_game().information_state_tensor_shape()
                    )
                    agent_source["information_state"] = self._to_tensor(
                        self._env.information_state_tensor(agent_index)
                    ).reshape(information_state_shape)

                if not exclude_reward:
                    agent_source["reward"] = self._to_tensor(reward[agent_index])

                agent_td = TensorDict(
                    source=agent_source,
                    batch_size=self.batch_size,
                    device=self.device,
                )
                agent_tds.append(agent_td)

            source[group] = torch.stack(agent_tds, dim=0)

        tensordict_out = TensorDict(
            source=source,
            batch_size=self.batch_size,
            device=self.device,
        )

        return tensordict_out

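    # Note on rewards (per the OpenSpiel API reference): `State.returns()` is
    # the undiscounted cumulative return of each player since the start of the
    # game, while `State.rewards()` would be only the increment from the most
    # recent transition. Because `_make_td_out` reads `returns()`, the
    # per-step "reward" entry accumulates over the episode.
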
    def _get_action_from_tensor(self, tensor):
        if not self.categorical_actions:
            action = torch.argmax(tensor, dim=-1)
        else:
            action = tensor
        return action

    def _step_parallel(self, tensordict: TensorDictBase):
        actions = [0] * self._env.num_players()
        for group, agents in self.group_map.items():
            for index_in_group, agent in enumerate(agents):
                agent_index = self.agent_names_to_indices_map[agent]
                action_tensor = tensordict[group, "action"][index_in_group]
                action = self._get_action_from_tensor(action_tensor)
                actions[agent_index] = action

        self._env.apply_actions(actions)

    def _step_sequential(self, tensordict: TensorDictBase):
        agent_index = self._env.current_player()

        # If the game has ended, do nothing
        if agent_index == self.lib.PlayerId.TERMINAL:
            return

        agent = self.agent_names[agent_index]
        agent_group = None
        agent_index_in_group = None

        for group, agents in self.group_map.items():
            if agent in agents:
                agent_group = group
                agent_index_in_group = agents.index(agent)
                break

        action_tensor = tensordict[agent_group, "action"][agent_index_in_group]
        action = self._get_action_from_tensor(action_tensor)
        self._env.apply_action(action)

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        if self.parallel:
            self._step_parallel(tensordict)
        else:
            self._step_sequential(tensordict)

        self._update_action_mask()
        return self._make_td_out()

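    # Stepping convention, in brief: in SIMULTANEOUS ("parallel") games,
    # `apply_actions` consumes one action per player, with non-acting players
    # masked down to the dummy action 0; in turn-based games only the current
    # player's entry of the action tensordict is read and all other agents'
    # actions are ignored. Hypothetical sketch for a turn-based game:
    #
    #     >>> td = env.reset()
    #     >>> td["current_player"]   # this player's action will be consumed
    #     tensor(0, dtype=torch.int32)
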
    def _to_tensor(self, value):
        return torch.tensor(value, device=self.device, dtype=torch.float32)

    def _reset(
        self, tensordict: TensorDictBase | None = None, **kwargs
    ) -> TensorDictBase:
        game = self._get_game()

        if tensordict is not None and "state" in tensordict:
            new_env = game.deserialize_state(tensordict["state"])
        else:
            new_env = game.new_initial_state()

        self._env = new_env
        self._update_action_mask()
        return self._make_td_out(exclude_reward=True)


class OpenSpielEnv(OpenSpielWrapper):
    """Google DeepMind OpenSpiel environment wrapper built with the game string.

    GitHub: https://github.com/google-deepmind/open_spiel

    Documentation: https://openspiel.readthedocs.io/en/latest/index.html

    Args:
        game_string (str): the name of the game to wrap. Must be part of
            :attr:`~.available_envs`.

    Keyword Args:
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``None``.
        batch_size (torch.Size, optional): the batch size of the environment.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.
        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
            group agents in tensordicts for input/output. See
            :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
            Defaults to
            :class:`~torchrl.envs.utils.MarlGroupMapType.ALL_IN_ONE_GROUP`.
        categorical_actions (bool, optional): if ``True``, categorical specs
            will be converted to the TorchRL equivalent
            (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
            will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.
        return_state (bool, optional): if ``True``, "state" is included in the
            output of :meth:`reset` and :meth:`~step`. The state can be given
            to :meth:`reset` to reset to that state, rather than resetting to
            the initial state.
            Defaults to ``False``.

    Attributes:
        available_envs: environments available to build

    Examples:
        >>> from torchrl.envs import OpenSpielEnv
        >>> from tensordict import TensorDict
        >>> env = OpenSpielEnv("chess", return_state=True)
        >>> td = env.reset()
        >>> td = env.step(env.full_action_spec.rand())
        >>> print(td)
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([2, 4672]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        agents: TensorDict(
                            fields={
                                observation: Tensor(shape=torch.Size([2, 20, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False),
                                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([2]),
                            device=None,
                            is_shared=False),
                        current_player: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        state: NonTensorData(data=FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
                        674
                        , batch_size=torch.Size([]), device=None),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)
        >>> print(env.available_envs)
        ['2048', 'add_noise', 'amazons', 'backgammon', ...]

    :meth:`reset` can restore a specific state, rather than the initial state,
    as long as ``return_state=True``.

        >>> from torchrl.envs import OpenSpielEnv
        >>> from tensordict import TensorDict
        >>> env = OpenSpielEnv("chess", return_state=True)
        >>> td = env.reset()
        >>> td = env.step(env.full_action_spec.rand())
        >>> td_restore = td["next"]
        >>> td = env.step(env.full_action_spec.rand())
        >>> # The current state is not equal to `td_restore`
        >>> (td["next"] == td_restore).all()
        False
        >>> td = env.reset(td_restore)
        >>> # After resetting, the current state is equal to `td_restore`
        >>> (td == td_restore).all()
        True
    """

    def __init__(
        self,
        game_string,
        *,
        group_map: MarlGroupMapType
        | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
        categorical_actions=False,
        return_state: bool = False,
        **kwargs,
    ):
        kwargs["game_string"] = game_string
        super().__init__(
            group_map=group_map,
            categorical_actions=categorical_actions,
            return_state=return_state,
            **kwargs,
        )

    def _build_env(
        self,
        game_string: str,
        **kwargs,
    ) -> pyspiel.State:  # noqa: F821
        if not _has_pyspiel:
            raise ImportError(
                f"open_spiel not found, unable to create {game_string}. Consider "
                f"downloading and installing open_spiel from {self.git_url}"
            )
        requires_grad = kwargs.pop("requires_grad", False)
        parameters = kwargs.pop("parameters", None)
        if kwargs:
            raise ValueError("kwargs not supported.")

        if parameters:
            game = self.lib.load_game(game_string, parameters=parameters)
        else:
            game = self.lib.load_game(game_string)

        env = game.new_initial_state()
        return super()._build_env(
            env,
            requires_grad=requires_grad,
        )

    @property
    def game_string(self):
        return self._constructor_kwargs["game_string"]

    def _check_kwargs(self, kwargs: dict):
        if "game_string" not in kwargs:
            raise TypeError("Expected 'game_string' to be part of kwargs")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(env={self.game_string}, batch_size={self.batch_size}, device={self.device})"
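

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the released file):
    # play one game of tic-tac-toe with uniformly sampled legal actions.
    # "tic_tac_toe" is a standard OpenSpiel game name; any turn-based game
    # without chance nodes works the same way.
    env = OpenSpielEnv("tic_tac_toe", categorical_actions=True)
    td = env.reset()
    while not td["done"].item():
        # rand() samples only legal actions thanks to the action mask.
        td = env.step(env.full_action_spec.rand())["next"]
    print("final returns:", td["agents", "reward"].squeeze(-1).tolist())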