torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,1639 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import typing
8
+ from typing import Any
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from tensordict import TensorDictBase, unravel_key_list
13
+ from tensordict.base import NO_DEFAULT
14
+ from tensordict.nn import dispatch, TensorDictModuleBase as ModuleBase
15
+ from tensordict.utils import expand_as_right, prod, set_lazy_legacy
16
+ from torch import nn, Tensor
17
+ from torch.nn.modules.rnn import RNNCellBase
18
+
19
+ from torchrl._utils import _ContextManager, _DecoratorContextManager
20
+ from torchrl.data.tensor_specs import Unbounded
21
+
22
+
23
class LSTMCell(RNNCellBase):
    r"""A long short-term memory (LSTM) cell that performs the same operation as nn.LSTMCell but is fully coded in Python.

    .. note::
        This class is implemented without relying on CuDNN, which makes it compatible with :func:`torch.vmap` and :func:`torch.compile`.

    Examples:
        >>> import torch
        >>> from torchrl.modules.tensordict_module.rnn import LSTMCell
        >>> device = torch.device("cuda") if torch.cuda.device_count() else torch.device("cpu")
        >>> B = 2
        >>> N_IN = 10
        >>> N_OUT = 20
        >>> V = 4  # vector size
        >>> lstm_cell = LSTMCell(input_size=N_IN, hidden_size=N_OUT, device=device)

        # single call
        >>> x = torch.randn(B, 10, device=device)
        >>> h0 = torch.zeros(B, 20, device=device)
        >>> c0 = torch.zeros(B, 20, device=device)
        >>> with torch.no_grad():
        ...     (h1, c1) = lstm_cell(x, (h0, c0))

        # vectorised call - not possible with nn.LSTMCell
        >>> def call_lstm(x, h, c):
        ...     h_out, c_out = lstm_cell(x, (h, c))
        ...     return h_out, c_out
        >>> batched_call = torch.vmap(call_lstm)
        >>> x = torch.randn(V, B, 10, device=device)
        >>> h0 = torch.zeros(V, B, 20, device=device)
        >>> c0 = torch.zeros(V, B, 20, device=device)
        >>> with torch.no_grad():
        ...     (h1, c1) = batched_call(x, h0, c0)
    """

    __doc__ += nn.LSTMCell.__doc__

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        # RNNCellBase allocates ``num_chunks`` blocks of parameters;
        # an LSTM has 4 gates, hence num_chunks=4.
        super().__init__(
            input_size,
            hidden_size,
            bias,
            num_chunks=4,
            device=device,
            dtype=dtype,
        )

    def forward(
        self, input: Tensor, hx: tuple[Tensor, Tensor] | None = None
    ) -> tuple[Tensor, Tensor]:
        """Run one LSTM step on ``input`` given the optional ``(h, c)`` state pair.

        Accepts 1D (single sample) or 2D (batched) inputs; when ``hx`` is
        omitted, zero-filled hidden and cell states are used.
        """
        # Rank validation: a cell step only accepts a single sample (1D) or a
        # batch of samples (2D).
        if input.dim() not in (1, 2):
            raise ValueError(
                f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
            )
        if hx is not None:
            for idx, state in enumerate(hx):
                if state.dim() not in (1, 2):
                    raise ValueError(
                        f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {state.dim()}D instead"
                    )

        batched = input.dim() == 2
        if not batched:
            # Promote to a batch of one so the cell math is always 2D.
            input = input.unsqueeze(0)

        if hx is None:
            # Default states are zeros matching the (batched) input.
            zeros = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
            hx = (zeros, zeros)
        elif not batched:
            hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0))

        hy, cy = self.lstm_cell(input, hx[0], hx[1])

        if batched:
            return hy, cy
        # Undo the batch-of-one promotion for unbatched callers.
        return hy.squeeze(0), cy.squeeze(0)

    def lstm_cell(self, x, hx, cx):
        """Compute one LSTM update; returns the new ``(hidden, cell)`` states."""
        x = x.view(-1, x.size(1))

        # All four gate pre-activations in two matmuls, then split per gate.
        pre_activations = F.linear(x, self.weight_ih, self.bias_ih) + F.linear(
            hx, self.weight_hh, self.bias_hh
        )
        in_gate, forget_gate, cell_gate, out_gate = pre_activations.chunk(4, 1)

        # Standard LSTM recurrence: c' = f*c + i*g ; h' = o*tanh(c')
        cy = cx * forget_gate.sigmoid() + in_gate.sigmoid() * cell_gate.tanh()
        hy = out_gate.sigmoid() * cy.tanh()

        return hy, cy
121
+
122
+
123
# copy LSTM
class LSTMBase(nn.RNNBase):
    """A Base module for LSTM. Inheriting from LSTMBase enables compatibility with torch.compile."""

    def __init__(self, *args, **kwargs):
        # ``nn.RNNBase`` dispatches on its ``mode`` string; hard-code "LSTM"
        # so instances get LSTM-shaped parameters (4 gates per layer).
        # Note: no ``return`` here -- ``__init__`` must return ``None``.
        super().__init__("LSTM", *args, **kwargs)


# Graft every attribute defined on ``nn.LSTM`` (except ``__init__``, replaced
# above) onto ``LSTMBase`` so that it behaves like a drop-in ``nn.LSTM`` while
# keeping the torch.compile-friendly constructor.
for attr in nn.LSTM.__dict__:
    if attr != "__init__":
        setattr(LSTMBase, attr, getattr(nn.LSTM, attr))
