torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/objectives/value/functional.py
@@ -0,0 +1,1459 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import math
8
+ import warnings
9
+ from functools import wraps
10
+
11
+ import torch
12
+
13
+ try:
14
+ from torch.compiler import is_dynamo_compiling
15
+ except ImportError:
16
+ from torch._dynamo import is_compiling as is_dynamo_compiling
17
+
18
+ __all__ = [
19
+ "generalized_advantage_estimate",
20
+ "vec_generalized_advantage_estimate",
21
+ "td0_advantage_estimate",
22
+ "td0_return_estimate",
23
+ "td1_return_estimate",
24
+ "vec_td1_return_estimate",
25
+ "td1_advantage_estimate",
26
+ "vec_td1_advantage_estimate",
27
+ "td_lambda_return_estimate",
28
+ "vec_td_lambda_return_estimate",
29
+ "td_lambda_advantage_estimate",
30
+ "vec_td_lambda_advantage_estimate",
31
+ "vtrace_advantage_estimate",
32
+ ]
33
+
34
+ from torchrl.objectives.value.utils import (
35
+ _custom_conv1d,
36
+ _get_num_per_traj,
37
+ _inv_pad_sequence,
38
+ _make_gammas_tensor,
39
+ _split_and_pad_sequence,
40
+ )
41
+
42
+ SHAPE_ERR = (
43
+ "All input tensors (value, reward and done states) must share a unique shape."
44
+ )
45
+
46
+
47
+ def _transpose_time(fun):
48
+ """Checks the time_dim argument of the function to allow for any dim.
49
+
50
+ If ``time_dim`` is not -2, transposes all multi-dim input tensors so that the
51
+ time dimension sits at -2, and applies the inverse transpose to the outputs.
52
+ """
53
+ ERROR = (
54
+ "The tensor shape and the time dimension are not compatible: "
55
+ "got {} and time_dim={}."
56
+ )
57
+
58
+ @wraps(fun)
59
+ def transposed_fun(*args, **kwargs):
60
+ time_dim = kwargs.pop("time_dim", -2)
61
+
62
+ def transpose_tensor(tensor):
63
+ if not isinstance(tensor, torch.Tensor) or tensor.numel() <= 1:
64
+ return tensor, False
65
+ if time_dim >= 0:
66
+ timedim = time_dim - tensor.ndim
67
+ else:
68
+ timedim = time_dim
69
+ if timedim < -tensor.ndim or timedim >= 0:
70
+ raise RuntimeError(ERROR.format(tensor.shape, timedim))
71
+ if tensor.ndim >= 2:
72
+ single_dim = False
73
+ tensor = tensor.transpose(timedim, -2)
74
+ elif tensor.ndim == 1 and timedim == -1:
75
+ single_dim = True
76
+ tensor = tensor.unsqueeze(-1)
77
+ else:
78
+ raise RuntimeError(ERROR.format(tensor.shape, timedim))
79
+ return tensor, single_dim
80
+
81
+ if time_dim != -2:
82
+ single_dim = False
83
+ if args:
84
+ args, single_dim = zip(*(transpose_tensor(arg) for arg in args))
85
+ single_dim = any(single_dim)
86
+ for k, item in list(kwargs.items()):
87
+ item, sd = transpose_tensor(item)
88
+ single_dim = single_dim or sd
89
+ kwargs[k] = item
90
+ # We don't pass time_dim because it isn't supposed to be used thereafter
91
+ out = fun(*args, **kwargs)
92
+ if isinstance(out, torch.Tensor):
93
+ out = transpose_tensor(out)[0]
94
+ if single_dim:
95
+ out = out.squeeze(-2)
96
+ return out
97
+ if single_dim:
98
+ return tuple(transpose_tensor(_out)[0].squeeze(-2) for _out in out)
99
+ return tuple(transpose_tensor(_out)[0] for _out in out)
100
+ # We don't pass time_dim because it isn't supposed to be used thereafter
101
+ out = fun(*args, **kwargs)
102
+ if isinstance(out, tuple):
103
+ for _out in out:
104
+ if _out.ndim < 2:
105
+ raise RuntimeError(ERROR.format(_out.shape, time_dim))
106
+ else:
107
+ if out.ndim < 2:
108
+ raise RuntimeError(ERROR.format(out.shape, time_dim))
109
+ return out
110
+
111
+ return transposed_fun
112
+
113
+
114
+ ########################################################################
115
+ # GAE
116
+ # ---
117
+
118
+
119
+ @_transpose_time
120
+ def generalized_advantage_estimate(
121
+ gamma: float,
122
+ lmbda: float,
123
+ state_value: torch.Tensor,
124
+ next_state_value: torch.Tensor,
125
+ reward: torch.Tensor,
126
+ done: torch.Tensor,
127
+ terminated: torch.Tensor | None = None,
128
+ *,
129
+ time_dim: int = -2,
130
+ ) -> tuple[torch.Tensor, torch.Tensor]:
131
+ """Generalized advantage estimate of a trajectory.
132
+
133
+ Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION"
134
+ https://arxiv.org/pdf/1506.02438.pdf for more context.
135
+
136
+ Args:
137
+ gamma (scalar): exponential mean discount.
138
+ lmbda (scalar): trajectory discount.
139
+ state_value (Tensor): value function result with old_state input.
140
+ next_state_value (Tensor): value function result with new_state input.
141
+ reward (Tensor): reward of taking actions in the environment.
142
+ done (Tensor): boolean flag for end of trajectory.
143
+ terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
144
+ if not provided.
145
+ time_dim (int): dimension where the time is unrolled. Defaults to -2.
146
+
147
+ All tensors (values, reward and done) must have shape
148
+ ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
149
+
150
+ """
151
+ if terminated is None:
152
+ terminated = done.clone()
153
+ if not (
154
+ next_state_value.shape
155
+ == state_value.shape
156
+ == reward.shape
157
+ == done.shape
158
+ == terminated.shape
159
+ ):
160
+ raise RuntimeError(SHAPE_ERR)
161
+ dtype = next_state_value.dtype
162
+ device = state_value.device
163
+ not_done = (~done).int()
164
+ not_terminated = (~terminated).int()
165
+ *batch_size, time_steps, lastdim = not_done.shape
166
+ advantage = torch.empty(
167
+ *batch_size, time_steps, lastdim, device=device, dtype=dtype
168
+ )
169
+ prev_advantage = 0
170
+ g_not_terminated = gamma * not_terminated
171
+ delta = reward + (g_not_terminated * next_state_value) - state_value
172
+ discount = lmbda * gamma * not_done
173
+ for t in reversed(range(time_steps)):
174
+ prev_advantage = advantage[..., t, :] = delta[..., t, :] + (
175
+ prev_advantage * discount[..., t, :]
176
+ )
177
+
178
+ value_target = advantage + state_value
179
+
180
+ return advantage, value_target
181
+
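A minimal usage sketch of the function above, based only on its signature and the ``[*Batch x TimeSteps x *F]`` shape convention from the docstring (illustrative random data, not an official example); the second call shows the ``time_dim`` handling provided by the ``_transpose_time`` decorator:

>>> import torch
>>> from torchrl.objectives.value.functional import generalized_advantage_estimate
>>> B, T = 4, 10
>>> state_value = torch.randn(B, T, 1)
>>> next_state_value = torch.randn(B, T, 1)
>>> reward = torch.randn(B, T, 1)
>>> done = torch.zeros(B, T, 1, dtype=torch.bool)
>>> done[:, -1] = True  # mark the last step of each trajectory as done
>>> advantage, value_target = generalized_advantage_estimate(
...     0.99, 0.95, state_value, next_state_value, reward, done
... )
>>> advantage.shape
torch.Size([4, 10, 1])
>>> # same computation with time-major inputs, letting time_dim handle the layout
>>> adv_t, _ = generalized_advantage_estimate(
...     0.99, 0.95, state_value.transpose(0, 1), next_state_value.transpose(0, 1),
...     reward.transpose(0, 1), done.transpose(0, 1), time_dim=0
... )
>>> torch.testing.assert_close(adv_t.transpose(0, 1), advantage)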
182
+
183
+ def _geom_series_like(t, r, thr):
184
+ """Creates a geometric series of the form [1, gammalmbda, gammalmbda**2] with the shape of `t`.
185
+
186
+ Drops all elements which are smaller than `thr` (unless in compile mode).
187
+ """
188
+ if is_dynamo_compiling():
189
+ if isinstance(r, torch.Tensor):
190
+ rs = r.expand_as(t)
191
+ else:
192
+ rs = torch.full_like(t, r)
193
+ else:
194
+ if isinstance(r, torch.Tensor):
195
+ r = r.item()
196
+
197
+ if r == 0.0:
198
+ return torch.zeros_like(t)
199
+ elif r >= 1.0:
200
+ lim = t.numel()
201
+ else:
202
+ lim = int(math.log(thr) / math.log(r))
203
+
204
+ rs = torch.full_like(t[:lim], r)
205
+ rs[0] = 1.0
206
+ rs = rs.cumprod(0)
207
+ rs = rs.unsqueeze(-1)
208
+ return rs
209
+
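A small illustration of the truncation behaviour described above; ``_geom_series_like`` is a private helper, so calling it directly is for illustration only:

>>> import torch
>>> from torchrl.objectives.value.functional import _geom_series_like
>>> t = torch.zeros(100)  # only the shape/dtype of the template tensor matter
>>> rs = _geom_series_like(t, 0.9, thr=1e-3)
>>> rs.shape  # truncated roughly where 0.9 ** k reaches thr
torch.Size([65, 1])
>>> rs[:3, 0]
tensor([1.0000, 0.9000, 0.8100])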
210
+
211
+ def _fast_vec_gae(
212
+ reward: torch.Tensor,
213
+ state_value: torch.Tensor,
214
+ next_state_value: torch.Tensor,
215
+ done: torch.Tensor,
216
+ terminated: torch.Tensor,
217
+ gamma: float,
218
+ lmbda: float,
219
+ thr: float = 1e-7,
220
+ ):
221
+ """Fast vectorized Generalized Advantage Estimate when gamma and lmbda are scalars.
222
+
223
+ In contrast to `vec_generalized_advantage_estimate` this function does not need
224
+ to allocate a big tensor of the form [B, T, T].
225
+
226
+ Args:
227
+ reward (torch.Tensor): a [*B, T, F] tensor containing rewards
228
+ state_value (torch.Tensor): a [*B, T, F] tensor containing state values (value function)
229
+ next_state_value (torch.Tensor): a [*B, T, F] tensor containing next state values (value function)
230
+ done (torch.Tensor): a [B, T] boolean tensor containing the done states.
231
+ terminated (torch.Tensor): a [B, T] boolean tensor containing the terminated states.
232
+ gamma (scalar): the gamma decay (exponential mean discount)
233
+ lmbda (scalar): the lambda decay (trajectory discount)
234
+ thr (:obj:`float`): threshold for the filter. Below this limit, components will be ignored.
235
+ Defaults to 1e-7.
236
+
237
+ All tensors (values, reward and done) must have shape
238
+ ``[*Batch x TimeSteps x F]``, with ``F`` feature dimensions.
239
+
240
+ """
241
+ # _get_num_per_traj and _split_and_pad_sequence need
242
+ # time dimension at last position
243
+ done = done.transpose(-2, -1)
244
+ terminated = terminated.transpose(-2, -1)
245
+ reward = reward.transpose(-2, -1)
246
+ state_value = state_value.transpose(-2, -1)
247
+ next_state_value = next_state_value.transpose(-2, -1)
248
+
249
+ gammalmbda = gamma * lmbda
250
+ not_terminated = (~terminated).int()
251
+ td0 = reward + not_terminated * gamma * next_state_value - state_value
252
+
253
+ num_per_traj = _get_num_per_traj(done)
254
+ td0_flat, mask = _split_and_pad_sequence(td0, num_per_traj, return_mask=True)
255
+
256
+ gammalmbdas = _geom_series_like(td0_flat[0], gammalmbda, thr=thr)
257
+
258
+ advantage = _custom_conv1d(td0_flat.unsqueeze(1), gammalmbdas)
259
+ advantage = advantage.squeeze(1)
260
+ advantage = advantage[mask].view_as(reward)
261
+
262
+ value_target = advantage + state_value
263
+
264
+ advantage = advantage.transpose(-1, -2)
265
+ value_target = value_target.transpose(-1, -2)
266
+
267
+ return advantage, value_target
268
+
269
+
270
+ @_transpose_time
271
+ def vec_generalized_advantage_estimate(
272
+ gamma: float | torch.Tensor,
273
+ lmbda: float | torch.Tensor,
274
+ state_value: torch.Tensor,
275
+ next_state_value: torch.Tensor,
276
+ reward: torch.Tensor,
277
+ done: torch.Tensor,
278
+ terminated: torch.Tensor | None = None,
279
+ *,
280
+ time_dim: int = -2,
281
+ ) -> tuple[torch.Tensor, torch.Tensor]:
282
+ """Vectorized Generalized advantage estimate of a trajectory.
283
+
284
+ Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION"
285
+ https://arxiv.org/pdf/1506.02438.pdf for more context.
286
+
287
+ Args:
288
+ gamma (scalar): exponential mean discount.
289
+ lmbda (scalar): trajectory discount.
290
+ state_value (Tensor): value function result with old_state input.
291
+ next_state_value (Tensor): value function result with new_state input.
292
+ reward (Tensor): reward of taking actions in the environment.
293
+ done (Tensor): boolean flag for end of trajectory.
294
+ terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
295
+ if not provided.
296
+ time_dim (int): dimension where the time is unrolled. Defaults to -2.
297
+
298
+ All tensors (values, reward and done) must have shape
299
+ ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
300
+
301
+ """
302
+ if terminated is None:
303
+ terminated = done.clone()
304
+ if not (
305
+ next_state_value.shape
306
+ == state_value.shape
307
+ == reward.shape
308
+ == done.shape
309
+ == terminated.shape
310
+ ):
311
+ raise RuntimeError(SHAPE_ERR)
312
+ dtype = state_value.dtype
313
+ *batch_size, time_steps, lastdim = terminated.shape
314
+
315
+ value = gamma * lmbda
316
+
317
+ if isinstance(value, torch.Tensor) and value.numel() > 1:
318
+ # create tensor while ensuring that gradients are passed
319
+ not_done = (~done).to(dtype)
320
+ gammalmbdas = not_done * value
321
+ else:
322
+ # when gamma and lmbda are scalars, use fast_vec_gae implementation
323
+ return _fast_vec_gae(
324
+ reward=reward,
325
+ state_value=state_value,
326
+ next_state_value=next_state_value,
327
+ done=done,
328
+ terminated=terminated,
329
+ gamma=gamma,
330
+ lmbda=lmbda,
331
+ )
332
+
333
+ gammalmbdas = _make_gammas_tensor(gammalmbdas, time_steps, True)
334
+ gammalmbdas = gammalmbdas.cumprod(-2)
335
+
336
+ # Skip data-dependent truncation optimization during compile (causes guards)
337
+ if not is_dynamo_compiling():
338
+ first_below_thr = gammalmbdas < 1e-7
339
+ # if we have multiple gammas, we only want to truncate if _all_ of
340
+ # the geometric sequences fall below the threshold
341
+ first_below_thr = first_below_thr.flatten(0, 1).all(0).all(-1)
342
+ if first_below_thr.any():
343
+ first_below_thr = torch.where(first_below_thr)[0][0].item()
344
+ gammalmbdas = gammalmbdas[..., :first_below_thr, :]
345
+
346
+ not_terminated = (~terminated).to(dtype)
347
+ td0 = reward + not_terminated * gamma * next_state_value - state_value
348
+
349
+ if len(batch_size) > 1:
350
+ td0 = td0.flatten(0, len(batch_size) - 1)
351
+ elif not len(batch_size):
352
+ td0 = td0.unsqueeze(0)
353
+
354
+ td0_r = td0.transpose(-2, -1)
355
+ shapes = td0_r.shape[:2]
356
+ if lastdim != 1:
357
+ # flatten the leading dims again and insert a singleton dim in between
358
+ td0_r = td0_r.flatten(0, 1).unsqueeze(1)
359
+ advantage = _custom_conv1d(td0_r, gammalmbdas)
360
+ if lastdim != 1:
361
+ advantage = advantage.squeeze(1).unflatten(0, shapes)
362
+
363
+ if len(batch_size) > 1:
364
+ advantage = advantage.unflatten(0, batch_size)
365
+ elif not len(batch_size):
366
+ advantage = advantage.squeeze(0)
367
+
368
+ advantage = advantage.transpose(-2, -1)
369
+ value_target = advantage + state_value
370
+ return advantage, value_target
371
+
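As a sanity-check sketch (not an official test): with scalar ``gamma``/``lmbda`` this vectorized version dispatches to ``_fast_vec_gae`` and should agree with the loop-based ``generalized_advantage_estimate`` up to floating-point error:

>>> import torch
>>> from torchrl.objectives.value.functional import (
...     generalized_advantage_estimate,
...     vec_generalized_advantage_estimate,
... )
>>> B, T = 2, 50
>>> val, next_val, reward = (torch.randn(B, T, 1) for _ in range(3))
>>> done = torch.zeros(B, T, 1, dtype=torch.bool)
>>> done[:, -1] = True
>>> adv_loop, tgt_loop = generalized_advantage_estimate(
...     0.99, 0.95, val, next_val, reward, done
... )
>>> adv_vec, tgt_vec = vec_generalized_advantage_estimate(
...     0.99, 0.95, val, next_val, reward, done
... )
>>> torch.testing.assert_close(adv_vec, adv_loop, atol=1e-4, rtol=1e-4)
>>> torch.testing.assert_close(tgt_vec, tgt_loop, atol=1e-4, rtol=1e-4)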
372
+
373
+ ########################################################################
374
+ # TD(0)
375
+ # -----
376
+
377
+
378
+ def td0_advantage_estimate(
379
+ gamma: float,
380
+ state_value: torch.Tensor,
381
+ next_state_value: torch.Tensor,
382
+ reward: torch.Tensor,
383
+ done: torch.Tensor,
384
+ terminated: torch.Tensor | None = None,
385
+ ) -> torch.Tensor:
386
+ """TD(0) advantage estimate of a trajectory.
387
+
388
+ Also known as bootstrapped Temporal Difference or one-step return.
389
+
390
+ Args:
391
+ gamma (scalar): exponential mean discount.
392
+ state_value (Tensor): value function result with old_state input.
393
+ next_state_value (Tensor): value function result with new_state input.
394
+ reward (Tensor): reward of taking actions in the environment.
395
+ done (Tensor): boolean flag for end of trajectory.
396
+ terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
397
+ if not provided.
398
+
399
+ All tensors (values, reward and done) must have shape
400
+ ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
401
+
402
+ """
403
+ if terminated is None:
404
+ terminated = done.clone()
405
+ if not (
406
+ next_state_value.shape
407
+ == state_value.shape
408
+ == reward.shape
409
+ == done.shape
410
+ == terminated.shape
411
+ ):
412
+ raise RuntimeError(SHAPE_ERR)
413
+ returns = td0_return_estimate(gamma, next_state_value, reward, terminated)
414
+ advantage = returns - state_value
415
+ return advantage
416
+
417
+
418
+ def td0_return_estimate(
419
+ gamma: float,
420
+ next_state_value: torch.Tensor,
421
+ reward: torch.Tensor,
422
+ terminated: torch.Tensor | None = None,
423
+ *,
424
+ done: torch.Tensor | None = None,
425
+ ) -> torch.Tensor:
426
+ # noqa: D417
427
+ """TD(0) discounted return estimate of a trajectory.
428
+
429
+ Also known as bootstrapped Temporal Difference or one-step return.
430
+
431
+ Args:
432
+ gamma (scalar): exponential mean discount.
433
+ next_state_value (Tensor): value function result with new_state input.
434
+ must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor
435
+ reward (Tensor): reward of taking actions in the environment.
436
+ must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor
437
+ terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
438
+ if not provided.
439
+
440
+ Keyword Args:
441
+ done (Tensor): Deprecated. Use ``terminated`` instead.
442
+
443
+ All tensors (values, reward and done) must have shape
444
+ ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
445
+
446
+ """
447
+ if done is not None and terminated is None:
448
+ terminated = done.clone()
449
+ warnings.warn(
450
+ "done for td0_return_estimate is deprecated. Pass ``terminated`` instead."
451
+ )
452
+ if not (next_state_value.shape == reward.shape == terminated.shape):
453
+ raise RuntimeError(SHAPE_ERR)
454
+ not_terminated = (~terminated).int()
455
+ returns = reward + gamma * not_terminated * next_state_value
456
+ return returns
457
+
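The one-step return computed above is simply ``reward + gamma * (~terminated) * next_state_value``; a tiny worked example with illustrative values:

>>> import torch
>>> from torchrl.objectives.value.functional import td0_return_estimate
>>> next_val = torch.tensor([[1.0], [2.0], [3.0]])
>>> reward = torch.tensor([[0.5], [0.5], [0.5]])
>>> terminated = torch.tensor([[False], [False], [True]])
>>> td0_return_estimate(0.9, next_val, reward, terminated)
tensor([[1.4000],
        [2.3000],
        [0.5000]])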
458
+
459
+ ########################################################################
460
+ # TD(1)
461
+ # ----------
462
+
463
+
464
+ @_transpose_time
465
+ def td1_return_estimate(
466
+ gamma: float,
467
+ next_state_value: torch.Tensor,
468
+ reward: torch.Tensor,
469
+ done: torch.Tensor,
470
+ terminated: torch.Tensor | None = None,
471
+ rolling_gamma: bool | None = None,
472
+ *,
473
+ time_dim: int = -2,
474
+ ) -> torch.Tensor:
475
+ r"""TD(1) return estimate.
476
+
477
+ Args:
478
+ gamma (scalar): exponential mean discount.
479
+ next_state_value (Tensor): value function result with new_state input.
480
+ reward (Tensor): reward of taking actions in the environment.
481
+ done (Tensor): boolean flag for end of trajectory.
482
+ terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
483
+ if not provided.
484
+ rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
485
+ of a gamma tensor is tied to a single event:
486
+
487
+ >>> gamma = [g1, g2, g3, g4]
488
+ >>> value = [v1, v2, v3, v4]
489
+ >>> return = [
490
+ ... v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
491
+ ... v2 + g2 v3 + g2 g3 v4,
492
+ ... v3 + g3 v4,
493
+ ... v4,
494
+ ... ]
495
+
496
+ if ``False``, it is assumed that each gamma is tied to the upcoming
497
+ trajectory:
498
+
499
+ >>> gamma = [g1, g2, g3, g4]
500
+ >>> value = [v1, v2, v3, v4]
501
+ >>> return = [
502
+ ... v1 + g1 v2 + g1**2 v3 + g1**3 v4,
503
+ ... v2 + g2 v3 + g2**2 v4,
504
+ ... v3 + g3 v4,
505
+ ... v4,
506
+ ... ]
507
+
508
+ Default is ``True``.
509
+ time_dim (int): dimension where the time is unrolled. Defaults to -2.
510
+
511
+ All tensors (values, reward and done) must have shape
512
+ ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
513
+
514
+ """
515
+ if terminated is None:
516
+ terminated = done.clone()
517
+ if not (next_state_value.shape == reward.shape == done.shape == terminated.shape):
518
+ raise RuntimeError(SHAPE_ERR)
519
+ not_done = (~done).int()
520
+ not_terminated = (~terminated).int()
521
+
522
+ returns = torch.empty_like(next_state_value)
523
+
524
+ T = returns.shape[-2]
525
+
526
+ single_gamma = False
527
+ if not (isinstance(gamma, torch.Tensor) and gamma.shape == not_done.shape):
528
+ single_gamma = True
529
+ if isinstance(gamma, torch.Tensor):
530
+ # Use expand instead of full_like to avoid .item() call which creates
531
+ # unbacked symbols during torch.compile tracing.
532
+ if gamma.device != next_state_value.device:
533
+ gamma = gamma.to(next_state_value.device)
534
+ gamma = gamma.expand(next_state_value.shape)
535
+ else:
536
+ gamma = torch.full_like(next_state_value, gamma)
537
+
538
+ if rolling_gamma is None:
539
+ rolling_gamma = True
540
+ elif not rolling_gamma and single_gamma:
541
+ raise RuntimeError(
542
+ "rolling_gamma=False is expected only with time-sensitive gamma values"
543
+ )
544
+
545
+ done_but_not_terminated = (done & ~terminated).int()
546
+ if rolling_gamma:
547
+ gamma = gamma * not_terminated
548
+ g = next_state_value[..., -1, :]
549
+ for i in reversed(range(T)):
550
+ # if not done (and hence not terminated), get the bootstrapped value
551
+ # if done but not terminated, use next_state_value
552
+ # if terminated, take nothing (gamma = 0)
553
+ dnt = done_but_not_terminated[..., i, :]
554
+ g = returns[..., i, :] = reward[..., i, :] + gamma[..., i, :] * (
555
+ (1 - dnt) * g + dnt * next_state_value[..., i, :]
556
+ )
557
+ else:
558
+ for k in range(T):
559
+ g = 0
560
+ _gamma = gamma[..., k, :]
561
+ nd = not_terminated
562
+ _gamma = _gamma.unsqueeze(-2) * nd
563
+ for i in reversed(range(k, T)):
564
+ dnt = done_but_not_terminated[..., i, :]
565
+ g = reward[..., i, :] + _gamma[..., i, :] * (
566
+ (1 - dnt) * g + dnt * next_state_value[..., i, :]
567
+ )
568
+ returns[..., k, :] = g
569
+ return returns
570
+
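A short sketch of the tensor-valued ``gamma`` path described above: when ``gamma`` has the same shape as ``done``, ``rolling_gamma=True`` (the default) ties each discount to its own step (illustrative random data):

>>> import torch
>>> from torchrl.objectives.value.functional import td1_return_estimate
>>> T = 5
>>> next_val = torch.randn(T, 1)
>>> reward = torch.randn(T, 1)
>>> done = torch.zeros(T, 1, dtype=torch.bool)
>>> done[-1] = True
>>> gamma = torch.full((T, 1), 0.99)  # one gamma per time step
>>> ret = td1_return_estimate(gamma, next_val, reward, done, rolling_gamma=True)
>>> ret.shape
torch.Size([5, 1])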
571
+
572
+ def td1_advantage_estimate(
573
+ gamma: float,
574
+ state_value: torch.Tensor,
575
+ next_state_value: torch.Tensor,
576
+ reward: torch.Tensor,
577
+ done: torch.Tensor,
578
+ terminated: torch.Tensor | None = None,
579
+ rolling_gamma: bool | None = None,
580
+ time_dim: int = -2,
581
+ ) -> torch.Tensor:
582
+ """TD(1) advantage estimate.
583
+
584
+ Args:
585
+ gamma (scalar): exponential mean discount.
586
+ state_value (Tensor): value function result with old_state input.
587
+ next_state_value (Tensor): value function result with new_state input.
588
+ reward (Tensor): reward of taking actions in the environment.
589
+ done (Tensor): boolean flag for end of trajectory.
590
+ terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
591
+ if not provided.
592
+ rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
593
+ of a gamma tensor is tied to a single event:
594
+
595
+ >>> gamma = [g1, g2, g3, g4]
596
+ >>> value = [v1, v2, v3, v4]
597
+ >>> return = [
598
+ ... v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
599
+ ... v2 + g2 v3 + g2 g3 v4,
600
+ ... v3 + g3 v4,
601
+ ... v4,
602
+ ... ]
603
+
604
+ if ``False``, it is assumed that each gamma is tied to the upcoming
605
+ trajectory:
606
+
607
+ >>> gamma = [g1, g2, g3, g4]
608
+ >>> value = [v1, v2, v3, v4]
609
+ >>> return = [
610
+ ... v1 + g1 v2 + g1**2 v3 + g1**3 v4,
611
+ ... v2 + g2 v3 + g2**2 v4,
612
+ ... v3 + g3 v4,
613
+ ... v4,
614
+ ... ]
615
+
616
+ Default is ``True``.
617
+ time_dim (int): dimension where the time is unrolled. Defaults to -2.
618
+
619
+ All tensors (values, reward and done) must have shape
620
+ ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
621
+
622
+ """
623
+ if terminated is None:
624
+ terminated = done.clone()
625
+ if not (
626
+ next_state_value.shape
627
+ == state_value.shape
628
+ == reward.shape
629
+ == done.shape
630
+ == terminated.shape
631
+ ):
632
+ raise RuntimeError(SHAPE_ERR)
633
+ if not state_value.shape == next_state_value.shape:
634
+ raise RuntimeError("shape of state_value and next_state_value must match")
635
+ returns = td1_return_estimate(
636
+ gamma,
637
+ next_state_value,
638
+ reward,
639
+ done,
640
+ terminated=terminated,
641
+ rolling_gamma=rolling_gamma,
642
+ time_dim=time_dim,
643
+ )
644
+ advantage = returns - state_value
645
+ return advantage
646
+
647
+
648
+ @_transpose_time
649
+ def vec_td1_return_estimate(
650
+ gamma,
651
+ next_state_value,
652
+ reward,
653
+ done: torch.Tensor,
654
+ terminated: torch.Tensor | None = None,
655
+ rolling_gamma: bool | None = None,
656
+ time_dim: int = -2,
657
+ ):
658
+ """Vectorized TD(1) return estimate.
659
+
660
+ Args:
661
+ gamma (scalar, Tensor): exponential mean discount. If tensor-valued, must be a [Batch x TimeSteps x 1] tensor.
662
+ next_state_value (Tensor): value function result with new_state input.
663
+ reward (Tensor): reward of taking actions in the environment.
664
+ done (Tensor): boolean flag for end of trajectory.
665
+ terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
666
+ if not provided.
667
+ rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
668
+ of the gamma tensor is tied to a single event:
669
+
670
+ >>> gamma = [g1, g2, g3, g4]
671
+ >>> value = [v1, v2, v3, v4]
672
+ >>> return = [
673
+ ... v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
674
+ ... v2 + g2 v3 + g2 g3 v4,
675
+ ... v3 + g3 v4,
676
+ ... v4,
677
+ ... ]
678
+
679
+ if ``False``, it is assumed that each gamma is tied to the upcoming
680
+ trajectory:
681
+
682
+ >>> gamma = [g1, g2, g3, g4]
683
+ >>> value = [v1, v2, v3, v4]
684
+ >>> return = [
685
+ ... v1 + g1 v2 + g1**2 v3 + g1**3 v4,
686
+ ... v2 + g2 v3 + g2**2 v4,
687
+ ... v3 + g3 v4,
688
+ ... v4,
689
+ ... ]
690
+
691
+ Default is ``True``.
692
+ time_dim (int): dimension where the time is unrolled. Defaults to ``-2``.
693
+
694
+ All tensors (values, reward and done) must have shape
695
+ ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
696
+
697
+ """
698
+ return vec_td_lambda_return_estimate(
699
+ gamma=gamma,
700
+ next_state_value=next_state_value,
701
+ reward=reward,
702
+ done=done,
703
+ terminated=terminated,
704
+ rolling_gamma=rolling_gamma,
705
+ lmbda=1,
706
+ time_dim=time_dim,
707
+ )
708
+
709
+
710
+ def vec_td1_advantage_estimate(
711
+ gamma,
712
+ state_value,
713
+ next_state_value,
714
+ reward,
715
+ done: torch.Tensor,
716
+ terminated: torch.Tensor | None = None,
717
+ rolling_gamma: bool | None = None,
718
+ time_dim: int = -2,
719
+ ):
720
+ """Vectorized TD(1) advantage estimate.
721
+
722
+ Args:
723
+ gamma (scalar, Tensor): exponential mean discount. If tensor-valued, must be a [Batch x TimeSteps x 1] tensor.
724
+ state_value (Tensor): value function result with old_state input.
725
+ next_state_value (Tensor): value function result with new_state input.
726
+ reward (Tensor): reward of taking actions in the environment.
727
+ done (Tensor): boolean flag for end of trajectory.
728
+ terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
729
+ if not provided.
730
+ rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
731
+ of a gamma tensor is tied to a single event:
732
+
733
+ >>> gamma = [g1, g2, g3, g4]
734
+ >>> value = [v1, v2, v3, v4]
735
+ >>> return = [
736
+ ... v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
737
+ ... v2 + g2 v3 + g2 g3 v4,
738
+ ... v3 + g3 v4,
739
+ ... v4,
740
+ ... ]
741
+
742
+ if ``False``, it is assumed that each gamma is tied to the upcoming
743
+ trajectory:
744
+
745
+ >>> gamma = [g1, g2, g3, g4]
746
+ >>> value = [v1, v2, v3, v4]
747
+ >>> return = [
748
+ ... v1 + g1 v2 + g1**2 v3 + g1**3 v4,
749
+ ... v2 + g2 v3 + g2**2 v4,
750
+ ... v3 + g3 v4,
751
+ ... v4,
752
+ ... ]
753
+
754
+ Default is ``True``.
755
+ time_dim (int): dimension where the time is unrolled. Defaults to -2.
756
+
757
+ All tensors (values, reward and done) must have shape
758
+ ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
759
+
760
+ """
761
+ if terminated is None:
762
+ terminated = done.clone()
763
+ if not (
764
+ next_state_value.shape
765
+ == state_value.shape
766
+ == reward.shape
767
+ == done.shape
768
+ == terminated.shape
769
+ ):
770
+ raise RuntimeError(SHAPE_ERR)
771
+ return (
772
+ vec_td1_return_estimate(
773
+ gamma,
774
+ next_state_value,
775
+ reward,
776
+ done=done,
777
+ terminated=terminated,
778
+ rolling_gamma=rolling_gamma,
779
+ time_dim=time_dim,
780
+ )
781
+ - state_value
782
+ )
783
+
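As with GAE, the vectorized TD(1) advantage (which reuses the TD(lambda) machinery with ``lmbda=1``) should agree with the loop-based ``td1_advantage_estimate`` up to floating-point error; a quick check sketch with illustrative data:

>>> import torch
>>> from torchrl.objectives.value.functional import (
...     td1_advantage_estimate,
...     vec_td1_advantage_estimate,
... )
>>> B, T = 2, 20
>>> val, next_val, reward = (torch.randn(B, T, 1) for _ in range(3))
>>> done = torch.zeros(B, T, 1, dtype=torch.bool)
>>> done[:, -1] = True
>>> adv_loop = td1_advantage_estimate(0.9, val, next_val, reward, done)
>>> adv_vec = vec_td1_advantage_estimate(0.9, val, next_val, reward, done)
>>> torch.testing.assert_close(adv_vec, adv_loop, atol=1e-4, rtol=1e-4)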
784
+
785
+ ########################################################################
786
+ # TD(lambda)
787
+ # ----------
788
+
789
+
790
+ @_transpose_time
791
+ def td_lambda_return_estimate(
792
+ gamma: float,
793
+ lmbda: float,
794
+ next_state_value: torch.Tensor,
795
+ reward: torch.Tensor,
796
+ done: torch.Tensor,
797
+ terminated: torch.Tensor | None = None,
798
+ rolling_gamma: bool | None = None,
799
+ *,
800
+ time_dim: int = -2,
801
+ ) -> torch.Tensor:
802
+ r"""TD(:math:`\lambda`) return estimate.
803
+
804
+ Args:
805
+ gamma (scalar): exponential mean discount.
806
+ lmbda (scalar): trajectory discount.
807
+ next_state_value (Tensor): value function result with new_state input.
808
+ reward (Tensor): reward of taking actions in the environment.
809
+ done (Tensor): boolean flag for end of trajectory.
810
+ terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
811
+ if not provided.
812
+ rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
813
+ of a gamma tensor is tied to a single event:
814
+
815
+ >>> gamma = [g1, g2, g3, g4]
816
+ >>> value = [v1, v2, v3, v4]
817
+ >>> return = [
818
+ ... v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
819
+ ... v2 + g2 v3 + g2 g3 v4,
820
+ ... v3 + g3 v4,
821
+ ... v4,
822
+ ... ]
823
+
824
+ if ``False``, it is assumed that each gamma is tied to the upcoming
825
+ trajectory:
826
+
827
+ >>> gamma = [g1, g2, g3, g4]
828
+ >>> value = [v1, v2, v3, v4]
829
+ >>> return = [
830
+ ... v1 + g1 v2 + g1**2 v3 + g1**3 v4,
831
+ ... v2 + g2 v3 + g2**2 v4,
832
+ ... v3 + g3 v4,
833
+ ... v4,
834
+ ... ]
835
+
836
+ Default is ``True``.
837
+ time_dim (int): dimension where the time is unrolled. Defaults to -2.
838
+
839
+ All tensors (values, reward and done) must have shape
840
+ ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
841
+
842
+ """
843
+ if terminated is None:
844
+ terminated = done.clone()
845
+ if not (next_state_value.shape == reward.shape == done.shape == terminated.shape):
846
+ raise RuntimeError(SHAPE_ERR)
847
+
848
+ not_terminated = (~terminated).int()
849
+
850
+ returns = torch.empty_like(next_state_value)
851
+ next_state_value = next_state_value * not_terminated
852
+
853
+ *batch, T, lastdim = returns.shape
854
+
855
+ # if gamma is not a tensor of the same shape as other inputs, we use rolling_gamma = True
856
+ single_gamma = False
857
+ if not (isinstance(gamma, torch.Tensor) and gamma.shape == done.shape):
858
+ single_gamma = True
859
+ if isinstance(gamma, torch.Tensor):
860
+ # Use expand instead of full_like to avoid .item() call which creates
861
+ # unbacked symbols during torch.compile tracing.
862
+ if gamma.device != next_state_value.device:
863
+ gamma = gamma.to(next_state_value.device)
864
+ gamma = gamma.expand(next_state_value.shape)
865
+ else:
866
+ gamma = torch.full_like(next_state_value, gamma)
867
+
868
+ single_lambda = False
869
+ if not (isinstance(lmbda, torch.Tensor) and lmbda.shape == done.shape):
870
+ single_lambda = True
871
+ if isinstance(lmbda, torch.Tensor):
872
+ # Use expand instead of full_like to avoid .item() call which creates
873
+ # unbacked symbols during torch.compile tracing.
874
+ if lmbda.device != next_state_value.device:
875
+ lmbda = lmbda.to(next_state_value.device)
876
+ lmbda = lmbda.expand(next_state_value.shape)
877
+ else:
878
+ lmbda = torch.full_like(next_state_value, lmbda)
879
+
880
+ if rolling_gamma is None:
881
+ rolling_gamma = True
882
+ elif not rolling_gamma and single_gamma and single_lambda:
883
+ raise RuntimeError(
884
+ "rolling_gamma=False is expected only with time-sensitive gamma or lambda values"
885
+ )
886
+ if rolling_gamma:
887
+ g = next_state_value[..., -1, :]
888
+ for i in reversed(range(T)):
889
+ dn = done[..., i, :].int()
890
+ nv = next_state_value[..., i, :]
891
+ lmd = lmbda[..., i, :]
892
+ # if done, the bootstrapped gain is the next value, otherwise it's the
893
+ # value we computed during the previous iter
894
+ g = g * (1 - dn) + nv * dn
895
+ g = returns[..., i, :] = reward[..., i, :] + gamma[..., i, :] * (
896
+ (1 - lmd) * nv + lmd * g
897
+ )
898
+ else:
899
+ for k in range(T):
900
+ g = next_state_value[..., -1, :]
901
+ _gamma = gamma[..., k, :]
902
+ _lambda = lmbda[..., k, :]
903
+ for i in reversed(range(k, T)):
904
+ dn = done[..., i, :].int()
905
+ nv = next_state_value[..., i, :]
906
+ g = g * (1 - dn) + nv * dn
907
+ g = reward[..., i, :] + _gamma * ((1 - _lambda) * nv + _lambda * g)
908
+ returns[..., k, :] = g
909
+
910
+ return returns
911
+
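A sanity-check sketch of the two limiting cases: with ``lmbda=0`` the TD(lambda) return reduces to the TD(0) return, and with ``lmbda=1`` it matches the TD(1) return (here every ``done`` step is also a terminated step):

>>> import torch
>>> from torchrl.objectives.value.functional import (
...     td0_return_estimate,
...     td1_return_estimate,
...     td_lambda_return_estimate,
... )
>>> T = 6
>>> next_val, reward = torch.randn(T, 1), torch.randn(T, 1)
>>> done = torch.zeros(T, 1, dtype=torch.bool)
>>> done[-1] = True
>>> r_lmbda0 = td_lambda_return_estimate(0.95, 0.0, next_val, reward, done)
>>> torch.testing.assert_close(r_lmbda0, td0_return_estimate(0.95, next_val, reward, terminated=done))
>>> r_lmbda1 = td_lambda_return_estimate(0.95, 1.0, next_val, reward, done)
>>> torch.testing.assert_close(r_lmbda1, td1_return_estimate(0.95, next_val, reward, done))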
912
+
913
+ def td_lambda_advantage_estimate(
914
+ gamma: float,
915
+ lmbda: float,
916
+ state_value: torch.Tensor,
917
+ next_state_value: torch.Tensor,
918
+ reward: torch.Tensor,
919
+ done: torch.Tensor,
920
+ terminated: torch.Tensor | None = None,
921
+ rolling_gamma: bool | None = None,
922
+ # not a kwarg because used directly
923
+ time_dim: int = -2,
924
+ ) -> torch.Tensor:
925
+ r"""TD(:math:`\lambda`) advantage estimate.
926
+
927
+ Args:
928
+ gamma (scalar): exponential mean discount.
929
+ lmbda (scalar): trajectory discount.
930
+ state_value (Tensor): value function result with old_state input.
931
+ next_state_value (Tensor): value function result with new_state input.
932
+ reward (Tensor): reward of taking actions in the environment.
933
+ done (Tensor): boolean flag for end of trajectory.
934
+ terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
935
+ if not provided.
936
+ rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
937
+ of a gamma tensor is tied to a single event:
938
+
939
+ >>> gamma = [g1, g2, g3, g4]
940
+ >>> value = [v1, v2, v3, v4]
941
+ >>> return = [
942
+ ... v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
943
+ ... v2 + g2 v3 + g2 g3 v4,
944
+ ... v3 + g3 v4,
945
+ ... v4,
946
+ ... ]
947
+
948
+ if ``False``, it is assumed that each gamma is tied to the upcoming
949
+ trajectory:
950
+
951
+ >>> gamma = [g1, g2, g3, g4]
952
+ >>> value = [v1, v2, v3, v4]
953
+ >>> return = [
954
+ ... v1 + g1 v2 + g1**2 v3 + g1**3 v4,
955
+ ... v2 + g2 v3 + g2**2 v4,
956
+ ... v3 + g3 v4,
957
+ ... v4,
958
+ ... ]
959
+
960
+ Default is ``True``.
961
+ time_dim (int): dimension where the time is unrolled. Defaults to -2.
962
+
963
+ All tensors (values, reward and done) must have shape
964
+ ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
965
+
966
+ """
967
+ if terminated is None:
968
+ terminated = done.clone()
969
+ if not (
970
+ next_state_value.shape
971
+ == state_value.shape
972
+ == reward.shape
973
+ == done.shape
974
+ == terminated.shape
975
+ ):
976
+ raise RuntimeError(SHAPE_ERR)
977
+ if not state_value.shape == next_state_value.shape:
978
+ raise RuntimeError("shape of state_value and next_state_value must match")
979
+ returns = td_lambda_return_estimate(
980
+ gamma,
981
+ lmbda,
982
+ next_state_value,
983
+ reward,
984
+ done,
985
+ terminated=terminated,
986
+ rolling_gamma=rolling_gamma,
987
+ time_dim=time_dim,
988
+ )
989
+ advantage = returns - state_value
990
+ return advantage
991
+
992
+
993
+ def _fast_td_lambda_return_estimate(
994
+ gamma: torch.Tensor | float,
995
+ lmbda: float,
996
+ next_state_value: torch.Tensor,
997
+ reward: torch.Tensor,
998
+ done: torch.Tensor,
999
+ terminated: torch.Tensor,
1000
+ thr: float = 1e-7,
1001
+ ):
1002
+ """Fast vectorized TD lambda return estimate.
1003
+
1004
+ In contrast to the generalized `vec_td_lambda_return_estimate` this function does not need
1005
+ to allocate a big tensor of the form [B, T, T], but it only works with gamma/lmbda being scalars.
1006
+
1007
+ Args:
1008
+ gamma (scalar): the gamma decay, can be a tensor with a single element (exponential mean discount)
1009
+ lmbda (scalar): the lambda decay (trajectory discount)
1010
+ next_state_value (torch.Tensor): a [*B, T, F] tensor containing next state values (value function)
1011
+ reward (torch.Tensor): a [*B, T, F] tensor containing rewards
1012
+ done (Tensor): boolean flag for end of trajectory.
1013
+ terminated (Tensor): boolean flag for end of episode.
1014
+ thr (:obj:`float`): threshold for the filter. Below this limit, components will be ignored.
1015
+ Defaults to 1e-7.
1016
+
1017
+ All tensors (values, reward and done) must have shape
1018
+ ``[*Batch x TimeSteps x F]``, with ``F`` feature dimensions.
1019
+
1020
+ """
1021
+ device = reward.device
1022
+ done = done.transpose(-2, -1)
1023
+ terminated = terminated.transpose(-2, -1)
1024
+ reward = reward.transpose(-2, -1)
1025
+ next_state_value = next_state_value.transpose(-2, -1)
1026
+
1027
+ # the only valid next states are those where the trajectory does not terminate
1028
+ next_state_value = (~terminated).int() * next_state_value
1029
+
1030
+ # Use torch.full to create directly on device (avoids DeviceCopy in cudagraph)
1031
+ # Handle both scalar and single-element tensor gamma
1032
+ if isinstance(gamma, torch.Tensor):
1033
+ gamma_tensor = gamma.to(device).view(1)
1034
+ else:
1035
+ gamma_tensor = torch.full((1,), gamma, device=device)
1036
+ gammalmbda = gamma_tensor * lmbda
1037
+
1038
+ num_per_traj = _get_num_per_traj(done)
1039
+
1040
+ done = done.clone()
1041
+ done[..., -1] = 1
1042
+ not_done = (~done).int()
1043
+
1044
+ t = reward + next_state_value * gamma_tensor * (1 - not_done * lmbda)
1045
+
1046
+ t_flat, mask = _split_and_pad_sequence(t, num_per_traj, return_mask=True)
1047
+
1048
+ gammalmbdas = _geom_series_like(t_flat[0], gammalmbda, thr=thr)
1049
+
1050
+ ret_flat = _custom_conv1d(t_flat.unsqueeze(1), gammalmbdas)
1051
+ ret = ret_flat.squeeze(1)[mask]
1052
+
1053
+ return ret.view_as(reward).transpose(-1, -2)
1054
+
1055
+
1056
+ @_transpose_time
+ def vec_td_lambda_return_estimate(
+     gamma,
+     lmbda,
+     next_state_value,
+     reward,
+     done,
+     terminated: torch.Tensor | None = None,
+     rolling_gamma: bool | None = None,
+     *,
+     time_dim: int = -2,
+ ):
+     r"""Vectorized TD(:math:`\lambda`) return estimate.
+
+     Args:
+         gamma (scalar, Tensor): exponential mean discount. If tensor-valued,
+             must be a [Batch x TimeSteps x 1] tensor.
+         lmbda (scalar): trajectory discount.
+         next_state_value (Tensor): value function result with new_state input.
+             Must be a [Batch x TimeSteps x 1] tensor.
+         reward (Tensor): reward of taking actions in the environment.
+             Must be a [Batch x TimeSteps x 1] or [Batch x TimeSteps] tensor.
+         done (Tensor): boolean flag for end of trajectory.
+         terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
+             if not provided.
+         rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
+             of a gamma tensor is tied to a single event:
+
+             >>> gamma = [g1, g2, g3, g4]
+             >>> value = [v1, v2, v3, v4]
+             >>> return = [
+             ...     v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
+             ...     v2 + g2 v3 + g2 g3 v4,
+             ...     v3 + g3 v4,
+             ...     v4,
+             ... ]
+
+             if ``False``, it is assumed that each gamma is tied to the upcoming
+             trajectory:
+
+             >>> gamma = [g1, g2, g3, g4]
+             >>> value = [v1, v2, v3, v4]
+             >>> return = [
+             ...     v1 + g1 v2 + g1**2 v3 + g1**3 v4,
+             ...     v2 + g2 v3 + g2**2 v4,
+             ...     v3 + g3 v4,
+             ...     v4,
+             ... ]
+
+             Default is ``True``.
+         time_dim (int): dimension where the time is unrolled. Defaults to -2.
+
+     All tensors (values, reward and done) must have shape
+     ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
+
+     """
+     if terminated is None:
+         terminated = done.clone()
+     if not (next_state_value.shape == reward.shape == done.shape == terminated.shape):
+         raise RuntimeError(SHAPE_ERR)
+
+     gamma_thr = 1e-7
+     shape = next_state_value.shape
+
+     *batch, T, lastdim = shape
+
+     def _is_scalar(tensor):
+         return not isinstance(tensor, torch.Tensor) or tensor.numel() == 1
+
+     # There are two use-cases: if gamma/lmbda are scalars we can use the
+     # fast implementation, if not we must construct a gamma tensor.
+     if _is_scalar(gamma) and _is_scalar(lmbda):
+         return _fast_td_lambda_return_estimate(
+             gamma=gamma,
+             lmbda=lmbda,
+             next_state_value=next_state_value,
+             reward=reward,
+             done=done,
+             terminated=terminated,
+             thr=gamma_thr,
+         )
+
+     next_state_value = next_state_value.transpose(-2, -1).unsqueeze(-2)
+     if len(batch):
+         next_state_value = next_state_value.flatten(0, len(batch))
+
+     reward = reward.transpose(-2, -1).unsqueeze(-2)
+     if len(batch):
+         reward = reward.flatten(0, len(batch))
+
+     # Vectorized version of td_lambda_return_estimate
+     device = reward.device
+     not_done = (~done).int()
+     not_terminated = (~terminated).int().transpose(-2, -1).unsqueeze(-2)
+     if len(batch):
+         not_terminated = not_terminated.flatten(0, len(batch))
+     next_state_value = next_state_value * not_terminated
+
+     if rolling_gamma is None:
+         rolling_gamma = True
+     if not rolling_gamma and not is_dynamo_compiling():
+         # Skip this validation during compile to avoid CUDA syncs
+         terminated_follows_terminated = terminated[..., 1:, :][
+             terminated[..., :-1, :]
+         ].all()
+         if not terminated_follows_terminated:
+             raise NotImplementedError(
+                 "When using rolling_gamma=False and vectorized TD(lambda) with time-dependent gamma, "
+                 "make sure that consecutive trajectories are separated as different batch "
+                 "items. Propagating a gamma value across trajectories is not permitted with "
+                 "this method. Check that you need to use rolling_gamma=False, and if so "
+                 "consider using the non-vectorized version of the return computation or splitting "
+                 "your trajectories."
+             )
+
+     if rolling_gamma:
+         # Make the coefficient table
+         gammas = _make_gammas_tensor(gamma * not_done, T, rolling_gamma)
+         gammas_cp = torch.cumprod(gammas, -2)
+         lambdas = torch.ones(T + 1, 1, device=device)
+         lambdas[1:] = lmbda
+         lambdas_cp = torch.cumprod(lambdas, -2)
+         lambdas = lambdas[1:]
+         dec = gammas_cp * lambdas_cp
+
+         gammas = _make_gammas_tensor(gamma, T, rolling_gamma)
+         gammas = gammas[..., 1:, :]
+         if gammas.ndimension() == 4 and gammas.shape[1] > 1:
+             gammas = gammas[:, :1]
+         if lambdas.ndimension() == 4 and lambdas.shape[1] > 1:
+             lambdas = lambdas[:, :1]
+
+         not_done = not_done.transpose(-2, -1).unsqueeze(-2)
+         if len(batch):
+             not_done = not_done.flatten(0, len(batch))
+         # lambdas = lambdas * not_done
+
+         v3 = (gammas * lambdas).squeeze(-1) * next_state_value * not_done
+         v3[..., :-1] = 0
+         out = _custom_conv1d(
+             reward
+             + gammas.squeeze(-1)
+             * next_state_value
+             * (1 - lambdas.squeeze(-1) * not_done)
+             + v3,
+             dec,
+         )
+
+         return out.view(*batch, lastdim, T).transpose(-2, -1)
+     else:
+         raise NotImplementedError(
+             "The vectorized version of TD(lambda) with rolling_gamma=False is currently not available. "
+             "To use this feature, use the non-vectorized version of TD(lambda). You can expect "
+             "good speed improvements by decorating the function with torch.compile!"
+         )
+
+
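A sketch of the tensor-gamma branch (editor's addition, assuming the call paths defined in this module): a constant `[B, T, 1]` gamma tensor takes the `rolling_gamma=True` branch above and is expected to match the scalar fast path:

```python
import torch

from torchrl.objectives.value.functional import vec_td_lambda_return_estimate

B, T = 3, 6
next_state_value = torch.randn(B, T, 1)
reward = torch.randn(B, T, 1)
done = torch.zeros(B, T, 1, dtype=torch.bool)
done[:, -1] = True
terminated = done.clone()

scalar_path = vec_td_lambda_return_estimate(
    0.9, 0.95, next_state_value, reward, done, terminated=terminated
)
# A constant per-step gamma tensor exercises the rolling_gamma branch instead.
gamma_tensor = torch.full((B, T, 1), 0.9)
tensor_path = vec_td_lambda_return_estimate(
    gamma_tensor, 0.95, next_state_value, reward, done, terminated=terminated
)
torch.testing.assert_close(scalar_path, tensor_path)
```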
+ def vec_td_lambda_advantage_estimate(
+     gamma,
+     lmbda,
+     state_value,
+     next_state_value,
+     reward,
+     done,
+     terminated: torch.Tensor | None = None,
+     rolling_gamma: bool | None = None,
+     # not a kwarg because used directly
+     time_dim: int = -2,
+ ):
+     r"""Vectorized TD(:math:`\lambda`) advantage estimate.
+
+     Args:
+         gamma (scalar, Tensor): exponential mean discount. If tensor-valued,
+             must be a [Batch x TimeSteps x 1] tensor.
+         lmbda (scalar): trajectory discount.
+         state_value (Tensor): value function result with old_state input.
+         next_state_value (Tensor): value function result with new_state input.
+         reward (Tensor): reward of taking actions in the environment.
+         done (Tensor): boolean flag for end of trajectory.
+         terminated (Tensor): boolean flag for the end of episode. Defaults to ``done``
+             if not provided.
+         rolling_gamma (bool, optional): if ``True``, it is assumed that each gamma
+             of a gamma tensor is tied to a single event:
+
+             >>> gamma = [g1, g2, g3, g4]
+             >>> value = [v1, v2, v3, v4]
+             >>> return = [
+             ...     v1 + g1 v2 + g1 g2 v3 + g1 g2 g3 v4,
+             ...     v2 + g2 v3 + g2 g3 v4,
+             ...     v3 + g3 v4,
+             ...     v4,
+             ... ]
+
+             if ``False``, it is assumed that each gamma is tied to the upcoming
+             trajectory:
+
+             >>> gamma = [g1, g2, g3, g4]
+             >>> value = [v1, v2, v3, v4]
+             >>> return = [
+             ...     v1 + g1 v2 + g1**2 v3 + g1**3 v4,
+             ...     v2 + g2 v3 + g2**2 v4,
+             ...     v3 + g3 v4,
+             ...     v4,
+             ... ]
+
+             Default is ``True``.
+         time_dim (int): dimension where the time is unrolled. Defaults to -2.
+
+     All tensors (values, reward and done) must have shape
+     ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
+
+     """
+     if terminated is None:
+         terminated = done.clone()
+     if not (
+         next_state_value.shape
+         == state_value.shape
+         == reward.shape
+         == done.shape
+         == terminated.shape
+     ):
+         raise RuntimeError(SHAPE_ERR)
+     return (
+         vec_td_lambda_return_estimate(
+             gamma,
+             lmbda,
+             next_state_value,
+             reward,
+             done=done,
+             terminated=terminated,
+             rolling_gamma=rolling_gamma,
+             time_dim=time_dim,
+         )
+         - state_value
+     )
+
+
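By construction, the advantage returned here is the vectorized λ-return minus the state value; a minimal sketch (editor's addition, under the same import-path assumption as above):

```python
import torch

from torchrl.objectives.value.functional import (
    vec_td_lambda_advantage_estimate,
    vec_td_lambda_return_estimate,
)

B, T = 2, 5
state_value = torch.randn(B, T, 1)
next_state_value = torch.randn(B, T, 1)
reward = torch.randn(B, T, 1)
done = torch.zeros(B, T, 1, dtype=torch.bool)
done[:, -1] = True

adv = vec_td_lambda_advantage_estimate(
    0.99, 0.95, state_value, next_state_value, reward, done
)
# The advantage is exactly the lambda-return minus the state value.
ret = vec_td_lambda_return_estimate(0.99, 0.95, next_state_value, reward, done)
torch.testing.assert_close(adv, ret - state_value)
```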
+ ########################################################################
+ # V-Trace
+ # -------
+
+
+ @_transpose_time
+ def vtrace_advantage_estimate(
+     gamma: float,
+     log_pi: torch.Tensor,
+     log_mu: torch.Tensor,
+     state_value: torch.Tensor,
+     next_state_value: torch.Tensor,
+     reward: torch.Tensor,
+     done: torch.Tensor,
+     terminated: torch.Tensor | None = None,
+     rho_thresh: float | torch.Tensor = 1.0,
+     c_thresh: float | torch.Tensor = 1.0,
+     # not a kwarg because used directly
+     time_dim: int = -2,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """Computes V-Trace off-policy actor critic targets.
+
+     Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures"
+     https://arxiv.org/abs/1802.01561 for more context.
+
+     Args:
+         gamma (scalar): exponential mean discount.
+         log_pi (Tensor): log probability of the actions under the current (learner) actor.
+         log_mu (Tensor): log probability of the actions under the collection (behavior) actor.
+         state_value (Tensor): value function result with state input.
+         next_state_value (Tensor): value function result with next_state input.
+         reward (Tensor): reward of taking actions in the environment.
+         done (Tensor): boolean flag for end of episode.
+         terminated (torch.Tensor): a [B, T] boolean tensor containing the terminated states.
+             Defaults to ``done`` if not provided.
+         rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights.
+         c_thresh (Union[float, Tensor]): c clipping parameter for importance weights.
+         time_dim (int): dimension where the time is unrolled. Defaults to -2.
+
+     All tensors (values, reward and done) must have shape
+     ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
+     """
+     if not (next_state_value.shape == state_value.shape == reward.shape == done.shape):
+         raise RuntimeError(SHAPE_ERR)
+
+     device = state_value.device
+
+     if not isinstance(rho_thresh, torch.Tensor):
+         rho_thresh = torch.tensor(rho_thresh, device=device)
+     if not isinstance(c_thresh, torch.Tensor):
+         c_thresh = torch.tensor(c_thresh, device=device)
+
+     c_thresh = c_thresh.to(device)
+     rho_thresh = rho_thresh.to(device)
+
+     not_done = (~done).int()
+     not_terminated = not_done if terminated is None else (~terminated).int()
+     *batch_size, time_steps, lastdim = not_done.shape
+     done_discounts = gamma * not_done
+     terminated_discounts = gamma * not_terminated
+
+     rho = (log_pi - log_mu).exp()
+     clipped_rho = rho.clamp_max(rho_thresh)
+     deltas = clipped_rho * (
+         reward + terminated_discounts * next_state_value - state_value
+     )
+     clipped_c = rho.clamp_max(c_thresh)
+
+     vs_minus_v_xs = [torch.zeros_like(next_state_value[..., -1, :])]
+     for i in reversed(range(time_steps)):
+         discount_t, c_t, delta_t = (
+             done_discounts[..., i, :],
+             clipped_c[..., i, :],
+             deltas[..., i, :],
+         )
+         vs_minus_v_xs.append(delta_t + discount_t * c_t * vs_minus_v_xs[-1])
+     vs_minus_v_xs = torch.stack(vs_minus_v_xs[1:], dim=time_dim)
+     vs_minus_v_xs = torch.flip(vs_minus_v_xs, dims=[time_dim])
+     vs = vs_minus_v_xs + state_value
+     vs_t_plus_1 = torch.cat(
+         [vs[..., 1:, :], next_state_value[..., -1:, :]], dim=time_dim
+     )
+     advantages = clipped_rho * (
+         reward + terminated_discounts * vs_t_plus_1 - state_value
+     )
+
+     return advantages, vs
+
+
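A usage sketch for the V-trace estimator (editor's addition, assuming the signature above): `log_pi` and `log_mu` are per-step log-probabilities; when they coincide, the importance ratios `exp(log_pi - log_mu)` are all one, so clipping with the default thresholds is inactive:

```python
import torch

from torchrl.objectives.value.functional import vtrace_advantage_estimate

B, T = 2, 7
log_pi = torch.randn(B, T, 1)
log_mu = log_pi.clone()  # on-policy: importance ratios are exactly 1
state_value = torch.randn(B, T, 1)
next_state_value = torch.randn(B, T, 1)
reward = torch.randn(B, T, 1)
done = torch.zeros(B, T, 1, dtype=torch.bool)
done[:, -1] = True

advantages, vs = vtrace_advantage_estimate(
    0.99,
    log_pi,
    log_mu,
    state_value,
    next_state_value,
    reward,
    done,
    terminated=done.clone(),
)
assert advantages.shape == vs.shape == (B, T, 1)
```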
+ ########################################################################
+ # Reward to go
+ # ------------
+
+
+ @_transpose_time
+ def reward2go(
+     reward,
+     done,
+     gamma,
+     *,
+     time_dim: int = -2,
+ ):
+     """Compute the discounted cumulative sum of rewards given multiple trajectories and the episode ends.
+
+     Args:
+         reward (torch.Tensor): A tensor containing the rewards
+             received at each time step over multiple trajectories.
+         done (Tensor): boolean flag for end of episode. Differs from
+             truncated, where the episode did not end but was interrupted.
+         gamma (:obj:`float`): The discount factor to use for computing the
+             discounted cumulative sum of rewards.
+         time_dim (int): dimension where the time is unrolled. Defaults to -2.
+
+     Returns:
+         torch.Tensor: A tensor of shape [B, T] containing the discounted cumulative
+             sum of rewards (reward-to-go) at each time step.
+
+     Examples:
+         >>> reward = torch.ones(1, 10)
+         >>> done = torch.zeros(1, 10, dtype=torch.bool)
+         >>> done[:, [3, 7]] = True
+         >>> reward2go(reward, done, 0.99, time_dim=-1)
+         tensor([[3.9404],
+                 [2.9701],
+                 [1.9900],
+                 [1.0000],
+                 [3.9404],
+                 [2.9701],
+                 [1.9900],
+                 [1.0000],
+                 [1.9900],
+                 [1.0000]])
+
+     """
+     shape = reward.shape
+     if shape != done.shape:
+         raise ValueError(
+             f"reward and done must share the same shape, got {reward.shape} and {done.shape}"
+         )
+     # flatten if needed
+     if reward.ndim > 2:
+         # we know time dim is at -2, let's put it at -3
+         rflip = reward.transpose(-2, -3)
+         rflip_shape = rflip.shape[-2:]
+         r2go = reward2go(
+             rflip.flatten(-2, -1), done.transpose(-2, -3).flatten(-2, -1), gamma=gamma
+         ).unflatten(-1, rflip_shape)
+         return r2go.transpose(-2, -3)
+
+     # place time at dim -1
+     reward = reward.transpose(-2, -1)
+     done = done.transpose(-2, -1)
+
+     num_per_traj = _get_num_per_traj(done)
+     td0_flat = _split_and_pad_sequence(reward, num_per_traj)
+     gammas = _geom_series_like(td0_flat[0], gamma, thr=1e-7)
+     cumsum = _custom_conv1d(td0_flat.unsqueeze(1), gammas)
+     cumsum = cumsum.squeeze(1)
+     cumsum = _inv_pad_sequence(cumsum, num_per_traj)
+     cumsum = cumsum.reshape_as(reward)
+     cumsum = cumsum.transpose(-2, -1)
+     if cumsum.shape != shape:
+         try:
+             cumsum = cumsum.reshape(shape)
+         except RuntimeError:
+             raise RuntimeError(
+                 f"Wrong shape for output reward2go: {cumsum.shape} when {shape} was expected."
+             )
+     return cumsum
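Finally, a sketch of `reward2go` with the default `[B, T, 1]` layout (the docstring example above uses `time_dim=-1`); this is an editor's addition under the same import-path assumption:

```python
import torch

from torchrl.objectives.value.functional import reward2go

B, T = 2, 6
reward = torch.ones(B, T, 1)
done = torch.zeros(B, T, 1, dtype=torch.bool)
done[:, 2] = True  # first trajectory ends after 3 steps
done[:, -1] = True  # remaining steps form a second trajectory

r2go = reward2go(reward, done, gamma=0.99)  # time_dim defaults to -2
assert r2go.shape == (B, T, 1)
# The last step of each trajectory keeps only its own reward (here 1).
torch.testing.assert_close(r2go[:, 2], torch.ones(B, 1))
```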