torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,432 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import functools
8
+ import importlib
9
+ import json
10
+ import os
11
+ import pathlib
12
+ import shutil
13
+ import tempfile
14
+ from collections import defaultdict
15
+ from collections.abc import Callable
16
+ from pathlib import Path
17
+
18
+ import numpy as np
19
+ import torch
20
+ from tensordict import PersistentTensorDict, TensorDict
21
+ from torch import multiprocessing as mp
22
+ from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger
23
+ from torchrl.data.datasets.common import BaseDatasetExperienceReplay
24
+ from torchrl.data.datasets.utils import _get_root_dir
25
+ from torchrl.data.replay_buffers.samplers import Sampler
26
+ from torchrl.data.replay_buffers.storages import TensorStorage
27
+ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
28
+ from torchrl.envs.transforms import Compose, Resize, ToTensorImage
29
+ from torchrl.envs.utils import _classproperty
30
+
31
+ _has_tqdm = importlib.util.find_spec("tqdm", None) is not None
32
+ _has_h5py = importlib.util.find_spec("h5py", None) is not None
33
+ _has_hf_hub = importlib.util.find_spec("huggingface_hub", None) is not None
34
+
35
+ THIS_DIR = pathlib.Path(__file__).parent
36
+
37
+
38
+ class VD4RLExperienceReplay(BaseDatasetExperienceReplay):
39
+ """V-D4RL experience replay dataset.
40
+
41
+ This class downloads the H5/npz data from V-D4RL and processes it in a mmap
42
+ format, which makes indexing (and therefore sampling) faster.
43
+
44
+ Learn more about V-D4RL here: https://arxiv.org/abs/2206.04779
45
+
46
+ The `"pixels"` entry is located at the root of the data, and all the data
47
+ that is not reward, done-state, action or pixels is moved under a `"state"`
48
+ node.
49
+
50
+ The data format follows the :ref:`TED convention <TED-format>`.
51
+
52
+ Args:
53
+ dataset_id (str): the dataset to be downloaded. Must be part of
54
+ VD4RLExperienceReplay.available_datasets.
55
+ batch_size (int): Batch-size used during sampling. Can be overridden by
56
+ `data.sample(batch_size)` if necessary.
57
+
58
+ Keyword Args:
59
+ root (Path or str, optional): The V-D4RL dataset root directory.
60
+ The actual dataset memory-mapped files will be saved under
61
+ `<root>/<dataset_id>`. If none is provided, it defaults to
62
+ `~/.cache/torchrl/vd4rl`.
63
+ download (bool or str, optional): Whether the dataset should be downloaded if
64
+ not found. Defaults to ``True``. Download can also be passed as ``"force"``,
65
+ in which case the downloaded data will be overwritten.
66
+ sampler (Sampler, optional): the sampler to be used. If none is provided
67
+ a default RandomSampler() will be used.
68
+ writer (Writer, optional): the writer to be used. If none is provided
69
+ a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used.
70
+ collate_fn (callable, optional): merges a list of samples to form a
71
+ mini-batch of Tensor(s)/outputs. Used when using batched
72
+ loading from a map-style dataset.
73
+ pin_memory (bool): whether pin_memory() should be called on the rb
74
+ samples.
75
+ prefetch (int, optional): number of next batches to be prefetched
76
+ using multithreading.
77
+ transform (Transform, optional): Transform to be executed when sample() is called.
78
+ To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class.
79
+ split_trajs (bool, optional): if ``True``, the trajectories will be split
80
+ along the first dimension and padded to have a matching shape.
81
+ To split the trajectories, the ``"done"`` signal will be used, which
82
+ is recovered via ``done = truncated | terminated``. In other words,
83
+ it is assumed that any ``truncated`` or ``terminated`` signal is
84
+ equivalent to the end of a trajectory. For some datasets from
85
+ ``D4RL``, this may not be true. It is up to the user to make
86
+ accurate choices regarding this usage of ``split_trajs``.
87
+ Defaults to ``False``.
88
+ totensor (bool, optional): if ``True``, a :class:`~torchrl.envs.transforms.ToTensorImage`
89
+ transform will be included in the transform list (if not automatically
90
+ detected). Defaults to ``True``.
91
+ image_size (int, list of ints or None): if not ``None``, this argument
92
+ will be used to create a :class:`~torchrl.envs.transforms.Resize`
93
+ transform that will be appended to the transform list. Supports
94
+ `int` types (square resizing) or a list/tuple of `int` (rectangular
95
+ resizing). Defaults to ``None`` (no resizing).
96
+ num_workers (int, optional): the number of workers to download the files.
97
+ Defaults to ``0`` (no multiprocessing).
98
+
99
+ Attributes:
100
+ available_datasets: a list of accepted entries to be downloaded. These
101
+ names correspond to the directory path in the huggingface dataset
102
+ repository. If possible, the list will be dynamically retrieved from
103
+ huggingface. If no internet connection is available, a cached
104
+ version will be used.
105
+
106
+ .. note:: Since not all experience replay datasets have start and stop signals, we
107
+ do not mark the episodes in the retrieved dataset.
108
+
109
+ Examples:
110
+ >>> import torch
111
+ >>> torch.manual_seed(0)
112
+ >>> from torchrl.data.datasets import VD4RLExperienceReplay
113
+ >>> d = VD4RLExperienceReplay("main/walker_walk/random/64px", batch_size=32,
114
+ ... image_size=50)
115
+ >>> for batch in d:
116
+ ... break
117
+ >>> print(batch)
118
+ TensorDict(
119
+ fields={
120
+ action: Tensor(shape=torch.Size([32, 6]), device=cpu, dtype=torch.float32, is_shared=False),
121
+ done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
122
+ index: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False),
123
+ is_init: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.bool, is_shared=False),
124
+ next: TensorDict(
125
+ fields={
126
+ done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
127
+ observation: TensorDict(
128
+ fields={
129
+ height: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.float32, is_shared=False),
130
+ orientations: Tensor(shape=torch.Size([32, 14]), device=cpu, dtype=torch.float32, is_shared=False),
131
+ velocity: Tensor(shape=torch.Size([32, 9]), device=cpu, dtype=torch.float32, is_shared=False)},
132
+ batch_size=torch.Size([32]),
133
+ device=cpu,
134
+ is_shared=False),
135
+ pixels: Tensor(shape=torch.Size([32, 3, 50, 50]), device=cpu, dtype=torch.float32, is_shared=False),
136
+ reward: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False),
137
+ terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
138
+ truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
139
+ batch_size=torch.Size([32]),
140
+ device=cpu,
141
+ is_shared=False),
142
+ observation: TensorDict(
143
+ fields={
144
+ height: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.float32, is_shared=False),
145
+ orientations: Tensor(shape=torch.Size([32, 14]), device=cpu, dtype=torch.float32, is_shared=False),
146
+ velocity: Tensor(shape=torch.Size([32, 9]), device=cpu, dtype=torch.float32, is_shared=False)},
147
+ batch_size=torch.Size([32]),
148
+ device=cpu,
149
+ is_shared=False),
150
+ pixels: Tensor(shape=torch.Size([32, 3, 50, 50]), device=cpu, dtype=torch.float32, is_shared=False),
151
+ terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
152
+ truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
153
+ batch_size=torch.Size([32]),
154
+ device=cpu,
155
+ is_shared=False)
156
+
157
+ """
158
+
159
def __init__(
    self,
    dataset_id,
    batch_size: int,
    *,
    root: str | Path | None = None,
    download: bool | str = True,
    sampler: Sampler | None = None,
    writer: Writer | None = None,
    collate_fn: Callable | None = None,
    pin_memory: bool = False,
    prefetch: int | None = None,
    transform: torchrl.envs.Transform | None = None,  # noqa-F821
    split_trajs: bool = False,
    totensor: bool = True,
    image_size: int | list[int] | None = None,
    num_workers: int = 0,
    **env_kwargs,
):
    """Build the dataset storage and initialize the replay buffer.

    The arguments are described in the class docstring. ``download`` may be
    a boolean or the string ``"force"``. ``**env_kwargs`` is accepted but is
    not used by this constructor.

    Raises:
        ImportError: if ``h5py`` or ``huggingface_hub`` is not installed.
        ValueError: if ``dataset_id`` is not part of ``available_datasets``.
    """
    if not _has_h5py or not _has_hf_hub:
        raise ImportError(
            "h5py and huggingface_hub are required for V-D4RL datasets."
        )
    if dataset_id not in self.available_datasets:
        raise ValueError(
            f"The dataset_id {dataset_id} isn't part of the accepted datasets. "
            f"To check which dataset can be downloaded, call `{type(self)}.available_datasets`."
        )
    self.dataset_id = dataset_id
    if root is None:
        # Only the default cache directory is created here.
        root = _get_root_dir("vd4rl")
        os.makedirs(root, exist_ok=True)
    self.root = root
    self.split_trajs = split_trajs
    self.download = download
    self.num_workers = num_workers

    if self.download == "force" or (self.download and not self._is_downloaded()):
        if self.download == "force":
            # Best-effort removal of any previous copy: a path that is
            # already gone is not an error.
            try:
                if os.path.exists(self.data_path_root):
                    shutil.rmtree(self.data_path_root)
                if self.data_path != self.data_path_root:
                    shutil.rmtree(self.data_path)
            except FileNotFoundError:
                pass
        storage = self._download_and_preproc(
            dataset_id, data_path=self.data_path, num_workers=self.num_workers
        )
    elif self.split_trajs and not os.path.exists(self.data_path):
        storage = self._make_split()
    else:
        storage = self._load()

    if totensor and transform is None:
        transform = ToTensorImage(
            in_keys=["pixels", ("next", "pixels")], shape_tolerant=True
        )
    elif totensor and (
        not isinstance(transform, Compose)
        or not any(isinstance(t, ToTensorImage) for t in transform)
    ):
        # No ToTensorImage was found in the user transform: append one.
        transform = Compose(
            transform,
            ToTensorImage(
                in_keys=["pixels", ("next", "pixels")], shape_tolerant=True
            ),
        )
    if image_size is not None:
        resize = Resize(image_size, in_keys=["pixels", ("next", "pixels")])
        # With ``totensor=False`` and no user transform there is nothing to
        # chain the resize with, and ``Compose`` must not be given ``None``.
        transform = resize if transform is None else Compose(transform, resize)
    storage = TensorStorage(storage)

    if writer is None:
        writer = ImmutableDatasetWriter()

    super().__init__(
        storage=storage,
        sampler=sampler,
        writer=writer,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        prefetch=prefetch,
        transform=transform,
        batch_size=batch_size,
    )
