torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/objectives/reinforce.py
@@ -0,0 +1,530 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import contextlib
+from copy import deepcopy
+from dataclasses import dataclass
+
+import torch
+from tensordict import TensorDict, TensorDictBase, TensorDictParams
+
+from tensordict.nn import (
+    composite_lp_aggregate,
+    dispatch,
+    ProbabilisticTensorDictSequential,
+    TensorDictModule,
+)
+from tensordict.utils import NestedKey
+from torchrl.objectives.common import LossModule
+
+from torchrl.objectives.utils import (
+    _clip_value_loss,
+    _GAMMA_LMBDA_DEPREC_ERROR,
+    _reduce,
+    default_value_kwargs,
+    distance_loss,
+    ValueEstimators,
+)
+from torchrl.objectives.value import (
+    GAE,
+    TD0Estimator,
+    TD1Estimator,
+    TDLambdaEstimator,
+    ValueEstimatorBase,
+    VTrace,
+)
+
+
+class ReinforceLoss(LossModule):
+    """Reinforce loss module.
+
+    Presented in "Simple statistical gradient-following algorithms for connectionist reinforcement learning", Williams, 1992
+    https://doi.org/10.1007/BF00992696
+
+
+    Args:
+        actor_network (ProbabilisticTensorDictSequential): policy operator.
+        critic_network (ValueOperator): value operator.
+
+    Keyword Args:
+        delay_value (bool, optional): if ``True``, a target network is needed
+            for the critic. Defaults to ``False``. Incompatible with ``functional=False``.
+        loss_critic_type (str): loss function for the value discrepancy.
+            Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
+        advantage_key (str): [Deprecated, use .set_keys(advantage_key=advantage_key) instead]
+            The input tensordict key where the advantage is expected to be written.
+            Defaults to ``"advantage"``.
+        value_target_key (str): [Deprecated, use .set_keys(value_target_key=value_target_key) instead]
+            The input tensordict key where the target state
+            value is expected to be written. Defaults to ``"value_target"``.
+        separate_losses (bool, optional): if ``True``, shared parameters between
+            policy and critic will only be trained on the policy loss.
+            Defaults to ``False``, i.e., gradients are propagated to shared
+            parameters for both policy and critic losses.
+        functional (bool, optional): whether modules should be functionalized.
+            Functionalizing permits features like meta-RL, but makes it
+            impossible to use distributed models (DDP, FSDP, ...) and comes
+            with a little cost. Defaults to ``True``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
+            ``"mean"``: the sum of the output will be divided by the number of
+            elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+        clip_value (:obj:`float`, optional): If provided, it will be used to compute a clipped version of the value
+            prediction with respect to the input tensordict value estimate and use it to calculate the value loss.
+            The purpose of clipping is to limit the impact of extreme value predictions, helping stabilize training
+            and preventing large updates. However, it will have no impact if the value estimate was done by the current
+            version of the value estimator. Defaults to ``None``.
+
+    .. note::
+        The advantage (typically GAE) can be computed by the loss function or
+        in the training loop. The latter option is usually preferred, but it is
+        up to the user to choose which option to prefer.
+        If the advantage key (``"advantage"`` by default) is not present in the
+        input tensordict, the advantage will be computed by the :meth:`~.forward`
+        method.
+
+        >>> reinforce_loss = ReinforceLoss(actor, critic)
+        >>> advantage = GAE(critic)
+        >>> data = next(datacollector)
+        >>> losses = reinforce_loss(data)
+        >>> # equivalent
+        >>> advantage(data)
+        >>> losses = reinforce_loss(data)
+
+        A custom advantage module can be built using :meth:`~.make_value_estimator`.
+        The default is :class:`~torchrl.objectives.value.GAE` with hyperparameters
+        dictated by :func:`~torchrl.objectives.utils.default_value_kwargs`.
+
+        >>> reinforce_loss = ReinforceLoss(actor, critic)
+        >>> reinforce_loss.make_value_estimator(ValueEstimators.TDLambda)
+        >>> data = next(datacollector)
+        >>> losses = reinforce_loss(data)
+
+    Examples:
+        >>> import torch
+        >>> from torch import nn
+        >>> from torchrl.data.tensor_specs import Unbounded
+        >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
+        >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
+        >>> from torchrl.modules.tensordict_module.common import SafeModule
+        >>> from torchrl.objectives.reinforce import ReinforceLoss
+        >>> from tensordict import TensorDict
+        >>> n_obs, n_act = 3, 5
+        >>> value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
+        >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
+        >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
+        >>> actor_net = ProbabilisticActor(
+        ...     module,
+        ...     distribution_class=TanhNormal,
+        ...     return_log_prob=True,
+        ...     in_keys=["loc", "scale"],
+        ...     spec=Unbounded(n_act),)
+        >>> loss = ReinforceLoss(actor_net, value_net)
+        >>> batch = 2
+        >>> data = TensorDict({
+        ...     "observation": torch.randn(batch, n_obs),
+        ...     "next": {
+        ...         "observation": torch.randn(batch, n_obs),
+        ...         "reward": torch.randn(batch, 1),
+        ...         "done": torch.zeros(batch, 1, dtype=torch.bool),
+        ...         "terminated": torch.zeros(batch, 1, dtype=torch.bool),
+        ...     },
+        ...     "action": torch.randn(batch, n_act),
+        ... }, [batch])
+        >>> loss(data)
+        TensorDict(
+            fields={
+                loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
+            batch_size=torch.Size([]),
+            device=None,
+            is_shared=False)
+
+    This class is compatible with non-tensordict based modules too and can be
+    used without resorting to any tensordict-related primitive. In this case,
+    the expected keyword arguments are:
+    ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and critic network
+    The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_value"]``.
+
+    Examples:
+        >>> import torch
+        >>> from torch import nn
+        >>> from torchrl.data.tensor_specs import Unbounded
+        >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
+        >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
+        >>> from torchrl.modules.tensordict_module.common import SafeModule
+        >>> from torchrl.objectives.reinforce import ReinforceLoss
+        >>> n_obs, n_act = 3, 5
+        >>> value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
+        >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
+        >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
+        >>> actor_net = ProbabilisticActor(
+        ...     module,
+        ...     distribution_class=TanhNormal,
+        ...     return_log_prob=True,
+        ...     in_keys=["loc", "scale"],
+        ...     spec=Unbounded(n_act),)
+        >>> loss = ReinforceLoss(actor_net, value_net)
+        >>> batch = 2
+        >>> loss_actor, loss_value = loss(
+        ...     observation=torch.randn(batch, n_obs),
+        ...     next_observation=torch.randn(batch, n_obs),
+        ...     next_reward=torch.randn(batch, 1),
+        ...     next_done=torch.zeros(batch, 1, dtype=torch.bool),
+        ...     next_terminated=torch.zeros(batch, 1, dtype=torch.bool),
+        ...     action=torch.randn(batch, n_act),)
+        >>> loss_actor.backward()
+
+    """
+
+    @dataclass
+    class _AcceptedKeys:
+        """Maintains default values for all configurable tensordict keys.
+
+        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+        default values.
+
+        Attributes:
+            advantage (NestedKey): The input tensordict key where the advantage is expected.
+                Will be used for the underlying value estimator. Defaults to ``"advantage"``.
+            value_target (NestedKey): The input tensordict key where the target state value is expected.
+                Will be used for the underlying value estimator. Defaults to ``"value_target"``.
+            value (NestedKey): The input tensordict key where the state value is expected.
+                Will be used for the underlying value estimator. Defaults to ``"state_value"``.
+            sample_log_prob (NestedKey): The input tensordict key where the sample log probability is expected.
+                Defaults to ``"sample_log_prob"`` when :func:`~tensordict.nn.composite_lp_aggregate` returns `True`,
+                `"action_log_prob"` otherwise.
+            action (NestedKey): The input tensordict key where the action is expected.
+                Defaults to ``"action"``.
+            reward (NestedKey): The input tensordict key where the reward is expected.
+                Will be used for the underlying value estimator. Defaults to ``"reward"``.
+            done (NestedKey): The key in the input TensorDict that indicates
+                whether a trajectory is done. Will be used for the underlying value estimator.
+                Defaults to ``"done"``.
+            terminated (NestedKey): The key in the input TensorDict that indicates
+                whether a trajectory is terminated. Will be used for the underlying value estimator.
+                Defaults to ``"terminated"``.
+        """
+
+        advantage: NestedKey = "advantage"
+        value_target: NestedKey = "value_target"
+        value: NestedKey = "state_value"
+        sample_log_prob: NestedKey | None = None
+        action: NestedKey = "action"
+        reward: NestedKey = "reward"
+        done: NestedKey = "done"
+        terminated: NestedKey = "terminated"
+
+        def __post_init__(self):
+            if self.sample_log_prob is None:
+                if composite_lp_aggregate(nowarn=True):
+                    self.sample_log_prob = "sample_log_prob"
+                else:
+                    self.sample_log_prob = "action_log_prob"
+
+    tensor_keys: _AcceptedKeys
+    default_keys = _AcceptedKeys
+    default_value_estimator = ValueEstimators.GAE
+    out_keys = ["loss_actor", "loss_value"]
+
+    actor_network: TensorDictModule
+    critic_network: TensorDictModule
+    actor_network_params: TensorDictParams | None
+    critic_network_params: TensorDictParams | None
+    target_actor_network_params: TensorDictParams | None
+    target_critic_network_params: TensorDictParams | None
+
+    @classmethod
+    def __new__(cls, *args, **kwargs):
+        cls._tensor_keys = cls._AcceptedKeys()
+        return super().__new__(cls)
+
+    def __init__(
+        self,
+        actor_network: ProbabilisticTensorDictSequential,
+        critic_network: TensorDictModule | None = None,
+        *,
+        delay_value: bool = False,
+        loss_critic_type: str = "smooth_l1",
+        gamma: float | None = None,
+        advantage_key: str | None = None,
+        value_target_key: str | None = None,
+        separate_losses: bool = False,
+        functional: bool = True,
+        actor: ProbabilisticTensorDictSequential = None,
+        critic: ProbabilisticTensorDictSequential = None,
+        reduction: str | None = None,
+        clip_value: float | None = None,
+    ) -> None:
+        if actor is not None:
+            actor_network = actor
+            del actor
+        if critic is not None:
+            critic_network = critic
+            del critic
+        if actor_network is None or critic_network is None:
+            raise TypeError(
+                "Missing positional arguments actor_network or critic_network."
+            )
+        if not functional and delay_value:
+            raise RuntimeError(
+                "delay_value and ~functional are incompatible, as delayed value currently relies on functional calls."
+            )
+        if reduction is None:
+            reduction = "mean"
+
+        self._functional = functional
+
+        super().__init__()
+        self.in_keys = None
+        self._set_deprecated_ctor_keys(
+            advantage=advantage_key, value_target=value_target_key
+        )
+
+        self.delay_value = delay_value
+        self.loss_critic_type = loss_critic_type
+        self.reduction = reduction
+
+        # Actor
+        if self.functional:
+            self.convert_to_functional(
+                actor_network,
+                "actor_network",
+                create_target_params=False,
+            )
+        else:
+            self.actor_network = actor_network
+
+        if separate_losses:
+            # we want to make sure there are no duplicates in the params: the
+            # params of critic must be refs to actor if they're shared
+            policy_params = list(actor_network.parameters())
+        else:
+            policy_params = None
+        # Value
+        if critic_network is not None:
+            if self.functional:
+                self.convert_to_functional(
+                    critic_network,
+                    "critic_network",
+                    create_target_params=self.delay_value,
+                    compare_against=policy_params,
+                )
+            else:
+                self.critic_network = critic_network
+                self.target_critic_network_params = None
+
+        if gamma is not None:
+            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
+
+        if clip_value is not None:
+            if isinstance(clip_value, float):
+                clip_value = torch.tensor(clip_value)
+            elif isinstance(clip_value, torch.Tensor):
+                if clip_value.numel() != 1:
+                    raise ValueError(
+                        f"clip_value must be a float or a scalar tensor, got {clip_value}."
+                    )
+            else:
+                raise ValueError(
+                    f"clip_value must be a float or a scalar tensor, got {clip_value}."
+                )
+        self.register_buffer("clip_value", clip_value)
+
+    @property
+    def functional(self):
+        return self._functional
+
+    def _forward_value_estimator_keys(self, **kwargs) -> None:
+        if self._value_estimator is not None:
+            self._value_estimator.set_keys(
+                advantage=self.tensor_keys.advantage,
+                value_target=self.tensor_keys.value_target,
+                value=self.tensor_keys.value,
+                reward=self.tensor_keys.reward,
+                done=self.tensor_keys.done,
+                terminated=self.tensor_keys.terminated,
+            )
+        self._set_in_keys()
+
+    def _set_in_keys(self):
+        keys = [
+            self.tensor_keys.action,
+            ("next", self.tensor_keys.reward),
+            ("next", self.tensor_keys.done),
+            ("next", self.tensor_keys.terminated),
+            *self.actor_network.in_keys,
+            *[("next", key) for key in self.actor_network.in_keys],
+            *self.critic_network.in_keys,
+        ]
+        self._in_keys = list(set(keys))
+
+    @property
+    def in_keys(self):
+        if self._in_keys is None:
+            self._set_in_keys()
+        return self._in_keys
+
+    @in_keys.setter
+    def in_keys(self, values):
+        self._in_keys = values
+
+    @dispatch
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        advantage = tensordict.get(self.tensor_keys.advantage, None)
+        if advantage is None:
+            self.value_estimator(
+                tensordict,
+                params=self.critic_network_params.detach() if self.functional else None,
+                target_params=self.target_critic_network_params
+                if self.functional
+                else None,
+            )
+            advantage = tensordict.get(self.tensor_keys.advantage)
+
+        # compute log-prob
+        with self.actor_network_params.to_module(
+            self.actor_network
+        ) if self.functional else contextlib.nullcontext():
+            tensordict = self.actor_network(tensordict)
+
+        log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
+        if log_prob.shape == advantage.shape[:-1]:
+            log_prob = log_prob.unsqueeze(-1)
+        loss_actor = -log_prob * advantage.detach()
+        td_out = TensorDict({"loss_actor": loss_actor}, batch_size=[])
+
+        loss_value, value_clip_fraction = self.loss_critic(tensordict)
+        td_out.set("loss_value", loss_value)
+        if value_clip_fraction is not None:
+            td_out.set("value_clip_fraction", value_clip_fraction)
+        td_out = td_out.named_apply(
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
+            if name.startswith("loss_")
+            else value,
+        )
+        self._clear_weakrefs(
+            tensordict,
+            td_out,
+            "actor_network_params",
+            "critic_network_params",
+            "target_actor_network_params",
+            "target_critic_network_params",
+        )
+        return td_out

