torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,1346 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import math
8
+ import warnings
9
+ from copy import deepcopy
10
+ from dataclasses import dataclass
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ from tensordict import TensorDict, TensorDictBase, TensorDictParams
16
+ from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
17
+ from tensordict.utils import NestedKey, unravel_key
18
+ from torch import Tensor
19
+
20
+ from torchrl.data.tensor_specs import Composite
21
+ from torchrl.data.utils import _find_action_space
22
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
23
+ from torchrl.modules.tensordict_module.actors import QValueActor
24
+ from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible
25
+ from torchrl.objectives.common import LossModule
26
+ from torchrl.objectives.utils import (
27
+ _cache_values,
28
+ _GAMMA_LMBDA_DEPREC_ERROR,
29
+ _reduce,
30
+ _vmap_func,
31
+ default_value_kwargs,
32
+ distance_loss,
33
+ ValueEstimators,
34
+ )
35
+ from torchrl.objectives.value import (
36
+ TD0Estimator,
37
+ TD1Estimator,
38
+ TDLambdaEstimator,
39
+ ValueEstimatorBase,
40
+ )
41
+
42
+
43
+ class CQLLoss(LossModule):
44
+ """TorchRL implementation of the continuous CQL loss.
45
+
46
+ Presented in "Conservative Q-Learning for Offline Reinforcement Learning" https://arxiv.org/abs/2006.04779
47
+
48
+ Args:
49
+ actor_network (ProbabilisticTensorDictSequential): stochastic actor
50
+ qvalue_network (TensorDictModule or list of TensorDictModule): Q(s, a) parametric model.
51
+ This module typically outputs a ``"state_action_value"`` entry.
52
+ If a single instance of `qvalue_network` is provided, it will be duplicated ``N``
53
+ times (where ``N=2`` for this loss). If a list of modules is passed, their
54
+ parameters will be stacked unless they share the same identity (in which case
55
+ the original parameter will be expanded).
56
+
57
+ .. warning:: When a list of parameters is passed, it will **not** be compared against the policy parameters
58
+ and all the parameters will be considered as untied.
59
+
60
+ Keyword args:
61
+ loss_function (str, optional): loss function to be used with
62
+ the value function loss. Default is `"smooth_l1"`.
63
+ alpha_init (:obj:`float`, optional): initial entropy multiplier.
64
+ Default is 1.0.
65
+ min_alpha (:obj:`float`, optional): min value of alpha.
66
+ Default is None (no minimum value).
67
+ max_alpha (:obj:`float`, optional): max value of alpha.
68
+ Default is None (no maximum value).
69
+ action_spec (TensorSpec, optional): the action tensor spec. If not provided
70
+ and the target entropy is ``"auto"``, it will be retrieved from
71
+ the actor.
72
+ fixed_alpha (bool, optional): if ``True``, alpha will be fixed to its
73
+ initial value. Otherwise, alpha will be optimized to
74
+ match the 'target_entropy' value.
75
+ Default is ``False``.
76
+ target_entropy (:obj:`float` or str, optional): Target entropy for the
77
+ stochastic policy. Default is "auto", where target entropy is
78
+ computed as :obj:`-prod(n_actions)`.
79
+ delay_actor (bool, optional): Whether to separate the target actor
80
+ networks from the actor networks used for data collection.
81
+ Default is ``False``.
82
+ delay_qvalue (bool, optional): Whether to separate the target Q value
83
+ networks from the Q value networks used for data collection.
84
+ Default is ``True``.
85
+ gamma (:obj:`float`, optional): Discount factor. Default is ``None``.
86
+ temperature (:obj:`float`, optional): CQL temperature. Default is 1.0.
87
+ min_q_weight (:obj:`float`, optional): Minimum Q weight. Default is 1.0.
88
+ max_q_backup (bool, optional): Whether to use the max-min Q backup.
89
+ Default is ``False``.
90
+ deterministic_backup (bool, optional): Whether to use the deterministic backup. Default is ``True``.
91
+ num_random (int, optional): Number of random actions to sample for the CQL loss.
92
+ Default is 10.
93
+ with_lagrange (bool, optional): Whether to use the Lagrange multiplier.
94
+ Default is ``False``.
95
+ lagrange_thresh (:obj:`float`, optional): Lagrange threshold. Default is 0.0.
96
+ reduction (str, optional): Specifies the reduction to apply to the output:
97
+ ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
98
+ ``"mean"``: the sum of the output will be divided by the number of
99
+ elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
100
+ deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
101
+ Defaults to ``False``.
102
+
103
+ Examples:
104
+ >>> import torch
105
+ >>> from torch import nn
106
+ >>> from torchrl.data import Bounded
107
+ >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
108
+ >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
109
+ >>> from torchrl.modules.tensordict_module.common import SafeModule
110
+ >>> from torchrl.objectives.cql import CQLLoss
111
+ >>> from tensordict import TensorDict
112
+ >>> n_act, n_obs = 4, 3
113
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
114
+ >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
115
+ >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
116
+ >>> actor = ProbabilisticActor(
117
+ ... module=module,
118
+ ... in_keys=["loc", "scale"],
119
+ ... spec=spec,
120
+ ... distribution_class=TanhNormal)
121
+ >>> class ValueClass(nn.Module):
122
+ ... def __init__(self):
123
+ ... super().__init__()
124
+ ... self.linear = nn.Linear(n_obs + n_act, 1)
125
+ ... def forward(self, obs, act):
126
+ ... return self.linear(torch.cat([obs, act], -1))
127
+ >>> module = ValueClass()
128
+ >>> qvalue = ValueOperator(
129
+ ... module=module,
130
+ ... in_keys=['observation', 'action'])
131
+ >>> loss = CQLLoss(actor, qvalue)
132
+ >>> batch = [2, ]
133
+ >>> action = spec.rand(batch)
134
+ >>> data = TensorDict({
135
+ ... "observation": torch.randn(*batch, n_obs),
136
+ ... "action": action,
137
+ ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
138
+ ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
139
+ ... ("next", "reward"): torch.randn(*batch, 1),
140
+ ... ("next", "observation"): torch.randn(*batch, n_obs),
141
+ ... }, batch)
142
+ >>> loss(data)
143
+ TensorDict(
144
+ fields={
145
+ alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
146
+ entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
147
+ loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
148
+ loss_actor_bc: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
149
+ loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
150
+ loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
151
+ loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
152
+ batch_size=torch.Size([]),
153
+ device=None,
154
+ is_shared=False)
155
+
156
+ This class is compatible with non-tensordict based modules too and can be
157
+ used without recurring to any tensordict-related primitive. In this case,
158
+ the expected keyword arguments are:
159
+ ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network.
160
+ The return value is a tuple of tensors in the following order:
161
+ ``["loss_actor", "loss_qvalue", "loss_alpha", "loss_alpha_prime", "alpha", "entropy"]``.
162
+
163
+ Examples:
164
+ >>> import torch
165
+ >>> from torch import nn
166
+ >>> from torchrl.data import Bounded
167
+ >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
168
+ >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
169
+ >>> from torchrl.modules.tensordict_module.common import SafeModule
170
+ >>> from torchrl.objectives.cql import CQLLoss
171
+ >>> _ = torch.manual_seed(42)
172
+ >>> n_act, n_obs = 4, 3
173
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
174
+ >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
175
+ >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
176
+ >>> actor = ProbabilisticActor(
177
+ ... module=module,
178
+ ... in_keys=["loc", "scale"],
179
+ ... spec=spec,
180
+ ... distribution_class=TanhNormal)
181
+ >>> class ValueClass(nn.Module):
182
+ ... def __init__(self):
183
+ ... super().__init__()
184
+ ... self.linear = nn.Linear(n_obs + n_act, 1)
185
+ ... def forward(self, obs, act):
186
+ ... return self.linear(torch.cat([obs, act], -1))
187
+ >>> module = ValueClass()
188
+ >>> qvalue = ValueOperator(
189
+ ... module=module,
190
+ ... in_keys=['observation', 'action'])
191
+ >>> loss = CQLLoss(actor, qvalue)
192
+ >>> batch = [2, ]
193
+ >>> action = spec.rand(batch)
194
+ >>> loss_actor, loss_actor_bc, loss_qvalue, loss_cql, *_ = loss(
195
+ ... observation=torch.randn(*batch, n_obs),
196
+ ... action=action,
197
+ ... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
198
+ ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
199
+ ... next_observation=torch.zeros(*batch, n_obs),
200
+ ... next_reward=torch.randn(*batch, 1))
201
+ >>> loss_actor.backward()
202
+
203
+ The output keys can also be filtered using the :meth:`CQLLoss.select_out_keys`
204
+ method.
205
+
206
+ Examples:
207
+ >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
208
+ >>> loss_actor, loss_qvalue = loss(
209
+ ... observation=torch.randn(*batch, n_obs),
210
+ ... action=action,
211
+ ... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
212
+ ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
213
+ ... next_observation=torch.zeros(*batch, n_obs),
214
+ ... next_reward=torch.randn(*batch, 1))
215
+ >>> loss_actor.backward()
216
+ """
217
+
218
@dataclass
class _AcceptedKeys:
    """Maintains default values for all configurable tensordict keys.

    This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
    default values.

    Attributes:
        action (NestedKey): The input tensordict key where the action is expected.
            Defaults to ``"action"``.
        value (NestedKey): The input tensordict key where the state value is expected.
            Will be used for the underlying value estimator. Defaults to ``"state_value"``.
        state_action_value (NestedKey): The input tensordict key where the
            state action value is expected. Defaults to ``"state_action_value"``.
        log_prob (NestedKey): The input tensordict key where the log probability is expected.
            Defaults to ``"_log_prob"``.
        pred_q1 (NestedKey): The input tensordict key where the predicted Q1 values are expected.
            Defaults to ``"pred_q1"``.
        pred_q2 (NestedKey): The input tensordict key where the predicted Q2 values are expected.
            Defaults to ``"pred_q2"``.
        priority (NestedKey): The input tensordict key where the target priority is written to.
            Defaults to ``"td_error"``.
        cql_q1_loss (NestedKey): The input tensordict key where the CQL Q1 loss is expected.
            Defaults to ``"cql_q1_loss"``.
        cql_q2_loss (NestedKey): The input tensordict key where the CQL Q2 loss is expected.
            Defaults to ``"cql_q2_loss"``.
        reward (NestedKey): The input tensordict key where the reward is expected.
            Defaults to ``"reward"``.
        done (NestedKey): The input tensordict key where the done flag is expected.
            Defaults to ``"done"``.
        terminated (NestedKey): The input tensordict key where the terminated flag is expected.
            Defaults to ``"terminated"``.
    """

    # Each key is declared exactly once; dataclass field order follows declaration order.
    action: NestedKey = "action"
    value: NestedKey = "state_value"
    state_action_value: NestedKey = "state_action_value"
    log_prob: NestedKey = "_log_prob"
    pred_q1: NestedKey = "pred_q1"
    pred_q2: NestedKey = "pred_q2"
    priority: NestedKey = "td_error"
    cql_q1_loss: NestedKey = "cql_q1_loss"
    cql_q2_loss: NestedKey = "cql_q2_loss"
    reward: NestedKey = "reward"
    done: NestedKey = "done"
    terminated: NestedKey = "terminated"
