torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,878 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import functools
8
+ import gzip
9
+ import io
10
+ import json
11
+ import os
12
+ import shutil
13
+ import subprocess
14
+ import tempfile
15
+ from collections import defaultdict
16
+ from collections.abc import Callable
17
+ from pathlib import Path
18
+
19
+ import numpy as np
20
+ import torch
21
+ from tensordict import MemoryMappedTensor, TensorDict, TensorDictBase
22
+ from torch import multiprocessing as mp
23
+ from torchrl._utils import logger as torchrl_logger
24
+ from torchrl.data.datasets.common import BaseDatasetExperienceReplay
25
+ from torchrl.data.replay_buffers.samplers import (
26
+ SamplerWithoutReplacement,
27
+ SliceSampler,
28
+ SliceSamplerWithoutReplacement,
29
+ )
30
+ from torchrl.data.replay_buffers.storages import Storage, TensorStorage
31
+ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter
32
+ from torchrl.data.utils import CloudpickleWrapper
33
+ from torchrl.envs.utils import _classproperty
34
+
35
+
36
+ class AtariDQNExperienceReplay(BaseDatasetExperienceReplay):
37
+ """Atari DQN Experience replay class.
38
+
39
+ The Atari DQN dataset (https://offline-rl.github.io/) is a collection of 5 training
40
+ iterations of DQN over each of the Atari 2600 games for a total of 200 million frames.
41
+ The sub-sampling rate (frame-skip) is equal to 4, meaning that each game dataset
42
+ has 50 million steps in total.
43
+
44
+ The data format follows the :ref:`TED convention <TED-format>`. Since the dataset is quite heavy,
45
+ the data formatting is done on-line, at sampling time.
46
+
47
+ To make training more modular, we split the dataset in each of the Atari games
48
+ and separate each training round. Consequently, each dataset is presented as
49
+ a Storage of length 50x10^6 elements. Under the hood, this dataset is split
50
+ into 50 memory-mapped tensordicts of length 1 million each.
51
+
52
+ Args:
53
+ dataset_id (str): The dataset to be downloaded.
54
+ Must be part of ``AtariDQNExperienceReplay.available_datasets``.
55
+ batch_size (int): Batch-size used during sampling.
56
+ Can be overridden by `data.sample(batch_size)` if necessary.
57
+
58
+ Keyword Args:
59
+ root (Path or str, optional): The AtariDQN dataset root directory.
60
+ The actual dataset memory-mapped files will be saved under
61
+ `<root>/<dataset_id>`. If none is provided, it defaults to
62
+ `~/.cache/torchrl/atari`.
63
+ num_procs (int, optional): number of processes to launch for preprocessing.
64
+ Has no effect whenever the data is already downloaded. Defaults to 0
65
+ (no multiprocessing used).
66
+ download (bool or str, optional): Whether the dataset should be downloaded if
67
+ not found. Defaults to ``True``. Download can also be passed as ``"force"``,
68
+ in which case the downloaded data will be overwritten.
69
+ sampler (Sampler, optional): the sampler to be used. If none is provided
70
+ a default RandomSampler() will be used.
71
+ writer (Writer, optional): the writer to be used. If none is provided
72
+ a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used.
73
+ collate_fn (callable, optional): merges a list of samples to form a
74
+ mini-batch of Tensor(s)/outputs. Used when using batched
75
+ loading from a map-style dataset.
76
+ pin_memory (bool): whether pin_memory() should be called on the rb
77
+ samples.
78
+ prefetch (int, optional): number of next batches to be prefetched
79
+ using multithreading.
80
+ transform (Transform, optional): Transform to be executed when sample() is called.
81
+ To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class.
82
+ num_slices (int, optional): the number of slices to be sampled. The batch-size
83
+ must be greater or equal to the ``num_slices`` argument. Exclusive
84
+ with ``slice_len``. Defaults to ``None`` (no slice sampling).
85
+ The ``sampler`` arg will override this value.
86
+ slice_len (int, optional): the length of the slices to be sampled. The batch-size
87
+ must be greater or equal to the ``slice_len`` argument and divisible
88
+ by it. Exclusive with ``num_slices``. Defaults to ``None`` (no slice sampling).
89
+ The ``sampler`` arg will override this value.
90
+ strict_length (bool, optional): if ``False``, trajectories of length
91
+ shorter than `slice_len` (or `batch_size // num_slices`) will be
92
+ allowed to appear in the batch.
93
+ Be mindful that this can result in effective `batch_size` shorter
94
+ than the one asked for! Trajectories can be split using
95
+ :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
96
+ The ``sampler`` arg will override this value.
97
+ replacement (bool, optional): if ``False``, sampling will occur without replacement.
98
+ The ``sampler`` arg will override this value.
99
+ mp_start_method (str, optional): the start method for multiprocessed
100
+ download. Defaults to ``"fork"``.
101
+
102
+ Attributes:
103
+ available_datasets: list of available datasets, formatted as `<game_name>/<run>`. Example:
104
+ `"Pong/5"`, `"Krull/2"`, ...
105
+ dataset_id (str): the name of the dataset.
106
+ episodes (torch.Tensor): a 1d tensor indicating to what run each of the
107
+ 1M frames belongs. To be used with :class:`~torchrl.data.replay_buffers.SliceSampler`
108
+ to cheaply sample slices of episodes.
109
+
110
+ Examples:
111
+ >>> from torchrl.data.datasets import AtariDQNExperienceReplay
112
+ >>> dataset = AtariDQNExperienceReplay("Pong/5", batch_size=128)
113
+ >>> for data in dataset:
114
+ ... print(data)
115
+ ... break
116
+ TensorDict(
117
+ fields={
118
+ action: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int32, is_shared=False),
119
+ done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
120
+ index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False),
121
+ metadata: NonTensorData(
122
+ data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'}},
123
+ batch_size=torch.Size([128]),
124
+ device=None,
125
+ is_shared=False),
126
+ next: TensorDict(
127
+ fields={
128
+ done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
129
+ observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
130
+ reward: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.float32, is_shared=False),
131
+ terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
132
+ truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)},
133
+ batch_size=torch.Size([128]),
134
+ device=None,
135
+ is_shared=False),
136
+ observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
137
+ terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
138
+ truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)},
139
+ batch_size=torch.Size([128]),
140
+ device=None,
141
+ is_shared=False)
142
+
143
+ .. warning::
144
+ Atari-DQN does not provide the next observation after a termination signal.
145
+ In other words, there is no way to obtain the ``("next", "observation")`` state
146
+ when ``("next", "done")`` is ``True``. This value is filled with 0s but should
147
+ not be used in practice. If TorchRL's value estimators (:class:`~torchrl.objectives.value.ValueEstimatorBase`)
148
+ are used, this should not be an issue.
149
+
150
+ .. note::
151
+ Because the construction of the sampler for episode sampling is slightly
152
+ convoluted, we made it easy for users to pass the arguments of the
153
+ :class:`~torchrl.data.replay_buffers.SliceSampler` directly to the
154
+ ``AtariDQNExperienceReplay`` dataset: any of the ``num_slices`` or
155
+ ``slice_len`` arguments will make the sampler an instance of
156
+ :class:`~torchrl.data.replay_buffers.SliceSampler`. The ``strict_length``
157
+ can also be passed.
158
+
159
+ >>> from torchrl.data.datasets import AtariDQNExperienceReplay
160
+ >>> from torchrl.data.replay_buffers import SliceSampler
161
+ >>> dataset = AtariDQNExperienceReplay("Pong/5", batch_size=128, slice_len=64)
162
+ >>> for data in dataset:
163
+ ... print(data)
164
+ ... print(data.get("index")) # indices are in 2 groups of consecutive values (128 // 64)
165
+ ... break
166
+ TensorDict(
167
+ fields={
168
+ action: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int32, is_shared=False),
169
+ done: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
170
+ index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False),
171
+ metadata: NonTensorData(
172
+ data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'}},
173
+ batch_size=torch.Size([128]),
174
+ device=None,
175
+ is_shared=False),
176
+ next: TensorDict(
177
+ fields={
178
+ done: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False),
179
+ observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
180
+ reward: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.float32, is_shared=False),
181
+ terminated: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False),
182
+ truncated: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
183
+ batch_size=torch.Size([128]),
184
+ device=None,
185
+ is_shared=False),
186
+ observation: Tensor(shape=torch.Size([128, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
187
+ terminated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False),
188
+ truncated: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.uint8, is_shared=False)},
189
+ batch_size=torch.Size([128]),
190
+ device=None,
191
+ is_shared=False)
192
+ tensor([2657628, 2657629, 2657630, 2657631, 2657632, 2657633, 2657634, 2657635,
193
+ 2657636, 2657637, 2657638, 2657639, 2657640, 2657641, 2657642, 2657643,
194
+ 2657644, 2657645, 2657646, 2657647, 2657648, 2657649, 2657650, 2657651,
195
+ 2657652, 2657653, 2657654, 2657655, 2657656, 2657657, 2657658, 2657659,
196
+ 2657660, 2657661, 2657662, 2657663, 2657664, 2657665, 2657666, 2657667,
197
+ 2657668, 2657669, 2657670, 2657671, 2657672, 2657673, 2657674, 2657675,
198
+ 2657676, 2657677, 2657678, 2657679, 2657680, 2657681, 2657682, 2657683,
199
+ 2657684, 2657685, 2657686, 2657687, 2657688, 2657689, 2657690, 2657691,
200
+ 1995687, 1995688, 1995689, 1995690, 1995691, 1995692, 1995693, 1995694,
201
+ 1995695, 1995696, 1995697, 1995698, 1995699, 1995700, 1995701, 1995702,
202
+ 1995703, 1995704, 1995705, 1995706, 1995707, 1995708, 1995709, 1995710,
203
+ 1995711, 1995712, 1995713, 1995714, 1995715, 1995716, 1995717, 1995718,
204
+ 1995719, 1995720, 1995721, 1995722, 1995723, 1995724, 1995725, 1995726,
205
+ 1995727, 1995728, 1995729, 1995730, 1995731, 1995732, 1995733, 1995734,
206
+ 1995735, 1995736, 1995737, 1995738, 1995739, 1995740, 1995741, 1995742,
207
+ 1995743, 1995744, 1995745, 1995746, 1995747, 1995748, 1995749, 1995750])
208
+
209
+ .. note::
210
+ As always, datasets should be composed using :class:`~torchrl.data.replay_buffers.ReplayBufferEnsemble`:
211
+
212
+ >>> from torchrl.data.datasets import AtariDQNExperienceReplay
213
+ >>> from torchrl.data.replay_buffers import ReplayBufferEnsemble
214
+ >>> # we change this parameter for quick experimentation, in practice it should be left untouched
215
+ >>> AtariDQNExperienceReplay._max_runs = 2
216
+ >>> dataset_asterix = AtariDQNExperienceReplay("Asterix/5", batch_size=128, slice_len=64, num_procs=4)
217
+ >>> dataset_pong = AtariDQNExperienceReplay("Pong/5", batch_size=128, slice_len=64, num_procs=4)
218
+ >>> dataset = ReplayBufferEnsemble(dataset_pong, dataset_asterix, batch_size=128, sample_from_all=True)
219
+ >>> sample = dataset.sample()
220
+ >>> print("first sample, Pong", sample[0])
221
+ first sample, Pong TensorDict(
222
+ fields={
223
+ action: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int32, is_shared=False),
224
+ done: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
225
+ index: TensorDict(
226
+ fields={
227
+ buffer_ids: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
228
+ index: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
229
+ batch_size=torch.Size([64]),
230
+ device=None,
231
+ is_shared=False),
232
+ metadata: NonTensorData(
233
+ data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Pong/5'},
234
+ batch_size=torch.Size([64]),
235
+ device=None,
236
+ is_shared=False),
237
+ next: TensorDict(
238
+ fields={
239
+ done: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
240
+ observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
241
+ reward: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
242
+ terminated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
243
+ truncated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
244
+ batch_size=torch.Size([64]),
245
+ device=None,
246
+ is_shared=False),
247
+ observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
248
+ terminated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
249
+ truncated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False)},
250
+ batch_size=torch.Size([64]),
251
+ device=None,
252
+ is_shared=False)
253
+ >>> print("second sample, Asterix", sample[1])
254
+ second sample, Asterix TensorDict(
255
+ fields={
256
+ action: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int32, is_shared=False),
257
+ done: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
258
+ index: TensorDict(
259
+ fields={
260
+ buffer_ids: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False),
261
+ index: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
262
+ batch_size=torch.Size([64]),
263
+ device=None,
264
+ is_shared=False),
265
+ metadata: NonTensorData(
266
+ data={'invalid_range': MemoryMappedTensor([999998, 999999, 0, 1, 2]), 'add_count': MemoryMappedTensor(999999), 'dataset_id': 'Asterix/5'},
267
+ batch_size=torch.Size([64]),
268
+ device=None,
269
+ is_shared=False),
270
+ next: TensorDict(
271
+ fields={
272
+ done: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
273
+ observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
274
+ reward: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
275
+ terminated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
276
+ truncated: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
277
+ batch_size=torch.Size([64]),
278
+ device=None,
279
+ is_shared=False),
280
+ observation: Tensor(shape=torch.Size([64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
281
+ terminated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False),
282
+ truncated: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.uint8, is_shared=False)},
283
+ batch_size=torch.Size([64]),
284
+ device=None,
285
+ is_shared=False)
286
+ >>> print("Aggregate (metadata hidden)", sample)
287
+ Aggregate (metadata hidden) LazyStackedTensorDict(
288
+ fields={
289
+ action: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int32, is_shared=False),
290
+ done: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False),
291
+ index: LazyStackedTensorDict(
292
+ fields={
293
+ buffer_ids: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False),
294
+ index: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False)},
295
+ exclusive_fields={
296
+ },
297
+ batch_size=torch.Size([2, 64]),
298
+ device=None,
299
+ is_shared=False,
300
+ stack_dim=0),
301
+ metadata: LazyStackedTensorDict(
302
+ fields={
303
+ },
304
+ exclusive_fields={
305
+ },
306
+ batch_size=torch.Size([2, 64]),
307
+ device=None,
308
+ is_shared=False,
309
+ stack_dim=0),
310
+ next: LazyStackedTensorDict(
311
+ fields={
312
+ done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
313
+ observation: Tensor(shape=torch.Size([2, 64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
314
+ reward: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False),
315
+ terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
316
+ truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
317
+ exclusive_fields={
318
+ },
319
+ batch_size=torch.Size([2, 64]),
320
+ device=None,
321
+ is_shared=False,
322
+ stack_dim=0),
323
+ observation: Tensor(shape=torch.Size([2, 64, 84, 84]), device=cpu, dtype=torch.uint8, is_shared=False),
324
+ terminated: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False),
325
+ truncated: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.uint8, is_shared=False)},
326
+ exclusive_fields={
327
+ },
328
+ batch_size=torch.Size([2, 64]),
329
+ device=None,
330
+ is_shared=False,
331
+ stack_dim=0)
332
+
333
+ """
334
+
335
@_classproperty
def available_datasets(cls):
    """Return every accepted ``dataset_id``.

    Ids have the form ``"<game>/<run>"`` where ``<game>`` is one of the Atari
    2600 games listed below and ``<run>`` is an integer between 1 and 5.
    The list is ordered game by game, then by increasing run number.
    """
    games = (
        "AirRaid", "Alien", "Amidar", "Assault", "Asterix",
        "Asteroids", "Atlantis", "BankHeist", "BattleZone", "BeamRider",
        "Berzerk", "Bowling", "Boxing", "Breakout", "Carnival",
        "Centipede", "ChopperCommand", "CrazyClimber", "DemonAttack", "DoubleDunk",
        "ElevatorAction", "Enduro", "FishingDerby", "Freeway", "Frostbite",
        "Gopher", "Gravitar", "Hero", "IceHockey", "Jamesbond",
        "JourneyEscape", "Kangaroo", "Krull", "KungFuMaster", "MontezumaRevenge",
        "MsPacman", "NameThisGame", "Phoenix", "Pitfall", "Pong",
        "Pooyan", "PrivateEye", "Qbert", "Riverraid", "RoadRunner",
        "Robotank", "Seaquest", "Skiing", "Solaris", "SpaceInvaders",
    )
    run_numbers = range(1, 6)
    return [f"{game}/{run}" for game in games for run in run_numbers]
