torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/tensordict_module/actors.py
@@ -0,0 +1,2457 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from collections.abc import Sequence
8
+
9
+ import torch
10
+ from tensordict import TensorDictBase, unravel_key
11
+ from tensordict.nn import (
12
+ CompositeDistribution,
13
+ dispatch,
14
+ TensorDictModule,
15
+ TensorDictModuleBase,
16
+ TensorDictModuleWrapper,
17
+ TensorDictSequential,
18
+ )
19
+ from tensordict.utils import expand_as_right, NestedKey
20
+ from torch import nn
21
+ from torch.distributions import Categorical
22
+
23
+ from torchrl._utils import _replace_last
24
+ from torchrl.data.tensor_specs import Composite, TensorSpec
25
+ from torchrl.data.utils import _process_action_space_spec
26
+ from torchrl.modules.tensordict_module.common import DistributionalDQNnet, SafeModule
27
+ from torchrl.modules.tensordict_module.probabilistic import (
28
+ SafeProbabilisticModule,
29
+ SafeProbabilisticTensorDictSequential,
30
+ )
31
+ from torchrl.modules.tensordict_module.sequence import SafeSequential
32
+
33
+
34
+ class Actor(SafeModule):
35
+ """General class for deterministic actors in RL.
36
+
37
+ The Actor class comes with default values for the out_keys (``["action"]``)
38
+ and if the spec is provided but not as a
39
+ :class:`~torchrl.data.Composite` object, it will be
40
+ automatically translated into ``spec = Composite(action=spec)``.
41
+
42
+ Args:
43
+ module (nn.Module): a :class:`~torch.nn.Module` used to map the input to
44
+ the output parameter space.
45
+ in_keys (iterable of str, optional): keys to be read from input
46
+ tensordict and passed to the module. If it
47
+ contains more than one element, the values will be passed in the
48
+ order given by the in_keys iterable.
49
+ Defaults to ``["observation"]``.
50
+ out_keys (iterable of str): keys to be written to the input tensordict.
51
+ The length of out_keys must match the
52
+ number of tensors returned by the embedded module. Using ``"_"`` as a
53
+ key avoids writing the tensor to the output.
54
+ Defaults to ``["action"]``.
55
+
56
+ Keyword Args:
57
+ spec (TensorSpec, optional): Keyword-only argument.
58
+ Specs of the output tensor. If the module
59
+ outputs multiple output tensors,
60
+ spec characterizes the space of the first output tensor.
61
+ safe (bool): Keyword-only argument.
62
+ If ``True``, the value of the output is checked against the
63
+ input spec. Out-of-domain sampling can
64
+ occur because of exploration policies or numerical under/overflow
65
+ issues. If this value is out of bounds, it is projected back onto the
66
+ desired space using the :meth:`~torchrl.data.TensorSpec.project`
67
+ method. Default is ``False``.
68
+
69
+ Examples:
70
+ >>> import torch
71
+ >>> from tensordict import TensorDict
72
+ >>> from torchrl.data import Unbounded
73
+ >>> from torchrl.modules import Actor
74
+ >>> torch.manual_seed(0)
75
+ >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
76
+ >>> action_spec = Unbounded(4)
77
+ >>> module = torch.nn.Linear(4, 4)
78
+ >>> td_module = Actor(
79
+ ... module=module,
80
+ ... spec=action_spec,
81
+ ... )
82
+ >>> td_module(td)
83
+ TensorDict(
84
+ fields={
85
+ action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
86
+ observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
87
+ batch_size=torch.Size([3]),
88
+ device=None,
89
+ is_shared=False)
90
+ >>> print(td.get("action"))
91
+ tensor([[-1.3635, -0.0340, 0.1476, -1.3911],
92
+ [-0.1664, 0.5455, 0.2247, -0.4583],
93
+ [-0.2916, 0.2160, 0.5337, -0.5193]], grad_fn=<AddmmBackward0>)
94
+
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ module: nn.Module,
100
+ in_keys: Sequence[NestedKey] | None = None,
101
+ out_keys: Sequence[NestedKey] | None = None,
102
+ *,
103
+ spec: TensorSpec | None = None,
104
+ **kwargs,
105
+ ):
106
+ if in_keys is None:
107
+ in_keys = ["observation"]
108
+ if out_keys is None:
109
+ out_keys = ["action"]
110
+ if (
111
+ "action" in out_keys
112
+ and spec is not None
113
+ and not isinstance(spec, Composite)
114
+ ):
115
+ spec = Composite(action=spec)
116
+
117
+ super().__init__(
118
+ module,
119
+ in_keys=in_keys,
120
+ out_keys=out_keys,
121
+ spec=spec,
122
+ **kwargs,
123
+ )
124
+
125
+
126
+ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
127
+ """General class for probabilistic actors in RL.
128
+
129
+ The ProbabilisticActor class comes with default values for the out_keys (["action"])
130
+ and if the spec is provided but not as a Composite object, it will be
131
+ automatically translated into :obj:`spec = Composite(action=spec)`
132
+
133
+ Args:
134
+ module (nn.Module): a :class:`torch.nn.Module` used to map the input to
135
+ the output parameter space.
136
+ in_keys (str or iterable of str or dict): key(s) that will be read from the
137
+ input TensorDict and used to build the distribution. Importantly, if it's an
138
+ iterable of strings or a string, those keys must match the keywords used by
139
+ the distribution class of interest, e.g. :obj:`"loc"` and :obj:`"scale"` for
140
+ the Normal distribution and similar. If in_keys is a dictionary, the keys
141
+ are the keys of the distribution and the values are the keys in the
142
+ tensordict that will be matched to the corresponding distribution keys.
143
+ out_keys (str or iterable of str): keys where the sampled values will be
144
+ written. Importantly, if these keys are found in the input TensorDict, the
145
+ sampling step will be skipped.
146
+ spec (TensorSpec, optional): keyword-only argument containing the specs
147
+ of the output tensor. If the module outputs multiple output tensors,
148
+ spec characterizes the space of the first output tensor.
149
+ safe (bool): keyword-only argument. If ``True``, the value of the output is checked against the
150
+ input spec. Out-of-domain sampling can
151
+ occur because of exploration policies or numerical under/overflow
152
+ issues. If this value is out of bounds, it is projected back onto the
153
+ desired space using the :obj:`TensorSpec.project`
154
+ method. Default is ``False``.
155
+ default_interaction_type (tensordict.nn.InteractionType, optional): keyword-only argument.
156
+ Default method to be used to retrieve
157
+ the output value. Should be one of: ``InteractionType.MODE``, ``InteractionType.DETERMINISTIC``,
158
+ ``InteractionType.MEDIAN``, ``InteractionType.MEAN`` or
159
+ ``InteractionType.RANDOM`` (in which case the value is sampled
160
+ randomly from the distribution).
161
+ TorchRL's ``ExplorationType`` class is a proxy to ``InteractionType``.
162
+ Defaults to ``InteractionType.DETERMINISTIC``.
163
+
164
+ .. note:: When a sample is drawn, the :class:`ProbabilisticActor` instance will
165
+ first look for the interaction mode dictated by the
166
+ :func:`~tensordict.nn.probabilistic.interaction_type`
167
+ global function. If this returns `None` (its default value), then the
168
+ `default_interaction_type` of the `ProbabilisticTDModule`
169
+ instance will be used. Note that
170
+ :class:`~torchrl.collectors.BaseCollector`
171
+ instances will set the interaction type to
172
+ :class:`tensordict.nn.InteractionType.RANDOM` by default.
173
+
174
+ distribution_class (Type, optional): keyword-only argument.
175
+ A :class:`torch.distributions.Distribution` class to
176
+ be used for sampling.
177
+ Default is :class:`tensordict.nn.distributions.Delta`.
178
+
179
+ .. note:: if ``distribution_class`` is of type :class:`~tensordict.nn.distributions.CompositeDistribution`,
180
+ the keys will be inferred from the ``distribution_map`` / ``name_map`` keyword arguments of that
181
+ distribution. If this distribution is used with another constructor (e.g., partial or lambda function)
182
+ then the out_keys will need to be provided explicitly.
183
+ Note also that actions will **not** be prefixed with an ``"action"`` key; see the example below
184
+ on how this can be achieved with a ``ProbabilisticActor``.
185
+
186
+ distribution_kwargs (dict, optional): keyword-only argument.
187
+ Keyword-argument pairs to be passed to the distribution.
188
+ return_log_prob (bool, optional): keyword-only argument.
189
+ If ``True``, the log-probability of the
190
+ distribution sample will be written in the tensordict with the key
191
+ `'sample_log_prob'`. Default is ``False``.
192
+ cache_dist (bool, optional): keyword-only argument.
193
+ EXPERIMENTAL: if ``True``, the parameters of the
194
+ distribution (i.e. the output of the module) will be written to the
195
+ tensordict along with the sample. Those parameters can be used to re-compute
196
+ the original distribution later on (e.g. to compute the divergence between
197
+ the distribution used to sample the action and the updated distribution in
198
+ PPO). Default is ``False``.
199
+ n_empirical_estimate (int, optional): keyword-only argument.
200
+ Number of samples to compute the empirical
201
+ mean when it is not available. Defaults to 1000.
202
+
203
+ Examples:
204
+ >>> import torch
205
+ >>> from tensordict import TensorDict
206
+ >>> from tensordict.nn import TensorDictModule
207
+ >>> from torchrl.data import Bounded
208
+ >>> from torchrl.modules import ProbabilisticActor, NormalParamExtractor, TanhNormal
209
+ >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
210
+ >>> action_spec = Bounded(shape=torch.Size([4]),
211
+ ... low=-1, high=1)
212
+ >>> module = torch.nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor())
213
+ >>> tensordict_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"])
214
+ >>> td_module = ProbabilisticActor(
215
+ ... module=tensordict_module,
216
+ ... spec=action_spec,
217
+ ... in_keys=["loc", "scale"],
218
+ ... distribution_class=TanhNormal,
219
+ ... )
220
+ >>> td = td_module(td)
221
+ >>> td
222
+ TensorDict(
223
+ fields={
224
+ action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
225
+ loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
226
+ observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
227
+ scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
228
+ batch_size=torch.Size([3]),
229
+ device=None,
230
+ is_shared=False)
231
+
232
+ Probabilistic actors also support compound actions through the
233
+ :class:`tensordict.nn.CompositeDistribution` class. This distribution takes
234
+ a tensordict as input (typically `"params"`) and reads it as a whole: the
235
+ content of this tensordict is the input to the distributions contained in the
236
+ compound one.
237
+
238
+ Examples:
239
+ >>> from tensordict import TensorDict
240
+ >>> from tensordict.nn import CompositeDistribution, TensorDictModule
241
+ >>> from torchrl.modules import ProbabilisticActor
242
+ >>> from torch import nn, distributions as d
243
+ >>> import torch
244
+ >>>
245
+ >>> class Module(nn.Module):
246
+ ... def forward(self, x):
247
+ ... return x[..., :3], x[..., 3:6], x[..., 6:]
248
+ >>> module = TensorDictModule(Module(),
249
+ ... in_keys=["x"],
250
+ ... out_keys=[("params", "normal", "loc"),
251
+ ... ("params", "normal", "scale"),
252
+ ... ("params", "categ", "logits")])
253
+ >>> actor = ProbabilisticActor(module,
254
+ ... in_keys=["params"],
255
+ ... distribution_class=CompositeDistribution,
256
+ ... distribution_kwargs={"distribution_map": {
257
+ ... "normal": d.Normal, "categ": d.Categorical}}
258
+ ... )
259
+ >>> data = TensorDict({"x": torch.rand(10)}, [])
260
+ >>> actor(data)
261
+ TensorDict(
262
+ fields={
263
+ categ: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
264
+ normal: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
265
+ params: TensorDict(
266
+ fields={
267
+ categ: TensorDict(
268
+ fields={
269
+ logits: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
270
+ batch_size=torch.Size([]),
271
+ device=None,
272
+ is_shared=False),
273
+ normal: TensorDict(
274
+ fields={
275
+ loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
276
+ scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
277
+ batch_size=torch.Size([]),
278
+ device=None,
279
+ is_shared=False)},
280
+ batch_size=torch.Size([]),
281
+ device=None,
282
+ is_shared=False),
283
+ x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
284
+ batch_size=torch.Size([]),
285
+ device=None,
286
+ is_shared=False)
287
+
288
+ Using a probabilistic actor with a composite distribution can be achieved using the following
289
+ example code:
290
+
291
+ Examples:
292
+ >>> import torch
293
+ >>> from tensordict import TensorDict
294
+ >>> from tensordict.nn import CompositeDistribution
295
+ >>> from tensordict.nn import TensorDictModule
296
+ >>> from torch import distributions as d
297
+ >>> from torch import nn
298
+ >>>
299
+ >>> from torchrl.modules import ProbabilisticActor
300
+ >>>
301
+ >>>
302
+ >>> class Module(nn.Module):
303
+ ... def forward(self, x):
304
+ ... return x[..., :3], x[..., 3:6], x[..., 6:]
305
+ ...
306
+ >>>
307
+ >>> module = TensorDictModule(Module(),
308
+ ... in_keys=["x"],
309
+ ... out_keys=[
310
+ ... ("params", "normal", "loc"), ("params", "normal", "scale"), ("params", "categ", "logits")
311
+ ... ])
312
+ >>> actor = ProbabilisticActor(module,
313
+ ... in_keys=["params"],
314
+ ... distribution_class=CompositeDistribution,
315
+ ... distribution_kwargs={"distribution_map": {"normal": d.Normal, "categ": d.Categorical},
316
+ ... "name_map": {"normal": ("action", "normal"),
317
+ ... "categ": ("action", "categ")}}
318
+ ... )
319
+ >>> print(actor.out_keys)
320
+ [('params', 'normal', 'loc'), ('params', 'normal', 'scale'), ('params', 'categ', 'logits'), ('action', 'normal'), ('action', 'categ')]
321
+ >>>
322
+ >>> data = TensorDict({"x": torch.rand(10)}, [])
323
+ >>> module(data)
324
+ >>> print(actor(data))
325
+ TensorDict(
326
+ fields={
327
+ action: TensorDict(
328
+ fields={
329
+ categ: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
330
+ normal: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
331
+ batch_size=torch.Size([]),
332
+ device=None,
333
+ is_shared=False),
334
+ params: TensorDict(
335
+ fields={
336
+ categ: TensorDict(
337
+ fields={
338
+ logits: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
339
+ batch_size=torch.Size([]),
340
+ device=None,
341
+ is_shared=False),
342
+ normal: TensorDict(
343
+ fields={
344
+ loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
345
+ scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
346
+ batch_size=torch.Size([]),
347
+ device=None,
348
+ is_shared=False)},
349
+ batch_size=torch.Size([]),
350
+ device=None,
351
+ is_shared=False),
352
+ x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
353
+ batch_size=torch.Size([]),
354
+ device=None,
355
+ is_shared=False)
356
+
357
+ """
358
+
359
+ def __init__(
360
+ self,
361
+ module: TensorDictModule,
362
+ in_keys: NestedKey | Sequence[NestedKey],
363
+ out_keys: Sequence[NestedKey] | None = None,
364
+ *,
365
+ spec: TensorSpec | None = None,
366
+ **kwargs,
367
+ ):
368
+ distribution_class = kwargs.get("distribution_class")
369
+ if out_keys is None:
370
+ if distribution_class is CompositeDistribution:
371
+ if "distribution_map" not in kwargs.get("distribution_kwargs", {}):
372
+ raise KeyError(
373
+ "'distribution_map' must be provided within "
374
+ "distribution_kwargs whenever the distribution is of type CompositeDistribution."
375
+ )
376
+ distribution_map = kwargs["distribution_kwargs"]["distribution_map"]
377
+ name_map = kwargs["distribution_kwargs"].get("name_map", None)
378
+ if name_map is not None:
379
+ out_keys = list(name_map.values())
380
+ else:
381
+ out_keys = list(distribution_map.keys())
382
+ else:
383
+ out_keys = ["action"]
384
+ if len(out_keys) == 1 and spec is not None and not isinstance(spec, Composite):
385
+ spec = Composite({out_keys[0]: spec})
386
+
387
+ super().__init__(
388
+ module,
389
+ SafeProbabilisticModule(
390
+ in_keys=in_keys, out_keys=out_keys, spec=spec, **kwargs
391
+ ),
392
+ )
393
+
394
+
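
A minimal usage sketch for the class defined above, assuming the interaction-type helpers exported by ``tensordict.nn`` and the ``NormalParamExtractor``/``TanhNormal`` modules from ``torchrl.modules``; the construction mirrors the docstring example, and the module and key names are purely illustrative:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import InteractionType, TensorDictModule, set_interaction_type
    from torchrl.modules import NormalParamExtractor, ProbabilisticActor, TanhNormal

    # Backbone producing the distribution parameters read by the actor.
    backbone = TensorDictModule(
        torch.nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    )
    actor = ProbabilisticActor(
        backbone, in_keys=["loc", "scale"], distribution_class=TanhNormal
    )
    td = TensorDict({"observation": torch.randn(3, 4)}, [3])

    # The surrounding interaction type decides how the distribution is queried,
    # as described in the note on `default_interaction_type` above.
    with set_interaction_type(InteractionType.RANDOM):         # stochastic sample
        actor(td.clone())
    with set_interaction_type(InteractionType.DETERMINISTIC):  # deterministic value
        actor(td.clone())
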
395
+ class ValueOperator(TensorDictModule):
396
+ """General class for value functions in RL.
397
+
398
+ The ValueOperator class comes with default values for the in_keys and
399
+ out_keys arguments (["observation"] and ["state_value"] or
400
+ ["state_action_value"], respectively and depending on whether the "action"
401
+ key is part of the in_keys list).
402
+
403
+ Args:
404
+ module (nn.Module): a :class:`torch.nn.Module` used to map the input to
405
+ the output parameter space.
406
+ in_keys (iterable of str, optional): keys to be read from input
407
+ tensordict and passed to the module. If it
408
+ contains more than one element, the values will be passed in the
409
+ order given by the in_keys iterable.
410
+ Defaults to ``["observation"]``.
411
+ out_keys (iterable of str): keys to be written to the input tensordict.
412
+ The length of out_keys must match the
413
+ number of tensors returned by the embedded module. Using "_" as a
414
+ key avoids writing the tensor to the output.
415
+ Defaults to ``["state_value"]`` or
416
+ ``["state_action_value"]`` if ``"action"`` is part of the ``in_keys``.
417
+
418
+ Examples:
419
+ >>> import torch
420
+ >>> from tensordict import TensorDict
421
+ >>> from torch import nn
422
+ >>> from torchrl.data import Unbounded
423
+ >>> from torchrl.modules import ValueOperator
424
+ >>> td = TensorDict({"observation": torch.randn(3, 4), "action": torch.randn(3, 2)}, [3,])
425
+ >>> class CustomModule(nn.Module):
426
+ ... def __init__(self):
427
+ ... super().__init__()
428
+ ... self.linear = torch.nn.Linear(6, 1)
429
+ ... def forward(self, obs, action):
430
+ ... return self.linear(torch.cat([obs, action], -1))
431
+ >>> module = CustomModule()
432
+ >>> td_module = ValueOperator(
433
+ ... in_keys=["observation", "action"], module=module
434
+ ... )
435
+ >>> td = td_module(td)
436
+ >>> print(td)
437
+ TensorDict(
438
+ fields={
439
+ action: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False),
440
+ observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
441
+ state_action_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
442
+ batch_size=torch.Size([3]),
443
+ device=None,
444
+ is_shared=False)
445
+
446
+
447
+ """
448
+
449
+ def __init__(
450
+ self,
451
+ module: nn.Module,
452
+ in_keys: Sequence[NestedKey] | None = None,
453
+ out_keys: Sequence[NestedKey] | None = None,
454
+ ) -> None:
455
+ if in_keys is None:
456
+ in_keys = ["observation"]
457
+ if out_keys is None:
458
+ out_keys = (
459
+ ["state_value"] if "action" not in in_keys else ["state_action_value"]
460
+ )
461
+ super().__init__(
462
+ module=module,
463
+ in_keys=in_keys,
464
+ out_keys=out_keys,
465
+ )
466
+
467
+
468
+ class QValueModule(TensorDictModuleBase):
469
+ """Q-Value TensorDictModule for Q-value policies.
470
+
471
+ This module processes a tensor containing action values into its argmax
472
+ component (i.e. the resulting greedy action), following a given
473
+ action space (one-hot, binary or categorical).
474
+ It works with both tensordict and regular tensors.
475
+
476
+ Args:
477
+ action_space (str, optional): Action space. Must be one of
478
+ ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
479
+ This argument is exclusive with ``spec``, since ``spec``
480
+ conditions the action_space.
481
+ action_value_key (str or tuple of str, optional): The input key
482
+ representing the action value. Defaults to ``"action_value"``.
483
+ action_mask_key (str or tuple of str, optional): The input key
484
+ representing the action mask. Defaults to ``None`` (equivalent to no masking).
485
+ out_keys (list of str or tuple of str, optional): The output keys
486
+ representing the actions, action values and chosen action value.
487
+ Defaults to ``["action", "action_value", "chosen_action_value"]``.
488
+ var_nums (int, optional): if ``action_space = "mult-one-hot"``,
489
+ this value represents the cardinality of each
490
+ action component.
491
+ spec (TensorSpec, optional): if provided, the specs of the action (and/or
492
+ other outputs). This is exclusive with ``action_space``, as the spec
493
+ conditions the action space.
494
+ safe (bool): if ``True``, the value of the output is checked against the
495
+ input spec. Out-of-domain sampling can
496
+ occur because of exploration policies or numerical under/overflow issues.
497
+ If this value is out of bounds, it is projected back onto the
498
+ desired space using the :obj:`TensorSpec.project`
499
+ method. Default is ``False``.
500
+
501
+ Returns:
502
+ if the input is a single tensor, a triplet containing the chosen action,
503
+ the values and the value of the chosen action is returned. If a tensordict
504
+ is provided, it is updated with these entries at the keys indicated by the
505
+ ``out_keys`` field.
506
+
507
+ Examples:
508
+ >>> import torch
+ >>> from tensordict import TensorDict
+ >>> from torchrl.modules import QValueModule
509
+ >>> action_space = "categorical"
510
+ >>> action_value_key = "my_action_value"
511
+ >>> actor = QValueModule(action_space, action_value_key=action_value_key)
512
+ >>> # This module works with both tensordict and regular tensors:
513
+ >>> value = torch.zeros(4)
514
+ >>> value[-1] = 1
515
+ >>> actor(my_action_value=value)
516
+ (tensor(3), tensor([0., 0., 0., 1.]), tensor([1.]))
517
+ >>> actor(value)
518
+ (tensor(3), tensor([0., 0., 0., 1.]), tensor([1.]))
519
+ >>> actor(TensorDict({action_value_key: value}, []))
520
+ TensorDict(
521
+ fields={
522
+ action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
523
+ action_value: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
524
+ chosen_action_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
525
+ my_action_value: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
526
+ batch_size=torch.Size([]),
527
+ device=None,
528
+ is_shared=False)
529
+
530
+ """
531
+
532
+ def __init__(
533
+ self,
534
+ action_space: str | None = None,
535
+ action_value_key: NestedKey | None = None,
536
+ action_mask_key: NestedKey | None = None,
537
+ out_keys: Sequence[NestedKey] | None = None,
538
+ var_nums: int | None = None,
539
+ spec: TensorSpec | None = None,
540
+ safe: bool = False,
541
+ ):
542
+ if isinstance(action_space, TensorSpec):
543
+ raise TypeError("Using specs in action_space is deprecated")
544
+ action_space, spec = _process_action_space_spec(action_space, spec)
545
+ self.action_space = action_space
546
+ self.var_nums = var_nums
547
+ self.action_func_mapping = {
548
+ "one_hot": self._one_hot,
549
+ "mult_one_hot": self._mult_one_hot,
550
+ "binary": self._binary,
551
+ "categorical": self._categorical,
552
+ }
553
+ self.action_value_func_mapping = {
554
+ "categorical": self._categorical_action_value,
555
+ }
556
+ if action_space not in self.action_func_mapping:
557
+ raise ValueError(
558
+ f"action_space must be one of {list(self.action_func_mapping.keys())}, got {action_space}"
559
+ )
560
+ if action_value_key is None:
561
+ action_value_key = "action_value"
562
+ self.action_mask_key = action_mask_key
563
+ in_keys = [action_value_key]
564
+ if self.action_mask_key is not None:
565
+ in_keys.append(self.action_mask_key)
566
+ self.in_keys = in_keys
567
+ if out_keys is None:
568
+ out_keys = ["action", action_value_key, "chosen_action_value"]
569
+ elif action_value_key not in out_keys:
570
+ raise RuntimeError(
571
+ f"Expected the action-value key to be '{action_value_key}' but got {out_keys[1]} instead."
572
+ )
573
+ self.out_keys = out_keys
574
+ action_key = out_keys[0]
575
+ if not isinstance(spec, Composite):
576
+ spec = Composite({action_key: spec})
577
+ super().__init__()
578
+ self.register_spec(safe=safe, spec=spec)
579
+
580
+ register_spec = SafeModule.register_spec
581
+
582
+ @property
583
+ def spec(self) -> Composite:
584
+ return self._spec
585
+
586
+ @spec.setter
587
+ def spec(self, spec: Composite) -> None:
588
+ if not isinstance(spec, Composite):
589
+ raise RuntimeError(
590
+ f"Trying to set an object of type {type(spec)} as a tensorspec but expected a Composite instance."
591
+ )
592
+ self._spec = spec
593
+
594
+ @property
595
+ def action_value_key(self):
596
+ return self.in_keys[0]
597
+
598
+ @dispatch(auto_batch_size=False)
599
+ def forward(self, tensordict: torch.Tensor) -> TensorDictBase:
600
+ action_values = tensordict.get(self.action_value_key, None)
601
+ if action_values is None:
602
+ raise KeyError(
603
+ f"Action value key {self.action_value_key} not found in {tensordict}."
604
+ )
605
+ if self.action_mask_key is not None:
606
+ action_mask = tensordict.get(self.action_mask_key, None)
607
+ if action_mask is None:
608
+ raise KeyError(
609
+ f"Action mask key {self.action_mask_key} not found in {tensordict}."
610
+ )
611
+ action_values = torch.where(
612
+ action_mask, action_values, torch.finfo(action_values.dtype).min
613
+ )
614
+
615
+ action = self.action_func_mapping[self.action_space](action_values)
616
+
617
+ action_value_func = self.action_value_func_mapping.get(
618
+ self.action_space, self._default_action_value
619
+ )
620
+ chosen_action_value = action_value_func(action_values, action)
621
+ tensordict.update(
622
+ dict(zip(self.out_keys, (action, action_values, chosen_action_value)))
623
+ )
624
+ return tensordict
625
+
626
+ @staticmethod
627
+ def _one_hot(value: torch.Tensor) -> torch.Tensor:
628
+ out = (value == value.max(dim=-1, keepdim=True)[0]).to(torch.long)
629
+ return out
630
+
631
+ @staticmethod
632
+ def _categorical(value: torch.Tensor) -> torch.Tensor:
633
+ return torch.argmax(value, dim=-1).to(torch.long)
634
+
635
+ def _mult_one_hot(
636
+ self, value: torch.Tensor, support: torch.Tensor = None
637
+ ) -> torch.Tensor:
638
+ if self.var_nums is None:
639
+ raise ValueError(
640
+ "var_nums must be provided to the constructor for multi one-hot action spaces."
641
+ )
642
+ values = value.split(self.var_nums, dim=-1)
643
+ return torch.cat(
644
+ [
645
+ self._one_hot(
646
+ _value,
647
+ )
648
+ for _value in values
649
+ ],
650
+ -1,
651
+ )
652
+
653
+ @staticmethod
654
+ def _binary(value: torch.Tensor, support: torch.Tensor) -> torch.Tensor:
655
+ raise NotImplementedError
656
+
657
+ @staticmethod
658
+ def _default_action_value(
659
+ values: torch.Tensor, action: torch.Tensor
660
+ ) -> torch.Tensor:
661
+ return (action * values).sum(-1, True)
662
+
663
+ @staticmethod
664
+ def _categorical_action_value(
665
+ values: torch.Tensor, action: torch.Tensor
666
+ ) -> torch.Tensor:
667
+ return values.gather(-1, action.unsqueeze(-1))
668
+ # if values.ndim == 1:
669
+ # return values[action].unsqueeze(-1)
670
+ # batch_size = values.size(0)
671
+ # return values[range(batch_size), action].unsqueeze(-1)
672
+
673
+
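
As a small numerical illustration of the two value-selection helpers defined above, assuming Q-values of shape ``[batch, n_actions]``: ``_default_action_value`` (used with one-hot actions) and ``_categorical_action_value`` (used with categorical actions) both recover the value of the greedy action.

    import torch

    values = torch.randn(5, 4)                        # Q-values: 5 samples, 4 actions
    categ = values.argmax(dim=-1)                     # categorical greedy action, shape [5]
    one_hot = torch.nn.functional.one_hot(categ, 4)   # one-hot greedy action, shape [5, 4]

    chosen_one_hot = (one_hot * values).sum(-1, True)      # as in _default_action_value
    chosen_categ = values.gather(-1, categ.unsqueeze(-1))  # as in _categorical_action_value
    assert torch.allclose(chosen_one_hot, chosen_categ)
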
674
+ class DistributionalQValueModule(QValueModule):
675
+ """Distributional Q-Value hook for Q-value policies.
676
+
677
+ This module processes a tensor containing action value logits into its argmax
678
+ component (i.e. the resulting greedy action), following a given
679
+ action space (one-hot, binary or categorical).
680
+ It works with both tensordict and regular tensors.
681
+
682
+ The input action value is expected to be the result of a log-softmax
683
+ operation.
684
+
685
+ For more details regarding Distributional DQN, refer to "A Distributional Perspective on Reinforcement Learning",
686
+ https://arxiv.org/pdf/1707.06887.pdf
687
+
688
+ Args:
689
+ action_space (str, optional): Action space. Must be one of
690
+ ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
691
+ This argument is exclusive with ``spec``, since ``spec``
692
+ conditions the action_space.
693
+ support (torch.Tensor): support of the action values.
694
+ action_value_key (str or tuple of str, optional): The input key
695
+ representing the action value. Defaults to ``"action_value"``.
696
+ action_mask_key (str or tuple of str, optional): The input key
697
+ representing the action mask. Defaults to ``None`` (equivalent to no masking).
698
+ out_keys (list of str or tuple of str, optional): The output keys
699
+ representing the actions and action values.
700
+ Defaults to ``["action", "action_value"]``.
701
+ var_nums (int, optional): if ``action_space = "mult-one-hot"``,
702
+ this value represents the cardinality of each
703
+ action component.
704
+ spec (TensorSpec, optional): if provided, the specs of the action (and/or
705
+ other outputs). This is exclusive with ``action_space``, as the spec
706
+ conditions the action space.
707
+ safe (bool): if ``True``, the value of the output is checked against the
708
+ input spec. Out-of-domain sampling can
709
+ occur because of exploration policies or numerical under/overflow issues.
710
+ If this value is out of bounds, it is projected back onto the
711
+ desired space using the :obj:`TensorSpec.project`
712
+ method. Default is ``False``.
713
+
714
+ Examples:
715
+ >>> import torch
+ >>> from tensordict import TensorDict
+ >>> from torchrl.modules import DistributionalQValueModule
716
+ >>> torch.manual_seed(0)
717
+ >>> action_space = "categorical"
718
+ >>> action_value_key = "my_action_value"
719
+ >>> support = torch.tensor([-1, 0.0, 1.0]) # the action value is between -1 and 1
720
+ >>> actor = DistributionalQValueModule(action_space, support=support, action_value_key=action_value_key)
721
+ >>> # This module works with both tensordict and regular tensors:
722
+ >>> value = torch.full((3, 4), -100)
723
+ >>> # the first bin (-1) of the first action is high: there's a high chance that it has a low value
724
+ >>> value[0, 0] = 0
725
+ >>> # the second bin (0) of the second action is high: there's a high chance that it has an intermediate value
726
+ >>> value[1, 1] = 0
727
+ >>> # the third bin (1) of the third action is high: there's a high chance that it has a high value
728
+ >>> value[2, 2] = 0
729
+ >>> actor(my_action_value=value)
730
+ (tensor(2), tensor([[ 0, -100, -100, -100],
731
+ [-100, 0, -100, -100],
732
+ [-100, -100, 0, -100]]))
733
+ >>> actor(value)
734
+ (tensor(2), tensor([[ 0, -100, -100, -100],
735
+ [-100, 0, -100, -100],
736
+ [-100, -100, 0, -100]]))
737
+ >>> actor(TensorDict({action_value_key: value}, []))
738
+ TensorDict(
739
+ fields={
740
+ action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
741
+ my_action_value: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.int64, is_shared=False)},
742
+ batch_size=torch.Size([]),
743
+ device=None,
744
+ is_shared=False)
745
+
746
+ """
747
+
748
+ def __init__(
749
+ self,
750
+ action_space: str | None,
751
+ support: torch.Tensor,
752
+ action_value_key: NestedKey | None = None,
753
+ action_mask_key: NestedKey | None = None,
754
+ out_keys: Sequence[NestedKey] | None = None,
755
+ var_nums: int | None = None,
756
+ spec: TensorSpec = None,
757
+ safe: bool = False,
758
+ ):
759
+ if action_value_key is None:
760
+ action_value_key = "action_value"
761
+ if out_keys is None:
762
+ out_keys = ["action", action_value_key]
763
+ super().__init__(
764
+ action_space=action_space,
765
+ action_value_key=action_value_key,
766
+ action_mask_key=action_mask_key,
767
+ out_keys=out_keys,
768
+ var_nums=var_nums,
769
+ spec=spec,
770
+ safe=safe,
771
+ )
772
+ self.register_buffer("support", support)
773
+
774
+ @dispatch(auto_batch_size=False)
775
+ def forward(self, tensordict: torch.Tensor) -> TensorDictBase:
776
+ action_values = tensordict.get(self.action_value_key, None)
777
+ if action_values is None:
778
+ raise KeyError(
779
+ f"Action value key {self.action_value_key} not found in {tensordict}."
780
+ )
781
+ if self.action_mask_key is not None:
782
+ action_mask = tensordict.get(self.action_mask_key, None)
783
+ if action_mask is None:
784
+ raise KeyError(
785
+ f"Action mask key {self.action_mask_key} not found in {tensordict}."
786
+ )
787
+ action_values = torch.where(
788
+ action_mask, action_values, torch.finfo(action_values.dtype).min
789
+ )
790
+
791
+ action = self.action_func_mapping[self.action_space](action_values)
792
+
793
+ tensordict.update(
794
+ dict(
795
+ zip(
796
+ self.out_keys,
797
+ (
798
+ action,
799
+ action_values,
800
+ ),
801
+ )
802
+ )
803
+ )
804
+ return tensordict
805
+
806
+ def _support_expected(
807
+ self, log_softmax_values: torch.Tensor, support=None
808
+ ) -> torch.Tensor:
809
+ if support is None:
810
+ support = self.support
811
+ support = support.to(log_softmax_values.device)
812
+ if log_softmax_values.shape[-2] != support.shape[-1]:
813
+ raise RuntimeError(
814
+ "Support length and number of atoms in module output should match, "
815
+ f"got self.support.shape={support.shape} and module(...).shape={log_softmax_values.shape}"
816
+ )
817
+ if (log_softmax_values > 0).any():
818
+ raise ValueError(
819
+ f"input to QValueHook must be log-softmax values (which are expected to be non-positive numbers). "
820
+ f"got a maximum value of {log_softmax_values.max():4.4f}"
821
+ )
822
+ return (log_softmax_values.exp() * support.unsqueeze(-1)).sum(-2)
823
+
824
+ def _one_hot(self, value: torch.Tensor, support=None) -> torch.Tensor:
825
+ if support is None:
826
+ support = self.support
827
+ if not isinstance(value, torch.Tensor):
828
+ raise TypeError(f"got value of type {value.__class__.__name__}")
829
+ if not isinstance(support, torch.Tensor):
830
+ raise TypeError(f"got support of type {support.__class__.__name__}")
831
+ value = self._support_expected(value)
832
+ out = (value == value.max(dim=-1, keepdim=True)[0]).to(torch.long)
833
+ return out
834
+
835
+ def _mult_one_hot(self, value: torch.Tensor, support=None) -> torch.Tensor:
836
+ if support is None:
837
+ support = self.support
838
+ values = value.split(self.var_nums, dim=-1)
839
+ return torch.cat(
840
+ [
841
+ self._one_hot(_value, _support)
842
+ for _value, _support in zip(values, support)
843
+ ],
844
+ -1,
845
+ )
846
+
847
+ def _categorical(
848
+ self,
849
+ value: torch.Tensor,
850
+ ) -> torch.Tensor:
851
+ value = self._support_expected(
852
+ value,
853
+ )
854
+ return torch.argmax(value, dim=-1).to(torch.long)
855
+
856
+ def _binary(self, value: torch.Tensor) -> torch.Tensor:
857
+ raise NotImplementedError(
858
+ "'binary' is currently not supported for DistributionalQValueModule."
859
+ )
860
+
861
+
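
A short sketch of the expectation computed by ``_support_expected`` above, assuming the log-softmax values are laid out as ``[..., n_atoms, n_actions]`` as the shape check in that method implies: the expected Q-value of each action is the probability-weighted sum of the support atoms.

    import torch

    support = torch.tensor([-1.0, 0.0, 1.0])                      # 3 atoms
    logits = torch.randn(3, 4).log_softmax(dim=-2)                # [atoms, actions] log-probabilities
    expected_q = (logits.exp() * support.unsqueeze(-1)).sum(-2)   # expected value per action, shape [4]
    greedy_action = expected_q.argmax(dim=-1)                     # greedy ("categorical") choice
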
862
+ class QValueHook:
863
+ """Q-Value hook for Q-value policies.
864
+
865
+ Given the output of a regular nn.Module, representing the values of the
866
+ different discrete actions available,
867
+ a QValueHook will transform these values into their argmax component (i.e.
868
+ the resulting greedy action).
869
+
870
+ Args:
871
+ action_space (str): Action space. Must be one of
872
+ ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
873
+ var_nums (int, optional): if ``action_space = "mult-one-hot"``,
874
+ this value represents the cardinality of each
875
+ action component.
876
+ action_value_key (str or tuple of str, optional): to be used when hooked on
877
+ a TensorDictModule. The input key representing the action value. Defaults
878
+ to ``"action_value"``.
879
+ action_mask_key (str or tuple of str, optional): The input key
880
+ representing the action mask. Defaults to ``None`` (equivalent to no masking).
881
+ out_keys (list of str or tuple of str, optional): to be used when hooked on
882
+ a TensorDictModule. The output keys representing the actions, action values
883
+ and chosen action value. Defaults to ``["action", "action_value", "chosen_action_value"]``.
884
+
885
+ Examples:
886
+ >>> import torch
887
+ >>> from tensordict import TensorDict
888
+ >>> from torch import nn
889
+ >>> from torchrl.data import OneHot
890
+ >>> from torchrl.modules.tensordict_module.actors import QValueHook, Actor
891
+ >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
892
+ >>> module = nn.Linear(4, 4)
893
+ >>> hook = QValueHook("one_hot")
894
+ >>> module.register_forward_hook(hook)
895
+ >>> action_spec = OneHot(4)
896
+ >>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"])
897
+ >>> td = qvalue_actor(td)
898
+ >>> print(td)
899
+ TensorDict(
900
+ fields={
901
+ action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
902
+ action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
903
+ observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
904
+ batch_size=torch.Size([5]),
905
+ device=None,
906
+ is_shared=False)
907
+
908
+ """
909
+
910
+ def __init__(
911
+ self,
912
+ action_space: str,
913
+ var_nums: int | None = None,
914
+ action_value_key: NestedKey | None = None,
915
+ action_mask_key: NestedKey | None = None,
916
+ out_keys: Sequence[NestedKey] | None = None,
917
+ ):
918
+ if isinstance(action_space, TensorSpec):
919
+ raise RuntimeError(
920
+ "Using specs in action_space is deprecated. "
921
+ "Please use the 'spec' argument if you want to provide an action spec"
922
+ )
923
+ action_space, _ = _process_action_space_spec(action_space, None)
924
+
925
+ self.qvalue_model = QValueModule(
926
+ action_space=action_space,
927
+ var_nums=var_nums,
928
+ action_value_key=action_value_key,
929
+ action_mask_key=action_mask_key,
930
+ out_keys=out_keys,
931
+ )
932
+ action_value_key = self.qvalue_model.in_keys[0]
933
+ if isinstance(action_value_key, tuple):
934
+ action_value_key = "_".join(action_value_key)
935
+ # uses "dispatch" to get and return tensors
936
+ self.action_value_key = action_value_key
937
+
938
+ def __call__(
939
+ self, net: nn.Module, observation: torch.Tensor, values: torch.Tensor
940
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
941
+ kwargs = {self.action_value_key: values}
942
+ return self.qvalue_model(**kwargs)
943
+
944
+
945
+ class DistributionalQValueHook(QValueHook):
946
+ """Distributional Q-Value hook for Q-value policies.
947
+
948
+ Given the output of a mapping operator, representing the log-probability of the
949
+ different action value bins available,
950
+ a DistributionalQValueHook will transform these values into their argmax
951
+ component using the provided support.
952
+
953
+ For more details regarding Distributional DQN, refer to "A Distributional Perspective on Reinforcement Learning",
954
+ https://arxiv.org/pdf/1707.06887.pdf
955
+
956
+ Args:
957
+ action_space (str): Action space. Must be one of
958
+ ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
959
+ action_value_key (str or tuple of str, optional): to be used when hooked on
960
+ a TensorDictModule. The input key representing the action value. Defaults
961
+ to ``"action_value"``.
962
+ action_mask_key (str or tuple of str, optional): The input key
963
+ representing the action mask. Defaults to ``None`` (equivalent to no masking).
964
+ support (torch.Tensor): support of the action values.
965
+ var_nums (int, optional): if ``action_space = "mult-one-hot"``, this
966
+ value represents the cardinality of each
967
+ action component.
968
+
969
+ Examples:
970
+ >>> import torch
971
+ >>> from tensordict import TensorDict
972
+ >>> from torch import nn
973
+ >>> from torchrl.data import OneHot
974
+ >>> from torchrl.modules.tensordict_module.actors import DistributionalQValueHook, Actor
975
+ >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
976
+ >>> nbins = 3
977
+ >>> class CustomDistributionalQval(nn.Module):
978
+ ... def __init__(self):
979
+ ... super().__init__()
980
+ ... self.linear = nn.Linear(4, nbins*4)
981
+ ...
982
+ ... def forward(self, x):
983
+ ... return self.linear(x).view(-1, nbins, 4).log_softmax(-2)
984
+ ...
985
+ >>> module = CustomDistributionalQval()
986
+ >>> params = TensorDict.from_module(module)
987
+ >>> action_spec = OneHot(4)
988
+ >>> hook = DistributionalQValueHook("one_hot", support = torch.arange(nbins))
989
+ >>> module.register_forward_hook(hook)
990
+ >>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"])
991
+ >>> with params.to_module(module):
992
+ ... qvalue_actor(td)
993
+ >>> print(td)
994
+ TensorDict(
995
+ fields={
996
+ action: Tensor(torch.Size([5, 4]), dtype=torch.int64),
997
+ action_value: Tensor(torch.Size([5, 3, 4]), dtype=torch.float32),
998
+ observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)},
999
+ batch_size=torch.Size([5]),
1000
+ device=None,
1001
+ is_shared=False)
1002
+
1003
+ """
1004
+
1005
+ def __init__(
1006
+ self,
1007
+ action_space: str,
1008
+ support: torch.Tensor,
1009
+ var_nums: int | None = None,
1010
+ action_value_key: NestedKey | None = None,
1011
+ action_mask_key: NestedKey | None = None,
1012
+ out_keys: Sequence[NestedKey] | None = None,
1013
+ ):
1014
+ if isinstance(action_space, TensorSpec):
1015
+ raise RuntimeError("Using specs in action_space is deprecated")
1016
+ action_space, _ = _process_action_space_spec(action_space, None)
1017
+ self.qvalue_model = DistributionalQValueModule(
1018
+ action_space=action_space,
1019
+ var_nums=var_nums,
1020
+ support=support,
1021
+ action_value_key=action_value_key,
1022
+ action_mask_key=action_mask_key,
1023
+ out_keys=out_keys,
1024
+ )
1025
+ action_value_key = self.qvalue_model.in_keys[0]
1026
+ if isinstance(action_value_key, tuple):
1027
+ action_value_key = "_".join(action_value_key)
1028
+ # uses "dispatch" to get and return tensors
1029
+ self.action_value_key = action_value_key
1030
+
1031
+
1032
+ class QValueActor(SafeSequential):
1033
+ """A Q-Value actor class.
1034
+
1035
+ This class appends a :class:`~.QValueModule` after the input module
1036
+ such that the action values are used to select an action.
1037
+
1038
+ Args:
1039
+ module (nn.Module): a :class:`torch.nn.Module` used to map the input to
1040
+ the output parameter space. If the class provided is not compatible
1041
+ with :class:`tensordict.nn.TensorDictModuleBase`, it will be
1042
+ wrapped in a :class:`tensordict.nn.TensorDictModule` with
1043
+ ``in_keys`` indicated by the following keyword argument.
1044
+
1045
+ Keyword Args:
1046
+ in_keys (iterable of str, optional): If the class provided is not
1047
+ compatible with :class:`tensordict.nn.TensorDictModuleBase`, this
1048
+ list of keys indicates what observations need to be passed to the
1049
+ wrapped module to get the action values.
1050
+ Defaults to ``["observation"]``.
1051
+ spec (TensorSpec, optional): Keyword-only argument.
1052
+ Specs of the output tensor. If the module
1053
+ outputs multiple output tensors,
1054
+ spec characterizes the space of the first output tensor.
1055
+ safe (bool): Keyword-only argument.
1056
+ If ``True``, the value of the output is checked against the
1057
+ input spec. Out-of-domain sampling can
1058
+ occur because of exploration policies or numerical under/overflow
1059
+ issues. If this value is out of bounds, it is projected back onto the
1060
+ desired space using the :obj:`TensorSpec.project`
1061
+ method. Default is ``False``.
1062
+ action_space (str, optional): Action space. Must be one of
1063
+ ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
1064
+ This argument is exclusive with ``spec``, since ``spec``
1065
+ conditions the action_space.
1066
+ action_value_key (str or tuple of str, optional): if the input module
1067
+ is a :class:`tensordict.nn.TensorDictModuleBase` instance, it must
1068
+ match one of its output keys. Otherwise, this string represents
1069
+ the name of the action-value entry in the output tensordict.
1070
+ action_mask_key (str or tuple of str, optional): The input key
1071
+ representing the action mask. Defaults to ``None`` (equivalent to no masking).
1072
+
1073
+ .. note::
1074
+ ``out_keys`` cannot be passed. If the module is a :class:`tensordict.nn.TensorDictModule`
1075
+ instance, the out_keys will be updated accordingly. For a regular
1076
+ :class:`torch.nn.Module` instance, the triplet ``["action", action_value_key, "chosen_action_value"]``
1077
+ will be used.
1078
+
1079
+ Examples:
1080
+ >>> import torch
1081
+ >>> from tensordict import TensorDict
1082
+ >>> from torch import nn
1083
+ >>> from torchrl.data import OneHot
1084
+ >>> from torchrl.modules.tensordict_module.actors import QValueActor
+ >>> from tensordict.nn import TensorDictModule
1085
+ >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
1086
+ >>> # with a regular nn.Module
1087
+ >>> module = nn.Linear(4, 4)
1088
+ >>> action_spec = OneHot(4)
1089
+ >>> qvalue_actor = QValueActor(module=module, spec=action_spec)
1090
+ >>> td = qvalue_actor(td)
1091
+ >>> print(td)
1092
+ TensorDict(
1093
+ fields={
1094
+ action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
1095
+ action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1096
+ chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
1097
+ observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
1098
+ batch_size=torch.Size([5]),
1099
+ device=None,
1100
+ is_shared=False)
1101
+ >>> # with a TensorDictModule
1102
+ >>> td = TensorDict({'obs': torch.randn(5, 4)}, [5])
1103
+ >>> module = TensorDictModule(lambda x: x, in_keys=["obs"], out_keys=["action_value"])
1104
+ >>> action_spec = OneHot(4)
1105
+ >>> qvalue_actor = QValueActor(module=module, spec=action_spec)
1106
+ >>> td = qvalue_actor(td)
1107
+ >>> print(td)
1108
+ TensorDict(
1109
+ fields={
1110
+ action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
1111
+ action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1112
+ chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
1113
+ obs: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
1114
+ batch_size=torch.Size([5]),
1115
+ device=None,
1116
+ is_shared=False)
1117
+
1118
+ """
1119
+
1120
+ def __init__(
1121
+ self,
1122
+ module,
1123
+ *,
1124
+ in_keys=None,
1125
+ spec=None,
1126
+ safe=False,
1127
+ action_space: str | None = None,
1128
+ action_value_key=None,
1129
+ action_mask_key: NestedKey | None = None,
1130
+ ):
1131
+ if isinstance(action_space, TensorSpec):
1132
+ raise RuntimeError(
1133
+ "Using specs in action_space is deprecated. "
1134
+ "Please use the 'spec' argument if you want to provide an action spec"
1135
+ )
1136
+ action_space, spec = _process_action_space_spec(action_space, spec)
1137
+
1138
+ self.action_space = action_space
1139
+ self.action_value_key = action_value_key
1140
+ if action_value_key is None:
1141
+ action_value_key = "action_value"
1142
+ out_keys = [
1143
+ "action",
1144
+ action_value_key,
1145
+ "chosen_action_value",
1146
+ ]
1147
+ if isinstance(module, TensorDictModuleBase):
1148
+ if action_value_key not in module.out_keys:
1149
+ raise KeyError(
1150
+ f"The key '{action_value_key}' is not part of the module out-keys."
1151
+ )
1152
+ else:
1153
+ if in_keys is None:
1154
+ in_keys = ["observation"]
1155
+ module = TensorDictModule(
1156
+ module, in_keys=in_keys, out_keys=[action_value_key]
1157
+ )
1158
+ if spec is None:
1159
+ spec = Composite()
1160
+ if isinstance(spec, Composite):
1161
+ spec = spec.clone()
1162
+ if "action" not in spec.keys():
1163
+ spec["action"] = None
1164
+ else:
1165
+ spec = Composite(action=spec, shape=spec.shape[:-1])
1166
+ spec[action_value_key] = None
1167
+ spec["chosen_action_value"] = None
1168
+ qvalue = QValueModule(
1169
+ action_value_key=action_value_key,
1170
+ out_keys=out_keys,
1171
+ spec=spec,
1172
+ safe=safe,
1173
+ action_space=action_space,
1174
+ action_mask_key=action_mask_key,
1175
+ )
1176
+
1177
+ super().__init__(module, qvalue)
1178
+
1179
+
1180
+ class DistributionalQValueActor(QValueActor):
1181
+ """A Distributional DQN actor class.
1182
+
1183
+ This class appends a :class:`~.DistributionalQValueModule` after the input module
1184
+ such that the action values are used to select an action.
1185
+
1186
+ Args:
1187
+ module (nn.Module): a :class:`torch.nn.Module` used to map the input to
1188
+ the output parameter space.
1189
+ If the module isn't of type :class:`torchrl.modules.DistributionalDQNnet`,
1190
+ :class:`~.DistributionalQValueActor` will ensure that a log-softmax
1191
+ operation is applied to the action value tensor along dimension ``-2``.
1192
+ This can be deactivated by turning off the ``make_log_softmax``
1193
+ keyword argument.
1194
+
1195
+ Keyword Args:
1196
+ in_keys (iterable of str, optional): keys to be read from input
1197
+ tensordict and passed to the module. If it
1198
+ contains more than one element, the values will be passed in the
1199
+ order given by the in_keys iterable.
1200
+ Defaults to ``["observation"]``.
1201
+ spec (TensorSpec, optional): Keyword-only argument.
1202
+ Specs of the output tensor. If the module
1203
+ outputs multiple output tensors,
1204
+ spec characterizes the space of the first output tensor.
1205
+ safe (bool): Keyword-only argument.
1206
+ If ``True``, the value of the output is checked against the
1207
+ input spec. Out-of-domain sampling can
1208
+ occur because of exploration policies or numerical under/overflow
1209
+ issues. If this value is out of bounds, it is projected back onto the
1210
+ desired space using the :obj:`TensorSpec.project`
1211
+ method. Default is ``False``.
1212
+ var_nums (int, optional): if ``action_space = "mult-one-hot"``,
1213
+ this value represents the cardinality of each
1214
+ action component.
1215
+ support (torch.Tensor): support of the action values.
1216
+ action_space (str, optional): Action space. Must be one of
1217
+ ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``.
1218
+ This argument is exclusive with ``spec``, since ``spec``
1219
+ conditions the action_space.
1220
+ make_log_softmax (bool, optional): if ``True`` and if the module is not
1221
+ of type :class:`torchrl.modules.DistributionalDQNnet`, a log-softmax
1222
+ operation will be applied along dimension -2 of the action value tensor.
1223
+ action_value_key (str or tuple of str, optional): if the input module
1224
+ is a :class:`tensordict.nn.TensorDictModuleBase` instance, it must
1225
+ match one of its output keys. Otherwise, this string represents
1226
+ the name of the action-value entry in the output tensordict.
1227
+ action_mask_key (str or tuple of str, optional): The input key
1228
+ representing the action mask. Defaults to ``None`` (equivalent to no masking).
1229
+
1230
+ Examples:
1231
+ >>> import torch
1232
+ >>> from tensordict import TensorDict
1233
+ >>> from tensordict.nn import TensorDictModule, TensorDictSequential
1234
+ >>> from torch import nn
1235
+ >>> from torchrl.data import OneHot
1236
+ >>> from torchrl.modules import DistributionalQValueActor, MLP
1237
+ >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
1238
+ >>> nbins = 3
1239
+ >>> module = MLP(out_features=(nbins, 4), depth=2)
1240
+ >>> # let us make sure that the output is a log-softmax
1241
+ >>> module = TensorDictSequential(
1242
+ ... TensorDictModule(module, ["observation"], ["action_value"]),
1243
+ ... TensorDictModule(lambda x: x.log_softmax(-2), ["action_value"], ["action_value"]),
1244
+ ... )
1245
+ >>> action_spec = OneHot(4)
1246
+ >>> qvalue_actor = DistributionalQValueActor(
1247
+ ... module=module,
1248
+ ... spec=action_spec,
1249
+ ... support=torch.arange(nbins))
1250
+ >>> td = qvalue_actor(td)
1251
+ >>> print(td)
1252
+ TensorDict(
1253
+ fields={
1254
+ action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
1255
+ action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1256
+ observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
1257
+ batch_size=torch.Size([5]),
1258
+ device=None,
1259
+ is_shared=False)
1260
+
1261
+ """
1262
+
1263
+ def __init__(
1264
+ self,
1265
+ module,
1266
+ support: torch.Tensor,
1267
+ in_keys=None,
1268
+ spec=None,
1269
+ safe=False,
1270
+ var_nums: int | None = None,
1271
+ action_space: str | None = None,
1272
+ action_value_key: str = "action_value",
1273
+ action_mask_key: NestedKey | None = None,
1274
+ make_log_softmax: bool = True,
1275
+ ):
1276
+ if isinstance(action_space, TensorSpec):
1277
+ raise RuntimeError("Using specs in action_space is deprecated")
1278
+ action_space, spec = _process_action_space_spec(action_space, spec)
1279
+ self.action_space = action_space
1280
+ self.action_value_key = action_value_key
1281
+ out_keys = [
1282
+ "action",
1283
+ action_value_key,
1284
+ ]
1285
+ if isinstance(module, TensorDictModuleBase):
1286
+ if action_value_key not in module.out_keys:
1287
+ raise KeyError(
1288
+ f"The key '{action_value_key}' is not part of the module out-keys."
1289
+ )
1290
+ else:
1291
+ if in_keys is None:
1292
+ in_keys = ["observation"]
1293
+ module = TensorDictModule(
1294
+ module, in_keys=in_keys, out_keys=[action_value_key]
1295
+ )
1296
+ if spec is None:
1297
+ spec = Composite()
1298
+ if isinstance(spec, Composite):
1299
+ spec = spec.clone()
1300
+ if "action" not in spec.keys():
1301
+ spec["action"] = None
1302
+ else:
1303
+ spec = Composite(action=spec, shape=spec.shape[:-1])
1304
+ spec[action_value_key] = None
1305
+
1306
+ qvalue = DistributionalQValueModule(
1307
+ action_value_key=action_value_key,
1308
+ out_keys=out_keys,
1309
+ spec=spec,
1310
+ safe=safe,
1311
+ action_space=action_space,
1312
+ action_mask_key=action_mask_key,
1313
+ support=support,
1314
+ var_nums=var_nums,
1315
+ )
1316
+ self.make_log_softmax = make_log_softmax
1317
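+ # Add a log-softmax stage over the atom dimension (-2) when the module does not already output log-probabilities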
+ if make_log_softmax and not isinstance(module, DistributionalDQNnet):
1318
+ log_softmax_module = DistributionalDQNnet(
1319
+ in_keys=qvalue.in_keys, out_keys=qvalue.in_keys
1320
+ )
1321
+ super(QValueActor, self).__init__(module, log_softmax_module, qvalue)
1322
+ else:
1323
+ super(QValueActor, self).__init__(module, qvalue)
1324
+ self.register_buffer("support", support)
1325
+
1326
+
1327
+ class ActorValueOperator(SafeSequential):
1328
+ """Actor-value operator.
1329
+
1330
+ This class wraps together an actor and a value model that share a common
1331
+ observation embedding network:
1332
+
1333
+ .. aafig::
1334
+ :aspect: 60
1335
+ :scale: 120
1336
+ :proportional:
1337
+ :textual:
1338
+
1339
+ +---------------+
1340
+ |Observation (s)|
1341
+ +---------------+
1342
+ |
1343
+ "common"
1344
+ |
1345
+ v
1346
+ +------------+
1347
+ |Hidden state|
1348
+ +------------+
1349
+ | |
1350
+ actor critic
1351
+ | |
1352
+ v v
1353
+ +-------------+ +------------+
1354
+ |Action (a(s))| |Value (V(s))|
1355
+ +-------------+ +------------+
1356
+
1357
+ .. note::
1358
+ For a similar class that returns an action and a Quality value :math:`Q(s, a)`,
1359
+ see :class:`~.ActorCriticOperator`. For a version without common embedding,
1360
+ refer to :class:`~.ActorCriticWrapper`.
1361
+
1362
+ To facilitate the workflow, this class comes with get_policy_operator() and get_value_operator() methods, which
1363
+ will both return a standalone TDModule with the dedicated functionality.
1364
+
1365
+ Args:
1366
+ common_operator (TensorDictModule): a common operator that reads
1367
+ observations and produces a hidden variable
1368
+ policy_operator (TensorDictModule): a policy operator that reads the
1369
+ hidden variable and returns an action
1370
+ value_operator (TensorDictModule): a value operator, that reads the
1371
+ hidden variable and returns a value
1372
+
1373
+ Examples:
1374
+ >>> import torch
1375
+ >>> from tensordict import TensorDict
+ >>> from tensordict.nn import TensorDictModule
+ >>> from torch import nn
1376
+ >>> from torchrl.modules import ProbabilisticActor, SafeModule
1377
+ >>> from torchrl.modules import ValueOperator, TanhNormal, ActorValueOperator, NormalParamExtractor
1378
+ >>> module_hidden = torch.nn.Linear(4, 4)
1379
+ >>> td_module_hidden = SafeModule(
1380
+ ... module=module_hidden,
1381
+ ... in_keys=["observation"],
1382
+ ... out_keys=["hidden"],
1383
+ ... )
1384
+ >>> module_action = TensorDictModule(
1385
+ ... nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()),
1386
+ ... in_keys=["hidden"],
1387
+ ... out_keys=["loc", "scale"],
1388
+ ... )
1389
+ >>> td_module_action = ProbabilisticActor(
1390
+ ... module=module_action,
1391
+ ... in_keys=["loc", "scale"],
1392
+ ... out_keys=["action"],
1393
+ ... distribution_class=TanhNormal,
1394
+ ... return_log_prob=True,
1395
+ ... )
1396
+ >>> module_value = torch.nn.Linear(4, 1)
1397
+ >>> td_module_value = ValueOperator(
1398
+ ... module=module_value,
1399
+ ... in_keys=["hidden"],
1400
+ ... )
1401
+ >>> td_module = ActorValueOperator(td_module_hidden, td_module_action, td_module_value)
1402
+ >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
1403
+ >>> td_clone = td_module(td.clone())
1404
+ >>> print(td_clone)
1405
+ TensorDict(
1406
+ fields={
1407
+ action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1408
+ hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1409
+ loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1410
+ observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1411
+ sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
1412
+ scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1413
+ state_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
1414
+ batch_size=torch.Size([3]),
1415
+ device=None,
1416
+ is_shared=False)
1417
+ >>> td_clone = td_module.get_policy_operator()(td.clone())
1418
+ >>> print(td_clone) # no value
1419
+ TensorDict(
1420
+ fields={
1421
+ action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1422
+ hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1423
+ loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1424
+ observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1425
+ sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
1426
+ scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
1427
+ batch_size=torch.Size([3]),
1428
+ device=None,
1429
+ is_shared=False)
1430
+ >>> td_clone = td_module.get_value_operator()(td.clone())
1431
+ >>> print(td_clone) # no action
1432
+ TensorDict(
1433
+ fields={
1434
+ hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1435
+ observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1436
+ state_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
1437
+ batch_size=torch.Size([3]),
1438
+ device=None,
1439
+ is_shared=False)
1440
+
1441
+ """
1442
+
1443
+ def __init__(
1444
+ self,
1445
+ common_operator: TensorDictModule,
1446
+ policy_operator: TensorDictModule,
1447
+ value_operator: TensorDictModule,
1448
+ ):
1449
+ super().__init__(
1450
+ common_operator,
1451
+ policy_operator,
1452
+ value_operator,
1453
+ )
1454
+
1455
+ def get_policy_operator(self) -> TensorDictSequential:
1456
+ """Returns a standalone policy operator that maps an observation to an action."""
1457
+ if isinstance(self.module[1], SafeProbabilisticTensorDictSequential):
1458
+ return SafeProbabilisticTensorDictSequential(
1459
+ self.module[0], *self.module[1].module
1460
+ )
1461
+ return SafeSequential(self.module[0], self.module[1])
1462
+
1463
+ def get_value_operator(self) -> TensorDictSequential:
1464
+ """Returns a standalone value network operator that maps an observation to a value estimate."""
1465
+ return SafeSequential(self.module[0], self.module[2])
1466
+
1467
+ def get_policy_head(self) -> TensorDictModule:
1468
+ """Returns the policy head."""
1469
+ return self.module[1]
1470
+
1471
+ def get_value_head(self) -> TensorDictModule:
1472
+ """Returns the value head."""
1473
+ return self.module[2]
1474
+
1475
+
1476
+ class ActorCriticOperator(ActorValueOperator):
1477
+ """Actor-critic operator.
1478
+
1479
+ This class wraps together an actor and a value model that share a common
1480
+ observation embedding network:
1481
+
1482
+ .. aafig::
1483
+ :aspect: 60
1484
+ :scale: 120
1485
+ :proportional:
1486
+ :textual:
1487
+
1488
+ +---------------+
1489
+ |Observation (s)|
1490
+ +---------------+
1491
+ |
1492
+ v
1493
+ "common"
1494
+ |
1495
+ v
1496
+ +------------+
1497
+ |Hidden state|
1498
+ +------------+
1499
+ | |
1500
+ v v
1501
+ actor --> critic
1502
+ | |
1503
+ v v
1504
+ +-------------+ +----------------+
1505
+ |Action (a(s))| |Quality (Q(s,a))|
1506
+ +-------------+ +----------------+
1507
+
1508
+ .. note::
1509
+ For a similar class that returns an action and a state-value :math:`V(s)`
1510
+ see :class:`~.ActorValueOperator`.
1511
+
1512
+
1513
+ To facilitate the workflow, this class comes with a get_policy_operator() method, which
1514
+ will return a standalone TDModule with the dedicated functionality. The get_critic_operator will return the
1515
+ parent object, as the value is computed based on the policy output.
1516
+
1517
+ Args:
1518
+ common_operator (TensorDictModule): a common operator that reads
1519
+ observations and produces a hidden variable
1520
+ policy_operator (TensorDictModule): a policy operator that reads the
1521
+ hidden variable and returns an action
1522
+ value_operator (TensorDictModule): a value operator, that reads the
1523
+ hidden variable and returns a value
1524
+
1525
+ Examples:
1526
+ >>> import torch
1527
+ >>> from tensordict import TensorDict
1528
+ >>> from tensordict.nn import TensorDictModule
+ >>> from torch import nn
+ >>> from torchrl.modules import ProbabilisticActor, SafeModule
1529
+ >>> from torchrl.modules import ValueOperator, TanhNormal, ActorCriticOperator, NormalParamExtractor, MLP
1530
+ >>> module_hidden = torch.nn.Linear(4, 4)
1531
+ >>> td_module_hidden = SafeModule(
1532
+ ... module=module_hidden,
1533
+ ... in_keys=["observation"],
1534
+ ... out_keys=["hidden"],
1535
+ ... )
1536
+ >>> module_action = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor())
1537
+ >>> module_action = TensorDictModule(module_action, in_keys=["hidden"], out_keys=["loc", "scale"])
1538
+ >>> td_module_action = ProbabilisticActor(
1539
+ ... module=module_action,
1540
+ ... in_keys=["loc", "scale"],
1541
+ ... out_keys=["action"],
1542
+ ... distribution_class=TanhNormal,
1543
+ ... return_log_prob=True,
1544
+ ... )
1545
+ >>> module_value = MLP(in_features=8, out_features=1, num_cells=[])
1546
+ >>> td_module_value = ValueOperator(
1547
+ ... module=module_value,
1548
+ ... in_keys=["hidden", "action"],
1549
+ ... out_keys=["state_action_value"],
1550
+ ... )
1551
+ >>> td_module = ActorCriticOperator(td_module_hidden, td_module_action, td_module_value)
1552
+ >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
1553
+ >>> td_clone = td_module(td.clone())
1554
+ >>> print(td_clone)
1555
+ TensorDict(
1556
+ fields={
1557
+ action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1558
+ hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1559
+ loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1560
+ observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1561
+ sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
1562
+ scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1563
+ state_action_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
1564
+ batch_size=torch.Size([3]),
1565
+ device=None,
1566
+ is_shared=False)
1567
+ >>> td_clone = td_module.get_policy_operator()(td.clone())
1568
+ >>> print(td_clone) # no value
1569
+ TensorDict(
1570
+ fields={
1571
+ action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1572
+ hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1573
+ loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1574
+ observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1575
+ sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
1576
+ scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
1577
+ batch_size=torch.Size([3]),
1578
+ device=None,
1579
+ is_shared=False)
1580
+ >>> td_clone = td_module.get_critic_operator()(td.clone())
1581
+ >>> print(td_clone) # no action
1582
+ TensorDict(
1583
+ fields={
1584
+ action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1585
+ hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1586
+ loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1587
+ observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1588
+ sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
1589
+ scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1590
+ state_action_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
1591
+ batch_size=torch.Size([3]),
1592
+ device=None,
1593
+ is_shared=False)
1594
+
1595
+ """
1596
+
1597
+ def __init__(
1598
+ self,
1599
+ common_operator: TensorDictModule,
1600
+ policy_operator: TensorDictModule,
1601
+ value_operator: TensorDictModule,
1602
+ ):
1603
+ super().__init__(
1604
+ common_operator,
1605
+ policy_operator,
1606
+ value_operator,
1607
+ )
1608
+ if self[2].out_keys[0] == "state_value":
1609
+ raise RuntimeError(
1610
+ "Value out_key is state_value, which may lead to errors in downstream usages"
1611
+ "of that module. Consider setting `'state_action_value'` instead."
1612
+ "Make also sure that `'action'` is amongst the input keys of the value network."
1613
+ "If you are confident that action should not be used to compute the value, please"
1614
+ "user `ActorValueOperator` instead."
1615
+ )
1616
+
1617
+ def get_critic_operator(self) -> TensorDictModuleWrapper:
1618
+ """Returns a standalone critic network operator that maps a state-action pair to a critic estimate."""
1619
+ return self
1620
+
1621
+ def get_value_operator(self) -> TensorDictModuleWrapper:
1622
+ raise RuntimeError(
1623
+ "value_operator is the term used for operators that associate a value with a "
1624
+ "state/observation. This class computes the value of a state-action pair: to get the "
1625
+ "network computing this value, please call td_sequence.get_critic_operator()"
1626
+ )
1627
+
1628
+ def get_policy_head(self) -> TensorDictModule:
1629
+ """Returns the policy head."""
1630
+ return self.module[1]
1631
+
1632
+ def get_value_head(self) -> TensorDictModule:
1633
+ """Returns the value head."""
1634
+ return self.module[2]
1635
+
1636
+
1637
+ class ActorCriticWrapper(SafeSequential):
1638
+ """Actor-value operator without common module.
1639
+
1640
+ This class wraps together an actor and a value model that do not share a common observation embedding network:
1641
+
1642
+ .. aafig::
1643
+ :aspect: 60
1644
+ :scale: 120
1645
+ :proportional:
1646
+ :textual:
1647
+
1648
+ +---------------+
1649
+ |Observation (s)|
1650
+ +---------------+
1651
+ | | |
1652
+ v | v
1653
+ actor | critic
1654
+ | | |
1655
+ v | v
1656
+ +-------------+ | +------------+
1657
+ |Action (a(s))| | |Value (V(s))|
1658
+ +-------------+ | +------------+
1659
+
1660
+
1661
+ To facilitate the workflow, this class comes with get_policy_operator() and get_value_operator() methods, which
1662
+ will both return a standalone TDModule with the dedicated functionality.
1663
+
1664
+ Args:
1665
+ policy_operator (TensorDictModule): a policy operator that reads the hidden variable and returns an action
1666
+ value_operator (TensorDictModule): a value operator, that reads the hidden variable and returns a value
1667
+
1668
+ Examples:
1669
+ >>> import torch
1670
+ >>> from tensordict import TensorDict
1671
+ >>> from tensordict.nn import TensorDictModule
+ >>> from torch import nn
1672
+ >>> from torchrl.modules import (
1673
+ ... ActorCriticWrapper,
1674
+ ... ProbabilisticActor,
1675
+ ... NormalParamExtractor,
1676
+ ... TanhNormal,
1677
+ ... ValueOperator,
1678
+ ... )
1679
+ >>> action_module = TensorDictModule(
1680
+ ... nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()),
1681
+ ... in_keys=["observation"],
1682
+ ... out_keys=["loc", "scale"],
1683
+ ... )
1684
+ >>> td_module_action = ProbabilisticActor(
1685
+ ... module=action_module,
1686
+ ... in_keys=["loc", "scale"],
1687
+ ... distribution_class=TanhNormal,
1688
+ ... return_log_prob=True,
1689
+ ... )
1690
+ >>> module_value = torch.nn.Linear(4, 1)
1691
+ >>> td_module_value = ValueOperator(
1692
+ ... module=module_value,
1693
+ ... in_keys=["observation"],
1694
+ ... )
1695
+ >>> td_module = ActorCriticWrapper(td_module_action, td_module_value)
1696
+ >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
1697
+ >>> td_clone = td_module(td.clone())
1698
+ >>> print(td_clone)
1699
+ TensorDict(
1700
+ fields={
1701
+ action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1702
+ loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1703
+ observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1704
+ sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
1705
+ scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1706
+ state_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
1707
+ batch_size=torch.Size([3]),
1708
+ device=None,
1709
+ is_shared=False)
1710
+ >>> td_clone = td_module.get_policy_operator()(td.clone())
1711
+ >>> print(td_clone) # no value
1712
+ TensorDict(
1713
+ fields={
1714
+ action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1715
+ loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1716
+ observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1717
+ sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
1718
+ scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
1719
+ batch_size=torch.Size([3]),
1720
+ device=None,
1721
+ is_shared=False)
1722
+ >>> td_clone = td_module.get_value_operator()(td.clone())
1723
+ >>> print(td_clone) # no action
1724
+ TensorDict(
1725
+ fields={
1726
+ observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1727
+ state_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
1728
+ batch_size=torch.Size([3]),
1729
+ device=None,
1730
+ is_shared=False)
1731
+
1732
+ """
1733
+
1734
+ def __init__(
1735
+ self,
1736
+ policy_operator: TensorDictModule,
1737
+ value_operator: TensorDictModule,
1738
+ ):
1739
+ super().__init__(
1740
+ policy_operator,
1741
+ value_operator,
1742
+ )
1743
+
1744
+ def get_policy_operator(self) -> TensorDictModule:
1745
+ """Returns a standalone policy operator that maps an observation to an action."""
1746
+ return self.module[0]
1747
+
1748
+ def get_value_operator(self) -> TensorDictModule:
1749
+ """Returns a standalone value network operator that maps an observation to a value estimate."""
1750
+ return self.module[1]
1751
+
1752
+ get_policy_head = get_policy_operator
1753
+ get_value_head = get_value_operator
1754
+
1755
+
1756
+ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper):
1757
+ """Inference Action Wrapper for the Decision Transformer.
1758
+
1759
+ A wrapper specifically designed for the Decision Transformer, which will mask the
1760
+ input tensordict sequences to the inference context.
1761
+ The output will be a TensorDict with the same keys as the input, but with only the last
1762
+ action of the predicted action sequence and the last return to go.
1763
+
1764
+ This module returns a modified copy of the tensordict, i.e. it does
1765
+ **not** modify the tensordict in-place.
1766
+
1767
+ .. note:: If the action, observation or return-to-go key is not standard,
1768
+ the method :meth:`set_tensor_keys` should be used, e.g.
1769
+
1770
+ >>> dt_inference_wrapper.set_tensor_keys(action="foo", observation="bar", return_to_go="baz")
1771
+
1772
+ The in_keys are the observation, action and return-to-go keys. The out-keys
1773
+ match the in-keys, with the addition of any other out-key from the policy
1774
+ (e.g., parameters of the distribution or hidden values).
1775
+
1776
+ Args:
1777
+ policy (TensorDictModule): The policy module that takes in
1778
+ observations and produces an action value
1779
+
1780
+ Keyword Args:
1781
+ inference_context (int): The number of previous actions that will not be masked in the context.
1782
+ For example for an observation input of shape [batch_size, context, obs_dim] with context=20 and inference_context=5, the first 15 entries
1783
+ of the context will be masked. Defaults to 5.
1784
+ spec (Optional[TensorSpec]): The spec of the input TensorDict. If None, it will be inferred from the policy module.
1785
+ device (torch.device, optional): if provided, the device where the buffers / specs will be placed.
1786
+
1787
+ Examples:
1788
+ >>> import torch
1789
+ >>> from tensordict import TensorDict
1790
+ >>> from tensordict.nn import TensorDictModule
1791
+ >>> from torchrl.modules import (
1792
+ ... ProbabilisticActor,
1793
+ ... TanhDelta,
1794
+ ... DTActor,
1795
+ ... DecisionTransformerInferenceWrapper,
1796
+ ... )
1797
+ >>> dtactor = DTActor(state_dim=4, action_dim=2,
1798
+ ... transformer_config=DTActor.default_config()
1799
+ ... )
1800
+ >>> actor_module = TensorDictModule(
1801
+ ... dtactor,
1802
+ ... in_keys=["observation", "action", "return_to_go"],
1803
+ ... out_keys=["param"])
1804
+ >>> dist_class = TanhDelta
1805
+ >>> dist_kwargs = {
1806
+ ... "low": -1.0,
1807
+ ... "high": 1.0,
1808
+ ... }
1809
+ >>> actor = ProbabilisticActor(
1810
+ ... in_keys=["param"],
1811
+ ... out_keys=["action"],
1812
+ ... module=actor_module,
1813
+ ... distribution_class=dist_class,
1814
+ ... distribution_kwargs=dist_kwargs)
1815
+ >>> inference_actor = DecisionTransformerInferenceWrapper(actor)
1816
+ >>> sequence_length = 20
1817
+ >>> td = TensorDict({"observation": torch.randn(1, sequence_length, 4),
1818
+ ... "action": torch.randn(1, sequence_length, 2),
1819
+ ... "return_to_go": torch.randn(1, sequence_length, 1)}, [1,])
1820
+ >>> result = inference_actor(td)
1821
+ >>> print(result)
1822
+ TensorDict(
1823
+ fields={
1824
+ action: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
1825
+ observation: Tensor(shape=torch.Size([1, 20, 4]), device=cpu, dtype=torch.float32, is_shared=False),
1826
+ param: Tensor(shape=torch.Size([1, 20, 2]), device=cpu, dtype=torch.float32, is_shared=False),
1827
+ return_to_go: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
1828
+ batch_size=torch.Size([1]),
1829
+ device=None,
1830
+ is_shared=False)
1831
+ """
1832
+
1833
+ def __init__(
1834
+ self,
1835
+ policy: TensorDictModule,
1836
+ *,
1837
+ inference_context: int = 5,
1838
+ spec: TensorSpec | None = None,
1839
+ device: torch.device | None = None,
1840
+ ):
1841
+ super().__init__(policy)
1842
+ self.observation_key = "observation"
1843
+ self.action_key = "action"
1844
+ self.out_action_key = "action"
1845
+ self.return_to_go_key = "return_to_go"
1846
+ self.inference_context = inference_context
1847
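+ # Resolve the output spec: use the explicit spec when given, else fall back to the wrapped policy's spec if available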
+ if spec is not None:
1848
+ if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
1849
+ spec = Composite({self.action_key: spec}, shape=spec.shape[:-1])
1850
+ self._spec = spec
1851
+ elif hasattr(self.td_module, "_spec"):
1852
+ self._spec = self.td_module._spec.clone()
1853
+ if self.action_key not in self._spec.keys():
1854
+ self._spec[self.action_key] = None
1855
+ elif hasattr(self.td_module, "spec"):
1856
+ self._spec = self.td_module.spec.clone()
1857
+ if self.action_key not in self._spec.keys():
1858
+ self._spec[self.action_key] = None
1859
+ else:
1860
+ self._spec = Composite({key: None for key in policy.out_keys})
1861
+ if device is not None:
1862
+ self._spec = self._spec.to(device)
1863
+ self.checked = False
1864
+
1865
+ @property
1866
+ def in_keys(self):
1867
+ return [self.observation_key, self.action_key, self.return_to_go_key]
1868
+
1869
+ @property
1870
+ def out_keys(self):
1871
+ return sorted(
1872
+ set(self.td_module.out_keys).union(
1873
+ {self.observation_key, self.action_key, self.return_to_go_key}
1874
+ ),
1875
+ key=str,
1876
+ )
1877
+
1878
+ def set_tensor_keys(self, **kwargs):
1879
+ """Sets the input keys of the module.
1880
+
1881
+ Keyword Args:
1882
+ observation (NestedKey, optional): The observation key.
1883
+ action (NestedKey, optional): The action key (input to the network).
1884
+ return_to_go (NestedKey, optional): The return_to_go key.
1885
+ out_action (NestedKey, optional): The action key (output of the network).
1886
+
1887
+ """
1888
+ observation_key = unravel_key(kwargs.pop("observation", self.observation_key))
1889
+ action_key = unravel_key(kwargs.pop("action", self.action_key))
1890
+ out_action_key = unravel_key(kwargs.pop("out_action", self.out_action_key))
1891
+ return_to_go_key = unravel_key(
1892
+ kwargs.pop("return_to_go", self.return_to_go_key)
1893
+ )
1894
+ if kwargs:
1895
+ raise TypeError(
1896
+ f"Got unknown input(s) {kwargs.keys()}. Accepted keys are 'action', 'return_to_go' and 'observation'."
1897
+ )
1898
+ self.observation_key = observation_key
1899
+ self.action_key = action_key
1900
+ self.return_to_go_key = return_to_go_key
1901
+ if out_action_key not in self.td_module.out_keys:
1902
+ raise ValueError(
1903
+ f"The value of out_action_key ({out_action_key}) must be "
1904
+ f"within the actor output keys ({self.td_module.out_keys})."
1905
+ )
1906
+ self.out_action_key = out_action_key
1907
+
1908
+ def step(self, frames: int = 1) -> None:
1909
+ pass
1910
+
1911
+ @staticmethod
1912
+ def _check_tensor_dims(reward, obs, action):
1913
+ if not (reward.shape[:-1] == obs.shape[:-1] == action.shape[:-1]):
1914
+ raise ValueError(
1915
+ "Mismatched tensor dimensions. This is not supported yet, file an issue on torchrl"
1916
+ )
1917
+
1918
+ def mask_context(self, tensordict: TensorDictBase) -> TensorDictBase:
1919
+ """Mask the context of the input sequences."""
1920
+ observation = tensordict.get(self.observation_key).clone()
1921
+ action = tensordict.get(self.action_key).clone()
1922
+ return_to_go = tensordict.get(self.return_to_go_key).clone()
1923
+ self._check_tensor_dims(return_to_go, observation, action)
1924
+
1925
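+ # Mask (zero out) the context entries that fall outside the inference window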
+ observation[..., : -self.inference_context, :] = 0
1926
+ action[
1927
+ ..., : -(self.inference_context - 1), :
1928
+ ] = 0 # as we add zeros to the end of the action
1929
+ action = torch.cat(
1930
+ [
1931
+ action[..., 1:, :],
1932
+ torch.zeros(
1933
+ *action.shape[:-2], 1, action.shape[-1], device=action.device
1934
+ ),
1935
+ ],
1936
+ dim=-2,
1937
+ )
1938
+ return_to_go[..., : -self.inference_context, :] = 0
1939
+
1940
+ tensordict.set(self.observation_key, observation)
1941
+ tensordict.set(self.action_key, action)
1942
+ tensordict.set(self.return_to_go_key, return_to_go)
1943
+ return tensordict
1944
+
1945
+ def check_keys(self):
1946
+ # an exception will be raised if the action key mismatch
1947
+ self.set_tensor_keys()
1948
+ self.checked = True
1949
+
1950
+ @dispatch
1951
+ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
1952
+ """Forward pass of the inference wrapper."""
1953
+ if not self.checked:
1954
+ self.check_keys()
1955
+ tensordict = tensordict.clone(False)
1956
+ obs = tensordict.get(self.observation_key)
1957
+ # Mask the context of the input sequences
1958
+ tensordict = self.mask_context(tensordict)
1959
+ # forward pass
1960
+ tensordict = self.td_module.forward(tensordict)
1961
+ # get last action prediction
1962
+ out_action = tensordict.get(self.out_action_key)
1963
+ if tensordict.ndim == out_action.ndim - 1:
1964
+ # then time dimension is in the TD's dimensions, and we must get rid of it
1965
+ tensordict.batch_size = tensordict.batch_size[:-1]
1966
+ out_action = out_action[..., -1, :]
1967
+ tensordict.set(self.out_action_key, out_action)
1968
+
1969
+ out_rtg = tensordict.get(self.return_to_go_key)
1970
+ out_rtg = out_rtg[..., -1, :]
1971
+ tensordict.set(self.return_to_go_key, out_rtg)
1972
+
1973
+ # set unmasked observation
1974
+ tensordict.set(self.observation_key, obs)
1975
+ return tensordict
1976
+
1977
+
1978
+ class TanhModule(TensorDictModuleBase):
1979
+ """A Tanh module for deterministic policies with bounded action space.
1980
+
1981
+ This transform is to be used as a TensorDictModule layer to map a network
1982
+ output to a bounded space.
1983
+
1984
+ Args:
1985
+ in_keys (list of str or tuples of str): the input keys of the module.
1986
+ out_keys (list of str or tuples of str, optional): the output keys of the module.
1987
+ If none is provided, the same keys as in_keys are assumed.
1988
+
1989
+ Keyword Args:
1990
+ spec (TensorSpec, optional): if provided, the spec of the output.
1991
+ If a Composite is provided, its key(s) must match the key(s)
1992
+ in out_keys. Otherwise, the key(s) of out_keys are assumed and the
1993
+ same spec is used for all outputs.
1994
+ low (:obj:`float`, np.ndarray or torch.Tensor): the lower bound of the space.
1995
+ If none is provided and no spec is provided, -1 is assumed. If a
1996
+ spec is provided, the minimum value of the spec will be retrieved.
1997
+ high (:obj:`float`, np.ndarray or torch.Tensor): the higher bound of the space.
1998
+ If none is provided and no spec is provided, 1 is assumed. If a
1999
+ spec is provided, the maximum value of the spec will be retrieved.
2000
+ clamp (bool, optional): if ``True``, the outputs will be clamped to be
2001
+ within the boundaries but at a minimum resolution from them.
2002
+ Defaults to ``False``.
2003
+
2004
+ Examples:
2005
+ >>> import torch
+ >>> from tensordict import TensorDict
+ >>> from torchrl.modules import TanhModule
2006
+ >>> # simplest use case: (-1, 1) boundaries
2007
+ >>> torch.manual_seed(0)
2008
+ >>> in_keys = ["action"]
2009
+ >>> mod = TanhModule(
2010
+ ... in_keys=in_keys,
2011
+ ... )
2012
+ >>> data = TensorDict({"action": torch.randn(5) * 10}, [])
2013
+ >>> data = mod(data)
2014
+ >>> data['action']
2015
+ tensor([ 1.0000, -0.9944, -1.0000, 1.0000, -1.0000])
2016
+ >>> # low and high can be customized
2017
+ >>> low = -2
2018
+ >>> high = 1
2019
+ >>> mod = TanhModule(
2020
+ ... in_keys=in_keys,
2021
+ ... low=low,
2022
+ ... high=high,
2023
+ ... )
2024
+ >>> data = TensorDict({"action": torch.randn(5) * 10}, [])
2025
+ >>> data = mod(data)
2026
+ >>> data['action']
2027
+ tensor([-2.0000, 0.9991, 1.0000, -2.0000, -1.9991])
2028
+ >>> # A spec can be provided
2029
+ >>> from torchrl.data import Bounded, Composite
2030
+ >>> spec = Bounded(low, high, shape=())
2031
+ >>> mod = TanhModule(
2032
+ ... in_keys=in_keys,
2033
+ ... low=low,
2034
+ ... high=high,
2035
+ ... spec=spec,
2036
+ ... clamp=False,
2037
+ ... )
2038
+ >>> # One can also work with multiple keys
2039
+ >>> in_keys = ['a', 'b']
2040
+ >>> spec = Composite(
2041
+ ... a=Bounded(-3, 0, shape=()),
2042
+ ... b=Bounded(0, 3, shape=()))
2043
+ >>> mod = TanhModule(
2044
+ ... in_keys=in_keys,
2045
+ ... spec=spec,
2046
+ ... )
2047
+ >>> data = TensorDict(
2048
+ ... {'a': torch.randn(10), 'b': torch.randn(10)}, batch_size=[])
2049
+ >>> data = mod(data)
2050
+ >>> data['a']
2051
+ tensor([-2.3020, -1.2299, -2.5418, -0.2989, -2.6849, -1.3169, -2.2690, -0.9649,
2052
+ -2.5686, -2.8602])
2053
+ >>> data['b']
2054
+ tensor([2.0315, 2.8455, 2.6027, 2.4746, 1.7843, 2.7782, 0.2111, 0.5115, 1.4687,
2055
+ 0.5760])
2056
+ """
2057
+
2058
+ def __init__(
2059
+ self,
2060
+ in_keys,
2061
+ out_keys=None,
2062
+ *,
2063
+ spec=None,
2064
+ low=None,
2065
+ high=None,
2066
+ clamp: bool = False,
2067
+ ):
2068
+ super().__init__()
2069
+ self.in_keys = in_keys
2070
+ if out_keys is None:
2071
+ out_keys = in_keys
2072
+ if len(in_keys) != len(out_keys):
2073
+ raise ValueError(
2074
+ "in_keys and out_keys should have the same length, "
2075
+ f"got in_keys={in_keys} and out_keys={out_keys}"
2076
+ )
2077
+ self.out_keys = out_keys
2078
+ # action_spec can be a composite spec or not
2079
+ if isinstance(spec, Composite):
2080
+ for out_key in self.out_keys:
2081
+ if out_key not in spec.keys(True, True):
2082
+ spec[out_key] = None
2083
+ else:
2084
+ # if one spec is present, we assume it is the same for all keys
2085
+ spec = Composite(
2086
+ {out_key: spec for out_key in out_keys},
2087
+ )
2088
+
2089
+ leaf_specs = [spec[out_key] for out_key in self.out_keys]
2090
+ self.spec = spec
2091
+ self.non_trivial = {}
2092
+ for out_key, leaf_spec in zip(out_keys, leaf_specs):
2093
+ _low, _high = self._make_low_high(low, high, leaf_spec)
2094
+ key = out_key if isinstance(out_key, str) else "_".join(out_key)
2095
+ self.register_buffer(f"{key}_low", _low)
2096
+ self.register_buffer(f"{key}_high", _high)
2097
+ self.non_trivial[out_key] = (_high != 1).any() or (_low != -1).any()
2098
+ if (_high < _low).any():
2099
+ raise ValueError(f"Got high < low in {type(self)}.")
2100
+ self.clamp = clamp
2101
+
2102
+ def _make_low_high(self, low, high, leaf_spec):
2103
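+ # Resolve bounds: default to -1/1 when neither a value nor a spec is given, otherwise check they agree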
+ if low is None and leaf_spec is None:
2104
+ low = -torch.ones(())
2105
+ elif low is None:
2106
+ low = leaf_spec.space.low
2107
+ elif leaf_spec is not None:
2108
+ if (low != leaf_spec.space.low).any():
2109
+ raise ValueError(
2110
+ f"The minimum value ({low}) provided to {type(self)} does not match the action spec one ({leaf_spec.space.low})."
2111
+ )
2112
+ if not isinstance(low, torch.Tensor):
2113
+ low = torch.tensor(low)
2114
+ if high is None and leaf_spec is None:
2115
+ high = torch.ones(())
2116
+ elif high is None:
2117
+ high = leaf_spec.space.high
2118
+ elif leaf_spec is not None:
2119
+ if (high != leaf_spec.space.high).any():
2120
+ raise ValueError(
2121
+ f"The maximum value ({high}) provided to {type(self)} does not match the action spec one ({leaf_spec.space.high})."
2122
+ )
2123
+ if not isinstance(high, torch.Tensor):
2124
+ high = torch.tensor(high)
2125
+ return low, high
2126
+
2127
+ @dispatch
2128
+ def forward(self, tensordict):
2129
+ inputs = [tensordict.get(key) for key in self.in_keys]
2130
+ # map
2131
+ for out_key, feature in zip(self.out_keys, inputs):
2132
+ key = out_key if isinstance(out_key, str) else "_".join(out_key)
2133
+ low_key = f"{key}_low"
2134
+ high_key = f"{key}_high"
2135
+ low = getattr(self, low_key)
2136
+ high = getattr(self, high_key)
2137
+ feature = feature.tanh()
2138
+ if self.clamp:
2139
+ eps = torch.finfo(feature.dtype).resolution
2140
+ feature = feature.clamp(-1 + eps, 1 - eps)
2141
+ if self.non_trivial[out_key]:
2142
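+ # Affine rescale from (-1, 1) to (low, high)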
+ feature = low + (high - low) * (feature + 1) / 2
2143
+ tensordict.set(out_key, feature)
2144
+ return tensordict
2145
+
2146
+
2147
+ class LMHeadActorValueOperator(ActorValueOperator):
2148
+ """Builds an Actor-Value operator from an huggingface-like *LMHeadModel.
2149
+
2150
+ This method:
2151
+
2152
+ - takes as input a huggingface-like *LMHeadModel
2153
+ - extracts the final linear layer, uses it as the base layer of the actor_head and
2154
+ adds the sampling layer
2155
+ - uses the common transformer as common model
2156
+ - adds a linear critic
2157
+
2158
+ Args:
2159
+ base_model (nn.Module): a torch model composed by a `.transformer` model and `.lm_head` linear layer
2160
+
2161
+ .. note:: For more details regarding the class construction, please refer to :class:`~.ActorValueOperator`.
2162
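+
+ A minimal usage sketch (assuming the ``transformers`` package is installed and a GPT2-style
+ model whose ``.transformer`` submodule returns tuples, e.g. built with ``return_dict=False``;
+ the model name below is illustrative):
+
+ >>> from transformers import GPT2LMHeadModel
+ >>> base_model = GPT2LMHeadModel.from_pretrained("gpt2", return_dict=False)
+ >>> actor_value = LMHeadActorValueOperator(base_model)
+ >>> policy_op = actor_value.get_policy_operator()
+ >>> value_op = actor_value.get_value_operator()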
+ """
2163
+
2164
+ def __init__(self, base_model):
2165
+ actor_head = base_model.lm_head
2166
+ value_head = nn.Linear(actor_head.in_features, 1, bias=False)
2167
+ common = TensorDictSequential(
2168
+ TensorDictModule(
2169
+ base_model.transformer,
2170
+ in_keys={"input_ids": "input_ids", "attention_mask": "attention_mask"},
2171
+ out_keys=["x", "_"],
2172
+ ),
2173
+ TensorDictModule(lambda x: x[:, -1, :], in_keys=["x"], out_keys=["x"]),
2174
+ )
2175
+ actor_head = TensorDictModule(actor_head, in_keys=["x"], out_keys=["logits"])
2176
+ actor_head = SafeProbabilisticTensorDictSequential(
2177
+ actor_head,
2178
+ SafeProbabilisticModule(
2179
+ in_keys=["logits"],
2180
+ out_keys=["action"],
2181
+ distribution_class=Categorical,
2182
+ return_log_prob=True,
2183
+ ),
2184
+ )
2185
+ value_head = TensorDictModule(
2186
+ value_head, in_keys=["x"], out_keys=["state_value"]
2187
+ )
2188
+
2189
+ super().__init__(common, actor_head, value_head)
2190
+
2191
+
2192
+ class MultiStepActorWrapper(TensorDictModuleBase):
2193
+ """A wrapper around a multi-action actor.
2194
+
2195
+ This class enables macros to be executed in an environment.
2196
+ The actor action(s) entry must have an additional time dimension to
2197
+ be consumed. It must be placed adjacent to the last dimension of the
2198
+ input tensordict (i.e. at ``tensordict.ndim``).
2199
+
2200
+ The action entry keys are retrieved automatically from the actor if
2201
+ not provided using a simple heuristic (any nested key ending with the
2202
+ ``"action"`` string).
2203
+
2204
+ An ``"is_init"`` entry must also be present in the input tensordict
2205
+ to track when the current collection should be interrupted
2206
+ because a "done" state has been encountered. Unlike ``action_keys``,
2207
+ this key must be unique.
2208
+
2209
+ Args:
2210
+ actor (TensorDictModuleBase): An actor.
2211
+ n_steps (int, optional): the number of actions the actor outputs at once
2212
+ (lookahead window). Defaults to `None`.
2213
+
2214
+ Keyword Args:
2215
+ action_keys (list of NestedKeys, optional): the action keys from
2216
+ the environment. Can be retrieved from ``env.action_keys``.
2217
+ Defaults to all ``out_keys`` of the ``actor`` which end
2218
+ with the ``"action"`` string.
2219
+ init_key (NestedKey, optional): the key of the entry indicating
2220
+ when the environment has gone through a reset.
2221
+ Defaults to ``"is_init"`` which is the ``out_key`` from the
2222
+ :class:`~torchrl.envs.transforms.InitTracker` transform.
2223
+ keep_dim (bool, optional): whether to keep the time dimension of
2224
+ the macro during indexing. Defaults to ``False``.
2225
+
2226
+ Examples:
2227
+ >>> import torch.nn
2228
+ >>> from torchrl.modules.tensordict_module.actors import MultiStepActorWrapper, Actor
2229
+ >>> from torchrl.envs import CatFrames, GymEnv, TransformedEnv, SerialEnv, InitTracker, Compose
2230
+ >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
2231
+ >>>
2232
+ >>> time_steps = 6
2233
+ >>> n_obs = 4
2234
+ >>> n_action = 2
2235
+ >>> batch = 5
2236
+ >>>
2237
+ >>> # Reshapes the CatFrames output into a stack of frames
2238
+ >>> def reshape_cat(data: torch.Tensor):
2239
+ ... return data.unflatten(-1, (time_steps, n_obs))
2240
+ >>> # an actor that reads `time_steps` frames and outputs one action per frame
2241
+ >>> # (actions are conditioned on the observation of `time_steps` in the past)
2242
+ >>> actor_base = Seq(
2243
+ ... Mod(reshape_cat, in_keys=["obs_cat"], out_keys=["obs_cat_reshape"]),
2244
+ ... Mod(torch.nn.Linear(n_obs, n_action), in_keys=["obs_cat_reshape"], out_keys=["action"])
2245
+ ... )
2246
+ >>> # Wrap the actor to dispatch the actions
2247
+ >>> actor = MultiStepActorWrapper(actor_base, n_steps=time_steps)
2248
+ >>>
2249
+ >>> env = TransformedEnv(
2250
+ ... SerialEnv(batch, lambda: GymEnv("CartPole-v1")),
2251
+ ... Compose(
2252
+ ... InitTracker(),
2253
+ ... CatFrames(N=time_steps, in_keys=["observation"], out_keys=["obs_cat"], dim=-1)
2254
+ ... )
2255
+ ... )
2256
+ >>>
2257
+ >>> print(env.rollout(100, policy=actor, break_when_any_done=False))
2258
+ TensorDict(
2259
+ fields={
2260
+ action: Tensor(shape=torch.Size([5, 100, 2]), device=cpu, dtype=torch.float32, is_shared=False),
2261
+ action_orig: Tensor(shape=torch.Size([5, 100, 6, 2]), device=cpu, dtype=torch.float32, is_shared=False),
2262
+ counter: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.int32, is_shared=False),
2263
+ done: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
2264
+ is_init: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
2265
+ next: TensorDict(
2266
+ fields={
2267
+ done: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
2268
+ is_init: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
2269
+ obs_cat: Tensor(shape=torch.Size([5, 100, 24]), device=cpu, dtype=torch.float32, is_shared=False),
2270
+ observation: Tensor(shape=torch.Size([5, 100, 4]), device=cpu, dtype=torch.float32, is_shared=False),
2271
+ reward: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
2272
+ terminated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
2273
+ truncated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
2274
+ batch_size=torch.Size([5, 100]),
2275
+ device=cpu,
2276
+ is_shared=False),
2277
+ obs_cat: Tensor(shape=torch.Size([5, 100, 24]), device=cpu, dtype=torch.float32, is_shared=False),
2278
+ observation: Tensor(shape=torch.Size([5, 100, 4]), device=cpu, dtype=torch.float32, is_shared=False),
2279
+ terminated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
2280
+ truncated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
2281
+ batch_size=torch.Size([5, 100]),
2282
+ device=cpu,
2283
+ is_shared=False)
2284
+
2285
+ .. seealso:: :class:`torchrl.envs.MultiStepEnvWrapper` is the EnvBase alter-ego of this wrapper:
2286
+ It wraps an environment and unbinds the action, executing it one element at a time.
2287
+
2288
+ """
2289
+
2290
+     def __init__(
+         self,
+         actor: TensorDictModuleBase,
+         n_steps: int | None = None,
+         *,
+         action_keys: list[NestedKey] | None = None,
+         init_key: NestedKey | None = None,
+         keep_dim: bool = False,
+     ):
+         self.action_keys = action_keys
+         self.init_key = init_key
+         self.n_steps = n_steps
+         self.keep_dim = keep_dim
+
+         super().__init__()
+         self.actor = actor
+
+     @property
+     def in_keys(self):
+         return self.actor.in_keys + [self.init_key]
+
+     @property
+     def out_keys(self):
+         return (
+             self.actor.out_keys
+             + list(self._actor_keys_map.values())
+             + [self.counter_key]
+         )
+
+     def _get_and_move(self, tensordict: TensorDictBase) -> TensorDictBase:
+         # Copy each action entry to its "<key>_orig" slot, which stores the
+         # full macro while the per-step action is unrolled.
+         for action_key in self.action_keys:
+             action = tensordict.get(action_key)
+             if isinstance(action_key, tuple):
+                 action_key_orig = (*action_key[:-1], action_key[-1] + "_orig")
+             else:
+                 action_key_orig = action_key + "_orig"
+             tensordict.set(action_key_orig, action)
+         return tensordict
+
+     _NO_INIT_ERR = RuntimeError(
+         "Cannot initialize the wrapper with partial is_init signal."
+     )
+
+     def _init(self, tensordict: TensorDictBase):
+         is_init = tensordict.get(self.init_key, default=None)
+         if is_init is None:
+             raise KeyError("No init key was passed to the batched action wrapper.")
+         counter = tensordict.get(self.counter_key, None)
+         if counter is None:
+             counter = is_init.int()
+         # A macro must be (re)computed whenever the env was just reset or the
+         # lookahead window has been exhausted.
+         is_init = is_init | (counter == self.n_steps)
+         if is_init.any():
+             counter = counter.masked_fill(is_init, 0)
+             tensordict_filtered = tensordict[is_init.reshape(tensordict.shape)]
+             output = self.actor(tensordict_filtered)
+
+             for action_key, action_key_orig in self._actor_keys_map.items():
+                 action_computed = output.get(action_key, default=None)
+                 action_orig = tensordict.get(action_key_orig, default=None)
+                 if action_orig is None:
+                     if not is_init.all():
+                         raise self._NO_INIT_ERR
+                 else:
+                     # Scatter the freshly computed macros into the stored ones,
+                     # only for the batch elements that were (re)initialized.
+                     is_init_expand = expand_as_right(is_init, action_orig)
+                     action_computed = torch.masked_scatter(
+                         action_orig, is_init_expand, action_computed
+                     )
+                 tensordict.set(action_key_orig, action_computed)
+         tensordict.set(self.counter_key, counter + 1)
+
+     def forward(
+         self,
+         tensordict: TensorDictBase,
+     ) -> TensorDictBase:
+         # Refresh the stored macro(s) where needed, then pop the current action.
+         self._init(tensordict)
+         for action_key, action_key_orig in self._actor_keys_map.items():
+             # get orig
+             if isinstance(action_key_orig, str):
+                 parent_td = tensordict
+                 action_entry = parent_td.get(action_key_orig, None)
+             else:
+                 parent_td = tensordict.get(action_key_orig[:-1])
+                 action_entry = parent_td.get(action_key_orig[-1], None)
+             if action_entry is None:
+                 raise self._NO_INIT_ERR
+             if (
+                 self.n_steps is not None
+                 and action_entry.shape[parent_td.ndim] != self.n_steps
+             ):
+                 raise RuntimeError(
+                     f"The action's time dimension (dim={parent_td.ndim}) doesn't match the n_steps argument ({self.n_steps}). "
+                     f"The action shape was {action_entry.shape}."
+                 )
+             base_idx = (slice(None),) * parent_td.ndim
+             # Expose the first action of the macro, then roll the macro so that
+             # the next call exposes the following one.
+             if not self.keep_dim:
+                 cur_action = action_entry[base_idx + (0,)]
+             else:
+                 cur_action = action_entry[base_idx + (slice(1),)]
+             tensordict.set(action_key, cur_action)
+             tensordict.set(
+                 action_key_orig,
+                 torch.roll(action_entry, shifts=-1, dims=parent_td.ndim),
+             )
+         return tensordict
+
+     @property
+     def action_keys(self) -> list[NestedKey]:
+         action_keys = self.__dict__.get("_action_keys", None)
+         if action_keys is None:
+
+             def ends_with_action(key):
+                 if isinstance(key, str):
+                     return key == "action"
+                 return key[-1] == "action"
+
+             action_keys = [key for key in self.actor.out_keys if ends_with_action(key)]
+
+             self.__dict__["_action_keys"] = action_keys
+         return action_keys
+
+     @action_keys.setter
+     def action_keys(self, value):
+         if value is None:
+             return
+         self.__dict__["_actor_keys_map_values"] = None
+         if not isinstance(value, list):
+             value = [value]
+         self._action_keys = [unravel_key(key) for key in value]
+
+     @property
+     def _actor_keys_map(self) -> dict[NestedKey, NestedKey]:
+         val = self.__dict__.get("_actor_keys_map_values", None)
+         if val is None:
+
+             def _replace_last(action_key):
+                 if isinstance(action_key, tuple):
+                     action_key_orig = (*action_key[:-1], action_key[-1] + "_orig")
+                 else:
+                     action_key_orig = action_key + "_orig"
+                 return action_key_orig
+
+             val = {key: _replace_last(key) for key in self.action_keys}
+             self.__dict__["_actor_keys_map_values"] = val
+         return val
+
+     @property
+     def init_key(self) -> NestedKey:
+         """The indicator of the initial step for a given element of the batch."""
+         init_key = self.__dict__.get("_init_key", None)
+         if init_key is None:
+             self.init_key = "is_init"
+             return self.init_key
+         return init_key
+
+     @init_key.setter
+     def init_key(self, value):
+         if value is None:
+             return
+         if isinstance(value, list):
+             raise ValueError("Only a single init_key can be passed.")
+         self._init_key = value
+
+     @property
+     def counter_key(self):
+         return _replace_last(self.init_key, "counter")
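The docstring above demonstrates the wrapper inside a Gym rollout; for readers skimming the diff, here is a minimal, environment-free sketch of the same unrolling mechanics. The toy `make_macro` actor, the key names and the sizes are illustrative assumptions, not part of the package.

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.modules.tensordict_module.actors import MultiStepActorWrapper

n_steps, n_act, batch = 3, 2, 4

# Toy actor (illustrative): emits a whole macro of shape [batch, n_steps, n_act]
# in a single call, with the time dimension right after the batch dimension.
def make_macro(obs):
    return torch.randn(obs.shape[0], n_steps, n_act)

actor = TensorDictModule(make_macro, in_keys=["observation"], out_keys=["action"])
wrapped = MultiStepActorWrapper(actor, n_steps=n_steps)

td = TensorDict(
    {
        "observation": torch.randn(batch, 3),
        "is_init": torch.ones(batch, 1, dtype=torch.bool),  # fresh episode everywhere
    },
    batch_size=[batch],
)
for _ in range(n_steps):
    td = wrapped(td)             # pops one action from the stored macro per call
    print(td["action"].shape)    # torch.Size([4, 2]) at every step
    td["is_init"] = torch.zeros(batch, 1, dtype=torch.bool)  # no reset afterwards
```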