torchrl-0.11.0-cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/replay_buffers/writers.py
@@ -0,0 +1,781 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import heapq
+ import json
+ import textwrap
+ from abc import ABC, abstractmethod
+ from collections.abc import Sequence
+ from copy import copy
+ from multiprocessing.context import get_spawning_popen
+ from pathlib import Path
+ from typing import Any
+
+ import numpy as np
+ import torch
+ from tensordict import is_tensor_collection, MemoryMappedTensor, TensorDictBase
+ from tensordict.utils import expand_as_right, is_tensorclass
+ from torch import multiprocessing as mp
+ from torchrl._utils import _STRDTYPE2DTYPE
+
+ try:
+     from torch.compiler import disable as compile_disable
+ except ImportError:
+     from torch._dynamo import disable as compile_disable
+
+ try:
+     from torch.utils._pytree import tree_leaves
+ except ImportError:
+     from torch.utils._pytree import tree_flatten
+
+     def tree_leaves(data):  # noqa: D103
+         tree_flat, _ = tree_flatten(data)
+         return tree_flat
+
+
+ from torchrl.data.replay_buffers.storages import Storage
+ from torchrl.data.replay_buffers.utils import _is_int, _reduce
+
+
+ class Writer(ABC):
+     """A ReplayBuffer base Writer class."""
+
+     _storage: Storage
+     _rng: torch.Generator | None = None
+
+     def __init__(self, compilable: bool = False) -> None:
+         self._storage = None
+         self._compilable = compilable
+
+     def register_storage(self, storage: Storage) -> None:
+         self._storage = storage
+
+     @abstractmethod
+     def add(self, data: Any) -> int:
+         """Inserts one piece of data at an appropriate index, and returns that index."""
+         ...
+
+     @abstractmethod
+     def extend(self, data: Sequence) -> torch.Tensor:
+         """Inserts a series of data points at appropriate indices, and returns a tensor containing the indices."""
+         ...
+
+     @abstractmethod
+     def _empty(self, empty_write_count: bool = True) -> None:
+         ...
+
+     @abstractmethod
+     def dumps(self, path):
+         ...
+
+     @abstractmethod
+     def loads(self, path):
+         ...
+
+     @abstractmethod
+     def state_dict(self) -> dict[str, Any]:
+         ...
+
+     @abstractmethod
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         ...
+
+     def _replicate_index(self, index):
+         # replicates the index in a non-zero format to have as many indices as
+         # elements truly written when the storage is multidim
+         if self._storage.ndim == 1:
+             return index
+         device = (
+             index.device if isinstance(index, torch.Tensor) else torch.device("cpu")
+         )
+         mesh = torch.stack(
+             torch.meshgrid(
+                 *(torch.arange(dim, device=device) for dim in self._storage.shape[1:])
+             ),
+             -1,
+         ).flatten(0, -2)
+         if _is_int(index):
+             index0 = torch.as_tensor(int(index)).expand(mesh.shape[0], 1)
+             return torch.cat([index0, mesh], 1)
+         return torch.cat(
+             [
+                 index.repeat_interleave(mesh.shape[0]).unsqueeze(1),
+                 mesh.repeat(index.numel(), 1),
+             ],
+             1,
+         )
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}()"
+
+     def __getstate__(self):
+         state = copy(self.__dict__)
+         state["_rng"] = None
+         return state
+
+
+ class ImmutableDatasetWriter(Writer):
+     """A blocking writer for immutable datasets."""
+
+     WRITING_ERR = "This dataset doesn't allow writing."
+
+     def add(self, data: Any) -> int:
+         raise RuntimeError(self.WRITING_ERR)
+
+     def extend(self, data: Sequence) -> torch.Tensor:
+         raise RuntimeError(self.WRITING_ERR)
+
+     def _empty(self, empty_write_count: bool = True) -> None:
+         raise RuntimeError(self.WRITING_ERR)
+
+     def dumps(self, path):
+         ...
+
+     def loads(self, path):
+         ...
+
+     def state_dict(self) -> dict[str, Any]:
+         return {}
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         return
+
+
+ class RoundRobinWriter(Writer):
+     """A RoundRobin Writer class for composable replay buffers.
+
+     Args:
+         compilable (bool, optional): whether the writer is compilable.
+             If ``True``, the writer cannot be shared between multiple processes.
+             Defaults to ``False``.
+
+     """
+
+     def __init__(self, compilable: bool = False) -> None:
+         super().__init__(compilable=compilable)
+         self._cursor = 0
+         self._write_count  # noqa
+
+     def dumps(self, path):
+         path = Path(path).absolute()
+         path.mkdir(exist_ok=True)
+         with open(path / "metadata.json", "w") as file:
+             json.dump({"cursor": self._cursor}, file)
+
+     def loads(self, path):
+         path = Path(path).absolute()
+         with open(path / "metadata.json") as file:
+             metadata = json.load(file)
+         self._cursor = metadata["cursor"]
+
+     def add(self, data: Any) -> int | torch.Tensor:
+         index = self._cursor
+         _cursor = self._cursor
+         # we need to update the cursor first to avoid race conditions between workers
+         self._cursor = (self._cursor + 1) % self._storage._max_size_along_dim0(
+             single_data=data
+         )
+         self._write_count += 1
+         # Replicate index requires the shape of the storage to be known
+         # Other than that, a "flat" (1d) index is ok to write the data
+         self._storage.set(_cursor, data)
+         index = self._replicate_index(index)
+         self._mark_update_entities(index)
+         return index
+
+     def extend(self, data: Sequence) -> torch.Tensor:
+         cur_size = self._cursor
+         if is_tensor_collection(data) or isinstance(data, torch.Tensor):
+             batch_size = len(data)
+         elif isinstance(data, list):
+             batch_size = len(data)
+         else:
+             batch_size = len(tree_leaves(data)[0])
+         if batch_size == 0:
+             raise RuntimeError(f"Expected at least one element in extend. Got {data=}")
+         device = data.device if hasattr(data, "device") else None
+         max_size_along0 = self._storage._max_size_along_dim0(batched_data=data)
+         index = (
+             torch.arange(
+                 cur_size, batch_size + cur_size, dtype=torch.long, device=device
+             )
+             % max_size_along0
+         )
+         # we need to update the cursor first to avoid race conditions between workers
+         self._cursor = (batch_size + cur_size) % max_size_along0
+         self._write_count += batch_size
+         # Replicate index requires the shape of the storage to be known
+         # Other than that, a "flat" (1d) index is ok to write the data
+         self._storage.set(index, data)
+         index = self._replicate_index(index)
+         self._mark_update_entities(index)
+         return index
+
+     def state_dict(self) -> dict[str, Any]:
+         return {"_cursor": self._cursor}
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         self._cursor = state_dict["_cursor"]
+
+     def _empty(self, empty_write_count: bool = True) -> None:
+         self._cursor = 0
+         if empty_write_count:
+             self._write_count = 0
+
+     # TODO: Workaround for PyTorch nightly regression where compiler can't handle
+     # method calls on objects returned from _attached_entities_iter()
+     @compile_disable()
+     def _mark_update_entities(self, index: torch.Tensor) -> None:
+         """Mark entities as updated with the given index."""
+         for ent in self._storage._attached_entities_iter():
+             ent.mark_update(index)
+
+     @property
+     def _cursor(self):
+         _cursor_value = self.__dict__.get("_cursor_value", None)
+         if not self._compilable:
+             if _cursor_value is None:
+                 _cursor_value = self._cursor_value = mp.Value("i", 0)
+             return _cursor_value.value
+         else:
+             if _cursor_value is None:
+                 _cursor_value = self._cursor_value = 0
+             return _cursor_value
+
+     @_cursor.setter
+     def _cursor(self, value):
+         if not self._compilable:
+             _cursor_value = self.__dict__.get("_cursor_value", None)
+             if _cursor_value is None:
+                 _cursor_value = self._cursor_value = mp.Value("i", 0)
+             _cursor_value.value = value
+         else:
+             self._cursor_value = value
+
+     @property
+     def _write_count(self):
+         _write_count = self.__dict__.get("_write_count_value", None)
+         if not self._compilable:
+             if _write_count is None:
+                 _write_count = self._write_count_value = mp.Value("i", 0)
+             return _write_count.value
+         else:
+             if _write_count is None:
+                 _write_count = self._write_count_value = 0
+             return _write_count
+
+     @_write_count.setter
+     def _write_count(self, value):
+         if not self._compilable:
+             _write_count = self.__dict__.get("_write_count_value", None)
+             if _write_count is None:
+                 _write_count = self._write_count_value = mp.Value("i", 0)
+             _write_count.value = value
+         else:
+             self._write_count_value = value
+
+     def __getstate__(self):
+         state = super().__getstate__()
+         if get_spawning_popen() is None:
+             cursor = self._cursor
+             write_count = self._write_count
+             del state["_cursor_value"]
+             del state["_write_count_value"]
+             state["cursor__context"] = cursor
+             state["write_count__context"] = write_count
+         return state
+
+     def __setstate__(self, state):
+         cursor = state.pop("cursor__context", None)
+         write_count = state.pop("write_count__context", None)
+         if cursor is not None:
+             if not state["_compilable"]:
+                 _cursor_value = mp.Value("i", cursor)
+             else:
+                 _cursor_value = cursor
+             state["_cursor_value"] = _cursor_value
+         if write_count is not None:
+             if not state["_compilable"]:
+                 _write_count_value = mp.Value("i", write_count)
+             else:
+                 _write_count_value = write_count
+             state["_write_count_value"] = _write_count_value
+         self.__dict__.update(state)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}(cursor={int(self._cursor)}, full_storage={self._storage._is_full})"
+
+
+ class TensorDictRoundRobinWriter(RoundRobinWriter):
+     """A RoundRobin Writer class for composable, tensordict-based replay buffers."""
+
+     def add(self, data: Any) -> int | torch.Tensor:
+         index = self._cursor
+         # we need to update the cursor first to avoid race conditions between workers
+         max_size_along_dim0 = self._storage._max_size_along_dim0(single_data=data)
+         self._cursor = (index + 1) % max_size_along_dim0
+         self._write_count += 1
+         if not is_tensorclass(data):
+             data.set(
+                 "index",
+                 expand_as_right(
+                     torch.as_tensor(index, device=data.device, dtype=torch.long), data
+                 ),
+             )
+         self._storage.set(index, data)
+         index = self._replicate_index(index)
+         self._mark_update_entities(index)
+         return index
+
+     def extend(self, data: Sequence) -> torch.Tensor:
+         cur_size = self._cursor
+         batch_size = len(data)
+         device = data.device if hasattr(data, "device") else None
+         max_size_along_dim0 = self._storage._max_size_along_dim0(batched_data=data)
+         index = (
+             torch.arange(
+                 cur_size, batch_size + cur_size, dtype=torch.long, device=device
+             )
+             % max_size_along_dim0
+         )
+         # we need to update the cursor first to avoid race conditions between workers
+         self._cursor = (batch_size + cur_size) % max_size_along_dim0
+         self._write_count += batch_size
+         # storage must convert the data to the appropriate format if needed
+         if not is_tensorclass(data):
+             data.set(
+                 "index",
+                 expand_as_right(
+                     torch.as_tensor(index, device=data.device, dtype=torch.long), data
+                 ),
+             )
+         # Replicate index requires the shape of the storage to be known
+         # Other than that, a "flat" (1d) index is ok to write the data
+         self._storage.set(index, data)
+         index = self._replicate_index(index)
+         self._mark_update_entities(index)
+         return index
+
+
+ class TensorDictMaxValueWriter(Writer):
+     """A Writer class for composable replay buffers that keeps the top elements based on some ranking key.
+
+     Args:
+         rank_key (str or tuple of str): the key to rank the elements by. Defaults to ``("next", "reward")``.
+         reduction (str): the reduction method to use if the rank key has more than one element.
+             Can be ``"max"``, ``"min"``, ``"mean"``, ``"median"`` or ``"sum"``.
+
+     Examples:
+         >>> import torch
+         >>> from tensordict import TensorDict
+         >>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter
+         >>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+         >>> rb = TensorDictReplayBuffer(
+         ...     storage=LazyTensorStorage(1),
+         ...     sampler=SamplerWithoutReplacement(),
+         ...     batch_size=1,
+         ...     writer=TensorDictMaxValueWriter(rank_key="key"),
+         ... )
+         >>> td = TensorDict({
+         ...     "key": torch.tensor(range(10)),
+         ...     "obs": torch.tensor(range(10))
+         ... }, batch_size=10)
+         >>> rb.extend(td)
+         >>> print(rb.sample().get("obs").item())
+         9
+         >>> td = TensorDict({
+         ...     "key": torch.tensor(range(10, 20)),
+         ...     "obs": torch.tensor(range(10, 20))
+         ... }, batch_size=10)
+         >>> rb.extend(td)
+         >>> print(rb.sample().get("obs").item())
+         19
+         >>> td = TensorDict({
+         ...     "key": torch.tensor(range(10)),
+         ...     "obs": torch.tensor(range(10))
+         ... }, batch_size=10)
+         >>> rb.extend(td)
+         >>> print(rb.sample().get("obs").item())
+         19
+
+     .. note::
+         This class isn't compatible with storages with more than one dimension.
+         This doesn't mean that storing trajectories is prohibited, but that
+         the trajectories stored must be stored on a per-trajectory basis.
+         Here are some examples of valid and invalid usages of the class.
+         First, a flat buffer where we store individual transitions:
+
+         >>> from torchrl.data import TensorStorage
+         >>> # Simplest use case: data comes in 1d and is stored as such
+         >>> data = TensorDict({
+         ...     "obs": torch.zeros(10, 3),
+         ...     "reward": torch.zeros(10, 1),
+         ... }, batch_size=[10])
+         >>> rb = TensorDictReplayBuffer(
+         ...     storage=LazyTensorStorage(max_size=100),
+         ...     writer=TensorDictMaxValueWriter(rank_key="reward")
+         ... )
+         >>> # We initialize the buffer: a total of 100 *transitions* can be stored
+         >>> rb.extend(data)
+         >>> # Samples 5 *transitions* at random
+         >>> sample = rb.sample(5)
+         >>> assert sample.shape == (5,)
+
+         Second, a buffer where we store trajectories. The max signal is aggregated
+         in each batch (e.g. the reward of each rollout is summed):
+
+         >>> # One can also store batches of data, each batch being a sub-trajectory
+         >>> env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
+         >>> # Get a batch of [2, 10] -- format is [Batch, Time]
+         >>> rollout = env.rollout(max_steps=10)
+         >>> rb = TensorDictReplayBuffer(
+         ...     storage=LazyTensorStorage(max_size=100),
+         ...     writer=TensorDictMaxValueWriter(rank_key="reward")
+         ... )
+         >>> # We initialize the buffer: a total of 100 *trajectories* (!) can be stored
+         >>> rb.extend(rollout)
+         >>> # Sample 5 trajectories at random
+         >>> sample = rb.sample(5)
+         >>> assert sample.shape == (5, 10)
+
+         If data come in batch but a flat buffer is needed, we can simply flatten
+         the data before extending the buffer:
+
+         >>> rb = TensorDictReplayBuffer(
+         ...     storage=LazyTensorStorage(max_size=100),
+         ...     writer=TensorDictMaxValueWriter(rank_key="reward")
+         ... )
+         >>> # We initialize the buffer: a total of 100 *transitions* can be stored
+         >>> rb.extend(rollout.reshape(-1))
+         >>> # Sample 5 trajectories at random
+         >>> sample = rb.sample(5)
+         >>> assert sample.shape == (5,)
+
+         It is not possible to create a buffer that is extended along the time
+         dimension, which is usually the recommended way of using buffers with
+         batches of trajectories. Since trajectories are overlapping, it's hard
+         if not impossible to aggregate the reward values and compare them.
+         This constructor isn't valid (notice the ndim argument):
+
+         >>> rb = TensorDictReplayBuffer(
+         ...     storage=LazyTensorStorage(max_size=100, ndim=2),  # Breaks!
+         ...     writer=TensorDictMaxValueWriter(rank_key="reward")
+         ... )
+
+     """
+
+     def __init__(self, rank_key=None, reduction: str = "sum", **kwargs) -> None:
+         super().__init__(**kwargs)
+         self._cursor = 0
+         self._current_top_values = []
+         self._rank_key = rank_key
+         self._reduction = reduction
+         if self._rank_key is None:
+             self._rank_key = ("next", "reward")
+
+     def register_storage(self, storage: Storage) -> None:
+         if storage.ndim > 1:
+             raise ValueError(
+                 "TensorDictMaxValueWriter is not compatible with storages with more than one dimension. "
+                 "See the docstring constructor note about storing trajectories with TensorDictMaxValueWriter."
+             )
+         return super().register_storage(storage)
+
+     def get_insert_index(self, data: Any) -> int:
+         """Returns the index where the data should be inserted, or ``None`` if it should not be inserted."""
+         if not is_tensor_collection(data):
+             raise RuntimeError(
+                 f"{type(self)} expects data to be a tensor collection (tensordict or tensorclass). Found a {type(data)} instead."
+             )
+         if data.batch_dims > 1:
+             raise RuntimeError(
+                 "Expected input tensordict to have no more than 1 dimension, got"
+                 f"tensordict.batch_size = {data.batch_size}"
+             )
+
+         ret = None
+         rank_data = data.get(self._rank_key)
+
+         # If time dimension, sum along it.
+         if rank_data.numel() > 1:
+             rank_data = _reduce(rank_data.reshape(-1), self._reduction, dim=0)
+         else:
+             rank_data = rank_data.item()
+
+         if rank_data is None:
+             raise KeyError(f"Rank key {self._rank_key} not found in data.")
+
+         # If the buffer is not full, add the data
+         if len(self._current_top_values) < self._storage.max_size:
+             ret = self._cursor
+             self._cursor = (self._cursor + 1) % self._storage.max_size
+
+             # Add new reward to the heap
+             heapq.heappush(self._current_top_values, (rank_data, ret))
+
+         # If the buffer is full, check if the new data is better than the worst data in the buffer
+         elif rank_data > self._current_top_values[0][0]:
+
+             # retrieve position of the smallest value
+             min_sample = heapq.heappop(self._current_top_values)
+             ret = min_sample[1]
+
+             # Add new reward to the heap
+             heapq.heappush(self._current_top_values, (rank_data, ret))
+
+         return ret
+
+     @property
+     def _write_count(self):
+         _write_count = self.__dict__.get("_write_count_value", None)
+         if _write_count is None:
+             _write_count = self._write_count_value = mp.Value("i", 0)
+         return _write_count.value
+
+     @_write_count.setter
+     def _write_count(self, value):
+         _write_count = self.__dict__.get("_write_count_value", None)
+         if _write_count is None:
+             _write_count = self._write_count_value = mp.Value("i", 0)
+         _write_count.value = value
+
+     def add(self, data: Any) -> int | torch.Tensor:
+         """Inserts a single element of data at an appropriate index, and returns that index.
+
+         The ``rank_key`` in the data passed to this module should be structured as [].
+         If it has more dimensions, it will be reduced to a single value using the ``reduction`` method.
+         """
+         index = self.get_insert_index(data)
+         if index is not None:
+             data.set("index", index)
+             self._write_count += 1
+             # Replicate index requires the shape of the storage to be known
+             # Other than that, a "flat" (1d) index is ok to write the data
+             self._storage.set(index, data)
+             index = self._replicate_index(index)
+             for ent in self._storage._attached_entities_iter():
+                 ent.mark_update(index)
+         return index
+
+     def extend(self, data: TensorDictBase) -> None:
+         """Inserts a series of data points at appropriate indices.
+
+         The ``rank_key`` in the data passed to this module should be structured as [B].
+         If it has more dimensions, it will be reduced to a single value using the ``reduction`` method.
+         """
+         # a map of [idx_in_storage, idx_in_data]
+         data_to_replace = {}
+         for data_idx, sample in enumerate(data):
+             storage_idx = self.get_insert_index(sample)
+             if storage_idx is not None:
+                 self._write_count += 1
+                 data_to_replace[storage_idx] = data_idx
+
+         # -1 will be interpreted as invalid by prioritized buffers
+         # Replace the data in the storage all at once
+         if len(data_to_replace) > 0:
+             storage_idx, data_idx = zip(*data_to_replace.items())
+             index = data.get("index", None)
+             dtype = index.dtype if index is not None else torch.long
+             device = index.device if index is not None else data.device
+             out_index = torch.full(data.shape, -1, dtype=torch.long, device=device)
+             data_idx = torch.as_tensor(data_idx, dtype=dtype, device=device)
+             storage_idx = torch.as_tensor(storage_idx, dtype=dtype, device=device)
+             out_index[data_idx] = storage_idx
+             self._storage.set(storage_idx, data[data_idx])
+         else:
+             device = getattr(self._storage, "device", None)
+             out_index = torch.full(data.shape, -1, dtype=torch.long, device=device)
+         index = self._replicate_index(out_index)
+         self._mark_update_entities(index)
+         return index
+
+     # TODO: Workaround for PyTorch nightly regression where compiler can't handle
+     # method calls on objects returned from _attached_entities_iter()
+     @compile_disable()
+     def _mark_update_entities(self, index: torch.Tensor) -> None:
+         """Mark entities as updated with the given index."""
+         for ent in self._storage._attached_entities_iter():
+             ent.mark_update(index)
+
+     def _empty(self, empty_write_count: bool = True) -> None:
+         self._cursor = 0
+         self._current_top_values = []
+         if empty_write_count:
+             self._write_count = 0
+
+     def __getstate__(self):
+         if get_spawning_popen() is not None:
+             raise RuntimeError(
+                 f"Writers of type {type(self)} cannot be shared between processes. "
+                 f"Please submit an issue at https://github.com/pytorch/rl if this feature is needed."
+             )
+         state = super().__getstate__()
+         # Handle the mp.Value object for pickling
+         if "_write_count_value" in state:
+             write_count = self._write_count
+             del state["_write_count_value"]
+             state["write_count__context"] = write_count
+         return state
+
+     def __setstate__(self, state):
+         write_count = state.pop("write_count__context", None)
+         if write_count is not None:
+             state["_write_count_value"] = mp.Value("i", write_count)
+         self.__dict__.update(state)
+
+     def dumps(self, path):
+         path = Path(path).absolute()
+         path.mkdir(exist_ok=True)
+         t = torch.as_tensor(self._current_top_values)
+         try:
+             MemoryMappedTensor.from_filename(
+                 filename=path / "current_top_values.memmap",
+                 shape=t.shape,
+                 dtype=t.dtype,
+             ).copy_(t)
+         except FileNotFoundError:
+             MemoryMappedTensor.from_tensor(
+                 t, filename=path / "current_top_values.memmap"
+             )
+         with open(path / "metadata.json", "w") as file:
+             json.dump(
+                 {
+                     "cursor": self._cursor,
+                     "rank_key": self._rank_key,
+                     "dtype": str(t.dtype),
+                     "shape": list(t.shape),
+                 },
+                 file,
+             )
+
+     def loads(self, path):
+         path = Path(path).absolute()
+         with open(path / "metadata.json") as file:
+             metadata = json.load(file)
+         self._cursor = metadata["cursor"]
+         self._rank_key = metadata["rank_key"]
+         shape = torch.Size(metadata["shape"])
+         dtype = metadata["dtype"]
+         self._current_top_values = MemoryMappedTensor.from_filename(
+             filename=path / "current_top_values.memmap",
+             dtype=_STRDTYPE2DTYPE[dtype],
+             shape=shape,
+         ).tolist()
+
+     def state_dict(self) -> dict[str, Any]:
+         raise NotImplementedError
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         raise NotImplementedError
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}(cursor={int(self._cursor)}, full_storage={self._storage._is_full}, rank_key={self._rank_key}, reduction={self._reduction})"
+
+
+ class WriterEnsemble(Writer):
+     """An ensemble of writers.
+
+     This class is designed to work with :class:`~torchrl.data.replay_buffers.replay_buffers.ReplayBufferEnsemble`.
+     It contains the writers but blocks writing with any of them.
+
+     Args:
+         writers (sequence of Writer): the writers to make the composite writer.
+
+     .. warning::
+         This class does not support writing.
+         To extend one of the replay buffers, simply index the parent
+         :class:`~torchrl.data.ReplayBufferEnsemble` object.
+
+     """
+
+     def __init__(self, *writers):
+         self._rng_private = None
+         self._writers = writers
+
+     @property
+     def _rng(self):
+         return self._rng_private
+
+     @_rng.setter
+     def _rng(self, value):
+         self._rng_private = value
+         for writer in self._writers:
+             writer._rng = value
+
+     def _empty(self, empty_write_count: bool = True) -> None:
+         raise NotImplementedError
+
+     def dumps(self, path: Path):
+         path = Path(path).absolute()
+         for i, writer in enumerate(self._writers):
+             writer.dumps(path / str(i))
+
+     def loads(self, path: Path):
+         path = Path(path).absolute()
+         for i, writer in enumerate(self._writers):
+             writer.loads(path / str(i))
+
+     def add(self):
+         raise NotImplementedError
+
+     def extend(self):
+         raise NotImplementedError
+
+     _INDEX_ERROR = "Expected an index of type torch.Tensor, range, np.ndarray, int, slice or ellipsis, got {} instead."
+
+     def __getitem__(self, index):
+         if isinstance(index, tuple):
+             if index[0] is Ellipsis:
+                 index = (slice(None), index[1:])
+             result = self[index[0]]
+             if len(index) > 1:
+                 raise IndexError(
+                     f"Tuple of length greater than 1 are not accepted to index writers of type {type(self)}."
+                 )
+             return result
+         if isinstance(index, slice) and index == slice(None):
+             return self
+         if isinstance(index, (list, range, np.ndarray)):
+             index = torch.as_tensor(index)
+         if isinstance(index, torch.Tensor):
+             if index.ndim > 1:
+                 raise RuntimeError(
+                     f"Cannot index a {type(self)} with tensor indices that have more than one dimension."
+                 )
+             if index.is_floating_point():
+                 raise TypeError(
+                     "A floating point index was received when an integer dtype was expected."
+                 )
+         if isinstance(index, int) or (not isinstance(index, slice) and len(index) == 0):
+             try:
+                 index = int(index)
+             except Exception:
+                 raise IndexError(self._INDEX_ERROR.format(type(index)))
+             try:
+                 return self._writers[index]
+             except IndexError:
+                 raise IndexError(self._INDEX_ERROR.format(type(index)))
+         if isinstance(index, torch.Tensor):
+             index = index.tolist()
+             writers = [self._writers[i] for i in index]
+         else:
+             # slice
+             writers = self._writers[index]
+         return WriterEnsemble(*writers)
+
+     def __len__(self):
+         return len(self._writers)
+
+     def __repr__(self):
+         writers = textwrap.indent(f"writers={self._writers}", " " * 4)
+         return f"WriterEnsemble(\n{writers})"
+
+     def state_dict(self) -> dict[str, Any]:
+         raise NotImplementedError
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         raise NotImplementedError
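
For orientation only (not part of the packaged file above): the writers defined in this module are not used on their own, they are handed to a replay buffer, which delegates its add/extend calls to them. Below is a minimal usage sketch assuming only the public torchrl.data replay-buffer API that this module backs; the buffer size, data and batch size are made up for illustration.

    # Hypothetical usage sketch, not shipped in the wheel.
    import torch
    from torchrl.data import LazyTensorStorage, ReplayBuffer
    from torchrl.data.replay_buffers.writers import RoundRobinWriter

    buffer = ReplayBuffer(
        storage=LazyTensorStorage(100),   # storage is allocated lazily on the first write
        writer=RoundRobinWriter(),        # default writer: overwrites the oldest entries once full
        batch_size=8,
    )
    indices = buffer.extend(torch.arange(20))  # the writer returns the indices it wrote to
    sample = buffer.sample()                   # draws 8 of the 20 stored elements

Swapping RoundRobinWriter for TensorDictMaxValueWriter (together with a TensorDictReplayBuffer) turns the buffer into a top-k store ranked by rank_key, as the docstring examples in the diff above demonstrate.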