+ terminated: NestedKey = "terminated"
265
+
266
+ tensor_keys: _AcceptedKeys
267
+ default_keys = _AcceptedKeys
268
+ default_value_estimator = ValueEstimators.TD0
269
+
270
+ actor_network: TensorDictModule
271
+ qvalue_network: TensorDictModule
272
+ actor_network_params: TensorDictParams
273
+ qvalue_network_params: TensorDictParams
274
+ target_actor_network_params: TensorDictParams
275
+ target_qvalue_network_params: TensorDictParams
276
+
277
+ def __init__(
278
+ self,
279
+ actor_network: ProbabilisticTensorDictSequential,
280
+ qvalue_network: TensorDictModule | list[TensorDictModule],
281
+ *,
282
+ loss_function: str = "smooth_l1",
283
+ alpha_init: float = 1.0,
284
+ min_alpha: float | None = None,
285
+ max_alpha: float | None = None,
286
+ action_spec=None,
287
+ fixed_alpha: bool = False,
288
+ target_entropy: str | float = "auto",
289
+ delay_actor: bool = False,
290
+ delay_qvalue: bool = True,
291
+ gamma: float | None = None,
292
+ temperature: float = 1.0,
293
+ min_q_weight: float = 1.0,
294
+ max_q_backup: bool = False,
295
+ deterministic_backup: bool = True,
296
+ num_random: int = 10,
297
+ with_lagrange: bool = False,
298
+ lagrange_thresh: float = 0.0,
299
+ reduction: str | None = None,
300
+ deactivate_vmap: bool = False,
301
+ ) -> None:
302
+ self._out_keys = None
303
+ if reduction is None:
304
+ reduction = "mean"
305
+ super().__init__()
306
+
307
+ # Actor
308
+ self.delay_actor = delay_actor
309
+ self.convert_to_functional(
310
+ actor_network,
311
+ "actor_network",
312
+ create_target_params=self.delay_actor,
313
+ )
314
+ self.deactivate_vmap = deactivate_vmap
315
+
316
+ # Q value
317
+ self.delay_qvalue = delay_qvalue
318
+ self.num_qvalue_nets = 2
319
+
320
+ self.convert_to_functional(
321
+ qvalue_network,
322
+ "qvalue_network",
323
+ self.num_qvalue_nets,
324
+ create_target_params=self.delay_qvalue,
325
+ compare_against=list(actor_network.parameters()),
326
+ )
327
+
328
+ self.loss_function = loss_function
329
+ try:
330
+ device = next(self.parameters()).device
331
+ except AttributeError:
332
+ device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
333
+ self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
334
+ if bool(min_alpha) ^ bool(max_alpha):
335
+ min_alpha = min_alpha if min_alpha else 0.0
336
+ if max_alpha == 0:
337
+ raise ValueError("max_alpha must be either None or greater than 0.")
338
+ max_alpha = max_alpha if max_alpha else 1e9
339
+ if min_alpha:
340
+ self.register_buffer(
341
+ "min_log_alpha", torch.tensor(min_alpha, device=device).log()
342
+ )
343
+ else:
344
+ self.min_log_alpha = None
345
+ if max_alpha:
346
+ self.register_buffer(
347
+ "max_log_alpha", torch.tensor(max_alpha, device=device).log()
348
+ )
349
+ else:
350
+ self.max_log_alpha = None
351
+ self.fixed_alpha = fixed_alpha
352
+ if fixed_alpha:
353
+ self.register_buffer(
354
+ "log_alpha", torch.tensor(math.log(alpha_init), device=device)
355
+ )
356
+ else:
357
+ self.register_parameter(
358
+ "log_alpha",
359
+ torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)),
360
+ )
361
+
362
+ self._target_entropy = target_entropy
363
+ self._action_spec = action_spec
364
+ self.target_entropy_buffer = None
365
+
366
+ if gamma is not None:
367
+ raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
368
+
369
+ self.temperature = temperature
370
+ self.min_q_weight = min_q_weight
371
+ self.max_q_backup = max_q_backup
372
+ self.deterministic_backup = deterministic_backup
373
+ self.num_random = num_random
374
+ self.with_lagrange = with_lagrange
375
+
376
+ if self.with_lagrange:
377
+ self.target_action_gap = lagrange_thresh
378
+ self.register_parameter(
379
+ "log_alpha_prime",
380
+ torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
381
+ )
382
+ self._make_vmap()
383
+ self.reduction = reduction
384
+ _ = self.target_entropy
385
+
386
+ def _make_vmap(self):
387
+ self._vmap_qvalue_networkN0 = _vmap_func(
388
+ self.qvalue_network,
389
+ (None, 0),
390
+ randomness=self.vmap_randomness,
391
+ pseudo_vmap=self.deactivate_vmap,
392
+ )
393
+ self._vmap_qvalue_network00 = _vmap_func(
394
+ self.qvalue_network,
395
+ randomness=self.vmap_randomness,
396
+ pseudo_vmap=self.deactivate_vmap,
397
+ )
398
+
399
+ @property
400
+ def target_entropy(self):
401
+ target_entropy = self.target_entropy_buffer
402
+ if target_entropy is None:
403
+ delattr(self, "target_entropy_buffer")
404
+ target_entropy = self._target_entropy
405
+ action_spec = self._action_spec
406
+ actor_network = self.actor_network
407
+ device = next(self.parameters()).device
408
+ if target_entropy == "auto":
409
+ action_spec = (
410
+ action_spec
411
+ if action_spec is not None
412
+ else getattr(actor_network, "spec", None)
413
+ )
414
+ if action_spec is None:
415
+ raise RuntimeError(
416
+ "Cannot infer the dimensionality of the action. Consider providing "
417
+ "the target entropy explicitly or provide the spec of the "
418
+ "action tensor in the actor network."
419
+ )
420
+ if not isinstance(action_spec, Composite):
421
+ action_spec = Composite({self.tensor_keys.action: action_spec})
422
+ if (
423
+ isinstance(self.tensor_keys.action, tuple)
424
+ and len(self.tensor_keys.action) > 1
425
+ ):
426
+ action_container_shape = action_spec[
427
+ self.tensor_keys.action[:-1]
428
+ ].shape
429
+ else:
430
+ action_container_shape = action_spec.shape
431
+ target_entropy = -float(
432
+ action_spec[self.tensor_keys.action]
433
+ .shape[len(action_container_shape) :]
434
+ .numel()
435
+ )
436
+ self.register_buffer(
437
+ "target_entropy_buffer", torch.tensor(target_entropy, device=device)
438
+ )
439
+ return self.target_entropy_buffer
440
+ return target_entropy
441
+
442
+ def _forward_value_estimator_keys(self, **kwargs) -> None:
443
+ if self._value_estimator is not None:
444
+ self._value_estimator.set_keys(
445
+ value=self._tensor_keys.value,
446
+ reward=self.tensor_keys.reward,
447
+ done=self.tensor_keys.done,
448
+ terminated=self.tensor_keys.terminated,
449
+ )
450
+
451
+ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
452
+ if value_type is None:
453
+ value_type = self.default_value_estimator
454
+
455
+ # Handle ValueEstimatorBase instance or class
456
+ if isinstance(value_type, ValueEstimatorBase) or (
457
+ isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
458
+ ):
459
+ return LossModule.make_value_estimator(self, value_type, **hyperparams)
460
+
461
+ self.value_type = value_type
462
+
463
+ # we will take care of computing the next value inside this module
464
+ value_net = None
465
+
466
+ hp = dict(default_value_kwargs(value_type))
467
+ hp.update(hyperparams)
468
+ if value_type is ValueEstimators.TD1:
469
+ self._value_estimator = TD1Estimator(
470
+ **hp,
471
+ value_network=value_net,
472
+ )
473
+ elif value_type is ValueEstimators.TD0:
474
+ self._value_estimator = TD0Estimator(
475
+ **hp,
476
+ value_network=value_net,
477
+ )
478
+ elif value_type is ValueEstimators.GAE:
479
+ raise NotImplementedError(
480
+ f"Value type {value_type} it not implemented for loss {type(self)}."
481
+ )
482
+ elif value_type is ValueEstimators.TDLambda:
483
+ self._value_estimator = TDLambdaEstimator(
484
+ **hp,
485
+ value_network=value_net,
486
+ )
487
+ else:
488
+ raise NotImplementedError(f"Unknown value type {value_type}")
489
+
490
+ tensor_keys = {
491
+ "value_target": "value_target",
492
+ "value": self.tensor_keys.value,
493
+ "reward": self.tensor_keys.reward,
494
+ "done": self.tensor_keys.done,
495
+ "terminated": self.tensor_keys.terminated,
496
+ }
497
+ self._value_estimator.set_keys(**tensor_keys)
498
+
499
+ @property
500
+ def in_keys(self):
501
+ keys = [
502
+ self.tensor_keys.action,
503
+ ("next", self.tensor_keys.reward),
504
+ ("next", self.tensor_keys.done),
505
+ ("next", self.tensor_keys.terminated),
506
+ *self.actor_network.in_keys,
507
+ *[("next", key) for key in self.actor_network.in_keys],
508
+ *self.qvalue_network.in_keys,
509
+ ]
510
+
511
+ return list(set(keys))
512
+
513
+ @property
514
+ def out_keys(self):
515
+ if self._out_keys is None:
516
+ keys = [
517
+ "loss_actor",
518
+ "loss_actor_bc",
519
+ "loss_qvalue",
520
+ "loss_cql",
521
+ "loss_alpha",
522
+ "alpha",
523
+ "entropy",
524
+ ]
525
+ if self.with_lagrange:
526
+ keys.append("loss_alpha_prime")
527
+ self._out_keys = keys
528
+ return self._out_keys
529
+
530
+ @out_keys.setter
531
+ def out_keys(self, values):
532
+ self._out_keys = values
533
+
534
+ @dispatch
535
+ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
536
+ q_loss, metadata = self.q_loss(tensordict)
537
+ cql_loss, cql_metadata = self.cql_loss(tensordict)
538
+ if self.with_lagrange:
539
+ alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(tensordict)
540
+ metadata.update(alpha_prime_metadata)
541
+ loss_actor_bc, bc_metadata = self.actor_bc_loss(tensordict)
542
+ loss_actor, actor_metadata = self.actor_loss(tensordict)
543
+ loss_alpha, alpha_metadata = self.alpha_loss(actor_metadata)
544
+ metadata.update(bc_metadata)
545
+ metadata.update(cql_metadata)
546
+ metadata.update(actor_metadata)
547
+ metadata.update(alpha_metadata)
548
+ tensordict.set(
549
+ self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
550
+ )
551
+ out = {
552
+ "loss_actor": loss_actor,
553
+ "loss_actor_bc": loss_actor_bc,
554
+ "loss_qvalue": q_loss,
555
+ "loss_cql": cql_loss,
556
+ "loss_alpha": loss_alpha,
557
+ "alpha": self._alpha,
558
+ "entropy": -actor_metadata.get(self.tensor_keys.log_prob).mean().detach(),
559
+ }
560
+ if self.with_lagrange:
561
+ out["loss_alpha_prime"] = alpha_prime_loss.mean()
562
+ td_loss = TensorDict(out)
563
+ self._clear_weakrefs(
564
+ tensordict,
565
+ td_loss,
566
+ "actor_network_params",
567
+ "qvalue_network_params",
568
+ "target_actor_network_params",
569
+ "target_qvalue_network_params",
570
+ )
571
+ return td_loss
572
+
573
+ @property
574
+ @_cache_values
575
+ def _cached_detach_qvalue_params(self):
576
+ return self.qvalue_network_params.detach()
577
+
578
+ def actor_bc_loss(self, tensordict: TensorDictBase) -> Tensor:
579
+ with set_exploration_type(
580
+ ExplorationType.RANDOM
581
+ ), self.actor_network_params.to_module(self.actor_network):
582
+ dist = self.actor_network.get_dist(
583
+ tensordict,
584
+ )
585
+ a_reparm = dist.rsample()
586
+ log_prob = dist.log_prob(a_reparm)
587
+ bc_log_prob = dist.log_prob(tensordict.get(self.tensor_keys.action))
588
+
589
+ bc_actor_loss = self._alpha * log_prob - bc_log_prob
590
+ bc_actor_loss = _reduce(bc_actor_loss, reduction=self.reduction)
591
+ metadata = {"bc_log_prob": bc_log_prob.mean().detach()}
592
+ self._clear_weakrefs(
593
+ tensordict,
594
+ "actor_network_params",
595
+ "qvalue_network_params",
596
+ "target_actor_network_params",
597
+ "target_qvalue_network_params",
598
+ )
599
+ return bc_actor_loss, metadata
600
+
601
+ def actor_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
602
+ with set_exploration_type(
603
+ ExplorationType.RANDOM
604
+ ), self.actor_network_params.to_module(self.actor_network):
605
+ dist = self.actor_network.get_dist(
606
+ tensordict,
607
+ )
608
+ a_reparm = dist.rsample()
609
+ log_prob = dist.log_prob(a_reparm)
610
+
611
+ td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
612
+ if td_q is tensordict:
613
+ raise RuntimeError
614
+ td_q.set(self.tensor_keys.action, a_reparm)
615
+ td_q = self._vmap_qvalue_networkN0(
616
+ td_q,
617
+ self._cached_detach_qvalue_params,
618
+ )
619
+ min_q_logprob = (
620
+ td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
621
+ )
622
+
623
+ if log_prob.shape != min_q_logprob.shape:
624
+ raise RuntimeError(
625
+ f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}"
626
+ )
627
+
628
+ metadata = {}
629
+ metadata[self.tensor_keys.log_prob] = log_prob.detach()
630
+ actor_loss = self._alpha * log_prob - min_q_logprob
631
+ actor_loss = _reduce(actor_loss, reduction=self.reduction)
632
+ self._clear_weakrefs(
633
+ tensordict,
634
+ "actor_network_params",
635
+ "qvalue_network_params",
636
+ "target_actor_network_params",
637
+ "target_qvalue_network_params",
638
+ )
639
+ return actor_loss, metadata
640
+
641
+ def _get_policy_actions(self, data, actor_params, num_actions=10):
642
+ batch_size = data.batch_size
643
+ batch_size = list(batch_size[:-1]) + [batch_size[-1] * num_actions]
644
+ in_keys = [unravel_key(key) for key in self.actor_network.in_keys]
645
+
646
+ def filter_and_repeat(name, x):
647
+ if name in in_keys:
648
+ return x.repeat_interleave(num_actions, dim=data.ndim - 1)
649
+
650
+ tensordict = data.named_apply(
651
+ filter_and_repeat, batch_size=batch_size, filter_empty=True
652
+ )
653
+ with set_exploration_type(ExplorationType.RANDOM), actor_params.data.to_module(
654
+ self.actor_network
655
+ ):
656
+ dist = self.actor_network.get_dist(tensordict)
657
+ action = dist.rsample()
658
+ tensordict.set(self.tensor_keys.action, action)
659
+ sample_log_prob = dist.log_prob(action)
660
+
661
+ return (
662
+ tensordict.select(
663
+ *self.actor_network.in_keys, self.tensor_keys.action, strict=False
664
+ ),
665
+ sample_log_prob,
666
+ )
667
+
668
+ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):
669
+ tensordict = tensordict.clone(False)
670
+ # get actions and log-probs
671
+ # TODO: wait for compile to handle this properly
672
+ actor_data = actor_params.data.to_module(self.actor_network)
673
+ with set_exploration_type(ExplorationType.RANDOM):
674
+ next_tensordict = tensordict.get("next").clone(False)
675
+ next_dist = self.actor_network.get_dist(next_tensordict)
676
+ next_action = next_dist.rsample()
677
+ next_tensordict.set(self.tensor_keys.action, next_action)
678
+ next_sample_log_prob = next_dist.log_prob(next_action)
679
+ actor_data.to_module(self.actor_network, return_swap=False)
680
+
681
+ # get q-values
682
+ if not self.max_q_backup:
683
+ next_tensordict_expand = self._vmap_qvalue_networkN0(
684
+ next_tensordict, qval_params.data
685
+ )
686
+ next_state_value = next_tensordict_expand.get(
687
+ self.tensor_keys.state_action_value
688
+ ).min(0)[0]
689
+ if (
690
+ next_state_value.shape[-len(next_sample_log_prob.shape) :]
691
+ != next_sample_log_prob.shape
692
+ ):
693
+ next_sample_log_prob = next_sample_log_prob.unsqueeze(-1)
694
+ if not self.deterministic_backup:
695
+ next_state_value = next_state_value - _alpha * next_sample_log_prob
696
+
697
+ if self.max_q_backup:
698
+ next_tensordict, _ = self._get_policy_actions(
699
+ tensordict.get("next").copy(),
700
+ actor_params,
701
+ num_actions=self.num_random,
702
+ )
703
+ next_tensordict_expand = self._vmap_qvalue_networkN0(
704
+ next_tensordict, qval_params.data
705
+ )
706
+
707
+ state_action_value = next_tensordict_expand.get(
708
+ self.tensor_keys.state_action_value
709
+ )
710
+ # take max over actions
711
+ state_action_value = state_action_value.reshape(
712
+ torch.Size(
713
+ [self.num_qvalue_nets, *tensordict.shape, self.num_random, -1]
714
+ )
715
+ ).max(-2)[0]
716
+ # take min over qvalue nets
717
+ next_state_value = state_action_value.min(0)[0]
718
+
719
+ tensordict.set(
720
+ ("next", self.value_estimator.tensor_keys.value), next_state_value
721
+ )
722
+ target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
723
+ return target_value
724
+
725
+ def q_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
726
+ # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first.
727
+ target_value = self._get_value_v(
728
+ tensordict.copy(),
729
+ self._alpha,
730
+ self.actor_network_params,
731
+ self.target_qvalue_network_params,
732
+ )
733
+
734
+ tensordict_pred_q = tensordict.select(
735
+ *self.qvalue_network.in_keys, strict=False
736
+ )
737
+ q_pred = self._vmap_qvalue_networkN0(
738
+ tensordict_pred_q, self.qvalue_network_params
739
+ ).get(self.tensor_keys.state_action_value)
740
+
741
+ # write pred values in tensordict for cql loss
742
+ tensordict.set(self.tensor_keys.pred_q1, q_pred[0])
743
+ tensordict.set(self.tensor_keys.pred_q2, q_pred[1])
744
+
745
+ q_pred = q_pred.squeeze(-1)
746
+ loss_qval = distance_loss(
747
+ q_pred,
748
+ target_value.expand_as(q_pred),
749
+ loss_function=self.loss_function,
750
+ ).sum(0)
751
+ loss_qval = _reduce(loss_qval, reduction=self.reduction)
752
+ td_error = (q_pred - target_value).pow(2)
753
+ metadata = {"td_error": td_error.detach()}
754
+ self._clear_weakrefs(
755
+ tensordict,
756
+ "actor_network_params",
757
+ "qvalue_network_params",
758
+ "target_actor_network_params",
759
+ "target_qvalue_network_params",
760
+ )
761
+ return loss_qval, metadata
762
+
763
+ def cql_loss(self, tensordict: TensorDictBase) -> tuple[Tensor, dict]:
764
+ pred_q1 = tensordict.get(self.tensor_keys.pred_q1)
765
+ pred_q2 = tensordict.get(self.tensor_keys.pred_q2)
766
+
767
+ if pred_q1 is None:
768
+ raise KeyError(
769
+ f"Couldn't find the pred_q1 with key {self.tensor_keys.pred_q1} in the input tensordict. "
770
+ "This could be caused by calling cql_loss method before q_loss method."
771
+ )
772
+ if pred_q2 is None:
773
+ raise KeyError(
774
+ f"Couldn't find the pred_q2 with key {self.tensor_keys.pred_q2} in the input tensordict. "
775
+ "This could be caused by calling cql_loss method before q_loss method."
776
+ )
777
+
778
+ random_actions_tensor = pred_q1.new_empty(
779
+ (
780
+ *tensordict.shape[:-1],
781
+ tensordict.shape[-1] * self.num_random,
782
+ tensordict[self.tensor_keys.action].shape[-1],
783
+ )
784
+ ).uniform_(-1, 1)
785
+ curr_actions_td, curr_log_pis = self._get_policy_actions(
786
+ tensordict.copy(),
787
+ self.actor_network_params,
788
+ num_actions=self.num_random,
789
+ )
790
+ new_curr_actions_td, new_log_pis = self._get_policy_actions(
791
+ tensordict.get("next").copy(),
792
+ self.actor_network_params,
793
+ num_actions=self.num_random,
794
+ )
795
+
796
+ # process all in one forward pass
797
+ # stack qvalue params
798
+ qvalue_params = torch.cat(
799
+ [
800
+ self.qvalue_network_params,
801
+ self.qvalue_network_params,
802
+ self.qvalue_network_params,
803
+ ],
804
+ 0,
805
+ )
806
+ # select and stack input params
807
+ # q value random action
808
+ tensordict_q_random = tensordict.select(
809
+ *self.actor_network.in_keys, strict=False
810
+ )
811
+
812
+ batch_size = tensordict_q_random.batch_size
813
+ batch_size = list(batch_size[:-1]) + [batch_size[-1] * self.num_random]
814
+ in_keys = [unravel_key(key) for key in self.actor_network.in_keys]
815
+
816
+ def filter_and_repeat(name, x):
817
+ if name in in_keys:
818
+ return x.repeat_interleave(
819
+ self.num_random, dim=tensordict_q_random.ndim - 1
820
+ )
821
+
822
+ tensordict_q_random = tensordict_q_random.named_apply(
823
+ filter_and_repeat,
824
+ batch_size=batch_size,
825
+ filter_empty=True,
826
+ )
827
+ tensordict_q_random.set(self.tensor_keys.action, random_actions_tensor)
828
+ cql_tensordict = torch.cat(
829
+ [
830
+ tensordict_q_random.expand(
831
+ self.num_qvalue_nets, *curr_actions_td.batch_size
832
+ ),
833
+ curr_actions_td.expand(
834
+ self.num_qvalue_nets, *curr_actions_td.batch_size
835
+ ),
836
+ new_curr_actions_td.expand(
837
+ self.num_qvalue_nets, *curr_actions_td.batch_size
838
+ ),
839
+ ],
840
+ 0,
841
+ )
842
+ cql_tensordict = cql_tensordict.contiguous()
843
+
844
+ cql_tensordict_expand = self._vmap_qvalue_network00(
845
+ cql_tensordict, qvalue_params
846
+ )
847
+ # get q values
848
+ state_action_value = cql_tensordict_expand.get(
849
+ self.tensor_keys.state_action_value
850
+ )
851
+ # split q values
852
+ (q_random, q_curr, q_new,) = state_action_value.split(
853
+ [
854
+ self.num_qvalue_nets,
855
+ self.num_qvalue_nets,
856
+ self.num_qvalue_nets,
857
+ ],
858
+ dim=0,
859
+ )
860
+
861
+ # importance sammpled version
862
+ random_density = np.log(
863
+ 0.5 ** curr_actions_td[self.tensor_keys.action].shape[-1]
864
+ )
865
+ cat_q1 = torch.cat(
866
+ [
867
+ q_random[0] - random_density,
868
+ q_new[0] - new_log_pis.detach().unsqueeze(-1),
869
+ q_curr[0] - curr_log_pis.detach().unsqueeze(-1),
870
+ ],
871
+ -1,
872
+ )
873
+ cat_q2 = torch.cat(
874
+ [
875
+ q_random[1] - random_density,
876
+ q_new[1] - new_log_pis.detach().unsqueeze(-1),
877
+ q_curr[1] - curr_log_pis.detach().unsqueeze(-1),
878
+ ],
879
+ -1,
880
+ )
881
+
882
+ min_qf1_loss = (
883
+ torch.logsumexp(cat_q1 / self.temperature, dim=-1)
884
+ * self.min_q_weight
885
+ * self.temperature
886
+ )
887
+ min_qf2_loss = (
888
+ torch.logsumexp(cat_q2 / self.temperature, dim=-1)
889
+ * self.min_q_weight
890
+ * self.temperature
891
+ )
892
+
893
+ # Subtract the log likelihood of data
894
+ cql_q1_loss = min_qf1_loss.flatten() - pred_q1 * self.min_q_weight
895
+ cql_q2_loss = min_qf2_loss.flatten() - pred_q2 * self.min_q_weight
896
+
897
+ # write cql losses in tensordict for alpha prime loss
898
+ tensordict.set(self.tensor_keys.cql_q1_loss, cql_q1_loss)
899
+ tensordict.set(self.tensor_keys.cql_q2_loss, cql_q2_loss)
900
+
901
+ cql_q_loss = (cql_q1_loss + cql_q2_loss).mean(-1)
902
+ cql_q_loss = _reduce(cql_q_loss, reduction=self.reduction)
903
+
904
+ self._clear_weakrefs(
905
+ tensordict,
906
+ "actor_network_params",
907
+ "qvalue_network_params",
908
+ "target_actor_network_params",
909
+ "target_qvalue_network_params",
910
+ )
911
+ return cql_q_loss, {}
912
+
913
+ def alpha_prime_loss(self, tensordict: TensorDictBase) -> Tensor:
914
+ cql_q1_loss = tensordict.get(self.tensor_keys.cql_q1_loss)
915
+ cql_q2_loss = tensordict.get(self.tensor_keys.cql_q2_loss)
916
+
917
+ if cql_q1_loss is None:
918
+ raise KeyError(
919
+ f"Couldn't find the cql_q1_loss with key {self.tensor_keys.cql_q1_loss} in the input tensordict. "
920
+ "This could be caused by calling alpha_prime_loss method before cql_loss method."
921
+ )
922
+ if cql_q2_loss is None:
923
+ raise KeyError(
924
+ f"Couldn't find the cql_q2_loss with key {self.tensor_keys.cql_q2_loss} in the input tensordict. "
925
+ "This could be caused by calling alpha_prime_loss method before cql_loss method."
926
+ )
927
+
928
+ alpha_prime = torch.clamp_max(self.log_alpha_prime.exp(), max=1000000.0)
929
+ min_qf1_loss = alpha_prime * (cql_q1_loss.mean() - self.target_action_gap)
930
+ min_qf2_loss = alpha_prime * (cql_q2_loss.mean() - self.target_action_gap)
931
+
932
+ alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
933
+ alpha_prime_loss = _reduce(alpha_prime_loss, reduction=self.reduction)
934
+ self._clear_weakrefs(
935
+ tensordict,
936
+ "actor_network_params",
937
+ "qvalue_network_params",
938
+ "target_actor_network_params",
939
+ "target_qvalue_network_params",
940
+ )
941
+ return alpha_prime_loss, {}
942
+
943
+ def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:
944
+ log_pi = tensordict.get(self.tensor_keys.log_prob)
945
+ if self.target_entropy is not None:
946
+ # we can compute this loss even if log_alpha is not a parameter
947
+ alpha_loss = -self.log_alpha * (log_pi.detach() + self.target_entropy)
948
+ else:
949
+ # placeholder
950
+ alpha_loss = torch.zeros_like(log_pi)
951
+ alpha_loss = _reduce(alpha_loss, reduction=self.reduction)
952
+ self._clear_weakrefs(
953
+ tensordict,
954
+ "actor_network_params",
955
+ "qvalue_network_params",
956
+ "target_actor_network_params",
957
+ "target_qvalue_network_params",
958
+ )
959
+ return alpha_loss, {}
960
+
961
+ @property
962
+ def _alpha(self):
963
+ if self.min_log_alpha is not None or self.max_log_alpha is not None:
964
+ self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
965
+ alpha = self.log_alpha.data.exp()
966
+ return alpha
967
+
968
+
969
+ class DiscreteCQLLoss(LossModule):
970
+ """TorchRL implementation of the discrete CQL loss.
971
+
972
+ This class implements the discrete conservative Q-learning (CQL) loss function, as presented in the paper
973
+ "Conservative Q-Learning for Offline Reinforcement Learning" (https://arxiv.org/abs/2006.04779).
974
+
975
+ Args:
976
+ value_network (Union[QValueActor, nn.Module]): The Q-value network used to estimate state-action values.
977
+ Keyword Args:
978
+ loss_function (Optional[str]): The distance function used to calculate the distance between the predicted
979
+ Q-values and the target Q-values. Defaults to ``l2``.
980
+ delay_value (bool): Whether to separate the target Q value
981
+ networks from the Q value networks used for data collection.
982
+ Default is ``True``.
983
+ gamma (:obj:`float`, optional): Discount factor. Default is ``None``.
984
+ action_space: The action space of the environment. If None, it is inferred from the value network.
985
+ Defaults to None.
986
+ reduction (str, optional): Specifies the reduction to apply to the output:
987
+ ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
988
+ ``"mean"``: the sum of the output will be divided by the number of
989
+ elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
990
+
991
+ Examples:
992
+ >>> from torchrl.modules import MLP, QValueActor
993
+ >>> from torchrl.data import OneHot
994
+ >>> from torchrl.objectives import DiscreteCQLLoss
995
+ >>> n_obs, n_act = 4, 3
996
+ >>> value_net = MLP(in_features=n_obs, out_features=n_act)
997
+ >>> spec = OneHot(n_act)
998
+ >>> actor = QValueActor(value_net, in_keys=["observation"], action_space=spec)
999
+ >>> loss = DiscreteCQLLoss(actor, action_space=spec)
1000
+ >>> batch = [10,]
1001
+ >>> data = TensorDict({
1002
+ ... "observation": torch.randn(*batch, n_obs),
1003
+ ... "action": spec.rand(batch),
1004
+ ... ("next", "observation"): torch.randn(*batch, n_obs),
1005
+ ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
1006
+ ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
1007
+ ... ("next", "reward"): torch.randn(*batch, 1)
1008
+ ... }, batch)
1009
+ >>> loss(data)
1010
+ TensorDict(
1011
+ fields={
1012
+ loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
1013
+ loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
1014
+ pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
1015
+ target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
1016
+ td_error: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False)},
1017
+ batch_size=torch.Size([]),
1018
+ device=None,
1019
+ is_shared=False)
1020
+
1021
+ This class is compatible with non-tensordict based modules too and can be
1022
+ used without recurring to any tensordict-related primitive. In this case,
1023
+ the expected keyword arguments are:
1024
+ ``["observation", "next_observation", "action", "next_reward", "next_done", "next_terminated"]``,
1025
+ and a single loss value is returned.
1026
+
1027
+ Examples:
1028
+ >>> from torchrl.objectives import DiscreteCQLLoss
1029
+ >>> from torchrl.data import OneHot
1030
+ >>> from torch import nn
1031
+ >>> import torch
1032
+ >>> n_obs = 3
1033
+ >>> n_action = 4
1034
+ >>> action_spec = OneHot(n_action)
1035
+ >>> value_network = nn.Linear(n_obs, n_action) # a simple value model
1036
+ >>> dcql_loss = DiscreteCQLLoss(value_network, action_space=action_spec)
1037
+ >>> # define data
1038
+ >>> observation = torch.randn(n_obs)
1039
+ >>> next_observation = torch.randn(n_obs)
1040
+ >>> action = action_spec.rand()
1041
+ >>> next_reward = torch.randn(1)
1042
+ >>> next_done = torch.zeros(1, dtype=torch.bool)
1043
+ >>> next_terminated = torch.zeros(1, dtype=torch.bool)
1044
+ >>> loss_val = dcql_loss(
1045
+ ... observation=observation,
1046
+ ... next_observation=next_observation,
1047
+ ... next_reward=next_reward,
1048
+ ... next_done=next_done,
1049
+ ... next_terminated=next_terminated,
1050
+ ... action=action)
1051
+ """
1052
+
1053
+ @dataclass
1054
+ class _AcceptedKeys:
1055
+ """Maintains default values for all configurable tensordict keys.
1056
+
1057
+ This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
1058
+ default values.
1059
+
1060
+ Attributes:
1061
+ value_target (NestedKey): The input tensordict key where the target state value is expected.
1062
+ Will be used for the underlying value estimator Defaults to ``"value_target"``.
1063
+ value (NestedKey): The input tensordict key where the chosen action value is expected.
1064
+ Will be used for the underlying value estimator. Defaults to ``"chosen_action_value"``.
1065
+ action_value (NestedKey): The input tensordict key where the action value is expected.
1066
+ Defaults to ``"action_value"``.
1067
+ action (NestedKey): The input tensordict key where the action is expected.
1068
+ Defaults to ``"action"``.
1069
+ priority (NestedKey): The input tensordict key where the target priority is written to.
1070
+ Defaults to ``"td_error"``.
1071
+ reward (NestedKey): The input tensordict key where the reward is expected.
1072
+ Will be used for the underlying value estimator. Defaults to ``"reward"``.
1073
+ done (NestedKey): The key in the input TensorDict that indicates
1074
+ whether a trajectory is done. Will be used for the underlying value estimator.
1075
+ Defaults to ``"done"``.
1076
+ terminated (NestedKey): The key in the input TensorDict that indicates
1077
+ whether a trajectory is terminated. Will be used for the underlying value estimator.
1078
+ Defaults to ``"terminated"``.
1079
+ pred_val (NestedKey): The key where the predicted value will be written
1080
+ in the input tensordict. This value is subsequently used by cql_loss.
1081
+ Defaults to ``"pred_val"``.
1082
+
1083
+ """
1084
+
1085
+ value_target: NestedKey = "value_target"
1086
+ value: NestedKey = "chosen_action_value"
1087
+ action_value: NestedKey = "action_value"
1088
+ action: NestedKey = "action"
1089
+ priority: NestedKey = "td_error"
1090
+ reward: NestedKey = "reward"
1091
+ done: NestedKey = "done"
1092
+ terminated: NestedKey = "terminated"
1093
+ pred_val: NestedKey = "pred_val"
1094
+
1095
+ tensor_keys: _AcceptedKeys
1096
+ default_keys = _AcceptedKeys
1097
+ default_value_estimator = ValueEstimators.TD0
1098
+ out_keys = [
1099
+ "loss_qvalue",
1100
+ "loss_cql",
1101
+ ]
1102
+
1103
+ value_network: TensorDictModule
1104
+ value_network_params: TensorDictParams
1105
+ target_value_network_params: TensorDictParams
1106
+
1107
def __init__(
    self,
    value_network: QValueActor | nn.Module,
    *,
    loss_function: str | None = "l2",
    delay_value: bool = True,
    gamma: float | None = None,
    action_space=None,
    reduction: str | None = None,
) -> None:
    """Set up the discrete CQL loss around a Q-value network.

    Args:
        value_network: the Q-value network. It is passed through
            ``ensure_tensordict_compatible`` (wrapper type ``QValueActor``)
            and then registered via ``convert_to_functional`` under the
            name ``"value_network"``.

    Keyword Args:
        loss_function: distance used for the TD-regression term (forwarded
            to ``distance_loss`` in ``value_loss``). Defaults to ``"l2"``.
        delay_value: forwarded as ``create_target_params`` to
            ``convert_to_functional``. Defaults to ``True``.
        gamma: deprecated. Any value other than ``None`` raises.
        action_space: action space (or spec) passed to
            ``_find_action_space``. When ``None`` it is inferred from the
            network's ``spec`` attribute, falling back to its
            ``action_space`` attribute. If that still yields ``None``, a
            warning is emitted and ``"one-hot"`` is assumed.
        reduction: reduction applied to the losses. Defaults to ``"mean"``.

    Raises:
        TypeError: if ``gamma`` is not ``None``.
        ValueError: if ``action_space`` is ``None`` and the network exposes
            neither a ``spec`` nor an ``action_space`` attribute.
    """
    self._in_keys = None
    if reduction is None:
        reduction = "mean"
    super().__init__()
    # Deprecated argument: fail before doing any parameter bookkeeping.
    if gamma is not None:
        raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)

    self.delay_value = delay_value
    value_network = ensure_tensordict_compatible(
        module=value_network,
        wrapper_type=QValueActor,
        action_space=action_space,
    )

    self.convert_to_functional(
        value_network,
        "value_network",
        create_target_params=self.delay_value,
    )

    self.value_network_in_keys = value_network.in_keys

    self.loss_function = loss_function
    if action_space is None:
        # Infer from the value net: prefer its ``spec``, then fall back to an
        # explicit ``action_space`` attribute. (The previous version checked
        # for ``action_space`` but read ``spec``, and vice versa.)
        try:
            action_space = value_network.spec
        except AttributeError:
            try:
                action_space = value_network.action_space
            except AttributeError:
                raise ValueError(self.ACTION_SPEC_ERROR) from None
    if action_space is None:
        warnings.warn(
            "action_space was not specified. DiscreteCQLLoss will default to 'one-hot'. "
            "This behavior will be deprecated soon and a space will have to be passed. "
            "Check the DiscreteCQLLoss documentation to see how to pass the action space."
        )
        action_space = "one-hot"
    self.action_space = _find_action_space(action_space)
    self.reduction = reduction
