torchrl-0.11.0-cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,581 @@
+ from __future__ import annotations
+
+ import contextlib
+ import queue
+ from collections.abc import Callable
+ from functools import partial
+ from multiprocessing import connection, queues
+ from typing import Any
+
+ import numpy as np
+ import torch
+ from tensordict import TensorDict, TensorDictBase
+
+ from torchrl import logger as torchrl_logger
+ from torchrl._utils import timeit, VERBOSE
+ from torchrl.collectors._base import BaseCollector, ProfileConfig
+ from torchrl.collectors._constants import (
+     _MAX_IDLE_COUNT,
+     _MIN_TIMEOUT,
+     _TIMEOUT,
+     DEFAULT_EXPLORATION_TYPE,
+ )
+ from torchrl.collectors._single import Collector
+
+ from torchrl.collectors.utils import (
+     _cast,
+     _make_policy_factory,
+     _map_to_cpu_if_needed,
+     _TrajectoryPool,
+ )
+ from torchrl.data import ReplayBuffer
+ from torchrl.envs import EnvBase, EnvCreator
+ from torchrl.envs.utils import ExplorationType
+ from torchrl.weight_update import WeightSyncScheme
+
+
+ class _WorkerProfiler:
+     """Helper class for profiling worker rollouts.
+
+     Manages the PyTorch profiler lifecycle for a worker process,
+     handling warmup, active profiling, and trace export.
+     """
+
+     def __init__(
+         self,
+         profile_config: ProfileConfig,
+         worker_idx: int,
+     ):
+         self.config = profile_config
+         self.worker_idx = worker_idx
+         self.rollout_count = 0
+         self._profiler = None
+         self._stopped = False
+         self._active = False
+
+         # Check if this worker should be profiled
+         if not self.config.should_profile_worker(worker_idx):
+             return
+
+         # Set up profiler schedule
+         # - skip_first: warmup rollouts (profiler runs but data discarded)
+         # - wait: 0 (no wait between cycles)
+         # - warmup: 0 (we handle warmup via skip_first)
+         # - active: num_rollouts - warmup_rollouts
+         # - repeat: 1 (single profiling cycle)
+         active_rollouts = self.config.num_rollouts - self.config.warmup_rollouts
+         profiler_schedule = torch.profiler.schedule(
+             skip_first=self.config.warmup_rollouts,
+             wait=0,
+             warmup=0,
+             active=active_rollouts,
+             repeat=1,
+         )
+
+         # Get activities
+         activities = self.config.get_activities()
+         if not activities:
+             torchrl_logger.warning(
+                 f"Worker {worker_idx}: No profiler activities available. Profiling disabled."
+             )
+             return
+
+         # Determine trace handler
+         if self.config.on_trace_ready is not None:
+             on_trace_ready = self.config.on_trace_ready
+         else:
+             save_path = self.config.get_save_path(worker_idx)
+             save_path.parent.mkdir(parents=True, exist_ok=True)
+
+             def on_trace_ready(prof, save_path=save_path):
+                 prof.export_chrome_trace(str(save_path))
+                 torchrl_logger.info(
+                     f"Worker {worker_idx}: Profiling trace saved to {save_path}"
+                 )
+
+         self._profiler = torch.profiler.profile(
+             activities=activities,
+             schedule=profiler_schedule,
+             on_trace_ready=on_trace_ready,
+             record_shapes=self.config.record_shapes,
+             profile_memory=self.config.profile_memory,
+             with_stack=self.config.with_stack,
+             with_flops=self.config.with_flops,
+         )
+         self._active = True
+
+     def start(self) -> None:
+         """Start the profiler."""
+         if self._profiler is not None and not self._stopped:
+             self._profiler.start()
+             torchrl_logger.info(
+                 f"Worker {self.worker_idx}: Profiling started. "
+                 f"Will profile rollouts {self.config.warmup_rollouts} to {self.config.num_rollouts - 1}."
+             )
+
+     def step(self) -> bool:
+         """Step the profiler after a rollout.
+
+         Returns:
+             True if profiling is complete.
+         """
+         if self._profiler is None or self._stopped:
+             return False
+
+         self.rollout_count += 1
+         self._profiler.step()
+
+         # Check if profiling is complete
+         if self.rollout_count >= self.config.num_rollouts:
+             self.stop()
+             return True
+
+         return False
+
+     def stop(self) -> None:
+         """Stop the profiler and export trace."""
+         if self._profiler is not None and not self._stopped:
+             self._profiler.stop()
+             self._stopped = True
+             torchrl_logger.info(
+                 f"Worker {self.worker_idx}: Profiling complete after {self.rollout_count} rollouts."
+             )
+
+     @property
+     def is_active(self) -> bool:
+         """Check if profiling is active."""
+         return self._active and not self._stopped
+
+     @contextlib.contextmanager
+     def profile_rollout(self):
+         """Context manager for profiling a single rollout."""
+         if self._profiler is not None and not self._stopped:
+             with torch.profiler.record_function(f"worker_{self.worker_idx}_rollout"):
+                 yield
+         else:
+             yield
+
+
+ def _main_async_collector(
+     pipe_child: connection.Connection,
+     queue_out: queues.Queue,
+     create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase],  # noqa: F821
+     create_env_kwargs: dict[str, Any],
+     policy: Callable[[TensorDictBase], TensorDictBase],
+     max_frames_per_traj: int,
+     frames_per_batch: int,
+     reset_at_each_iter: bool,
+     storing_device: torch.device | str | int | None,
+     env_device: torch.device | str | int | None,
+     policy_device: torch.device | str | int | None,
+     idx: int = 0,
+     exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
+     reset_when_done: bool = True,
+     verbose: bool = VERBOSE,
+     interruptor=None,
+     set_truncated: bool = False,
+     use_buffers: bool | None = None,
+     replay_buffer: ReplayBuffer | None = None,
+     extend_buffer: bool = True,
+     traj_pool: _TrajectoryPool = None,
+     trust_policy: bool = False,
+     compile_policy: bool = False,
+     cudagraph_policy: bool = False,
+     no_cuda_sync: bool = False,
+     policy_factory: Callable | None = None,
+     collector_class: type | Callable[[], BaseCollector] | None = None,
+     postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
+     weight_sync_schemes: dict[str, WeightSyncScheme] | None = None,
+     worker_idx: int | None = None,
+     init_random_frames: int | None = None,
+     profile_config: ProfileConfig | None = None,
+ ) -> None:
+     if collector_class is None:
+         collector_class = Collector
+     # init variables that will be cleared when closing
+     collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None
+
+     # Make a policy-factory out of the policy
+     policy_factory = partial(
+         _make_policy_factory,
+         policy=policy,
+         policy_factory=policy_factory,
+         weight_sync_scheme=weight_sync_schemes.get("policy")
+         if weight_sync_schemes
+         else None,
+         worker_idx=worker_idx,
+         pipe=pipe_child,
+     )
+     policy = None
+     # Store the original init_random_frames for run_free mode logic
+     original_init_random_frames = (
+         init_random_frames if init_random_frames is not None else 0
+     )
+     try:
+         collector_class._ignore_rb = extend_buffer
+         inner_collector = collector_class(
+             create_env_fn,
+             create_env_kwargs=create_env_kwargs,
+             policy=policy,
+             policy_factory=policy_factory,
+             total_frames=-1,
+             max_frames_per_traj=max_frames_per_traj,
+             frames_per_batch=frames_per_batch,
+             reset_at_each_iter=reset_at_each_iter,
+             postproc=postproc,
+             split_trajs=False,
+             storing_device=storing_device,
+             policy_device=policy_device,
+             env_device=env_device,
+             exploration_type=exploration_type,
+             reset_when_done=reset_when_done,
+             return_same_td=replay_buffer is None,
+             interruptor=interruptor,
+             set_truncated=set_truncated,
+             use_buffers=use_buffers,
+             replay_buffer=replay_buffer,
+             extend_buffer=extend_buffer,
+             traj_pool=traj_pool,
+             trust_policy=trust_policy,
+             compile_policy=compile_policy,
+             cudagraph_policy=cudagraph_policy,
+             no_cuda_sync=no_cuda_sync,
+             # We don't pass the weight sync scheme as only the sender has the weight sync scheme within.
+             # weight_sync_schemes=weight_sync_schemes,
+             worker_idx=worker_idx,
+             # init_random_frames is passed; inner collector will use _should_use_random_frames()
+             # which checks replay_buffer.write_count when replay_buffer is provided
+             init_random_frames=init_random_frames,
+         )
+         # Set up weight receivers for worker process using the standard register_scheme_receiver API.
+         # This properly initializes the schemes on the receiver side and stores them in _receiver_schemes.
+         if weight_sync_schemes:
+             inner_collector.register_scheme_receiver(weight_sync_schemes)
+
+         use_buffers = inner_collector._use_buffers
+         if verbose:
+             torchrl_logger.debug("Sync data collector created")
+
+         # Set up profiler for this worker if configured
+         worker_profiler = None
+         if profile_config is not None:
+             worker_profiler = _WorkerProfiler(profile_config, worker_idx)
+             if worker_profiler.is_active:
+                 worker_profiler.start()
+
+         dc_iter = iter(inner_collector)
+         j = 0
+         pipe_child.send("instantiated")
+     except Exception as e:
+         # Send error information to main process
+         # We send a dict with the exception info so we can recreate it in the main process
+         import traceback
+
+         error_info = {
+             "error": True,
+             "exception_type": type(e).__name__,
+             "exception_module": type(e).__module__,
+             "exception_msg": str(e),
+             "traceback": traceback.format_exc(),
+         }
+         try:
+             pipe_child.send(error_info)
+         except Exception:
+             # If pipe is broken, nothing we can do
+             pass
+         return
+
+     has_timed_out = False
+     counter = 0
+     run_free = False
+     while True:
+         _timeout = _TIMEOUT if not has_timed_out else 1e-3
+         if not run_free and pipe_child.poll(_timeout):
+             counter = 0
+             try:
+                 data_in, msg = pipe_child.recv()
+                 if verbose:
+                     torchrl_logger.debug(f"mp worker {idx} received {msg}")
+             except EOFError:
+                 raise
+         elif not run_free:
+             if verbose:
+                 torchrl_logger.debug(f"poll failed, j={j}, worker={idx}")
+             # default is "continue" (after first iteration)
+             # this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe
+             # in that case, the main process probably expects the worker to continue collecting data
+             if has_timed_out:
+                 counter = 0
+                 # has_timed_out is True if the process failed to send data, which will
+                 # typically occur if main has taken another batch (i.e. the queue is Full).
+                 # In this case, msg is the previous msg sent by main, which will typically be "continue"
+                 # If it's not the case, it is not expected that has_timed_out is True.
+                 if msg not in ("continue", "continue_random"):
+                     raise RuntimeError(f"Unexpected message after time out: msg={msg}")
+             else:
+                 # if has_timed_out is False, then the time out does not come from the fact that the queue is Full.
+                 # this means that our process has been waiting for a command from main in vain, while main was not
+                 # receiving data.
+                 # This will occur if main is busy doing something else (e.g. computing loss etc).
+
+                 counter += _timeout
+                 if verbose:
+                     torchrl_logger.debug(f"mp worker {idx} has counter {counter}")
+                 if counter >= (_MAX_IDLE_COUNT * _TIMEOUT):
+                     raise RuntimeError(
+                         f"This process waited for {counter} seconds "
+                         f"without receiving a command from main. Consider increasing the maximum idle count "
+                         f"if this is expected via the environment variable MAX_IDLE_COUNT "
+                         f"(current value is {_MAX_IDLE_COUNT})."
+                         f"\nIf this occurs at the end of a function or program, it means that your collector has not been "
+                         f"collected, consider calling `collector.shutdown()` before ending the program."
+                     )
+                 continue
+         else:
+             # placeholder, will be checked after
+             msg = "continue"
+         if msg == "run_free":
+             run_free = True
+             msg = "continue"
+         if run_free:
+             # Capture shutdown / update / seed signal, but continue should not be expected
+             if pipe_child.poll(1e-4):
+                 data_in, msg = pipe_child.recv()
+                 if msg == "continue":
+                     # Switch back to run_free = False
+                     run_free = False
+                 if msg == "pause":
+                     queue_out.put((idx, "paused"), timeout=_TIMEOUT)
+                     while not pipe_child.poll(1e-2):
+                         continue
+                     data_in, msg = pipe_child.recv()
+                     if msg != "restart":
+                         raise RuntimeError(f"Expected msg='restart', got {msg=}")
+                     msg = "continue"
+             else:
+                 data_in = None
+                 # In run_free mode, determine msg based on replay_buffer.write_count for random frames
+                 if (
+                     replay_buffer is not None
+                     and original_init_random_frames > 0
+                     and replay_buffer.write_count < original_init_random_frames
+                 ):
+                     msg = "continue_random"
+                 else:
+                     msg = "continue"
+             # Note: Weight updates are handled by background threads in weight sync schemes.
+             # The scheme's background receiver thread listens for "receive" instructions.
+
+         if msg == "enable_profile":
+             # Handle profile configuration sent after worker startup
+             if worker_profiler is None or not worker_profiler.is_active:
+                 worker_profiler = _WorkerProfiler(data_in, worker_idx)
+                 if worker_profiler.is_active:
+                     worker_profiler.start()
+             pipe_child.send((j, "profile_enabled"))
+             has_timed_out = False
+             continue
+
+         if msg == "update":
+             # Legacy - weight updater
+             with timeit(f"worker/{idx}/update") as update_timer:
+                 torchrl_logger.debug(
+                     f"mp worker {idx}: Received weight update request..."
+                 )
+                 inner_collector.update_policy_weights_(policy_weights=data_in)
+                 torchrl_logger.debug(
+                     f"mp worker {idx}: Weight update completed in {update_timer.elapsed():.3f}s"
+                 )
+             pipe_child.send((j, "updated"))
+             has_timed_out = False
+             continue
+
+         # Note: Weight updates are now handled by background threads in the weight sync schemes.
+         # The scheme's background receiver thread listens for "receive" instructions and
+         # applies weights automatically. No explicit message handling needed here.
+
+         if msg in ("continue", "continue_random"):
+             # When in run_free mode with a replay_buffer, the inner collector uses
+             # _should_use_random_frames() which checks replay_buffer.write_count.
+             # So we don't override init_random_frames. Otherwise, we use the message
+             # to control whether random frames are used.
+             if not run_free or replay_buffer is None:
+                 if msg == "continue_random":
+                     inner_collector.init_random_frames = float("inf")
+                 else:
+                     inner_collector.init_random_frames = -1
+
+             # Debug logging for rollout timing
+             # Use profiler context if profiling is active
+             profile_ctx = (
+                 worker_profiler.profile_rollout()
+                 if worker_profiler is not None and worker_profiler.is_active
+                 else contextlib.nullcontext()
+             )
+             with profile_ctx:
+                 with timeit(f"worker/{idx}/rollout") as rollout_timer:
+                     torchrl_logger.debug(
+                         f"mp worker {idx}: Starting rollout (j={j})..."
+                     )
+                     next_data = next(dc_iter)
+                     torchrl_logger.debug(
+                         f"mp worker {idx}: Rollout completed in {rollout_timer.elapsed():.3f}s, "
+                         f"frames={next_data.numel() if hasattr(next_data, 'numel') else 'N/A'}"
+                     )
+
+             # Step the profiler after each rollout
+             if worker_profiler is not None and worker_profiler.is_active:
+                 worker_profiler.step()
+             if pipe_child.poll(_MIN_TIMEOUT):
+                 # in this case, main sent a message to the worker while it was busy collecting trajectories.
+                 # In that case, we skip the collected trajectory and get the message from main. This is faster than
+                 # sending the trajectory in the queue until timeout when it's never going to be received.
+                 continue
+
+             if replay_buffer is not None:
+                 if extend_buffer:
+                     next_data.names = None
+                     replay_buffer.extend(next_data)
+
+                 if run_free:
+                     continue
+
+                 try:
+                     queue_out.put((idx, j), timeout=_TIMEOUT)
+                     if verbose:
+                         torchrl_logger.debug(f"mp worker {idx} successfully sent data")
+                     j += 1
+                     has_timed_out = False
+                     continue
+                 except queue.Full:
+                     has_timed_out = True
+                     continue
+
+             if j == 0 or not use_buffers:
+                 collected_tensordict = next_data
+                 if (
+                     storing_device is not None
+                     and collected_tensordict.device != storing_device
+                 ):
+                     raise RuntimeError(
+                         f"expected device to be {storing_device} but got {collected_tensordict.device}"
+                     )
+                 if use_buffers:
+                     # If policy and env are on cpu, we put in shared mem,
+                     # if policy is on cuda and env on cuda, we are fine with this
+                     # If policy is on cuda and env on cpu (or opposite) we put tensors that
+                     # are on cpu in shared mem.
+                     MPS_ERROR = (
+                         "tensors on mps device cannot be put in shared memory. Make sure "
+                         "the shared device (aka storing_device) is set to CPU."
+                     )
+                     if collected_tensordict.device is not None:
+                         # placeholder in case we need different behaviors
+                         if collected_tensordict.device.type in ("cpu",):
+                             collected_tensordict.share_memory_()
+                         elif collected_tensordict.device.type in ("mps",):
+                             raise RuntimeError(MPS_ERROR)
+                         elif collected_tensordict.device.type == "cuda":
+                             collected_tensordict.share_memory_()
+                         else:
+                             raise NotImplementedError(
+                                 f"Device {collected_tensordict.device} is not supported in multi-collectors yet."
+                             )
+                     else:
+                         # make sure each cpu tensor is shared - assuming non-cpu devices are shared
+                         def cast_tensor(x, MPS_ERROR=MPS_ERROR):
+                             if x.device.type in ("cpu",):
+                                 x.share_memory_()
+                             if x.device.type in ("mps",):
+                                 raise RuntimeError(MPS_ERROR)
+
+                         collected_tensordict.apply(cast_tensor, filter_empty=True)
+                 data = (collected_tensordict, idx)
+             else:
+                 if next_data is not collected_tensordict:
+                     raise RuntimeError(
+                         "Collector should return the same tensordict modified in-place."
+                     )
+                 data = idx  # flag the worker that has sent its data
+             try:
+                 queue_out.put((data, j), timeout=_TIMEOUT)
+                 if verbose:
+                     torchrl_logger.debug(f"mp worker {idx} successfully sent data")
+                 j += 1
+                 has_timed_out = False
+                 continue
+             except queue.Full:
+                 if verbose:
+                     torchrl_logger.debug(f"mp worker {idx} has timed out")
+                 has_timed_out = True
+                 continue
+
+         if msg == "seed":
+             data_in, static_seed = data_in
+             new_seed = inner_collector.set_seed(data_in, static_seed=static_seed)
+             torch.manual_seed(data_in)
+             np.random.seed(data_in)
+             pipe_child.send((new_seed, "seeded"))
+             has_timed_out = False
+             continue
+
+         elif msg == "reset":
+             inner_collector.reset()
+             pipe_child.send((j, "reset"))
+             continue
+
+         elif msg == "state_dict":
+             from torch.utils._pytree import tree_map
+
+             state_dict = inner_collector.state_dict()
+             # Map exotic devices (MPS, NPU, etc.) to CPU for multiprocessing compatibility
+             # CPU and CUDA tensors are already shareable and don't need conversion BUT we need to clone the CUDA tensors in case they were sent from main (cannot send cuda tensors back and forth)
+             state_dict = tree_map(_map_to_cpu_if_needed, state_dict)
+             state_dict = TensorDict(state_dict)
+             state_dict = state_dict.clone().apply(_cast, state_dict).to_dict()
+             pipe_child.send((state_dict, "state_dict"))
+             has_timed_out = False
+             continue
+
+         elif msg == "load_state_dict":
+             state_dict = data_in
+             inner_collector.load_state_dict(state_dict)
+             del state_dict
+             pipe_child.send((j, "loaded"))
+             has_timed_out = False
+             continue
+
+         elif msg == "getattr_policy":
+             attr_name = data_in
+             try:
+                 result = getattr(inner_collector.policy, attr_name)
+                 pipe_child.send((result, "getattr_policy"))
+             except AttributeError as e:
+                 pipe_child.send((e, "getattr_policy"))
+             has_timed_out = False
+             continue
+
+         elif msg == "getattr_env":
+             attr_name = data_in
+             try:
+                 result = getattr(inner_collector.env, attr_name)
+                 pipe_child.send((result, "getattr_env"))
+             except AttributeError as e:
+                 pipe_child.send((e, "getattr_env"))
+             has_timed_out = False
+             continue
+
+         elif msg == "close":
+             # Stop profiler if active
+             if worker_profiler is not None and worker_profiler.is_active:
+                 worker_profiler.stop()
+             del collected_tensordict, data, next_data, data_in
+             inner_collector.shutdown()
+             del inner_collector, dc_iter
+             pipe_child.send("closed")
+             if verbose:
+                 torchrl_logger.debug(f"collector {idx} closed")
+             break
+
+         else:
+             raise Exception(f"Unrecognized message {msg}")
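
For reference, the profiler wiring in _WorkerProfiler follows the standard torch.profiler scheduling pattern: warmup rollouts are consumed via skip_first, the remaining rollouts are recorded, and the trace is exported when the schedule completes. Below is a minimal standalone sketch of that pattern; the rollout() body and the counts are illustrative placeholders, not part of torchrl.

# Sketch of the torch.profiler schedule used by _WorkerProfiler (illustrative values).
import torch

WARMUP_ROLLOUTS = 2  # profiled but discarded (skip_first)
NUM_ROLLOUTS = 5     # total rollouts before the trace is exported

def rollout():
    # stand-in for next(dc_iter): any workload to be profiled
    torch.randn(64, 64) @ torch.randn(64, 64)

schedule = torch.profiler.schedule(
    skip_first=WARMUP_ROLLOUTS,
    wait=0,
    warmup=0,
    active=NUM_ROLLOUTS - WARMUP_ROLLOUTS,
    repeat=1,
)

def on_trace_ready(prof):
    # mirrors the default handler above: dump a Chrome trace when the cycle ends
    prof.export_chrome_trace("worker_0_trace.json")

prof = torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    schedule=schedule,
    on_trace_ready=on_trace_ready,
)
prof.start()
for _ in range(NUM_ROLLOUTS):
    with torch.profiler.record_function("worker_0_rollout"):
        rollout()
    prof.step()  # advance the schedule, as _WorkerProfiler.step() does after each rollout
prof.stop()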
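
The worker loop in _main_async_collector is driven by a small command protocol: the parent sends (payload, message) pairs over a multiprocessing Pipe, while collected batches are signalled back through a Queue. The following is a generic, self-contained sketch of that parent/worker handshake using only the standard library; the message names and payload shapes are illustrative and do not reproduce the collector's exact protocol.

# Generic parent/worker command loop over a Pipe + Queue (illustrative only).
import multiprocessing as mp

def worker(pipe, queue_out):
    pipe.send("instantiated")           # handshake once setup succeeds
    j = 0
    while True:
        data_in, msg = pipe.recv()      # (payload, command) pairs from the parent
        if msg == "continue":
            batch = {"step": j}         # stand-in for a collected rollout
            queue_out.put((batch, j))   # hand the result back via the queue
            j += 1
        elif msg == "close":
            pipe.send("closed")
            break

if __name__ == "__main__":
    parent_pipe, child_pipe = mp.Pipe()
    queue_out = mp.Queue()
    proc = mp.Process(target=worker, args=(child_pipe, queue_out))
    proc.start()
    assert parent_pipe.recv() == "instantiated"
    for _ in range(3):
        parent_pipe.send((None, "continue"))
        batch, j = queue_out.get()      # blocks until the worker delivers a batch
    parent_pipe.send((None, "close"))
    assert parent_pipe.recv() == "closed"
    proc.join()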