torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,363 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import importlib.util
8
+ import os.path
9
+ import shutil
10
+ import tempfile
11
+ from collections.abc import Callable
12
+ from contextlib import nullcontext
13
+ from pathlib import Path
14
+
15
+ import torch
16
+ from tensordict import PersistentTensorDict, TensorDict
17
+ from torchrl._utils import (
18
+ KeyDependentDefaultDict,
19
+ logger as torchrl_logger,
20
+ print_directory_tree,
21
+ )
22
+ from torchrl.data.datasets.common import BaseDatasetExperienceReplay
23
+ from torchrl.data.datasets.utils import _get_root_dir
24
+ from torchrl.data.replay_buffers.samplers import Sampler
25
+ from torchrl.data.replay_buffers.storages import TensorStorage
26
+ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
27
+
28
# Optional dependencies: each flag is True when the corresponding module can be
# found on the current Python path (nothing is imported here).
_has_tqdm, _has_h5py, _has_hf_hub = (
    importlib.util.find_spec(module_name, None) is not None
    for module_name in ("tqdm", "h5py", "huggingface_hub")
)

# Renaming table from the key names found in the Roboset H5 episodes to the
# names used in the TorchRL TensorDict. The default factory returns the key
# itself, so names not listed below presumably pass through unchanged.
_NAME_MATCH = KeyDependentDefaultDict(lambda key: key)
for _h5_key, _torchrl_key in (
    ("observations", "observation"),
    ("rewards", "reward"),
    ("actions", "action"),
    ("env_infos", "info"),
):
    _NAME_MATCH[_h5_key] = _torchrl_key
