torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .chess import ChessEnv
7
+ from .llm import LLMHashingEnv
8
+ from .pendulum import PendulumEnv
9
+ from .tictactoeenv import TicTacToeEnv
10
+
11
+ __all__ = ["ChessEnv", "LLMHashingEnv", "PendulumEnv", "TicTacToeEnv"]
@@ -0,0 +1,617 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import importlib.util
8
+ import io
9
+ import pathlib
10
+
11
+ import torch
12
+ from tensordict import TensorDict, TensorDictBase
13
+ from torchrl.data.tensor_specs import (
14
+ Binary,
15
+ Bounded,
16
+ Categorical,
17
+ Composite,
18
+ NonTensor,
19
+ Unbounded,
20
+ )
21
+ from torchrl.envs import EnvBase
22
+ from torchrl.envs.common import _EnvPostInit
23
+ from torchrl.envs.utils import _classproperty
24
+
25
+
26
class _ChessMeta(_EnvPostInit):
    """Metaclass that post-processes ``ChessEnv`` instances right after construction.

    Depending on the keyword arguments given to the constructor, the freshly
    built environment is wrapped (via ``append_transform``) with:

    * a ``Hash`` transform over the enabled string observations
      (``"san"`` -> ``"san_hash"``, ``"fen"`` -> ``"fen_hash"``,
      ``"pgn"`` -> ``"pgn_hash"``) when ``include_hash=True``. The same keys are
      also registered as inverse keys when ``include_hash_inv=True``;
    * an ``ActionMask`` transform unless ``mask_actions=False`` is passed.
    """

    def __call__(cls, *args, **kwargs):
        """Build the environment, then append the requested transforms.

        Returns:
            The environment instance, possibly wrapped by the appended transforms.

        Raises:
            ValueError: if ``include_hash_inv=True`` is passed without
                ``include_hash=True``.
        """
        instance = super().__call__(*args, **kwargs)
        include_hash = kwargs.get("include_hash")
        include_hash_inv = kwargs.get("include_hash_inv")
        if include_hash:
            # Local import, presumably to avoid a circular import with torchrl.envs.
            from torchrl.envs import Hash

            in_keys = []
            out_keys = []
            # Only string observations that the instance actually emits get hashed.
            for enabled, key in (
                (instance.include_san, "san"),
                (instance.include_fen, "fen"),
                (instance.include_pgn, "pgn"),
            ):
                if enabled:
                    in_keys.append(key)
                    out_keys.append(f"{key}_hash")
            if include_hash_inv:
                # Separate list objects that mirror the forward keys.
                in_keys_inv = list(in_keys)
                out_keys_inv = list(out_keys)
            else:
                in_keys_inv = None
                out_keys_inv = None

            instance = instance.append_transform(
                Hash(in_keys, out_keys, in_keys_inv, out_keys_inv)
            )
        elif include_hash_inv:
            raise ValueError(
                "'include_hash_inv=True' can only be set if "
                f"'include_hash=True', but got 'include_hash={include_hash}'."
            )
        # Action masking is on unless explicitly disabled.
        if kwargs.get("mask_actions", True):
            from torchrl.envs import ActionMask

            instance = instance.append_transform(ActionMask())
        return instance
