torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,185 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from collections.abc import Callable
8
+
9
+ import torch
10
+ from torch.nn import Module
11
+
12
+
13
class BinaryToDecimal(Module):
    """A Module to convert binary-encoded tensors to decimals.

    This is a utility class that converts a binary encoding tensor (e.g. ``1001``) to
    its decimal value (e.g. ``9``).

    Args:
        num_bits (int): the number of bits to use for the bases table.
            The number of bits must be lower or equal to the input length and the input length
            must be divisible by ``num_bits``. If ``num_bits`` is lower than the number of
            bits in the input, the end result will be aggregated on the last dimension using
            :func:`~torch.sum`.
        device (torch.device): the device where inputs and outputs are to be expected.
        dtype (torch.dtype): the dtype of the bases table; the (possibly binarized)
            input is cast to this dtype before being multiplied with the bases.
        convert_to_binary (bool, optional): if ``True``, the input to the ``forward``
            method will be cast to a binary input using :func:`~torch.heaviside`
            (strictly positive entries become ``1``, all others ``0``).
            Defaults to ``False``.

    Examples:
        >>> binary_to_decimal = BinaryToDecimal(
        ...    num_bits=4, device="cpu", dtype=torch.int32, convert_to_binary=True
        ... )
        >>> binary = torch.Tensor([[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 10, 0]])
        >>> decimal = binary_to_decimal(binary)
        >>> assert decimal.shape == (2,)
        >>> assert (decimal == torch.Tensor([3, 2])).all()
    """

    def __init__(
        self,
        num_bits: int,
        device: torch.device,
        dtype: torch.dtype,
        convert_to_binary: bool = False,
    ):
        super().__init__()
        self.convert_to_binary = convert_to_binary
        # Powers of two, most-significant bit first: [2**(num_bits-1), ..., 2, 1].
        self.bases = 2 ** torch.arange(num_bits - 1, -1, -1, device=device, dtype=dtype)
        self.num_bits = num_bits
        # Value heaviside assigns to entries that are exactly zero.
        self.zero_tensor = torch.zeros((1,), device=device)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Convert the last dimension of ``features`` from binary to decimal.

        Args:
            features (torch.Tensor): tensor whose last dimension holds the bits.

        Returns:
            A 1-D tensor with one entry per leading (batch) element. When the last
            dimension holds several groups of ``num_bits`` bits, the decimal values
            of the groups are summed.

        Raises:
            ValueError: if the last dimension is smaller than ``num_bits`` or is
                not a multiple of it.
        """
        num_features = features.shape[-1]
        if self.num_bits > num_features:
            raise ValueError(f"{num_features=} is less than {self.num_bits=}")
        elif num_features % self.num_bits != 0:
            raise ValueError(f"{num_features=} is not divisible by {self.num_bits=}")

        if self.convert_to_binary:
            # torch.heaviside requires `input` and `values` to share a dtype, so
            # the zero tensor is cast to whatever dtype the features come in.
            binary_features = torch.heaviside(
                features, self.zero_tensor.to(features.dtype)
            )
        else:
            binary_features = features

        # One row per group of `num_bits` bits.
        feature_parts = binary_features.reshape(shape=(-1, self.num_bits))
        digits = torch.vmap(torch.dot, (None, 0))(
            self.bases, feature_parts.to(self.bases.dtype)
        )
        # Regroup the per-chunk values by batch element, then add them up.
        digits = digits.reshape(shape=(-1, num_features // self.num_bits))
        aggregated_digits = torch.sum(digits, dim=-1)
        return aggregated_digits
73
+
74
+
75
class SipHash(Module):
    """A Module to compute hash values for the rows of a tensor.

    Each entry along the first dimension of the input is serialized to bytes and,
    optionally, hashed with the builtin ``hash`` function. Input tensors should have
    shape ``[batch_size, num_features]`` and the output shape will be ``[batch_size]``.

    Args:
        as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
            through the builtin ``hash`` function and mapped to a tensor. If ``False``,
            the list of raw bytes is returned instead. Default: ``True``.

    .. warning:: This module relies on the builtin ``hash`` function.
        To get reproducible results across runs, the ``PYTHONHASHSEED`` environment
        variable must be set before the code is run (changing this value during code
        execution is without effect).

    Examples:
        >>> # Assuming we set PYTHONHASHSEED=0 prior to running this code
        >>> a = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
        >>> b = a.clone()
        >>> hash_module = SipHash(as_tensor=True)
        >>> hash_a = hash_module(a)
        >>> hash_a
        tensor([-4669941682990263259, -3778166555168484291, -9122128731510687521])
        >>> hash_b = hash_module(b)
        >>> assert (hash_a == hash_b).all()
    """

    def __init__(self, as_tensor: bool = True):
        super().__init__()
        self.as_tensor = as_tensor

    def forward(self, x: torch.Tensor) -> torch.Tensor | list[bytes]:
        """Hash every entry along the first dimension of ``x``.

        Returns:
            An ``int64`` tensor of hashes when ``as_tensor`` is ``True``, otherwise
            the list of byte strings the hashes would have been computed from.
        """
        # bfloat16 tensors cannot be exported to numpy; go through float16 instead.
        if x.dtype == torch.bfloat16:
            x = x.to(torch.float16)
        rows = x.detach().cpu().numpy()
        raw_bytes = [row.tobytes() for row in rows]
        if self.as_tensor:
            return torch.tensor([hash(chunk) for chunk in raw_bytes], dtype=torch.int64)
        return raw_bytes
117
+
118
+
119
class RandomProjectionHash(SipHash):
    """A module that combines random projections with SipHash to get a low-dimensional tensor, easier to embed through :class:`~.SipHash`.

    The input is cast to ``dtype_cast``, multiplied by a projection matrix of shape
    ``[num_features, n_components]`` and the result is hashed by the parent class.
    The projection matrix is created lazily: it is materialized and initialized by
    :meth:`fit`, which :meth:`forward` calls automatically on first use.

    .. note:: The projection is implemented with torch operations only.

    Keyword Args:
        n_components (int, optional): the low-dimensional number of components of the projections.
            Defaults to 16.
        dtype_cast (torch.dtype, optional): the dtype to cast the projection to.
            Defaults to ``torch.bfloat16``.
        as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
            through the builtin ``hash`` function and mapped to a tensor. Default: ``True``.
        init_method (Callable[[torch.Tensor], torch.Tensor | None], optional): a callable
            applied to the freshly materialized projection matrix to fill it in place.
            Defaults to :func:`torch.nn.init.normal_`.
        **kwargs: accepted for signature compatibility and ignored.

    .. warning:: This module relies on the builtin ``hash`` function.
        To get reproducible results across runs, the ``PYTHONHASHSEED`` environment
        variable must be set before the code is run (changing this value during code
        execution is without effect).

    """

    _N_COMPONENTS_DEFAULT = 16

    def __init__(
        self,
        *,
        n_components: int | None = None,
        dtype_cast=torch.bfloat16,
        as_tensor: bool = True,
        init_method: Callable[[torch.Tensor], torch.Tensor | None] | None = None,
        **kwargs,
    ):
        if n_components is None:
            n_components = self._N_COMPONENTS_DEFAULT

        super().__init__(as_tensor=as_tensor)
        self.register_buffer("_n_components", torch.as_tensor(n_components))

        # Plain Python flag: tracks whether `fit` has materialized the projection.
        self._init = False
        if init_method is None:
            init_method = torch.nn.init.normal_
        self.init_method = init_method

        self.dtype_cast = dtype_cast
        # The shape depends on the input feature size, only known at `fit` time.
        self.register_buffer("transform", torch.nn.UninitializedBuffer())

    @property
    def n_components(self):
        """The number of components of the projection, as a Python number."""
        return self._n_components.item()

    def fit(self, x):
        """Fits the random projection to the input data.

        Materializes the projection matrix with shape
        ``(x.shape[-1], n_components)`` on the device of ``x`` and fills it with
        ``init_method``.
        """
        self.transform.materialize(
            (x.shape[-1], self.n_components), dtype=self.dtype_cast, device=x.device
        )
        self.init_method(self.transform)
        self._init = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to ``n_components`` features and hash the result."""
        # Lazy initialization: the first batch seen determines the input size.
        if not self._init:
            self.fit(x)
        x = x.to(self.dtype_cast) @ self.transform
        return super().forward(x)