134
+
135
+
136
class LSTM(LSTMBase):
    """A PyTorch module for executing multiple steps of a multi-layer LSTM. The module behaves exactly like :class:`torch.nn.LSTM`, but this implementation is exclusively coded in Python.

    .. note::
        This class is implemented without relying on CuDNN, which makes it compatible with :func:`torch.vmap` and :func:`torch.compile`.

    .. note::
        ``bidirectional=True`` and ``proj_size > 0`` are not supported by this
        Python implementation and raise :class:`NotImplementedError` at
        construction time.

    Examples:
        >>> import torch
        >>> from torchrl.modules.tensordict_module.rnn import LSTM

        >>> device = torch.device("cuda") if torch.cuda.device_count() else torch.device("cpu")
        >>> B = 2
        >>> T = 4
        >>> N_IN = 10
        >>> N_OUT = 20
        >>> N_LAYERS = 2
        >>> V = 4  # vector size
        >>> lstm = LSTM(
        ...     input_size=N_IN,
        ...     hidden_size=N_OUT,
        ...     device=device,
        ...     num_layers=N_LAYERS,
        ... )

        # single call
        >>> x = torch.randn(B, T, N_IN, device=device)
        >>> h0 = torch.zeros(N_LAYERS, B, N_OUT, device=device)
        >>> c0 = torch.zeros(N_LAYERS, B, N_OUT, device=device)
        >>> with torch.no_grad():
        ...     h1, c1 = lstm(x, (h0, c0))

        # vectorised call - not possible with nn.LSTM
        >>> def call_lstm(x, h, c):
        ...     h_out, c_out = lstm(x, (h, c))
        ...     return h_out, c_out
        >>> batched_call = torch.vmap(call_lstm)
        >>> x = torch.randn(V, B, T, 10, device=device)
        >>> h0 = torch.zeros(V, N_LAYERS, B, N_OUT, device=device)
        >>> c0 = torch.zeros(V, N_LAYERS, B, N_OUT, device=device)
        >>> with torch.no_grad():
        ...     h1, c1 = batched_call(x, h0, c0)
    """

    __doc__ += nn.LSTM.__doc__

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        batch_first: bool = True,
        bias: bool = True,
        dropout: float = 0.0,
        # Fixed annotation: this flag was previously annotated ``float``.
        bidirectional: bool = False,
        proj_size: int = 0,
        device=None,
        dtype=None,
    ) -> None:
        # Truthiness test (rather than ``is True``) so that e.g.
        # ``bidirectional=1`` cannot silently build an unsupported
        # bidirectional module through ``nn.RNNBase``.
        if bidirectional:
            raise NotImplementedError(
                "Bidirectional LSTMs are not supported yet in this implementation."
            )
        # ``_lstm_cell`` below never applies the ``weight_hr_l*`` projection
        # matrices, so a non-zero ``proj_size`` would silently produce wrong
        # results; refuse it explicitly instead.
        if proj_size != 0:
            raise NotImplementedError(
                "proj_size > 0 is not supported yet in this implementation."
            )

        super().__init__(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=batch_first,
            dropout=dropout,
            bidirectional=bidirectional,
            proj_size=proj_size,
            device=device,
            dtype=dtype,
        )

    @staticmethod
    def _lstm_cell(x, hx, cx, weight_ih, bias_ih, weight_hh, bias_hh):
        """Single-layer, single-step LSTM update; returns ``(hy, cy)``.

        ``bias_ih``/``bias_hh`` may be ``None`` when the module has no bias.
        """
        # All four gate pre-activations in two matmuls, split per gate.
        # Gate layout follows the torch convention: input, forget, cell, output.
        gates = F.linear(x, weight_ih, bias_ih) + F.linear(hx, weight_hh, bias_hh)

        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, 1)

        i_gate = i_gate.sigmoid()
        f_gate = f_gate.sigmoid()
        g_gate = g_gate.tanh()
        o_gate = o_gate.sigmoid()

        # Standard LSTM recurrence: c' = f*c + i*g ; h' = o*tanh(c')
        cy = cx * f_gate + i_gate * g_gate
        hy = o_gate * cy.tanh()

        return hy, cy

    def _lstm(self, x, hx):
        """Unroll the stacked LSTM over the whole sequence.

        Args:
            x: input of shape ``(B, T, F)`` if ``batch_first`` else ``(T, B, F)``.
            hx: tuple ``(h_0, c_0)`` each of shape ``(num_layers, B, hidden_size)``.

        Returns:
            ``(outputs, (h_n, c_n))`` with the same layout as :class:`torch.nn.LSTM`.

        .. note::
            A zero-length time dimension is not supported (the final stacks
            would be empty/undefined).
        """
        h_t, c_t = hx
        # One state tensor per layer.
        h_t, c_t = h_t.unbind(0), c_t.unbind(0)

        outputs = []

        # Gather the per-layer parameters once, outside the time loop.
        weight_ihs = []
        weight_hhs = []
        bias_ihs = []
        bias_hhs = []
        for weights in self._all_weights:
            # Retrieve weights
            weight_ihs.append(getattr(self, weights[0]))
            weight_hhs.append(getattr(self, weights[1]))
            if self.bias:
                bias_ihs.append(getattr(self, weights[2]))
                bias_hhs.append(getattr(self, weights[3]))
            else:
                bias_ihs.append(None)
                bias_hhs.append(None)

        # Time dimension is dim 1 when batch_first, else dim 0.
        for x_t in x.unbind(int(self.batch_first)):
            h_t_out = []
            c_t_out = []

            for layer, (
                weight_ih,
                bias_ih,
                weight_hh,
                bias_hh,
                _h_t,
                _c_t,
            ) in enumerate(zip(weight_ihs, bias_ihs, weight_hhs, bias_hhs, h_t, c_t)):
                # Run cell
                _h_t, _c_t = self._lstm_cell(
                    x_t, _h_t, _c_t, weight_ih, bias_ih, weight_hh, bias_hh
                )
                h_t_out.append(_h_t)
                c_t_out.append(_c_t)

                # Apply dropout if in training mode (never after the last layer,
                # matching nn.LSTM semantics).
                if layer < self.num_layers - 1 and self.dropout:
                    x_t = F.dropout(_h_t, p=self.dropout, training=self.training)
                else:  # No dropout after the last layer
                    x_t = _h_t
            h_t = h_t_out
            c_t = c_t_out
            outputs.append(x_t)

        outputs = torch.stack(outputs, dim=int(self.batch_first))

        return outputs, (torch.stack(h_t_out, 0), torch.stack(c_t_out, 0))

    def forward(self, input, hx=None):  # noqa: F811
        """Mimic :meth:`torch.nn.LSTM.forward` for 3D (batched) inputs only."""
        # proj_size is forced to 0 in __init__, so this always equals
        # hidden_size; kept for parity with the torch implementation.
        real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
        if input.dim() != 3:
            raise ValueError(
                f"LSTM: Expected input to be 3D, got {input.dim()}D instead"
            )
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        if hx is None:
            # Default to zero-filled initial states, as nn.LSTM does.
            h_zeros = torch.zeros(
                self.num_layers,
                max_batch_size,
                real_hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            c_zeros = torch.zeros(
                self.num_layers,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            hx = (h_zeros, c_zeros)
        return self._lstm(input, hx)
309
+
310
+
311
class LSTMModule(ModuleBase):
    """An embedder for an LSTM module.

    This class adds the following functionality to :class:`torch.nn.LSTM`:

    - Compatibility with TensorDict: the hidden states are reshaped to match
      the tensordict batch size.
    - Optional multi-step execution: with torch.nn, one has to choose between
      :class:`torch.nn.LSTMCell` and :class:`torch.nn.LSTM`, the former being
      compatible with single step inputs and the latter being compatible with
      multi-step. This class enables both usages.


    After construction, the module is *not* set in recurrent mode, ie. it will
    expect single steps inputs.

    If in recurrent mode, it is expected that the last dimension of the tensordict
    marks the number of steps. There is no constrain on the dimensionality of the
    tensordict (except that it must be greater than one for temporal inputs).

    .. note::
      This class can handle multiple consecutive trajectories along the time dimension
      *but* the final hidden values should not be trusted in those cases (ie. they
      should not be re-used for a consecutive trajectory).
      The reason is that LSTM returns only the last hidden value, which for the
      padded inputs we provide can correspond to a 0-filled input.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two LSTMs together to form a `stacked LSTM`,
            with the second LSTM taking in outputs of the first LSTM and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            LSTM layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        python_based: If ``True``, will use a full Python implementation of the LSTM cell. Default: ``False``

    Keyword Args:
        in_key (str or tuple of str): the input key of the module. Exclusive use
            with ``in_keys``. If provided, the recurrent keys are assumed to be
            ["recurrent_state_h", "recurrent_state_c"] and the ``in_key`` will be
            appended before these.
        in_keys (list of str): a triplet of strings corresponding to the input value,
            first and second hidden key. Exclusive with ``in_key``.
        out_key (str or tuple of str): the output key of the module. Exclusive use
            with ``out_keys``. If provided, the recurrent keys are assumed to be
            [("next", "recurrent_state_h"), ("next", "recurrent_state_c")]
            and the ``out_key`` will be
            appended before these.
        out_keys (list of str): a triplet of strings corresponding to the output value,
            first and second hidden key.

            .. note::
              For a better integration with TorchRL's environments, the best naming
              for the output hidden key is ``("next", <custom_key>)``, such
              that the hidden values are passed from step to step during a rollout.

        device (torch.device or compatible): the device of the module.
        lstm (torch.nn.LSTM, optional): an LSTM instance to be wrapped.
            Exclusive with other nn.LSTM arguments.
        default_recurrent_mode (bool, optional): if provided, the recurrent mode if it hasn't been overridden
            by the :class:`~torchrl.modules.set_recurrent_mode` context manager / decorator.
            Defaults to ``False``.

    Attributes:
        recurrent_mode: Returns the recurrent mode of the module.

    Methods:
        set_recurrent_mode: controls whether the module should be executed in
            recurrent mode.
        make_tensordict_primer: creates the TensorDictPrimer transforms for the environment to be aware of the
            recurrent states of the RNN.

    .. note:: This module relies on specific ``recurrent_state`` keys being present in the input
        TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
        add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.LSTMModule.make_tensordict_primer`.
        If this class is a submodule in a larger module, the method :func:`~torchrl.modules.utils.get_primers_from_module` can be called
        on the parent module to automatically generate the primer transforms required for all submodules, including this one.


    Examples:
        >>> from torchrl.envs import TransformedEnv, InitTracker
        >>> from torchrl.envs import GymEnv
        >>> from torchrl.modules import MLP, LSTMModule
        >>> from torch import nn
        >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
        >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
        >>> lstm_module = LSTMModule(
        ...     input_size=env.observation_spec["observation"].shape[-1],
        ...     hidden_size=64,
        ...     in_keys=["observation", "rs_h", "rs_c"],
        ...     out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")])
        >>> mlp = MLP(num_cells=[64], out_features=1)
        >>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
        >>> policy(env.reset())
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
                is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        rs_c: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
                        rs_h: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)

    """

    # Default hidden-state keys used when only ``in_key`` / ``out_key`` is given.
    DEFAULT_IN_KEYS = ["recurrent_state_h", "recurrent_state_c"]
    DEFAULT_OUT_KEYS = [("next", "recurrent_state_h"), ("next", "recurrent_state_c")]

    def __init__(
        self,
        input_size: int | None = None,
        hidden_size: int | None = None,
        num_layers: int = 1,
        bias: bool = True,
        batch_first=True,
        dropout=0,
        proj_size=0,
        bidirectional=False,
        python_based=False,
        *,
        in_key=None,
        in_keys=None,
        out_key=None,
        out_keys=None,
        device=None,
        lstm=None,
        default_recurrent_mode: bool | None = None,
    ):
        super().__init__()
        # Either wrap a pre-built LSTM instance or build one from the arguments,
        # but never both. Only batch_first, non-bidirectional LSTMs are supported.
        if lstm is not None:
            if not lstm.batch_first:
                raise ValueError("The input lstm must have batch_first=True.")
            if lstm.bidirectional:
                raise ValueError("The input lstm cannot be bidirectional.")
            if input_size is not None or hidden_size is not None:
                raise ValueError(
                    "An LSTM instance cannot be passed along with class argument."
                )
        else:
            if not batch_first:
                raise ValueError("The input lstm must have batch_first=True.")
            if bidirectional:
                raise ValueError("The input lstm cannot be bidirectional.")
            if not hidden_size:
                raise ValueError("hidden_size must be passed.")
            if python_based:
                # Pure-Python LSTM: vmap/compile-friendly (no CuDNN kernel).
                lstm = LSTM(
                    input_size=input_size,
                    hidden_size=hidden_size,
                    num_layers=num_layers,
                    bias=bias,
                    dropout=dropout,
                    proj_size=proj_size,
                    device=device,
                    batch_first=True,
                    bidirectional=False,
                )
            else:
                lstm = nn.LSTM(
                    input_size=input_size,
                    hidden_size=hidden_size,
                    num_layers=num_layers,
                    bias=bias,
                    dropout=dropout,
                    proj_size=proj_size,
                    device=device,
                    batch_first=True,
                    bidirectional=False,
                )
        # Exactly one of in_key / in_keys must be provided (xor check).
        if not ((in_key is None) ^ (in_keys is None)):
            raise ValueError(
                f"Either in_keys or in_key must be specified but not both or none. Got {in_keys} and {in_key} respectively."
            )
        elif in_key:
            in_keys = [in_key, *self.DEFAULT_IN_KEYS]

        if not ((out_key is None) ^ (out_keys is None)):
            raise ValueError(
                f"Either out_keys or out_key must be specified but not both or none. Got {out_keys} and {out_key} respectively."
            )
        elif out_key:
            out_keys = [out_key, *self.DEFAULT_OUT_KEYS]

        in_keys = unravel_key_list(in_keys)
        out_keys = unravel_key_list(out_keys)
        # in_keys: value + two hidden states, optionally followed by "is_init".
        if not isinstance(in_keys, (tuple, list)) or (
            len(in_keys) != 3 and not (len(in_keys) == 4 and in_keys[-1] == "is_init")
        ):
            raise ValueError(
                f"LSTMModule expects 3 inputs: a value, and two hidden states (and potentially an 'is_init' marker). Got in_keys {in_keys} instead."
            )
        if not isinstance(out_keys, (tuple, list)) or len(out_keys) != 3:
            raise ValueError(
                f"LSTMModule expects 3 outputs: a value, and two hidden states. Got out_keys {out_keys} instead."
            )
        self.lstm = lstm
        # "is_init" is always read in forward(); append it if the caller did not.
        if "is_init" not in in_keys:
            in_keys = in_keys + ["is_init"]
        self.in_keys = in_keys
        self.out_keys = out_keys
        self._recurrent_mode = default_recurrent_mode

    def make_python_based(self) -> LSTMModule:
        """Transforms the LSTM layer in its python-based version.

        Returns:
            self

        """
        if isinstance(self.lstm, LSTM):
            # Already python-based: nothing to do.
            return self
        lstm = LSTM(
            input_size=self.lstm.input_size,
            hidden_size=self.lstm.hidden_size,
            num_layers=self.lstm.num_layers,
            bias=self.lstm.bias,
            dropout=self.lstm.dropout,
            proj_size=self.lstm.proj_size,
            device="meta",
            batch_first=self.lstm.batch_first,
            bidirectional=self.lstm.bidirectional,
        )
        from tensordict import from_module

        # Transfer the existing weights onto the new (meta-device) module.
        from_module(self.lstm).to_module(lstm)
        self.lstm = lstm
        return self

    def make_cudnn_based(self) -> LSTMModule:
        """Transforms the LSTM layer in its CuDNN-based version.

        Returns:
            self

        """
        if isinstance(self.lstm, nn.LSTM):
            # Already CuDNN-based: nothing to do.
            return self
        lstm = nn.LSTM(
            input_size=self.lstm.input_size,
            hidden_size=self.lstm.hidden_size,
            num_layers=self.lstm.num_layers,
            bias=self.lstm.bias,
            dropout=self.lstm.dropout,
            proj_size=self.lstm.proj_size,
            device="meta",
            batch_first=self.lstm.batch_first,
            bidirectional=self.lstm.bidirectional,
        )
        from tensordict import from_module

        # Transfer the existing weights onto the new (meta-device) module.
        from_module(self.lstm).to_module(lstm)
        self.lstm = lstm
        return self

    def make_tensordict_primer(self):
        """Makes a tensordict primer for the environment.

        A :class:`~torchrl.envs.TensorDictPrimer` object will ensure that the policy is aware of the supplementary
        inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across
        processes and dealt with properly.

        When using batched environments such as :class:`~torchrl.envs.ParallelEnv`, the transform can be used at the
        single env instance level (i.e., a batch of transformed envs with tensordict primers set within) or at the
        batched env instance level (i.e., a transformed batch of regular envs).

        Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviors, for instance
        in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root
        tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states
        are not registered within the environment specs.

        See :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given
        module.

        Examples:
            >>> from torchrl.collectors import Collector
            >>> from torchrl.envs import TransformedEnv, InitTracker
            >>> from torchrl.envs import GymEnv
            >>> from torchrl.modules import MLP, LSTMModule
            >>> from torch import nn
            >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
            >>>
            >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
            >>> lstm_module = LSTMModule(
            ...     input_size=env.observation_spec["observation"].shape[-1],
            ...     hidden_size=64,
            ...     in_keys=["observation", "rs_h", "rs_c"],
            ...     out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")])
            >>> mlp = MLP(num_cells=[64], out_features=1)
            >>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
            >>> policy(env.reset())
            >>> env = env.append_transform(lstm_module.make_tensordict_primer())
            >>> data_collector = Collector(
            ...     env,
            ...     policy,
            ...     frames_per_batch=10
            ... )
            >>> for data in data_collector:
            ...     print(data)
            ...     break

        """
        from torchrl.envs.transforms.transforms import TensorDictPrimer

        def make_tuple(key):
            # Normalize a key to tuple form for comparison.
            if isinstance(key, tuple):
                return key
            return (key,)

        out_key1 = make_tuple(self.out_keys[1])
        in_key1 = make_tuple(self.in_keys[1])
        out_key2 = make_tuple(self.out_keys[2])
        in_key2 = make_tuple(self.in_keys[2])
        # The primer can only be built when hidden outputs mirror hidden inputs
        # under the "next" prefix, ie. out == ("next", *in).
        if out_key1 != ("next", *in_key1) or out_key2 != ("next", *in_key2):
            raise RuntimeError(
                "make_tensordict_primer is supposed to work with in_keys/out_keys that "
                "have compatible names, ie. the out_keys should be named after ('next', <in_key>). Got "
                f"in_keys={self.in_keys} and out_keys={self.out_keys} instead."
            )
        return TensorDictPrimer(
            {
                in_key1: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)),
                in_key2: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)),
            },
            expand_specs=True,
        )

    @property
    def recurrent_mode(self):
        # The context-manager setting (set_recurrent_mode) takes precedence over
        # the constructor default.
        rm = recurrent_mode()
        if rm is None:
            return bool(self._recurrent_mode)
        return rm

    @recurrent_mode.setter
    def recurrent_mode(self, value):
        raise RuntimeError(
            "recurrent_mode cannot be changed in-place. Please use the set_recurrent_mode context manager."
        )

    @property
    def temporal_mode(self):
        # Deprecated alias kept only to raise a helpful error.
        raise RuntimeError(
            "temporal_mode is deprecated, use recurrent_mode instead.",
        )

    def set_recurrent_mode(self, mode: bool = True):
        # Removed API kept only to raise a helpful error.
        raise RuntimeError(
            "The lstm.set_recurrent_mode() API has been removed in v0.8. "
            "To set the recurrent mode, use the :class:`~torchrl.modules.set_recurrent_mode` context manager or "
            "the `default_recurrent_mode` keyword argument in the constructor."
        )

    @dispatch
    def forward(self, tensordict: TensorDictBase):
        """Runs the LSTM over the tensordict and writes value/hidden outputs in place."""
        from torchrl.objectives.value.functional import (
            _inv_pad_sequence,
            _split_and_pad_sequence,
        )

        # we want to get an error if the value input is missing, but not the hidden states
        defaults = [NO_DEFAULT, None, None]
        shape = tensordict.shape
        tensordict_shaped = tensordict
        if self.recurrent_mode:
            # Recurrent mode: reshape the input to 3 dims [batch, time, feature].
            # if less than 2 dims, unsqueeze
            ndim = tensordict_shaped.get(self.in_keys[0]).ndim
            while ndim < 3:
                tensordict_shaped = tensordict_shaped.unsqueeze(0)
                ndim += 1
            if ndim > 3:
                dims_to_flatten = ndim - 3
                # we assume that the tensordict can be flattened like this
                nelts = prod(tensordict_shaped.shape[: dims_to_flatten + 1])
                tensordict_shaped = tensordict_shaped.apply(
                    lambda value: value.flatten(0, dims_to_flatten),
                    batch_size=[nelts, tensordict_shaped.shape[-1]],
                )
        else:
            # Single-step mode: flatten the batch and add a time dim of size 1.
            tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1)

        is_init = tensordict_shaped["is_init"].squeeze(-1)
        splits = None
        if self.recurrent_mode and is_init[..., 1:].any():
            from torchrl.objectives.value.utils import _get_num_per_traj_init

            # if we have consecutive trajectories, things get a little more complicated
            # we have a tensordict of shape [B, T]
            # we will split / pad things such that we get a tensordict of shape
            # [N, T'] where T' <= T and N >= B is the new batch size, such that
            # each index of N is an independent trajectory. We'll need to keep
            # track of the indices though, as we want to put things back together in the end.
            splits = _get_num_per_traj_init(is_init)
            tensordict_shaped_shape = tensordict_shaped.shape
            tensordict_shaped = _split_and_pad_sequence(
                tensordict_shaped.select(*self.in_keys, strict=False), splits
            )
            is_init = tensordict_shaped["is_init"].squeeze(-1)

        value, hidden0, hidden1 = (
            tensordict_shaped.get(key, default)
            for key, default in zip(self.in_keys, defaults)
        )
        # packed sequences do not help to get the accurate last hidden values
        # if splits is not None:
        #     value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True)

        if not self.recurrent_mode and hidden0 is not None:
            # We zero the hidden states if we're calling the lstm recursively
            # as we assume the hidden state comes from the previous trajectory.
            # When using the recurrent_mode=True option, the lstm can be called from
            # any intermediate state, hence zeroing should not be done.
            is_init_expand = expand_as_right(is_init, hidden0)
            zeros = torch.zeros_like(hidden0)
            hidden0 = torch.where(is_init_expand, zeros, hidden0)
            hidden1 = torch.where(is_init_expand, zeros, hidden1)

        batch, steps = value.shape[:2]
        device = value.device
        dtype = value.dtype

        val, hidden0, hidden1 = self._lstm(
            value, batch, steps, device, dtype, hidden0, hidden1
        )
        tensordict_shaped.set(self.out_keys[0], val)
        tensordict_shaped.set(self.out_keys[1], hidden0)
        tensordict_shaped.set(self.out_keys[2], hidden1)
        if splits is not None:
            # let's recover our original shape
            tensordict_shaped = _inv_pad_sequence(tensordict_shaped, splits).reshape(
                tensordict_shaped_shape
            )

        # Copy results back into the caller's tensordict if we worked on a reshaped view.
        if shape != tensordict_shaped.shape or tensordict_shaped is not tensordict:
            tensordict.update(tensordict_shaped.reshape(shape))
        return tensordict

    def _lstm(
        self,
        input: torch.Tensor,
        batch,
        steps,
        device,
        dtype,
        hidden0_in: torch.Tensor | None = None,
        hidden1_in: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Runs the wrapped LSTM on a [batch, steps, feature] input and returns
        # (output, hidden_h, hidden_c) with hidden states padded along the time dim.

        if not self.recurrent_mode and steps != 1:
            raise ValueError("Expected a single step")

        if hidden1_in is None and hidden0_in is None:
            # No hidden state provided: start from zeros for both h and c.
            shape = (batch, steps)
            hidden0_in, hidden1_in = (
                torch.zeros(
                    *shape,
                    self.lstm.num_layers,
                    self.lstm.hidden_size,
                    device=device,
                    dtype=dtype,
                )
                for _ in range(2)
            )
        elif hidden1_in is None or hidden0_in is None:
            # Providing only one of (h, c) is an error.
            raise RuntimeError(
                f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}"
            )

        # we only need the first hidden state
        _hidden0_in = hidden0_in[..., 0, :, :]
        _hidden1_in = hidden1_in[..., 0, :, :]
        hidden = (
            _hidden0_in.transpose(-3, -2).contiguous(),
            _hidden1_in.transpose(-3, -2).contiguous(),
        )

        y, hidden = self.lstm(input, hidden)
        # dim 0 in hidden is num_layers, but that will conflict with tensordict
        hidden = tuple(_h.transpose(0, 1) for _h in hidden)

        out = [y, *hidden]
        # we pad the hidden states with zero to make tensordict happy
        for i in range(1, 3):
            out[i] = torch.stack(
                [torch.zeros_like(out[i]) for _ in range(steps - 1)] + [out[i]],
                1,
            )
        return tuple(out)
