torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/llm/dataset.py
@@ -0,0 +1,491 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import importlib.util
+ import os
+ from collections.abc import Sequence
+ from pathlib import Path
+
+ import torch
+ from tensordict import TensorDict, TensorDictBase
+ from tensordict.utils import NestedKey
+ from torchrl._utils import logger as torchrl_logger
+ from torchrl.data.replay_buffers import (
+     SamplerWithoutReplacement,
+     TensorDictReplayBuffer,
+     TensorStorage,
+ )
+
+ _has_transformers = importlib.util.find_spec("transformers") is not None
+ _has_datasets = importlib.util.find_spec("datasets") is not None
+
+
+ class TokenizedDatasetLoader:
+     """Loads a tokenized dataset and caches a memory-mapped copy of it.
+
+     Args:
+         split (str): One of ``"train"`` or ``"valid"``.
+         max_length (int): the maximum sequence length.
+         dataset_name (str): the name of the dataset.
+         tokenizer_fn (callable): the tokenizing method constructor, such as
+             :class:`torchrl.data.llm.TensorDictTokenizer`. When called,
+             it should return a :class:`tensordict.TensorDict` instance
+             or a dictionary-like structure with the tokenized data.
+         pre_tokenization_hook (callable, optional): called on
+             the Dataset before tokenization. It should return a modified
+             Dataset object.
+             The intended use is for carrying out tasks that
+             require modifying the dataset as a whole, as opposed to modifying
+             individual datapoints, for example discarding certain datapoints
+             based on a particular condition. Tokenization and other
+             "elementwise" operations on the data are performed by the process
+             function, which is mapped over the dataset.
+         root_dir (path, optional): the path where the datasets are stored.
+             Defaults to ``"$HOME/.cache/torchrl/data"``.
+         from_disk (bool, optional): if ``True``, :func:`datasets.load_from_disk`
+             will be used. Otherwise, :func:`datasets.load_dataset` will be used.
+             Defaults to ``False``.
+         valid_size (int, optional): the size to which the validation dataset
+             will be truncated (when ``split`` starts with ``"valid"``).
+             Defaults to 2000 items.
+         num_workers (int, optional): number of workers for :meth:`datasets.dataset.map`,
+             which is called during tokenization.
+             Defaults to ``max(os.cpu_count() // 2, 1)``.
+         tokenizer_class (type, optional): a tokenizer class, such as
+             :class:`~transformers.AutoTokenizer` (default).
+         tokenizer_model_name (str, optional): the model from which the vocabulary
+             should be gathered. Defaults to ``"gpt2"``.
+
+     The dataset will be stored in ``<root_dir>/<dataset>/<split>/<max_length>/``.
+
+     Examples:
+         >>> from torchrl.data.llm import TensorDictTokenizer
+         >>> from torchrl.data.llm.reward import pre_tokenization_hook
+         >>> split = "train"
+         >>> max_length = 550
+         >>> dataset_name = "CarperAI/openai_summarize_comparisons"
+         >>> loader = TokenizedDatasetLoader(
+         ...     split,
+         ...     max_length,
+         ...     dataset_name,
+         ...     TensorDictTokenizer,
+         ...     pre_tokenization_hook=pre_tokenization_hook,
+         ... )
+         >>> dataset = loader.load()
+         >>> print(dataset)
+         TensorDict(
+             fields={
+                 attention_mask: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                 input_ids: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)},
+             batch_size=torch.Size([185068]),
+             device=None,
+             is_shared=False)
+
+     """
+
+     def __init__(
+         self,
+         split,
+         max_length,
+         dataset_name,
+         tokenizer_fn: type[TensorDictTokenizer],
+         pre_tokenization_hook=None,
+         root_dir=None,
+         from_disk=False,
+         valid_size: int = 2000,
+         num_workers: int | None = None,
+         tokenizer_class=None,
+         tokenizer_model_name=None,
+     ):
+         self.split = split
+         self.max_length = max_length
+         self.dataset_name = dataset_name
+         self.tokenizer_fn = tokenizer_fn
+         self.pre_tokenization_hook = pre_tokenization_hook
+         self.root_dir = root_dir
+         self.from_disk = from_disk
+         self.valid_size = valid_size
+         if num_workers is None:
+             num_workers = max(os.cpu_count() // 2, 1)
+         self.num_workers = num_workers
+         if tokenizer_class is None:
+             from transformers import AutoTokenizer
+
+             tokenizer_class = AutoTokenizer
+         if tokenizer_model_name is None:
+             tokenizer_model_name = "gpt2"
+         self.make_tokenizer(
+             tokenizer_class=tokenizer_class, tokenizer_model_name=tokenizer_model_name
+         )
+
+     def make_tokenizer(self, *, tokenizer_class, tokenizer_model_name):
+         tokenizer = tokenizer_class.from_pretrained(tokenizer_model_name)
+         tokenizer.pad_token = tokenizer.eos_token
+         self.tokenizer = tokenizer
+
+     def load(self):
+         """Loads a pre-processed, memory-mapped dataset if it exists, and creates it otherwise."""
+         root_dir = self.root_dir
+         max_length = self.max_length
+         split = self.split
+         if root_dir is None:
+             root_dir = Path.home() / ".cache/torchrl/data/"
+             os.makedirs(root_dir, exist_ok=True)
+         root_dir = Path(root_dir)
+         data_dir = root_dir / str(Path(self.dataset_name).name).split("-")[0]
+         data_dir_total = data_dir / split / str(max_length)
+         # search for cached data
+         torchrl_logger.info(f"Looking for data in {data_dir_total}")
+         if os.path.exists(data_dir_total):
+             dataset = TensorDict.load_memmap(data_dir_total)
+             return dataset
+         dataset = self._load_dataset()
+         dataset = self._tokenize(dataset)
+         prefix = (split, str(max_length))
+         result = self.dataset_to_tensordict(
+             dataset, data_dir=data_dir, prefix=prefix, valid_mask_key="valid_sample"
+         )
+         return result[prefix]
+
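A minimal usage sketch of the caching round-trip implemented by load() above, assuming the datasets and transformers packages are installed. The root_dir path is illustrative; the dataset name mirrors the docstring example.

# Sketch only: first call builds the cache, second call reuses it.
# "/tmp/torchrl_data" is an illustrative path, not a torchrl default.
from torchrl.data.llm import TensorDictTokenizer
from torchrl.data.llm.dataset import TokenizedDatasetLoader

loader = TokenizedDatasetLoader(
    split="valid",
    max_length=128,
    dataset_name="CarperAI/openai_summarize_comparisons",
    tokenizer_fn=TensorDictTokenizer,
    root_dir="/tmp/torchrl_data",
)
# First call: downloads, tokenizes, and memory-maps the data under
# /tmp/torchrl_data/openai_summarize_comparisons/valid/128/
data = loader.load()
# Second call: finds the directory and loads the memory-mapped copy directly.
data_again = loader.load()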
+     def _load_dataset(self):
+         """Loads a text dataset from ``datasets``.
+
+         Returns: a dataset of type ``datasets.Dataset``.
+         """
+         if not _has_datasets:
+             raise ImportError(
+                 "preproc_data requires the datasets package to be installed."
+             )
+         from datasets import load_dataset, load_from_disk
+
+         if self.from_disk:
+             dataset = load_from_disk(str(self.dataset_name))[self.split]
+         else:
+             dataset = load_dataset(self.dataset_name, split=self.split)
+         if self.split.startswith("valid"):
+             # reduce the size of the validation dataset
+             dataset = dataset.select(range(self.valid_size))
+         if self.pre_tokenization_hook is not None:
+             dataset = self.pre_tokenization_hook(dataset)
+         return dataset
+
+     def _tokenize(
+         self,
+         dataset,
+         excluded_features: Sequence[str] | None = None,
+     ):
+         """Preprocesses a text dataset from ``datasets``.
+
+         Args:
+             dataset (datasets.Dataset): a dataset loaded using :meth:`_load_dataset`.
+             excluded_features (sequence of str, optional): the features to exclude
+                 once tokenization is complete. Defaults to ``{"text", "prompt", "label", "valid_sample"}``.
+
+         Returns: a dataset of type ``datasets.Dataset``.
+         """
+         if not _has_transformers:
+             raise ImportError("The transformers library is missing.")
+
+         num_workers = self.num_workers
+         if excluded_features is None:
+             excluded_features = {"text", "prompt", "label", "valid_sample"}
+         tokenizer = self.tokenizer
+         # tokenize the dataset
+         # TODO: replace this by TensorDict.map
+         dataset = dataset.map(
+             self.tokenizer_fn(
+                 tokenizer, max_length=self.max_length, return_tensordict=False
+             ),
+             desc="Tokenizing...",
+             num_proc=num_workers,
+             batched=True,
+         )
+         if not isinstance(dataset, TensorDictBase):
+             dataset_dict = dataset.to_dict()
+             if excluded_features:
+                 dataset_dict = {
+                     key: value
+                     for key, value in dataset_dict.items()
+                     if key not in excluded_features
+                 }
+             dataset = TensorDict.from_dict(
+                 dataset_dict, auto_batch_size=True, batch_dims=1
+             )
+         elif excluded_features:
+             dataset = dataset.exclude(*excluded_features)
+         # keep non-empty rows (i.e. where at least one token is not eos)
+         if "valid_sample" in dataset.keys():
+             mask = dataset.get("valid_sample")
+             dataset = dataset[mask]
+         return dataset
+
+     @staticmethod
+     def dataset_to_tensordict(
+         dataset: datasets.Dataset | TensorDict,  # noqa: F821
+         data_dir: Path,
+         prefix: NestedKey = None,
+         features: Sequence[str] = None,
+         batch_dims=1,
+         valid_mask_key=None,
+     ):
+         """Converts a dataset to a memory-mapped TensorDict.
+
+         If the dataset is already a :class:`TensorDict` instance, it is simply converted
+         to a memory-mapped TensorDict.
+         Otherwise, the dataset is expected to have a ``features`` attribute
+         which is a sequence of strings indicating the features that can be found
+         in the dataset. If it does not, ``features`` must be passed explicitly
+         to this function.
+
+         Args:
+             dataset (datasets.Dataset, TensorDict or equivalent): a dataset to convert
+                 to a memory-mapped TensorDict.
+                 If ``features`` is ``None``, it must have a ``features`` attribute
+                 with the list of keys to write in the tensordict.
+             data_dir (Path or equivalent): directory where the data should be written.
+             prefix (NestedKey, optional): the prefix of the dataset location. This can
+                 be used to differentiate several copies of the same dataset that have
+                 undergone different preprocessing steps.
+             features (sequence of str, optional): a sequence of str indicating the
+                 features that can be found in the dataset.
+             batch_dims (int, optional): the number of batch dimensions of the data
+                 (i.e. the number of dimensions along which the tensordict can be
+                 indexed). Defaults to 1.
+             valid_mask_key (NestedKey, optional): if provided, this entry will be
+                 tentatively gathered and used to filter the data. Defaults to
+                 ``None`` (i.e., no filter key).
+
+         Returns: a TensorDict containing memory-mapped tensors with the dataset.
+
+         Examples:
+             >>> from datasets import Dataset
+             >>> import tempfile
+             >>> data = Dataset.from_dict({"tokens": torch.randint(20, (10, 11)), "labels": torch.zeros(10, 11)})
+             >>> with tempfile.TemporaryDirectory() as tmpdir:
+             ...     data_memmap = TokenizedDatasetLoader.dataset_to_tensordict(
+             ...         data, data_dir=tmpdir, prefix=("some", "prefix"), features=["tokens", "labels"]
+             ...     )
+             ...     print(data_memmap)
+             TensorDict(
+                 fields={
+                     some: TensorDict(
+                         fields={
+                             prefix: TensorDict(
+                                 fields={
+                                     labels: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
+                                     tokens: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)},
+                                 batch_size=torch.Size([10]),
+                                 device=None,
+                                 is_shared=False)},
+                         batch_size=torch.Size([]),
+                         device=None,
+                         is_shared=False)},
+                 batch_size=torch.Size([]),
+                 device=None,
+                 is_shared=False)
+
+         """
+         if not isinstance(dataset, TensorDict):
+             if features is None:
+                 features = dataset.features
+             if prefix is None:
+                 prefix = ()
+             data_dict = {key: torch.as_tensor(dataset[key]) for key in features}
+             out = TensorDict.from_dict(
+                 data_dict, batch_dims=batch_dims, auto_batch_size=True
+             )
+         else:
+             out = dataset
+         if valid_mask_key is not None and valid_mask_key in out.keys(
+             include_nested=True
+         ):
+             out = out[out.get(valid_mask_key)]
+         out = TensorDict({prefix: out})
+         out.memmap_(prefix=data_dir)
+         return out
+
+
+ def create_infinite_iterator(iterator):
+     """Iterates indefinitely over an iterator."""
+     while True:
+         yield from iterator
+
+
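Because the while-True loop re-enters ``yield from`` on every pass, the argument must be re-iterable (a list, or a replay buffer as in get_dataloader below) rather than a one-shot iterator, which would be exhausted after the first pass. A tiny self-contained sketch:

# Sketch: next() never raises StopIteration because the underlying
# re-iterable (here a plain list) is restarted each time it runs out.
it = create_infinite_iterator([1, 2, 3])
print([next(it) for _ in range(7)])  # [1, 2, 3, 1, 2, 3, 1]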
+ def get_dataloader(
+     batch_size: int,
+     block_size: int,
+     tensorclass_type: type,
+     device: torch.device,
+     dataset_name: str | None = None,
+     infinite: bool = True,
+     prefetch: int = 0,
+     split: str = "train",
+     root_dir: str | None = None,
+     from_disk: bool = False,
+     num_workers: int | None = None,
+ ):
+     """Creates a dataset and returns a dataloader from it.
+
+     Args:
+         batch_size (int): the batch size of the dataloader samples.
+         block_size (int): the maximum length of a sequence in the dataloader.
+         tensorclass_type (tensorclass class): a tensorclass with a :meth:`from_dataset`
+             method that must accept three keyword arguments: ``split`` (see below),
+             ``max_length``, which is the block size to be used for training, and
+             ``dataset_name``, a string indicating the dataset. The ``root_dir``
+             and ``from_disk`` arguments should also be supported.
+         device (torch.device or equivalent): the device where the samples should
+             be cast.
+         dataset_name (str, optional): the dataset name. If not provided and if
+             the tensorclass supports it, a default dataset name will be gathered
+             for the tensorclass being used.
+         infinite (bool, optional): if ``True``, the iteration will be infinite
+             such that ``next(iterator)`` will always return a value.
+             Defaults to ``True``.
+         prefetch (int, optional): the number of items to be prefetched if
+             multithreaded dataloading is being used.
+         split (str, optional): the data split. Either ``"train"`` or ``"valid"``.
+             Defaults to ``"train"``.
+         root_dir (path, optional): the path where the datasets are stored.
+             Defaults to ``"$HOME/.cache/torchrl/data"``.
+         from_disk (bool, optional): if ``True``, :func:`datasets.load_from_disk`
+             will be used. Otherwise, :func:`datasets.load_dataset` will be used.
+             Defaults to ``False``.
+         num_workers (int, optional): number of workers for :meth:`datasets.dataset.map`,
+             which is called during tokenization.
+             Defaults to ``max(os.cpu_count() // 2, 1)``.
+
+     Examples:
+         >>> from torchrl.data.llm.reward import PairwiseDataset
+         >>> dataloader = get_dataloader(
+         ...     batch_size=256, block_size=550, tensorclass_type=PairwiseDataset, device="cpu")
+         >>> for d in dataloader:
+         ...     print(d)
+         ...     break
+         PairwiseDataset(
+             chosen_data=RewardData(
+                 attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                 input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                 rewards=None,
+                 end_scores=None,
+                 batch_size=torch.Size([256]),
+                 device=cpu,
+                 is_shared=False),
+             rejected_data=RewardData(
+                 attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                 input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                 rewards=None,
+                 end_scores=None,
+                 batch_size=torch.Size([256]),
+                 device=cpu,
+                 is_shared=False),
+             batch_size=torch.Size([256]),
+             device=cpu,
+             is_shared=False)
+     """
+     data = tensorclass_type.from_dataset(
+         split=split,
+         dataset_name=dataset_name,
+         max_length=block_size,
+         root_dir=root_dir,
+         from_disk=from_disk,
+         num_workers=num_workers,
+     )
+     out = TensorDictReplayBuffer(
+         storage=TensorStorage(data),
+         collate_fn=lambda x: x.as_tensor().to(device, non_blocking=True),
+         sampler=SamplerWithoutReplacement(drop_last=True),
+         batch_size=batch_size,
+         prefetch=prefetch,
+     )
+     if infinite:
+         return create_infinite_iterator(out)
+     return out
+
+
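get_dataloader delegates dataset construction entirely to tensorclass_type.from_dataset. A hedged sketch of the contract a custom tensorclass would need to satisfy; MyPromptData and its fields are hypothetical and stand in for a real loader (which would typically wrap TokenizedDatasetLoader(...).load()):

# Sketch only: MyPromptData is a hypothetical tensorclass; random tensors
# stand in for real tokenized data. Only the keyword arguments mirror what
# get_dataloader actually passes.
import torch
from tensordict import tensorclass

@tensorclass
class MyPromptData:
    input_ids: torch.Tensor
    attention_mask: torch.Tensor

    @classmethod
    def from_dataset(
        cls, split, dataset_name=None, max_length=550,
        root_dir=None, from_disk=False, num_workers=None,
    ):
        n = 32  # toy dataset size
        return cls(
            input_ids=torch.randint(100, (n, max_length)),
            attention_mask=torch.ones(n, max_length, dtype=torch.int64),
            batch_size=[n],
        )

dataloader = get_dataloader(
    batch_size=8, block_size=550, tensorclass_type=MyPromptData, device="cpu"
)
batch = next(dataloader)  # infinite=True by default, so next() always yields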
+ class TensorDictTokenizer:
+     """Factory for a process function that applies a tokenizer over a text example.
+
+     Args:
+         tokenizer (tokenizer from the transformers library): the tokenizer to use.
+         max_length (int): maximum length of the sequence.
+         key (str, optional): the key where the text is to be found. Defaults to ``"text"``.
+         padding (str, optional): type of padding. Defaults to ``"max_length"``.
+         truncation (bool, optional): whether the sequences should be truncated to
+             ``max_length``. Defaults to ``True``.
+         return_tensordict (bool, optional): if ``True``, a TensorDict is returned.
+             Otherwise, the original data will be returned. Defaults to ``True``.
+         device (torch.device, optional): the device where to store the data.
+             This option is ignored if ``return_tensordict=False``.
+
+     See the transformers library for more information about tokenizers,
+     padding and truncation: `<https://huggingface.co/docs/transformers/pad_truncation>`_
+
+     Returns: a :class:`tensordict.TensorDict` instance with the same batch size
+         as the input data.
+
+     Examples:
+         >>> from transformers import AutoTokenizer
+         >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+         >>> tokenizer.pad_token = tokenizer.eos_token
+         >>> process = TensorDictTokenizer(tokenizer, max_length=10)
+         >>> # example with a single input
+         >>> example = {"text": "I am a little worried"}
+         >>> process(example)
+         TensorDict(
+             fields={
+                 attention_mask: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
+                 input_ids: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
+             batch_size=torch.Size([]),
+             device=None,
+             is_shared=False)
+         >>> # example with multiple inputs
+         >>> example = {"text": ["Let me reassure you", "It will be ok"]}
+         >>> process(example)
+         TensorDict(
+             fields={
+                 attention_mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
+                 input_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
+             batch_size=torch.Size([2]),
+             device=None,
+             is_shared=False)
+
+     """
+
+     def __init__(
+         self,
+         tokenizer,
+         max_length,
+         key="text",
+         padding="max_length",
+         truncation=True,
+         return_tensordict=True,
+         device=None,
+     ):
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+         self.key = key
+         self.padding = padding
+         self.truncation = truncation
+         self.return_tensordict = return_tensordict
+         self.device = device
+
+     def __call__(self, sample):
+         input = sample[self.key]
+         tokenized_sample = self.tokenizer(
+             input,
+             max_length=self.max_length,
+             padding=self.padding,
+             truncation=self.truncation,
+         )
+         batch_size = [] if isinstance(input, str) else [len(input)]
+         if self.return_tensordict:
+             return TensorDict.from_dict(
+                 dict(tokenized_sample),
+                 batch_size=batch_size,
+                 device=self.device,
+                 auto_batch_size=True,
+             )
+         return tokenized_sample
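
The ``return_tensordict=False`` path exists so the factory can be passed straight to ``datasets.Dataset.map``, which is exactly how ``_tokenize`` uses it above. A sketch, assuming transformers and datasets are installed; the toy dataset is illustrative:

# Sketch of the return_tensordict=False path, mirroring _tokenize.
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
fn = TensorDictTokenizer(tokenizer, max_length=16, return_tensordict=False)

data = Dataset.from_dict({"text": ["hello world", "goodbye"]})
data = data.map(fn, batched=True, desc="Tokenizing...")
# The mapped dataset now carries "input_ids" and "attention_mask" columns
# alongside the original "text" column.
print(data.column_names)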