torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/libs/jumanji.py
@@ -0,0 +1,963 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import importlib.util
+
+ import numpy as np
+ import torch
+ from packaging import version
+ from tensordict import TensorDict, TensorDictBase
+ from torchrl.envs.common import _EnvPostInit
+ from torchrl.envs.utils import _classproperty
+
+ _has_jumanji = importlib.util.find_spec("jumanji") is not None
+
+ from torchrl.data.tensor_specs import (
+     Bounded,
+     Categorical,
+     Composite,
+     DEVICE_TYPING,
+     MultiCategorical,
+     MultiOneHot,
+     OneHot,
+     TensorSpec,
+     Unbounded,
+ )
+ from torchrl.data.utils import numpy_to_torch_dtype_dict
+ from torchrl.envs.gym_like import GymLikeEnv
+
+ from torchrl.envs.libs.jax_utils import (
+     _extract_spec,
+     _ndarray_to_tensor,
+     _object_to_tensordict,
+     _tensordict_to_object,
+     _tree_flatten,
+     _tree_reshape,
+ )
+
+
+ def _get_envs():
+     if not _has_jumanji:
+         raise ImportError("Jumanji is not installed in your virtual environment.")
+     import jumanji
+
+     return jumanji.registered_environments()
+
+
+ def _jumanji_to_torchrl_spec_transform(
+     spec,
+     dtype: torch.dtype | None = None,
+     device: DEVICE_TYPING = None,
+     categorical_action_encoding: bool = True,
+ ) -> TensorSpec:
+     import jumanji
+
+     if isinstance(spec, jumanji.specs.DiscreteArray):
+         action_space_cls = Categorical if categorical_action_encoding else OneHot
+         if dtype is None:
+             dtype = numpy_to_torch_dtype_dict[spec.dtype]
+         return action_space_cls(spec.num_values, dtype=dtype, device=device)
+     if isinstance(spec, jumanji.specs.MultiDiscreteArray):
+         action_space_cls = (
+             MultiCategorical if categorical_action_encoding else MultiOneHot
+         )
+         if dtype is None:
+             dtype = numpy_to_torch_dtype_dict[spec.dtype]
+         return action_space_cls(
+             torch.as_tensor(np.asarray(spec.num_values)), dtype=dtype, device=device
+         )
+     elif isinstance(spec, jumanji.specs.BoundedArray):
+         shape = spec.shape
+         if dtype is None:
+             dtype = numpy_to_torch_dtype_dict[spec.dtype]
+         return Bounded(
+             shape=shape,
+             low=np.asarray(spec.minimum),
+             high=np.asarray(spec.maximum),
+             dtype=dtype,
+             device=device,
+         )
+     elif isinstance(spec, jumanji.specs.Array):
+         shape = spec.shape
+         if dtype is None:
+             dtype = numpy_to_torch_dtype_dict[spec.dtype]
+         return Unbounded(shape=shape, dtype=dtype, device=device)
+     elif isinstance(spec, jumanji.specs.Spec) and hasattr(spec, "__dict__"):
+         new_spec = {}
+         for key, value in spec.__dict__.items():
+             if isinstance(value, jumanji.specs.Spec):
+                 if key.endswith("_obs"):
+                     key = key[:-4]
+                 if key.endswith("_spec"):
+                     key = key[:-5]
+                 new_spec[key] = _jumanji_to_torchrl_spec_transform(
+                     value, dtype, device, categorical_action_encoding
+                 )
+         return Composite(**new_spec)
+     else:
+         raise TypeError(f"Unsupported spec type {type(spec)}")
+
+
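+ # A minimal conversion sketch (illustrative only; assumes jumanji is installed
+ # and that `jumanji.specs.DiscreteArray` accepts `num_values` as below):
+ #
+ #     >>> import jumanji
+ #     >>> spec = jumanji.specs.DiscreteArray(num_values=4)
+ #     >>> _jumanji_to_torchrl_spec_transform(spec)  # Categorical with n=4
+ #     >>> _jumanji_to_torchrl_spec_transform(
+ #     ...     spec, categorical_action_encoding=False
+ #     ... )  # OneHot with n=4
+
+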
+ class _JumanjiMakeRender(_EnvPostInit):
+     def __call__(self, *args, **kwargs):
+         instance = super().__call__(*args, **kwargs)
+         if instance.from_pixels:
+             return instance.make_render()
+         return instance
+
+
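+ # `_JumanjiMakeRender` is a post-init hook: when `from_pixels=True`, the freshly
+ # built wrapper is replaced by `instance.make_render()`, i.e. the same env with
+ # a `PixelRenderTransform` appended (see `make_render` below).
+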
+ class JumanjiWrapper(GymLikeEnv, metaclass=_JumanjiMakeRender):
+     """Jumanji's environment wrapper.
+
+     Jumanji offers a vectorized simulation framework based on Jax.
+     TorchRL's wrapper incurs some overhead for the jax-to-torch conversion,
+     but computational graphs can still be built on top of the simulated trajectories,
+     allowing for backpropagation through the rollout.
+
+     GitHub: https://github.com/instadeepai/jumanji
+
+     Doc: https://instadeepai.github.io/jumanji/
+
+     Paper: https://arxiv.org/abs/2306.09884
+
+     .. note:: For better performance, turn `jit` on when instantiating this class.
+         The `jit` attribute can also be flipped during code execution:
+
+         >>> env.jit = True  # use jit
+         >>> env.jit = False  # eager
+
+     Args:
+         env (jumanji.env.Environment): the env to wrap.
+         categorical_action_encoding (bool, optional): if ``True``, categorical
+             specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
+             otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
+             Defaults to ``True``.
+
+     Keyword Args:
+         batch_size (torch.Size, optional): the batch size of the environment.
+             With ``jumanji``, this indicates the number of vectorized environments.
+             If the batch-size is empty, the environment is not batch-locked and an arbitrary number
+             of environments can be executed simultaneously.
+             Defaults to ``torch.Size([])``.
+
+             >>> import jumanji
+             >>> from torchrl.envs import JumanjiWrapper
+             >>> base_env = jumanji.make("Snake-v1")
+             >>> env = JumanjiWrapper(base_env)
+             >>> # Setting the batch-size of the TensorDict instead of the env allows one to
+             >>> # control the number of envs being run simultaneously
+             >>> tdreset = env.reset(TensorDict(batch_size=[32]))
+             >>> # Execute a rollout until all envs are done or max steps is reached, whichever comes first
+             >>> rollout = env.rollout(100, break_when_all_done=True, auto_reset=False, tensordict=tdreset)
+
+         from_pixels (bool, optional): Whether the environment should render its output.
+             This will drastically impact the environment throughput. Only the first environment
+             will be rendered. See :meth:`~torchrl.envs.JumanjiWrapper.render` for more information.
+             Defaults to `False`.
+         frame_skip (int, optional): if provided, indicates for how many steps the
+             same action is to be repeated. The observation returned will be the
+             last observation of the sequence, whereas the reward will be the sum
+             of rewards across steps.
+         device (torch.device, optional): if provided, the device on which the data
+             is to be cast. Defaults to ``torch.device("cpu")``.
+         allow_done_after_reset (bool, optional): if ``True``, it is tolerated
+             for envs to be ``done`` just after :meth:`reset` is called.
+             Defaults to ``False``.
+         jit (bool, optional): whether the step and reset method should be wrapped in `jit`.
+             Defaults to ``True``.
+
+     Attributes:
+         available_envs: environments available to build
+
+     Examples:
+         >>> import jumanji
+         >>> from torchrl.envs import JumanjiWrapper
+         >>> base_env = jumanji.make("Snake-v1")
+         >>> env = JumanjiWrapper(base_env)
+         >>> env.set_seed(0)
+         >>> td = env.reset()
+         >>> td["action"] = env.action_spec.rand()
+         >>> td = env.step(td)
+         >>> print(td)
+         TensorDict(
+             fields={
+                 action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                 action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
+                 done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 grid: Tensor(shape=torch.Size([12, 12, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                 next: TensorDict(
+                     fields={
+                         action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
+                         done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                         grid: Tensor(shape=torch.Size([12, 12, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                         reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                         state: TensorDict(
+                             fields={
+                                 action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
+                                 body: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False),
+                                 body_state: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.int32, is_shared=False),
+                                 fruit_position: TensorDict(
+                                     fields={
+                                         col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                                         row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
+                                     batch_size=torch.Size([]),
+                                     device=cpu,
+                                     is_shared=False),
+                                 head_position: TensorDict(
+                                     fields={
+                                         col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                                         row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
+                                     batch_size=torch.Size([]),
+                                     device=cpu,
+                                     is_shared=False),
+                                 key: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
+                                 length: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                                 step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                                 tail: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False)},
+                             batch_size=torch.Size([]),
+                             device=cpu,
+                             is_shared=False),
+                         step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                         terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                     batch_size=torch.Size([]),
+                     device=cpu,
+                     is_shared=False),
+                 state: TensorDict(
+                     fields={
+                         action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
+                         body: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False),
+                         body_state: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.int32, is_shared=False),
+                         fruit_position: TensorDict(
+                             fields={
+                                 col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                                 row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
+                             batch_size=torch.Size([]),
+                             device=cpu,
+                             is_shared=False),
+                         head_position: TensorDict(
+                             fields={
+                                 col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                                 row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
+                             batch_size=torch.Size([]),
+                             device=cpu,
+                             is_shared=False),
+                         key: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
+                         length: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                         step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                         tail: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False)},
+                     batch_size=torch.Size([]),
+                     device=cpu,
+                     is_shared=False),
+                 step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                 terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+             batch_size=torch.Size([]),
+             device=cpu,
+             is_shared=False)
+         >>> print(env.available_envs)
+         ['Game2048-v1',
+          'Maze-v0',
+          'Cleaner-v0',
+          'CVRP-v1',
+          'MultiCVRP-v0',
+          'Minesweeper-v0',
+          'RubiksCube-v0',
+          'Knapsack-v1',
+          'Sudoku-v0',
+          'Snake-v1',
+          'TSP-v1',
+          'Connector-v2',
+          'MMST-v0',
+          'GraphColoring-v0',
+          'RubiksCube-partly-scrambled-v0',
+          'RobotWarehouse-v0',
+          'Tetris-v0',
+          'BinPack-v2',
+          'Sudoku-very-easy-v0',
+          'JobShop-v0']
+
+     To take advantage of Jumanji, one usually executes multiple environments at the
+     same time.
+
+         >>> import jumanji
+         >>> from torchrl.envs import JumanjiWrapper
+         >>> base_env = jumanji.make("Snake-v1")
+         >>> env = JumanjiWrapper(base_env, batch_size=[10])
+         >>> env.set_seed(0)
+         >>> td = env.reset()
+         >>> td["action"] = env.action_spec.rand()
+         >>> td = env.step(td)
+
+     In the following example, we iteratively test different batch sizes
+     and report the execution time for a short rollout:
+
+     Examples:
+         >>> from torch.utils.benchmark import Timer
+         >>> for batch_size in [4, 16, 128]:
+         ...     timer = Timer(
+         ...         '''
+         ... env.rollout(100)
+         ... ''',
+         ...         setup=f'''
+         ... from torchrl.envs import JumanjiWrapper
+         ... import jumanji
+         ... env = JumanjiWrapper(jumanji.make('Snake-v1'), batch_size=[{batch_size}])
+         ... env.set_seed(0)
+         ... env.rollout(2)
+         ... ''')
+         ...     print(batch_size, timer.timeit(number=10))
+         4
+         env.rollout(100)
+         setup: [...]
+           Median: 122.40 ms
+           2 measurements, 1 runs per measurement, 1 thread
+
+         16
+         env.rollout(100)
+         setup: [...]
+           Median: 134.39 ms
+           2 measurements, 1 runs per measurement, 1 thread
+
+         128
+         env.rollout(100)
+         setup: [...]
+           Median: 172.31 ms
+           2 measurements, 1 runs per measurement, 1 thread
+
+     """
+
+     git_url = "https://github.com/instadeepai/jumanji"
+     libname = "jumanji"
+
+     @_classproperty
+     def available_envs(cls):
+         if not _has_jumanji:
+             return []
+         return sorted(_get_envs())
+
+     @property
+     def lib(self):
+         import jumanji
+
+         if version.parse(jumanji.__version__) < version.parse("1.0.0"):
+             raise ImportError("jumanji version must be >= 1.0.0")
+         return jumanji
+
+     def __init__(
+         self,
+         env: jumanji.env.Environment = None,  # noqa: F821
+         categorical_action_encoding=True,
+         jit: bool = True,
+         **kwargs,
+     ):
+         if not _has_jumanji:
+             raise ImportError(
+                 "jumanji is not installed or importing it failed. Consider checking your installation."
+             )
+         self.categorical_action_encoding = categorical_action_encoding
+         if env is not None:
+             kwargs["env"] = env
+         batch_locked = kwargs.pop("batch_locked", kwargs.get("batch_size") is not None)
+         super().__init__(**kwargs)
+         self._batch_locked = batch_locked
+         self.jit = jit
+
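+     # Batch-size semantics, mirroring the class docstring (sketch):
+     #     >>> env = JumanjiWrapper(jumanji.make("Snake-v1"), batch_size=[8])  # batch-locked
+     #     >>> env = JumanjiWrapper(jumanji.make("Snake-v1"))  # batch-unlocked
+     #     >>> env.reset(TensorDict(batch_size=[32]))  # batch size chosen at call time
+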
+     @property
+     def jit(self):
+         return self._jit
+
+     @jit.setter
+     def jit(self, value):
+         self._jit = value
+         if value:
+             import jax
+
+             self._env_reset = jax.jit(self._env.reset)
+             self._env_step = jax.jit(self._env.step)
+         else:
+             self._env_reset = self._env.reset
+             self._env_step = self._env.step
+
+     def _build_env(
+         self,
+         env,
+         _seed: int | None = None,
+         from_pixels: bool = False,
+         render_kwargs: dict | None = None,
+         pixels_only: bool = False,
+         camera_id: int | str = 0,
+         **kwargs,
+     ):
+         self.from_pixels = from_pixels
+         self.pixels_only = pixels_only
+
+         return env
+
+     def make_render(self):
+         """Returns a transformed environment that can be rendered.
+
+         Examples:
+             >>> from torchrl.envs import JumanjiEnv
+             >>> from torchrl.record import CSVLogger, VideoRecorder
+             >>>
+             >>> envname = JumanjiEnv.available_envs[-1]
+             >>> logger = CSVLogger("jumanji", video_format="mp4", video_fps=2)
+             >>> env = JumanjiEnv(envname, from_pixels=True)
+             >>>
+             >>> env = env.append_transform(
+             ...     VideoRecorder(logger=logger, in_keys=["pixels"], tag=envname)
+             ... )
+             >>> env.set_seed(0)
+             >>> r = env.rollout(100)
+             >>> env.transform.dump()
+
+         """
+         from torchrl.record import PixelRenderTransform
+
+         return self.append_transform(
+             PixelRenderTransform(
+                 out_keys=["pixels"],
+                 pass_tensordict=True,
+                 as_non_tensor=bool(self.batch_size),
+                 as_numpy=bool(self.batch_size),
+             )
+         )
+
+     def _make_state_example(self, env):
+         import jax
+         from jax import numpy as jnp
+
+         key = jax.random.PRNGKey(0)
+         keys = jax.random.split(key, self.batch_size.numel())
+         state, _ = jax.vmap(env.reset)(jnp.stack(keys))
+         state = _tree_reshape(state, self.batch_size)
+         return state
+
+     def _make_state_spec(self, env) -> TensorSpec:
+         import jax
+
+         key = jax.random.PRNGKey(0)
+         state, _ = env.reset(key)
+         state_dict = _object_to_tensordict(state, self.device, batch_size=())
+         state_spec = _extract_spec(state_dict)
+         return state_spec
+
+     def _make_action_spec(self, env) -> TensorSpec:
+         action_spec = _jumanji_to_torchrl_spec_transform(
+             env.action_spec,
+             device=self.device,
+             categorical_action_encoding=self.categorical_action_encoding,
+         )
+         action_spec = action_spec.expand(*self.batch_size, *action_spec.shape)
+         return action_spec
+
+     def _make_observation_spec(self, env) -> TensorSpec:
+         jumanji = self.lib
+
+         spec = env.observation_spec
+         new_spec = _jumanji_to_torchrl_spec_transform(spec, device=self.device)
+         if isinstance(spec, jumanji.specs.Array):
+             return Composite(observation=new_spec).expand(self.batch_size)
+         elif isinstance(spec, jumanji.specs.Spec):
+             return Composite(**{k: v for k, v in new_spec.items()}).expand(
+                 self.batch_size
+             )
+         else:
+             raise TypeError(f"Unsupported spec type {type(spec)}")
+
+     def _make_reward_spec(self, env) -> TensorSpec:
+         reward_spec = _jumanji_to_torchrl_spec_transform(
+             env.reward_spec, device=self.device
+         )
+         if not len(reward_spec.shape):
+             reward_spec.shape = torch.Size([1])
+         return reward_spec.expand([*self.batch_size, *reward_spec.shape])
+
+     def _make_specs(self, env: jumanji.env.Environment) -> None:  # noqa: F821
+         # extract spec from jumanji definition
+         self.action_spec = self._make_action_spec(env)
+         self.observation_spec = self._make_observation_spec(env)
+         self.reward_spec = self._make_reward_spec(env)
+
+         # extract state spec from instance
+         state_spec = self._make_state_spec(env).expand(self.batch_size)
+         self.state_spec["state"] = state_spec
+         self.observation_spec["state"] = state_spec.clone()
+
+         # build state example for data conversion
+         self._state_example = self._make_state_example(env)
+
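+     # Spec layout built above: `observation_spec` holds the converted observation
+     # entries plus a clone of the "state" spec, while `_state_example` keeps a jax
+     # pytree template used to rebuild jumanji states from tensordicts.
+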
+     def _check_kwargs(self, kwargs: dict):
+         jumanji = self.lib
+         if "env" not in kwargs:
+             raise TypeError("Could not find environment key 'env' in kwargs.")
+         env = kwargs["env"]
+         if not isinstance(env, (jumanji.env.Environment,)):
+             raise TypeError("env is not of type 'jumanji.env.Environment'.")
+
+     def _init_env(self):
+         pass
+
+     @property
+     def key(self):
+         key = getattr(self, "_key", None)
+         if key is None:
+             raise RuntimeError(
+                 "the env.key attribute wasn't found. Make sure to call `env.set_seed(seed)` before any interaction."
+             )
+         return key
+
+     @key.setter
+     def key(self, value):
+         self._key = value
+
+     def _set_seed(self, seed: int | None) -> None:
+         import jax
+
+         if seed is None:
+             raise Exception("Jumanji requires an integer seed.")
+         self.key = jax.random.PRNGKey(seed)
+
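+     # Seeding sketch: `set_seed` stores a jax PRNG key on the env, and `_reset`
+     # splits it into one sub-key per sub-environment:
+     #     >>> env.set_seed(0)   # self.key = jax.random.PRNGKey(0)
+     #     >>> td = env.reset()  # consumes and refreshes self.key
+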
+     def read_state(self, state, batch_size=None):
+         state_dict = _object_to_tensordict(
+             state, self.device, self.batch_size if batch_size is None else batch_size
+         )
+         return self.state_spec["state"].encode(state_dict)
+
+     def read_obs(self, obs, batch_size=None):
+         from jax import numpy as jnp
+
+         if isinstance(obs, (list, jnp.ndarray, np.ndarray)):
+             obs_dict = _ndarray_to_tensor(obs).to(self.device)
+         else:
+             obs_dict = _object_to_tensordict(
+                 obs, self.device, self.batch_size if batch_size is None else batch_size
+             )
+         return super().read_obs(obs_dict)
+
+     def render(
+         self,
+         tensordict,
+         matplotlib_backend: str | None = None,
+         as_numpy: bool = False,
+         **kwargs,
+     ):
+         """Renders the environment output given an input tensordict.
+
+         This method is intended to be called by the :class:`~torchrl.record.PixelRenderTransform`
+         created whenever `from_pixels=True` is selected.
+         To create an appropriate rendering transform, use a call similar to the one below:
+
+         >>> from torchrl.record import PixelRenderTransform
+         >>> matplotlib_backend = None  # Change this value if a specific matplotlib backend has to be used.
+         >>> env = env.append_transform(
+         ...     PixelRenderTransform(out_keys=["pixels"], pass_tensordict=True, matplotlib_backend=matplotlib_backend)
+         ... )
+
+         This pipeline will write a `"pixels"` entry in your output tensordict.
+
+         Args:
+             tensordict (TensorDictBase): a tensordict containing a state to represent.
+             matplotlib_backend (str, optional): the matplotlib backend.
+             as_numpy (bool, optional): if ``False``, the np.ndarray will be converted to a torch.Tensor.
+                 Defaults to ``False``.
+
+         """
+         import io
+
+         import jax
+         import jax.numpy as jnp
+         import jumanji
+
+         try:
+             import matplotlib
+             import matplotlib.pyplot as plt
+             import PIL
+             import torchvision.transforms.v2.functional
+         except ImportError as err:
+             raise ImportError(
+                 "Rendering with Jumanji requires torchvision, matplotlib and PIL to be installed."
+             ) from err
+
+         if matplotlib_backend is not None:
+             matplotlib.use(matplotlib_backend)
+
+         # Get only one env
+         _state_example = self._state_example
+         while tensordict.ndim:
+             tensordict = tensordict[0]
+             _state_example = jax.tree_util.tree_map(
+                 lambda x: jnp.take(x, 0, axis=0), _state_example
+             )
+         # Patch jumanji is_notebook
+         is_notebook = jumanji.environments.is_notebook
+         try:
+             jumanji.environments.is_notebook = lambda: False
+
+             isinteractive = plt.isinteractive()
+             plt.ion()
+             buf = io.BytesIO()
+             state = _tensordict_to_object(
+                 tensordict.get("state"),
+                 _state_example,
+                 batch_size=tensordict.batch_size if not self.batch_locked else None,
+             )
+             self._env.render(state, **kwargs)
+             plt.savefig(buf, format="png")
+             buf.seek(0)
+             # Load the image into a PIL object.
+             img = PIL.Image.open(buf)
+             img_array = torchvision.transforms.v2.functional.pil_to_tensor(img)
+             if not isinteractive:
+                 plt.ioff()
+                 plt.close()
+             if not as_numpy:
+                 return img_array[:3]
+             return img_array[:3].numpy().copy()
+         finally:
+             jumanji.environments.is_notebook = is_notebook
+
+     def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
+         import jax
+
+         if self.batch_locked:
+             batch_size = self.batch_size
+         else:
+             batch_size = tensordict.batch_size
+
+         # prepare inputs
+         state = _tensordict_to_object(
+             tensordict.get("state"),
+             self._state_example,
+             batch_size=tensordict.batch_size if not self.batch_locked else None,
+         )
+         action = self.read_action(tensordict.get("action"))
+
+         # flatten batch size into vector
+         state = _tree_flatten(state, batch_size)
+         action = _tree_flatten(action, batch_size)
+
+         # jax vectorizing map on env.step
+         state, timestep = jax.vmap(self._env_step)(state, action)
+
+         # reshape batch size from vector
+         state = _tree_reshape(state, batch_size)
+         timestep = _tree_reshape(timestep, batch_size)
+
+         # collect outputs
+         state_dict = self.read_state(state, batch_size=batch_size)
+         obs_dict = self.read_obs(timestep.observation, batch_size=batch_size)
+         reward = self.read_reward(np.asarray(timestep.reward))
+         done = timestep.step_type == self.lib.types.StepType.LAST
+         done = _ndarray_to_tensor(done).view(torch.bool).to(self.device)
+
+         # build results
+         tensordict_out = TensorDict(
+             source=obs_dict,
+             batch_size=tensordict.batch_size,
+             device=self.device,
+         )
+         tensordict_out.set("reward", reward)
+         tensordict_out.set("done", done)
+         tensordict_out.set("terminated", done)
+         tensordict_out["state"] = state_dict
+
+         return tensordict_out
+
+     def _reset(
+         self, tensordict: TensorDictBase | None = None, **kwargs
+     ) -> TensorDictBase:
+         import jax
+         from jax import numpy as jnp
+
+         if self.batch_locked or tensordict is None:
+             numel = self.numel()
+             batch_size = self.batch_size
+         else:
+             numel = tensordict.numel()
+             batch_size = tensordict.batch_size
+
+         # generate random keys
+         self.key, *keys = jax.random.split(self.key, numel + 1)
+
+         # jax vectorizing map on env.reset
+         state, timestep = jax.vmap(self._env_reset)(jnp.stack(keys))
+
+         # reshape batch size from vector
+         state = _tree_reshape(state, batch_size)
+         timestep = _tree_reshape(timestep, batch_size)
+
+         # collect outputs
+         state_dict = self.read_state(state, batch_size=batch_size)
+         obs_dict = self.read_obs(timestep.observation, batch_size=batch_size)
+         if not self.batch_locked:
+             done_td = self.full_done_spec.zero(batch_size)
+         else:
+             done_td = self.full_done_spec.zero()
+
+         # build results
+         tensordict_out = TensorDict(
+             source=obs_dict,
+             batch_size=batch_size,
+             device=self.device,
+         )
+         tensordict_out.update(done_td)
+         tensordict_out["state"] = state_dict
+
+         return tensordict_out
+
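+     # `_step` and `_reset` share the same jax pattern (sketch):
+     #     state = _tree_flatten(state, batch_size)          # flatten batch dims
+     #     state, ts = jax.vmap(self._env_step)(state, act)  # vectorized call
+     #     state = _tree_reshape(state, batch_size)          # restore batch dims
+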
+     def read_reward(self, reward):
+         """Reads the reward and maps it to the reward space.
+
+         Args:
+             reward (torch.Tensor or TensorDict): reward to be mapped.
+
+         """
+         if isinstance(reward, int) and reward == 0:
+             return self.reward_spec.zero()
+         if self.batch_locked:
+             reward = self.reward_spec.encode(reward, ignore_device=True)
+         else:
+             reward = torch.as_tensor(reward)
+             if not reward.ndim or (reward.shape[-1] != self.reward_spec.shape[-1]):
+                 reward = reward.unsqueeze(-1)
+
+         if reward is None:
+             reward = torch.tensor(np.nan).expand(self.reward_spec.shape)
+
+         return reward
+
+     def _output_transform(self, step_outputs_tuple: tuple) -> tuple:
+         ...
+
+     def _reset_output_transform(self, reset_outputs_tuple: tuple) -> tuple:
+         ...
+
+
+ class JumanjiEnv(JumanjiWrapper):
+     """Jumanji environment wrapper built with the environment name.
+
+     Jumanji offers a vectorized simulation framework based on Jax.
+     TorchRL's wrapper incurs some overhead for the jax-to-torch conversion,
+     but computational graphs can still be built on top of the simulated trajectories,
+     allowing for backpropagation through the rollout.
+
+     GitHub: https://github.com/instadeepai/jumanji
+
+     Doc: https://instadeepai.github.io/jumanji/
+
+     Paper: https://arxiv.org/abs/2306.09884
+
+     Args:
+         env_name (str): the name of the environment to wrap. Must be part of :attr:`~.available_envs`.
+         categorical_action_encoding (bool, optional): if ``True``, categorical
+             specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
+             otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
+             Defaults to ``True``.
+
+     Keyword Args:
+         from_pixels (bool, optional): Not yet supported.
+         frame_skip (int, optional): if provided, indicates for how many steps the
+             same action is to be repeated. The observation returned will be the
+             last observation of the sequence, whereas the reward will be the sum
+             of rewards across steps.
+         device (torch.device, optional): if provided, the device on which the data
+             is to be cast. Defaults to ``torch.device("cpu")``.
+         batch_size (torch.Size, optional): the batch size of the environment.
+             With ``jumanji``, this indicates the number of vectorized environments.
+             Defaults to ``torch.Size([])``.
+         allow_done_after_reset (bool, optional): if ``True``, it is tolerated
+             for envs to be ``done`` just after :meth:`reset` is called.
+             Defaults to ``False``.
+
+     Attributes:
+         available_envs: environments available to build
+
+     Examples:
+         >>> from torchrl.envs import JumanjiEnv
+         >>> env = JumanjiEnv("Snake-v1")
+         >>> env.set_seed(0)
+         >>> td = env.reset()
+         >>> td["action"] = env.action_spec.rand()
+         >>> td = env.step(td)
+         >>> print(td)
+         TensorDict(
+             fields={
+                 action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                 action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
+                 done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 grid: Tensor(shape=torch.Size([12, 12, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                 next: TensorDict(
+                     fields={
+                         action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
+                         done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                         grid: Tensor(shape=torch.Size([12, 12, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                         reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                         state: TensorDict(
+                             fields={
+                                 action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
+                                 body: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False),
+                                 body_state: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.int32, is_shared=False),
+                                 fruit_position: TensorDict(
+                                     fields={
+                                         col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                                         row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
+                                     batch_size=torch.Size([]),
+                                     device=cpu,
+                                     is_shared=False),
+                                 head_position: TensorDict(
+                                     fields={
+                                         col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                                         row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
+                                     batch_size=torch.Size([]),
+                                     device=cpu,
+                                     is_shared=False),
+                                 key: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
+                                 length: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                                 step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                                 tail: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False)},
+                             batch_size=torch.Size([]),
+                             device=cpu,
+                             is_shared=False),
+                         step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                         terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                     batch_size=torch.Size([]),
+                     device=cpu,
+                     is_shared=False),
+                 state: TensorDict(
+                     fields={
+                         action_mask: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.bool, is_shared=False),
+                         body: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False),
+                         body_state: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.int32, is_shared=False),
+                         fruit_position: TensorDict(
+                             fields={
+                                 col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                                 row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
+                             batch_size=torch.Size([]),
+                             device=cpu,
+                             is_shared=False),
+                         head_position: TensorDict(
+                             fields={
+                                 col: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                                 row: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False)},
+                             batch_size=torch.Size([]),
+                             device=cpu,
+                             is_shared=False),
+                         key: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
+                         length: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                         step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                         tail: Tensor(shape=torch.Size([12, 12]), device=cpu, dtype=torch.bool, is_shared=False)},
+                     batch_size=torch.Size([]),
+                     device=cpu,
+                     is_shared=False),
+                 step_count: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                 terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+             batch_size=torch.Size([]),
+             device=cpu,
+             is_shared=False)
+         >>> print(env.available_envs)
+         ['Game2048-v1',
+          'Maze-v0',
+          'Cleaner-v0',
+          'CVRP-v1',
+          'MultiCVRP-v0',
+          'Minesweeper-v0',
+          'RubiksCube-v0',
+          'Knapsack-v1',
+          'Sudoku-v0',
+          'Snake-v1',
+          'TSP-v1',
+          'Connector-v2',
+          'MMST-v0',
+          'GraphColoring-v0',
+          'RubiksCube-partly-scrambled-v0',
+          'RobotWarehouse-v0',
+          'Tetris-v0',
+          'BinPack-v2',
+          'Sudoku-very-easy-v0',
+          'JobShop-v0']
+
+     To take advantage of Jumanji, one usually executes multiple environments at the
+     same time.
+
+         >>> from torchrl.envs import JumanjiEnv
+         >>> env = JumanjiEnv("Snake-v1", batch_size=[10])
+         >>> env.set_seed(0)
+         >>> td = env.reset()
+         >>> td["action"] = env.action_spec.rand()
+         >>> td = env.step(td)
+
+     In the following example, we iteratively test different batch sizes
+     and report the execution time for a short rollout:
+
+     Examples:
+         >>> from torch.utils.benchmark import Timer
+         >>> for batch_size in [4, 16, 128]:
+         ...     timer = Timer(
+         ...         '''
+         ... env.rollout(100)
+         ... ''',
+         ...         setup=f'''
+         ... from torchrl.envs import JumanjiEnv
+         ... env = JumanjiEnv('Snake-v1', batch_size=[{batch_size}])
+         ... env.set_seed(0)
+         ... env.rollout(2)
+         ... ''')
+         ...     print(batch_size, timer.timeit(number=10))
+         4 <torch.utils.benchmark.utils.common.Measurement object at 0x1fca91910>
+         env.rollout(100)
+         setup: [...]
+           Median: 122.40 ms
+           2 measurements, 1 runs per measurement, 1 thread
+         16 <torch.utils.benchmark.utils.common.Measurement object at 0x1ff9baee0>
+         env.rollout(100)
+         setup: [...]
+           Median: 134.39 ms
+           2 measurements, 1 runs per measurement, 1 thread
+         128 <torch.utils.benchmark.utils.common.Measurement object at 0x1ff9ba7c0>
+         env.rollout(100)
+         setup: [...]
+           Median: 172.31 ms
+           2 measurements, 1 runs per measurement, 1 thread
+     """
+
+     def __init__(self, env_name, **kwargs):
+         kwargs["env_name"] = env_name
+         super().__init__(**kwargs)
+
+     def _build_env(
+         self,
+         env_name: str,
+         **kwargs,
+     ) -> jumanji.env.Environment:  # noqa: F821
+         if not _has_jumanji:
+             raise ImportError(
+                 f"jumanji not found, unable to create {env_name}. "
+                 f"Consider installing jumanji. More info:"
+                 f" {self.git_url}."
+             )
+         from_pixels = kwargs.pop("from_pixels", False)
+         pixels_only = kwargs.pop("pixels_only", True)
+         if kwargs:
+             raise ValueError(f"Extra kwargs are not supported by {type(self)}.")
+         self.wrapper_frame_skip = 1
+         env = self.lib.make(env_name, **kwargs)
+         return super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels)
+
+     @property
+     def env_name(self):
+         return self._constructor_kwargs["env_name"]
+
+     def _check_kwargs(self, kwargs: dict):
+         if "env_name" not in kwargs:
+             raise TypeError("Expected 'env_name' to be part of kwargs")
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}(env={self.env_name}, batch_size={self.batch_size}, device={self.device})"