torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/objectives/sac.py
@@ -0,0 +1,1580 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import math
8
+ import warnings
9
+ from dataclasses import dataclass
10
+ from functools import wraps
11
+ from numbers import Number
12
+
13
+ import numpy as np
14
+ import torch
15
+ from tensordict import TensorDict, TensorDictBase, TensorDictParams
16
+ from tensordict.nn import (
17
+ composite_lp_aggregate,
18
+ CompositeDistribution,
19
+ dispatch,
20
+ ProbabilisticTensorDictSequential,
21
+ set_composite_lp_aggregate,
22
+ TensorDictModule,
23
+ )
24
+ from tensordict.utils import expand_right, NestedKey
25
+ from torch import Tensor
26
+
27
+ from torchrl.data.tensor_specs import Composite, TensorSpec
28
+ from torchrl.data.utils import _find_action_space
29
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
30
+ from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
31
+ from torchrl.objectives.common import LossModule
32
+ from torchrl.objectives.utils import (
33
+ _cache_values,
34
+ _GAMMA_LMBDA_DEPREC_ERROR,
35
+ _reduce,
36
+ _vmap_func,
37
+ default_value_kwargs,
38
+ distance_loss,
39
+ ValueEstimators,
40
+ )
41
+ from torchrl.objectives.value import (
42
+ TD0Estimator,
43
+ TD1Estimator,
44
+ TDLambdaEstimator,
45
+ ValueEstimatorBase,
46
+ )
47
+
48
+
49
+ def _delezify(func):
50
+ @wraps(func)
51
+ def new_func(self, *args, **kwargs):
52
+ self.target_entropy
53
+ return func(self, *args, **kwargs)
54
+
55
+ return new_func
56
+
57
+
58
+ def compute_log_prob(action_dist, action_or_tensordict, tensor_key) -> torch.Tensor:
59
+ """Compute the log probability of an action given a distribution."""
60
+ lp = action_dist.log_prob(action_or_tensordict)
61
+ if isinstance(action_dist, CompositeDistribution):
62
+ with set_composite_lp_aggregate(False):
63
+ return sum(lp.sum(dim="feature").values(True, True))
64
+ return lp
65
+
66
+
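# A minimal usage sketch for the helper above (illustrative, assuming a plain
# factorized Normal policy head): for an ordinary torch.distributions
# distribution the helper simply returns ``dist.log_prob(action)``, whereas for
# a ``CompositeDistribution`` it sums the per-leaf log-probabilities with
# composite aggregation disabled.
#
#     >>> import torch
#     >>> from torch.distributions import Normal
#     >>> dist = Normal(torch.zeros(4), torch.ones(4))
#     >>> action = dist.rsample()
#     >>> lp = compute_log_prob(dist, action, "action_log_prob")
#     >>> torch.testing.assert_close(lp, dist.log_prob(action))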
67
+ class SACLoss(LossModule):
68
+ """TorchRL implementation of the SAC loss.
69
+
70
+ Presented in "Soft Actor-Critic: Off-Policy Maximum Entropy Deep
71
+ Reinforcement Learning with a Stochastic Actor" https://arxiv.org/abs/1801.01290
72
+ and "Soft Actor-Critic Algorithms and Applications" https://arxiv.org/abs/1812.05905
73
+
74
+ Args:
75
+ actor_network (ProbabilisticTensorDictSequential): stochastic actor
76
+ qvalue_network (TensorDictModule): Q(s, a) parametric model.
77
+ This module typically outputs a ``"state_action_value"`` entry.
78
+ If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets``
79
+ times. If a list of modules is passed, their
80
+ parameters will be stacked unless they share the same identity (in which case
81
+ the original parameter will be expanded).
82
+
83
+ .. warning:: When a list of parameters is passed, it will **not** be compared against the policy parameters
84
+ and all the parameters will be considered as untied.
85
+
86
+ value_network (TensorDictModule, optional): V(s) parametric model.
87
+ This module typically outputs a ``"state_value"`` entry.
88
+
89
+ .. note::
90
+ If not provided, the second version of SAC is assumed, where
91
+ only the Q-Value network is needed.
92
+
93
+ Keyword Args:
94
+ num_qvalue_nets (integer, optional): number of Q-Value networks used.
95
+ Defaults to ``2``.
96
+ loss_function (str, optional): loss function to be used with
97
+ the value function loss. Default is `"smooth_l1"`.
98
+ alpha_init (:obj:`float`, optional): initial entropy multiplier.
99
+ Default is 1.0.
100
+ min_alpha (:obj:`float`, optional): min value of alpha.
101
+ Default is None (no minimum value).
102
+ max_alpha (:obj:`float`, optional): max value of alpha.
103
+ Default is None (no maximum value).
104
+ action_spec (TensorSpec, optional): the action tensor spec. If not provided
105
+ and the target entropy is ``"auto"``, it will be retrieved from
106
+ the actor.
107
+ fixed_alpha (bool, optional): if ``True``, alpha will be fixed to its
108
+ initial value. Otherwise, alpha will be optimized to
109
+ match the 'target_entropy' value.
110
+ Default is ``False``.
111
+ target_entropy (:obj:`float` or str, optional): Target entropy for the
112
+ stochastic policy. Default is "auto", where target entropy is
113
+ computed as :obj:`-prod(n_actions)`.
114
+ delay_actor (bool, optional): Whether to separate the target actor
115
+ networks from the actor networks used for data collection.
116
+ Default is ``False``.
117
+ delay_qvalue (bool, optional): Whether to separate the target Q value
118
+ networks from the Q value networks used for data collection.
119
+ Default is ``True``.
120
+ delay_value (bool, optional): Whether to separate the target value
121
+ networks from the value networks used for data collection.
122
+ Default is ``True``.
123
+ priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
124
+ Tensordict key where to write the
125
+ priority (for prioritized replay buffer usage). Defaults to ``"td_error"``.
126
+ separate_losses (bool, optional): if ``True``, shared parameters between
127
+ policy and critic will only be trained on the policy loss.
128
+ Defaults to ``False``, i.e., gradients are propagated to shared
129
+ parameters for both policy and critic losses.
130
+ reduction (str, optional): Specifies the reduction to apply to the output:
131
+ ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
132
+ ``"mean"``: the sum of the output will be divided by the number of
133
+ elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
134
+ skip_done_states (bool, optional): whether the actor network used for value computation should only be run on
135
+ valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
136
+ shape of the data and that masking the data results in a valid data structure. Among other things, this may
137
+ not be true in MARL settings or when using RNNs. Defaults to ``False``.
138
+ deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
139
+ Defaults to ``False``.
140
+
141
+ Examples:
142
+ >>> import torch
143
+ >>> from torch import nn
144
+ >>> from torchrl.data import Bounded
145
+ >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
146
+ >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
147
+ >>> from torchrl.modules.tensordict_module.common import SafeModule
148
+ >>> from torchrl.objectives.sac import SACLoss
149
+ >>> from tensordict import TensorDict
150
+ >>> n_act, n_obs = 4, 3
151
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
152
+ >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
153
+ >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
154
+ >>> actor = ProbabilisticActor(
155
+ ... module=module,
156
+ ... in_keys=["loc", "scale"],
157
+ ... spec=spec,
158
+ ... distribution_class=TanhNormal)
159
+ >>> class ValueClass(nn.Module):
160
+ ... def __init__(self):
161
+ ... super().__init__()
162
+ ... self.linear = nn.Linear(n_obs + n_act, 1)
163
+ ... def forward(self, obs, act):
164
+ ... return self.linear(torch.cat([obs, act], -1))
165
+ >>> module = ValueClass()
166
+ >>> qvalue = ValueOperator(
167
+ ... module=module,
168
+ ... in_keys=['observation', 'action'])
169
+ >>> module = nn.Linear(n_obs, 1)
170
+ >>> value = ValueOperator(
171
+ ... module=module,
172
+ ... in_keys=["observation"])
173
+ >>> loss = SACLoss(actor, qvalue, value)
174
+ >>> batch = [2, ]
175
+ >>> action = spec.rand(batch)
176
+ >>> data = TensorDict({
177
+ ... "observation": torch.randn(*batch, n_obs),
178
+ ... "action": action,
179
+ ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
180
+ ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
181
+ ... ("next", "reward"): torch.randn(*batch, 1),
182
+ ... ("next", "observation"): torch.randn(*batch, n_obs),
183
+ ... }, batch)
184
+ >>> loss(data)
185
+ TensorDict(
186
+ fields={
187
+ alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
188
+ entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
189
+ loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
190
+ loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
191
+ loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
192
+ loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
193
+ batch_size=torch.Size([]),
194
+ device=None,
195
+ is_shared=False)
196
+
197
+ This class is compatible with non-tensordict based modules too and can be
198
+ used without resorting to any tensordict-related primitive. In this case,
199
+ the expected keyword arguments are:
200
+ ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network.
201
+ The return value is a tuple of tensors in the following order:
202
+ ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]`` + ``"loss_value"`` if version one is used.
203
+
204
+ Examples:
205
+ >>> import torch
206
+ >>> from torch import nn
207
+ >>> from torchrl.data import Bounded
208
+ >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
209
+ >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
210
+ >>> from torchrl.modules.tensordict_module.common import SafeModule
211
+ >>> from torchrl.objectives.sac import SACLoss
212
+ >>> _ = torch.manual_seed(42)
213
+ >>> n_act, n_obs = 4, 3
214
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
215
+ >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
216
+ >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
217
+ >>> actor = ProbabilisticActor(
218
+ ... module=module,
219
+ ... in_keys=["loc", "scale"],
220
+ ... spec=spec,
221
+ ... distribution_class=TanhNormal)
222
+ >>> class ValueClass(nn.Module):
223
+ ... def __init__(self):
224
+ ... super().__init__()
225
+ ... self.linear = nn.Linear(n_obs + n_act, 1)
226
+ ... def forward(self, obs, act):
227
+ ... return self.linear(torch.cat([obs, act], -1))
228
+ >>> module = ValueClass()
229
+ >>> qvalue = ValueOperator(
230
+ ... module=module,
231
+ ... in_keys=['observation', 'action'])
232
+ >>> module = nn.Linear(n_obs, 1)
233
+ >>> value = ValueOperator(
234
+ ... module=module,
235
+ ... in_keys=["observation"])
236
+ >>> loss = SACLoss(actor, qvalue, value)
237
+ >>> batch = [2, ]
238
+ >>> action = spec.rand(batch)
239
+ >>> loss_actor, loss_qvalue, _, _, _, _ = loss(
240
+ ... observation=torch.randn(*batch, n_obs),
241
+ ... action=action,
242
+ ... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
243
+ ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
244
+ ... next_observation=torch.zeros(*batch, n_obs),
245
+ ... next_reward=torch.randn(*batch, 1))
246
+ >>> loss_actor.backward()
247
+
248
+ The output keys can also be filtered using the :meth:`SACLoss.select_out_keys`
249
+ method.
250
+
251
+ Examples:
252
+ >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
253
+ >>> loss_actor, loss_qvalue = loss(
254
+ ... observation=torch.randn(*batch, n_obs),
255
+ ... action=action,
256
+ ... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
257
+ ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
258
+ ... next_observation=torch.zeros(*batch, n_obs),
259
+ ... next_reward=torch.randn(*batch, 1))
260
+ >>> loss_actor.backward()
261
+ """
262
+
263
+ @dataclass
264
+ class _AcceptedKeys:
265
+ """Maintains default values for all configurable tensordict keys.
266
+
267
+ This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
268
+ default values.
269
+
270
+ Attributes:
271
+ action (NestedKey): The input tensordict key where the action is expected.
272
+ Defaults to ``"advantage"``.
273
+ value (NestedKey): The input tensordict key where the state value is expected.
274
+ Will be used for the underlying value estimator. Defaults to ``"state_value"``.
275
+ state_action_value (NestedKey): The input tensordict key where the
276
+ state action value is expected. Defaults to ``"state_action_value"``.
277
+ log_prob (NestedKey): The input tensordict key where the log probability is expected.
278
+ Defaults to ``"sample_log_prob"`` when :func:`~tensordict.nn.composite_lp_aggregate` returns `True`,
279
+ `"action_log_prob"` otherwise.
280
+ priority (NestedKey): The input tensordict key where the target priority is written to.
281
+ Defaults to ``"td_error"``.
282
+ reward (NestedKey): The input tensordict key where the reward is expected.
283
+ Will be used for the underlying value estimator. Defaults to ``"reward"``.
284
+ done (NestedKey): The key in the input TensorDict that indicates
285
+ whether a trajectory is done. Will be used for the underlying value estimator.
286
+ Defaults to ``"done"``.
287
+ terminated (NestedKey): The key in the input TensorDict that indicates
288
+ whether a trajectory is terminated. Will be used for the underlying value estimator.
289
+ Defaults to ``"terminated"``.
290
+ """
291
+
292
+ action: NestedKey = "action"
293
+ value: NestedKey = "state_value"
294
+ state_action_value: NestedKey = "state_action_value"
295
+ log_prob: NestedKey | None = None
296
+ priority: NestedKey = "td_error"
297
+ reward: NestedKey = "reward"
298
+ done: NestedKey = "done"
299
+ terminated: NestedKey = "terminated"
300
+ priority_weight: NestedKey = "priority_weight"
301
+
302
+ def __post_init__(self):
303
+ if self.log_prob is None:
304
+ if composite_lp_aggregate(nowarn=True):
305
+ self.log_prob = "sample_log_prob"
306
+ else:
307
+ self.log_prob = "action_log_prob"
308
+
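# A minimal sketch of overriding these defaults (illustrative; ``actor``,
# ``qvalue`` and the key names below are assumed, not taken from this file):
#
#     >>> loss = SACLoss(actor, qvalue)
#     >>> loss.set_keys(action="action_continuous", reward="task_reward")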
309
+ default_keys = _AcceptedKeys
310
+ tensor_keys: _AcceptedKeys
311
+ default_value_estimator = ValueEstimators.TD0
312
+
313
+ actor_network: TensorDictModule
314
+ qvalue_network: TensorDictModule
315
+ value_network: TensorDictModule | None
316
+ actor_network_params: TensorDictParams
317
+ qvalue_network_params: TensorDictParams
318
+ value_network_params: TensorDictParams | None
319
+ target_actor_network_params: TensorDictParams
320
+ target_qvalue_network_params: TensorDictParams
321
+ target_value_network_params: TensorDictParams | None
322
+
323
+ def __init__(
324
+ self,
325
+ actor_network: ProbabilisticTensorDictSequential,
326
+ qvalue_network: TensorDictModule | list[TensorDictModule],
327
+ value_network: TensorDictModule | None = None,
328
+ *,
329
+ num_qvalue_nets: int = 2,
330
+ loss_function: str = "smooth_l1",
331
+ alpha_init: float = 1.0,
332
+ min_alpha: float | None = None,
333
+ max_alpha: float | None = None,
334
+ action_spec: TensorSpec | None = None,
335
+ fixed_alpha: bool = False,
336
+ target_entropy: str | float = "auto",
337
+ delay_actor: bool = False,
338
+ delay_qvalue: bool = True,
339
+ delay_value: bool = True,
340
+ gamma: float | None = None,
341
+ priority_key: str | None = None,
342
+ separate_losses: bool = False,
343
+ reduction: str | None = None,
344
+ skip_done_states: bool = False,
345
+ deactivate_vmap: bool = False,
346
+ use_prioritized_weights: str | bool = "auto",
347
+ ) -> None:
348
+ self._in_keys = None
349
+ self._out_keys = None
350
+ if reduction is None:
351
+ reduction = "mean"
352
+ super().__init__()
353
+ self.use_prioritized_weights = use_prioritized_weights
354
+ self._set_deprecated_ctor_keys(priority_key=priority_key)
355
+
356
+ # Actor
357
+ self.delay_actor = delay_actor
358
+ self.deactivate_vmap = deactivate_vmap
359
+ self.convert_to_functional(
360
+ actor_network,
361
+ "actor_network",
362
+ create_target_params=self.delay_actor,
363
+ )
364
+ if separate_losses:
365
+ # we want to make sure there are no duplicates in the params: the
366
+ # params of critic must be refs to actor if they're shared
367
+ policy_params = list(actor_network.parameters())
368
+ else:
369
+ policy_params = None
370
+ q_value_policy_params = None
371
+ # Value
372
+ if value_network is not None:
373
+ self._version = 1
374
+ self.delay_value = delay_value
375
+ self.convert_to_functional(
376
+ value_network,
377
+ "value_network",
378
+ create_target_params=self.delay_value,
379
+ compare_against=policy_params,
380
+ )
381
+ else:
382
+ self._version = 2
383
+ self.value_network_params = None
384
+ self.target_value_network_params = None
385
+
386
+ # Q value
387
+ self.delay_qvalue = delay_qvalue
388
+ self.num_qvalue_nets = num_qvalue_nets
389
+ if self._version == 1:
390
+ if separate_losses:
391
+ value_params = list(value_network.parameters())
392
+ q_value_policy_params = policy_params + value_params
393
+ else:
394
+ q_value_policy_params = policy_params
395
+ else:
396
+ q_value_policy_params = policy_params
397
+ self.convert_to_functional(
398
+ qvalue_network,
399
+ "qvalue_network",
400
+ num_qvalue_nets,
401
+ create_target_params=self.delay_qvalue,
402
+ compare_against=q_value_policy_params,
403
+ )
404
+
405
+ self.loss_function = loss_function
406
+ try:
407
+ device = next(self.parameters()).device
408
+ except AttributeError:
409
+ device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
410
+ self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
411
+ if bool(min_alpha) ^ bool(max_alpha):
412
+ min_alpha = min_alpha if min_alpha else 0.0
413
+ if max_alpha == 0:
414
+ raise ValueError("max_alpha must be either None or greater than 0.")
415
+ max_alpha = max_alpha if max_alpha else 1e9
416
+ if min_alpha:
417
+ self.register_buffer(
418
+ "min_log_alpha", torch.tensor(min_alpha, device=device).log()
419
+ )
420
+ else:
421
+ self.min_log_alpha = None
422
+ if max_alpha:
423
+ self.register_buffer(
424
+ "max_log_alpha", torch.tensor(max_alpha, device=device).log()
425
+ )
426
+ else:
427
+ self.max_log_alpha = None
428
+ self.fixed_alpha = fixed_alpha
429
+ if fixed_alpha:
430
+ self.register_buffer(
431
+ "log_alpha", torch.tensor(math.log(alpha_init), device=device)
432
+ )
433
+ else:
434
+ self.register_parameter(
435
+ "log_alpha",
436
+ torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)),
437
+ )
438
+
439
+ self._target_entropy = target_entropy
440
+ self._action_spec = action_spec
441
+ if self._version == 1:
442
+ self.__dict__["actor_critic"] = ActorCriticWrapper(
443
+ self.actor_network, self.value_network
444
+ )
445
+ if gamma is not None:
446
+ raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
447
+ self._make_vmap()
448
+ self.reduction = reduction
449
+ self.skip_done_states = skip_done_states
450
+
451
+ log_prob_keys = getattr(self.actor_network, "log_prob_keys", [])
452
+ action_keys = getattr(self.actor_network, "dist_sample_keys", [])
453
+ if len(log_prob_keys) > 1:
454
+ self.set_keys(log_prob=log_prob_keys, action=action_keys)
455
+ else:
456
+ self.set_keys(log_prob=log_prob_keys[0], action=action_keys[0])
457
+
458
+ def _make_vmap(self):
459
+ self._vmap_qnetworkN0 = _vmap_func(
460
+ self.qvalue_network,
461
+ (None, 0),
462
+ randomness=self.vmap_randomness,
463
+ pseudo_vmap=self.deactivate_vmap,
464
+ )
465
+ if self._version == 1:
466
+ self._vmap_qnetwork00 = _vmap_func(
467
+ self.qvalue_network,
468
+ randomness=self.vmap_randomness,
469
+ pseudo_vmap=self.deactivate_vmap,
470
+ )
471
+
472
+ @property
473
+ def target_entropy_buffer(self):
474
+ return self.target_entropy
475
+
476
+ @property
477
+ def target_entropy(self):
478
+ target_entropy = self._buffers.get("_target_entropy", None)
479
+ if target_entropy is not None:
480
+ return target_entropy
481
+ target_entropy = self._target_entropy
482
+ action_spec = self._action_spec
483
+ actor_network = self.actor_network
484
+ device = next(self.parameters()).device
485
+ if target_entropy == "auto":
486
+ action_spec = (
487
+ action_spec
488
+ if action_spec is not None
489
+ else getattr(actor_network, "spec", None)
490
+ )
491
+ if action_spec is None:
492
+ raise RuntimeError(
493
+ "Cannot infer the dimensionality of the action. Consider providing "
494
+ "the target entropy explicitly or provide the spec of the "
495
+ "action tensor in the actor network."
496
+ )
497
+ if not isinstance(action_spec, Composite):
498
+ action_spec = Composite({self.tensor_keys.action: action_spec})
499
+ if (
500
+ isinstance(self.tensor_keys.action, tuple)
501
+ and len(self.tensor_keys.action) > 1
502
+ ):
503
+ action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape
504
+ else:
505
+ action_container_shape = action_spec.shape
506
+ action_spec_leaf = action_spec[self.tensor_keys.action]
507
+ if action_spec_leaf is None:
508
+ raise RuntimeError(
509
+ "Cannot infer the dimensionality of the action. The action spec "
510
+ f"for key '{self.tensor_keys.action}' is None. This can happen when "
511
+ "using composite action distributions. Consider providing the "
512
+ "'action_spec' or 'target_entropy' argument explicitly to the loss."
513
+ )
514
+ if isinstance(action_spec_leaf, Composite):
515
+ # For composite action specs, sum the numel of all leaf specs
516
+ target_entropy = -float(
517
+ self._compute_composite_spec_numel(
518
+ action_spec_leaf, action_container_shape
519
+ )
520
+ )
521
+ else:
522
+ target_entropy = -float(
523
+ action_spec_leaf.shape[len(action_container_shape) :].numel()
524
+ )
525
+ delattr(self, "_target_entropy")
526
+ self.register_buffer(
527
+ "_target_entropy", torch.tensor(target_entropy, device=device)
528
+ )
529
+ return self._target_entropy
530
+
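# Worked example of the "auto" rule above (illustrative): with a single-leaf
# action spec of shape ``(6,)`` and an empty container batch shape, the leaf
# contributes 6 elements, so the buffer registered here is
# ``target_entropy = -6.0``; a composite spec instead sums the numel of all of
# its leaves via ``_compute_composite_spec_numel``.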
531
+ state_dict = _delezify(LossModule.state_dict)
532
+ load_state_dict = _delezify(LossModule.load_state_dict)
533
+
534
+ def _compute_composite_spec_numel(
535
+ self, spec: Composite, container_shape: torch.Size
536
+ ) -> int:
537
+ """Compute the total number of action elements in a Composite spec.
538
+
539
+ This handles composite action distributions where multiple sub-actions
540
+ are grouped together.
541
+ """
542
+ total = 0
543
+ for subspec in spec.values():
544
+ if subspec is None:
545
+ continue
546
+ if isinstance(subspec, Composite):
547
+ total += self._compute_composite_spec_numel(subspec, container_shape)
548
+ else:
549
+ total += subspec.shape[len(container_shape) :].numel()
550
+ return total
551
+
552
+ def _forward_value_estimator_keys(self, **kwargs) -> None:
553
+ if self._value_estimator is not None:
554
+ self._value_estimator.set_keys(
555
+ value=self.tensor_keys.value,
556
+ reward=self.tensor_keys.reward,
557
+ done=self.tensor_keys.done,
558
+ terminated=self.tensor_keys.terminated,
559
+ )
560
+ self._set_in_keys()
561
+
562
+ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
563
+ if value_type is None:
564
+ value_type = self.default_value_estimator
565
+
566
+ # Handle ValueEstimatorBase instance or class
567
+ if isinstance(value_type, ValueEstimatorBase) or (
568
+ isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
569
+ ):
570
+ return LossModule.make_value_estimator(self, value_type, **hyperparams)
571
+
572
+ self.value_type = value_type
573
+ if self._version == 1:
574
+ value_net = self.actor_critic
575
+ elif self._version == 2:
576
+ # we will take care of computing the next value inside this module
577
+ value_net = None
578
+ else:
579
+ # unreachable
580
+ raise NotImplementedError
581
+
582
+ hp = dict(default_value_kwargs(value_type))
583
+ hp.update(hyperparams)
584
+ if value_type is ValueEstimators.TD1:
585
+ self._value_estimator = TD1Estimator(
586
+ **hp,
587
+ value_network=value_net,
588
+ deactivate_vmap=self.deactivate_vmap,
589
+ )
590
+ elif value_type is ValueEstimators.TD0:
591
+ self._value_estimator = TD0Estimator(
592
+ **hp,
593
+ value_network=value_net,
594
+ deactivate_vmap=self.deactivate_vmap,
595
+ )
596
+ elif value_type is ValueEstimators.GAE:
597
+ raise NotImplementedError(
598
+ f"Value type {value_type} it not implemented for loss {type(self)}."
599
+ )
600
+ elif value_type is ValueEstimators.TDLambda:
601
+ self._value_estimator = TDLambdaEstimator(
602
+ **hp,
603
+ value_network=value_net,
604
+ deactivate_vmap=self.deactivate_vmap,
605
+ )
606
+ else:
607
+ raise NotImplementedError(f"Unknown value type {value_type}")
608
+
609
+ tensor_keys = {
610
+ "value_target": "value_target",
611
+ "value": self.tensor_keys.value,
612
+ "reward": self.tensor_keys.reward,
613
+ "done": self.tensor_keys.done,
614
+ "terminated": self.tensor_keys.terminated,
615
+ }
616
+ self._value_estimator.set_keys(**tensor_keys)
617
+
618
+ @property
619
+ def device(self) -> torch.device:
620
+ for p in self.parameters():
621
+ return p.device
622
+ raise RuntimeError(
623
+ "At least one of the networks of SACLoss must have trainable " "parameters."
624
+ )
625
+
626
+ def _set_in_keys(self):
627
+ keys = [
628
+ self.tensor_keys.action,
629
+ ("next", self.tensor_keys.reward),
630
+ ("next", self.tensor_keys.done),
631
+ ("next", self.tensor_keys.terminated),
632
+ *self.actor_network.in_keys,
633
+ *[("next", key) for key in self.actor_network.in_keys],
634
+ *self.qvalue_network.in_keys,
635
+ ]
636
+ if self._version == 1:
637
+ keys.extend(self.value_network.in_keys)
638
+ self._in_keys = list(set(keys))
639
+
640
+ @property
641
+ def in_keys(self):
642
+ if self._in_keys is None:
643
+ self._set_in_keys()
644
+ return self._in_keys
645
+
646
+ @in_keys.setter
647
+ def in_keys(self, values):
648
+ self._in_keys = values
649
+
650
+ @property
651
+ def out_keys(self):
652
+ if self._out_keys is None:
653
+ keys = ["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]
654
+ if self._version == 1:
655
+ keys.append("loss_value")
656
+ self._out_keys = keys
657
+ return self._out_keys
658
+
659
+ @out_keys.setter
660
+ def out_keys(self, values):
661
+ self._out_keys = values
662
+
663
+ @dispatch
664
+ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
665
+ if self._version == 1:
666
+ loss_qvalue, value_metadata = self.qvalue_v1_loss(tensordict)
667
+ loss_value, _ = self.value_loss(tensordict)
668
+ else:
669
+ loss_qvalue, value_metadata = self.qvalue_v2_loss(tensordict)
670
+ loss_value = None
671
+ loss_actor, metadata_actor = self.actor_loss(tensordict)
672
+ loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"])
673
+ weights = self._maybe_get_priority_weight(tensordict)
674
+ loss_alpha = _reduce(loss_alpha, reduction=self.reduction, weights=weights)
675
+ tensordict.set(self.tensor_keys.priority, value_metadata["td_error"])
676
+ if (loss_actor.shape != loss_qvalue.shape) or (
677
+ loss_value is not None and loss_actor.shape != loss_value.shape
678
+ ):
679
+ raise RuntimeError(
680
+ f"Losses shape mismatch: {loss_actor.shape}, {loss_qvalue.shape} and {loss_value.shape}"
681
+ )
682
+ entropy = -metadata_actor["log_prob"]
683
+ out = {
684
+ "loss_actor": loss_actor,
685
+ "loss_qvalue": loss_qvalue,
686
+ "loss_alpha": loss_alpha,
687
+ "alpha": self._alpha,
688
+ "entropy": entropy.detach().mean(),
689
+ }
690
+ if self._version == 1:
691
+ out["loss_value"] = loss_value
692
+ td_out = TensorDict(out)
693
+ self._clear_weakrefs(
694
+ tensordict,
695
+ td_out,
696
+ "actor_network_params",
697
+ "qvalue_network_params",
698
+ "value_network_params",
699
+ "target_actor_network_params",
700
+ "target_qvalue_network_params",
701
+ "target_value_network_params",
702
+ )
703
+ return td_out
704
+
705
+ @property
706
+ @_cache_values
707
+ def _cached_detached_qvalue_params(self):
708
+ return self.qvalue_network_params.detach()
709
+
710
+ def actor_loss(
711
+ self, tensordict: TensorDictBase
712
+ ) -> tuple[Tensor, dict[str, Tensor]]:
713
+ weights = self._maybe_get_priority_weight(tensordict)
714
+ with set_exploration_type(
715
+ ExplorationType.RANDOM
716
+ ), self.actor_network_params.to_module(self.actor_network):
717
+ dist = self.actor_network.get_dist(tensordict)
718
+ a_reparm = dist.rsample()
719
+ log_prob = compute_log_prob(dist, a_reparm, self.tensor_keys.log_prob)
720
+
721
+ td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
722
+ td_q.set(self.tensor_keys.action, a_reparm)
723
+ td_q = self._vmap_qnetworkN0(
724
+ td_q,
725
+ self._cached_detached_qvalue_params, # should we clone?
726
+ )
727
+ min_q_logprob = (
728
+ td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
729
+ )
730
+
731
+ if log_prob.shape != min_q_logprob.shape:
732
+ raise RuntimeError(
733
+ f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}"
734
+ )
735
+ loss_actor = self._alpha * log_prob - min_q_logprob
736
+ loss_actor = _reduce(loss_actor, reduction=self.reduction, weights=weights)
737
+ return loss_actor, {"log_prob": log_prob.detach()}
738
+
739
+ def alpha_loss(self, log_prob: Tensor) -> Tensor:
740
+ """Compute the alpha loss for SAC.
741
+
742
+ This method computes the alpha loss which adapts the entropy coefficient
743
+ to maintain the target entropy level.
744
+
745
+ Args:
746
+ log_prob (Tensor): The log probability of actions from the actor network.
747
+
748
+ Returns:
749
+ The alpha loss tensor
750
+ """
751
+ return self._alpha_loss(log_prob)
752
+
753
+ @property
754
+ def _alpha(self):
755
+ if self.min_log_alpha is not None or self.max_log_alpha is not None:
756
+ self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
757
+ with torch.no_grad():
758
+ alpha = self.log_alpha.exp()
759
+ return alpha
760
+
761
+ @property
762
+ @_cache_values
763
+ def _cached_target_params_actor_value(self):
764
+ return TensorDict._new_unsafe(
765
+ {
766
+ "module": {
767
+ "0": self.target_actor_network_params,
768
+ "1": self.target_value_network_params,
769
+ }
770
+ },
771
+ torch.Size([]),
772
+ )
773
+
774
+ def qvalue_v1_loss(
775
+ self, tensordict: TensorDictBase
776
+ ) -> tuple[Tensor, dict[str, Tensor]]:
777
+ weights = self._maybe_get_priority_weight(tensordict)
778
+ target_params = self._cached_target_params_actor_value
779
+ with set_exploration_type(self.deterministic_sampling_mode):
780
+ target_value = self.value_estimator.value_estimate(
781
+ tensordict, target_params=target_params
782
+ ).squeeze(-1)
783
+
784
+ # Q-nets must be trained independently: as such, we split the data in 2
785
+ # if required and train each q-net on one half of the data.
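# (e.g., with num_qvalue_nets=2 and a batch of 256 transitions, each Q-network
# is trained on a disjoint half of 128 transitions.)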
786
+ shape = tensordict.shape
787
+ if shape[0] % self.num_qvalue_nets != 0:
788
+ raise RuntimeError(
789
+ f"Batch size={tensordict.shape} is incompatible "
790
+ f"with num_qvqlue_nets={self.num_qvalue_nets}."
791
+ )
792
+ tensordict_chunks = tensordict.reshape(
793
+ self.num_qvalue_nets, -1, *tensordict.shape[1:]
794
+ )
795
+ target_chunks = target_value.reshape(
796
+ self.num_qvalue_nets, -1, *target_value.shape[1:]
797
+ )
798
+
799
+ # if vmap=True, it is assumed that the input tensordict must be cast to the param shape
800
+ tensordict_chunks = self._vmap_qnetwork00(
801
+ tensordict_chunks, self.qvalue_network_params
802
+ )
803
+ pred_val = tensordict_chunks.get(self.tensor_keys.state_action_value)
804
+ pred_val = pred_val.squeeze(-1)
805
+ loss_value = distance_loss(
806
+ pred_val, target_chunks, loss_function=self.loss_function
807
+ ).view(*shape)
808
+ loss_value = _reduce(loss_value, reduction=self.reduction, weights=weights)
809
+ metadata = {"td_error": (pred_val - target_chunks).pow(2).flatten(0, 1)}
810
+
811
+ return loss_value, metadata
812
+
813
+ def _compute_target_v2(self, tensordict) -> Tensor:
814
+ r"""Value network for SAC v2.
815
+
816
+ SAC v2 is based on a value estimate of the form:
817
+
818
+ .. math::
819
+
820
+ V = Q(s, a) - \alpha \log p(a \mid s)
821
+
822
+ This class computes this value given the actor and qvalue network
823
+
824
+ """
825
+ tensordict = tensordict.clone(False)
826
+ # get actions and log-probs
827
+ with torch.no_grad():
828
+ with set_exploration_type(
829
+ ExplorationType.RANDOM
830
+ ), self.actor_network_params.to_module(self.actor_network):
831
+ next_tensordict = tensordict.get("next").copy()
832
+ if self.skip_done_states:
833
+ # Check done state and avoid passing these to the actor
834
+ done = next_tensordict.get(self.tensor_keys.done)
835
+ if done is not None and done.any():
836
+ next_tensordict_select = next_tensordict[~done.squeeze(-1)]
837
+ else:
838
+ next_tensordict_select = next_tensordict
839
+ next_dist = self.actor_network.get_dist(next_tensordict_select)
840
+ next_action = next_dist.rsample()
841
+ next_sample_log_prob = compute_log_prob(
842
+ next_dist, next_action, self.tensor_keys.log_prob
843
+ )
844
+ if next_tensordict_select is not next_tensordict:
845
+ mask = ~done.squeeze(-1)
846
+ if mask.ndim < next_action.ndim:
847
+ mask = expand_right(
848
+ mask, (*mask.shape, *next_action.shape[mask.ndim :])
849
+ )
850
+ next_action = next_action.new_zeros(mask.shape).masked_scatter_(
851
+ mask, next_action
852
+ )
853
+ mask = ~done.squeeze(-1)
854
+ if mask.ndim < next_sample_log_prob.ndim:
855
+ mask = expand_right(
856
+ mask,
857
+ (*mask.shape, *next_sample_log_prob.shape[mask.ndim :]),
858
+ )
859
+ next_sample_log_prob = next_sample_log_prob.new_zeros(
860
+ mask.shape
861
+ ).masked_scatter_(mask, next_sample_log_prob)
862
+ next_tensordict.set(self.tensor_keys.action, next_action)
863
+ else:
864
+ next_dist = self.actor_network.get_dist(next_tensordict)
865
+ next_action = next_dist.rsample()
866
+ next_tensordict.set(self.tensor_keys.action, next_action)
867
+ next_sample_log_prob = compute_log_prob(
868
+ next_dist, next_action, self.tensor_keys.log_prob
869
+ )
870
+
871
+ # get q-values
872
+ next_tensordict_expand = self._vmap_qnetworkN0(
873
+ next_tensordict, self.target_qvalue_network_params
874
+ )
875
+ state_action_value = next_tensordict_expand.get(
876
+ self.tensor_keys.state_action_value
877
+ )
878
+ if (
879
+ state_action_value.shape[-len(next_sample_log_prob.shape) :]
880
+ != next_sample_log_prob.shape
881
+ ):
882
+ next_sample_log_prob = next_sample_log_prob.unsqueeze(-1)
883
+ next_state_value = state_action_value - self._alpha * next_sample_log_prob
884
+ next_state_value = next_state_value.min(0)[0]
885
+ tensordict.set(
886
+ ("next", self.value_estimator.tensor_keys.value), next_state_value
887
+ )
888
+ target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
889
+ return target_value
890
+
891
+ def qvalue_v2_loss(
892
+ self, tensordict: TensorDictBase
893
+ ) -> tuple[Tensor, dict[str, Tensor]]:
894
+ weights = self._maybe_get_priority_weight(tensordict)
895
+ # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first.
896
+ target_value = self._compute_target_v2(tensordict)
897
+
898
+ tensordict_expand = self._vmap_qnetworkN0(
899
+ tensordict.select(*self.qvalue_network.in_keys, strict=False),
900
+ self.qvalue_network_params,
901
+ )
902
+ pred_val = tensordict_expand.get(self.tensor_keys.state_action_value).squeeze(
903
+ -1
904
+ )
905
+ td_error = abs(pred_val - target_value)
906
+ loss_qval = distance_loss(
907
+ pred_val,
908
+ target_value.expand_as(pred_val),
909
+ loss_function=self.loss_function,
910
+ ).sum(0)
911
+ loss_qval = _reduce(loss_qval, reduction=self.reduction, weights=weights)
912
+ metadata = {"td_error": td_error.detach().max(0)[0]}
913
+ return loss_qval, metadata
914
+
915
+ def value_loss(
916
+ self, tensordict: TensorDictBase
917
+ ) -> tuple[Tensor, dict[str, Tensor]]:
918
+ weights = self._maybe_get_priority_weight(tensordict)
919
+ # value loss
920
+ td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach()
921
+ with self.value_network_params.to_module(self.value_network):
922
+ self.value_network(td_copy)
923
+ pred_val = td_copy.get(self.tensor_keys.value).squeeze(-1)
924
+ with self.target_actor_network_params.to_module(self.actor_network):
925
+ action_dist = self.actor_network.get_dist(td_copy) # resample an action
926
+ action = action_dist.rsample()
927
+
928
+ td_copy.set(self.tensor_keys.action, action, inplace=False)
929
+
930
+ td_copy = self._vmap_qnetworkN0(
931
+ td_copy,
932
+ self.target_qvalue_network_params,
933
+ )
934
+
935
+ min_qval = (
936
+ td_copy.get(self.tensor_keys.state_action_value).squeeze(-1).min(0)[0]
937
+ )
938
+
939
+ log_p = compute_log_prob(action_dist, action, self.tensor_keys.log_prob)
940
+
941
+ if log_p.shape != min_qval.shape:
942
+ raise RuntimeError(
943
+ f"Losses shape mismatch: {min_qval.shape} and {log_p.shape}"
944
+ )
945
+ target_val = min_qval - self._alpha * log_p
946
+
947
+ loss_value = distance_loss(
948
+ pred_val, target_val, loss_function=self.loss_function
949
+ )
950
+ loss_value = _reduce(loss_value, reduction=self.reduction, weights=weights)
951
+ return loss_value, {}
952
+
953
+ def _alpha_loss(self, log_prob: Tensor) -> Tensor:
954
+ if self.target_entropy is not None:
955
+ # we can compute this loss even if log_alpha is not a parameter
956
+ alpha_loss = -self.log_alpha * (log_prob + self.target_entropy)
957
+ else:
958
+ # placeholder
959
+ alpha_loss = torch.zeros_like(log_prob)
960
+ return alpha_loss
961
+
962
+
963
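+ # Editor's sketch (illustrative only, not part of the upstream file): the alpha loss above
+ # adapts the entropy coefficient so that the policy entropy tracks `target_entropy`.
+ # A minimal numeric check, assuming plain gradient descent on `log_alpha`:
+ #
+ #     import torch
+ #     log_alpha = torch.zeros((), requires_grad=True)
+ #     log_prob, target_entropy = torch.tensor(-0.5), 1.0   # entropy 0.5 below target 1.0
+ #     loss = -log_alpha * (log_prob + target_entropy)
+ #     loss.backward()
+ #     print(log_alpha.grad)  # tensor(-0.5000): descent increases log_alpha, so alpha grows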
+ class DiscreteSACLoss(LossModule):
964
+ """Discrete SAC Loss module.
965
+
966
+ Args:
967
+ actor_network (ProbabilisticTensorDictSequential): the actor to be trained
968
+ qvalue_network (TensorDictModule): a single Q-value network that will be duplicated as many times as needed.
969
+ action_space (str or TensorSpec): Action space. Must be one of
970
+ ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``,
971
+ or an instance of the corresponding specs (:class:`torchrl.data.OneHot`,
972
+ :class:`torchrl.data.MultiOneHot`,
973
+ :class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`).
974
+ num_actions (int, optional): number of actions in the action space.
975
+ To be provided if target_entropy is set to "auto".
976
+ num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 2.
977
+ loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`, "l2",
978
+ "l1", Default is "smooth_l1".
979
+ alpha_init (:obj:`float`, optional): initial entropy multiplier.
980
+ Default is 1.0.
981
+ min_alpha (:obj:`float`, optional): min value of alpha.
982
+ Default is None (no minimum value).
983
+ max_alpha (:obj:`float`, optional): max value of alpha.
984
+ Default is None (no maximum value).
985
+ fixed_alpha (bool, optional): if ``True``, alpha is kept fixed at its initial value and not optimized to match a target entropy. Default is ``False``.
986
+ target_entropy_weight (:obj:`float`, optional): weight for the target entropy term. Defaults to ``0.98``.
987
+ target_entropy (Union[str, Number], optional): Target entropy for the
988
+ stochastic policy. Default is "auto", where target entropy is
989
+ computed as :obj:`-target_entropy_weight * log(1 / num_actions)`.
990
+ delay_qvalue (bool, optional): Whether to separate the target Q value networks from the Q value networks used
991
+ for data collection. Default is ``True``.
992
+ priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
993
+ Key where to write the priority value for prioritized replay buffers.
994
+ Default is `"td_error"`.
995
+ separate_losses (bool, optional): if ``True``, shared parameters between
996
+ policy and critic will only be trained on the policy loss.
997
+ Defaults to ``False``, i.e., gradients are propagated to shared
998
+ parameters for both policy and critic losses.
999
+ reduction (str, optional): Specifies the reduction to apply to the output:
1000
+ ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
1001
+ ``"mean"``: the sum of the output will be divided by the number of
1002
+ elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
1003
+ skip_done_states (bool, optional): whether the actor network used for value computation should only be run on
1004
+ valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
1005
+ shape of the data and that masking the data results in a valid data structure. Among other things, this may
1006
+ not be true in MARL settings or when using RNNs. Defaults to ``False``.
1007
+ deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
1008
+ Defaults to ``False``.
1009
+
1010
+ Examples:
1011
+ >>> import torch
1012
+ >>> from torch import nn
1013
+ >>> from torchrl.data.tensor_specs import OneHot
1014
+ >>> from torchrl.modules.distributions import NormalParamExtractor, OneHotCategorical
1015
+ >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
1016
+ >>> from torchrl.modules.tensordict_module.common import SafeModule
1017
+ >>> from torchrl.objectives.sac import DiscreteSACLoss
1018
+ >>> from tensordict import TensorDict
1019
+ >>> from tensordict.nn import TensorDictModule
1020
+ >>> n_act, n_obs = 4, 3
1021
+ >>> spec = OneHot(n_act)
1022
+ >>> module = TensorDictModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
1023
+ >>> actor = ProbabilisticActor(
1024
+ ... module=module,
1025
+ ... in_keys=["logits"],
1026
+ ... out_keys=["action"],
1027
+ ... spec=spec,
1028
+ ... distribution_class=OneHotCategorical)
1029
+ >>> qvalue = TensorDictModule(
1030
+ ... nn.Linear(n_obs, n_act),
1031
+ ... in_keys=["observation"],
1032
+ ... out_keys=["action_value"],
1033
+ ... )
1034
+ >>> loss = DiscreteSACLoss(actor, qvalue, action_space=spec, num_actions=spec.space.n)
1035
+ >>> batch = [2,]
1036
+ >>> action = spec.rand(batch)
1037
+ >>> data = TensorDict({
1038
+ ... "observation": torch.randn(*batch, n_obs),
1039
+ ... "action": action,
1040
+ ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
1041
+ ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
1042
+ ... ("next", "reward"): torch.randn(*batch, 1),
1043
+ ... ("next", "observation"): torch.randn(*batch, n_obs),
1044
+ ... }, batch)
1045
+ >>> loss(data)
1046
+ TensorDict(
1047
+ fields={
1048
+ alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
1049
+ entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
1050
+ loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
1051
+ loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
1052
+ loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
1053
+ batch_size=torch.Size([]),
1054
+ device=None,
1055
+ is_shared=False)
1056
+
1057
+
1058
+ This class is compatible with non-tensordict based modules too and can be
1059
+ used without resorting to any tensordict-related primitive. In this case,
1060
+ the expected keyword arguments are:
1061
+ ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network.
1062
+ The return value is a tuple of tensors in the following order:
1063
+ ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]``.
1064
+ The output keys can also be filtered using :meth:`DiscreteSACLoss.select_out_keys` method.
1065
+
1066
+ Examples:
1067
+ >>> import torch
1068
+ >>> from torch import nn
1069
+ >>> from torchrl.data.tensor_specs import OneHot
1070
+ >>> from torchrl.modules.distributions import NormalParamExtractor, OneHotCategorical
1071
+ >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
1072
+ >>> from torchrl.modules.tensordict_module.common import SafeModule
1073
+ >>> from torchrl.objectives.sac import DiscreteSACLoss
1074
+ >>> n_act, n_obs = 4, 3
1075
+ >>> spec = OneHot(n_act)
1076
+ >>> net = nn.Linear(n_obs, n_act)  # outputs one logit per discrete action
1077
+ >>> module = SafeModule(net, in_keys=["observation"], out_keys=["logits"])
1078
+ >>> actor = ProbabilisticActor(
1079
+ ... module=module,
1080
+ ... in_keys=["logits"],
1081
+ ... out_keys=["action"],
1082
+ ... spec=spec,
1083
+ ... distribution_class=OneHotCategorical)
1084
+ >>> class ValueClass(nn.Module):
1085
+ ... def __init__(self):
1086
+ ... super().__init__()
1087
+ ... self.linear = nn.Linear(n_obs, n_act)
1088
+ ... def forward(self, obs):
1089
+ ... return self.linear(obs)
1090
+ >>> module = ValueClass()
1091
+ >>> qvalue = ValueOperator(
1092
+ ... module=module,
1093
+ ... in_keys=['observation'])
1094
+ >>> loss = DiscreteSACLoss(actor, qvalue, num_actions=actor.spec["action"].space.n)
1095
+ >>> batch = [2, ]
1096
+ >>> action = spec.rand(batch)
1097
+ >>> # filter output keys to "loss_actor", and "loss_qvalue"
1098
+ >>> _ = loss.select_out_keys("loss_actor", "loss_qvalue")
1099
+ >>> loss_actor, loss_qvalue = loss(
1100
+ ... observation=torch.randn(*batch, n_obs),
1101
+ ... action=action,
1102
+ ... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
1103
+ ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
1104
+ ... next_observation=torch.zeros(*batch, n_obs),
1105
+ ... next_reward=torch.randn(*batch, 1))
1106
+ >>> loss_actor.backward()
1107
+ """
1108
+
1109
+ @dataclass
1110
+ class _AcceptedKeys:
1111
+ """Maintains default values for all configurable tensordict keys.
1112
+
1113
+ This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
1114
+ default values.
1115
+
1116
+ Attributes:
1117
+ action (NestedKey): The input tensordict key where the action is expected.
1118
+ Defaults to ``"action"``.
1119
+ value (NestedKey): The input tensordict key where the state value is expected.
1120
+ Will be used for the underlying value estimator. Defaults to ``"state_value"``.
1121
+ priority (NestedKey): The input tensordict key where the target priority is written to.
1122
+ Defaults to ``"td_error"``.
1123
+ reward (NestedKey): The input tensordict key where the reward is expected.
1124
+ Will be used for the underlying value estimator. Defaults to ``"reward"``.
1125
+ done (NestedKey): The key in the input TensorDict that indicates
1126
+ whether a trajectory is done. Will be used for the underlying value estimator.
1127
+ Defaults to ``"done"``.
1128
+ terminated (NestedKey): The key in the input TensorDict that indicates
1129
+ whether a trajectory is terminated. Will be used for the underlying value estimator.
1130
+ Defaults to ``"terminated"``.
1131
+ """
1132
+
1133
+ action: NestedKey = "action"
1134
+ value: NestedKey = "state_value"
1135
+ action_value: NestedKey = "action_value"
1136
+ priority: NestedKey = "td_error"
1137
+ reward: NestedKey = "reward"
1138
+ done: NestedKey = "done"
1139
+ terminated: NestedKey = "terminated"
1140
+ log_prob: NestedKey = "log_prob"
1141
+ priority_weight: NestedKey = "priority_weight"
1142
+
1143
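+ # Editor's note (illustrative, assuming a constructed `loss` instance): any of the keys above
+ # can be remapped at runtime through `set_keys`, e.g.
+ #
+ #     loss.set_keys(reward=("agent", "reward"), done=("agent", "done"), priority="p")
+ #
+ # which also forwards the reward/done/terminated keys to the underlying value estimator.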
+ tensor_keys: _AcceptedKeys
1144
+ default_keys = _AcceptedKeys
1145
+ default_value_estimator = ValueEstimators.TD0
1146
+ delay_actor: bool = False
1147
+ out_keys = [
1148
+ "loss_actor",
1149
+ "loss_qvalue",
1150
+ "loss_alpha",
1151
+ "alpha",
1152
+ "entropy",
1153
+ ]
1154
+
1155
+ actor_network: TensorDictModule
1156
+ qvalue_network: TensorDictModule
1157
+ value_network: TensorDictModule | None
1158
+ actor_network_params: TensorDictParams
1159
+ qvalue_network_params: TensorDictParams
1160
+ value_network_params: TensorDictParams | None
1161
+ target_actor_network_params: TensorDictParams
1162
+ target_qvalue_network_params: TensorDictParams
1163
+ target_value_network_params: TensorDictParams | None
1164
+
1165
+ def __init__(
1166
+ self,
1167
+ actor_network: ProbabilisticTensorDictSequential,
1168
+ qvalue_network: TensorDictModule,
1169
+ *,
1170
+ action_space: str | TensorSpec = None,
1171
+ num_actions: int | None = None,
1172
+ num_qvalue_nets: int = 2,
1173
+ loss_function: str = "smooth_l1",
1174
+ alpha_init: float = 1.0,
1175
+ min_alpha: float | None = None,
1176
+ max_alpha: float | None = None,
1177
+ fixed_alpha: bool = False,
1178
+ target_entropy_weight: float = 0.98,
1179
+ target_entropy: str | Number = "auto",
1180
+ delay_qvalue: bool = True,
1181
+ priority_key: str | None = None,
1182
+ separate_losses: bool = False,
1183
+ reduction: str | None = None,
1184
+ skip_done_states: bool = False,
1185
+ deactivate_vmap: bool = False,
1186
+ use_prioritized_weights: str | bool = "auto",
1187
+ ):
1188
+ if reduction is None:
1189
+ reduction = "mean"
1190
+ self._in_keys = None
1191
+ super().__init__()
1192
+ self.use_prioritized_weights = use_prioritized_weights
1193
+ self._set_deprecated_ctor_keys(priority_key=priority_key)
1194
+
1195
+ self.convert_to_functional(
1196
+ actor_network,
1197
+ "actor_network",
1198
+ create_target_params=self.delay_actor,
1199
+ )
1200
+ self.deactivate_vmap = deactivate_vmap
1201
+ if separate_losses:
1202
+ # we want to make sure there are no duplicates in the params: the
1203
+ # params of critic must be refs to actor if they're shared
1204
+ policy_params = list(actor_network.parameters())
1205
+ else:
1206
+ policy_params = None
1207
+ self.delay_qvalue = delay_qvalue
1208
+ self.convert_to_functional(
1209
+ qvalue_network,
1210
+ "qvalue_network",
1211
+ num_qvalue_nets,
1212
+ create_target_params=self.delay_qvalue,
1213
+ compare_against=policy_params,
1214
+ )
1215
+ self.num_qvalue_nets = num_qvalue_nets
1216
+ self.loss_function = loss_function
1217
+
1218
+ try:
1219
+ device = next(self.parameters()).device
1220
+ except AttributeError:
1221
+ device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
1222
+
1223
+ self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
1224
+ if bool(min_alpha) ^ bool(max_alpha):
1225
+ min_alpha = min_alpha if min_alpha else 0.0
1226
+ if max_alpha == 0:
1227
+ raise ValueError("max_alpha must be either None or greater than 0.")
1228
+ max_alpha = max_alpha if max_alpha else 1e9
1229
+ if min_alpha:
1230
+ self.register_buffer(
1231
+ "min_log_alpha", torch.tensor(min_alpha, device=device).log()
1232
+ )
1233
+ else:
1234
+ self.min_log_alpha = None
1235
+ if max_alpha:
1236
+ self.register_buffer(
1237
+ "max_log_alpha", torch.tensor(max_alpha, device=device).log()
1238
+ )
1239
+ else:
1240
+ self.max_log_alpha = None
1241
+ self.fixed_alpha = fixed_alpha
1242
+ if fixed_alpha:
1243
+ self.register_buffer(
1244
+ "log_alpha", torch.tensor(math.log(alpha_init), device=device)
1245
+ )
1246
+ else:
1247
+ self.register_parameter(
1248
+ "log_alpha",
1249
+ torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)),
1250
+ )
1251
+
1252
+ if action_space is None:
1253
+ warnings.warn(
1254
+ "action_space was not specified. DiscreteSACLoss will default to 'one-hot'. "
1255
+ "This behavior will be deprecated soon and a space will have to be passed. "
1256
+ "Check the DiscreteSACLoss documentation to see how to pass the action space. "
1257
+ )
1258
+ action_space = "one-hot"
1259
+ self.action_space = _find_action_space(action_space)
1260
+ if target_entropy == "auto":
1261
+ if num_actions is None:
1262
+ raise ValueError(
1263
+ "num_actions needs to be provided if target_entropy == 'auto'"
1264
+ )
1265
+ target_entropy = -float(np.log(1.0 / num_actions) * target_entropy_weight)
1266
+ self.register_buffer(
1267
+ "target_entropy", torch.tensor(target_entropy, device=device)
1268
+ )
1269
+ self._make_vmap()
1270
+ self.reduction = reduction
1271
+ self.skip_done_states = skip_done_states
1272
+
1273
+ def _make_vmap(self):
1274
+ self._vmap_qnetworkN0 = _vmap_func(
1275
+ self.qvalue_network,
1276
+ (None, 0),
1277
+ randomness=self.vmap_randomness,
1278
+ pseudo_vmap=self.deactivate_vmap,
1279
+ )
1280
+
1281
+ def _forward_value_estimator_keys(self, **kwargs) -> None:
1282
+ if self._value_estimator is not None:
1283
+ self._value_estimator.set_keys(
1284
+ value=self._tensor_keys.value,
1285
+ reward=self.tensor_keys.reward,
1286
+ done=self.tensor_keys.done,
1287
+ terminated=self.tensor_keys.terminated,
1288
+ )
1289
+ self._set_in_keys()
1290
+
1291
+ def _set_in_keys(self):
1292
+ keys = [
1293
+ self.tensor_keys.action,
1294
+ ("next", self.tensor_keys.reward),
1295
+ ("next", self.tensor_keys.done),
1296
+ ("next", self.tensor_keys.terminated),
1297
+ *self.actor_network.in_keys,
1298
+ *[("next", key) for key in self.actor_network.in_keys],
1299
+ *self.qvalue_network.in_keys,
1300
+ ]
1301
+ self._in_keys = list(set(keys))
1302
+
1303
+ @property
1304
+ def in_keys(self):
1305
+ if self._in_keys is None:
1306
+ self._set_in_keys()
1307
+ return self._in_keys
1308
+
1309
+ @in_keys.setter
1310
+ def in_keys(self, values):
1311
+ self._in_keys = values
1312
+
1313
+ @dispatch
1314
+ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
1315
+ loss_qvalue, metadata_value = self.qvalue_loss(tensordict)
1316
+ loss_actor, metadata_actor = self.actor_loss(tensordict)
1317
+ loss_alpha = self._alpha_loss(
1318
+ log_prob=metadata_actor["log_prob"],
1319
+ )
1320
+ weights = self._maybe_get_priority_weight(tensordict)
1321
+ loss_alpha = _reduce(loss_alpha, reduction=self.reduction, weights=weights)
1322
+
1323
+ tensordict.set(self.tensor_keys.priority, metadata_value["td_error"])
1324
+ if loss_actor.shape != loss_qvalue.shape:
1325
+ raise RuntimeError(
1326
+ f"Losses shape mismatch: {loss_actor.shape}, and {loss_qvalue.shape}"
1327
+ )
1328
+ entropy = -metadata_actor["log_prob"]
1329
+ out = {
1330
+ "loss_actor": loss_actor,
1331
+ "loss_qvalue": loss_qvalue,
1332
+ "loss_alpha": loss_alpha,
1333
+ "alpha": self._alpha,
1334
+ "entropy": entropy.detach().mean(),
1335
+ }
1336
+ td_out = TensorDict(out, [])
1337
+ self._clear_weakrefs(
1338
+ tensordict,
1339
+ td_out,
1340
+ "actor_network_params",
1341
+ "qvalue_network_params",
1342
+ "target_actor_network_params",
1343
+ "target_qvalue_network_params",
1344
+ "target_value_network_params",
1345
+ "value_network_params",
1346
+ )
1347
+ return td_out
1348
+
1349
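+ # Editor's sketch (illustrative only; `loss_module`, `batch` and `optim` are assumed names):
+ # a typical optimization step sums the loss_* entries returned by forward() before backprop:
+ #
+ #     td = loss_module(batch)                       # batch: TensorDict of transitions
+ #     total = td["loss_actor"] + td["loss_qvalue"] + td["loss_alpha"]
+ #     total.backward()
+ #     optim.step()
+ #     optim.zero_grad()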
+ def _compute_target(self, tensordict) -> Tensor:
1350
+ r"""Value network for SAC v2.
1351
+
1352
+ SAC v2 is based on a value estimate of the form:
1353
+
1354
+ .. math::
1355
+
1356
+ V = Q(s,a) - \alpha * \log p(a | s)
1357
+
1358
+ This method computes this value given the actor and the Q-value networks.
1359
+
1360
+ """
1361
+ tensordict = tensordict.clone(False)
1362
+ # get actions and log-probs
1363
+ with torch.no_grad():
1364
+ next_tensordict = tensordict.get("next").clone(False)
1365
+
1366
+ if self.skip_done_states:
1367
+ done = next_tensordict.get(self.tensor_keys.done)
1368
+ if done is not None and done.any():
1369
+ next_tensordict_select = next_tensordict[~done.squeeze(-1)]
1370
+ else:
1371
+ next_tensordict_select = next_tensordict
1372
+
1373
+ # get probs and log probs for actions computed from "next"
1374
+ with self.actor_network_params.to_module(self.actor_network):
1375
+ next_dist = self.actor_network.get_dist(next_tensordict_select)
1376
+ next_log_prob = next_dist.logits
1377
+ next_prob = next_log_prob.exp()
1378
+
1379
+ # get q-values for all actions
1380
+ next_tensordict_expand = self._vmap_qnetworkN0(
1381
+ next_tensordict_select, self.target_qvalue_network_params
1382
+ )
1383
+ next_action_value = next_tensordict_expand.get(
1384
+ self.tensor_keys.action_value
1385
+ )
1386
+
1387
+ # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term
1388
+ next_state_value = (
1389
+ next_action_value.min(0)[0] - self._alpha * next_log_prob
1390
+ )
1391
+ # unlike in continuous SAC, we can compute the exact expectation over all discrete actions
1392
+ next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)
1393
+ if next_tensordict_select is not next_tensordict:
1394
+ mask = ~done
1395
+ next_state_value = next_state_value.new_zeros(
1396
+ mask.shape
1397
+ ).masked_scatter_(mask, next_state_value)
1398
+ else:
1399
+ # get probs and log probs for actions computed from "next"
1400
+ with self.actor_network_params.to_module(self.actor_network):
1401
+ next_dist = self.actor_network.get_dist(next_tensordict)
1402
+ next_prob = next_dist.probs
1403
+ next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob))
1404
+
1405
+ # get q-values for all actions
1406
+ next_tensordict_expand = self._vmap_qnetworkN0(
1407
+ next_tensordict, self.target_qvalue_network_params
1408
+ )
1409
+ next_action_value = next_tensordict_expand.get(
1410
+ self.tensor_keys.action_value
1411
+ )
1412
+ # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term
1413
+ next_state_value = (
1414
+ next_action_value.min(0)[0] - self._alpha * next_log_prob
1415
+ )
1416
+ # unlike in continuous SAC, we can compute the exact expectation over all discrete actions
1417
+ next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)
1418
+
1419
+ tensordict.set(
1420
+ ("next", self.value_estimator.tensor_keys.value), next_state_value
1421
+ )
1422
+ target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
1423
+ return target_value
1424
+
1425
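+ # Editor's note (illustrative): since the action space is discrete, the target above is the
+ # exact expectation V(s') = sum_a pi(a|s') * [min_i Q_i(s', a) - alpha * log pi(a|s')],
+ # rather than the single-sample estimate used in the continuous case.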
+ def qvalue_loss(
1426
+ self, tensordict: TensorDictBase
1427
+ ) -> tuple[Tensor, dict[str, Tensor]]:
1428
+ weights = self._maybe_get_priority_weight(tensordict)
1429
+ target_value = self._compute_target(tensordict)
1430
+ tensordict_expand = self._vmap_qnetworkN0(
1431
+ tensordict.select(*self.qvalue_network.in_keys, strict=False),
1432
+ self.qvalue_network_params,
1433
+ )
1434
+
1435
+ action_value = tensordict_expand.get(self.tensor_keys.action_value)
1436
+ action = tensordict.get(self.tensor_keys.action)
1437
+ action = action.expand((action_value.shape[0], *action.shape)) # Add vmap dim
1438
+
1439
+ # TODO this block comes from the dqn loss, we need to swap all these with a proper
1440
+ # helper function which selects the value given the action for all discrete spaces
1441
+ if self.action_space == "categorical":
1442
+ if action.shape != action_value.shape:
1443
+ # unsqueeze the action if it lacks on trailing singleton dim
1444
+ action = action.unsqueeze(-1)
1445
+ chosen_action_value = torch.gather(action_value, -1, index=action).squeeze(
1446
+ -1
1447
+ )
1448
+ else:
1449
+ action = action.to(torch.float)
1450
+ chosen_action_value = (action_value * action).sum(-1)
1451
+
1452
+ td_error = torch.abs(chosen_action_value - target_value)
1453
+ loss_qval = distance_loss(
1454
+ chosen_action_value,
1455
+ target_value.expand_as(chosen_action_value),
1456
+ loss_function=self.loss_function,
1457
+ ).sum(0)
1458
+ loss_qval = _reduce(loss_qval, reduction=self.reduction, weights=weights)
1459
+
1460
+ metadata = {
1461
+ "td_error": td_error.detach().max(0)[0],
1462
+ }
1463
+ return loss_qval, metadata
1464
+
1465
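+ # Editor's note (illustrative numeric check): the gather / masked-sum above selects the Q-value
+ # of the taken action. For action_value = [[1., 2.], [3., 4.]], a categorical action [[1], [0]]
+ # yields [2., 3.] via torch.gather, and the equivalent one-hot action [[0, 1], [1, 0]] yields
+ # the same via (action_value * action).sum(-1).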
+ def actor_loss(
1466
+ self, tensordict: TensorDictBase
1467
+ ) -> tuple[Tensor, dict[str, Tensor]]:
1468
+ weights = self._maybe_get_priority_weight(tensordict)
1469
+ # get probs and log probs for actions
1470
+ with self.actor_network_params.to_module(self.actor_network):
1471
+ dist = self.actor_network.get_dist(tensordict.clone(False))
1472
+ prob = dist.probs
1473
+ log_prob = dist.logits
1474
+
1475
+ td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
1476
+
1477
+ td_q = self._vmap_qnetworkN0(
1478
+ td_q, self._cached_detached_qvalue_params # should we clone?
1479
+ )
1480
+ min_q = td_q.get(self.tensor_keys.action_value).min(0)[0]
1481
+
1482
+ if log_prob.shape != min_q.shape:
1483
+ raise RuntimeError(
1484
+ f"Losses shape mismatch: {log_prob.shape} and {min_q.shape}"
1485
+ )
1486
+
1487
+ # like in continuous SAC, we take the entropy term and subtract the minimum of the value ensemble
1488
+ loss = self._alpha * log_prob - min_q
1489
+ # unlike in continuous SAC, we can compute the exact expectation over all discrete actions
1490
+ loss = (prob * loss).sum(-1)
1491
+ loss = _reduce(loss, reduction=self.reduction, weights=weights)
1492
+
1493
+ return loss, {"log_prob": (log_prob * prob).sum(-1).detach()}
1494
+
1495
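+ # Editor's note (illustrative numeric check): with p = dist.probs, log-probabilities log(p) and
+ # per-action minimum Q-values q_min, the loss above is (p * (alpha * log(p) - q_min)).sum(-1);
+ # for p = [0.5, 0.5], q_min = [1.0, 0.0] and alpha = 0.1 this evaluates to about -0.569.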
+ def _alpha_loss(self, log_prob: Tensor) -> Tensor:
1496
+ if self.target_entropy is not None:
1497
+ # we can compute this loss even if log_alpha is not a parameter
1498
+ alpha_loss = -self.log_alpha * (log_prob + self.target_entropy)
1499
+ else:
1500
+ # placeholder
1501
+ alpha_loss = torch.zeros_like(log_prob)
1502
+ return alpha_loss
1503
+
1504
+ def alpha_loss(self, log_prob: Tensor) -> Tensor:
1505
+ """Compute the alpha loss for discrete SAC.
1506
+
1507
+ This method computes the alpha loss which adapts the entropy coefficient
1508
+ to maintain the target entropy level for discrete actions.
1509
+
1510
+ Args:
1511
+ log_prob (Tensor): The log probability of actions from the actor network.
1512
+
1513
+ Returns:
1514
+ The alpha loss tensor
1515
+ """
1516
+ return self._alpha_loss(log_prob)
1517
+
1518
+ @property
1519
+ def _alpha(self):
1520
+ if self.min_log_alpha is not None or self.max_log_alpha is not None:
1521
+ self.log_alpha.data = self.log_alpha.data.clamp(
1522
+ self.min_log_alpha, self.max_log_alpha
1523
+ )
1524
+ with torch.no_grad():
1525
+ alpha = self.log_alpha.exp()
1526
+ return alpha
1527
+
1528
+ @property
1529
+ @_cache_values
1530
+ def _cached_detached_qvalue_params(self):
1531
+ return self.qvalue_network_params.detach()
1532
+
1533
+ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
1534
+ if value_type is None:
1535
+ value_type = self.default_value_estimator
1536
+
1537
+ # Handle ValueEstimatorBase instance or class
1538
+ if isinstance(value_type, ValueEstimatorBase) or (
1539
+ isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
1540
+ ):
1541
+ return LossModule.make_value_estimator(self, value_type, **hyperparams)
1542
+
1543
+ self.value_type = value_type
1544
+ hp = dict(default_value_kwargs(value_type))
1545
+ hp.update(hyperparams)
1546
+ if hasattr(self, "gamma"):
1547
+ hp["gamma"] = self.gamma
1548
+ if value_type is ValueEstimators.TD1:
1549
+ self._value_estimator = TD1Estimator(
1550
+ **hp,
1551
+ value_network=None,
1552
+ deactivate_vmap=self.deactivate_vmap,
1553
+ )
1554
+ elif value_type is ValueEstimators.TD0:
1555
+ self._value_estimator = TD0Estimator(
1556
+ **hp,
1557
+ value_network=None,
1558
+ deactivate_vmap=self.deactivate_vmap,
1559
+ )
1560
+ elif value_type is ValueEstimators.GAE:
1561
+ raise NotImplementedError(
1562
+ f"Value type {value_type} it not implemented for loss {type(self)}."
1563
+ )
1564
+ elif value_type is ValueEstimators.TDLambda:
1565
+ self._value_estimator = TDLambdaEstimator(
1566
+ **hp,
1567
+ value_network=None,
1568
+ deactivate_vmap=self.deactivate_vmap,
1569
+ )
1570
+ else:
1571
+ raise NotImplementedError(f"Unknown value type {value_type}")
1572
+
1573
+ tensor_keys = {
1574
+ "value": self.tensor_keys.value,
1575
+ "value_target": "value_target",
1576
+ "reward": self.tensor_keys.reward,
1577
+ "done": self.tensor_keys.done,
1578
+ "terminated": self.tensor_keys.terminated,
1579
+ }
1580
+ self._value_estimator.set_keys(**tensor_keys)
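+ # Editor's sketch (illustrative, assuming a constructed `loss` instance): switching the value
+ # estimator, e.g. to a TD(lambda) target:
+ #
+ #     from torchrl.objectives import ValueEstimators
+ #     loss.make_value_estimator(ValueEstimators.TDLambda, gamma=0.99, lmbda=0.95)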