814
+
815
+
816
class GRUCell(RNNCellBase):
    r"""A gated recurrent unit (GRU) cell that performs the same operation as nn.GRUCell but is fully coded in Python.

    .. note::
        This class is implemented without relying on CuDNN, which makes it compatible with :func:`torch.vmap` and :func:`torch.compile`.

    Examples:
        >>> import torch
        >>> from torchrl.modules.tensordict_module.rnn import GRUCell
        >>> device = torch.device("cuda") if torch.cuda.device_count() else torch.device("cpu")
        >>> B = 2
        >>> N_IN = 10
        >>> N_OUT = 20
        >>> V = 4  # vector size
        >>> gru_cell = GRUCell(input_size=N_IN, hidden_size=N_OUT, device=device)

        # single call
        >>> x = torch.randn(B, 10, device=device)
        >>> h0 = torch.zeros(B, 20, device=device)
        >>> with torch.no_grad():
        ...     h1 = gru_cell(x, h0)

        # vectorised call - not possible with nn.GRUCell
        >>> def call_gru(x, h):
        ...     h_out = gru_cell(x, h)
        ...     return h_out
        >>> batched_call = torch.vmap(call_gru)
        >>> x = torch.randn(V, B, 10, device=device)
        >>> h0 = torch.zeros(V, B, 20, device=device)
        >>> with torch.no_grad():
        ...     h1 = batched_call(x, h0)
    """

    __doc__ += nn.GRUCell.__doc__

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        # num_chunks=3: GRU has three gates (reset, input/update, new).
        super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)

    def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor:
        """Computes the next hidden state for a (batched or unbatched) input.

        Args:
            input: tensor of shape ``(input_size,)`` or ``(batch, input_size)``.
            hx: optional hidden state of shape ``(hidden_size,)`` or
                ``(batch, hidden_size)``. Defaults to zeros.

        Returns:
            The next hidden state, with the same batching as ``input``.

        Raises:
            ValueError: if ``input`` or ``hx`` is not 1D or 2D.
        """
        if input.dim() not in (1, 2):
            raise ValueError(
                f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
            )
        if hx is not None and hx.dim() not in (1, 2):
            raise ValueError(
                f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead"
            )
        is_batched = input.dim() == 2
        # Internally we always work on a batched (2D) input.
        if not is_batched:
            input = input.unsqueeze(0)

        if hx is None:
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        else:
            hx = hx.unsqueeze(0) if not is_batched else hx

        ret = self.gru_cell(input, hx)

        # Restore the caller's (unbatched) shape if needed.
        if not is_batched:
            ret = ret.squeeze(0)

        return ret

    def gru_cell(self, x, hx):
        """Applies one GRU update: returns the new hidden state for input ``x``."""
        x = x.view(-1, x.size(1))

        gate_x = F.linear(x, self.weight_ih, self.bias_ih)
        gate_h = F.linear(hx, self.weight_hh, self.bias_hh)

        i_r, i_i, i_n = gate_x.chunk(3, 1)
        h_r, h_i, h_n = gate_h.chunk(3, 1)

        # Tensor-method activations (the deprecated F.sigmoid/F.tanh aliases were
        # replaced; this matches GRU._gru_cell elsewhere in this file).
        resetgate = (i_r + h_r).sigmoid()
        inputgate = (i_i + h_i).sigmoid()
        newgate = (i_n + (resetgate * h_n)).tanh()

        hy = newgate + inputgate * (hx - newgate)

        return hy