64
+
65
+
66
+ class ChessEnv(EnvBase, metaclass=_ChessMeta):
67
+ r"""A chess environment that follows the TorchRL API.
68
+
69
+ This environment simulates a chess game using the `chess` library. It supports various state representations
70
+ and can be configured to include different types of observations such as SAN, FEN, PGN, and legal moves.
71
+
72
+ Requires: the `chess` library. More info `here <https://python-chess.readthedocs.io/en/latest/>`__.
73
+
74
+ Args:
75
+ stateful (bool): Whether to keep track of the internal state of the board.
76
+ If False, the state will be stored in the observation and passed back
77
+ to the environment on each call. Default: ``True``.
78
+ include_san (bool): Whether to include SAN (Standard Algebraic Notation) in the observations. Default: ``False``.
79
+ The ``"san"`` entry corresponding to ``rollout["action"]`` will be found in ``rollout["next", "san"]``,
80
+ whereas the value at the root ``rollout["san"]`` will correspond to the value of the san preceding the
81
+ same index action.
82
+ include_fen (bool): Whether to include FEN (Forsyth-Edwards Notation) in the observations. Default: ``False``.
83
+ include_pgn (bool): Whether to include PGN (Portable Game Notation) in the observations. Default: ``False``.
84
+ include_legal_moves (bool): Whether to include legal moves in the observations. Default: ``False``.
85
+ include_hash (bool): Whether to include hash transformations in the environment. Default: ``False``.
86
+ mask_actions (bool): if ``True``, a :class:`~torchrl.envs.ActionMask` transform will be appended
87
+ to the env to make sure that the actions are properly masked. Default: ``True``.
88
+ pixels (bool): Whether to include pixel-based observations of the board. Default: ``False``.
89
+
90
+ .. note::
91
+ The action spec is a :class:`~torchrl.data.Categorical` with a number of actions equal to the number of possible SAN moves.
92
+ The action space is structured as a categorical distribution over all possible SAN moves, with the legal moves
93
+ being a subset of this space. The environment uses a mask to ensure only legal moves are selected.
94
+
95
+ Examples:
96
+ >>> import torch
97
+ >>> from torchrl.envs import ChessEnv
98
+ >>> _ = torch.manual_seed(0)
99
+ >>> env = ChessEnv(include_fen=True, include_san=True, include_pgn=True, include_legal_moves=True)
100
+ >>> print(env)
101
+ TransformedEnv(
102
+ env=ChessEnv(),
103
+ transform=ActionMask(keys=['action', 'action_mask']))
104
+ >>> r = env.reset()
105
+ >>> print(env.rand_step(r))
106
+ TensorDict(
107
+ fields={
108
+ action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
109
+ action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
110
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
111
+ fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None),
112
+ legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
113
+ next: TensorDict(
114
+ fields={
115
+ action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
116
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
117
+ fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/5P2/8/PPPPP1PP/RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
118
+ legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
119
+ pgn: NonTensorData(data=[Event "?"]
120
+ [Site "?"]
121
+ [Date "????.??.??"]
122
+ [Round "?"]
123
+ [White "?"]
124
+ [Black "?"]
125
+ [Result "*"]
126
+
127
+ 1. f4 *, batch_size=torch.Size([]), device=None),
128
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
129
+ san: NonTensorData(data=f4, batch_size=torch.Size([]), device=None),
130
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
131
+ turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
132
+ batch_size=torch.Size([]),
133
+ device=None,
134
+ is_shared=False),
135
+ pgn: NonTensorData(data=[Event "?"]
136
+ [Site "?"]
137
+ [Date "????.??.??"]
138
+ [Round "?"]
139
+ [White "?"]
140
+ [Black "?"]
141
+ [Result "*"]
142
+
143
+ *, batch_size=torch.Size([]), device=None),
144
+ san: NonTensorData(data=<start>, batch_size=torch.Size([]), device=None),
145
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
146
+ turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
147
+ batch_size=torch.Size([]),
148
+ device=None,
149
+ is_shared=False)
150
+ >>> print(env.rollout(1000))
151
+ TensorDict(
152
+ fields={
153
+ action: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.int64, is_shared=False),
154
+ action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
155
+ done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
156
+ fen: NonTensorStack(
157
+ ['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ...,
158
+ batch_size=torch.Size([96]),
159
+ device=None),
160
+ legal_moves: Tensor(shape=torch.Size([96, 219]), device=cpu, dtype=torch.int64, is_shared=False),
161
+ next: TensorDict(
162
+ fields={
163
+ action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
164
+ done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
165
+ fen: NonTensorStack(
166
+ ['rnbqkbnr/pppppppp/8/8/8/5N2/PPPPPPPP/RNBQKB1R b ...,
167
+ batch_size=torch.Size([96]),
168
+ device=None),
169
+ legal_moves: Tensor(shape=torch.Size([96, 219]), device=cpu, dtype=torch.int64, is_shared=False),
170
+ pgn: NonTensorStack(
171
+ ['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
172
+ batch_size=torch.Size([96]),
173
+ device=None),
174
+ reward: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.float32, is_shared=False),
175
+ san: NonTensorStack(
176
+ ['Nf3', 'Na6', 'c4', 'f6', 'h4', 'Rb8', 'Na3', 'Ra...,
177
+ batch_size=torch.Size([96]),
178
+ device=None),
179
+ terminated: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
180
+ turn: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.bool, is_shared=False)},
181
+ batch_size=torch.Size([96]),
182
+ device=None,
183
+ is_shared=False),
184
+ pgn: NonTensorStack(
185
+ ['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
186
+ batch_size=torch.Size([96]),
187
+ device=None),
188
+ san: NonTensorStack(
189
+ ['<start>', 'Nf3', 'Na6', 'c4', 'f6', 'h4', 'Rb8',...,
190
+ batch_size=torch.Size([96]),
191
+ device=None),
192
+ terminated: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
193
+ turn: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.bool, is_shared=False)},
194
+ batch_size=torch.Size([96]),
195
+ device=None,
196
+ is_shared=False)
197
+ """ # noqa: D301
198
+
199
+ _hash_table: dict[int, str] = {}
200
+ _PGN_RESTART = """[Event "?"]
201
+ [Site "?"]
202
+ [Date "????.??.??"]
203
+ [Round "?"]
204
+ [White "?"]
205
+ [Black "?"]
206
+ [Result "*"]
207
+
208
+ *"""
209
+
210
@_classproperty
def lib(cls):
    """Return the ``chess`` module, making sure ``chess.pgn`` is loaded too.

    Raises:
        ImportError: if the ``chess`` package is not installed. The original
            import error is chained as the cause.
    """
    try:
        import chess
        import chess.pgn  # noqa: F401  (loads the submodule so ``chess.pgn`` is available)
    except ImportError as err:
        raise ImportError(
            "The `chess` library could not be found. Make sure you installed it through `pip install chess`."
        ) from err
    return chess