390
+
391
# Optional directory used to stage the raw (gzipped) Atari files while they
# are downloaded. When ``None``, a throw-away temporary directory is used.
# Note that each run's staged files are removed once that run is processed.
tmpdir = None
# Debugging knob: runs whose index is >= ``_max_runs`` are skipped, which
# avoids downloading the entire dataset. ``None`` processes every run.
_max_runs = None
395
+
396
def __init__(
    self,
    dataset_id: str,
    batch_size: int | None = None,
    *,
    root: str | Path | None = None,
    download: bool | str = True,
    sampler=None,
    writer=None,
    transform: Transform | None = None,  # noqa: F821
    num_procs: int = 0,
    num_slices: int | None = None,
    slice_len: int | None = None,
    strict_len: bool = True,
    replacement: bool = True,
    mp_start_method: str = "fork",
    **kwargs,
):
    """Build the replay buffer, downloading and preprocessing the data if needed.

    Args:
        dataset_id (str): one of :attr:`available_datasets`, e.g. ``"Pong/5"``.
        batch_size (int, optional): forwarded to the parent constructor.
        root (str or Path, optional): folder under which datasets live. Defaults
            to ``_get_root_dir("atari")``.
        download (bool or "force"): ``"force"`` always re-downloads; a truthy
            value downloads only when :attr:`_is_downloaded` is ``False``.
        sampler: custom sampler. When ``None``, a slice sampler is built if
            ``num_slices`` or ``slice_len`` is given (with or without
            replacement depending on ``replacement``). Otherwise a
            :class:`SamplerWithoutReplacement` is built if ``replacement`` is
            ``False``; in the remaining case ``None`` is passed on to the parent.
        writer: defaults to :class:`ImmutableDatasetWriter`.
        transform: forwarded to the parent constructor.
        num_procs (int): ``0`` processes runs sequentially; otherwise the number
            of worker processes used for downloading/preprocessing.
        num_slices, slice_len: slice-sampler parameters.
        strict_len (bool): forwarded as ``strict_length`` to the slice sampler.
        replacement (bool): whether sampling is done with replacement.
        mp_start_method (str): start method of the multiprocessing context.
        **kwargs: forwarded to the parent constructor.

    Raises:
        ValueError: if ``dataset_id`` is not in :attr:`available_datasets`.
    """
    import warnings

    warnings.warn(
        "This dataset is no longer available. We are working on a fix, or possibly a deprecation.",
        DeprecationWarning,
    )
    if dataset_id not in self.available_datasets:
        raise ValueError(
            f"The dataset_id {dataset_id!r} is not part of the available datasets. "
            "The dataset should be named <game_name>/<run> "
            "where <game_name> is one of the Atari 2600 games and the run is a number between 1 and 5. "
            "The full list of accepted dataset_ids is available under AtariDQNExperienceReplay.available_datasets."
        )
    self.dataset_id = dataset_id
    from torchrl.data.datasets.utils import _get_root_dir

    if root is None:
        root = _get_root_dir("atari")
    self.root = root
    self.num_procs = num_procs
    self.mp_start_method = mp_start_method
    if download == "force" or (download and not self._is_downloaded):
        try:
            self._download_and_preproc()
        except Exception:
            # Do not leave a half-written dataset behind.
            if os.path.exists(self.dataset_path):
                shutil.rmtree(self.dataset_path)
            raise
    if self._downloaded_and_preproc:
        storage = TensorStorage(TensorDict.load_memmap(self.dataset_path))
    else:
        storage = _AtariStorage(self.dataset_path)
    if writer is None:
        writer = ImmutableDatasetWriter()
    if sampler is None:
        if num_slices is not None or slice_len is not None:
            # strict_len used to be accepted but silently ignored.
            if not replacement:
                sampler = SliceSamplerWithoutReplacement(
                    num_slices=num_slices,
                    slice_len=slice_len,
                    trajectories=storage.episodes,
                    strict_length=strict_len,
                )
            else:
                sampler = SliceSampler(
                    num_slices=num_slices,
                    slice_len=slice_len,
                    trajectories=storage.episodes,
                    cache_values=True,
                    strict_length=strict_len,
                )
        elif not replacement:
            sampler = SamplerWithoutReplacement()

    super().__init__(
        storage=storage,
        batch_size=batch_size,
        writer=writer,
        sampler=sampler,
        collate_fn=lambda x: x,
        transform=transform,
        **kwargs,
    )
