torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,132 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ from tensordict.nn import TensorDictModule, TensorDictSequential
9
+ from torch import nn
10
+
11
+ from torchrl.data.tensor_specs import Composite
12
+ from torchrl.modules.tensordict_module.common import SafeModule
13
+
14
+
15
class SafeSequential(TensorDictSequential, SafeModule):
    """A safe sequence of TensorDictModules.

    Similarly to :obj:`nn.Sequential`, which passes a tensor through a chain of mappings that each read and write a
    single tensor, this module reads from and writes to a tensordict by querying each of the input modules in order.
    When calling a :obj:`TensorDictSequential` instance with a functional module, it is expected that the parameter
    lists (and buffers) will be concatenated in a single list.

    The spec of the sequence is a :class:`~torchrl.data.Composite` built by merging the ``spec`` of every module, in
    order. A module that exposes no ``spec`` attribute contributes a ``None`` entry for each of its ``out_keys``.

    Args:
        modules (iterable of TensorDictModules): ordered sequence of TensorDictModule instances to be run sequentially.
        partial_tolerant (bool, optional): if ``True``, the input tensordict can miss some of the input keys.
            If so, the only modules that will be executed are those which can be executed given the keys that
            are present.
            Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is ``True`` AND if the
            stack does not have the required keys, then SafeSequential will scan through the sub-tensordicts
            looking for those that have the required keys, if any.
        inplace (bool or str, optional): if `True`, the input tensordict is modified in-place. If `False`, a new empty
            :class:`~tensordict.TensorDict` instance is created. If `"empty"`, `input.empty()` is used instead (ie, the
            output preserves type, device and batch-size). Defaults to `None` (relies on sub-modules).

    SafeSequential supports functional, modular and vmap coding:
    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from torchrl.data import Composite, Unbounded
        >>> from torchrl.modules import TanhNormal, SafeSequential, TensorDictModule, NormalParamExtractor
        >>> from torchrl.modules.tensordict_module import SafeProbabilisticModule
        >>> td = TensorDict({"input": torch.randn(3, 4)}, [3,])
        >>> spec1 = Composite(hidden=Unbounded(4), loc=None, scale=None)
        >>> net1 = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor())
        >>> module1 = TensorDictModule(net1, in_keys=["input"], out_keys=["loc", "scale"])
        >>> td_module1 = SafeProbabilisticModule(
        ...     module=module1,
        ...     spec=spec1,
        ...     in_keys=["loc", "scale"],
        ...     out_keys=["hidden"],
        ...     distribution_class=TanhNormal,
        ...     return_log_prob=True,
        ...     )
        >>> spec2 = Unbounded(8)
        >>> module2 = torch.nn.Linear(4, 8)
        >>> td_module2 = TensorDictModule(
        ...    module=module2,
        ...    spec=spec2,
        ...    in_keys=["hidden"],
        ...    out_keys=["output"],
        ...    )
        >>> td_module = SafeSequential(td_module1, td_module2)
        >>> params = TensorDict.from_module(td_module)
        >>> with params.to_module(td_module):
        ...     td_module(td)
        >>> print(td)
        TensorDict(
            fields={
                hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32),
                input: Tensor(torch.Size([3, 4]), dtype=torch.float32),
                loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),
                output: Tensor(torch.Size([3, 8]), dtype=torch.float32),
                sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),
                scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)
        >>> # The module spec aggregates all the input specs:
        >>> print(td_module.spec)
        Composite(
            hidden: UnboundedContinuous(
                 shape=torch.Size([4]), space=None, device=cpu, dtype=torch.float32, domain=continuous),
            loc: None,
            scale: None,
            output: UnboundedContinuous(
                 shape=torch.Size([8]), space=None, device=cpu, dtype=torch.float32, domain=continuous))

    In the vmap case:
        >>> from torch import vmap
        >>> params = params.expand(4, *params.shape)
        >>> td_vmap = vmap(td_module, (None, 0))(td, params)
        >>> print(td_vmap)
        TensorDict(
            fields={
                hidden: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32),
                input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32),
                loc: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32),
                output: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32),
                sample_log_prob: Tensor(torch.Size([4, 3, 1]), dtype=torch.float32),
                scale: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32)},
            batch_size=torch.Size([4, 3]),
            device=None,
            is_shared=False)

    """

    module: nn.ModuleList

    def __init__(
        self,
        *modules: TensorDictModule,
        partial_tolerant: bool = False,
        inplace: bool | str | None = None,
    ):
        self.partial_tolerant = partial_tolerant

        sequence_in_keys, sequence_out_keys = self._compute_in_and_out_keys(modules)

        # Merge the specs of all sub-modules, in order, into a single Composite.
        merged_spec = Composite()
        for submodule in modules:
            try:
                submodule_spec = submodule.spec
                merged_spec.update(submodule_spec)
            except AttributeError:
                # No usable spec on this module: register a None placeholder
                # for each key it writes.
                placeholder = Composite(dict.fromkeys(submodule.out_keys))
                merged_spec.update(placeholder)

        # `super(TensorDictSequential, self)` starts the method lookup *after*
        # TensorDictSequential in the MRO, so TensorDictSequential.__init__ is
        # deliberately bypassed here.
        super(TensorDictSequential, self).__init__(
            module=nn.ModuleList(list(modules)),
            spec=merged_spec,
            in_keys=sequence_in_keys,
            out_keys=sequence_out_keys,
            inplace=inplace,
        )
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from tensordict.nn import TensorDictModule, TensorDictSequential
8
+
9
+
10
class WorldModelWrapper(TensorDictSequential):
    """World model wrapper.

    Chains a transition model and a reward model into a single sequence:
    the transition model predicts an imaginary world state, and the reward
    model predicts the reward of that imagined transition.

    Args:
        transition_model (TensorDictModule): a transition model that generates a new world state.
        reward_model (TensorDictModule): a reward model that reads the world state and returns a reward.

    """

    def __init__(
        self, transition_model: TensorDictModule, reward_model: TensorDictModule
    ):
        # Order matters: the accessors below rely on the transition model
        # being stored first and the reward model second.
        stages = (transition_model, reward_model)
        super().__init__(*stages)

    def get_transition_model_operator(self) -> TensorDictModule:
        """Returns a transition operator that maps either an observation to a world state or a world state to the next world state."""
        transition_operator = self.module[0]
        return transition_operator

    def get_reward_operator(self) -> TensorDictModule:
        """Returns a reward operator that maps a world state to a reward."""
        reward_operator = self.module[1]
        return reward_operator