244
+
245
@classmethod
def _parse_datasets(cls):
    """Group the data files of the ``conglu/vd4rl`` dataset repo by directory.

    Returns:
        a ``defaultdict(list)`` mapping the parent directory of each listed
        file whose name ends with ``npz`` or ``hdf5`` (as a ``Path``) to the
        list of ``Path`` objects of those files.
    """
    from huggingface_hub import HfApi

    repo_info = HfApi().dataset_info("conglu/vd4rl")
    files_by_dir = defaultdict(list)
    for entry in repo_info.siblings:
        name = entry.rfilename
        # Only the data files are of interest; skip everything else.
        if not name.endswith(("npz", "hdf5")):
            continue
        file_path = Path(name)
        files_by_dir[file_path.parent].append(file_path)
    return files_by_dir
256
+
257
@classmethod
def _hf_hub_download(cls, subfolder, filename, *, tmpdir):
    """Download one file of the ``conglu/vd4rl`` dataset repository.

    Args:
        subfolder: folder of the repository that holds the file.
        filename: name of the file inside ``subfolder``.
        tmpdir: directory forwarded (as a string) as the download cache dir.

    Returns:
        the value returned by ``huggingface_hub.hf_hub_download``
        (presumably the local path of the file -- callers in this file use
        it as a path string).
    """
    from huggingface_hub import hf_hub_download as _fetch

    options = {
        "subfolder": subfolder,
        "filename": filename,
        "repo_type": "dataset",
        "cache_dir": str(tmpdir),
    }
    return _fetch("conglu/vd4rl", **options)
