torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,363 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import abc
8
+ import functools
9
+ from abc import abstractmethod
10
+ from collections.abc import Callable
11
+ from typing import Any, Generic, TypeVar
12
+
13
+ import torch
14
+ from tensordict import is_tensor_collection, NestedKey, TensorDictBase
15
+ from tensordict.nn.common import TensorDictModuleBase
16
+
17
+ from torchrl.data.map.hash import RandomProjectionHash, SipHash
18
+ from torchrl.data.map.query import QueryModule
19
+ from torchrl.data.replay_buffers.storages import (
20
+ _get_default_collate,
21
+ LazyTensorStorage,
22
+ TensorStorage,
23
+ )
24
+
25
K = TypeVar("K")
V = TypeVar("V")


class TensorMap(abc.ABC, Generic[K, V]):
    """Abstract interface of a map-like storage from keys ``K`` to values ``V``.

    This class is for internal use, please use derived classes instead.
    """

    @abc.abstractmethod
    def __getitem__(self, item: K) -> V:
        """Return the value associated with ``item``."""
        raise NotImplementedError

    @abc.abstractmethod
    def __setitem__(self, key: K, value: V) -> None:
        """Associate ``value`` with ``key``."""
        raise NotImplementedError

    @abc.abstractmethod
    def __len__(self) -> int:
        """Return the number of elements held by the map."""
        raise NotImplementedError

    @abc.abstractmethod
    def contains(self, item: K) -> torch.Tensor:
        """Tell whether ``item`` is present in the map."""
        raise NotImplementedError

    @abc.abstractmethod
    def clear(self) -> None:
        """Remove the content of the map."""
        raise NotImplementedError

    def __contains__(self, item):
        # The ``in`` operator simply delegates to :meth:`contains`.
        return self.contains(item)
