torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,814 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import warnings
8
+ from typing import TYPE_CHECKING
9
+
10
+ import numpy as np
11
+ import torch
12
+ from tensordict import NestedKey, TensorDictBase
13
+ from tensordict.nn import (
14
+ TensorDictModule,
15
+ TensorDictModuleBase,
16
+ TensorDictModuleWrapper,
17
+ )
18
+ from tensordict.utils import expand_as_right, expand_right
19
+ from torch import nn
20
+
21
+ from torchrl.data.tensor_specs import Composite, TensorSpec
22
+ from torchrl.envs.utils import exploration_type, ExplorationType
23
+ from torchrl.modules.tensordict_module.common import _forward_hook_safe_action
24
+
25
+ if TYPE_CHECKING:
26
+ from torchrl.envs import EnvBase
27
+
28
+ __all__ = [
29
+ "EGreedyWrapper",
30
+ "EGreedyModule",
31
+ "AdditiveGaussianModule",
32
+ "OrnsteinUhlenbeckProcessModule",
33
+ "OrnsteinUhlenbeckProcessWrapper",
34
+ "set_exploration_modules_spec_from_env",
35
+ ]
36
+
37
+
38
class EGreedyModule(TensorDictModuleBase):
    """Epsilon-Greedy exploration module.

    This module randomly updates the action(s) in a tensordict given an epsilon greedy exploration strategy.
    At each call, random draws (one per action) are executed given a certain probability threshold. If successful,
    the corresponding actions are being replaced by random samples drawn from the action spec provided.
    Others are left unchanged.

    Args:
        spec (TensorSpec): the spec used for sampling actions.
        eps_init (scalar, optional): initial epsilon value.
            default: 1.0
        eps_end (scalar, optional): final epsilon value.
            default: 0.1
        annealing_num_steps (int, optional): number of steps it will take for epsilon to reach
            the ``eps_end`` value. Defaults to `1000`.

    Keyword Args:
        action_key (NestedKey, optional): the key where the action can be found in the input tensordict.
            Default is ``"action"``.
        action_mask_key (NestedKey, optional): the key where the action mask can be found in the input tensordict.
            Default is ``None`` (corresponding to no mask).
        device (torch.device, optional): the device of the exploration module.

    .. note::
        It is crucial to incorporate a call to :meth:`step` in the training loop
        to update the exploration factor.
        Since it is not easy to capture this omission no warning or exception
        will be raised if this is omitted!

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from tensordict.nn import TensorDictSequential
        >>> from torchrl.modules import EGreedyModule, Actor
        >>> from torchrl.data import Bounded
        >>> torch.manual_seed(0)
        >>> spec = Bounded(-1, 1, torch.Size([4]))
        >>> module = torch.nn.Linear(4, 4, bias=False)
        >>> policy = Actor(spec=spec, module=module)
        >>> explorative_policy = TensorDictSequential(policy, EGreedyModule(eps_init=0.2))
        >>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
        >>> print(explorative_policy(td).get("action"))
        tensor([[ 0.0000, 0.0000, 0.0000, 0.0000],
            [ 0.0000, 0.0000, 0.0000, 0.0000],
            [ 0.9055, -0.9277, -0.6295, -0.2532],
            [ 0.0000, 0.0000, 0.0000, 0.0000],
            [ 0.0000, 0.0000, 0.0000, 0.0000],
            [ 0.0000, 0.0000, 0.0000, 0.0000],
            [ 0.0000, 0.0000, 0.0000, 0.0000],
            [ 0.0000, 0.0000, 0.0000, 0.0000],
            [ 0.0000, 0.0000, 0.0000, 0.0000],
            [ 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=<AddBackward0>)

    """

    def __init__(
        self,
        spec: TensorSpec,
        eps_init: float = 1.0,
        eps_end: float = 0.1,
        annealing_num_steps: int = 1000,
        *,
        action_key: NestedKey | None = "action",
        action_mask_key: NestedKey | None = None,
        device: torch.device | None = None,
    ):
        if not isinstance(eps_init, float):
            warnings.warn("eps_init should be a float.")
        if eps_end > eps_init:
            # Annealing only moves eps downward (see step()); an increasing
            # schedule is rejected up front.
            raise RuntimeError("eps should decrease over time or be constant")
        # in_keys/out_keys must be set before TensorDictModuleBase.__init__.
        self.action_key = action_key
        self.action_mask_key = action_mask_key
        in_keys = [self.action_key]
        if self.action_mask_key is not None:
            in_keys.append(self.action_mask_key)
        self.in_keys = in_keys
        self.out_keys = [self.action_key]

        super().__init__()

        # eps_init/eps_end are fixed schedule endpoints; eps is the mutable
        # current value updated in-place by step().
        self.register_buffer("eps_init", torch.as_tensor(eps_init, device=device))
        self.register_buffer("eps_end", torch.as_tensor(eps_end, device=device))
        self.annealing_num_steps = annealing_num_steps
        self.register_buffer(
            "eps", torch.as_tensor(eps_init, dtype=torch.float32, device=device)
        )

        if spec is not None:
            # Wrap a single leaf spec into a Composite keyed by action_key so
            # forward() can treat both cases uniformly.
            if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
                spec = Composite({action_key: spec}, shape=spec.shape[:-1])
            if device is not None:
                spec = spec.to(device)
        self._spec = spec

    @property
    def spec(self):
        # Sampling spec set at construction time (possibly None).
        return self._spec

    def step(self, frames: int = 1) -> None:
        """A step of epsilon decay.

        After `self.annealing_num_steps` calls to this method, calls result in no-op.

        Args:
            frames (int, optional): number of frames since last step. Defaults to ``1``.

        """
        for _ in range(frames):
            # Linear decay clamped at eps_end; in-place copy keeps the buffer
            # registered on the module.
            self.eps.data.copy_(
                torch.maximum(
                    self.eps_end,
                    (
                        self.eps
                        - (self.eps_init - self.eps_end) / self.annealing_num_steps
                    ),
                )
            )

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Replace a random subset of actions with samples from the spec.

        Only active under ``ExplorationType.RANDOM`` (or when no exploration
        type is set); otherwise the tensordict passes through unchanged.
        """
        expl = exploration_type()
        if expl in (ExplorationType.RANDOM, None):
            # For a nested action key, operate on the sub-tensordict that
            # directly holds the action leaf.
            if isinstance(self.action_key, tuple) and len(self.action_key) > 1:
                action_tensordict = tensordict.get(self.action_key[:-1])
                action_key = self.action_key[-1]
            else:
                action_tensordict = tensordict
                action_key = self.action_key

            action = action_tensordict.get(action_key)
            eps = self.eps
            device = eps.device
            action_device = action.device
            if action_device is not None and action_device != device:
                raise RuntimeError(
                    f"Expected action and e-greedy module to be on the same device, but got {action.device=} and e-greedy device={device}."
                )
            # One Bernoulli(eps) draw per batch element, broadcast over the
            # action's trailing dims.
            cond = torch.rand(action_tensordict.shape, device=device) < eps
            # cond = torch.zeros(action_tensordict.shape, device=out.device, dtype=torch.bool).bernoulli_(eps)
            cond = expand_as_right(cond, action)
            spec = self.spec
            if spec is not None:
                if isinstance(spec, Composite):
                    spec = spec[self.action_key]
                if spec.shape != action.shape:
                    # In batched envs if the spec is passed unbatched, the rand() will not
                    # cover all batched dims
                    if (
                        not len(spec.shape)
                        or action.shape[-len(spec.shape) :] == spec.shape
                    ):
                        spec = spec.expand(action.shape)
                    else:
                        raise ValueError(
                            "Action spec shape does not match the action shape"
                        )
                if self.action_mask_key is not None:
                    action_mask = tensordict.get(self.action_mask_key, None)
                    if action_mask is None:
                        raise KeyError(
                            f"Action mask key {self.action_mask_key} not found in {tensordict}."
                        )
                    # Restrict random sampling to currently-valid actions.
                    spec.update_mask(action_mask)
                r = spec.rand()
                if r.device != device:
                    r = r.to(device)
                # Keep policy action where cond is False, random sample where True.
                action = torch.where(cond, r, action)
            else:
                raise RuntimeError("spec must be provided to the exploration wrapper.")
            action_tensordict.set(action_key, action)
        return tensordict
210
+
211
class EGreedyWrapper(TensorDictModuleWrapper):
    """[Deprecated] Epsilon-Greedy PO wrapper.

    This wrapper has been removed; instantiating it always raises a
    :class:`RuntimeError` pointing at :class:`~torchrl.modules.EGreedyModule`.
    """

    def __init__(
        self,
        policy: TensorDictModule,
        *,
        eps_init: float = 1.0,
        eps_end: float = 0.1,
        annealing_num_steps: int = 1000,
        action_key: NestedKey | None = "action",
        action_mask_key: NestedKey | None = None,
        spec: TensorSpec | None = None,
    ):
        # Fail fast: the signature is kept only so old call sites get a clear
        # migration error rather than a TypeError.
        message = (
            "This class has been deprecated in favor of torchrl.modules.EGreedyModule."
        )
        raise RuntimeError(message)
229
+
230
class AdditiveGaussianWrapper(TensorDictModuleWrapper):
    """[Deprecated] Additive Gaussian PO wrapper.

    This wrapper has been removed; instantiating it always raises a
    :class:`RuntimeError` pointing at
    :class:`~torchrl.modules.AdditiveGaussianModule`.
    """

    def __init__(
        self,
        policy: TensorDictModule,
        *,
        sigma_init: float = 1.0,
        sigma_end: float = 0.1,
        annealing_num_steps: int = 1000,
        mean: float = 0.0,
        std: float = 1.0,
        action_key: NestedKey | None = "action",
        spec: TensorSpec | None = None,
        safe: bool | None = True,
        device: torch.device | None = None,
    ):
        # Fail fast: the signature is kept only so old call sites get a clear
        # migration error rather than a TypeError.
        message = (
            "This module has been removed from TorchRL. Please use torchrl.modules.AdditiveGaussianModule instead."
        )
        raise RuntimeError(message)
250
+
251
+
252
+ class AdditiveGaussianModule(TensorDictModuleBase):
253
+ """Additive Gaussian PO module.
254
+
255
+ Args:
256
+ spec (TensorSpec, optional): the spec used for sampling actions. The sampled
257
+ action will be projected onto the valid action space once explored.
258
+ Can be ``None`` for delayed initialization, in which case the spec
259
+ must be set via the :attr:`spec` property setter before calling
260
+ :meth:`forward`.
261
+ default: None
262
+ sigma_init (scalar, optional): initial epsilon value.
263
+ default: 1.0
264
+ sigma_end (scalar, optional): final epsilon value.
265
+ default: 0.1
266
+ annealing_num_steps (int, optional): number of steps it will take for
267
+ sigma to reach the :obj:`sigma_end` value.
268
+ default: 1000
269
+ mean (:obj:`float`, optional): mean of each output element's normal distribution.
270
+ default: 0.0
271
+ std (:obj:`float`, optional): standard deviation of each output element's normal distribution.
272
+ default: 1.0
273
+
274
+ Keyword Args:
275
+ action_key (NestedKey, optional): if the policy module has more than one output key,
276
+ its output spec will be of type Composite. One needs to know where to
277
+ find the action spec.
278
+ default: "action"
279
+ safe (bool): if ``True``, actions that are out of bounds given the action specs will be projected in the space
280
+ given the :obj:`TensorSpec.project` heuristic.
281
+ default: False
282
+ device (torch.device, optional): the device where the buffers have to be stored.
283
+
284
+ .. note::
285
+ It is
286
+ crucial to incorporate a call to :meth:`step` in the training loop
287
+ to update the exploration factor.
288
+ Since it is not easy to capture this omission no warning or exception
289
+ will be raised if this is omitted!
290
+
291
+
292
+ """
293
+
294
def __init__(
    self,
    spec: TensorSpec | None = None,
    sigma_init: float = 1.0,
    sigma_end: float = 0.1,
    annealing_num_steps: int = 1000,
    mean: float = 0.0,
    std: float = 1.0,
    *,
    action_key: NestedKey | None = "action",
    # safe is already implemented because we project in the noise addition
    safe: bool = False,
    device: torch.device | None = None,
):
    """Initializes the additive Gaussian exploration module.

    See the class docstring for the meaning of each argument.
    """
    if not isinstance(sigma_init, float):
        # Bug fix: the warning previously said "eps_init", which is not the
        # name of the parameter being checked.
        warnings.warn("sigma_init should be a float.")
    if sigma_end > sigma_init:
        raise RuntimeError("sigma should decrease over time or be constant")
    self.action_key = action_key
    self.in_keys = [self.action_key]
    self.out_keys = [self.action_key]

    super().__init__()

    # Scalars are stored as buffers so they follow the module across devices.
    self.register_buffer("sigma_init", torch.tensor(sigma_init, device=device))
    self.register_buffer("sigma_end", torch.tensor(sigma_end, device=device))
    self.annealing_num_steps = annealing_num_steps
    self.register_buffer("mean", torch.tensor(mean, device=device))
    self.register_buffer("std", torch.tensor(std, device=device))
    self.register_buffer(
        "sigma", torch.tensor(sigma_init, dtype=torch.float32, device=device)
    )

    # spec can be None for delayed initialization. In this case, it must be
    # set via the spec property before forward() is called.
    if (
        spec is not None
        and not isinstance(spec, Composite)
        and len(self.out_keys) >= 1
    ):
        # Wrap a leaf spec so the action spec can be looked up by key.
        spec = Composite({action_key: spec}, shape=spec.shape[:-1])
    self._spec = spec
    self.safe = safe
    if self.safe:
        self.register_forward_hook(_forward_hook_safe_action)
340
@property
def spec(self):
    """The (composite) spec used to project noisy actions."""
    return self._spec

@spec.setter
def spec(self, value: TensorSpec) -> None:
    # A leaf spec is wrapped into a Composite keyed by the action key, so the
    # projection step can retrieve the action spec by name later on.
    if not isinstance(value, Composite) and len(self.out_keys) >= 1:
        value = Composite({self.action_key: value}, shape=value.shape[:-1])
    self._spec = value
350
def step(self, frames: int = 1) -> None:
    """A step of sigma decay.

    After `self.annealing_num_steps` calls to this method, calls result in no-op.

    Args:
        frames (int): number of frames since last step. Defaults to ``1``.

    """
    # The per-frame decrement is constant, so compute it once.
    decrement = (self.sigma_init - self.sigma_end) / self.annealing_num_steps
    for _ in range(frames):
        # Linear decay, clamped from below at sigma_end.
        self.sigma.data.copy_(torch.maximum(self.sigma_end, self.sigma - decrement))
370
+ def _add_noise(self, action: torch.Tensor) -> torch.Tensor:
371
+ if self._spec is None:
372
+ raise RuntimeError(
373
+ "spec has not been set. Pass spec at construction time or set it via "
374
+ "the `spec` property before calling forward()."
375
+ )
376
+ sigma = self.sigma
377
+ mean = self.mean.expand(action.shape)
378
+ std = self.std.expand(action.shape)
379
+ if not mean.dtype.is_floating_point:
380
+ mean = mean.to(torch.get_default_dtype())
381
+ if not std.dtype.is_floating_point:
382
+ std = std.to(torch.get_default_dtype())
383
+ noise = torch.normal(mean=mean, std=std)
384
+ if noise.device != action.device:
385
+ noise = noise.to(action.device)
386
+ action = action + noise * sigma
387
+ spec = self.spec[self.action_key]
388
+ action = spec.project(action)
389
+ return action
390
+
391
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
    """Adds exploration noise to the action entry when exploration is active."""
    explore = exploration_type()
    if explore is ExplorationType.RANDOM or explore is None:
        action = tensordict.get(self.action_key)
        tensordict.set(self.action_key, self._add_noise(action))
    return tensordict
399
class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper):
    """[Deprecated] Ornstein-Uhlenbeck exploration policy wrapper.

    This class has been removed. Instantiating it always raises a
    :class:`RuntimeError` pointing to the replacement,
    :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule`.
    """

    def __init__(self, *args, **kwargs):
        # The long legacy signature was intentionally collapsed: every call
        # path ends in the same error, so validating arguments is pointless.
        raise RuntimeError(
            "OrnsteinUhlenbeckProcessWrapper has been removed. Please use torchrl.modules.OrnsteinUhlenbeckProcessModule instead."
        )
428
class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase):
    r"""Ornstein-Uhlenbeck exploration policy module.

    Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf.

    The OU exploration is to be used with continuous control policies and introduces an auto-correlated exploration
    noise. This enables a sort of 'structured' exploration.

    Noise equation:

    .. math::
        noise_t = noise_{t-1} + \theta * (mu - noise_{t-1}) * dt + \sigma_t * \sqrt{dt} * W

    Sigma equation:

    .. math::
        \sigma_t = \max\left(\sigma^{min},\;
        -\frac{\sigma - \sigma^{min}}{n^{\text{steps annealing}}} \, n^{\text{steps}} + \sigma\right)

    To keep track of the steps and noise from sample to sample, an :obj:`"ou_prev_noise{id}"` and :obj:`"ou_steps{id}"` keys
    will be written in the input/output tensordict. It is expected that the tensordict will be zeroed at reset,
    indicating that a new trajectory is being collected. If not, and if the same tensordict is used for consecutive
    trajectories, the step count will keep on increasing across rollouts. Note that the collector classes take care of
    zeroing the tensordict at reset time.

    .. note::
        It is
        crucial to incorporate a call to :meth:`step` in the training loop
        to update the exploration factor.
        Since it is not easy to capture this omission no warning or exception
        will be raised if this is omitted!

    Args:
        spec (TensorSpec): the spec used for sampling actions. The sampled
            action will be projected onto the valid action space once explored.
        eps_init (scalar): initial epsilon value, determining the amount of noise to be added.
            default: 1.0
        eps_end (scalar): final epsilon value, determining the amount of noise to be added.
            default: 0.1
        annealing_num_steps (int): number of steps it will take for epsilon to reach the eps_end value.
            default: 1000
        theta (scalar): theta factor in the noise equation
            default: 0.15
        mu (scalar): OU average (mu in the noise equation).
            default: 0.0
        sigma (scalar): sigma value in the sigma equation.
            default: 0.2
        dt (scalar): dt in the noise equation.
            default: 0.01
        x0 (Tensor, ndarray, optional): initial value of the process.
            default: 0.0
        sigma_min (number, optional): sigma_min in the sigma equation.
            default: None
        n_steps_annealing (int): number of steps for the sigma annealing.
            default: 1000

    Keyword Args:
        action_key (NestedKey, optional): key of the action to be modified.
            default: "action"
        is_init_key (NestedKey, optional): key where to find the is_init flag used to reset the noise steps.
            default: "is_init"
        safe (boolean, optional): if False, the TensorSpec can be None. If it
            is set to False but the spec is passed, the projection will still
            happen.
            Default is True.
        device (torch.device, optional): the device where the buffers have to be stored.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from tensordict.nn import TensorDictSequential
        >>> from torchrl.data import Bounded
        >>> from torchrl.modules import OrnsteinUhlenbeckProcessModule, Actor
        >>> torch.manual_seed(0)
        >>> spec = Bounded(-1, 1, torch.Size([4]))
        >>> module = torch.nn.Linear(4, 4, bias=False)
        >>> policy = Actor(module=module, spec=spec)
        >>> ou = OrnsteinUhlenbeckProcessModule(spec=spec)
        >>> explorative_policy = TensorDictSequential(policy, ou)
        >>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
        >>> print(explorative_policy(td))
        TensorDict(
            fields={
                _ou_prev_noise: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                _ou_steps: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                action: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)
    """

    def __init__(
        self,
        spec: TensorSpec,
        eps_init: float = 1.0,
        eps_end: float = 0.1,
        annealing_num_steps: int = 1000,
        theta: float = 0.15,
        mu: float = 0.0,
        sigma: float = 0.2,
        dt: float = 1e-2,
        x0: torch.Tensor | np.ndarray | None = None,
        sigma_min: float | None = None,
        n_steps_annealing: int = 1000,
        *,
        action_key: NestedKey = "action",
        is_init_key: NestedKey = "is_init",
        safe: bool = True,
        device: torch.device | None = None,
    ):
        super().__init__()

        # The underlying OU process holds the noise-generation logic and the
        # names of the state entries written into the tensordict.
        self.ou = _OrnsteinUhlenbeckProcess(
            theta=theta,
            mu=mu,
            sigma=sigma,
            dt=dt,
            x0=x0,
            sigma_min=sigma_min,
            n_steps_annealing=n_steps_annealing,
            key=action_key,
            device=device,
        )

        self.register_buffer("eps_init", torch.tensor(eps_init, device=device))
        self.register_buffer("eps_end", torch.tensor(eps_end, device=device))
        if self.eps_end > self.eps_init:
            raise ValueError(
                "eps should decrease over time or be constant, "
                f"got eps_init={eps_init} and eps_end={eps_end}"
            )
        self.annealing_num_steps = annealing_num_steps
        self.register_buffer(
            "eps", torch.tensor(eps_init, dtype=torch.float32, device=device)
        )

        self.in_keys = [self.ou.key]
        self.out_keys = [self.ou.key] + self.ou.out_keys
        self.is_init_key = is_init_key
        noise_key = self.ou.noise_key
        steps_key = self.ou.steps_key

        if spec is not None:
            if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
                spec = Composite({action_key: spec}, shape=spec.shape[:-1])
            self._spec = spec
        else:
            raise RuntimeError("spec cannot be None.")
        # Register the OU state entries in the spec so they are declared outputs.
        ou_specs = {
            noise_key: None,
            steps_key: None,
        }
        self._spec.update(ou_specs)
        if len(set(self.out_keys)) != len(self.out_keys):
            raise RuntimeError(f"Got multiple identical output keys: {self.out_keys}")
        self.safe = safe
        if self.safe:
            self.register_forward_hook(_forward_hook_safe_action)

    @property
    def spec(self):
        return self._spec

    def step(self, frames: int = 1) -> None:
        """Updates the eps noise factor.

        Args:
            frames (int): number of frames of the current batch (corresponding to the number of updates to be made).

        """
        for _ in range(frames):
            if self.annealing_num_steps > 0:
                # Linear decay of eps, clamped from below at eps_end.
                self.eps.data.copy_(
                    torch.maximum(
                        self.eps_end,
                        (
                            self.eps
                            - (self.eps_init - self.eps_end) / self.annealing_num_steps
                        ),
                    )
                )
            else:
                raise ValueError(
                    f"{self.__class__.__name__}.step() called when "
                    f"self.annealing_num_steps={self.annealing_num_steps}. Expected a strictly positive "
                    f"number of frames."
                )

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        if exploration_type() == ExplorationType.RANDOM or exploration_type() is None:
            is_init = tensordict.get(self.is_init_key, None)
            if is_init is None:
                warnings.warn(
                    f"The tensordict passed to {self.__class__.__name__} appears to be "
                    f"missing the '{self.is_init_key}' entry. This entry is used to "
                    f"reset the noise at the beginning of a trajectory, without it "
                    f"the behavior of this exploration method is undefined. "
                    f"This is allowed for BC compatibility purposes but it will be deprecated soon! "
                    f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker "
                    f"transform to your environment with `env = TransformedEnv(env, InitTracker())`."
                )
            tensordict = self.ou.add_sample(tensordict, self.eps, is_init=is_init)
        return tensordict
632
+
633
# Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
class _OrnsteinUhlenbeckProcess(nn.Module):
    """Stateful OU noise generator storing its per-trajectory state in the tensordict."""

    def __init__(
        self,
        theta: float,
        mu: float = 0.0,
        sigma: float = 0.2,
        dt: float = 1e-2,
        x0: torch.Tensor | np.ndarray | None = None,
        sigma_min: float | None = None,
        n_steps_annealing: int = 1000,
        key: NestedKey | None = "action",
        is_init_key: NestedKey | None = "is_init",
        device: torch.device | None = None,
    ):
        super().__init__()
        # Empty buffer whose only purpose is to remember the target device.
        self.register_buffer("_empty_tensor_device", torch.zeros(0, device=device))

        # Bug fix: `self.mu = mu` used to be assigned twice; keep a single assignment.
        self.mu = mu
        self.sigma = sigma

        # Linear annealing of sigma: sigma(n) = m * n + c, clamped at sigma_min.
        if sigma_min is not None:
            self.m = -float(sigma - sigma_min) / float(n_steps_annealing)
            self.c = sigma
            self.sigma_min = sigma_min
        else:
            # No annealing requested: sigma stays constant.
            self.m = 0.0
            self.c = sigma
            self.sigma_min = sigma

        self.theta = theta
        self.dt = dt
        self.x0 = x0 if x0 is not None else 0.0
        self.key = key
        self.is_init_key = is_init_key
        self._noise_key = "_ou_prev_noise"
        self._steps_key = "_ou_steps"
        self.out_keys = [self.noise_key, self.steps_key]
        self._auto_buffer()

    def _auto_buffer(self):
        # Re-register any plain tensor attribute as a buffer so it follows the
        # module across devices / state_dict handling.
        for key, item in list(self.__dict__.items()):
            if isinstance(item, torch.Tensor):
                delattr(self, key)
                self.register_buffer(key, item)

    @property
    def noise_key(self):
        return self._noise_key  # + str(id(self))

    @property
    def steps_key(self):
        return self._steps_key  # + str(id(self))

    def _make_noise_pair(
        self,
        action_tensordict: TensorDictBase,
        tensordict: TensorDictBase,
        is_init: torch.Tensor,
    ):
        device = tensordict.device
        if device is None:
            device = self._empty_tensor_device.device

        if self.steps_key not in tensordict.keys():
            # First call on this tensordict: initialize noise and step counters.
            noise = torch.zeros(tensordict.get(self.key).shape, device=device)
            steps = torch.zeros(
                action_tensordict.batch_size, dtype=torch.long, device=device
            )
            tensordict.set(self.noise_key, noise)
            tensordict.set(self.steps_key, steps)
        else:
            # We must clone for cudagraph, otherwise the same tensor may re-enter the compiled region
            noise = tensordict.get(self.noise_key).clone()
            steps = tensordict.get(self.steps_key).clone()
            if is_init is not None:
                # Zero-out state wherever a new trajectory begins.
                noise = torch.masked_fill(noise, expand_right(is_init, noise.shape), 0)
                steps = torch.masked_fill(steps, expand_right(is_init, steps.shape), 0)
        return noise, steps

    def add_sample(
        self,
        tensordict: TensorDictBase,
        eps: float = 1.0,
        is_init: torch.Tensor | None = None,
    ) -> TensorDictBase:

        # Get the nested tensordict where the action lives
        if isinstance(self.key, tuple) and len(self.key) > 1:
            action_tensordict = tensordict.get(self.key[:-1])
        else:
            action_tensordict = tensordict

        if is_init is None:
            is_init = tensordict.get(self.is_init_key, None)
        if (
            is_init is not None
        ):  # is_init has the shape of done_spec, let's bring it to the action_tensordict shape
            if is_init.ndim > 1 and is_init.shape[-1] == 1:
                is_init = is_init.squeeze(-1)  # Squeeze dangling dim
            if (
                action_tensordict.ndim >= is_init.ndim
            ):  # if is_init has fewer dimensions than action_tensordict we expand it
                is_init = expand_right(is_init, action_tensordict.shape)
            else:
                is_init = is_init.sum(
                    tuple(range(action_tensordict.batch_dims, is_init.ndim)),
                    dtype=torch.bool,
                )  # otherwise we reduce it to that batch_size
            if is_init.shape != action_tensordict.shape:
                raise ValueError(
                    f"'{self.is_init_key}' shape not compatible with action tensordict shape, "
                    f"got {tensordict.get(self.is_init_key).shape} and {action_tensordict.shape}"
                )

        prev_noise, n_steps = self._make_noise_pair(
            action_tensordict, tensordict, is_init
        )

        prev_noise = prev_noise + self.x0
        # Discretized OU update: drift towards mu plus annealed Gaussian diffusion.
        noise = (
            prev_noise
            + self.theta * (self.mu - prev_noise) * self.dt
            + self.current_sigma(expand_as_right(n_steps, prev_noise))
            * np.sqrt(self.dt)
            * torch.randn_like(prev_noise)
        )
        tensordict.set(self.noise_key, noise - self.x0)
        tensordict.set(self.key, tensordict.get(self.key) + eps * noise)
        tensordict.set(self.steps_key, n_steps + 1)
        return tensordict

    def current_sigma(self, n_steps: torch.Tensor) -> torch.Tensor:
        # sigma(n) = m * n + c, never below sigma_min.
        sigma = (self.m * n_steps + self.c).clamp_min(self.sigma_min)
        return sigma
770
+
771
class RandomPolicy:
    """A random policy for data collectors.

    This is a wrapper around the action_spec.rand method: each call draws a
    random action from the spec and writes it into the tensordict.

    Args:
        action_spec: TensorSpec object describing the action specs

    Examples:
        >>> from tensordict import TensorDict
        >>> from torchrl.data.tensor_specs import Bounded
        >>> action_spec = Bounded(-torch.ones(3), torch.ones(3))
        >>> actor = RandomPolicy(action_spec=action_spec)
        >>> td = actor(TensorDict()) # selects a random action in the cube [-1; 1]
    """

    def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"):
        super().__init__()
        # Clone so later mutations of the caller's spec do not leak into us.
        self.action_spec = action_spec.clone()
        self.action_key = action_key

    def __call__(self, td: TensorDictBase) -> TensorDictBase:
        spec = self.action_spec
        # Composite specs carry their own keys; leaf specs use action_key.
        if isinstance(spec, Composite):
            return td.update(spec.rand())
        return td.set(self.action_key, spec.rand())
798
+
799
def set_exploration_modules_spec_from_env(policy: nn.Module, env: EnvBase) -> None:
    """Sets exploration module specs from an environment action spec.

    This is intended for cases where exploration modules (e.g. AdditiveGaussianModule)
    are instantiated with ``spec=None`` and must be configured once the environment
    is known (e.g. inside a collector).
    """
    # Prefer the per-env (unbatched) action spec when the env exposes one.
    if hasattr(env, "action_spec_unbatched"):
        resolved_spec = env.action_spec_unbatched
    else:
        resolved_spec = env.action_spec

    # Only touch modules that were left unconfigured (spec=None).
    for submodule in policy.modules():
        if isinstance(submodule, AdditiveGaussianModule) and submodule._spec is None:
            submodule.spec = resolved_spec