torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
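
A wheel is a standard zip archive, so a listing like the one above can be reproduced locally from the downloaded file. A minimal sketch using only the Python standard library; the local filename is an assumption:

import zipfile

# Path to the downloaded wheel (assumed local filename).
wheel_path = "torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl"

with zipfile.ZipFile(wheel_path) as whl:
    for info in whl.infolist():
        # Print each archived file with its uncompressed size in bytes.
        print(f"{info.filename}\t{info.file_size}")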
torchrl/objectives/ddpg.py
@@ -0,0 +1,453 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from __future__ import annotations
+
+ from copy import deepcopy
+ from dataclasses import dataclass
+
+ import torch
+ from tensordict import TensorDict, TensorDictBase, TensorDictParams
+ from tensordict.nn import dispatch, TensorDictModule
+ from tensordict.utils import NestedKey, unravel_key
+
+ from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
+ from torchrl.objectives.common import LossModule
+ from torchrl.objectives.utils import (
+     _cache_values,
+     _GAMMA_LMBDA_DEPREC_ERROR,
+     _reduce,
+     default_value_kwargs,
+     distance_loss,
+     ValueEstimators,
+ )
+ from torchrl.objectives.value import (
+     TD0Estimator,
+     TD1Estimator,
+     TDLambdaEstimator,
+     ValueEstimatorBase,
+ )
+
+
+ class DDPGLoss(LossModule):
+     """The DDPG Loss class.
+
+     Args:
+         actor_network (TensorDictModule): a policy operator.
+         value_network (TensorDictModule): a Q value operator.
+         loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
+         delay_actor (bool, optional): whether to separate the target actor networks from the actor networks used for
+             data collection. Default is ``False``.
+         delay_value (bool, optional): whether to separate the target value networks from the value networks used for
+             data collection. Default is ``True``.
+         separate_losses (bool, optional): if ``True``, shared parameters between
+             policy and critic will only be trained on the policy loss.
+             Defaults to ``False``, i.e., gradients are propagated to shared
+             parameters for both policy and critic losses.
+         reduction (str, optional): Specifies the reduction to apply to the output:
+             ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
+             ``"mean"``: the sum of the output will be divided by the number of
+             elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+
+     Examples:
+         >>> import torch
+         >>> from torch import nn
+         >>> from torchrl.data import Bounded
+         >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
+         >>> from torchrl.objectives.ddpg import DDPGLoss
+         >>> from tensordict import TensorDict
+         >>> n_act, n_obs = 4, 3
+         >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+         >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))
+         >>> class ValueClass(nn.Module):
+         ...     def __init__(self):
+         ...         super().__init__()
+         ...         self.linear = nn.Linear(n_obs + n_act, 1)
+         ...     def forward(self, obs, act):
+         ...         return self.linear(torch.cat([obs, act], -1))
+         >>> module = ValueClass()
+         >>> value = ValueOperator(
+         ...     module=module,
+         ...     in_keys=["observation", "action"])
+         >>> loss = DDPGLoss(actor, value)
+         >>> batch = [2, ]
+         >>> data = TensorDict({
+         ...     "observation": torch.randn(*batch, n_obs),
+         ...     "action": spec.rand(batch),
+         ...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
+         ...     ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
+         ...     ("next", "reward"): torch.randn(*batch, 1),
+         ...     ("next", "observation"): torch.randn(*batch, n_obs),
+         ... }, batch)
+         >>> loss(data)
+         TensorDict(
+             fields={
+                 loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                 loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                 pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                 pred_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                 target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                 target_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
+             batch_size=torch.Size([]),
+             device=None,
+             is_shared=False)
+
+     This class is compatible with non-tensordict based modules too and can be
+     used without recurring to any tensordict-related primitive. In this case,
+     the expected keyword arguments are:
+     ``["next_reward", "next_done", "next_terminated"]`` + in_keys of the actor_network and value_network.
+     The return value is a tuple of tensors in the following order:
+     ``["loss_actor", "loss_value", "pred_value", "target_value", "pred_value_max", "target_value_max"]``
+
+     Examples:
+         >>> import torch
+         >>> from torch import nn
+         >>> from torchrl.data import Bounded
+         >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
+         >>> from torchrl.objectives.ddpg import DDPGLoss
+         >>> _ = torch.manual_seed(42)
+         >>> n_act, n_obs = 4, 3
+         >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+         >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))
+         >>> class ValueClass(nn.Module):
+         ...     def __init__(self):
+         ...         super().__init__()
+         ...         self.linear = nn.Linear(n_obs + n_act, 1)
+         ...     def forward(self, obs, act):
+         ...         return self.linear(torch.cat([obs, act], -1))
+         >>> module = ValueClass()
+         >>> value = ValueOperator(
+         ...     module=module,
+         ...     in_keys=["observation", "action"])
+         >>> loss = DDPGLoss(actor, value)
+         >>> loss_actor, loss_value, pred_value, target_value, pred_value_max, target_value_max = loss(
+         ...     observation=torch.randn(n_obs),
+         ...     action=spec.rand(),
+         ...     next_done=torch.zeros(1, dtype=torch.bool),
+         ...     next_terminated=torch.zeros(1, dtype=torch.bool),
+         ...     next_observation=torch.randn(n_obs),
+         ...     next_reward=torch.randn(1))
+         >>> loss_actor.backward()
+
+     The output keys can also be filtered using the :meth:`DDPGLoss.select_out_keys`
+     method.
+
+     Examples:
+         >>> loss.select_out_keys('loss_actor', 'loss_value')
+         >>> loss_actor, loss_value = loss(
+         ...     observation=torch.randn(n_obs),
+         ...     action=spec.rand(),
+         ...     next_done=torch.zeros(1, dtype=torch.bool),
+         ...     next_terminated=torch.zeros(1, dtype=torch.bool),
+         ...     next_observation=torch.randn(n_obs),
+         ...     next_reward=torch.randn(1))
+         >>> loss_actor.backward()
+
+     """
+
+     @dataclass
+     class _AcceptedKeys:
+         """Maintains default values for all configurable tensordict keys.
+
+         This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+         default values.
+
+         Attributes:
+             state_action_value (NestedKey): The input tensordict key where the
+                 state action value is expected. Will be used for the underlying
+                 value estimator as value key. Defaults to ``"state_action_value"``.
+             priority (NestedKey): The input tensordict key where the target
+                 priority is written to. Defaults to ``"td_error"``.
+             reward (NestedKey): The input tensordict key where the reward is expected.
+                 Will be used for the underlying value estimator. Defaults to ``"reward"``.
+             done (NestedKey): The key in the input TensorDict that indicates
+                 whether a trajectory is done. Will be used for the underlying value estimator.
+                 Defaults to ``"done"``.
+             terminated (NestedKey): The key in the input TensorDict that indicates
+                 whether a trajectory is terminated. Will be used for the underlying value estimator.
+                 Defaults to ``"terminated"``.
+
+         """
+
+         state_action_value: NestedKey = "state_action_value"
+         priority: NestedKey = "td_error"
+         reward: NestedKey = "reward"
+         done: NestedKey = "done"
+         terminated: NestedKey = "terminated"
+         priority_weight: NestedKey = "priority_weight"
+
+     tensor_keys: _AcceptedKeys
+     default_keys = _AcceptedKeys
+     default_value_estimator: ValueEstimators = ValueEstimators.TD0
+     out_keys = [
+         "loss_actor",
+         "loss_value",
+         "pred_value",
+         "target_value",
+         "pred_value_max",
+         "target_value_max",
+     ]
+
+     actor_network: TensorDictModule
+     value_network: TensorDictModule
+     actor_network_params: TensorDictParams
+     value_network_params: TensorDictParams
+     target_actor_network_params: TensorDictParams
+     target_value_network_params: TensorDictParams
+
+     def __init__(
+         self,
+         actor_network: TensorDictModule,
+         value_network: TensorDictModule,
+         *,
+         loss_function: str = "l2",
+         delay_actor: bool = False,
+         delay_value: bool = True,
+         gamma: float | None = None,
+         separate_losses: bool = False,
+         reduction: str | None = None,
+         use_prioritized_weights: str | bool = "auto",
+     ) -> None:
+         self._in_keys = None
+         if reduction is None:
+             reduction = "mean"
+         super().__init__()
+         self.use_prioritized_weights = use_prioritized_weights
+         self.delay_actor = delay_actor
+         self.delay_value = delay_value
+
+         actor_critic = ActorCriticWrapper(actor_network, value_network)
+         params = TensorDict.from_module(actor_critic)
+         params_meta = params.apply(
+             self._make_meta_params, device=torch.device("meta"), filter_empty=False
+         )
+         with params_meta.to_module(actor_critic):
+             self.__dict__["actor_critic"] = deepcopy(actor_critic)
+
+         self.convert_to_functional(
+             actor_network,
+             "actor_network",
+             create_target_params=self.delay_actor,
+         )
+         if separate_losses:
+             # we want to make sure there are no duplicates in the params: the
+             # params of critic must be refs to actor if they're shared
+             policy_params = list(actor_network.parameters())
+         else:
+             policy_params = None
+         self.convert_to_functional(
+             value_network,
+             "value_network",
+             create_target_params=self.delay_value,
+             compare_against=policy_params,
+         )
+         self.actor_critic.module[0] = self.actor_network
+         self.actor_critic.module[1] = self.value_network
+
+         self.actor_in_keys = actor_network.in_keys
+         self.value_exclusive_keys = set(self.value_network.in_keys) - (
+             set(self.actor_in_keys) | set(self.actor_network.out_keys)
+         )
+
+         self.loss_function = loss_function
+         self.reduction = reduction
+         if gamma is not None:
+             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
+
+     def _forward_value_estimator_keys(self, **kwargs) -> None:
+         if self._value_estimator is not None:
+             self._value_estimator.set_keys(
+                 value=self._tensor_keys.state_action_value,
+                 reward=self._tensor_keys.reward,
+                 done=self._tensor_keys.done,
+                 terminated=self._tensor_keys.terminated,
+             )
+         self._set_in_keys()
+
+     def _set_in_keys(self):
+         in_keys = {
+             unravel_key(("next", self.tensor_keys.reward)),
+             unravel_key(("next", self.tensor_keys.done)),
+             unravel_key(("next", self.tensor_keys.terminated)),
+             *self.actor_in_keys,
+             *[unravel_key(("next", key)) for key in self.actor_in_keys],
+             *self.value_network.in_keys,
+             *[unravel_key(("next", key)) for key in self.value_network.in_keys],
+         }
+         if self.use_prioritized_weights:
+             in_keys.add(unravel_key(self.tensor_keys.priority_weight))
+         self._in_keys = sorted(in_keys, key=str)
+
+     @property
+     def in_keys(self):
+         if self._in_keys is None:
+             self._set_in_keys()
+         return self._in_keys
+
+     @in_keys.setter
+     def in_keys(self, values):
+         self._in_keys = values
+
+     @dispatch
+     def forward(self, tensordict: TensorDictBase) -> TensorDict:
+         """Computes the DDPG losses given a tensordict sampled from the replay buffer.
+
+         This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
+         a priority to items in the tensordict.
+
+         Args:
+             tensordict (TensorDictBase): a tensordict with keys ["done", "terminated", "reward"] and the in_keys of the actor
+                 and value networks.
+
+         Returns:
+             a TensorDict containing the DDPG losses (``"loss_actor"`` and ``"loss_value"``) along with the value metadata.
+
+         """
+         loss_value, metadata = self.loss_value(tensordict)
+         loss_actor, metadata_actor = self.loss_actor(tensordict)
+         metadata.update(metadata_actor)
+         td_out = TensorDict(
+             source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata},
+             batch_size=[],
+         )
+         self._clear_weakrefs(
+             tensordict,
+             td_out,
+             "value_network_params",
+             "target_value_network_params",
+             "target_actor_network_params",
+             "actor_network_params",
+         )
+         return td_out
+
+     def loss_actor(
+         self,
+         tensordict: TensorDictBase,
+     ) -> tuple[torch.Tensor, dict]:
+         weights = self._maybe_get_priority_weight(tensordict)
+         td_copy = tensordict.select(
+             *self.actor_in_keys, *self.value_exclusive_keys, strict=False
+         ).detach()
+         with self.actor_network_params.to_module(self.actor_network):
+             td_copy = self.actor_network(td_copy)
+         with self._cached_detached_value_params.to_module(self.value_network):
+             td_copy = self.value_network(td_copy)
+         loss_actor = -td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)
+         metadata = {}
+         loss_actor = _reduce(loss_actor, self.reduction, weights=weights)
+         self._clear_weakrefs(
+             tensordict,
+             loss_actor,
+             "value_network_params",
+             "target_value_network_params",
+             "target_actor_network_params",
+             "actor_network_params",
+         )
+         return loss_actor, metadata
+
+     def loss_value(
+         self,
+         tensordict: TensorDictBase,
+     ) -> tuple[torch.Tensor, dict]:
+         weights = self._maybe_get_priority_weight(tensordict)
+         # value loss
+         td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach()
+         with self.value_network_params.to_module(self.value_network):
+             self.value_network(td_copy)
+         pred_val = td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)
+
+         target_value = self.value_estimator.value_estimate(
+             tensordict, target_params=self._cached_target_params
+         ).squeeze(-1)
+
+         # td_error = pred_val - target_value
+         loss_value = distance_loss(
+             pred_val, target_value, loss_function=self.loss_function
+         )
+
+         td_error = (pred_val - target_value).pow(2)
+         td_error = td_error.detach()
+         if tensordict.device is not None:
+             td_error = td_error.to(tensordict.device)
+         tensordict.set(
+             self.tensor_keys.priority,
+             td_error,
+             inplace=True,
+         )
+         with torch.no_grad():
+             metadata = {
+                 "td_error": td_error,
+                 "pred_value": pred_val,
+                 "target_value": target_value,
+                 "target_value_max": target_value.max(),
+                 "pred_value_max": pred_val.max(),
+             }
+         loss_value = _reduce(loss_value, self.reduction, weights=weights)
+         self._clear_weakrefs(
+             tensordict,
+             "value_network_params",
+             "target_value_network_params",
+             "target_actor_network_params",
+             "actor_network_params",
+         )
+         return loss_value, metadata
+
+     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
+         if value_type is None:
+             value_type = self.default_value_estimator
+
+         # Handle ValueEstimatorBase instance or class
+         if isinstance(value_type, ValueEstimatorBase) or (
+             isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
+         ):
+             return LossModule.make_value_estimator(self, value_type, **hyperparams)
+
+         self.value_type = value_type
+         hp = dict(default_value_kwargs(value_type))
+         if hasattr(self, "gamma"):
+             hp["gamma"] = self.gamma
+         hp.update(hyperparams)
+         if value_type == ValueEstimators.TD1:
+             self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp)
+         elif value_type == ValueEstimators.TD0:
+             self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp)
+         elif value_type == ValueEstimators.GAE:
+             raise NotImplementedError(
+                 f"Value type {value_type} is not implemented for loss {type(self)}."
+             )
+         elif value_type == ValueEstimators.TDLambda:
+             self._value_estimator = TDLambdaEstimator(
+                 value_network=self.actor_critic, **hp
+             )
+         else:
+             raise NotImplementedError(f"Unknown value type {value_type}")
+
+         tensor_keys = {
+             "value": self.tensor_keys.state_action_value,
+             "reward": self.tensor_keys.reward,
+             "done": self.tensor_keys.done,
+             "terminated": self.tensor_keys.terminated,
+         }
+         self._value_estimator.set_keys(**tensor_keys)
+
+     @property
+     @_cache_values
+     def _cached_target_params(self):
+         target_params = TensorDict(
+             {
+                 "module": {
+                     "0": self.target_actor_network_params,
+                     "1": self.target_value_network_params,
+                 }
+             },
+             batch_size=self.target_actor_network_params.batch_size,
+             device=self.target_actor_network_params.device,
+         )
+         return target_params
+
+     @property
+     @_cache_values
+     def _cached_detached_value_params(self):
+         return self.value_network_params.detach()
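
As a usage note, the value estimator backing `loss_value` can be swapped after construction through `make_value_estimator`, and the tensordict keys can be remapped through `set_keys` (see `_AcceptedKeys` above). A minimal sketch reusing the setup from the class docstring; the TD(lambda) hyperparameter values and the `my_reward_key` name are illustrative choices, not defaults from the source:

import torch
from torch import nn
from torchrl.data import Bounded
from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
from torchrl.objectives.ddpg import DDPGLoss
from torchrl.objectives.utils import ValueEstimators

n_act, n_obs = 4, 3
spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))

class QValue(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(n_obs + n_act, 1)

    def forward(self, obs, act):
        return self.linear(torch.cat([obs, act], -1))

value = ValueOperator(module=QValue(), in_keys=["observation", "action"])

loss = DDPGLoss(actor, value)
# The default estimator is TD0 (see ``default_value_estimator`` above);
# switch to TD(lambda) with illustrative hyperparameters.
loss.make_value_estimator(ValueEstimators.TDLambda, gamma=0.99, lmbda=0.95)
# Tensordict keys can be remapped as documented in ``_AcceptedKeys``
# ("my_reward_key" is a hypothetical key name).
loss.set_keys(reward="my_reward_key")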