torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/redq/utils.py
@@ -0,0 +1,1060 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from collections.abc import Callable, Sequence
8
+
9
+ from copy import copy
10
+
11
+ import torch
12
+ from omegaconf import OmegaConf
13
+ from tensordict.nn import (
14
+ InteractionType,
15
+ ProbabilisticTensorDictSequential,
16
+ TensorDictModule,
17
+ TensorDictModuleWrapper,
18
+ )
19
+ from torch import distributions as d, nn, optim
20
+ from torch.optim.lr_scheduler import CosineAnnealingLR
21
+
22
+ from torchrl._utils import logger as torchrl_logger, VERBOSE
23
+ from torchrl.collectors import DataCollectorBase
24
+ from torchrl.data import (
25
+ LazyMemmapStorage,
26
+ MultiStep,
27
+ PrioritizedSampler,
28
+ RandomSampler,
29
+ ReplayBuffer,
30
+ TensorDictReplayBuffer,
31
+ )
32
+ from torchrl.data.utils import DEVICE_TYPING
33
+ from torchrl.envs import (
34
+ CatFrames,
35
+ CatTensors,
36
+ CenterCrop,
37
+ Compose,
38
+ DMControlEnv,
39
+ DoubleToFloat,
40
+ env_creator,
41
+ EnvBase,
42
+ EnvCreator,
43
+ FlattenObservation,
44
+ GrayScale,
45
+ gSDENoise,
46
+ GymEnv,
47
+ InitTracker,
48
+ NoopResetEnv,
49
+ ObservationNorm,
50
+ ParallelEnv,
51
+ Resize,
52
+ RewardScaling,
53
+ StepCounter,
54
+ ToTensorImage,
55
+ TransformedEnv,
56
+ VecNorm,
57
+ )
58
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
59
+ from torchrl.modules import (
60
+ ActorCriticOperator,
61
+ ActorValueOperator,
62
+ DdpgCnnActor,
63
+ DdpgCnnQNet,
64
+ MLP,
65
+ NoisyLinear,
66
+ NormalParamExtractor,
67
+ ProbabilisticActor,
68
+ SafeModule,
69
+ SafeSequential,
70
+ TanhNormal,
71
+ ValueOperator,
72
+ )
73
+ from torchrl.modules.distributions.continuous import SafeTanhTransform
74
+ from torchrl.modules.models.exploration import LazygSDEModule
75
+ from torchrl.objectives import HardUpdate, LossModule, SoftUpdate, TargetNetUpdater
76
+ from torchrl.objectives.deprecated import REDQLoss_deprecated
77
+ from torchrl.record.loggers import Logger
78
+ from torchrl.record.recorder import VideoRecorder
79
+ from torchrl.trainers.helpers import sync_async_collector, sync_sync_collector
80
+ from torchrl.trainers.trainers import (
81
+ BatchSubSampler,
82
+ ClearCudaCache,
83
+ CountFramesLog,
84
+ LogScalar,
85
+ LogValidationReward,
86
+ ReplayBufferTrainer,
87
+ RewardNormalizer,
88
+ Trainer,
89
+ UpdateWeights,
90
+ )
91
+
92
+ LIBS = {
93
+ "gym": GymEnv,
94
+ "dm_control": DMControlEnv,
95
+ }
96
+ ACTIVATIONS = {
97
+ "elu": nn.ELU,
98
+ "tanh": nn.Tanh,
99
+ "relu": nn.ReLU,
100
+ }
101
+ OPTIMIZERS = {
102
+ "adam": optim.Adam,
103
+ "sgd": optim.SGD,
104
+ "adamax": optim.Adamax,
105
+ }
106
+
107
+
108
+ def correct_for_frame_skip(cfg: DictConfig) -> DictConfig: # noqa: F821
109
+ """Correct the arguments for the input frame_skip, by dividing all the arguments that reflect a count of frames by the frame_skip.
110
+
111
+ This is aimed at avoiding unknowingly over-sampling from the environment, i.e. targeting a total number of frames
112
+ of 1M but actually collecting frame_skip * 1M frames.
113
+
114
+ Args:
115
+ cfg (DictConfig): DictConfig containing some frame-counting argument, including:
116
+ "max_frames_per_traj", "total_frames", "frames_per_batch", "record_frames", "annealing_frames",
117
+ "init_random_frames", "init_env_steps"
118
+
119
+ Returns:
120
+ the input DictConfig, modified in-place.
121
+
122
+ """
123
+
124
+ def _hasattr(field):
125
+ local_cfg = cfg
126
+ fields = field.split(".")
127
+ for f in fields:
128
+ if not hasattr(local_cfg, f):
129
+ return False
130
+ local_cfg = getattr(local_cfg, f)
131
+ else:
132
+ return True
133
+
134
+ def _getattr(field):
135
+ local_cfg = cfg
136
+ fields = field.split(".")
137
+ for f in fields:
138
+ local_cfg = getattr(local_cfg, f)
139
+ return local_cfg
140
+
141
+ def _setattr(field, val):
142
+ local_cfg = cfg
143
+ fields = field.split(".")
144
+ for f in fields[:-1]:
145
+ local_cfg = getattr(local_cfg, f)
146
+ setattr(local_cfg, fields[-1], val)
147
+
148
+ # Adapt all frame counts wrt frame_skip
149
+ frame_skip = cfg.env.frame_skip
150
+ if frame_skip != 1:
151
+ fields = [
152
+ "collector.max_frames_per_traj",
153
+ "collector.total_frames",
154
+ "collector.frames_per_batch",
155
+ "logger.record_frames",
156
+ "exploration.annealing_frames",
157
+ "collector.init_random_frames",
158
+ "env.init_env_steps",
159
+ "env.noops",
160
+ ]
161
+ for field in fields:
162
+ if _hasattr(field):
163
+ _setattr(field, _getattr(field) // frame_skip)
164
+ return cfg
165
+
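+ # Editorial note (illustration only, not part of the released file): with
+ # cfg.env.frame_skip = 4 and cfg.collector.total_frames = 1_000_000, this helper
+ # rewrites total_frames to 250_000, so the simulator still performs roughly 1M
+ # env.step() calls (each collected frame accounts for frame_skip simulator steps).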
166
+
167
+ def make_trainer(
168
+ collector: DataCollectorBase,
169
+ loss_module: LossModule,
170
+ recorder: EnvBase | None,
171
+ target_net_updater: TargetNetUpdater | None,
172
+ policy_exploration: TensorDictModuleWrapper | TensorDictModule | None,
173
+ replay_buffer: ReplayBuffer | None,
174
+ logger: Logger | None,
175
+ cfg: DictConfig, # noqa: F821
176
+ ) -> Trainer:
177
+ """Creates a Trainer instance given its constituents.
178
+
179
+ Args:
180
+ collector (DataCollectorBase): A data collector to be used to collect data.
181
+ loss_module (LossModule): A TorchRL loss module
182
+ recorder (EnvBase, optional): a recorder environment.
183
+ target_net_updater (TargetNetUpdater): A target network update object.
184
+ policy_exploration (TDModule or TensorDictModuleWrapper): a policy to be used for recording and exploration
185
+ updates (should be synced with the learnt policy).
186
+ replay_buffer (ReplayBuffer): a replay buffer to be used to collect data.
187
+ logger (Logger): a Logger to be used for logging.
188
+ cfg (DictConfig): a DictConfig containing the arguments of the script.
189
+
190
+ Returns:
191
+ A trainer built with the input objects. The optimizer is built by this helper function using the cfg provided.
192
+
193
+ Examples:
194
+ >>> import torch
195
+ >>> import tempfile
196
+ >>> from torchrl.record.loggers import TensorboardLogger
197
+ >>> from torchrl.trainers import Trainer
198
+ >>> from torchrl.envs import EnvCreator
199
+ >>> from torchrl.collectors import SyncDataCollector
200
+ >>> from torchrl.data import TensorDictReplayBuffer
201
+ >>> from torchrl.envs.libs.gym import GymEnv
202
+ >>> from torchrl.modules import TensorDictModuleWrapper, SafeModule, ValueOperator, EGreedyWrapper
203
+ >>> from torchrl.objectives.common import LossModule
204
+ >>> from torchrl.objectives.utils import TargetNetUpdater
205
+ >>> from torchrl.objectives import DDPGLoss
206
+ >>> env_maker = EnvCreator(lambda: GymEnv("Pendulum-v1"))
207
+ >>> env_proof = env_maker()
208
+ >>> obs_spec = env_proof.observation_spec
209
+ >>> action_spec = env_proof.action_spec
210
+ >>> net = torch.nn.Linear(env_proof.observation_spec.shape[-1], action_spec.shape[-1])
211
+ >>> net_value = torch.nn.Linear(env_proof.observation_spec.shape[-1], 1) # for the purpose of testing
212
+ >>> policy = SafeModule(action_spec, net, in_keys=["observation"], out_keys=["action"])
213
+ >>> value = ValueOperator(net_value, in_keys=["observation"], out_keys=["state_action_value"])
214
+ >>> collector = SyncDataCollector(env_maker, policy, total_frames=100)
215
+ >>> loss_module = DDPGLoss(policy, value, gamma=0.99)
216
+ >>> recorder = env_proof
217
+ >>> target_net_updater = None
218
+ >>> policy_exploration = EGreedyWrapper(policy)
219
+ >>> replay_buffer = TensorDictReplayBuffer()
220
+ >>> dir = tempfile.gettempdir()
221
+ >>> logger = TensorboardLogger(exp_name=dir)
222
+ >>> trainer = make_trainer(collector, loss_module, recorder, target_net_updater, policy_exploration,
223
+ ... replay_buffer, logger)
224
+ >>> torchrl_logger.info(trainer)
225
+
226
+ """
227
+
228
+ optimizer = OPTIMIZERS[cfg.optim.optimizer](
229
+ loss_module.parameters(),
230
+ lr=cfg.optim.lr,
231
+ weight_decay=cfg.optim.weight_decay,
232
+ eps=cfg.optim.eps,
233
+ **OmegaConf.to_container(cfg.optim.kwargs),
234
+ )
235
+ device = next(loss_module.parameters()).device
236
+ if cfg.optim.lr_scheduler == "cosine":
237
+ optim_scheduler = CosineAnnealingLR(
238
+ optimizer,
239
+ T_max=int(
240
+ cfg.collector.total_frames
241
+ / cfg.collector.frames_per_batch
242
+ * cfg.optim.steps_per_batch
243
+ ),
244
+ )
245
+ elif cfg.optim.lr_scheduler == "":
246
+ optim_scheduler = None
247
+ else:
248
+ raise NotImplementedError(f"lr scheduler {cfg.optim.lr_scheduler}")
249
+
250
+ if VERBOSE:
251
+ torchrl_logger.info(
252
+ f"collector = {collector}; \n"
253
+ f"loss_module = {loss_module}; \n"
254
+ f"recorder = {recorder}; \n"
255
+ f"target_net_updater = {target_net_updater}; \n"
256
+ f"policy_exploration = {policy_exploration}; \n"
257
+ f"replay_buffer = {replay_buffer}; \n"
258
+ f"logger = {logger}; \n"
259
+ f"cfg = {cfg}; \n"
260
+ )
261
+
262
+ if logger is not None:
263
+ # log hyperparams
264
+ logger.log_hparams(cfg)
265
+
266
+ trainer = Trainer(
267
+ collector=collector,
268
+ frame_skip=cfg.env.frame_skip,
269
+ total_frames=cfg.collector.total_frames * cfg.env.frame_skip,
270
+ loss_module=loss_module,
271
+ optimizer=optimizer,
272
+ logger=logger,
273
+ optim_steps_per_batch=cfg.optim.steps_per_batch,
274
+ clip_grad_norm=cfg.optim.clip_grad_norm,
275
+ clip_norm=cfg.optim.clip_norm,
276
+ )
277
+
278
+ if torch.cuda.device_count() > 0:
279
+ trainer.register_op("pre_optim_steps", ClearCudaCache(1))
280
+
281
+ trainer.register_op("batch_process", lambda batch: batch.cpu())
282
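+ # move each collected batch to CPU in the "batch_process" hook before it is passed
+ # to the subsequent hooks (e.g. the replay buffer extension registered below)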
+
283
+ if replay_buffer is not None:
284
+ # replay buffer is used 2 or 3 times: to register data, to sample
285
+ # data and to update priorities
286
+ rb_trainer = ReplayBufferTrainer(
287
+ replay_buffer,
288
+ cfg.buffer.batch_size,
289
+ flatten_tensordicts=True,
290
+ memmap=False,
291
+ device=device,
292
+ )
293
+
294
+ trainer.register_op("batch_process", rb_trainer.extend)
295
+ trainer.register_op("process_optim_batch", rb_trainer.sample)
296
+ trainer.register_op("post_loss", rb_trainer.update_priority)
297
+ else:
298
+ # trainer.register_op("batch_process", mask_batch)
299
+ trainer.register_op(
300
+ "process_optim_batch",
301
+ BatchSubSampler(
302
+ batch_size=cfg.buffer.batch_size, sub_traj_len=cfg.buffer.sub_traj_len
303
+ ),
304
+ )
305
+ trainer.register_op("process_optim_batch", lambda batch: batch.to(device))
306
+
307
+ if optim_scheduler is not None:
308
+ trainer.register_op("post_optim", optim_scheduler.step)
309
+
310
+ if target_net_updater is not None:
311
+ trainer.register_op("post_optim", target_net_updater.step)
312
+
313
+ if cfg.env.normalize_rewards_online:
314
+ # if used the running statistics of the rewards are computed and the
315
+ # rewards used for training will be normalized based on these.
316
+ reward_normalizer = RewardNormalizer(
317
+ scale=cfg.env.normalize_rewards_online_scale,
318
+ decay=cfg.env.normalize_rewards_online_decay,
319
+ )
320
+ trainer.register_op("batch_process", reward_normalizer.update_reward_stats)
321
+ trainer.register_op("process_optim_batch", reward_normalizer.normalize_reward)
322
+
323
+ if policy_exploration is not None and hasattr(policy_exploration, "step"):
324
+ trainer.register_op(
325
+ "post_steps", policy_exploration.step, frames=cfg.collector.frames_per_batch
326
+ )
327
+
328
+ trainer.register_op(
329
+ "post_steps_log", lambda *cfg: {"lr": optimizer.param_groups[0]["lr"]}
330
+ )
331
+
332
+ if recorder is not None:
333
+ # create recorder object
334
+ recorder_obj = LogValidationReward(
335
+ record_frames=cfg.logger.record_frames,
336
+ frame_skip=cfg.env.frame_skip,
337
+ policy_exploration=policy_exploration,
338
+ environment=recorder,
339
+ record_interval=cfg.logger.record_interval,
340
+ log_keys=cfg.logger.recorder_log_keys,
341
+ )
342
+ # register recorder
343
+ trainer.register_op(
344
+ "post_steps_log",
345
+ recorder_obj,
346
+ )
347
+ # call recorder - could be removed
348
+ recorder_obj(None)
349
+ # create explorative recorder - could be optional
350
+ recorder_obj_explore = LogValidationReward(
351
+ record_frames=cfg.logger.record_frames,
352
+ frame_skip=cfg.env.frame_skip,
353
+ policy_exploration=policy_exploration,
354
+ environment=recorder,
355
+ record_interval=cfg.logger.record_interval,
356
+ exploration_type=ExplorationType.RANDOM,
357
+ suffix="exploration",
358
+ out_keys={("next", "reward"): "r_evaluation_exploration"},
359
+ )
360
+ # register recorder
361
+ trainer.register_op(
362
+ "post_steps_log",
363
+ recorder_obj_explore,
364
+ )
365
+ # call recorder - could be removed
366
+ recorder_obj_explore(None)
367
+
368
+ trainer.register_op(
369
+ "post_steps", UpdateWeights(collector, update_weights_interval=1)
370
+ )
371
+
372
+ trainer.register_op("pre_steps_log", LogScalar())
373
+ trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.env.frame_skip))
374
+
375
+ return trainer
376
+
377
+
378
+ def make_redq_model(
379
+ proof_environment: EnvBase,
380
+ cfg: DictConfig, # noqa: F821
381
+ device: DEVICE_TYPING = "cpu",
382
+ in_keys: Sequence[str] | None = None,
383
+ actor_net_kwargs=None,
384
+ qvalue_net_kwargs=None,
385
+ observation_key=None,
386
+ **kwargs,
387
+ ) -> nn.ModuleList:
388
+ """Actor and Q-value model constructor helper function for REDQ.
389
+
390
+ Follows default parameters proposed in REDQ original paper: https://openreview.net/pdf?id=AY8zfZm0tDd.
391
+ Other configurations can easily be implemented by modifying this function at will.
392
+ A single instance of the Q-value model is returned. It will be duplicated by the loss function.
393
+
394
+ Args:
395
+ proof_environment (EnvBase): a dummy environment to retrieve the observation and action spec
396
+ cfg (DictConfig): contains arguments of the REDQ script
397
+ device (torch.device, optional): device on which the model must be cast. Default is "cpu".
398
+ in_keys (iterable of strings, optional): observation key to be read by the actor, usually one of
399
+ `'observation_vector'` or `'pixels'`. If none is provided, one of these two keys is chosen
400
+ based on the `cfg.env.from_pixels` argument.
401
+ actor_net_kwargs (dict, optional): kwargs of the actor MLP.
402
+ qvalue_net_kwargs (dict, optional): kwargs of the qvalue MLP.
403
+
404
+ Returns:
405
+ An nn.ModuleList containing the actor and the Q-value operator.
406
+
407
+ """
408
+ torch.manual_seed(cfg.seed)
409
+ tanh_loc = cfg.network.tanh_loc
410
+ default_policy_scale = cfg.network.default_policy_scale
411
+ gSDE = cfg.exploration.gSDE
412
+
413
+ action_spec = proof_environment.action_spec_unbatched
414
+
415
+ if actor_net_kwargs is None:
416
+ actor_net_kwargs = {}
417
+ if qvalue_net_kwargs is None:
418
+ qvalue_net_kwargs = {}
419
+
420
+ linear_layer_class = torch.nn.Linear if not cfg.exploration.noisy else NoisyLinear
421
+
422
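+ # with gSDE disabled the actor net outputs 2 * action_dim features (loc and scale);
+ # with gSDE enabled it outputs action_dim features and the scale comes from LazygSDEModule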
+ out_features_actor = (2 - gSDE) * action_spec.shape[-1]
423
+ if cfg.env.from_pixels:
424
+ if in_keys is None:
425
+ in_keys_actor = ["pixels"]
426
+ else:
427
+ in_keys_actor = in_keys
428
+ actor_net_kwargs_default = {
429
+ "mlp_net_kwargs": {
430
+ "layer_class": linear_layer_class,
431
+ "activation_class": ACTIVATIONS[cfg.network.activation],
432
+ },
433
+ "conv_net_kwargs": {
434
+ "activation_class": ACTIVATIONS[cfg.network.activation]
435
+ },
436
+ }
437
+ actor_net_kwargs_default.update(actor_net_kwargs)
438
+ actor_net = DdpgCnnActor(out_features_actor, **actor_net_kwargs_default)
439
+ gSDE_state_key = "hidden"
440
+ out_keys_actor = ["param", "hidden"]
441
+
442
+ value_net_default_kwargs = {
443
+ "mlp_net_kwargs": {
444
+ "layer_class": linear_layer_class,
445
+ "activation_class": ACTIVATIONS[cfg.network.activation],
446
+ },
447
+ "conv_net_kwargs": {
448
+ "activation_class": ACTIVATIONS[cfg.network.activation]
449
+ },
450
+ }
451
+ value_net_default_kwargs.update(qvalue_net_kwargs)
452
+
453
+ in_keys_qvalue = ["pixels", "action"]
454
+ qvalue_net = DdpgCnnQNet(**value_net_default_kwargs)
455
+ else:
456
+ if in_keys is None:
457
+ in_keys_actor = ["observation_vector"]
458
+ else:
459
+ in_keys_actor = in_keys
460
+
461
+ actor_net_kwargs_default = {
462
+ "num_cells": [cfg.network.actor_cells] * cfg.network.actor_depth,
463
+ "out_features": out_features_actor,
464
+ "activation_class": ACTIVATIONS[cfg.network.activation],
465
+ }
466
+ actor_net_kwargs_default.update(actor_net_kwargs)
467
+ actor_net = MLP(**actor_net_kwargs_default)
468
+ out_keys_actor = ["param"]
469
+ gSDE_state_key = in_keys_actor[0]
470
+
471
+ qvalue_net_kwargs_default = {
472
+ "num_cells": [cfg.network.qvalue_cells] * cfg.network.qvalue_depth,
473
+ "out_features": 1,
474
+ "activation_class": ACTIVATIONS[cfg.network.activation],
475
+ }
476
+ qvalue_net_kwargs_default.update(qvalue_net_kwargs)
477
+ qvalue_net = MLP(
478
+ **qvalue_net_kwargs_default,
479
+ )
480
+ in_keys_qvalue = in_keys_actor + ["action"]
481
+
482
+ dist_class = TanhNormal
483
+ dist_kwargs = {
484
+ "low": action_spec.space.low,
485
+ "high": action_spec.space.high,
486
+ "tanh_loc": tanh_loc,
487
+ }
488
+
489
+ if not gSDE:
490
+ actor_net = nn.Sequential(
491
+ actor_net,
492
+ NormalParamExtractor(
493
+ scale_mapping=f"biased_softplus_{default_policy_scale}",
494
+ scale_lb=cfg.network.scale_lb,
495
+ ),
496
+ )
497
+ actor_module = SafeModule(
498
+ actor_net,
499
+ in_keys=in_keys_actor,
500
+ out_keys=["loc", "scale"] + out_keys_actor[1:],
501
+ )
502
+
503
+ else:
504
+ actor_module = SafeModule(
505
+ actor_net,
506
+ in_keys=in_keys_actor,
507
+ out_keys=["action"] + out_keys_actor[1:], # will be overwritten
508
+ )
509
+
510
+ if action_spec.domain == "continuous":
511
+ min = action_spec.space.low
512
+ max = action_spec.space.high
513
+ transform = SafeTanhTransform()
514
+ if (min != -1).any() or (max != 1).any():
515
+ transform = d.ComposeTransform(
516
+ transform,
517
+ d.AffineTransform(loc=(max + min) / 2, scale=(max - min) / 2),
518
+ )
519
+ else:
520
+ raise RuntimeError("cannot use gSDE with discrete actions")
521
+
522
+ actor_module = SafeSequential(
523
+ actor_module,
524
+ SafeModule(
525
+ LazygSDEModule(transform=transform, device=device),
526
+ in_keys=["action", gSDE_state_key, "_eps_gSDE"],
527
+ out_keys=["loc", "scale", "action", "_eps_gSDE"],
528
+ ),
529
+ )
530
+
531
+ actor = ProbabilisticActor(
532
+ spec=action_spec,
533
+ in_keys=["loc", "scale"],
534
+ module=actor_module,
535
+ distribution_class=dist_class,
536
+ distribution_kwargs=dist_kwargs,
537
+ default_interaction_type=InteractionType.RANDOM,
538
+ return_log_prob=True,
539
+ )
540
+ qvalue = ValueOperator(
541
+ in_keys=in_keys_qvalue,
542
+ module=qvalue_net,
543
+ )
544
+ model = nn.ModuleList([actor, qvalue]).to(device)
545
+
546
+ # init nets
547
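+ # a dummy forward pass on a fake tensordict materializes the parameters of any
+ # lazily initialized layers (e.g. LazygSDEModule) before the modules see real data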
+ with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
548
+ td = proof_environment.fake_tensordict()
549
+ td = td.unsqueeze(-1)
550
+ td = td.to(device)
551
+ for net in model:
552
+ net(td)
553
+ del td
554
+ return model
555
+
556
+
557
+ def transformed_env_constructor(
558
+ cfg: DictConfig, # noqa: F821
559
+ video_tag: str = "",
560
+ logger: Logger | None = None,
561
+ stats: dict | None = None,
562
+ norm_obs_only: bool = False,
563
+ use_env_creator: bool = False,
564
+ custom_env_maker: Callable | None = None,
565
+ custom_env: EnvBase | None = None,
566
+ return_transformed_envs: bool = True,
567
+ action_dim_gsde: int | None = None,
568
+ state_dim_gsde: int | None = None,
569
+ batch_dims: int | None = 0,
570
+ obs_norm_state_dict: dict | None = None,
571
+ ) -> Callable | EnvCreator:
572
+ """Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.
573
+
574
+ Args:
575
+ cfg (DictConfig): a DictConfig containing the arguments of the script.
576
+ video_tag (str, optional): video tag to be passed to the Logger object
577
+ logger (Logger, optional): logger associated with the script
578
+ stats (dict, optional): a dictionary containing the :obj:`loc` and :obj:`scale` for the `ObservationNorm` transform
579
+ norm_obs_only (bool, optional): If `True` and `VecNorm` is used, the reward won't be normalized online.
580
+ Default is `False`.
581
+ use_env_creator (bool, optional): whether the `EnvCreator` class should be used. By using `EnvCreator`,
582
+ one can make sure that running statistics will be put in shared memory and accessible for all workers
583
+ when using a `VecNorm` transform. Default is `False`.
584
+ custom_env_maker (callable, optional): if your env maker is not part
585
+ of torchrl env wrappers, a custom callable
586
+ can be passed instead. In this case it will override the
587
+ constructor retrieved from `cfg`.
588
+ custom_env (EnvBase, optional): if an existing environment needs to be
589
+ transformed, it can be passed directly to this helper. `custom_env_maker`
590
+ and `custom_env` are exclusive features.
591
+ return_transformed_envs (bool, optional): if ``True``, a transformed environment
592
+ is returned.
593
+ action_dim_gsde (int, optional): if gSDE is used, the action dimension used to initialize the exploration noise.
594
+ Make sure this is provided when the environments are executed in parallel.
595
+ state_dim_gsde (int, optional): if gSDE is used, the state dimension used to initialize the exploration noise.
596
+ Make sure this is provided when the environments are executed in parallel.
597
+ batch_dims (int, optional): number of dimensions of a batch of data. If a single env is
598
+ used, it should be 0 (default). If multiple envs are being transformed in parallel,
599
+ it should be set to 1 (or the number of dims of the batch).
600
+ obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded into the
601
+ environment
602
+ """
603
+
604
+ def make_transformed_env(**kwargs) -> TransformedEnv:
605
+ env_name = cfg.env.name
606
+ env_task = cfg.env.task
607
+ env_library = LIBS[cfg.env.library]
608
+ frame_skip = cfg.env.frame_skip
609
+ from_pixels = cfg.env.from_pixels
610
+ categorical_action_encoding = cfg.env.categorical_action_encoding
611
+
612
+ if custom_env is None and custom_env_maker is None:
613
+ if cfg.collector.device in ("", None):
614
+ device = "cpu" if not torch.cuda.is_available() else "cuda:0"
615
+ elif isinstance(cfg.collector.device, str):
616
+ device = cfg.collector.device
617
+ elif isinstance(cfg.collector.device, Sequence):
618
+ device = cfg.collector.device[0]
619
+ else:
620
+ raise ValueError(
621
+ "collector_device must be either a string or a sequence of strings"
622
+ )
623
+ env_kwargs = {
624
+ "env_name": env_name,
625
+ "device": device,
626
+ "frame_skip": frame_skip,
627
+ "from_pixels": from_pixels or len(video_tag),
628
+ "pixels_only": from_pixels,
629
+ }
630
+ if env_library is GymEnv:
631
+ env_kwargs.update(
632
+ {"categorical_action_encoding": categorical_action_encoding}
633
+ )
634
+ elif categorical_action_encoding:
635
+ raise NotImplementedError(
636
+ "categorical_action_encoding=True is currently only compatible with GymEnvs."
637
+ )
638
+ if env_library is DMControlEnv:
639
+ env_kwargs.update({"task_name": env_task})
640
+ env_kwargs.update(kwargs)
641
+ env = env_library(**env_kwargs)
642
+ elif custom_env is None and custom_env_maker is not None:
643
+ env = custom_env_maker(**kwargs)
644
+ elif custom_env_maker is None and custom_env is not None:
645
+ env = custom_env
646
+ else:
647
+ raise RuntimeError("cannot provide both custom_env and custom_env_maker")
648
+
649
+ if cfg.env.noops and custom_env is None:
650
+ # this is a bit hacky: if custom_env is not None, it is probably a ParallelEnv
651
+ # that already has its NoopResetEnv set for the contained envs.
652
+ # There is a risk however that we're just skipping the NoopsReset instantiation
653
+ env = TransformedEnv(env, NoopResetEnv(cfg.env.noops))
654
+ if not return_transformed_envs:
655
+ return env
656
+
657
+ return make_env_transforms(
658
+ env,
659
+ cfg,
660
+ video_tag,
661
+ logger,
662
+ env_name,
663
+ stats,
664
+ norm_obs_only,
665
+ env_library,
666
+ action_dim_gsde,
667
+ state_dim_gsde,
668
+ batch_dims=batch_dims,
669
+ obs_norm_state_dict=obs_norm_state_dict,
670
+ )
671
+
672
+ if use_env_creator:
673
+ return env_creator(make_transformed_env)
674
+ return make_transformed_env
675
+
676
+
677
+ def get_norm_state_dict(env):
678
+ """Gets the normalization loc and scale from the env state_dict."""
679
+ sd = env.state_dict()
680
+ sd = {
681
+ key: val
682
+ for key, val in sd.items()
683
+ if key.endswith("loc") or key.endswith("scale")
684
+ }
685
+ return sd
686
+
687
+
688
+ def initialize_observation_norm_transforms(
689
+ proof_environment: EnvBase,
690
+ num_iter: int = 1000,
691
+ key: str | tuple[str, ...] | None = None,
692
+ ):
693
+ """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.
694
+
695
+ If an :obj:`ObservationNorm` already has non-null :obj:`loc` or :obj:`scale`, a call to :obj:`initialize_observation_norm_transforms` will be a no-op.
696
+ Similarly, if the transformed environment does not contain any :obj:`ObservationNorm`, a call to this function will have no effect.
697
+ If no key is provided but the observations of the :obj:`EnvBase` contains more than one key, an exception will
698
+ be raised.
699
+
700
+ Args:
701
+ proof_environment (EnvBase): the transformed environment whose rollouts are
702
+ used to gather the statistics that initialize its
703
+ :obj:`ObservationNorm` transforms.
704
+ num_iter (int): Number of iterations used for initializing the :obj:`ObservationNorms`
705
+ key (str, optional): if provided, the stats of this key will be gathered.
706
+ If not, it is expected that only one key exists in `env.observation_spec`.
707
+
708
+ """
709
+ if not isinstance(proof_environment.transform, Compose) and not isinstance(
710
+ proof_environment.transform, ObservationNorm
711
+ ):
712
+ return
713
+
714
+ if key is None:
715
+ keys = list(proof_environment.base_env.observation_spec.keys(True, True))
716
+ key = keys.pop()
717
+ if len(keys):
718
+ raise RuntimeError(
719
+ f"More than one key exists in the observation_specs: {[key] + keys} were found, "
720
+ "thus initialize_observation_norm_transforms cannot infer which to compute the stats of."
721
+ )
722
+
723
+ if isinstance(proof_environment.transform, Compose):
724
+ for transform in proof_environment.transform:
725
+ if isinstance(transform, ObservationNorm) and not transform.initialized:
726
+ transform.init_stats(num_iter=num_iter, key=key)
727
+ elif not proof_environment.transform.initialized:
728
+ proof_environment.transform.init_stats(num_iter=num_iter, key=key)
729
+
730
+
731
+ def parallel_env_constructor(
732
+ cfg: DictConfig, **kwargs # noqa: F821
733
+ ) -> ParallelEnv | EnvCreator:
734
+ """Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.
735
+
736
+ Args:
737
+ cfg (DictConfig): config containing user-defined arguments
738
+ kwargs: keyword arguments for the `transformed_env_constructor` method.
739
+ """
740
+ batch_transform = cfg.env.batch_transform
741
+ if not batch_transform:
742
+ raise NotImplementedError(
743
+ "batch_transform must be set to True for the recorder to be synced "
744
+ "with the collection envs."
745
+ )
746
+ if cfg.collector.env_per_collector == 1:
747
+ kwargs.update({"cfg": cfg, "use_env_creator": True})
748
+ make_transformed_env = transformed_env_constructor(**kwargs)
749
+ return make_transformed_env
750
+ kwargs.update({"cfg": cfg, "use_env_creator": True})
751
+ make_transformed_env = transformed_env_constructor(
752
+ return_transformed_envs=not batch_transform, **kwargs
753
+ )
754
+ parallel_env = ParallelEnv(
755
+ num_workers=cfg.collector.env_per_collector,
756
+ create_env_fn=make_transformed_env,
757
+ create_env_kwargs=None,
758
+ serial_for_single=True,
759
+ pin_memory=False,
760
+ )
761
+ if batch_transform:
762
+ kwargs.update(
763
+ {
764
+ "cfg": cfg,
765
+ "use_env_creator": False,
766
+ "custom_env": parallel_env,
767
+ "batch_dims": 1,
768
+ }
769
+ )
770
+ env = transformed_env_constructor(**kwargs)()
771
+ return env
772
+ return parallel_env
773
+
774
+
775
+ def retrieve_observation_norms_state_dict(proof_environment: TransformedEnv):
776
+ """Traverses the transforms of the environment and retrieves the :obj:`ObservationNorm` state dicts.
777
+
778
+ Returns a list of tuples (idx, state_dict), one for each :obj:`ObservationNorm` transform in proof_environment.
779
+ If the environment transforms do not contain any :obj:`ObservationNorm`, an empty list is returned.
780
+
781
+ Args:
782
+ proof_environment (TransformedEnv): the :obj:`TransformedEnv` to retrieve the :obj:`ObservationNorm`
783
+ state dict from
784
+ """
785
+ obs_norm_state_dicts = []
786
+
787
+ if isinstance(proof_environment.transform, Compose):
788
+ for idx, transform in enumerate(proof_environment.transform):
789
+ if isinstance(transform, ObservationNorm):
790
+ obs_norm_state_dicts.append((idx, transform.state_dict()))
791
+
792
+ if isinstance(proof_environment.transform, ObservationNorm):
793
+ obs_norm_state_dicts.append((0, proof_environment.transform.state_dict()))
794
+
795
+ return obs_norm_state_dicts
796
+
797
+
798
+ def make_env_transforms(
799
+ env,
800
+ cfg,
801
+ video_tag,
802
+ logger,
803
+ env_name,
804
+ stats,
805
+ norm_obs_only,
806
+ env_library,
807
+ action_dim_gsde,
808
+ state_dim_gsde,
809
+ batch_dims=0,
810
+ obs_norm_state_dict=None,
811
+ ):
812
+ """Creates the typical transforms for and env."""
813
+ env = TransformedEnv(env)
814
+
815
+ from_pixels = cfg.env.from_pixels
816
+ vecnorm = cfg.env.vecnorm
817
+ norm_rewards = vecnorm and cfg.env.norm_rewards
818
+ _norm_obs_only = norm_obs_only or not norm_rewards
819
+ reward_scaling = cfg.env.reward_scaling
820
+ reward_loc = cfg.env.reward_loc
821
+
822
+ if len(video_tag):
823
+ center_crop = cfg.env.center_crop
824
+ if center_crop:
825
+ center_crop = center_crop[0]
826
+ env.append_transform(
827
+ VideoRecorder(
828
+ logger=logger,
829
+ tag=f"{video_tag}_{env_name}_video",
830
+ center_crop=center_crop,
831
+ ),
832
+ )
833
+
834
+ if from_pixels:
835
+ if not cfg.env.catframes:
836
+ raise RuntimeError(
837
+ "this env builder currently only accepts positive catframes values"
838
+ "when pixels are being used."
839
+ )
840
+ env.append_transform(ToTensorImage())
841
+ if cfg.env.center_crop:
842
+ env.append_transform(CenterCrop(*cfg.env.center_crop))
843
+ env.append_transform(Resize(cfg.env.image_size, cfg.env.image_size))
844
+ if cfg.env.grayscale:
845
+ env.append_transform(GrayScale())
846
+ env.append_transform(FlattenObservation(0, -3, allow_positive_dim=True))
847
+ env.append_transform(CatFrames(N=cfg.env.catframes, in_keys=["pixels"], dim=-3))
848
+ if stats is None and obs_norm_state_dict is None:
849
+ obs_stats = {}
850
+ elif stats is None:
851
+ obs_stats = copy(obs_norm_state_dict)
852
+ else:
853
+ obs_stats = copy(stats)
854
+ obs_stats["standard_normal"] = True
855
+ obs_norm = ObservationNorm(**obs_stats, in_keys=["pixels"])
856
+ env.append_transform(obs_norm)
857
+ if norm_rewards:
858
+ reward_scaling = 1.0
859
+ reward_loc = 0.0
860
+ if norm_obs_only:
861
+ reward_scaling = 1.0
862
+ reward_loc = 0.0
863
+ if reward_scaling is not None:
864
+ env.append_transform(RewardScaling(reward_loc, reward_scaling))
865
+
866
+ if not from_pixels:
867
+ selected_keys = [
868
+ key
869
+ for key in env.observation_spec.keys(True, True)
870
+ if ("pixels" not in key) and (key not in env.state_spec.keys(True, True))
871
+ ]
872
+
873
+ # even if there is a single tensor, it'll be renamed in "observation_vector"
874
+ out_key = "observation_vector"
875
+ env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key))
876
+
877
+ if not vecnorm:
878
+ if stats is None and obs_norm_state_dict is None:
879
+ _stats = {}
880
+ elif stats is None:
881
+ _stats = copy(obs_norm_state_dict)
882
+ else:
883
+ _stats = copy(stats)
884
+ _stats.update({"standard_normal": True})
885
+ obs_norm = ObservationNorm(
886
+ **_stats,
887
+ in_keys=[out_key],
888
+ )
889
+ env.append_transform(obs_norm)
890
+ else:
891
+ env.append_transform(
892
+ VecNorm(
893
+ in_keys=[out_key, "reward"] if not _norm_obs_only else [out_key],
894
+ decay=0.9999,
895
+ )
896
+ )
897
+
898
+ env.append_transform(DoubleToFloat())
899
+
900
+ if hasattr(cfg, "catframes") and cfg.env.catframes:
901
+ env.append_transform(
902
+ CatFrames(N=cfg.env.catframes, in_keys=[out_key], dim=-1)
903
+ )
904
+
905
+ else:
906
+ env.append_transform(DoubleToFloat())
907
+
908
+ if hasattr(cfg, "gSDE") and cfg.exploration.gSDE:
909
+ env.append_transform(
910
+ gSDENoise(action_dim=action_dim_gsde, state_dim=state_dim_gsde)
911
+ )
912
+
913
+ env.append_transform(StepCounter())
914
+ env.append_transform(InitTracker())
915
+
916
+ return env
917
+
918
+
919
+ def make_redq_loss(model, cfg) -> tuple[REDQLoss_deprecated, TargetNetUpdater | None]:
920
+ """Builds the REDQ loss module."""
921
+ loss_kwargs = {}
922
+ loss_kwargs.update({"loss_function": cfg.loss.loss_function})
923
+ loss_kwargs.update({"delay_qvalue": cfg.loss.type == "double"})
924
+ loss_class = REDQLoss_deprecated
925
+ if isinstance(model, ActorValueOperator):
926
+ actor_model = model.get_policy_operator()
927
+ qvalue_model = model.get_value_operator()
928
+ elif isinstance(model, ActorCriticOperator):
929
+ raise RuntimeError(
930
+ "Although REDQ Q-value depends upon selected actions, using the"
931
+ "ActorCriticOperator will lead to resampling of the actions when"
932
+ "computing the Q-value loss, which we don't want. Please use the"
933
+ "ActorValueOperator instead."
934
+ )
935
+ else:
936
+ actor_model, qvalue_model = model
937
+
938
+ loss_module = loss_class(
939
+ actor_network=actor_model,
940
+ qvalue_network=qvalue_model,
941
+ num_qvalue_nets=cfg.loss.num_q_values,
942
+ gSDE=cfg.exploration.gSDE,
943
+ **loss_kwargs,
944
+ )
945
+ loss_module.make_value_estimator(gamma=cfg.loss.gamma)
946
+ target_net_updater = make_target_updater(cfg, loss_module)
947
+ return loss_module, target_net_updater
948
+
949
+
950
+ def make_target_updater(
951
+ cfg: DictConfig, loss_module: LossModule # noqa: F821
952
+ ) -> TargetNetUpdater | None:
953
+ """Builds a target network weight update object."""
954
+ if cfg.loss.type == "double":
955
+ if not cfg.loss.hard_update:
956
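+ # Polyak averaging: eps is set so that the target network converges on a timescale
+ # comparable to one hard update every value_network_update_interval optimization steps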
+ target_net_updater = SoftUpdate(
957
+ loss_module, eps=1 - 1 / cfg.loss.value_network_update_interval
958
+ )
959
+ else:
960
+ target_net_updater = HardUpdate(
961
+ loss_module,
962
+ value_network_update_interval=cfg.loss.value_network_update_interval,
963
+ )
964
+ else:
965
+ if cfg.loss.hard_update:
966
+ raise RuntimeError(
967
+ "hard/soft-update are supposed to be used with double SAC loss. "
968
+ "Consider using --loss=double or discarding the hard_update flag."
969
+ )
970
+ target_net_updater = None
971
+ return target_net_updater
972
+
973
+
974
+ def make_collector_offpolicy(
975
+ make_env: Callable[[], EnvBase],
976
+ actor_model_explore: TensorDictModuleWrapper | ProbabilisticTensorDictSequential,
977
+ cfg: DictConfig, # noqa: F821
978
+ make_env_kwargs: dict | None = None,
979
+ ) -> DataCollectorBase:
980
+ """Returns a data collector for off-policy sota-implementations.
981
+
982
+ Args:
983
+ make_env (Callable): environment creator
984
+ actor_model_explore (SafeModule): Model instance used for evaluation and exploration update
985
+ cfg (DictConfig): config for creating collector object
986
+ make_env_kwargs (dict): kwargs for the env creator
987
+
988
+ """
989
+ if cfg.collector.async_collection:
990
+ collector_helper = sync_async_collector
991
+ else:
992
+ collector_helper = sync_sync_collector
993
+
994
+ if cfg.collector.multi_step:
995
+ ms = MultiStep(
996
+ gamma=cfg.loss.gamma,
997
+ n_steps=cfg.collector.n_steps_return,
998
+ )
999
+ else:
1000
+ ms = None
1001
+
1002
+ env_kwargs = {}
1003
+ if make_env_kwargs is not None and isinstance(make_env_kwargs, dict):
1004
+ env_kwargs.update(make_env_kwargs)
1005
+ elif make_env_kwargs is not None:
1006
+ env_kwargs = make_env_kwargs
1007
+ if cfg.collector.device in ("", None):
1008
+ cfg.collector.device = "cpu" if not torch.cuda.is_available() else "cuda:0"
1009
+ else:
1010
+ cfg.collector.device = (
1011
+ cfg.collector.device
1012
+ if len(cfg.collector.device) > 1
1013
+ else cfg.collector.device[0]
1014
+ )
1015
+ collector_helper_kwargs = {
1016
+ "env_fns": make_env,
1017
+ "env_kwargs": env_kwargs,
1018
+ "policy": actor_model_explore,
1019
+ "max_frames_per_traj": cfg.collector.max_frames_per_traj,
1020
+ "frames_per_batch": cfg.collector.frames_per_batch,
1021
+ "total_frames": cfg.collector.total_frames,
1022
+ "postproc": ms,
1023
+ "num_env_per_collector": 1,
1024
+ # we already took care of building the make_parallel_env function
1025
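+ # ceil(num_workers / env_per_collector), computed via negated floor division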
+ "num_collectors": -cfg.num_workers // -cfg.collector.env_per_collector,
1026
+ "device": cfg.collector.device,
1027
+ "init_random_frames": cfg.collector.init_random_frames,
1028
+ "split_trajs": True,
1029
+ # trajectories must be separated if multi-step is used
1030
+ }
1031
+
1032
+ collector = collector_helper(**collector_helper_kwargs)
1033
+ collector.set_seed(cfg.seed)
1034
+ return collector
1035
+
1036
+
1037
+ def make_replay_buffer(
1038
+ device: DEVICE_TYPING, cfg: DictConfig # noqa: F821
1039
+ ) -> ReplayBuffer: # noqa: F821
1040
+ """Builds a replay buffer using the config built from ReplayArgsConfig."""
1041
+ device = torch.device(device)
1042
+ if not cfg.buffer.prb:
1043
+ sampler = RandomSampler()
1044
+ else:
1045
+ sampler = PrioritizedSampler(
1046
+ max_capacity=cfg.buffer.size,
1047
+ alpha=0.7,
1048
+ beta=0.5,
1049
+ )
1050
+ buffer = TensorDictReplayBuffer(
1051
+ storage=LazyMemmapStorage(
1052
+ cfg.buffer.size,
1053
+ scratch_dir=cfg.buffer.scratch_dir,
1054
+ ),
1055
+ sampler=sampler,
1056
+ pin_memory=device != torch.device("cpu"),
1057
+ prefetch=cfg.buffer.prefetch,
1058
+ batch_size=cfg.buffer.batch_size,
1059
+ )
1060
+ return buffer