475
+
476
@property
def episodes(self):
    """Forward the ``episodes`` attribute of the underlying storage."""
    storage = self._storage
    return storage.episodes
479
+
480
@property
def root(self) -> Path:
    """Folder under which datasets are stored, always as a :class:`~pathlib.Path`."""
    return self._root

@root.setter
def root(self, value):
    # Normalise any path-like input (str, Path, ...) so that ``/`` joins work.
    normalized = Path(value)
    self._root = normalized
487
+
488
@property
def dataset_path(self) -> Path:
    """Location of this dataset on disk, i.e. ``<root>/<dataset_id>``."""
    return self._root.joinpath(self.dataset_id)
491
+
492
@property
def _downloaded_and_preproc(self):
    """Whether ``meta.json`` exists in the dataset folder.

    When it does, the constructor loads the folder directly as a
    memory-mapped TensorDict instead of building an ``_AtariStorage``.
    """
    meta_file = self.dataset_path / "meta.json"
    return os.path.exists(meta_file)
495
+
496
@property
def _is_downloaded(self):
    """Tell whether the data on disk can be reused without downloading again.

    Returns ``True`` when ``meta.json`` exists, or when ``processed.json``
    records the same ``_max_runs`` value as the current one. A missing,
    unreadable or malformed ``processed.json`` (for instance truncated by an
    interrupted write) yields ``False``, so the data is downloaded again
    instead of the constructor crashing.
    """
    dataset_path = self.dataset_path
    if os.path.exists(dataset_path / "meta.json"):
        return True
    processed_file = dataset_path / "processed.json"
    if not os.path.exists(processed_file):
        return False
    try:
        with open(processed_file, encoding="utf-8") as jsonfile:
            processed = json.load(jsonfile)
    except json.JSONDecodeError:
        return False
    if not isinstance(processed, dict):
        return False
    # processed.json stores _max_runs so that changing it forces a re-download.
    return processed.get("processed", False) == self._max_runs