1159
+
1160
def _forward_value_estimator_keys(self, **kwargs) -> None:
    """Push this loss's key names onto its value estimator, then refresh in_keys.

    When no estimator has been built yet (``self._value_estimator is None``),
    only the cached ``in_keys`` are recomputed.
    """
    estimator = self._value_estimator
    if estimator is not None:
        keys = self.tensor_keys
        estimator.set_keys(
            value_target=keys.value_target,
            value=keys.value,
            reward=keys.reward,
            done=keys.done,
            terminated=keys.terminated,
        )
    self._set_in_keys()
1170
+
1171
def _set_in_keys(self):
    """Recompute and cache (in ``self._in_keys``) the keys this loss reads.

    The set covers the action key; the reward/done/terminated keys under
    ``"next"``; and every input key of the value network, both at the root
    and under ``"next"``. Duplicates are removed and the result is sorted by
    the keys' string form, so plain and nested (tuple) keys can be ordered
    together.
    """
    keys = self.tensor_keys
    network_keys = self.value_network.in_keys
    next_entries = (keys.reward, keys.done, keys.terminated, *network_keys)

    required = {keys.action, *network_keys}
    required.update(unravel_key(("next", entry)) for entry in next_entries)
    self._in_keys = sorted(required, key=str)
1181
+
1182
def make_value_estimator(self, value_type: ValueEstimators | None = None, **hyperparams):
    """Build the value estimator that ``value_loss`` uses for Q-value targets.

    Args:
        value_type: a ``ValueEstimators`` member, or a ``ValueEstimatorBase``
            subclass/instance. Subclasses and instances are delegated to
            ``LossModule.make_value_estimator``. Defaults to
            ``self.default_value_estimator``.
        **hyperparams: overrides merged on top of
            ``default_value_kwargs(value_type)``.

    Raises:
        NotImplementedError: for ``ValueEstimators.GAE`` or any value type
            not handled below.
    """
    if value_type is None:
        value_type = self.default_value_estimator

    # Custom estimator classes/instances go through the generic base implementation.
    if isinstance(value_type, ValueEstimatorBase) or (
        isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
    ):
        return LossModule.make_value_estimator(self, value_type, **hyperparams)

    self.value_type = value_type

    # The estimator receives its own copy of the Q-network, loaded with the
    # current online parameters. value_loss later passes the detached
    # target parameters explicitly via ``params=``.
    value_net = deepcopy(self.value_network)
    self.value_network_params.to_module(value_net, return_swap=False)

    hp = dict(default_value_kwargs(value_type))
    hp.update(hyperparams)
    if value_type is ValueEstimators.TD1:
        self._value_estimator = TD1Estimator(**hp, value_network=value_net)
    elif value_type is ValueEstimators.TD0:
        self._value_estimator = TD0Estimator(**hp, value_network=value_net)
    elif value_type is ValueEstimators.GAE:
        raise NotImplementedError(
            f"Value type {value_type} is not implemented for loss {type(self)}."
        )
    elif value_type is ValueEstimators.TDLambda:
        self._value_estimator = TDLambdaEstimator(**hp, value_network=value_net)
    else:
        raise NotImplementedError(f"Unknown value type {value_type}")

    # Use the configured key names, including value_target, so that a
    # user-customised key is honoured here exactly as in
    # _forward_value_estimator_keys.
    self._value_estimator.set_keys(
        value_target=self.tensor_keys.value_target,
        value=self.tensor_keys.value,
        reward=self.tensor_keys.reward,
        done=self.tensor_keys.done,
        terminated=self.tensor_keys.terminated,
    )
