torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,845 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import math
8
+ import uuid
9
+ import warnings
10
+ from collections import OrderedDict
11
+ from collections.abc import Sequence
12
+ from copy import copy
13
+
14
+ from typing import Any
15
+
16
+ import torch
17
+ from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
18
+ from tensordict.utils import _zip_strict
19
+ from torch import multiprocessing as mp
20
+ from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
21
+
22
+ from torchrl.envs.common import EnvBase
23
+ from torchrl.envs.transforms.transforms import Compose, ObservationNorm, Transform
24
+
25
+ from torchrl.envs.transforms.utils import _set_missing_tolerance
26
+
27
+
28
+ class VecNormV2(Transform):
29
+ """A class for normalizing vectorized observations and rewards in reinforcement learning environments.
30
+
31
+ `VecNormV2` can operate in either a stateful or stateless mode. In stateful mode, it maintains
32
+ internal statistics (mean and variance) to normalize inputs. In stateless mode, it requires
33
+ external statistics to be provided for normalization.
34
+
35
+ .. note:: This class is designed to be an almost drop-in replacement for :class:`~torchrl.envs.transforms.VecNorm`.
36
+ It should not be constructed directly, but rather with the :class:`~torchrl.envs.transforms.VecNorm`
37
+ transform using the `new_api=True` keyword argument. In v0.10, the :class:`~torchrl.envs.transforms.VecNorm`
38
+ transform will be switched to the new api by default.
39
+
40
+ Stateful vs. Stateless:
41
+ Stateful Mode (`stateful=True`):
42
+
43
+ - Maintains internal statistics (`loc`, `var`, `count`) for normalization.
44
+ - Updates statistics with each call unless frozen.
45
+ - `state_dict` returns the current statistics.
46
+ - `load_state_dict` updates the internal statistics with the provided state.
47
+
48
+ Stateless Mode (`stateful=False`):
49
+
50
+ - Requires external statistics to be provided for normalization.
51
+ - Does not maintain or update internal statistics.
52
+ - `state_dict` returns an empty dictionary.
53
+ - `load_state_dict` does not affect internal state.
54
+
55
+ Args:
56
+ in_keys (Sequence[NestedKey]): The input keys for the data to be normalized.
57
+ out_keys (Sequence[NestedKey] | None): The output keys for the normalized data. Defaults to `in_keys` if
58
+ not provided.
59
+ lock (mp.Lock, optional): A lock for thread safety.
60
+ stateful (bool, optional): Whether the `VecNorm` is stateful. Stateless versions of this
61
+ transform requires the data to be carried within the input/output tensordicts.
62
+ Defaults to `True`.
63
+ decay (float, optional): The decay rate for updating statistics. Defaults to `0.9999`.
64
+ If `decay=1` is used, the normalizing statistics have an infinite memory (each item is weighed
65
+ identically). Lower values weigh recent data more than old ones.
66
+ eps (float, optional): A small value to prevent division by zero. Defaults to `1e-4`.
67
+ shared_data (TensorDictBase | None, optional): Shared data for initialization. Defaults to `None`.
68
+ reduce_batch_dims (bool, optional): If `True`, the batch dimensions are reduced by averaging the data
69
+ before updating the statistics. This is useful when samples are received in batches, as it allows
70
+ the moving average to be computed over the entire batch rather than individual elements. Note that
71
+ this option is only supported in stateful mode (`stateful=True`). Defaults to `False`.
72
+
73
+ Attributes:
74
+ stateful (bool): Indicates whether the VecNormV2 is stateful or stateless.
75
+ lock (mp.Lock): A multiprocessing lock to ensure thread safety when updating statistics.
76
+ decay (float): The decay rate for updating statistics.
77
+ eps (float): A small value to prevent division by zero during normalization.
78
+ frozen (bool): Indicates whether the VecNormV2 is frozen, preventing updates to statistics.
79
+ _cast_int_to_float (bool): Indicates whether integer inputs should be cast to float.
80
+
81
+ Methods:
82
+ freeze(): Freezes the VecNorm, preventing updates to statistics.
83
+ unfreeze(): Unfreezes the VecNorm, allowing updates to statistics.
84
+ frozen_copy(): Returns a frozen copy of the VecNorm.
85
+ clone(): Returns a clone of the VecNorm.
86
+ transform_observation_spec(observation_spec): Transforms the observation specification.
87
+ transform_reward_spec(reward_spec, observation_spec): Transforms the reward specification.
88
+ transform_output_spec(output_spec): Transforms the output specification.
89
+ to_observation_norm(): Converts the VecNorm to an ObservationNorm transform.
90
+ set_extra_state(state): Sets the extra state for the VecNorm.
91
+ get_extra_state(): Gets the extra state of the VecNorm.
92
+ loc: Returns the location (mean) for normalization.
93
+ scale: Returns the scale (standard deviation) for normalization.
94
+ standard_normal: Indicates whether the normalization follows the standard normal distribution.
95
+
96
+ State Dict Behavior:
97
+
98
+ - In stateful mode, `state_dict` returns a dictionary containing the current `loc`, `var`, and `count`.
99
+ These can be used to share the tensors across processes (this method is automatically triggered by
100
+ :class:`~torchrl.envs.VecNorm` to share the VecNorm states across processes).
101
+ - In stateless mode, `state_dict` returns an empty dictionary as no internal state is maintained.
102
+
103
+ Load State Dict Behavior:
104
+
105
+ - In stateful mode, `load_state_dict` updates the internal `loc`, `var`, and `count` with the provided state.
106
+ - In stateless mode, `load_state_dict` does not modify any internal state as there is none to update.
107
+
108
+ .. seealso:: :class:`~torchrl.envs.transforms.VecNorm` for the first version of this transform.
109
+
110
+ Examples:
111
+ >>> import torch
112
+ >>> from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, SerialEnv, VecNormV2
113
+ >>>
114
+ >>> torch.manual_seed(0)
115
+ >>> env = GymEnv("Pendulum-v1")
116
+ >>> env_trsf = env.append_transform(
117
+ >>> VecNormV2(in_keys=["observation", "reward"], out_keys=["observation_norm", "reward_norm"])
118
+ >>> )
119
+ >>> r = env_trsf.rollout(10)
120
+ >>> print("Unnormalized rewards", r["next", "reward"])
121
+ Unnormalized rewards tensor([[ -1.7967],
122
+ [ -2.1238],
123
+ [ -2.5911],
124
+ [ -3.5275],
125
+ [ -4.8585],
126
+ [ -6.5028],
127
+ [ -8.2505],
128
+ [-10.3169],
129
+ [-12.1332],
130
+ [-13.1235]])
131
+ >>> print("Normalized rewards", r["next", "reward_norm"])
132
+ Normalized rewards tensor([[-1.6596e-04],
133
+ [-8.3072e-02],
134
+ [-1.9170e-01],
135
+ [-3.9255e-01],
136
+ [-5.9131e-01],
137
+ [-7.4671e-01],
138
+ [-8.3760e-01],
139
+ [-9.2058e-01],
140
+ [-9.3484e-01],
141
+ [-8.6185e-01]])
142
+ >>> # Aggregate values when using batched envs
143
+ >>> env = SerialEnv(2, [lambda: GymEnv("Pendulum-v1")] * 2)
144
+ >>> env_trsf = env.append_transform(
145
+ >>> VecNormV2(
146
+ >>> in_keys=["observation", "reward"],
147
+ >>> out_keys=["observation_norm", "reward_norm"],
148
+ >>> # Use reduce_batch_dims=True to aggregate values across batch elements
149
+ >>> reduce_batch_dims=True, )
150
+ >>> )
151
+ >>> r = env_trsf.rollout(10)
152
+ >>> print("Unnormalized rewards", r["next", "reward"])
153
+ Unnormalized rewards tensor([[[-0.1456],
154
+ [-0.1862],
155
+ [-0.2053],
156
+ [-0.2605],
157
+ [-0.4046],
158
+ [-0.5185],
159
+ [-0.8023],
160
+ [-1.1364],
161
+ [-1.6183],
162
+ [-2.5406]],
163
+
164
+ [[-0.0920],
165
+ [-0.1492],
166
+ [-0.2702],
167
+ [-0.3917],
168
+ [-0.5001],
169
+ [-0.7947],
170
+ [-1.0160],
171
+ [-1.3347],
172
+ [-1.9082],
173
+ [-2.9679]]])
174
+ >>> print("Normalized rewards", r["next", "reward_norm"])
175
+ Normalized rewards tensor([[[-0.2199],
176
+ [-0.2918],
177
+ [-0.1668],
178
+ [-0.2083],
179
+ [-0.4981],
180
+ [-0.5046],
181
+ [-0.7950],
182
+ [-0.9791],
183
+ [-1.1484],
184
+ [-1.4182]],
185
+
186
+ [[ 0.2201],
187
+ [-0.0403],
188
+ [-0.5206],
189
+ [-0.7791],
190
+ [-0.8282],
191
+ [-1.2306],
192
+ [-1.2279],
193
+ [-1.2907],
194
+ [-1.4929],
195
+ [-1.7793]]])
196
+ >>> print("Loc / scale", env_trsf.transform.loc["reward"], env_trsf.transform.scale["reward"])
197
+ Loc / scale tensor([-0.8626]) tensor([1.1832])
198
+ >>>
199
+ >>> # Share values between workers
200
+ >>> def make_env():
201
+ ... env = GymEnv("Pendulum-v1")
202
+ ... env_trsf = env.append_transform(
203
+ ... VecNormV2(in_keys=["observation", "reward"], out_keys=["observation_norm", "reward_norm"])
204
+ ... )
205
+ ... return env_trsf
206
+ ...
207
+ ...
208
+ >>> if __name__ == "__main__":
209
+ ... # EnvCreator will share the loc/scale vals
210
+ ... make_env = EnvCreator(make_env)
211
+ ... # Create a local env to track the loc/scale
212
+ ... local_env = make_env()
213
+ ... env = ParallelEnv(2, [make_env] * 2)
214
+ ... r = env.rollout(10)
215
+ ... # Non-zero loc and scale testify that the sub-envs share their summary stats with us
216
+ ... print("Remotely updated loc / scale", local_env.transform.loc["reward"], local_env.transform.scale["reward"])
217
+ Remotely updated loc / scale tensor([-0.4307]) tensor([0.9613])
218
+ ... env.close()
219
+
220
+ """
221
+
222
+ # TODO:
223
+ # - test 2 different vecnorms, one for reward one for obs and that they don't collide
224
+ # - test that collision is spotted
225
+ # - customize the vecnorm keys in stateless
226
+ def __init__(
227
+ self,
228
+ in_keys: Sequence[NestedKey],
229
+ out_keys: Sequence[NestedKey] | None = None,
230
+ *,
231
+ lock: mp.Lock = None,
232
+ stateful: bool = True,
233
+ decay: float = 0.9999,
234
+ eps: float = 1e-4,
235
+ shared_data: TensorDictBase | None = None,
236
+ reduce_batch_dims: bool = False,
237
+ ) -> None:
238
+ self.stateful = stateful
239
+ if lock is None:
240
+ lock = mp.Lock()
241
+ if out_keys is None:
242
+ out_keys = copy(in_keys)
243
+ super().__init__(in_keys=in_keys, out_keys=out_keys)
244
+
245
+ self.lock = lock
246
+ self.decay = decay
247
+ self.eps = eps
248
+ self.frozen = False
249
+ self._cast_int_to_float = False
250
+ if self.stateful:
251
+ self.register_buffer("initialized", torch.zeros((), dtype=torch.bool))
252
+ if shared_data:
253
+ self._loc = shared_data["loc"]
254
+ self._var = shared_data["var"]
255
+ self._count = shared_data["count"]
256
+ else:
257
+ self._loc = None
258
+ self._var = None
259
+ self._count = None
260
+ else:
261
+ self.initialized = False
262
+ if shared_data:
263
+ # FIXME
264
+ raise NotImplementedError
265
+ if reduce_batch_dims and not stateful:
266
+ raise RuntimeError(
267
+ "reduce_batch_dims=True and stateful=False are not supported."
268
+ )
269
+ self.reduce_batch_dims = reduce_batch_dims
270
+
271
+ @property
272
+ def in_keys(self) -> Sequence[NestedKey]:
273
+ in_keys = self._in_keys
274
+ if not self.stateful:
275
+ in_keys = in_keys + [
276
+ f"{self.prefix}_count",
277
+ f"{self.prefix}_loc",
278
+ f"{self.prefix}_var",
279
+ ]
280
+ return in_keys
281
+
282
+ @in_keys.setter
283
+ def in_keys(self, in_keys: Sequence[NestedKey]):
284
+ self._in_keys = in_keys
285
+
def set_container(self, container: Transform | EnvBase) -> None:
    """Attach the transform to its container and eagerly set up state.

    In stateful mode, the running statistics are initialized from the parent
    env's fake tensordict; in stateless mode only the key prefix is resolved
    against the parent's output spec.
    """
    super().set_container(container)
    if self.stateful:
        parent = getattr(self, "parent", None)
        if parent is not None and isinstance(parent, EnvBase):
            if not parent.batch_locked:
                warnings.warn(
                    f"Support of {type(self).__name__} for unbatched container is experimental and subject to change."
                )
            if parent.batch_size:
                warnings.warn(
                    f"Support of {type(self).__name__} for containers with non-empty batch-size is experimental and subject to change."
                )
            # init
            data = parent.fake_tensordict().get("next")
            self._maybe_stateful_init(data)
    else:
        parent = getattr(self, "parent", None)
        if parent is not None and isinstance(parent, EnvBase):
            # Resolve (and cache) a collision-free prefix for the stat keys.
            self._make_prefix(parent.output_spec)
