torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,464 @@
1
+ from __future__ import annotations
2
+
3
+ import collections
4
+ import time
5
+ import warnings
6
+ from collections import OrderedDict
7
+ from collections.abc import Iterator, Sequence
8
+ from queue import Empty
9
+
10
+ import torch
11
+
12
+ from tensordict import TensorDict, TensorDictBase
13
+ from tensordict.nn import TensorDictModuleBase
14
+ from torchrl import logger as torchrl_logger
15
+ from torchrl._utils import (
16
+ _check_for_faulty_process,
17
+ accept_remote_rref_udf_invocation,
18
+ RL_WARNINGS,
19
+ )
20
+ from torchrl.collectors._base import _make_legacy_metaclass
21
+ from torchrl.collectors._constants import _MAX_IDLE_COUNT, _TIMEOUT
22
+ from torchrl.collectors._multi_base import _MultiCollectorMeta, MultiCollector
23
+ from torchrl.collectors.utils import split_trajectories
24
+
25
+
26
+ @accept_remote_rref_udf_invocation
27
+ class MultiSyncCollector(MultiCollector):
28
+ """Runs a given number of DataCollectors on separate processes synchronously.
29
+
30
+ .. aafig::
31
+
32
+ +----------------------------------------------------------------------+
33
+ | "MultiSyncCollector" | |
34
+ |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| |
35
+ | "Collector 1" | "Collector 2" | "Collector 3" | Main |
36
+ |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~|
37
+ | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | |
38
+ |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~|
39
+ |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | |
40
+ | | | | | | | |
41
+ | "actor" | | | "actor" | |
42
+ | | | | | |
43
+ | "step" | "step" | "actor" | | |
44
+ | | | | | |
45
+ | | | | "step" | "step" | |
46
+ | | | | | | |
47
+ | "actor" | "step" | "step" | "actor" | |
48
+ | | | | | |
49
+ | | "actor" | | |
50
+ | | | | |
51
+ | "yield batch of traj 1"------->"collect, train"|
52
+ | | |
53
+ | "step" | "step" | "step" | "step" | "step" | "step" | |
54
+ | | | | | | | |
55
+ | "actor" | "actor" | | | |
56
+ | | "step" | "step" | "actor" | |
57
+ | | | | | |
58
+ | "step" | "step" | "actor" | "step" | "step" | |
59
+ | | | | | | |
60
+ | "actor" | | "actor" | |
61
+ | "yield batch of traj 2"------->"collect, train"|
62
+ | | |
63
+ +----------------------------------------------------------------------+
64
+
65
+ Envs can be identical or different.
66
+
67
+ The collection starts when the next item of the collector is queried,
68
+ and no environment step is computed in between the reception of a batch of
69
+ trajectory and the start of the next collection.
70
+ This class can be safely used with online RL sota-implementations.
71
+
72
+ .. note::
73
+ Python requires multiprocessed code to be instantiated within a main guard:
74
+
75
+ >>> from torchrl.collectors import MultiSyncCollector
76
+ >>> if __name__ == "__main__":
77
+ ... # Create your collector here
78
+ ... collector = MultiSyncCollector(...)
79
+
80
+ See https://docs.python.org/3/library/multiprocessing.html for more info.
81
+
82
+ Examples:
83
+ >>> from torchrl.envs.libs.gym import GymEnv
84
+ >>> from tensordict.nn import TensorDictModule
85
+ >>> from torch import nn
86
+ >>> from torchrl.collectors import MultiSyncCollector
87
+ >>> if __name__ == "__main__":
88
+ ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
89
+ ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
90
+ ... collector = MultiSyncCollector(
91
+ ... create_env_fn=[env_maker, env_maker],
92
+ ... policy=policy,
93
+ ... total_frames=2000,
94
+ ... max_frames_per_traj=50,
95
+ ... frames_per_batch=200,
96
+ ... init_random_frames=-1,
97
+ ... reset_at_each_iter=False,
98
+ ... device="cpu",
99
+ ... storing_device="cpu",
100
+ ... cat_results="stack",
101
+ ... )
102
+ ... for i, data in enumerate(collector):
103
+ ... if i == 2:
104
+ ... print(data)
105
+ ... break
106
+ ... collector.shutdown()
107
+ ... del collector
108
+ TensorDict(
109
+ fields={
110
+ action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
111
+ collector: TensorDict(
112
+ fields={
113
+ traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
114
+ batch_size=torch.Size([200]),
115
+ device=cpu,
116
+ is_shared=False),
117
+ done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
118
+ next: TensorDict(
119
+ fields={
120
+ done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
121
+ observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
122
+ reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
123
+ step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
124
+ truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
125
+ batch_size=torch.Size([200]),
126
+ device=cpu,
127
+ is_shared=False),
128
+ observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
129
+ step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
130
+ truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
131
+ batch_size=torch.Size([200]),
132
+ device=cpu,
133
+ is_shared=False)
134
+
135
+ """
136
+
137
+ __doc__ += MultiCollector.__doc__
138
+
139
# for RPC
def next(self):
    """Collect and return the next batch of data.

    Pure delegation to the parent collector; redefined locally so that the
    ``@accept_remote_rref_udf_invocation`` class decorator can expose the
    method for remote RPC invocation.
    """
    return super().next()
