torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,516 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import math
+ from dataclasses import dataclass
+ from numbers import Number
+
+ import numpy as np
+ import torch
+ from tensordict import TensorDict, TensorDictBase, TensorDictParams
+ from tensordict.nn import composite_lp_aggregate, dispatch, TensorDictModule
+ from tensordict.utils import NestedKey
+ from torch import Tensor
+
+ from torchrl.data.tensor_specs import Composite
+ from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
+ from torchrl.objectives import default_value_kwargs, distance_loss, ValueEstimators
+ from torchrl.objectives.common import LossModule
+ from torchrl.objectives.utils import (
+     _cache_values,
+     _GAMMA_LMBDA_DEPREC_ERROR,
+     _reduce,
+     _vmap_func,
+ )
+ from torchrl.objectives.value import (
+     TD0Estimator,
+     TD1Estimator,
+     TDLambdaEstimator,
+     ValueEstimatorBase,
+ )
+
+
+ class REDQLoss_deprecated(LossModule):
+     """REDQ Loss module.
+
+     REDQ (RANDOMIZED ENSEMBLED DOUBLE Q-LEARNING: LEARNING FAST WITHOUT A MODEL,
+     https://openreview.net/pdf?id=AY8zfZm0tDd) generalizes the idea of using an ensemble of Q-value functions to
+     train a SAC-like algorithm.
+
+     Args:
+         actor_network (TensorDictModule): the actor to be trained.
+         qvalue_network (TensorDictModule): a single Q-value network that will
+             be multiplied as many times as needed.
+             If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets``
+             times. If a list of modules is passed, their
+             parameters will be stacked unless they share the same identity (in which case
+             the original parameter will be expanded).
+
+             .. warning:: When a list of parameters is passed, it will **not** be compared against the policy parameters
+               and all the parameters will be considered as untied.
+
+     Keyword Args:
+         num_qvalue_nets (int, optional): Number of Q-value networks to be trained.
+             Default is ``10``.
+         sub_sample_len (int, optional): number of Q-value networks to be
+             subsampled to evaluate the next state value.
+             Default is ``2``.
+         loss_function (str, optional): loss function to be used for the Q-value.
+             Can be one of ``"smooth_l1"``, ``"l2"``,
+             ``"l1"``. Default is ``"smooth_l1"``.
+         alpha_init (:obj:`float`, optional): initial entropy multiplier.
+             Default is ``1.0``.
+         min_alpha (:obj:`float`, optional): min value of alpha.
+             Default is ``0.1``.
+         max_alpha (:obj:`float`, optional): max value of alpha.
+             Default is ``10.0``.
+         action_spec (TensorSpec, optional): the action tensor spec. If not provided
+             and the target entropy is ``"auto"``, it will be retrieved from
+             the actor.
+         fixed_alpha (bool, optional): if ``True``, alpha is kept fixed at ``alpha_init``
+             instead of being trained to match the target entropy. Default is ``False``.
+         target_entropy (Union[str, Number], optional): Target entropy for the
+             stochastic policy. Default is ``"auto"``.
+         delay_qvalue (bool, optional): Whether to separate the target Q value
+             networks from the Q value networks used
+             for data collection. Default is ``True``.
+         gSDE (bool, optional): Knowing if gSDE is used is necessary to create
+             random noise variables.
+             Default is ``False``.
+         priority_key (str, optional): [Deprecated] Key where to write the priority value
+             for prioritized replay buffers. Default is
+             ``"td_error"``.
+         separate_losses (bool, optional): if ``True``, shared parameters between
+             policy and critic will only be trained on the policy loss.
+             Defaults to ``False``, i.e., gradients are propagated to shared
+             parameters for both policy and critic losses.
+         reduction (str, optional): Specifies the reduction to apply to the output:
+             ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
+             ``"mean"``: the sum of the output will be divided by the number of
+             elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+         deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+             Defaults to ``False``.
+     """
+
+     @dataclass
+     class _AcceptedKeys:
+         """Maintains default values for all configurable tensordict keys.
+
+         This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+         default values.
+
+         Attributes:
+             action (NestedKey): The input tensordict key where the action is expected.
+                 Defaults to ``"action"``.
+             value (NestedKey): The input tensordict key where the state value is expected.
+                 Will be used for the underlying value estimator. Defaults to ``"state_value"``.
+             state_action_value (NestedKey): The input tensordict key where the
+                 state action value is expected. Defaults to ``"state_action_value"``.
+             log_prob (NestedKey): The input tensordict key where the log probability is expected.
+                 Defaults to ``"sample_log_prob"`` when :func:`~tensordict.nn.composite_lp_aggregate` returns ``True``,
+                 and to ``"action_log_prob"`` otherwise.
+             priority (NestedKey): The input tensordict key where the target priority is written to.
+                 Defaults to ``"td_error"``.
+             reward (NestedKey): The input tensordict key where the reward is expected.
+                 Will be used for the underlying value estimator. Defaults to ``"reward"``.
+             done (NestedKey): The key in the input TensorDict that indicates
+                 whether a trajectory is done. Will be used for the underlying value estimator.
+                 Defaults to ``"done"``.
+             terminated (NestedKey): The key in the input TensorDict that indicates
+                 whether a trajectory is terminated. Will be used for the underlying value estimator.
+                 Defaults to ``"terminated"``.
+         """
+
+         action: NestedKey = "action"
+         state_action_value: NestedKey = "state_action_value"
+         value: NestedKey = "state_value"
+         log_prob: NestedKey | None = None
+         priority: NestedKey = "td_error"
+         reward: NestedKey = "reward"
+         done: NestedKey = "done"
+         terminated: NestedKey = "terminated"
+
+         def __post_init__(self):
+             if self.log_prob is None:
+                 if composite_lp_aggregate(nowarn=True):
+                     self.log_prob = "sample_log_prob"
+                 else:
+                     self.log_prob = "action_log_prob"
+
+     tensor_keys: _AcceptedKeys
+     default_keys = _AcceptedKeys
+     delay_actor: bool = False
+     default_value_estimator = ValueEstimators.TD0
+
+     actor_network: TensorDictModule
+     qvalue_network: TensorDictModule
+     actor_network_params: TensorDictParams
+     qvalue_network_params: TensorDictParams
+     target_actor_network_params: TensorDictParams
+     target_qvalue_network_params: TensorDictParams
+
+     def __init__(
+         self,
+         actor_network: TensorDictModule,
+         qvalue_network: TensorDictModule | list[TensorDictModule],
+         *,
+         num_qvalue_nets: int = 10,
+         sub_sample_len: int = 2,
+         loss_function: str = "smooth_l1",
+         alpha_init: float = 1.0,
+         min_alpha: float = 0.1,
+         max_alpha: float = 10.0,
+         action_spec=None,
+         fixed_alpha: bool = False,
+         target_entropy: str | Number = "auto",
+         delay_qvalue: bool = True,
+         gSDE: bool = False,
+         gamma: float | None = None,
+         priority_key: str | None = None,
+         separate_losses: bool = False,
+         reduction: str | None = None,
+         deactivate_vmap: bool = False,
+     ):
+         self._in_keys = None
+         self._out_keys = None
+         if reduction is None:
+             reduction = "mean"
+         super().__init__()
+         self._set_deprecated_ctor_keys(priority_key=priority_key)
+
+         self.deactivate_vmap = deactivate_vmap
+
+         self.convert_to_functional(
+             actor_network,
+             "actor_network",
+             create_target_params=self.delay_actor,
+         )
+         if separate_losses:
+             # we want to make sure there are no duplicates in the params: the
+             # params of critic must be refs to actor if they're shared
+             policy_params = list(actor_network.parameters())
+         else:
+             policy_params = None
+         # let's make sure that actor_network has `return_log_prob` set to True
+         self.actor_network.return_log_prob = True
+
+         self.delay_qvalue = delay_qvalue
+         self.convert_to_functional(
+             qvalue_network,
+             "qvalue_network",
+             expand_dim=num_qvalue_nets,
+             create_target_params=self.delay_qvalue,
+             compare_against=policy_params,
+         )
+         self.num_qvalue_nets = num_qvalue_nets
+         self.sub_sample_len = max(1, min(sub_sample_len, num_qvalue_nets - 1))
+         self.loss_function = loss_function
+
+         try:
+             device = next(self.parameters()).device
+         except AttributeError:
+             device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
+
+         self.register_buffer("alpha_init", torch.as_tensor(alpha_init, device=device))
+         self.register_buffer(
+             "min_log_alpha", torch.as_tensor(min_alpha, device=device).log()
+         )
+         self.register_buffer(
+             "max_log_alpha", torch.as_tensor(max_alpha, device=device).log()
+         )
+         self.fixed_alpha = fixed_alpha
+         if fixed_alpha:
+             self.register_buffer(
+                 "log_alpha", torch.as_tensor(math.log(alpha_init), device=device)
+             )
+         else:
+             self.register_parameter(
+                 "log_alpha",
+                 torch.nn.Parameter(
+                     torch.as_tensor(math.log(alpha_init), device=device)
+                 ),
+             )
+
+         self._target_entropy = target_entropy
+         self._action_spec = action_spec
+         self.target_entropy_buffer = None
+         self.gSDE = gSDE
+         self._make_vmap()
+         self.reduction = reduction
+
+         if gamma is not None:
+             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
+
+     def _make_vmap(self):
+         self._vmap_qvalue_networkN0 = _vmap_func(
+             self.qvalue_network, (None, 0), pseudo_vmap=self.deactivate_vmap
+         )
+
+     @property
+     def target_entropy(self):
+         target_entropy = self.target_entropy_buffer
+         if target_entropy is None:
+             delattr(self, "target_entropy_buffer")
+             target_entropy = self._target_entropy
+             action_spec = self._action_spec
+             actor_network = self.actor_network
+             device = next(self.parameters()).device
+             if target_entropy == "auto":
+                 action_spec = (
+                     action_spec
+                     if action_spec is not None
+                     else getattr(actor_network, "spec", None)
+                 )
+                 if action_spec is None:
+                     raise RuntimeError(
+                         "Cannot infer the dimensionality of the action. Consider providing "
+                         "the target entropy explicitly or provide the spec of the "
+                         "action tensor in the actor network."
+                     )
+                 if not isinstance(action_spec, Composite):
+                     action_spec = Composite({self.tensor_keys.action: action_spec})
+                 target_entropy = -float(
+                     np.prod(action_spec[self.tensor_keys.action].shape)
+                 )
+             self.register_buffer(
+                 "target_entropy_buffer", torch.as_tensor(target_entropy, device=device)
+             )
+             return self.target_entropy_buffer
+         return target_entropy
+
+     def _forward_value_estimator_keys(self, **kwargs) -> None:
+         if self._value_estimator is not None:
+             self._value_estimator.set_keys(
+                 value=self.tensor_keys.value,
+                 reward=self.tensor_keys.reward,
+                 done=self.tensor_keys.done,
+                 terminated=self.tensor_keys.terminated,
+             )
+         self._set_in_keys()
+
+     @property
+     def alpha(self):
+         with torch.no_grad():
+             # keep alpha in a reasonable range
+             alpha = self.log_alpha.clamp(self.min_log_alpha, self.max_log_alpha).exp()
+         return alpha
+
+     def _set_in_keys(self):
+         keys = [
+             self.tensor_keys.action,
+             ("next", self.tensor_keys.reward),
+             ("next", self.tensor_keys.done),
+             ("next", self.tensor_keys.terminated),
+             *self.actor_network.in_keys,
+             *[("next", key) for key in self.actor_network.in_keys],
+             *self.qvalue_network.in_keys,
+         ]
+         self._in_keys = list(set(keys))
+
+     @property
+     def in_keys(self):
+         if self._in_keys is None:
+             self._set_in_keys()
+         return self._in_keys
+
+     @in_keys.setter
+     def in_keys(self, values):
+         self._in_keys = values
+
+     @property
+     def out_keys(self):
+         if self._out_keys is None:
+             keys = ["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]
+             self._out_keys = keys
+         return self._out_keys
+
+     @out_keys.setter
+     def out_keys(self, values):
+         self._out_keys = values
+
+     @dispatch
+     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+         loss_actor, sample_log_prob = self._actor_loss(tensordict)
+
+         loss_qval = self._qvalue_loss(tensordict)
+         loss_alpha = self._loss_alpha(sample_log_prob)
+         if not loss_qval.shape == loss_actor.shape:
+             raise RuntimeError(
+                 f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}"
+             )
+         td_out = TensorDict(
+             {
+                 "loss_actor": loss_actor,
+                 "loss_qvalue": loss_qval,
+                 "loss_alpha": loss_alpha,
+                 "alpha": self.alpha,
+                 "entropy": -sample_log_prob.detach().mean(),
+             },
+             [],
+         )
+         td_out = td_out.named_apply(
+             lambda name, value: _reduce(value, reduction=self.reduction)
+             if name.startswith("loss_")
+             else value,
+             batch_size=[],
+         )
+         self._clear_weakrefs(
+             tensordict,
+             td_out,
+             "actor_network_params",
+             "qvalue_network_params",
+             "target_actor_network_params",
+             "target_qvalue_network_params",
+         )
+         return td_out
+
+     @property
+     @_cache_values
+     def _cached_detach_qvalue_network_params(self):
+         return self.qvalue_network_params.detach()
+
+     def _actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, Tensor]:
+         obs_keys = self.actor_network.in_keys
+         tensordict_clone = tensordict.select(*obs_keys, strict=False)
+         with set_exploration_type(
+             ExplorationType.RANDOM
+         ), self.actor_network_params.to_module(self.actor_network):
+             self.actor_network(tensordict_clone)
+
+         tensordict_expand = self._vmap_qvalue_networkN0(
+             tensordict_clone.select(*self.qvalue_network.in_keys, strict=False),
+             self._cached_detach_qvalue_network_params,
+         )
+         state_action_value = tensordict_expand.get(
+             self.tensor_keys.state_action_value
+         ).squeeze(-1)
+         loss_actor = -(
+             state_action_value
+             - self.alpha * tensordict_clone.get(self.tensor_keys.log_prob).squeeze(-1)
+         )
+         return loss_actor, tensordict_clone.get(self.tensor_keys.log_prob)
+
+     def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor:
+         tensordict_save = tensordict
+
+         obs_keys = self.actor_network.in_keys
+         tensordict = tensordict.select(
+             "next", *obs_keys, self.tensor_keys.action, strict=False
+         ).clone(False)
+
+         selected_models_idx = torch.randperm(self.num_qvalue_nets)[
+             : self.sub_sample_len
+         ].sort()[0]
+         with torch.no_grad():
+             selected_q_params = self.target_qvalue_network_params[selected_models_idx]
+
+             next_td = step_mdp(tensordict).select(
+                 *self.actor_network.in_keys, strict=False
+             )  # next_observation ->
+             # observation
+             # select pseudo-action
+             with set_exploration_type(
+                 ExplorationType.RANDOM
+             ), self.target_actor_network_params.to_module(self.actor_network):
+                 self.actor_network(next_td)
+             sample_log_prob = next_td.get(self.tensor_keys.log_prob)
+             # get q-values
+             next_td = self._vmap_qvalue_networkN0(
+                 next_td,
+                 selected_q_params,
+             )
+             state_action_value = next_td.get(self.tensor_keys.state_action_value)
+             if (
+                 state_action_value.shape[-len(sample_log_prob.shape) :]
+                 != sample_log_prob.shape
+             ):
+                 sample_log_prob = sample_log_prob.unsqueeze(-1)
+             next_state_value = (
+                 next_td.get(self.tensor_keys.state_action_value)
+                 - self.alpha * sample_log_prob
+             )
+             next_state_value = next_state_value.min(0)[0]
+
+         tensordict.set(("next", self.tensor_keys.value), next_state_value)
+         target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
+         tensordict_expand = self._vmap_qvalue_networkN0(
+             tensordict.select(*self.qvalue_network.in_keys, strict=False),
+             self.qvalue_network_params,
+         )
+         pred_val = tensordict_expand.get(self.tensor_keys.state_action_value).squeeze(
+             -1
+         )
+         td_error = abs(pred_val - target_value)
+         loss_qval = distance_loss(
+             pred_val,
+             target_value.expand_as(pred_val),
+             loss_function=self.loss_function,
+         )
+         tensordict_save.set("td_error", td_error.detach().max(0)[0])
+         return loss_qval
+
+     def _loss_alpha(self, log_pi: Tensor) -> Tensor:
+         if torch.is_grad_enabled() and not log_pi.requires_grad:
+             raise RuntimeError(
+                 "expected log_pi to require gradient for the alpha loss"
+             )
+         if self.target_entropy is not None:
+             # we can compute this loss even if log_alpha is not a parameter
+             alpha_loss = -self.log_alpha.clamp(
+                 self.min_log_alpha, self.max_log_alpha
+             ).exp() * (log_pi.detach() + self.target_entropy)
+         else:
+             # placeholder
+             alpha_loss = torch.zeros_like(log_pi)
+         return alpha_loss
+
+     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
+         if value_type is None:
+             value_type = self.default_value_estimator
+
+         # Handle ValueEstimatorBase instance or class
+         if isinstance(value_type, ValueEstimatorBase) or (
+             isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
+         ):
+             return LossModule.make_value_estimator(self, value_type, **hyperparams)
+
+         self.value_type = value_type
+         hp = dict(default_value_kwargs(value_type))
+         if hasattr(self, "gamma"):
+             hp["gamma"] = self.gamma
+         hp.update(hyperparams)
+         # we do not need a value network bc the next state value is already passed
+         if value_type == ValueEstimators.TD1:
+             self._value_estimator = TD1Estimator(value_network=None, **hp)
+         elif value_type == ValueEstimators.TD0:
+             self._value_estimator = TD0Estimator(value_network=None, **hp)
+         elif value_type == ValueEstimators.GAE:
+             raise NotImplementedError(
+                 f"Value type {value_type} is not implemented for loss {type(self)}."
+             )
+         elif value_type == ValueEstimators.TDLambda:
+             self._value_estimator = TDLambdaEstimator(value_network=None, **hp)
+         else:
+             raise NotImplementedError(f"Unknown value type {value_type}")
+         tensor_keys = {
+             "value": self.tensor_keys.value,
+             "reward": self.tensor_keys.reward,
+             "done": self.tensor_keys.done,
+             "terminated": self.tensor_keys.terminated,
+         }
+         self._value_estimator.set_keys(**tensor_keys)
+
+
+ class DoubleREDQLoss_deprecated(REDQLoss_deprecated):
+     """[Deprecated] Class for delayed target-REDQ (which should be the default behavior)."""
+
+     delay_qvalue: bool = True
+
+     actor_network: TensorDictModule
+     qvalue_network: TensorDictModule
+     actor_network_params: TensorDictParams
+     qvalue_network_params: TensorDictParams
+     target_actor_network_params: TensorDictParams
+     target_qvalue_network_params: TensorDictParams
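
For orientation, the snippet below sketches how the REDQLoss_deprecated module shipped in this wheel is typically wired to an actor and a Q-value network. It is not part of the package contents above: the toy dimensions, network sizes, the Bounded action spec and the key names are illustrative assumptions, and the exact log-prob key resolution depends on the installed tensordict version.

# Minimal usage sketch for REDQLoss_deprecated (assumed toy setup, not from the wheel).
import torch
from tensordict import TensorDict
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.data import Bounded
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives.deprecated import REDQLoss_deprecated

obs_dim, act_dim, batch = 4, 2, 8
action_spec = Bounded(low=-1.0, high=1.0, shape=(act_dim,))

# Actor: observation -> (loc, scale) -> TanhNormal policy with log-prob output.
actor_net = TensorDictModule(
    torch.nn.Sequential(
        MLP(in_features=obs_dim, out_features=2 * act_dim, num_cells=[64]),
        NormalParamExtractor(),
    ),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
actor = ProbabilisticActor(
    actor_net,
    in_keys=["loc", "scale"],
    spec=action_spec,
    distribution_class=TanhNormal,
    return_log_prob=True,
)

# Q-value network: (observation, action) -> state_action_value.
# A single instance is passed; the loss expands it num_qvalue_nets times.
qvalue = ValueOperator(
    MLP(in_features=obs_dim + act_dim, out_features=1, num_cells=[64]),
    in_keys=["observation", "action"],
)

loss_module = REDQLoss_deprecated(actor, qvalue, num_qvalue_nets=10, sub_sample_len=2)

# Fake transition batch carrying the keys listed in loss_module.in_keys.
data = TensorDict(
    {
        "observation": torch.randn(batch, obs_dim),
        "action": torch.rand(batch, act_dim) * 2 - 1,
        "next": {
            "observation": torch.randn(batch, obs_dim),
            "reward": torch.randn(batch, 1),
            "done": torch.zeros(batch, 1, dtype=torch.bool),
            "terminated": torch.zeros(batch, 1, dtype=torch.bool),
        },
    },
    batch_size=[batch],
)

out = loss_module(data)  # TensorDict with loss_actor, loss_qvalue, loss_alpha, alpha, entropy
loss = out["loss_actor"] + out["loss_qvalue"] + out["loss_alpha"]
loss.backward()

The three loss_* entries can be summed as above or routed to separate optimizers; with the default reduction="mean" each entry is a scalar, so backward() can be called on their sum directly.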