268
+
269
@classmethod
def _download_and_preproc(cls, dataset_id, data_path, num_workers):
    """Download every file of ``dataset_id`` and store them as one memmap.

    The files are fetched into a temporary directory, loaded as tensordicts,
    passed through :meth:`_process_data` and copied one after the other into
    a single memory-mapped tensordict created at ``data_path``.

    Args:
        dataset_id: substring identifying the dataset; every hub folder whose
            path contains it is downloaded.
        data_path: destination forwarded to ``memmap_like``.
        num_workers: number of download processes. ``0`` or less downloads
            sequentially in the current process.

    Returns:
        The memory-mapped tensordict holding all processed steps.

    Raises:
        ValueError: if no file of the hub repository matches ``dataset_id``.
    """
    tds = []
    with tempfile.TemporaryDirectory() as tmpdir:
        sibs = cls._parse_datasets()
        total_steps = 0

        paths_to_proc = []
        files_to_proc = []

        for path, path_files in sibs.items():
            if dataset_id not in str(path):
                continue
            for file in path_files:
                paths_to_proc.append(str(path))
                files_to_proc.append(str(file.name))

        # Without this guard, an unknown dataset_id would only surface later
        # as an UnboundLocalError on ``td_save``.
        if not files_to_proc:
            raise ValueError(
                f"No file matching dataset_id={dataset_id!r} was found in the "
                "conglu/vd4rl repository."
            )

        func = functools.partial(cls._hf_hub_download, tmpdir=tmpdir)
        if num_workers > 0:
            with mp.Pool(num_workers) as pool:
                # starmap already returns a list of results.
                files = pool.starmap(func, zip(paths_to_proc, files_to_proc))
        else:
            files = [
                func(subfolder, filename)
                for (subfolder, filename) in zip(paths_to_proc, files_to_proc)
            ]
        torchrl_logger.info("Downloaded, processing files")
        if _has_tqdm:
            import tqdm

            pbar = tqdm.tqdm(files)
        else:
            pbar = files
        for local_path in pbar:
            if _has_tqdm:
                pbar.set_description(f"file={local_path}")
            # we memmap temporarily the files for faster access later
            if local_path.endswith("hdf5"):
                td = (
                    PersistentTensorDict.from_h5(local_path)
                    .to_tensordict()
                    .memmap(num_threads=32)
                )
            else:
                td = _from_npz(local_path).memmap(num_threads=32)
            td.unlock_()
            if total_steps == 0:
                # The first processed step serves as the template (keys,
                # shapes, dtypes) for the destination storage.
                tdc = cls._process_data(td.clone())
                td_save = tdc[0]
            tds.append(td)
            total_steps += td.shape[0]

    # From this point on, the downloaded files are not needed anymore.
    td_save = td_save.expand(total_steps).memmap_like(data_path, num_threads=32)
    torchrl_logger.info(f"Saved tensordict: {td_save}")
    idx0 = 0
    while tds:
        # Pop instead of iterating so each chunk can be released once copied.
        td = cls._process_data(tds.pop(0))
        idx1 = idx0 + td.shape[0]
        td_save[idx0:idx1] = td
        idx0 = idx1
    return td_save
