torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/datasets/openx.py
@@ -0,0 +1,798 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import importlib.util
+ import io
+ import json
+ import os
+ import shutil
+ import tempfile
+ from collections.abc import Callable
+ from pathlib import Path
+ from typing import Any
+
+ import torch
+ from tensordict import make_tensordict, NonTensorData, pad, TensorDict
+ from tensordict.utils import _is_non_tensor
+
+ from torchrl.data.datasets.common import BaseDatasetExperienceReplay
+ from torchrl.data.datasets.utils import _get_root_dir
+ from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
+ from torchrl.data.replay_buffers.samplers import (
+     Sampler,
+     SliceSampler,
+     SliceSamplerWithoutReplacement,
+ )
+ from torchrl.data.replay_buffers.storages import _collate_id, Storage, TensorStorage
+ from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
+
+ _has_datasets = importlib.util.find_spec("datasets", None) is not None
+ _has_tv = importlib.util.find_spec("torchvision", None) is not None
+
+
+ class OpenXExperienceReplay(BaseDatasetExperienceReplay):
+     """Open X-Embodiment datasets experience replay.
+
+     The Open X-Embodiment Dataset contains 1M+ real robot trajectories
+     spanning 22 robot embodiments, collected through a collaboration between
+     21 institutions, demonstrating 527 skills (160266 tasks).
+
+     Website: https://robotics-transformer-x.github.io/
+
+     GitHub: https://github.com/google-deepmind/open_x_embodiment
+
+     Paper: https://arxiv.org/abs/2310.08864
+
+     The data format follows the :ref:`TED convention <TED-format>`.
+
+     .. note::
+         Non-tensor data will be written in the tensordict data using the
+         :class:`~tensordict.tensorclass.NonTensorData` primitive.
+         For instance, the `language_instruction` field in the data will
+         be stored in `data.get_non_tensor("language_instruction")` (or equivalently
+         `data.get("language_instruction").data`). See the documentation of this
+         class for more information on how to interact with non-tensor data
+         stored in a :class:`~tensordict.TensorDict`.
+
+     Args:
+         dataset_id (str): The dataset to be downloaded.
+             Must be part of ``OpenXExperienceReplay.available_datasets``.
+         batch_size (int): Batch-size used during sampling.
+             Can be overridden by `data.sample(batch_size)` if necessary.
+             See ``num_slices`` and ``slice_len`` keyword arguments for a refined
+             sampling strategy.
+             If the ``batch_size`` is ``None`` (default), iterating over the
+             dataset will deliver trajectories one at a time *whereas* calling
+             :meth:`sample` will *still* require a batch-size to be provided.
+
+     Keyword Args:
+         shuffle (bool, optional): if ``True``, trajectories are delivered in a
+             random order when the dataset is iterated over.
+             If ``False``, the dataset is iterated over in the pre-defined order.
+
+             .. warning::
+                 shuffle=False will also impact the sampling. We advise users to
+                 create a copy of the dataset where the ``shuffle`` attribute of the
+                 sampler is set to ``False`` if they wish to enjoy the two different
+                 behaviors (shuffled and not) within the same code base.
+
+         num_slices (int, optional): the number of slices in a batch. This
+             corresponds to the number of trajectories present in a batch.
+             Once collected, the batch is presented as a concatenation of
+             sub-trajectories that can be recovered through `batch.reshape(num_slices, -1)`.
+             The `batch_size` must be divisible by `num_slices` if provided.
+             This argument is exclusive with ``slice_len``.
+             If the ``num_slices`` argument equals the ``batch_size``, each sample
+             will belong to a different trajectory.
+             If neither ``slice_len`` nor ``num_slices`` are provided:
+             whenever a trajectory has a length shorter than the
+             batch-size, a contiguous slice of it of length `batch_size` will be
+             sampled. If the trajectory length is insufficient, an exception will
+             be raised unless `pad` is not `None`.
+         slice_len (int, optional): the length of slices in a batch. This
+             corresponds to the length of trajectories present in a batch.
+             Once collected, the batch is presented as a concatenation of
+             sub-trajectories that can be recovered through `batch.reshape(-1, slice_len)`.
+             The `batch_size` must be divisible by `slice_len` if provided.
+             This argument is exclusive with ``num_slices``.
+             If the ``slice_len`` argument equals ``1``, each sample
+             will belong to a different trajectory.
+             If neither ``slice_len`` nor ``num_slices`` are provided:
+             whenever a trajectory has a length shorter than the
+             batch-size, a contiguous slice of it of length `batch_size` will be
+             sampled. If the trajectory length is insufficient, an exception will
+             be raised unless `pad` is not `None`.
+
+             .. note::
+                 The ``slice_len`` (but not ``num_slices``) can be used when
+                 iterating over a dataset without passing a batch-size in the
+                 constructor. In these cases, a random sub-sequence of the
+                 trajectory will be chosen.
+
+         replacement (bool, optional): if ``False``, sampling will be done
+             without replacement. Defaults to ``True`` for downloaded datasets,
+             ``False`` for streamed datasets.
+         pad (bool, :obj:`float` or None): if ``True``, trajectories of insufficient length
+             given the `slice_len` or `num_slices` arguments will be padded with
+             0s. If another value is provided, it will be used for padding. If
+             ``False`` or ``None`` (default) any encounter with a trajectory of
+             insufficient length will raise an exception.
+         root (Path or str, optional): The OpenX dataset root directory.
+             The actual dataset memory-mapped files will be saved under
+             `<root>/<dataset_id>`. If none is provided, it defaults to
+             ``~/.cache/torchrl/openx``.
+         streaming (bool, optional): if ``True``, the data won't be downloaded but
+             read from a stream instead.
+
+             .. note:: The formatting of the data **will change** when `download=True`
+                 compared to `streaming=True`. If the data is downloaded and
+                 the sampler is left untouched (i.e., `num_slices=None`, `slice_len=None`
+                 and `sampler=None`), transitions will be sampled randomly from
+                 the dataset. This isn't possible at a reasonable cost with
+                 `streaming=True`: in this case, trajectories will be sampled
+                 one at a time and delivered as such (with cropping to comply with
+                 the batch-size etc.). The behavior of the two modalities is
+                 much more similar when `num_slices` and `slice_len` are specified,
+                 as in these cases, views of sub-episodes will be returned in both
+                 cases.
+
+         download (bool or str, optional): Whether the dataset should be downloaded if
+             not found. Defaults to ``True`` when ``streaming=False`` and to ``False``
+             otherwise. Download can also be passed as ``"force"``, in which case the
+             downloaded data will be overwritten.
+         sampler (Sampler, optional): the sampler to be used. If none is provided
+             a default RandomSampler() will be used.
+         writer (Writer, optional): the writer to be used. If none is provided
+             a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used.
+         collate_fn (callable, optional): merges a list of samples to form a
+             mini-batch of Tensor(s)/outputs. Used when using batched
+             loading from a map-style dataset.
+         pin_memory (bool): whether pin_memory() should be called on the rb
+             samples.
+         prefetch (int, optional): number of next batches to be prefetched
+             using multithreading.
+         transform (Transform, optional): Transform to be executed when sample() is called.
+             To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class.
+         split_trajs (bool, optional): if ``True``, the trajectories will be split
+             along the first dimension and padded to have a matching shape.
+             To split the trajectories, the ``"done"`` signal will be used, which
+             is recovered via ``done = truncated | terminated``. In other words,
+             it is assumed that any ``truncated`` or ``terminated`` signal is
+             equivalent to the end of a trajectory.
+             Defaults to ``False``.
+         strict_length (bool, optional): if ``False``, trajectories of length
+             shorter than `slice_len` (or `batch_size // num_slices`) will be
+             allowed to appear in the batch.
+             Be mindful that this can result in effective `batch_size` shorter
+             than the one asked for! Trajectories can be split using
+             :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
+
+     Examples:
+         >>> from torchrl.data.datasets import OpenXExperienceReplay
+         >>> import tempfile
+         >>> # Download the data, and sample 128 elements in each batch out of two trajectories
+         >>> num_slices = 2
+         >>> with tempfile.TemporaryDirectory() as root:
+         ...     dataset = OpenXExperienceReplay("cmu_stretch", batch_size=128,
+         ...         num_slices=num_slices, download=True, streaming=False,
+         ...         root=root,
+         ...     )
+         ...     for batch in dataset:
+         ...         print(batch.reshape(num_slices, -1))
+         ...         break
+         TensorDict(
+             fields={
+                 action: Tensor(shape=torch.Size([2, 64, 8]), device=cpu, dtype=torch.float64, is_shared=False),
+                 discount: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False),
+                 done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 episode: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int32, is_shared=False),
+                 index: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False),
+                 is_init: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.bool, is_shared=False),
+                 language_embedding: Tensor(shape=torch.Size([2, 64, 512]), device=cpu, dtype=torch.float64, is_shared=False),
+                 language_instruction: NonTensorData(
+                     data='lift open green garbage can lid',
+                     batch_size=torch.Size([2, 64]),
+                     device=cpu,
+                     is_shared=False),
+                 next: TensorDict(
+                     fields={
+                         done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                         observation: TensorDict(
+                             fields={
+                                 image: Tensor(shape=torch.Size([2, 64, 3, 128, 128]), device=cpu, dtype=torch.uint8, is_shared=False),
+                                 state: Tensor(shape=torch.Size([2, 64, 4]), device=cpu, dtype=torch.float64, is_shared=False)},
+                             batch_size=torch.Size([2, 64]),
+                             device=cpu,
+                             is_shared=False),
+                         reward: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                         terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                         truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                     batch_size=torch.Size([2, 64]),
+                     device=cpu,
+                     is_shared=False),
+                 observation: TensorDict(
+                     fields={
+                         image: Tensor(shape=torch.Size([2, 64, 3, 128, 128]), device=cpu, dtype=torch.uint8, is_shared=False),
+                         state: Tensor(shape=torch.Size([2, 64, 4]), device=cpu, dtype=torch.float64, is_shared=False)},
+                     batch_size=torch.Size([2, 64]),
+                     device=cpu,
+                     is_shared=False),
+                 terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+             batch_size=torch.Size([2, 64]),
+             device=cpu,
+             is_shared=False)
+         >>> # Read data from a stream. Deliver entire trajectories when iterating
+         >>> dataset = OpenXExperienceReplay("cmu_stretch",
+         ...     num_slices=num_slices, download=False, streaming=True)
+         >>> for data in dataset:  # data does not have a consistent shape
+         ...     break
+         >>> # Define batch-size dynamically
+         >>> data = dataset.sample(128)  # delivers 2 sub-trajectories of length 64
+
+     """
+
+     available_datasets = [
+         "fractal20220817_data",
+         "kuka",
+         "bridge",
+         "taco_play",
+         "jaco_play",
+         "berkeley_cable_routing",
+         "roboturk",
+         "nyu_door_opening_surprising_effectiveness",
+         "viola",
+         "berkeley_autolab_ur5",
+         "toto",
+         "language_table",
+         "columbia_cairlab_pusht_real",
+         "stanford_kuka_multimodal_dataset_converted_externally_to_rlds",
+         "nyu_rot_dataset_converted_externally_to_rlds",
+         "stanford_hydra_dataset_converted_externally_to_rlds",
+         "austin_buds_dataset_converted_externally_to_rlds",
+         "nyu_franka_play_dataset_converted_externally_to_rlds",
+         "maniskill_dataset_converted_externally_to_rlds",
+         "furniture_bench_dataset_converted_externally_to_rlds",
+         "cmu_franka_exploration_dataset_converted_externally_to_rlds",
+         "ucsd_kitchen_dataset_converted_externally_to_rlds",
+         "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
+         "austin_sailor_dataset_converted_externally_to_rlds",
+         "austin_sirius_dataset_converted_externally_to_rlds",
+         "bc_z",
+         "usc_cloth_sim_converted_externally_to_rlds",
+         "utokyo_pr2_opening_fridge_converted_externally_to_rlds",
+         "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds",
+         "utokyo_saytap_converted_externally_to_rlds",
+         "utokyo_xarm_pick_and_place_converted_externally_to_rlds",
+         "utokyo_xarm_bimanual_converted_externally_to_rlds",
+         "robo_net",
+         "berkeley_mvp_converted_externally_to_rlds",
+         "berkeley_rpt_converted_externally_to_rlds",
+         "kaist_nonprehensile_converted_externally_to_rlds",
+         "stanford_mask_vit_converted_externally_to_rlds",
+         "tokyo_u_lsmo_converted_externally_to_rlds",
+         "dlr_sara_pour_converted_externally_to_rlds",
+         "dlr_sara_grid_clamp_converted_externally_to_rlds",
+         "dlr_edan_shared_control_converted_externally_to_rlds",
+         "asu_table_top_converted_externally_to_rlds",
+         "stanford_robocook_converted_externally_to_rlds",
+         "eth_agent_affordances",
+         "imperialcollege_sawyer_wrist_cam",
+         "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
+         "uiuc_d3field",
+         "utaustin_mutex",
+         "berkeley_fanuc_manipulation",
+         "cmu_playing_with_food",
+         "cmu_play_fusion",
+         "cmu_stretch",
+         "berkeley_gnm_recon",
+         "berkeley_gnm_cory_hall",
+         "berkeley_gnm_sac_son",
+     ]
+
+     # some very high number that should be above all trajectory lengths in the dataset
+     _MAX_TRAJ_LEN = 1_000_000
+
+     def __init__(
+         self,
+         dataset_id,
+         batch_size: int | None = None,
+         *,
+         shuffle: bool = True,
+         num_slices: int | None = None,
+         slice_len: int | None = None,
+         pad: float | bool | None = None,
+         replacement: bool | None = None,
+         streaming: bool | None = None,
+         root: str | Path | None = None,
+         download: bool | None = None,
+         sampler: Sampler | None = None,
+         writer: Writer | None = None,
+         collate_fn: Callable | None = None,
+         pin_memory: bool = False,
+         prefetch: int | None = None,
+         transform: torchrl.envs.Transform | None = None,  # noqa-F821
+         split_trajs: bool = False,
+         strict_length: bool = True,
+     ):
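+         # If neither `download` nor `streaming` is specified, the dataset is streamed
+         # (streaming=True, download=False); otherwise the unspecified flag defaults to
+         # the negation of the one that was provided.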
+         if download is None and streaming is None:
+             download = False
+             streaming = True
+         elif download is None:
+             download = not streaming
+         elif streaming is None:
+             streaming = not download
+         self.download = download
+         self.streaming = streaming
+         self.dataset_id = dataset_id
+         self.split_trajs = split_trajs
+         self.shuffle = shuffle
+         self.num_slices = num_slices
+         self.slice_len = slice_len
+         self.pad = pad
+         self.strict_length = strict_length
+         if (self.num_slices is not None) and (self.slice_len is not None):
+             raise ValueError("num_slices or slice_len can be not None, but not both.")
+         if split_trajs:
+             raise NotImplementedError
+         if not streaming:
+             if replacement is None:
+                 replacement = True
+             if pad is not None:
+                 raise RuntimeError(
+                     "the `pad` argument is to be used only with streaming datasets."
+                 )
+             if root is None:
+                 root = _get_root_dir("openx")
+                 os.makedirs(root, exist_ok=True)
+             self.root = Path(root)
+             if self.download == "force" or (
+                 self.download and not self._is_downloaded()
+             ):
+                 if download == "force" and os.path.exists(self.data_path_root):
+                     shutil.rmtree(self.data_path_root)
+
+                 storage = self._download_and_preproc()
+             else:
+                 storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
+             if num_slices is not None or slice_len is not None:
+                 if sampler is not None:
+                     raise ValueError(
+                         "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
+                     )
+
+                 if replacement:
+                     if not self.shuffle:
+                         raise RuntimeError(
+                             "shuffle=False can only be used when replacement=False."
+                         )
+                     sampler = SliceSampler(
+                         num_slices=num_slices,
+                         slice_len=slice_len,
+                         strict_length=strict_length,
+                     )
+                 else:
+                     sampler = SliceSamplerWithoutReplacement(
+                         num_slices=num_slices,
+                         slice_len=slice_len,
+                         strict_length=strict_length,
+                         shuffle=self.shuffle,
+                     )
+
+         else:
+             if replacement is True:
+                 # replacement can be False or None
+                 raise RuntimeError(
+                     "replacement=True is not available with streamed datasets."
+                 )
+             self.root = None
+             if download:
+                 raise ValueError(
+                     "download and streaming cannot be set to ``True`` concomitantly."
+                 )
+             storage = _StreamingStorage(
+                 dataset_id=dataset_id,
+                 shuffle=self.shuffle,
+                 num_slices=self.num_slices,
+                 slice_len=self.slice_len,
+                 pad=self.pad,
+             )
+             if sampler is None:
+                 sampler = _StreamingSampler()
+         if writer is None:
+             writer = ImmutableDatasetWriter()
+         if collate_fn is None:
+             collate_fn = _collate_id
+         super().__init__(
+             storage=storage,
+             sampler=sampler,
+             writer=writer,
+             collate_fn=collate_fn,
+             pin_memory=pin_memory,
+             prefetch=prefetch,
+             batch_size=batch_size,
+             transform=transform,
+         )
+
+     def __iter__(self):
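+         # When no batch_size was given, iteration yields whole trajectories one at a
+         # time (or random `slice_len` crops when `slice_len` is set); with a
+         # batch_size, the regular replay-buffer iteration is used.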
+         if self._batch_size is None:
+             # we can still iterate over the dataset
+             if isinstance(self._storage, _StreamingStorage):
+                 yield from self._storage
+             elif self.slice_len is not None and self.num_slices is None:
+                 try:
+                     # truncate the trajs with slice_len
+                     self._batch_size = self.slice_len
+                     self.num_slices = 1
+                     self.slice_len = None
+                     yield from self
+                 finally:
+                     self.slice_len = self._batch_size
+                     self._batch_size = None
+                     self.num_slices = None
+             else:
+                 # if we don't have a batch size but we know how many trajectories
+                 # we want in each batch, we can build that on the fly.
+                 # The only time we can do this is if num_slices is given but not
+                 # slice_len.
+                 num_slices = self.num_slices
+                 if not num_slices:
+                     num_slices = 1
+                 sampler = SliceSamplerWithoutReplacement(
+                     num_slices=num_slices,
+                     strict_length=False,
+                     shuffle=self.shuffle,
+                 )
+                 batch_size = self._MAX_TRAJ_LEN
+                 yield from TensorDictReplayBuffer(
+                     storage=self._storage,
+                     sampler=sampler,
+                     batch_size=batch_size,
+                     transform=self._transform,
+                 )
+         else:
+             yield from super().__iter__()
+
+     @property
+     def data_path(self):
+         if self.streaming:
+             return None
+         if self.split_trajs:
+             return Path(self.root) / (self.dataset_id + "_split")
+         return self.data_path_root
+
+     @property
+     def data_path_root(self):
+         if self.streaming:
+             return None
+         return self.root / self.dataset_id
+
+     def _is_downloaded(self):
+         return os.path.exists(self.data_path_root)
+
+     def _download_and_preproc(self):
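+         # Two passes over the Hugging Face dataset: a first pass counts frames and
+         # builds a zeroed template TensorDict, a second pass writes every episode into
+         # a memory-mapped TensorDict stored under `<root>/<dataset_id>`.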
+         if not _has_datasets:
+             raise ImportError(
+                 f"the `datasets` library is required for the dataset {self.dataset_id}."
+             )
+         import datasets
+
+         with tempfile.TemporaryDirectory() as cache_dir:
+             dataset = datasets.load_dataset(
+                 "jxu124/OpenX-Embodiment",
+                 self.dataset_id,
+                 streaming=False,
+                 split="train",
+                 cache_dir=cache_dir,
+                 trust_remote_code=True,
+             )
+             # iterate over the dataset a first time to count elements
+             total_frames = 0
+
+             try:
+                 import tqdm
+
+                 _has_tqdm = True
+                 pbar = tqdm.tqdm(dataset, desc="counting")
+             except ImportError:
+                 _has_tqdm = False
+                 pbar = dataset
+
+             for data in pbar:
+                 if total_frames == 0:
+                     for step in data["data.pickle"]["steps"]:
+                         td = _make_tensordict_image_conv(step).zero_()
+                     # format td: requires td to have a non-null batch_size
+                     td = td.expand(2, *td.shape)
+                     _format_data(td, 0)
+                     td = td[0]
+                 total_frames += len(data["data.pickle"]["steps"])
+             td_data = td.expand(total_frames)
+
+             def expand_non_tensor(x):
+                 if isinstance(x, NonTensorData):
+                     return x.maybe_to_stack()
+                 return x
+
+             td_data = td_data._apply_nest(
+                 expand_non_tensor,
+                 is_leaf=lambda x: issubclass(x, torch.Tensor) or _is_non_tensor(x),
+             )
+             td_data = td_data.memmap_like(self.root / self.dataset_id)
+             if _has_tqdm:
+                 pbar = tqdm.tqdm(dataset, desc="preproc", total=total_frames)
+             else:
+                 pbar = dataset
+             idx0 = 0
+             idx1 = 0
+             episode = 0
+             for data in pbar:
+                 current_ep = torch.stack(
+                     [
+                         _make_tensordict_image_conv(step)
+                         for step in data["data.pickle"]["steps"]
+                     ]
+                 ).contiguous()
+                 _format_data(current_ep, episode)
+                 episode += 1
+                 idx1 += len(current_ep)
+                 td_data[idx0:idx1] = current_ep
+                 idx0 = idx1
+                 if _has_tqdm:
+                     pbar.update(current_ep.shape[0])
+         return TensorStorage(td_data.lock_())
+
+
+ class _StreamingStorage(Storage):
+     SLICE_MISMATCH = "The batch_size {} must be divisible by num_slices {} or slice_len {} if provided."
+
+     def __init__(
+         self,
+         dataset_id: str,
+         repo: str = "jxu124/OpenX-Embodiment",
+         split="train",
+         base_path="data.pickle",
+         shuffle: bool = True,
+         truncate: bool = True,
+         num_slices=None,
+         slice_len=None,
+         pad=None,
+     ):
+         self.shuffle = shuffle
+         self.dataset_id = dataset_id
+         self.repo = repo
+         self.split = split
+         self._init()
+         self.base_path = base_path
+         self.truncate = truncate
+         self.num_slices = num_slices
+         self.slice_len = slice_len
+         self.pad = pad
+
+     def _init(self):
+         if not _has_datasets:
+             raise ImportError(
+                 f"the `datasets` library is required for the dataset {self.dataset_id}."
+             )
+         import datasets
+
+         try:
+             dataset = datasets.load_dataset(
+                 self.repo, self.dataset_id, streaming=True, split=self.split
+             )
+         except Exception as e:
+             if "Dataset scripts are no longer supported" in str(e):
+                 raise RuntimeError(
+                     f"Failed to load dataset {self.dataset_id}. Your version of `datasets` is too new - please downgrade to <4.0.0."
+                 ) from e
+             raise e
+
+         if self.shuffle:
+             dataset = dataset.shuffle()
+         self.dataset = dataset
+         self.dataset_iter = iter(dataset)
+
+     def __iter__(self):
+         episode = 0
+         for data in self.dataset:
+             if self.base_path:
+                 data = data[self.base_path]
+             data = torch.stack(
+                 [_make_tensordict_image_conv(step) for step in data["steps"]]
+             ).contiguous()
+             _format_data(data, episode)
+             if self.slice_len is not None:
+                 yield _slice_data(data, slice_len=self.slice_len, pad_value=self.pad)
+             else:
+                 yield data
+
+     def get(self, index: range | torch.Tensor) -> Any:
+         if not isinstance(index, range):
+             if (index[1:] != index[:-1] + 1).any():
+                 # we use a range to indicate how much data we want
+                 raise RuntimeError("iterable datasets do not support indexing.")
+             index = range(index.shape[0])
+         total = 0
+         data_list = []
+         episode = 0
+         batch_size = index.stop
+         if self.num_slices is not None:
+             if batch_size % self.num_slices != 0:
+                 raise ValueError(
+                     self.SLICE_MISMATCH.format(
+                         batch_size, self.num_slices, self.slice_len
+                     )
+                 )
+             num_slices = self.num_slices
+             slice_len = batch_size // num_slices
+         else:
+             if batch_size % self.slice_len != 0:
+                 raise ValueError(
+                     self.SLICE_MISMATCH.format(
+                         batch_size, self.num_slices, self.slice_len
+                     )
+                 )
+             slice_len = self.slice_len
+             # num_slices = batch_size // slice_len
+
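+         # Episodes are pulled from the streaming iterator one at a time, cropped (or
+         # padded) to `slice_len` frames by `_slice_data`, and concatenated until at
+         # least `batch_size` frames have been collected; e.g. with num_slices=2 and
+         # batch_size=128, two 64-frame sub-trajectories are returned.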
+         while total < batch_size:
+             try:
+                 data = next(self.dataset_iter)
+             except StopIteration:
+                 self.dataset_iter = iter(self.dataset)
+                 data = next(self.dataset_iter)
+
+             if self.base_path:
+                 data = data[self.base_path]
+             data = torch.stack(
+                 [_make_tensordict_image_conv(step) for step in data["steps"]]
+             ).contiguous()
+             _format_data(data, episode)
+             data = _slice_data(data, slice_len=slice_len, pad_value=self.pad)
+             data_list.append(data)
+             total += data.numel()
+             episode += 1
+         data = torch.cat(data_list)
+         if self.truncate:
+             return data[: index.stop]
+         return data
+
+     def dumps(self, path):
+         path = Path(path)
+         state_dict = self.state_dict()
+         json.dump(state_dict, path / "state_dict.json")
+
+     def state_dict(self) -> dict[str, Any]:
+         return {
+             "repo": self.repo,
+             "split": self.split,
+             "dataset_id": self.dataset_id,
+             "shuffle": self.shuffle,
+             "base_path": self.base_path,
+             "truncated": self.truncate,
+             "num_slices": self.num_slices,
+             "slice_len": self.slice_len,
+             "pad": self.pad,
+         }
+
+     def loads(self, path):
+         path = Path(path)
+         state_dict = json.load(path / "state_dict.json")
+         self.load_state_dict(state_dict)
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         for key, val in state_dict.items():
+             setattr(self, key, val)
+         self._init()
+
+     def __len__(self):
+         raise RuntimeError(
+             f"{type(self)} does not have a length. Use a downloaded dataset to "
+             f"access this property."
+         )
+
+
+ def _slice_data(data: TensorDict, slice_len, pad_value):
+     if data.shape[-1] < slice_len:
+         if pad_value is None:
+             raise RuntimeError(
+                 f"The trajectory length ({data.shape[-1]}) is shorter than the slice length ({slice_len}). "
+                 f"Decrease the slice length or provide a padding value."
+             )
+         if pad_value is True:
+             pad_value = 0
+         return pad(data, [0, slice_len - data.shape[-1]], value=pad_value)
+
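+     # Otherwise take a random contiguous crop of `slice_len` steps and mark the last
+     # step of the crop as truncated (and hence done) so that downstream consumers see
+     # a proper trajectory boundary.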
+     if data.ndim == 1:
+         random_range = (
+             ((data.shape[-1] - slice_len) * torch.rand(())).floor().int().item()
+         )
+         random_range = slice(random_range, random_range + slice_len)
+     else:
+         raise NotImplementedError(data)
+     data = data[..., random_range]
+     truncated = data.get(("next", "truncated"))
+     truncated = torch.index_fill(
+         truncated,
+         dim=data.ndim - 1,
+         value=True,
+         index=torch.as_tensor(-1, device=truncated.device),
+     )
+     done = data.get(("next", "done"))
+     data.set(("next", "truncated"), truncated)
+     data.set(("next", "done"), truncated | done)
+     return data
+
+
+ class _StreamingSampler(Sampler):
+     def __init__(self):
+         ...
+
+     def sample(self, storage: Storage, batch_size: int) -> tuple[Any, dict]:
+         return range(batch_size), {}
+
+     def _empty(self):
+         return
+
+     def dumps(self, path):
+         ...
+
+     def loads(self, path):
+         ...
+
+     def state_dict(self) -> dict[str, Any]:
+         return {}
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         ...
+
+
+ OPENX_KEY_MAP = {
+     "is_first": "is_init",
+     "is_last": ("next", "done"),
+     "is_terminal": ("next", "terminated"),
+     "reward": ("next", "reward"),
+ }
+
+
+ def _format_data(data: TensorDict, episode: int):
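+     # Convert a raw OpenX episode to the TED layout: shift the observations by one
+     # step into ("next", "observation"), rename the raw flags according to
+     # OPENX_KEY_MAP, derive "truncated" as done & ~terminated, add the trailing
+     # singleton dimension expected for the done/reward entries, and tag every frame
+     # with the episode index.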
+     observation_ = data.get("observation")
+     observation_pad = pad(observation_[1:], [0, 1])
+     data.set(("next", "observation"), observation_pad)
+     for key, newkey in OPENX_KEY_MAP.items():
+         data.rename_key_(key, newkey)
+     data.set(
+         ("next", "truncated"),
+         data.get(("next", "done")) & ~data.get(("next", "terminated")),
+     )
+
+     for key in ("done", "terminated", "truncated", "reward"):
+         data.set(("next", key), data.get(("next", key)).unsqueeze(-1))
+         if key != "reward":
+             data.set(key, torch.zeros_like(data.get(("next", key))))
+
+     data.set(
+         "episode", torch.full(data.shape, episode, device=data.device, dtype=torch.int)
+     )
+
+
+ def _make_tensordict_image_conv(data):
+     # in some datasets, the images are not well converted.
+     # before building the tensordict, we load the PIL image and convert it to a tensor
+     try:
+         img_bytes = data["observation"]["image"]["bytes"]
+         if not _has_tv:
+             raise ImportError(
+                 "the `torchvision` library is required to read the image observation."
+             )
+         import torchvision.transforms.v2.functional
+         from PIL import Image
+
+         img = Image.open(io.BytesIO(img_bytes))
+         tensor = torchvision.transforms.v2.functional.pil_to_tensor(img)
+         data["observation"]["image"] = tensor
+     except KeyError:
+         pass
+     return make_tensordict(data, auto_batch_size=True)
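
For reference, a minimal usage sketch of the class added in this file, following the docstring example above. It assumes the optional `datasets` and `torchvision` dependencies are installed and that the `jxu124/OpenX-Embodiment` hub repository is reachable; the shapes in the comments are taken from the docstring output for the `cmu_stretch` dataset.

    from torchrl.data.datasets import OpenXExperienceReplay

    # Stream "cmu_stretch" without downloading; each sampled batch of 128 frames is
    # made of 2 sub-trajectories of 64 steps (batch_size must be divisible by
    # num_slices).
    dataset = OpenXExperienceReplay(
        "cmu_stretch",
        batch_size=128,
        num_slices=2,
        streaming=True,
        download=False,
    )
    batch = dataset.sample()  # TensorDict with batch_size [128]
    slices = batch.reshape(2, -1)  # recover the two 64-step sub-trajectories
    print(slices["next", "reward"].shape)  # torch.Size([2, 64, 1])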