906
+
907
+
908
+ # copy GRU
909
# copy GRU
class GRUBase(nn.RNNBase):
    """A Base module for GRU. Inheriting from GRUBase enables compatibility with torch.compile."""

    def __init__(self, *args, **kwargs):
        # Fix the RNNBase mode to "GRU"; all other arguments pass through.
        # (No `return` here: __init__ must not return a value.)
        super().__init__("GRU", *args, **kwargs)
914
+
915
+
916
# Mirror every attribute of nn.GRU (except its constructor, which GRUBase
# overrides to pin the mode) onto GRUBase so it behaves like nn.GRU.
for _name in vars(nn.GRU):
    if _name == "__init__":
        continue
    setattr(GRUBase, _name, getattr(nn.GRU, _name))
919
+
920
+
921
class GRU(GRUBase):
    """A PyTorch module for executing multiple steps of a multi-layer GRU. The module behaves exactly like :class:`torch.nn.GRU`, but this implementation is exclusively coded in Python.

    .. note::
        This class is implemented without relying on CuDNN, which makes it
        compatible with :func:`torch.vmap` and :func:`torch.compile`.

    Examples:
        >>> import torch
        >>> from torchrl.modules.tensordict_module.rnn import GRU

        >>> device = torch.device("cuda") if torch.cuda.device_count() else torch.device("cpu")
        >>> B = 2
        >>> T = 4
        >>> N_IN = 10
        >>> N_OUT = 20
        >>> N_LAYERS = 2
        >>> V = 4  # vector size
        >>> gru = GRU(
        ...     input_size=N_IN,
        ...     hidden_size=N_OUT,
        ...     device=device,
        ...     num_layers=N_LAYERS,
        ... )

        # single call
        >>> x = torch.randn(B, T, N_IN, device=device)
        >>> h0 = torch.zeros(N_LAYERS, B, N_OUT, device=device)
        >>> with torch.no_grad():
        ...     h1 = gru(x, h0)

        # vectorised call - not possible with nn.GRU
        >>> def call_gru(x, h):
        ...     h_out = gru(x, h)
        ...     return h_out
        >>> batched_call = torch.vmap(call_gru)
        >>> x = torch.randn(V, B, T, 10, device=device)
        >>> h0 = torch.zeros(V, N_LAYERS, B, N_OUT, device=device)
        >>> with torch.no_grad():
        ...     h1 = batched_call(x, h0)
    """

    __doc__ += nn.GRU.__doc__

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = True,
        dropout: float = 0.0,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None:

        if bidirectional:
            # Fixed error message: this is a GRU, not an LSTM.
            raise NotImplementedError(
                "Bidirectional GRUs are not supported yet in this implementation."
            )

        super().__init__(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=batch_first,
            dropout=dropout,
            bidirectional=False,
            device=device,
            dtype=dtype,
        )

    @staticmethod
    def _gru_cell(x, hx, weight_ih, bias_ih, weight_hh, bias_hh):
        """Single GRU cell update for one time step of one layer."""
        x = x.view(-1, x.size(1))

        gate_x = F.linear(x, weight_ih, bias_ih)
        gate_h = F.linear(hx, weight_hh, bias_hh)

        i_r, i_i, i_n = gate_x.chunk(3, 1)
        h_r, h_i, h_n = gate_h.chunk(3, 1)

        resetgate = (i_r + h_r).sigmoid()
        inputgate = (i_i + h_i).sigmoid()
        newgate = (i_n + (resetgate * h_n)).tanh()

        hy = newgate + inputgate * (hx - newgate)

        return hy

    def _gru(self, x, hx):
        """Unrolls the GRU over time in pure Python.

        Args:
            x: input of shape ``(batch, seq, features)`` (or ``(seq, batch, features)``
                if ``batch_first=False``).
            hx: initial hidden state of shape ``(num_layers, batch, hidden_size)``.

        Returns:
            Tuple of (outputs, final hidden state stacked over layers).
        """

        if not self.batch_first:
            x = x.permute(
                1, 0, 2
            )  # Change (seq_len, batch, features) to (batch, seq_len, features)

        bs, seq_len, input_size = x.size()
        h_t = list(hx.unbind(0))

        # Gather per-layer parameters once, before the time loop.
        weight_ih = []
        weight_hh = []
        bias_ih = []
        bias_hh = []
        for layer in range(self.num_layers):

            # Retrieve weights
            weights = self._all_weights[layer]
            weight_ih.append(getattr(self, weights[0]))
            weight_hh.append(getattr(self, weights[1]))
            if self.bias:
                bias_ih.append(getattr(self, weights[2]))
                bias_hh.append(getattr(self, weights[3]))
            else:
                bias_ih.append(None)
                bias_hh.append(None)

        outputs = []

        for x_t in x.unbind(1):
            for layer in range(self.num_layers):
                h_t[layer] = self._gru_cell(
                    x_t,
                    h_t[layer],
                    weight_ih[layer],
                    bias_ih[layer],
                    weight_hh[layer],
                    bias_hh[layer],
                )

                # Apply dropout if in training mode and not the last layer
                if layer < self.num_layers - 1 and self.dropout:
                    x_t = F.dropout(h_t[layer], p=self.dropout, training=self.training)
                else:
                    x_t = h_t[layer]

            outputs.append(x_t)

        outputs = torch.stack(outputs, dim=1)
        if not self.batch_first:
            outputs = outputs.permute(
                1, 0, 2
            )  # Change back (batch, seq_len, features) to (seq_len, batch, features)

        return outputs, torch.stack(h_t, 0)

    def forward(self, input, hx=None):  # noqa: F811
        """Runs the GRU on a 3D input, initializing hidden states to zeros if absent."""
        if input.dim() != 3:
            raise ValueError(
                f"GRU: Expected input to be 3D, got {input.dim()}D instead"
            )
        if hx is not None and hx.dim() != 3:
            raise RuntimeError(
                f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor"
            )
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        if hx is None:
            hx = torch.zeros(
                self.num_layers,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )

        self.check_forward_args(input, hx, batch_sizes=None)
        result = self._gru(input, hx)

        output = result[0]
        hidden = result[1]

        return output, hidden