def freeze(self) -> VecNormV2:
    """Stop updating the running statistics on subsequent calls.

    See :meth:`~.unfreeze`.
    """
    self.frozen = True
    return self
def unfreeze(self) -> VecNormV2:
    """Resume updating the running statistics.

    See :meth:`~.freeze`.
    """
    self.frozen = False
    return self
def frozen_copy(self):
    """Returns a copy of the Transform that keeps track of the stats but does not update them.

    Raises:
        RuntimeError: if the transform is stateless, or if it has not been
            initialized yet (there are no statistics to copy).
    """
    if not self.stateful:
        # Fix: error message previously read "statelss".
        raise RuntimeError("Cannot create a frozen copy of a stateless VecNorm.")
    if self._loc is None:
        raise RuntimeError(
            "Make sure the VecNorm has been initialized before creating a frozen copy."
        )
    clone = self.clone()
    if self.stateful:
        # replace values
        clone._var = self._var.clone()
        clone._loc = self._loc.clone()
        clone._count = self._count.clone()
    # freeze
    return clone.freeze()
def clone(self) -> VecNormV2:
    """Return a copy of the transform, duplicating its running statistics."""
    other = super().clone()
    if self.stateful:
        # `initialized` is a registered buffer: re-register a detached clone
        # so the copy does not share storage with `self`.
        delattr(other, "initialized")
        other.register_buffer("initialized", self.initialized.clone())
        if self._loc is not None:
            other.initialized.fill_(True)
            other._loc = self._loc.clone()
            other._var = self._var.clone()
            other._count = self._count.clone()
    return other