142
+
143
# for RPC
def shutdown(
    self,
    timeout: float | None = None,
    close_env: bool = True,
    raise_on_error: bool = True,
) -> None:
    """Shut down the collector and its worker processes.

    Args:
        timeout: maximum time to wait, forwarded to the parent ``shutdown``.
        close_env: must be ``True``; this collector cannot be shut down
            while keeping its environments alive.
        raise_on_error: if ``False``, exceptions raised during the parent
            shutdown are swallowed (best-effort teardown).

    Raises:
        RuntimeError: if ``close_env`` is ``False``.
    """
    if not close_env:
        raise RuntimeError(
            f"Cannot shutdown {type(self).__name__} collector without environment being closed."
        )
    # Drop references to the (possibly shared-memory) buffers before the
    # worker processes are torn down.
    if hasattr(self, "out_buffer"):
        del self.out_buffer
    if hasattr(self, "buffers"):
        del self.buffers
    try:
        return super().shutdown(timeout=timeout)
    except Exception:
        if raise_on_error:
            # Bare `raise` re-raises with the original traceback intact
            # (`raise e` would rewrite the traceback at this frame).
            raise
        # Caller asked for best-effort shutdown: swallow the error.
165
+
166
# for RPC
def set_seed(self, seed: int, static_seed: bool = False) -> int:
    """Set the seed of the collector (delegates to the parent class).

    Redefined locally so that ``@accept_remote_rref_udf_invocation`` can
    expose the method for remote RPC invocation.
    """
    return super().set_seed(seed, static_seed)
169
+
170
# for RPC
def state_dict(self) -> OrderedDict:
    """Return the collector state dict (delegates to the parent class).

    Redefined locally so that ``@accept_remote_rref_udf_invocation`` can
    expose the method for remote RPC invocation.
    """
    return super().state_dict()
173
+
174
# for RPC
def load_state_dict(self, state_dict: OrderedDict) -> None:
    """Load a collector state dict (delegates to the parent class).

    Redefined locally so that ``@accept_remote_rref_udf_invocation`` can
    expose the method for remote RPC invocation.
    """
    return super().load_state_dict(state_dict)