1095
+
1096
+
1097
+ class GRUModule(ModuleBase):
1098
+ """An embedder for an GRU module.
1099
+
1100
+ This class adds the following functionality to :class:`torch.nn.GRU`:
1101
+
1102
+ - Compatibility with TensorDict: the hidden states are reshaped to match
1103
+ the tensordict batch size.
1104
+ - Optional multi-step execution: with torch.nn, one has to choose between
1105
+ :class:`torch.nn.GRUCell` and :class:`torch.nn.GRU`, the former being
1106
+ compatible with single step inputs and the latter being compatible with
1107
+ multi-step. This class enables both usages.
1108
+
1109
+
1110
+ After construction, the module is *not* set in recurrent mode, ie. it will
1111
+ expect single steps inputs.
1112
+
1113
+ If in recurrent mode, it is expected that the last dimension of the tensordict
1114
+ marks the number of steps. There is no constrain on the dimensionality of the
1115
+ tensordict (except that it must be greater than one for temporal inputs).
1116
+
1117
+ Args:
1118
+ input_size: The number of expected features in the input `x`
1119
+ hidden_size: The number of features in the hidden state `h`
1120
+ num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
1121
+ would mean stacking two GRUs together to form a `stacked GRU`,
1122
+ with the second GRU taking in outputs of the first GRU and
1123
+ computing the final results. Default: 1
1124
+ bias: If ``False``, then the layer does not use bias weights.
1125
+ Default: ``True``
1126
+ dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
1127
+ GRU layer except the last layer, with dropout probability equal to
1128
+ :attr:`dropout`. Default: 0
1129
+ python_based: If ``True``, will use a full Python implementation of the GRU cell. Default: ``False``
1130
+
1131
+ Keyword Args:
1132
+ in_key (str or tuple of str): the input key of the module. Exclusive use
1133
+ with ``in_keys``. If provided, the recurrent keys are assumed to be
1134
+ ["recurrent_state"] and the ``in_key`` will be
1135
+ appended before this.
1136
+ in_keys (list of str): a pair of strings corresponding to the input value and recurrent entry.
1137
+ Exclusive with ``in_key``.
1138
+ out_key (str or tuple of str): the output key of the module. Exclusive use
1139
+ with ``out_keys``. If provided, the recurrent keys are assumed to be
1140
+ [("recurrent_state")] and the ``out_key`` will be
1141
+ appended before these.
1142
+ out_keys (list of str): a pair of strings corresponding to the output value,
1143
+ first and second hidden key.
1144
+
1145
+ .. note::
1146
+ For a better integration with TorchRL's environments, the best naming
1147
+ for the output hidden key is ``("next", <custom_key>)``, such
1148
+ that the hidden values are passed from step to step during a rollout.
1149
+
1150
+ device (torch.device or compatible): the device of the module.
1151
+ gru (torch.nn.GRU, optional): a GRU instance to be wrapped.
1152
+ Exclusive with other nn.GRU arguments.
1153
+ default_recurrent_mode (bool, optional): if provided, the recurrent mode if it hasn't been overridden
1154
+ by the :class:`~torchrl.modules.set_recurrent_mode` context manager / decorator.
1155
+ Defaults to ``False``.
1156
+
1157
+ Attributes:
1158
+ recurrent_mode: Returns the recurrent mode of the module.
1159
+
1160
+ Methods:
1161
+ set_recurrent_mode: controls whether the module should be executed in
1162
+ recurrent mode.
1163
+ make_tensordict_primer: creates the TensorDictPrimer transforms for the environment to be aware of the
1164
+ recurrent states of the RNN.
1165
+
1166
+ .. note:: This module relies on specific ``recurrent_state`` keys being present in the input
1167
+ TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
1168
+ add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.GRUModule.make_tensordict_primer`.
1169
+ If this class is a submodule in a larger module, the method :func:`~torchrl.modules.utils.get_primers_from_module` can be called
1170
+ on the parent module to automatically generate the primer transforms required for all submodules, including this one.
1171
+
1172
+ Examples:
1173
+ >>> from torchrl.envs import TransformedEnv, InitTracker
1174
+ >>> from torchrl.envs import GymEnv
1175
+ >>> from torchrl.modules import MLP
1176
+ >>> from torch import nn
1177
+ >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
1178
+ >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
1179
+ >>> gru_module = GRUModule(
1180
+ ... input_size=env.observation_spec["observation"].shape[-1],
1181
+ ... hidden_size=64,
1182
+ ... in_keys=["observation", "rs"],
1183
+ ... out_keys=["intermediate", ("next", "rs")])
1184
+ >>> mlp = MLP(num_cells=[64], out_features=1)
1185
+ >>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
1186
+ >>> policy(env.reset())
1187
+ TensorDict(
1188
+ fields={
1189
+ action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
1190
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
1191
+ intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
1192
+ is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
1193
+ next: TensorDict(
1194
+ fields={
1195
+ rs: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
1196
+ batch_size=torch.Size([]),
1197
+ device=cpu,
1198
+ is_shared=False),
1199
+ observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
1200
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
1201
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
1202
+ batch_size=torch.Size([]),
1203
+ device=cpu,
1204
+ is_shared=False)
1205
        >>> policy_training = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
        >>> traj_td = env.rollout(3) # some random temporal data
        >>> with set_recurrent_mode(True):
        ...     traj_td = policy_training(traj_td)
1209
+ >>> print(traj_td)
1210
+ TensorDict(
1211
+ fields={
1212
+ action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
1213
+ done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1214
+ intermediate: Tensor(shape=torch.Size([3, 64]), device=cpu, dtype=torch.float32, is_shared=False),
1215
+ is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1216
+ next: TensorDict(
1217
+ fields={
1218
+ done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1219
+ is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1220
+ observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
1221
+ reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
1222
+ rs: Tensor(shape=torch.Size([3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
1223
+ terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1224
+ truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
1225
+ batch_size=torch.Size([3]),
1226
+ device=cpu,
1227
+ is_shared=False),
1228
+ observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
1229
+ terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1230
+ truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
1231
+ batch_size=torch.Size([3]),
1232
+ device=cpu,
1233
+ is_shared=False)
1234
+
1235
+ """
1236
+
1237
    # Default tensordict keys for the GRU hidden state: it is read from
    # "recurrent_state" and the updated state is written to
    # ("next", "recurrent_state") so that env.step_mdp can carry it over.
    DEFAULT_IN_KEYS = ["recurrent_state"]
    DEFAULT_OUT_KEYS = [("next", "recurrent_state")]
