torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/libs/brax.py
@@ -0,0 +1,846 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import importlib.util
+ import warnings
+
+ import torch
+ from packaging import version
+ from tensordict import TensorDict, TensorDictBase
+
+ from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
+ from torchrl.envs.batched_envs import ParallelEnv
+ from torchrl.envs.common import _EnvPostInit, _EnvWrapper
+ from torchrl.envs.libs.jax_utils import (
+     _extract_spec,
+     _ndarray_to_tensor,
+     _object_to_tensordict,
+     _tensor_to_ndarray,
+     _tensordict_to_object,
+     _tree_flatten,
+     _tree_reshape,
+ )
+ from torchrl.envs.utils import _classproperty
+
+ _has_brax = importlib.util.find_spec("brax") is not None
+
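+ # Default interval (in env steps) between automatic JAX cache clears.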
+ _DEFAULT_CACHE_CLEAR_FREQUENCY = 20
+
+
+ def _get_envs():
+     if not _has_brax:
+         raise ImportError("BRAX is not installed in your virtual environment.")
+
+     import brax.envs
+
+     return list(brax.envs._envs.keys())
+
+
+ class _BraxMeta(_EnvPostInit):
+     """Metaclass for BraxEnv that returns a lazy ParallelEnv when num_workers > 1."""
+
+     def __call__(cls, *args, num_workers: int | None = None, **kwargs):
+         # Extract num_workers from explicit kwarg or kwargs dict
+         if num_workers is None:
+             num_workers = kwargs.pop("num_workers", 1)
+         else:
+             kwargs.pop("num_workers", None)
+
+         num_workers = int(num_workers)
+         if cls.__name__ == "BraxEnv" and num_workers > 1:
+             # Extract env_name from args or kwargs
+             env_name = args[0] if len(args) >= 1 else kwargs.get("env_name")
+
+             # Remove env_name from kwargs if present (it will be passed positionally)
+             env_kwargs = {k: v for k, v in kwargs.items() if k != "env_name"}
+
+             # Create factory function that builds single BraxEnv instances
+             def make_env(_env_name=env_name, _kwargs=env_kwargs):
+                 return cls(_env_name, num_workers=1, **_kwargs)
+
+             # Return lazy ParallelEnv (workers not started yet)
+             return ParallelEnv(num_workers, make_env)
+
+         return super().__call__(*args, **kwargs)
+
+
+ class BraxWrapper(_EnvWrapper):
+     """Google Brax environment wrapper.
+
+     Brax offers a vectorized and differentiable simulation framework based on Jax.
+     TorchRL's wrapper incurs some overhead for the jax-to-torch conversion,
+     but computational graphs can still be built on top of the simulated trajectories,
+     allowing for backpropagation through the rollout.
+
+     GitHub: https://github.com/google/brax
+
+     Paper: https://arxiv.org/abs/2106.13281
+
+     Args:
+         env (brax.envs.base.PipelineEnv): the environment to wrap.
+         categorical_action_encoding (bool, optional): if ``True``, categorical
+             specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
+             otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
+             Defaults to ``False``.
+         cache_clear_frequency (int, optional): automatically clear JAX's internal
+             cache every N steps to prevent memory leaks when using ``requires_grad=True``.
+             Defaults to ``None``, which clears the cache every 20 steps; pass ``False``
+             to deactivate automatic cache clearing.
+
+     Keyword Args:
+         from_pixels (bool, optional): Not yet supported.
+         frame_skip (int, optional): if provided, indicates for how many steps the
+             same action is to be repeated. The observation returned will be the
+             last observation of the sequence, whereas the reward will be the sum
+             of rewards across steps.
+         device (torch.device, optional): if provided, the device on which the data
+             is to be cast. Defaults to ``torch.device("cpu")``.
+         batch_size (torch.Size, optional): the batch size of the environment.
+             In ``brax``, this controls the number of environments simulated in
+             parallel via JAX's ``vmap`` on a single device (GPU/TPU). Brax leverages
+             MuJoCo XLA (MJX) for hardware-accelerated batched simulation, enabling
+             thousands of environments to run in parallel within a single process.
+             Defaults to ``torch.Size([])``.
+         allow_done_after_reset (bool, optional): if ``True``, it is tolerated
+             for envs to be ``done`` just after :meth:`reset` is called.
+             Defaults to ``False``.
+
+     Attributes:
+         available_envs: environments available to build
+
+     Examples:
+         >>> import brax.envs
+         >>> from torchrl.envs import BraxWrapper
+         >>> import torch
+         >>> device = "cuda" if torch.cuda.is_available() else "cpu"
+         >>> base_env = brax.envs.get_environment("ant")
+         >>> env = BraxWrapper(base_env, device=device)
+         >>> env.set_seed(0)
+         >>> td = env.reset()
+         >>> td["action"] = env.action_spec.rand()
+         >>> td = env.step(td)
+         >>> print(td)
+         TensorDict(
+             fields={
+                 action: Tensor(torch.Size([8]), dtype=torch.float32),
+                 done: Tensor(torch.Size([1]), dtype=torch.bool),
+                 next: TensorDict(
+                     fields={
+                         observation: Tensor(torch.Size([87]), dtype=torch.float32)},
+                     batch_size=torch.Size([]),
+                     device=cpu,
+                     is_shared=False),
+                 observation: Tensor(torch.Size([87]), dtype=torch.float32),
+                 reward: Tensor(torch.Size([1]), dtype=torch.float32),
+                 state: TensorDict(...)},
+             batch_size=torch.Size([]),
+             device=cpu,
+             is_shared=False)
+         >>> print(env.available_envs)
+         ['acrobot', 'ant', 'fast', 'fetch', ...]
+
+     To take advantage of Brax, one usually executes multiple environments at the
+     same time. In the following example, we iteratively test different batch sizes
+     and report the execution time for a short rollout:
+
+     Examples:
+         >>> import torch
+         >>> from torch.utils.benchmark import Timer
+         >>> device = "cuda" if torch.cuda.is_available() else "cpu"
+         >>> for batch_size in [4, 16, 128]:
+         ...     timer = Timer('''
+         ... env.rollout(100)
+         ... ''',
+         ...     setup=f'''
+         ... import brax.envs
+         ... from torchrl.envs import BraxWrapper
+         ... env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[{batch_size}], device="{device}")
+         ... env.set_seed(0)
+         ... env.rollout(2)
+         ... ''')
+         ...     print(batch_size, timer.timeit(10))
+         4
+         env.rollout(100)
+         setup: [...]
+         310.00 ms
+         1 measurement, 10 runs , 1 thread
+
+         16
+         env.rollout(100)
+         setup: [...]
+         268.46 ms
+         1 measurement, 10 runs , 1 thread
+
+         128
+         env.rollout(100)
+         setup: [...]
+         433.80 ms
+         1 measurement, 10 runs , 1 thread
+
+     One can backpropagate through the rollout and optimize the policy directly:
+
+         >>> import brax.envs
+         >>> from torchrl.envs import BraxWrapper
+         >>> from tensordict.nn import TensorDictModule
+         >>> from torch import nn
+         >>> import torch
+         >>>
+         >>> env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[10], requires_grad=True, cache_clear_frequency=100)
+         >>> env.set_seed(0)
+         >>> torch.manual_seed(0)
+         >>> policy = TensorDictModule(nn.Linear(27, 8), in_keys=["observation"], out_keys=["action"])
+         >>>
+         >>> td = env.rollout(10, policy)
+         >>>
+         >>> td["next", "reward"].mean().backward(retain_graph=True)
+         >>> print(policy.module.weight.grad.norm())
+         tensor(213.8605)
+
+     """
+
+     git_url = "https://github.com/google/brax"
+
+     @_classproperty
+     def available_envs(cls):
+         if not _has_brax:
+             return []
+         return list(_get_envs())
+
+     libname = "brax"
+
+     _lib = None
+     _jax = None
+
+     @_classproperty
+     def lib(cls):
+         if cls._lib is not None:
+             return cls._lib
+
+         import brax
+         import brax.envs
+
+         cls._lib = brax
+         return brax
+
+     @_classproperty
+     def jax(cls):
+         if cls._jax is not None:
+             return cls._jax
+
+         import jax
+
+         cls._jax = jax
+         return jax
+
+     def __init__(
+         self,
+         env=None,
+         categorical_action_encoding=False,
+         cache_clear_frequency: int | None = None,
+         **kwargs,
+     ):
+         if env is not None:
+             kwargs["env"] = env
+         self._seed_calls_reset = None
+         self._categorical_action_encoding = categorical_action_encoding
+         # False disables automatic cache clearing; None/True select the
+         # default frequency (_DEFAULT_CACHE_CLEAR_FREQUENCY).
+         if cache_clear_frequency in (False,):
+             self._cache_clear_frequency = False
+         elif cache_clear_frequency in (None, True):
+             self._cache_clear_frequency = _DEFAULT_CACHE_CLEAR_FREQUENCY
+         else:
+             self._cache_clear_frequency = cache_clear_frequency
+         self._step_count = 0
+         super().__init__(**kwargs)
+         if not self.device:
+             warnings.warn(
+                 f"No device is set for env {self}. "
+                 f"Setting a device in Brax wrapped environments is strongly recommended."
+             )
+
+     def _check_kwargs(self, kwargs: dict):
+         brax = self.lib
+         if version.parse(brax.__version__) < version.parse("0.10.4"):
+             raise ImportError("Brax v0.10.4 or greater is required.")
+
+         if "env" not in kwargs:
+             raise TypeError("Could not find environment key 'env' in kwargs.")
+         env = kwargs["env"]
+         if not isinstance(env, brax.envs.Env):
+             raise TypeError("env is not of type 'brax.envs.Env'.")
+
+     def _build_env(
+         self,
+         env,
+         _seed: int | None = None,
+         from_pixels: bool = False,
+         render_kwargs: dict | None = None,
+         pixels_only: bool = False,
+         requires_grad: bool = False,
+         camera_id: int | str = 0,
+         **kwargs,
+     ):
+         self.from_pixels = from_pixels
+         self.pixels_only = pixels_only
+         self.requires_grad = requires_grad
+
+         if from_pixels:
+             raise NotImplementedError(
+                 "from_pixels=True is not yet supported within BraxWrapper"
+             )
+         return env
+
+     def _make_state_spec(self, env: brax.envs.env.Env):  # noqa: F821
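+         # Build the "state" spec by resetting the raw brax env once with a
+         # fixed PRNG key, converting the resulting State pytree to a
+         # TensorDict, and deriving a spec from that sample.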
+         jax = self.jax
+
+         key = jax.random.PRNGKey(0)
+         state = env.reset(key)
+         state_dict = _object_to_tensordict(state, self.device, batch_size=())
+         state_spec = _extract_spec(state_dict).expand(self.batch_size)
+         return state_spec
+
+     def _make_specs(self, env: brax.envs.env.Env) -> None:  # noqa: F821
+         self.action_spec = Bounded(
+             low=-1,
+             high=1,
+             shape=(
+                 *self.batch_size,
+                 env.action_size,
+             ),
+             device=self.device,
+         )
+         self.reward_spec = Unbounded(
+             shape=[
+                 *self.batch_size,
+                 1,
+             ],
+             device=self.device,
+         )
+         self.observation_spec = Composite(
+             observation=Unbounded(
+                 shape=(
+                     *self.batch_size,
+                     env.observation_size,
+                 ),
+                 device=self.device,
+             ),
+             shape=self.batch_size,
+         )
+         # extract state spec from instance
+         state_spec = self._make_state_spec(env)
+         self.state_spec["state"] = state_spec
+         self.observation_spec["state"] = state_spec.clone()
+
+     def _make_state_example(self):
+         jax = self.jax
+
+         key = jax.random.PRNGKey(0)
+         keys = jax.random.split(key, self.batch_size.numel())
+         state = self._vmap_jit_env_reset(jax.numpy.stack(keys))
+         state = _tree_reshape(state, self.batch_size)
+         return state
+
+     def _init_env(self) -> int | None:
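+         # Compile and vectorize the brax reset/step functions once, and keep
+         # a template state so tensordicts can be mapped back to brax pytrees.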
+         jax = self.jax
+         self._key = None
+         self._vmap_jit_env_reset = jax.vmap(jax.jit(self._env.reset))
+         self._vmap_jit_env_step = jax.vmap(jax.jit(self._env.step))
+         self._state_example = self._make_state_example()
+
+     def _set_seed(self, seed: int | None) -> None:
+         jax = self.jax
+         if seed is None:
+             raise Exception("Brax requires an integer seed.")
+         self._key = jax.random.PRNGKey(seed)
+
+     def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
+         jax = self.jax
+
+         # ensure a valid JAX PRNG key exists
+         if getattr(self, "_key", None) is None:
+             seed = getattr(self, "_seed", None)
+             if seed is None:
+                 seed = 0
+
+             self._key = jax.random.PRNGKey(int(seed))
+
+         # generate random keys
+         self._key, *keys = jax.random.split(self._key, 1 + self.numel())
+
+         # call env reset with jit and vmap
+         state = self._vmap_jit_env_reset(jax.numpy.stack(keys))
+
+         # reshape batch size
+         state = _tree_reshape(state, self.batch_size)
+         state = _object_to_tensordict(state, self.device, self.batch_size)
+
+         # build result
+         state["reward"] = state.get("reward").view(*self.reward_spec.shape)
+         state["done"] = state.get("done").view(*self.reward_spec.shape)
+         done = state["done"].bool()
+         tensordict_out = TensorDict._new_unsafe(
+             source={
+                 "observation": state.get("obs"),
+                 # "reward": reward,
+                 "done": done,
+                 "terminated": done.clone(),
+                 "state": state,
+             },
+             batch_size=self.batch_size,
+             device=self.device,
+         )
+         return tensordict_out
+
+     def _step_without_grad(self, tensordict: TensorDictBase):
+
+         # convert tensors to ndarrays
+         state = _tensordict_to_object(tensordict.get("state"), self._state_example)
+         action = _tensor_to_ndarray(tensordict.get("action"))
+
+         # flatten batch size
+         state = _tree_flatten(state, self.batch_size)
+         action = _tree_flatten(action, self.batch_size)
+
+         # call env step with jit and vmap
+         next_state = self._vmap_jit_env_step(state, action)
+
+         # reshape batch size and convert ndarrays to tensors
+         next_state = _tree_reshape(next_state, self.batch_size)
+         next_state = _object_to_tensordict(next_state, self.device, self.batch_size)
+
+         # build result
+         next_state.set("reward", next_state.get("reward").view(self.reward_spec.shape))
+         next_state.set("done", next_state.get("done").view(self.reward_spec.shape))
+         done = next_state["done"].bool()
+         reward = next_state["reward"]
+         tensordict_out = TensorDict._new_unsafe(
+             source={
+                 "observation": next_state.get("obs"),
+                 "reward": reward,
+                 "done": done,
+                 "terminated": done.clone(),
+                 "state": next_state,
+             },
+             batch_size=self.batch_size,
+             device=self.device,
+         )
+         return tensordict_out
+
+     def _step_with_grad(self, tensordict: TensorDictBase):
+
+         # convert tensors to ndarrays
+         action = tensordict.get("action")
+         state = tensordict.get("state")
+         qp_keys, qp_values = zip(*state.get("pipeline_state").items())
+
+         # call env step with autograd function
+         next_state_nograd, next_obs, next_reward, *next_qp_values = _BraxEnvStep.apply(
+             self, state, action, *qp_values
+         )
+
+         # extract done values: we assume a shape identical to reward
+         next_done = next_state_nograd.get("done").view(*self.reward_spec.shape)
+         next_reward = next_reward.view(*self.reward_spec.shape)
+
+         # merge the gradient-tracking tensors back into the next state
+         next_state = next_state_nograd
+         next_state["obs"] = next_obs
+         next_state.set("reward", next_reward)
+         next_state.set("done", next_done)
+         next_done = next_done.bool()
+         next_state.get("pipeline_state").update(dict(zip(qp_keys, next_qp_values)))
+
+         # build result
+         tensordict_out = TensorDict._new_unsafe(
+             source={
+                 "observation": next_obs,
+                 "reward": next_reward,
+                 "done": next_done,
+                 "terminated": next_done,
+                 "state": next_state,
+             },
+             batch_size=self.batch_size,
+             device=self.device,
+         )
+         return tensordict_out
+
+     def _step(
+         self,
+         tensordict: TensorDictBase,
+     ) -> TensorDictBase:
+
+         if self.requires_grad:
+             out = self._step_with_grad(tensordict)
+         else:
+             out = self._step_without_grad(tensordict)
+
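+         # Periodically drop JAX's internal caches to bound memory growth
+         # (see ``cache_clear_frequency``).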
+         self._step_count += 1
+         if (
+             self._cache_clear_frequency
+             and (self._step_count % self._cache_clear_frequency) == 0
+         ):
+             self.clear_cache()
+
+         return out
+
+     def clear_cache(self):
+         """Clear JAX's internal cache to prevent memory leaks.
+
+         This method should be called periodically when using requires_grad=True
+         to prevent memory accumulation from JAX's internal computation graph.
+         """
+         if hasattr(self, "jax"):
+             try:
+                 # Clear JAX's compilation cache
+                 if hasattr(self.jax.jit, "clear_caches"):
+                     self.jax.jit.clear_caches()
+                 # Alternative: clear JAX's internal cache
+                 if hasattr(self.jax, "clear_caches"):
+                     self.jax.clear_caches()
+                 # Clear JAX's XLA compilation cache if available
+                 try:
+                     import jaxlib
+
+                     if hasattr(jaxlib, "xla_extension"):
+                         jaxlib.xla_extension.clear_caches()
+                 except Exception:
+                     pass
+             except Exception:
+                 pass
+
+
+ class BraxEnv(BraxWrapper, metaclass=_BraxMeta):
+     """Google Brax environment wrapper built with the environment name.
+
+     Brax offers a vectorized and differentiable simulation framework based on Jax.
+     TorchRL's wrapper incurs some overhead for the jax-to-torch conversion,
+     but computational graphs can still be built on top of the simulated trajectories,
+     allowing for backpropagation through the rollout.
+
+     GitHub: https://github.com/google/brax
+
+     Paper: https://arxiv.org/abs/2106.13281
+
+     Args:
+         env_name (str): the name of the environment to wrap. Must be part of
+             :attr:`~.available_envs`.
+         categorical_action_encoding (bool, optional): if ``True``, categorical
+             specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
+             otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
+             Defaults to ``False``.
+         cache_clear_frequency (int, optional): automatically clear JAX's internal
+             cache every N steps to prevent memory leaks when using ``requires_grad=True``.
+             Defaults to ``None``, which clears the cache every 20 steps; pass ``False``
+             to deactivate automatic cache clearing.
+
+     Keyword Args:
+         from_pixels (bool, optional): Not yet supported.
+         frame_skip (int, optional): if provided, indicates for how many steps the
+             same action is to be repeated. The observation returned will be the
+             last observation of the sequence, whereas the reward will be the sum
+             of rewards across steps.
+         device (torch.device, optional): if provided, the device on which the data
+             is to be cast. Defaults to ``torch.device("cpu")``.
+         batch_size (torch.Size, optional): the batch size of the environment.
+             In ``brax``, this controls the number of environments simulated in
+             parallel via JAX's ``vmap`` on a single device (GPU/TPU). Brax leverages
+             MuJoCo XLA (MJX) for hardware-accelerated batched simulation, enabling
+             thousands of environments to run in parallel within a single process.
+             Defaults to ``torch.Size([])``.
+         allow_done_after_reset (bool, optional): if ``True``, it is tolerated
+             for envs to be ``done`` just after :meth:`reset` is called.
+             Defaults to ``False``.
+         num_workers (int, optional): if greater than 1, a lazy :class:`~torchrl.envs.ParallelEnv`
+             will be returned instead, with each worker instantiating its own
+             :class:`~torchrl.envs.BraxEnv` instance. Defaults to ``None``.
+
+     .. note::
+         There are two orthogonal ways to scale environment throughput:
+
+         - **batch_size**: Uses Brax's native JAX-based vectorization (``vmap``) to run
+           multiple environments in parallel on a single GPU/TPU. This is highly efficient
+           for moderate batch sizes where the MJX solver has not yet saturated.
+         - **num_workers**: Uses TorchRL's :class:`~torchrl.envs.ParallelEnv` to spawn
+           multiple Python processes, each running its own ``BraxEnv``.
+
+         These can be combined: ``BraxEnv("ant", batch_size=[128], num_workers=4)`` creates
+         4 worker processes, each running 128 vectorized environments, for a total of 512
+         parallel environments. This hybrid approach can be beneficial when the MJX solver
+         saturates on a single device, or when distributing across multiple GPUs/CPUs.
+
+     Attributes:
+         available_envs: environments available to build
+
+     Examples:
+         >>> from torchrl.envs import BraxEnv
+         >>> import torch
+         >>> device = "cuda" if torch.cuda.is_available() else "cpu"
+         >>> env = BraxEnv("ant", device=device)
+         >>> env.set_seed(0)
+         >>> td = env.reset()
+         >>> td["action"] = env.action_spec.rand()
+         >>> td = env.step(td)
+         >>> print(td)
+         TensorDict(
+             fields={
+                 action: Tensor(torch.Size([8]), dtype=torch.float32),
+                 done: Tensor(torch.Size([1]), dtype=torch.bool),
+                 next: TensorDict(
+                     fields={
+                         observation: Tensor(torch.Size([87]), dtype=torch.float32)},
+                     batch_size=torch.Size([]),
+                     device=cpu,
+                     is_shared=False),
+                 observation: Tensor(torch.Size([87]), dtype=torch.float32),
+                 reward: Tensor(torch.Size([1]), dtype=torch.float32),
+                 state: TensorDict(...)},
+             batch_size=torch.Size([]),
+             device=cpu,
+             is_shared=False)
+         >>> print(env.available_envs)
+         ['acrobot', 'ant', 'fast', 'fetch', ...]
+
+         # Example: create a parallel environment with 4 workers. This returns a lazy
+         # ParallelEnv; each worker will instantiate a BraxEnv with num_workers=1.
+         >>> from torchrl.envs import BraxEnv
+         >>> par_env = BraxEnv("ant", batch_size=[8], num_workers=4, device="cpu")
+         >>> # par_env is a ParallelEnv; start interacting as usual
+         >>> par_env.set_seed(0)
+         >>> td = par_env.reset()
+         >>> print(td.shape)
+         torch.Size([4, 8])
+         >>> td["action"] = par_env.action_spec.rand()
+         >>> td = par_env.step(td)
+
+     To take advantage of Brax, one usually executes multiple environments at the
+     same time. In the following example, we iteratively test different batch sizes
+     and report the execution time for a short rollout:
+
+     Examples:
+         >>> import torch
+         >>> from torch.utils.benchmark import Timer
+         >>> device = "cuda" if torch.cuda.is_available() else "cpu"
+         >>> for batch_size in [4, 16, 128]:
+         ...     timer = Timer('''
+         ... env.rollout(100)
+         ... ''',
+         ...     setup=f'''
+         ... from torchrl.envs import BraxEnv
+         ... env = BraxEnv("ant", batch_size=[{batch_size}], device="{device}")
+         ... env.set_seed(0)
+         ... env.rollout(2)
+         ... ''')
+         ...     print(batch_size, timer.timeit(10))
+         4
+         env.rollout(100)
+         setup: [...]
+         310.00 ms
+         1 measurement, 10 runs , 1 thread
+
+         16
+         env.rollout(100)
+         setup: [...]
+         268.46 ms
+         1 measurement, 10 runs , 1 thread
+
+         128
+         env.rollout(100)
+         setup: [...]
+         433.80 ms
+         1 measurement, 10 runs , 1 thread
+
+     One can backpropagate through the rollout and optimize the policy directly:
+
+         >>> from torchrl.envs import BraxEnv
+         >>> from tensordict.nn import TensorDictModule
+         >>> from torch import nn
+         >>> import torch
+         >>>
+         >>> env = BraxEnv("ant", batch_size=[10], requires_grad=True, cache_clear_frequency=100)
+         >>> env.set_seed(0)
+         >>> torch.manual_seed(0)
+         >>> policy = TensorDictModule(nn.Linear(27, 8), in_keys=["observation"], out_keys=["action"])
+         >>>
+         >>> td = env.rollout(10, policy)
+         >>>
+         >>> td["next", "reward"].mean().backward(retain_graph=True)
+         >>> print(policy.module.weight.grad.norm())
+         tensor(213.8605)
+
+     """
+
+     def __init__(self, env_name, **kwargs):
+         kwargs["env_name"] = env_name
+         super().__init__(**kwargs)
+
+     def _build_env(
+         self,
+         env_name: str,
+         **kwargs,
+     ) -> brax.envs.env.Env:  # noqa: F821
+         if not _has_brax:
+             raise ImportError(
+                 f"brax not found, unable to create {env_name}. "
+                 f"Consider downloading and installing brax from"
+                 f" {self.git_url}"
+             )
+         from_pixels = kwargs.pop("from_pixels", False)
+         pixels_only = kwargs.pop("pixels_only", True)
+         requires_grad = kwargs.pop("requires_grad", False)
+         cache_clear_frequency = kwargs.pop("cache_clear_frequency", False)
+         if kwargs:
+             raise ValueError("kwargs not supported.")
+         self.wrapper_frame_skip = 1
+         env = self.lib.envs.get_environment(env_name, **kwargs)
+         return super()._build_env(
+             env,
+             pixels_only=pixels_only,
+             from_pixels=from_pixels,
+             requires_grad=requires_grad,
+             cache_clear_frequency=cache_clear_frequency,
+         )
+
+     @property
+     def env_name(self):
+         return self._constructor_kwargs["env_name"]
+
+     def _check_kwargs(self, kwargs: dict):
+         if "env_name" not in kwargs:
+             raise TypeError("Expected 'env_name' to be part of kwargs")
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}(env={self.env_name}, batch_size={self.batch_size}, device={self.device})"
+
+
+ class _BraxEnvStep(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, env: BraxWrapper, state_td, action_tensor, *qp_values):
+         import jax
+
+         # convert tensors to ndarrays
+         state_obj = _tensordict_to_object(state_td, env._state_example)
+         action_nd = _tensor_to_ndarray(action_tensor)
+
+         # flatten batch size
+         state = _tree_flatten(state_obj, env.batch_size)
+         action = _tree_flatten(action_nd, env.batch_size)
+
+         # call vjp with jit and vmap
+         next_state, vjp_fn = jax.vjp(env._vmap_jit_env_step, state, action)
+
+         # reshape batch size
+         next_state_reshape = _tree_reshape(next_state, env.batch_size)
+
+         # convert ndarrays to tensors
+         next_state_tensor = _object_to_tensordict(
+             next_state_reshape, device=env.device, batch_size=env.batch_size
+         )
+
+         # save context
+         ctx.vjp_fn = vjp_fn
+         ctx.next_state = next_state_tensor
+         ctx.env = env
+         # Mark that backward hasn't been called yet
+         ctx._backward_called = False
+
+         return (
+             next_state_tensor,  # no gradient
+             next_state_tensor["obs"],
+             next_state_tensor["reward"],
+             *next_state_tensor["pipeline_state"].values(),
+         )
+
+     @staticmethod
+     def backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values):
+         # Prevent multiple backward calls on the same context
+         if hasattr(ctx, "_backward_called") and ctx._backward_called:
+             # one None per forward input: env, state_td, action_tensor and qp_values
+             return (None, None, None, *([None] * len(grad_next_qp_values)))
+
+         ctx._backward_called = True
+
+         pipeline_state = dict(
+             zip(ctx.next_state.get("pipeline_state").keys(), grad_next_qp_values)
+         )
+         none_keys = []
+
+         def _make_none(key, val):
+             if val is not None:
+                 return val
+             none_keys.append(key)
+             return torch.zeros_like(ctx.next_state.get(("pipeline_state", key)))
+
+         pipeline_state = {
+             key: _make_none(key, val) for key, val in pipeline_state.items()
+         }
+         metrics = ctx.next_state.get("metrics", None)
+         if metrics is None:
+             metrics = {}
+         info = ctx.next_state.get("info", None)
+         if info is None:
+             info = {}
+         grad_next_state_td = TensorDict(
+             source={
+                 "pipeline_state": pipeline_state,
+                 "obs": grad_next_obs,
+                 "reward": grad_next_reward,
+                 "done": torch.zeros_like(ctx.next_state.get("done")),
+                 "metrics": {k: torch.zeros_like(v) for k, v in metrics.items()},
+                 "info": {k: torch.zeros_like(v) for k, v in info.items()},
+             },
+             device=ctx.env.device,
+             batch_size=ctx.env.batch_size,
+         )
+         # convert tensors to ndarrays
+         grad_next_state_obj = _tensordict_to_object(
+             grad_next_state_td, ctx.env._state_example
+         )
+
+         # flatten batch size
+         grad_next_state_flat = _tree_flatten(grad_next_state_obj, ctx.env.batch_size)
+
+         # call vjp to get gradients
+         grad_state, grad_action = ctx.vjp_fn(grad_next_state_flat)
+         # assert grad_action.device == ctx.env.device
+
+         # reshape batch size
+         grad_state = _tree_reshape(grad_state, ctx.env.batch_size)
+         grad_action = _tree_reshape(grad_action, ctx.env.batch_size)
+         # assert grad_action.device == ctx.env.device
+
+         # convert ndarrays to tensors
+         grad_state_qp = _object_to_tensordict(
+             grad_state.pipeline_state,
+             device=ctx.env.device,
+             batch_size=ctx.env.batch_size,
+         )
+         grad_action = _ndarray_to_tensor(grad_action).to(ctx.env.device)
+         grad_state_qp = {
+             key: val if key not in none_keys else None
+             for key, val in grad_state_qp.items()
+         }
+         grads = (grad_action, *grad_state_qp.values())
+
+         # Clean up context to prevent memory leaks
+         try:
+             # Clear JAX VJP function reference
+             del ctx.vjp_fn
+         except AttributeError:
+             pass
+         try:
+             # Clear stored tensors
+             del ctx.next_state
+         except AttributeError:
+             pass
+         try:
+             # Clear environment reference
+             del ctx.env
+         except AttributeError:
+             pass
+         try:
+             # Clear the backward flag
+             del ctx._backward_called
+         except AttributeError:
+             pass
+
+         return (None, None, *grads)
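
For readers unfamiliar with the `_BraxEnvStep` pattern above, here is a minimal, self-contained sketch (not part of the torchrl source; `toy_step`, `_to_jax` and `_to_torch` are illustrative names) of how a `torch.autograd.Function` can bridge JAX and PyTorch autodiff: the forward pass runs the JAX computation through `jax.vjp` and stores the returned pullback, and the backward pass replays that pullback on the incoming cotangents. torchrl's actual implementation additionally converts whole brax `State` pytrees with the dlpack-based helpers in `torchrl.envs.libs.jax_utils` and runs the step under `vmap`/`jit`.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def _to_jax(t: torch.Tensor):
    # torch -> jax via a host copy (torchrl uses zero-copy dlpack helpers instead)
    return jnp.asarray(t.detach().cpu().numpy())


def _to_torch(a) -> torch.Tensor:
    # jax -> torch via a host copy
    return torch.from_numpy(np.array(a))


def toy_step(x):
    # stand-in for a differentiable brax step: y = x * sin(x)
    return x * jnp.sin(x)


class JaxStep(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # run the JAX computation and keep the pullback for backward
        y, vjp_fn = jax.vjp(toy_step, _to_jax(x))
        ctx.vjp_fn = vjp_fn
        return _to_torch(y)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor) -> torch.Tensor:
        # replay the stored pullback on the incoming cotangent
        (grad_x,) = ctx.vjp_fn(_to_jax(grad_y))
        del ctx.vjp_fn  # drop the JAX reference, as _BraxEnvStep does
        return _to_torch(grad_x)


x = torch.randn(4, requires_grad=True)
JaxStep.apply(x).sum().backward()
# analytical gradient of x*sin(x) is sin(x) + x*cos(x)
print(torch.allclose(x.grad, torch.sin(x) + x * torch.cos(x), atol=1e-5))
```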