torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
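Before turning to the per-file diff, a quick way to sanity-check the wheel is to import the top-level subpackages that appear in the listing above. The snippet below is a hypothetical smoke test written for this review, not a file shipped in the package; it assumes the wheel has been installed into a CPython 3.14 (cp314/cp314t) environment.

# smoke_test.py (hypothetical, not part of the wheel)
# pip install torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl
import torchrl
import torchrl.collectors
import torchrl.data
import torchrl.envs
import torchrl.modules
import torchrl.objectives
import torchrl.trainers

# The version string should match the wheel metadata.
print(torchrl.__version__)  # expected: 0.11.0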
torchrl/trainers/helpers/models.py
@@ -0,0 +1,658 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import itertools
+ from dataclasses import dataclass
+
+ import torch
+ from tensordict import set_lazy_legacy
+ from tensordict.nn import InteractionType
+ from torch import nn
+ from torchrl.data.tensor_specs import Categorical, Composite, Unbounded
+ from torchrl.data.utils import DEVICE_TYPING
+ from torchrl.envs.common import EnvBase
+ from torchrl.envs.model_based.dreamer import DreamerEnv
+ from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.modules import (
+     NoisyLinear,
+     SafeModule,
+     SafeProbabilisticModule,
+     SafeProbabilisticTensorDictSequential,
+     SafeSequential,
+ )
+ from torchrl.modules.distributions import (
+     Delta,
+     OneHotCategorical,
+     TanhDelta,
+     TanhNormal,
+ )
+ from torchrl.modules.models.model_based import (
+     DreamerActor,
+     ObsDecoder,
+     ObsEncoder,
+     RSSMPosterior,
+     RSSMPrior,
+     RSSMRollout,
+ )
+ from torchrl.modules.models.models import DuelingCnnDQNet, DuelingMlpDQNet, MLP
+ from torchrl.modules.tensordict_module import (
+     Actor,
+     DistributionalQValueActor,
+     QValueActor,
+ )
+ from torchrl.modules.tensordict_module.world_models import WorldModelWrapper
+ from torchrl.trainers.helpers import transformed_env_constructor
+
+ DISTRIBUTIONS = {
+     "delta": Delta,
+     "tanh-normal": TanhNormal,
+     "categorical": OneHotCategorical,
+     "tanh-delta": TanhDelta,
+ }
+
+ ACTIVATIONS = {
+     "elu": nn.ELU,
+     "tanh": nn.Tanh,
+     "relu": nn.ReLU,
+ }
+
+
+ def make_dqn_actor(
+     proof_environment: EnvBase, cfg: DictConfig, device: torch.device  # noqa: F821
+ ) -> Actor:
+     """DQN constructor helper function.
+
+     Args:
+         proof_environment (EnvBase): a dummy environment to retrieve the observation and action spec.
+         cfg (DictConfig): contains arguments of the DQN script
+         device (torch.device): device on which the model must be cast
+
+     Returns:
+         A DQN policy operator.
+
+     Examples:
+         >>> from torchrl.trainers.helpers.models import make_dqn_actor, DiscreteModelConfig
+         >>> from torchrl.trainers.helpers.envs import EnvConfig
+         >>> from torchrl.envs.libs.gym import GymEnv
+         >>> from torchrl.envs.transforms import ToTensorImage, TransformedEnv
+         >>> import hydra
+         >>> from hydra.core.config_store import ConfigStore
+         >>> import dataclasses
+         >>> proof_environment = TransformedEnv(GymEnv("ALE/Pong-v5",
+         ...     pixels_only=True), ToTensorImage())
+         >>> device = torch.device("cpu")
+         >>> config_fields = [(config_field.name, config_field.type, config_field) for config_cls in
+         ...     (DiscreteModelConfig, EnvConfig)
+         ...     for config_field in dataclasses.fields(config_cls)]
+         >>> Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
+         >>> cs = ConfigStore.instance()
+         >>> cs.store(name="config", node=Config)
+         >>> with initialize(config_path=None):
+         >>>     cfg = compose(config_name="config")
+         >>> actor = make_dqn_actor(proof_environment, cfg, device)
+         >>> td = proof_environment.reset()
+         >>> print(actor(td))
+         TensorDict(
+             fields={
+                 done: Tensor(torch.Size([1]), dtype=torch.bool),
+                 pixels: Tensor(torch.Size([3, 210, 160]), dtype=torch.float32),
+                 action: Tensor(torch.Size([6]), dtype=torch.int64),
+                 action_value: Tensor(torch.Size([6]), dtype=torch.float32),
+                 chosen_action_value: Tensor(torch.Size([1]), dtype=torch.float32)},
+             batch_size=torch.Size([]),
+             device=cpu,
+             is_shared=False)
+
+
+     """
+     env_specs = proof_environment.specs
+
+     atoms = cfg.atoms if cfg.distributional else None
+     linear_layer_class = torch.nn.Linear if not cfg.noisy else NoisyLinear
+
+     action_spec = env_specs["input_spec", "full_action_spec", "action"]
+     if action_spec.domain != "discrete":
+         raise ValueError(
+             f"env {proof_environment} has an action domain "
+             f"{action_spec.domain} which is incompatible with "
+             f"DQN. Make sure your environment has a discrete "
+             f"domain."
+         )
+
+     if cfg.from_pixels:
+         net_class = DuelingCnnDQNet
+         default_net_kwargs = {
+             "cnn_kwargs": {
+                 "bias_last_layer": True,
+                 "depth": None,
+                 "num_cells": [32, 64, 64],
+                 "kernel_sizes": [8, 4, 3],
+                 "strides": [4, 2, 1],
+             },
+             "mlp_kwargs": {"num_cells": 512, "layer_class": linear_layer_class},
+         }
+         in_key = "pixels"
+
+     else:
+         net_class = DuelingMlpDQNet
+         default_net_kwargs = {
+             "mlp_kwargs_feature": {},  # see class for details
+             "mlp_kwargs_output": {"num_cells": 512, "layer_class": linear_layer_class},
+         }
+         # automatically infer in key
+         (in_key,) = itertools.islice(
+             env_specs["output_spec", "full_observation_spec"], 1
+         )
+
+     actor_class = QValueActor
+     actor_kwargs = {}
+
+     if isinstance(action_spec, Categorical):
+         # if action spec is modeled as categorical variable, we still need to have features equal
+         # to the number of possible choices and also set categorical behavioral for actors.
+         actor_kwargs.update({"action_space": "categorical"})
+         out_features = env_specs["input_spec", "full_action_spec", "action"].space.n
+     else:
+         out_features = action_spec.shape[0]
+
+     if cfg.distributional:
+         if not atoms:
+             raise RuntimeError(
+                 "Expected atoms to be a positive integer, " f"got {atoms}"
+             )
+         vmin = -3
+         vmax = 3
+
+         out_features = (atoms, out_features)
+         support = torch.linspace(vmin, vmax, atoms)
+         actor_class = DistributionalQValueActor
+         actor_kwargs.update({"support": support})
+         default_net_kwargs.update({"out_features_value": (atoms, 1)})
+
+     net = net_class(
+         out_features=out_features,
+         **default_net_kwargs,
+     )
+
+     model = actor_class(
+         module=net,
+         spec=Composite(action=action_spec),
+         in_keys=[in_key],
+         safe=True,
+         **actor_kwargs,
+     ).to(device)
+
+     # init
+     with torch.no_grad():
+         td = proof_environment.fake_tensordict()
+         td = td.unsqueeze(-1)
+         model(td.to(device))
+     return model
+
+
+ @set_lazy_legacy(False)
+ def make_dreamer(
+     cfg: DictConfig,  # noqa: F821
+     proof_environment: EnvBase = None,
+     device: DEVICE_TYPING = "cpu",
+     action_key: str = "action",
+     value_key: str = "state_value",
+     use_decoder_in_env: bool = False,
+     obs_norm_state_dict=None,
+ ) -> nn.ModuleList:
+     """Create Dreamer components.
+
+     Args:
+         cfg (DictConfig): Config object.
+         proof_environment (EnvBase): Environment to initialize the model.
+         device (DEVICE_TYPING, optional): Device to use.
+             Defaults to "cpu".
+         action_key (str, optional): Key to use for the action.
+             Defaults to "action".
+         value_key (str, optional): Key to use for the value.
+             Defaults to "state_value".
+         use_decoder_in_env (bool, optional): Whether to use the decoder in the model based dreamer env.
+             Defaults to `False`.
+         obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform used
+             when proof_environment is missing. Defaults to None.
+
+     Returns:
+         nn.TensorDictModel: Dreamer World model.
+         nn.TensorDictModel: Dreamer Model based environment.
+         nn.TensorDictModel: Dreamer Actor the world model space.
+         nn.TensorDictModel: Dreamer Value model.
+         nn.TensorDictModel: Dreamer Actor for the real world space.
+
+     """
+     proof_env_is_none = proof_environment is None
+     if proof_env_is_none:
+         proof_environment = transformed_env_constructor(
+             cfg=cfg, use_env_creator=False, obs_norm_state_dict=obs_norm_state_dict
+         )()
+
+     # Modules
+     obs_encoder = ObsEncoder()
+     obs_decoder = ObsDecoder()
+
+     rssm_prior = RSSMPrior(
+         hidden_dim=cfg.rssm_hidden_dim,
+         rnn_hidden_dim=cfg.rssm_hidden_dim,
+         state_dim=cfg.state_dim,
+         action_spec=proof_environment.action_spec,
+     )
+     rssm_posterior = RSSMPosterior(
+         hidden_dim=cfg.rssm_hidden_dim, state_dim=cfg.state_dim
+     )
+     reward_module = MLP(
+         out_features=1, depth=2, num_cells=cfg.mlp_num_units, activation_class=nn.ELU
+     )
+
+     world_model = _dreamer_make_world_model(
+         obs_encoder, obs_decoder, rssm_prior, rssm_posterior, reward_module
+     ).to(device)
+     with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+         tensordict = proof_environment.fake_tensordict().unsqueeze(-1)
+         tensordict = tensordict.to(device)
+         world_model(tensordict)
+
+     model_based_env = _dreamer_make_mbenv(
+         reward_module,
+         rssm_prior,
+         obs_decoder,
+         proof_environment,
+         use_decoder_in_env,
+         cfg.state_dim,
+         cfg.rssm_hidden_dim,
+     )
+     model_based_env = model_based_env.to(device)
+
+     actor_simulator, actor_realworld = _dreamer_make_actors(
+         obs_encoder,
+         rssm_prior,
+         rssm_posterior,
+         cfg.mlp_num_units,
+         action_key,
+         proof_environment,
+     )
+     actor_simulator = actor_simulator.to(device)
+
+     value_model = _dreamer_make_value_model(cfg.mlp_num_units, value_key)
+     value_model = value_model.to(device)
+     with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+         tensordict = model_based_env.fake_tensordict().unsqueeze(-1)
+         tensordict = tensordict.to(device)
+         tensordict = actor_simulator(tensordict)
+         value_model(tensordict)
+
+     actor_realworld = actor_realworld.to(device)
+     if proof_env_is_none:
+         proof_environment.close()
+         torch.cuda.empty_cache()
+         del proof_environment
+
+     del tensordict
+     return world_model, model_based_env, actor_simulator, value_model, actor_realworld
+
+
+ def _dreamer_make_world_model(
+     obs_encoder, obs_decoder, rssm_prior, rssm_posterior, reward_module
+ ):
+     # World Model and reward model
+     rssm_rollout = RSSMRollout(
+         SafeModule(
+             rssm_prior,
+             in_keys=["state", "belief", "action"],
+             out_keys=[
+                 ("next", "prior_mean"),
+                 ("next", "prior_std"),
+                 "_",
+                 ("next", "belief"),
+             ],
+         ),
+         SafeModule(
+             rssm_posterior,
+             in_keys=[("next", "belief"), ("next", "encoded_latents")],
+             out_keys=[
+                 ("next", "posterior_mean"),
+                 ("next", "posterior_std"),
+                 ("next", "state"),
+             ],
+         ),
+     )
+
+     transition_model = SafeSequential(
+         SafeModule(
+             obs_encoder,
+             in_keys=[("next", "pixels")],
+             out_keys=[("next", "encoded_latents")],
+         ),
+         rssm_rollout,
+         SafeModule(
+             obs_decoder,
+             in_keys=[("next", "state"), ("next", "belief")],
+             out_keys=[("next", "reco_pixels")],
+         ),
+     )
+     reward_model = SafeModule(
+         reward_module,
+         in_keys=[("next", "state"), ("next", "belief")],
+         out_keys=[("next", "reward")],
+     )
+     world_model = WorldModelWrapper(
+         transition_model,
+         reward_model,
+     )
+     return world_model
+
+
+ def _dreamer_make_actors(
+     obs_encoder,
+     rssm_prior,
+     rssm_posterior,
+     mlp_num_units,
+     action_key,
+     proof_environment,
+ ):
+     actor_module = DreamerActor(
+         out_features=proof_environment.action_spec.shape[0],
+         depth=3,
+         num_cells=mlp_num_units,
+         activation_class=nn.ELU,
+     )
+     actor_simulator = _dreamer_make_actor_sim(
+         action_key, proof_environment, actor_module
+     )
+     actor_realworld = _dreamer_make_actor_real(
+         obs_encoder,
+         rssm_prior,
+         rssm_posterior,
+         actor_module,
+         action_key,
+         proof_environment,
+     )
+     return actor_simulator, actor_realworld
+
+
+ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
+     actor_simulator = SafeProbabilisticTensorDictSequential(
+         SafeModule(
+             actor_module,
+             in_keys=["state", "belief"],
+             out_keys=["loc", "scale"],
+             spec=Composite(
+                 **{
+                     "loc": Unbounded(
+                         proof_environment.action_spec.shape,
+                         device=proof_environment.action_spec.device,
+                     ),
+                     "scale": Unbounded(
+                         proof_environment.action_spec.shape,
+                         device=proof_environment.action_spec.device,
+                     ),
+                 }
+             ),
+         ),
+         SafeProbabilisticModule(
+             in_keys=["loc", "scale"],
+             out_keys=[action_key],
+             default_interaction_type=InteractionType.RANDOM,
+             distribution_class=TanhNormal,
+             distribution_kwargs={"tanh_loc": True},
+             spec=Composite(**{action_key: proof_environment.action_spec}),
+         ),
+     )
+     return actor_simulator
+
+
+ def _dreamer_make_actor_real(
+     obs_encoder, rssm_prior, rssm_posterior, actor_module, action_key, proof_environment
+ ):
+     # actor for real world: interacts with states ~ posterior
+     # Out actor differs from the original paper where first they compute prior and posterior and then act on it
+     # but we found that this approach worked better.
+     actor_realworld = SafeSequential(
+         SafeModule(
+             obs_encoder,
+             in_keys=["pixels"],
+             out_keys=["encoded_latents"],
+         ),
+         SafeModule(
+             rssm_posterior,
+             in_keys=["belief", "encoded_latents"],
+             out_keys=[
+                 "_",
+                 "_",
+                 "state",
+             ],
+         ),
+         SafeProbabilisticTensorDictSequential(
+             SafeModule(
+                 actor_module,
+                 in_keys=["state", "belief"],
+                 out_keys=["loc", "scale"],
+                 spec=Composite(
+                     **{
+                         "loc": Unbounded(
+                             proof_environment.action_spec.shape,
+                         ),
+                         "scale": Unbounded(
+                             proof_environment.action_spec.shape,
+                         ),
+                     }
+                 ),
+             ),
+             SafeProbabilisticModule(
+                 in_keys=["loc", "scale"],
+                 out_keys=[action_key],
+                 default_interaction_type=InteractionType.DETERMINISTIC,
+                 distribution_class=TanhNormal,
+                 distribution_kwargs={"tanh_loc": True},
+                 spec=Composite(**{action_key: proof_environment.action_spec.to("cpu")}),
+             ),
+         ),
+         SafeModule(
+             rssm_prior,
+             in_keys=["state", "belief", action_key],
+             out_keys=[
+                 "_",
+                 "_",
+                 "_",  # we don't need the prior state
+                 ("next", "belief"),
+             ],
+         ),
+     )
+     return actor_realworld
+
+
+ def _dreamer_make_value_model(mlp_num_units, value_key):
+     # actor for simulator: interacts with states ~ prior
+     value_model = SafeModule(
+         MLP(
+             out_features=1,
+             depth=3,
+             num_cells=mlp_num_units,
+             activation_class=nn.ELU,
+         ),
+         in_keys=["state", "belief"],
+         out_keys=[value_key],
+     )
+     return value_model
+
+
+ def _dreamer_make_mbenv(
+     reward_module,
+     rssm_prior,
+     obs_decoder,
+     proof_environment,
+     use_decoder_in_env,
+     state_dim,
+     rssm_hidden_dim,
+ ):
+     # MB environment
+     if use_decoder_in_env:
+         mb_env_obs_decoder = SafeModule(
+             obs_decoder,
+             in_keys=[("next", "state"), ("next", "belief")],
+             out_keys=[("next", "reco_pixels")],
+         )
+     else:
+         mb_env_obs_decoder = None
+
+     transition_model = SafeSequential(
+         SafeModule(
+             rssm_prior,
+             in_keys=["state", "belief", "action"],
+             out_keys=[
+                 "_",
+                 "_",
+                 "state",
+                 "belief",
+             ],
+         ),
+     )
+     reward_model = SafeModule(
+         reward_module,
+         in_keys=["state", "belief"],
+         out_keys=["reward"],
+     )
+     model_based_env = DreamerEnv(
+         world_model=WorldModelWrapper(
+             transition_model,
+             reward_model,
+         ),
+         prior_shape=torch.Size([state_dim]),
+         belief_shape=torch.Size([rssm_hidden_dim]),
+         obs_decoder=mb_env_obs_decoder,
+     )
+
+     model_based_env.set_specs_from_env(proof_environment)
+     model_based_env = TransformedEnv(model_based_env)
+     default_dict = {
+         "state": Unbounded(state_dim),
+         "belief": Unbounded(rssm_hidden_dim),
+         # "action": proof_environment.action_spec,
+     }
+     model_based_env.append_transform(
+         TensorDictPrimer(random=False, default_value=0, **default_dict)
+     )
+     return model_based_env
+
+
+ @dataclass
+ class DreamerConfig:
+     """Dreamer model config struct."""
+
+     batch_length: int = 50
+     state_dim: int = 30
+     rssm_hidden_dim: int = 200
+     mlp_num_units: int = 400
+     grad_clip: int = 100
+     world_model_lr: float = 6e-4
+     actor_value_lr: float = 8e-5
+     imagination_horizon: int = 15
+     model_device: str = ""
+     # Decay of the reward moving averaging
+     exploration: str = "additive_gaussian"
+     # One of "additive_gaussian", "ou_exploration" or ""
+
+
+ @dataclass
+ class REDQModelConfig:
+     """REDQ model config struct."""
+
+     annealing_frames: int = 1000000
+     # float of frames used for annealing of the OrnsteinUhlenbeckProcess. Default=1e6.
+     noisy: bool = False
+     # whether to use NoisyLinearLayers in the value network.
+     ou_exploration: bool = False
+     # wraps the policy in an OU exploration wrapper, similar to DDPG. SAC being designed for
+     # efficient entropy-based exploration, this should be left for experimentation only.
+     ou_sigma: float = 0.2
+     # Ornstein-Uhlenbeck sigma
+     ou_theta: float = 0.15
+     # Aimed at superseding --ou_exploration.
+     distributional: bool = False
+     # whether a distributional loss should be used (TODO: not implemented yet).
+     atoms: int = 51
+     # number of atoms used for the distributional loss (TODO)
+     gSDE: bool = False
+     # if True, exploration is achieved using the gSDE technique.
+     tanh_loc: bool = False
+     # if True, uses a Tanh-Normal transform for the policy location of the form
+     # upscale * tanh(loc/upscale) (only available with TanhTransform and TruncatedGaussian distributions)
+     default_policy_scale: float = 1.0
+     # Default policy scale parameter
+     distribution: str = "tanh_normal"
+     # if True, uses a Tanh-Normal-Tanh distribution for the policy
+     actor_cells: int = 256
+     # cells of the actor
+     qvalue_cells: int = 256
+     # cells of the qvalue net
+     scale_lb: float = 0.1
+     # min value of scale
+     value_cells: int = 256
+     # cells of the value net
+     activation: str = "tanh"
+     # activation function, either relu or elu or tanh, Default=tanh
+
+
+ @dataclass
+ class ContinuousModelConfig:
+     """Continuous control model config struct."""
+
+     annealing_frames: int = 1000000
+     # float of frames used for annealing of the OrnsteinUhlenbeckProcess. Default=1e6.
+     noisy: bool = False
+     # whether to use NoisyLinearLayers in the value network.
+     ou_exploration: bool = False
+     # wraps the policy in an OU exploration wrapper, similar to DDPG. SAC being designed for
+     # efficient entropy-based exploration, this should be left for experimentation only.
+     ou_sigma: float = 0.2
+     # Ornstein-Uhlenbeck sigma
+     ou_theta: float = 0.15
+     # Aimed at superseding --ou_exploration.
+     distributional: bool = False
+     # whether a distributional loss should be used (TODO: not implemented yet).
+     atoms: int = 51
+     # number of atoms used for the distributional loss (TODO)
+     gSDE: bool = False
+     # if True, exploration is achieved using the gSDE technique.
+     tanh_loc: bool = False
+     # if True, uses a Tanh-Normal transform for the policy location of the form
+     # upscale * tanh(loc/upscale) (only available with TanhTransform and TruncatedGaussian distributions)
+     default_policy_scale: float = 1.0
+     # Default policy scale parameter
+     distribution: str = "tanh_normal"
+     # if True, uses a Tanh-Normal-Tanh distribution for the policy
+     lstm: bool = False
+     # if True, uses an LSTM for the policy.
+     shared_mapping: bool = False
+     # if True, the first layers of the actor-critic are shared.
+     actor_cells: int = 256
+     # cells of the actor
+     qvalue_cells: int = 256
+     # cells of the qvalue net
+     scale_lb: float = 0.1
+     # min value of scale
+     value_cells: int = 256
+     # cells of the value net
+     activation: str = "tanh"
+     # activation function, either relu or elu or tanh, Default=tanh
+
+
+ @dataclass
+ class DiscreteModelConfig:
+     """Discrete model config struct."""
+
+     annealing_frames: int = 1000000
+     # Number of frames used for annealing of the EGreedy exploration. Default=1e6.
+     noisy: bool = False
+     # whether to use NoisyLinearLayers in the value network
+     distributional: bool = False
+     # whether a distributional loss should be used.
+     atoms: int = 51
+     # number of atoms used for the distributional loss
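
The helpers added in torchrl/trainers/helpers/models.py above are driven by a Hydra-style cfg object; make_dqn_actor only reads cfg.distributional, cfg.atoms, cfg.noisy and cfg.from_pixels. As a rough, hedged sketch (written for this review, not part of the package), the same call can be exercised with a plain namespace instead of the Hydra config store used in the docstring; the CartPole environment and the SimpleNamespace shortcut are assumptions made purely for illustration.

# sketch.py (hypothetical; assumes gymnasium and a working GymEnv backend are installed)
from types import SimpleNamespace

import torch
from torchrl.envs.libs.gym import GymEnv
from torchrl.trainers.helpers.models import make_dqn_actor

# Only the attributes that make_dqn_actor reads are provided here; a real run
# would compose DiscreteModelConfig and EnvConfig through Hydra as in the docstring.
cfg = SimpleNamespace(distributional=False, atoms=51, noisy=False, from_pixels=False)

proof_env = GymEnv("CartPole-v1")  # state-based env, so the DuelingMlpDQNet branch is taken
actor = make_dqn_actor(proof_env, cfg, torch.device("cpu"))

td = actor(proof_env.reset())  # adds "action", "action_value" and "chosen_action_value"
print(td)

The docstring example in the file remains the authoritative recipe; this sketch only shows the minimal configuration surface the function touches.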