def _apply(self, fn, recurse=True):
    """Apply device/dtype transformation to the module and its TensorDict state.

    This method is called internally by PyTorch when using .to(), .cuda(), .cpu(), etc.
    In stateful mode, we manually apply the transformation to _loc, _var, and _count
    since they are TensorDict instances, not registered buffers.
    """
    super()._apply(fn, recurse=recurse)

    if self.stateful and self._loc is not None:
        self._loc = self._loc.apply(fn)
        self._var = self._var.apply(fn)
        # Move _count to same device as _loc, but preserve its int dtype.
        # We extract the device from an actual leaf tensor because TensorDict.device
        # can be stale after .apply(fn) moves the leaves.
        leaf_tensor = next(
            t for t in self._loc.values(True, True) if isinstance(t, torch.Tensor)
        )
        # Fix: the original if/else on the type of _count had two byte-identical
        # branches; both TensorDictBase and plain tensors accept .to(device=...),
        # so a single call suffices.
        self._count = self._count.to(device=leaf_tensor.device)

    return self
def _reset(
    self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
    """Update stats with, and normalize, the reset data via :meth:`_step`.

    Missing keys are tolerated during reset (e.g. rewards are absent), hence
    the temporary missing-tolerance context.
    """
    # TODO: remove this decorator when trackers are in data
    with _set_missing_tolerance(self, True):
        # Fix: a trailing `return tensordict_reset` after this statement was
        # unreachable and has been removed.
        return self._step(tensordict_reset, tensordict_reset)
def _step(
    self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
    """Update the running statistics with ``next_tensordict`` and write the normalized values back.

    The full update/normalize sequence runs under ``self.lock`` so that
    workers sharing the statistics do not interleave their updates.
    """
    if self.lock is not None:
        self.lock.acquire()
    try:
        if self.stateful:
            self._maybe_stateful_init(next_tensordict)
            next_tensordict_select = next_tensordict.select(
                *self.in_keys, strict=not self.missing_tolerance
            )
            if self.missing_tolerance and next_tensordict_select.is_empty():
                # Nothing to normalize at this step (e.g. keys absent on reset).
                return next_tensordict
            self._stateful_update(next_tensordict_select)
            next_tensordict_norm = self._stateful_norm(next_tensordict_select)
        else:
            self._maybe_stateless_init(tensordict)
            next_tensordict_select = next_tensordict.select(
                *self._in_keys_safe, strict=not self.missing_tolerance
            )
            if self.missing_tolerance and next_tensordict_select.is_empty():
                return next_tensordict
            # Stateless mode: the statistics travel with the input tensordict.
            loc = tensordict[f"{self.prefix}_loc"]
            var = tensordict[f"{self.prefix}_var"]
            count = tensordict[f"{self.prefix}_count"]

            loc, var, count = self._stateless_update(
                next_tensordict_select, loc, var, count
            )
            next_tensordict_norm = self._stateless_norm(
                next_tensordict_select, loc, var, count
            )
            # updates have been done in-place, we're good
            next_tensordict_norm.set(f"{self.prefix}_loc", loc)
            next_tensordict_norm.set(f"{self.prefix}_var", var)
            next_tensordict_norm.set(f"{self.prefix}_count", count)

        next_tensordict.update(next_tensordict_norm)
    finally:
        # Release the lock on every path, including the early returns above.
        if self.lock is not None:
            self.lock.release()

    return next_tensordict
431
+ def _maybe_cast_to_float(self, data):
432
+ if self._cast_int_to_float:
433
+ dtype = torch.get_default_dtype()
434
+ data = data.apply(
435
+ lambda x: x.to(dtype) if not x.dtype.is_floating_point else x
436
+ )
437
+ return data
438
+
439
+ @staticmethod
440
+ def _maybe_make_float(x):
441
+ if x.dtype.is_floating_point:
442
+ return x
443
+ return x.to(torch.get_default_dtype())
444
+
def _maybe_stateful_init(self, data):
    """Lazily create ``_loc``/``_var``/``_count`` the first time data is seen."""
    if not self.initialized:
        self.initialized.copy_(True)
        # Some keys (specifically rewards) may be missing, but we can use the
        # specs for them
        try:
            data_select = data.select(*self._in_keys_safe, strict=True)
        except KeyError:
            # Fill the missing entries from zeroed spec values, then overlay
            # whatever `data` actually provides.
            data_select = self.parent.full_observation_spec.zero().update(
                self.parent.full_reward_spec.zero()
            )
            data_select = data_select.update(data)
            data_select = data_select.select(*self._in_keys_safe, strict=True)
        if self.reduce_batch_dims and data_select.ndim:
            # collapse the batch-dims
            data_select = data_select.mean(dim=tuple(range(data.ndim)))
        # For the count, we must use a TD because some keys (eg Reward) may be missing at some steps (eg, reset)
        # We use mean() to eliminate all dims - since it's local we don't need to expand the shape
        count = (
            torch.zeros_like(data_select, dtype=torch.float32)
            .mean()
            .to(torch.int64)
        )
        # create loc
        loc = torch.zeros_like(data_select.apply(self._maybe_make_float))
        # create var
        var = torch.zeros_like(data_select.apply(self._maybe_make_float))
        self._loc = loc
        self._var = var
        self._count = count
476
+ @property
477
+ def _in_keys_safe(self):
478
+ if not self.stateful:
479
+ return self.in_keys[:-3]
480
+ return self.in_keys
481
+
def _norm(self, data, loc, var, count):
    """Return ``(data - loc) / scale`` with EMA bias-correction applied.

    Note that ``var`` stores a running estimate of E[x^2]; the variance is
    recovered as ``E[x^2] - E[x]^2`` before the square root is taken.
    """
    if self.missing_tolerance:
        # Restrict the stats to the entries actually present in `data`.
        loc = loc.select(*data.keys(True, True))
        var = var.select(*data.keys(True, True))
        count = count.select(*data.keys(True, True))
        if loc.is_empty():
            return data

    if self.decay < 1.0:
        # EMA bias correction: 1 - decay**count, computed via log for stability.
        bias_correction = 1 - (count * math.log(self.decay)).exp()
        bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
    else:
        bias_correction = 1

    var = var - loc.pow(2)
    loc = loc / bias_correction
    var = var / bias_correction

    scale = var.sqrt().clamp_min(self.eps)

    data_update = (data - loc) / scale
    if self.out_keys[: len(self.in_keys)] != self.in_keys:
        # map names
        for in_key, out_key in _zip_strict(self._in_keys_safe, self.out_keys):
            if in_key in data_update:
                data_update.rename_key_(in_key, out_key)
            else:
                # Key absent under missing tolerance: nothing to rename.
                pass
    return data_update
512
+ def _stateful_norm(self, data):
513
+ return self._norm(data, self._loc, self._var, self._count)
514
+
def _stateful_update(self, data):
    """Fold ``data`` into the running moments in-place; no-op when frozen."""
    if self.frozen:
        return
    if self.missing_tolerance:
        # Views restricted to the entries present in `data`; the lerp_ below
        # still mutates the underlying self._loc/_var/_count storage.
        var = self._var.select(*data.keys(True, True))
        loc = self._loc.select(*data.keys(True, True))
        count = self._count.select(*data.keys(True, True))
    else:
        var = self._var
        loc = self._loc
        count = self._count
    data = self._maybe_cast_to_float(data)
    if self.reduce_batch_dims and data.ndim:
        # The naive way to do this would be to convert the data to a list and iterate over it, but (1) that is
        # slow, and (2) it makes the value of the loc/var conditioned on the order we take to iterate over the data.
        # The second approach would be to average the data, but that would mean that having one vecnorm per batched
        # env or one per sub-env will lead to different results as a batch of N elements will actually be
        # considered as a single one.
        # What we go for instead is to average the data (and its squared value) then do the moving average with
        # adapted decay.
        n = data.numel()
        count += n
        data2 = data.pow(2).mean(dim=tuple(range(data.ndim)))
        data_mean = data.mean(dim=tuple(range(data.ndim)))
        if self.decay != 1.0:
            # n samples folded at once => effective decay is decay**n.
            weight = 1 - self.decay**n
        else:
            weight = n / count
    else:
        count += 1
        data2 = data.pow(2)
        data_mean = data
        if self.decay != 1.0:
            weight = 1 - self.decay
        else:
            # decay == 1.0 degenerates to a running arithmetic mean.
            weight = 1 / count
    loc.lerp_(end=data_mean, weight=weight)
    var.lerp_(end=data2, weight=weight)
def _maybe_stateless_init(self, data):
    """Seed zeroed count/loc/var entries into ``data`` on first use (stateless mode)."""
    if not self.initialized or f"{self.prefix}_loc" not in data.keys():
        self.initialized = True
        # select all except vecnorm
        # Some keys (specifically rewards) may be missing, but we can use the
        # specs for them
        try:
            data_select = data.select(*self._in_keys_safe, strict=True)
        except KeyError:
            # Fill missing entries from zeroed specs, then overlay `data`.
            data_select = self.parent.full_observation_spec.zero().update(
                self.parent.full_reward_spec.zero()
            )
            data_select = data_select.update(data)
            data_select = data_select.select(*self._in_keys_safe, strict=True)

        data[f"{self.prefix}_count"] = torch.zeros_like(
            data_select, dtype=torch.int64
        )
        # create loc
        loc = torch.zeros_like(data_select.apply(self._maybe_make_float))
        # create var
        var = torch.zeros_like(data_select.apply(self._maybe_make_float))
        data[f"{self.prefix}_loc"] = loc
        data[f"{self.prefix}_var"] = var
579
+ def _stateless_norm(self, data, loc, var, count):
580
+ data = self._norm(data, loc, var, count)
581
+ return data
582
+
583
+ def _stateless_update(self, data, loc, var, count):
584
+ if self.frozen:
585
+ return loc, var, count
586
+ count = count + 1
587
+ data = self._maybe_cast_to_float(data)
588
+ if self.decay != 1.0:
589
+ weight = 1 - self.decay
590
+ else:
591
+ weight = 1 / count
592
+ loc = loc.lerp(end=data, weight=weight)
593
+ var = var.lerp(end=data.pow(2), weight=weight)
594
+ return loc, var, count
595
+
def transform_observation_spec(self, observation_spec: Composite) -> Composite:
    """Adapt the observation spec to reflect the normalized entries."""
    return self._transform_spec(observation_spec)
def transform_reward_spec(
    self, reward_spec: Composite, observation_spec
) -> Composite:
    """Adapt the reward spec; the observation spec receives the stat entries."""
    return self._transform_spec(reward_spec, observation_spec)
def transform_output_spec(self, output_spec: Composite) -> Composite:
    # This is a copy-paste of the parent method to ensure that we correct the reward spec properly
    """Transform the full output spec, routing the reward spec through the observation spec.

    The observation spec is computed first and passed to
    :meth:`transform_reward_spec` because, in stateless mode, the stat
    entries are hosted by the observation spec.
    """
    output_spec = output_spec.clone()
    observation_spec = self.transform_observation_spec(
        output_spec["full_observation_spec"]
    )
    if "full_reward_spec" in output_spec.keys():
        output_spec["full_reward_spec"] = self.transform_reward_spec(
            output_spec["full_reward_spec"], observation_spec
        )
    output_spec["full_observation_spec"] = observation_spec
    if "full_done_spec" in output_spec.keys():
        output_spec["full_done_spec"] = self.transform_done_spec(
            output_spec["full_done_spec"]
        )
    # Sanity check: every out_key that is not also an in_key must appear in
    # the resulting spec, otherwise rollouts cannot be built consistently.
    output_spec_keys = [
        unravel_key(k[1:]) for k in output_spec.keys(True) if isinstance(k, tuple)
    ]
    out_keys = {unravel_key(k) for k in self.out_keys}
    in_keys = {unravel_key(k) for k in self.in_keys}
    for key in out_keys - in_keys:
        if unravel_key(key) not in output_spec_keys:
            warnings.warn(
                f"The key '{key}' is unaccounted for by the transform (expected keys {output_spec_keys}). "
                f"Every new entry in the tensordict resulting from a call to a transform must be "
                f"registered in the specs for torchrl rollouts to be consistently built. "
                f"Make sure transform_output_spec/transform_observation_spec/... is coded correctly. "
                "This warning will trigger a KeyError in v0.9, make sure to adapt your code accordingly.",
                category=FutureWarning,
            )
    return output_spec
def _maybe_convert_bounded(self, in_spec):
    """Recursively rewrite a spec for normalized output.

    Integer dtypes become the default float dtype and ``Bounded`` specs become
    ``Unbounded`` (normalization breaks the original bounds). Side effect:
    sets ``self._cast_int_to_float`` when an integer spec is encountered so
    incoming data is cast accordingly.
    """
    if isinstance(in_spec, Composite):
        return Composite(
            {
                key: self._maybe_convert_bounded(value)
                for key, value in in_spec.items()
            }
        )
    dtype = in_spec.dtype
    if dtype is not None and not dtype.is_floating_point:
        # we need to cast the tensor and spec to a float type
        in_spec = in_spec.clone()
        in_spec.dtype = torch.get_default_dtype()
        self._cast_int_to_float = True

    if isinstance(in_spec, Bounded):
        in_spec = Unbounded(
            shape=in_spec.shape, device=in_spec.device, dtype=in_spec.dtype
        )
    return in_spec
@property
def prefix(self):
    """Key prefix of the count/loc/var entries (default ``"_vecnorm"``)."""
    return getattr(self, "_prefix", "_vecnorm")
662
+ def _make_prefix(self, output_spec):
663
+ prefix = getattr(self, "_prefix", None)
664
+ if prefix is not None:
665
+ return prefix
666
+ if (
667
+ "_vecnorm_loc" in output_spec["full_observation_spec"].keys()
668
+ or "_vecnorm_loc" in output_spec["full_reward_spec"].keys()
669
+ ):
670
+ prefix = "_vecnorm" + str(uuid.uuid1())
671
+ else:
672
+ prefix = "_vecnorm"
673
+ self._prefix = prefix
674
+ return prefix
675
+
def _proc_count_spec(self, count_spec, parent_shape=None):
    """Recursively rewrite leaf specs as int64 ``Unbounded`` specs for the count entries."""
    if isinstance(count_spec, Composite):
        for key, spec in count_spec.items():
            spec = self._proc_count_spec(spec, parent_shape=count_spec.shape)
            count_spec[key] = spec
        return count_spec
    if count_spec.dtype:
        # Replace the leaf with an integer counter spec of the same shape/device.
        count_spec = Unbounded(
            shape=count_spec.shape, dtype=torch.int64, device=count_spec.device
        )
    return count_spec
def _transform_spec(
    self, spec: Composite, obs_spec: Composite | None = None
) -> Composite:
    """Rewrite ``spec`` for normalized outputs; register stat specs in stateless mode.

    ``obs_spec`` is the spec that hosts the loc/var/count entries; it
    defaults to ``spec`` itself (i.e. when transforming observations).
    """
    in_specs = {}
    for in_key, out_key in zip(self._in_keys_safe, self.out_keys):
        if unravel_key(in_key) in spec.keys(True):
            in_spec = spec.get(in_key).clone()
            in_spec = self._maybe_convert_bounded(in_spec)
            spec.set(out_key, in_spec)
            in_specs[in_key] = in_spec
    if not self.stateful and in_specs:
        if obs_spec is None:
            obs_spec = spec
        loc_spec = obs_spec.get(f"{self.prefix}_loc", default=None)
        var_spec = obs_spec.get(f"{self.prefix}_var", default=None)
        count_spec = obs_spec.get(f"{self.prefix}_count", default=None)
        if loc_spec is None:
            # First registration: create empty composite holders.
            loc_spec = Composite(shape=obs_spec.shape, device=obs_spec.device)
            var_spec = Composite(shape=obs_spec.shape, device=obs_spec.device)
            count_spec = Composite(shape=obs_spec.shape, device=obs_spec.device)
        loc_spec.update(in_specs)
        # should we clone?
        var_spec.update(in_specs)
        count_spec = count_spec.update(in_specs)
        count_spec = self._proc_count_spec(count_spec)
        obs_spec[f"{self.prefix}_loc"] = loc_spec
        obs_spec[f"{self.prefix}_var"] = var_spec
        obs_spec[f"{self.prefix}_count"] = count_spec
    return spec
def to_observation_norm(self) -> Compose | ObservationNorm:
    """Export the current statistics as (a composition of) ObservationNorm transforms.

    Returns a single :class:`ObservationNorm` for one key, otherwise a
    :class:`Compose` of one per key. Only available in stateful mode.
    """
    if not self.stateful:
        # FIXME
        raise NotImplementedError()
    result = []

    loc, scale = self._get_loc_scale()

    for key, key_out in _zip_strict(self.in_keys, self.out_keys):
        local_result = ObservationNorm(
            loc=loc.get(key),
            scale=scale.get(key),
            standard_normal=True,
            in_keys=key,
            out_keys=key_out,
            eps=self.eps,
        )
        result += [local_result]
    if len(self.in_keys) > 1:
        return Compose(*result)
    # NOTE(review): with empty in_keys this would raise NameError on
    # `local_result` — presumably in_keys is never empty here; confirm.
    return local_result
def _get_loc_scale(self, loc_only: bool = False) -> tuple:
    """Return bias-corrected ``(loc, scale)``; ``scale`` is ``None`` when ``loc_only``.

    Raises:
        RuntimeError: if the transform is stateless.
    """
    if self.stateful:
        loc = self._loc
        count = self._count
        if self.decay != 1.0:
            # EMA bias correction: 1 - decay**count, computed via log.
            bias_correction = 1 - (count * math.log(self.decay)).exp()
            bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), loc)
        else:
            bias_correction = 1
        if loc_only:
            return loc / bias_correction, None
        var = self._var
        # _var stores E[x^2]; recover the variance before the sqrt.
        var = var - loc.pow(2)
        loc = loc / bias_correction
        var = var / bias_correction
        scale = var.sqrt().clamp_min(self.eps)
        return loc, scale
    else:
        raise RuntimeError("_get_loc_scale() called on stateless vecnorm.")