220
+
221
+ _san_moves = []
222
+
223
@_classproperty
def san_moves(cls):
    """Return the list of all SAN moves known to the environment.

    The list is read once from ``san_moves.txt``, located next to this module,
    and cached in the shared class-level list ``_san_moves``. A move's position
    in this list is the index used for it in the action space.
    """
    if not cls._san_moves:
        san_path = pathlib.Path(__file__).parent / "san_moves.txt"
        # Read-only mode: the file is never written here, and "r+" would fail
        # on read-only installations (e.g. system site-packages).
        with open(san_path, "r", encoding="utf-8") as f:
            # split("\n") rather than splitlines() keeps the list length, and
            # therefore the size of the action space, unchanged.
            cls._san_moves.extend(f.read().split("\n"))
    return cls._san_moves
229
+
230
def _legal_moves_to_index(
    self,
    tensordict: TensorDictBase | None = None,
    board: chess.Board | None = None,  # noqa: F821
    return_mask: bool = False,
    pad: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Return the action indices of the legal moves for a board position.

    Args:
        tensordict: only used when the env is not stateful. In that case the
            board is rebuilt from its ``"fen"`` entry (if ``include_fen``),
            or otherwise from its ``"pgn"`` entry (if ``include_pgn``).
        board: board to read the legal moves from, unless a board was rebuilt
            from ``tensordict``. Defaults to ``self.board``.
        return_mask: if ``True``, also return a boolean mask over all SAN moves.
        pad: if ``True``, right-pad the indices with ``len(self.san_moves)``.

    Returns:
        An int64 tensor of indices into ``san_moves``, or ``(indices, mask)``
        when ``return_mask`` is ``True``.

    Raises:
        ValueError: if the SAN of a legal move is not present in ``san_moves``.
    """
    if not self.stateful and tensordict is not None:
        # With tensordict=None we trust the board as given.
        if self.include_fen:
            fen = tensordict.get("fen", None)
            fen = fen.data
            self.board.set_fen(fen)
            board = self.board
        elif self.include_pgn:
            pgn = tensordict.get("pgn")
            pgn = pgn.data
            board = self._pgn_to_board(pgn, self.board)

    if board is None:
        board = self.board

    # O(1) dict lookup instead of a linear list.index() scan over every SAN move.
    # Going through san_moves also guarantees the move list is loaded.
    san_to_index = self._san_to_index_map()
    try:
        index_list = [san_to_index[board.san(move)] for move in board.legal_moves]
    except KeyError as err:
        # Same exception type and message shape as the former list.index() call.
        raise ValueError(f"{err.args[0]!r} is not in list") from err
    indices = torch.tensor(index_list, dtype=torch.int64)

    mask = self._move_index_to_mask(indices) if return_mask else None
    if pad:
        # Pads to 219 entries (218 + 1), using len(self.san_moves) as the
        # filler value, which is outside the range of real move indices.
        # NOTE(review): the "legal_moves" spec built in __init__ declares
        # shape (218,); confirm which length is intended.
        indices = torch.nn.functional.pad(
            indices, [0, 218 - indices.numel() + 1], value=len(self.san_moves)
        )
    if return_mask:
        return indices, mask
    return indices

@classmethod
def _san_to_index_map(cls) -> dict[str, int]:
    """Return a lazily built, cached ``{san: index}`` lookup over ``san_moves``.

    For duplicated SAN strings, the first index is kept, matching ``list.index``.
    """
    lookup = getattr(cls, "_san_to_index_cache", None)
    if lookup is None:
        lookup = {}
        for idx, san in enumerate(cls.san_moves):
            lookup.setdefault(san, idx)
        cls._san_to_index_cache = lookup
    return lookup
268
+
269
@classmethod
def _move_index_to_mask(cls, indices: torch.Tensor) -> torch.Tensor:
    """Convert move indices into a boolean mask over every SAN move.

    Args:
        indices: index tensor into ``san_moves``. It is expected to be int64,
            as built by ``_legal_moves_to_index``.

    Returns:
        A bool tensor of length ``len(cls.san_moves)`` that is ``True`` at
        ``indices`` and ``False`` everywhere else.
    """
    num_moves = len(cls.san_moves)
    mask = torch.zeros(num_moves, dtype=torch.bool)
    # index_fill_ works in place and returns the same tensor object.
    mask.index_fill_(0, indices, True)
    return mask
274
+
275
def __init__(
    self,
    *,
    stateful: bool = True,
    include_san: bool = False,
    include_fen: bool = False,
    include_pgn: bool = False,
    include_legal_moves: bool = False,
    include_hash: bool = False,
    include_hash_inv: bool = False,
    mask_actions: bool = True,
    pixels: bool = False,
):
    """Build the env's specs and the underlying ``chess.Board``.

    Keyword Args:
        stateful: if ``False``, ``_step`` rebuilds the board from the
            ``"fen"`` / ``"pgn"`` entries of the input tensordict, so at least
            one of ``include_fen`` / ``include_pgn`` must be ``True``.
        include_san, include_fen, include_pgn: add the matching ``NonTensor``
            entry to the observation spec.
        include_legal_moves: add an int64 ``"legal_moves"`` entry of shape
            ``(218,)``.
        include_hash, include_hash_inv: NOTE(review): not read in this
            method; presumably handled elsewhere in the class — verify.
        mask_actions: add a boolean ``"action_mask"`` entry with one element
            per SAN move.
        pixels: add a uint8 ``"pixels"`` entry of shape ``(3, 390, 390)``.

    Raises:
        RuntimeError: if ``stateful=False`` and neither ``include_pgn`` nor
            ``include_fen`` is set.
        ImportError: if ``pixels=True`` and cairosvg or torchvision cannot be
            found, or (through ``self.lib``) if ``chess`` is not installed.
    """
    # Resolve the lazily imported python-chess module first, so a missing
    # install fails before any spec is built.
    chess = self.lib
    super().__init__()
    self.full_observation_spec = Composite(
        turn=Categorical(n=2, dtype=torch.bool, shape=()),
    )
    self.include_san = include_san
    self.include_fen = include_fen
    self.include_pgn = include_pgn
    self.mask_actions = mask_actions
    self.include_legal_moves = include_legal_moves
    if include_legal_moves:
        # 218 max possible legal moves per chess board position
        # https://www.stmintz.com/ccc/index.php?id=424966
        # _legal_moves_to_index pads with len(self.san_moves); the upper bound
        # 1 + len(self.san_moves) keeps that padding value inside the spec.
        self.full_observation_spec["legal_moves"] = Bounded(
            0, 1 + len(self.san_moves), shape=(218,), dtype=torch.int64
        )
    if include_san:
        self.full_observation_spec["san"] = NonTensor(shape=(), example_data="Nc6")
    if include_pgn:
        self.full_observation_spec["pgn"] = NonTensor(
            shape=(), example_data=self._PGN_RESTART
        )
    if include_fen:
        self.full_observation_spec["fen"] = NonTensor(shape=(), example_data="any")
    # Without fen or pgn, a non-stateful env has nothing to rebuild the board from.
    if not stateful and not (include_pgn or include_fen):
        raise RuntimeError(
            "At least one state representation (pgn or fen) must be enabled when stateful "
            f"is {stateful}."
        )

    self.stateful = stateful

    # state_spec is loosely defined as such - it's not really an issue that extra keys
    # can go missing but it allows us to reset the env using fen passed to the reset
    # method.
    # The clone happens before "pixels" and "action_mask" are added below, so
    # those keys are observation-only and not part of the state spec.
    self.full_state_spec = self.full_observation_spec.clone()

    self.pixels = pixels
    if pixels:
        if importlib.util.find_spec("cairosvg") is None:
            raise ImportError(
                "Please install cairosvg to use this environment with pixel rendering."
            )
        if importlib.util.find_spec("torchvision") is None:
            raise ImportError(
                "Please install torchvision to use this environment with pixel rendering."
            )
        self.full_observation_spec["pixels"] = Unbounded(
            shape=(3, 390, 390), dtype=torch.uint8
        )

    # One categorical action per SAN move in the vocabulary.
    self.full_action_spec = Composite(
        action=Categorical(n=len(self.san_moves), shape=(), dtype=torch.int64)
    )
    self.full_reward_spec = Composite(
        reward=Unbounded(shape=(1,), dtype=torch.float32)
    )
    if self.mask_actions:
        self.full_observation_spec["action_mask"] = Binary(
            n=len(self.san_moves), dtype=torch.bool
        )

    # done spec generated automatically
    self.board = chess.Board()
    if self.stateful:
        # NOTE(review): presumably restricts the action spec to the number of
        # legal moves in the initial position; set_provisional_n semantics are
        # not visible here — confirm.
        self.action_spec.set_provisional_n(len(list(self.board.legal_moves)))
355
+
356
def _is_done(self, board):
    """Return ``True`` when the episode on ``board`` has ended.

    The episode ends if ``board.is_game_over()`` is true or if
    ``board.is_fifty_moves()`` is true (the fifty-move rule).
    """
    # Logical ``or`` instead of bitwise ``|``: same boolean result, and it
    # short-circuits, skipping the fifty-move check once the game is over.
    return board.is_game_over() or board.is_fifty_moves()
358
+
359
def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
    """Return the legal actions via the base-class ``all_actions``.

    Only available when the env was built with ``mask_actions=True``.

    Raises:
        RuntimeError: if ``mask_actions`` is ``False``. In that case enumerate
            every action with ``env.full_action_spec.enumerate()`` instead.
    """
    if self.mask_actions:
        return super().all_actions(tensordict)
    raise RuntimeError(
        "Cannot generate legal actions since 'mask_actions=False' was "
        "set. If you really want to generate all actions, not just "
        "legal ones, call 'env.full_action_spec.enumerate()'."
    )
367
+
368
def _reset(self, tensordict=None):
    """Reset the board and return the initial observation.

    The new position is chosen as follows:

    * if ``include_fen`` and ``tensordict`` has a ``"fen"`` entry, the board
      is set to that FEN. A position that is already over raises
      ``ValueError``;
    * otherwise, if ``include_pgn`` (and not ``include_fen``) and
      ``tensordict`` has a ``"pgn"`` entry, ``self.board`` is replaced by the
      board rebuilt from that PGN;
    * otherwise the board is reset to the starting position.

    The output is ``tensordict.empty()`` (or a new ``TensorDict``). It holds
    ``"turn"`` and, depending on the ``include_*`` / ``mask_actions`` /
    ``pixels`` options, ``"san"`` (SAN of the last move, or ``"<start>"``),
    ``"fen"``, ``"pgn"``, ``"legal_moves"``, ``"action_mask"`` and
    ``"pixels"``.
    """
    fen = None
    pgn = None
    if tensordict is not None:
        dest = tensordict.empty()
        if self.include_fen:
            fen = tensordict.get("fen", None)
            if fen is not None:
                fen = fen.data
        elif self.include_pgn:
            pgn = tensordict.get("pgn", None)
            if pgn is not None:
                pgn = pgn.data
    else:
        dest = TensorDict()

    if fen is None and pgn is None:
        self.board.reset()
    elif fen is not None:
        self.board.set_fen(fen)
        if self._is_done(self.board):
            raise ValueError(
                "Cannot reset to a fen that is a gameover state." f" fen: {fen}"
            )
    elif pgn is not None:
        self.board = self._pgn_to_board(pgn)

    if self.include_fen and fen is None:
        fen = self.board.fen()
    if self.include_pgn and pgn is None:
        pgn = self._board_to_pgn(self.board)

    turn = self.board.turn
    if self.include_san:
        if self.board.move_stack:
            # SAN is relative to the position *before* the move: python-chess
            # ``Board.san`` expects a move that is legal in the current
            # position. Undo the last move, render it, then replay it. The
            # ``finally`` guarantees the board is restored even if rendering
            # fails.
            last_move = self.board.pop()
            try:
                san = self.board.san(last_move)
            finally:
                self.board.push(last_move)
        else:
            san = "<start>"
        dest.set("san", san)
    if self.include_fen:
        dest.set("fen", fen)
    if self.include_pgn:
        dest.set("pgn", pgn)
    dest.set("turn", turn)
    if self.include_legal_moves:
        moves_idx = self._legal_moves_to_index(
            board=self.board, pad=True, return_mask=self.mask_actions
        )
        if self.mask_actions:
            moves_idx, mask = moves_idx
            dest.set("action_mask", mask)
        dest.set("legal_moves", moves_idx)
    elif self.mask_actions:
        dest.set(
            "action_mask",
            self._legal_moves_to_index(
                board=self.board, pad=True, return_mask=True
            )[1],
        )

    if self.pixels:
        dest.set("pixels", self._get_tensor_image(board=self.board))
    return dest
434
+
435
# Cached ``cairosvg`` module; stays ``None`` until ``_cairosvg`` is first read.
_cairosvg_lib = None

@_classproperty
def _cairosvg(cls):
    """The ``cairosvg`` module, imported on first access and cached on the class."""
    if cls._cairosvg_lib is None:
        import cairosvg

        cls._cairosvg_lib = cairosvg
    return cls._cairosvg_lib
445
+
446
# Cached ``torchvision`` module; stays ``None`` until ``_torchvision`` is first read.
_torchvision_lib = None

@_classproperty
def _torchvision(cls):
    """The ``torchvision`` module, imported on first access and cached on the class."""
    if cls._torchvision_lib is None:
        import torchvision

        cls._torchvision_lib = torchvision
    return cls._torchvision_lib
456
+
457
@classmethod
def _get_tensor_image(cls, board):
    """Render ``board`` as a ``(3, H, W)`` uint8 image tensor.

    The board SVG (``board._repr_svg_()``) is rasterized to PNG with
    cairosvg, decoded with Pillow and converted with torchvision's
    ``pil_to_tensor``. The image is forced to RGB, so the result always has
    the 3 channels declared by the ``pixels`` spec, even if the PNG carries
    an alpha channel.

    Raises:
        ImportError: if cairosvg, PIL or torchvision is not installed. The
            original import error is chained as the cause.
    """
    # Only the dependency lookups can raise ImportError; keep the try narrow
    # so that genuine rendering errors are not mislabelled as missing packages.
    try:
        from PIL import Image

        svg2png = cls._cairosvg.svg2png
        pil_to_tensor = cls._torchvision.transforms.functional.pil_to_tensor
    except ImportError as err:
        raise ImportError(
            "Chess rendering requires cairosvg, PIL and torchvision to be installed."
        ) from err

    svg = board._repr_svg_()
    # Convert SVG to PNG using cairosvg
    png_data = io.BytesIO()
    svg2png(bytestring=svg.encode("utf-8"), write_to=png_data)
    png_data.seek(0)
    # Decode with Pillow and drop any alpha channel (a no-op for RGB input).
    img = Image.open(png_data).convert("RGB")
    return pil_to_tensor(img)
475
+
476
@classmethod
def _pgn_to_board(
    cls, pgn_string: str, board: chess.Board | None = None  # noqa: F821
) -> chess.Board:  # noqa: F821
    """Rebuild the position reached at the end of a PGN game.

    Args:
        pgn_string: PGN text. Only its first game, and that game's mainline,
            are used.
        board: optional board to reuse. It is overwritten in place (its move
            stack is cleared) and returned.

    Returns:
        The board after replaying every mainline move from the game's
        starting position.

    Raises:
        ValueError: if no game can be read from ``pgn_string``.
    """
    game = cls.lib.pgn.read_game(io.StringIO(pgn_string))
    if game is None:
        # Same error as ``_add_move_to_pgn`` instead of an AttributeError below.
        raise ValueError("Invalid PGN string")
    # Start from the position declared by the PGN headers (e.g. the "FEN" tag
    # written for games that did not begin at the standard initial position).
    # Replaying those moves from the initial position would give a wrong board.
    start = game.board()
    if board is None:
        board = start
    else:
        board.set_fen(start.fen())
    for move in game.mainline_moves():
        board.push(move)
    return board
489
+
490
@classmethod
def _add_move_to_pgn(cls, pgn_string: str, move: chess.Move) -> str:  # noqa: F821
    """Return ``pgn_string`` with ``move`` appended to the end of its mainline.

    Raises:
        ValueError: if no game can be read from ``pgn_string``.
    """
    game = cls.lib.pgn.read_game(io.StringIO(pgn_string))
    if game is None:
        raise ValueError("Invalid PGN string")
    last_node = game.end()
    last_node.add_variation(move)
    return str(game)
498
+
499
@classmethod
def _board_to_pgn(cls, board: chess.Board) -> str:  # noqa: F821
    """Serialize ``board`` and its move stack to a PGN string via ``chess.pgn.Game.from_board``."""
    return str(cls.lib.pgn.Game.from_board(board))
504
+
505
def get_legal_moves(self, tensordict=None, uci=False):
    """List the legal moves in a position.

    The env's ``"action"`` is an index into :attr:`san_moves` (the full move
    vocabulary), not into the returned list. To play a move ``m`` from this
    list (SAN format), set the action to ``env.san_moves.index(m)``.

    Args:
        tensordict (TensorDict, optional): Tensordict describing the position
            through its ``"fen"`` entry or, if that is absent, its ``"pgn"``
            entry. Required if not stateful. If stateful, this argument is
            ignored and the current state of the env is used instead.

        uci (bool, optional): If ``False``, moves are given in SAN format.
            If ``True``, moves are given in UCI format. Default is
            ``False``.

    Returns:
        list[str]: the legal moves of the position.

    Raises:
        ValueError: if the env is not stateful and ``tensordict`` is ``None``.
        KeyError: if the env is not stateful and ``tensordict`` has neither a
            ``"fen"`` nor a ``"pgn"`` entry.
    """
    board = self.board
    if not self.stateful:
        if tensordict is None:
            raise ValueError(
                "tensordict must be given since this env is not stateful"
            )
        fen = tensordict.get("fen", None)
        if fen is not None:
            board.set_fen(fen.data)
        else:
            # A non-stateful env may carry only the PGN (include_fen=False,
            # include_pgn=True), as _step and _legal_moves_to_index allow.
            pgn = tensordict.get("pgn", None)
            if pgn is None:
                raise KeyError(
                    "tensordict must contain a 'fen' or 'pgn' entry describing the position"
                )
            board = self._pgn_to_board(pgn.data, board)
    moves = board.legal_moves

    if uci:
        return [board.uci(move) for move in moves]
    return [board.san(move) for move in moves]
536
+
537
def _step(self, tensordict):
    """Play ``tensordict["action"]`` and return the next observation.

    The action indexes :attr:`san_moves`. The matching SAN move is pushed
    with ``board.push_san``, which raises (a ``ValueError`` subclass in
    python-chess) if the move is illegal. For a non-stateful env, the board
    is first rebuilt from the input's ``"fen"`` (preferred) or ``"pgn"``.

    Reward: ``1.0`` for the side that delivers checkmate, ``0.5`` for any
    other game end, ``0.0`` otherwise. ``"done"`` and ``"terminated"`` both
    mirror ``_is_done``.

    Raises:
        RuntimeError: if the env is not stateful and neither ``include_fen``
            nor ``include_pgn`` is set.
    """
    action = tensordict.get("action")
    board = self.board

    pgn = None
    fen = None
    if not self.stateful:
        if self.include_fen:
            fen = tensordict.get("fen").data
            board.set_fen(fen)
            if self.include_pgn:
                # ``set_fen`` wipes the move history. Keep the incoming PGN
                # (when present) so the emitted PGN still holds the whole game,
                # as it does for a stateful env.
                pgn = tensordict.get("pgn", None)
                if pgn is not None:
                    pgn = pgn.data
        elif self.include_pgn:
            pgn = tensordict.get("pgn").data
            board = self._pgn_to_board(pgn, board)
        else:
            raise RuntimeError(
                "Not enough information to deduce the board. If stateful=False, include_pgn or include_fen must be True."
            )

    san = self.san_moves[action]
    board.push_san(san)

    dest = tensordict.empty()

    # Collect data
    if self.include_fen:
        fen = board.fen()
        dest.set("fen", fen)

    if self.include_pgn:
        if pgn is not None:
            pgn = self._add_move_to_pgn(pgn, board.move_stack[-1])
        else:
            pgn = self._board_to_pgn(board)
        dest.set("pgn", pgn)

    if self.include_san:
        dest.set("san", san)

    if self.include_legal_moves:
        moves_idx = self._legal_moves_to_index(
            board=board, pad=True, return_mask=self.mask_actions
        )
        if self.mask_actions:
            moves_idx, mask = moves_idx
            dest.set("action_mask", mask)
        dest.set("legal_moves", moves_idx)
    elif self.mask_actions:
        dest.set(
            "action_mask",
            self._legal_moves_to_index(board=board, pad=True, return_mask=True)[1],
        )

    turn = torch.tensor(board.turn)
    done = self._is_done(board)
    if board.is_checkmate():
        # The turn has already flipped to the mated side, so the player who
        # just moved is the winner and receives the full reward.
        reward_val = 1
    elif done:
        # Any other game end (e.g. stalemate, fifty-move rule).
        reward_val = 0.5
    else:
        reward_val = 0.0

    reward = torch.tensor([reward_val], dtype=torch.float32)
    dest.set("reward", reward)
    dest.set("turn", turn)
    dest.set("done", torch.tensor([done]))
    dest.set("terminated", torch.tensor([done]))
    if self.pixels:
        dest.set("pixels", self._get_tensor_image(board=board))
    return dest
611
+
612
def _set_seed(self, *args, **kwargs) -> None:
    """Seeding is a no-op: the reset/step logic uses no random number generator."""
614
+
615
def cardinality(self, tensordict: TensorDictBase | None = None) -> int:
    """Number of actions in the action space after refreshing it.

    ``self._set_action_space(tensordict)`` runs first. The cardinality of
    the resulting ``self.action_spec`` is then returned.
    """
    self._set_action_space(tensordict)
    spec = self.action_spec
    return spec.cardinality()