+    def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
+
+        if self.clip_value:
+            old_state_value = tensordict.get(
+                self.tensor_keys.value, None
+            )  # TODO: None soon to be removed
+            if old_state_value is None:
+                raise KeyError(
+                    f"clip_value is set to {self.clip_value}, but "
+                    f"the key {self.tensor_keys.value} was not found in the input tensordict. "
+                    f"Make sure that the value_key passed to Reinforce exists in the input tensordict."
+                )
+            old_state_value = old_state_value.clone()
+
+        target_return = tensordict.get(
+            self.tensor_keys.value_target, None
+        )  # TODO: None soon to be removed
+        if target_return is None:
+            raise KeyError(
+                f"the key {self.tensor_keys.value_target} was not found in the input tensordict. "
+                f"Make sure you provided the right key and the value_target (i.e. the target "
+                f"return) has been retrieved accordingly. Advantage classes such as GAE, "
+                f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that "
+                f"can be used for the value loss."
+            )
+
+        tensordict_select = tensordict.select(
+            *self.critic_network.in_keys, strict=False
+        )
+        with self.critic_network_params.to_module(
+            self.critic_network
+        ) if self.functional else contextlib.nullcontext():
+            state_value = self.critic_network(tensordict_select).get(
+                self.tensor_keys.value
+            )
+        loss_value = distance_loss(
+            target_return,
+            state_value,
+            loss_function=self.loss_critic_type,
+        )
+        clip_fraction = None
+        if self.clip_value:
+            loss_value, clip_fraction = _clip_value_loss(
+                old_state_value,
+                state_value,
+                self.clip_value.to(state_value.device),
+                target_return,
+                loss_value,
+                self.loss_critic_type,
+            )
+        self._clear_weakrefs(
+            tensordict,
+            "actor_network_params",
+            "critic_network_params",
+            "target_actor_network_params",
+            "target_critic_network_params",
+        )
+
+        return loss_value, clip_fraction
+
+    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
+        if value_type is None:
+            value_type = self.default_value_estimator
+
+        # Handle ValueEstimatorBase instance or class
+        if isinstance(value_type, ValueEstimatorBase) or (
+            isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
+        ):
+            return LossModule.make_value_estimator(self, value_type, **hyperparams)
+
+        self.value_type = value_type
+        hp = dict(default_value_kwargs(value_type))
+        if hasattr(self, "gamma"):
+            hp["gamma"] = self.gamma
+        hp.update(hyperparams)
+        if value_type == ValueEstimators.TD1:
+            self._value_estimator = TD1Estimator(
+                value_network=self.critic_network, **hp
+            )
+        elif value_type == ValueEstimators.TD0:
+            self._value_estimator = TD0Estimator(
+                value_network=self.critic_network, **hp
+            )
+        elif value_type == ValueEstimators.GAE:
+            self._value_estimator = GAE(value_network=self.critic_network, **hp)
+        elif value_type == ValueEstimators.TDLambda:
+            self._value_estimator = TDLambdaEstimator(
+                value_network=self.critic_network, **hp
+            )
+        elif value_type == ValueEstimators.VTrace:
+            # VTrace currently does not support functional call on the actor
+            if self.functional:
+                actor_with_params = deepcopy(self.actor_network)
+                self.actor_network_params.to_module(actor_with_params)
+            else:
+                actor_with_params = self.actor_network
+            self._value_estimator = VTrace(
+                value_network=self.critic_network, actor_network=actor_with_params, **hp
+            )
+        else:
+            raise NotImplementedError(f"Unknown value type {value_type}")
+
+        tensor_keys = {
+            "advantage": self.tensor_keys.advantage,
+            "value": self.tensor_keys.value,
+            "value_target": self.tensor_keys.value_target,
+            "reward": self.tensor_keys.reward,
+            "done": self.tensor_keys.done,
+            "terminated": self.tensor_keys.terminated,
+            "sample_log_prob": self.tensor_keys.sample_log_prob,
+        }
+        self._value_estimator.set_keys(**tensor_keys)