338
+
339
@classmethod
def _process_data(cls, td: TensorDict) -> TensorDict:
    """Rename the raw dataset fields and add step-level entries, in place.

    Args:
        td: tensordict loaded from one dataset file. It is modified in place.

    Returns:
        The same ``td`` object, with keys renamed according to
        ``_NAME_MATCH``, unmapped keys moved under ``"state"``, and the
        ``("next", ...)`` / root done-flag entries filled in.
    """
    # Iterate over a snapshot of the keys, since entries are renamed in the loop.
    for name in list(td.keys()):
        # Keys without an entry in _NAME_MATCH are moved under "state".
        if name not in _NAME_MATCH:
            td.rename_key_(name, ("state", name))
        # Keys whose mapped name differs are renamed to that mapped name.
        elif name != _NAME_MATCH[name]:
            td.rename_key_(name, _NAME_MATCH[name])
    # Append a trailing singleton dimension to the reward.
    if ("next", "reward") in td.keys(True):
        td.set(("next", "reward"), td.get(("next", "reward")).unsqueeze(-1))
    if ("next", "done") in td.keys(True) and ("next", "terminated") in td.keys(
        True
    ):
        # First append a trailing singleton dimension to both flags.
        td.set(("next", "done"), td.get(("next", "done")).unsqueeze(-1))
        td.set(("next", "terminated"), td.get(("next", "terminated")).unsqueeze(-1))
        # Create the root-level flags as all-zero tensors of matching shape.
        td.set("done", torch.zeros_like(td.get(("next", "done"))))
        td.set("terminated", torch.zeros_like(td.get(("next", "terminated"))))
        # Truncated steps are those that are done without being terminated.
        td.set(
            ("next", "truncated"),
            td.get(("next", "done")) & ~td.get(("next", "terminated")),
        )

        td.set("truncated", torch.zeros_like(td.get(("next", "truncated"))))

    pixels = td.get("pixels")
    # View over every element but the last along the first dimension: the
    # "next" value of element i is the current value of element i + 1.
    subtd = td._get_sub_tensordict(slice(0, -1))
    subtd.set(("next", "pixels"), pixels[1:], inplace=True)
    # "state" only exists when some raw key had no entry in _NAME_MATCH.
    state = td.get("state", None)
    if state is not None:
        subtd.set(("next", "state"), state[1:], inplace=True)

    return td
374
+
375
@_classproperty
def available_datasets(cls):
    """Dataset names, as computed by :meth:`_available_datasets`."""
    dataset_names = cls._available_datasets()
    return dataset_names