@@ -0,0 +1,204 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from collections.abc import Callable, Mapping
8
+
9
+ from copy import deepcopy
10
+ from typing import Any, TypeVar
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from tensordict import NestedKey, TensorDictBase
15
+ from tensordict.nn.common import TensorDictModuleBase
16
+ from torchrl._utils import logger as torchrl_logger
17
+ from torchrl.data.map.hash import SipHash
18
+
19
+ K = TypeVar("K")
20
+ V = TypeVar("V")
21
+
22
+
23
class HashToInt(nn.Module):
    """Converts a hash value to an integer that can be used for indexing a contiguous storage.

    The module keeps a dictionary that maps every hash value it has registered to
    a consecutive integer (``0, 1, 2, ...`` in order of first registration).
    """

    def __init__(self):
        super().__init__()
        self._index_to_index = {}

    def __call__(self, key: torch.Tensor, extend: bool = False) -> torch.Tensor:
        """Map each hash in ``key`` to its integer index.

        Args:
            key (torch.Tensor): the hash values to look up.
            extend (bool, optional): if ``True``, unknown hashes are registered and
                assigned the next free index. If ``False``, unknown hashes are mapped
                to the current number of registered hashes (an index that is not
                assigned yet) and the mapping is left untouched. Defaults to ``False``.

        Returns:
            A tensor of indices with the device and dtype of ``key``.
        """
        table = self._index_to_index
        if extend:
            # len(table) is re-evaluated per item so each new hash gets a fresh index.
            result = [table.setdefault(item, len(table)) for item in key.tolist()]
        else:
            unknown = len(table)
            result = [table.get(item, unknown) for item in key.tolist()]
        return torch.tensor(result, device=key.device, dtype=key.dtype)

    def state_dict(
        self, *args, destination=None, prefix: str = "", keep_vars: bool = False
    ) -> dict[str, torch.Tensor]:
        """Return the hash-to-index mapping as two aligned tensors.

        The signature mirrors :meth:`torch.nn.Module.state_dict` so that the module
        can be serialized as part of a parent module, which passes ``destination``
        and ``prefix`` to its children.

        Returns:
            ``destination`` (a new dict when none is given), with the entries
            ``prefix + "keys"`` and ``prefix + "values"`` set.
        """
        if args:
            # Legacy positional form: (destination, prefix, keep_vars).
            destination = args[0]
            if len(args) > 1:
                prefix = args[1]
        if destination is None:
            destination = {}
        table = self._index_to_index
        # torch.tensor cannot consume dict views directly: materialize lists first.
        destination[prefix + "keys"] = torch.tensor(list(table.keys()))
        destination[prefix + "values"] = torch.tensor(
            list(table.values()), dtype=torch.int64
        )
        return destination

    def load_state_dict(
        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
    ):
        """Restore the mapping from the ``"keys"`` and ``"values"`` entries of ``state_dict``."""
        keys = state_dict["keys"]
        values = state_dict["values"]
        self._index_to_index = dict(zip(keys.tolist(), values.tolist()))

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # Called when a parent module loads its state dict: pick up the prefixed
        # entries written by `state_dict`. Absent entries leave the mapping as is.
        keys = state_dict.get(prefix + "keys")
        values = state_dict.get(prefix + "values")
        if keys is not None and values is not None:
            self._index_to_index = dict(zip(keys.tolist(), values.tolist()))