504
+
505
def _download_and_preproc(self):
    """Download every run of ``self.dataset_id`` with ``gsutil`` and preprocess it.

    Any existing content of :attr:`dataset_path` is deleted first. Raw files
    are staged in :attr:`tmpdir` when it is set, otherwise in a temporary
    directory. Runs are processed sequentially when ``num_procs == 0`` and
    with a pool of ``num_procs`` worker processes otherwise. On success a
    ``processed.json`` file recording ``_max_runs`` is written so that
    :attr:`_is_downloaded` can recognise a compatible local copy.

    Raises:
        RuntimeError: if ``gsutil`` cannot be run, or if no remote file is listed.
    """
    torchrl_logger.info(
        f"Downloading and preprocessing dataset {self.dataset_id} with {self.num_procs} processes. This may take a while..."
    )
    if os.path.exists(self.dataset_path):
        shutil.rmtree(self.dataset_path)
    with tempfile.TemporaryDirectory() as tempdir:
        if self.tmpdir is not None:
            tempdir = self.tmpdir
            # A user-provided staging folder may not exist yet: create it
            # (listing it first would raise FileNotFoundError).
            os.makedirs(tempdir, exist_ok=True)
        # Check gsutil is usable. A missing executable raises FileNotFoundError,
        # a broken installation raises CalledProcessError.
        try:
            subprocess.run(["gsutil", "version"], check=True, capture_output=True)
        except (subprocess.CalledProcessError, FileNotFoundError) as err:
            raise RuntimeError("gsutil is not installed or not found in PATH.") from err
        # Get the list of remote files (argument list: no shell involved).
        output = subprocess.run(
            [
                "gsutil",
                "-m",
                "ls",
                "-R",
                f"gs://atari-replay-datasets/dqn/{self.dataset_id}/replay_logs",
            ],
            capture_output=True,
        )
        files = [
            # "$" is escaped because _download_and_proc_split interpolates
            # these names into a shell command.
            file.decode("utf-8").replace("$", r"\$")  # noqa: W605
            for file in output.stdout.splitlines()
            if file.endswith(b".gz")
        ]
        self.remote_gz_files = self._list_runs(None, files)
        remote_gz_files = list(self.remote_gz_files)
        if not remote_gz_files:
            raise RuntimeError("No files in file list.")

        # Highest run index; only used in progress log messages.
        total_runs = remote_gz_files[-1]
        if self.num_procs == 0:
            for run, run_files in self.remote_gz_files.items():
                self._download_and_proc_split(
                    run,
                    run_files,
                    tempdir=tempdir,
                    dataset_path=self.dataset_path,
                    total_episodes=total_runs,
                    max_runs=self._max_runs,
                    multithreaded=True,
                )
        else:
            func = functools.partial(
                self._download_and_proc_split,
                tempdir=tempdir,
                dataset_path=self.dataset_path,
                total_episodes=total_runs,
                max_runs=self._max_runs,
                multithreaded=False,
            )
            args = list(self.remote_gz_files.items())
            ctx = mp.get_context(self.mp_start_method)
            with ctx.Pool(self.num_procs) as pool:
                pool.starmap(func, args)
    # The folder may not exist if no run was processed (e.g. _max_runs == 0).
    os.makedirs(self.dataset_path, exist_ok=True)
    with open(self.dataset_path / "processed.json", "w") as file:
        # we save self._max_runs such that changing the number of runs to process
        # forces the data to be re-downloaded
        json.dump({"processed": self._max_runs}, file)