+
1240
+ def __init__(
1241
+ self,
1242
+ input_size: int | None = None,
1243
+ hidden_size: int | None = None,
1244
+ num_layers: int = 1,
1245
+ bias: bool = True,
1246
+ batch_first=True,
1247
+ dropout=0,
1248
+ bidirectional=False,
1249
+ python_based=False,
1250
+ *,
1251
+ in_key=None,
1252
+ in_keys=None,
1253
+ out_key=None,
1254
+ out_keys=None,
1255
+ device=None,
1256
+ gru=None,
1257
+ default_recurrent_mode: bool | None = None,
1258
+ ):
1259
+ super().__init__()
1260
+ if gru is not None:
1261
+ if not gru.batch_first:
1262
+ raise ValueError("The input gru must have batch_first=True.")
1263
+ if gru.bidirectional:
1264
+ raise ValueError("The input gru cannot be bidirectional.")
1265
+ if input_size is not None or hidden_size is not None:
1266
+ raise ValueError(
1267
+ "An GRU instance cannot be passed along with class argument."
1268
+ )
1269
+ else:
1270
+ if not batch_first:
1271
+ raise ValueError("The input gru must have batch_first=True.")
1272
+ if bidirectional:
1273
+ raise ValueError("The input gru cannot be bidirectional.")
1274
+
1275
+ if python_based:
1276
+ gru = GRU(
1277
+ input_size=input_size,
1278
+ hidden_size=hidden_size,
1279
+ num_layers=num_layers,
1280
+ bias=bias,
1281
+ dropout=dropout,
1282
+ device=device,
1283
+ batch_first=True,
1284
+ bidirectional=False,
1285
+ )
1286
+ else:
1287
+ gru = nn.GRU(
1288
+ input_size=input_size,
1289
+ hidden_size=hidden_size,
1290
+ num_layers=num_layers,
1291
+ bias=bias,
1292
+ dropout=dropout,
1293
+ device=device,
1294
+ batch_first=True,
1295
+ bidirectional=False,
1296
+ )
1297
+ if not ((in_key is None) ^ (in_keys is None)):
1298
+ raise ValueError(
1299
+ f"Either in_keys or in_key must be specified but not both or none. Got {in_keys} and {in_key} respectively."
1300
+ )
1301
+ elif in_key:
1302
+ in_keys = [in_key, *self.DEFAULT_IN_KEYS]
1303
+
1304
+ if not ((out_key is None) ^ (out_keys is None)):
1305
+ raise ValueError(
1306
+ f"Either out_keys or out_key must be specified but not both or none. Got {out_keys} and {out_key} respectively."
1307
+ )
1308
+ elif out_key:
1309
+ out_keys = [out_key, *self.DEFAULT_OUT_KEYS]
1310
+
1311
+ in_keys = unravel_key_list(in_keys)
1312
+ out_keys = unravel_key_list(out_keys)
1313
+ if not isinstance(in_keys, (tuple, list)) or (
1314
+ len(in_keys) != 2 and not (len(in_keys) == 3 and in_keys[-1] == "is_init")
1315
+ ):
1316
+ raise ValueError(
1317
+ f"GRUModule expects 3 inputs: a value, and two hidden states (and potentially an 'is_init' marker). Got in_keys {in_keys} instead."
1318
+ )
1319
+ if not isinstance(out_keys, (tuple, list)) or len(out_keys) != 2:
1320
+ raise ValueError(
1321
+ f"GRUModule expects 3 outputs: a value, and two hidden states. Got out_keys {out_keys} instead."
1322
+ )
1323
+ self.gru = gru
1324
+ if "is_init" not in in_keys:
1325
+ in_keys = in_keys + ["is_init"]
1326
+ self.in_keys = in_keys
1327
+ self.out_keys = out_keys
1328
+ self._recurrent_mode = default_recurrent_mode
1329
+
1330
+ def make_python_based(self) -> GRUModule:
1331
+ """Transforms the GRU layer in its python-based version.
1332
+
1333
+ Returns:
1334
+ self
1335
+
1336
+ """
1337
+ if isinstance(self.gru, GRU):
1338
+ return self
1339
+ gru = GRU(
1340
+ input_size=self.gru.input_size,
1341
+ hidden_size=self.gru.hidden_size,
1342
+ num_layers=self.gru.num_layers,
1343
+ bias=self.gru.bias,
1344
+ dropout=self.gru.dropout,
1345
+ device="meta",
1346
+ batch_first=self.gru.batch_first,
1347
+ bidirectional=self.gru.bidirectional,
1348
+ )
1349
+ from tensordict import from_module
1350
+
1351
+ from_module(self.gru).to_module(gru)
1352
+ self.gru = gru
1353
+ return self
1354
+
1355
+ def make_cudnn_based(self) -> GRUModule:
1356
+ """Transforms the GRU layer in its CuDNN-based version.
1357
+
1358
+ Returns:
1359
+ self
1360
+
1361
+ """
1362
+ if isinstance(self.gru, nn.GRU):
1363
+ return self
1364
+ gru = nn.GRU(
1365
+ input_size=self.gru.input_size,
1366
+ hidden_size=self.gru.hidden_size,
1367
+ num_layers=self.gru.num_layers,
1368
+ bias=self.gru.bias,
1369
+ dropout=self.gru.dropout,
1370
+ device="meta",
1371
+ batch_first=self.gru.batch_first,
1372
+ bidirectional=self.gru.bidirectional,
1373
+ )
1374
+ from tensordict import from_module
1375
+
1376
+ from_module(self.gru).to_module(gru)
1377
+ self.gru = gru
1378
+ return self
1379
+
1380
    def make_tensordict_primer(self):
        """Makes a tensordict primer for the environment.

        A :class:`~torchrl.envs.TensorDictPrimer` object will ensure that the policy is aware of the supplementary
        inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across
        processes and dealt with properly.

        Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviors, for instance
        in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root
        tensordict, which the :meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states
        are not registered within the environment specs.

        When using batched environments such as :class:`~torchrl.envs.ParallelEnv`, the transform can be used at the
        single env instance level (i.e., a batch of transformed envs with tensordict primers set within) or at the
        batched env instance level (i.e., a transformed batch of regular envs).

        See :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given
        module.

        Examples:
            >>> from torchrl.collectors import Collector
            >>> from torchrl.envs import TransformedEnv, InitTracker
            >>> from torchrl.envs import GymEnv
            >>> from torchrl.modules import GRUModule, MLP
            >>> from torch import nn
            >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
            >>>
            >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
            >>> gru_module = GRUModule(
            ...     input_size=env.observation_spec["observation"].shape[-1],
            ...     hidden_size=64,
            ...     in_keys=["observation", "rs"],
            ...     out_keys=["intermediate", ("next", "rs")])
            >>> mlp = MLP(num_cells=[64], out_features=1)
            >>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
            >>> policy(env.reset())
            >>> env = env.append_transform(gru_module.make_tensordict_primer())
            >>> data_collector = Collector(
            ...     env,
            ...     policy,
            ...     frames_per_batch=10
            ... )
            >>> for data in data_collector:
            ...     print(data)
            ...     break

        """
        from torchrl.envs import TensorDictPrimer

        def make_tuple(key):
            # Normalize a (possibly nested) key to tuple form for comparison.
            if isinstance(key, tuple):
                return key
            return (key,)

        out_key1 = make_tuple(self.out_keys[1])
        in_key1 = make_tuple(self.in_keys[1])
        # The state written at ("next", <key>) must be copied back to <key> at
        # each step, so the names must match exactly.
        if out_key1 != ("next", *in_key1):
            raise RuntimeError(
                "make_tensordict_primer is supposed to work with in_keys/out_keys that "
                "have compatible names, ie. the out_keys should be named after ('next', <in_key>). Got "
                f"in_keys={self.in_keys} and out_keys={self.out_keys} instead."
            )
        return TensorDictPrimer(
            {
                in_key1: Unbounded(shape=(self.gru.num_layers, self.gru.hidden_size)),
            },
            expand_specs=True,
        )