@@ -0,0 +1,38 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ from packaging import version
10
+
11
+
12
if version.parse(torch.__version__) < version.parse("1.12.0"):
    # Fallback used only for torch releases older than 1.12.0
    # (presumably because those do not ship `_ParameterMeta` -- TODO confirm).

    # Metaclass to combine _TensorMeta and the instance check override for Parameter.
    class _ParameterMeta(torch._C._TensorMeta):
        # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag.
        def __instancecheck__(self, instance):
            # Regular instance check first; fall back to the `_is_param` flag
            # carried by tensor instances.
            if super().__instancecheck__(instance):
                return True
            return isinstance(instance, torch.Tensor) and getattr(
                instance, "_is_param", False
            )

else:
    from torch.nn.parameter import _ParameterMeta
25
+
26
+
27
+ from .mappings import biased_softplus, inv_softplus, mappings
28
+ from .utils import get_primers_from_module
29
+
30
+ __all__ = [
31
+ "OrderedDict",
32
+ "torch",
33
+ "version",
34
+ "biased_softplus",
35
+ "inv_softplus",
36
+ "mappings",
37
+ "get_primers_from_module",
38
+ ]
@@ -0,0 +1,9 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from tensordict.nn.utils import biased_softplus, expln, inv_softplus, mappings
8
+
9
+ __all__ = ["biased_softplus", "expln", "inv_softplus", "mappings"]
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import warnings
8
+
9
+ import torch
10
+ from tensordict.utils import expand_as_right
11
+
12
+
13
def get_primers_from_module(module):
    """Collect the tensordict primers exposed by a module and its submodules.

    ``module.apply`` is used to visit the module tree; every visited module
    that has a ``make_tensordict_primer`` attribute has it called, and the
    results are gathered in visiting order.

    Args:
        module (torch.nn.Module): The parent module.

    Returns:
        ``None`` (after emitting a warning) if no primer was found, the single
        primer if exactly one was found, or a
        :class:`~torchrl.envs.transforms.Compose` of all primers otherwise.

    Example:
        >>> from torchrl.modules.utils import get_primers_from_module
        >>> from torchrl.modules import GRUModule, MLP
        >>> from tensordict.nn import TensorDictModule, TensorDictSequential
        >>> # Define a GRU module
        >>> gru_module = GRUModule(
        ...     input_size=10,
        ...     hidden_size=10,
        ...     num_layers=1,
        ...     in_keys=["input", "recurrent_state", "is_init"],
        ...     out_keys=["features", ("next", "recurrent_state")],
        ... )
        >>> # Define a head module
        >>> head = TensorDictModule(
        ...     MLP(
        ...         in_features=10,
        ...         out_features=10,
        ...         num_cells=[],
        ...     ),
        ...     in_keys=["features"],
        ...     out_keys=["output"],
        ... )
        >>> # Create a sequential model
        >>> model = TensorDictSequential(gru_module, head)
        >>> # Retrieve primers from the model
        >>> primers = get_primers_from_module(model)
        >>> print(primers)

        TensorDictPrimer(primers=Composite(
            recurrent_state: UnboundedContinuous(
                shape=torch.Size([1, 10]),
                space=None,
                device=cpu,
                dtype=torch.float32,
                domain=continuous), device=None, shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)

    """
    found = []

    def _collect(submodule):
        # Only modules exposing the primer factory contribute.
        if not hasattr(submodule, "make_tensordict_primer"):
            return
        found.append(submodule.make_tensordict_primer())

    module.apply(_collect)

    if len(found) > 1:
        # Imported lazily, only when several primers need to be combined.
        from torchrl.envs.transforms import Compose

        return Compose(*found)
    if found:
        return found[0]

    warnings.warn("No primers found in the module.")
    return None
