torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,1058 @@
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ import contextlib
5
+ import functools
6
+ import typing
7
+ import warnings
8
+ from collections import OrderedDict
9
+ from collections.abc import Callable, Iterator
10
+ from copy import deepcopy
11
+ from dataclasses import dataclass, field
12
+ from pathlib import Path
13
+ from typing import Any, overload
14
+
15
+ import torch
16
+ from tensordict import TensorDict, TensorDictBase
17
+ from tensordict.base import NO_DEFAULT
18
+ from tensordict.nn import TensorDictModule, TensorDictModuleBase
19
+ from torch import nn as nn
20
+ from torch.utils.data import IterableDataset
21
+ from torchrl.collectors.utils import _map_weight
22
+
23
+ from torchrl.collectors.weight_update import WeightUpdaterBase
24
+ from torchrl.weight_update.utils import _resolve_attr
25
+ from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme
26
+
27
+
28
@dataclass
class ProfileConfig:
    """Settings for profiling collector rollouts with ``torch.profiler``.

    A single configuration object shared by all collector types. For
    multi-process collectors only the worker indices listed in ``workers``
    are profiled; single-process collectors ignore that field.

    Attributes:
        workers: Worker indices to profile. Defaults to ``[0]``.
        num_rollouts: Total number of profiled rollouts, warmup included.
            Profiling stops afterwards. Defaults to 3.
        warmup_rollouts: Rollouts skipped before real measurement starts
            (e.g. for JIT/compile warmup). Defaults to 1.
        save_path: Destination of the trace file. Supports a ``{worker_idx}``
            placeholder; when ``None``, traces land in
            ``./collector_profile_{worker_idx}.json``.
        activities: Profiler activities by name (``"cpu"``, ``"cuda"``).
        record_shapes: Record tensor shapes. Defaults to True.
        profile_memory: Track memory usage. Defaults to False.
        with_stack: Record Python stack traces. Defaults to True.
        with_flops: Estimate FLOPS. Defaults to False.
        on_trace_ready: Optional callback invoked when the trace is ready;
            when ``None``, the trace is exported in Chrome trace format.

    Example:
        >>> from torchrl.collectors import MultiSyncCollector, ProfileConfig
        >>> collector = MultiSyncCollector(...)
        >>> collector.enable_profile(
        ...     workers=[0],
        ...     num_rollouts=5,
        ...     warmup_rollouts=2,
        ...     save_path="./traces/worker_{worker_idx}.json",
        ... )
        >>> for data in collector:
        ...     # First worker will be profiled for rollouts 2-4
        ...     process(data)
    """

    workers: list[int] = field(default_factory=lambda: [0])
    num_rollouts: int = 3
    warmup_rollouts: int = 1
    save_path: str | Path | None = None
    activities: list[str] = field(default_factory=lambda: ["cpu", "cuda"])
    record_shapes: bool = True
    profile_memory: bool = False
    with_stack: bool = True
    with_flops: bool = False
    on_trace_ready: Callable | None = None

    def __post_init__(self):
        """Reject inconsistent rollout counts right after construction."""
        # There must be at least one measured (non-warmup) rollout.
        if self.num_rollouts <= self.warmup_rollouts:
            raise ValueError(
                f"num_rollouts ({self.num_rollouts}) must be greater than "
                f"warmup_rollouts ({self.warmup_rollouts})"
            )
        if self.warmup_rollouts < 0:
            raise ValueError(
                f"warmup_rollouts must be >= 0, got {self.warmup_rollouts}"
            )

    def get_save_path(self, worker_idx: int) -> Path:
        """Resolve the trace-file path for one worker.

        Args:
            worker_idx: The worker index.

        Returns:
            Path of the trace file, with any ``{worker_idx}`` placeholder filled in.
        """
        template = self.save_path
        if template is None:
            return Path(f"./collector_profile_{worker_idx}.json")
        return Path(str(template).format(worker_idx=worker_idx))

    def get_activities(self) -> list:
        """Translate the configured activity names into profiler activities.

        Returns:
            List of ``torch.profiler.ProfilerActivity`` values. Unknown names
            are ignored; ``"cuda"`` is dropped when CUDA is unavailable.
        """
        import torch.profiler

        known = {
            "cpu": torch.profiler.ProfilerActivity.CPU,
            "cuda": torch.profiler.ProfilerActivity.CUDA,
        }
        selected = []
        for name in self.activities:
            key = name.lower()
            if key not in known:
                continue
            # CUDA profiling only makes sense when a CUDA device exists.
            if key == "cuda" and not torch.cuda.is_available():
                continue
            selected.append(known[key])
        return selected

    def should_profile_worker(self, worker_idx: int) -> bool:
        """Return ``True`` when ``worker_idx`` is in the profiled-worker list.

        Args:
            worker_idx: The worker index to check.
        """
        return worker_idx in self.workers