+
1449
+ @property
1450
+ def recurrent_mode(self):
1451
+ rm = recurrent_mode()
1452
+ if rm is None:
1453
+ return bool(self._recurrent_mode)
1454
+ return rm
1455
+
1456
+ @recurrent_mode.setter
1457
+ def recurrent_mode(self, value):
1458
+ raise RuntimeError(
1459
+ "recurrent_mode cannot be changed in-place. Please use the set_recurrent_mode context manager."
1460
+ )
1461
+
1462
+ @property
1463
+ def temporal_mode(self):
1464
+ raise RuntimeError(
1465
+ "temporal_mode is deprecated, use recurrent_mode instead.",
1466
+ )
1467
+
1468
+ def set_recurrent_mode(self, mode: bool = True):
1469
+ raise RuntimeError(
1470
+ "The gru.set_recurrent_mode() API has been removed in v0.8. "
1471
+ "To set the recurent mode, use the :class:`~torchrl.modules.set_recurrent_mode` context manager or "
1472
+ "the `default_recurrent_mode` keyword argument in the constructor.",
1473
+ )
1474
+
1475
    @dispatch
    @set_lazy_legacy(False)
    def forward(self, tensordict: TensorDictBase):
        """Runs the GRU over the input tensordict and writes the outputs in place.

        In recurrent mode the tensordict is reshaped to ``[batch, time]``;
        if trajectories are stacked back-to-back (detected via ``"is_init"``
        entries beyond the first step), they are split and zero-padded so each
        row is an independent trajectory, then stitched back at the end.
        Outside recurrent mode every element is a single step and hidden
        states are zeroed wherever ``"is_init"`` is True.
        """
        from torchrl.objectives.value.functional import (
            _inv_pad_sequence,
            _split_and_pad_sequence,
        )

        # we want to get an error if the value input is missing, but not the hidden states
        defaults = [NO_DEFAULT, None]
        shape = tensordict.shape
        tensordict_shaped = tensordict
        if self.recurrent_mode:
            # if less than 2 dims, unsqueeze
            ndim = tensordict_shaped.get(self.in_keys[0]).ndim
            while ndim < 3:
                tensordict_shaped = tensordict_shaped.unsqueeze(0)
                ndim += 1
            if ndim > 3:
                dims_to_flatten = ndim - 3
                # we assume that the tensordict can be flattened like this
                nelts = prod(tensordict_shaped.shape[: dims_to_flatten + 1])
                tensordict_shaped = tensordict_shaped.apply(
                    lambda value: value.flatten(0, dims_to_flatten),
                    batch_size=[nelts, tensordict_shaped.shape[-1]],
                )
        else:
            # Sequential mode: every element becomes its own length-1 sequence.
            tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1)

        is_init = tensordict_shaped["is_init"].squeeze(-1)
        splits = None
        if self.recurrent_mode and is_init[..., 1:].any():
            from torchrl.objectives.value.utils import _get_num_per_traj_init

            # if we have consecutive trajectories, things get a little more complicated
            # we have a tensordict of shape [B, T]
            # we will split / pad things such that we get a tensordict of shape
            # [N, T'] where T' <= T and N >= B is the new batch size, such that
            # each index of N is an independent trajectory. We'll need to keep
            # track of the indices though, as we want to put things back together in the end.
            splits = _get_num_per_traj_init(is_init)
            tensordict_shaped_shape = tensordict_shaped.shape
            tensordict_shaped = _split_and_pad_sequence(
                tensordict_shaped.select(*self.in_keys, strict=False), splits
            )
            is_init = tensordict_shaped["is_init"].squeeze(-1)

        value, hidden = (
            tensordict_shaped.get(key, default)
            for key, default in zip(self.in_keys, defaults)
        )
        batch, steps = value.shape[:2]
        device = value.device
        dtype = value.dtype
        # packed sequences do not help to get the accurate last hidden values
        # if splits is not None:
        #     value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True)
        if not self.recurrent_mode and is_init.any() and hidden is not None:
            # Reset the hidden state on episode boundaries.
            is_init_expand = expand_as_right(is_init, hidden)
            hidden = torch.where(is_init_expand, 0, hidden)
        val, hidden = self._gru(value, batch, steps, device, dtype, hidden)
        tensordict_shaped.set(self.out_keys[0], val)
        tensordict_shaped.set(self.out_keys[1], hidden)
        if splits is not None:
            # let's recover our original shape
            tensordict_shaped = _inv_pad_sequence(tensordict_shaped, splits).reshape(
                tensordict_shaped_shape
            )

        if shape != tensordict_shaped.shape or tensordict_shaped is not tensordict:
            tensordict.update(tensordict_shaped.reshape(shape))
        return tensordict