def __getstate__(self) -> dict[str, Any]:
    """Drop the unpicklable multiprocessing lock, leaving a placeholder marker."""
    state = super().__getstate__()
    if state.pop("lock", None) is not None:
        state["lock_placeholder"] = None
    return state
def __setstate__(self, state: dict[str, Any]):
    """Restore state, recreating a fresh multiprocessing lock if one was dropped."""
    if "lock_placeholder" in state:
        del state["lock_placeholder"]
        state["lock"] = mp.Lock()
    super().__setstate__(state)
774
+ SEP = ".-|-."
775
def set_extra_state(self, state: OrderedDict) -> None:
    """Restore the running statistics from a flattened state-dict payload.

    Counterpart of :meth:`get_extra_state`; invoked by
    ``nn.Module.load_state_dict``. No-op in stateless mode or when both the
    state and the instance are uninitialized.

    Raises:
        RuntimeError: if an empty state is loaded into an initialized instance.
    """
    if not self.stateful:
        return
    if not state:
        if self._loc is None:
            # we're good, not init yet
            return
        raise RuntimeError(
            "set_extra_state() called with a void state-dict while the instance is initialized."
        )
    td = TensorDict(state).unflatten_keys(self.SEP)
    if self._loc is None and not all(v.is_shared() for v in td.values(True, True)):
        warnings.warn(
            "VecNorm wasn't initialized and the tensordict is not shared. In single "
            "process settings, this is ok, but if you need to share the statistics "
            "between workers this should require some attention. "
            "Make sure that the content of VecNorm is transmitted to the workers "
            "after calling load_state_dict and not before, as other workers "
            "may not have access to the loaded TensorDict."
        )
        td.share_memory_()
    self._loc = td["loc"]
    self._var = td["var"]
    self._count = td["count"]