569
+
570
@classmethod
def _download_and_proc_split(
    cls,
    run,
    run_files,
    *,
    tempdir,
    dataset_path,
    total_episodes,
    max_runs,
    multithreaded=True,
):
    """Download the files of one run and convert them to memory-mapped tensors.

    Args:
        run (int): index of the run.
        run_files (list of str): remote paths of the run's ``.gz`` files, already
            shell-escaped because they are joined into a shell command.
        tempdir: staging folder; files go to ``tempdir/<run>``, which is
            removed once the run is processed.
        dataset_path (Path): output folder; the run is written under
            ``dataset_path/<run>``.
        total_episodes: value reported in the completion log message.
        max_runs (int or None): runs with index ``>= max_runs`` are skipped.
        multithreaded (bool): whether ``gsutil cp`` is called with ``-m``.

    Raises:
        subprocess.CalledProcessError: if the download command fails.
    """
    if (max_runs is not None) and (run >= max_runs):
        return
    tempdir = Path(tempdir)
    run_tempdir = tempdir / str(run)
    # exist_ok: a user-provided staging folder can hold leftovers from an
    # interrupted attempt.
    os.makedirs(run_tempdir, exist_ok=True)
    files_str = " ".join(run_files)
    torchrl_logger.info(f"downloading {files_str}")
    if multithreaded:
        command = f"gsutil -m cp {files_str} {tempdir}/{run}"
    else:
        command = f"gsutil cp {files_str} {tempdir}/{run}"
    # check=True: a failed download must not end up marked as processed.
    subprocess.run(command, shell=True, check=True)
    local_gz_files = cls._list_runs(run_tempdir)
    # The staging folder only holds this run, so this loop has one iteration.
    for local_run in local_gz_files:
        path = dataset_path / str(local_run)
        try:
            cls._preproc_run(path, local_gz_files, local_run)
        except Exception:
            # The folder may not exist if preprocessing failed early; never
            # let the cleanup mask the original exception.
            shutil.rmtree(path, ignore_errors=True)
            raise
    shutil.rmtree(run_tempdir)
    torchrl_logger.info(f"Concluded run {run} out of {total_episodes}")