57
+
58
+
59
class TensorDictMap(
    TensorDictModuleBase, TensorMap[TensorDictModuleBase, TensorDictModuleBase]
):
    """A Map-Storage for TensorDict.

    This module resembles a storage. It takes a tensordict as its input and
    returns another tensordict as output similar to TensorDictModuleBase. However,
    it provides additional functionality like a python map: values are written
    with ``map[key_td] = value_td``, read back with ``map[key_td]`` and looked
    up with :meth:`contains`.

    Keyword Args:
        query_module (TensorDictModuleBase): a query module, typically an instance of
            :class:`~torchrl.data.map.QueryModule`, used to map a set of tensordict
            entries to a hash key.
        storage (Dict[NestedKey, TensorMap[torch.Tensor, torch.Tensor]]):
            a dictionary representing the map from an index key to a tensor storage.

            .. note:: NOTE(review): the methods of this class call ``max_size``,
                ``set``, ``contains``, ``len`` and tensor indexing directly on
                ``storage``, and :meth:`from_tensordict_pair` passes a single
                storage instance. Confirm whether a plain dictionary is still
                a supported input.
        collate_fn (callable, optional): a function to use to collate samples from the
            storage. Defaults to a custom value for each known storage type (stack for
            :class:`~torchrl.data.ListStorage`, identity for :class:`~torchrl.data.TensorStorage`
            subtypes and others).
        out_keys (list of NestedKey, optional): the keys of the values to store.
            When provided, values are restricted to these keys before being written.
            When omitted, the keys are read lazily from the storage content.
        write_fn (callable, optional): a function applied to the values before they
            are written. It is called as ``write_fn(new_value, stored_value)`` for
            the entries that are already present in the map and as
            ``write_fn(new_value)`` for the others.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from typing import cast
        >>> from torchrl.data import LazyTensorStorage
        >>> query_module = QueryModule(
        ...     in_keys=["key1", "key2"],
        ...     index_key="index",
        ... )
        >>> embedding_storage = LazyTensorStorage(1000)
        >>> tensor_dict_storage = TensorDictMap(
        ...     query_module=query_module,
        ...     storage={"out": embedding_storage},
        ... )
        >>> index = TensorDict(
        ...     {
        ...         "key1": torch.Tensor([[-1], [1], [3], [-3]]),
        ...         "key2": torch.Tensor([[0], [2], [4], [-4]]),
        ...     },
        ...     batch_size=(4,),
        ... )
        >>> value = TensorDict(
        ...     {"out": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
        ... )
        >>> tensor_dict_storage[index] = value
        >>> tensor_dict_storage[index]
        TensorDict(
            fields={
                out: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4]),
            device=None,
            is_shared=False)
        >>> assert torch.sum(tensor_dict_storage.contains(index)).item() == 4
        >>> new_index = index.clone(True)
        >>> new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
        >>> retrieve_value = tensor_dict_storage[new_index]
        >>> assert cast(torch.Tensor, retrieve_value["index"] == value["index"]).all()
    """

    def __init__(
        self,
        *,
        query_module: QueryModule,
        storage: dict[NestedKey, TensorMap[torch.Tensor, torch.Tensor]],
        collate_fn: Callable[[Any], Any] | None = None,
        out_keys: list[NestedKey] | None = None,
        write_fn: Callable[[Any, Any], Any] | None = None,
    ):
        super().__init__()

        self.in_keys = query_module.in_keys
        if out_keys is not None:
            # Goes through the ``out_keys`` setter, which marks the keys as explicit.
            self.out_keys = out_keys

        self.query_module = query_module
        self.index_key = query_module.index_key
        self.storage = storage
        # Set by ``_maybe_add_batch`` and read by ``_maybe_remove_batch``.
        self.batch_added = False
        if collate_fn is None:
            collate_fn = _get_default_collate(self.storage)
        self.collate_fn = collate_fn
        self.write_fn = write_fn

    @property
    def max_size(self):
        """The ``max_size`` attribute of the underlying storage."""
        return self.storage.max_size

    @property
    def out_keys(self) -> list[NestedKey]:
        """The keys of the values held by the map.

        Explicitly assigned keys are returned as they are. Otherwise the keys
        are read from the content of the storage (when the storage is a
        :class:`~torchrl.data.TensorStorage` holding a tensor collection) and
        cached.

        Raises:
            AttributeError: if no keys were assigned and none can be read from
                the storage.
        """
        cached = self.__dict__.get("_out_keys_and_lazy")
        if cached is not None:
            return cached[0]
        storage = self.storage
        if isinstance(storage, TensorStorage) and is_tensor_collection(
            storage._storage
        ):
            out_keys = list(storage._storage.keys(True, True))
            # The ``True`` flag records that these keys were inferred, not assigned.
            self._out_keys_and_lazy = (out_keys, True)
            return out_keys
        raise AttributeError(
            f"No out-keys found in the storage of type {type(storage)}"
        )

    @out_keys.setter
    def out_keys(self, value):
        # The ``False`` flag records that these keys were assigned explicitly.
        self._out_keys_and_lazy = (value, False)

    def _has_lazy_out_keys(self):
        """Return ``True`` unless the out-keys were explicitly assigned."""
        cached = self.__dict__.get("_out_keys_and_lazy")
        if cached is None:
            return True
        return cached[1]

    @classmethod
    def from_tensordict_pair(
        cls,
        source,
        dest,
        in_keys: list[NestedKey],
        out_keys: list[NestedKey] | None = None,
        max_size: int = 1000,
        storage_constructor: type | None = None,
        hash_module: Callable | None = None,
        collate_fn: Callable[[Any], Any] | None = None,
        write_fn: Callable[[Any, Any], Any] | None = None,
        consolidated: bool | None = None,
    ) -> TensorDictMap:
        """Creates a new TensorDictMap from a pair of tensordicts (source and dest) using pre-defined rules of thumb.

        Args:
            source (TensorDict): An example of source tensordict, used as index in the storage.
            dest (TensorDict): An example of dest tensordict, used as data in the storage.
                This argument is currently not read by the constructor.
            in_keys (List[NestedKey]): a list of keys to use in the map.
            out_keys (List[NestedKey]): a list of keys to return in the output tensordict.
                All keys absent from out_keys, even if present in ``dest``, will not be stored
                in the storage. Defaults to ``None`` (all keys are registered).
            max_size (int, optional): the maximum number of elements in the storage. Ignored if the
                ``storage_constructor`` is passed. Defaults to ``1000``.
            storage_constructor (Type, optional): a type of tensor storage, called
                without positional arguments.
                Defaults to :class:`~torchrl.data.LazyTensorStorage` built with ``max_size``.
            hash_module (Callable, optional): a hash function to use in the :class:`~torchrl.data.map.QueryModule`.
                Defaults to :class:`SipHash` for low-dimensional inputs, and :class:`~torchrl.data.map.RandomProjectionHash`
                for larger inputs.
            collate_fn (callable, optional): a function to use to collate samples from the
                storage. Defaults to a custom value for each known storage type (stack for
                :class:`~torchrl.data.ListStorage`, identity for :class:`~torchrl.data.TensorStorage`
                subtypes and others).
            write_fn (callable, optional): a function applied to the values before they
                are written in the storage. It receives the new value and, for entries
                that are already present, the stored value as a second argument.
            consolidated (bool, optional): whether to consolidate the storage in a single storage tensor.
                Defaults to ``False``. When a ``storage_constructor`` is passed, the flag is only
                forwarded to it if it is not ``None``.

        Examples:
            >>> # The following example requires torchrl and gymnasium to be installed
            >>> from torchrl.envs import GymEnv
            >>> torch.manual_seed(0)
            >>> env = GymEnv("CartPole-v1")
            >>> env.set_seed(0)
            >>> rollout = env.rollout(100)
            >>> source, dest = rollout.exclude("next"), rollout.get("next")
            >>> storage = TensorDictMap.from_tensordict_pair(
            ...     source, dest,
            ...     in_keys=["observation", "action"],
            ... )
            >>> # maps the (obs, action) tuple to a corresponding next state
            >>> storage[source] = dest
            >>> print(source["_index"])
            tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13])
            >>> storage[source]
            TensorDict(
                fields={
                    done: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    observation: Tensor(shape=torch.Size([14, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                    reward: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                    terminated: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([14, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([14]),
                device=None,
                is_shared=False)

        """
        # Build query module
        if hash_module is None:
            # One hash module per in-key: entries whose feature dimension exceeds
            # RandomProjectionHash._N_COMPONENTS_DEFAULT are first projected to
            # that dimensionality, the others are hashed directly.
            hash_module = []
            for in_key in in_keys:
                entry = source[in_key]
                # this is a good example of why td/tc are useful - carrying metadata
                #  allows us to know if there's a feature dim or not
                n_feat = 0 if entry.ndim == source.ndim else entry.shape[-1]
                if n_feat > RandomProjectionHash._N_COMPONENTS_DEFAULT:
                    _hash_module = RandomProjectionHash()
                else:
                    _hash_module = SipHash()
                hash_module.append(_hash_module)
        query_module = QueryModule(in_keys, hash_module=hash_module)

        # Build key_to_storage
        if storage_constructor is None:
            storage_constructor = functools.partial(
                LazyTensorStorage, max_size, consolidated=bool(consolidated)
            )
        elif consolidated is not None:
            storage_constructor = functools.partial(
                storage_constructor, consolidated=consolidated
            )
        storage = storage_constructor()
        return cls(
            query_module=query_module,
            storage=storage,
            collate_fn=collate_fn,
            out_keys=out_keys,
            write_fn=write_fn,
        )

    def clear(self) -> None:
        """Clears every entry returned by ``self.storage.values()``."""
        # NOTE(review): every other method treats ``self.storage`` as a single
        # storage object (``set``, ``contains``, ``max_size``) rather than a
        # mapping; confirm that such objects expose ``values()`` or update this
        # method, and check whether the query module should be reset as well.
        for mem in self.storage.values():
            mem.clear()

    def _to_index(
        self, item: TensorDictBase, extend: bool, clone: bool | None = None
    ) -> torch.Tensor:
        """Runs the query module on ``item`` and returns the entry stored under ``index_key``."""
        item = self.query_module(item, extend=extend, clone=clone)
        return item[self.index_key]

    def _maybe_add_batch(
        self, item: TensorDictBase, value: TensorDictBase | None
    ) -> tuple[TensorDictBase, TensorDictBase | None]:
        """Adds a leading batch dimension to ``item`` (and ``value``) if ``item`` has none.

        The outcome is recorded in ``self.batch_added`` so that
        ``_maybe_remove_batch`` can undo the operation on the result.

        Returns:
            the ``(item, value)`` pair, unsqueezed if needed.
        """
        self.batch_added = False
        if len(item.batch_size) == 0:
            self.batch_added = True

            item = item.unsqueeze(dim=0)
            if value is not None:
                value = value.unsqueeze(dim=0)

        return item, value

    def _maybe_remove_batch(
        self, item: TensorDictBase | torch.Tensor
    ) -> TensorDictBase | torch.Tensor:
        """Removes the leading dimension if the last ``_maybe_add_batch`` call added one."""
        if self.batch_added:
            item = item.squeeze(dim=0)
        return item

    def __getitem__(self, item: TensorDictBase) -> TensorDictBase:
        """Returns the collated values stored under the indices computed from ``item``."""
        # Work on a copy so that the query module does not write in the caller's data.
        item = item.copy()
        item, _ = self._maybe_add_batch(item, None)

        index = self._to_index(item, extend=False, clone=False)

        res = self.storage[index]
        res = self.collate_fn(res)
        res = self._maybe_remove_batch(res)
        return res

    def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
        """Writes ``value`` in the storage at the indices computed from ``item``.

        The query module is called with ``extend=True`` on ``item`` itself
        (no copy is made here).
        """
        if not self._has_lazy_out_keys():
            # TODO: make this work with pytrees and avoid calling select if keys match
            value = value.select(*self.out_keys, strict=False)
        item, value = self._maybe_add_batch(item, value)
        index = self._to_index(item, extend=True)
        if index.unique().numel() < index.numel():
            # If multiple values point to the same place in the storage, we cannot process them by batch
            # There could be a better way to deal with this, using unique ids.
            vals = []
            for it, val in zip(item.split(1), value.split(1)):
                self[it] = val
                vals.append(val)
            # __setitem__ may affect the content of the input data
            value.update(TensorDictBase.lazy_stack(vals))
            return
        if self.write_fn is not None:
            # We use this block in the following context: the value written in the storage is already present,
            # but it needs to be updated.
            # We first check if the value is already there using `contains`. If so, we pass the new value and the
            # previous one to write_fn. The values that are not present are passed alone.
            if len(self):
                modifiable = self.contains(item)
                if modifiable.any():
                    to_modify = (value[modifiable], self[item[modifiable]])
                    v1 = self.write_fn(*to_modify)
                    result = value.empty()
                    result[modifiable] = v1
                    result[~modifiable] = self.write_fn(value[~modifiable])
                    value = result
                else:
                    value = self.write_fn(value)
            else:
                value = self.write_fn(value)
        self.storage.set(index, value)

    def __len__(self):
        """Returns the length of the underlying storage."""
        return len(self.storage)

    def contains(self, item: TensorDictBase) -> torch.Tensor:
        """Returns the result of ``storage.contains`` for the indices computed from ``item``."""
        item, _ = self._maybe_add_batch(item, None)
        # ``clone=True`` is forwarded to the query module so that the lookup
        # operates on a copy of ``item``.
        index = self._to_index(item, extend=False, clone=True)

        res = self.storage.contains(index)
        res = self._maybe_remove_batch(res)
        return res