torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,705 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ import sys
9
+
10
+ # Check for hydra/omegaconf availability - required for config system
11
+ try:
12
+ import hydra # noqa: F401
13
+ import omegaconf # noqa: F401
14
+ from hydra.core.config_store import ConfigStore
15
+
16
+ _has_hydra = True
17
+ except ImportError as e:
18
+ raise ImportError(
19
+ "The TorchRL configuration system requires hydra-core and omegaconf. "
20
+ "Please install them with: pip install 'torchrl[utils]' or pip install hydra-core omegaconf"
21
+ ) from e
22
+
23
+ from torchrl.trainers.algorithms.configs.collectors import (
24
+ # New canonical config names
25
+ AsyncCollectorConfig,
26
+ # Legacy config names (aliases)
27
+ AsyncDataCollectorConfig,
28
+ BaseCollectorConfig,
29
+ CollectorConfig,
30
+ MultiAsyncCollectorConfig,
31
+ MultiSyncCollectorConfig,
32
+ SyncDataCollectorConfig,
33
+ )
34
+
35
+ from torchrl.trainers.algorithms.configs.common import ConfigBase
36
+ from torchrl.trainers.algorithms.configs.data import (
37
+ LazyMemmapStorageConfig,
38
+ LazyStackStorageConfig,
39
+ LazyTensorStorageConfig,
40
+ ListStorageConfig,
41
+ PrioritizedSamplerConfig,
42
+ RandomSamplerConfig,
43
+ ReplayBufferConfig,
44
+ RoundRobinWriterConfig,
45
+ SamplerWithoutReplacementConfig,
46
+ SliceSamplerConfig,
47
+ SliceSamplerWithoutReplacementConfig,
48
+ StorageEnsembleConfig,
49
+ StorageEnsembleWriterConfig,
50
+ TensorDictReplayBufferConfig,
51
+ TensorStorageConfig,
52
+ )
53
+ from torchrl.trainers.algorithms.configs.envs import (
54
+ BatchedEnvConfig,
55
+ EnvConfig,
56
+ TransformedEnvConfig,
57
+ )
58
+ from torchrl.trainers.algorithms.configs.envs_libs import (
59
+ BraxEnvConfig,
60
+ DMControlEnvConfig,
61
+ EnvLibsConfig,
62
+ GymEnvConfig,
63
+ HabitatEnvConfig,
64
+ IsaacGymEnvConfig,
65
+ JumanjiEnvConfig,
66
+ MeltingpotEnvConfig,
67
+ MOGymEnvConfig,
68
+ MultiThreadedEnvConfig,
69
+ OpenMLEnvConfig,
70
+ OpenSpielEnvConfig,
71
+ PettingZooEnvConfig,
72
+ RoboHiveEnvConfig,
73
+ SMACv2EnvConfig,
74
+ UnityMLAgentsEnvConfig,
75
+ VmasEnvConfig,
76
+ )
77
+ from torchrl.trainers.algorithms.configs.logging import (
78
+ CSVLoggerConfig,
79
+ LoggerConfig,
80
+ TensorboardLoggerConfig,
81
+ WandbLoggerConfig,
82
+ )
83
+ from torchrl.trainers.algorithms.configs.modules import (
84
+ AdditiveGaussianModuleConfig,
85
+ ConvNetConfig,
86
+ MLPConfig,
87
+ ModelConfig,
88
+ TanhModuleConfig,
89
+ TanhNormalModelConfig,
90
+ TensorDictModuleConfig,
91
+ TensorDictSequentialConfig,
92
+ ValueModelConfig,
93
+ )
94
+ from torchrl.trainers.algorithms.configs.objectives import (
95
+ GAEConfig,
96
+ HardUpdateConfig,
97
+ LossConfig,
98
+ PPOLossConfig,
99
+ SACLossConfig,
100
+ SoftUpdateConfig,
101
+ )
102
+ from torchrl.trainers.algorithms.configs.trainers import (
103
+ PPOTrainerConfig,
104
+ SACTrainerConfig,
105
+ TrainerConfig,
106
+ )
107
+ from torchrl.trainers.algorithms.configs.transforms import (
108
+ ActionDiscretizerConfig,
109
+ ActionMaskConfig,
110
+ AutoResetTransformConfig,
111
+ BatchSizeTransformConfig,
112
+ BinarizeRewardConfig,
113
+ BurnInTransformConfig,
114
+ CatFramesConfig,
115
+ CatTensorsConfig,
116
+ CenterCropConfig,
117
+ ClipTransformConfig,
118
+ ComposeConfig,
119
+ ConditionalPolicySwitchConfig,
120
+ ConditionalSkipConfig,
121
+ CropConfig,
122
+ DeviceCastTransformConfig,
123
+ DiscreteActionProjectionConfig,
124
+ DoubleToFloatConfig,
125
+ DTypeCastTransformConfig,
126
+ EndOfLifeTransformConfig,
127
+ ExcludeTransformConfig,
128
+ FiniteTensorDictCheckConfig,
129
+ FlattenObservationConfig,
130
+ FlattenTensorDictConfig,
131
+ FrameSkipTransformConfig,
132
+ GrayScaleConfig,
133
+ HashConfig,
134
+ InitTrackerConfig,
135
+ KLRewardTransformConfig,
136
+ LineariseRewardsConfig,
137
+ ModuleTransformConfig,
138
+ MultiActionConfig,
139
+ MultiStepTransformConfig,
140
+ NoopResetEnvConfig,
141
+ ObservationNormConfig,
142
+ PermuteTransformConfig,
143
+ PinMemoryTransformConfig,
144
+ R3MTransformConfig,
145
+ RandomCropTensorDictConfig,
146
+ RemoveEmptySpecsConfig,
147
+ RenameTransformConfig,
148
+ ResizeConfig,
149
+ Reward2GoTransformConfig,
150
+ RewardClippingConfig,
151
+ RewardScalingConfig,
152
+ RewardSumConfig,
153
+ SelectTransformConfig,
154
+ SignTransformConfig,
155
+ SqueezeTransformConfig,
156
+ StackConfig,
157
+ StepCounterConfig,
158
+ TargetReturnConfig,
159
+ TensorDictPrimerConfig,
160
+ TimeMaxPoolConfig,
161
+ TimerConfig,
162
+ TokenizerConfig,
163
+ ToTensorImageConfig,
164
+ TrajCounterConfig,
165
+ TransformConfig,
166
+ UnaryTransformConfig,
167
+ UnsqueezeTransformConfig,
168
+ VC1TransformConfig,
169
+ VecGymEnvTransformConfig,
170
+ VecNormConfig,
171
+ VecNormV2Config,
172
+ VIPRewardTransformConfig,
173
+ VIPTransformConfig,
174
+ )
175
+ from torchrl.trainers.algorithms.configs.utils import (
176
+ AdadeltaConfig,
177
+ AdagradConfig,
178
+ AdamaxConfig,
179
+ AdamConfig,
180
+ AdamWConfig,
181
+ ASGDConfig,
182
+ LBFGSConfig,
183
+ LionConfig,
184
+ NAdamConfig,
185
+ RAdamConfig,
186
+ RMSpropConfig,
187
+ RpropConfig,
188
+ SGDConfig,
189
+ SparseAdamConfig,
190
+ )
191
+ from torchrl.trainers.algorithms.configs.weight_sync_schemes import (
192
+ DistributedWeightSyncSchemeConfig,
193
+ MultiProcessWeightSyncSchemeConfig,
194
+ NoWeightSyncSchemeConfig,
195
+ RayModuleTransformSchemeConfig,
196
+ RayWeightSyncSchemeConfig,
197
+ RPCWeightSyncSchemeConfig,
198
+ SharedMemWeightSyncSchemeConfig,
199
+ VLLMDoubleBufferSyncSchemeConfig,
200
+ VLLMWeightSyncSchemeConfig,
201
+ WeightSyncSchemeConfig,
202
+ )
203
+ from torchrl.trainers.algorithms.configs.weight_update import (
204
+ DistributedWeightUpdaterConfig,
205
+ MultiProcessedWeightUpdaterConfig,
206
+ RayWeightUpdaterConfig,
207
+ RemoteModuleWeightUpdaterConfig,
208
+ RPCWeightUpdaterConfig,
209
+ VanillaWeightUpdaterConfig,
210
+ vLLMUpdaterConfig,
211
+ WeightUpdaterConfig,
212
+ )
213
+
214
# Public API of the configs package: every config class imported above is
# re-exported here exactly once, so `from ... import *` and documentation
# tooling see a single, duplicate-free list.
__all__ = [
    # Base configuration
    "ConfigBase",
    # Optimizers
    "AdamConfig",
    "AdamWConfig",
    "AdamaxConfig",
    "AdadeltaConfig",
    "AdagradConfig",
    "ASGDConfig",
    "LBFGSConfig",
    "LionConfig",
    "NAdamConfig",
    "RAdamConfig",
    "RMSpropConfig",
    "RpropConfig",
    "SGDConfig",
    "SparseAdamConfig",
    # Collectors (new canonical names)
    "AsyncCollectorConfig",
    "CollectorConfig",
    "BaseCollectorConfig",
    "MultiAsyncCollectorConfig",
    "MultiSyncCollectorConfig",
    # Collectors (legacy aliases)
    # NOTE: MultiSyncCollectorConfig / MultiAsyncCollectorConfig are already
    # exported in the canonical section above and are not repeated here.
    "AsyncDataCollectorConfig",
    "SyncDataCollectorConfig",
    # Environments
    "BatchedEnvConfig",
    "EnvConfig",
    "TransformedEnvConfig",
    # Environment Libs
    "BraxEnvConfig",
    "DMControlEnvConfig",
    "EnvLibsConfig",
    "GymEnvConfig",
    "HabitatEnvConfig",
    "IsaacGymEnvConfig",
    "JumanjiEnvConfig",
    "MeltingpotEnvConfig",
    "MOGymEnvConfig",
    "MultiThreadedEnvConfig",
    "OpenMLEnvConfig",
    "OpenSpielEnvConfig",
    "PettingZooEnvConfig",
    "RoboHiveEnvConfig",
    "SMACv2EnvConfig",
    "UnityMLAgentsEnvConfig",
    "VmasEnvConfig",
    # Networks and Models
    "ConvNetConfig",
    "MLPConfig",
    "ModelConfig",
    "TanhModuleConfig",
    "TanhNormalModelConfig",
    "TensorDictModuleConfig",
    "TensorDictSequentialConfig",
    "ValueModelConfig",
    "AdditiveGaussianModuleConfig",
    # Transforms - Core
    "ActionDiscretizerConfig",
    "ActionMaskConfig",
    "AutoResetTransformConfig",
    "BatchSizeTransformConfig",
    "BinarizeRewardConfig",
    "BurnInTransformConfig",
    "CatFramesConfig",
    "CatTensorsConfig",
    "CenterCropConfig",
    "ClipTransformConfig",
    "ComposeConfig",
    "ConditionalPolicySwitchConfig",
    "ConditionalSkipConfig",
    "CropConfig",
    "DeviceCastTransformConfig",
    "DiscreteActionProjectionConfig",
    "DoubleToFloatConfig",
    "DTypeCastTransformConfig",
    "EndOfLifeTransformConfig",
    "ExcludeTransformConfig",
    "FiniteTensorDictCheckConfig",
    "FlattenObservationConfig",
    "FlattenTensorDictConfig",
    "FrameSkipTransformConfig",
    "GrayScaleConfig",
    "HashConfig",
    "InitTrackerConfig",
    "KLRewardTransformConfig",
    "LineariseRewardsConfig",
    "ModuleTransformConfig",
    "MultiActionConfig",
    "MultiStepTransformConfig",
    "NoopResetEnvConfig",
    "ObservationNormConfig",
    "PermuteTransformConfig",
    "PinMemoryTransformConfig",
    "RandomCropTensorDictConfig",
    "RemoveEmptySpecsConfig",
    "RenameTransformConfig",
    "ResizeConfig",
    "Reward2GoTransformConfig",
    "RewardClippingConfig",
    "RewardScalingConfig",
    "RewardSumConfig",
    "R3MTransformConfig",
    "SelectTransformConfig",
    "SignTransformConfig",
    "SqueezeTransformConfig",
    "StackConfig",
    "StepCounterConfig",
    "TargetReturnConfig",
    "TensorDictPrimerConfig",
    "TimerConfig",
    "TimeMaxPoolConfig",
    "ToTensorImageConfig",
    "TokenizerConfig",
    "TrajCounterConfig",
    "TransformConfig",
    "UnaryTransformConfig",
    "UnsqueezeTransformConfig",
    "VC1TransformConfig",
    "VecGymEnvTransformConfig",
    "VecNormConfig",
    "VecNormV2Config",
    "VIPRewardTransformConfig",
    "VIPTransformConfig",
    # Storage and Replay Buffers
    "LazyMemmapStorageConfig",
    "LazyStackStorageConfig",
    "LazyTensorStorageConfig",
    "ListStorageConfig",
    "ReplayBufferConfig",
    "RoundRobinWriterConfig",
    "StorageEnsembleConfig",
    "StorageEnsembleWriterConfig",
    "TensorDictReplayBufferConfig",
    "TensorStorageConfig",
    # Samplers
    "PrioritizedSamplerConfig",
    "RandomSamplerConfig",
    "SamplerWithoutReplacementConfig",
    "SliceSamplerConfig",
    "SliceSamplerWithoutReplacementConfig",
    # Losses
    "LossConfig",
    "PPOLossConfig",
    "SACLossConfig",
    # Update configs imported from the objectives module; previously imported
    # by this package but left out of the export list.
    "HardUpdateConfig",
    "SoftUpdateConfig",
    # Value functions
    "GAEConfig",
    # Trainers
    "PPOTrainerConfig",
    "SACTrainerConfig",
    "TrainerConfig",
    # Loggers
    "CSVLoggerConfig",
    "LoggerConfig",
    "TensorboardLoggerConfig",
    "WandbLoggerConfig",
    # Weight Updaters
    "WeightUpdaterConfig",
    "VanillaWeightUpdaterConfig",
    "MultiProcessedWeightUpdaterConfig",
    "RayWeightUpdaterConfig",
    "RemoteModuleWeightUpdaterConfig",
    "RPCWeightUpdaterConfig",
    "DistributedWeightUpdaterConfig",
    "vLLMUpdaterConfig",
    # Weight Sync Schemes
    "WeightSyncSchemeConfig",
    "MultiProcessWeightSyncSchemeConfig",
    "SharedMemWeightSyncSchemeConfig",
    "NoWeightSyncSchemeConfig",
    "RayWeightSyncSchemeConfig",
    "RayModuleTransformSchemeConfig",
    "RPCWeightSyncSchemeConfig",
    "DistributedWeightSyncSchemeConfig",
    "VLLMWeightSyncSchemeConfig",
    "VLLMDoubleBufferSyncSchemeConfig",
]
395
+
396
+
397
def _register_configs():
    """Register every TorchRL config dataclass with Hydra's ``ConfigStore``.

    Each config class is stored under a ``group``/``name`` pair so that it can
    be selected by name from a Hydra config (for instance ``env: gym`` or
    ``optimizer: adam``).

    Registration is kept inside a function, rather than written as bare
    module-level statements, to avoid GlobalHydra initialization issues during
    testing. Note that this module calls the function at import time when
    running on Python >= 3.10 (see the guard at the bottom of the file); on
    older interpreters it must be called explicitly.

    Every ``group``/``name`` pair must be registered exactly once: registering
    the same pair again is at best redundant and at worst silently replaces
    the earlier node.

    To add a new config:
    - Write the config class in the appropriate file (e.g. torchrl/trainers/algorithms/configs/transforms.py) and add it to the __all__ list in torchrl/trainers/algorithms/configs/__init__.py
    - Register the config in the appropriate group, e.g. cs.store(group="transform", name="new_transform", node=NewTransformConfig)
    """
    cs = ConfigStore.instance()

    # =============================================================================
    # Environment Configurations
    # =============================================================================

    # Core environment configs
    cs.store(group="env", name="gym", node=GymEnvConfig)
    cs.store(group="env", name="batched_env", node=BatchedEnvConfig)
    cs.store(group="env", name="transformed_env", node=TransformedEnvConfig)

    # Environment libs configs
    cs.store(group="env", name="brax", node=BraxEnvConfig)
    cs.store(group="env", name="dm_control", node=DMControlEnvConfig)
    cs.store(group="env", name="habitat", node=HabitatEnvConfig)
    cs.store(group="env", name="isaac_gym", node=IsaacGymEnvConfig)
    cs.store(group="env", name="jumanji", node=JumanjiEnvConfig)
    cs.store(group="env", name="meltingpot", node=MeltingpotEnvConfig)
    cs.store(group="env", name="mo_gym", node=MOGymEnvConfig)
    cs.store(group="env", name="multi_threaded", node=MultiThreadedEnvConfig)
    cs.store(group="env", name="openml", node=OpenMLEnvConfig)
    cs.store(group="env", name="openspiel", node=OpenSpielEnvConfig)
    cs.store(group="env", name="pettingzoo", node=PettingZooEnvConfig)
    cs.store(group="env", name="robohive", node=RoboHiveEnvConfig)
    cs.store(group="env", name="smacv2", node=SMACv2EnvConfig)
    cs.store(group="env", name="unity_mlagents", node=UnityMLAgentsEnvConfig)
    cs.store(group="env", name="vmas", node=VmasEnvConfig)

    # =============================================================================
    # Network and Model Configurations
    # =============================================================================

    # Network configs
    cs.store(group="network", name="mlp", node=MLPConfig)
    cs.store(group="network", name="convnet", node=ConvNetConfig)

    # Model configs
    # NOTE: the two TensorDict wrappers live in the "network" group, while the
    # remaining model configs live in the "model" group.
    cs.store(group="network", name="tensordict_module", node=TensorDictModuleConfig)
    cs.store(
        group="network", name="tensordict_sequential", node=TensorDictSequentialConfig
    )
    cs.store(group="model", name="tanh_module", node=TanhModuleConfig)
    cs.store(group="model", name="tanh_normal", node=TanhNormalModelConfig)
    cs.store(group="model", name="value", node=ValueModelConfig)

    # Exploration configs
    cs.store(
        group="exploration",
        name="additive_gaussian",
        node=AdditiveGaussianModuleConfig,
    )

    # =============================================================================
    # Transform Configurations
    # =============================================================================

    # Core transforms
    cs.store(group="transform", name="noop_reset", node=NoopResetEnvConfig)
    cs.store(group="transform", name="step_counter", node=StepCounterConfig)
    cs.store(group="transform", name="compose", node=ComposeConfig)
    cs.store(group="transform", name="double_to_float", node=DoubleToFloatConfig)
    cs.store(group="transform", name="to_tensor_image", node=ToTensorImageConfig)
    cs.store(group="transform", name="clip", node=ClipTransformConfig)
    cs.store(group="transform", name="resize", node=ResizeConfig)
    cs.store(group="transform", name="center_crop", node=CenterCropConfig)
    cs.store(group="transform", name="crop", node=CropConfig)
    cs.store(
        group="transform", name="flatten_observation", node=FlattenObservationConfig
    )
    cs.store(group="transform", name="flatten_tensordict", node=FlattenTensorDictConfig)
    cs.store(group="transform", name="gray_scale", node=GrayScaleConfig)
    cs.store(group="transform", name="observation_norm", node=ObservationNormConfig)
    cs.store(group="transform", name="cat_frames", node=CatFramesConfig)
    cs.store(group="transform", name="reward_clipping", node=RewardClippingConfig)
    cs.store(group="transform", name="reward_scaling", node=RewardScalingConfig)
    cs.store(group="transform", name="binarize_reward", node=BinarizeRewardConfig)
    cs.store(group="transform", name="target_return", node=TargetReturnConfig)
    cs.store(group="transform", name="vec_norm", node=VecNormConfig)
    cs.store(group="transform", name="frame_skip", node=FrameSkipTransformConfig)
    cs.store(group="transform", name="device_cast", node=DeviceCastTransformConfig)
    cs.store(group="transform", name="dtype_cast", node=DTypeCastTransformConfig)
    cs.store(group="transform", name="unsqueeze", node=UnsqueezeTransformConfig)
    cs.store(group="transform", name="squeeze", node=SqueezeTransformConfig)
    cs.store(group="transform", name="permute", node=PermuteTransformConfig)
    cs.store(group="transform", name="cat_tensors", node=CatTensorsConfig)
    cs.store(group="transform", name="stack", node=StackConfig)
    cs.store(
        group="transform",
        name="discrete_action_projection",
        node=DiscreteActionProjectionConfig,
    )
    cs.store(group="transform", name="tensordict_primer", node=TensorDictPrimerConfig)
    cs.store(group="transform", name="pin_memory", node=PinMemoryTransformConfig)
    cs.store(group="transform", name="reward_sum", node=RewardSumConfig)
    cs.store(group="transform", name="exclude", node=ExcludeTransformConfig)
    cs.store(group="transform", name="select", node=SelectTransformConfig)
    cs.store(group="transform", name="time_max_pool", node=TimeMaxPoolConfig)
    cs.store(
        group="transform",
        name="random_crop_tensordict",
        node=RandomCropTensorDictConfig,
    )
    cs.store(group="transform", name="init_tracker", node=InitTrackerConfig)
    cs.store(group="transform", name="rename", node=RenameTransformConfig)
    cs.store(group="transform", name="reward2go", node=Reward2GoTransformConfig)
    cs.store(group="transform", name="action_mask", node=ActionMaskConfig)
    cs.store(group="transform", name="vec_gym_env", node=VecGymEnvTransformConfig)
    cs.store(group="transform", name="burn_in", node=BurnInTransformConfig)
    cs.store(group="transform", name="sign", node=SignTransformConfig)
    cs.store(group="transform", name="remove_empty_specs", node=RemoveEmptySpecsConfig)
    cs.store(group="transform", name="batch_size", node=BatchSizeTransformConfig)
    cs.store(group="transform", name="auto_reset", node=AutoResetTransformConfig)
    cs.store(group="transform", name="action_discretizer", node=ActionDiscretizerConfig)
    cs.store(group="transform", name="traj_counter", node=TrajCounterConfig)
    cs.store(group="transform", name="linearise_rewards", node=LineariseRewardsConfig)
    cs.store(group="transform", name="module", node=ModuleTransformConfig)
    cs.store(group="transform", name="conditional_skip", node=ConditionalSkipConfig)
    cs.store(group="transform", name="multi_action", node=MultiActionConfig)
    cs.store(group="transform", name="timer", node=TimerConfig)
    cs.store(
        group="transform",
        name="conditional_policy_switch",
        node=ConditionalPolicySwitchConfig,
    )
    cs.store(
        group="transform",
        name="finite_tensordict_check",
        node=FiniteTensorDictCheckConfig,
    )
    cs.store(group="transform", name="unary", node=UnaryTransformConfig)
    cs.store(group="transform", name="hash", node=HashConfig)
    cs.store(group="transform", name="tokenizer", node=TokenizerConfig)

    # Specialized transforms
    # NOTE: "module" (ModuleTransformConfig) is registered once, with the core
    # transforms above; it must not be repeated here.
    cs.store(group="transform", name="end_of_life", node=EndOfLifeTransformConfig)
    cs.store(group="transform", name="multi_step", node=MultiStepTransformConfig)
    cs.store(group="transform", name="kl_reward", node=KLRewardTransformConfig)
    cs.store(group="transform", name="r3m", node=R3MTransformConfig)
    cs.store(group="transform", name="vc1", node=VC1TransformConfig)
    cs.store(group="transform", name="vip", node=VIPTransformConfig)
    cs.store(group="transform", name="vip_reward", node=VIPRewardTransformConfig)
    cs.store(group="transform", name="vec_norm_v2", node=VecNormV2Config)

    # =============================================================================
    # Loss Configurations
    # =============================================================================

    cs.store(group="loss", name="base", node=LossConfig)
    cs.store(group="loss", name="ppo", node=PPOLossConfig)
    cs.store(group="loss", name="sac", node=SACLossConfig)

    # =============================================================================
    # Value Function Configurations
    # =============================================================================

    cs.store(group="value", name="gae", node=GAEConfig)

    # =============================================================================
    # Target Net Updater Configurations
    # =============================================================================

    cs.store(group="target_net_updater", name="soft", node=SoftUpdateConfig)
    cs.store(group="target_net_updater", name="hard", node=HardUpdateConfig)

    # =============================================================================
    # Replay Buffer Configurations
    # =============================================================================

    # Buffers
    cs.store(group="replay_buffer", name="base", node=ReplayBufferConfig)
    cs.store(
        group="replay_buffer", name="tensordict", node=TensorDictReplayBufferConfig
    )

    # Samplers
    cs.store(group="sampler", name="random", node=RandomSamplerConfig)
    cs.store(
        group="sampler",
        name="without_replacement",
        node=SamplerWithoutReplacementConfig,
    )
    cs.store(group="sampler", name="prioritized", node=PrioritizedSamplerConfig)
    cs.store(group="sampler", name="slice", node=SliceSamplerConfig)
    cs.store(
        group="sampler",
        name="slice_without_replacement",
        node=SliceSamplerWithoutReplacementConfig,
    )

    # Storages
    cs.store(group="storage", name="lazy_stack", node=LazyStackStorageConfig)
    cs.store(group="storage", name="list", node=ListStorageConfig)
    cs.store(group="storage", name="tensor", node=TensorStorageConfig)
    cs.store(group="storage", name="lazy_tensor", node=LazyTensorStorageConfig)
    cs.store(group="storage", name="lazy_memmap", node=LazyMemmapStorageConfig)

    # Writers
    cs.store(group="writer", name="round_robin", node=RoundRobinWriterConfig)

    # =============================================================================
    # Collector Configurations
    # =============================================================================

    cs.store(group="collector", name="sync", node=CollectorConfig)
    cs.store(group="collector", name="async", node=AsyncCollectorConfig)
    cs.store(group="collector", name="multi_sync", node=MultiSyncCollectorConfig)
    cs.store(group="collector", name="multi_async", node=MultiAsyncCollectorConfig)

    # =============================================================================
    # Trainer Configurations
    # =============================================================================

    cs.store(group="trainer", name="base", node=TrainerConfig)
    cs.store(group="trainer", name="ppo", node=PPOTrainerConfig)
    cs.store(group="trainer", name="sac", node=SACTrainerConfig)

    # =============================================================================
    # Optimizer Configurations
    # =============================================================================

    cs.store(group="optimizer", name="adam", node=AdamConfig)
    cs.store(group="optimizer", name="adamw", node=AdamWConfig)
    cs.store(group="optimizer", name="adamax", node=AdamaxConfig)
    cs.store(group="optimizer", name="adadelta", node=AdadeltaConfig)
    cs.store(group="optimizer", name="adagrad", node=AdagradConfig)
    cs.store(group="optimizer", name="asgd", node=ASGDConfig)
    cs.store(group="optimizer", name="lbfgs", node=LBFGSConfig)
    cs.store(group="optimizer", name="lion", node=LionConfig)
    cs.store(group="optimizer", name="nadam", node=NAdamConfig)
    cs.store(group="optimizer", name="radam", node=RAdamConfig)
    cs.store(group="optimizer", name="rmsprop", node=RMSpropConfig)
    cs.store(group="optimizer", name="rprop", node=RpropConfig)
    cs.store(group="optimizer", name="sgd", node=SGDConfig)
    cs.store(group="optimizer", name="sparse_adam", node=SparseAdamConfig)

    # =============================================================================
    # Logger Configurations
    # =============================================================================

    cs.store(group="logger", name="wandb", node=WandbLoggerConfig)
    cs.store(group="logger", name="tensorboard", node=TensorboardLoggerConfig)
    cs.store(group="logger", name="csv", node=CSVLoggerConfig)
    cs.store(group="logger", name="base", node=LoggerConfig)

    # =============================================================================
    # Weight Updater Configurations
    # =============================================================================

    cs.store(group="weight_updater", name="base", node=WeightUpdaterConfig)
    cs.store(group="weight_updater", name="vanilla", node=VanillaWeightUpdaterConfig)
    cs.store(
        group="weight_updater",
        name="multiprocessed",
        node=MultiProcessedWeightUpdaterConfig,
    )
    cs.store(group="weight_updater", name="ray", node=RayWeightUpdaterConfig)
    cs.store(
        group="weight_updater",
        name="remote_module",
        node=RemoteModuleWeightUpdaterConfig,
    )
    cs.store(group="weight_updater", name="rpc", node=RPCWeightUpdaterConfig)
    cs.store(
        group="weight_updater", name="distributed", node=DistributedWeightUpdaterConfig
    )
    cs.store(group="weight_updater", name="vllm", node=vLLMUpdaterConfig)

    # =============================================================================
    # Weight Sync Scheme Configurations
    # =============================================================================

    cs.store(group="weight_sync_scheme", name="base", node=WeightSyncSchemeConfig)
    cs.store(
        group="weight_sync_scheme",
        name="multiprocess",
        node=MultiProcessWeightSyncSchemeConfig,
    )
    cs.store(
        group="weight_sync_scheme",
        name="shared_mem",
        node=SharedMemWeightSyncSchemeConfig,
    )
    cs.store(group="weight_sync_scheme", name="no_sync", node=NoWeightSyncSchemeConfig)
    cs.store(group="weight_sync_scheme", name="ray", node=RayWeightSyncSchemeConfig)
    cs.store(
        group="weight_sync_scheme",
        name="ray_module_transform",
        node=RayModuleTransformSchemeConfig,
    )
    cs.store(group="weight_sync_scheme", name="rpc", node=RPCWeightSyncSchemeConfig)
    cs.store(
        group="weight_sync_scheme",
        name="distributed",
        node=DistributedWeightSyncSchemeConfig,
    )
    cs.store(group="weight_sync_scheme", name="vllm", node=VLLMWeightSyncSchemeConfig)
    cs.store(
        group="weight_sync_scheme",
        name="vllm_double_buffer",
        node=VLLMDoubleBufferSyncSchemeConfig,
    )
702
+
703
+
704
# Populate the Hydra ConfigStore at import time, but only on Python 3.10 and
# newer; on older interpreters registration is skipped here.
if sys.version_info >= (3, 10):  # type: ignore # noqa
    _register_configs()