del _h5_key, _torchrl_key
37
+
38
+
39
class RobosetExperienceReplay(BaseDatasetExperienceReplay):
    """Roboset experience replay dataset.

    This class downloads the H5 data from roboset and processes it in a mmap
    format, which makes indexing (and therefore sampling) faster.

    Learn more about roboset here: https://sites.google.com/view/robohive/roboset

    The data format follows the :ref:`TED convention <TED-format>`.

    Args:
        dataset_id (str): the dataset to be downloaded. Must be part of RobosetExperienceReplay.available_datasets.
        batch_size (int): Batch-size used during sampling. Can be overridden by `data.sample(batch_size)` if
            necessary.

    Keyword Args:
        root (Path or str, optional): The Roboset dataset root directory.
            The actual dataset memory-mapped files will be saved under
            `<root>/<dataset_id>`. If none is provided, it defaults to
            ``~/.cache/torchrl/roboset``.
        download (bool or str, optional): Whether the dataset should be downloaded if
            not found. Defaults to ``True``. Download can also be passed as ``"force"``,
            in which case the downloaded data will be overwritten.
        sampler (Sampler, optional): the sampler to be used. If none is provided
            a default RandomSampler() will be used.
        writer (Writer, optional): the writer to be used. If none is provided
            a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used.
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s)/outputs. Used when using batched
            loading from a map-style dataset.
        pin_memory (bool): whether pin_memory() should be called on the rb
            samples.
        prefetch (int, optional): number of next batches to be prefetched
            using multithreading.
        transform (Transform, optional): Transform to be executed when sample() is called.
            To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class.
        split_trajs (bool, optional): if ``True``, the trajectories will be split
            along the first dimension and padded to have a matching shape.
            To split the trajectories, the ``"done"`` signal will be used, which
            is recovered via ``done = truncated | terminated``. In other words,
            it is assumed that any ``truncated`` or ``terminated`` signal is
            equivalent to the end of a trajectory.
            The split copy is stored under ``<root>/<dataset_id>_split`` and is
            built right after downloading when the data is (re-)downloaded.
            Defaults to ``False``.
        **env_kwargs: extra keyword arguments; currently unused by this class.

    Attributes:
        available_datasets: a list of accepted entries to be downloaded.

    Examples:
        >>> import torch
        >>> torch.manual_seed(0)
        >>> from torchrl.envs.transforms import ExcludeTransform
        >>> from torchrl.data.datasets import RobosetExperienceReplay
        >>> d = RobosetExperienceReplay("FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4", batch_size=32,
        ...     transform=ExcludeTransform("info", ("next", "info")))  # excluding info dict for conciseness
        >>> for batch in d:
        ...     break
        >>> # data is organised by seed and episode, but stored contiguously
        >>> print(f"{batch['seed']}, {batch['episode']}")
        tensor([2, 1, 0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 2, 2, 2, 1, 1, 2, 0, 2, 0, 2, 2, 1,
                0, 2, 0, 0, 1, 1, 2, 1]) tensor([17, 20, 18,  9,  6,  1, 12,  6,  2,  6,  8, 15,  8, 21, 17,  3,  9, 20,
                23, 12,  3, 16, 19, 16, 16,  4,  4, 12,  1,  2, 15, 24])
        >>> print(batch)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([32, 9]), device=cpu, dtype=torch.float64, is_shared=False),
                done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                episode: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False),
                index: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([32, 75]), device=cpu, dtype=torch.float64, is_shared=False),
                        reward: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                        terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([32]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([32, 75]), device=cpu, dtype=torch.float64, is_shared=False),
                seed: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False),
                terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                time: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.float64, is_shared=False),
                truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([32]),
            device=cpu,
            is_shared=False)

    """

    available_datasets = [
        "DAPG(expert)/door_v2d-v1",
        "DAPG(expert)/relocate_v2d-v1",
        "DAPG(expert)/hammer_v2d-v1",
        "DAPG(expert)/pen_v2d-v1",
        "DAPG(human)/door_v2d-v1",
        "DAPG(human)/relocate_v2d-v1",
        "DAPG(human)/hammer_v2d-v1",
        "DAPG(human)/pen_v2d-v1",
        "FK1-v4(expert)/FK1_MicroOpenRandom_v2d-v4",
        "FK1-v4(expert)/FK1_Knob2OffRandom_v2d-v4",
        "FK1-v4(expert)/FK1_LdoorOpenRandom_v2d-v4",
        "FK1-v4(expert)/FK1_SdoorOpenRandom_v2d-v4",
        "FK1-v4(expert)/FK1_Knob1OnRandom_v2d-v4",
        "FK1-v4(human)/human_demos_by_playdata",
        "FK1-v4(human)/human_demos_by_task/human_demo_singleTask_Fixed-v4",
        "FK1-v4(human)/human_demos_by_task/FK1_SdoorOpenRandom_v2d-v4",
        "FK1-v4(human)/human_demos_by_task/FK1_LdoorOpenRandom_v2d-v4",
        "FK1-v4(human)/human_demos_by_task/FK1_Knob2OffRandom_v2d-v4",
        "FK1-v4(human)/human_demos_by_task/FK1_Knob1OnRandom_v2d-v4",
        "FK1-v4(human)/human_demos_by_task/FK1_MicroOpenRandom_v2d-v4",
    ]

    def __init__(
        self,
        dataset_id,
        batch_size: int,
        *,
        root: str | Path | None = None,
        download: bool | str = True,
        sampler: Sampler | None = None,
        writer: Writer | None = None,
        collate_fn: Callable | None = None,
        pin_memory: bool = False,
        prefetch: int | None = None,
        transform: torchrl.envs.Transform | None = None,  # noqa: F821
        split_trajs: bool = False,
        **env_kwargs,
    ):
        if not _has_h5py or not _has_hf_hub:
            raise ImportError(
                "h5py and huggingface_hub are required for Roboset datasets."
            )
        if dataset_id not in self.available_datasets:
            raise ValueError(
                f"The dataset_id {dataset_id} isn't part of the accepted datasets. "
                f"To check which dataset can be downloaded, call `{type(self)}.available_datasets`."
            )
        self.dataset_id = dataset_id
        if root is None:
            root = _get_root_dir("roboset")
            os.makedirs(root, exist_ok=True)
        self.root = root
        self.split_trajs = split_trajs
        self.download = download
        if self.download == "force" or (self.download and not self._is_downloaded()):
            if self.download == "force":
                self._remove_local_copies()
            storage = self._download_and_preproc()
            if self.split_trajs:
                # The freshly processed data is stored unsplit under
                # data_path_root: build the split copy now so that split_trajs
                # is honoured on the very first instantiation as well.
                storage = self._make_split()
        elif self.split_trajs and not os.path.exists(self.data_path):
            storage = self._make_split()
        else:
            storage = self._load()
        storage = TensorStorage(storage)

        if writer is None:
            writer = ImmutableDatasetWriter()

        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=writer,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            transform=transform,
            batch_size=batch_size,
        )

    def _remove_local_copies(self):
        """Delete the processed data and, if it is a distinct folder, its split copy.

        Folders that do not exist are silently skipped.
        """
        paths = [self.data_path_root]
        if self.data_path != self.data_path_root:
            paths.append(self.data_path)
        for path in paths:
            try:
                shutil.rmtree(path)
            except FileNotFoundError:
                # This copy was never materialized: nothing to delete.
                pass

    def _download_from_huggingface(self, tempdir):
        """Download every ``.h5`` file of ``self.dataset_id`` from the Hugging Face hub.

        Args:
            tempdir (str): folder in which the files are cached (under ``<tempdir>/data``).

        Returns:
            the sorted list of local paths of the downloaded ``.h5`` files.

        Raises:
            ImportError: if ``huggingface_hub`` cannot be imported.
        """
        try:
            from huggingface_hub import hf_hub_download, HfApi
        except ImportError as err:
            raise ImportError(
                f"huggingface_hub is required for downloading {type(self)}'s datasets."
            ) from err
        repo_id = "jdvakil/RoboSet_Sim"
        dataset = HfApi().dataset_info(repo_id)
        h5_files = []
        datapath = Path(tempdir) / "data"
        for sibling in dataset.siblings:
            filename = sibling.rfilename
            if not (filename.startswith(self.dataset_id) and filename.endswith(".h5")):
                continue
            path = Path(filename)
            local_path = hf_hub_download(
                repo_id,
                subfolder=str(path.parent),
                filename=path.name,
                repo_type="dataset",
                cache_dir=str(datapath),
            )
            h5_files.append(local_path)

        return sorted(h5_files)

    def _download_and_preproc(self):
        """Download the raw H5 files to a temporary folder and convert them.

        The temporary folder (and the raw files in it) is deleted once the
        memory-mapped copy has been written under :attr:`data_path_root`.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            h5_data_files = self._download_from_huggingface(tempdir)
            return self._preproc_h5(h5_data_files)

    def _preproc_h5(self, h5_data_files):
        """Convert H5 episode files into one memory-mapped TensorDict.

        The index of a file in ``h5_data_files`` is stored as the ``"seed"``
        entry of all its steps.

        Args:
            h5_data_files (list of str): paths to the ``.h5`` files.

        Returns:
            the memory-mapped TensorDict (stored under :attr:`data_path_root`)
            holding every step of every episode.

        Raises:
            RuntimeError: if no episode is found, or if an entry of an episode
                does not have as many steps as that episode's ``"actions"``.
        """
        h5_datas = []
        try:
            td_data, episode_dict, total_steps = self._scan_h5_files(
                h5_data_files, h5_datas
            )
            td_data = td_data.expand(total_steps)
            # save to designated location
            torchrl_logger.info(f"creating tensordict data in {self.data_path_root}: ")
            td_data = td_data.memmap_like(self.data_path_root)
            torchrl_logger.info(
                f"Local dataset structure: {print_directory_tree(self.data_path_root)}"
            )
            self._fill_from_h5(td_data, episode_dict, h5_datas, total_steps)
            return td_data
        except BaseException:
            # _is_downloaded() only checks that data_path_root exists, so a
            # half-written copy would be silently loaded next time: remove it
            # before propagating the error (including KeyboardInterrupt).
            shutil.rmtree(self.data_path_root, ignore_errors=True)
            raise
        finally:
            # Release the H5 handles before the temporary folder is deleted.
            for h5_data in h5_datas:
                h5_data.close()

    def _scan_h5_files(self, h5_data_files, h5_datas):
        """First pass over the files: index the episodes and build a one-step template.

        Each opened file is appended to ``h5_datas`` as soon as it is opened, so
        that the caller can close it even if this method raises.

        Returns:
            a tuple ``(template, episode_dict, total_steps)``. ``template`` is a
            batch-less TensorDict of zeros modelled on the first episode of the
            first file. ``episode_dict`` maps ``(seed, episode_num)`` to
            ``(episode_key, episode_len)``. ``total_steps`` is the sum of all
            episode lengths.
        """
        td_data = TensorDict()
        total_steps = 0
        torchrl_logger.info(
            f"first read through data files {h5_data_files} to create data structure..."
        )
        episode_dict = {}
        for seed, h5_data_name in enumerate(h5_data_files):
            torchrl_logger.info(f"\nReading {h5_data_name}")
            h5_data = PersistentTensorDict.from_h5(h5_data_name)
            h5_datas.append(h5_data)
            for i, (episode_key, episode) in enumerate(h5_data.items()):
                # Episode keys are expected to look like "Trial<int>".
                episode_num = int(episode_key[len("Trial") :])
                episode_len = episode["actions"].shape[0]
                episode_dict[(seed, episode_num)] = (episode_key, episode_len)
                total_steps += episode_len
                if i == 0 and seed == 0:
                    self._add_template_entries(td_data, episode)
        if not episode_dict:
            raise RuntimeError(
                f"No episode was found for dataset {self.dataset_id} "
                f"in the data files {h5_data_files}."
            )
        torchrl_logger.info(f"total_steps {total_steps}")

        # give it the proper size: done-like entries get a trailing singleton
        # dim, and terminated/truncated are modelled on done.
        td_data["next", "done"] = td_data["next", "done"].unsqueeze(-1)
        td_data["done"] = td_data["done"].unsqueeze(-1)
        td_data["next", "terminated"] = td_data["next", "done"]
        td_data["next", "truncated"] = td_data["next", "done"]
        td_data["terminated"] = td_data["done"]
        td_data["truncated"] = td_data["done"]
        return td_data, episode_dict, total_steps

    @staticmethod
    def _add_template_entries(td_data, episode):
        """Populate ``td_data`` with zero entries shaped like one step of ``episode``."""
        td_data.set("episode", 0)
        td_data.set("seed", 0)
        for key, val in episode.items():
            match = _NAME_MATCH[key]
            if key in ("observations", "env_infos", "done"):
                # These entries exist both at the root and under "next".
                td_data.set(("next", match), torch.zeros_like(val[0]))
                td_data.set(match, torch.zeros_like(val[0]))
            elif key not in ("rewards",):
                td_data.set(match, torch.zeros_like(val[0]))
            else:
                # Rewards only live under "next", with a trailing singleton dim.
                td_data.set(
                    ("next", match),
                    torch.zeros_like(val[0].unsqueeze(-1)),
                )

    @staticmethod
    def _fill_from_h5(td_data, episode_dict, h5_datas, total_steps):
        """Second pass: copy every episode into a contiguous slice of ``td_data``.

        Episodes are visited by increasing episode number. Ties between files
        keep the file (seed) order, because ``sorted`` is stable and
        ``episode_dict`` was filled file by file.
        """
        torchrl_logger.info(f"Reading data from {len(episode_dict)} episodes")
        if _has_tqdm:
            from tqdm import tqdm

            progress = tqdm(total=total_steps)
        else:
            progress = nullcontext()
        index = 0
        with progress as pbar:
            for seed, episode_num in sorted(episode_dict, key=lambda key: key[1]):
                h5_data = h5_datas[seed]
                episode_key, steps = episode_dict[(seed, episode_num)]
                episode = h5_data.get(episode_key)
                data_view = td_data[index : index + steps]
                data_view.fill_("episode", episode_num)
                data_view.fill_("seed", seed)
                for key, val in episode.items():
                    match = _NAME_MATCH[key]
                    if steps != val.shape[0]:
                        raise RuntimeError(
                            f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}."
                        )
                    if key in ("observations", "env_infos"):
                        # next[t] = current[t + 1]; the next entry of the last
                        # step of the episode is not written here.
                        data_view["next", match][:-1].copy_(val[1:])
                        data_view[match].copy_(val)
                    elif key not in ("rewards", "done", "terminated", "truncated"):
                        data_view[match].copy_(val)
                    elif key in ("done", "terminated", "truncated"):
                        data_view[match].copy_(val.unsqueeze(-1))
                        data_view[("next", match)].copy_(val.unsqueeze(-1))
                    else:
                        data_view[("next", match)].copy_(val.unsqueeze(-1))
                data_view["next", "terminated"].copy_(data_view["next", "done"])
                if pbar is not None:
                    pbar.update(steps)
                    pbar.set_description(
                        f"index={index} - episode num {episode_num} - seed {seed}"
                    )
                index += steps

    def _make_split(self):
        """Split the stored trajectories and memory-map the result under :attr:`data_path`."""
        from torchrl.collectors.utils import split_trajectories

        td_data = TensorDict.load_memmap(self.data_path_root)
        td_data = split_trajectories(td_data).memmap_(self.data_path)
        return td_data

    def _load(self):
        """Load the memory-mapped data stored under :attr:`data_path`."""
        return TensorDict.load_memmap(self.data_path)

    @property
    def data_path(self):
        """Folder of the data actually used: ``<root>/<dataset_id>[_split]``."""
        if self.split_trajs:
            return Path(self.root) / (self.dataset_id + "_split")
        return self.data_path_root

    @property
    def data_path_root(self):
        """Folder of the unsplit processed data: ``<root>/<dataset_id>``."""
        return Path(self.root) / self.dataset_id

    def _is_downloaded(self):
        """Return whether the unsplit processed data folder exists on disk."""
        return os.path.exists(self.data_path_root)
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import os
8
+
9
+
10
+ def _get_root_dir(dataset: str):
11
+ return os.path.join(os.path.expanduser("~"), ".cache", "torchrl", dataset)