78
+
79
+
80
def _unpad_tensors(
    tensors, mask, as_nested: bool = True
) -> torch.Tensor | list[torch.Tensor]:
    """Select the masked entries of ``tensors`` and regroup them per leading index.

    Args:
        tensors: tensor to select from. Dimensions from index 2 onward are
            treated as the per-element shape
            (assumes a ``(batch, time, *feature)`` layout -- TODO confirm).
        mask: mask converted to boolean and expanded on the right to the
            shape of ``tensors``; ``True`` marks entries to keep.
        as_nested (bool, optional): if ``True`` (default), the chunks are
            packed with :func:`torch.nested.as_nested_tensor`. Otherwise the
            plain list of chunks is returned.

    Returns:
        A nested tensor when ``as_nested`` is ``True``, else a list of
        tensors, one per chunk, each viewed as ``(-1, *tensors.shape[2:])``.
    """
    shape = tensors.shape[2:]
    mask = expand_as_right(mask.bool(), tensors)
    # Count the selected entries for each index of the leading dimension by
    # summing the mask over every trailing dimension.
    nelts = mask.sum(-1)
    while nelts.dim() > 1:
        nelts = nelts.sum(-1)
    # Boolean indexing flattens the selection; split it back into one chunk
    # per leading index and restore the trailing per-element shape.
    vals = [t.view(-1, *shape) for t in tensors[mask].split(nelts.tolist(), dim=0)]
    if as_nested:
        return torch.nested.as_nested_tensor(vals)
    return vals
@@ -0,0 +1,78 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from torchrl.objectives.a2c import A2CLoss
7
+ from torchrl.objectives.common import add_random_module, LossModule
8
+ from torchrl.objectives.cql import CQLLoss, DiscreteCQLLoss
9
+ from torchrl.objectives.crossq import CrossQLoss
10
+ from torchrl.objectives.ddpg import DDPGLoss
11
+ from torchrl.objectives.decision_transformer import DTLoss, OnlineDTLoss
12
+ from torchrl.objectives.dqn import DistributionalDQNLoss, DQNLoss
13
+ from torchrl.objectives.dreamer import (
14
+ DreamerActorLoss,
15
+ DreamerModelLoss,
16
+ DreamerValueLoss,
17
+ )
18
+ from torchrl.objectives.gail import GAILLoss
19
+ from torchrl.objectives.iql import DiscreteIQLLoss, IQLLoss
20
+ from torchrl.objectives.multiagent import QMixerLoss
21
+ from torchrl.objectives.ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss
22
+ from torchrl.objectives.redq import REDQLoss
23
+ from torchrl.objectives.reinforce import ReinforceLoss
24
+ from torchrl.objectives.sac import DiscreteSACLoss, SACLoss
25
+ from torchrl.objectives.td3 import TD3Loss
26
+ from torchrl.objectives.td3_bc import TD3BCLoss
27
+ from torchrl.objectives.utils import (
28
+ default_value_kwargs,
29
+ distance_loss,
30
+ group_optimizers,
31
+ HardUpdate,
32
+ hold_out_net,
33
+ hold_out_params,
34
+ next_state_value,
35
+ SoftUpdate,
36
+ TargetNetUpdater,
37
+ ValueEstimators,
38
+ )
39
+
40
+ __all__ = [
41
+ "A2CLoss",
42
+ "CQLLoss",
43
+ "ClipPPOLoss",
44
+ "CrossQLoss",
45
+ "DDPGLoss",
46
+ "DQNLoss",
47
+ "DTLoss",
48
+ "DiscreteCQLLoss",
49
+ "DiscreteIQLLoss",
50
+ "DiscreteSACLoss",
51
+ "DistributionalDQNLoss",
52
+ "DreamerActorLoss",
53
+ "DreamerModelLoss",
54
+ "DreamerValueLoss",
55
+ "GAILLoss",
56
+ "HardUpdate",
57
+ "IQLLoss",
58
+ "KLPENPPOLoss",
59
+ "LossModule",
60
+ "OnlineDTLoss",
61
+ "PPOLoss",
62
+ "QMixerLoss",
63
+ "REDQLoss",
64
+ "ReinforceLoss",
65
+ "SACLoss",
66
+ "SoftUpdate",
67
+ "TD3BCLoss",
68
+ "TD3Loss",
69
+ "TargetNetUpdater",
70
+ "ValueEstimators",
71
+ "add_random_module",
72
+ "default_value_kwargs",
73
+ "distance_loss",
74
+ "group_optimizers",
75
+ "hold_out_net",
76
+ "hold_out_params",
77
+ "next_state_value",
78
+ ]