1230
+
1231
@property
def in_keys(self):
    """Keys read by this loss; computed lazily via ``_set_in_keys`` on first access."""
    if self._in_keys is not None:
        return self._in_keys
    self._set_in_keys()
    return self._in_keys

@in_keys.setter
def in_keys(self, values):
    """Override the cached keys (setting ``None`` makes the next read recompute them)."""
    self._in_keys = values
1240
+
1241
@dispatch
def value_loss(
    self,
    tensordict: TensorDictBase,
) -> tuple[torch.Tensor, dict]:
    """Compute the TD-regression term of the discrete CQL loss.

    The online Q-network (loaded with ``value_network_params``) is run on a
    non-recursive ``clone(False)`` copy of ``tensordict``. The Q-value of the
    stored action is selected and regressed onto the target returned by
    ``self.value_estimator.value_estimate``, which is called with the
    detached target-network parameters.

    Side effects: the squared TD error is written under
    ``tensor_keys.priority`` and the full predicted Q-values under
    ``tensor_keys.pred_val`` into the input ``tensordict``. ``cql_loss``
    reads the latter.

    Args:
        tensordict: batch holding the action and the value network's inputs.

    Returns:
        ``(loss, metadata)``. ``loss`` is ``0.5 * distance_loss(...)`` reduced
        with ``self.reduction``; ``metadata`` holds detached ``"td_error"``,
        ``"pred_value"`` and ``"target_value"`` summaries.
    """
    keys = self.tensor_keys

    scratch = tensordict.clone(False)
    with self.value_network_params.to_module(self.value_network):
        self.value_network(scratch)

    taken = tensordict.get(keys.action)
    q_all = scratch.get(keys.action_value)

    if self.action_space != "categorical":
        # Encoded (e.g. one-hot) actions: mask the Q-values with the
        # float-cast action and sum over the last dim.
        q_taken = (q_all * taken.to(torch.float)).sum(-1)
    else:
        # Index actions: gather needs a trailing singleton dim when the
        # action shape does not already match the Q-value shape.
        index = taken if taken.shape == q_all.shape else taken.unsqueeze(-1)
        q_taken = torch.gather(q_all, -1, index=index).squeeze(-1)

    q_target = self.value_estimator.value_estimate(
        scratch, params=self._cached_detached_target_value_params
    ).squeeze(-1)

    sq_err = (q_taken - q_target).pow(2).unsqueeze(-1)
    tensordict.set(keys.priority, sq_err, inplace=True)
    tensordict.set(keys.pred_val, q_all, inplace=True)

    loss = _reduce(
        0.5 * distance_loss(q_taken, q_target, self.loss_function),
        reduction=self.reduction,
    )
    return loss, {
        "td_error": sq_err.mean(0).detach(),
        "pred_value": q_all.mean().detach(),
        "target_value": q_target.mean().detach(),
    }