606
+
607
@classmethod
def _preproc_run(cls, path, gz_files, run):
    """Turn the gzipped numpy arrays of one run into memory-mapped tensors.

    Args:
        path: output folder of the run; every file is created below it.
        gz_files (dict): run index -> list of local ``.gz`` files, as returned
            by :meth:`_list_runs`.
        run: key of ``gz_files`` to process.

    Layout written under ``path``:
      * ``data/observation.memmap``: the observations, padded with one extra
        trailing frame that is not written here;
      * ``data/next/{reward,done,terminated}.memmap``;
      * any other array at the location given by its key.
    The two path components above the run folder are joined with ``/`` and
    stored as the non-tensor entry ``"dataset_id"``.
    """
    run_files = gz_files[run]
    out = TensorDict()
    path = Path(path)
    next_step_keys = (
        ("data", "reward"),
        ("data", "done"),
        ("data", "terminated"),
    )
    for gz_path in run_files:
        stem = str(Path(gz_path).parts[-1]).split(".")[0]
        with gzip.GzipFile(gz_path, mode="rb") as handle:
            buffer = io.BytesIO(handle.read())
        array = torch.as_tensor(np.load(buffer))
        key = cls._process_name(stem)
        if key == ("data", "observation"):
            # One spare frame at the end of the buffer.
            padded_shape = [array.shape[0] + 1, *array.shape[1:]]
            target = path / "data" / "observation.memmap"
            os.makedirs(target.parent, exist_ok=True)
            padded = MemoryMappedTensor.empty(
                padded_shape, dtype=array.dtype, filename=target
            )
            padded[:-1].copy_(array)
            out[key] = padded
        elif key in next_step_keys:
            target = path / "data" / "next" / f"{key[-1]}.memmap"
            os.makedirs(target.parent, exist_ok=True)
            out["data", "next", key[1:]] = MemoryMappedTensor.from_tensor(
                array, filename=target
            )
        else:
            target = path.joinpath(*key[:-1], f"{key[-1]}.memmap")
            os.makedirs(target.parent, exist_ok=True)
            out[key] = MemoryMappedTensor.from_tensor(array, filename=target)
    out.set_non_tensor("dataset_id", "/".join(path.parts[-3:-1]))
    out.memmap_(path, copy_existing=False)
651
+
652
@staticmethod
def _process_name(name):
    """Map the stem of a replay-log file name to a TensorDict key.

    A trailing ``"_ckpt"`` is dropped. Names containing ``"store"`` (such as
    ``"$store$_reward_ckpt"``) are placed under ``"data"``, using the token
    after the first underscore. Other names become one-element keys. A final
    ``"terminal"`` entry is renamed ``"terminated"``.

    Examples: ``"$store$_terminal_ckpt"`` -> ``("data", "terminated")``;
    ``"add_count_ckpt"`` -> ``("add_count",)``.
    """
    stem = name.removesuffix("_ckpt")
    key = ("data", stem.split("_")[1]) if "store" in stem else (stem,)
    *prefix, leaf = key
    return (*prefix, "terminated") if leaf == "terminal" else key
663
+
664
@classmethod
def _list_runs(cls, download_path, gz_files=None) -> dict:
    """Group ``.gz`` replay-log files by run index.

    File names are expected to end with ``.<run>.gz``; any extra dots in the
    leading part of the name are tolerated.

    Args:
        download_path: folder scanned recursively for ``.gz`` files when
            ``gz_files`` is not given (ignored otherwise).
        gz_files (list of str, optional): explicit list of file paths or URLs.

    Returns:
        dict mapping each run index (int) to its files, with keys in
        increasing order and files kept in input order.
    """
    if gz_files is None:
        gz_files = [
            os.path.join(root, file)
            for root, _, files in os.walk(download_path)
            for file in files
            if file.endswith(".gz")
        ]
    runs = defaultdict(list)
    for file in gz_files:
        filename = str(Path(file).parts[-1])
        # rsplit so that dots inside the name prefix do not break unpacking.
        _, run, _ = filename.rsplit(".", 2)
        runs[int(run)].append(file)
    return dict(sorted(runs.items()))