138
+
139
+
140
class BaseCollector(IterableDataset, metaclass=abc.ABCMeta):
    """Base class for data collectors."""

    # Background task handle — presumably set by async subclasses; TODO confirm.
    _task = None
    # Cached iterator; None until iteration begins.
    _iterator = None
    # Flipped when iteration starts; enable_profile() refuses to run afterwards.
    _iteration_started = False
    # Total number of frames the collector will deliver over its lifetime.
    total_frames: int
    # Batch size as requested by the caller (pre-adjustment).
    requested_frames_per_batch: int
    # Effective number of frames per yielded batch.
    frames_per_batch: int
    # NOTE(review): semantics of these three flags are set by subclass __init__;
    # presumably: skip policy wrapping / torch.compile applied / CUDA graphs applied.
    trust_policy: bool
    compiled_policy: bool
    cudagraphed_policy: bool
    # Managed through the `weight_updater` property, which registers the
    # collector on assignment.
    _weight_updater: WeightUpdaterBase | None = None
    # Per-model weight-sync schemes, keyed by model id — see cascade_execute usage.
    _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None
    # Verbosity toggle for subclass logging.
    verbose: bool = False
    # Profiling settings installed by enable_profile(); None means disabled.
    _profile_config: ProfileConfig | None = None
156
+
157
+ def enable_profile(
158
+ self,
159
+ *,
160
+ workers: list[int] | None = None,
161
+ num_rollouts: int = 3,
162
+ warmup_rollouts: int = 1,
163
+ save_path: str | Path | None = None,
164
+ activities: list[str] | None = None,
165
+ record_shapes: bool = True,
166
+ profile_memory: bool = False,
167
+ with_stack: bool = True,
168
+ with_flops: bool = False,
169
+ on_trace_ready: Callable | None = None,
170
+ ) -> None:
171
+ """Enable profiling for collector worker rollouts.
172
+
173
+ This method configures the collector to profile rollouts using PyTorch's
174
+ profiler. For multi-process collectors, profiling happens in the worker
175
+ processes. For single-process collectors (Collector), profiling happens
176
+ in the main process.
177
+
178
+ Args:
179
+ workers: List of worker indices to profile. Defaults to [0].
180
+ For single-process collectors, this is ignored.
181
+ num_rollouts: Total number of rollouts to run the profiler for
182
+ (including warmup). Profiling stops after this many rollouts.
183
+ Defaults to 3.
184
+ warmup_rollouts: Number of rollouts to skip before starting actual
185
+ profiling. Useful for JIT/compile warmup. The profiler runs
186
+ but discards data during warmup. Defaults to 1.
187
+ save_path: Path to save the profiling trace. Supports {worker_idx}
188
+ placeholder for worker-specific files. If None, traces are
189
+ saved to "./collector_profile_{worker_idx}.json".
190
+ activities: List of profiler activities ("cpu", "cuda").
191
+ Defaults to ["cpu", "cuda"].
192
+ record_shapes: Whether to record tensor shapes. Defaults to True.
193
+ profile_memory: Whether to profile memory usage. Defaults to False.
194
+ with_stack: Whether to record Python stack traces. Defaults to True.
195
+ with_flops: Whether to compute FLOPS. Defaults to False.
196
+ on_trace_ready: Optional callback when trace is ready. If None,
197
+ traces are exported to Chrome trace format at save_path.
198
+
199
+ Raises:
200
+ RuntimeError: If called after iteration has started.
201
+ ValueError: If num_rollouts <= warmup_rollouts.
202
+
203
+ Example:
204
+ >>> from torchrl.collectors import MultiSyncCollector
205
+ >>> collector = MultiSyncCollector(
206
+ ... create_env_fn=[make_env] * 4,
207
+ ... policy=policy,
208
+ ... frames_per_batch=1000,
209
+ ... total_frames=100000,
210
+ ... )
211
+ >>> collector.enable_profile(
212
+ ... workers=[0],
213
+ ... num_rollouts=5,
214
+ ... warmup_rollouts=2,
215
+ ... save_path="./traces/worker_{worker_idx}.json",
216
+ ... )
217
+ >>> # Worker 0 will be profiled for rollouts 2, 3, 4
218
+ >>> for data in collector:
219
+ ... train(data)
220
+ >>> collector.shutdown()
221
+
222
+ Note:
223
+ - Profiling adds overhead, so only profile specific workers
224
+ - The trace file can be viewed in Chrome's trace viewer
225
+ (chrome://tracing) or with PyTorch's TensorBoard plugin
226
+ - For multi-process collectors, this must be called BEFORE
227
+ iteration starts as it needs to configure workers
228
+ """
229
+ if self._iteration_started:
230
+ raise RuntimeError(
231
+ "Cannot enable profiling after iteration has started. "
232
+ "Call enable_profile() before iterating over the collector."
233
+ )
234
+
235
+ if workers is None:
236
+ workers = [0]
237
+ if activities is None:
238
+ activities = ["cpu", "cuda"]
239
+
240
+ self._profile_config = ProfileConfig(
241
+ workers=workers,
242
+ num_rollouts=num_rollouts,
243
+ warmup_rollouts=warmup_rollouts,
244
+ save_path=save_path,
245
+ activities=activities,
246
+ record_shapes=record_shapes,
247
+ profile_memory=profile_memory,
248
+ with_stack=with_stack,
249
+ with_flops=with_flops,
250
+ on_trace_ready=on_trace_ready,
251
+ )
252
+
253
+ @property
254
+ def profile_config(self) -> ProfileConfig | None:
255
+ """Get the profiling configuration.
256
+
257
+ Returns:
258
+ ProfileConfig if profiling is enabled, None otherwise.
259
+ """
260
+ return self._profile_config
261
+
262
+ @property
263
+ def weight_updater(self) -> WeightUpdaterBase:
264
+ return self._weight_updater
265
+
266
+ @weight_updater.setter
267
+ def weight_updater(self, value: WeightUpdaterBase | None):
268
+ if value is not None:
269
+ if not isinstance(value, WeightUpdaterBase) and callable(
270
+ value
271
+ ): # Fall back to default constructor
272
+ value = value()
273
+ value.register_collector(self)
274
+ if value.collector is not self:
275
+ raise RuntimeError("Failed to register collector.")
276
+ self._weight_updater = value
277
+
278
+ @property
279
+ def worker_idx(self) -> int | None:
280
+ """Get the worker index for this collector.
281
+
282
+ Returns:
283
+ The worker index (0-indexed).
284
+
285
+ Raises:
286
+ RuntimeError: If worker_idx has not been set.
287
+ """
288
+ if not hasattr(self, "_worker_idx"):
289
+ raise RuntimeError(
290
+ "worker_idx has not been set. This collector may not have been "
291
+ "initialized as a worker in a distributed setup."
292
+ )
293
+ return self._worker_idx
294
+
295
+ @worker_idx.setter
296
+ def worker_idx(self, value: int | None) -> None:
297
+ """Set the worker index for this collector.
298
+
299
+ Args:
300
+ value: The worker index (0-indexed) or None.
301
+ """
302
+ self._worker_idx = value
303
+
304
+ def cascade_execute(self, attr_path: str, *args, **kwargs) -> Any:
305
+ """Execute a method on a nested attribute of this collector.
306
+
307
+ This method allows remote callers to invoke methods on nested attributes
308
+ of the collector without needing to know the full structure. It's particularly
309
+ useful for calling methods on weight sync schemes from the sender side.
310
+
311
+ Args:
312
+ attr_path: Full path to the callable, e.g.,
313
+ "_receiver_schemes['model_id']._set_dist_connection_info"
314
+ *args: Positional arguments to pass to the method.
315
+ **kwargs: Keyword arguments to pass to the method.
316
+
317
+ Returns:
318
+ The return value of the method call.
319
+
320
+ Examples:
321
+ >>> collector.cascade_execute(
322
+ ... "_receiver_schemes['policy']._set_dist_connection_info",
323
+ ... connection_info_ref,
324
+ ... worker_idx=0
325
+ ... )
326
+ """
327
+ attr = _resolve_attr(self, attr_path)
328
+ if callable(attr):
329
+ return attr(*args, **kwargs)
330
+ else:
331
+ if args or kwargs:
332
+ raise ValueError(
333
+ f"Arguments and keyword arguments are not supported for non-callable attributes. Got {args} and {kwargs} for {attr_path}"
334
+ )
335
+ return attr
336
+
337
    def _get_policy_and_device(
        self,
        policy: Callable[[Any], Any] | None = None,
        policy_device: Any = NO_DEFAULT,
        env_maker: Any | None = None,
        env_maker_kwargs: dict[str, Any] | None = None,
    ) -> tuple[TensorDictModule, None | Callable[[], dict]]:
        """Util method to get a policy and its device given the collector __init__ inputs.

        We want to copy the policy and then move the data there, not call policy.to(device).

        Args:
            policy (TensorDictModule, optional): a policy to be used
            policy_device (torch.device, optional): the device where the policy should be placed.
                Defaults to self.policy_device
            env_maker (a callable or a batched env, optional): the env_maker function for this device/policy pair.
            env_maker_kwargs (a dict, optional): the env_maker function kwargs.

        Returns:
            A ``(policy, weights_getter)`` pair. When no casting is needed the
            original policy is returned with ``None``; otherwise a deep copy of
            the policy living on ``policy_device`` is returned together with a
            closure yielding the original (untouched) weights.
        """
        if policy_device is NO_DEFAULT:
            policy_device = self.policy_device

        # No target device: nothing to cast, hand the policy back untouched.
        if not policy_device:
            return policy, None

        if isinstance(policy, nn.Module):
            param_and_buf = TensorDict.from_module(policy, as_module=True)
        else:
            # Because we want to reach the warning below (empty param set).
            param_and_buf = TensorDict()

        # Scan parameters/buffers; break as soon as one lives off-device.
        i = -1
        for p in param_and_buf.values(True, True):
            i += 1
            if p.device != policy_device:
                # Then we need casting
                break
        else:
            # Loop completed without a device mismatch: no casting required.
            if i == -1 and not self.trust_policy:
                # No parameters were found at all; we must trust that the
                # policy's device is adequate.
                warnings.warn(
                    "A policy device was provided but no parameter/buffer could be found in "
                    "the policy. Casting to policy_device is therefore impossible. "
                    "The collector will trust that the devices match. To suppress this "
                    "warning, set `trust_policy=True` when building the collector."
                )
            return policy, None

        # Create a stateless policy, then populate this copy with params on device
        def get_original_weights(policy=policy):
            td = TensorDict.from_module(policy)
            return td.data

        # We need to use ".data" otherwise buffers may disappear from the `get_original_weights` function
        with param_and_buf.data.to("meta").to_module(policy):
            policy_new_device = deepcopy(policy)

        param_and_buf_new_device = param_and_buf.apply(
            functools.partial(_map_weight, policy_device=policy_device),
            filter_empty=False,
        )
        param_and_buf_new_device.to_module(policy_new_device)
        # Sanity check
        if set(TensorDict.from_module(policy_new_device).keys(True, True)) != set(
            get_original_weights().keys(True, True)
        ):
            raise RuntimeError("Failed to map weights. The weight sets mismatch.")
        return policy_new_device, get_original_weights
