torchrl 0.11.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,321 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import warnings
8
+
9
+ import torch
10
+ from tensordict import TensorDictBase, unravel_key_list
11
+ from tensordict.nn import (
12
+ InteractionType,
13
+ ProbabilisticTensorDictModule,
14
+ ProbabilisticTensorDictSequential,
15
+ TensorDictModule,
16
+ )
17
+ from tensordict.utils import NestedKey
18
+ from torchrl.data.tensor_specs import Composite, TensorSpec
19
+ from torchrl.modules.distributions import Delta
20
+ from torchrl.modules.tensordict_module.common import _forward_hook_safe_action
21
+ from torchrl.modules.tensordict_module.sequence import SafeSequential
22
+
23
+
24
+ class SafeProbabilisticModule(ProbabilisticTensorDictModule):
25
+ """:class:`tensordict.nn.ProbabilisticTensorDictModule` subclass that accepts a :class:`~torchrl.envs.TensorSpec` as an argument to control the output domain.
26
+
27
+ `SafeProbabilisticModule` is a non-parametric module embedding a
28
+ probability distribution constructor. It reads the distribution parameters from an input
29
+ TensorDict using the specified `in_keys` and outputs a sample (loosely speaking) of the
30
+ distribution.
31
+
32
+ The output "sample" is produced given some rule, specified by the input ``default_interaction_type``
33
+ argument and the ``interaction_type()`` global function.
34
+
35
+ `SafeProbabilisticModule` can be used to construct the distribution
36
+ (through the :meth:`get_dist` method) and/or sampling from this distribution
37
+ (through a regular :meth:`__call__` to the module).
38
+
39
+ A `SafeProbabilisticModule` instance has two main features:
40
+
41
+ - It reads and writes from and to TensorDict objects;
42
+ - It uses a real mapping R^n -> R^m to create a distribution in R^d from
43
+ which values can be sampled or computed.
44
+
45
+ When the ``__call__`` and ``forward`` methods are called, a distribution is
46
+ created, and a value computed (depending on the ``interaction_type`` value, 'dist.mean',
47
+ 'dist.mode', 'dist.median' attributes could be used, as well as
48
+ the 'dist.rsample', 'dist.sample' method). The sampling step is skipped if the supplied
49
+ TensorDict has all the desired key-value pairs already.
50
+
51
+ By default, `SafeProbabilisticModule` distribution class is a :class:`~torchrl.modules.distributions.Delta`
52
+ distribution, making `SafeProbabilisticModule` a simple wrapper around
53
+ a deterministic mapping function.
54
+
55
+ This class differs from :class:`tensordict.nn.ProbabilisticTensorDictModule` in that it accepts a :attr:`spec`
56
+ keyword argument which can be used to control whether samples belong to the distribution or not. The :attr:`safe`
57
+ keyword argument controls whether the sample values should be checked against the spec.
58
+
59
+ Args:
60
+ in_keys (NestedKey | List[NestedKey] | Dict[str, NestedKey]): key(s) that will be read from the input TensorDict
61
+ and used to build the distribution.
62
+ Importantly, if it's a list of NestedKey or a NestedKey, the leaf (last element) of those keys must match the keywords used by
63
+ the distribution class of interest, e.g. ``"loc"`` and ``"scale"`` for
64
+ the :class:`~torch.distributions.Normal` distribution and similar.
65
+ If in_keys is a dictionary, the keys are the keys of the distribution and the values are the keys in the
66
+ tensordict that will be matched to the corresponding distribution keys.
67
+ out_keys (NestedKey | List[NestedKey] | None): key(s) where the sampled values will be written.
68
+ Importantly, if these keys are found in the input TensorDict, the sampling step will be skipped.
69
+ spec (TensorSpec): specs of the first output tensor. Used when calling
70
+ td_module.random() to generate random values in the target space.
71
+
72
+ Keyword Args:
73
+ safe (bool, optional): if ``True``, the value of the sample is checked against the
74
+ input spec. Out-of-domain sampling can occur because of exploration policies
75
+ or numerical under/overflow issues. As for the :obj:`spec` argument, this
76
+ check will only occur for the distribution sample, but not the other tensors
77
+ returned by the input module. If the sample is out of bounds, it is
78
+ projected back onto the desired space using the `TensorSpec.project` method.
79
+ Default is ``False``.
80
+ default_interaction_type (InteractionType, optional): keyword-only argument.
81
+ Default method to be used to retrieve
82
+ the output value. Should be one of InteractionType: DETERMINISTIC, MODE, MEDIAN, MEAN or RANDOM
83
+ (in which case the value is sampled randomly from the distribution). Default
84
+ is DETERMINISTIC.
85
+
86
+ .. note:: When a sample is drawn, the
87
+ :class:`ProbabilisticTensorDictModule` instance will
88
+ first look for the interaction mode dictated by the
89
+ :func:`~tensordict.nn.probabilistic.interaction_type`
90
+ global function. If this returns `None` (its default value), then the
91
+ `default_interaction_type` of the `ProbabilisticTDModule`
92
+ instance will be used. Note that
93
+ :class:`~torchrl.collectors.BaseCollector`
94
+ instances will use `set_interaction_type` to
95
+ :class:`tensordict.nn.InteractionType.RANDOM` by default.
96
+
97
+ .. note::
98
+ In some cases, the mode, median or mean value may not be
99
+ readily available through the corresponding attribute.
100
+ To palliate this, :class:`~ProbabilisticTensorDictModule` will first attempt
101
+ to get the value through a call to ``get_mode()``, ``get_median()`` or ``get_mean()``
102
+ if the method exists.
103
+
104
+ distribution_class (Type or Callable[[Any], Distribution], optional): keyword-only argument.
105
+ A :class:`torch.distributions.Distribution` class to
106
+ be used for sampling.
107
+ Default is :class:`~tensordict.nn.distributions.Delta`.
108
+
109
+ .. note::
110
+ If the distribution class is of type
111
+ :class:`~tensordict.nn.distributions.CompositeDistribution`, the ``out_keys``
112
+ can be inferred directly from the ``"distribution_map"`` or ``"name_map"``
113
+ keyword arguments provided through this class' ``distribution_kwargs``
114
+ keyword argument, making the ``out_keys`` optional in such cases.
115
+
116
+ distribution_kwargs (dict, optional): keyword-only argument.
117
+ Keyword-argument pairs to be passed to the distribution.
118
+
119
+ .. note:: if your kwargs contain tensors that you would like to transfer to device with the module, or
120
+ tensors that should see their dtype modified when calling `module.to(dtype)`, you can wrap the kwargs
121
+ in a :class:`~tensordict.nn.TensorDictParams` to do this automatically.
122
+
123
+ return_log_prob (bool, optional): keyword-only argument.
124
+ If ``True``, the log-probability of the
125
+ distribution sample will be written in the tensordict with the key
126
+ `log_prob_key`. Default is ``False``.
127
+ log_prob_keys (List[NestedKey], optional): keys where to write the log_prob if ``return_log_prob=True``.
128
+ Defaults to `'<sample_key_name>_log_prob'`, where `<sample_key_name>` is each of the :attr:`out_keys`.
129
+
130
+ .. note:: This is only available when :func:`~tensordict.nn.probabilistic.composite_lp_aggregate` is set to ``False``.
131
+
132
+ log_prob_key (NestedKey, optional): key where to write the log_prob if ``return_log_prob=True``.
133
+ Defaults to `'sample_log_prob'` when :func:`~tensordict.nn.probabilistic.composite_lp_aggregate` is set to `True`
134
+ or `'<sample_key_name>_log_prob'` otherwise.
135
+
136
+ .. note:: When there is more than one sample, this is only available when :func:`~tensordict.nn.probabilistic.composite_lp_aggregate` is set to ``True``.
137
+
138
+ cache_dist (bool, optional): keyword-only argument.
139
+ EXPERIMENTAL: if ``True``, the parameters of the
140
+ distribution (i.e. the output of the module) will be written to the
141
+ tensordict along with the sample. Those parameters can be used to re-compute
142
+ the original distribution later on (e.g. to compute the divergence between
143
+ the distribution used to sample the action and the updated distribution in
144
+ PPO). Default is ``False``.
145
+ n_empirical_estimate (int, optional): keyword-only argument.
146
+ Number of samples to compute the empirical
147
+ mean when it is not available. Defaults to 1000.
148
+
149
+ .. warning:: Running checks takes time! Using `safe=True` will guarantee that the samples are within the spec bounds
150
+ given some heuristic coded in :meth:`~torchrl.data.TensorSpec.project`, but that requires checking whether the
151
+ values are within the spec space, which will induce some overhead.
152
+
153
+ .. seealso:: :class:`The composite distribution in tensordict <~tensordict.nn.CompositeDistribution>` can be used
154
+ to create multi-head policies.
155
+
156
+ Example:
157
+ >>> from torchrl.modules import SafeProbabilisticModule
158
+ >>> from torchrl.data import Bounded
159
+ >>> import torch
160
+ >>> from tensordict import TensorDict
161
+ >>> from tensordict.nn import InteractionType
162
+ >>> mod = SafeProbabilisticModule(
163
+ ... in_keys=["loc", "scale"],
164
+ ... out_keys=["action"],
165
+ ... distribution_class=torch.distributions.Normal,
166
+ ... safe=True,
167
+ ... spec=Bounded(low=-1, high=1, shape=()),
168
+ ... default_interaction_type=InteractionType.RANDOM
169
+ ... )
170
+ >>> _ = torch.manual_seed(0)
171
+ >>> data = TensorDict(
172
+ ... loc=torch.zeros(10, requires_grad=True),
173
+ ... scale=torch.full((10,), 10.0),
174
+ ... batch_size=(10,))
175
+ >>> data = mod(data)
176
+ >>> print(data["action"]) # All actions are within bound
177
+ tensor([ 1., -1., -1., 1., -1., -1., 1., 1., -1., -1.],
178
+ grad_fn=<ClampBackward0>)
179
+ >>> data["action"].mean().backward()
180
+ >>> print(data["loc"].grad) # clamp annihilates gradients
181
+ tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
182
+ """
183
+
184
def __init__(
    self,
    in_keys: NestedKey | list[NestedKey] | dict[str, NestedKey],
    out_keys: NestedKey | list[NestedKey] | None = None,
    spec: TensorSpec | None = None,
    *,
    safe: bool = False,
    default_interaction_type: InteractionType = InteractionType.DETERMINISTIC,
    distribution_class: type = Delta,
    distribution_kwargs: dict | None = None,
    return_log_prob: bool = False,
    log_prob_keys: list[NestedKey] | None = None,
    log_prob_key: NestedKey | None = None,
    cache_dist: bool = False,
    n_empirical_estimate: int = 1000,
    num_samples: int | torch.Size | None = None,
):
    """Build the module and normalize ``spec`` into a :class:`Composite`.

    All distribution-related arguments are forwarded unchanged to the parent
    constructor. ``spec`` is then validated, cloned and wrapped so that
    ``self._spec`` is always a ``Composite`` whose keys match the module's
    out-keys (out-keys without a spec are filled with ``None``).

    Raises:
        TypeError: if ``spec`` is neither ``None`` nor a ``TensorSpec``.
        RuntimeError: if a single non-composite spec is given for several
            out-keys, if the spec keys cannot be reconciled with the
            out-keys, or if ``safe=True`` while no spec entry is set.
    """
    super().__init__(
        in_keys=in_keys,
        out_keys=out_keys,
        default_interaction_type=default_interaction_type,
        distribution_class=distribution_class,
        distribution_kwargs=distribution_kwargs,
        return_log_prob=return_log_prob,
        cache_dist=cache_dist,
        n_empirical_estimate=n_empirical_estimate,
        log_prob_keys=log_prob_keys,
        log_prob_key=log_prob_key,
        num_samples=num_samples,
    )

    if spec is None:
        spec = Composite()
    elif not isinstance(spec, TensorSpec):
        # Validate the type BEFORE touching the object: calling ``clone()``
        # first would surface an unrelated AttributeError for arbitrary
        # inputs instead of the intended TypeError.
        raise TypeError("spec must be a TensorSpec subclass")
    else:
        # Work on a private copy so later edits to the caller's spec do not
        # leak into this module (and vice versa).
        spec = spec.clone()
        if isinstance(spec, Composite):
            if "_" in spec.keys():
                warnings.warn('got a spec with key "_": it will be ignored')
        else:
            # A bare (non-composite) spec can only describe one output; the
            # log-prob entry, when requested, does not count as a sample key.
            if len(self.out_keys) - return_log_prob > 1:
                raise RuntimeError(
                    f"got more than one out_key for the SafeModule: {self.out_keys},\nbut only one spec. "
                    "Consider using a Composite object or no spec at all."
                )
            spec = Composite({self.out_keys[0]: spec})

    spec_keys = set(unravel_key_list(list(spec.keys(True, True))))
    out_keys = set(unravel_key_list(self._out_keys))
    if spec_keys != out_keys:
        # then assume that all the non indicated specs are None
        for key in out_keys:
            if key not in spec_keys:
                spec[key] = None
        spec_keys = set(unravel_key_list(list(spec.keys(True, True))))

    if spec_keys != out_keys:
        raise RuntimeError(
            f"spec keys and out_keys do not match, got: {spec_keys} and {out_keys} respectively"
        )

    self._spec = spec
    self.safe = safe
    if safe:
        # ``spec`` is always a Composite at this point, so the configuration
        # is invalid exactly when none of its entries carries a real spec.
        if all(_spec is None for _spec in spec.values()):
            raise RuntimeError(
                "`SafeProbabilisticModule(spec=None, safe=True)` is not a valid configuration as the tensor "
                "specs are not specified"
            )
        self.register_forward_hook(_forward_hook_safe_action)
