torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,321 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ from collections.abc import Sequence
9
+
10
+ import torch
11
+ from tensordict import NonTensorData, NonTensorStack, TensorDictBase
12
+ from tensordict.nn import dispatch
13
+ from tensordict.utils import _zip_strict, NestedKey
14
+ from torch import Tensor
15
+ from torchrl._utils import _replace_last
16
+ from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec
17
+ from torchrl.envs import Transform, UnaryTransform
18
+ from torchrl.envs.transforms.utils import _set_missing_tolerance
19
+
20
+
21
class Tokenizer(UnaryTransform):
    r"""Applies a tokenization operation on the specified inputs.

    Args:
        in_keys (sequence of NestedKey): the keys of inputs to the tokenization operation.
        out_keys (sequence of NestedKey): the keys of the outputs of the tokenization operation.
        in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the tokenization operation during inverse call.
        out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the tokenization operation during inverse call.

    Keyword Args:
        tokenizer (transformers.PretrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``,
            "bert-base-uncased" will be used by default. If a string is provided, it should be the name of a
            pre-trained tokenizer.
        use_raw_nontensor (bool, optional): if ``False``, data is extracted from
            :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before the tokenization
            function is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
            inputs are given directly to the tokenization function, which must support those inputs. Default is ``False``.
        additional_tokens (List[str], optional): list of additional tokens to add to the tokenizer's vocabulary.
        skip_special_tokens (bool, optional): whether special tokens are stripped when decoding tokens back
            to strings in the inverse call. Default is ``True``.
        add_special_tokens (bool, optional): whether the tokenizer's special tokens are added when encoding.
            Default is ``False``.
        padding (bool, optional): padding strategy forwarded to ``batch_encode_plus`` when ``max_length``
            is not set. Default is ``True``.
        max_length (int, optional): if provided, encoded sequences are padded to this length
            (``padding="max_length"``). Default is ``None``.
        return_attention_mask (bool, optional): if ``True``, an attention mask is written alongside every
            output entry, under the sibling key ``(*root, "attention_mask")``. Default is ``True``.
        missing_tolerance (bool, optional): if ``True``, input keys missing from the tensordict are silently
            skipped; if ``False``, a :class:`KeyError` is raised. Default is ``True``.
        call_before_reset (bool, optional): if ``True``, tokenization is applied to the input tensordict
            *before* the env reset (in ``_reset_env_preprocess``) and skipped afterwards; if ``False``, it is
            applied to the reset output as usual. Default is ``False``.

    .. note:: This transform can be used both to transform output strings into tokens and to transform back tokenized
        actions or states into strings. If the environment has a string state-spec, the transformed version will have
        a tokenized state-spec. If it is a string action spec, it will result in a tokenized action spec.

    """

    def __init__(
        self,
        in_keys: Sequence[NestedKey] | None = None,
        out_keys: Sequence[NestedKey] | None = None,
        in_keys_inv: Sequence[NestedKey] | None = None,
        out_keys_inv: Sequence[NestedKey] | None = None,
        *,
        tokenizer: transformers.PretrainedTokenizerBase = None,  # noqa: F821
        use_raw_nontensor: bool = False,
        additional_tokens: list[str] | None = None,
        skip_special_tokens: bool = True,
        add_special_tokens: bool = False,
        padding: bool = True,
        max_length: int | None = None,
        return_attention_mask: bool = True,
        missing_tolerance: bool = True,
        call_before_reset: bool = False,
    ):
        # Lazy transformers import: the dependency is only required when the
        # transform is actually instantiated, not at module import time.
        if tokenizer is None:
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        elif isinstance(tokenizer, str):
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(tokenizer)

        self.tokenizer = tokenizer
        self.add_special_tokens = add_special_tokens
        self.skip_special_tokens = skip_special_tokens
        self.padding = padding
        self.max_length = max_length
        self.return_attention_mask = return_attention_mask
        self.call_before_reset = call_before_reset
        if additional_tokens:
            self.tokenizer.add_tokens(additional_tokens)
        super().__init__(
            in_keys=in_keys,
            out_keys=out_keys,
            in_keys_inv=in_keys_inv,
            out_keys_inv=out_keys_inv,
            # Encoding/decoding are delegated to the bound methods below so
            # UnaryTransform's generic machinery drives the key plumbing.
            fn=self.call_tokenizer_fn,
            inv_fn=self.call_tokenizer_inv_fn,
            use_raw_nontensor=use_raw_nontensor,
        )
        # Set after super().__init__ — presumably the parent defines the
        # `missing_tolerance` property backed by this attribute (TODO confirm,
        # property not visible in this chunk).
        self._missing_tolerance = missing_tolerance

    @property
    def device(self):
        # Cached lookup of the parent env's device; resolved lazily because the
        # transform may be built before being attached to an env.
        if "_device" in self.__dict__:
            return self._device
        parent = self.parent
        if parent is None:
            # Not attached yet: do not cache, retry on next access.
            return None
        device = parent.device
        self._device = device
        return device

    def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
        # Specialized for attention mask
        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
            value = next_tensordict.get(in_key, default=None)
            if value is not None:
                observation = self._apply_transform(value)
                if self.return_attention_mask:
                    # call_tokenizer_fn returns (tokens, mask) in this mode;
                    # the mask is stored next to the tokens.
                    observation, attention_mask = observation
                    next_tensordict.set(
                        _replace_last(out_key, "attention_mask"),
                        attention_mask,
                    )
                next_tensordict.set(
                    out_key,
                    observation,
                )
            elif (
                self.missing_tolerance
                and self.return_attention_mask
                and out_key in next_tensordict.keys(True)
            ):
                # Input string is gone but tokens are already present (e.g. a
                # step that carried tokens over): backfill an all-ones mask so
                # downstream consumers always find one.
                attention_key = _replace_last(out_key, "attention_mask")
                if attention_key not in next_tensordict:
                    next_tensordict[attention_key] = torch.ones_like(
                        next_tensordict.get(out_key)
                    )
            elif not self.missing_tolerance:
                raise KeyError(
                    f"{self}: '{in_key}' not found in tensordict {next_tensordict}"
                )
        return next_tensordict

    @dispatch(source="in_keys", dest="out_keys")
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Same flow as _call but without the mask backfill for missing inputs.
        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
            data = tensordict.get(in_key, None)
            if data is not None:
                data = self._apply_transform(data)
                if self.return_attention_mask:
                    data, attention_mask = data
                    tensordict.set(
                        _replace_last(out_key, "attention_mask"),
                        attention_mask,
                    )
                tensordict.set(out_key, data)
            elif not self.missing_tolerance:
                raise KeyError(f"'{in_key}' not found in tensordict {tensordict}")
        return tensordict

    def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
        # When call_before_reset is set, tokenize the incoming tensordict here,
        # before the env resets; missing keys are tolerated during reset.
        if self.call_before_reset:
            with _set_missing_tolerance(self, True):
                tensordict = self._call(tensordict)
        return tensordict

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        if self.call_before_reset:
            # Work was already done in _reset_env_preprocess: pass through.
            return tensordict_reset
        return super()._reset(tensordict, tensordict_reset)

    def call_tokenizer_fn(self, value: str | list[str]):
        """Encode a string (or list of strings) into token ids, optionally with an attention mask."""
        device = self.device
        kwargs = {"add_special_tokens": self.add_special_tokens}
        if self.max_length is not None:
            kwargs["padding"] = "max_length"
            kwargs["max_length"] = self.max_length
        if isinstance(value, str):
            # Single string: encode returns a (1, seq) tensor; take row 0.
            out = self.tokenizer.encode(value, return_tensors="pt", **kwargs)[0]
            # TODO: incorporate attention mask
            # NOTE(review): with max_length padding this all-ones mask also
            # covers pad tokens — confirm whether that is intended.
            if self.return_attention_mask:
                attention_mask = torch.ones_like(out, dtype=torch.int64)
        else:
            # Batch of strings: pad within the batch (or to max_length).
            kwargs["padding"] = (
                self.padding if self.max_length is None else "max_length"
            )
            kwargs["return_attention_mask"] = self.return_attention_mask
            # kwargs["return_token_type_ids"] = False
            out = self.tokenizer.batch_encode_plus(value, return_tensors="pt", **kwargs)
            if self.return_attention_mask:
                attention_mask = out["attention_mask"]
            out = out["input_ids"]

        if device is not None and out.device != device:
            out = out.to(device)
            if self.return_attention_mask:
                attention_mask = attention_mask.to(device)
        if self.return_attention_mask:
            return out, attention_mask
        return out

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Override _inv_call to account for ragged dims
        if not self.in_keys_inv:
            return tensordict
        for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
            # as_padded_tensor densifies ragged token tensors before decoding.
            data = tensordict.get(out_key, None, as_padded_tensor=True)
            if data is not None:
                item = self._inv_apply_transform(data)
                tensordict.set(in_key, item)
            elif not self.missing_tolerance:
                raise KeyError(f"'{out_key}' not found in tensordict {tensordict}")
        return tensordict

    def call_tokenizer_inv_fn(self, value: Tensor):
        """Decode token ids back into a NonTensorData/NonTensorStack of strings."""
        if value.ndim == 1:
            # Single sequence -> single string.
            out = self.tokenizer.decode(
                value.int(), skip_special_tokens=self.skip_special_tokens
            )
        else:
            # Batch of sequences -> list of strings.
            out = self.tokenizer.batch_decode(
                value.int(), skip_special_tokens=self.skip_special_tokens
            )
        device = self._str_device
        if isinstance(out, list):
            result = NonTensorStack(*out)
            if device:
                result = result.to(device)
            return result
        return NonTensorData(out, device=device)

    @property
    def _str_device(self):
        # Device of the *string* entry, looked up from the parent env's specs
        # (observation, then action, then state); None when undeterminable.
        parent = self.parent
        if parent is None:
            return None
        if self.in_keys:
            in_key = self.in_keys[0]
        elif self.in_keys_inv:
            in_key = self.in_keys_inv[0]
        else:
            return None
        if in_key in parent.observation_keys:
            return parent.full_observation_spec[in_key].device
        if in_key in parent.action_keys:
            return parent.full_action_spec[in_key].device
        if in_key in parent.state_keys:
            return parent.full_state_spec[in_key].device
        return None

    def transform_input_spec(self, input_spec: Composite) -> Composite:
        # We need to cap the spec to generate valid random strings
        for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
            if in_key in input_spec["full_state_spec"].keys(True, True):
                spec = input_spec["full_state_spec"]
            # NOTE(review): state uses keys(True, True) but action uses
            # keys(False, True) — nested action keys would not be found here.
            # Confirm whether the asymmetry is intentional.
            elif in_key in input_spec["full_action_spec"].keys(False, True):
                spec = input_spec["full_action_spec"]
            else:
                raise KeyError(
                    f"The input keys {in_key} wasn't found in the env input specs."
                )
            # Replace the string spec with a bounded integer token spec.
            local_spec = spec.pop(in_key)
            local_dtype = local_spec.dtype
            if local_dtype is None or local_dtype.is_floating_point:
                local_dtype = torch.int64
            new_shape = spec.shape
            if self.max_length is None:
                # Then we can't tell what the shape will be
                new_shape = new_shape + torch.Size((-1,))
            else:
                new_shape = new_shape + torch.Size((self.max_length,))
            spec[out_key] = Bounded(
                0,
                self.tokenizer.vocab_size,
                shape=new_shape,
                device=local_spec.device,
                dtype=local_dtype,
            )
        return input_spec

    # Bypass UnaryTransform's overrides: only observation and input specs are
    # rewritten by this transform; reward/done/output specs pass through with
    # the base Transform behavior.
    transform_output_spec = Transform.transform_output_spec
    transform_reward_spec = Transform.transform_reward_spec
    transform_done_spec = Transform.transform_done_spec

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        # Tracks mask keys already emitted so two token entries can't silently
        # share (and clobber) the same `(*root, "attention_mask")` sibling.
        attention_mask_keys = set()
        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
            # Token length is data-dependent: use -1 (ragged) for the last dim.
            new_shape = observation_spec.shape + torch.Size((-1,))
            try:
                in_spec = observation_spec[in_key]
                obs_dtype = in_spec.dtype
                device = in_spec.device
            except KeyError:
                # In some cases (eg, the tokenizer is applied during reset on data that
                # originates from a dataloader) we don't have an in_spec
                in_spec = None
                obs_dtype = None
                device = observation_spec.device
            if obs_dtype is None or obs_dtype.is_floating_point:
                obs_dtype = torch.int64
            observation_spec[out_key] = Bounded(
                0,
                self.tokenizer.vocab_size,
                shape=new_shape,
                device=device,
                dtype=obs_dtype,
            )
            if self.return_attention_mask:
                attention_mask_key = _replace_last(out_key, "attention_mask")
                if attention_mask_key in attention_mask_keys:
                    raise KeyError(
                        "Conflicting attention_mask keys. Make sure the token tensors are "
                        "nested at different places in the tensordict such that `(*root, 'attention_mask')` "
                        "entries are unique."
                    )
                attention_mask_keys.add(attention_mask_key)
                attention_dtype = obs_dtype
                if attention_dtype is None or attention_dtype.is_floating_point:
                    attention_dtype = torch.int64
                # Mask values are 0/1, hence the [0, 2) bound.
                observation_spec[attention_mask_key] = Bounded(
                    0,
                    2,
                    shape=new_shape,
                    device=device,
                    dtype=attention_dtype,
                )
        return observation_spec