680
+
681
def preprocess(
    self,
    fn: Callable[[TensorDictBase], TensorDictBase],
    dim: int = 0,
    num_workers: int | None = None,
    *,
    chunksize: int | None = None,
    num_chunks: int | None = None,
    pool: mp.Pool | None = None,
    generator: torch.Generator | None = None,
    max_tasks_per_child: int | None = None,
    worker_threads: int = 1,
    index_with_generator: bool = False,
    pbar: bool = False,
    mp_start_method: str | None = None,
    dest: str | Path,
    num_frames: int | None = None,
):
    """Apply ``fn`` to the dataset and store the result as memory-mapped tensors in ``dest``.

    Args:
        fn: transform applied to the stored items (without their ``"metadata"``
            entry). Its output on the first item defines the output layout.
        dim (int): only ``0`` is supported.
        num_workers, chunksize, num_chunks, pool, generator, max_tasks_per_child,
        worker_threads, index_with_generator, pbar, mp_start_method: forwarded
            to the TensorDict ``map`` call that runs ``fn``.
        dest (str or Path): folder where the preprocessed data is written.
        num_frames (int, optional): number of leading items to process;
            defaults to ``len(self)``.

    Returns:
        a ``TensorStorage`` over the preprocessed data, with the metadata of
        the first item set under ``"metadata"``.

    Raises:
        RuntimeError: if ``dim != 0``.
    """
    # Fail fast, before allocating the (potentially huge) output memmap.
    if dim != 0:
        raise RuntimeError("dim != 0 is not supported.")
    # The output is written to ``dest`` so that the returned storage is backed
    # by files that persist after this call returns.
    dest = Path(dest)
    os.makedirs(dest, exist_ok=True)

    first_item = self[0]
    metadata = first_item.pop("metadata")

    mmap = fn(first_item)
    if num_frames is None:
        num_frames = len(self)
    mmap = mmap.expand(num_frames, *first_item.shape)
    mmap = mmap.memmap_like(dest, num_threads=32)
    with mmap.unlock_():
        # Each chunk handled by ``func`` reads which dataset items it covers here.
        mmap["_indices"] = torch.arange(mmap.shape[0])
    mmap.memmap_(dest, num_threads=32)

    def func(mmap: TensorDictBase):
        idx = mmap["_indices"]
        orig = self[idx].exclude("metadata")
        orig = fn(orig)
        # In-place update of the memory-mapped chunk.
        mmap.update(orig, inplace=True)
        return

    mmap.map(
        fn=CloudpickleWrapper(func),
        dim=dim,
        num_workers=num_workers,
        chunksize=chunksize,
        num_chunks=num_chunks,
        pool=pool,
        generator=generator,
        max_tasks_per_child=max_tasks_per_child,
        worker_threads=worker_threads,
        index_with_generator=index_with_generator,
        mp_start_method=mp_start_method,
        pbar=pbar,
    )

    with mmap.unlock_():
        return TensorStorage(mmap.set("metadata", metadata))
740
+
741
+
742
+ class _AtariStorage(Storage):
743
def __init__(self, path):
    """Index the preprocessed runs stored under ``path``.

    Every sub-folder of ``path`` must be named after an integer run index and
    contain a memory-mapped TensorDict, as written by ``_preproc_run``.

    Attributes set:
        splits: sorted run indices.
        frames_per_split: ``(num_splits + 1, 2)`` integer tensor. Row 0 is
            ``[-1, 0]``; row ``i + 1`` is ``[splits[i], number of frames in
            splits[0..i]]`` (a cumulative count).
        episodes: episode index of every transition, over all splits.

    Raises:
        RuntimeError: if ``path`` contains no run folder.
    """
    self.path = Path(path)
    self.splits = sorted(
        int(Path(name).parts[-1])
        for name in os.listdir(path)
        if os.path.isdir(os.path.join(path, name))
    )
    if not self.splits:
        raise RuntimeError(f"No preprocessed run was found in {self.path}.")
    self._split_tds = []
    frames_per_split = {}
    for split in self.splits:
        self._split_tds.append(self._load_split(self.path / str(split)))
        # take away 1 because observations are padded with 1 empty frame
        frames_per_split[split] = (
            self._split_tds[-1].get(("data", "observation")).shape[0] - 1
        )

    frames_per_split = torch.tensor(
        [[split, length] for (split, length) in frames_per_split.items()]
    )
    frames_per_split[:, 1] = frames_per_split[:, 1].cumsum(0)
    self.frames_per_split = torch.cat(
        [torch.tensor([[-1, 0]]), frames_per_split],
        0,
    )

    # Episode index of each transition. A terminal transition is the *last*
    # step of its episode, so the counter must only increase after it: use an
    # exclusive cumulative sum (an inclusive one would put every terminal step
    # in the following episode).
    terminated = torch.cat(
        [td.get(("data", "next", "terminated")) for td in self._split_tds], 0
    ).to(torch.int64)
    self.episodes = torch.cumsum(terminated, 0) - terminated
    super().__init__(max_size=len(self))
787
+
788
def __len__(self):
    """Total number of transitions: the last cumulative count of ``frames_per_split``."""
    cumulative_counts = self.frames_per_split[:, 1]
    return int(cumulative_counts[-1])