405
+
406
    def start(self):
        """Starts the collector for asynchronous data collection.

        This method initiates the background collection of data, allowing for decoupling of data collection and training.

        The collected data is typically stored in a replay buffer passed during the collector's initialization.

        .. note:: After calling this method, it's essential to shut down the collector using :meth:`~.async_shutdown`
            when you're done with it to free up resources.

        .. warning:: Asynchronous data collection can significantly impact training performance due to its decoupled nature.
            Ensure you understand the implications for your specific algorithm before using this mode.

        Raises:
            NotImplementedError: If not implemented by a subclass.
        """
        # Subclasses that support async collection must override this method.
        raise NotImplementedError(
            f"Collector start() is not implemented for {type(self).__name__}."
        )
425
+
426
    @contextlib.contextmanager
    def pause(self):
        """Context manager that pauses the collector if it is running free.

        Raises:
            NotImplementedError: Subclasses that support free-running
                collection must override this.
        """
        raise NotImplementedError(
            f"Collector pause() is not implemented for {type(self).__name__}."
        )
432
+
433
    def async_shutdown(
        self, timeout: float | None = None, close_env: bool = True
    ) -> None:
        """Shuts down the collector when started asynchronously with the `start` method.

        This is a thin wrapper delegating to :meth:`shutdown`.

        Args:
            timeout (float, optional): The maximum time to wait for the collector to shutdown.
            close_env (bool, optional): If True, the collector will close the contained environment.
                Defaults to `True`.

        .. seealso:: :meth:`~.start`

        """
        return self.shutdown(timeout=timeout, close_env=close_env)