1290
+
1291
@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDict:
    """Compute the (DQN) CQL losses for a tensordict sampled from a replay buffer.

    ``value_loss`` runs first. Besides the TD-regression term, it writes a
    priority entry (``tensor_keys.priority``, ``"td_error"`` by default) that
    prioritized replay buffers can use, and stores the predicted Q-values
    under ``tensor_keys.pred_val``. ``cql_loss`` then reads those stored
    Q-values, so the call order below matters.

    Args:
        tensordict (TensorDictBase): a tensordict with the action and the
            in_keys of the value network (observations, plus "done",
            "terminated" and "reward" in a "next" tensordict).

    Returns:
        a TensorDict with an empty batch size holding ``"loss_qvalue"``,
        ``"loss_cql"`` and the diagnostics returned by ``value_loss``.
    """
    qvalue_loss, diagnostics = self.value_loss(tensordict)
    conservative_loss, _ = self.cql_loss(tensordict)
    return TensorDict(
        source={
            "loss_qvalue": qvalue_loss,
            "loss_cql": conservative_loss,
            **diagnostics,
        },
        batch_size=[],
    )
1319
+
1320
@property
@_cache_values
def _cached_detached_target_value_params(self):
    """Target-network parameters with ``.detach()`` applied.

    Wrapped by ``_cache_values``, which presumably memoises the result (see
    its definition). ``value_loss`` passes this as ``params`` to the value
    estimator.
    """
    target_params = self.target_value_network_params
    return target_params.detach()