def get_extra_state(self) -> OrderedDict:
    """Return the running statistics as a flat dict for ``nn.Module.state_dict``.

    Keys are joined with :attr:`SEP`; an empty dict is returned when the
    transform is stateless or not yet initialized.
    """
    if not self.stateful:
        return {}
    if self._loc is None:
        warnings.warn(
            "Querying state_dict on an uninitialized VecNorm transform will "
            "return a `None` value for the summary statistics. "
            "Loading such a state_dict on an initialized VecNorm will result in "
            "an error."
        )
        return {}
    td = TensorDict(
        loc=self._loc,
        var=self._var,
        count=self._count,
    )
    return td.flatten_keys(self.SEP).to_dict()
@property
def loc(self):
    """Returns a TensorDict with the loc to be used for an affine transform."""
    if not self.stateful:
        raise RuntimeError("loc cannot be computed with stateless vecnorm.")
    # Never cached: another process may update the shared stats at any time.
    location, _ = self._get_loc_scale(loc_only=True)
    return location
@property
def scale(self):
    """Returns a TensorDict with the scale to be used for an affine transform."""
    if not self.stateful:
        raise RuntimeError("scale cannot be computed with stateless vecnorm.")
    # Never cached: another process may update the shared stats at any time.
    _, spread = self._get_loc_scale()
    return spread
@property
def standard_normal(self):
    """Indicates that normalization follows the ``(x - loc) / scale`` form.

    Mirrors the ``standard_normal`` attribute of
    :class:`~torchrl.envs.ObservationNorm`; always ``True`` here.
    """
    return True