447
+
448
+ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any:
449
+ """Extract weights from a model if needed.
450
+
451
+ For the new weight sync scheme system, weight preparation is handled
452
+ by the scheme's prepare_weights() method. This method now only handles
453
+ legacy weight updater cases.
454
+
455
+ Args:
456
+ weights: Either already-extracted weights or a model to extract from.
457
+ model_id: The model identifier for resolving string paths.
458
+
459
+ Returns:
460
+ Extracted weights in the appropriate format.
461
+ """
462
+ # New weight sync schemes handle preparation themselves
463
+ if self._weight_sync_schemes:
464
+ # Just pass through - WeightSender will call scheme.prepare_weights()
465
+ return weights
466
+
467
+ # Legacy weight updater path
468
+ return self._legacy_extract_weights(weights, model_id)
469
+
470
+ def _legacy_extract_weights(self, weights: Any, model_id: str) -> Any:
471
+ """Legacy weight extraction for old weight updater system.
472
+
473
+ Args:
474
+ weights: Either already-extracted weights or a model to extract from.
475
+ model_id: The model identifier.
476
+
477
+ Returns:
478
+ Extracted weights.
479
+ """
480
+ if weights is None:
481
+ if model_id == "policy" and hasattr(self, "policy_weights"):
482
+ return self.policy_weights
483
+ elif model_id == "policy" and hasattr(self, "_policy_weights_dict"):
484
+ policy_device = (
485
+ self.policy_device
486
+ if not isinstance(self.policy_device, (list, tuple))
487
+ else self.policy_device[0]
488
+ )
489
+ return self._policy_weights_dict.get(policy_device)
490
+ return None
491
+
492
+ return weights
493
+
494
    @property
    def _legacy_weight_updater(self) -> bool:
        """Whether the deprecated weight-updater path is active, i.e. a
        ``weight_updater`` instance is configured."""
        return self._weight_updater is not None
497
+
498
    # Overloads for update_policy_weights_ to support multiple calling conventions.
    # These are typing-only stubs (never executed); the implementation follows below.
    @overload
    def update_policy_weights_(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | nn.Module | dict,
        /,
    ) -> None:
        ...

    @overload
    def update_policy_weights_(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | nn.Module | dict,
        /,
        *,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
        model_id: str | None = None,
    ) -> None:
        ...

    @overload
    def update_policy_weights_(
        self,
        *,
        weights: TensorDictBase | dict,
        model_id: str | None = None,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
    ) -> None:
        ...

    @overload
    def update_policy_weights_(
        self,
        *,
        policy: TensorDictModuleBase | nn.Module,
        model_id: str | None = None,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
    ) -> None:
        ...

    @overload
    def update_policy_weights_(
        self,
        *,
        weights_dict: dict[
            str, TensorDictBase | TensorDictModuleBase | nn.Module | dict
        ],
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
    ) -> None:
        ...
