torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,177 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Any
+
+ from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss, SACLoss
+ from torchrl.objectives.sac import DiscreteSACLoss
+ from torchrl.trainers.algorithms.configs.common import ConfigBase
+
+
+ @dataclass
+ class LossConfig(ConfigBase):
+     """Base class for loss configurations; concrete subclasses select the loss type."""
+
+     _partial_: bool = False
+
+     def __post_init__(self) -> None:
+         """Post-initialization hook for loss configurations."""
+
+
+ @dataclass
+ class SACLossConfig(LossConfig):
+     """A class to configure a SAC loss."""
+
+     actor_network: Any = None
+     qvalue_network: Any = None
+     value_network: Any = None
+     discrete: bool = False
+     num_qvalue_nets: int = 2
+     loss_function: str = "smooth_l1"
+     alpha_init: float = 1.0
+     min_alpha: float | None = None
+     max_alpha: float | None = None
+     action_spec: Any = None
+     fixed_alpha: bool = False
+     target_entropy: str | float = "auto"
+     delay_actor: bool = False
+     delay_qvalue: bool = True
+     delay_value: bool = True
+     gamma: float | None = None
+     priority_key: str | None = None
+     separate_losses: bool = False
+     reduction: str | None = None
+     skip_done_states: bool = False
+     deactivate_vmap: bool = False
+     _target_: str = "torchrl.trainers.algorithms.configs.objectives._make_sac_loss"
+
+     def __post_init__(self) -> None:
+         """Post-initialization hook for SAC loss configurations."""
+         super().__post_init__()
+
+
+ def _make_sac_loss(*args, **kwargs) -> SACLoss:
+     discrete_loss_type = kwargs.pop("discrete", False)
+
+     # Instantiate networks if they are config objects
+     actor_network = kwargs.get("actor_network")
+     qvalue_network = kwargs.get("qvalue_network")
+     value_network = kwargs.get("value_network")
+
+     if actor_network is not None and hasattr(actor_network, "_target_"):
+         kwargs["actor_network"] = actor_network()
+     if qvalue_network is not None and hasattr(qvalue_network, "_target_"):
+         kwargs["qvalue_network"] = qvalue_network()
+     if value_network is not None and hasattr(value_network, "_target_"):
+         kwargs["value_network"] = value_network()
+
+     if discrete_loss_type:
+         return DiscreteSACLoss(*args, **kwargs)
+     else:
+         return SACLoss(*args, **kwargs)
+
+
+ @dataclass
+ class PPOLossConfig(LossConfig):
+     """A class to configure a PPO loss."""
+
+     actor_network: Any = None
+     critic_network: Any = None
+     loss_type: str = "clip"
+     entropy_bonus: bool = True
+     samples_mc_entropy: int = 1
+     entropy_coeff: float | None = None
+     log_explained_variance: bool = True
+     critic_coeff: float = 0.25
+     loss_critic_type: str = "smooth_l1"
+     normalize_advantage: bool = True
+     normalize_advantage_exclude_dims: tuple = ()
+     gamma: float | None = None
+     separate_losses: bool = False
+     advantage_key: str | None = None
+     value_target_key: str | None = None
+     value_key: str | None = None
+     functional: bool = True
+     actor: Any = None
+     critic: Any = None
+     reduction: str | None = None
+     clip_value: float | None = None
+     device: Any = None
+     _target_: str = "torchrl.trainers.algorithms.configs.objectives._make_ppo_loss"
+
+     def __post_init__(self) -> None:
+         """Post-initialization hook for PPO loss configurations."""
+         super().__post_init__()
+
+
+ def _make_ppo_loss(*args, **kwargs) -> PPOLoss:
+     loss_type = kwargs.pop("loss_type", "clip")
+     if loss_type == "clip":
+         return ClipPPOLoss(*args, **kwargs)
+     elif loss_type == "kl":
+         return KLPENPPOLoss(*args, **kwargs)
+     elif loss_type == "ppo":
+         return PPOLoss(*args, **kwargs)
+     else:
+         raise ValueError(f"Invalid loss type: {loss_type}")
+
+
+ @dataclass
+ class TargetNetUpdaterConfig:
+     """An abstract class to configure target net updaters."""
+
+     loss_module: Any
+     _partial_: bool = True
+
+
+ @dataclass
+ class SoftUpdateConfig(TargetNetUpdaterConfig):
+     """A class for soft update instantiation."""
+
+     _target_: str = "torchrl.objectives.utils.SoftUpdate"
+     eps: float | None = None  # noqa  # type: ignore
+     tau: float | None = 0.001  # noqa  # type: ignore
+
+
+ @dataclass
+ class HardUpdateConfig(TargetNetUpdaterConfig):
+     """A class for hard update instantiation."""
+
+     _target_: str = "torchrl.objectives.utils.HardUpdate"
+     value_network_update_interval: int = 1000
+
+
+ @dataclass
+ class GAEConfig(LossConfig):
+     """A class to configure a GAE (generalized advantage estimation) module."""
+
+     gamma: float | None = None
+     lmbda: float | None = None
+     value_network: Any = None
+     average_gae: bool = True
+     differentiable: bool = False
+     vectorized: bool | None = None
+     skip_existing: bool | None = None
+     advantage_key: str | None = None
+     value_target_key: str | None = None
+     value_key: str | None = None
+     shifted: bool = False
+     device: Any = None
+     time_dim: int | None = None
+     auto_reset_env: bool = False
+     deactivate_vmap: bool = False
+     _target_: str = "torchrl.objectives.value.GAE"
+     _partial_: bool = False
+
+     def __post_init__(self) -> None:
+         """Post-initialization hook for GAE configurations."""
+         super().__post_init__()
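
Usage note: these dataclasses follow the Hydra `_target_`/`_partial_` convention, so a loss is normally built by handing a config to `hydra.utils.instantiate`, which dispatches to the `_make_*_loss` factories above. A minimal sketch, assuming `hydra-core` is installed and a 4-dimensional observation vector; the toy networks below are illustrative placeholders, not part of this package:

    import torch
    from hydra.utils import instantiate
    from tensordict.nn import NormalParamExtractor, TensorDictModule
    from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
    from torchrl.trainers.algorithms.configs.objectives import PPOLossConfig

    # Toy Gaussian policy head: MLP -> (loc, scale) -> TanhNormal.
    actor_net = torch.nn.Sequential(
        MLP(in_features=4, out_features=4, num_cells=[32]), NormalParamExtractor()
    )
    actor = ProbabilisticActor(
        TensorDictModule(actor_net, in_keys=["observation"], out_keys=["loc", "scale"]),
        in_keys=["loc", "scale"],
        distribution_class=TanhNormal,
        return_log_prob=True,
    )
    # Toy critic mapping observations to a scalar state value.
    critic = ValueOperator(
        MLP(in_features=4, out_features=1, num_cells=[32]), in_keys=["observation"]
    )

    cfg = PPOLossConfig(loss_type="clip")
    # instantiate() resolves `_target_` to _make_ppo_loss, which returns a ClipPPOLoss here.
    loss = instantiate(cfg, actor_network=actor, critic_network=critic)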
@@ -0,0 +1,340 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Any
+
+ import torch
+ from tensordict.nn import TensorDictModuleBase
+
+ from torchrl.collectors import BaseCollector
+ from torchrl.objectives.common import LossModule
+ from torchrl.objectives.utils import TargetNetUpdater
+ from torchrl.objectives.value.advantages import GAE
+ from torchrl.trainers.algorithms.configs.common import ConfigBase
+ from torchrl.trainers.algorithms.ppo import PPOTrainer
+ from torchrl.trainers.algorithms.sac import SACTrainer
+
+
+ @dataclass
+ class TrainerConfig(ConfigBase):
+     """Base configuration class for trainers."""
+
+     def __post_init__(self) -> None:
+         """Post-initialization hook for trainer configurations."""
+
+
+ @dataclass
+ class SACTrainerConfig(TrainerConfig):
+     """Configuration class for SAC (Soft Actor Critic) trainer.
+
+     This class defines the configuration parameters for creating a SAC trainer,
+     including both required and optional fields with sensible defaults.
+     """
+
+     collector: Any
+     total_frames: int
+     optim_steps_per_batch: int | None
+     loss_module: Any
+     optimizer: Any
+     logger: Any
+     save_trainer_file: Any
+     replay_buffer: Any
+     frame_skip: int = 1
+     clip_grad_norm: bool = True
+     clip_norm: float | None = None
+     progress_bar: bool = True
+     seed: int | None = None
+     save_trainer_interval: int = 10000
+     log_interval: int = 10000
+     create_env_fn: Any = None
+     actor_network: Any = None
+     critic_network: Any = None
+     target_net_updater: Any = None
+     async_collection: bool = False
+     log_timings: bool = False
+
+     _target_: str = "torchrl.trainers.algorithms.configs.trainers._make_sac_trainer"
+
+     def __post_init__(self) -> None:
+         """Post-initialization hook for SAC trainer configuration."""
+         super().__post_init__()
+
+
+ def _make_sac_trainer(*args, **kwargs) -> SACTrainer:
+     from torchrl.trainers.trainers import Logger
+
+     collector = kwargs.pop("collector")
+     total_frames = kwargs.pop("total_frames")
+     if total_frames is None:
+         total_frames = collector.total_frames
+     frame_skip = kwargs.pop("frame_skip", 1)
+     optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1)
+     loss_module = kwargs.pop("loss_module")
+     optimizer = kwargs.pop("optimizer")
+     logger = kwargs.pop("logger")
+     clip_grad_norm = kwargs.pop("clip_grad_norm", True)
+     clip_norm = kwargs.pop("clip_norm")
+     progress_bar = kwargs.pop("progress_bar", True)
+     replay_buffer = kwargs.pop("replay_buffer")
+     save_trainer_interval = kwargs.pop("save_trainer_interval", 10000)
+     log_interval = kwargs.pop("log_interval", 10000)
+     save_trainer_file = kwargs.pop("save_trainer_file")
+     seed = kwargs.pop("seed")
+     actor_network = kwargs.pop("actor_network")
+     critic_network = kwargs.pop("critic_network")
+     kwargs.pop("create_env_fn")
+     target_net_updater = kwargs.pop("target_net_updater")
+     async_collection = kwargs.pop("async_collection", False)
+     log_timings = kwargs.pop("log_timings", False)
+
+     # Instantiate networks first
+     if actor_network is not None:
+         actor_network = actor_network()
+     if critic_network is not None:
+         critic_network = critic_network()
+
+     if not isinstance(collector, BaseCollector):
+         # then it's a partial config
+         if not async_collection:
+             collector = collector()
+         elif replay_buffer is not None:
+             collector = collector(replay_buffer=replay_buffer)
+         elif getattr(collector, "replay_buffer", None) is not None:
+             # the partial config already carries its own replay buffer
+             collector = collector()
+         else:
+             raise ValueError(
+                 "replay_buffer must be provided when async_collection is True"
+             )
+
+     if not isinstance(loss_module, LossModule):
+         # then it's a partial config
+         loss_module = loss_module(
+             actor_network=actor_network, critic_network=critic_network
+         )
+     if target_net_updater is not None and not isinstance(
+         target_net_updater, TargetNetUpdater
+     ):
+         # target_net_updater must be a partial taking the loss as input
+         target_net_updater = target_net_updater(loss_module)
+     if not isinstance(optimizer, torch.optim.Optimizer):
+         # then it's a partial config
+         optimizer = optimizer(params=loss_module.parameters())
+
+     # Quick instance checks
+     if not isinstance(collector, BaseCollector):
+         raise ValueError(f"collector must be a BaseCollector, got {type(collector)}")
+     if not isinstance(loss_module, LossModule):
+         raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}")
+     if not isinstance(optimizer, torch.optim.Optimizer):
+         raise ValueError(
+             f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}"
+         )
+     if logger is not None and not isinstance(logger, Logger):
+         raise ValueError(f"logger must be a Logger, got {type(logger)}")
+
+     return SACTrainer(
+         collector=collector,
+         total_frames=total_frames,
+         frame_skip=frame_skip,
+         optim_steps_per_batch=optim_steps_per_batch,
+         loss_module=loss_module,
+         optimizer=optimizer,
+         logger=logger,
+         clip_grad_norm=clip_grad_norm,
+         clip_norm=clip_norm,
+         progress_bar=progress_bar,
+         seed=seed,
+         save_trainer_interval=save_trainer_interval,
+         log_interval=log_interval,
+         save_trainer_file=save_trainer_file,
+         replay_buffer=replay_buffer,
+         target_net_updater=target_net_updater,
+         async_collection=async_collection,
+         log_timings=log_timings,
+     )
+
+
+ @dataclass
+ class PPOTrainerConfig(TrainerConfig):
+     """Configuration class for PPO (Proximal Policy Optimization) trainer.
+
+     This class defines the configuration parameters for creating a PPO trainer,
+     including both required and optional fields with sensible defaults.
+
+     Args:
+         collector: The data collector for gathering training data.
+         total_frames: Total number of frames to train for.
+         optim_steps_per_batch: Number of optimization steps per batch.
+         loss_module: The loss module for computing policy and value losses.
+         optimizer: The optimizer for training.
+         logger: Logger for tracking training metrics.
+         save_trainer_file: File path for saving trainer state.
+         replay_buffer: Replay buffer for storing data.
+         frame_skip: Frame skip value for the environment. Default: 1.
+         clip_grad_norm: Whether to clip gradient norms. Default: True.
+         clip_norm: Maximum gradient norm value.
+         progress_bar: Whether to show a progress bar. Default: True.
+         seed: Random seed for reproducibility.
+         save_trainer_interval: Interval for saving trainer state. Default: 10000.
+         log_interval: Interval for logging metrics. Default: 10000.
+         create_env_fn: Environment creation function.
+         actor_network: Actor network configuration.
+         critic_network: Critic network configuration.
+         num_epochs: Number of epochs per batch. Default: 4.
+         async_collection: Whether to use async collection. Default: False.
+         add_gae: Whether to add GAE computation. Default: True.
+         gae: Custom GAE module configuration.
+         weight_update_map: Mapping from collector destination paths to trainer source paths.
+             Required if the collector has weight_sync_schemes configured.
+             Example: ``{"policy": "loss_module.actor_network", "replay_buffer.transforms[0]": "loss_module.critic_network"}``.
+         log_timings: Whether to automatically log timing information for all hooks.
+             If True, timing metrics will be logged to the logger (e.g., wandb, tensorboard)
+             with the prefix "time/" (e.g., "time/hook/UpdateWeights"). Default: False.
+     """
+
+     collector: Any
+     total_frames: int
+     optim_steps_per_batch: int | None
+     loss_module: Any
+     optimizer: Any
+     logger: Any
+     save_trainer_file: Any
+     replay_buffer: Any
+     frame_skip: int = 1
+     clip_grad_norm: bool = True
+     clip_norm: float | None = None
+     progress_bar: bool = True
+     seed: int | None = None
+     save_trainer_interval: int = 10000
+     log_interval: int = 10000
+     create_env_fn: Any = None
+     actor_network: Any = None
+     critic_network: Any = None
+     num_epochs: int = 4
+     async_collection: bool = False
+     add_gae: bool = True
+     gae: Any = None
+     weight_update_map: dict[str, str] | None = None
+     log_timings: bool = False
+
+     _target_: str = "torchrl.trainers.algorithms.configs.trainers._make_ppo_trainer"
+
+     def __post_init__(self) -> None:
+         """Post-initialization hook for PPO trainer configuration."""
+         super().__post_init__()
+
+
+ def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer:
+     from torchrl.trainers.trainers import Logger
+
+     collector = kwargs.pop("collector")
+     total_frames = kwargs.pop("total_frames")
+     if total_frames is None:
+         total_frames = collector.total_frames
+     frame_skip = kwargs.pop("frame_skip", 1)
+     optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1)
+     loss_module = kwargs.pop("loss_module")
+     optimizer = kwargs.pop("optimizer")
+     logger = kwargs.pop("logger")
+     clip_grad_norm = kwargs.pop("clip_grad_norm", True)
+     clip_norm = kwargs.pop("clip_norm")
+     progress_bar = kwargs.pop("progress_bar", True)
+     replay_buffer = kwargs.pop("replay_buffer")
+     save_trainer_interval = kwargs.pop("save_trainer_interval", 10000)
+     log_interval = kwargs.pop("log_interval", 10000)
+     save_trainer_file = kwargs.pop("save_trainer_file")
+     seed = kwargs.pop("seed")
+     actor_network = kwargs.pop("actor_network")
+     critic_network = kwargs.pop("critic_network")
+     add_gae = kwargs.pop("add_gae", True)
+     gae = kwargs.pop("gae")
+     # create_env_fn may be referenced elsewhere in the config; nothing to do here.
+     kwargs.pop("create_env_fn")
+     weight_update_map = kwargs.pop("weight_update_map", None)
+     log_timings = kwargs.pop("log_timings", False)
+     num_epochs = kwargs.pop("num_epochs", 4)
+     async_collection = kwargs.pop("async_collection", False)
+
+     # Instantiate networks first
+     if actor_network is not None:
+         actor_network = actor_network()
+     if critic_network is not None:
+         critic_network = critic_network()
+     else:
+         critic_network = loss_module.critic_network
+
+     # Ensure the GAE transform in the replay buffer uses the same value network
+     # instance as the loss module; Hydra would otherwise instantiate separate
+     # copies of the value model.
+     if (
+         replay_buffer is not None
+         and hasattr(replay_buffer, "_transform")
+         and len(replay_buffer._transform) > 1
+         and hasattr(replay_buffer._transform[1], "module")
+         and hasattr(replay_buffer._transform[1].module, "value_network")
+     ):
+         replay_buffer._transform[1].module.value_network = critic_network
+
+     if not isinstance(collector, BaseCollector):
+         # then it's a partial config
+         if not async_collection:
+             collector = collector()
+         else:
+             collector = collector(replay_buffer=replay_buffer)
+     elif async_collection and getattr(collector, "replay_buffer", None) is None:
+         raise RuntimeError(
+             "replay_buffer must be provided when async_collection is True"
+         )
+     if not isinstance(loss_module, LossModule):
+         # then it's a partial config
+         loss_module = loss_module(
+             actor_network=actor_network, critic_network=critic_network
+         )
+     if not isinstance(optimizer, torch.optim.Optimizer):
+         # then it's a partial config
+         optimizer = optimizer(params=loss_module.parameters())
+
+     # Quick instance checks
+     if not isinstance(collector, BaseCollector):
+         raise ValueError(f"collector must be a BaseCollector, got {type(collector)}")
+     if not isinstance(loss_module, LossModule):
+         raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}")
+     if not isinstance(optimizer, torch.optim.Optimizer):
+         raise ValueError(
+             f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}"
+         )
+     if logger is not None and not isinstance(logger, Logger):
+         raise ValueError(f"logger must be a Logger, got {type(logger)}")
+     # Instantiate gae if it is a partial config
+     if gae is not None and not isinstance(gae, (GAE, TensorDictModuleBase)):
+         gae = gae()
+
+     return PPOTrainer(
+         collector=collector,
+         total_frames=total_frames,
+         frame_skip=frame_skip,
+         optim_steps_per_batch=optim_steps_per_batch,
+         loss_module=loss_module,
+         optimizer=optimizer,
+         logger=logger,
+         clip_grad_norm=clip_grad_norm,
+         clip_norm=clip_norm,
+         progress_bar=progress_bar,
+         seed=seed,
+         save_trainer_interval=save_trainer_interval,
+         log_interval=log_interval,
+         save_trainer_file=save_trainer_file,
+         replay_buffer=replay_buffer,
+         num_epochs=num_epochs,
+         async_collection=async_collection,
+         add_gae=add_gae,
+         gae=gae,
+         weight_update_map=weight_update_map,
+         log_timings=log_timings,
+     )
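
Usage note: both trainer factories above expect Hydra partials rather than built objects for the optimizer, loss module, and target-net updater, and wire them together at build time (e.g., `optimizer(params=loss_module.parameters())`). A minimal sketch of that mechanism, assuming `hydra-core` is installed; `AdamConfig` is a hypothetical config written for illustration, not part of this package:

    from dataclasses import dataclass

    from hydra.utils import instantiate

    @dataclass
    class AdamConfig:  # hypothetical, for illustration only
        _target_: str = "torch.optim.Adam"
        _partial_: bool = True
        lr: float = 3e-4

    # With `_partial_: True`, instantiate() returns
    # functools.partial(torch.optim.Adam, lr=0.0003) instead of calling the
    # constructor right away; the trainer factory later supplies the parameters:
    optim_partial = instantiate(AdamConfig())
    # optimizer = optim_partial(params=loss_module.parameters())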