torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,1712 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import dataclasses
8
+ from collections.abc import Callable, Sequence
9
+ from copy import deepcopy
10
+ from numbers import Number
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from torchrl._utils import prod
16
+ from torchrl.data.utils import DEVICE_TYPING
17
+ from torchrl.modules.models.decision_transformer import DecisionTransformer
18
+ from torchrl.modules.models.utils import (
19
+ _find_depth,
20
+ create_on_device,
21
+ LazyMapping,
22
+ SquashDims,
23
+ Squeeze2dLayer,
24
+ SqueezeLayer,
25
+ )
26
+ from torchrl.modules.tensordict_module.common import DistributionalDQNnet # noqa
27
+
28
+
29
+ class MLP(nn.Sequential):
30
+ """A multi-layer perceptron.
31
+
32
+ If MLP receives more than one input, it concatenates them all along the last dimension before passing the
33
+ resulting tensor through the network. This is aimed at allowing for a seamless interface with calls of the type of
34
+
35
+ >>> model(state, action) # compute state-action value
36
+
37
+ In the future, this feature may be moved to the ProbabilisticTDModule, though it would require it to handle
38
+ different cases (vectors, images, ...)
39
+
40
+ Args:
41
+ in_features (int, optional): number of input features;
42
+ out_features (int, torch.Size or equivalent): number of output
43
+ features. If iterable of integers, the output is reshaped to the
44
+ desired shape.
45
+ depth (int, optional): depth of the network. A depth of 0 will produce
46
+ a single linear layer network with the desired input and output size.
47
+ A length of 1 will create 2 linear layers etc. If no depth is indicated,
48
+ the depth information should be contained in the ``num_cells``
49
+ argument (see below). If ``num_cells`` is an iterable and depth is
50
+ indicated, both should match: ``len(num_cells)`` must be equal to
51
+ ``depth``.
52
+ Defaults to ``0`` (no depth - the network contains a single linear layer).
53
+ num_cells (int or sequence of int, optional): number of cells of every
54
+ layer in between the input and output. If an integer is provided,
55
+ every layer will have the same number of cells. If an iterable is provided,
56
+ the linear layers ``out_features`` will match the content of
57
+ ``num_cells``. Defaults to ``32``;
58
+ activation_class (Type[nn.Module] or callable, optional): activation
59
+ class or constructor to be used.
60
+ Defaults to :class:`~torch.nn.Tanh`.
61
+ activation_kwargs (dict or list of dicts, optional): kwargs to be used
62
+ with the activation class. Also accepts a list of kwargs of length
63
+ ``depth + int(activate_last_layer)``.
64
+ norm_class (Type or callable, optional): normalization class or
65
+ constructor, if any.
66
+ norm_kwargs (dict or list of dicts, optional): kwargs to be used with
67
+ the normalization layers. Also accepts a list of kwargs of length
68
+ ``depth + int(activate_last_layer)``.
69
+ dropout (:obj:`float`, optional): dropout probability. Defaults to ``None`` (no
70
+ dropout);
71
+ bias_last_layer (bool): if ``True``, the last Linear layer will have a bias parameter.
72
+ default: True;
73
+ single_bias_last_layer (bool): if ``True``, the last dimension of the bias of the last layer will be a singleton
74
+ dimension.
75
+ default: True;
76
+ layer_class (Type[nn.Module] or callable, optional): class to be used
77
+ for the linear layers;
78
+ layer_kwargs (dict or list of dicts, optional): kwargs for the linear
79
+ layers. Also accepts a list of kwargs of length ``depth + 1``.
80
+ activate_last_layer (bool): whether the MLP output should be activated. This is useful when the MLP output
81
+ is used as the input for another module.
82
+ default: False.
83
+ device (torch.device, optional): device to create the module on.
84
+
85
+ Examples:
86
+ >>> # All of the following examples provide valid, working MLPs
87
+ >>> mlp = MLP(in_features=3, out_features=6, depth=0) # MLP consisting of a single 3 x 6 linear layer
88
+ >>> print(mlp)
89
+ MLP(
90
+ (0): Linear(in_features=3, out_features=6, bias=True)
91
+ )
92
+ >>> mlp = MLP(in_features=3, out_features=6, depth=4, num_cells=32)
93
+ >>> print(mlp)
94
+ MLP(
95
+ (0): Linear(in_features=3, out_features=32, bias=True)
96
+ (1): Tanh()
97
+ (2): Linear(in_features=32, out_features=32, bias=True)
98
+ (3): Tanh()
99
+ (4): Linear(in_features=32, out_features=32, bias=True)
100
+ (5): Tanh()
101
+ (6): Linear(in_features=32, out_features=32, bias=True)
102
+ (7): Tanh()
103
+ (8): Linear(in_features=32, out_features=6, bias=True)
104
+ )
105
+ >>> mlp = MLP(out_features=6, depth=4, num_cells=32) # LazyLinear for the first layer
106
+ >>> print(mlp)
107
+ MLP(
108
+ (0): LazyLinear(in_features=0, out_features=32, bias=True)
109
+ (1): Tanh()
110
+ (2): Linear(in_features=32, out_features=32, bias=True)
111
+ (3): Tanh()
112
+ (4): Linear(in_features=32, out_features=32, bias=True)
113
+ (5): Tanh()
114
+ (6): Linear(in_features=32, out_features=32, bias=True)
115
+ (7): Tanh()
116
+ (8): Linear(in_features=32, out_features=6, bias=True)
117
+ )
118
+ >>> mlp = MLP(out_features=6, num_cells=[32, 33, 34, 35]) # defines the depth by the num_cells arg
119
+ >>> print(mlp)
120
+ MLP(
121
+ (0): LazyLinear(in_features=0, out_features=32, bias=True)
122
+ (1): Tanh()
123
+ (2): Linear(in_features=32, out_features=33, bias=True)
124
+ (3): Tanh()
125
+ (4): Linear(in_features=33, out_features=34, bias=True)
126
+ (5): Tanh()
127
+ (6): Linear(in_features=34, out_features=35, bias=True)
128
+ (7): Tanh()
129
+ (8): Linear(in_features=35, out_features=6, bias=True)
130
+ )
131
+ >>> mlp = MLP(out_features=(6, 7), num_cells=[32, 33, 34, 35]) # returns a view of the output tensor with shape [*, 6, 7]
132
+ >>> print(mlp)
133
+ MLP(
134
+ (0): LazyLinear(in_features=0, out_features=32, bias=True)
135
+ (1): Tanh()
136
+ (2): Linear(in_features=32, out_features=33, bias=True)
137
+ (3): Tanh()
138
+ (4): Linear(in_features=33, out_features=34, bias=True)
139
+ (5): Tanh()
140
+ (6): Linear(in_features=34, out_features=35, bias=True)
141
+ (7): Tanh()
142
+ (8): Linear(in_features=35, out_features=42, bias=True)
143
+ )
144
+ >>> from torchrl.modules import NoisyLinear
145
+ >>> mlp = MLP(out_features=(6, 7), num_cells=[32, 33, 34, 35], layer_class=NoisyLinear) # uses NoisyLinear layers
146
+ >>> print(mlp)
147
+ MLP(
148
+ (0): NoisyLazyLinear(in_features=0, out_features=32, bias=False)
149
+ (1): Tanh()
150
+ (2): NoisyLinear(in_features=32, out_features=33, bias=True)
151
+ (3): Tanh()
152
+ (4): NoisyLinear(in_features=33, out_features=34, bias=True)
153
+ (5): Tanh()
154
+ (6): NoisyLinear(in_features=34, out_features=35, bias=True)
155
+ (7): Tanh()
156
+ (8): NoisyLinear(in_features=35, out_features=42, bias=True)
157
+ )
158
+
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ in_features: int | None = None,
164
+ out_features: int | torch.Size | None = None,
165
+ depth: int | None = None,
166
+ num_cells: Sequence[int] | int | None = None,
167
+ activation_class: type[nn.Module] | Callable = nn.Tanh,
168
+ activation_kwargs: dict | list[dict] | None = None,
169
+ norm_class: type[nn.Module] | Callable | None = None,
170
+ norm_kwargs: dict | list[dict] | None = None,
171
+ dropout: float | None = None,
172
+ bias_last_layer: bool = True,
173
+ single_bias_last_layer: bool = False,
174
+ layer_class: type[nn.Module] | Callable = nn.Linear,
175
+ layer_kwargs: dict | None = None,
176
+ activate_last_layer: bool = False,
177
+ device: DEVICE_TYPING | None = None,
178
+ ):
179
+ if out_features is None:
180
+ raise ValueError("out_features must be specified for MLP.")
181
+
182
+ if num_cells is None:
183
+ default_num_cells = 32
184
+ if depth is None:
185
+ num_cells = []
186
+ depth = 0
187
+ else:
188
+ num_cells = [default_num_cells] * depth
189
+
190
+ self.in_features = in_features
191
+
192
+ _out_features_num = out_features
193
+ if not isinstance(out_features, Number):
194
+ _out_features_num = prod(out_features)
195
+ self.out_features = out_features
196
+ self._reshape_out = not isinstance(
197
+ self.out_features, (int, torch.SymInt, Number)
198
+ )
199
+ self._out_features_num = _out_features_num
200
+ self.activation_class = activation_class
201
+ self.norm_class = norm_class
202
+ self.dropout = dropout
203
+ self.bias_last_layer = bias_last_layer
204
+ self.single_bias_last_layer = single_bias_last_layer
205
+ self.layer_class = layer_class
206
+
207
+ self.activation_kwargs = activation_kwargs
208
+ self.norm_kwargs = norm_kwargs
209
+ self.layer_kwargs = layer_kwargs
210
+
211
+ self.activate_last_layer = activate_last_layer
212
+ if single_bias_last_layer:
213
+ raise NotImplementedError
214
+
215
+ if not (isinstance(num_cells, Sequence) or depth is not None):
216
+ raise RuntimeError(
217
+ "If num_cells is provided as an integer, \
218
+ depth must be provided too."
219
+ )
220
+ self.num_cells = (
221
+ list(num_cells) if isinstance(num_cells, Sequence) else [num_cells] * depth
222
+ )
223
+ self.depth = depth if depth is not None else len(self.num_cells)
224
+ if not (len(self.num_cells) == depth or depth is None):
225
+ raise RuntimeError(
226
+ "depth and num_cells length conflict, \
227
+ consider matching or specifying a constant num_cells argument together with a a desired depth"
228
+ )
229
+
230
+ self._activation_kwargs_iter = _iter_maybe_over_single(
231
+ activation_kwargs, n=self.depth + self.activate_last_layer
232
+ )
233
+ self._norm_kwargs_iter = _iter_maybe_over_single(
234
+ norm_kwargs, n=self.depth + self.activate_last_layer
235
+ )
236
+ self._layer_kwargs_iter = _iter_maybe_over_single(
237
+ layer_kwargs, n=self.depth + 1
238
+ )
239
+ layers = self._make_net(device)
240
+ layers = [
241
+ layer if isinstance(layer, nn.Module) else _ExecutableLayer(layer)
242
+ for layer in layers
243
+ ]
244
+ super().__init__(*layers)
245
+
246
def _make_net(self, device: DEVICE_TYPING | None) -> list[nn.Module]:
    """Build the ordered list of linear / dropout / norm / activation modules of the MLP."""
    modules = []
    fan_ins = [self.in_features] + self.num_cells
    fan_outs = self.num_cells + [self._out_features_num]
    for idx, (fan_in, fan_out) in enumerate(zip(fan_ins, fan_outs)):
        kwargs_for_layer = next(self._layer_kwargs_iter)
        # The final linear layer honours ``bias_last_layer``; every other one
        # defaults to a bias unless the user-provided kwargs override it.
        default_bias = self.bias_last_layer if idx == self.depth else True
        use_bias = kwargs_for_layer.pop("bias", default_bias)
        if fan_in is None:
            # Input size unknown: fall back to the lazy counterpart of the layer class.
            try:
                lazy_cls = LazyMapping[self.layer_class]
            except KeyError:
                raise KeyError(
                    f"The lazy version of {self.layer_class.__name__} is not implemented yet. "
                    "Consider providing the input feature dimensions explicitly when creating an MLP module"
                )
            modules.append(
                create_on_device(
                    lazy_cls, device, fan_out, bias=use_bias, **kwargs_for_layer
                )
            )
        else:
            modules.append(
                create_on_device(
                    self.layer_class,
                    device,
                    fan_in,
                    fan_out,
                    bias=use_bias,
                    **kwargs_for_layer,
                )
            )

        is_hidden = idx < self.depth
        if is_hidden or self.activate_last_layer:
            kwargs_for_norm = next(self._norm_kwargs_iter)
            kwargs_for_activation = next(self._activation_kwargs_iter)
            if self.dropout is not None:
                modules.append(create_on_device(nn.Dropout, device, p=self.dropout))
            if self.norm_class is not None:
                modules.append(
                    create_on_device(self.norm_class, device, **kwargs_for_norm)
                )
            modules.append(
                create_on_device(self.activation_class, device, **kwargs_for_activation)
            )

    return modules
294
+
295
def forward(self, *inputs: tuple[torch.Tensor]) -> torch.Tensor:
    """Run the MLP; multiple tensor inputs are concatenated along the last dim."""
    if len(inputs) > 1:
        # Fuse all inputs into a single feature tensor before the first layer.
        inputs = (torch.cat(list(inputs), dim=-1),)

    result = super().forward(*inputs)
    if self._reshape_out:
        # Expand the flat feature dimension back to the requested output shape.
        result = result.view(*result.shape[:-1], *self.out_features)
    return result
303
+
304
+
305
class ConvNet(nn.Sequential):
    """A 2D-convolutional neural network.

    The network stacks ``depth`` convolutional layers, each followed by an
    activation (and optionally a normalization layer), then by default
    flattens the trailing feature dimensions through an aggregator module.

    Args:
        in_features (int, optional): number of input channels. If ``None``, a
            :class:`~torch.nn.LazyConv2d` module is used for the first layer.
        depth (int, optional): depth of the network. If no depth is indicated,
            the depth information should be contained in the ``num_cells``
            argument (see below). If ``num_cells`` is an iterable and ``depth``
            is indicated, both should match: ``len(num_cells)`` must be equal
            to the ``depth``.
        num_cells (int or Sequence of int, optional): number of cells of
            every layer in between the input and output. If an integer is
            provided, every layer will have the same number of cells. If an
            iterable is provided, the layers' ``out_channels`` will match the
            content of ``num_cells``. Defaults to ``[32, 32, 32]``, or
            ``[32] * depth`` if ``depth`` is not ``None``.
        kernel_sizes (int, sequence of int, optional): Kernel size(s) of the
            conv network. If iterable, the length must match the depth,
            defined by the ``num_cells`` or depth arguments.
            Defaults to ``3``.
        strides (int or sequence of int, optional): Stride(s) of the conv
            network. If iterable, the length must match the depth. Defaults
            to ``1``.
        paddings (int or sequence of int, optional): Padding(s) of the conv
            network. If iterable, the length must match the depth. Defaults
            to ``0``.
        activation_class (Type[nn.Module] or callable, optional): activation
            class or constructor to be used. Defaults to
            :class:`~torch.nn.ELU`.
        activation_kwargs (dict or list of dicts, optional): kwargs to be used
            with the activation class. A list of kwargs of length ``depth``
            can also be passed, with one element per layer.
        norm_class (Type or callable, optional): normalization class or
            constructor, if any.
        norm_kwargs (dict or list of dicts, optional): kwargs to be used with
            the normalization layers. A list of kwargs of length ``depth`` can
            also be passed, with one element per layer.
        bias_last_layer (bool): if ``True``, the last conv layer will have a
            bias parameter. Defaults to ``True``.
        aggregator_class (Type[nn.Module] or callable): aggregator class or
            constructor to use at the end of the chain.
            Defaults to :class:`torchrl.modules.utils.models.SquashDims`.
        aggregator_kwargs (dict, optional): kwargs for the
            ``aggregator_class``.
        squeeze_output (bool): whether the output should be squeezed of its
            singleton dimensions. Defaults to ``False``.
        device (torch.device, optional): device to create the module on.

    Examples:
        >>> cnet = ConvNet(in_features=3, num_cells=[32, 33, 34, 35])  # defines the depth by the num_cells arg
        >>> print(cnet)
        ConvNet(
          (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
          (1): ELU(alpha=1.0)
          (2): Conv2d(32, 33, kernel_size=(3, 3), stride=(1, 1))
          (3): ELU(alpha=1.0)
          (4): Conv2d(33, 34, kernel_size=(3, 3), stride=(1, 1))
          (5): ELU(alpha=1.0)
          (6): Conv2d(34, 35, kernel_size=(3, 3), stride=(1, 1))
          (7): ELU(alpha=1.0)
          (8): SquashDims()
        )

    """

    def __init__(
        self,
        in_features: int | None = None,
        depth: int | None = None,
        num_cells: Sequence[int] | int = None,
        kernel_sizes: Sequence[int] | int = 3,
        strides: Sequence[int] | int = 1,
        paddings: Sequence[int] | int = 0,
        activation_class: type[nn.Module] | Callable = nn.ELU,
        activation_kwargs: dict | list[dict] | None = None,
        norm_class: type[nn.Module] | Callable | None = None,
        norm_kwargs: dict | list[dict] | None = None,
        bias_last_layer: bool = True,
        aggregator_class: type[nn.Module] | Callable | None = SquashDims,
        aggregator_kwargs: dict | None = None,
        squeeze_output: bool = False,
        device: DEVICE_TYPING | None = None,
    ):
        if num_cells is None:
            # FIX: honour ``depth`` when only a depth is given (previously the
            # fixed 3-element default conflicted with any ``depth != 3``);
            # this mirrors Conv3dNet's behaviour.
            num_cells = [32, 32, 32] if depth is None else [32] * depth

        self.in_features = in_features
        self.activation_class = activation_class
        self.norm_class = norm_class
        self.bias_last_layer = bias_last_layer
        self.aggregator_class = aggregator_class
        self.aggregator_kwargs = (
            aggregator_kwargs if aggregator_kwargs is not None else {"ndims_in": 3}
        )
        self.squeeze_output = squeeze_output

        self.activation_kwargs = (
            activation_kwargs if activation_kwargs is not None else {}
        )
        self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {}

        depth = _find_depth(depth, num_cells, kernel_sizes, strides, paddings)
        self.depth = depth
        if depth == 0:
            raise ValueError("Null depth is not permitted with ConvNet.")

        for _field, _value in zip(
            ["num_cells", "kernel_sizes", "strides", "paddings"],
            [num_cells, kernel_sizes, strides, paddings],
        ):
            if not isinstance(_value, Sequence):
                # FIX: validate *before* broadcasting, so a scalar with an
                # unknown depth raises a clear error instead of a TypeError
                # from ``[_value] * None``.
                if depth is None:
                    raise RuntimeError(
                        f"If {_field} is provided as an integer, "
                        "depth must be provided too."
                    )
                _value = [_value] * depth
            elif depth is not None and len(_value) != depth:
                raise RuntimeError(
                    f"depth={depth} and {_field}={len(_value)} length conflict, "
                    f"consider matching or specifying a constant {_field} argument together with a desired depth"
                )
            # Store as a list so downstream list concatenations are type-safe.
            setattr(self, _field, list(_value))

        self.out_features = self.num_cells[-1]
        # After validation, the kernel list is the single source of truth.
        self.depth = len(self.kernel_sizes)

        self._activation_kwargs_iter = _iter_maybe_over_single(
            activation_kwargs, n=self.depth
        )
        self._norm_kwargs_iter = _iter_maybe_over_single(norm_kwargs, n=self.depth)

        layers = self._make_net(device)
        layers = [
            layer if isinstance(layer, nn.Module) else _ExecutableLayer(layer)
            for layer in layers
        ]
        super().__init__(*layers)

    def _make_net(self, device: DEVICE_TYPING | None) -> list[nn.Module]:
        """Build the list of conv / activation / norm / aggregator modules."""
        layers = []
        channel_ins = [self.in_features] + list(self.num_cells[:-1])
        channel_outs = list(self.num_cells)
        for i, (_in, _out, _kernel, _stride, _padding) in enumerate(
            zip(channel_ins, channel_outs, self.kernel_sizes, self.strides, self.paddings)
        ):
            # FIX: the *last* conv layer honours ``bias_last_layer``. The
            # previous condition (``i < len(in_features) - 1``) was always
            # true, so ``bias_last_layer=False`` was silently ignored.
            _bias = True if i < self.depth - 1 else self.bias_last_layer
            if _in is not None:
                layers.append(
                    nn.Conv2d(
                        _in,
                        _out,
                        kernel_size=_kernel,
                        stride=_stride,
                        bias=_bias,
                        padding=_padding,
                        device=device,
                    )
                )
            else:
                # Input channel count unknown: defer to a lazy conv module.
                layers.append(
                    nn.LazyConv2d(
                        _out,
                        kernel_size=_kernel,
                        stride=_stride,
                        bias=_bias,
                        padding=_padding,
                        device=device,
                    )
                )

            activation_kwargs = next(self._activation_kwargs_iter)
            layers.append(
                create_on_device(self.activation_class, device, **activation_kwargs)
            )
            if self.norm_class is not None:
                norm_kwargs = next(self._norm_kwargs_iter)
                layers.append(create_on_device(self.norm_class, device, **norm_kwargs))

        if self.aggregator_class is not None:
            layers.append(
                create_on_device(
                    self.aggregator_class, device, **self.aggregator_kwargs
                )
            )

        if self.squeeze_output:
            layers.append(Squeeze2dLayer())
        return layers

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Flatten leading batch dims, run the conv stack, restore batch dims.

        The input is expected to end with (C, H, W) dimensions; any number of
        leading batch dimensions is supported.
        """
        *batch, _C, _H, _W = inputs.shape
        if len(batch) > 1:
            inputs = inputs.flatten(0, len(batch) - 1)
        out = super().forward(inputs)
        if len(batch) > 1:
            out = out.unflatten(0, batch)
        return out

    @classmethod
    def default_atari_dqn(cls, num_actions: int):
        """Returns the default DQN as presented in the seminal DQN paper.

        Args:
            num_actions (int): the action space of the atari game.

        Returns:
            nn.Sequential: the DQN convolutional trunk followed by an MLP head.
        """
        cnn = ConvNet(
            activation_class=torch.nn.ReLU,
            num_cells=[32, 64, 64],
            kernel_sizes=[8, 4, 3],
            strides=[4, 2, 1],
        )
        mlp = MLP(
            activation_class=torch.nn.ReLU,
            out_features=num_actions,
            num_cells=[512],
        )
        return nn.Sequential(cnn, mlp)
567
+
568
+
569
# Backwards-compatible alias: ``Conv2dNet`` is simply another name for ``ConvNet``.
Conv2dNet = ConvNet
570
+
571
+
572
class Conv3dNet(nn.Sequential):
    """A 3D-convolutional neural network.

    The network stacks ``depth`` 3D-convolutional layers, each followed by an
    activation (and optionally a normalization layer), then by default
    flattens the trailing feature dimensions through an aggregator module.

    Args:
        in_features (int, optional): number of input channels. A lazy
            implementation that automatically retrieves the input size will be
            used if none is provided.
        depth (int, optional): depth of the network. If no ``depth`` is
            indicated, the depth information should be contained in the
            ``num_cells`` argument (see below). If ``num_cells`` is an
            iterable and ``depth`` is indicated, both should match:
            ``len(num_cells)`` must be equal to the ``depth``.
        num_cells (int or sequence of int, optional): number of cells of every
            layer in between the input and output. If an integer is provided,
            every layer will have the same number of cells and the depth will
            be retrieved from ``depth``. If an iterable is provided, the
            layers' ``out_channels`` will match the content of ``num_cells``.
            Defaults to ``[32, 32, 32]``, or ``[32] * depth`` if ``depth`` is
            not ``None``.
        kernel_sizes (int, sequence of int, optional): Kernel size(s) of the
            conv network. If iterable, the length must match the depth,
            defined by the ``num_cells`` or depth arguments. Defaults to ``3``.
        strides (int or sequence of int): Stride(s) of the conv network.
            If iterable, the length must match the depth. Defaults to ``1``.
        paddings (int or sequence of int): Padding(s) of the conv network.
            If iterable, the length must match the depth. Defaults to ``0``.
        activation_class (Type[nn.Module] or callable): activation class or
            constructor to be used. Defaults to :class:`~torch.nn.ELU`.
        activation_kwargs (dict or list of dicts, optional): kwargs to be used
            with the activation class. A list of kwargs of length ``depth``
            with one element per layer can also be provided.
        norm_class (Type or callable, optional): normalization class, if any.
        norm_kwargs (dict or list of dicts, optional): kwargs to be used with
            the normalization layers. A list of kwargs of length ``depth``
            with one element per layer can also be provided.
        bias_last_layer (bool): if ``True``, the last conv layer will have a
            bias parameter. Defaults to ``True``.
        aggregator_class (Type[nn.Module] or callable): aggregator class or
            constructor to use at the end of the chain. Defaults to
            :class:`~torchrl.modules.models.utils.SquashDims`.
        aggregator_kwargs (dict, optional): kwargs for the ``aggregator_class``
            constructor.
        squeeze_output (bool): whether the output should be squeezed of its
            singleton dimensions. Defaults to ``False``.
        device (torch.device, optional): device to create the module on.

    Examples:
        >>> cnet = Conv3dNet(in_features=3, num_cells=[32, 33, 34, 35])  # defines the depth by the num_cells arg
        >>> print(cnet)
        Conv3dNet(
          (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (1): ELU(alpha=1.0)
          (2): Conv3d(32, 33, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (3): ELU(alpha=1.0)
          (4): Conv3d(33, 34, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (5): ELU(alpha=1.0)
          (6): Conv3d(34, 35, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (7): ELU(alpha=1.0)
          (8): SquashDims()
        )

    """

    def __init__(
        self,
        in_features: int | None = None,
        depth: int | None = None,
        num_cells: Sequence[int] | int = None,
        kernel_sizes: Sequence[int] | int = 3,
        strides: Sequence[int] | int = 1,
        paddings: Sequence[int] | int = 0,
        activation_class: type[nn.Module] | Callable = nn.ELU,
        activation_kwargs: dict | list[dict] | None = None,
        norm_class: type[nn.Module] | Callable | None = None,
        norm_kwargs: dict | list[dict] | None = None,
        bias_last_layer: bool = True,
        aggregator_class: type[nn.Module] | Callable | None = SquashDims,
        aggregator_kwargs: dict | None = None,
        squeeze_output: bool = False,
        device: DEVICE_TYPING | None = None,
    ):
        if num_cells is None:
            # Honour ``depth`` when only a depth is given.
            num_cells = [32, 32, 32] if depth is None else [32] * depth

        self.in_features = in_features
        self.activation_class = activation_class
        self.norm_class = norm_class

        self.activation_kwargs = (
            activation_kwargs if activation_kwargs is not None else {}
        )
        self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {}

        self.bias_last_layer = bias_last_layer
        self.aggregator_class = aggregator_class
        self.aggregator_kwargs = (
            aggregator_kwargs if aggregator_kwargs is not None else {"ndims_in": 4}
        )
        self.squeeze_output = squeeze_output

        depth = _find_depth(depth, num_cells, kernel_sizes, strides, paddings)
        self.depth = depth
        if depth == 0:
            raise ValueError("Null depth is not permitted with Conv3dNet.")

        for _field, _value in zip(
            ["num_cells", "kernel_sizes", "strides", "paddings"],
            [num_cells, kernel_sizes, strides, paddings],
        ):
            if not isinstance(_value, Sequence):
                # FIX: a scalar with an unknown depth previously crashed with
                # a TypeError from ``[_value] * None``; raise the same clear
                # error as ConvNet instead.
                if depth is None:
                    raise RuntimeError(
                        f"If {_field} is provided as an integer, "
                        "depth must be provided too."
                    )
                _value = [_value] * depth
            elif depth is not None and len(_value) != depth:
                raise ValueError(
                    f"depth={depth} and {_field}={len(_value)} length conflict, "
                    f"consider matching or specifying a constant {_field} argument together with a desired depth"
                )
            # FIX: store as a list — previously a tuple ``num_cells`` made
            # ``[self.in_features] + self.num_cells[...]`` raise a TypeError
            # (list + tuple) in ``_make_net``.
            setattr(self, _field, list(_value))

        self.out_features = self.num_cells[-1]
        # After validation, the kernel list is the single source of truth.
        self.depth = len(self.kernel_sizes)

        self._activation_kwargs_iter = _iter_maybe_over_single(
            activation_kwargs, n=self.depth
        )
        self._norm_kwargs_iter = _iter_maybe_over_single(norm_kwargs, n=self.depth)

        layers = self._make_net(device)
        layers = [
            layer if isinstance(layer, nn.Module) else _ExecutableLayer(layer)
            for layer in layers
        ]
        super().__init__(*layers)

    def _make_net(self, device: DEVICE_TYPING | None) -> list[nn.Module]:
        """Build the list of conv / activation / norm / aggregator modules."""
        layers = []
        channel_ins = [self.in_features] + list(self.num_cells[:-1])
        channel_outs = list(self.num_cells)
        for i, (_in, _out, _kernel, _stride, _padding) in enumerate(
            zip(channel_ins, channel_outs, self.kernel_sizes, self.strides, self.paddings)
        ):
            # FIX: the *last* conv layer honours ``bias_last_layer``. The
            # previous condition (``i < len(in_features) - 1``) was always
            # true, so ``bias_last_layer=False`` was silently ignored.
            _bias = True if i < self.depth - 1 else self.bias_last_layer
            if _in is not None:
                layers.append(
                    nn.Conv3d(
                        _in,
                        _out,
                        kernel_size=_kernel,
                        stride=_stride,
                        bias=_bias,
                        padding=_padding,
                        device=device,
                    )
                )
            else:
                # Input channel count unknown: defer to a lazy conv module.
                layers.append(
                    nn.LazyConv3d(
                        _out,
                        kernel_size=_kernel,
                        stride=_stride,
                        bias=_bias,
                        padding=_padding,
                        device=device,
                    )
                )

            activation_kwargs = next(self._activation_kwargs_iter)
            layers.append(
                create_on_device(self.activation_class, device, **activation_kwargs)
            )
            if self.norm_class is not None:
                norm_kwargs = next(self._norm_kwargs_iter)
                layers.append(create_on_device(self.norm_class, device, **norm_kwargs))

        if self.aggregator_class is not None:
            layers.append(
                create_on_device(
                    self.aggregator_class, device, **self.aggregator_kwargs
                )
            )

        if self.squeeze_output:
            layers.append(SqueezeLayer((-3, -2, -1)))
        return layers

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Flatten leading batch dims, run the conv stack, restore batch dims.

        The input is expected to end with (C, D, H, W) dimensions; any number
        of leading batch dimensions is supported.

        Raises:
            ValueError: if the input has fewer than 4 dimensions.
        """
        try:
            *batch, _C, _D, _H, _W = inputs.shape
        except ValueError as err:
            raise ValueError(
                f"The input value of {self.__class__.__name__} must have at least 4 dimensions, got {inputs.ndim} instead."
            ) from err
        if len(batch) > 1:
            inputs = inputs.flatten(0, len(batch) - 1)
        out = super().forward(inputs)
        if len(batch) > 1:
            out = out.unflatten(0, batch)
        return out
817
+
818
+
819
class DuelingMlpDQNet(nn.Module):
    """Creates a Dueling MLP Q-network.

    Presented in https://arxiv.org/abs/1511.06581

    A shared MLP trunk (``features``) feeds two MLP heads: ``advantage``
    (with ``out_features`` outputs) and ``value`` (with ``out_features_value``
    outputs). The forward pass combines them as
    ``value + advantage - advantage.mean(-1, keepdim=True)``.

    Args:
        out_features (int, torch.Size or equivalent): number of features for the advantage network.
        out_features_value (int): number of features for the value network.
            Defaults to ``1``.
        mlp_kwargs_feature (dict, optional): kwargs for the feature network.
            Defaults to ``{"num_cells": [256, 256], "out_features": 256,
            "activation_class": nn.ELU, "activate_last_layer": True}``.
        mlp_kwargs_output (dict, optional): kwargs for the advantage and value
            networks. Defaults to ``{"depth": 1, "activation_class": nn.ELU,
            "num_cells": 512, "bias_last_layer": True}``.
        device (torch.device, optional): device to create the module on.

    Examples:
        >>> net = DuelingMlpDQNet(out_features=(3, 2))
        >>> net(torch.zeros(1, 5)).shape
        torch.Size([1, 3, 2])
    """

    def __init__(
        self,
        out_features: int | torch.Size,
        out_features_value: int = 1,
        mlp_kwargs_feature: dict | None = None,
        mlp_kwargs_output: dict | None = None,
        device: DEVICE_TYPING | None = None,
    ):
        super().__init__()

        trunk_cfg = {
            "num_cells": [256, 256],
            "out_features": 256,
            "activation_class": nn.ELU,
            "activate_last_layer": True,
        }
        trunk_cfg.update(mlp_kwargs_feature or {})
        self.features = MLP(device=device, **trunk_cfg)

        head_cfg = {
            "depth": 1,
            "activation_class": nn.ELU,
            "num_cells": 512,
            "bias_last_layer": True,
        }
        head_cfg.update(mlp_kwargs_output or {})
        self.out_features = out_features
        self.out_features_value = out_features_value
        self.advantage = MLP(out_features=out_features, device=device, **head_cfg)
        self.value = MLP(out_features=out_features_value, device=device, **head_cfg)

        # Zero every linear/conv bias, matching the reference initialization.
        for sub in self.modules():
            if isinstance(sub, (nn.Conv2d, nn.Linear)) and isinstance(
                sub.bias, torch.Tensor
            ):
                sub.bias.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return dueling Q-values for the input observations."""
        hidden = self.features(x)
        adv = self.advantage(hidden)
        val = self.value(hidden)
        # Mean-subtraction makes the value/advantage decomposition identifiable.
        return val + adv - adv.mean(dim=-1, keepdim=True)
934
+
935
+
936
class DuelingCnnDQNet(nn.Module):
    """Dueling CNN Q-network.

    Presented in https://arxiv.org/abs/1511.06581

    A convolutional trunk (``features``) feeds two MLP heads: ``advantage``
    (with ``out_features`` outputs) and ``value`` (with ``out_features_value``
    outputs). The forward pass combines them as
    ``value + advantage - advantage.mean(-1, keepdim=True)``.

    Args:
        out_features (int): number of features for the advantage network.
        out_features_value (int): number of features for the value network.
            Defaults to ``1``.
        cnn_kwargs (dict or list of dicts, optional): kwargs for the feature
            network. Defaults to ``{"num_cells": [32, 64, 64],
            "strides": [4, 2, 1], "kernel_sizes": [8, 4, 3]}``.
        mlp_kwargs (dict or list of dicts, optional): kwargs for the advantage
            and value networks. Defaults to ``{"depth": 1,
            "activation_class": nn.ELU, "num_cells": 512,
            "bias_last_layer": True}``.
        device (torch.device, optional): device to create the module on.

    Examples:
        >>> net = DuelingCnnDQNet(out_features=20)
        >>> net(torch.zeros(1, 3, 64, 64)).shape
        torch.Size([1, 20])
    """

    def __init__(
        self,
        out_features: int,
        out_features_value: int = 1,
        cnn_kwargs: dict | None = None,
        mlp_kwargs: dict | None = None,
        device: DEVICE_TYPING | None = None,
    ):
        super().__init__()

        trunk_cfg = {
            "num_cells": [32, 64, 64],
            "strides": [4, 2, 1],
            "kernel_sizes": [8, 4, 3],
        }
        trunk_cfg.update(cnn_kwargs or {})
        self.features = ConvNet(device=device, **trunk_cfg)

        head_cfg = {
            "depth": 1,
            "activation_class": nn.ELU,
            "num_cells": 512,
            "bias_last_layer": True,
        }
        head_cfg.update(mlp_kwargs or {})
        self.out_features = out_features
        self.out_features_value = out_features_value
        self.advantage = MLP(out_features=out_features, device=device, **head_cfg)
        self.value = MLP(out_features=out_features_value, device=device, **head_cfg)

        # Zero every linear/conv bias, matching the reference initialization.
        for sub in self.modules():
            if isinstance(sub, (nn.Conv2d, nn.Linear)) and isinstance(
                sub.bias, torch.Tensor
            ):
                sub.bias.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return dueling Q-values for the input pixels."""
        hidden = self.features(x)
        adv = self.advantage(hidden)
        val = self.value(hidden)
        # Mean-subtraction makes the value/advantage decomposition identifiable.
        return val + adv - adv.mean(dim=-1, keepdim=True)
1040
+
1041
+
1042
def ddpg_init_last_layer(
    module: nn.Sequential,
    scale: float = 6e-4,
    device: DEVICE_TYPING | None = None,
) -> None:
    """Initializer for the last layer of DDPG modules.

    Re-initialises the weights (and bias, if present) of the last
    ``nn.Linear`` / ``nn.Conv2d`` layer of ``module`` with uniform noise in
    ``[-scale / 2, scale / 2)``, as presented in "CONTINUOUS CONTROL WITH
    DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf

    Args:
        module (nn.Module): an actor or critic to be initialized.
        scale (:obj:`float`, optional): the noise scale. Defaults to ``6e-4``.
        device (torch.device, optional): the device where the noise should be
            created. Defaults to the device of the last layer's weight
            parameter.

    Raises:
        RuntimeError: if ``module`` contains no linear or 2D-conv layer.

    Examples:
        >>> from torchrl.modules.models.models import MLP, ddpg_init_last_layer
        >>> mlp = MLP(in_features=4, out_features=5, num_cells=(10, 10))
        >>> # init the last layer of the MLP
        >>> ddpg_init_last_layer(mlp)

    """
    target = next(
        (
            sub
            for sub in reversed(module)
            if isinstance(sub, (nn.Linear, nn.Conv2d))
        ),
        None,
    )
    if target is None:
        raise RuntimeError("Could not find a nn.Linear / nn.Conv2d to initialize.")

    def _uniform_noise(tensor: torch.Tensor) -> torch.Tensor:
        # Uniform noise in [-scale/2, scale/2), optionally on an explicit device.
        return torch.rand_like(tensor, device=device) * scale - scale / 2

    target.weight.data.copy_(_uniform_noise(target.weight.data))
    if target.bias is not None:
        target.bias.data.copy_(_uniform_noise(target.bias.data))
1079
+
1080
+
1081
class DdpgCnnActor(nn.Module):
    """DDPG Convolutional Actor class.

    Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING",
    https://arxiv.org/pdf/1509.02971.pdf

    The DDPG Convolutional Actor takes as input an observation (some simple
    transformation of the observed pixels) and returns an action vector from
    it, as well as an observation embedding that can be reused for a value
    estimation. It should be trained to maximise the value returned by the
    DDPG Q Value network.

    Args:
        action_dim (int): length of the action vector.
        conv_net_kwargs (dict or list of dicts, optional): kwargs for the
            ConvNet. Defaults to ``{"in_features": None,
            "num_cells": [32, 64, 64], "kernel_sizes": [8, 4, 3],
            "strides": [4, 2, 1], "paddings": [0, 0, 1],
            "activation_class": nn.ELU, "norm_class": None,
            "aggregator_class": SquashDims,
            "aggregator_kwargs": {"ndims_in": 3}, "squeeze_output": True}``.
        mlp_net_kwargs (dict, optional): kwargs for the MLP. Defaults to
            ``{"in_features": None, "out_features": action_dim, "depth": 2,
            "num_cells": 200, "activation_class": nn.ELU,
            "bias_last_layer": True}``.
        use_avg_pooling (bool, optional): if ``True``, a
            :class:`~torch.nn.AdaptiveAvgPool2d` layer is used to aggregate
            the output. Defaults to ``False``.
        device (torch.device, optional): device to create the module on.

    Examples:
        >>> actor = DdpgCnnActor(action_dim=4)
        >>> obs = torch.randn(10, 3, 64, 64)
        >>> action, hidden = actor(obs)
        >>> action.shape
        torch.Size([10, 4])
    """

    def __init__(
        self,
        action_dim: int,
        conv_net_kwargs: dict | None = None,
        mlp_net_kwargs: dict | None = None,
        use_avg_pooling: bool = False,
        device: DEVICE_TYPING | None = None,
    ):
        super().__init__()
        # Pick the aggregation strategy once, then build the default configs.
        if use_avg_pooling:
            aggregator_cls = nn.AdaptiveAvgPool2d
            aggregator_cfg = {"output_size": (1, 1)}
        else:
            aggregator_cls = SquashDims
            aggregator_cfg = {"ndims_in": 3}

        conv_cfg = {
            "in_features": None,
            "num_cells": [32, 64, 64],
            "kernel_sizes": [8, 4, 3],
            "strides": [4, 2, 1],
            "paddings": [0, 0, 1],
            "activation_class": nn.ELU,
            "norm_class": None,
            "aggregator_class": aggregator_cls,
            "aggregator_kwargs": aggregator_cfg,
            "squeeze_output": use_avg_pooling,
        }
        conv_cfg.update(conv_net_kwargs or {})

        mlp_cfg = {
            "in_features": None,
            "out_features": action_dim,
            "depth": 2,
            "num_cells": 200,
            "activation_class": nn.ELU,
            "bias_last_layer": True,
        }
        mlp_cfg.update(mlp_net_kwargs or {})

        self.convnet = ConvNet(device=device, **conv_cfg)
        self.mlp = MLP(device=device, **mlp_cfg)
        # DDPG paper: small uniform init on the final layer of the policy head.
        ddpg_init_last_layer(self.mlp, 6e-4, device=device)

    def forward(self, observation: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Map pixel observations to ``(action, hidden_embedding)``."""
        hidden = self.convnet(observation)
        action = self.mlp(hidden)
        return action, hidden
1205
+
1206
+
1207
class DdpgMlpActor(nn.Module):
    """DDPG Actor class.

    Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING",
    https://arxiv.org/pdf/1509.02971.pdf

    Maps an observation vector to an action vector. It is trained to maximise
    the value returned by the DDPG Q-value network.

    Args:
        action_dim (int): length of the action vector.
        mlp_net_kwargs (dict, optional): kwargs for the underlying MLP.
            Defaults to

            >>> {
            ...     'in_features': None,
            ...     'out_features': action_dim,
            ...     'depth': 2,
            ...     'num_cells': [400, 300],
            ...     'activation_class': nn.ELU,
            ...     'bias_last_layer': True,
            ... }

        device (torch.device, optional): device to create the module on.

    Examples:
        >>> import torch
        >>> from torchrl.modules import DdpgMlpActor
        >>> actor = DdpgMlpActor(action_dim=4)
        >>> obs = torch.zeros(10, 6)
        >>> action = actor(obs)
        >>> print(action.shape)
        torch.Size([10, 4])

    """

    def __init__(
        self,
        action_dim: int,
        mlp_net_kwargs: dict | None = None,
        device: DEVICE_TYPING | None = None,
    ):
        super().__init__()
        net_kwargs = {
            "in_features": None,
            "out_features": action_dim,
            "depth": 2,
            "num_cells": [400, 300],
            "activation_class": nn.ELU,
            "bias_last_layer": True,
        }
        # User-provided kwargs override the defaults.
        net_kwargs.update(mlp_net_kwargs or {})
        self.mlp = MLP(device=device, **net_kwargs)
        # Small uniform init on the output layer, as in the DDPG paper.
        ddpg_init_last_layer(self.mlp, 6e-3, device=device)

    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        """Compute the action for ``observation``."""
        return self.mlp(observation)
1276
+
1277
+
1278
class DdpgCnnQNet(nn.Module):
    """DDPG Convolutional Q-value class.

    Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING",
    https://arxiv.org/pdf/1509.02971.pdf

    The DDPG Q-value network takes as input an observation and an action, and
    returns a scalar from it.

    Args:
        conv_net_kwargs (dict, optional): kwargs for the
            convolutional network.
            Defaults to

            >>> {
            ...     'in_features': None,
            ...     'num_cells': [32, 64, 128],
            ...     'kernel_sizes': [8, 4, 3],
            ...     'strides': [4, 2, 1],
            ...     'paddings': [0, 0, 1],
            ...     'activation_class': nn.ELU,
            ...     'norm_class': None,
            ...     'aggregator_class': nn.AdaptiveAvgPool2d,
            ...     'aggregator_kwargs': {'output_size': (1, 1)},
            ...     'squeeze_output': True,
            ... }

        mlp_net_kwargs (dict, optional): kwargs for MLP.
            Defaults to

            >>> {
            ...     'in_features': None,
            ...     'out_features': 1,
            ...     'depth': 2,
            ...     'num_cells': 200,
            ...     'activation_class': nn.ELU,
            ...     'bias_last_layer': True,
            ... }

        use_avg_pooling (bool, optional): if ``True``, an
            :class:`~torch.nn.AdaptiveAvgPool2d` layer is used to aggregate
            the output. Default is ``True``.
        device (torch.device, optional): device to create the module on.

    Examples:
        >>> from torchrl.modules import DdpgCnnQNet
        >>> import torch
        >>> net = DdpgCnnQNet()
        >>> obs = torch.zeros(1, 3, 64, 64)
        >>> action = torch.zeros(1, 4)
        >>> value = net(obs, action)
        >>> print(value.shape)
        torch.Size([1, 1])

    """

    def __init__(
        self,
        conv_net_kwargs: dict | None = None,
        mlp_net_kwargs: dict | None = None,
        use_avg_pooling: bool = True,
        device: DEVICE_TYPING | None = None,
    ):
        super().__init__()
        # Aggregate the conv feature map either by flattening (SquashDims) or
        # by adaptive average pooling down to a 1x1 spatial map.
        if use_avg_pooling:
            aggregator_class = nn.AdaptiveAvgPool2d
            aggregator_kwargs = {"output_size": (1, 1)}
        else:
            aggregator_class = SquashDims
            aggregator_kwargs = {"ndims_in": 3}
        conv_defaults = {
            "in_features": None,
            "num_cells": [32, 64, 128],
            "kernel_sizes": [8, 4, 3],
            "strides": [4, 2, 1],
            "paddings": [0, 0, 1],
            "activation_class": nn.ELU,
            "norm_class": None,
            "aggregator_class": aggregator_class,
            "aggregator_kwargs": aggregator_kwargs,
            "squeeze_output": use_avg_pooling,
        }
        # User-provided kwargs take precedence over the defaults above.
        conv_defaults.update(conv_net_kwargs or {})
        mlp_defaults = {
            "in_features": None,
            "out_features": 1,
            "depth": 2,
            "num_cells": 200,
            "activation_class": nn.ELU,
            "bias_last_layer": True,
        }
        mlp_defaults.update(mlp_net_kwargs or {})
        self.convnet = ConvNet(device=device, **conv_defaults)
        self.mlp = MLP(device=device, **mlp_defaults)
        # DDPG-style small uniform init on the last MLP layer.
        ddpg_init_last_layer(self.mlp, 6e-4, device=device)

    def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return the scalar Q-value for the (observation, action) pair."""
        # The action is concatenated to the conv embedding before the MLP head.
        hidden = torch.cat([self.convnet(observation), action], -1)
        return self.mlp(hidden)
1399
+
1400
+
1401
class DdpgMlpQNet(nn.Module):
    """DDPG Q-value MLP class.

    Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING",
    https://arxiv.org/pdf/1509.02971.pdf

    Takes an observation and an action and returns a scalar value.
    Because actions are integrated later than observations, two networks are
    created: ``mlp1`` embeds the observation, then the action is concatenated
    to that embedding and fed to ``mlp2``.

    Args:
        mlp_net_kwargs_net1 (dict, optional): kwargs for the first MLP.
            Defaults to

            >>> {
            ...     'in_features': None,
            ...     'out_features': 400,
            ...     'depth': 0,
            ...     'num_cells': [],
            ...     'activation_class': nn.ELU,
            ...     'bias_last_layer': True,
            ...     'activate_last_layer': True,
            ... }

        mlp_net_kwargs_net2 (dict, optional): kwargs for the second MLP.
            Defaults to

            >>> {
            ...     'in_features': None,
            ...     'out_features': 1,
            ...     'num_cells': [300, ],
            ...     'activation_class': nn.ELU,
            ...     'bias_last_layer': True,
            ... }

        device (torch.device, optional): device to create the module on.

    Examples:
        >>> import torch
        >>> from torchrl.modules import DdpgMlpQNet
        >>> net = DdpgMlpQNet()
        >>> obs = torch.zeros(1, 32)
        >>> action = torch.zeros(1, 4)
        >>> value = net(obs, action)
        >>> print(value.shape)
        torch.Size([1, 1])

    """

    def __init__(
        self,
        mlp_net_kwargs_net1: dict | None = None,
        mlp_net_kwargs_net2: dict | None = None,
        device: DEVICE_TYPING | None = None,
    ):
        super().__init__()
        # First stage: observation embedding (single activated linear layer).
        net1_kwargs = {
            "in_features": None,
            "out_features": 400,
            "depth": 0,
            "num_cells": [],
            "activation_class": nn.ELU,
            "bias_last_layer": True,
            "activate_last_layer": True,
        }
        net1_kwargs.update(mlp_net_kwargs_net1 or {})
        self.mlp1 = MLP(device=device, **net1_kwargs)

        # Second stage: maps [embedding, action] to a scalar value.
        net2_kwargs = {
            "in_features": None,
            "out_features": 1,
            "num_cells": [
                300,
            ],
            "activation_class": nn.ELU,
            "bias_last_layer": True,
        }
        net2_kwargs.update(mlp_net_kwargs_net2 or {})
        self.mlp2 = MLP(device=device, **net2_kwargs)
        # DDPG-style small uniform init on the output layer.
        ddpg_init_last_layer(self.mlp2, 6e-3, device=device)

    def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return the scalar Q-value for the (observation, action) pair."""
        obs_embedding = self.mlp1(observation)
        return self.mlp2(torch.cat([obs_embedding, action], -1))
1505
+
1506
+
1507
class OnlineDTActor(nn.Module):
    """Online Decision Transformer Actor class.

    Actor class for the Online Decision Transformer, sampling actions from a
    gaussian distribution, as presented in
    `"Online Decision Transformer" <https://arxiv.org/abs/2202.05607.pdf>`_.

    Returns the mean and standard deviation for the gaussian distribution to
    sample actions from.

    Args:
        state_dim (int): state dimension.
        action_dim (int): action dimension.
        transformer_config (Dict or :class:`DecisionTransformer.DTConfig`):
            config for the GPT2 transformer.
            Defaults to :meth:`default_config`.
        device (torch.device, optional): device to use. Defaults to None.

    Examples:
        >>> model = OnlineDTActor(state_dim=4, action_dim=2,
        ...     transformer_config=OnlineDTActor.default_config())
        >>> observation = torch.randn(32, 10, 4)
        >>> action = torch.randn(32, 10, 2)
        >>> return_to_go = torch.randn(32, 10, 1)
        >>> mu, std = model(observation, action, return_to_go)
        >>> mu.shape
        torch.Size([32, 10, 2])
        >>> std.shape
        torch.Size([32, 10, 2])
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        transformer_config: dict | DecisionTransformer.DTConfig = None,
        device: DEVICE_TYPING | None = None,
    ):
        super().__init__()
        if transformer_config is None:
            transformer_config = self.default_config()
        # Normalize the config to a plain dict for key access below.
        if isinstance(transformer_config, DecisionTransformer.DTConfig):
            transformer_config = dataclasses.asdict(transformer_config)
        self.transformer = DecisionTransformer(
            state_dim=state_dim,
            action_dim=action_dim,
            config=transformer_config,
            device=device,
        )
        # Two linear heads on top of the transformer embedding: one for the
        # mean, one for the (pre-squash) log standard deviation.
        self.action_layer_mean = nn.Linear(
            transformer_config["n_embd"], action_dim, device=device
        )
        self.action_layer_logstd = nn.Linear(
            transformer_config["n_embd"], action_dim, device=device
        )

        # Bounds applied to the log-std after the tanh squash in forward().
        self.log_std_min, self.log_std_max = -5.0, 2.0

        def weight_init(m):
            """Custom weight init for Conv2D and Linear layers."""
            if isinstance(m, torch.nn.Linear):
                nn.init.orthogonal_(m.weight.data)
                if hasattr(m.bias, "data"):
                    m.bias.data.fill_(0.0)

        self.action_layer_mean.apply(weight_init)
        self.action_layer_logstd.apply(weight_init)

    def forward(
        self,
        observation: torch.Tensor,
        action: torch.Tensor,
        return_to_go: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the mean and standard deviation of the action distribution."""
        hidden_state = self.transformer(observation, action, return_to_go)
        mu = self.action_layer_mean(hidden_state)
        log_std = self.action_layer_logstd(hidden_state)

        log_std = torch.tanh(log_std)
        # log_std is the output of tanh so it will be between [-1, 1]
        # map it to be between [log_std_min, log_std_max]
        log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (
            log_std + 1.0
        )
        std = log_std.exp()

        return mu, std

    @classmethod
    def default_config(cls):
        """Default configuration for :class:`~OnlineDTActor`."""
        return DecisionTransformer.DTConfig(
            n_embd=512,
            n_layer=4,
            n_head=4,
            n_inner=2048,
            activation="relu",
            n_positions=1024,
            resid_pdrop=0.1,
            attn_pdrop=0.1,
        )
1607
+
1608
+
1609
class DTActor(nn.Module):
    """Decision Transformer Actor class.

    Actor class for the Decision Transformer to output deterministic action as
    presented in `"Decision Transformer" <https://arxiv.org/abs/2106.01345>`_.
    Returns the deterministic actions.

    Args:
        state_dim (int): state dimension.
        action_dim (int): action dimension.
        transformer_config (Dict or :class:`DecisionTransformer.DTConfig`, optional):
            config for the GPT2 transformer.
            Defaults to :meth:`~.default_config`.
        device (torch.device, optional): device to use. Defaults to None.

    Examples:
        >>> model = DTActor(state_dim=4, action_dim=2,
        ...     transformer_config=DTActor.default_config())
        >>> observation = torch.randn(32, 10, 4)
        >>> action = torch.randn(32, 10, 2)
        >>> return_to_go = torch.randn(32, 10, 1)
        >>> output = model(observation, action, return_to_go)
        >>> output.shape
        torch.Size([32, 10, 2])

    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        transformer_config: dict | DecisionTransformer.DTConfig = None,
        device: DEVICE_TYPING | None = None,
    ):
        super().__init__()
        if transformer_config is None:
            transformer_config = self.default_config()
        # Normalize the config to a plain dict for key access below.
        if isinstance(transformer_config, DecisionTransformer.DTConfig):
            transformer_config = dataclasses.asdict(transformer_config)
        self.transformer = DecisionTransformer(
            state_dim=state_dim,
            action_dim=action_dim,
            config=transformer_config,
            device=device,
        )
        # Single linear head mapping the transformer embedding to actions.
        self.action_layer = nn.Linear(
            transformer_config["n_embd"], action_dim, device=device
        )

        def weight_init(m):
            """Custom weight init for Conv2D and Linear layers."""
            if isinstance(m, torch.nn.Linear):
                nn.init.orthogonal_(m.weight.data)
                if hasattr(m.bias, "data"):
                    m.bias.data.fill_(0.0)

        self.action_layer.apply(weight_init)

    def forward(
        self,
        observation: torch.Tensor,
        action: torch.Tensor,
        return_to_go: torch.Tensor,
    ) -> torch.Tensor:
        """Return the deterministic action for the given trajectory inputs."""
        hidden_state = self.transformer(observation, action, return_to_go)
        return self.action_layer(hidden_state)

    @classmethod
    def default_config(cls):
        """Default configuration for :class:`~DTActor`."""
        return DecisionTransformer.DTConfig(
            n_embd=512,
            n_layer=4,
            n_head=4,
            n_inner=2048,
            activation="relu",
            n_positions=1024,
            resid_pdrop=0.1,
            attn_pdrop=0.1,
        )
1690
+
1691
+
1692
+ def _iter_maybe_over_single(item: dict | list[dict] | None, n):
1693
+ if item is None:
1694
+ return iter([{} for _ in range(n)])
1695
+ elif isinstance(item, dict):
1696
+ return iter([deepcopy(item) for _ in range(n)])
1697
+ else:
1698
+ return iter([deepcopy(_item) for _item in item])
1699
+
1700
+
1701
+ class _ExecutableLayer(nn.Module):
1702
+ """A thin wrapper around a function to be executed as a module."""
1703
+
1704
+ def __init__(self, func):
1705
+ super().__init__()
1706
+ self.func = func
1707
+
1708
+ def forward(self, *args, **kwargs):
1709
+ return self.func(*args, **kwargs)
1710
+
1711
+ def __repr__(self):
1712
+ return f"{self.__class__.__name__}(func={self.func})"