torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/dreamer/dreamer_utils.py
@@ -0,0 +1,1107 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import functools
+ import tempfile
+ from contextlib import nullcontext
+
+ import torch
+ import torch.nn as nn
+ from tensordict import NestedKey
+ from tensordict.nn import (
+     InteractionType,
+     ProbabilisticTensorDictModule,
+     ProbabilisticTensorDictSequential,
+     TensorDictModule,
+     TensorDictSequential,
+ )
+ from torchrl import logger as torchrl_logger
+ from torchrl._utils import set_profiling_enabled
+ from torchrl.collectors import MultiCollector
+
+ from torchrl.data import (
+     Composite,
+     LazyMemmapStorage,
+     SliceSampler,
+     TensorDictReplayBuffer,
+     Unbounded,
+ )
+
+ from torchrl.envs import (
+     Compose,
+     DMControlEnv,
+     DoubleToFloat,
+     DreamerDecoder,
+     DreamerEnv,
+     EnvCreator,
+     ExcludeTransform,
+     FrameSkipTransform,
+     GrayScale,
+     GymEnv,
+     ParallelEnv,
+     RenameTransform,
+     Resize,
+     RewardSum,
+     set_gym_backend,
+     StepCounter,
+     TensorDictPrimer,
+     ToTensorImage,
+     TransformedEnv,
+ )
+ from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
+ from torchrl.modules import (
+     AdditiveGaussianModule,
+     DreamerActor,
+     IndependentNormal,
+     MLP,
+     ObsDecoder,
+     ObsEncoder,
+     RSSMPosterior,
+     RSSMPrior,
+     RSSMRollout,
+     SafeModule,
+     SafeProbabilisticModule,
+     SafeProbabilisticTensorDictSequential,
+     SafeSequential,
+     TanhNormal,
+     WorldModelWrapper,
+ )
+ from torchrl.record import VideoRecorder
+
+
+ def allocate_collector_devices(
+     num_collectors: int, training_device: torch.device
+ ) -> list[torch.device]:
+     """Allocate CUDA devices for collectors, reserving cuda:0 for training.
+
+     Device allocation strategy:
+     - Training always uses cuda:0
+     - Collectors use cuda:1, cuda:2, ..., cuda:N-1 if available
+     - If only 1 CUDA device, colocate training and inference on cuda:0
+     - If num_collectors >= num_cuda_devices, raise an exception
+
+     Args:
+         num_collectors: Number of collector workers requested
+         training_device: The device used for training (determines if CUDA is used)
+
+     Returns:
+         List of devices for each collector worker
+
+     Raises:
+         ValueError: If num_collectors >= num_cuda_devices (no device left for training)
+     """
+     if training_device.type != "cuda":
+         # CPU training: all collectors on CPU
+         return [torch.device("cpu")] * num_collectors
+
+     num_cuda_devices = torch.cuda.device_count()
+
+     if num_cuda_devices == 0:
+         # No CUDA devices available, fall back to CPU
+         return [torch.device("cpu")] * num_collectors
+
+     if num_cuda_devices == 1:
+         # Single GPU: colocate training and inference
+         torchrl_logger.info(
+             f"Single CUDA device available. Colocating {num_collectors} collectors "
+             "with training on cuda:0"
+         )
+         return [torch.device("cuda:0")] * num_collectors
+
+     # Multiple GPUs available
+     # Reserve cuda:0 for training, use cuda:1..cuda:N-1 for inference
+     inference_devices = num_cuda_devices - 1  # Devices available for collectors
+
+     if num_collectors > inference_devices:
+         raise ValueError(
+             f"Requested {num_collectors} collectors but only {inference_devices} "
+             f"CUDA devices available for inference (cuda:1 to cuda:{num_cuda_devices - 1}). "
+             f"cuda:0 is reserved for training. Either reduce num_collectors to "
+             f"{inference_devices} or add more GPUs."
+         )
+
+     # Distribute collectors across available inference devices (round-robin)
+     collector_devices = []
+     for i in range(num_collectors):
+         device_idx = (i % inference_devices) + 1  # +1 to skip cuda:0
+         collector_devices.append(torch.device(f"cuda:{device_idx}"))
+
+     device_str = ", ".join(str(d) for d in collector_devices)
+     torchrl_logger.info(
+         f"Allocated {num_collectors} collectors to devices: [{device_str}]. "
+         f"Training on cuda:0."
+     )
+
+     return collector_devices
+
+
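+ # A sketch of the allocation above on a hypothetical 4-GPU machine (cuda:0..cuda:3);
+ # the values shown are what the round-robin logic would produce:
+ #
+ #   >>> allocate_collector_devices(3, torch.device("cuda"))
+ #   [device(type='cuda', index=1), device(type='cuda', index=2), device(type='cuda', index=3)]
+ #   >>> allocate_collector_devices(5, torch.device("cuda"))
+ #   ValueError: Requested 5 collectors but only 3 CUDA devices available for inference ...
+ #   >>> allocate_collector_devices(2, torch.device("cpu"))
+ #   [device(type='cpu'), device(type='cpu')]
+
+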
+ class DreamerProfiler:
+     """Helper class for PyTorch profiling in Dreamer training.
+
+     Encapsulates profiler setup, stepping, and trace export logic.
+
+     Args:
+         cfg: Hydra config with profiling section.
+         device: Training device (used to determine CUDA profiling).
+         pbar: Progress bar to update total when profiling.
+         compile_warmup: Number of eager warmup calls when compiling with warmup;
+             profiling automatically skips these steps (plus the compile step).
+     """
+
+     def __init__(self, cfg, device, pbar=None, *, compile_warmup: int = 0):
+         self.enabled = cfg.profiling.enabled
+         self.cfg = cfg
+         self.total_optim_steps = 0
+         self._profiler = None
+         self._stopped = False
+         self._compile_warmup = compile_warmup
+
+         # Enable detailed profiling instrumentation in torchrl when profiling
+         set_profiling_enabled(self.enabled)
+
+         if not self.enabled:
+             return
+
+         # Override total_optim_steps for profiling runs
+         torchrl_logger.info(
+             f"Profiling enabled: running {cfg.profiling.total_optim_steps} optim steps "
+             f"(skip_first={cfg.profiling.skip_first}, warmup={cfg.profiling.warmup_steps}, "
+             f"active={cfg.profiling.active_steps})"
+         )
+         if pbar is not None:
+             pbar.total = cfg.profiling.total_optim_steps
+
+         # Setup profiler schedule
+         # - skip_first: steps to skip entirely (no profiling)
+         # - warmup: steps to warm up profiler (data discarded)
+         # - active: steps to actually profile (data kept)
+         #
+         # When torch.compile is enabled via compile_with_warmup, the first `compile_warmup`
+         # calls run eagerly and the *next* call typically triggers compilation. Profiling
+         # these steps is usually undesirable because it captures compilation overhead and
+         # non-representative eager execution.
+         #
+         # Therefore we automatically extend skip_first by (compile_warmup + 1) optim steps.
+         extra_skip = self._compile_warmup + 1 if self._compile_warmup else 0
+         skip_first = cfg.profiling.skip_first + extra_skip
+         profiler_schedule = torch.profiler.schedule(
+             skip_first=skip_first,
+             wait=0,
+             warmup=cfg.profiling.warmup_steps,
+             active=cfg.profiling.active_steps,
+             repeat=1,
+         )
+
+         # Determine profiler activities
+         activities = [torch.profiler.ProfilerActivity.CPU]
+         if cfg.profiling.profile_cuda and device.type == "cuda":
+             activities.append(torch.profiler.ProfilerActivity.CUDA)
+
+         self._profiler = torch.profiler.profile(
+             activities=activities,
+             schedule=profiler_schedule,
+             on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler_logs")
+             if not cfg.profiling.trace_file
+             else None,
+             record_shapes=cfg.profiling.record_shapes,
+             profile_memory=cfg.profiling.profile_memory,
+             with_stack=cfg.profiling.with_stack,
+             with_flops=cfg.profiling.with_flops,
+         )
+         self._profiler.start()
+
+     def step(self) -> bool:
+         """Step the profiler and check if profiling is complete.
+
+         Returns:
+             True if profiling is complete and training should exit.
+         """
+         if not self.enabled or self._stopped:
+             return False
+
+         self.total_optim_steps += 1
+         self._profiler.step()
+
+         # Check if we should stop profiling
+         extra_skip = self._compile_warmup + 1 if self._compile_warmup else 0
+         target_steps = (
+             self.cfg.profiling.skip_first
+             + extra_skip
+             + self.cfg.profiling.warmup_steps
+             + self.cfg.profiling.active_steps
+         )
+         if self.total_optim_steps >= target_steps:
+             torchrl_logger.info(
+                 f"Profiling complete after {self.total_optim_steps} optim steps. "
+                 f"Exporting trace to {self.cfg.profiling.trace_file}"
+             )
+             self._profiler.stop()
+             self._stopped = True
+             # Export trace if trace_file is set
+             if self.cfg.profiling.trace_file:
+                 self._profiler.export_chrome_trace(self.cfg.profiling.trace_file)
+             return True
+
+         return False
+
+     def should_exit(self) -> bool:
+         """Check if training loop should exit due to profiling completion."""
+         if not self.enabled:
+             return False
+         extra_skip = self._compile_warmup + 1 if self._compile_warmup else 0
+         target_steps = (
+             self.cfg.profiling.skip_first
+             + extra_skip
+             + self.cfg.profiling.warmup_steps
+             + self.cfg.profiling.active_steps
+         )
+         return self.total_optim_steps >= target_steps
+
+
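+ # How the profiler is meant to be driven from the training loop (a sketch; the
+ # loop body and `collector` are placeholders, not the actual training script):
+ #
+ #   profiler = DreamerProfiler(cfg, device, pbar=pbar, compile_warmup=0)
+ #   for data in collector:
+ #       ...  # world-model / actor / value optimization step
+ #       if profiler.step():  # True once skip_first + warmup + active steps elapsed
+ #           break
+
+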
+ def _make_env(cfg, device, from_pixels=False):
+     lib = cfg.env.backend
+     if lib in ("gym", "gymnasium"):
+         with set_gym_backend(lib):
+             env = GymEnv(
+                 cfg.env.name,
+                 device=device,
+                 from_pixels=cfg.env.from_pixels or from_pixels,
+                 pixels_only=cfg.env.from_pixels,
+             )
+         # Gym doesn't support native frame_skip, apply transform inside worker
+         if cfg.env.frame_skip > 1:
+             env = TransformedEnv(env, FrameSkipTransform(cfg.env.frame_skip))
+     elif lib == "dm_control":
+         env = DMControlEnv(
+             cfg.env.name,
+             cfg.env.task,
+             from_pixels=cfg.env.from_pixels or from_pixels,
+             pixels_only=cfg.env.from_pixels,
+             device=device,
+             frame_skip=cfg.env.frame_skip,  # Native frame skip inside worker
+         )
+     else:
+         raise NotImplementedError(f"Unknown lib {lib}.")
+     default_dict = {
+         "state": Unbounded(shape=(cfg.networks.state_dim,)),
+         "belief": Unbounded(shape=(cfg.networks.rssm_hidden_dim,)),
+     }
+     env = env.append_transform(
+         TensorDictPrimer(random=False, default_value=0, **default_dict)
+     )
+     return env
+
+
+ def transform_env(cfg, env):
+     if not isinstance(env, TransformedEnv):
+         env = TransformedEnv(env)
+     if cfg.env.from_pixels:
+         # transforms pixels from 0-255 to 0-1 (uint8 to float32)
+         env.append_transform(
+             RenameTransform(in_keys=["pixels"], out_keys=["pixels_int"])
+         )
+         env.append_transform(
+             ToTensorImage(from_int=True, in_keys=["pixels_int"], out_keys=["pixels"])
+         )
+         if cfg.env.grayscale:
+             env.append_transform(GrayScale())
+
+         image_size = cfg.env.image_size
+         env.append_transform(Resize(image_size, image_size))
+
+     env.append_transform(DoubleToFloat())
+     env.append_transform(RewardSum())
+     # Note: FrameSkipTransform is now applied inside workers (in _make_env) to avoid
+     # extra IPC round-trips. DMControl uses native frame_skip, Gym uses the transform.
+     env.append_transform(StepCounter(cfg.env.horizon))
+
+     return env
+
+
+ def make_environments(cfg, parallel_envs=1, logger=None):
+     """Make environments for training and evaluation.
+
+     Returns:
+         train_env_factory: A callable that creates a training environment (for MultiCollector)
+         eval_env: The evaluation environment instance
+     """
+
+     def train_env_factory():
+         """Factory function for creating training environments."""
+         func = functools.partial(
+             _make_env, cfg=cfg, device=_default_device(cfg.env.device)
+         )
+         train_env = ParallelEnv(
+             parallel_envs,
+             EnvCreator(func),
+             serial_for_single=True,
+         )
+         train_env = transform_env(cfg, train_env)
+         train_env.set_seed(cfg.env.seed)
+         return train_env
+
+     # Create eval env directly (not a factory)
+     func = functools.partial(
+         _make_env,
+         cfg=cfg,
+         device=_default_device(cfg.env.device),
+         from_pixels=cfg.logger.video,
+     )
+     eval_env = ParallelEnv(
+         1,
+         EnvCreator(func),
+         serial_for_single=True,
+     )
+     eval_env = transform_env(cfg, eval_env)
+     eval_env.set_seed(cfg.env.seed + 1)
+     if cfg.logger.video:
+         eval_env.insert_transform(
+             0,
+             VideoRecorder(
+                 logger,
+                 tag="eval/video",
+                 in_keys=["pixels"],
+                 skip=cfg.logger.video_skip,
+             ),
+         )
+
+     # Check specs on a temporary train env
+     temp_train_env = train_env_factory()
+     check_env_specs(temp_train_env)
+     temp_train_env.close()
+     del temp_train_env
+
+     check_env_specs(eval_env)
+     return train_env_factory, eval_env
+
+
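+ # Typical call-site wiring (a sketch; `policy` and `training_device` come from
+ # the surrounding training script):
+ #
+ #   train_env_factory, eval_env = make_environments(cfg, parallel_envs=1, logger=logger)
+ #   collector = make_collector(cfg, train_env_factory, policy, training_device)
+
+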
+ def dump_video(module, step: int | None = None):
+     """Dump video from VideoRecorder transforms.
+
+     Args:
+         module: The transform module to check.
+         step: Optional step to log the video at. If not provided,
+             the VideoRecorder uses its internal counter.
+     """
+     if isinstance(module, VideoRecorder):
+         module.dump(step=step)
+
+
+ def _compute_encoder_output_size(image_size, channels=32, num_layers=4):
+     """Compute the flattened output size of ObsEncoder."""
+     # Compute spatial size after each conv layer (kernel=4, stride=2)
+     size = image_size
+     for _ in range(num_layers):
+         size = (size - 4) // 2 + 1
+     # Final channels = channels * 2^(num_layers-1)
+     final_channels = channels * (2 ** (num_layers - 1))
+     return final_channels * size * size
+
+
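+ # Worked example with image_size=64 (a common Dreamer setting): the spatial size
+ # shrinks 64 -> 31 -> 14 -> 6 -> 2 under (size - 4) // 2 + 1, the final layer has
+ # 32 * 2**3 = 256 channels, and the flattened embedding is 256 * 2 * 2 = 1024 --
+ # matching the obs_embed_dim used by the vector-observation MLP encoder in
+ # make_dreamer below.
+
+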
+ def make_dreamer(
+     cfg,
+     device,
+     action_key: str = "action",
+     value_key: str = "state_value",
+     use_decoder_in_env: bool = False,
+     compile: bool = True,
+     logger=None,
+ ):
+     test_env = _make_env(cfg, device="cpu")
+     test_env = transform_env(cfg, test_env)
+
+     # Get dimensions for explicit module instantiation (avoids lazy modules)
+     state_dim = cfg.networks.state_dim
+     rssm_hidden_dim = cfg.networks.rssm_hidden_dim
+     action_dim = test_env.action_spec.shape[-1]
+
+     # Make encoder and decoder
+     if cfg.env.from_pixels:
+         # Determine input channels (1 for grayscale, 3 for RGB)
+         in_channels = 1 if cfg.env.grayscale else 3
+         image_size = cfg.env.image_size
+
+         # Compute encoder output size for explicit posterior input
+         obs_embed_dim = _compute_encoder_output_size(
+             image_size, channels=32, num_layers=4
+         )
+
+         encoder = ObsEncoder(in_channels=in_channels, device=device)
+         decoder = ObsDecoder(latent_dim=state_dim + rssm_hidden_dim, device=device)
+
+         observation_in_key = "pixels"
+         observation_out_key = "reco_pixels"
+     else:
+         obs_embed_dim = 1024  # MLP output size
+         encoder = MLP(
+             out_features=obs_embed_dim,
+             depth=2,
+             num_cells=cfg.networks.hidden_dim,
+             activation_class=get_activation(cfg.networks.activation),
+             device=device,
+         )
+         decoder = MLP(
+             out_features=test_env.observation_spec["observation"].shape[-1],
+             depth=2,
+             num_cells=cfg.networks.hidden_dim,
+             activation_class=get_activation(cfg.networks.activation),
+             device=device,
+         )
+
+         observation_in_key = "observation"
+         observation_out_key = "reco_observation"
+
+     # Make RSSM with explicit input sizes (no lazy modules)
+     rssm_prior = RSSMPrior(
+         hidden_dim=rssm_hidden_dim,
+         rnn_hidden_dim=rssm_hidden_dim,
+         state_dim=state_dim,
+         action_spec=test_env.action_spec,
+         action_dim=action_dim,
+         device=device,
+     )
+     rssm_posterior = RSSMPosterior(
+         hidden_dim=rssm_hidden_dim,
+         state_dim=state_dim,
+         rnn_hidden_dim=rssm_hidden_dim,
+         obs_embed_dim=obs_embed_dim,
+         device=device,
+     )
+
+     # When use_scan=True or rssm_rollout.compile=True, replace C++ GRU with Python-based GRU
+     # for torch.compile compatibility. The C++ GRU (cuBLAS) cannot be traced by torch.compile.
+     if cfg.networks.use_scan or cfg.networks.rssm_rollout.compile:
+         from torchrl.modules.tensordict_module.rnn import GRUCell as PythonGRUCell
+
+         old_rnn = rssm_prior.rnn
+         python_rnn = PythonGRUCell(
+             old_rnn.input_size, old_rnn.hidden_size, device=device
+         )
+         python_rnn.load_state_dict(old_rnn.state_dict())
+         rssm_prior.rnn = python_rnn
+         torchrl_logger.info(
+             "Switched RSSMPrior to Python-based GRU for torch.compile compatibility"
+         )
+
+     # Make reward module
+     reward_module = MLP(
+         out_features=1,
+         depth=2,
+         num_cells=cfg.networks.hidden_dim,
+         activation_class=get_activation(cfg.networks.activation),
+         device=device,
+     )
+
+     # Make combined world model (modules already on device)
+     world_model = _dreamer_make_world_model(
+         encoder,
+         decoder,
+         rssm_prior,
+         rssm_posterior,
+         reward_module,
+         observation_in_key=observation_in_key,
+         observation_out_key=observation_out_key,
+         use_scan=cfg.networks.use_scan,
+         rssm_rollout_compile=cfg.networks.rssm_rollout.compile,
+         rssm_rollout_compile_backend=cfg.networks.rssm_rollout.compile_backend,
+         rssm_rollout_compile_mode=cfg.networks.rssm_rollout.compile_mode,
+     )
+
+     # Initialize world model (already on device)
+     with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+         tensordict = (
+             test_env.rollout(5, auto_cast_to_device=True).unsqueeze(-1).to(device)
+         )
+         tensordict = tensordict.to_tensordict()
+         world_model(tensordict)
+
+     # Create model-based environment
+     model_based_env = _dreamer_make_mbenv(
+         reward_module=reward_module,
+         rssm_prior=rssm_prior,
+         decoder=decoder,
+         observation_out_key=observation_out_key,
+         test_env=test_env,
+         use_decoder_in_env=use_decoder_in_env,
+         state_dim=cfg.networks.state_dim,
+         rssm_hidden_dim=cfg.networks.rssm_hidden_dim,
+     )
+
+     # def detach_state_and_belief(data):
+     #     data.set("state", data.get("state").detach())
+     #     data.set("belief", data.get("belief").detach())
+     #     return data
+     #
+     # model_based_env = model_based_env.append_transform(detach_state_and_belief)
+     check_env_specs(model_based_env)
+
+     # Make actor (modules already on device)
+     actor_simulator, actor_realworld = _dreamer_make_actors(
+         encoder=encoder,
+         observation_in_key=observation_in_key,
+         rssm_prior=rssm_prior,
+         rssm_posterior=rssm_posterior,
+         mlp_num_units=cfg.networks.hidden_dim,
+         activation=get_activation(cfg.networks.activation),
+         action_key=action_key,
+         test_env=test_env,
+         device=device,
+     )
+     # Exploration noise to be added to the actor_realworld
+     actor_realworld = TensorDictSequential(
+         actor_realworld,
+         AdditiveGaussianModule(
+             spec=test_env.action_spec,
+             sigma_init=1.0,
+             sigma_end=1.0,
+             annealing_num_steps=1,
+             mean=0.0,
+             std=cfg.networks.exploration_noise,
+             device=device,
+         ),
+     )
+
+     # Make Critic (on device)
+     value_model = _dreamer_make_value_model(
+         hidden_dim=cfg.networks.hidden_dim,
+         activation=cfg.networks.activation,
+         value_key=value_key,
+         device=device,
+     )
+
+     # Move model_based_env to device (it contains references to modules already on device)
+     model_based_env.to(device)
+
+     # Initialize model-based environment, actor and critic
+     with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+         tensordict = (
+             model_based_env.fake_tensordict().unsqueeze(-1).to(value_model.device)
+         )
+         tensordict = actor_simulator(tensordict)
+         value_model(tensordict)
+
+     if cfg.logger.video:
+         model_based_env_eval = model_based_env.append_transform(DreamerDecoder())
+
+         def float_to_int(data):
+             reco_pixels_float = data.get("reco_pixels")
+             reco_pixels = (reco_pixels_float * 255).floor()
+             # assert (reco_pixels < 256).all() and (reco_pixels > 0).all(), (reco_pixels.min(), reco_pixels.max())
+             reco_pixels = reco_pixels.to(torch.uint8)
+             data.set("reco_pixels_float", reco_pixels_float)
+             return data.set("reco_pixels", reco_pixels)
+
+         model_based_env_eval.append_transform(float_to_int)
+         model_based_env_eval.append_transform(
+             VideoRecorder(
+                 logger=logger,
+                 tag="eval/simulated_video",
+                 in_keys=["reco_pixels"],
+                 skip=cfg.logger.video_skip,
+             )
+         )
+     else:
+         model_based_env_eval = None
+     return (
+         world_model,
+         model_based_env,
+         model_based_env_eval,
+         actor_simulator,
+         value_model,
+         actor_realworld,
+     )
+
+
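+ # The tuple above is unpacked at the call site roughly as (a sketch):
+ #
+ #   (world_model, model_based_env, model_based_env_eval,
+ #    actor_simulator, value_model, actor_realworld) = make_dreamer(cfg, device, logger=logger)
+
+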
+ def make_collector(
+     cfg,
+     train_env_factory,
+     actor_model_explore,
+     training_device: torch.device,
+     replay_buffer=None,
+     storage_transform=None,
+     track_policy_version=False,
+ ):
+     """Make async multi-collector for parallel data collection.
+
+     Args:
+         cfg: Configuration object
+         train_env_factory: A callable that creates a training environment
+         actor_model_explore: The exploration policy
+         training_device: Device used for training (used to allocate collector devices)
+         replay_buffer: Optional replay buffer for true async collection with start()
+         storage_transform: Optional transform to apply before storing in buffer
+         track_policy_version: If True, track policy version using integer versioning.
+             Can also be a PolicyVersion instance for custom versioning.
+
+     Returns:
+         MultiCollector in async mode with multiple worker processes
+
+     Device allocation:
+         - If training on CUDA with multiple GPUs: collectors use cuda:1, cuda:2, etc.
+         - If training on CUDA with single GPU: collectors colocate on cuda:0
+         - If training on CPU: collectors use CPU
+     """
+     num_collectors = cfg.collector.num_collectors
+     init_random_frames = (
+         cfg.collector.init_random_frames
+         if not cfg.profiling.enabled
+         else cfg.profiling.collector.init_random_frames_override
+     )
+
+     # Allocate devices for collectors (reserves cuda:0 for training if multi-GPU)
+     collector_devices = allocate_collector_devices(num_collectors, training_device)
+
+     collector = MultiCollector(
+         create_env_fn=[train_env_factory] * num_collectors,
+         policy=actor_model_explore,
+         frames_per_batch=cfg.collector.frames_per_batch,
+         total_frames=-1,  # Run indefinitely until async_shutdown() is called
+         init_random_frames=init_random_frames,
+         policy_device=collector_devices,
+         env_device=collector_devices,  # Match env output device to policy device for CUDA transforms
+         storing_device="cpu",
+         sync=False,  # Async mode for overlapping collection with training
+         update_at_each_batch=False,  # We manually call update_policy_weights_() in training loop
+         replay_buffer=replay_buffer,
+         postproc=storage_transform,
+         track_policy_version=track_policy_version,
+         # Skip fake data initialization - storage handles coordination
+         local_init_rb=True,
+     )
+     collector.set_seed(cfg.env.seed)
+
+     return collector
+
+
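+ # With a replay buffer attached, the collector can run fully in the background
+ # (a sketch; `rb` and `policy` come from the surrounding training script):
+ #
+ #   collector = make_collector(cfg, train_env_factory, policy, training_device, replay_buffer=rb)
+ #   collector.start()            # workers fill `rb` asynchronously
+ #   ...                          # optimize, sampling from `rb`
+ #   collector.async_shutdown()   # stop the never-ending (total_frames=-1) workers
+
+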
+ def make_storage_transform(
+     *,
+     pixel_obs=True,
+     grayscale=True,
+     image_size,
+ ):
+     """Create transforms to be applied at extend-time (once per frame).
+
+     These heavy transforms (ToTensorImage, GrayScale, Resize) are applied once
+     when data is added to the buffer, rather than on every sample.
+     """
+     if not pixel_obs:
+         return None
+
+     storage_transforms = Compose(
+         ExcludeTransform("pixels", ("next", "pixels"), inverse=True),
+         ToTensorImage(
+             in_keys=["pixels_int", ("next", "pixels_int")],
+             out_keys=["pixels", ("next", "pixels")],
+         ),
+     )
+     if grayscale:
+         storage_transforms.append(GrayScale(in_keys=["pixels", ("next", "pixels")]))
+     storage_transforms.append(
+         Resize(image_size, image_size, in_keys=["pixels", ("next", "pixels")])
+     )
+     return storage_transforms
+
+
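+ # The extend-time/sample-time split in practice (a sketch): this transform is
+ # passed to make_collector(..., storage_transform=...) so raw integer frames are
+ # converted once when written to the buffer, while make_replay_buffer() below
+ # only casts sampled batches to the training device.
+ #
+ #   storage_transform = make_storage_transform(
+ #       pixel_obs=cfg.env.from_pixels, grayscale=cfg.env.grayscale, image_size=cfg.env.image_size
+ #   )
+
+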
+ def _to_device(td, device):
+     return td.to(device=device, non_blocking=True)
+
+
+ def make_replay_buffer(
+     *,
+     batch_size,
+     batch_seq_len,
+     buffer_size=1000000,
+     buffer_scratch_dir=None,
+     device=None,
+     prefetch=8,
+     pixel_obs=True,
+     grayscale=True,
+     image_size,
+ ):
+     """Create replay buffer with minimal sample-time transforms.
+
+     Heavy image transforms are expected to be applied at extend-time using
+     make_storage_transform(). Only a lightweight device cast (see _to_device)
+     is applied at sample-time.
+
+     Note: We don't compile the SliceSampler because:
+     1. Sampler operations (index computation) happen on CPU and are already fast
+     2. torch.compile with inductor has bugs with the sampler's vectorized int64 operations
+     """
+     with (
+         tempfile.TemporaryDirectory()
+         if buffer_scratch_dir is None
+         else nullcontext(buffer_scratch_dir)
+     ) as scratch_dir:
+         # Sample-time transforms: only device transfer (fast)
+         sample_transforms = Compose(
+             functools.partial(_to_device, device=device),
+         )
+
+         replay_buffer = TensorDictReplayBuffer(
+             pin_memory=False,
+             prefetch=prefetch,
+             storage=LazyMemmapStorage(
+                 buffer_size,
+                 scratch_dir=scratch_dir,
+                 device="cpu",
+                 ndim=2,
+                 shared_init=True,  # Allow remote processes to initialize storage
+             ),
+             sampler=SliceSampler(
+                 slice_len=batch_seq_len,
+                 strict_length=False,
+                 traj_key=("collector", "traj_ids"),
+                 cache_values=False,  # Disabled for async collection (cache not synced across processes)
+                 # Don't compile the sampler - inductor has C++ codegen bugs for int64 ops
+             ),
+             transform=sample_transforms,
+             batch_size=batch_size,
+         )
+         return replay_buffer
+
+
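+ # Sampling sketch: with ndim=2 storage and a SliceSampler, sample() returns
+ # sub-trajectories of up to batch_seq_len steps (strict_length=False permits
+ # shorter slices at trajectory boundaries), already moved to `device` by the
+ # sample-time transform.
+ #
+ #   rb = make_replay_buffer(batch_size=..., batch_seq_len=..., device=device, image_size=cfg.env.image_size)
+ #   td = rb.sample()
+
+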
+ def _dreamer_make_value_model(
+     hidden_dim: int = 400,
+     activation: str = "elu",
+     value_key: str = "state_value",
+     device=None,
+ ):
+     value_model = MLP(
+         out_features=1,
+         depth=3,
+         num_cells=hidden_dim,
+         activation_class=get_activation(activation),
+         device=device,
+     )
+     value_model = ProbabilisticTensorDictSequential(
+         TensorDictModule(
+             value_model,
+             in_keys=["state", "belief"],
+             out_keys=["loc"],
+         ),
+         ProbabilisticTensorDictModule(
+             in_keys=["loc"],
+             out_keys=[value_key],
+             distribution_class=IndependentNormal,
+             distribution_kwargs={"scale": 1.0, "event_dim": 1},
+         ),
+     )
+
+     return value_model
+
+
+ def _dreamer_make_actors(
+     encoder,
+     observation_in_key,
+     rssm_prior,
+     rssm_posterior,
+     mlp_num_units,
+     activation,
+     action_key,
+     test_env,
+     device=None,
+ ):
+     actor_module = DreamerActor(
+         out_features=test_env.action_spec.shape[-1],
+         depth=3,
+         num_cells=mlp_num_units,
+         activation_class=activation,
+         device=device,
+     )
+     actor_simulator = _dreamer_make_actor_sim(action_key, test_env, actor_module)
+     actor_realworld = _dreamer_make_actor_real(
+         encoder,
+         observation_in_key,
+         rssm_prior,
+         rssm_posterior,
+         actor_module,
+         action_key,
+         test_env,
+     )
+     return actor_simulator, actor_realworld
+
+
+ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
+     actor_simulator = SafeProbabilisticTensorDictSequential(
+         SafeModule(
+             actor_module,
+             in_keys=["state", "belief"],
+             out_keys=["loc", "scale"],
+             spec=Composite(
+                 **{
+                     "loc": Unbounded(
+                         proof_environment.action_spec_unbatched.shape,
+                         device=proof_environment.action_spec_unbatched.device,
+                     ),
+                     "scale": Unbounded(
+                         proof_environment.action_spec_unbatched.shape,
+                         device=proof_environment.action_spec_unbatched.device,
+                     ),
+                 }
+             ),
+         ),
+         SafeProbabilisticModule(
+             in_keys=["loc", "scale"],
+             out_keys=[action_key],
+             default_interaction_type=InteractionType.RANDOM,
+             distribution_class=TanhNormal,
+             distribution_kwargs={"tanh_loc": True},
+             spec=Composite(**{action_key: proof_environment.action_spec_unbatched}),
+         ),
+     )
+     return actor_simulator
+
+
+ def _dreamer_make_actor_real(
+     encoder,
+     observation_in_key,
+     rssm_prior,
+     rssm_posterior,
+     actor_module,
+     action_key,
+     proof_environment,
+ ):
+     # Actor for the real world: interacts with states sampled from the posterior.
+     # Our actor differs from the original paper, where the prior and posterior are
+     # computed first and then acted upon; we found this approach worked better.
+     actor_realworld = SafeSequential(
+         SafeModule(
+             encoder,
+             in_keys=[observation_in_key],
+             out_keys=["encoded_latents"],
+         ),
+         SafeModule(
+             rssm_posterior,
+             in_keys=["belief", "encoded_latents"],
+             out_keys=[
+                 "_",
+                 "_",
+                 "state",
+             ],
+         ),
+         SafeProbabilisticTensorDictSequential(
+             SafeModule(
+                 actor_module,
+                 in_keys=["state", "belief"],
+                 out_keys=["loc", "scale"],
+                 spec=Composite(
+                     **{
+                         "loc": Unbounded(
+                             proof_environment.action_spec_unbatched.shape,
+                         ),
+                         "scale": Unbounded(
+                             proof_environment.action_spec_unbatched.shape,
+                         ),
+                     }
+                 ),
+             ),
+             SafeProbabilisticModule(
+                 in_keys=["loc", "scale"],
+                 out_keys=[action_key],
+                 default_interaction_type=InteractionType.DETERMINISTIC,
+                 distribution_class=TanhNormal,
+                 distribution_kwargs={"tanh_loc": True},
+                 spec=proof_environment.full_action_spec_unbatched.to("cpu"),
+             ),
+         ),
+         SafeModule(
+             rssm_prior,
+             in_keys=["state", "belief", action_key],
+             out_keys=[
+                 "_",
+                 "_",
+                 "_",  # we don't need the prior state
+                 ("next", "belief"),
+             ],
+         ),
+     )
+     return actor_realworld
+
+
+ def _dreamer_make_mbenv(
+     reward_module,
+     rssm_prior,
+     test_env,
+     decoder,
+     observation_out_key: str = "reco_pixels",
+     use_decoder_in_env: bool = False,
+     state_dim: int = 30,
+     rssm_hidden_dim: int = 200,
+ ):
+     # MB environment
+     if use_decoder_in_env:
+         mb_env_obs_decoder = SafeModule(
+             decoder,
+             in_keys=["state", "belief"],
+             out_keys=[observation_out_key],
+         )
+     else:
+         mb_env_obs_decoder = None
+
+     transition_model = SafeSequential(
+         SafeModule(
+             rssm_prior,
+             in_keys=["state", "belief", "action"],
+             out_keys=[
+                 "_",
+                 "_",
+                 "state",
+                 "belief",
+             ],
+         ),
+     )
+
+     reward_model = SafeProbabilisticTensorDictSequential(
+         SafeModule(
+             reward_module,
+             in_keys=["state", "belief"],
+             out_keys=["loc"],
+         ),
+         SafeProbabilisticModule(
+             in_keys=["loc"],
+             out_keys=["reward"],
+             distribution_class=IndependentNormal,
+             distribution_kwargs={"scale": 1.0, "event_dim": 1},
+         ),
+     )
+
+     model_based_env = DreamerEnv(
+         world_model=WorldModelWrapper(
+             transition_model,
+             reward_model,
+         ),
+         prior_shape=torch.Size([state_dim]),
+         belief_shape=torch.Size([rssm_hidden_dim]),
+         obs_decoder=mb_env_obs_decoder,
+     )
+
+     model_based_env.set_specs_from_env(test_env)
+     return model_based_env
+
+
+ def _dreamer_make_world_model(
+     encoder,
+     decoder,
+     rssm_prior,
+     rssm_posterior,
+     reward_module,
+     observation_in_key: NestedKey = "pixels",
+     observation_out_key: NestedKey = "reco_pixels",
+     use_scan: bool = False,
+     rssm_rollout_compile: bool = False,
+     rssm_rollout_compile_backend: str = "inductor",
+     rssm_rollout_compile_mode: str | None = "reduce-overhead",
+ ):
+     # World Model and reward model
+     # Note: in_keys uses dict form with out_to_in_map=True to map function args to tensordict keys.
+     # {"noise": "prior_noise"} means: read "prior_noise" from tensordict, pass as `noise` kwarg.
+     # With strict=False (default), missing noise keys pass None to the module.
+     rssm_rollout = RSSMRollout(
+         TensorDictModule(
+             rssm_prior,
+             in_keys={
+                 "state": "state",
+                 "belief": "belief",
+                 "action": "action",
+                 "noise": "prior_noise",
+             },
+             out_keys=[
+                 ("next", "prior_mean"),
+                 ("next", "prior_std"),
+                 "_",
+                 ("next", "belief"),
+             ],
+             out_to_in_map=True,
+         ),
+         TensorDictModule(
+             rssm_posterior,
+             in_keys={
+                 "belief": ("next", "belief"),
+                 "obs_embedding": ("next", "encoded_latents"),
+                 "noise": "posterior_noise",
+             },
+             out_keys=[
+                 ("next", "posterior_mean"),
+                 ("next", "posterior_std"),
+                 ("next", "state"),
+             ],
+             out_to_in_map=True,
+         ),
+         use_scan=use_scan,
+         compile_step=rssm_rollout_compile,
+         compile_backend=rssm_rollout_compile_backend,
+         compile_mode=rssm_rollout_compile_mode,
+     )
+     event_dim = 3 if observation_out_key == "reco_pixels" else 1  # 3 for RGB
+     decoder = ProbabilisticTensorDictSequential(
+         TensorDictModule(
+             decoder,
+             in_keys=[("next", "state"), ("next", "belief")],
+             out_keys=["loc"],
+         ),
+         ProbabilisticTensorDictModule(
+             in_keys=["loc"],
+             out_keys=[("next", observation_out_key)],
+             distribution_class=IndependentNormal,
+             distribution_kwargs={"scale": 1.0, "event_dim": event_dim},
+         ),
+     )
+
+     transition_model = TensorDictSequential(
+         TensorDictModule(
+             encoder,
+             in_keys=[("next", observation_in_key)],
+             out_keys=[("next", "encoded_latents")],
+         ),
+         rssm_rollout,
+         decoder,
+     )
+
+     reward_model = ProbabilisticTensorDictSequential(
+         TensorDictModule(
+             reward_module,
+             in_keys=[("next", "state"), ("next", "belief")],
+             out_keys=[("next", "loc")],
+         ),
+         ProbabilisticTensorDictModule(
+             in_keys=[("next", "loc")],
+             out_keys=[("next", "reward")],
+             distribution_class=IndependentNormal,
+             distribution_kwargs={"scale": 1.0, "event_dim": 1},
+         ),
+     )
+
+     world_model = WorldModelWrapper(
+         transition_model,
+         reward_model,
+     )
+     return world_model
+
+
+ def log_metrics(logger, metrics, step):
+     for metric_name, metric_value in metrics.items():
+         logger.log_scalar(metric_name, metric_value, step)
+
+
+ def get_activation(name):
+     if name == "relu":
+         return nn.ReLU
+     elif name == "tanh":
+         return nn.Tanh
+     elif name == "leaky_relu":
+         return nn.LeakyReLU
+     elif name == "elu":
+         return nn.ELU
+     else:
+         raise NotImplementedError(f"Unknown activation {name}.")
+
+
+ def _default_device(device=None):
+     if device in ("", None):
+         if torch.cuda.is_available():
+             return torch.device("cuda")
+         return torch.device("cpu")
+     return torch.device(device)
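+
+
+ # Device resolution examples (the first assumes a CUDA machine):
+ #   >>> _default_device(None)
+ #   device(type='cuda')
+ #   >>> _default_device("cpu")
+ #   device(type='cpu')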