torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
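The hunk below corresponds to entry 322 in the list above, torchrl/objectives/value/advantages.py (+1956 lines), which adds the value/advantage estimators (ValueEstimatorBase, TD0Estimator, TD1Estimator, TDLambdaEstimator, ...). For orientation only, and not part of the published diff, here is a minimal usage sketch adapted from the docstring examples contained in that file; the toy value network and the 3-dimensional observation are illustrative assumptions.

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.objectives.value import TD0Estimator

# Toy value network mapping a 3-dim observation to a scalar "state_value".
value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"])
estimator = TD0Estimator(gamma=0.98, value_network=value_net)

# A batch of 1 trajectory with 10 time steps.
obs, next_obs = torch.randn(2, 1, 10, 3)
reward = torch.randn(1, 10, 1)
done = torch.zeros(1, 10, 1, dtype=torch.bool)
terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
data = TensorDict(
    {"obs": obs, "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated}},
    batch_size=[1, 10],
)

estimator(data)  # writes "advantage" and "value_target" into the tensordict
assert "advantage" in data.keys()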
torchrl/objectives/value/advantages.py
@@ -0,0 +1,1956 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import abc
+import functools
+import warnings
+from collections.abc import Callable
+from contextlib import nullcontext
+from dataclasses import asdict, dataclass
+from functools import wraps
+
+import torch
+from tensordict import is_tensor_collection, TensorDictBase
+from tensordict.nn import (
+    composite_lp_aggregate,
+    dispatch,
+    ProbabilisticTensorDictModule,
+    set_composite_lp_aggregate,
+    set_skip_existing,
+    TensorDictModule,
+    TensorDictModuleBase,
+)
+from tensordict.nn.probabilistic import interaction_type
+from tensordict.utils import NestedKey, unravel_key
+from torch import Tensor
+
+from torchrl._utils import logger, rl_warnings
+from torchrl.envs.utils import step_mdp
+from torchrl.objectives.utils import (
+    _maybe_get_or_select,
+    _pseudo_vmap,
+    _vmap_func,
+    hold_out_net,
+)
+from torchrl.objectives.value.functional import (
+    generalized_advantage_estimate,
+    td0_return_estimate,
+    td_lambda_return_estimate,
+    vec_generalized_advantage_estimate,
+    vec_td1_return_estimate,
+    vec_td_lambda_return_estimate,
+    vtrace_advantage_estimate,
+)
+
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
+try:
+    from torch import vmap
+except ImportError as err:
+    try:
+        from functorch import vmap
+    except ImportError:
+        raise ImportError(
+            "vmap couldn't be found. Make sure you have torch>2.0 installed."
+        ) from err
+
+
+def _self_set_grad_enabled(fun):
+    @wraps(fun)
+    def new_fun(self, *args, **kwargs):
+        with torch.set_grad_enabled(self.differentiable):
+            return fun(self, *args, **kwargs)
+
+    return new_fun
+
+
+def _self_set_skip_existing(fun):
+    @functools.wraps(fun)
+    def new_func(self, *args, **kwargs):
+        if self.skip_existing is not None:
+            with set_skip_existing(self.skip_existing):
+                return fun(self, *args, **kwargs)
+        return fun(self, *args, **kwargs)
+
+    return new_func
+
+
+def _call_actor_net(
+    actor_net: ProbabilisticTensorDictModule,
+    data: TensorDictBase,
+    params: TensorDictBase,
+    log_prob_key: NestedKey,
+):
+    dist = actor_net.get_dist(data.select(*actor_net.in_keys, strict=False))
+    s = actor_net._dist_sample(dist, interaction_type=interaction_type())
+    with set_composite_lp_aggregate(True):
+        return dist.log_prob(s)
+
+
+class ValueEstimatorBase(TensorDictModuleBase):
+    """An abstract parent class for value function modules.
+
+    Its :meth:`ValueFunctionBase.forward` method will compute the value (given
+    by the value network) and the value estimate (given by the value estimator)
+    as well as the advantage and write these values in the output tensordict.
+
+    If only the value estimate is needed, the :meth:`ValueFunctionBase.value_estimate`
+    should be used instead.
+
+    """
+
+    @dataclass
+    class _AcceptedKeys:
+        """Maintains default values for all configurable tensordict keys.
+
+        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+        default values.
+
+        Attributes:
+            advantage (NestedKey): The input tensordict key where the advantage is written to.
+                Will be used for the underlying value estimator. Defaults to ``"advantage"``.
+            value_target (NestedKey): The input tensordict key where the target state value is written to.
+                Will be used for the underlying value estimator. Defaults to ``"value_target"``.
+            value (NestedKey): The input tensordict key where the state value is expected.
+                Will be used for the underlying value estimator. Defaults to ``"state_value"``.
+            reward (NestedKey): The input tensordict key where the reward is written to.
+                Defaults to ``"reward"``.
+            done (NestedKey): The key in the input TensorDict that indicates
+                whether a trajectory is done. Defaults to ``"done"``.
+            terminated (NestedKey): The key in the input TensorDict that indicates
+                whether a trajectory is terminated. Defaults to ``"terminated"``.
+            steps_to_next_obs (NestedKey): The key in the input tensordict
+                that indicates the number of steps to the next observation.
+                Defaults to ``"steps_to_next_obs"``.
+            sample_log_prob (NestedKey): The key in the input tensordict that
+                indicates the log probability of the sampled action.
+                Defaults to ``"sample_log_prob"`` when :func:`~tensordict.nn.composite_lp_aggregate` returns `True`,
+                `"action_log_prob"` otherwise.
+        """
+
+        advantage: NestedKey = "advantage"
+        value_target: NestedKey = "value_target"
+        value: NestedKey = "state_value"
+        reward: NestedKey = "reward"
+        done: NestedKey = "done"
+        terminated: NestedKey = "terminated"
+        steps_to_next_obs: NestedKey = "steps_to_next_obs"
+        sample_log_prob: NestedKey | None = None
+
+        def __post_init__(self):
+            if self.sample_log_prob is None:
+                if composite_lp_aggregate(nowarn=True):
+                    self.sample_log_prob = "sample_log_prob"
+                else:
+                    self.sample_log_prob = "action_log_prob"
+
+    default_keys = _AcceptedKeys
+    tensor_keys: _AcceptedKeys
+    value_network: TensorDictModule | Callable
+    _vmap_randomness = None
+    deactivate_vmap: bool = False
+
+    @property
+    def advantage_key(self):
+        return self.tensor_keys.advantage
+
+    @property
+    def value_key(self):
+        return self.tensor_keys.value
+
+    @property
+    def value_target_key(self):
+        return self.tensor_keys.value_target
+
+    @property
+    def reward_key(self):
+        return self.tensor_keys.reward
+
+    @property
+    def done_key(self):
+        return self.tensor_keys.done
+
+    @property
+    def terminated_key(self):
+        return self.tensor_keys.terminated
+
+    @property
+    def steps_to_next_obs_key(self):
+        return self.tensor_keys.steps_to_next_obs
+
+    @property
+    def sample_log_prob_key(self):
+        return self.tensor_keys.sample_log_prob
+
+    @abc.abstractmethod
+    def forward(
+        self,
+        tensordict: TensorDictBase,
+        *,
+        params: TensorDictBase | None = None,
+        target_params: TensorDictBase | None = None,
+    ) -> TensorDictBase:
+        """Computes the advantage estimate given the data in tensordict.
+
+        If a functional module is provided, a nested TensorDict containing the parameters
+        (and if relevant the target parameters) can be passed to the module.
+
+        Args:
+            tensordict (TensorDictBase): A TensorDict containing the data
+                (an observation key, ``"action"``, ``("next", "reward")``,
+                ``("next", "done")``, ``("next", "terminated")``,
+                and ``"next"`` tensordict state as returned by the environment)
+                necessary to compute the value estimates and the TDEstimate.
+                The data passed to this module should be structured as
+                :obj:`[*B, T, *F]` where :obj:`B` are
+                the batch size, :obj:`T` the time dimension and :obj:`F` the
+                feature dimension(s). The tensordict must have shape ``[*B, T]``.
+
+        Keyword Args:
+            params (TensorDictBase, optional): A nested TensorDict containing the params
+                to be passed to the functional value network module.
+            target_params (TensorDictBase, optional): A nested TensorDict containing the
+                target params to be passed to the functional value network module.
+            device (torch.device, optional): the device where the buffers will be instantiated.
+                Defaults to ``torch.get_default_device()``.
+
+        Returns:
+            An updated TensorDict with the advantage and value_target keys as defined in the constructor.
+        """
+        ...
+
+    def __init__(
+        self,
+        *,
+        value_network: TensorDictModule,
+        shifted: bool = False,
+        differentiable: bool = False,
+        skip_existing: bool | None = None,
+        advantage_key: NestedKey = None,
+        value_target_key: NestedKey = None,
+        value_key: NestedKey = None,
+        device: torch.device | None = None,
+        deactivate_vmap: bool = False,
+    ):
+        super().__init__()
+        if device is None:
+            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
+        # this is saved for tracking only and should not be used to cast anything else than buffers during
+        # init.
+        self._device = device
+        self._tensor_keys = None
+        self.differentiable = differentiable
+        self.deactivate_vmap = deactivate_vmap
+        self.skip_existing = skip_existing
+        self.__dict__["value_network"] = value_network
+        self.dep_keys = {}
+        self.shifted = shifted
+
+        if advantage_key is not None:
+            raise RuntimeError(
+                "Setting 'advantage_key' via constructor is deprecated, use .set_keys(advantage_key='some_key') instead.",
+            )
+        if value_target_key is not None:
+            raise RuntimeError(
+                "Setting 'value_target_key' via constructor is deprecated, use .set_keys(value_target_key='some_key') instead.",
+            )
+        if value_key is not None:
+            raise RuntimeError(
+                "Setting 'value_key' via constructor is deprecated, use .set_keys(value_key='some_key') instead.",
+            )
+
+    @property
+    def tensor_keys(self) -> _AcceptedKeys:
+        if self._tensor_keys is None:
+            self.set_keys()
+        return self._tensor_keys
+
+    @tensor_keys.setter
+    def tensor_keys(self, value):
+        if not isinstance(value, type(self._AcceptedKeys)):
+            raise ValueError("value must be an instance of _AcceptedKeys")
+        self._keys = value
+
+    @property
+    def in_keys(self):
+        try:
+            in_keys = (
+                self.value_network.in_keys
+                + [
+                    ("next", self.tensor_keys.reward),
+                    ("next", self.tensor_keys.done),
+                    ("next", self.tensor_keys.terminated),
+                ]
+                + [("next", in_key) for in_key in self.value_network.in_keys]
+            )
+        except AttributeError:
+            # value network does not have an `in_keys` attribute
+            in_keys = []
+        return in_keys
+
+    @property
+    def out_keys(self):
+        return [
+            self.tensor_keys.advantage,
+            self.tensor_keys.value_target,
+        ]
+
+    def set_keys(self, **kwargs) -> None:
+        """Set tensordict key names."""
+        for key, value in list(kwargs.items()):
+            if isinstance(value, list):
+                value = [unravel_key(k) for k in value]
+            elif not isinstance(value, (str, tuple)):
+                if value is None:
+                    raise ValueError("tensordict keys cannot be None")
+                raise ValueError(
+                    f"key name must be of type NestedKey (Union[str, Tuple[str]]) but got {type(value)}"
+                )
+            else:
+                value = unravel_key(value)
+
+            if key not in self._AcceptedKeys.__dict__:
+                raise KeyError(
+                    f"{key} is not an accepted tensordict key for advantages"
+                )
+            if (
+                key == "value"
+                and hasattr(self.value_network, "out_keys")
+                and (value not in self.value_network.out_keys)
+            ):
+                raise KeyError(
+                    f"value key '{value}' not found in value network out_keys {self.value_network.out_keys}"
+                )
+            kwargs[key] = value
+        if self._tensor_keys is None:
+            conf = asdict(self.default_keys())
+            conf.update(self.dep_keys)
+        else:
+            conf = asdict(self._tensor_keys)
+        conf.update(kwargs)
+        self._tensor_keys = self._AcceptedKeys(**conf)
+
+    def value_estimate(
+        self,
+        tensordict,
+        target_params: TensorDictBase | None = None,
+        next_value: torch.Tensor | None = None,
+        **kwargs,
+    ):
+        """Gets a value estimate, usually used as a target value for the value network.
+
+        If the state value key is present under ``tensordict.get(("next", self.tensor_keys.value))``
+        then this value will be used without querying the value network.
+
+        Args:
+            tensordict (TensorDictBase): the tensordict containing the data to
+                read.
+            target_params (TensorDictBase, optional): A nested TensorDict containing the
+                target params to be passed to the functional value network module.
+            next_value (torch.Tensor, optional): the value of the next state
+                or state-action pair. Mutually exclusive with ``target_params``.
+            **kwargs: the keyword arguments to be passed to the value network.
+
+        Returns: a tensor corresponding to the state value.
+
+        """
+        raise NotImplementedError
+
+    @property
+    def is_functional(self):
+        # legacy
+        return False
+
+    @property
+    def is_stateless(self):
+        # legacy
+        return False
+
+    def _next_value(self, tensordict, target_params, kwargs):
+        step_td = step_mdp(tensordict, keep_other=False)
+        if self.value_network is not None:
+            with hold_out_net(
+                self.value_network
+            ) if target_params is None else target_params.to_module(self.value_network):
+                self.value_network(step_td)
+        next_value = step_td.get(self.tensor_keys.value)
+        return next_value
+
+    @property
+    def vmap_randomness(self):
+        if self._vmap_randomness is None:
+            if is_dynamo_compiling():
+                self._vmap_randomness = "different"
+                return "different"
+            do_break = False
+            for val in self.__dict__.values():
+                if isinstance(val, torch.nn.Module):
+                    import torchrl.objectives.utils
+
+                    for module in val.modules():
+                        if isinstance(
+                            module, torchrl.objectives.utils.RANDOM_MODULE_LIST
+                        ):
+                            self._vmap_randomness = "different"
+                            do_break = True
+                            break
+                if do_break:
+                    # double break
+                    break
+            else:
+                self._vmap_randomness = "error"
+
+        return self._vmap_randomness
+
+    def set_vmap_randomness(self, value):
+        self._vmap_randomness = value
+
+    def _get_time_dim(self, time_dim: int | None, data: TensorDictBase):
+        if time_dim is not None:
+            if time_dim < 0:
+                time_dim = data.ndim + time_dim
+            return time_dim
+        time_dim_attr = getattr(self, "time_dim", None)
+        if time_dim_attr is not None:
+            if time_dim_attr < 0:
+                time_dim_attr = data.ndim + time_dim_attr
+            return time_dim_attr
+        if data._has_names():
+            for i, name in enumerate(data.names):
+                if name == "time":
+                    return i
+        return data.ndim - 1
+
+    def _call_value_nets(
+        self,
+        data: TensorDictBase,
+        params: TensorDictBase,
+        next_params: TensorDictBase,
+        single_call: bool,
+        value_key: NestedKey,
+        detach_next: bool,
+        vmap_randomness: str = "error",
+        *,
+        value_net: TensorDictModuleBase | None = None,
+    ):
+        if value_net is None:
+            value_net = self.value_network
+        in_keys = value_net.in_keys
+        if single_call:
+            # We are going to flatten the data, then interleave the last observation of each trajectory in between its
+            # previous obs (from the root TD) and the first of the next trajectory. Eventually, each trajectory will
+            # have T+1 elements (or, for a batch of N trajectories, we will have \Sum_{t=0}^{T-1} length_t + T
+            # elements). Then, we can feed that to our RNN which will understand which trajectory is which, pad the data
+            # accordingly and process each of them independently.
+            try:
+                ndim = list(data.names).index("time") + 1
+            except ValueError:
+                if rl_warnings():
+                    logger.warning(
+                        "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. "
+                        "This warning can be turned off by setting the environment variable RL_WARNINGS to False."
+                    )
+                ndim = data.ndim
+            data_copy = data.copy()
+            # we are going to modify the done so let's clone it
+            done = data_copy["next", "done"].clone()
+            # Mark the last step of every sequence as done. We do this because flattening would cause the trajectories
+            # of different batches to be merged.
+            done[(slice(None),) * (ndim - 1) + (-1,)].fill_(True)
+            truncated = data_copy.get(("next", "truncated"), done)
+            if truncated is not done:
+                truncated[(slice(None),) * (ndim - 1) + (-1,)].fill_(True)
+            data_copy["next", "done"] = done
+            data_copy["next", "truncated"] = truncated
+            # Reshape to -1 because we cannot guarantee that all dims have the same number of done states
+            with data_copy.view(-1) as data_copy_view:
+                # Interleave next data when done
+                data_copy_select = data_copy_view.select(
+                    *in_keys, value_key, strict=False
+                )
+                total_elts = (
+                    data_copy_view.shape[0]
+                    + data_copy_view["next", "done"].sum().item()
+                )
+                data_in = data_copy_select.new_zeros((total_elts,))
+                # we can get the indices of non-done data by adding the shifted done cumsum to an arange
+                # traj = [0, 0, 0, 1, 1, 2, 2]
+                # arange = [0, 1, 2, 3, 4, 5, 6]
+                # done = [0, 0, 1, 0, 1, 0, 1]
+                # done_cs = [0, 0, 0, 1, 1, 2, 2]
+                # indices = [0, 1, 2, 4, 5, 7, 8]
+                done_view = data_copy_view["next", "done"]
+                if done_view.shape[-1] == 1:
+                    done_view = done_view.squeeze(-1)
+                else:
+                    done_view = done_view.any(-1)
+                done_cs = done_view.cumsum(0)
+                done_cs = torch.cat([done_cs.new_zeros((1,)), done_cs[:-1]], dim=0)
+                indices = torch.arange(done_cs.shape[0], device=done_cs.device)
+                indices = indices + done_cs
+                data_in[indices] = data_copy_select
+                # To get the indices of the extra data, we can mask indices with done_view and add 1
+                indices_interleaved = indices[done_view] + 1
+                # assert not set(indices_interleaved.tolist()).intersection(indices.tolist())
+                data_in[indices_interleaved] = (
+                    data_copy_view[done_view]
+                    .get("next")
+                    .select(*in_keys, value_key, strict=False)
+                )
+                if next_params is not None and next_params is not params:
+                    raise ValueError(
+                        "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed."
+                    )
+                if params is not None:
+                    with params.to_module(value_net):
+                        value_est = value_net(data_in).get(value_key)
+                else:
+                    value_est = value_net(data_in).get(value_key)
+                value, value_ = value_est[indices], value_est[indices + 1]
+                value = value.view_as(done)
+                value_ = value_.view_as(done)
+        else:
+            data_root = data.select(*in_keys, value_key, strict=False)
+            data_next = data.get("next").select(*in_keys, value_key, strict=False)
+            if "is_init" in data_root.keys():
+                # We need to mark the first element of the "next" td as being an init step for RNNs
+                # otherwise, consecutive elements in the sequence will be considered as part of the same
+                # trajectory, even if they're not.
+                data_next["is_init"] = data_next["is_init"] | data_root["is_init"]
+            data_in = torch.stack(
+                [data_root, data_next],
+                0,
+            )
+            if (params is not None) ^ (next_params is not None):
+                raise ValueError(
+                    "params and next_params must be either both provided or not."
+                )
+            elif params is not None:
+                params_stack = torch.stack([params, next_params], 0).contiguous()
+                data_out = _vmap_func(
+                    value_net,
+                    (0, 0),
+                    randomness=vmap_randomness,
+                    pseudo_vmap=self.deactivate_vmap,
+                )(data_in, params_stack)
+            elif not self.deactivate_vmap:
+                data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in)
+            else:
+                data_out = _pseudo_vmap(value_net, (0,), randomness=vmap_randomness)(
+                    data_in
+                )
+            value_est = data_out.get(value_key)
+            value, value_ = value_est[0], value_est[1]
+        data.set(value_key, value)
+        data.set(("next", value_key), value_)
+        if detach_next:
+            value_ = value_.detach()
+        return value, value_
+
+
+class TD0Estimator(ValueEstimatorBase):
+    """Temporal Difference (TD(0)) estimate of advantage function.
+
+    AKA bootstrapped temporal difference or 1-step return.
+
+    Keyword Args:
+        gamma (scalar): exponential mean discount.
+        value_network (TensorDictModule): value operator used to retrieve
+            the value estimates.
+        shifted (bool, optional): if ``True``, the value and next value are
+            estimated with a single call to the value network. This is faster
+            but is only valid whenever (1) the ``"next"`` value is shifted by
+            only one time step (which is not the case with multi-step value
+            estimation, for instance) and (2) when the parameters used at time
+            ``t`` and ``t+1`` are identical (which is not the case when target
+            parameters are to be used). Defaults to ``False``.
+        average_rewards (bool, optional): if ``True``, rewards will be standardized
+            before the TD is computed.
+        differentiable (bool, optional): if ``True``, gradients are propagated through
+            the computation of the value function. Default is ``False``.
+
+            .. note::
+              The proper way to make the function call non-differentiable is to
+              decorate it in a `torch.no_grad()` context manager/decorator or
+              pass detached parameters for functional modules.
+
+        skip_existing (bool, optional): if ``True``, the value network will skip
+            modules whose outputs are already present in the tensordict.
+            Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()`
+            is not affected.
+        advantage_key (str or tuple of str, optional): [Deprecated] the key of
+            the advantage entry. Defaults to ``"advantage"``.
+        value_target_key (str or tuple of str, optional): [Deprecated] the key
+            of the value target entry. Defaults to ``"value_target"``.
+        value_key (str or tuple of str, optional): [Deprecated] the value key to
+            read from the input tensordict. Defaults to ``"state_value"``.
+        device (torch.device, optional): the device where the buffers will be instantiated.
+            Defaults to ``torch.get_default_device()``.
+        deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+            Defaults to ``False``.
+
+    """
+
+    def __init__(
+        self,
+        *,
+        gamma: float | torch.Tensor,
+        value_network: TensorDictModule,
+        shifted: bool = False,
+        average_rewards: bool = False,
+        differentiable: bool = False,
+        advantage_key: NestedKey = None,
+        value_target_key: NestedKey = None,
+        value_key: NestedKey = None,
+        skip_existing: bool | None = None,
+        device: torch.device | None = None,
+        deactivate_vmap: bool = False,
+    ):
+        super().__init__(
+            value_network=value_network,
+            differentiable=differentiable,
+            shifted=shifted,
+            advantage_key=advantage_key,
+            value_target_key=value_target_key,
+            value_key=value_key,
+            skip_existing=skip_existing,
+            device=device,
+            deactivate_vmap=deactivate_vmap,
+        )
+        self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
+        self.average_rewards = average_rewards
+
+    @_self_set_skip_existing
+    @_self_set_grad_enabled
+    @dispatch
+    def forward(
+        self,
+        tensordict: TensorDictBase,
+        *,
+        params: TensorDictBase | None = None,
+        target_params: TensorDictBase | None = None,
+    ) -> TensorDictBase:
+        """Computes the TD(0) advantage given the data in tensordict.
+
+        If a functional module is provided, a nested TensorDict containing the parameters
+        (and if relevant the target parameters) can be passed to the module.
+
+        Args:
+            tensordict (TensorDictBase): A TensorDict containing the data
+                (an observation key, ``"action"``, ``("next", "reward")``,
+                ``("next", "done")``, ``("next", "terminated")``, and ``"next"``
+                tensordict state as returned by the environment) necessary to
+                compute the value estimates and the TDEstimate.
+                The data passed to this module should be structured as
+                :obj:`[*B, T, *F]` where :obj:`B` are
+                the batch size, :obj:`T` the time dimension and :obj:`F` the
+                feature dimension(s). The tensordict must have shape ``[*B, T]``.
+
+        Keyword Args:
+            params (TensorDictBase, optional): A nested TensorDict containing the params
+                to be passed to the functional value network module.
+            target_params (TensorDictBase, optional): A nested TensorDict containing the
+                target params to be passed to the functional value network module.
+
+        Returns:
+            An updated TensorDict with the advantage and value_target keys as defined in the constructor.
+
+        Examples:
+            >>> from tensordict import TensorDict
+            >>> value_net = TensorDictModule(
+            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
+            ... )
+            >>> module = TDEstimate(
+            ...     gamma=0.98,
+            ...     value_network=value_net,
+            ... )
+            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
+            >>> reward = torch.randn(1, 10, 1)
+            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "terminated": terminated, "reward": reward}}, [1, 10])
+            >>> _ = module(tensordict)
+            >>> assert "advantage" in tensordict.keys()
+
+        The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
+
+        Examples:
+            >>> value_net = TensorDictModule(
+            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
+            ... )
+            >>> module = TDEstimate(
+            ...     gamma=0.98,
+            ...     value_network=value_net,
+            ... )
+            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
+            >>> reward = torch.randn(1, 10, 1)
+            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
+
+        """
+        if tensordict.batch_dims < 1:
+            raise RuntimeError(
+                "Expected input tensordict to have at least one dimension, got "
+                f"tensordict.batch_size = {tensordict.batch_size}"
+            )
+
+        if self.is_stateless and params is None:
+            raise RuntimeError(
+                "Expected params to be passed to advantage module but got none."
+            )
+        if self.value_network is not None:
+            if params is not None:
+                params = params.detach()
+                if target_params is None:
+                    target_params = params.clone(False)
+            with hold_out_net(self.value_network) if (
+                params is None and target_params is None
+            ) else nullcontext():
+                # we may still need to pass gradient, but we don't want to assign grads to
+                # value net params
+                value, next_value = self._call_value_nets(
+                    data=tensordict,
+                    params=params,
+                    next_params=target_params,
+                    single_call=self.shifted,
+                    value_key=self.tensor_keys.value,
+                    detach_next=True,
+                    vmap_randomness=self.vmap_randomness,
+                )
+        else:
+            value = tensordict.get(self.tensor_keys.value)
+            next_value = tensordict.get(("next", self.tensor_keys.value))
+
+        value_target = self.value_estimate(tensordict, next_value=next_value)
+        tensordict.set(self.tensor_keys.advantage, value_target - value)
+        tensordict.set(self.tensor_keys.value_target, value_target)
+        return tensordict
+
+    def value_estimate(
+        self,
+        tensordict,
+        target_params: TensorDictBase | None = None,
+        next_value: torch.Tensor | None = None,
+        **kwargs,
+    ):
+        reward = tensordict.get(("next", self.tensor_keys.reward))
+        device = reward.device
+
+        if self.gamma.device != device:
+            self.gamma = self.gamma.to(device)
+        gamma = self.gamma
+        steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
+        if steps_to_next_obs is not None:
+            gamma = gamma ** steps_to_next_obs.view_as(reward)
+
+        if self.average_rewards:
+            reward = reward - reward.mean()
+            reward = reward / reward.std().clamp_min(1e-5)
+            tensordict.set(
+                ("next", self.tensor_keys.reward), reward
+            )  # we must update the rewards if they are used later in the code
+        if next_value is None:
+            next_value = self._next_value(tensordict, target_params, kwargs=kwargs)
+
+        done = tensordict.get(("next", self.tensor_keys.done))
+        terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)
+        value_target = td0_return_estimate(
+            gamma=gamma,
+            next_state_value=next_value,
+            reward=reward,
+            done=done,
+            terminated=terminated,
+        )
+        return value_target
+
+
774
+ class TD1Estimator(ValueEstimatorBase):
775
+ r""":math:`\infty`-Temporal Difference (TD(1)) estimate of advantage function.
776
+
777
+ Keyword Args:
778
+ gamma (scalar): exponential mean discount.
779
+ value_network (TensorDictModule): value operator used to retrieve the value estimates.
780
+ average_rewards (bool, optional): if ``True``, rewards will be standardized
781
+ before the TD is computed.
782
+ differentiable (bool, optional): if ``True``, gradients are propagated through
783
+ the computation of the value function. Default is ``False``.
784
+
785
+ .. note::
786
+ The proper way to make the function call non-differentiable is to
787
+ decorate it in a `torch.no_grad()` context manager/decorator or
788
+ pass detached parameters for functional modules.
789
+
790
+ skip_existing (bool, optional): if ``True``, the value network will skip
791
+ modules which outputs are already present in the tensordict.
792
+ Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()`
793
+ is not affected.
794
+ advantage_key (str or tuple of str, optional): [Deprecated] the key of
795
+ the advantage entry. Defaults to ``"advantage"``.
796
+ value_target_key (str or tuple of str, optional): [Deprecated] the key
797
+ of the advantage entry. Defaults to ``"value_target"``.
798
+ value_key (str or tuple of str, optional): [Deprecated] the value key to
799
+ read from the input tensordict. Defaults to ``"state_value"``.
800
+ shifted (bool, optional): if ``True``, the value and next value are
801
+ estimated with a single call to the value network. This is faster
802
+ but is only valid whenever (1) the ``"next"`` value is shifted by
803
+ only one time step (which is not the case with multi-step value
804
+ estimation, for instance) and (2) when the parameters used at time
805
+ ``t`` and ``t+1`` are identical (which is not the case when target
806
+ parameters are to be used). Defaults to ``False``.
807
+ device (torch.device, optional): the device where the buffers will be instantiated.
808
+ Defaults to ``torch.get_default_device()``.
809
+ time_dim (int, optional): the dimension corresponding to the time
810
+ in the input tensordict. If not provided, defaults to the dimension
811
+ marked with the ``"time"`` name if any, and to the last dimension
812
+ otherwise. Can be overridden during a call to
813
+ :meth:`~.value_estimate`.
814
+ Negative dimensions are considered with respect to the input
815
+ tensordict.
816
+ deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
817
+ Defaults to ``False``.
818
+
819
+ """
820
+
821
+ def __init__(
822
+ self,
823
+ *,
824
+ gamma: float | torch.Tensor,
825
+ value_network: TensorDictModule,
826
+ average_rewards: bool = False,
827
+ differentiable: bool = False,
828
+ skip_existing: bool | None = None,
829
+ advantage_key: NestedKey = None,
830
+ value_target_key: NestedKey = None,
831
+ value_key: NestedKey = None,
832
+ shifted: bool = False,
833
+ device: torch.device | None = None,
834
+ time_dim: int | None = None,
835
+ deactivate_vmap: bool = False,
836
+ ):
837
+ super().__init__(
838
+ value_network=value_network,
839
+ differentiable=differentiable,
840
+ advantage_key=advantage_key,
841
+ value_target_key=value_target_key,
842
+ value_key=value_key,
843
+ shifted=shifted,
844
+ skip_existing=skip_existing,
845
+ device=device,
846
+ deactivate_vmap=deactivate_vmap,
847
+ )
848
+ self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
849
+ self.average_rewards = average_rewards
850
+ self.time_dim = time_dim
851
+
852
+ @_self_set_skip_existing
853
+ @_self_set_grad_enabled
854
+ @dispatch
855
+ def forward(
856
+ self,
857
+ tensordict: TensorDictBase,
858
+ *,
859
+ params: TensorDictBase | None = None,
860
+ target_params: TensorDictBase | None = None,
861
+ ) -> TensorDictBase:
862
+ """Computes the TD(1) advantage given the data in tensordict.
863
+
864
+ If a functional module is provided, a nested TensorDict containing the parameters
865
+ (and if relevant the target parameters) can be passed to the module.
866
+
867
+ Args:
868
+ tensordict (TensorDictBase): A TensorDict containing the data
869
+ (an observation key, ``"action"``, ``("next", "reward")``,
870
+ ``("next", "done")``, ``("next", "terminated")``,
871
+ and ``"next"`` tensordict state as returned by the environment)
872
+ necessary to compute the value estimates and the TDEstimate.
873
+ The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` are
874
+ the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
875
+ The tensordict must have shape ``[*B, T]``.
876
+
877
+ Keyword Args:
878
+ params (TensorDictBase, optional): A nested TensorDict containing the params
879
+ to be passed to the functional value network module.
880
+ target_params (TensorDictBase, optional): A nested TensorDict containing the
881
+ target params to be passed to the functional value network module.
882
+
883
+ Returns:
884
+ An updated TensorDict with an advantage and a value_error keys as defined in the constructor.
885
+
886
+ Examples:
887
+ >>> from tensordict import TensorDict
888
+ >>> value_net = TensorDictModule(
889
+ ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
890
+ ... )
891
+ >>> module = TDEstimate(
892
+ ... gamma=0.98,
893
+ ... value_network=value_net,
894
+ ... )
895
+ >>> obs, next_obs = torch.randn(2, 1, 10, 3)
896
+ >>> reward = torch.randn(1, 10, 1)
897
+ >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
898
+ >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
899
+ >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward, "terminated": terminated}}, [1, 10])
900
+ >>> _ = module(tensordict)
901
+ >>> assert "advantage" in tensordict.keys()
902
+
903
+ The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
904
+
905
+ Examples:
906
+ >>> value_net = TensorDictModule(
907
+ ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
908
+ ... )
909
+ >>> module = TDEstimate(
910
+ ... gamma=0.98,
911
+ ... value_network=value_net,
912
+ ... )
913
+ >>> obs, next_obs = torch.randn(2, 1, 10, 3)
914
+ >>> reward = torch.randn(1, 10, 1)
915
+ >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
916
+ >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
917
+ >>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
918
+
919
+ """
920
+ if tensordict.batch_dims < 1:
921
+ raise RuntimeError(
922
+ "Expected input tensordict to have at least one dimensions, got"
923
+ f"tensordict.batch_size = {tensordict.batch_size}"
924
+ )
925
+
926
+ if self.is_stateless and params is None:
927
+ raise RuntimeError(
928
+ "Expected params to be passed to advantage module but got none."
929
+ )
930
+ if self.value_network is not None:
931
+ if params is not None:
932
+ params = params.detach()
933
+ if target_params is None:
934
+ target_params = params.clone(False)
935
+ with hold_out_net(self.value_network) if (
936
+ params is None and target_params is None
937
+ ) else nullcontext():
938
+ # we may still need to pass gradient, but we don't want to assign grads to
939
+ # value net params
940
+ value, next_value = self._call_value_nets(
941
+ data=tensordict,
942
+ params=params,
943
+ next_params=target_params,
944
+ single_call=self.shifted,
945
+ value_key=self.tensor_keys.value,
946
+ detach_next=True,
947
+ vmap_randomness=self.vmap_randomness,
948
+ )
949
+ else:
950
+ value = tensordict.get(self.tensor_keys.value)
951
+ next_value = tensordict.get(("next", self.tensor_keys.value))
952
+
953
+ value_target = self.value_estimate(tensordict, next_value=next_value)
954
+
955
+ tensordict.set(self.tensor_keys.advantage, value_target - value)
956
+ tensordict.set(self.tensor_keys.value_target, value_target)
957
+ return tensordict
958
+
959
+ def value_estimate(
960
+ self,
961
+ tensordict,
962
+ target_params: TensorDictBase | None = None,
963
+ next_value: torch.Tensor | None = None,
964
+ time_dim: int | None = None,
965
+ **kwargs,
966
+ ):
967
+ reward = tensordict.get(("next", self.tensor_keys.reward))
968
+ device = reward.device
969
+ if self.gamma.device != device:
970
+ self.gamma = self.gamma.to(device)
971
+ gamma = self.gamma
972
+ steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
973
+ if steps_to_next_obs is not None:
974
+ gamma = gamma ** steps_to_next_obs.view_as(reward)
975
+
976
+ if self.average_rewards:
977
+ reward = reward - reward.mean()
978
+ reward = reward / reward.std().clamp_min(1e-5)
979
+ tensordict.set(
980
+ ("next", self.tensor_keys.reward), reward
981
+ ) # we must update the rewards if they are used later in the code
982
+ if next_value is None:
983
+ next_value = self._next_value(tensordict, target_params, kwargs=kwargs)
984
+
985
+ done = tensordict.get(("next", self.tensor_keys.done))
986
+ terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)
987
+ time_dim = self._get_time_dim(time_dim, tensordict)
988
+ value_target = vec_td1_return_estimate(
989
+ gamma,
990
+ next_value,
991
+ reward,
992
+ done=done,
993
+ terminated=terminated,
994
+ time_dim=time_dim,
995
+ )
996
+ return value_target
997
+
998
+
999
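A minimal, non-vectorized sketch of the recursion that `vec_td1_return_estimate` computes for the estimator above may help read the code. The helper name `td1_return_reference`, the assumed `[*B, T, 1]` layout with the time dimension at `-2`, and the plain Python loop are illustrative assumptions, not torchrl API; the library routine additionally handles arbitrary `time_dim` placements and per-step discounts.

import torch

def td1_return_reference(gamma, next_value, reward, done, terminated):
    # Discounted sum of rewards to the end of each trajectory, bootstrapping
    # with next_value when the episode is truncated (done but not terminated)
    # and at the end of the collected batch.
    not_terminated = (~terminated).to(reward.dtype)
    returns = torch.zeros_like(reward)
    T = reward.shape[-2]
    running = next_value[..., T - 1, :]  # bootstrap after the last collected step
    for t in reversed(range(T)):
        # restart from the bootstrap value wherever a trajectory ends
        running = torch.where(done[..., t, :], next_value[..., t, :], running)
        running = reward[..., t, :] + gamma * not_terminated[..., t, :] * running
        returns[..., t, :] = running
    return returns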
+ class TDLambdaEstimator(ValueEstimatorBase):
1000
+ r"""TD(:math:`\lambda`) estimate of advantage function.
1001
+
1002
+ Args:
1003
+ gamma (scalar): exponential mean discount.
1004
+ lmbda (scalar): trajectory discount.
1005
+ value_network (TensorDictModule): value operator used to retrieve the value estimates.
1006
+ average_rewards (bool, optional): if ``True``, rewards will be standardized
1007
+ before the TD is computed.
1008
+ differentiable (bool, optional): if ``True``, gradients are propagated through
1009
+ the computation of the value function. Default is ``False``.
1010
+
1011
+ .. note::
1012
+ The proper way to make the function call non-differentiable is to
1013
+ wrap it in a `torch.no_grad()` context manager/decorator or
1014
+ pass detached parameters for functional modules.
1015
+
1016
+ vectorized (bool, optional): whether to use the vectorized version of the
1017
+ lambda return. Default is `True`.
1018
+ skip_existing (bool, optional): if ``True``, the value network will skip
1019
+ modules whose outputs are already present in the tensordict.
1020
+ Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()`
1021
+ is not affected.
1022
+ advantage_key (str or tuple of str, optional): [Deprecated] the key of
1023
+ the advantage entry. Defaults to ``"advantage"``.
1024
+ value_target_key (str or tuple of str, optional): [Deprecated] the key
1025
+ of the value target entry. Defaults to ``"value_target"``.
1026
+ value_key (str or tuple of str, optional): [Deprecated] the value key to
1027
+ read from the input tensordict. Defaults to ``"state_value"``.
1028
+ shifted (bool, optional): if ``True``, the value and next value are
1029
+ estimated with a single call to the value network. This is faster
1030
+ but is only valid whenever (1) the ``"next"`` value is shifted by
1031
+ only one time step (which is not the case with multi-step value
1032
+ estimation, for instance) and (2) when the parameters used at time
1033
+ ``t`` and ``t+1`` are identical (which is not the case when target
1034
+ parameters are to be used). Defaults to ``False``.
1035
+ device (torch.device, optional): the device where the buffers will be instantiated.
1036
+ Defaults to ``torch.get_default_device()``.
1037
+ time_dim (int, optional): the dimension corresponding to the time
1038
+ in the input tensordict. If not provided, defaults to the dimension
1039
+ marked with the ``"time"`` name if any, and to the last dimension
1040
+ otherwise. Can be overridden during a call to
1041
+ :meth:`~.value_estimate`.
1042
+ Negative dimensions are considered with respect to the input
1043
+ tensordict.
1044
+ deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
1045
+ Defaults to ``False``.
1046
+
1047
+ """
1048
+
1049
+ def __init__(
1050
+ self,
1051
+ *,
1052
+ gamma: float | torch.Tensor,
1053
+ lmbda: float | torch.Tensor,
1054
+ value_network: TensorDictModule,
1055
+ average_rewards: bool = False,
1056
+ differentiable: bool = False,
1057
+ vectorized: bool = True,
1058
+ skip_existing: bool | None = None,
1059
+ advantage_key: NestedKey = None,
1060
+ value_target_key: NestedKey = None,
1061
+ value_key: NestedKey = None,
1062
+ shifted: bool = False,
1063
+ device: torch.device | None = None,
1064
+ time_dim: int | None = None,
1065
+ deactivate_vmap: bool = False,
1066
+ ):
1067
+ super().__init__(
1068
+ value_network=value_network,
1069
+ differentiable=differentiable,
1070
+ advantage_key=advantage_key,
1071
+ value_target_key=value_target_key,
1072
+ value_key=value_key,
1073
+ skip_existing=skip_existing,
1074
+ shifted=shifted,
1075
+ device=device,
1076
+ deactivate_vmap=deactivate_vmap,
1077
+ )
1078
+ self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
1079
+ self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device))
1080
+ self.average_rewards = average_rewards
1081
+ self.vectorized = vectorized
1082
+ self.time_dim = time_dim
1083
+
1084
+ @property
1085
+ def vectorized(self):
1086
+ if is_dynamo_compiling():
1087
+ return False
1088
+ return self._vectorized
1089
+
1090
+ @vectorized.setter
1091
+ def vectorized(self, value):
1092
+ self._vectorized = value
1093
+
1094
+ @_self_set_skip_existing
1095
+ @_self_set_grad_enabled
1096
+ @dispatch
1097
+ def forward(
1098
+ self,
1099
+ tensordict: TensorDictBase,
1100
+ *,
1101
+ params: list[Tensor] | None = None,
1102
+ target_params: list[Tensor] | None = None,
1103
+ ) -> TensorDictBase:
1104
+ r"""Computes the TD(:math:`\lambda`) advantage given the data in tensordict.
1105
+
1106
+ If a functional module is provided, a nested TensorDict containing the parameters
1107
+ (and if relevant the target parameters) can be passed to the module.
1108
+
1109
+ Args:
1110
+ tensordict (TensorDictBase): A TensorDict containing the data
1111
+ (an observation key, ``"action"``, ``("next", "reward")``,
1112
+ ``("next", "done")``, ``("next", "terminated")``,
1113
+ and ``"next"`` tensordict state as returned by the environment)
1114
+ necessary to compute the value estimates and the TD(lambda) estimate.
1115
+ The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` is
1116
+ the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
1117
+ The tensordict must have shape ``[*B, T]``.
1118
+
1119
+ Keyword Args:
1120
+ params (TensorDictBase, optional): A nested TensorDict containing the params
1121
+ to be passed to the functional value network module.
1122
+ target_params (TensorDictBase, optional): A nested TensorDict containing the
1123
+ target params to be passed to the functional value network module.
1124
+
1125
+ Returns:
1126
+ An updated TensorDict with the advantage and value_target entries as defined in the constructor.
1127
+
1128
+ Examples:
1129
+ >>> from tensordict import TensorDict
1130
+ >>> value_net = TensorDictModule(
1131
+ ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
1132
+ ... )
1133
+ >>> module = TDLambdaEstimator(
1134
+ ... gamma=0.98,
1135
+ ... lmbda=0.94,
1136
+ ... value_network=value_net,
1137
+ ... )
1138
+ >>> obs, next_obs = torch.randn(2, 1, 10, 3)
1139
+ >>> reward = torch.randn(1, 10, 1)
1140
+ >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
1141
+ >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
1142
+ >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward, "terminated": terminated}}, [1, 10])
1143
+ >>> _ = module(tensordict)
1144
+ >>> assert "advantage" in tensordict.keys()
1145
+
1146
+ The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
1147
+
1148
+ Examples:
1149
+ >>> value_net = TensorDictModule(
1150
+ ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
1151
+ ... )
1152
+ >>> module = TDLambdaEstimator(
1153
+ ... gamma=0.98,
1154
+ ... lmbda=0.94,
1155
+ ... value_network=value_net,
1156
+ ... )
1157
+ >>> obs, next_obs = torch.randn(2, 1, 10, 3)
1158
+ >>> reward = torch.randn(1, 10, 1)
1159
+ >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
1160
+ >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
1161
+ >>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
1162
+
1163
+ """
1164
+ if tensordict.batch_dims < 1:
1165
+ raise RuntimeError(
1166
+ "Expected input tensordict to have at least one dimension, got "
1167
+ f"tensordict.batch_size = {tensordict.batch_size}"
1168
+ )
1169
+ if self.is_stateless and params is None:
1170
+ raise RuntimeError(
1171
+ "Expected params to be passed to advantage module but got none."
1172
+ )
1173
+ if self.value_network is not None:
1174
+ if params is not None:
1175
+ params = params.detach()
1176
+ if target_params is None:
1177
+ target_params = params.clone(False)
1178
+ with hold_out_net(self.value_network) if (
1179
+ params is None and target_params is None
1180
+ ) else nullcontext():
1181
+ # we may still need to pass gradient, but we don't want to assign grads to
1182
+ # value net params
1183
+ value, next_value = self._call_value_nets(
1184
+ data=tensordict,
1185
+ params=params,
1186
+ next_params=target_params,
1187
+ single_call=self.shifted,
1188
+ value_key=self.tensor_keys.value,
1189
+ detach_next=True,
1190
+ vmap_randomness=self.vmap_randomness,
1191
+ )
1192
+ else:
1193
+ value = tensordict.get(self.tensor_keys.value)
1194
+ next_value = tensordict.get(("next", self.tensor_keys.value))
1195
+ value_target = self.value_estimate(tensordict, next_value=next_value)
1196
+
1197
+ tensordict.set(self.tensor_keys.advantage, value_target - value)
1198
+ tensordict.set(self.tensor_keys.value_target, value_target)
1199
+ return tensordict
1200
+
1201
+ def value_estimate(
1202
+ self,
1203
+ tensordict,
1204
+ target_params: TensorDictBase | None = None,
1205
+ next_value: torch.Tensor | None = None,
1206
+ time_dim: int | None = None,
1207
+ **kwargs,
1208
+ ):
1209
+ reward = tensordict.get(("next", self.tensor_keys.reward))
1210
+ device = reward.device
1211
+
1212
+ if self.gamma.device != device:
1213
+ self.gamma = self.gamma.to(device)
1214
+ gamma = self.gamma
1215
+ steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
1216
+ if steps_to_next_obs is not None:
1217
+ gamma = gamma ** steps_to_next_obs.view_as(reward)
1218
+
1219
+ if self.lmbda.device != device:
1220
+ self.lmbda = self.lmbda.to(device)
1221
+ lmbda = self.lmbda
1222
+ if self.average_rewards:
1223
+ reward = reward - reward.mean()
1224
+ reward = reward / reward.std().clamp_min(1e-4)
1225
+ tensordict.set(
1226
+ ("next", self.tensor_keys.reward), reward
1227
+ ) # we must update the rewards if they are used later in the code
1228
+
1229
+ if next_value is None:
1230
+ next_value = self._next_value(tensordict, target_params, kwargs=kwargs)
1231
+
1232
+ done = tensordict.get(("next", self.tensor_keys.done))
1233
+ terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)
1234
+ time_dim = self._get_time_dim(time_dim, tensordict)
1235
+ if self.vectorized:
1236
+ val = vec_td_lambda_return_estimate(
1237
+ gamma,
1238
+ lmbda,
1239
+ next_value,
1240
+ reward,
1241
+ done=done,
1242
+ terminated=terminated,
1243
+ time_dim=time_dim,
1244
+ )
1245
+ else:
1246
+ val = td_lambda_return_estimate(
1247
+ gamma,
1248
+ lmbda,
1249
+ next_value,
1250
+ reward,
1251
+ done=done,
1252
+ terminated=terminated,
1253
+ time_dim=time_dim,
1254
+ )
1255
+ return val
1256
+
1257
+
1258
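Before the GAE wrapper below, a minimal sketch of the lambda-return recursion that `TDLambdaEstimator.value_estimate` delegates to `td_lambda_return_estimate` when `vectorized=False`. The helper name `td_lambda_return_reference` and the assumed `[*B, T, 1]` layout (time dimension at `-2`) are illustrative, and the boundary handling is simplified relative to the library routine.

import torch

def td_lambda_return_reference(gamma, lmbda, next_value, reward, done, terminated):
    # Backward recursion:
    #   G_t = r_t + gamma * (1 - terminated_t) * ((1 - lmbda) * V(s_{t+1}) + lmbda * G_{t+1})
    # with a full bootstrap from V(s_{t+1}) at episode boundaries (done_t).
    not_terminated = (~terminated).to(reward.dtype)
    returns = torch.zeros_like(reward)
    T = reward.shape[-2]
    nxt = next_value[..., T - 1, :]
    for t in reversed(range(T)):
        v_next = next_value[..., t, :]
        future = torch.where(done[..., t, :], v_next, (1 - lmbda) * v_next + lmbda * nxt)
        returns[..., t, :] = reward[..., t, :] + gamma * not_terminated[..., t, :] * future
        nxt = returns[..., t, :]
    return returns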
+ class GAE(ValueEstimatorBase):
1259
+ """A class wrapper around the generalized advantage estimate functional.
1260
+
1261
+ Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION"
1262
+ https://arxiv.org/pdf/1506.02438.pdf for more context.
1263
+
1264
+ Args:
1265
+ gamma (scalar): exponential mean discount.
1266
+ lmbda (scalar): trajectory discount.
1267
+ value_network (TensorDictModule, optional): value operator used to retrieve the value estimates.
1268
+ If ``None``, this module will expect the ``"state_value"`` keys to be already filled, and
1269
+ will not call the value network to produce it.
1270
+ average_gae (bool): if ``True``, the resulting GAE values will be standardized.
1271
+ Default is ``False``.
1272
+ differentiable (bool, optional): if ``True``, gradients are propagated through
1273
+ the computation of the value function. Default is ``False``.
1274
+
1275
+ .. note::
1276
+ The proper way to make the function call non-differentiable is to
1277
+ wrap it in a `torch.no_grad()` context manager/decorator or
1278
+ pass detached parameters for functional modules.
1279
+
1280
+ vectorized (bool, optional): whether to use the vectorized version of the
1281
+ lambda return. Default is `True` if not compiling.
1282
+ skip_existing (bool, optional): if ``True``, the value network will skip
1283
+ modules whose outputs are already present in the tensordict.
1284
+ Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()`
1285
+ is not affected.
1286
+ advantage_key (str or tuple of str, optional): [Deprecated] the key of
1288
+ the advantage entry. Defaults to ``"advantage"``.
1289
+ value_target_key (str or tuple of str, optional): [Deprecated] the key
1290
+ of the value target entry. Defaults to ``"value_target"``.
1291
+ value_key (str or tuple of str, optional): [Deprecated] the value key to
1292
+ read from the input tensordict. Defaults to ``"state_value"``.
1293
+ shifted (bool, optional): if ``True``, the value and next value are
1294
+ estimated with a single call to the value network. This is faster
1295
+ but is only valid whenever (1) the ``"next"`` value is shifted by
1296
+ only one time step (which is not the case with multi-step value
1297
+ estimation, for instance) and (2) when the parameters used at time
1298
+ ``t`` and ``t+1`` are identical (which is not the case when target
1299
+ parameters are to be used). Defaults to ``False``.
1300
+ device (torch.device, optional): the device where the buffers will be instantiated.
1301
+ Defaults to ``torch.get_default_device()``.
1302
+ time_dim (int, optional): the dimension corresponding to the time
1303
+ in the input tensordict. If not provided, defaults to the dimension
1304
+ marked with the ``"time"`` name if any, and to the last dimension
1305
+ otherwise. Can be overridden during a call to
1306
+ :meth:`~.value_estimate`.
1307
+ Negative dimensions are considered with respect to the input
1308
+ tensordict.
1309
+ auto_reset_env (bool, optional): if ``True``, the last ``"next"`` state
1310
+ of the episode isn't valid, so the GAE calculation will use the ``value``
1311
+ instead of ``next_value`` to bootstrap truncated episodes.
1312
+ deactivate_vmap (bool, optional): if ``True``, no vmap call will be used, and
1313
+ vectorized maps will be replaced with simple for loops. Defaults to ``False``.
1314
+
1315
+ GAE will return an :obj:`"advantage"` entry containing the advantage value. It will also
1316
+ return a :obj:`"value_target"` entry with the return value that is to be used
1317
+ to train the value network. Finally, if ``differentiable`` is ``True``,
1318
+ an additional and differentiable :obj:`"value_error"` entry will be returned,
1319
+ which simply represents the difference between the return and the value network
1320
+ output (i.e. an additional distance loss should be applied to that signed value).
1321
+
1322
+ .. note::
1323
+ As other advantage functions do, if the ``value_key`` is already present
1324
+ in the input tensordict, the GAE module will ignore the calls to the value
1325
+ network (if any) and use the provided value instead.
1326
+
1327
+ .. note:: GAE can be used with value networks that rely on recurrent neural networks, provided that the
1328
+ init markers (`"is_init"`) and terminated / truncated markers are properly set.
1329
+ If `shifted=True`, the trajectory batch will be flattened and the last step of each trajectory will
1330
+ be placed within the flat tensordict after the last step from the root, such that each trajectory has
1331
+ `T+1` elements. If `shifted=False`, the root and `"next"` trajectories will be stacked and the value
1332
+ network will be called with `vmap` over the stack of trajectories. Because RNNs require a fair amount of
1333
+ control flow, they are currently not compatible with `torch.vmap` and, as such, the `deactivate_vmap` option
1334
+ must be turned on in these cases.
1335
+ Similarly, if `shifted=False`, the `"is_init"` entry of the root tensordict will be copied onto the
1336
+ `"is_init"` of the `"next"` entry, such that trajectories are well separated both for root and `"next"` data.
1337
+ """
1338
+
1339
+ value_network: TensorDictModule | None
1340
+
1341
+ def __init__(
1342
+ self,
1343
+ *,
1344
+ gamma: float | torch.Tensor,
1345
+ lmbda: float | torch.Tensor,
1346
+ value_network: TensorDictModule | None,
1347
+ average_gae: bool = False,
1348
+ differentiable: bool = False,
1349
+ vectorized: bool | None = None,
1350
+ skip_existing: bool | None = None,
1351
+ advantage_key: NestedKey = None,
1352
+ value_target_key: NestedKey = None,
1353
+ value_key: NestedKey = None,
1354
+ shifted: bool = False,
1355
+ device: torch.device | None = None,
1356
+ time_dim: int | None = None,
1357
+ auto_reset_env: bool = False,
1358
+ deactivate_vmap: bool = False,
1359
+ ):
1360
+ super().__init__(
1361
+ shifted=shifted,
1362
+ value_network=value_network,
1363
+ differentiable=differentiable,
1364
+ advantage_key=advantage_key,
1365
+ value_target_key=value_target_key,
1366
+ value_key=value_key,
1367
+ skip_existing=skip_existing,
1368
+ device=device,
1369
+ )
1370
+ self.register_buffer(
1371
+ "gamma",
1372
+ gamma.to(self._device)
1373
+ if isinstance(gamma, Tensor)
1374
+ else torch.tensor(gamma, device=self._device),
1375
+ )
1376
+ self.register_buffer(
1377
+ "lmbda",
1378
+ lmbda.to(self._device)
1379
+ if isinstance(lmbda, Tensor)
1380
+ else torch.tensor(lmbda, device=self._device),
1381
+ )
1382
+ self.average_gae = average_gae
1383
+ self.vectorized = vectorized
1384
+ self.time_dim = time_dim
1385
+ self.auto_reset_env = auto_reset_env
1386
+ self.deactivate_vmap = deactivate_vmap
1387
+
1388
+ @property
1389
+ def vectorized(self):
1390
+ if is_dynamo_compiling():
1391
+ return False
1392
+ return self._vectorized
1393
+
1394
+ @vectorized.setter
1395
+ def vectorized(self, value):
1396
+ self._vectorized = value
1397
+
1398
+ @_self_set_skip_existing
1399
+ @_self_set_grad_enabled
1400
+ @dispatch
1401
+ def forward(
1402
+ self,
1403
+ tensordict: TensorDictBase,
1404
+ *,
1405
+ params: list[Tensor] | None = None,
1406
+ target_params: list[Tensor] | None = None,
1407
+ time_dim: int | None = None,
1408
+ ) -> TensorDictBase:
1409
+ """Computes the GAE given the data in tensordict.
1410
+
1411
+ If a functional module is provided, a nested TensorDict containing the parameters
1412
+ (and if relevant the target parameters) can be passed to the module.
1413
+
1414
+ Args:
1415
+ tensordict (TensorDictBase): A TensorDict containing the data
1416
+ (an observation key, ``"action"``, ``("next", "reward")``,
1417
+ ``("next", "done")``, ``("next", "terminated")``,
1418
+ and ``"next"`` tensordict state as returned by the environment)
1419
+ necessary to compute the value estimates and the GAE.
1420
+ The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` is
1421
+ the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
1422
+ The tensordict must have shape ``[*B, T]``.
1423
+
1424
+ Keyword Args:
1425
+ params (TensorDictBase, optional): A nested TensorDict containing the params
1426
+ to be passed to the functional value network module.
1427
+ target_params (TensorDictBase, optional): A nested TensorDict containing the
1428
+ target params to be passed to the functional value network module.
1429
+ time_dim (int, optional): the dimension corresponding to the time
1430
+ in the input tensordict. If not provided, defaults to the dimension
1431
+ marked with the ``"time"`` name if any, and to the last dimension
1432
+ otherwise.
1433
+ Negative dimensions are considered with respect to the input
1434
+ tensordict.
1435
+
1436
+ Returns:
1437
+ An updated TensorDict with the advantage and value_target entries as defined in the constructor.
1438
+
1439
+ Examples:
1440
+ >>> from tensordict import TensorDict
1441
+ >>> value_net = TensorDictModule(
1442
+ ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
1443
+ ... )
1444
+ >>> module = GAE(
1445
+ ... gamma=0.98,
1446
+ ... lmbda=0.94,
1447
+ ... value_network=value_net,
1448
+ ... differentiable=False,
1449
+ ... )
1450
+ >>> obs, next_obs = torch.randn(2, 1, 10, 3)
1451
+ >>> reward = torch.randn(1, 10, 1)
1452
+ >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
1453
+ >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
1454
+ >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward, "terminated": terminated}}, [1, 10])
1455
+ >>> _ = module(tensordict)
1456
+ >>> assert "advantage" in tensordict.keys()
1457
+
1458
+ The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
1459
+
1460
+ Examples:
1461
+ >>> value_net = TensorDictModule(
1462
+ ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
1463
+ ... )
1464
+ >>> module = GAE(
1465
+ ... gamma=0.98,
1466
+ ... lmbda=0.94,
1467
+ ... value_network=value_net,
1468
+ ... differentiable=False,
1469
+ ... )
1470
+ >>> obs, next_obs = torch.randn(2, 1, 10, 3)
1471
+ >>> reward = torch.randn(1, 10, 1)
1472
+ >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
1473
+ >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
1474
+ >>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
1475
+
1476
+ """
1477
+ if tensordict.batch_dims < 1:
1478
+ raise RuntimeError(
1479
+ "Expected input tensordict to have at least one dimension, got "
1480
+ f"tensordict.batch_size = {tensordict.batch_size}"
1481
+ )
1482
+ reward = tensordict.get(("next", self.tensor_keys.reward))
1483
+ device = reward.device
1484
+ if self.gamma.device != device:
1485
+ self.gamma = self.gamma.to(device)
1486
+ gamma = self.gamma
1487
+ if self.lmbda.device != device:
1488
+ self.lmbda = self.lmbda.to(device)
1489
+ lmbda = self.lmbda
1490
+ steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
1491
+ if steps_to_next_obs is not None:
1492
+ gamma = gamma ** steps_to_next_obs.view_as(reward)
1493
+
1494
+ if self.value_network is not None:
1495
+ if params is not None:
1496
+ params = params.detach()
1497
+ if target_params is None:
1498
+ target_params = params.clone(False)
1499
+ with hold_out_net(self.value_network) if (
1500
+ params is None and target_params is None
1501
+ ) else nullcontext():
1502
+ # with torch.no_grad():
1503
+ # we may still need to pass gradient, but we don't want to assign grads to
1504
+ # value net params
1505
+ value, next_value = self._call_value_nets(
1506
+ data=tensordict,
1507
+ params=params,
1508
+ next_params=target_params,
1509
+ single_call=self.shifted,
1510
+ value_key=self.tensor_keys.value,
1511
+ detach_next=True,
1512
+ vmap_randomness=self.vmap_randomness,
1513
+ )
1514
+ else:
1515
+ value = tensordict.get(self.tensor_keys.value)
1516
+ next_value = tensordict.get(("next", self.tensor_keys.value))
1517
+
1518
+ if value is None:
1519
+ raise ValueError(
1520
+ f"The tensor with key {self.tensor_keys.value} is missing, and no value network was provided."
1521
+ )
1522
+ if next_value is None:
1523
+ raise ValueError(
1524
+ f"The tensor with key {('next', self.tensor_keys.value)} is missing, and no value network was provided."
1525
+ )
1526
+
1527
+ done = tensordict.get(("next", self.tensor_keys.done))
1528
+ terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)
1529
+ time_dim = self._get_time_dim(time_dim, tensordict)
1530
+
1531
+ if self.auto_reset_env:
1532
+ truncated = tensordict.get(("next", "truncated"))
1533
+ if truncated.any():
1534
+ reward += gamma * value * truncated
1535
+
1536
+ if self.vectorized:
1537
+ adv, value_target = vec_generalized_advantage_estimate(
1538
+ gamma,
1539
+ lmbda,
1540
+ value,
1541
+ next_value,
1542
+ reward,
1543
+ done=done,
1544
+ terminated=terminated if not self.auto_reset_env else done,
1545
+ time_dim=time_dim,
1546
+ )
1547
+ else:
1548
+ adv, value_target = generalized_advantage_estimate(
1549
+ gamma,
1550
+ lmbda,
1551
+ value,
1552
+ next_value,
1553
+ reward,
1554
+ done=done,
1555
+ terminated=terminated if not self.auto_reset_env else done,
1556
+ time_dim=time_dim,
1557
+ )
1558
+
1559
+ if self.average_gae:
1560
+ loc = adv.mean()
1561
+ scale = adv.std().clamp_min(1e-4)
1562
+ adv = adv - loc
1563
+ adv = adv / scale
1564
+
1565
+ tensordict.set(self.tensor_keys.advantage, adv)
1566
+ tensordict.set(self.tensor_keys.value_target, value_target)
1567
+
1568
+ return tensordict
1569
+
1570
+ def value_estimate(
1571
+ self,
1572
+ tensordict,
1573
+ params: TensorDictBase | None = None,
1574
+ target_params: TensorDictBase | None = None,
1575
+ time_dim: int | None = None,
1576
+ **kwargs,
1577
+ ):
1578
+ if tensordict.batch_dims < 1:
1579
+ raise RuntimeError(
1580
+ "Expected input tensordict to have at least one dimension, got "
1581
+ f"tensordict.batch_size = {tensordict.batch_size}"
1582
+ )
1583
+ reward = tensordict.get(("next", self.tensor_keys.reward))
1584
+ device = reward.device
1585
+ if self.gamma.device != device:
1586
+ self.gamma = self.gamma.to(device)
1587
+ gamma = self.gamma
1588
+ if self.lmbda.device != device:
1589
+ self.lmbda = self.lmbda.to(device)
1590
+ lmbda = self.lmbda
1591
+ steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
1592
+ if steps_to_next_obs is not None:
1593
+ gamma = gamma ** steps_to_next_obs.view_as(reward)
1594
+
1595
+ time_dim = self._get_time_dim(time_dim, tensordict)
1596
+
1597
+ if self.is_stateless and params is None:
1598
+ raise RuntimeError(
1599
+ "Expected params to be passed to advantage module but got none."
1600
+ )
1601
+ if self.value_network is not None:
1602
+ if params is not None:
1603
+ params = params.detach()
1604
+ if target_params is None:
1605
+ target_params = params.clone(False)
1606
+ with hold_out_net(self.value_network) if (
1607
+ params is None and target_params is None
1608
+ ) else nullcontext():
1609
+ # we may still need to pass gradient, but we don't want to assign grads to
1610
+ # value net params
1611
+ value, next_value = self._call_value_nets(
1612
+ data=tensordict,
1613
+ params=params,
1614
+ next_params=target_params,
1615
+ single_call=self.shifted,
1616
+ value_key=self.tensor_keys.value,
1617
+ detach_next=True,
1618
+ vmap_randomness=self.vmap_randomness,
1619
+ )
1620
+ else:
1621
+ value = tensordict.get(self.tensor_keys.value)
1622
+ next_value = tensordict.get(("next", self.tensor_keys.value))
1623
+ done = tensordict.get(("next", self.tensor_keys.done))
1624
+ terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)
1625
+ _, value_target = vec_generalized_advantage_estimate(
1626
+ gamma,
1627
+ lmbda,
1628
+ value,
1629
+ next_value,
1630
+ reward,
1631
+ done=done,
1632
+ terminated=terminated,
1633
+ time_dim=time_dim,
1634
+ )
1635
+ return value_target
1636
+
1637
+
1638
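The GAE class above ultimately reduces to the recursion from Schulman et al. (2015). A minimal sketch, assuming a `[*B, T, 1]` layout with the time dimension at `-2`; the helper name `gae_reference` is illustrative and omits the `time_dim`, `shifted` and `auto_reset_env` handling of the actual module.

import torch

def gae_reference(gamma, lmbda, value, next_value, reward, done, terminated):
    # delta_t = r_t + gamma * (1 - terminated_t) * V(s_{t+1}) - V(s_t)
    # A_t     = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}
    # value_target_t = A_t + V(s_t)
    not_terminated = (~terminated).to(reward.dtype)
    not_done = (~done).to(reward.dtype)
    delta = reward + gamma * not_terminated * next_value - value
    advantage = torch.zeros_like(reward)
    running = torch.zeros_like(reward[..., -1, :])
    for t in reversed(range(reward.shape[-2])):
        running = delta[..., t, :] + gamma * lmbda * not_done[..., t, :] * running
        advantage[..., t, :] = running
    return advantage, advantage + value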
+ class VTrace(ValueEstimatorBase):
1639
+ """A class wrapper around V-Trace estimate functional.
1640
+
1641
+ Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures"
1642
+ `here <https://arxiv.org/abs/1802.01561>`_ for more context.
1643
+
1644
+ Keyword Args:
1645
+ gamma (scalar): exponential mean discount.
1646
+ value_network (TensorDictModule): value operator used to retrieve the value estimates.
1647
+ actor_network (TensorDictModule): actor operator used to retrieve the log prob.
1648
+ rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights.
1649
+ Defaults to ``1.0``.
1650
+ c_thresh (Union[float, Tensor]): c clipping parameter for importance weights.
1651
+ Defaults to ``1.0``.
1652
+ average_adv (bool): if ``True``, the resulting advantage values will be standardized.
1653
+ Default is ``False``.
1654
+ differentiable (bool, optional): if ``True``, gradients are propagated through
1655
+ the computation of the value function. Default is ``False``.
1656
+
1657
+ .. note::
1658
+ The proper way to make the function call non-differentiable is to
1659
+ wrap it in a `torch.no_grad()` context manager/decorator or
1660
+ pass detached parameters for functional modules.
1661
+ skip_existing (bool, optional): if ``True``, the value network will skip
1662
+ modules whose outputs are already present in the tensordict.
1663
+ Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()`
1664
+ is not affected.
1665
+ advantage_key (str or tuple of str, optional): [Deprecated] the key of
1667
+ the advantage entry. Defaults to ``"advantage"``.
1668
+ value_target_key (str or tuple of str, optional): [Deprecated] the key
1669
+ of the value target entry. Defaults to ``"value_target"``.
1670
+ value_key (str or tuple of str, optional): [Deprecated] the value key to
1671
+ read from the input tensordict. Defaults to ``"state_value"``.
1672
+ shifted (bool, optional): if ``True``, the value and next value are
1673
+ estimated with a single call to the value network. This is faster
1674
+ but is only valid whenever (1) the ``"next"`` value is shifted by
1675
+ only one time step (which is not the case with multi-step value
1676
+ estimation, for instance) and (2) when the parameters used at time
1677
+ ``t`` and ``t+1`` are identical (which is not the case when target
1678
+ parameters are to be used). Defaults to ``False``.
1679
+ device (torch.device, optional): the device where the buffers will be instantiated.
1680
+ Defaults to ``torch.get_default_device()``.
1681
+ time_dim (int, optional): the dimension corresponding to the time
1682
+ in the input tensordict. If not provided, defaults to the dimension
1683
+ marked with the ``"time"`` name if any, and to the last dimension
1684
+ otherwise. Can be overridden during a call to
1685
+ :meth:`~.value_estimate`.
1686
+ Negative dimensions are considered with respect to the input
1687
+ tensordict.
1688
+
1689
+ VTrace will return an :obj:`"advantage"` entry containing the advantage value. It will also
1690
+ return a :obj:`"value_target"` entry with the V-Trace target value.
1691
+
1692
+ .. note::
1693
+ As other advantage functions do, if the ``value_key`` is already present
1694
+ in the input tensordict, the VTrace module will ignore the calls to the value
1695
+ network (if any) and use the provided value instead.
1696
+
1697
+ """
1698
+
1699
+ def __init__(
1700
+ self,
1701
+ *,
1702
+ gamma: float | torch.Tensor,
1703
+ actor_network: TensorDictModule,
1704
+ value_network: TensorDictModule,
1705
+ rho_thresh: float | torch.Tensor = 1.0,
1706
+ c_thresh: float | torch.Tensor = 1.0,
1707
+ average_adv: bool = False,
1708
+ differentiable: bool = False,
1709
+ skip_existing: bool | None = None,
1710
+ advantage_key: NestedKey | None = None,
1711
+ value_target_key: NestedKey | None = None,
1712
+ value_key: NestedKey | None = None,
1713
+ shifted: bool = False,
1714
+ device: torch.device | None = None,
1715
+ time_dim: int | None = None,
1716
+ ):
1717
+ super().__init__(
1718
+ shifted=shifted,
1719
+ value_network=value_network,
1720
+ differentiable=differentiable,
1721
+ advantage_key=advantage_key,
1722
+ value_target_key=value_target_key,
1723
+ value_key=value_key,
1724
+ skip_existing=skip_existing,
1725
+ device=device,
1726
+ )
1727
+ if not isinstance(gamma, torch.Tensor):
1728
+ gamma = torch.tensor(gamma, device=self._device)
1729
+ if not isinstance(rho_thresh, torch.Tensor):
1730
+ rho_thresh = torch.tensor(rho_thresh, device=self._device)
1731
+ if not isinstance(c_thresh, torch.Tensor):
1732
+ c_thresh = torch.tensor(c_thresh, device=self._device)
1733
+
1734
+ self.register_buffer("gamma", gamma)
1735
+ self.register_buffer("rho_thresh", rho_thresh)
1736
+ self.register_buffer("c_thresh", c_thresh)
1737
+ self.average_adv = average_adv
1738
+ self.actor_network = actor_network
1739
+ self.time_dim = time_dim
1740
+
1741
+ if isinstance(gamma, torch.Tensor) and gamma.shape != ():
1742
+ raise NotImplementedError(
1743
+ "Per-value gamma is not supported yet. Gamma must be a scalar."
1744
+ )
1745
+
1746
+ @property
1747
+ def in_keys(self):
1748
+ parent_in_keys = super().in_keys
1749
+ extended_in_keys = parent_in_keys + [self.tensor_keys.sample_log_prob]
1750
+ return extended_in_keys
1751
+
1752
+ @_self_set_skip_existing
1753
+ @_self_set_grad_enabled
1754
+ @dispatch
1755
+ def forward(
1756
+ self,
1757
+ tensordict: TensorDictBase,
1758
+ *,
1759
+ params: list[Tensor] | None = None,
1760
+ target_params: list[Tensor] | None = None,
1761
+ time_dim: int | None = None,
1762
+ ) -> TensorDictBase:
1763
+ """Computes the V-Trace correction given the data in tensordict.
1764
+
1765
+ If a functional module is provided, a nested TensorDict containing the parameters
1766
+ (and if relevant the target parameters) can be passed to the module.
1767
+
1768
+ Args:
1769
+ tensordict (TensorDictBase): A TensorDict containing the data
1770
+ (an observation key, "action", "reward", "done" and "next" tensordict state
1771
+ as returned by the environment) necessary to compute the value estimates and the V-Trace estimate.
1772
+ The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` is
1773
+ the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
1774
+
1775
+ Keyword Args:
1776
+ params (TensorDictBase, optional): A nested TensorDict containing the params
1777
+ to be passed to the functional value network module.
1778
+ target_params (TensorDictBase, optional): A nested TensorDict containing the
1779
+ target params to be passed to the functional value network module.
1780
+ time_dim (int, optional): the dimension corresponding to the time
1781
+ in the input tensordict. If not provided, defaults to the dimension
1782
+ marked with the ``"time"`` name if any, and to the last dimension
1783
+ otherwise.
1784
+ Negative dimensions are considered with respect to the input
1785
+ tensordict.
1786
+
1787
+ Returns:
1788
+ An updated TensorDict with the advantage and value_target entries as defined in the constructor.
1789
+
1790
+ Examples:
1791
+ >>> value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"])
1792
+ >>> actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"])
1793
+ >>> actor_net = ProbabilisticActor(
1794
+ ... module=actor_net,
1795
+ ... in_keys=["logits"],
1796
+ ... out_keys=["action"],
1797
+ ... distribution_class=OneHotCategorical,
1798
+ ... return_log_prob=True,
1799
+ ... )
1800
+ >>> module = VTrace(
1801
+ ... gamma=0.98,
1802
+ ... value_network=value_net,
1803
+ ... actor_network=actor_net,
1804
+ ... differentiable=False,
1805
+ ... )
1806
+ >>> obs, next_obs = torch.randn(2, 1, 10, 3)
1807
+ >>> reward = torch.randn(1, 10, 1)
1808
+ >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
1809
+ >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
1810
+ >>> sample_log_prob = torch.randn(1, 10, 1)
1811
+ >>> tensordict = TensorDict({
1812
+ ... "obs": obs,
1813
+ ... "done": done,
1814
+ ... "terminated": terminated,
1815
+ ... "sample_log_prob": sample_log_prob,
1816
+ ... "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated},
1817
+ ... }, batch_size=[1, 10])
1818
+ >>> _ = module(tensordict)
1819
+ >>> assert "advantage" in tensordict.keys()
1820
+
1821
+ The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
1822
+
1823
+ Examples:
1824
+ >>> value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"])
1825
+ >>> actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"])
1826
+ >>> actor_net = ProbabilisticActor(
1827
+ ... module=actor_net,
1828
+ ... in_keys=["logits"],
1829
+ ... out_keys=["action"],
1830
+ ... distribution_class=OneHotCategorical,
1831
+ ... return_log_prob=True,
1832
+ ... )
1833
+ >>> module = VTrace(
1834
+ ... gamma=0.98,
1835
+ ... value_network=value_net,
1836
+ ... actor_network=actor_net,
1837
+ ... differentiable=False,
1838
+ ... )
1839
+ >>> obs, next_obs = torch.randn(2, 1, 10, 3)
1840
+ >>> reward = torch.randn(1, 10, 1)
1841
+ >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
1842
+ >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
1843
+ >>> sample_log_prob = torch.randn(1, 10, 1)
1844
+ >>> tensordict = TensorDict({
1845
+ ... "obs": obs,
1846
+ ... "done": done,
1847
+ ... "terminated": terminated,
1848
+ ... "sample_log_prob": sample_log_prob,
1849
+ ... "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated},
1850
+ ... }, batch_size=[1, 10])
1851
+ >>> advantage, value_target = module(
1852
+ ... obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated, sample_log_prob=sample_log_prob
1853
+ ... )
1854
+
1855
+ """
1856
+ if tensordict.batch_dims < 1:
1857
+ raise RuntimeError(
1858
+ "Expected input tensordict to have at least one dimension, got "
1859
+ f"tensordict.batch_size = {tensordict.batch_size}"
1860
+ )
1861
+ reward = tensordict.get(("next", self.tensor_keys.reward))
1862
+ device = reward.device
1863
+
1864
+ if self.gamma.device != device:
1865
+ self.gamma = self.gamma.to(device)
1866
+ gamma = self.gamma
1867
+ steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
1868
+ if steps_to_next_obs is not None:
1869
+ gamma = gamma ** steps_to_next_obs.view_as(reward)
1870
+
1871
+ # Make sure we have the value and next value
1872
+ if self.value_network is not None:
1873
+ if params is not None:
1874
+ params = params.detach()
1875
+ if target_params is None:
1876
+ target_params = params.clone(False)
1877
+ with hold_out_net(self.value_network):
1878
+ # we may still need to pass gradient, but we don't want to assign grads to
1879
+ # value net params
1880
+ value, next_value = self._call_value_nets(
1881
+ data=tensordict,
1882
+ params=params,
1883
+ next_params=target_params,
1884
+ single_call=self.shifted,
1885
+ value_key=self.tensor_keys.value,
1886
+ detach_next=True,
1887
+ vmap_randomness=self.vmap_randomness,
1888
+ )
1889
+ else:
1890
+ value = tensordict.get(self.tensor_keys.value)
1891
+ next_value = tensordict.get(("next", self.tensor_keys.value))
1892
+
1893
+ lp = _maybe_get_or_select(tensordict, self.tensor_keys.sample_log_prob)
1894
+ if is_tensor_collection(lp):
1895
+ # Sum all values to match the batch size
1896
+ lp = lp.sum(dim="feature", reduce=True)
1897
+ log_mu = lp.view_as(value)
1898
+
1899
+ # Compute log prob with current policy
1900
+ with hold_out_net(self.actor_network):
1901
+ log_pi = _call_actor_net(
1902
+ actor_net=self.actor_network,
1903
+ data=tensordict,
1904
+ params=None,
1905
+ log_prob_key=self.tensor_keys.sample_log_prob,
1906
+ )
1907
+ log_pi = log_pi.view_as(value)
1908
+
1909
+ # Compute the V-Trace correction
1910
+ done = tensordict.get(("next", self.tensor_keys.done))
1911
+ terminated = tensordict.get(("next", self.tensor_keys.terminated))
1912
+
1913
+ time_dim = self._get_time_dim(time_dim, tensordict)
1914
+ adv, value_target = vtrace_advantage_estimate(
1915
+ gamma,
1916
+ log_pi,
1917
+ log_mu,
1918
+ value,
1919
+ next_value,
1920
+ reward,
1921
+ done,
1922
+ terminated,
1923
+ rho_thresh=self.rho_thresh,
1924
+ c_thresh=self.c_thresh,
1925
+ time_dim=time_dim,
1926
+ )
1927
+
1928
+ if self.average_adv:
1929
+ loc = adv.mean()
1930
+ scale = adv.std().clamp_min(1e-5)
1931
+ adv = adv - loc
1932
+ adv = adv / scale
1933
+
1934
+ tensordict.set(self.tensor_keys.advantage, adv)
1935
+ tensordict.set(self.tensor_keys.value_target, value_target)
1936
+
1937
+ return tensordict
1938
+
1939
+
1940
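For the VTrace class above, a minimal sketch of the V-trace recursion from the IMPALA paper, assuming a `[*B, T, 1]` layout with the time dimension at `-2`. The helper name `vtrace_reference` is illustrative, and the boundary handling (masking by `done`/`terminated`) is a simplification of what `vtrace_advantage_estimate` does internally.

import torch

def vtrace_reference(gamma, log_pi, log_mu, value, next_value, reward,
                     done, terminated, rho_thresh=1.0, c_thresh=1.0):
    # rho_t = min(rho_thresh, pi/mu), c_t = min(c_thresh, pi/mu)
    # delta_t = rho_t * (r_t + gamma * V(s_{t+1}) - V(s_t))
    # vs_t    = V(s_t) + delta_t + gamma * c_t * (vs_{t+1} - V(s_{t+1}))
    # A_t     = rho_t * (r_t + gamma * vs_{t+1} - V(s_t))
    not_terminated = (~terminated).to(reward.dtype)
    not_done = (~done).to(reward.dtype)
    ratio = (log_pi - log_mu).exp()
    rho = ratio.clamp_max(rho_thresh)
    c = ratio.clamp_max(c_thresh)
    delta = rho * (reward + gamma * not_terminated * next_value - value)
    vs_minus_v = torch.zeros_like(reward)
    running = torch.zeros_like(reward[..., -1, :])
    for t in reversed(range(reward.shape[-2])):
        running = delta[..., t, :] + gamma * c[..., t, :] * not_done[..., t, :] * running
        vs_minus_v[..., t, :] = running
    vs = vs_minus_v + value
    # shift vs by one step; at trajectory boundaries fall back to the recorded next value
    vs_next = torch.cat([vs[..., 1:, :], next_value[..., -1:, :]], dim=-2)
    vs_next = torch.where(done, next_value, vs_next)
    advantage = rho * (reward + gamma * not_terminated * vs_next - value)
    return advantage, vs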
+ def _deprecate_class(cls, new_cls):
+ # capture the original __init__ before overriding it; calling cls.__init__
+ # from inside the wrapper would recurse into the wrapper itself
+ old_init = cls.__init__
+
+ @wraps(old_init)
+ def new_init(self, *args, **kwargs):
+ warnings.warn(f"class {cls} is deprecated, please use {new_cls} instead.")
+ old_init(self, *args, **kwargs)
+
+ cls.__init__ = new_init
1947
+
1948
+
1949
+ TD0Estimate = type("TD0Estimate", TD0Estimator.__bases__, dict(TD0Estimator.__dict__))
1950
+ _deprecate_class(TD0Estimate, TD0Estimator)
1951
+ TD1Estimate = type("TD1Estimate", TD1Estimator.__bases__, dict(TD1Estimator.__dict__))
1952
+ _deprecate_class(TD1Estimate, TD1Estimator)
1953
+ TDLambdaEstimate = type(
1954
+ "TDLambdaEstimate", TDLambdaEstimator.__bases__, dict(TDLambdaEstimator.__dict__)
1955
+ )
1956
+ _deprecate_class(TDLambdaEstimate, TDLambdaEstimator)