torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,489 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import importlib
8
+ import os
9
+ import shutil
10
+ import tempfile
11
+ import urllib
12
+ import warnings
13
+ from collections.abc import Callable
14
+ from pathlib import Path
15
+
16
+ import numpy as np
17
+ import torch
18
+ from tensordict import make_tensordict, PersistentTensorDict, TensorDict
19
+
20
+ from torchrl._utils import logger as torchrl_logger
21
+ from torchrl.collectors.utils import split_trajectories
22
+ from torchrl.data.datasets.common import BaseDatasetExperienceReplay
23
+ from torchrl.data.datasets.d4rl_infos import D4RL_DATASETS
24
+ from torchrl.data.datasets.utils import _get_root_dir
25
+ from torchrl.data.replay_buffers.samplers import Sampler
26
+ from torchrl.data.replay_buffers.storages import TensorStorage
27
+ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
28
+
29
+
30
+ class D4RLExperienceReplay(BaseDatasetExperienceReplay):
31
+ """An Experience replay class for D4RL.
32
+
33
+ To install D4RL, follow the instructions on the
34
+ `official repo <https://github.com/Farama-Foundation/D4RL>`__.
35
+
36
+ The data format follows the :ref:`TED convention <TED-format>`.
37
+ The replay buffer contains the env specs under D4RLExperienceReplay.specs.
38
+
39
+ If present, metadata will be written in ``D4RLExperienceReplay.metadata``
40
+ and excluded from the dataset.
41
+
42
+ The transitions are reconstructed using ``done = terminated | truncated`` and
43
+ the ``("next", "observation")`` of ``"done"`` states are zeroed.
44
+
45
+ Args:
46
+ dataset_id (str): the dataset_id of the D4RL env to get the data from.
47
+ batch_size (int): the batch size to use during sampling.
48
+ sampler (Sampler, optional): the sampler to be used. If none is provided
49
+ a default RandomSampler() will be used.
50
+ writer (Writer, optional): the writer to be used. If none is provided
51
+ a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used.
52
+ collate_fn (callable, optional): merges a list of samples to form a
53
+ mini-batch of Tensor(s)/outputs. Used when using batched
54
+ loading from a map-style dataset.
55
+ pin_memory (bool): whether pin_memory() should be called on the rb
56
+ samples.
57
+ prefetch (int, optional): number of next batches to be prefetched
58
+ using multithreading.
59
+ transform (Transform, optional): Transform to be executed when sample() is called.
60
+ To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class.
61
+ split_trajs (bool, optional): if ``True``, the trajectories will be split
62
+ along the first dimension and padded to have a matching shape.
63
+ To split the trajectories, the ``"done"`` signal will be used, which
64
+ is recovered via ``done = truncated | terminated``. In other words,
65
+ it is assumed that any ``truncated`` or ``terminated`` signal is
66
+ equivalent to the end of a trajectory. For some datasets from
67
+ ``D4RL``, this may not be true. It is up to the user to make
68
+ accurate choices regarding this usage of ``split_trajs``.
69
+ Defaults to ``False``.
70
+ from_env (bool, optional): if ``True``, :meth:`env.get_dataset` will
71
+ be used to retrieve the dataset. Otherwise :func:`d4rl.qlearning_dataset`
72
+ will be used. Defaults to ``True``.
73
+
74
+ .. note::
75
+
76
+ Using ``from_env=False`` will provide fewer data than ``from_env=True``.
77
+ For instance, the info keys will be left out.
78
+ Usually, ``from_env=False`` with ``terminate_on_end=True`` will
79
+ lead to the same result as ``from_env=True``, with the latter
80
+ containing meta-data and info entries that the former does
81
+ not possess.
82
+
83
+ .. note::
84
+
85
+ The keys in ``from_env=True`` and ``from_env=False`` *may* unexpectedly
86
+ differ. In particular, the ``"truncated"`` key (used to determine the
87
+ end of an episode) may be absent when ``from_env=False`` but present
88
+ otherwise, leading to a different slicing when ``split_trajs`` is enabled.
89
+ direct_download (bool): if ``True``, the data will be downloaded without
90
+ requiring D4RL. If ``None``, if ``d4rl`` is present in the env it will
91
+ be used to download the dataset, otherwise the download will fall back
92
+ on ``direct_download=True``.
93
+ This is not compatible with ``from_env=True``.
94
+ Defaults to ``None``.
95
+ use_truncated_as_done (bool, optional): if ``True``, ``done = terminated | truncated``.
96
+ Otherwise, only the ``terminated`` key is used. Defaults to ``True``.
97
+ terminate_on_end (bool, optional): Set ``done=True`` on the last timestep
98
+ in a trajectory. Default is ``False``, and will discard the
99
+ last timestep in each trajectory. This is to be used only with
100
+ ``direct_download=False``.
101
+ root (Path or str, optional): The D4RL dataset root directory.
102
+ The actual dataset memory-mapped files will be saved under
103
+ `<root>/<dataset_id>`. If none is provided, it defaults to
104
+ `~/.cache/torchrl/d4rl`.
105
+ download (bool, optional): Whether the dataset should be downloaded if
106
+ not found. Defaults to ``True``.
107
+ **env_kwargs (key-value pairs): additional kwargs for
108
+ :func:`d4rl.qlearning_dataset`.
109
+
110
+
111
+ Examples:
112
+ >>> from torchrl.data.datasets.d4rl import D4RLExperienceReplay
113
+ >>> from torchrl.envs import ObservationNorm
114
+ >>> data = D4RLExperienceReplay("maze2d-umaze-v1", 128)
115
+ >>> # we can append transforms to the dataset
116
+ >>> data.append_transform(ObservationNorm(loc=-1, scale=1.0, in_keys=["observation"]))
117
+ >>> data.sample(128)
118
+
119
+ """
120
+
121
+ D4RL_ERR = None
122
+
123
+ @classmethod
124
+ def _import_d4rl(cls):
125
+ cls._has_d4rl = importlib.util.find_spec("d4rl") is not None
126
+ try:
127
+ import d4rl # noqa
128
+
129
+ except ModuleNotFoundError as err:
130
+ cls.D4RL_ERR = err
131
+ except Exception:
132
+ pass
133
+
134
+ def __init__(
135
+ self,
136
+ dataset_id,
137
+ batch_size: int,
138
+ sampler: Sampler | None = None,
139
+ writer: Writer | None = None,
140
+ collate_fn: Callable | None = None,
141
+ pin_memory: bool = False,
142
+ prefetch: int | None = None,
143
+ transform: torchrl.envs.Transform | None = None, # noqa-F821
144
+ split_trajs: bool = False,
145
+ from_env: bool = False,
146
+ use_truncated_as_done: bool = True,
147
+ direct_download: bool | None = None,
148
+ terminate_on_end: bool | None = None,
149
+ download: bool = True,
150
+ root: str | Path | None = None,
151
+ **env_kwargs,
152
+ ):
153
+ self.use_truncated_as_done = use_truncated_as_done
154
+ if root is None:
155
+ root = _get_root_dir("d4rl")
156
+ self.root = Path(root)
157
+ self.dataset_id = dataset_id
158
+
159
+ if not from_env and direct_download is None:
160
+ self._import_d4rl()
161
+ direct_download = not self._has_d4rl
162
+
163
+ if not direct_download:
164
+ warnings.warn(
165
+ "You are using the D4RL library for collecting data. "
166
+ "We advise against this use, as D4RL formatting can be "
167
+ "inconsistent. "
168
+ "To download the D4RL data without the D4RL library, use "
169
+ "direct_download=True in the dataset constructor. "
170
+ "Recurring to `direct_download=False` will soon be deprecated."
171
+ )
172
+ self.from_env = from_env
173
+ else:
174
+ self.from_env = from_env
175
+
176
+ if (download == "force") or (download and not self._is_downloaded()):
177
+ if download == "force" and os.path.exists(self.data_path_root):
178
+ shutil.rmtree(self.data_path_root)
179
+
180
+ if not direct_download:
181
+ if terminate_on_end is None:
182
+ # we use the default of d4rl
183
+ terminate_on_end = False
184
+ self._import_d4rl()
185
+
186
+ if not self._has_d4rl:
187
+ raise ImportError("Could not import d4rl") from self.D4RL_ERR
188
+
189
+ if from_env:
190
+ dataset = self._get_dataset_from_env(dataset_id, env_kwargs)
191
+ else:
192
+ if self.use_truncated_as_done:
193
+ warnings.warn(
194
+ "Using use_truncated_as_done=True + terminate_on_end=True "
195
+ "with from_env=False may not have the intended effect "
196
+ "as the timeouts (truncation) "
197
+ "can be absent from the static dataset."
198
+ )
199
+ env_kwargs.update({"terminate_on_end": terminate_on_end})
200
+ dataset = self._get_dataset_direct(dataset_id, env_kwargs)
201
+ else:
202
+ if terminate_on_end is False:
203
+ raise ValueError(
204
+ "Using terminate_on_end=False is not compatible with direct_download=True."
205
+ )
206
+ dataset = self._get_dataset_direct_download(dataset_id, env_kwargs)
207
+ # Fill unknown next states with 0
208
+ dataset["next", "observation"][dataset["next", "done"].squeeze()] = 0
209
+
210
+ if split_trajs:
211
+ dataset = split_trajectories(dataset)
212
+ dataset["next", "done"][:, -1] = True
213
+
214
+ storage = TensorStorage(dataset.memmap(self._dataset_path))
215
+ elif self._is_downloaded():
216
+ storage = TensorStorage(TensorDict.load_memmap(self._dataset_path))
217
+ else:
218
+ raise RuntimeError(
219
+ f"The dataset could not be found in {self._dataset_path}."
220
+ )
221
+
222
+ if writer is None:
223
+ writer = ImmutableDatasetWriter()
224
+ super().__init__(
225
+ batch_size=batch_size,
226
+ storage=storage,
227
+ sampler=sampler,
228
+ writer=writer,
229
+ collate_fn=collate_fn,
230
+ pin_memory=pin_memory,
231
+ prefetch=prefetch,
232
+ transform=transform,
233
+ )
234
+
235
+ @property
236
+ def data_path(self) -> Path:
237
+ return self._dataset_path
238
+
239
+ @property
240
+ def data_path_root(self) -> Path:
241
+ return self._dataset_path
242
+
243
+ @property
244
+ def _dataset_path(self):
245
+ return Path(self.root) / self.dataset_id
246
+
247
+ def _is_downloaded(self):
248
+ return os.path.exists(self._dataset_path)
249
+
250
+ def _get_dataset_direct_download(self, name, env_kwargs):
251
+ """Directly download and use a D4RL dataset."""
252
+ if env_kwargs:
253
+ raise RuntimeError(
254
+ f"Cannot pass env_kwargs when `direct_download=True`. Got env_kwargs keys: {env_kwargs.keys()}"
255
+ )
256
+ url = D4RL_DATASETS.get(name, None)
257
+ if url is None:
258
+ raise KeyError(f"Env {name} not found.")
259
+ with tempfile.TemporaryDirectory() as tmpdir:
260
+ os.environ["D4RL_DATASET_DIR"] = tmpdir
261
+ h5path = _download_dataset_from_url(url, tmpdir)
262
+ # h5path_parent = Path(h5path).parent
263
+ dataset = PersistentTensorDict.from_h5(h5path)
264
+ dataset = dataset.to_tensordict()
265
+ with dataset.unlock_():
266
+ dataset = self._process_data_from_env(dataset)
267
+ return dataset
268
+
269
    def _get_dataset_direct(self, name, env_kwargs):
        """Build the dataset through ``d4rl.qlearning_dataset``.

        A gym environment is created both to source the transition arrays and
        to provide the specs used for dtype casting. The raw numpy arrays are
        converted into a TensorDict of TorchRL-style transitions with a
        ``("next", ...)`` sub-tensordict.

        Args:
            name (str): registered gym/d4rl environment name.
            env_kwargs (dict): extra keyword arguments forwarded to
                ``d4rl.qlearning_dataset``.

        Returns:
            the processed TensorDict of transitions.
        """
        from torchrl.envs.libs.gym import GymWrapper, set_gym_backend

        type(self)._import_d4rl()

        if not self._has_d4rl:
            raise ImportError("Could not import d4rl") from self.D4RL_ERR
        import d4rl

        # D4RL environments are registered with gym, not gymnasium
        # so we need to ensure we're using the gym backend
        with set_gym_backend("gym"):
            import gym

            env = GymWrapper(gym.make(name))
        with tempfile.TemporaryDirectory() as tmpdir:
            # Redirect d4rl's download cache into a throwaway directory.
            # NOTE(review): the env var still points at the deleted tmpdir
            # once this block exits — confirm nothing reads it afterwards.
            os.environ["D4RL_DATASET_DIR"] = tmpdir
            dataset = d4rl.qlearning_dataset(env._env, **env_kwargs)

        # Keep only the numpy arrays of the raw dataset dict, converted to
        # torch tensors; the leading batch size is inferred automatically.
        dataset = make_tensordict(
            {
                k: torch.from_numpy(item)
                for k, item in dataset.items()
                if isinstance(item, np.ndarray)
            },
            auto_batch_size=True,
        )
        # "a/b" flat keys become nested ("a", "b") entries.
        dataset = dataset.unflatten_keys("/")
        if "metadata" in dataset.keys():
            # Metadata does not share the dataset batch size: stash it aside
            # and re-infer the batch size without it.
            metadata = dataset.get("metadata")
            dataset = dataset.exclude("metadata")
            self.metadata = metadata
            # find batch size
            dataset = make_tensordict(
                dataset.flatten_keys("/").to_dict(), auto_batch_size=True
            )
            dataset = dataset.unflatten_keys("/")
        else:
            self.metadata = {}
        # D4RL -> TorchRL naming conventions.
        dataset.rename_key_("observations", "observation")
        dataset.create_nested("next")
        dataset.rename_key_("next_observations", ("next", "observation"))
        dataset.rename_key_("terminals", "terminated")
        if "timeouts" in dataset.keys():
            dataset.rename_key_("timeouts", "truncated")
        if self.use_truncated_as_done:
            # "done" aggregates natural termination and timeouts.
            done = dataset.get("terminated") | dataset.get("truncated", False)
            dataset.set("done", done)
        else:
            dataset.set("done", dataset.get("terminated"))
        dataset.rename_key_("rewards", "reward")
        dataset.rename_key_("actions", "action")

        # let's make sure that the dtypes match what's expected
        for key, spec in env.observation_spec.items(True, True):
            dataset[key] = dataset[key].to(spec.dtype)
            dataset["next", key] = dataset["next", key].to(spec.dtype)
        dataset["action"] = dataset["action"].to(env.action_spec.dtype)
        dataset["reward"] = dataset["reward"].to(env.reward_spec.dtype)

        # format done etc
        # Boolean flags and reward get a trailing singleton dimension, as
        # TorchRL expects.
        dataset["done"] = dataset["done"].bool().unsqueeze(-1)
        dataset["terminated"] = dataset["terminated"].bool().unsqueeze(-1)
        if "truncated" in dataset.keys():
            dataset["truncated"] = dataset["truncated"].bool().unsqueeze(-1)
        dataset["reward"] = dataset["reward"].unsqueeze(-1)
        # Mirror reward/done flags inside the "next" sub-tensordict.
        dataset["next"].update(
            dataset.select("reward", "done", "terminated", "truncated", strict=False)
        )
        dataset = (
            dataset.clone()
        )  # make sure that all tensors have a different data_ptr
        self._shift_reward_done(dataset)
        self.specs = env.specs.clone()
        return dataset
