torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/distributions/continuous.py
@@ -0,0 +1,830 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import weakref
+ from collections.abc import Callable, Sequence
+ from numbers import Number
+
+ import numpy as np
+ import torch
+ from packaging import version
+ from torch import distributions as D, nn
+ from torch.distributions import constraints
+ from torch.distributions.transforms import _InverseTransform
+
+ from torchrl._utils import safe_is_current_stream_capturing
+ from torchrl.modules.distributions.truncated_normal import (
+     TruncatedNormal as _TruncatedNormal,
+ )
+ from torchrl.modules.distributions.utils import (
+     _cast_device,
+     FasterTransformedDistribution,
+     safeatanh_noeps,
+     safetanh_noeps,
+ )
+
+ # speeds up distribution construction
+ D.Distribution.set_default_validate_args(False)
+
+ try:
+     from torch.compiler import assume_constant_result
+ except ImportError:
+     from torch._dynamo import assume_constant_result
+
+ try:
+     from torch.compiler import is_compiling
+ except ImportError:
+     from torch._dynamo import is_compiling
+
+ TORCH_VERSION = version.parse(torch.__version__).base_version
+ TORCH_VERSION_PRE_2_6 = version.parse(TORCH_VERSION) < version.parse("2.6.0")
+
+
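For context, a minimal sketch (separate from the package diff) of what the compatibility shims above resolve to at runtime; it assumes only an eager-mode interpreter with torch installed:

import torch

try:  # newer torch releases expose the public location
    from torch.compiler import is_compiling
except ImportError:  # older releases only ship the private one
    from torch._dynamo import is_compiling

# In eager mode the flag is False, so the guarded branches below
# (e.g. the weakref caching in the patched `inv` properties) stay active.
assert not is_compiling()
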
+ class IndependentNormal(D.Independent):
+     """Implements a Normal distribution with location scaling.
+
+     Location scaling prevents the location from being "too far" from 0, which would ultimately
+     lead to numerically unstable samples and poor gradient computation (e.g. gradient explosion).
+     In practice, the location is computed according to
+
+     .. math::
+         loc = tanh(loc / upscale) * upscale.
+
+     This behavior can be disabled by switching off the tanh_loc parameter (see below).
+
+
+     Args:
+         loc (torch.Tensor): normal distribution location parameter
+         scale (torch.Tensor, float, or callable): normal distribution sigma parameter (square root of variance).
+             Can be a tensor, a float, or a callable that takes the ``loc`` tensor as input and returns the scale tensor.
+             Using a callable (e.g., ``torch.ones_like`` or ``functools.partial(torch.full_like, fill_value=0.1)``)
+             avoids explicit device transfers like ``torch.tensor(val, device=device)`` and prevents graph breaks
+             in :func:`torch.compile`.
+         upscale (torch.Tensor or number, optional): scaling factor used in the formula:
+
+             .. math::
+                 loc = tanh(loc / upscale) * upscale.
+
+             Default is 5.0
+
+         tanh_loc (bool, optional): if ``True``, the above formula is used for
+             the location scaling, otherwise the raw value is kept.
+             Default is ``False``;
+
+     Example:
+         >>> import torch
+         >>> from functools import partial
+         >>> from torchrl.modules.distributions import IndependentNormal
+         >>> loc = torch.zeros(3, 4)
+         >>> # Using a callable scale avoids device transfers and graph breaks in torch.compile
+         >>> dist = IndependentNormal(loc, scale=torch.ones_like)
+         >>> # For a custom scale value, use partial to create a callable
+         >>> dist = IndependentNormal(loc, scale=partial(torch.full_like, fill_value=0.1))
+         >>> sample = dist.sample()
+         >>> sample.shape
+         torch.Size([3, 4])
+
+     """
+
+     num_params: int = 2
+
+     def __init__(
+         self,
+         loc: torch.Tensor,
+         scale: torch.Tensor | float | Callable[[torch.Tensor], torch.Tensor],
+         upscale: float = 5.0,
+         tanh_loc: bool = False,
+         event_dim: int = 1,
+         **kwargs,
+     ):
+         self.tanh_loc = tanh_loc
+         self.upscale = upscale
+         self._event_dim = event_dim
+         self._kwargs = kwargs
+         # Support callable scale (e.g., torch.ones_like) for compile-friendliness
+         if callable(scale) and not isinstance(scale, torch.Tensor):
+             scale = scale(loc)
+         elif not isinstance(scale, torch.Tensor):
+             scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
+         elif scale.device != loc.device:
+             scale = scale.to(loc.device, non_blocking=loc.device.type == "cuda")
+         super().__init__(D.Normal(loc, scale, **kwargs), event_dim)
+
+     def update(self, loc, scale):
+         if self.tanh_loc:
+             loc = self.upscale * (loc / self.upscale).tanh()
+         # Support callable scale (e.g., torch.ones_like) for compile-friendliness
+         if callable(scale) and not isinstance(scale, torch.Tensor):
+             scale = scale(loc)
+         elif not isinstance(scale, torch.Tensor):
+             scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
+         elif scale.device != loc.device:
+             scale = scale.to(loc.device, non_blocking=loc.device.type == "cuda")
+         super().__init__(D.Normal(loc, scale, **self._kwargs), self._event_dim)
+
+     @property
+     def mode(self):
+         return self.base_dist.mean
+
+     @property
+     def deterministic_sample(self):
+         return self.mean
+
+
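A usage sketch (separate from the package diff): the in-place ``update`` re-parameterizes the distribution without rebuilding the wrapper, and ``tanh_loc`` squashes the new location into (-upscale, upscale):

import torch
from torchrl.modules.distributions import IndependentNormal

loc = torch.zeros(3, 4)
dist = IndependentNormal(loc, scale=torch.ones_like, tanh_loc=True)

# Re-parameterize in place; with tanh_loc=True the new location becomes
# upscale * tanh(loc / upscale), so it cannot drift beyond +/- upscale.
new_loc = torch.full((3, 4), 100.0)
dist.update(new_loc, torch.ones_like(new_loc))
assert dist.mode.abs().max() <= 5.0  # default upscale is 5.0
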
+ class SafeTanhTransform(D.TanhTransform):
+     """TanhTransform subclass that ensures that the transformation is numerically invertible."""
+
+     def _call(self, x: torch.Tensor) -> torch.Tensor:
+         return safetanh_noeps(x)
+
+     def _inverse(self, y: torch.Tensor) -> torch.Tensor:
+         return safeatanh_noeps(y)
+
+     @property
+     def inv(self):
+         inv = None
+         if self._inv is not None:
+             inv = self._inv()
+         if inv is None:
+             inv = _InverseTransform(self)
+             if not is_compiling():
+                 self._inv = weakref.ref(inv)
+         return inv
+
+
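A small sketch (separate from the diff) of why the safe variants matter; it assumes, as the names suggest, that ``safeatanh_noeps`` clamps its input away from +/-1 so the inverse stays finite where a plain atanh(tanh(x)) round-trip overflows in float32:

import torch
from torchrl.modules.distributions.continuous import SafeTanhTransform

x = torch.tensor([0.0, 5.0, 20.0])  # tanh(20.0) rounds to exactly 1.0 in float32
t = SafeTanhTransform()
x_back = t.inv(t(x))
# torch.atanh(torch.tanh(torch.tensor(20.0))) would be inf here
assert torch.isfinite(x_back).all()
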
+ class NormalParamWrapper(nn.Module):  # noqa: D101
+     def __init__(
+         self,
+         operator: nn.Module,
+         scale_mapping: str = "biased_softplus_1.0",
+         scale_lb: Number = 1e-4,
+     ) -> None:
+         raise RuntimeError(
+             "NormalParamWrapper has been deprecated in favor of "
+             "`tensordict.nn.NormalParamExtractor`. Use that class instead."
+         )
+
+
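The replacement named in the error message lives in the tensordict package; a minimal migration sketch, assuming tensordict is installed:

import torch
from torch import nn
from tensordict.nn import NormalParamExtractor

# Where NormalParamWrapper used to wrap an operator, the extractor is now
# appended after it: it splits the last dimension into (loc, scale).
net = nn.Sequential(nn.Linear(4, 8), NormalParamExtractor())
loc, scale = net(torch.randn(2, 4))
assert loc.shape == scale.shape == (2, 4)
assert (scale > 0).all()  # scale passes through a biased softplus mapping
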
+ class TruncatedNormal(D.Independent):
+     """Implements a Truncated Normal distribution with location scaling.
+
+     Location scaling prevents the location from being "too far" from 0, which would ultimately
+     lead to numerically unstable samples and poor gradient computation (e.g. gradient explosion).
+     In practice, the location is computed according to
+
+     .. math::
+         loc = tanh(loc / upscale) * upscale.
+
+     This behavior can be disabled by switching off the tanh_loc parameter (see below).
+
+
+     Args:
+         loc (torch.Tensor): normal distribution location parameter
+         scale (torch.Tensor): normal distribution sigma parameter (square root of variance)
+         upscale (torch.Tensor or number, optional): scaling factor used in the formula:
+
+             .. math::
+                 loc = tanh(loc / upscale) * upscale.
+
+             Default is 5.0
+
+         low (torch.Tensor or number, optional): minimum value of the distribution. Default is -1.0;
+         high (torch.Tensor or number, optional): maximum value of the distribution. Default is 1.0;
+         tanh_loc (bool, optional): if ``True``, the above formula is used for
+             the location scaling, otherwise the raw value is kept.
+             Default is ``False``;
+     """
+
+     num_params: int = 2
+
+     base_dist: _TruncatedNormal
+
+     arg_constraints = {
+         "loc": constraints.real,
+         "scale": constraints.greater_than(1e-6),
+     }
+
+     def __init__(
+         self,
+         loc: torch.Tensor,
+         scale: torch.Tensor,
+         upscale: torch.Tensor | float = 5.0,
+         low: torch.Tensor | float = -1.0,
+         high: torch.Tensor | float = 1.0,
+         tanh_loc: bool = False,
+     ):
+         err_msg = "TruncatedNormal high values must be strictly greater than low values"
+         if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor):
+             if not (high > low).all():
+                 raise RuntimeError(err_msg)
+         elif isinstance(high, Number) and isinstance(low, Number):
+             if not high > low:
+                 raise RuntimeError(err_msg)
+         else:
+             if not all(high > low):
+                 raise RuntimeError(err_msg)
+
+         if isinstance(high, torch.Tensor):
+             self.non_trivial_max = (high != 1.0).any()
+         else:
+             self.non_trivial_max = high != 1.0
+
+         if isinstance(low, torch.Tensor):
+             self.non_trivial_min = (low != -1.0).any()
+         else:
+             self.non_trivial_min = low != -1.0
+         self.tanh_loc = tanh_loc
+
+         self.device = loc.device
+         self.upscale = torch.as_tensor(upscale, device=self.device)
+
+         high = torch.as_tensor(high, device=self.device)
+         low = torch.as_tensor(low, device=self.device)
+         self.low = low
+         self.high = high
+         self.update(loc, scale)
+
+     @property
+     def min(self):
+         self._warn_minmax()
+         return self.low
+
+     @property
+     def max(self):
+         self._warn_minmax()
+         return self.high
+
+     def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
+         if self.tanh_loc:
+             loc = (loc / self.upscale).tanh() * self.upscale
+         self.loc = loc
+         self.scale = scale
+
+         base_dist = _TruncatedNormal(
+             loc,
+             scale,
+             a=self.low.expand_as(loc),
+             b=self.high.expand_as(scale),
+             device=self.device,
+         )
+         super().__init__(base_dist, 1, validate_args=False)
+
+     @property
+     def mode(self):
+         m = self.base_dist.loc
+         a = self.base_dist._non_std_a + self.base_dist._dtype_min_gt_0
+         b = self.base_dist._non_std_b - self.base_dist._dtype_min_gt_0
+         m = torch.min(torch.stack([m, b], -1), dim=-1)[0]
+         return torch.max(torch.stack([m, a], -1), dim=-1)[0]
+
+     @property
+     def deterministic_sample(self):
+         return self.mean
+
+     def log_prob(self, value, **kwargs):
+         above_or_below = (self.low > value) | (self.high < value)
+         a = self.base_dist._non_std_a + self.base_dist._dtype_min_gt_0
+         a = a.expand_as(value)
+         b = self.base_dist._non_std_b - self.base_dist._dtype_min_gt_0
+         b = b.expand_as(value)
+         value = torch.min(torch.stack([value, b], -1), dim=-1)[0]
+         value = torch.max(torch.stack([value, a], -1), dim=-1)[0]
+         lp = super().log_prob(value, **kwargs)
+         if above_or_below.any():
+             if self.event_shape:
+                 above_or_below = above_or_below.flatten(
+                     -len(self.event_shape), -1
+                 ).any(-1)
+             lp = torch.masked_fill(
+                 lp,
+                 above_or_below.expand_as(lp),
+                 torch.tensor(-float("inf"), device=lp.device, dtype=lp.dtype),
+             )
+         return lp
+
+
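A brief usage sketch (separate from the diff): samples respect the bounds, and values outside [low, high] get a log-probability of -inf, matching the masking in log_prob above:

import torch
from torchrl.modules.distributions import TruncatedNormal

dist = TruncatedNormal(torch.zeros(3), torch.ones(3), low=-1.0, high=1.0)
sample = dist.sample()
assert (sample >= -1.0).all() and (sample <= 1.0).all()

lp = dist.log_prob(torch.full((3,), 2.0))  # event is out of bounds
assert lp.item() == -float("inf")
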
+ class _PatchedComposeTransform(D.ComposeTransform):
+     @property
+     def inv(self):
+         inv = None
+         if self._inv is not None:
+             inv = self._inv()
+         if inv is None:
+             inv = _PatchedComposeTransform([p.inv for p in reversed(self.parts)])
+             if not is_compiling():
+                 self._inv = weakref.ref(inv)
+                 inv._inv = weakref.ref(self)
+         return inv
+
+
+ class _PatchedAffineTransform(D.AffineTransform):
+     @property
+     def inv(self):
+         inv = None
+         if self._inv is not None:
+             inv = self._inv()
+         if inv is None:
+             inv = _InverseTransform(self)
+             if not is_compiling():
+                 self._inv = weakref.ref(inv)
+         return inv
+
+
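These subclasses only skip the weakref caching of ``inv`` while compiling; in eager mode the cache still works, as this sketch (separate from the diff, using the private helper defined just above) shows:

import torch
from torchrl.modules.distributions.continuous import _PatchedAffineTransform

t = _PatchedAffineTransform(loc=0.0, scale=2.0)
x = torch.randn(4)
assert torch.allclose(t.inv(t(x)), x)  # round-trip through the affine map
assert t.inv is t.inv  # eager mode: the inverse is weakref-cached
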
+ class TanhNormal(FasterTransformedDistribution):
+     """Implements a TanhNormal distribution with location scaling.
+
+     Location scaling prevents the location from drifting "too far" from 0 when a
+     ``TanhTransform`` is applied; a location far from 0 would otherwise saturate
+     the transform, leading to numerically unstable samples and poor gradient
+     computation (e.g. gradient explosion).
+     In practice, with location scaling the location is computed according to
+
+     .. math::
+         loc = tanh(loc / upscale) * upscale.
+
+
+     Args:
+         loc (torch.Tensor): normal distribution location parameter
+         scale (torch.Tensor, float, or callable): normal distribution sigma parameter (square root of variance).
+             Can be a tensor, a float, or a callable that takes the ``loc`` tensor as input and returns the scale tensor.
+             Using a callable (e.g., ``torch.ones_like`` or ``functools.partial(torch.full_like, fill_value=0.1)``)
+             avoids explicit device transfers like ``torch.tensor(val, device=device)`` and prevents graph breaks
+             in :func:`torch.compile`.
+         upscale (torch.Tensor or number): scaling factor used in the formula:
+
+             .. math::
+                 loc = tanh(loc / upscale) * upscale.
+
+         low (torch.Tensor or number, optional): minimum value of the distribution. Default is -1.0;
+         high (torch.Tensor or number, optional): maximum value of the distribution. Default is 1.0;
+         event_dims (int, optional): number of dimensions describing the action.
+             Default is 1. Setting ``event_dims`` to ``0`` will result in a log-probability that has the same shape
+             as the input, ``1`` will reduce (sum over) the last dimension, ``2`` the last two etc.
+         tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw
+             value is kept. Default is ``False``;
+         safe_tanh (bool, optional): if ``True``, the Tanh transform is done "safely", to avoid numerical overflows.
+             This will currently break with :func:`torch.compile` on torch versions prior to 2.6.0.
+
+     Example:
+         >>> import torch
+         >>> from functools import partial
+         >>> from torchrl.modules.distributions import TanhNormal
+         >>> loc = torch.zeros(3, 4)
+         >>> # Using a callable scale avoids device transfers and graph breaks in torch.compile
+         >>> dist = TanhNormal(loc, scale=torch.ones_like)
+         >>> # For a custom scale value, use partial to create a callable
+         >>> dist = TanhNormal(loc, scale=partial(torch.full_like, fill_value=0.1))
+         >>> sample = dist.sample()
+         >>> sample.shape
+         torch.Size([3, 4])
+
+     """
+
+     arg_constraints = {
+         "loc": constraints.real,
+         "scale": constraints.greater_than(1e-6),
+     }
+
+     num_params = 2
+
+     def __init__(
+         self,
+         loc: torch.Tensor,
+         scale: torch.Tensor | float | Callable[[torch.Tensor], torch.Tensor],
+         upscale: torch.Tensor | Number = 5.0,
+         low: torch.Tensor | Number = -1.0,
+         high: torch.Tensor | Number = 1.0,
+         event_dims: int | None = None,
+         tanh_loc: bool = False,
+         safe_tanh: bool = True,
+     ):
+         if not isinstance(loc, torch.Tensor):
+             loc = torch.as_tensor(loc, dtype=torch.get_default_dtype())
+         _non_blocking = loc.device.type == "cuda"
+         # Support callable scale (e.g., torch.ones_like) for compile-friendliness
+         if callable(scale) and not isinstance(scale, torch.Tensor):
+             scale = scale(loc)
+         elif not isinstance(scale, torch.Tensor):
+             scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
+         elif scale.device != loc.device:
+             scale = scale.to(loc.device, non_blocking=_non_blocking)
+         if event_dims is None:
+             event_dims = min(1, loc.ndim)
+
+         err_msg = "TanhNormal high values must be strictly greater than low values"
+         if not is_compiling() and not safe_is_current_stream_capturing():
+             if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor):
+                 if not (high > low).all():
+                     raise RuntimeError(err_msg)
+             elif isinstance(high, Number) and isinstance(low, Number):
+                 if not high > low:
+                     raise RuntimeError(err_msg)
+             else:
+                 if not all(high > low):
+                     raise RuntimeError(err_msg)
+
+         if not isinstance(high, torch.Tensor):
+             high = torch.as_tensor(high, device=loc.device)
+         elif high.device != loc.device:
+             high = high.to(loc.device, non_blocking=_non_blocking)
+         if not isinstance(low, torch.Tensor):
+             low = torch.as_tensor(low, device=loc.device)
+         elif low.device != loc.device:
+             low = low.to(loc.device, non_blocking=_non_blocking)
+         if not is_compiling() and not safe_is_current_stream_capturing():
+             self.non_trivial_max = (high != 1.0).any()
+             self.non_trivial_min = (low != -1.0).any()
+         else:
+             self.non_trivial_max = self.non_trivial_min = True
+
+         self.tanh_loc = tanh_loc
+         self._event_dims = event_dims
+
+         self.device = loc.device
+         self.upscale = (
+             upscale
+             if not isinstance(upscale, torch.Tensor)
+             else upscale.to(self.device, non_blocking=_non_blocking)
+         )
+
+         low = low.to(loc.device, non_blocking=_non_blocking)
+         self.low = low
+         self.high = high
+
+         if safe_tanh:
+             if is_compiling() and TORCH_VERSION_PRE_2_6:
+                 _err_compile_safetanh()
+             t = SafeTanhTransform()
+         else:
+             t = D.TanhTransform()
+         if is_compiling() or (self.non_trivial_max or self.non_trivial_min):
+             t = _PatchedComposeTransform(
+                 [
+                     t,
+                     _PatchedAffineTransform(
+                         loc=(high + low) / 2, scale=(high - low) / 2
+                     ),
+                 ]
+             )
+         self._t = t
+
+         self.update(loc, scale)
+
+     @property
+     def min(self):
+         self._warn_minmax()
+         return self.low
+
+     @property
+     def max(self):
+         self._warn_minmax()
+         return self.high
+
+     def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
+         if self.tanh_loc:
+             loc = (loc / self.upscale).tanh() * self.upscale
+             # loc must be rescaled if tanh_loc
+             if is_compiling() or (self.non_trivial_max or self.non_trivial_min):
+                 loc = loc + (self.high - self.low) / 2 + self.low
+         # Support callable scale (e.g., torch.ones_like) for compile-friendliness
+         if callable(scale) and not isinstance(scale, torch.Tensor):
+             scale = scale(loc)
+         elif not isinstance(scale, torch.Tensor):
+             scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
+         elif scale.device != loc.device:
+             scale = scale.to(loc.device, non_blocking=loc.device.type == "cuda")
+         self.loc = loc
+         self.scale = scale
+
+         if (
+             hasattr(self, "base_dist")
+             and (self.root_dist.loc.shape == self.loc.shape)
+             and (self.root_dist.scale.shape == self.scale.shape)
+         ):
+             self.root_dist.loc = self.loc
+             self.root_dist.scale = self.scale
+         else:
+             if self._event_dims > 0:
+                 base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
+             else:
+                 base = D.Normal(self.loc, self.scale)
+             super().__init__(base, self._t)
+
+     @property
+     def support(self):
+         return D.constraints.real
+
+     @property
+     def root_dist(self):
+         bd = self
+         while hasattr(bd, "base_dist"):
+             bd = bd.base_dist
+         return bd
+
+     @property
+     def mode(self):
+         raise RuntimeError(
+             f"The distribution {type(self).__name__} has no analytical mode. "
+             f"Use ExplorationMode.DETERMINISTIC to get a deterministic sample from it."
+         )
+
+     @property
+     def deterministic_sample(self):
+         m = self.root_dist.mean
+         for t in self.transforms:
+             m = t(m)
+         return m
+
+     @torch.enable_grad()
+     def get_mode(self):
+         """Computes an estimate of the mode using the Adam optimizer."""
+         # Get starting point
+         m = self.sample((1000,)).mean(0)
+         m = torch.nn.Parameter(m.clamp(self.low, self.high).detach())
+         optim = torch.optim.Adam((m,), lr=1e-2)
+         self_copy = type(self)(
+             loc=self.loc.detach(),
+             scale=self.scale.detach(),
+             low=self.low.detach(),
+             high=self.high.detach(),
+             event_dims=self._event_dims,
+             upscale=self.upscale,
+             tanh_loc=False,
+         )
+         for _ in range(200):
+             lp = -self_copy.log_prob(m)
+             lp.mean().backward()
+             mc = m.clone().detach()
+             m.grad.clamp_max_(1)
+             optim.step()
+             optim.zero_grad()
+             m.data.clamp_(self_copy.low, self_copy.high)
+             nans = m.isnan()
+             if nans.any():
+                 m.data = torch.where(nans, mc, m.data)
+             if (m - mc).norm() < 1e-3:
+                 break
+         return m.detach()
+
+     @property
+     def mean(self):
+         raise NotImplementedError(
+             f"{type(self).__name__} does not have a closed-form formula for the mean. "
+             "An estimate of this value can be computed using dist.sample((N,)).mean(dim=0), "
+             "where N is a large number of samples."
+         )
+
+
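A usage sketch (separate from the diff) of the bounded behavior: samples land in [low, high], and since there is no analytical mode, ``deterministic_sample`` pushes the root Normal's mean through the transforms instead:

import torch
from torchrl.modules.distributions import TanhNormal

dist = TanhNormal(torch.zeros(3, 4), scale=torch.ones_like, low=0.0, high=2.0)
sample = dist.sample()
assert (sample >= 0.0).all() and (sample <= 2.0).all()

det = dist.deterministic_sample  # dist.mode would raise a RuntimeError
assert det.shape == (3, 4)
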
+ def uniform_sample_tanhnormal(dist: TanhNormal, size=None) -> torch.Tensor:
+     """Defines what uniform sampling looks like for a TanhNormal distribution.
+
+     Args:
+         dist (TanhNormal): distribution defining the space where the sampling should occur.
+         size (torch.Size): batch-size of the output tensor
+
+     Returns:
+         a tensor sampled uniformly in the boundaries defined by the input distribution.
+
+     """
+     if size is None:
+         size = torch.Size([])
+     return torch.rand_like(dist.sample(size)) * (dist.max - dist.min) + dist.min
+
+
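For instance (a sketch separate from the diff), drawing exploration noise uniformly over the same support as a policy's TanhNormal:

import torch
from torchrl.modules.distributions import TanhNormal
from torchrl.modules.distributions.continuous import uniform_sample_tanhnormal

dist = TanhNormal(torch.zeros(4), scale=torch.ones_like, low=-2.0, high=2.0)
x = uniform_sample_tanhnormal(dist, torch.Size([8]))
assert x.shape == (8, 4)
assert (x >= -2.0).all() and (x <= 2.0).all()
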
+ class Delta(D.Distribution):
+     """Delta distribution.
+
+     Args:
+         param (torch.Tensor): parameter of the delta distribution;
+         atol (number, optional): absolute tolerance to consider that a tensor matches the distribution parameter;
+             Default is 1e-6
+         rtol (number, optional): relative tolerance to consider that a tensor matches the distribution parameter;
+             Default is 1e-6
+         batch_shape (torch.Size, optional): batch shape;
+         event_shape (torch.Size, optional): shape of the outcome.
+
+     """
+
+     arg_constraints: dict = {}
+
+     def __init__(
+         self,
+         param: torch.Tensor,
+         atol: float = 1e-6,
+         rtol: float = 1e-6,
+         batch_shape: torch.Size | Sequence[int] | None = None,
+         event_shape: torch.Size | Sequence[int] | None = None,
+     ):
+         if batch_shape is None:
+             batch_shape = torch.Size([])
+         if event_shape is None:
+             event_shape = torch.Size([])
+         self.update(param)
+         self.atol = atol
+         self.rtol = rtol
+         if not len(batch_shape) and not len(event_shape):
+             batch_shape = param.shape[:-1]
+             event_shape = param.shape[-1:]
+         super().__init__(batch_shape=batch_shape, event_shape=event_shape)
+
+     def expand(self, batch_shape: torch.Size, _instance=None):
+         if self.batch_shape != tuple(batch_shape):
+             return type(self)(
+                 self.param.expand((*batch_shape, *self.event_shape)),
+                 atol=self.atol,
+                 rtol=self.rtol,
+             )
+         return self
+
+     def update(self, param):
+         self.param = param
+
+     def _is_equal(self, value: torch.Tensor) -> torch.Tensor:
+         param = self.param.expand_as(value)
+         is_equal = abs(value - param) < self.atol + self.rtol * abs(param)
+         for i in range(-1, -len(self.event_shape) - 1, -1):
+             is_equal = is_equal.all(i)
+         return is_equal
+
+     def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+         is_equal = self._is_equal(value)
+         out = torch.zeros_like(is_equal, dtype=value.dtype)
+         out.masked_fill_(is_equal, np.inf)
+         out.masked_fill_(~is_equal, -np.inf)
+         return out
+
+     @torch.no_grad()
+     def sample(self, size=None) -> torch.Tensor:
+         if size is None:
+             size = torch.Size([])
+         return self.param.expand(*size, *self.param.shape)
+
+     def rsample(self, size=None) -> torch.Tensor:
+         if size is None:
+             size = torch.Size([])
+         return self.param.expand(*size, *self.param.shape)
+
+     @property
+     def mode(self) -> torch.Tensor:
+         return self.param
+
+     @property
+     def deterministic_sample(self):
+         return self.mean
+
+     @property
+     def mean(self) -> torch.Tensor:
+         return self.param
+
+
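A quick sketch (separate from the diff) of the tolerance-based log-probability: values within atol/rtol of the parameter score +inf, everything else -inf, with event dims reduced:

import torch
from torchrl.modules.distributions import Delta

param = torch.zeros(2, 3)
dist = Delta(param)  # batch_shape (2,), event_shape (3,)

lp_hit = dist.log_prob(param)         # matches the atom exactly
lp_miss = dist.log_prob(param + 1.0)  # outside the tolerance band
assert (lp_hit == float("inf")).all() and (lp_miss == -float("inf")).all()
assert lp_hit.shape == (2,)
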
+ class TanhDelta(FasterTransformedDistribution):
+     """Implements a Tanh-transformed Delta distribution.
+
+     Args:
+         param (torch.Tensor): parameter of the delta distribution;
+         low (torch.Tensor or number, optional): minimum value of the distribution. Default is -1.0;
+         high (torch.Tensor or number, optional): maximum value of the distribution. Default is 1.0;
+         event_dims (int, optional): number of dimensions describing the action.
+             Default is 1;
+         atol (number, optional): absolute tolerance to consider that a tensor matches the distribution parameter;
+             Default is 1e-6
+         rtol (number, optional): relative tolerance to consider that a tensor matches the distribution parameter;
+             Default is 1e-6
+         batch_shape (torch.Size, optional): batch shape;
+         event_shape (torch.Size, optional): shape of the outcome;
+
+     """
+
+     arg_constraints = {
+         "loc": constraints.real,
+     }
+
+     def __init__(
+         self,
+         param: torch.Tensor,
+         low: torch.Tensor | float = -1.0,
+         high: torch.Tensor | float = 1.0,
+         event_dims: int = 1,
+         atol: float = 1e-6,
+         rtol: float = 1e-6,
+         safe: bool = True,
+     ):
+         minmax_msg = "high value has been found to be equal or less than low value"
+         if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor):
+             if is_compiling():
+                 assert (high > low).all()
+             else:
+                 if not (high > low).all():
+                     raise ValueError(minmax_msg)
+         elif isinstance(high, Number) and isinstance(low, Number):
+             if is_compiling():
+                 assert high > low
+             elif high <= low:
+                 raise ValueError(minmax_msg)
+         else:
+             if not all(high > low):
+                 raise ValueError(minmax_msg)
+
+         if safe:
+             if is_compiling():
+                 _err_compile_safetanh()
+             t = SafeTanhTransform()
+         else:
+             t = torch.distributions.TanhTransform()
+         # `is_compiling` must be called: the bare function object is always truthy
+         non_trivial_min = is_compiling() or (
+             (isinstance(low, torch.Tensor) and (low != -1.0).any())
+             or (not isinstance(low, torch.Tensor) and low != -1.0)
+         )
+         non_trivial_max = is_compiling() or (
+             (isinstance(high, torch.Tensor) and (high != 1.0).any())
+             or (not isinstance(high, torch.Tensor) and high != 1.0)
+         )
+         self.non_trivial = non_trivial_min or non_trivial_max
+
+         self.low = _cast_device(low, param.device)
+         self.high = _cast_device(high, param.device)
+         loc = self.update(param)
+
+         if self.non_trivial:
+             t = _PatchedComposeTransform(
+                 [
+                     t,
+                     _PatchedAffineTransform(
+                         loc=(self.high + self.low) / 2, scale=(self.high - self.low) / 2
+                     ),
+                 ]
+             )
+         event_shape = param.shape[-event_dims:]
+         batch_shape = param.shape[:-event_dims]
+         base = Delta(
+             loc,
+             atol=atol,
+             rtol=rtol,
+             batch_shape=batch_shape,
+             event_shape=event_shape,
+         )
+
+         super().__init__(base, t)
+
+     @property
+     def min(self):
+         self._warn_minmax()
+         return self.low
+
+     @property
+     def max(self):
+         self._warn_minmax()
+         return self.high
+
+     def update(self, net_output: torch.Tensor) -> torch.Tensor | None:
+         loc = net_output
+         if self.non_trivial:
+             device = loc.device
+             shift = _cast_device(self.high - self.low, device)
+             loc = loc + shift / 2 + _cast_device(self.low, device)
+         if hasattr(self, "base_dist"):
+             self.base_dist.update(loc)
+         else:
+             return loc
+
+     @property
+     def mode(self) -> torch.Tensor:
+         mode = self.base_dist.param
+         for t in self.transforms:
+             mode = t(mode)
+         return mode
+
+     @property
+     def deterministic_sample(self):
+         return self.mode
+
+     @property
+     def mean(self) -> torch.Tensor:
+         raise AttributeError("TanhDelta mean has no analytical form.")
+
+
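A sketch (separate from the diff) of the deterministic mapping: the delta atom is shifted, squashed through tanh, and rescaled into [low, high]:

import torch
from torchrl.modules.distributions import TanhDelta

param = torch.zeros(2, 3)
dist = TanhDelta(param, low=0.0, high=4.0)

# update() shifts the atom to (high - low)/2 + low = 2, then the transform
# maps it to (high + low)/2 + tanh(2) * (high - low)/2.
expected = 2.0 + torch.tanh(torch.tensor(2.0)) * 2.0
assert torch.allclose(dist.mode, torch.full((2, 3), expected.item()))
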
+ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor:
+     if size is None:
+         size = torch.Size([])
+     return torch.randn_like(dist.sample(size))
+
+
+ uniform_sample_delta = _uniform_sample_delta
+
+
+ def _err_compile_safetanh():
+     raise RuntimeError(
+         "safe_tanh=True in TanhNormal is not compatible with torch.compile with torch pre 2.6.0. "
+         "To deactivate it, pass safe_tanh=False. "
+         "If you are using a ProbabilisticTensorDictModule, this can be done via "
+         "`distribution_kwargs={'safe_tanh': False}`. "
+         "See https://github.com/pytorch/pytorch/issues/133529 for more details."
+     )
+
+
+ _warn_compile_safetanh = assume_constant_result(_err_compile_safetanh)