1324
+
1325
def cql_loss(self, tensordict) -> tuple[torch.Tensor, dict]:
    """Compute the conservative (CQL) regulariser for discrete actions.

    The penalty is ``logsumexp_a Q(s, a) - Q(s, a_data)``. It uses the
    predicted Q-values that ``value_loss`` stored under
    ``tensor_keys.pred_val``, and is reduced with ``self.reduction``.

    Args:
        tensordict: the same tensordict previously passed to ``value_loss``.

    Returns:
        ``(loss_cql, {})``.

    Raises:
        KeyError: if no predicted Q-values are found under
            ``tensor_keys.pred_val`` (e.g. ``value_loss`` was not called first).
    """
    pred_key = self.tensor_keys.pred_val
    qvalues = tensordict.get(pred_key, default=None)
    if qvalues is None:
        # f-string so the actual key name appears in the message.
        raise KeyError(
            f"Couldn't find the predicted qvalue with key {pred_key} in the input tensordict. "
            "This could be caused by calling cql_loss method before value_loss."
        )

    current_action = tensordict.get(self.tensor_keys.action)

    logsumexp = torch.logsumexp(qvalues, dim=-1, keepdim=True)
    if self.action_space == "categorical":
        if current_action.shape != qvalues.shape:
            # unsqueeze the action if it lacks a trailing singleton dim
            current_action = current_action.unsqueeze(-1)
        q_a = qvalues.gather(-1, current_action)
    else:
        q_a = (qvalues * current_action).sum(dim=-1, keepdim=True)

    loss_cql = (logsumexp - q_a).squeeze(-1)
    loss_cql = _reduce(loss_cql, reduction=self.reduction)
    return loss_cql, {}