790
+
791
def _read_from_splits(self, item: int | torch.Tensor):
    """Fetch one or several transitions given their global index.

    Global indices run over the concatenation of all splits. The cumulative
    counts in ``self.frames_per_split`` give, for each index, the split it
    belongs to and its offset inside that split.

    Args:
        item: an int (or 0-d tensor), or a 1-d tensor of global indices.

    Returns:
        for a scalar index, the output of :meth:`_proc_td`; otherwise the
        per-split outputs concatenated along dim 0, in the order of ``item``.

    Raises:
        IndexError: if an index is negative or ``>= len(self)``.
    """
    frames_per_split = self.frames_per_split
    item = torch.as_tensor(item)
    is_int = item.ndim == 0
    item = item.reshape(-1).to(frames_per_split.dtype)
    total = frames_per_split[-1, 1]
    if item.numel() and ((item < 0).any() or (item >= total).any()):
        raise IndexError(f"Index out of range for a storage of length {int(total)}.")
    # The row r of a global index satisfies
    #   frames_per_split[r, 1] <= index < frames_per_split[r + 1, 1].
    # Working with rows (not run ids) keeps this valid for non-contiguous run
    # ids, since rows line up with ``self._split_tds``.
    rows = torch.searchsorted(frames_per_split[1:, 1].contiguous(), item, right=True)
    local = item - frames_per_split[rows, 1]
    if is_int:
        return self._proc_td(self._split_tds[int(rows[0])], local.squeeze(0))
    if not item.numel():
        return self._proc_td(self._split_tds[0], local)
    unique_rows, inverse = torch.unique(rows, return_inverse=True)
    if unique_rows.numel() == 1:
        # Everything comes from one split: order is already preserved.
        return self._proc_td(self._split_tds[int(unique_rows[0])], local)
    chunks = []
    positions = []
    for i, row in enumerate(unique_rows.tolist()):
        mask = inverse == i
        chunks.append(self._proc_td(self._split_tds[row], local[mask]))
        positions.append(mask.nonzero().squeeze(-1))
    # Chunks are grouped by split; undo this grouping so that the output
    # follows the order of the requested indices.
    order = torch.argsort(torch.cat(positions))
    return torch.cat(chunks, 0)[order]
825
+
826
def _load_split(self, path):
    """Load the memory-mapped TensorDict of one split from ``path``."""
    split_td = TensorDict.load_memmap(path)
    return split_td
828
+
829
def _proc_td(self, td, index):
    """Build the transition(s) at ``index`` (local to one split).

    Args:
        td: TensorDict of one split, with its arrays under ``"data"`` and
            run-level entries at the root.
        index: 0-d or 1-d integer tensor of positions inside the split.

    Returns:
        a TensorDict holding the entries of ``td["data"]`` at ``index``, plus:
        ``("next", "observation")`` (the following frame, zeroed where the
        transition is terminal); ``("next", "done")`` (a copy of
        ``("next", "terminated")``); zero-filled ``("next", "truncated")``,
        ``"terminated"``, ``"done"`` and ``"truncated"``; and the root-level
        entries of ``td`` as the non-tensor ``"metadata"``.
    """
    td_data = td.get("data")
    batch_ndim = index.ndim if isinstance(index, torch.Tensor) else 0
    # Observations are padded with one trailing frame, so index + 1 is valid.
    obs_ = td_data.get("observation")[index + 1]
    # Flatten the flags to the batch shape of obs_ whatever their trailing
    # dims, so scalar, length-1 and larger batches are all handled alike.
    done = (
        td_data.get(("next", "terminated"))[index]
        .bool()
        .reshape(obs_.shape[:batch_ndim])
    )
    if done.any():
        # Out-of-place on purpose: obs_ may be a view on the memory-mapped data.
        mask = done.reshape(*done.shape, *((1,) * (obs_.ndim - batch_ndim)))
        obs_ = obs_.masked_fill(mask, 0)
    td_idx = td.empty()
    td_idx.set(("next", "observation"), obs_)
    non_tensor = td.exclude("data").to_dict()
    td_idx.update(td_data.apply(lambda x: x[index]))
    if isinstance(index, torch.Tensor) and index.ndim:
        td_idx.batch_size = [len(index)]
    td_idx.set_non_tensor("metadata", non_tensor)

    terminated = td_idx.get(("next", "terminated"))
    zterminated = torch.zeros_like(terminated)
    td_idx.set(("next", "done"), terminated.clone())
    td_idx.set(("next", "truncated"), zterminated)
    td_idx.set("terminated", zterminated)
    td_idx.set("done", zterminated)
    td_idx.set("truncated", zterminated)

    return td_idx
852
+
853
def get(self, index):
    """Read item(s) from the storage.

    Supported indices:
      * ``int``, and 0-d or 1-d tensors;
      * 2-d tensors with a single column (squeezed to 1-d);
      * ``list`` and ``range``;
      * ``slice``: bounds and steps are resolved against ``len(self)`` like
        for Python sequences (negative values, clipping);
      * tuples: the first element selects items, the rest index the result.
    Any other index is applied to ``torch.arange(len(self))``.

    Raises:
        RuntimeError: for tensors with more than one column.
    """
    if isinstance(index, int):
        return self._read_from_splits(index)
    if isinstance(index, tuple):
        if len(index) == 1:
            return self.get(index[0])
        return self.get(index[0])[(Ellipsis, *index[1:])]
    if isinstance(index, torch.Tensor):
        if index.ndim <= 1:
            return self._read_from_splits(index)
        elif index.shape[1] == 1:
            index = index.squeeze(1)
            return self.get(index)
        else:
            raise RuntimeError("Only 1d tensors are accepted")
    if isinstance(index, (range, list)):
        return self[torch.tensor(index)]
    if isinstance(index, slice):
        # slice.indices resolves None/negative bounds and clips them to the
        # storage length, exactly like Python sequences do.
        start, stop, step = index.indices(len(self))
        return self.get(torch.arange(start, stop, step))
    return self[torch.arange(len(self))[index]]