+
1548
    def _gru(
        self,
        input: torch.Tensor,
        batch,
        steps,
        device,
        dtype,
        hidden_in: torch.Tensor | None = None,
        # Annotation fixed: this method returns 2 tensors (output, hidden), not 3.
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Runs the wrapped GRU over a ``[batch, steps, *]`` input.

        Returns ``(output, hidden)`` where ``hidden`` is zero-padded along the
        time dimension so that only the last step holds the final recurrent
        state, keeping shapes consistent with the rest of the tensordict.
        """

        if not self.recurrent_mode and steps != 1:
            raise ValueError("Expected a single step")

        if hidden_in is None:
            # No hidden state provided: start from zeros.
            shape = (batch, steps)
            hidden_in = torch.zeros(
                *shape,
                self.gru.num_layers,
                self.gru.hidden_size,
                device=device,
                dtype=dtype,
            )

        # we only need the first hidden state
        _hidden_in = hidden_in[:, 0]
        hidden = _hidden_in.transpose(-3, -2).contiguous()

        y, hidden = self.gru(input, hidden)
        # dim 0 in hidden is num_layers, but that will conflict with tensordict
        hidden = hidden.transpose(0, 1)

        # we pad the hidden states with zero to make tensordict happy
        hidden = torch.stack(
            [torch.zeros_like(hidden) for _ in range(steps - 1)] + [hidden],
            1,
        )
        out = [y, hidden]
        return tuple(out)
+
1587
+
1588
# Recurrent mode manager
# Process-wide holder for the recurrent-mode override installed by the
# ``set_recurrent_mode`` context manager.
recurrent_mode_state_manager = _ContextManager()


def recurrent_mode() -> bool | None:
    """Returns the currently active recurrent mode, or ``None`` if no override is set."""
    return recurrent_mode_state_manager.get_mode()
+
1596
+
1597
class set_recurrent_mode(_DecoratorContextManager):
    """Context manager for setting RNNs recurrent mode.

    Args:
        mode (bool, "recurrent" or "sequential"): the mode to activate while the
            context is entered. ``"recurrent"`` maps to ``True`` and
            ``"sequential"`` maps to ``False``. With recurrent mode enabled, an
            RNN module assumes that the data comes in time batches; otherwise
            each data element in a tensordict is assumed independent of the
            others. Defaults to ``True``. A value of ``None`` restores the
            per-module default (see the ``default_recurrent_mode`` argument of
            :class:`~torchrl.modules.LSTMModule` and
            :class:`~torchrl.modules.GRUModule`).

    .. seealso:: :func:`~torchrl.modules.recurrent_mode`.

    .. note:: All of TorchRL methods are decorated with ``set_recurrent_mode(True)`` by default.

    """

    def __init__(
        self, mode: bool | typing.Literal["recurrent", "sequential"] | None = True
    ) -> None:
        super().__init__()
        if isinstance(mode, str):
            # Map the string aliases onto their boolean equivalents.
            normalized = mode.lower()
            if normalized == "recurrent":
                mode = True
            elif normalized == "sequential":
                mode = False
            else:
                raise ValueError(
                    f"Unsupported recurrent mode. Must be a bool, or one of {('recurrent', 'sequential')}"
                )
        self.mode = mode

    def clone(self) -> set_recurrent_mode:
        # Re-instantiate with the same constructor argument (required because
        # this subclass takes __init__ parameters).
        return type(self)(self.mode)

    def __enter__(self) -> None:
        # Remember the previous mode so __exit__ can restore it.
        self.prev = recurrent_mode_state_manager.get_mode()
        recurrent_mode_state_manager.set_mode(self.mode)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        recurrent_mode_state_manager.set_mode(self.prev)