177
+
178
# for RPC
def update_policy_weights_(
    self,
    policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
    *,
    worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
    **kwargs,
) -> None:
    """Update the worker policies' weights (delegates to the parent class).

    Args:
        policy_or_weights: the policy module, a weights container, or
            ``None`` to use the collector's own reference weights.
        worker_ids: restrict the update to the given worker(s)/device(s).
        **kwargs: may carry the deprecated ``policy_weights`` keyword,
            accepted as an alias of ``policy_or_weights``.
    """
    if "policy_weights" in kwargs:
        warnings.warn(
            "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
            DeprecationWarning,
        )
        # Previously a deprecated `policy_weights` silently clobbered an
        # explicitly passed `policy_or_weights`; make the conflict loud.
        if policy_or_weights is not None:
            raise TypeError(
                "Cannot pass both `policy_or_weights` and the deprecated `policy_weights`."
            )
        policy_or_weights = kwargs.pop("policy_weights")

    super().update_policy_weights_(
        policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
    )
196
+
197
def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int:
    """Return the number of frames a single worker collects per batch.

    Args:
        worker_idx: index of the worker. When per-worker batch sizes were
            configured (``_frames_per_batch`` is a sequence), the entry for
            this worker is returned directly.

    Returns:
        The per-worker frame count; when the requested total is not evenly
        divisible, the count is rounded up so the aggregate meets (and may
        exceed) the request.
    """
    if worker_idx is not None and isinstance(self._frames_per_batch, Sequence):
        return self._frames_per_batch[worker_idx]
    if self.requested_frames_per_batch % self.num_workers != 0 and RL_WARNINGS:
        # Fixed message: "than requested" (was "that requested") and the
        # missing space before "To silence".
        warnings.warn(
            f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers},"
            f" this results in more frames_per_batch per iteration than requested."
            " To silence this message, set the environment variable RL_WARNINGS to False."
        )
    # Ceiling division so that num_workers * result >= requested total.
    return -(-self.requested_frames_per_batch // self.num_workers)
210
+
211
@property
def _queue_len(self) -> int:
    # Synchronous collection expects exactly one item per worker per
    # iteration, so the output queue must be able to hold one entry
    # for each worker.
    return self.num_workers
214
+
215
+ def iterator(self) -> Iterator[TensorDictBase]:
216
+ cat_results = self.cat_results
217
+ if cat_results is None:
218
+ cat_results = "stack"
219
+
220
+ self.buffers = [None for _ in range(self.num_workers)]
221
+ dones = [False for _ in range(self.num_workers)]
222
+ workers_frames = [0 for _ in range(self.num_workers)]
223
+ same_device = None
224
+ self.out_buffer = None
225
+ preempt = self.interruptor is not None and self.preemptive_threshold < 1.0
226
+
227
+ while not all(dones) and self._frames < self.total_frames:
228
+ _check_for_faulty_process(self.procs)
229
+ if self.update_at_each_batch:
230
+ self.update_policy_weights_()
231
+
232
+ for idx in range(self.num_workers):
233
+ if self._should_use_random_frames():
234
+ msg = "continue_random"
235
+ else:
236
+ msg = "continue"
237
+ self.pipes[idx].send((None, msg))
238
+
239
+ self._iter += 1
240
+
241
+ if preempt:
242
+ self.interruptor.start_collection()
243
+ while self.queue_out.qsize() < int(
244
+ self.num_workers * self.preemptive_threshold
245
+ ):
246
+ continue
247
+ self.interruptor.stop_collection()
248
+ # Now wait for stragglers to return
249
+ while self.queue_out.qsize() < int(self.num_workers):
250
+ continue
251
+
252
+ recv = collections.deque()
253
+ t0 = time.time()
254
+ while len(recv) < self.num_workers and (
255
+ (time.time() - t0) < (_TIMEOUT * _MAX_IDLE_COUNT)
256
+ ):
257
+ for _ in range(self.num_workers):
258
+ try:
259
+ new_data, j = self.queue_out.get(timeout=_TIMEOUT)
260
+ recv.append((new_data, j))
261
+ except (TimeoutError, Empty):
262
+ _check_for_faulty_process(self.procs)
263
+ if (time.time() - t0) > (_TIMEOUT * _MAX_IDLE_COUNT):
264
+ try:
265
+ self.shutdown()
266
+ finally:
267
+ raise RuntimeError(
268
+ f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. "
269
+ f"Increase the MAX_IDLE_COUNT environment variable to bypass this error."
270
+ )
271
+
272
+ for _ in range(self.num_workers):
273
+ new_data, j = recv.popleft()
274
+ use_buffers = self._use_buffers
275
+ if self.replay_buffer is not None:
276
+ idx = new_data
277
+ workers_frames[idx] = workers_frames[
278
+ idx
279
+ ] + self.frames_per_batch_worker(worker_idx=idx)
280
+ continue
281
+ elif j == 0 or not use_buffers:
282
+ try:
283
+ data, idx = new_data
284
+ self.buffers[idx] = data
285
+ if use_buffers is None and j > 0:
286
+ self._use_buffers = False
287
+ except TypeError:
288
+ if use_buffers is None:
289
+ self._use_buffers = True
290
+ idx = new_data
291
+ else:
292
+ raise
293
+ else:
294
+ idx = new_data
295
+
296
+ if preempt:
297
+ # mask buffers if cat, and create a mask if stack
298
+ if cat_results != "stack":
299
+ buffers = [None] * self.num_workers
300
+ for worker_idx, buffer in enumerate(self.buffers):
301
+ # Skip pre-empted envs:
302
+ if buffer is None:
303
+ continue
304
+ valid = buffer.get(("collector", "traj_ids")) != -1
305
+ if valid.ndim > 2:
306
+ valid = valid.flatten(0, -2)
307
+ if valid.ndim == 2:
308
+ valid = valid.any(0)
309
+ buffers[worker_idx] = buffer[..., valid]
310
+ else:
311
+ for buffer in filter(lambda x: x is not None, self.buffers):
312
+ with buffer.unlock_():
313
+ buffer.set(
314
+ ("collector", "mask"),
315
+ buffer.get(("collector", "traj_ids")) != -1,
316
+ )
317
+ buffers = self.buffers
318
+ else:
319
+ buffers = self.buffers
320
+
321
+ # Skip frame counting if this worker didn't send data this iteration
322
+ # (happens when reusing buffers or on first iteration with some workers)
323
+ if self.buffers[idx] is None:
324
+ continue
325
+
326
+ workers_frames[idx] = workers_frames[idx] + buffers[idx].numel()
327
+
328
+ if workers_frames[idx] >= self.total_frames:
329
+ dones[idx] = True
330
+
331
+ if self.replay_buffer is not None:
332
+ yield
333
+ self._frames += sum(
334
+ self.frames_per_batch_worker(worker_idx=worker_idx)
335
+ for worker_idx in range(self.num_workers)
336
+ )
337
+ continue
338
+
339
+ # we have to correct the traj_ids to make sure that they don't overlap
340
+ # We can count the number of frames collected for free in this loop
341
+ n_collected = 0
342
+ for idx in range(self.num_workers):
343
+ buffer = buffers[idx]
344
+ if buffer is None:
345
+ continue
346
+ traj_ids = buffer.get(("collector", "traj_ids"))
347
+ if preempt:
348
+ if cat_results == "stack":
349
+ mask_frames = buffer.get(("collector", "traj_ids")) != -1
350
+ n_collected += mask_frames.sum().cpu()
351
+ else:
352
+ n_collected += traj_ids.numel()
353
+ else:
354
+ n_collected += traj_ids.numel()
355
+
356
+ if same_device is None:
357
+ prev_device = None
358
+ same_device = True
359
+ for item in filter(lambda x: x is not None, self.buffers):
360
+ if prev_device is None:
361
+ prev_device = item.device
362
+ else:
363
+ same_device = same_device and (item.device == prev_device)
364
+
365
+ if self.split_trajs:
366
+ max_traj_id = -1
367
+ for idx in range(self.num_workers):
368
+ if buffers[idx] is not None:
369
+ traj_ids = buffers[idx].get(("collector", "traj_ids"))
370
+ if traj_ids is not None:
371
+ buffers[idx].set_(
372
+ ("collector", "traj_ids"), traj_ids + max_traj_id + 1
373
+ )
374
+ max_traj_id = (
375
+ buffers[idx].get(("collector", "traj_ids")).max()
376
+ )
377
+
378
+ if cat_results == "stack":
379
+ stack = (
380
+ torch.stack if self._use_buffers else TensorDict.maybe_dense_stack
381
+ )
382
+ if same_device:
383
+ self.out_buffer = stack(
384
+ [item for item in buffers if item is not None], 0
385
+ )
386
+ else:
387
+ self.out_buffer = stack(
388
+ [item.cpu() for item in buffers if item is not None], 0
389
+ )
390
+ else:
391
+ if self._use_buffers is None:
392
+ torchrl_logger.warning(
393
+ "use_buffer not specified and not yet inferred from data, assuming `True`."
394
+ )
395
+ elif not self._use_buffers:
396
+ raise RuntimeError(
397
+ "Cannot concatenate results with use_buffers=False"
398
+ )
399
+ try:
400
+ if same_device:
401
+ self.out_buffer = torch.cat(
402
+ [item for item in buffers if item is not None], cat_results
403
+ )
404
+ else:
405
+ self.out_buffer = torch.cat(
406
+ [item.cpu() for item in buffers if item is not None],
407
+ cat_results,
408
+ )
409
+ except RuntimeError as err:
410
+ if (
411
+ preempt
412
+ and cat_results != -1
413
+ and "Sizes of tensors must match" in str(err)
414
+ ):
415
+ raise RuntimeError(
416
+ "The value provided to cat_results isn't compatible with the collectors outputs. "
417
+ "Consider using `cat_results=-1`."
418
+ )
419
+ raise
420
+
421
+ # TODO: why do we need to do cat inplace and clone?
422
+ if self.split_trajs:
423
+ out = split_trajectories(self.out_buffer, prefix="collector")
424
+ else:
425
+ out = self.out_buffer
426
+ if cat_results in (-1, "stack"):
427
+ out.refine_names(*[None] * (out.ndim - 1) + ["time"])
428
+
429
+ self._frames += n_collected
430
+
431
+ if self.postprocs:
432
+ self.postprocs = (
433
+ self.postprocs.to(out.device)
434
+ if hasattr(self.postprocs, "to")
435
+ else self.postprocs
436
+ )
437
+ out = self.postprocs(out)
438
+ if self._exclude_private_keys:
439
+ excluded_keys = [key for key in out.keys() if key.startswith("_")]
440
+ if excluded_keys:
441
+ out = out.exclude(*excluded_keys)
442
+ yield out
443
+
444
+ del self.buffers
445
+ self.out_buffer = None
446
+ # We shall not call shutdown just yet as user may want to retrieve state_dict
447
+ # self._shutdown_main()
448
+
449
+ # for RPC
450
+ def receive_weights(self, policy_or_weights: TensorDictBase | None = None):
451
+ return super().receive_weights(policy_or_weights)
452
+
453
+ # for RPC
454
+ def _receive_weights_scheme(self):
455
+ return super()._receive_weights_scheme()
456
+
457
+
458
# Metaclass used by the deprecated MultiSyncDataCollector alias below.
# NOTE(review): presumably _make_legacy_metaclass wraps _MultiCollectorMeta to
# emit a deprecation warning when the legacy class name is instantiated —
# confirm against the helper's definition elsewhere in the file.
_LegacyMultiSyncMeta = _make_legacy_metaclass(_MultiCollectorMeta)
459
+
460
+
461
class MultiSyncDataCollector(MultiSyncCollector, metaclass=_LegacyMultiSyncMeta):
    """Deprecated version of :class:`~torchrl.collectors.MultiSyncCollector`."""

    # Intentionally empty body: all behavior comes from MultiSyncCollector;
    # the legacy metaclass presumably handles deprecation signalling — confirm.