torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/objectives/crossq.py
@@ -0,0 +1,710 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import math
+ from dataclasses import dataclass
+ from functools import wraps
+
+ import torch
+ from tensordict import TensorDict, TensorDictBase, TensorDictParams
+ from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
+ from tensordict.utils import NestedKey
+ from torch import Tensor
+
+ from torchrl.data.tensor_specs import Composite
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.objectives.common import LossModule
+ from torchrl.objectives.utils import (
+     _cache_values,
+     _reduce,
+     _vmap_func,
+     default_value_kwargs,
+     distance_loss,
+     ValueEstimators,
+ )
+ from torchrl.objectives.value import (
+     TD0Estimator,
+     TD1Estimator,
+     TDLambdaEstimator,
+     ValueEstimatorBase,
+ )
+
+
+ def _delezify(func):
+     @wraps(func)
+     def new_func(self, *args, **kwargs):
+         self.target_entropy
+         return func(self, *args, **kwargs)
+
+     return new_func
+
+
+ class CrossQLoss(LossModule):
+     """TorchRL implementation of the CrossQ loss.
+
+     Presented in "CROSSQ: BATCH NORMALIZATION IN DEEP REINFORCEMENT LEARNING
+     FOR GREATER SAMPLE EFFICIENCY AND SIMPLICITY" https://openreview.net/pdf?id=PczQtTsTIX
+
+     This class has three loss functions that will be called sequentially by the `forward` method:
+     :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`. Alternatively, they can
+     be called by the user in that order.
+
+     Args:
+         actor_network (ProbabilisticTensorDictSequential): stochastic actor
+         qvalue_network (TensorDictModule): Q(s, a) parametric model.
+             This module typically outputs a ``"state_action_value"`` entry.
+             If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets``
+             times. If a list of modules is passed, their
+             parameters will be stacked unless they share the same identity (in which case
+             the original parameter will be expanded).
+
+             .. warning:: When a list of parameters is passed, it will **not** be compared against the policy parameters
+               and all the parameters will be considered as untied.
+
+     Keyword Args:
+         num_qvalue_nets (integer, optional): number of Q-Value networks used.
+             Defaults to ``2``.
+         loss_function (str, optional): loss function to be used with
+             the value function loss. Default is `"smooth_l1"`.
+         alpha_init (:obj:`float`, optional): initial entropy multiplier.
+             Default is 1.0.
+         min_alpha (:obj:`float`, optional): min value of alpha.
+             Default is None (no minimum value).
+         max_alpha (:obj:`float`, optional): max value of alpha.
+             Default is None (no maximum value).
+         action_spec (TensorSpec, optional): the action tensor spec. If not provided
+             and the target entropy is ``"auto"``, it will be retrieved from
+             the actor.
+         fixed_alpha (bool, optional): if ``True``, alpha will be fixed to its
+             initial value. Otherwise, alpha will be optimized to
+             match the 'target_entropy' value.
+             Default is ``False``.
+         target_entropy (:obj:`float` or str, optional): Target entropy for the
+             stochastic policy. Default is "auto", where target entropy is
+             computed as :obj:`-prod(n_actions)`.
+         priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
+             The tensordict key where the priority is written
+             (for prioritized replay buffer usage). Defaults to ``"td_error"``.
+         separate_losses (bool, optional): if ``True``, shared parameters between
+             policy and critic will only be trained on the policy loss.
+             Defaults to ``False``, i.e., gradients are propagated to shared
+             parameters for both policy and critic losses.
+         reduction (str, optional): Specifies the reduction to apply to the output:
+             ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
+             ``"mean"``: the sum of the output will be divided by the number of
+             elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+         deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+             Defaults to ``False``.
+
+     Examples:
+         >>> import torch
+         >>> from torch import nn
+         >>> from torchrl.data import Bounded
+         >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
+         >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
+         >>> from torchrl.modules.tensordict_module.common import SafeModule
+         >>> from torchrl.objectives.crossq import CrossQLoss
+         >>> from tensordict import TensorDict
+         >>> n_act, n_obs = 4, 3
+         >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+         >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
+         >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
+         >>> actor = ProbabilisticActor(
+         ...     module=module,
+         ...     in_keys=["loc", "scale"],
+         ...     spec=spec,
+         ...     distribution_class=TanhNormal)
+         >>> class ValueClass(nn.Module):
+         ...     def __init__(self):
+         ...         super().__init__()
+         ...         self.linear = nn.Linear(n_obs + n_act, 1)
+         ...     def forward(self, obs, act):
+         ...         return self.linear(torch.cat([obs, act], -1))
+         >>> module = ValueClass()
+         >>> qvalue = ValueOperator(
+         ...     module=module,
+         ...     in_keys=['observation', 'action'])
+         >>> loss = CrossQLoss(actor, qvalue)
+         >>> batch = [2, ]
+         >>> action = spec.rand(batch)
+         >>> data = TensorDict({
+         ...     "observation": torch.randn(*batch, n_obs),
+         ...     "action": action,
+         ...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
+         ...     ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
+         ...     ("next", "reward"): torch.randn(*batch, 1),
+         ...     ("next", "observation"): torch.randn(*batch, n_obs),
+         ... }, batch)
+         >>> loss(data)
+         TensorDict(
+             fields={
+                 alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                 entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                 loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                 loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                 loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
+             batch_size=torch.Size([]),
+             device=None,
+             is_shared=False)
+
+     This class is compatible with non-tensordict based modules too and can be
+     used without resorting to any tensordict-related primitive. In this case,
+     the expected keyword arguments are:
+     ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network.
+     The return value is a tuple of tensors in the following order:
+     ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]``
+
+     Examples:
+         >>> import torch
+         >>> from torch import nn
+         >>> from torchrl.data import Bounded
+         >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
+         >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
+         >>> from torchrl.modules.tensordict_module.common import SafeModule
+         >>> from torchrl.objectives import CrossQLoss
+         >>> _ = torch.manual_seed(42)
+         >>> n_act, n_obs = 4, 3
+         >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+         >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
+         >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
+         >>> actor = ProbabilisticActor(
+         ...     module=module,
+         ...     in_keys=["loc", "scale"],
+         ...     spec=spec,
+         ...     distribution_class=TanhNormal)
+         >>> class ValueClass(nn.Module):
+         ...     def __init__(self):
+         ...         super().__init__()
+         ...         self.linear = nn.Linear(n_obs + n_act, 1)
+         ...     def forward(self, obs, act):
+         ...         return self.linear(torch.cat([obs, act], -1))
+         >>> module = ValueClass()
+         >>> qvalue = ValueOperator(
+         ...     module=module,
+         ...     in_keys=['observation', 'action'])
+         >>> loss = CrossQLoss(actor, qvalue)
+         >>> batch = [2, ]
+         >>> action = spec.rand(batch)
+         >>> loss_actor, loss_qvalue, _, _, _ = loss(
+         ...     observation=torch.randn(*batch, n_obs),
+         ...     action=action,
+         ...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
+         ...     next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
+         ...     next_observation=torch.zeros(*batch, n_obs),
+         ...     next_reward=torch.randn(*batch, 1))
+         >>> loss_actor.backward()
+
+     The output keys can also be filtered using the :meth:`CrossQLoss.select_out_keys`
+     method.
+
+     Examples:
+         >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
+         >>> loss_actor, loss_qvalue = loss(
+         ...     observation=torch.randn(*batch, n_obs),
+         ...     action=action,
+         ...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
+         ...     next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
+         ...     next_observation=torch.zeros(*batch, n_obs),
+         ...     next_reward=torch.randn(*batch, 1))
+         >>> loss_actor.backward()
+     """
+
+     @dataclass
+     class _AcceptedKeys:
+         """Maintains default values for all configurable tensordict keys.
+
+         This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+         default values.
+
+         Attributes:
+             action (NestedKey): The input tensordict key where the action is expected.
+                 Defaults to ``"action"``.
+             state_action_value (NestedKey): The input tensordict key where the
+                 state action value is expected. Defaults to ``"state_action_value"``.
+             priority (NestedKey): The input tensordict key where the target priority is written to.
+                 Defaults to ``"td_error"``.
+             reward (NestedKey): The input tensordict key where the reward is expected.
+                 Will be used for the underlying value estimator. Defaults to ``"reward"``.
+             done (NestedKey): The key in the input TensorDict that indicates
+                 whether a trajectory is done. Will be used for the underlying value estimator.
+                 Defaults to ``"done"``.
+             terminated (NestedKey): The key in the input TensorDict that indicates
+                 whether a trajectory is terminated. Will be used for the underlying value estimator.
+                 Defaults to ``"terminated"``.
+             log_prob (NestedKey): The input tensordict key where the log probability is expected.
+                 Defaults to ``"_log_prob"``.
+         """
+
+         action: NestedKey = "action"
+         state_action_value: NestedKey = "state_action_value"
+         priority: NestedKey = "td_error"
+         reward: NestedKey = "reward"
+         done: NestedKey = "done"
+         terminated: NestedKey = "terminated"
+         log_prob: NestedKey = "_log_prob"
+
+     tensor_keys: _AcceptedKeys
+     default_keys = _AcceptedKeys
+     default_value_estimator = ValueEstimators.TD0
+
+     actor_network: ProbabilisticTensorDictSequential
+     actor_network_params: TensorDictParams
+     qvalue_network: TensorDictModule
+     qvalue_network_params: TensorDictParams
+     target_actor_network_params: TensorDictParams
+     target_qvalue_network_params: TensorDictParams
+
+     def __init__(
+         self,
+         actor_network: ProbabilisticTensorDictSequential,
+         qvalue_network: TensorDictModule | list[TensorDictModule],
+         *,
+         num_qvalue_nets: int = 2,
+         loss_function: str = "smooth_l1",
+         alpha_init: float = 1.0,
+         min_alpha: float | None = None,
+         max_alpha: float | None = None,
+         action_spec=None,
+         fixed_alpha: bool = False,
+         target_entropy: str | float = "auto",
+         priority_key: str | None = None,
+         separate_losses: bool = False,
+         reduction: str | None = None,
+         deactivate_vmap: bool = False,
+     ) -> None:
+         self._in_keys = None
+         self._out_keys = None
+         if reduction is None:
+             reduction = "mean"
+         super().__init__()
+         self._set_deprecated_ctor_keys(priority_key=priority_key)
+
+         self.deactivate_vmap = deactivate_vmap
+
+         # Actor
+         self.convert_to_functional(
+             actor_network,
+             "actor_network",
+             create_target_params=False,
+         )
+         if separate_losses:
+             # we want to make sure there are no duplicates in the params: the
+             # params of critic must be refs to actor if they're shared
+             policy_params = list(actor_network.parameters())
+         else:
+             policy_params = None
+             q_value_policy_params = None
+
+         # Q value
+         self.num_qvalue_nets = num_qvalue_nets
+
+         q_value_policy_params = policy_params
+         self.convert_to_functional(
+             qvalue_network,
+             "qvalue_network",
+             num_qvalue_nets,
+             create_target_params=False,
+             compare_against=q_value_policy_params,
+         )
+
+         self.loss_function = loss_function
+         try:
+             device = next(self.parameters()).device
+         except AttributeError:
+             device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
+         self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
+         if bool(min_alpha) ^ bool(max_alpha):
+             min_alpha = min_alpha if min_alpha else 0.0
+             if max_alpha == 0:
+                 raise ValueError("max_alpha must be either None or greater than 0.")
+             max_alpha = max_alpha if max_alpha else 1e9
+         if min_alpha:
+             self.register_buffer(
+                 "min_log_alpha", torch.tensor(min_alpha, device=device).log()
+             )
+         else:
+             self.min_log_alpha = None
+         if max_alpha:
+             self.register_buffer(
+                 "max_log_alpha", torch.tensor(max_alpha, device=device).log()
+             )
+         else:
+             self.max_log_alpha = None
+         self.fixed_alpha = fixed_alpha
+         if fixed_alpha:
+             self.register_buffer(
+                 "log_alpha", torch.tensor(math.log(alpha_init), device=device)
+             )
+         else:
+             self.register_parameter(
+                 "log_alpha",
+                 torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)),
+             )
+
+         self._target_entropy = target_entropy
+         self._action_spec = action_spec
+         self._make_vmap()
+         self.reduction = reduction
+         # init target entropy
+         self.maybe_init_target_entropy()
+
+     def _make_vmap(self):
+         self._vmap_qnetworkN0 = _vmap_func(
+             self.qvalue_network,
+             (None, 0),
+             randomness=self.vmap_randomness,
+             pseudo_vmap=self.deactivate_vmap,
+         )
+
+     @property
+     def target_entropy_buffer(self):
+         """The target entropy.
+
+         This value can be controlled via the `target_entropy` kwarg in the constructor.
+         """
+         return self.target_entropy
+
+     def maybe_init_target_entropy(self, fault_tolerant=True):
+         """Initialize the target entropy.
+
+         Args:
+             fault_tolerant (bool, optional): if ``True``, returns None if the target entropy
+                 cannot be determined. Raises an exception otherwise. Defaults to ``True``.
+
+         """
+         if "_target_entropy" in self._buffers:
+             return
+         target_entropy = self._target_entropy
+         device = next(self.parameters()).device
+         if target_entropy == "auto":
+             action_spec = self.get_action_spec()
+             if action_spec is None:
+                 if fault_tolerant:
+                     return
+                 raise RuntimeError(
+                     "Cannot infer the dimensionality of the action. Consider providing "
+                     "the target entropy explicitly or provide the spec of the "
+                     "action tensor in the actor network."
+                 )
+             if not isinstance(action_spec, Composite):
+                 action_spec = Composite({self.tensor_keys.action: action_spec})
+             elif fault_tolerant and self.tensor_keys.action not in action_spec:
+                 return
+             if (
+                 isinstance(self.tensor_keys.action, tuple)
+                 and len(self.tensor_keys.action) > 1
+             ):
+                 action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape
+             else:
+                 action_container_shape = action_spec.shape
+             target_entropy = -float(
+                 action_spec[self.tensor_keys.action]
+                 .shape[len(action_container_shape) :]
+                 .numel()
+             )
+         delattr(self, "_target_entropy")
+         self.register_buffer(
+             "_target_entropy", torch.tensor(target_entropy, device=device)
+         )
+         return self._target_entropy
+
+     def get_action_spec(self):
+         action_spec = self._action_spec
+         actor_network = self.actor_network
+         action_spec = (
+             action_spec
+             if action_spec is not None
+             else getattr(actor_network, "spec", None)
+         )
+         return action_spec
+
+     @property
+     def target_entropy(self):
+         target_entropy = self._buffers.get("_target_entropy")
+         if target_entropy is not None:
+             return target_entropy
+         return self.maybe_init_target_entropy(fault_tolerant=False)
+
+     def set_keys(self, **kwargs) -> None:
+         out = super().set_keys(**kwargs)
+         self.maybe_init_target_entropy()
+         return out
+
+     state_dict = _delezify(LossModule.state_dict)
+     load_state_dict = _delezify(LossModule.load_state_dict)
+
+     def _forward_value_estimator_keys(self, **kwargs) -> None:
+         if self._value_estimator is not None:
+             self._value_estimator.set_keys(
+                 value=self.tensor_keys.value,
+                 reward=self.tensor_keys.reward,
+                 done=self.tensor_keys.done,
+                 terminated=self.tensor_keys.terminated,
+             )
+         self._set_in_keys()
+
+     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
+         if value_type is None:
+             value_type = self.default_value_estimator
+
+         # Handle ValueEstimatorBase instance or class
+         if isinstance(value_type, ValueEstimatorBase) or (
+             isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
+         ):
+             return LossModule.make_value_estimator(self, value_type, **hyperparams)
+
+         self.value_type = value_type
+
+         value_net = None
+         hp = dict(default_value_kwargs(value_type))
+         hp.update(hyperparams)
+         if value_type is ValueEstimators.TD1:
+             self._value_estimator = TD1Estimator(
+                 **hp,
+                 value_network=value_net,
+             )
+         elif value_type is ValueEstimators.TD0:
+             self._value_estimator = TD0Estimator(
+                 **hp,
+                 value_network=value_net,
+             )
+         elif value_type is ValueEstimators.GAE:
+             raise NotImplementedError(
+                 f"Value type {value_type} is not implemented for loss {type(self)}."
+             )
+         elif value_type is ValueEstimators.TDLambda:
+             self._value_estimator = TDLambdaEstimator(
+                 **hp,
+                 value_network=value_net,
+             )
+         else:
+             raise NotImplementedError(f"Unknown value type {value_type}")
+
+         tensor_keys = {
+             "reward": self.tensor_keys.reward,
+             "done": self.tensor_keys.done,
+             "terminated": self.tensor_keys.terminated,
+         }
+         self._value_estimator.set_keys(**tensor_keys)
+
+     @property
+     def device(self) -> torch.device:
+         for p in self.parameters():
+             return p.device
+         raise RuntimeError(
+             "At least one of the networks of CrossQLoss must have trainable parameters."
+         )
+
+     def _set_in_keys(self):
+         keys = [
+             self.tensor_keys.action,
+             ("next", self.tensor_keys.reward),
+             ("next", self.tensor_keys.done),
+             ("next", self.tensor_keys.terminated),
+             *self.actor_network.in_keys,
+             *[("next", key) for key in self.actor_network.in_keys],
+             *self.qvalue_network.in_keys,
+         ]
+         self._in_keys = list(set(keys))
+
+     @property
+     def in_keys(self):
+         if self._in_keys is None:
+             self._set_in_keys()
+         return self._in_keys
+
+     @in_keys.setter
+     def in_keys(self, values):
+         self._in_keys = values
+
+     @property
+     def out_keys(self):
+         if self._out_keys is None:
+             keys = ["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]
+             self._out_keys = keys
+         return self._out_keys
+
+     @out_keys.setter
+     def out_keys(self, values):
+         self._out_keys = values
+
+     @dispatch
+     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+         """The forward method.
+
+         Computes successively the :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`, and returns
+         a tensordict with these values along with the `"alpha"` value and the `"entropy"` value (detached).
+         To see what keys are expected in the input tensordict and what keys are expected as output, check the
+         class's `"in_keys"` and `"out_keys"` attributes.
+         """
+         loss_qvalue, value_metadata = self.qvalue_loss(tensordict)
+         loss_actor, metadata_actor = self.actor_loss(tensordict)
+         loss_alpha = self.alpha_loss(log_prob=metadata_actor["log_prob"])
+         tensordict.set(self.tensor_keys.priority, value_metadata["td_error"])
+         if loss_actor.shape != loss_qvalue.shape:
+             raise RuntimeError(
+                 f"Losses shape mismatch: {loss_actor.shape} and {loss_qvalue.shape}"
+             )
+         entropy = -metadata_actor["log_prob"]
+         out = {
+             "loss_actor": loss_actor,
+             "loss_qvalue": loss_qvalue,
+             "loss_alpha": loss_alpha,
+             "alpha": self._alpha,
+             "entropy": entropy.detach().mean(),
+             **metadata_actor,
+             **value_metadata,
+         }
+         td_out = TensorDict(out)
+         self._clear_weakrefs(
+             tensordict,
+             td_out,
+             "actor_network_params",
+             "qvalue_network_params",
+             "target_actor_network_params",
+             "target_qvalue_network_params",
+         )
+         return td_out
+
+     @property
+     @_cache_values
+     def _cached_detached_qvalue_params(self):
+         return self.qvalue_network_params.detach()
+
+     def actor_loss(
+         self, tensordict: TensorDictBase
+     ) -> tuple[Tensor, dict[str, Tensor]]:
+         """Compute the actor loss.
+
+         The actor loss should be computed after the :meth:`~.qvalue_loss` and before the :meth:`~.alpha_loss`, which
+         requires the `log_prob` field of the `metadata` returned by this method.
+
+         Args:
+             tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
+                 are required for this to be computed.
+
+         Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the detached `"log_prob"` of the sampled action.
+         """
+         tensordict = tensordict.copy()
+         with set_exploration_type(
+             ExplorationType.RANDOM
+         ), self.actor_network_params.to_module(self.actor_network):
+             dist = self.actor_network.get_dist(tensordict)
+             a_reparm = dist.rsample()
+         log_prob = dist.log_prob(a_reparm)
+
+         td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
+         self.qvalue_network.eval()
+         td_q.set(self.tensor_keys.action, a_reparm)
+         td_q = self._vmap_qnetworkN0(
+             td_q,
+             self._cached_detached_qvalue_params,
+         )
+
+         min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
+         self.qvalue_network.train()
+
+         if log_prob.shape != min_q.shape:
+             raise RuntimeError(
+                 f"Losses shape mismatch: {log_prob.shape} and {min_q.shape}"
+             )
+         actor_loss = self._alpha * log_prob - min_q
+         return _reduce(actor_loss, reduction=self.reduction), {
+             "log_prob": log_prob.detach()
+         }
+
+     def qvalue_loss(
+         self, tensordict: TensorDictBase
+     ) -> tuple[Tensor, dict[str, Tensor]]:
+         """Compute the q-value loss.
+
+         The q-value loss should be computed before the :meth:`~.actor_loss`.
+
+         Args:
+             tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
+                 are required for this to be computed.
+
+         Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing
+             the detached `"td_error"` to be used for prioritized sampling.
+         """
+         tensordict = tensordict.copy()
+         # compute next action
+         with torch.no_grad():
+             with set_exploration_type(
+                 ExplorationType.RANDOM
+             ), self.actor_network_params.to_module(self.actor_network):
+                 next_tensordict = tensordict.get("next").clone(False)
+                 next_dist = self.actor_network.get_dist(next_tensordict)
+                 next_action = next_dist.sample()
+                 next_tensordict.set(self.tensor_keys.action, next_action)
+                 next_sample_log_prob = next_dist.log_prob(next_action)
+
+         combined = torch.cat(
+             [
+                 tensordict.select(*self.qvalue_network.in_keys, strict=False),
+                 next_tensordict.select(*self.qvalue_network.in_keys, strict=False),
+             ]
+         )
+         pred_qs = self._vmap_qnetworkN0(combined, self.qvalue_network_params).get(
+             self.tensor_keys.state_action_value
+         )
+         (current_state_action_value, next_state_action_value) = pred_qs.split(
+             tensordict.batch_size[0], dim=1
+         )
+
+         # compute target value
+         if (
+             next_state_action_value.shape[-len(next_sample_log_prob.shape) :]
+             != next_sample_log_prob.shape
+         ):
+             next_sample_log_prob = next_sample_log_prob.unsqueeze(-1)
+         next_state_action_value = next_state_action_value.min(0)[0]
+         next_state_action_value = (
+             next_state_action_value - self._alpha * next_sample_log_prob
+         ).detach()
+
+         target_value = self.value_estimator.value_estimate(
+             tensordict, next_value=next_state_action_value
+         ).squeeze(-1)
+
+         # get current q-values
+         pred_val = current_state_action_value.squeeze(-1)
+
+         # compute loss
+         td_error = abs(pred_val - target_value)
+         loss_qval = distance_loss(
+             pred_val,
+             target_value.expand_as(pred_val),
+             loss_function=self.loss_function,
+         ).sum(0)
+         metadata = {"td_error": td_error.detach().max(0)[0]}
+         return _reduce(loss_qval, reduction=self.reduction), metadata
+
+     def alpha_loss(self, log_prob: Tensor) -> Tensor:
+         """Compute the entropy loss.
+
+         The entropy loss should be computed last.
+
+         Args:
+             log_prob (torch.Tensor): a log-probability as computed by the :meth:`~.actor_loss` and returned in the `metadata`.
+
+         Returns: a differentiable tensor with the entropy loss.
+         """
+         if self.target_entropy is not None:
+             # we can compute this loss even if log_alpha is not a parameter
+             alpha_loss = -self.log_alpha * (log_prob + self.target_entropy)
+         else:
+             # placeholder
+             alpha_loss = torch.zeros_like(log_prob)
+         return _reduce(alpha_loss, reduction=self.reduction)
+
+     @property
+     def _alpha(self):
+         if self.min_log_alpha is not None or self.max_log_alpha is not None:
+             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
+         with torch.no_grad():
+             alpha = self.log_alpha.exp()
+         return alpha
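
For reference, below is a minimal sketch of how the CrossQLoss module shipped in this file is typically driven during training. The actor and Q-network mirror the toy construction from the class docstring; the single shared Adam optimizer, the summed loss and the learning rate are illustrative assumptions for this sketch, not values taken from this release.

import torch
from torch import nn
from tensordict import TensorDict
from torchrl.data import Bounded
from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
from torchrl.modules.tensordict_module.common import SafeModule
from torchrl.objectives import CrossQLoss

n_act, n_obs = 4, 3
spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))

# Stochastic actor, as in the class docstring example.
actor = ProbabilisticActor(
    module=SafeModule(
        nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    ),
    in_keys=["loc", "scale"],
    spec=spec,
    distribution_class=TanhNormal,
)

# Toy Q(s, a) model producing the "state_action_value" entry expected by the loss.
class QNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(n_obs + n_act, 1)

    def forward(self, obs, act):
        return self.linear(torch.cat([obs, act], -1))

qvalue = ValueOperator(module=QNet(), in_keys=["observation", "action"])

loss_module = CrossQLoss(actor, qvalue)
# CrossQ keeps no target networks (create_target_params=False above), so no
# target-network updater is needed. A single optimizer over all loss-module
# parameters (actor, critics and log_alpha) is a simplification for this sketch.
optim = torch.optim.Adam(loss_module.parameters(), lr=3e-4)

batch = [2]
data = TensorDict(
    {
        "observation": torch.randn(*batch, n_obs),
        "action": spec.rand(batch),
        ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
        ("next", "reward"): torch.randn(*batch, 1),
        ("next", "observation"): torch.randn(*batch, n_obs),
    },
    batch,
)

# One gradient step: forward returns loss_actor, loss_qvalue, loss_alpha, alpha, entropy.
td_out = loss_module(data)
loss = td_out["loss_actor"] + td_out["loss_qvalue"] + td_out["loss_alpha"]
optim.zero_grad()
loss.backward()
optim.step()

In practice the actor, critic and entropy terms are usually optimized with separate optimizers (as in the sota-implementations/crossq scripts listed above), but the loss interface is the same.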