torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,429 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import importlib
8
+ import os
9
+ import warnings
10
+ from copy import copy
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import torch
15
+ from tensordict import TensorDict
16
+
17
+ from torchrl.data.tensor_specs import Unbounded
18
+ from torchrl.envs.common import _maybe_unlock
19
+ from torchrl.envs.libs.gym import (
20
+ _gym_to_torchrl_spec_transform,
21
+ _GymAsyncMeta,
22
+ gym_backend,
23
+ GymEnv,
24
+ )
25
+ from torchrl.envs.utils import _classproperty, make_composite_from_td
26
+
27
+ _has_gym = (
28
+ importlib.util.find_spec("gym") is not None
29
+ or importlib.util.find_spec("gymnasium") is not None
30
+ )
31
+ _has_robohive = importlib.util.find_spec("robohive") is not None and _has_gym
32
+
33
+ if _has_robohive:
34
+ os.environ.setdefault("sim_backend", "MUJOCO")
35
+
36
+
37
class set_directory:
    """Temporarily changes the current working directory.

    Usable both as a context manager and as a function decorator. On exit, the
    working directory that was active when the context was *entered* is
    restored, including when the body raises.

    Args:
        path (Path): The path to the cwd (anything accepted by :func:`os.chdir`).
    """

    def __init__(self, path: Path):
        self.path = path
        # Kept for backward compatibility; refreshed on every ``__enter__`` so the
        # restored directory is the one active at entry, not at construction.
        self.origin = Path().absolute()

    def __enter__(self):
        self.origin = Path().absolute()
        os.chdir(self.path)
        return self

    def __exit__(self, *args, **kwargs):
        os.chdir(self.origin)

    def __call__(self, fun):
        from functools import wraps

        # Preserve the decorated function's name/docstring.
        @wraps(fun)
        def new_fun(*args, **kwargs):
            # A fresh context per call: the cwd to restore is captured at call time.
            with set_directory(Path(self.path)):
                return fun(*args, **kwargs)

        return new_fun
60
+
61
+
62
class _RoboHiveBuild(_GymAsyncMeta):
    """Metaclass that refines the specs of every freshly built RoboHive env."""

    def __call__(self, *args, **kwargs):
        env: RoboHiveEnv = super().__call__(*args, **kwargs)
        # Construction is complete at this point; refine the specs before
        # handing the instance back to the caller.
        env._refine_specs()
        return env