344
+
345
    def _get_dataset_from_env(self, name, env_kwargs):
        """Creates an environment and retrieves the dataset using env.get_dataset().

        This method does not accept extra arguments.

        Args:
            name (str): registered gym/d4rl environment name.
            env_kwargs (dict): must be empty; raises otherwise.

        Returns:
            the processed TensorDict of transitions (see
            ``_process_data_from_env``).
        """
        if env_kwargs:
            raise RuntimeError("env_kwargs cannot be passed with using from_env=True")
        import d4rl  # noqa: F401

        # we do a local import to avoid circular import issues
        from torchrl.envs.libs.gym import GymWrapper, set_gym_backend

        # D4RL environments are registered with gym, not gymnasium
        # so we need to ensure we're using the gym backend
        with set_gym_backend("gym"), tempfile.TemporaryDirectory() as tmpdir:
            import gym

            # Redirect d4rl's download cache into a throwaway directory.
            # NOTE(review): the env var still points at the deleted tmpdir
            # once this block exits — confirm nothing reads it afterwards.
            os.environ["D4RL_DATASET_DIR"] = tmpdir
            env = GymWrapper(gym.make(name))
            # Keep only the numpy arrays of the raw dataset dict, converted
            # to torch tensors; the batch size is inferred automatically.
            dataset = make_tensordict(
                {
                    k: torch.from_numpy(item)
                    for k, item in env.get_dataset().items()
                    if isinstance(item, np.ndarray)
                },
                auto_batch_size=True,
            )
        # "a/b" flat keys become nested ("a", "b") entries.
        dataset = dataset.unflatten_keys("/")
        dataset = self._process_data_from_env(dataset, env)
        return dataset