378
+
379
@classmethod
def _available_datasets(cls):
    """List the dataset names, preferring the hub listing over the bundled file.

    Returns:
        A list of names derived from the folders reported by
        :meth:`_parse_datasets`; if anything raises along the way, the
        content of ``vd4rl.json`` next to this module is returned instead.
    """
    try:
        # Query the HF hub first; any failure falls through to the JSON file.
        names = []
        for folder in cls._parse_datasets():
            # Drop the first 6 characters of the folder path
            # (presumably a fixed "vd4rl/" prefix -- confirm against the repo).
            names.append(str(folder)[6:])
        return names
    except Exception:
        # Fall back on the default dataset list shipped with the package.
        fallback_file = THIS_DIR / "vd4rl.json"
        with open(fallback_file) as handle:
            return json.load(handle)
389
+
390
def _make_split(self):
    """Split the root memmap into trajectories and store them at ``data_path``.

    Returns:
        The tensordict returned by ``memmap_`` on the split data.
    """
    from torchrl.collectors.utils import split_trajectories

    flat_data = TensorDict.load_memmap(self.data_path_root)
    split_data = split_trajectories(flat_data)
    return split_data.memmap_(self.data_path)
396
+
397
def _load(self):
    """Load the memory-mapped tensordict stored at ``data_path``."""
    loader = TensorDict.load_memmap
    return loader(self.data_path)
399
+
400
@property
def data_path(self):
    """Directory the dataset is read from.

    ``<root>/<dataset_id>_split`` when ``split_trajs`` is truthy, otherwise
    the same location as ``data_path_root``.
    """
    if not self.split_trajs:
        return self.data_path_root
    root_dir = Path(self.root)
    return root_dir / (self.dataset_id + "_split")
405
+
406
@property
def data_path_root(self):
    """Directory of the unsplit dataset: ``<root>/<dataset_id>``."""
    return Path(self.root).joinpath(self.dataset_id)
409
+
410
+ def _is_downloaded(self):
411
+ return os.path.exists(self.data_path_root)
412
+
413
+
414
def _from_npz(npz_path):
    """Load a ``.npz`` archive into a ``TensorDict``.

    Args:
        npz_path: path of the archive, forwarded to ``numpy.load``.

    Returns:
        A ``TensorDict`` with one entry per array stored in the archive, built
        with ``auto_batch_size=True``.
    """
    # ``np.load`` on an ``.npz`` returns a lazy ``NpzFile`` that keeps the file
    # open; the context manager closes it once every array has been read.
    with np.load(npz_path) as npz:
        npz_dict = {name: npz[name] for name in npz.files}
    return TensorDict.from_dict(npz_dict, auto_batch_size=True)
418
+
419
+
420
# Raw dataset field name -> key used in the processed tensordict.
# ``_process_data`` renames every key listed here to its mapped value and moves
# keys that are NOT listed under the "state" sub-tensordict.
# NOTE(review): KeyDependentDefaultDict(lambda x: x) presumably returns the key
# itself for missing lookups -- confirm against its definition.
_NAME_MATCH = KeyDependentDefaultDict(lambda x: x)
_NAME_MATCH.update(
    {
        "is_first": "is_init",
        "is_last": ("next", "done"),
        "is_terminal": ("next", "terminated"),
        "reward": ("next", "reward"),
        # Both "image" and "observation" are mapped to the "pixels" entry.
        "image": "pixels",
        "observation": "pixels",
        # Identity entries: listed so these keys are not moved under "state".
        "discount": "discount",
        "action": "action",
    }
)
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .dataset import (
7
+ create_infinite_iterator,
8
+ get_dataloader,
9
+ TensorDictTokenizer,
10
+ TokenizedDatasetLoader,
11
+ )
12
+ from .history import add_chat_template, ContentBase, History
13
+ from .prompt import PromptData, PromptTensorDictTokenizer
14
+ from .reward import PairwiseDataset, RewardData
15
+ from .topk import TopKRewardSelector
16
+ from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel
17
+
18
# Names exported by ``from <package> import *``; each one is imported above
# from a submodule of this package.
__all__ = [
    "AdaptiveKLController",
    "ConstantKLController",
    "ContentBase",
    "History",
    "PairwiseDataset",
    "PromptData",
    "add_chat_template",
    "PromptTensorDictTokenizer",
    "RewardData",
    "RolloutFromModel",
    "TensorDictTokenizer",
    "TokenizedDatasetLoader",
    "create_infinite_iterator",
    "get_dataloader",
    "TopKRewardSelector",
]