67
+
68
+
69
class RoboHiveEnv(GymEnv, metaclass=_RoboHiveBuild):
    """A wrapper for RoboHive gym environments.

    RoboHive is a collection of environments/tasks simulated with the MuJoCo physics engine exposed using the OpenAI-Gym API.

    Github: https://github.com/vikashplus/robohive/

    Doc: https://github.com/vikashplus/robohive/wiki

    Paper: https://arxiv.org/abs/2310.06828

    .. warning::
        RoboHive requires gym 0.13.

    Args:
        env_name (str): the environment name to build. Must be one of :attr:`.available_envs`
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.

    Keyword Args:
        from_pixels (bool, optional): if ``True``, an attempt to return the pixel
            observations from the env will be performed. By default, these observations
            will be written under the ``"pixels"`` entry.
            The method being used varies
            depending on the gym version and may involve a ``wrappers.pixel_observation.PixelObservationWrapper``.
            Defaults to ``False``.
        pixels_only (bool, optional): if ``True``, only the pixel observations will
            be returned (by default under the ``"pixels"`` entry in the output tensordict).
            If ``False``, observations (eg, states) and pixels will be returned
            whenever ``from_pixels=True``. Defaults to ``False``.
        from_depths (bool, optional): if ``True``, an attempt to return the depth
            observations from the env will be performed. By default, these observations
            will be written under the ``"depths"`` entry. Requires ``from_pixels`` to be ``True``.
            Defaults to ``False``.
        cameras (list of str, optional): the cameras to render when ``from_pixels=True``.
            If omitted, every camera returned by :meth:`get_available_cams` is used
            and a warning is emitted. Passing a non-empty list while
            ``from_pixels=False`` raises a ``RuntimeError``.
        read_info (bool, optional): if ``True``, :meth:`read_info` is registered as
            an info-dict reader. Defaults to ``True``.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        batch_size (torch.Size, optional): Only ``torch.Size([])`` will work with
            ``RoboHiveEnv`` since vectorized environments are not supported within the
            class. To execute more than one environment at a time, see :class:`~torchrl.envs.ParallelEnv`.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.

    Attributes:
        available_envs (list): a list of available envs to build.

    Examples:
        >>> from torchrl.envs import RoboHiveEnv
        >>> env = RoboHiveEnv(RoboHiveEnv.available_envs[0])
        >>> env.rollout(3)

    """

    env_list = []

    # MuJoCo ``mjtObj`` value identifying camera objects (``mjOBJ_CAMERA``),
    # passed to ``model.id2name`` when listing cameras.
    _MJ_OBJ_CAMERA = 7
    # Where robohive can be obtained from, used in the ImportError message.
    _ROBOHIVE_GIT_URL = "https://github.com/vikashplus/robohive"

    @_classproperty
    def CURR_DIR(cls):
        """The ``CURR_DIR`` of ``robohive.envs.multi_task.substeps1`` (``None`` without robohive).

        Visual env variants are registered from within this directory.
        """
        if not _has_robohive:
            return None
        import robohive.envs.multi_task.substeps1

        return robohive.envs.multi_task.substeps1.CURR_DIR

    @_classproperty
    def available_envs(cls):
        """Names of the envs that can be built (an empty list without robohive)."""
        if not _has_robohive:
            return []
        cls.register_envs()
        return cls.env_list

    @classmethod
    def register_envs(cls):
        """Adds the robohive env suite to :attr:`env_list`.

        Safe to call repeatedly (it runs on every :attr:`available_envs` access):
        names already listed are not added again.

        Raises:
            ImportError: if robohive is not available.
            RuntimeError: if the robohive suite is empty.
        """
        if not _has_robohive:
            raise ImportError(
                "Cannot load robohive from the current virtual environment."
            )
        from robohive import robohive_env_suite as robohive_envs
        from robohive.utils.prompt_utils import Prompt, set_prompt_verbosity

        set_prompt_verbosity(Prompt.WARN)
        if not len(robohive_envs):
            raise RuntimeError("did not load any environment.")
        known = set(cls.env_list)
        for name in robohive_envs:
            if name not in known:
                cls.env_list.append(name)
                known.add(name)

    @staticmethod
    def _get_render_device(device) -> int:
        """Returns the device index used for rendering (``0`` when there is none)."""
        if device is None:
            return 0
        try:
            index = torch.device(device).index
        except (RuntimeError, TypeError, ValueError):
            return 0
        # e.g. "cpu" or "cuda" without an explicit index
        return 0 if index is None else index

    def _build_env(  # noqa: F811
        self,
        env_name: str,
        from_pixels: bool = False,
        pixels_only: bool = False,
        from_depths: bool = False,
        **kwargs,
    ) -> gym.core.Env:  # noqa: F821
        """Builds the robohive env, registering a visual variant when ``from_pixels=True``.

        Raises:
            ImportError: if gym/robohive is not available.
            RuntimeError: if cameras are passed while ``from_pixels=False``.
        """
        # Check first: listing cameras and registering visual variants already
        # require robohive.
        if not _has_robohive:
            raise ImportError(
                f"gym/robohive not found, unable to create {env_name}. "
                f"Consider downloading and installing robohive from"
                f" {self._ROBOHIVE_GIT_URL}"
            )
        # read_info configures this wrapper, it must not be forwarded to make().
        read_info = kwargs.pop("read_info", True)
        if from_pixels:
            if "cameras" not in kwargs:
                warnings.warn(
                    "from_pixels=True will lead to a registration of ALL available cameras, "
                    "which may lead to performance issue. "
                    "Consider passing only the needed cameras through cameras=list_of_cameras. "
                    "The list of available cameras for a specific environment can be obtained via "
                    "RoboHiveEnv.get_available_cams(env_name)."
                )
                kwargs["cameras"] = self.get_available_cams(env_name)
            cams = list(kwargs.pop("cameras"))
            env_name = self.register_visual_env(
                cams=cams, env_name=env_name, from_depths=from_depths
            )
        elif kwargs.pop("cameras", None):
            raise RuntimeError("Got a list of cameras but from_pixels is set to False.")

        self.pixels_only = pixels_only
        render_device = self._get_render_device(self.device)
        try:
            env = self.lib.make(
                env_name,
                frameskip=self.frame_skip,
                device_id=render_device,
                return_dict=True,
                **kwargs,
            )
            # The env accepted ``frameskip`` and handles it itself.
            self.wrapper_frame_skip = 1
        except TypeError as err:
            if "unexpected keyword argument 'frameskip" not in str(err):
                raise
            # The env does not support ``frameskip``: build it without it and
            # record the frame skip under ``wrapper_frame_skip`` instead.
            env = self.lib.make(
                env_name, return_dict=True, device_id=render_device, **kwargs
            )
            self.wrapper_frame_skip = self.frame_skip
        self.from_pixels = from_pixels
        self.from_depths = from_depths
        self.render_device = render_device
        if read_info:
            self.set_info_dict_reader(self.read_info)
        return env

    def _make_specs(self, env: gym.Env, batch_size=None) -> None:  # noqa: F821
        # Build the parent specs, then take one random step to discover the info
        # entries and register them under the "info" observation entry.
        out = super()._make_specs(env=env, batch_size=batch_size)
        self.env.reset()
        *_, info = self.env.step(self.env.action_space.sample())
        info = self.read_info(info, TensorDict())
        info = info.get("info")
        self.observation_spec["info"] = make_composite_from_td(info)
        return out

    @classmethod
    def register_visual_env(cls, env_name, cams, from_depths):
        """Registers a visual variant of ``env_name`` rendering ``cams`` and returns its id.

        224x224 ``rgb`` keys are requested for every camera, plus ``d`` (depth)
        keys when ``from_depths=True``.

        Raises:
            RuntimeError: if ``cams`` is empty.
        """
        with set_directory(cls.CURR_DIR):
            from robohive.envs.env_variants import register_env_variant

            if not len(cams):
                raise RuntimeError("Cannot create a visual envs without cameras.")
            cams = sorted(cams)
            cams_rep = [i.replace("A:", "A_") for i in cams]
            new_env_name = "-".join([cam[:-3] for cam in cams_rep] + [env_name])
            visual_keys = [f"rgb:{c}:224x224:2d" for c in cams]
            if from_depths:
                visual_keys.extend([f"d:{c}:224x224:2d" for c in cams])
            register_env_variant(
                env_name,
                variants={
                    "visual_keys": visual_keys,
                },
                variant_id=new_env_name,
            )
            env_name = new_env_name
            if env_name not in cls.env_list:
                cls.env_list.append(env_name)
        return env_name

    @_maybe_unlock
    def _refine_specs(self) -> None:  # noqa: F821
        """Rebuilds the action, observation and reward specs from short rollouts.

        A 3-step rollout of raw env data defines the observation spec, the reward
        spec is reset to an unbounded ``(1,)`` spec, and a 2-step TorchRL rollout
        then adds the entries not covered yet (presumably the info entries).
        """
        env = self._env
        self.action_spec = _gym_to_torchrl_spec_transform(
            env.action_space, device=self.device
        )
        # get a np rollout
        rollout = TensorDict({"done": torch.zeros(3, 1)}, [3])
        env.reset()

        def get_obs():
            _dict = {}
            obs_dict = copy(env.obs_dict)
            if self.from_pixels:
                visual = self.env.get_exteroception()
                obs_dict.update(visual)
            pixel_list, depth_list = [], []
            for obs_key in obs_dict:
                if obs_key.startswith("rgb"):
                    pix = obs_dict[obs_key]
                    if not pix.shape[0] == 1:
                        pix = pix[None]
                    pixel_list.append(pix)
                elif obs_key.startswith("d:"):
                    dep = obs_dict[obs_key]
                    dep = dep[None]
                    depth_list.append(dep)
                elif obs_key in env.obs_keys:
                    value = env.obs_dict[obs_key]
                    if not value.shape:
                        value = value[None]
                    _dict[obs_key] = value
            if pixel_list:
                _dict["pixels"] = np.concatenate(pixel_list, 0)
            if depth_list:
                _dict["depths"] = np.concatenate(depth_list, 0)
            return _dict

        for i in range(3):
            _dict = {}
            _dict.update(get_obs())
            _dict["action"] = action = env.action_space.sample()
            _, r, trunc, term, done, _ = self._output_transform(env.step(action))
            _dict[("next", "reward")] = r.reshape(1)
            # Placeholder flags: they are excluded when building the observation spec.
            _dict[("next", "done")] = [1]
            _dict[("next", "terminated")] = [1]
            _dict[("next", "truncated")] = [1]
            _dict["next"] = get_obs()
            rollout[i] = TensorDict(_dict, [])

        observation_spec = make_composite_from_td(
            rollout.get("next").exclude("done", "reward", "terminated", "truncated")[0]
        )
        self.observation_spec = observation_spec

        self.reward_spec = Unbounded(
            shape=(1,),
            device=self.device,
        )  # default

        rollout = self.rollout(2, return_contiguous=False).get("next")
        rollout = rollout.exclude(
            self.reward_key, *self.done_keys, *self.observation_spec.keys(True, True)
        )
        rollout = rollout[..., 0]
        spec = make_composite_from_td(rollout)
        self.observation_spec.update(spec)
        self.empty_cache()

    def _reset_output_transform(self, reset_data):
        # Normalize reset outputs to an ``(obs, info)`` pair: anything that is not
        # a 2-tuple is treated as a bare observation with an empty info dict.
        if not (isinstance(reset_data, tuple) and len(reset_data) == 2):
            return reset_data, {}
        return reset_data

    def set_from_pixels(self, from_pixels: bool) -> None:
        """Sets the from_pixels attribute to an existing environment.

        The specs are refined again when the value changes; nothing happens otherwise.

        Args:
            from_pixels (bool): new value for the from_pixels attribute

        """
        if from_pixels == self.from_pixels:
            return
        self.from_pixels = from_pixels
        self._refine_specs()

    def read_obs(self, observation):
        """Builds the observation dict from the env's ``obs_dict``.

        ``observation`` is not used. The ``"t"`` entry is dropped, entries listed
        in the env's ``obs_keys`` are kept (0-d values reshaped to ``(1,)``), and
        when ``from_pixels=True`` the ``"rgb..."`` / ``"d:..."`` entries are
        concatenated under ``"pixels"`` / ``"depths"`` (depths only when
        ``from_depths=True``).
        """
        # the info is missing from the reset
        # Work on a shallow copy (as _refine_specs does) so that dropping "t" and
        # merging the visual entries does not mutate the env's own obs_dict.
        observations = copy(self.env.obs_dict)
        observations.pop("t", None)
        if self.from_pixels:
            observations.update(self.env.get_exteroception())
        obs_keys = self._env.obs_keys
        obsdict = {}
        pixel_list, depth_list = [], []
        for key, value in observations.items():
            if key.startswith("rgb"):
                if not value.shape[0] == 1:
                    value = value[None]
                pixel_list.append(value)
            elif key.startswith("d:"):
                depth_list.append(value[None])
            elif key in obs_keys:
                if not value.shape:
                    value = value[None]
                obsdict[key] = value
        # np.concatenate raises on an empty sequence, hence the emptiness checks.
        if self.from_pixels and pixel_list:
            obsdict["pixels"] = np.concatenate(pixel_list, 0)
        if self.from_pixels and self.from_depths and depth_list:
            obsdict["depths"] = np.concatenate(depth_list, 0)
        return super().read_obs(obsdict)

    def read_info(self, info, tensordict_out):
        """Writes the env ``info`` under the ``"info"`` entry of ``tensordict_out``.

        With an empty ``info``, a zero-filled entry matching the ``"info"``
        observation spec is written (nothing if there is no such spec).
        Otherwise non-tensor data and the ``"obs_dict"``, ``"done"``, ``"reward"``,
        ``"act"`` and observation-key entries are dropped, and values are reshaped
        to the ``"info"`` spec shapes when that spec exists (0-d values to ``(1,)``
        otherwise).
        """
        if not info:
            info_spec = self.observation_spec.get("info", None)
            if info_spec is None:
                return tensordict_out
            tensordict_out.set("info", info_spec.zero())
            return tensordict_out
        out = (
            TensorDict(info, [])
            .filter_non_tensor_data()
            .exclude("obs_dict", "done", "reward", *self._env.obs_keys, "act")
            .apply(lambda x: x, filter_empty=True)
        )
        if "info" in self.observation_spec.keys():
            info_spec = self.observation_spec["info"]

            def func(name, x):
                spec = info_spec.get(name, None)
                if spec is None:
                    return None
                return x.reshape(spec.shape)

            out.update(out.named_apply(func, nested_keys=True, filter_empty=True))
        else:
            out.update(
                out.apply(
                    lambda x: x.reshape((1,)) if not x.shape else x, filter_empty=True
                )
            )
        tensordict_out.set("info", out)
        return tensordict_out

    def _init_env(self):
        # Intentionally a no-op override of the parent's init hook.
        pass

    def to(self, *args, **kwargs):
        """Moves the env and rebuilds the simulator if the render device changed."""
        out = super().to(*args, **kwargs)
        render_device = self._get_render_device(out.device)
        if render_device != self.render_device:
            # The rebuilt env must replace the current one, otherwise the new
            # render device is never used.
            out._env = out._build_env(**self._constructor_kwargs)
        return out

    @classmethod
    def get_available_cams(cls, env_name):
        """Returns the names of the cameras of ``env_name``'s MuJoCo model.

        A temporary env is built for the inspection and closed afterwards.
        """
        env = gym_backend().make(env_name)
        try:
            model = env.sim.model
            return [model.id2name(ic, cls._MJ_OBJ_CAMERA) for ic in range(model.ncam)]
        finally:
            env.close()