376
+
377
    def _process_data_from_env(self, dataset, env=None):
        """Convert a raw D4RL dataset into TorchRL's transition layout.

        Renames keys to TorchRL conventions, builds the ``("next", ...)``
        sub-tensordict, casts dtypes against the env specs (when ``env`` is
        given) and finally shifts reward/done flags one step forward via
        ``_shift_reward_done``.

        Args:
            dataset (TensorDict): raw dataset using D4RL key names
                ("observations", "actions", "rewards", "terminals", ...).
            env (GymWrapper, optional): when provided, its specs drive the
                dtype casting and are stored in ``self.specs``; otherwise
                ``self.specs`` is set to ``None``.

        Returns:
            the processed TensorDict, one element shorter than the input
            (the last row has no successor state).
        """
        if "metadata" in dataset.keys():
            # Metadata does not share the dataset batch size: stash it aside
            # and re-infer the batch size without it.
            metadata = dataset.get("metadata")
            dataset = dataset.exclude("metadata")
            self.metadata = metadata
            # find batch size
            dataset = make_tensordict(
                dataset.flatten_keys("/").to_dict(), auto_batch_size=True
            )
            dataset = dataset.unflatten_keys("/")
        else:
            self.metadata = {}

        # D4RL -> TorchRL naming conventions.
        dataset.rename_key_("observations", "observation")
        dataset.rename_key_("terminals", "terminated")
        if "timeouts" in dataset.keys():
            dataset.rename_key_("timeouts", "truncated")
        if self.use_truncated_as_done:
            # "done" aggregates natural termination and timeouts.
            dataset.set(
                "done",
                dataset.get("terminated") | dataset.get("truncated", False),
            )
        else:
            dataset.set("done", dataset.get("terminated"))

        dataset.rename_key_("rewards", "reward")
        dataset.rename_key_("actions", "action")
        try:
            dataset.rename_key_("infos", "info")
        except KeyError:
            # "infos" is optional in D4RL datasets.
            pass

        # let's make sure that the dtypes match what's expected
        if env is not None:
            for key, spec in env.observation_spec.items(True, True):
                dataset[key] = dataset[key].to(spec.dtype)
            dataset["action"] = dataset["action"].to(env.action_spec.dtype)
            dataset["reward"] = dataset["reward"].to(env.reward_spec.dtype)

        # format done
        # Boolean flags and reward get a trailing singleton dimension, as
        # TorchRL expects.
        dataset["done"] = dataset["done"].bool().unsqueeze(-1)
        dataset["terminated"] = dataset["terminated"].bool().unsqueeze(-1)
        if "truncated" in dataset.keys():
            dataset["truncated"] = dataset["truncated"].bool().unsqueeze(-1)

        dataset["reward"] = dataset["reward"].unsqueeze(-1)
        # Build the "next" sub-tensordict: use explicit "next_observations"
        # when the dataset ships them, otherwise take the following row's
        # observation. Either way the final transition is dropped.
        if "next_observations" in dataset.keys():
            dataset = dataset[:-1].set(
                "next",
                dataset.select("info", strict=False)[1:],
            )
            dataset.rename_key_("next_observations", ("next", "observation"))
        else:
            dataset = dataset[:-1].set(
                "next",
                dataset.select("observation", "info", strict=False)[1:],
            )
        # Mirror reward/done flags inside the "next" sub-tensordict.
        dataset["next"].update(
            dataset.select("reward", "done", "terminated", "truncated", strict=False)
        )
        dataset = (
            dataset.clone()
        )  # make sure that all tensors have a different data_ptr
        self._shift_reward_done(dataset)
        if env is not None:
            self.specs = env.specs.clone()
        else:
            self.specs = None
        return dataset
