torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/trainers/helpers/replay_buffer.py
@@ -0,0 +1,59 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ import torch
+ from torchrl._utils import _make_ordinal_device
+
+ from torchrl.data.replay_buffers.replay_buffers import (
+     ReplayBuffer,
+     TensorDictReplayBuffer,
+ )
+ from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
+ from torchrl.data.replay_buffers.storages import LazyMemmapStorage
+ from torchrl.data.utils import DEVICE_TYPING
+
+
+ def make_replay_buffer(
+     device: DEVICE_TYPING, cfg: DictConfig  # noqa: F821
+ ) -> ReplayBuffer:  # noqa: F821
+     """Builds a replay buffer using the config built from ReplayArgsConfig."""
+     device = _make_ordinal_device(torch.device(device))
+     if not cfg.prb:
+         sampler = RandomSampler()
+     else:
+         sampler = PrioritizedSampler(
+             max_capacity=cfg.buffer_size,
+             alpha=0.7,
+             beta=0.5,
+         )
+     buffer = TensorDictReplayBuffer(
+         storage=LazyMemmapStorage(
+             cfg.buffer_size,
+             scratch_dir=cfg.buffer_scratch_dir,
+             # device=device,  # when using prefetch, this can overload the GPU memory
+         ),
+         sampler=sampler,
+         pin_memory=device != torch.device("cpu"),
+         prefetch=cfg.buffer_prefetch,
+         batch_size=cfg.batch_size,
+     )
+     return buffer
+
+
+ @dataclass
+ class ReplayArgsConfig:
+     """Generic Replay Buffer config struct."""
+
+     buffer_size: int = 1000000
+     # buffer size, in number of frames stored. Default=1e6
+     prb: bool = False
+     # whether a Prioritized replay buffer should be used instead of a more basic circular one.
+     buffer_scratch_dir: str | None = None
+     # directory where the buffer data should be stored. If none is passed, they will be placed in /tmp/
+     buffer_prefetch: int = 10
+     # prefetching queue length for the replay buffer
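
For illustration, a minimal usage sketch of make_replay_buffer: the SimpleNamespace below is a hypothetical stand-in for the Hydra DictConfig the function expects, with fields mirroring ReplayArgsConfig plus the batch_size field the function also reads (declared in TrainerConfig in the next file).

    from types import SimpleNamespace

    from torchrl.trainers.helpers.replay_buffer import make_replay_buffer

    # Hypothetical config object standing in for a Hydra DictConfig.
    cfg = SimpleNamespace(
        buffer_size=1_000_000,    # frames stored in the buffer
        prb=False,                # True selects PrioritizedSampler over RandomSampler
        buffer_scratch_dir=None,  # memmap location; None falls back to /tmp/
        buffer_prefetch=10,       # prefetch queue length
        batch_size=256,           # sample batch size (declared in TrainerConfig)
    )
    buffer = make_replay_buffer(device="cpu", cfg=cfg)
    # buffer.extend(collected_tensordict)  # fill with data before sampling
    # batch = buffer.sample()              # returns a TensorDict of batch_size entries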
torchrl/trainers/helpers/trainers.py
@@ -0,0 +1,301 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from warnings import warn
+
+ import torch
+ from tensordict.nn import TensorDictModule, TensorDictModuleWrapper
+ from torch import optim
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+
+ from torchrl._utils import logger as torchrl_logger, VERBOSE
+ from torchrl.collectors import BaseCollector
+ from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
+ from torchrl.envs.common import EnvBase
+ from torchrl.envs.utils import ExplorationType
+ from torchrl.modules import reset_noise
+ from torchrl.objectives.common import LossModule
+ from torchrl.objectives.utils import TargetNetUpdater
+ from torchrl.record.loggers import Logger
+ from torchrl.trainers.trainers import (
+     BatchSubSampler,
+     ClearCudaCache,
+     CountFramesLog,
+     LogScalar,
+     LogValidationReward,
+     ReplayBufferTrainer,
+     RewardNormalizer,
+     SelectKeys,
+     Trainer,
+     UpdateWeights,
+ )
+
+ OPTIMIZERS = {
+     "adam": optim.Adam,
+     "sgd": optim.SGD,
+     "adamax": optim.Adamax,
+ }
+
+
+ @dataclass
+ class TrainerConfig:
+     """Trainer config struct."""
+
+     optim_steps_per_batch: int = 500
+     # Number of optimization steps between two data collections. See frames_per_batch below.
+     optimizer: str = "adam"
+     # Optimizer to be used.
+     lr_scheduler: str = "cosine"
+     # LR scheduler.
+     selected_keys: list | None = None
+     # a list of strings that indicate the data that should be kept from the data collector. Since storing and
+     # retrieving information from the replay buffer does not come for free, limiting the amount of data
+     # passed to it can improve the algorithm performance.
+     batch_size: int = 256
+     # batch size of the TensorDict retrieved from the replay buffer. Default=256.
+     log_interval: int = 10000
+     # logging interval, in terms of optimization steps. Default=10000.
+     lr: float = 3e-4
+     # Learning rate used for the optimizer. Default=3e-4.
+     weight_decay: float = 0.0
+     # Weight-decay to be used with the optimizer. Default=0.0.
+     clip_norm: float = 1000.0
+     # value at which the total gradient norm / single derivative should be clipped. Default=1000.0
+     clip_grad_norm: bool = False
+     # if True, the gradient will be clipped based on its L2 norm. Otherwise, single gradient values will be clipped to the desired threshold.
+     normalize_rewards_online: bool = False
+     # Computes the running statistics of the rewards and normalizes them before they are passed to the loss module.
+     normalize_rewards_online_scale: float = 1.0
+     # Final scale of the normalized rewards.
+     normalize_rewards_online_decay: float = 0.9999
+     # Decay of the reward moving average.
+     sub_traj_len: int = -1
+     # length of the trajectories that sub-samples must have in online settings.
+
+
+ def make_trainer(
+     collector: BaseCollector,
+     loss_module: LossModule,
+     recorder: EnvBase | None = None,
+     target_net_updater: TargetNetUpdater | None = None,
+     policy_exploration: None | (TensorDictModuleWrapper | TensorDictModule) = None,
+     replay_buffer: ReplayBuffer | None = None,
+     logger: Logger | None = None,
+     cfg: DictConfig = None,  # noqa: F821
+ ) -> Trainer:
+     """Creates a Trainer instance given its constituents.
+
+     Args:
+         collector (BaseCollector): A data collector to be used to collect data.
+         loss_module (LossModule): A TorchRL loss module.
+         recorder (EnvBase, optional): a recorder environment. If None, the trainer will train the policy without
+             testing it.
+         target_net_updater (TargetNetUpdater, optional): A target network update object.
+         policy_exploration (TDModule or TensorDictModuleWrapper, optional): a policy to be used for recording and exploration
+             updates (should be synced with the learnt policy).
+         replay_buffer (ReplayBuffer, optional): a replay buffer to be used to collect data.
+         logger (Logger, optional): a Logger to be used for logging.
+         cfg (DictConfig, optional): a DictConfig containing the arguments of the script. If None, the default
+             arguments are used.
+
+     Returns:
+         A trainer built with the input objects. The optimizer is built by this helper function using the cfg provided.
+
+     Examples:
+         >>> import torch
+         >>> import tempfile
+         >>> from torchrl.record.loggers import TensorboardLogger
+         >>> from torchrl.trainers import Trainer
+         >>> from torchrl.envs import EnvCreator
+         >>> from torchrl.collectors import Collector
+         >>> from torchrl.data import TensorDictReplayBuffer
+         >>> from torchrl.envs.libs.gym import GymEnv
+         >>> from torchrl.modules import TensorDictModuleWrapper, SafeModule, ValueOperator, EGreedyWrapper
+         >>> from torchrl.objectives.common import LossModule
+         >>> from torchrl.objectives.utils import TargetNetUpdater
+         >>> from torchrl.objectives import DDPGLoss
+         >>> env_maker = EnvCreator(lambda: GymEnv("Pendulum-v0"))
+         >>> env_proof = env_maker()
+         >>> obs_spec = env_proof.observation_spec
+         >>> action_spec = env_proof.action_spec
+         >>> net = torch.nn.Linear(env_proof.observation_spec.shape[-1], action_spec.shape[-1])
+         >>> net_value = torch.nn.Linear(env_proof.observation_spec.shape[-1], 1)  # for the purpose of testing
+         >>> policy = SafeModule(action_spec, net, in_keys=["observation"], out_keys=["action"])
+         >>> value = ValueOperator(net_value, in_keys=["observation"], out_keys=["state_action_value"])
+         >>> collector = Collector(env_maker, policy, total_frames=100)
+         >>> loss_module = DDPGLoss(policy, value, gamma=0.99)
+         >>> recorder = env_proof
+         >>> target_net_updater = None
+         >>> policy_exploration = EGreedyWrapper(policy)
+         >>> replay_buffer = TensorDictReplayBuffer()
+         >>> dir = tempfile.gettempdir()
+         >>> logger = TensorboardLogger(exp_name=dir)
+         >>> trainer = make_trainer(collector, loss_module, recorder, target_net_updater, policy_exploration,
+         ...     replay_buffer, logger)
+         >>> print(trainer)
+
+     """
+     if cfg is None:
+         warn(
+             "Getting default cfg for the trainer. "
+             "This should be only used for debugging."
+         )
+         cfg = TrainerConfig()
+         cfg.frame_skip = 1
+         cfg.total_frames = 1000
+         cfg.record_frames = 10
+         cfg.record_interval = 10
+
+     optimizer_kwargs = {} if cfg.optimizer != "adam" else {"betas": (0.0, 0.9)}
+     optimizer = OPTIMIZERS[cfg.optimizer](
+         loss_module.parameters(),
+         lr=cfg.lr,
+         weight_decay=cfg.weight_decay,
+         **optimizer_kwargs,
+     )
+     device = next(loss_module.parameters()).device
+     if cfg.lr_scheduler == "cosine":
+         optim_scheduler = CosineAnnealingLR(
+             optimizer,
+             T_max=int(
+                 cfg.total_frames / cfg.frames_per_batch * cfg.optim_steps_per_batch
+             ),
+         )
+     elif cfg.lr_scheduler == "":
+         optim_scheduler = None
+     else:
+         raise NotImplementedError(f"lr scheduler {cfg.lr_scheduler}")
+
+     if VERBOSE:
+         torchrl_logger.info(
+             f"collector = {collector}; \n"
+             f"loss_module = {loss_module}; \n"
+             f"recorder = {recorder}; \n"
+             f"target_net_updater = {target_net_updater}; \n"
+             f"policy_exploration = {policy_exploration}; \n"
+             f"replay_buffer = {replay_buffer}; \n"
+             f"logger = {logger}; \n"
+             f"cfg = {cfg}; \n"
+         )
+
+     if logger is not None:
+         # log hyperparams
+         logger.log_hparams(cfg)
+
+     trainer = Trainer(
+         collector=collector,
+         frame_skip=cfg.frame_skip,
+         total_frames=cfg.total_frames * cfg.frame_skip,
+         loss_module=loss_module,
+         optimizer=optimizer,
+         logger=logger,
+         optim_steps_per_batch=cfg.optim_steps_per_batch,
+         clip_grad_norm=cfg.clip_grad_norm,
+         clip_norm=cfg.clip_norm,
+     )
+
+     if torch.cuda.device_count() > 0:
+         trainer.register_op("pre_optim_steps", ClearCudaCache(1))
+
+     if hasattr(cfg, "noisy") and cfg.noisy:
+         trainer.register_op("pre_optim_steps", lambda: loss_module.apply(reset_noise))
+
+     if cfg.selected_keys:
+         trainer.register_op("batch_process", SelectKeys(cfg.selected_keys))
+     trainer.register_op("batch_process", lambda batch: batch.cpu())
+
+     if replay_buffer is not None:
+         # replay buffer is used 2 or 3 times: to register data, to sample
+         # data and to update priorities
+         rb_trainer = ReplayBufferTrainer(
+             replay_buffer,
+             cfg.batch_size,
+             flatten_tensordicts=False,
+             memmap=False,
+             device=device,
+         )
+
+         trainer.register_op("batch_process", rb_trainer.extend)
+         trainer.register_op("process_optim_batch", rb_trainer.sample)
+         trainer.register_op("post_loss", rb_trainer.update_priority)
+     else:
+         # trainer.register_op("batch_process", mask_batch)
+         trainer.register_op(
+             "process_optim_batch",
+             BatchSubSampler(batch_size=cfg.batch_size, sub_traj_len=cfg.sub_traj_len),
+         )
+         trainer.register_op("process_optim_batch", lambda batch: batch.to(device))
+
+     if optim_scheduler is not None:
+         trainer.register_op("post_optim", optim_scheduler.step)
+
+     if target_net_updater is not None:
+         trainer.register_op("post_optim", target_net_updater.step)
+
+     if cfg.normalize_rewards_online:
+         # if used, the running statistics of the rewards are computed and the
+         # rewards used for training will be normalized based on these.
+         reward_normalizer = RewardNormalizer(
+             scale=cfg.normalize_rewards_online_scale,
+             decay=cfg.normalize_rewards_online_decay,
+         )
+         trainer.register_op("batch_process", reward_normalizer.update_reward_stats)
+         trainer.register_op("process_optim_batch", reward_normalizer.normalize_reward)
+
+     if policy_exploration is not None and hasattr(policy_exploration, "step"):
+         trainer.register_op(
+             "post_steps", policy_exploration.step, frames=cfg.frames_per_batch
+         )
+
+     trainer.register_op(
+         "post_steps_log", lambda *cfg: {"lr": optimizer.param_groups[0]["lr"]}
+     )
+
+     if recorder is not None:
+         # create recorder object
+         recorder_obj = LogValidationReward(
+             record_frames=cfg.record_frames,
+             frame_skip=cfg.frame_skip,
+             policy_exploration=policy_exploration,
+             environment=recorder,
+             record_interval=cfg.record_interval,
+             log_keys=cfg.recorder_log_keys,
+         )
+         # register recorder
+         trainer.register_op(
+             "post_steps_log",
+             recorder_obj,
+         )
+         # call recorder - could be removed
+         recorder_obj(None)
+         # create explorative recorder - could be optional
+         recorder_obj_explore = LogValidationReward(
+             record_frames=cfg.record_frames,
+             frame_skip=cfg.frame_skip,
+             policy_exploration=policy_exploration,
+             environment=recorder,
+             record_interval=cfg.record_interval,
+             exploration_type=ExplorationType.RANDOM,
+             suffix="exploration",
+             out_keys={("next", "reward"): "r_evaluation_exploration"},
+         )
+         # register recorder
+         trainer.register_op(
+             "post_steps_log",
+             recorder_obj_explore,
+         )
+         # call recorder - could be removed
+         recorder_obj_explore(None)
+
+     trainer.register_op(
+         "post_steps", UpdateWeights(collector, update_weights_interval=1)
+     )
+
+     trainer.register_op("pre_steps_log", LogScalar())
+     trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.frame_skip))
+
+     return trainer
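
Note that TrainerConfig does not declare every field make_trainer reads: the body also accesses frame_skip, total_frames, frames_per_batch, record_frames, record_interval and, when a recorder environment is passed, recorder_log_keys. Below is a hedged sketch of a script-level config carrying those extra fields; the ScriptConfig name and its default values are illustrative, not part of the package.

    from __future__ import annotations

    from dataclasses import dataclass

    # Hypothetical extension of TrainerConfig adding the extra fields
    # make_trainer reads from the script-level cfg.
    @dataclass
    class ScriptConfig(TrainerConfig):
        frame_skip: int = 1
        total_frames: int = 1_000_000
        frames_per_batch: int = 1000          # sizes the cosine LR schedule
        record_frames: int = 1000             # read only when a recorder is passed
        record_interval: int = 10
        recorder_log_keys: list | None = None

    cfg = ScriptConfig()
    # trainer = make_trainer(collector, loss_module, cfg=cfg)
    # trainer.train()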