torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/map/utils.py
@@ -0,0 +1,103 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from collections.abc import Callable
8
+
9
+ from tensordict import NestedKey
10
+
11
+
12
+ def _plot_plotly_tree(
13
+ tree: Tree, make_labels: Callable[[Tree], str] | None = None # noqa: F821
14
+ ):
15
+ import plotly.graph_objects as go
16
+ from igraph import Graph
17
+
18
+ if make_labels is None:
19
+
20
+ def make_labels(tree, path, *args, **kwargs):
21
+ return str((tree.node_id, tree.hash))
22
+
23
+ nr_vertices = tree.num_vertices()
24
+ vertices = tree.vertices(key_type="path")
25
+
26
+ v_label = [make_labels(subtree, path) for path, subtree in vertices.items()]
27
+ G = Graph(nr_vertices, tree.edges())
28
+
29
+ layout = G.layout_sugiyama(range(nr_vertices))
30
+
31
+ position = {k: layout[k] for k in range(nr_vertices)}
32
+ # Y = [layout[k][1] for k in range(nr_vertices)]
33
+ # M = max(Y)
34
+
35
+ # es = EdgeSeq(G) # sequence of edges
36
+ E = [e.tuple for e in G.es] # list of edges
37
+
38
+ L = len(position)
39
+ Xn = [position[k][0] for k in range(L)]
40
+ # Yn = [2 * M - position[k][1] for k in range(L)]
41
+ Yn = [position[k][1] for k in range(L)]
42
+ Xe = []
43
+ Ye = []
44
+ for edge in E:
45
+ Xe += [position[edge[0]][0], position[edge[1]][0], None]
46
+ # Ye += [2 * M - position[edge[0]][1], 2 * M - position[edge[1]][1], None]
47
+ Ye += [position[edge[0]][1], position[edge[1]][1], None]
48
+
49
+ labels = v_label
50
+ fig = go.Figure()
51
+ fig.add_trace(
52
+ go.Scatter(
53
+ x=Xe,
54
+ y=Ye,
55
+ mode="lines",
56
+ line={"color": "rgb(210,210,210)", "width": 5},
57
+ hoverinfo="none",
58
+ )
59
+ )
60
+ fig.add_trace(
61
+ go.Scatter(
62
+ x=Xn,
63
+ y=Yn,
64
+ mode="markers+text",
65
+ name="bla",
66
+ marker={
67
+ "symbol": "circle-dot",
68
+ "size": 40,
69
+ "color": "#6175c1", # '#DB4551',
70
+ "line": {"color": "rgb(50,50,50)", "width": 1},
71
+ },
72
+ text=labels,
73
+ hoverinfo="text",
74
+ textposition="middle right",
75
+ opacity=0.8,
76
+ )
77
+ )
78
+ fig.show()
79
+
80
+
81
+ def _plot_plotly_box(tree: Tree, info: list[NestedKey] = None): # noqa: F821
82
+ import plotly.graph_objects as go
83
+
84
+ if info is None:
85
+ info = ["hash", ("next", "reward")]
86
+
87
+ parents = [""]
88
+ labels = [tree._label(info, tree, root=True)]
89
+
90
+ _tree = tree
91
+
92
+ def extend(tree: Tree, parent): # noqa: F821
93
+ children = tree.subtree
94
+ if children is None:
95
+ return
96
+ for child in children:
97
+ labels.append(tree._label(info, child))
98
+ parents.append(parent)
99
+ extend(child, labels[-1])
100
+
101
+ extend(_tree, labels[-1])
102
+ fig = go.Figure(go.Treemap(labels=labels, parents=parents))
103
+ fig.show()
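The helper above draws the tree by pairing python-igraph's Sugiyama layout with two plotly scatter traces, one for the edge segments and one for the labelled nodes. The toy sketch below reproduces that rendering pattern on a hand-built five-node tree; it does not touch any torchrl object and only assumes that plotly and python-igraph are installed.

# Illustrative reproduction of the edge/node rendering used in _plot_plotly_tree.
import plotly.graph_objects as go
from igraph import Graph

edges = [(0, 1), (0, 2), (1, 3), (1, 4)]  # parent -> child pairs
n_vertices = 5
G = Graph(n_vertices, edges)
layout = G.layout_sugiyama()  # layered layout suited to trees/DAGs
pos = {k: layout[k] for k in range(n_vertices)}

# Edge coordinates are interleaved with None so plotly draws disjoint segments.
Xe, Ye = [], []
for a, b in edges:
    Xe += [pos[a][0], pos[b][0], None]
    Ye += [pos[a][1], pos[b][1], None]

fig = go.Figure()
fig.add_trace(go.Scatter(x=Xe, y=Ye, mode="lines", hoverinfo="none"))
fig.add_trace(
    go.Scatter(
        x=[pos[k][0] for k in range(n_vertices)],
        y=[pos[k][1] for k in range(n_vertices)],
        mode="markers+text",
        text=[f"node {k}" for k in range(n_vertices)],
        textposition="middle right",
        hoverinfo="text",
    )
)
fig.show()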
torchrl/data/postprocs/__init__.py
@@ -0,0 +1,8 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .postprocs import DensifyReward, MultiStep
7
+
8
+ __all__ = ["MultiStep", "DensifyReward"]
torchrl/data/postprocs/postprocs.py
@@ -0,0 +1,391 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ import torch
9
+ from tensordict import NestedKey, TensorDictBase, unravel_key
10
+ from tensordict.nn import TensorDictModuleBase
11
+ from tensordict.utils import expand_right
12
+ from torch import nn
13
+
14
+
15
+ def _get_reward(
16
+ gamma: float,
17
+ reward: torch.Tensor,
18
+ done: torch.Tensor,
19
+ max_steps: int,
20
+ ):
21
+ """Sums the rewards up to max_steps in the future with a gamma decay.
22
+
23
+ Supports multiple consecutive trajectories.
24
+
25
+ Assumes that the time dimension is the *last* dim of reward and done.
26
+ """
27
+ filt = torch.tensor(
28
+ [gamma**i for i in range(max_steps + 1)],
29
+ device=reward.device,
30
+ dtype=reward.dtype,
31
+ ).view(1, 1, -1)
32
+ # make one done mask per trajectory
33
+ done_cumsum = done.cumsum(-1)
34
+ done_cumsum = torch.cat(
35
+ [torch.zeros_like(done_cumsum[..., :1]), done_cumsum[..., :-1]], -1
36
+ )
37
+ num_traj = done_cumsum.max().item() + 1
38
+ done_cumsum = done_cumsum.expand(num_traj, *done.shape)
39
+ traj_ids = done_cumsum == torch.arange(
40
+ num_traj, device=done.device, dtype=done_cumsum.dtype
41
+ ).view(num_traj, *[1 for _ in range(done_cumsum.ndim - 1)])
42
+ # an expanded reward tensor where each index along dim 0 is a different trajectory
43
+ # Note: rewards could have a different shape than done (e.g. multi-agent with a single
44
+ # done per group).
45
+ # we assume that reward has the same leading dimension as done.
46
+ if reward.shape != traj_ids.shape[1:]:
47
+ # We'll expand the ids on the right first
48
+ traj_ids_expand = expand_right(traj_ids, (num_traj, *reward.shape))
49
+ reward_traj = traj_ids_expand * reward
50
+ # we must make sure that the last dimension of the reward is the time
51
+ reward_traj = reward_traj.transpose(-1, traj_ids.ndim - 1)
52
+ else:
53
+ # simpler use case: reward shape and traj_ids match
54
+ reward_traj = traj_ids * reward
55
+
56
+ reward_traj = torch.nn.functional.pad(reward_traj, [0, max_steps], value=0.0)
57
+ shape = reward_traj.shape[:-1]
58
+ if len(shape) > 1:
59
+ reward_traj = reward_traj.flatten(0, reward_traj.ndim - 2)
60
+ reward_traj = reward_traj.unsqueeze(-2)
61
+ summed_rewards = torch.conv1d(reward_traj, filt)
62
+ summed_rewards = summed_rewards.squeeze(-2)
63
+ if len(shape) > 1:
64
+ summed_rewards = summed_rewards.unflatten(0, shape)
65
+ # let's check that our summed rewards have the right size
66
+ if reward.shape != traj_ids.shape[1:]:
67
+ summed_rewards = summed_rewards.transpose(-1, traj_ids.ndim - 1)
68
+ summed_rewards = (summed_rewards * traj_ids_expand).sum(0)
69
+ else:
70
+ summed_rewards = (summed_rewards * traj_ids).sum(0)
71
+
72
+ # time_to_obs is the tensor of the time delta to the next obs
73
+ # 0 = take the next obs (ie do nothing)
74
+ # 1 = take the obs after the next
75
+ time_to_obs = (
76
+ traj_ids.flip(-1).cumsum(-1).clamp_max(max_steps + 1).flip(-1) * traj_ids
77
+ )
78
+ time_to_obs = time_to_obs.sum(0)
79
+ time_to_obs = time_to_obs - 1
80
+ return summed_rewards, time_to_obs
81
+
82
+
83
+ class MultiStep(nn.Module):
84
+ """Multistep reward transform.
85
+
86
+ Presented in
87
+
88
+ | Sutton, R. S. 1988. Learning to predict by the methods of temporal differences. Machine learning 3(1):9–44.
89
+
90
+ This module maps the "next" observation to the t + n "next" observation.
91
+ It is an identity transform whenever :attr:`n_steps` is 1.
92
+
93
+ Args:
94
+ gamma (:obj:`float`): Discount factor for return computation
95
+ n_steps (integer): maximum look-ahead steps.
96
+
97
+ .. note:: This class is meant to be used within a ``DataCollector``.
98
+ It will only treat the data passed to it at the end of a collection,
99
+ and ignore data preceding that collection or coming in the next batch.
100
+ As such, results on the last steps of the batch may likely be biased
101
+ by the early truncation of the trajectory.
102
+ To mitigate this effect, please use :class:`~torchrl.envs.transforms.MultiStepTransform`
103
+ within the replay buffer instead.
104
+
105
+ Examples:
106
+ >>> from torchrl.modules import RandomPolicy
+ >>> from torchrl.collectors import Collector
107
+ >>> from torchrl.data.postprocs import MultiStep
108
+ >>> from torchrl.envs import GymEnv, TransformedEnv, StepCounter
109
+ >>> env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
110
+ >>> env.set_seed(0)
111
+ >>> collector = Collector(env, policy=RandomPolicy(env.action_spec),
112
+ ... frames_per_batch=10, total_frames=2000, postproc=MultiStep(n_steps=4, gamma=0.99))
113
+ >>> for data in collector:
114
+ ... break
115
+ >>> print(data["step_count"])
116
+ tensor([[0],
117
+ [1],
118
+ [2],
119
+ [3],
120
+ [4],
121
+ [5],
122
+ [6],
123
+ [7],
124
+ [8],
125
+ [9]])
126
+ >>> # the next step count is shifted by 4 steps in the future
127
+ >>> print(data["next", "step_count"])
128
+ tensor([[ 5],
129
+ [ 6],
130
+ [ 7],
131
+ [ 8],
132
+ [ 9],
133
+ [10],
134
+ [10],
135
+ [10],
136
+ [10],
137
+ [10]])
138
+
139
+ """
140
+
141
+ def __init__(
142
+ self,
143
+ gamma: float,
144
+ n_steps: int,
145
+ ):
146
+ super().__init__()
147
+ if n_steps <= 0:
148
+ raise ValueError("n_steps must be a non-negative integer.")
149
+ if not (gamma > 0 and gamma <= 1):
150
+ raise ValueError(f"got out-of-bounds gamma decay: gamma={gamma}")
151
+
152
+ self.gamma = gamma
153
+ self.n_steps = n_steps
154
+ self.register_buffer(
155
+ "gammas",
156
+ torch.tensor(
157
+ [gamma**i for i in range(n_steps + 1)],
158
+ dtype=torch.float,
159
+ ).reshape(1, 1, -1),
160
+ )
161
+ self.done_key = "done"
162
+ self.done_keys = ("done", "terminated", "truncated")
163
+ self.reward_keys = ("reward",)
164
+ self.mask_key = ("collector", "mask")
165
+
166
+ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
167
+ """Re-writes a tensordict following the multi-step transform.
168
+
169
+ Args:
170
+ tensordict: :class:`tensordict.TensorDictBase` instance with
171
+ ``[*Batch x Time-steps]`` shape.
172
+ The TensorDict must contain ``("next", "reward")`` and
173
+ ``("next", "done")`` keys.
174
+ All keys that are contained within the "next" nested tensordict
175
+ will be shifted by (at most) :attr:`~.n_steps` frames.
176
+ The TensorDict will also be updated with new key-value pairs:
177
+
178
+ - gamma: indicating the discount to be used for the next
179
+ reward;
180
+ - nonterminal: boolean value indicating whether a step is
181
+ non-terminal (not done or not last of trajectory);
182
+ - original_reward: previous reward collected in the
183
+ environment (i.e. before multi-step);
184
+ - The "reward" values will be replaced by the newly computed
185
+ rewards.
186
+
187
+ The ``"done"`` key can have either the shape of the tensordict
188
+ OR the shape of the tensordict followed by a singleton
189
+ dimension OR the shape of the tensordict followed by other
190
+ dimensions. In the latter case, the tensordict *must* be
191
+ compatible with a reshape that follows the done shape (ie. the
192
+ leading dimensions of every tensor it contains must match the
193
+ shape of the ``"done"`` entry).
194
+ The ``"reward"`` tensor can have either the shape of the
195
+ tensordict (or done state) or this shape followed by a singleton
196
+ dimension.
197
+
198
+ Returns:
199
+ in-place transformation of the input tensordict.
200
+
201
+ """
202
+ return _multi_step_func(
203
+ tensordict,
204
+ done_key=self.done_key,
205
+ done_keys=self.done_keys,
206
+ reward_keys=self.reward_keys,
207
+ mask_key=self.mask_key,
208
+ n_steps=self.n_steps,
209
+ gamma=self.gamma,
210
+ )
211
+
212
+
213
+ def _multi_step_func(
214
+ tensordict,
215
+ *,
216
+ done_key,
217
+ done_keys,
218
+ reward_keys,
219
+ mask_key,
220
+ n_steps,
221
+ gamma,
222
+ ):
223
+ # in accordance with common understanding of what n_steps should be
224
+ n_steps = n_steps - 1
225
+ tensordict = tensordict.clone(False)
226
+ done = tensordict.get(("next", done_key))
227
+
228
+ # we'll be using the done states to index the tensordict.
229
+ # if the shapes don't match we're in trouble.
230
+ ndim = tensordict.ndim
231
+ if done.shape != tensordict.shape:
232
+ if done.shape[-1] == 1 and done.shape[:-1] == tensordict.shape:
233
+ done = done.squeeze(-1)
234
+ else:
235
+ try:
236
+ # let's try to reshape the tensordict
237
+ tensordict.batch_size = done.shape
238
+ tensordict = tensordict.transpose(ndim - 1, tensordict.ndim - 1)
239
+ done = tensordict.get(("next", done_key))
240
+ except Exception as err:
241
+ raise RuntimeError(
242
+ "tensordict shape must be compatible with the done's shape "
243
+ "(trailing singleton dimension excluded)."
244
+ ) from err
245
+
246
+ if mask_key is not None:
247
+ mask = tensordict.get(mask_key, None)
248
+ else:
249
+ mask = None
250
+
251
+ *batch, T = tensordict.batch_size
252
+
253
+ summed_rewards = []
254
+ for reward_key in reward_keys:
255
+ reward = tensordict.get(("next", reward_key))
256
+
257
+ # sum rewards
258
+ summed_reward, time_to_obs = _get_reward(gamma, reward, done, n_steps)
259
+ summed_rewards.append(summed_reward)
260
+
261
+ idx_to_gather = torch.arange(
262
+ T, device=time_to_obs.device, dtype=time_to_obs.dtype
263
+ ).expand(*batch, T)
264
+ idx_to_gather = idx_to_gather + time_to_obs
265
+
266
+ # idx_to_gather looks like tensor([[ 2, 3, 4, 5, 5, 5, 8, 9, 10, 10, 10]])
267
+ # with a done state tensor([[ 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]])
268
+ # meaning that the first obs will be replaced by the third, the second by the fourth etc.
269
+ # The fifth remains the fifth as it is terminal
270
+ tensordict_gather = (
271
+ tensordict.get("next")
272
+ .exclude(*reward_keys, *done_keys)
273
+ .gather(-1, idx_to_gather)
274
+ )
275
+
276
+ tensordict.set("steps_to_next_obs", time_to_obs + 1)
277
+ for reward_key, summed_reward in zip(reward_keys, summed_rewards):
278
+ tensordict.rename_key_(("next", reward_key), ("next", "original_reward"))
279
+ tensordict.set(("next", reward_key), summed_reward)
280
+
281
+ tensordict.get("next").update(tensordict_gather)
282
+ tensordict.set("gamma", gamma ** (time_to_obs + 1))
283
+ nonterminal = time_to_obs != 0
284
+ if mask is not None:
285
+ mask = mask.view(*batch, T)
286
+ nonterminal[~mask] = False
287
+ tensordict.set("nonterminal", nonterminal)
288
+ if tensordict.ndim != ndim:
289
+ tensordict = tensordict.apply(
290
+ lambda x: x.transpose(ndim - 1, tensordict.ndim - 1),
291
+ batch_size=done.transpose(ndim - 1, tensordict.ndim - 1).shape,
292
+ )
293
+ tensordict.batch_size = tensordict.batch_size[:ndim]
294
+ return tensordict
295
+
296
+
297
+ class DensifyReward(TensorDictModuleBase):
298
+ """A util to reassign the reward at done state to the rest of the trajectory.
299
+
300
+ This transform is to be used with sparse rewards to assign a reward to each step of a trajectory when only the
301
+ reward at `done` is non-null.
302
+
303
+ .. note:: The class calls the :func:`~torchrl.objectives.value.functional.reward2go` function, which will
304
+ also sum intermediate rewards. Make sure you understand what the `reward2go` function returns before using
305
+ this module.
306
+
307
+ Args:
308
+ reward_key (NestedKey, optional): The key in the input TensorDict where the reward is stored.
309
+ Defaults to `"reward"`.
310
+ done_key (NestedKey, optional): The key in the input TensorDict where the done flag is stored.
311
+ Defaults to `"done"`.
312
+ reward_key_out (NestedKey | None, optional): The key in the output TensorDict where the reassigned reward
313
+ will be stored. If None, it defaults to the value of `reward_key`.
314
+ Defaults to `None`.
315
+ time_dim (int, optional): The dimension in the input TensorDict where the time is unrolled.
316
+ Defaults to `2`.
317
+ discount (float, optional): The discount factor to use for computing the discounted cumulative sum of rewards.
318
+ Defaults to `1.0` (no discounting).
319
+
320
+ Returns:
321
+ TensorDict: The input TensorDict with the reassigned reward stored under the key specified by `reward_key_out`.
322
+
323
+ Examples:
324
+ >>> import torch
325
+ >>> from tensordict import TensorDict
326
+ >>>
327
+ >>> from torchrl.data import DensifyReward
328
+ >>>
329
+ >>> # Create a sample TensorDict
330
+ >>> tensordict = TensorDict({
331
+ ... "next": {
332
+ ... "reward": torch.zeros(10, 1),
333
+ ... "done": torch.zeros(10, 1, dtype=torch.bool)
334
+ ... }
335
+ ... }, batch_size=[10])
336
+ >>> # Set some done flags and rewards
337
+ >>> tensordict["next", "done"][[3, 7]] = True
338
+ >>> tensordict["next", "reward"][3] = 3
339
+ >>> tensordict["next", "reward"][7] = 7
340
+ >>> # Create an instance of DensifyReward
341
+ >>> last_reward_to_traj = DensifyReward()
342
+ >>> # Apply the transform
343
+ >>> new_tensordict = last_reward_to_traj(tensordict)
344
+ >>> # Print the reassigned rewards
345
+ >>> print(new_tensordict["next", "reward"])
346
+ tensor([[3.],
347
+ [3.],
348
+ [3.],
349
+ [3.],
350
+ [7.],
351
+ [7.],
352
+ [7.],
353
+ [7.],
354
+ [0.],
355
+ [0.]])
356
+
357
+ """
358
+
359
+ def __init__(
360
+ self,
361
+ *,
362
+ reward_key: NestedKey = "reward",
363
+ done_key: NestedKey = "done",
364
+ reward_key_out: NestedKey | None = None,
365
+ time_dim: int = 2,
366
+ discount: float = 1.0,
367
+ ):
368
+ from torchrl.objectives.value.functional import reward2go
369
+
370
+ super().__init__()
371
+ self.in_keys = [unravel_key(reward_key), unravel_key(done_key)]
372
+ if reward_key_out is None:
373
+ reward_key_out = reward_key
374
+ self.out_keys = [unravel_key(reward_key_out)]
375
+ self.time_dim = time_dim
376
+ self.discount = discount
377
+ self.reward2go = reward2go
378
+
379
+ def forward(self, tensordict):
380
+ # Get done
381
+ done = tensordict.get(("next", self.in_keys[1]))
382
+ # Get reward
383
+ reward = tensordict.get(("next", self.in_keys[0]))
384
+ if reward.shape != done.shape:
385
+ raise RuntimeError(
386
+ f"reward and done state are expected to have the same shape. Got reard.shape={reward.shape} "
387
+ f"and done.shape={done.shape}."
388
+ )
389
+ reward = self.reward2go(reward, done, time_dim=-2, gamma=self.discount)
390
+ tensordict.set(("next", self.out_keys[0]), reward)
391
+ return tensordict
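The heart of _get_reward above is the 1-D convolution of the reward sequence with the filter [1, gamma, gamma**2, ...], which produces the discounted sum of up to max_steps + 1 future rewards at every time step; the traj_ids bookkeeping merely restricts that sum to a single trajectory. A minimal standalone check of the convolution trick, with toy numbers that are not taken from this package, could look as follows.

# Verify the conv1d n-step return trick from _get_reward on a toy trajectory.
import torch

gamma, max_steps, T = 0.9, 2, 4
reward = torch.tensor([[1.0, 0.0, 2.0, 1.0]])  # shape (1, T)
filt = torch.tensor([gamma**i for i in range(max_steps + 1)]).view(1, 1, -1)

padded = torch.nn.functional.pad(reward, [0, max_steps]).unsqueeze(-2)  # (1, 1, T + max_steps)
nstep = torch.conv1d(padded, filt).squeeze(-2)  # (1, T)

# Reference: explicit discounted sum over the same look-ahead window.
ref = torch.stack(
    [
        sum(gamma**i * reward[0, t + i] for i in range(max_steps + 1) if t + i < T)
        for t in range(T)
    ]
)
assert torch.allclose(nstep[0], ref)
print(nstep)  # discounted sums of up to max_steps + 1 future rewards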
torchrl/data/replay_buffers/__init__.py
@@ -0,0 +1,99 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .checkpointers import (
7
+ CompressedListStorageCheckpointer,
8
+ FlatStorageCheckpointer,
9
+ H5StorageCheckpointer,
10
+ ListStorageCheckpointer,
11
+ NestedStorageCheckpointer,
12
+ StorageCheckpointerBase,
13
+ StorageEnsembleCheckpointer,
14
+ TensorStorageCheckpointer,
15
+ )
16
+ from .ray_buffer import RayReplayBuffer
17
+ from .replay_buffers import (
18
+ PrioritizedReplayBuffer,
19
+ RemoteTensorDictReplayBuffer,
20
+ ReplayBuffer,
21
+ ReplayBufferEnsemble,
22
+ TensorDictPrioritizedReplayBuffer,
23
+ TensorDictReplayBuffer,
24
+ )
25
+ from .samplers import (
26
+ PrioritizedSampler,
27
+ PrioritizedSliceSampler,
28
+ RandomSampler,
29
+ Sampler,
30
+ SamplerEnsemble,
31
+ SamplerWithoutReplacement,
32
+ SliceSampler,
33
+ SliceSamplerWithoutReplacement,
34
+ )
35
+ from .storages import (
36
+ CompressedListStorage,
37
+ LazyMemmapStorage,
38
+ LazyStackStorage,
39
+ LazyTensorStorage,
40
+ ListStorage,
41
+ Storage,
42
+ StorageEnsemble,
43
+ TensorStorage,
44
+ )
45
+ from .utils import Flat2TED, H5Combine, H5Split, Nested2TED, TED2Flat, TED2Nested
46
+ from .writers import (
47
+ ImmutableDatasetWriter,
48
+ RoundRobinWriter,
49
+ TensorDictMaxValueWriter,
50
+ TensorDictRoundRobinWriter,
51
+ Writer,
52
+ WriterEnsemble,
53
+ )
54
+
55
+ __all__ = [
56
+ "CompressedListStorage",
57
+ "CompressedListStorageCheckpointer",
58
+ "FlatStorageCheckpointer",
59
+ "H5StorageCheckpointer",
60
+ "ListStorageCheckpointer",
61
+ "NestedStorageCheckpointer",
62
+ "StorageCheckpointerBase",
63
+ "StorageEnsembleCheckpointer",
64
+ "TensorStorageCheckpointer",
65
+ "RayReplayBuffer",
66
+ "PrioritizedReplayBuffer",
67
+ "RemoteTensorDictReplayBuffer",
68
+ "ReplayBuffer",
69
+ "ReplayBufferEnsemble",
70
+ "TensorDictPrioritizedReplayBuffer",
71
+ "TensorDictReplayBuffer",
72
+ "PrioritizedSampler",
73
+ "PrioritizedSliceSampler",
74
+ "RandomSampler",
75
+ "Sampler",
76
+ "SamplerEnsemble",
77
+ "SamplerWithoutReplacement",
78
+ "SliceSampler",
79
+ "SliceSamplerWithoutReplacement",
80
+ "LazyMemmapStorage",
81
+ "LazyStackStorage",
82
+ "LazyTensorStorage",
83
+ "ListStorage",
84
+ "Storage",
85
+ "StorageEnsemble",
86
+ "TensorStorage",
87
+ "Flat2TED",
88
+ "H5Combine",
89
+ "H5Split",
90
+ "Nested2TED",
91
+ "TED2Flat",
92
+ "TED2Nested",
93
+ "ImmutableDatasetWriter",
94
+ "RoundRobinWriter",
95
+ "TensorDictMaxValueWriter",
96
+ "TensorDictRoundRobinWriter",
97
+ "Writer",
98
+ "WriterEnsemble",
99
+ ]
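The classes re-exported here are designed to compose: a storage holds the data, a sampler picks indices, a writer decides where new items land, and the buffer ties them together. The snippet below is a sketch of the usual pairing of a TensorDict-aware buffer with a lazily allocated storage and uniform sampling; it reflects typical torchrl usage rather than anything specific to this diff.

# Sketch: replay buffer assembled from the classes re-exported above.
import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers import (
    LazyTensorStorage,
    RandomSampler,
    TensorDictReplayBuffer,
)

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(max_size=1_000),  # tensors allocated on first extend
    sampler=RandomSampler(),
    batch_size=32,
)
data = TensorDict(
    {"observation": torch.randn(100, 4), "reward": torch.randn(100, 1)},
    batch_size=[100],
)
rb.extend(data)  # writes 100 transitions into the storage
sample = rb.sample()  # TensorDict with batch_size [32]
print(sample["observation"].shape)  # torch.Size([32, 4])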