446
+
447
+ def _shift_reward_done(self, dataset):
448
+ dataset["reward"] = dataset["reward"].clone()
449
+ dataset["reward"][1:] = dataset["reward"][:-1].clone()
450
+ dataset["reward"][0] = 0
451
+ for key in ("done", "terminated", "truncated"):
452
+ if key not in dataset.keys():
453
+ continue
454
+ dataset[key] = dataset[key].clone()
455
+ dataset[key][1:] = dataset[key][:-1].clone()
456
+ dataset[key][0] = 0
457
+
458
+
459
def _download_dataset_from_url(dataset_url, dataset_path):
    """Fetch ``dataset_url`` into ``dataset_path``, skipping cached files.

    Returns the local file path. Raises ``OSError`` if the file is still
    missing after the download attempt.
    """
    local_file = _filepath_from_url(dataset_url, dataset_path)
    if os.path.exists(local_file):
        return local_file
    torchrl_logger.info(f"Downloading dataset: {dataset_url} to {local_file}")
    urllib.request.urlretrieve(dataset_url, local_file)
    if not os.path.exists(local_file):
        raise OSError("Failed to download dataset from %s" % dataset_url)
    return local_file
467
+
468
+
469
+ def _filepath_from_url(dataset_url, dataset_path):
470
+ _, dataset_name = os.path.split(dataset_url)
471
+ dataset_filepath = os.path.join(dataset_path, dataset_name)
472
+ return dataset_filepath
473
+
474
+
475
# Dead code kept for reference. NOTE(review): the commented call below is
# broken — it passes a path to ``os.environ.get`` as the environment-variable
# *name* and supplies no default; it presumably meant something like
# ``os.environ.get("D4RL_DATASET_DIR", _get_root_dir("d4rl"))``.
# def _set_dataset_path(path):
#     global DATASET_PATH
#     DATASET_PATH = path
#     os.makedirs(path, exist_ok=True)
#
#
# _set_dataset_path(
#     os.environ.get(_get_root_dir("d4rl")))
483
+
484
if __name__ == "__main__":
    # Manual smoke test: load a D4RL dataset and log one sampled batch.
    replay = D4RLExperienceReplay("kitchen-partial-v0", batch_size=128)
    torchrl_logger.info(replay)
    for batch in replay:
        torchrl_logger.info(batch)
        break