256
+
257
@property
def spec(self) -> Composite:
    """Spec currently stored on the module (the ``_spec`` attribute)."""
    return self._spec
260
+
261
@spec.setter
def spec(self, spec: Composite) -> None:
    """Replace the stored spec.

    Args:
        spec: the new spec; must be a ``Composite`` instance.

    Raises:
        RuntimeError: if ``spec`` is not a ``Composite``.
    """
    if isinstance(spec, Composite):
        self._spec = spec
        return
    raise RuntimeError(
        f"Trying to set an object of type {type(spec)} as a tensorspec but expected a Composite instance."
    )
268
+
269
def random(self, tensordict: TensorDictBase) -> TensorDictBase:
    """Write a random draw from the module's spec into ``tensordict``.

    The draw does not depend on the content of ``tensordict``; only its
    batch size is used. When the module has several output keys, only the
    first one is written.

    Args:
        tensordict (TensorDictBase): container that receives the value.

    Returns:
        the same ``tensordict`` object, updated in place under the first
        output key.

    """
    first_key = self.out_keys[0]
    # NOTE(review): the ``spec`` property is annotated as returning a
    # Composite, so ``rand`` presumably yields a container, not a single
    # tensor — confirm this is the value intended for ``first_key``.
    sample = self.spec.rand(tensordict.batch_size)
    tensordict.set(first_key, sample)
    return tensordict