57
+
58
+
59
class QueryModule(TensorDictModuleBase):
    """A Module to generate compatible indices for storage.

    A module that queries a storage and returns the required index of that storage.
    Each in-key is hashed with its own hash module; when there are several in-keys
    the per-key hashes are stacked and reduced to a single hash by the aggregator.
    The resulting hash is then mapped to an index through ``hash_to_int``.

    Args:
        in_keys (list of NestedKeys): keys of the input tensordict that
            will be used to generate the hash value.
        index_key (NestedKey): the output key where the index value will be written.
            Defaults to ``"_index"``.
        hash_key (NestedKey): the output key where the hash value will be written.
            Defaults to ``"_hash"``.

    Keyword Args:
        hash_module (Callable[[Any], int] or a list of these, optional): a hash
            module similar to :class:`~torchrl.data.map.hash.SipHash` (default).
            If a list of callables is provided, its length must equate the number of in_keys.
            If a single callable is provided with several in_keys, it is deep-copied
            once per in-key.
        hash_to_int (Callable[[int], int], optional): a stateful function that
            maps a hash value to a non-negative integer corresponding to an index in a
            storage. Defaults to :class:`~torchrl.data.map.HashToInt`.
        aggregator (Callable[[int], int], optional): a hash function to group multiple hashes
            together. This argument should only be passed when there is more than one ``in_keys``.
            If a single ``hash_module`` is provided but no aggregator is passed, it will take
            the value of the hash_module. If no ``hash_module`` or a list of ``hash_modules`` is
            provided but no aggregator is passed, it will default to ``SipHash``.
        clone (bool, optional): if ``True``, a shallow clone of the input TensorDict will be
            returned. This can be used to retrieve the integer index within the storage,
            corresponding to a given input tensordict. This can be overridden at runtime by
            providing the ``clone`` argument to the forward method.
            Defaults to ``False``.

    Examples:
        >>> query_module = QueryModule(
        ...     in_keys=["key1", "key2"],
        ...     index_key="index",
        ...     hash_module=SipHash(),
        ... )
        >>> query = TensorDict(
        ...     {
        ...         "key1": torch.Tensor([[1], [1], [1], [2]]),
        ...         "key2": torch.Tensor([[3], [3], [2], [3]]),
        ...         "other": torch.randn(4),
        ...     },
        ...     batch_size=(4,),
        ... )
        >>> res = query_module(query)
        >>> # The first two pairs of key1 and key2 match
        >>> assert res["index"][0] == res["index"][1]
        >>> # The last three pairs of key1 and key2 have at least one mismatching value
        >>> assert res["index"][1] != res["index"][2]
        >>> assert res["index"][2] != res["index"][3]

    """

    def __init__(
        self,
        in_keys: list[NestedKey],
        index_key: NestedKey = "_index",
        hash_key: NestedKey = "_hash",
        *,
        hash_module: Callable[[Any], int] | list[Callable[[Any], int]] | None = None,
        hash_to_int: Callable[[int], int] | None = None,
        aggregator: Callable[[Any], int] | None = None,
        clone: bool = False,
    ):
        if len(in_keys) == 0:
            raise ValueError("`in_keys` cannot be empty.")
        # A single key is accepted and wrapped in a list.
        in_keys = in_keys if isinstance(in_keys, list) else [in_keys]

        super().__init__()
        in_keys = self.in_keys = in_keys
        self.out_keys = [index_key, hash_key]
        # Read the keys back from out_keys so the stored values match it exactly.
        index_key = self.out_keys[0]
        self.hash_key = self.out_keys[1]

        if aggregator is not None and len(self.in_keys) == 1:
            # `Logger.warn` is a deprecated alias of `Logger.warning`.
            torchrl_logger.warning(
                "An aggregator was provided but there is only one in-key to be read. "
                "This module will be ignored."
            )
        elif aggregator is None:
            if hash_module is not None and not isinstance(hash_module, list):
                aggregator = hash_module
            else:
                aggregator = SipHash()
        if hash_module is None:
            hash_module = [SipHash() for _ in range(len(self.in_keys))]
        elif not isinstance(hash_module, list):
            try:
                # One independent copy per in-key; a lone in-key reuses the module as is.
                hash_module = [
                    deepcopy(hash_module) if len(self.in_keys) > 1 else hash_module
                    for _ in range(len(self.in_keys))
                ]
            except Exception as err:
                raise RuntimeError(
                    "failed to deepcopy the hash module. Please provide a list of hash modules instead."
                ) from err
        elif len(hash_module) != len(self.in_keys):
            raise ValueError(
                "The number of hash_modules must match the number of in_keys. "
                f"Got {len(hash_module)} hash modules but {len(in_keys)} in_keys."
            )
        if hash_to_int is None:
            hash_to_int = HashToInt()

        self.aggregator = aggregator
        self.hash_module = dict(zip(self.in_keys, hash_module))
        self.hash_to_int = hash_to_int

        self.index_key = index_key
        self.clone = clone

    def forward(
        self,
        tensordict: TensorDictBase,
        *,
        extend: bool = True,
        write_hash: bool = True,
        clone: bool | None = None,
    ) -> TensorDictBase:
        """Compute the hash and storage index of ``tensordict``.

        Args:
            tensordict (TensorDictBase): the input data; every in-key is read from it.

        Keyword Args:
            extend (bool, optional): forwarded to ``hash_to_int``. Defaults to ``True``.
            write_hash (bool, optional): if ``True``, the hash value is written under
                ``hash_key`` in addition to the index. Defaults to ``True``.
            clone (bool, optional): overrides the ``clone`` value given at construction
                when not ``None``. Without cloning, the input tensordict is updated
                in place and returned.

        Returns:
            The tensordict (or a copy of it) with the index written under
            ``index_key`` and, optionally, the hash under ``hash_key``.
        """
        hash_values = [self.hash_module[k](tensordict.get(k)) for k in self.in_keys]
        if len(self.in_keys) > 1:
            # Reduce the per-key hashes to a single hash per element.
            hash_values = torch.stack(
                hash_values,
                dim=-1,
            )
            hash_values = self.aggregator(hash_values)
        else:
            hash_values = hash_values[0]

        td_hash_value = self.hash_to_int(hash_values, extend=extend)

        clone = clone if clone is not None else self.clone
        if clone:
            output = tensordict.copy()
        else:
            output = tensordict

        output.set(self.index_key, td_hash_value)
        if write_hash:
            output.set(self.hash_key, hash_values)
        return output