torchrl-0.11.0-cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/collectors/distributed/utils.py
@@ -0,0 +1,160 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from __future__ import annotations
+
+ import subprocess
+ import time
+
+ from torchrl._utils import logger as torchrl_logger, VERBOSE
+ from torchrl.collectors.distributed.default_configs import (
+     DEFAULT_SLURM_CONF,
+     DEFAULT_SLURM_CONF_MAIN,
+     TCP_PORT,
+ )
+ from torchrl.collectors.distributed.generic import _distributed_init_delayed
+ from torchrl.collectors.distributed.rpc import _rpc_init_collection_node
+
+ try:
+     import submitit
+
+     _has_submitit = True
+ except ModuleNotFoundError as err:
+     _has_submitit = False
+     SUBMITIT_ERR = err
+
+
+ class submitit_delayed_launcher:
+     """Delayed launcher for submitit.
+
+     In some cases, launched jobs cannot spawn other jobs on their own and this
+     can only be done at the jump-host level.
+
+     In these cases, the :func:`submitit_delayed_launcher` can be used to
+     pre-launch collector nodes that will wait for the main worker to provide
+     the launching instruction.
+
+     Args:
+         num_jobs (int): the number of collection jobs to be launched.
+         framework (str, optional): the framework to use. Can be either ``"distributed"``
+             or ``"rpc"``. ``"distributed"`` requires a :class:`~.DistributedDataCollector`
+             collector whereas ``"rpc"`` requires an :class:`RPCDataCollector`.
+             Defaults to ``"distributed"``.
+         backend (str, optional): torch.distributed backend in case ``framework``
+             points to ``"distributed"``. This value must match the one passed to
+             the collector, otherwise the main and satellite nodes will fail to
+             reach the rendezvous and hang forever (i.e., no exception will be raised!).
+             Defaults to ``'gloo'``.
+         tcpport (int or str, optional): the TCP port to use.
+             Defaults to :obj:`torchrl.collectors.distributed.default_configs.TCP_PORT`.
+         submitit_main_conf (dict, optional): the main node configuration to be passed to submitit.
+             Defaults to :obj:`torchrl.collectors.distributed.default_configs.DEFAULT_SLURM_CONF_MAIN`.
+         submitit_collection_conf (dict, optional): the configuration to be passed to submitit.
+             Defaults to :obj:`torchrl.collectors.distributed.default_configs.DEFAULT_SLURM_CONF`.
+
+     Examples:
+         >>> num_jobs=2
+         >>> @submitit_delayed_launcher(num_jobs=num_jobs)
+         ... def main():
+         ...     from torchrl.modules.utils.utils import RandomPolicy
+         ...     from torchrl.envs.libs.gym import GymEnv
+         ...     from torchrl.data import BoundedContinuous
+         ...     from torchrl.collectors.distributed import DistributedDataCollector
+         ...     from torchrl.envs import EnvCreator
+         ...     collector = DistributedDataCollector(
+         ...         [EnvCreator(lambda: GymEnv("Pendulum-v1"))] * num_jobs,
+         ...         policy=RandomPolicy(BoundedContinuous(-1, 1, shape=(1,))),
+         ...         launcher="submitit_delayed",
+         ...     )
+         ...     for data in collector:
+         ...         print(data)
+         ...
+         >>> if __name__ == "__main__":
+         ...     main()
+         ...
+     """
+
+     _VERBOSE = VERBOSE  # for debugging
+
+     def __init__(
+         self,
+         num_jobs,
+         framework="distributed",
+         backend="gloo",
+         tcpport=TCP_PORT,
+         submitit_main_conf: dict = DEFAULT_SLURM_CONF_MAIN,
+         submitit_collection_conf: dict = DEFAULT_SLURM_CONF,
+     ):
+         self.num_jobs = num_jobs
+         self.backend = backend
+         self.framework = framework
+         self.submitit_collection_conf = submitit_collection_conf
+         self.submitit_main_conf = submitit_main_conf
+         self.tcpport = tcpport
+
+     def __call__(self, main_func):
+         def exec_fun():
+             if not _has_submitit:
+                 raise ModuleNotFoundError(
+                     "Failed to import submitit. Check installation of the library."
+                 ) from SUBMITIT_ERR
+             # submit main
+             executor = submitit.AutoExecutor(folder="log_test")
+             executor.update_parameters(**self.submitit_main_conf)
+             main_job = executor.submit(main_func)
+             # listen to output file looking for IP address
+             torchrl_logger.debug(f"job id: {main_job.job_id}")
+             time.sleep(2.0)
+             node = None
+             while not node:
+                 cmd = f"squeue -j {main_job.job_id} -o %N | tail -1"
+                 node = subprocess.check_output(cmd, shell=True, text=True).strip()
+                 try:
+                     node = int(node)
+                 except ValueError:
+                     time.sleep(0.5)
+                     continue
+             torchrl_logger.debug(f"node: {node}")
+             # by default, sinfo will truncate the node name at char 20, we increase this to 200
+             cmd = f"sinfo -n {node} -O nodeaddr:200 | tail -1"
+             rank0_ip = subprocess.check_output(cmd, shell=True, text=True).strip()
+             torchrl_logger.debug(f"IP: {rank0_ip}")
+             world_size = self.num_jobs + 1
+
+             # submit jobs
+             executor = submitit.AutoExecutor(folder="log_test")
+             executor.update_parameters(**self.submitit_collection_conf)
+             jobs = []
+             if self.framework == "rpc":
+                 from .rpc import DEFAULT_TENSORPIPE_OPTIONS
+
+                 tensorpipe_options = DEFAULT_TENSORPIPE_OPTIONS
+             for i in range(self.num_jobs):
+                 rank = i + 1
+                 if self.framework == "distributed":
+                     job = executor.submit(
+                         _distributed_init_delayed,
+                         rank,
+                         self.backend,
+                         rank0_ip,
+                         self.tcpport,
+                         world_size,
+                         self._VERBOSE,
+                     )
+                 elif self.framework == "rpc":
+                     job = executor.submit(
+                         _rpc_init_collection_node,
+                         rank,
+                         rank0_ip,
+                         self.tcpport,
+                         world_size,
+                         None,
+                         tensorpipe_options,
+                     )
+                 else:
+                     raise NotImplementedError(f"Unknown framework {self.framework}.")
+                 jobs.append(job)
+             for job in jobs:
+                 job.result()
+             main_job.result()
+
+         return exec_fun
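
The decorator above resolves the main job's node address by polling SLURM before submitting the satellite collection jobs. The snippet below isolates that lookup as a standalone sketch for readers who want to reuse it; it assumes a SLURM jump host with `squeue` and `sinfo` on the PATH, and `resolve_rank0_ip` is a hypothetical helper name, not part of torchrl.

import subprocess
import time


def resolve_rank0_ip(job_id: str, poll_interval: float = 0.5) -> str:
    """Poll SLURM until the job is placed, then return that node's address (hypothetical helper)."""
    node = ""
    while not node:
        # %N prints the allocated node list; it stays empty until the job is scheduled.
        out = subprocess.check_output(
            f"squeue -j {job_id} -o %N | tail -1", shell=True, text=True
        ).strip()
        if out and out != "(null)":
            node = out
        else:
            time.sleep(poll_interval)
    # nodeaddr:200 widens the column so long addresses are not truncated at 20 chars.
    return subprocess.check_output(
        f"sinfo -n {node} -O nodeaddr:200 | tail -1", shell=True, text=True
    ).strip()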
torchrl/collectors/llm/__init__.py
@@ -0,0 +1,10 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .base import LLMCollector
+ from .ray_collector import RayLLMCollector
+ from .weight_update import vLLMUpdater, vLLMUpdaterV2
+
+ __all__ = ["vLLMUpdater", "vLLMUpdaterV2", "LLMCollector", "RayLLMCollector"]
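
For reference, the subpackage re-exports these names at its top level, so downstream code can import them directly; a minimal sketch:

from torchrl.collectors.llm import LLMCollector, RayLLMCollector, vLLMUpdater, vLLMUpdaterV2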
torchrl/collectors/llm/base.py
@@ -0,0 +1,494 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ from collections import deque
+ from collections.abc import Callable
+ from typing import Any
+
+ import torch
+
+ from tensordict import lazy_stack, TensorDictBase
+
+ from torchrl._utils import as_remote, logger as torchrl_logger
+
+ from torchrl.collectors._single import Collector
+ from torchrl.collectors.llm.utils import _QueueAsRB
+ from torchrl.collectors.weight_update import WeightUpdaterBase
+ from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
+ from torchrl.envs import AsyncEnvPool
+ from torchrl.envs.common import EnvBase
+ from torchrl.envs.llm.transforms.policy_version import PolicyVersion
+
+
+ class LLMCollector(Collector):
+     """A simplified version of Collector for LLM inference.
+
+     Args:
+         env (EnvBase or EnvBase constructor): the environment to be used for data collection.
+
+     Keyword Args:
+         policy (Callable[[TensorDictBase], TensorDictBase]): the policy to be used for data collection.
+         policy_factory (Callable[[], Callable], optional): a callable that returns
+             a policy instance. This is mutually exclusive with the `policy` argument.
+
+             .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
+
+         dialog_turns_per_batch (int, optional): A keyword-only argument representing the total
+             number of elements in a batch. It is always required except when `yield_completed_trajectories=True`.
+         total_dialog_turns (int): A keyword-only argument representing the total
+             number of steps returned by the collector during its lifespan. -1 is never ending (until shutdown).
+             Defaults to -1.
+         yield_completed_trajectories (bool, optional): whether to yield batches of rollouts with a given number of steps
+             (`yield_completed_trajectories=False`, default) or single, completed trajectories
+             (`yield_completed_trajectories=True`).
+             Defaults to `False` unless `yield_only_last_steps=True`, where it cannot be `False`.
+
+             .. warning:: If the `done` state of the environment is not properly set, this may lead to a collector
+                 that never yields any data.
+
+         yield_only_last_steps (bool, optional): whether to yield every step of a trajectory, or only the
+             last (done) steps.
+             If `True`, a single trajectory is yielded (or written in the buffer) at a time.
+
+             .. warning:: If the `done` state of the environment is not properly set, this may lead to a collector
+                 that never yields any data.
+
+         postproc (Callable, optional): A post-processing transform, such as
+             a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
+             instance.
+             Defaults to ``None``.
+         async_envs (bool, optional): if ``True``, the environment will be run asynchronously. Defaults to `True` if the
+             environment is a :class:`~torchrl.envs.AsyncEnvPool` instance.
+         replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
+             but populate the buffer instead. Defaults to ``None``.
+         reset_at_each_iter (bool, optional): if ``True``, the environment will be reset at each iteration.
+         flatten_data (bool, optional): if ``True``, the collector will flatten the collected data
+             before returning it. In practice, this means that if an environment of batch-size `(B,)` is used
+             and run for `T` steps, `flatten_data=True` will present data of shape `(B*T,)`, whereas
+             `flatten_data=False` will present data of shape `(B, T)`.
+             Defaults to `True` when `replay_buffer` is provided, `False` otherwise.
+         weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase`
+             or its subclass, responsible for updating the policy weights on remote inference workers.
+             This is typically not used in :class:`~torchrl.collectors.Collector` as it operates in a single-process environment.
+             Consider using a constructor if the updater needs to be serialized.
+         track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
+             This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
+             Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
+             the policy version.
+             Defaults to `False`.
+         verbose (bool, optional): if ``True``, the collector will print progress information.
+             Defaults to `False`.
+
+     Examples:
+         >>> import vllm
+         >>> from torchrl.modules import vLLMWrapper
+         >>> from torchrl.testing.mocking_classes import DummyStrDataLoader
+         >>> from torchrl.envs import LLMEnv
+         >>> llm_model = vllm.LLM("gpt2")
+         >>> tokenizer = llm_model.get_tokenizer()
+         >>> tokenizer.pad_token = tokenizer.eos_token
+         >>> policy = vLLMWrapper(llm_model)
+         >>> dataloader = DummyStrDataLoader(1)
+         >>> env = LLMEnv.from_dataloader(
+         ...    dataloader=dataloader,
+         ...    tokenizer=tokenizer,
+         ...    from_text=True,
+         ...    batch_size=1,
+         ...    group_repeats=True,
+         ... )
+         >>> collector = LLMCollector(
+         ...    env=env,
+         ...    policy_factory=lambda: policy,
+         ...    dialog_turns_per_batch=env.batch_size[0],
+         ...    total_dialog_turns=3,
+         ... )
+         >>> for i, data in enumerate(collector):
+         ...     if i == 2:
+         ...         print(data)
+         ...         break
+         LazyStackedTensorDict(
+             fields={
+                 attention_mask: Tensor(shape=torch.Size([1, 1, 22]), device=cpu, dtype=torch.int64, is_shared=False),
+                 collector: LazyStackedTensorDict(
+                     fields={
+                         traj_ids: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False)},
+                     exclusive_fields={
+                     },
+                     batch_size=torch.Size([1, 1]),
+                     device=None,
+                     is_shared=False,
+                     stack_dim=1),
+                 done: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 terminated: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 text: NonTensorStack(
+                     [['plsgqejeyd']],
+                     batch_size=torch.Size([1, 1]),
+                     device=None),
+                 text_response: NonTensorStack(
+                     [['ec.n.n.n.tjbjz3perwhz']],
+                     batch_size=torch.Size([1, 1]),
+                     device=None),
+                 tokens: Tensor(shape=torch.Size([1, 1, 22]), device=cpu, dtype=torch.int64, is_shared=False),
+                 tokens_response: Tensor(shape=torch.Size([1, 1, 16]), device=cpu, dtype=torch.int64, is_shared=False)},
+             exclusive_fields={
+             },
+             batch_size=torch.Size([1, 1]),
+             device=None,
+             is_shared=False,
+             stack_dim=1)
+         >>> del collector
+
+     """
+
+     def __init__(
+         self,
+         env: EnvBase | Callable[[], EnvBase],
+         *,
+         policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
+         policy_factory: Callable[[], Callable[[TensorDictBase], TensorDictBase]]
+         | None = None,
+         dialog_turns_per_batch: int | None = None,
+         yield_only_last_steps: bool | None = None,
+         yield_completed_trajectories: bool | None = None,
+         postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
+         total_dialog_turns: int = -1,
+         async_envs: bool | None = None,
+         replay_buffer: ReplayBuffer | None = None,
+         reset_at_each_iter: bool = False,
+         flatten_data: bool | None = None,
+         weight_updater: WeightUpdaterBase
+         | Callable[[], WeightUpdaterBase]
+         | None = None,
+         queue: Any | None = None,
+         track_policy_version: bool | PolicyVersion = False,
+         verbose: bool = False,
+     ):
+         if queue is not None and replay_buffer is not None:
+             raise RuntimeError(
+                 "Handling both a buffer and a queue is not possible at the moment."
+             )
+         elif queue is not None:
+             # disguise the queue as a replay buffer
+             replay_buffer = _QueueAsRB(queue)
+         if dialog_turns_per_batch is None and yield_completed_trajectories:
+             dialog_turns_per_batch = 1
+         super().__init__(
+             create_env_fn=env,
+             policy=policy,
+             policy_factory=policy_factory,
+             frames_per_batch=dialog_turns_per_batch,
+             replay_buffer=replay_buffer,
+             total_frames=total_dialog_turns,
+             weight_updater=weight_updater,
+             reset_at_each_iter=reset_at_each_iter,
+             trust_policy=True,
+             use_buffers=False,
+             no_cuda_sync=True,
+             extend_buffer=True,
+             postproc=postproc,
+         )
+         if hasattr(self.policy, "register_collector"):
+             self.policy.register_collector(self)
+
+         if yield_only_last_steps is None:
+             yield_only_last_steps = False
+
+         if yield_completed_trajectories is None:
+             yield_completed_trajectories = yield_only_last_steps
+         elif yield_only_last_steps and not yield_completed_trajectories:
+             raise TypeError(
+                 "yield_only_last_steps=True requires yield_completed_trajectories=True (or None)"
+             )
+
+         if yield_only_last_steps:
+             if flatten_data is not None:
+                 raise TypeError(
+                     "`yield_only_last_steps` cannot be `True` when `flatten_data` is passed."
+                 )
+             if self.reset_at_each_iter:
+                 raise TypeError(
+                     "`yield_only_last_steps` cannot be `True` when `reset_at_each_iter=True`."
+                 )
+         if flatten_data is None:
+             flatten_data = replay_buffer is not None
+         self.flatten_data = flatten_data
+         self.yield_completed_trajectories = yield_completed_trajectories
+         self.yield_only_last_steps = yield_only_last_steps
+         self.verbose = verbose
+         self._shuttle = None  # Initialize shuttle for rollout
+         if self.yield_completed_trajectories:
+             # For async envs, we route by env_id so we only care about batch_size[0].
+             # For non-async envs, we need exactly one batch dimension.
+             if not isinstance(self.env, AsyncEnvPool) and len(self.env.batch_size) != 1:
+                 raise ValueError(
+                     "`yield_completed_trajectories` only works with envs that have a single batch dimension. Got "
+                     f"env.batch_size={self.env.batch_size}."
+                 )
+             self._yield_queues = [deque() for _ in range(self.env.batch_size[0])]
+             self._trajectory_queue = deque()
+         self.async_envs = bool(async_envs) | isinstance(self.env, AsyncEnvPool)
+         if self.async_envs and not isinstance(self.env, AsyncEnvPool):
+             # This basically means that `async_envs` is automatically set and passing it is useless as of today,
+             # except for the following error.
+             raise RuntimeError(
+                 "async_envs requires the environment to be an AsyncEnvPool instance."
+             )
+         self.policy_version_tracker = track_policy_version
+         if isinstance(track_policy_version, bool) and track_policy_version:
+             if isinstance(self.env, AsyncEnvPool):
+                 raise RuntimeError(
+                     "AsyncEnvPool is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, "
+                     "and pass that transform to the collector."
+                 )
+             self.policy_version_tracker = PolicyVersion()
+             self.env = self.env.append_transform(self.policy_version_tracker)  # type: ignore
+         elif isinstance(track_policy_version, PolicyVersion):
+             self.policy_version_tracker = track_policy_version
+             self.env = self.env.append_transform(self.policy_version_tracker)  # type: ignore
+         else:
+             self.policy_version_tracker = None
+
+     def set_postproc(self, postproc: Callable[[TensorDictBase], TensorDictBase]):
+         if self.postproc is not None:
+             raise RuntimeError("Postproc already set")
+         self.postproc = postproc
+
+     def increment_version(self):
+         """Increment the policy version."""
+         if self.policy_version_tracker is not None:
+             if not isinstance(self.policy_version_tracker, PolicyVersion):
+                 raise RuntimeError(
+                     "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
+                 )
+             self.policy_version_tracker.increment_version()
+
+     @property
+     def policy_version(self) -> str | int | None:
+         """The current policy version."""
+         if not isinstance(self.policy_version_tracker, PolicyVersion):
+             return None
+         return self.policy_version_tracker.version
+
+     def get_policy_version(self) -> str | int | None:
+         """Get the current policy version.
+
+         This method exists to support remote calls in Ray actors, since properties
+         cannot be accessed directly through Ray's RPC mechanism.
+
+         Returns:
+             The current version number (int) or UUID (str), or None if version tracking is disabled.
+         """
+         return self.policy_version
+
+     @property
+     def total_dialog_turns(self):
+         return self.total_frames
+
+     @property
+     def dialog_turns_per_batch(self) -> int:
+         """Alias to `frames_per_batch`."""
+         return self.requested_frames_per_batch
+
+     @property
+     def rollout(self) -> Callable[[], TensorDictBase]:
+         if self.yield_completed_trajectories:
+             if self.async_envs:
+                 return self._rollout_yield_trajs_async
+             else:
+                 return self._rollout_yield_trajs
+         else:
+             return self._rollout_all
+
+     def _rollout_all(self) -> TensorDictBase:  # A simplified version of rollout
+         if self.reset_at_each_iter or self._shuttle is None:
+             self._shuttle = self.env.reset()
+
+         trajectory = []
+         collected_steps = 0
+         policy_input = self._shuttle
+         while collected_steps < self.dialog_turns_per_batch:
+             if self.verbose:
+                 torchrl_logger.debug(
+                     f"LLMCollector: Collected {collected_steps} steps over {self.dialog_turns_per_batch} requested."
+                 )
+             env_input = self.policy(policy_input)
+             env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
+
+             # carry over collector data without messing up devices
+             collector_data = env_output.get("collector", default=None)
+             if collector_data is not None:
+                 env_next_output.set("collector", collector_data.copy())
+             self._update_traj_ids(env_output)
+             trajectory.append(env_output.clone())
+             collected_steps += env_output.numel()
+             policy_input = self._shuttle = env_next_output
+         trajectory = lazy_stack(trajectory, -1)
+         if self.flatten_data:
+             return trajectory.view(-1)
+         return trajectory
+
+     _result_numel = 0
+
+     def _rollout_yield_trajs(self) -> TensorDictBase:  # A simplified version of rollout
+         if self._shuttle is None:
+             self._shuttle = self.env.reset()
+         next_output = self._shuttle
+
+         collected_steps = 0
+         dones = torch.zeros(self.env.batch_size, dtype=torch.bool)
+         while True:
+             if self._result_numel >= self.dialog_turns_per_batch:
+                 break
+             elif self.verbose:
+                 torchrl_logger.debug(
+                     f"LLMCollector: Collected {collected_steps} steps with {self._result_numel} elements in the resulting batch, over {self.dialog_turns_per_batch} requested."
+                 )
+             env_input = self.policy(next_output)
+             cur_output, next_output = self.env.step_and_maybe_reset(env_input)
+             # for i in range(cur_output.numel()):
+             #     print(len(cur_output[i]["text"]) < len(cur_output[i]["next", "text"]))
+
+             # carry over collector data without messing up devices
+             collector_data = cur_output.get("collector", default=None)
+             if collector_data is not None:
+                 self._update_traj_ids(cur_output)
+                 next_output.set("collector", collector_data.copy())
+
+             # if the loop is interrupted
+             self._shuttle = next_output
+             collected_steps += next_output.numel()
+             for i, (_data, queue) in enumerate(
+                 zip(cur_output.unbind(0), self._yield_queues)
+             ):
+                 queue.append(_data)
+                 dones[i] = _data["next", "done"].any()
+             if dones.any():
+                 for idx in dones.nonzero(as_tuple=True)[0].tolist():
+                     if not self.yield_only_last_steps:
+                         _result = lazy_stack(self._yield_queues[idx], -1)
+                         self._trajectory_queue.append(_result)
+                     else:
+                         # FIXME: We need to increment the step count here because iterator() won't
+                         # see the extra steps
+                         # We use lazy-stack because unsqueeze doesn't nest the strings in lists
+                         _result = lazy_stack([self._yield_queues[idx][-1]])
+                         self._trajectory_queue.append(_result)
+                     self._result_numel += _result.numel()
+                     self._yield_queues[idx].clear()
+         result = [self._trajectory_queue.popleft()]
+         elt = result[0].numel()
+         self._result_numel -= result[0].numel()
+         while elt < self.dialog_turns_per_batch:
+             result.append(self._trajectory_queue.popleft())
+             elt += result[-1].numel()
+             self._result_numel -= result[-1].numel()
+         result = torch.cat(result, -1)
+         if self.verbose:
+             torchrl_logger.debug(
+                 f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
+             )
+         return result
+
+     started = False
+
+     def _rollout_yield_trajs_async(
+         self,
+     ) -> TensorDictBase:  # A simplified version of rollout
+         if not self.started:
+             if self._shuttle is None:
+                 self._shuttle = self.env.reset()
+             next_output = self._shuttle
+             env_input = self.policy(next_output)
+             self.env.async_step_and_maybe_reset_send(env_input)
+             self.started = True
+
+         collected_steps = 0
+         # Use only the first dimension (num_envs) for done tracking, since we route by env_id
+         dones = torch.zeros(self.env.batch_size[0], dtype=torch.bool)
+         while True:
+             if self._trajectory_queue:
+                 break
+
+             cur_output, next_output = self.env.async_step_and_maybe_reset_recv()
+
+             # Get the env ids - flatten to handle multi-dimensional batch sizes
+             # (e.g., AsyncEnvPool with batch_size=[4, 1] gives [[0], [1], [2], [3]])
+             env_ids_raw = cur_output.get(self.env._env_idx_key).tolist()
+             # Flatten nested lists to get scalar env indices
+             env_ids = []
+             for eid in env_ids_raw:
+                 while isinstance(eid, list) and len(eid) == 1:
+                     eid = eid[0]
+                 env_ids.append(eid)
+
+             # carry over collector data without messing up devices
+             collector_data = cur_output.get("collector", default=None)
+             if collector_data is not None:
+                 self._update_traj_ids(cur_output)
+                 next_output.set("collector", collector_data.copy())
+
+             collected_steps += next_output.numel()
+             dones.fill_(False)
+             for i, _data in zip(env_ids, cur_output.unbind(0)):
+                 queue = self._yield_queues[i]
+                 queue.append(_data)
+                 dones[i] = _data["next", "done"].any()
+             if dones.any():
+                 for idx in dones.nonzero(as_tuple=True)[0].tolist():
+                     if not self.yield_only_last_steps:
+                         self._trajectory_queue.append(
+                             lazy_stack(self._yield_queues[idx], -1)
+                         )
+                     else:
+                         # FIXME: We need to increment the step count here because iterator() won't
+                         # see the extra steps
+                         # We use lazy-stack because unsqueeze doesn't nest the strings in lists
+                         self._trajectory_queue.append(
+                             lazy_stack([self._yield_queues[idx][-1]])
+                         )
+                     self._yield_queues[idx].clear()
+
+             # Launch the next batch:
+             # FIXME: Add a condition RE number of frames here
+             if True:
+                 env_input = self.policy(next_output)
+                 self.env.async_step_and_maybe_reset_send(env_input)
+
+         result = self._trajectory_queue.popleft()
+         # Flatten the result - AsyncEnvPool child envs with batch_size=(1,) produce
+         # trajectories with shape [1, T] but we want [T] for consistency
+         result = result.view(-1)
+         if self.verbose:
+             torchrl_logger.debug(
+                 f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
+             )
+         return result
+
+     as_remote = as_remote
+
+     def get_policy_model(self):
+         """Get the policy model.
+
+         This method is used by RayLLMCollector to get the remote LLM instance
+         for weight updates.
+
+         Returns:
+             The policy model instance
+         """
+         return self.policy.model
+
+     def is_initialized(self) -> bool:
+         """Check if the collector is initialized and ready.
+
+         Returns:
+             bool: True if the collector is initialized and ready to collect data.
+         """
+         # The collector is initialized if it has a valid environment and policy
+         return hasattr(self, "_env") and hasattr(self, "_policy")
+
+     def set_weight_updater(self, weight_updater: WeightUpdaterBase):
+         self.weight_updater = weight_updater
+         return True
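
To round out the class docstring example above, the following sketch exercises the buffer-backed path described by the `replay_buffer`, `yield_completed_trajectories`, and `total_dialog_turns` arguments: completed trajectories are written into a replay buffer instead of being yielded. The env/policy setup mirrors the docstring; the default `ReplayBuffer()` storage and the stopping condition are illustrative assumptions, not prescribed usage.

import vllm

from torchrl.collectors.llm import LLMCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.envs import LLMEnv
from torchrl.modules import vLLMWrapper
from torchrl.testing.mocking_classes import DummyStrDataLoader

llm_model = vllm.LLM("gpt2")
tokenizer = llm_model.get_tokenizer()
tokenizer.pad_token = tokenizer.eos_token
policy = vLLMWrapper(llm_model)

env = LLMEnv.from_dataloader(
    dataloader=DummyStrDataLoader(1),
    tokenizer=tokenizer,
    from_text=True,
    batch_size=1,
    group_repeats=True,
)

rb = ReplayBuffer()  # illustrative: default storage
collector = LLMCollector(
    env=env,
    policy_factory=lambda: policy,
    replay_buffer=rb,  # completed trajectories are extended into the buffer
    yield_completed_trajectories=True,  # emit whole trajectories once `done` is set
    total_dialog_turns=8,
)
for _ in collector:  # iterations fill the buffer rather than returning data
    if len(rb) >= 4:  # illustrative stopping condition
        break
collector.shutdown()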