548
+
549
    def update_policy_weights_(
        self,
        policy_or_weights: TensorDictBase
        | TensorDictModuleBase
        | nn.Module
        | dict
        | None = None,
        *,
        weights: TensorDictBase | dict | None = None,
        policy: TensorDictModuleBase | nn.Module | None = None,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
        model_id: str | None = None,
        weights_dict: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        """Update policy weights for the data collector.

        This method synchronizes the policy weights used by the collector with the latest
        trained weights. It supports both local and remote weight updates, depending on
        the collector configuration.

        The method accepts weights in multiple forms for convenience:

        Examples:
            >>> # Pass policy module as positional argument
            >>> collector.update_policy_weights_(policy_module)
            >>>
            >>> # Pass TensorDict weights as positional argument
            >>> collector.update_policy_weights_(weights_tensordict)
            >>>
            >>> # Use keyword arguments for clarity
            >>> collector.update_policy_weights_(weights=weights_td, model_id="actor")
            >>> collector.update_policy_weights_(policy=actor_module, model_id="actor")
            >>>
            >>> # Update multiple models atomically
            >>> collector.update_policy_weights_(weights_dict={
            ...     "actor": actor_weights,
            ...     "critic": critic_weights,
            ... })

        Args:
            policy_or_weights: The weights to update with. Can be:

                - ``nn.Module``: A policy module whose weights will be extracted
                - ``TensorDictModuleBase``: A TensorDict module whose weights will be extracted
                - ``TensorDictBase``: A TensorDict containing weights
                - ``dict``: A regular dict containing weights
                - ``None``: Will try to get weights from server using ``_get_server_weights()``

        Keyword Args:
            weights: Alternative to positional argument. A TensorDict or dict containing
                weights to update. Cannot be used together with ``policy_or_weights`` or ``policy``.
            policy: Alternative to positional argument. An ``nn.Module`` or ``TensorDictModuleBase``
                whose weights will be extracted. Cannot be used together with ``policy_or_weights``
                or ``weights``.
            worker_ids: Identifiers for the workers to update. Relevant when the collector
                has multiple workers. Can be int, list of ints, device, or list of devices.
            model_id: The model identifier to update (default: ``"policy"``).
                Cannot be used together with ``weights_dict``.
            weights_dict: Dictionary mapping model_id to weights for updating
                multiple models atomically. Keys should match model_ids registered in
                ``weight_sync_schemes``. Cannot be used together with ``model_id``,
                ``policy_or_weights``, ``weights``, or ``policy``.

        Raises:
            TypeError: If ``worker_ids`` is provided but no ``weight_updater`` is configured.
            ValueError: If conflicting parameters are provided.

        .. note:: Users should extend the ``WeightUpdaterBase`` classes to customize
            the weight update logic for specific use cases.

        .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and
            :meth:`~torchrl.collectors.RemoteWeightsUpdaterBase`.

        """
        # Normalize the mutually-exclusive `weights`/`policy` keywords onto the
        # positional `policy_or_weights` slot, rejecting conflicting combinations.
        if weights is not None:
            if policy_or_weights is not None:
                raise ValueError(
                    "Cannot specify both positional 'policy_or_weights' and keyword 'weights'"
                )
            if policy is not None:
                raise ValueError("Cannot specify both 'weights' and 'policy'")
            policy_or_weights = weights

        if policy is not None:
            if policy_or_weights is not None:
                raise ValueError(
                    "Cannot specify both positional 'policy_or_weights' and keyword 'policy'"
                )
            policy_or_weights = policy
        # Route to the deprecated weight-updater path or the scheme-based one.
        if self._legacy_weight_updater:
            return self._legacy_weight_update_impl(
                policy_or_weights=policy_or_weights,
                worker_ids=worker_ids,
                model_id=model_id,
                weights_dict=weights_dict,
                **kwargs,
            )
        else:
            return self._weight_update_impl(
                policy_or_weights=policy_or_weights,
                worker_ids=worker_ids,
                model_id=model_id,
                weights_dict=weights_dict,
                **kwargs,
            )
656
+
657
+ def _legacy_weight_update_impl(
658
+ self,
659
+ policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
660
+ *,
661
+ worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
662
+ model_id: str | None = None,
663
+ weights_dict: dict[str, Any] | None = None,
664
+ **kwargs,
665
+ ) -> None:
666
+ if weights_dict is not None:
667
+ raise ValueError("weights_dict is not supported with legacy weight updater")
668
+ if model_id is not None:
669
+ raise ValueError("model_id is not supported with legacy weight updater")
670
+ # Fall back to old weight updater system
671
+ self.weight_updater(
672
+ policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
673
+ )
674
+
675
+ def _weight_update_impl(
676
+ self,
677
+ policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
678
+ *,
679
+ worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
680
+ model_id: str | None = None,
681
+ weights_dict: dict[str, Any] | None = None,
682
+ **kwargs,
683
+ ) -> None:
684
+ if "policy_weights" in kwargs:
685
+ warnings.warn(
686
+ "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
687
+ DeprecationWarning,
688
+ )
689
+ policy_or_weights = kwargs.pop("policy_weights")
690
+
691
+ if weights_dict is not None and model_id is not None:
692
+ raise ValueError("Cannot specify both 'weights_dict' and 'model_id'")
693
+
694
+ if weights_dict is not None and policy_or_weights is not None:
695
+ raise ValueError(
696
+ "Cannot specify both 'weights_dict' and 'policy_or_weights'"
697
+ )
698
+
699
+ if self._weight_sync_schemes:
700
+ if model_id is None:
701
+ model_id = "policy"
702
+ if policy_or_weights is not None and weights_dict is None:
703
+ # Use model_id as the key, not hardcoded "policy"
704
+ weights_dict = {model_id: policy_or_weights}
705
+ elif weights_dict is None:
706
+ weights_dict = {model_id: policy_or_weights}
707
+ for target_model_id, weights in weights_dict.items():
708
+ if target_model_id not in self._weight_sync_schemes:
709
+ raise KeyError(
710
+ f"Model '{target_model_id}' not found in registered weight sync schemes. "
711
+ f"Available models: {list(self._weight_sync_schemes.keys())}"
712
+ )
713
+ processed_weights = self._extract_weights_if_needed(
714
+ weights, target_model_id
715
+ )
716
+ # Use new send() API with worker_ids support
717
+ scheme = self._weight_sync_schemes.get(target_model_id)
718
+ if not isinstance(scheme, WeightSyncScheme):
719
+ raise TypeError(f"Expected WeightSyncScheme, got {target_model_id}")
720
+ self._send_weights_scheme(
721
+ scheme=scheme,
722
+ processed_weights=processed_weights,
723
+ worker_ids=worker_ids,
724
+ model_id=target_model_id,
725
+ )
726
+ elif self._weight_updater is not None:
727
+ # unreachable
728
+ raise RuntimeError
729
+ else:
730
+ # No weight updater configured, try fallback
731
+ self._maybe_fallback_update(policy_or_weights, model_id=model_id)
732
+
733
    def _maybe_fallback_update(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        model_id: str | None = None,
    ) -> None:
        """Fallback weight update when no scheme is configured.

        Override in subclasses to provide custom fallback behavior.
        By default, this is a no-op.
        """
744
+
745
    def _send_weights_scheme(self, *, model_id, scheme, processed_weights, worker_ids):
        """Push prepared weights through ``scheme``.

        Subclasses override this when the scheme requires an RPC call for the
        receiver to pick up the weights; ``model_id`` is unused here but kept
        for such overrides.
        """
        scheme.send(weights=processed_weights, worker_ids=worker_ids)
748
+
749
+ def _receive_weights_scheme(self):
750
+ """Receive weights for all registered receiver schemes.
751
+
752
+ scheme.receive() handles both applying weights locally and cascading
753
+ to sub-collectors via context.update_policy_weights_().
754
+ """
755
+ if not hasattr(self, "_receiver_schemes"):
756
+ raise RuntimeError("No receiver schemes registered.")
757
+
758
+ for scheme in self._receiver_schemes.values():
759
+ scheme.receive()
760
+
761
    # Overloads for receive_weights to support multiple calling conventions.
    # These are typing-only stubs (never executed); the implementation follows below.
    @overload
    def receive_weights(self) -> None:
        ...

    @overload
    def receive_weights(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | nn.Module | dict,
        /,
    ) -> None:
        ...

    @overload
    def receive_weights(
        self,
        *,
        weights: TensorDictBase | dict,
    ) -> None:
        ...

    @overload
    def receive_weights(
        self,
        *,
        policy: TensorDictModuleBase | nn.Module,
    ) -> None:
        ...
790
    def receive_weights(
        self,
        policy_or_weights: TensorDictBase
        | TensorDictModuleBase
        | nn.Module
        | dict
        | None = None,
        *,
        weights: TensorDictBase | dict | None = None,
        policy: TensorDictModuleBase | nn.Module | None = None,
    ) -> None:
        """Receive and apply weights to the collector's policy.

        This method applies weights to the local policy. When receiver schemes are
        registered, it delegates to those schemes. Otherwise, it directly applies
        the provided weights.

        The method accepts weights in multiple forms for convenience:

        Examples:
            >>> # Receive from registered schemes (distributed collectors)
            >>> collector.receive_weights()
            >>>
            >>> # Apply weights from a policy module (positional)
            >>> collector.receive_weights(trained_policy)
            >>>
            >>> # Apply weights from a TensorDict (positional)
            >>> collector.receive_weights(weights_tensordict)
            >>>
            >>> # Use keyword arguments for clarity
            >>> collector.receive_weights(weights=weights_td)
            >>> collector.receive_weights(policy=trained_policy)

        Args:
            policy_or_weights: The weights to apply. Can be:

                - ``nn.Module``: A policy module whose weights will be extracted and applied
                - ``TensorDictModuleBase``: A TensorDict module whose weights will be extracted
                - ``TensorDictBase``: A TensorDict containing weights
                - ``dict``: A regular dict containing weights
                - ``None``: Receive from registered schemes or mirror from original policy

        Keyword Args:
            weights: Alternative to positional argument. A TensorDict or dict containing
                weights to apply. Cannot be used together with ``policy_or_weights`` or ``policy``.
            policy: Alternative to positional argument. An ``nn.Module`` or ``TensorDictModuleBase``
                whose weights will be extracted. Cannot be used together with ``policy_or_weights``
                or ``weights``.

        Raises:
            ValueError: If conflicting parameters are provided or if arguments are passed
                when receiver schemes are registered.

        """
        # Normalize the mutually-exclusive `weights`/`policy` keywords onto the
        # positional `policy_or_weights` slot, rejecting conflicting combinations.
        if weights is not None:
            if policy_or_weights is not None:
                raise ValueError(
                    "Cannot specify both positional 'policy_or_weights' and keyword 'weights'"
                )
            if policy is not None:
                raise ValueError("Cannot specify both 'weights' and 'policy'")
            policy_or_weights = weights

        if policy is not None:
            if policy_or_weights is not None:
                raise ValueError(
                    "Cannot specify both positional 'policy_or_weights' and keyword 'policy'"
                )
            policy_or_weights = policy

        # Registered receiver schemes take precedence and fetch weights themselves.
        if getattr(self, "_receiver_schemes", None) is not None:
            if policy_or_weights is not None:
                raise ValueError(
                    "Cannot specify 'policy_or_weights' when using 'receiver_schemes'. Schemes should know how to get the weights."
                )
            self._receive_weights_scheme()
            return

        # No weight updater configured
        # For single-process collectors, apply weights locally if explicitly provided
        if policy_or_weights is not None:
            from torchrl.weight_update.weight_sync_schemes import WeightStrategy

            # Use WeightStrategy to apply weights properly
            strategy = WeightStrategy(extract_as="tensordict")

            # Extract weights if needed
            if isinstance(policy_or_weights, nn.Module):
                weights = strategy.extract_weights(policy_or_weights)
            else:
                weights = policy_or_weights

            # Apply to local policy
            if hasattr(self, "policy") and isinstance(self.policy, nn.Module):
                strategy.apply_weights(self.policy, weights)
        # Otherwise, no action needed - policy is local and changes are immediately visible
887
+
888
+ def register_scheme_receiver(
889
+ self,
890
+ weight_recv_schemes: dict[str, WeightSyncScheme],
891
+ *,
892
+ synchronize_weights: bool = True,
893
+ ): # noqa: D417
894
+ """Set up receiver schemes for this collector to receive weights from parent collectors.
895
+
896
+ This method initializes receiver schemes and stores them in _receiver_schemes
897
+ for later use by _receive_weights_scheme() and receive_weights().
898
+
899
+ Receiver schemes enable cascading weight updates across collector hierarchies:
900
+ - Parent collector sends weights via its weight_sync_schemes (senders)
901
+ - Child collector receives weights via its weight_recv_schemes (receivers)
902
+ - If child is also a parent (intermediate node), it can propagate to its own children
903
+
904
+ Args:
905
+ weight_recv_schemes (dict[str, WeightSyncScheme]): Dictionary of {model_id: WeightSyncScheme} to set up as receivers.
906
+ These schemes will receive weights from parent collectors.
907
+
908
+ Keyword Args:
909
+ synchronize_weights (bool, optional): If True, synchronize weights immediately after registering the schemes.
910
+ Defaults to `True`.
911
+ """
912
+ # Initialize _receiver_schemes if not already present
913
+ if not hasattr(self, "_receiver_schemes"):
914
+ self._receiver_schemes = {}
915
+
916
+ # Initialize each scheme on the receiver side
917
+ for model_id, scheme in weight_recv_schemes.items():
918
+ if not scheme.initialized_on_receiver:
919
+ if scheme.initialized_on_sender:
920
+ raise RuntimeError(
921
+ "Weight sync scheme cannot be initialized on both sender and receiver."
922
+ )
923
+ scheme.init_on_receiver(
924
+ model_id=model_id,
925
+ context=self,
926
+ worker_idx=self.worker_idx,
927
+ )
928
+
929
+ # Store the scheme for later use in receive_weights()
930
+ self._receiver_schemes[model_id] = scheme
931
+
932
+ # Perform initial synchronization
933
+ if synchronize_weights:
934
+ for scheme in weight_recv_schemes.values():
935
+ if not scheme.synchronized_on_receiver:
936
+ scheme.connect(worker_idx=self.worker_idx)
937
+
938
    def __iter__(self) -> Iterator[TensorDictBase]:
        """Iterate over collected batches, shutting the collector down on failure.

        Delegates to :meth:`iterator`; any exception raised while iterating
        triggers :meth:`shutdown` before being re-raised.
        """
        # Mark that iteration has started (used by enable_profile check)
        self._iteration_started = True
        try:
            yield from self.iterator()
        except Exception:
            # Release worker/env resources before propagating the error.
            self.shutdown()
            raise
946
+
947
    def next(self):
        """Return the next batch of collected data, or ``None`` when exhausted.

        The iterator is created lazily on first call. Batches living on a
        non-CPU device are copied and stripped of their device reference.
        """
        try:
            if self._iterator is None:
                self._iterator = iter(self)
            out = next(self._iterator)
            # if any, we don't want the device ref to be passed in distributed settings
            # NOTE(review): `out.device != "cpu"` is also True when the device is
            # None, so device-less outputs get copied too — presumably harmless,
            # confirm.
            if out is not None and (out.device != "cpu"):
                out = out.copy().clear_device_()
            return out
        except StopIteration:
            # Exhaustion is signalled with None rather than an exception.
            return None
958
+
959
    @abc.abstractmethod
    def shutdown(
        self,
        timeout: float | None = None,
        close_env: bool = True,
        raise_on_error: bool = True,
    ) -> None:
        """Shut down the collector and release its resources (abstract)."""
        raise NotImplementedError
967
+
968
    @abc.abstractmethod
    def iterator(self) -> Iterator[TensorDictBase]:
        """Yield batches of collected data (abstract)."""
        raise NotImplementedError
971
+
972
    @abc.abstractmethod
    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        """Set the seed and return an int (abstract).

        NOTE(review): presumably the returned int is the next seed to use for
        another collector — confirm against concrete subclasses.
        """
        raise NotImplementedError
975
+
976
    @abc.abstractmethod
    def state_dict(self) -> OrderedDict:
        """Return the collector state as an OrderedDict (abstract)."""
        raise NotImplementedError
979
+
980
    @abc.abstractmethod
    def load_state_dict(self, state_dict: OrderedDict) -> None:
        """Load a state previously obtained via ``state_dict`` (abstract)."""
        raise NotImplementedError
983
+
984
+ def _read_compile_kwargs(self, compile_policy, cudagraph_policy):
985
+ self.compiled_policy = compile_policy not in (False, None)
986
+ self.cudagraphed_policy = cudagraph_policy not in (False, None)
987
+ self.compiled_policy_kwargs = (
988
+ {} if not isinstance(compile_policy, typing.Mapping) else compile_policy
989
+ )
990
+ self.cudagraphed_policy_kwargs = (
991
+ {} if not isinstance(cudagraph_policy, typing.Mapping) else cudagraph_policy
992
+ )
993
+
994
+ def __repr__(self) -> str:
995
+ string = f"{self.__class__.__name__}()"
996
+ return string
997
+
998
    def __class_getitem__(self, index):
        """Subscripting the class (``Collector[...]``) is deliberately unsupported."""
        raise NotImplementedError
1000
+
1001
+ def __len__(self) -> int:
1002
+ if self.total_frames > 0:
1003
+ return -(self.total_frames // -self.requested_frames_per_batch)
1004
+ raise RuntimeError("Non-terminating collectors do not have a length")
1005
+
1006
+ def init_updater(self, *args, **kwargs):
1007
+ """Initialize the weight updater with custom arguments.
1008
+
1009
+ This method passes the arguments to the weight updater's init method.
1010
+ If no weight updater is set, this is a no-op.
1011
+
1012
+ Args:
1013
+ *args: Positional arguments for weight updater initialization
1014
+ **kwargs: Keyword arguments for weight updater initialization
1015
+ """
1016
+ if self.weight_updater is not None:
1017
+ self.weight_updater.init(*args, **kwargs)
1018
+
1019
+
1020
+ def _make_legacy_metaclass(parent_metaclass):
1021
+ """Create a legacy metaclass for deprecated collector names.
1022
+
1023
+ This factory creates a metaclass that inherits from the given parent metaclass
1024
+ to avoid metaclass conflicts.
1025
+ """
1026
+
1027
+ class _LegacyMeta(parent_metaclass):
1028
+ """Metaclass for deprecated collector class names.
1029
+
1030
+ Raises a deprecation warning when the old class name is instantiated,
1031
+ and ensures isinstance() checks work for both old and new names.
1032
+ """
1033
+
1034
+ def __call__(cls, *args, **kwargs):
1035
+ warnings.warn(
1036
+ f"{cls.__name__} has been deprecated and will be removed in v0.13. "
1037
+ f"Please use {cls.__bases__[0].__name__} instead.",
1038
+ category=DeprecationWarning,
1039
+ )
1040
+ return super().__call__(*args, **kwargs)
1041
+
1042
+ def __instancecheck__(cls, instance):
1043
+ if super().__instancecheck__(instance):
1044
+ return True
1045
+ parent_cls = cls.__bases__[0]
1046
+ return isinstance(instance, parent_cls)
1047
+
1048
+ return _LegacyMeta
1049
+
1050
+
1051
# Default legacy metaclass for deprecated aliases whose modern counterpart
# uses abc.ABCMeta (avoids a metaclass conflict with ABC-based collectors).
_LegacyCollectorMeta = _make_legacy_metaclass(abc.ABCMeta)
1053
+
1054
+
1055
class DataCollectorBase(BaseCollector, metaclass=_LegacyCollectorMeta):
    """Deprecated version of :class:`~torchrl.collectors.BaseCollector`.

    Instantiating this alias emits a :class:`DeprecationWarning`; use
    :class:`~torchrl.collectors.BaseCollector` instead.
    """

    ...