284
+
285
def random_sample(self, tensordict: TensorDictBase) -> TensorDictBase:
    """Alias of :meth:`random`; forwards ``tensordict`` and returns its result."""
    result = self.random(tensordict)
    return result
288
+
289
+
290
class SafeProbabilisticTensorDictSequential(
    ProbabilisticTensorDictSequential, SafeSequential
):
    """:class:`tensordict.nn.ProbabilisticTensorDictSequential` subclass that accepts a :class:`~torchrl.envs.TensorSpec` as argument to control the output domain.

    Similar to :obj:`TensorDictSequential`, but enforces that the final module in the
    sequence is a :obj:`ProbabilisticTensorDictModule`, and also exposes a ``get_dist``
    method to recover the distribution object from the ``ProbabilisticTensorDictModule``.

    Args:
        modules (iterable of TensorDictModules): ordered sequence of TensorDictModule
            instances, terminating in a ProbabilisticTensorDictModule, to be run
            sequentially.
        partial_tolerant (bool, optional): if ``True``, the input tensordict can miss some
            of the input keys. If so, the only modules that will be executed are those
            which can be executed given the keys that are present. Also, if the input
            tensordict is a lazy stack of tensordicts AND if partial_tolerant is
            ``True`` AND if the stack does not have the required keys, then
            TensorDictSequential will scan through the sub-tensordicts looking for those
            that have the required keys, if any.

    """

    def __init__(
        self,
        *modules: TensorDictModule | ProbabilisticTensorDictModule,
        partial_tolerant: bool = False,
    ) -> None:
        """Initialise the sequence through both parent constructors, in order."""
        shared_kwargs = {"partial_tolerant": partial_tolerant}
        # First pass: the regular initialiser found through the MRO.
        super().__init__(*modules, **shared_kwargs)
        # Second pass: start the lookup *after* ProbabilisticTensorDictSequential
        # in the MRO (presumably reaching the SafeSequential side of the
        # hierarchy — confirm against the parents' definitions).
        super(ProbabilisticTensorDictSequential, self).__init__(
            *modules, **shared_kwargs
        )