torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/weight_update/weight_sync_schemes.py
@@ -0,0 +1,1244 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import abc
8
+ import threading
9
+ import warnings
10
+ import weakref
11
+ from collections import defaultdict
12
+ from collections.abc import Callable, Iterator
13
+ from typing import Any, Literal, overload, Protocol
14
+
15
+ import torch
16
+
17
+ from tensordict import TensorDict, TensorDictBase
18
+ from torch import nn
19
+ from torchrl._utils import logger as torchrl_logger
20
+
21
+ __all__ = [
22
+ "TransportBackend",
23
+ "WeightStrategy",
24
+ "WeightSyncScheme",
25
+ ]
26
+
27
+ from torchrl.weight_update.utils import _resolve_model
28
+
29
+
30
+ # ============================================================================
31
+ # Transport Layer Abstraction
32
+ # ============================================================================
33
+
34
+
35
+ class TransportBackend(Protocol):
36
+ """Abstract interface for different communication mechanisms."""
37
+
38
+ def send_weights(self, weights: Any) -> None:
39
+ """Send weights to the receiver."""
40
+ ...
41
+
42
+ def receive_weights(
43
+ self,
44
+ timeout: float | None = None,
45
+ *,
46
+ weights: Any = None,
47
+ model: Any = None,
48
+ strategy: WeightStrategy | None = None,
49
+ ) -> Any | None:
50
+ """Receive weights from the sender and apply them to the model.
51
+
52
+ Args:
53
+ timeout: Maximum time to wait for weights (seconds).
54
+ None means no timeout (blocking). Some transports may not
55
+ support timeout and will raise ValueError if specified.
56
+ weights: Pre-allocated weight buffer to receive into.
57
+ model: The model to apply weights to.
58
+ strategy: Strategy for applying weights to the model.
59
+
60
+ Returns:
61
+ The received/applied weights, or None if timeout/no weights available.
62
+ """
63
+ ...
64
+
65
+ def setup_connection_and_weights_on_sender(self) -> None:
66
+ """Synchronize weights on sender side before collection starts.
67
+
68
+ This is called once after workers are initialized to send the initial
69
+ weights. This can be a no-op (weights are sent via
70
+ send_weights).
71
+ """
72
+ ...
73
+
74
+ def setup_connection_and_weights_on_receiver(
75
+ self,
76
+ *,
77
+ worker_idx: int,
78
+ weights: Any = None,
79
+ model: Any = None,
80
+ strategy: WeightStrategy | None = None,
81
+ ) -> Any:
82
+ """Synchronize weights on worker side before collection starts.
83
+
84
+ This is called once in each worker after initialization to receive
85
+ the initial weights. This is a no-op (weights are received via
86
+ receive_weights).
87
+
88
+ Args:
89
+ worker_idx: The worker index.
90
+ weights: Pre-allocated weight buffer to receive into.
91
+ model: The model to apply weights to.
92
+ strategy: Strategy for applying weights to the model.
93
+
94
+ Returns:
95
+ The received weights (for SharedMemTransport) or None.
96
+ """
97
+ ...
98
+
99
+
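
For illustration only (not part of the wheel), a minimal transport satisfying this protocol could be backed by a multiprocessing queue; the class name and queue plumbing below are assumptions made for the example:

    import queue as stdlib_queue
    import multiprocessing as mp

    class QueueTransport:
        """Illustrative TransportBackend backed by a multiprocessing queue (example only)."""

        def __init__(self, q):
            self.q = q  # an mp.Queue shared between sender and one worker

        def send_weights(self, weights) -> None:
            # Ship the (already extracted) weights to the worker.
            self.q.put(weights)

        def receive_weights(self, timeout=None, *, weights=None, model=None, strategy=None):
            try:
                received = self.q.get(timeout=timeout)
            except stdlib_queue.Empty:
                return None
            if model is not None and strategy is not None:
                # Apply to the local model, mirroring what the schemes expect.
                strategy.apply_weights(model, received)
            return received

        def setup_connection_and_weights_on_sender(self) -> None:
            # No handshake needed: initial weights go through send_weights().
            pass

        def setup_connection_and_weights_on_receiver(self, *, worker_idx, weights=None, model=None, strategy=None):
            # Nothing to pull eagerly for this transport.
            return None
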
100
+ # ============================================================================
101
+ # Weight Strategies
102
+ # ============================================================================
103
+
104
+
105
+ class WeightStrategy:
106
+ """Unified strategy for weight transmission.
107
+
108
+ This strategy handles both extraction and application of weights, supporting
109
+ both TensorDict and state_dict formats.
110
+
111
+ Args:
112
+ extract_as (str): Format for extracting weights. Can be:
113
+ - "tensordict" (default): Extract weights as TensorDict
114
+ - "state_dict": Extract weights as PyTorch state_dict
115
+
116
+ The application format is automatically detected based on the type of weights
117
+ received (dict -> state_dict, TensorDict -> tensordict).
118
+ """
119
+
120
+ def __init__(self, extract_as: Literal["tensordict", "state_dict"] = "tensordict"):
121
+ if extract_as == "state_dict":
122
+ warnings.warn(
123
+ "state_dict strategy is experimental. Use tensordict strategy for safer weight updates.",
124
+ UserWarning,
125
+ )
126
+ if extract_as not in ("tensordict", "state_dict"):
127
+ raise ValueError(
128
+ f"extract_as must be 'tensordict' or 'state_dict', got {extract_as}"
129
+ )
130
+ self.extract_as = extract_as
131
+
132
+ def extract_weights(self, source: Any) -> TensorDictBase | dict | None:
133
+ """Extract weights from source model in the specified format.
134
+
135
+ Args:
136
+ source: The model to extract weights from. Can be:
137
+ - nn.Module: PyTorch module
138
+ - TensorDictBase: TensorDict
139
+ - dict: State dictionary
140
+
141
+ Returns:
142
+ Weights in the format specified by `extract_as` constructor argument.
143
+ """
144
+ if self.extract_as == "tensordict":
145
+ # Extract as TensorDict
146
+ if isinstance(source, nn.Module):
147
+ return TensorDict.from_module(source)
148
+ elif isinstance(source, TensorDictBase):
149
+ return source
150
+ elif isinstance(source, dict):
151
+ # Convert state_dict to TensorDict
152
+ return TensorDict(source, batch_size=[])
153
+ else:
154
+ torchrl_logger.warning(
155
+ f"Unsupported source type for TensorDict extraction: {type(source)}"
156
+ )
157
+ return TensorDict(lock=True)
158
+ elif self.extract_as == "state_dict": # state_dict
159
+ # Extract as state_dict
160
+ if isinstance(source, nn.Module):
161
+ return source.state_dict()
162
+ elif isinstance(source, dict):
163
+ return source
164
+ elif isinstance(source, TensorDictBase):
165
+ # Convert TensorDict to state_dict
166
+ return source.flatten_keys().to_dict()
167
+ else:
168
+ torchrl_logger.warning(
169
+ f"Unsupported source type for TensorDict extraction: {type(source)}"
170
+ )
171
+ return {}
172
+ else:
173
+ raise ValueError(
174
+ f"Unknown extract_as: {self.extract_as}. Must be 'tensordict' or 'state_dict'."
175
+ )
176
+
177
+ def apply_weights(
178
+ self, destination: Any, weights: Any, inplace: bool = True
179
+ ) -> None:
180
+ """Apply weights to destination model.
181
+
182
+ The format is automatically detected from the weights type:
183
+ - dict -> state_dict format
184
+ - TensorDictBase -> tensordict format
185
+
186
+ Args:
187
+ destination: The model to apply weights to. Can be:
188
+ - nn.Module: PyTorch module
189
+ - TensorDictBase: TensorDict
190
+ - dict: State dictionary
191
+ weights: The weights to apply (dict or TensorDictBase).
192
+ inplace: Whether to apply weights in place.
193
+ """
194
+ if weights is None:
195
+ return
196
+
197
+ # Auto-detect format from weights type
198
+ if isinstance(weights, dict):
199
+ weights = TensorDict(weights)
200
+ if any("." in key for key in weights.keys()):
201
+ weights = weights.unflatten_keys(".")
202
+ if isinstance(destination, nn.Module):
203
+ # Do not update in-place
204
+ if not inplace:
205
+ weights.to_module(destination)
206
+ return
207
+ else:
208
+ destination = TensorDict.from_module(destination)
209
+ elif isinstance(destination, dict):
210
+ if not inplace:
211
+ raise ValueError("Cannot update state_dict out of place")
212
+ destination = TensorDict(destination)
213
+ if any(isinstance(key, str) and "." in key for key in destination.keys()):
214
+ destination = destination.unflatten_keys(".")
215
+
216
+ if not isinstance(weights, TensorDictBase):
217
+ raise ValueError(
218
+ f"Unsupported weights type: {type(weights)}. Must be dict or TensorDictBase."
219
+ )
220
+ if not isinstance(destination, TensorDictBase):
221
+ if not weights.is_empty():
222
+ raise ValueError(
223
+ "Non-empty weights are associated with a non-dict, non-td, non-Module destination."
224
+ )
225
+ return
226
+
227
+ try:
228
+ if not inplace:
229
+ destination.update(weights)
230
+ else:
231
+ destination.data.update_(weights.data)
232
+ except Exception as e:
233
+ raise KeyError(
234
+ f"Error updating destination. Destination keys: {destination.keys(True, True)}, weights keys: {weights.keys(True, True)}"
235
+ ) from e
236
+ return
237
+
238
+
239
+ def _get_strategy(strategy: Literal["tensordict", "state_dict"]) -> WeightStrategy:
240
+ """Get strategy object from string name.
241
+
242
+ Args:
243
+ strategy: Either "tensordict" or "state_dict".
244
+
245
+ Returns:
246
+ WeightStrategy: Strategy configured with the specified extraction format.
247
+ """
248
+ if strategy not in ("tensordict", "state_dict"):
249
+ raise ValueError(
250
+ f"Unknown strategy: {strategy}. Must be 'tensordict' or 'state_dict'."
251
+ )
252
+ return WeightStrategy(extract_as=strategy)
253
+
254
+
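
A minimal usage sketch of WeightStrategy (illustrative only; the two Linear modules stand in for a trainer-side and a worker-side copy of the same policy):

    import torch
    from torch import nn

    src = nn.Linear(4, 2)   # up-to-date copy (e.g. on the trainer)
    dst = nn.Linear(4, 2)   # stale copy (e.g. inside a worker)

    strategy = WeightStrategy(extract_as="tensordict")
    weights = strategy.extract_weights(src)   # TensorDict mirroring src's parameters
    strategy.apply_weights(dst, weights)      # in-place copy into dst's parameters

    assert torch.equal(src.weight, dst.weight)
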
255
+ # ============================================================================
256
+ # Weight Synchronization Schemes
257
+ # ============================================================================
258
+
259
+
260
+ class WeightSyncScheme(metaclass=abc.ABCMeta):
261
+ """Configuration for how to synchronize ONE model across workers.
262
+
263
+ Each scheme instance is responsible for exactly one model.
264
+ The collector maintains a dict of {model_id: scheme} pairs.
265
+
266
+ This class directly handles both sender and receiver functionality,
267
+ with behavior determined by whether init_on_sender() or init_on_receiver()
268
+ was called.
269
+ """
270
+
271
+ _model_id: str | None = None
272
+
273
+ # Transport management
274
+ _sender_transports: dict[int, TransportBackend] | None
275
+ _receiver_transport: TransportBackend | None
276
+ _shared_transport: TransportBackend | None
277
+
278
+ # Context and model references
279
+ _context_ref: weakref.ReferenceType[Any] | None
280
+ _model_ref: weakref.ReferenceType[Any] | None
281
+
282
+ # Strategy
283
+ _strategy: WeightStrategy
284
+
285
+ # Worker index (for receiver side)
286
+ _worker_idx: int | None
287
+
288
+ # Background thread
289
+ _background_thread = None
290
+ _stop_event = None
291
+
292
+ def __init__(self, strategy: Literal["state_dict", "tensordict"] = "tensordict"):
293
+ self.strategy_str = strategy
294
+ self._strategy = _get_strategy(strategy)
295
+ self._initialized_on_sender = False
296
+ self._initialized_on_receiver = False
297
+
298
+ # Transport management
299
+ self._sender_transports = None # worker_idx -> transport
300
+ self._receiver_transport = None
301
+ self._shared_transport = None
302
+
303
+ # Context and model references
304
+ self._context_ref = None
305
+ self._model_ref = None
306
+
307
+ # Worker index
308
+ self._worker_idx = None
309
+
310
+ # ========================================================================
311
+ # Initialization
312
+ # ========================================================================
313
+
314
+ @property
315
+ def strategy(self) -> WeightStrategy:
316
+ return self._strategy
317
+
318
+ @strategy.setter
319
+ def strategy(self, value: WeightStrategy) -> None:
320
+ self._strategy = value
321
+
322
+ @overload
323
+ def init_on_sender(
324
+ self,
325
+ *,
326
+ model_id: str,
327
+ context: Any,
328
+ ) -> None:
329
+ ...
330
+
331
+ @overload
332
+ def init_on_sender(
333
+ self,
334
+ *,
335
+ params_map: dict[int, TensorDictBase],
336
+ model_id: str | None = None,
337
+ ) -> None:
338
+ ...
339
+
340
+ @overload
341
+ def init_on_sender(
342
+ self,
343
+ *,
344
+ params_map: dict[int, TensorDictBase],
345
+ ) -> None:
346
+ ...
347
+
348
+ @overload
349
+ def init_on_sender(
350
+ self,
351
+ *,
352
+ weights: TensorDictBase,
353
+ devices: list[torch.device],
354
+ ) -> None:
355
+ ...
356
+
357
+ @overload
358
+ def init_on_sender(
359
+ self,
360
+ *,
361
+ weights: TensorDictBase,
362
+ devices: list[torch.device],
363
+ model_id: str | None = None,
364
+ ) -> None:
365
+ ...
366
+
367
+ @overload
368
+ def init_on_sender(
369
+ self,
370
+ *,
371
+ model: nn.Module,
372
+ devices: list[torch.device],
373
+ ) -> None:
374
+ ...
375
+
376
+ @overload
377
+ def init_on_sender(
378
+ self,
379
+ *,
380
+ model: nn.Module,
381
+ devices: list[torch.device],
382
+ model_id: str | None = None,
383
+ ) -> None:
384
+ ...
385
+
386
+ @overload
387
+ def init_on_sender(
388
+ self,
389
+ *,
390
+ weights: TensorDictBase,
391
+ device_map_fn: Callable[[int, TensorDictBase], TensorDictBase],
392
+ num_workers: int,
393
+ ) -> None:
394
+ ...
395
+
396
+ @overload
397
+ def init_on_sender(
398
+ self,
399
+ *,
400
+ model: nn.Module,
401
+ device_map_fn: Callable[[int, TensorDictBase], TensorDictBase],
402
+ num_workers: int,
403
+ model_id: str | None = None,
404
+ ) -> None:
405
+ ...
406
+
407
+ @overload
408
+ def init_on_sender(self):
409
+ ...
410
+
411
+ def init_on_sender(
412
+ self,
413
+ *args,
414
+ **kwargs,
415
+ ) -> None:
416
+ """Initialize on the main process (sender side).
417
+
418
+ This method is called once in the collector's _run_processes() method,
419
+ after workers have been started and are ready to receive messages.
420
+ """
421
+ self._initialized_on_sender = True
422
+ try:
423
+ result = self._init_on_sender_impl(*args, **kwargs)
424
+ except Exception:
425
+ self._initialized_on_sender = False
426
+ raise
427
+ return result
428
+
429
+ def _init_on_sender_impl(self, *args, **kwargs):
430
+ raise NotImplementedError
431
+
432
+ @property
433
+ def initialized_on_sender(self):
434
+ return getattr(self, "_initialized_on_sender", False)
435
+
436
+ @property
437
+ def initialized_on_receiver(self):
438
+ return getattr(self, "_initialized_on_receiver", False)
439
+
440
+ @overload
441
+ def init_on_receiver(
442
+ self,
443
+ model_id: str,
444
+ context: Any,
445
+ **kwargs,
446
+ ) -> None:
447
+ ...
448
+
449
+ @overload
450
+ def init_on_receiver(
451
+ self,
452
+ model_id: str,
453
+ context: None = None,
454
+ *,
455
+ worker_idx: int = ...,
456
+ model: Any | None = None,
457
+ **kwargs,
458
+ ) -> None:
459
+ ...
460
+
461
+ def init_on_receiver(
462
+ self,
463
+ *,
464
+ model_id: str,
465
+ context: Any = None,
466
+ **kwargs,
467
+ ) -> None:
468
+ """Initialize on worker process (receiver side).
469
+
470
+ This method is called once in each worker's initialization.
471
+
472
+ Args:
473
+ model_id: Identifier for the model being synchronized
474
+ context: Optional context object (e.g., inner collector)
475
+ **kwargs: Alternative to context (model, etc.)
476
+ """
477
+ if self.initialized_on_sender:
478
+ # emulate pickling to erase the current state
479
+ self.__setstate__(self.__getstate__())
480
+
481
+ self._initialized_on_receiver = True
482
+ try:
483
+ result = self._init_on_receiver_impl(
484
+ model_id=model_id, context=context, **kwargs
485
+ )
486
+ except Exception:
487
+ self._initialized_on_receiver = False
488
+ raise
489
+ return result
490
+
491
+ def _init_on_receiver_impl(
492
+ self,
493
+ model_id: str,
494
+ context: Any = None,
495
+ **kwargs,
496
+ ) -> None:
497
+ raise NotImplementedError
498
+
499
+ # ========================================================================
500
+ # Context and Model Management
501
+ # ========================================================================
502
+
503
+ @property
504
+ def context(self) -> Any | None:
505
+ """Get the context object (e.g., collector), if available.
506
+
507
+ Returns:
508
+ The context object if available, None otherwise.
509
+ """
510
+ if self._context_ref is not None:
511
+ return self._context_ref()
512
+ return None
513
+
514
+ @context.setter
515
+ def context(self, context: Any) -> None:
516
+ """Set the context object for resolving references.
517
+
518
+ Args:
519
+ context: The context object to resolve references from.
520
+ """
521
+ if context is not None:
522
+ self._context_ref = weakref.ref(context)
523
+ else:
524
+ self._context_ref = None
525
+
526
+ @property
527
+ def model_id(self) -> str | None:
528
+ """Get the model ID for this scheme.
529
+
530
+ Returns:
531
+ The model ID if set, None otherwise.
532
+ """
533
+ return self._model_id
534
+
535
+ @model_id.setter
536
+ def model_id(self, model_id: str) -> None:
537
+ """Set the model ID for this scheme.
538
+
539
+ Args:
540
+ model_id: The model ID to set.
541
+ """
542
+ self._model_id = model_id
543
+
544
+ @property
545
+ def worker_idx(self) -> int | None:
546
+ """Get the worker index for this scheme.
547
+
548
+ Returns:
549
+ The worker index if set, None otherwise.
550
+ """
551
+ return self._worker_idx
552
+
553
+ @worker_idx.setter
554
+ def worker_idx(self, worker_idx: int | None) -> None:
555
+ """Set the worker index for this scheme.
556
+
557
+ Args:
558
+ worker_idx: The worker index to set.
559
+ """
560
+ if self.initialized_on_sender and worker_idx is not None:
561
+ raise RuntimeError(
562
+ "Worker index cannot be set after initialization on sender"
563
+ )
564
+ self._worker_idx = worker_idx
565
+
566
+ @property
567
+ def model(self) -> Any | None:
568
+ """Get the model object, if available.
569
+
570
+ Returns:
571
+ The model object if available, None otherwise.
572
+ """
573
+ if self._model_ref is not None:
574
+ return self._model_ref()
575
+ if self._model_id is not None:
576
+ model = _resolve_model(self.context, self._model_id)
577
+ if model is None:
578
+ raise AttributeError(
579
+ f"Model {self._model_id} was `None` in context {self.context}"
580
+ )
581
+ self._model_ref = weakref.ref(model)
582
+ return model
583
+
584
+ @model.setter
585
+ def model(self, model: Any) -> None:
586
+ """Set the model object for applying weights.
587
+
588
+ Args:
589
+ model: The model object to apply weights to.
590
+ """
591
+ if model is not None:
592
+ self._model_ref = weakref.ref(model)
593
+ else:
594
+ self._model_ref = None
595
+
596
+ @property
597
+ def weights(self) -> Any | None:
598
+ """Get the current weights, if available.
599
+
600
+ Returns:
601
+ The weights as TensorDict if available, None otherwise.
602
+ """
603
+ if (weights := getattr(self, "_weights", None)) is not None:
604
+ return weights
605
+ model = self.model
606
+ if model is not None:
607
+ return self._strategy.extract_weights(model)
608
+ return None
609
+
610
+ @weights.setter
611
+ def weights(self, value: Any):
612
+ self._weights = value
613
+
614
+ def _get_weights_buffer_from_model(self, model: nn.Module | Any) -> TensorDictBase:
615
+ from torchrl.collectors.utils import _cast
616
+
617
+ if isinstance(model, torch.nn.Module):
618
+ td = TensorDict.from_module(model)
619
+ td = td.data.apply(_cast, td)
620
+ return td
621
+ # Return an empty TD
622
+ return TensorDict()
623
+
624
+ # ========================================================================
625
+ # Transport Management
626
+ # ========================================================================
627
+
628
+ def _register_worker_sender(
629
+ self,
630
+ *,
631
+ worker_idx: int,
632
+ transport: TransportBackend | None = None,
633
+ **transport_kwargs,
634
+ ) -> None:
635
+ """Register a worker's communication.
636
+
637
+ Args:
638
+ worker_idx: The worker index.
639
+ transport: Optional pre-created transport.
640
+ **transport_kwargs: Transport-specific configuration.
641
+ """
642
+ if self._sender_transports is None:
643
+ if self._shared_transport is not None:
644
+ raise RuntimeError(
645
+ "Cannot register transports on sender after shared transport is set"
646
+ )
647
+ self._sender_transports = {}
648
+ if worker_idx not in self._sender_transports:
649
+ if transport is not None:
650
+ self._sender_transports[worker_idx] = transport
651
+ else:
652
+ self._sender_transports[worker_idx] = self.create_transport(
653
+ **transport_kwargs
654
+ )
655
+
656
+ def _register_transport_receiver(
657
+ self, transport: TransportBackend | None = None, **transport_kwargs
658
+ ) -> None:
659
+ """Register a single transport (for receiver side).
660
+
661
+ Args:
662
+ transport: Optional pre-created transport.
663
+ **transport_kwargs: Transport-specific configuration.
664
+ """
665
+ if transport is not None:
666
+ self._receiver_transport = transport
667
+ else:
668
+ self._receiver_transport = self.create_transport(**transport_kwargs)
669
+
670
+ def _iterate_transports(
671
+ self, worker_ids: int | list[int] | None = None
672
+ ) -> Iterator[TransportBackend]:
673
+ """Iterate over transports for specified workers."""
674
+ if worker_ids is None:
675
+ # All workers
676
+ if not self.sender_transports:
677
+ if self.receiver_transport is not None:
678
+ yield self.receiver_transport
679
+ else:
680
+ # Make sure transports are sorted
681
+ for k in sorted(self.sender_transports.keys()):
682
+ yield self.sender_transports[k]
683
+ else:
684
+ # Specific workers
685
+ if isinstance(worker_ids, int):
686
+ worker_ids = [worker_ids]
687
+ for worker_id in worker_ids:
688
+ if worker_id in self.sender_transports:
689
+ yield self.sender_transports[worker_id]
690
+ else:
691
+ raise ValueError(f"Worker {worker_id} not registered")
692
+
693
+ @abc.abstractmethod
694
+ def create_transport(self, **kwargs) -> TransportBackend:
695
+ """Create transport for communication.
696
+
697
+ Args:
698
+ **kwargs: Transport-specific configuration parameters.
699
+
700
+ Returns:
701
+ A transport backend instance.
702
+
703
+ Note:
704
+ This is used internally by init_on_sender/init_on_receiver.
705
+ """
706
+ ...
707
+
708
+ @property
709
+ def sender_transports(self) -> dict[int, TransportBackend]:
710
+ """Get the sender transports.
711
+
712
+ Returns:
713
+ The sender transports.
714
+ """
715
+ if self._shared_transport is not None:
716
+ return defaultdict(lambda: self._shared_transport)
717
+ return self._sender_transports
718
+
719
+ @property
720
+ def receiver_transport(self) -> TransportBackend | None:
721
+ """Get the receiver transport.
722
+
723
+ Returns:
724
+ The receiver transport.
725
+ """
726
+ if self._shared_transport is not None:
727
+ return self._shared_transport
728
+ return self._receiver_transport
729
+
730
+ @property
731
+ def shared_transport(self) -> TransportBackend | None:
732
+ """Get the shared transport.
733
+
734
+ Returns:
735
+ The shared transport.
736
+ """
737
+ if self._receiver_transport is not None:
738
+ raise RuntimeError(
739
+ "Receiver transport and shared transport cannot be used together"
740
+ )
741
+ if self._sender_transports is not None:
742
+ raise RuntimeError(
743
+ "Sender transports and shared transport cannot be used together"
744
+ )
745
+ return self._shared_transport
746
+
747
+ @shared_transport.setter
748
+ def shared_transport(self, shared_transport: TransportBackend | None) -> None:
749
+ """Set the shared transport.
750
+
751
+ Args:
752
+ shared_transport: The shared transport to set.
753
+ """
754
+ self._shared_transport = shared_transport
755
+
756
+ # ========================================================================
757
+ # Sending Weights (Sender Side)
758
+ # ========================================================================
759
+
760
+ def send(
761
+ self,
762
+ weights: Any = None,
763
+ worker_ids: int | list[int] | None = None,
764
+ ) -> None:
765
+ """Send weights synchronously to workers.
766
+
767
+ This method:
768
+ 1. Prepares weights (extracts from model if weights=None)
769
+ 2. Sends to specified workers (or all if worker_ids=None)
770
+ 3. Waits for acknowledgments from those workers
771
+ 4. Returns when workers have applied the weights
772
+
773
+ Args:
774
+ weights: Weights to send. Can be:
775
+ - None: Extract from model via context.get_model(model_id)
776
+ - nn.Module: Extract weights from module
777
+ - TensorDict: Use directly
778
+ - dict: Convert to TensorDict
779
+ worker_ids: Which workers to send to:
780
+ - None: Send to all workers (default)
781
+ - int: Send to single worker
782
+ - list[int]: Send to specific workers
783
+
784
+ Note: This is a blocking call that ensures specified workers are updated
785
+ before returning.
786
+ """
787
+ if not self.initialized_on_sender:
788
+ raise RuntimeError("Must be initialized on sender before sending weights")
789
+ if not self.synchronized_on_sender:
790
+ raise RuntimeError("Must be synchronized on sender before sending weights")
791
+
792
+ context = self.context
793
+
794
+ # Let the scheme prepare the weights
795
+ prepared_weights = self.prepare_weights(
796
+ weights=weights,
797
+ model_id=self._model_id,
798
+ strategy=self._strategy,
799
+ context=context,
800
+ )
801
+
802
+ transports = list(self._iterate_transports(worker_ids))
803
+
804
+ if not transports:
805
+ raise RuntimeError("No transports available.")
806
+
807
+ # Send to all workers first (non-blocking if transport supports it)
808
+ for transport in transports:
809
+ if hasattr(transport, "send_weights_async"):
810
+ transport.send_weights_async(prepared_weights)
811
+ else:
812
+ # Fallback for transports that don't support async send
813
+ transport.send_weights(prepared_weights)
814
+
815
+ # Wait for all acknowledgments
816
+ for transport in transports:
817
+ if hasattr(transport, "wait_ack"):
818
+ transport.wait_ack()
819
+
820
+ def prepare_weights(
821
+ self,
822
+ weights: Any,
823
+ model_id: str,
824
+ strategy: WeightStrategy,
825
+ context: Any = None,
826
+ ) -> Any:
827
+ """Prepare weights for sending.
828
+
829
+ This method handles weight extraction, conversion, and any scheme-specific
830
+ preparation (e.g., cache lookups for SharedMemWeightSyncScheme).
831
+
832
+ Args:
833
+ weights: Raw weights input (can be None, nn.Module, TensorDict, dict, str reference, etc.)
834
+ model_id: The model identifier (e.g., "policy")
835
+ strategy: WeightStrategy for extracting/converting weights
836
+ context: Optional context (e.g., collector) for model resolution
837
+
838
+ Returns:
839
+ Prepared weights ready to send via transport
840
+ """
841
+ # Default implementation: extract from model or pass through
842
+ if weights is None and context is not None:
843
+ # Try to resolve and extract from model in context
844
+ try:
845
+ model = _resolve_model(context, model_id)
846
+ return strategy.extract_weights(model)
847
+ except (AttributeError, KeyError):
848
+ pass
849
+ # Try fallback policy
850
+ if model_id == "policy" and hasattr(context, "_fallback_policy"):
851
+ if context._fallback_policy is not None:
852
+ return strategy.extract_weights(context._fallback_policy)
853
+ return None
854
+
855
+ if isinstance(weights, nn.Module):
856
+ return strategy.extract_weights(weights)
857
+ elif isinstance(weights, str):
858
+ # String reference to model
859
+ if context is not None:
860
+ model = _resolve_model(context, weights)
861
+ return strategy.extract_weights(model)
862
+ raise ValueError(
863
+ f"Cannot resolve string reference '{weights}' without context"
864
+ )
865
+ else:
866
+ # Already extracted weights (TensorDict, dict, etc.)
867
+ return weights
868
+
869
+ # ========================================================================
870
+ # Receiving Weights (Receiver Side)
871
+ # ========================================================================
872
+
873
+ def receive(self, timeout: float | None = None) -> TensorDictBase | None:
874
+ """Check for and apply new weights (non-blocking).
875
+
876
+ This method is called in the worker's main loop to check if
877
+ new weights have been sent. If weights are available, they
878
+ are applied to the registered model immediately, and the update
879
+ is cascaded to any sub-collectors via context.update_policy_weights_().
880
+
881
+ Args:
882
+ timeout: Maximum time to wait for weights (seconds).
883
+ None means no timeout (blocking). Some transports may not
884
+ support timeout and will raise ValueError if specified.
885
+
886
+ Returns:
887
+ The received weights if available, None otherwise.
888
+
889
+ Note: For SharedMemWeightSyncScheme, this always returns None
890
+ since workers automatically see updates via shared memory.
891
+ """
892
+ if not self.initialized_on_receiver:
893
+ raise RuntimeError(
894
+ "Must be initialized on receiver before receiving weights"
895
+ )
896
+ if not self.synchronized_on_receiver:
897
+ raise RuntimeError(
898
+ "Must be synchronized on receiver before receiving weights"
899
+ )
900
+
901
+ # Determine which transport to use
902
+ if self._receiver_transport is not None:
903
+ transport = self._receiver_transport
904
+ elif self._shared_transport is not None:
905
+ # Use shared transport directly (e.g., SharedMemWeightSyncScheme)
906
+ transport = self._shared_transport
907
+ else:
908
+ return None
909
+
910
+ # Try to receive weights - transport handles receiving and applying
911
+ result = transport.receive_weights(
912
+ timeout=timeout,
913
+ weights=self.weights,
914
+ model=self.model,
915
+ strategy=self._strategy,
916
+ )
917
+ if result is None:
918
+ return None
919
+
920
+ weights = result
921
+ model_id = self._model_id or "policy"
922
+
923
+ # Cascade weight update to sub-collectors if context supports it
924
+ if self.context is not None and hasattr(self.context, "update_policy_weights_"):
925
+ self.context.update_policy_weights_(
926
+ model_id=model_id, policy_or_weights=weights
927
+ )
928
+
929
+ # Send acknowledgment if transport supports it
930
+ if hasattr(transport, "send_ack"):
931
+ transport.send_ack("updated")
932
+
933
+ return weights
934
+
935
+ def apply_weights(self, weights: TensorDictBase, inplace: bool = True) -> None:
936
+ """Apply weights to the model.
937
+
938
+ Args:
939
+ weights: The weights to apply.
940
+ inplace: Whether to apply weights in place. Default is `True`.
941
+ """
942
+ if not self.initialized_on_receiver:
943
+ if self.initialized_on_sender:
944
+ raise RuntimeError("apply_weights() called on a sender side.")
945
+ raise RuntimeError(
946
+ "apply_weights() called before init_on_receiver has been called."
947
+ )
948
+
949
+ if self._model_ref is None:
950
+ raise ValueError("No model registered")
951
+
952
+ model = self.model
953
+ self._strategy.apply_weights(model, weights, inplace=inplace)
954
+
955
+ # Send acknowledgment if transport supports it
956
+ if self.receiver_transport is not None and hasattr(
957
+ self.receiver_transport, "send_ack"
958
+ ):
959
+ self.receiver_transport.send_ack("updated")
960
+
961
+ # ========================================================================
962
+ # Synchronization
963
+ # ========================================================================
964
+
965
+ @overload
966
+ def connect(self, *, worker_idx: int | None = None) -> None:
967
+ ...
968
+
969
+ @overload
970
+ def connect(self, *, weights: Any | None = None) -> None:
971
+ ...
972
+
973
+ def connect(
974
+ self, *, worker_idx: int | None = None, weights: Any | None = None
975
+ ) -> None:
976
+ """Method to be called once the workers have started.
977
+
978
+ Triggers a rendezvous for the workers to receive their copy of the weights.
979
+
980
+ Dispatches to _setup_connection_and_weights_on_sender_impl() or _setup_connection_and_weights_on_receiver_impl()
981
+ based on which initialization was performed.
982
+ """
983
+ if self.synchronized_on_receiver or self.synchronized_on_sender:
984
+ raise RuntimeError("Cannot synchronize weights on sender twice.")
985
+ if self._initialized_on_sender:
986
+ if worker_idx is not None:
987
+ # Safety check, we can consider removing this in the future.
988
+ raise RuntimeError(
989
+ "Cannot specify worker_idx on sender side during synchronization."
990
+ )
991
+ self.synchronized_on_sender = True
992
+ try:
993
+ self._setup_connection_and_weights_on_sender_impl(weights=weights)
994
+ except Exception:
995
+ self.synchronized_on_sender = False
996
+ raise
997
+ elif self._initialized_on_receiver:
998
+ if weights is not None:
999
+ # safety check: weights are passed to sender, not receiver for initial sync
1000
+ raise RuntimeError(
1001
+ "Cannot specify weights on receiver side during synchronization."
1002
+ )
1003
+ self.synchronized_on_receiver = True
1004
+ try:
1005
+ self._setup_connection_and_weights_on_receiver_impl(
1006
+ worker_idx=worker_idx
1007
+ )
1008
+ except Exception:
1009
+ self.synchronized_on_receiver = False
1010
+ raise
1011
+ else:
1012
+ raise RuntimeError(
1013
+ "Neither init_on_sender nor init_on_receiver have been called."
1014
+ )
1015
+
1016
+ def _setup_connection_and_weights_on_sender_impl(
1017
+ self,
1018
+ *,
1019
+ worker_idx: int | None = None,
1020
+ weights: Any | None = None,
1021
+ ) -> None:
1022
+ """Synchronize weights on sender side.
1023
+
1024
+ Default implementation uses transport's setup_connection_and_weights_on_sender().
1025
+ Subclasses may override for custom behavior.
1026
+ """
1027
+ if self._shared_transport is not None:
1028
+ # We only need to synchronize once
1029
+ self.shared_transport.setup_connection_and_weights_on_sender()
1030
+ return
1031
+
1032
+ idx = -1
1033
+ for idx, transport in enumerate(self._iterate_transports()):
1034
+ if worker_idx is not None and idx != worker_idx:
1035
+ continue
1036
+ transport.setup_connection_and_weights_on_sender()
1037
+ if idx == -1:
1038
+ raise RuntimeError("No transports available.")
1039
+
1040
+ def _setup_connection_and_weights_on_receiver_impl(
1041
+ self, *, worker_idx: int | None = None
1042
+ ) -> None:
1043
+ """Synchronize weights on receiver side.
1044
+
1045
+ Default implementation uses transport's setup_connection_and_weights_on_receiver().
1046
+ Subclasses may override for custom behavior.
1047
+ """
1048
+ if self.receiver_transport is None:
1049
+ return
1050
+
1051
+ # Use stored worker_idx if not provided
1052
+ if worker_idx is None:
1053
+ worker_idx = self._worker_idx
1054
+
1055
+ # Call transport's synchronize method with all relevant kwargs
1056
+ weights = self.receiver_transport.setup_connection_and_weights_on_receiver(
1057
+ worker_idx=worker_idx,
1058
+ weights=self.weights,
1059
+ model=self.model,
1060
+ strategy=self._strategy,
1061
+ )
1062
+
1063
+ # Apply weights to model if received (SharedMemTransport case)
1064
+ # For other transports (MPTransport, etc.), weights is None and synchronization
1065
+ # happens later via receive(), so this is a no-op
1066
+ if weights is not None:
1067
+ model = self.model
1068
+ self._strategy.apply_weights(model, weights, inplace=False)
1069
+
1070
+ @property
1071
+ def synchronized_on_sender(self):
1072
+ return getattr(self, "_synchronized_on_sender", False)
1073
+
1074
+ @synchronized_on_sender.setter
1075
+ def synchronized_on_sender(self, value: bool):
1076
+ self._synchronized_on_sender = value
1077
+
1078
+ @property
1079
+ def synchronized_on_receiver(self):
1080
+ return getattr(self, "_synchronized_on_receiver", False)
1081
+
1082
+ @synchronized_on_receiver.setter
1083
+ def synchronized_on_receiver(self, value: bool):
1084
+ self._synchronized_on_receiver = value
1085
+
1086
+ # ========================================================================
1087
+ # Background Receiver
1088
+ # ========================================================================
1089
+
1090
+ def _start_background_receiver(self):
1091
+ """Start daemon thread that monitors for weight update instructions.
1092
+
1093
+ The background thread runs _background_receive_loop() which waits for
1094
+ instructions via _wait_for_instruction() and calls receive() when
1095
+ an instruction arrives.
1096
+ """
1097
+ if not self.initialized_on_receiver:
1098
+ raise RuntimeError(
1099
+ "_start_background_receiver must be called on the receiver side."
1100
+ )
1101
+ self._stop_event = threading.Event()
1102
+ self._background_thread = threading.Thread(
1103
+ target=self._background_receive_loop,
1104
+ daemon=True,
1105
+ name=f"WeightReceiver-{self._worker_idx}",
1106
+ )
1107
+ self._background_thread.start()
1108
+
1109
+ def _background_receive_loop(self):
1110
+ """Background thread loop that waits for instructions and receives weights.
1111
+
1112
+ Default implementation uses _wait_for_instruction() and receive().
1113
+ Subclasses may override for custom behavior.
1114
+ """
1115
+ while not self._stop_event.is_set():
1116
+ try:
1117
+ instruction = self._wait_for_instruction()
1118
+ if instruction is None:
1119
+ # Stop signal received
1120
+ break
1121
+ if instruction == "receive":
1122
+ self.receive()
1123
+ elif instruction == "stop":
1124
+ break
1125
+ else:
1126
+ torchrl_logger.warning(f"Unknown instruction: {instruction}")
1127
+ except Exception as e:
1128
+ if not self._stop_event.is_set():
1129
+ torchrl_logger.warning(f"Background receiver error: {e}")
1130
+
1131
+ def _wait_for_instruction(self, timeout: float | None = None) -> str | None:
1132
+ """Block until an instruction arrives from the sender.
1133
+
1134
+ This method should be overridden by subclasses to implement
1135
+ scheme-specific instruction waiting (e.g., queue.get(), store polling).
1136
+
1137
+ Args:
1138
+ timeout: Maximum time to wait for instruction (seconds).
1139
+ None means block indefinitely.
1140
+
1141
+ Returns:
1142
+ The instruction string (e.g., "receive", "stop"), or None if
1143
+ stop event is set or timeout expires.
1144
+ """
1145
+ raise NotImplementedError(
1146
+ f"{type(self).__name__} must implement _wait_for_instruction()"
1147
+ )
1148
+
1149
+ def _send_instruction(
1150
+ self,
1151
+ instruction: str = "receive",
1152
+ worker_ids: int | list[int] | None = None,
1153
+ ) -> None:
1154
+ """Send instruction to receiver(s) to trigger weight reception.
1155
+
1156
+ This method should be overridden by subclasses to implement
1157
+ scheme-specific instruction sending (e.g., queue.put(), store.set()).
1158
+
1159
+ Args:
1160
+ instruction: The instruction to send (default: "receive").
1161
+ worker_ids: Which workers to send to (None = all workers).
1162
+ """
1163
+ raise NotImplementedError(
1164
+ f"{type(self).__name__} must implement _send_instruction()"
1165
+ )
1166
+
1167
+ def _send_ack(self, message: str = "updated") -> None:
1168
+ """Send acknowledgment back to sender after receiving weights.
1169
+
1170
+ Called by the background receiver after successfully applying weights.
1171
+ Subclasses should override to implement scheme-specific acknowledgment.
1172
+
1173
+ Args:
1174
+ message: The acknowledgment message (default: "updated").
1175
+ """
1176
+ # Default: use transport's send_ack if available
1177
+ transport = self._receiver_transport or self._shared_transport
1178
+ if transport is not None and hasattr(transport, "send_ack"):
1179
+ transport.send_ack(message)
1180
+
1181
+ def _wait_for_ack( # noqa: B027
1182
+ self,
1183
+ worker_ids: int | list[int] | None = None,
1184
+ timeout: float | None = None,
1185
+ ) -> None:
1186
+ """Wait for acknowledgment from receiver(s).
1187
+
1188
+ Called by send() in synchronous mode to block until receivers confirm.
1189
+ Subclasses should override to implement scheme-specific waiting.
1190
+
1191
+ Args:
1192
+ worker_ids: Which workers to wait for (None = all workers).
1193
+ timeout: Maximum time to wait (seconds). None means block indefinitely.
1194
+ """
1195
+ # Default: no-op (subclasses implement scheme-specific waiting)
1196
+
1197
+ def __getstate__(self):
1198
+ """Prepare the scheme for pickling by excluding non-serializable runtime state."""
1199
+ state = self.__dict__.copy()
1200
+ # Remove non-serializable runtime state
1201
+ state["_context_ref"] = None
1202
+ state["_model_ref"] = None
1203
+
1204
+ state["_initialized_on_sender"] = False
1205
+ state["_initialized_on_receiver"] = False
1206
+
1207
+ state["_synchronized_on_sender"] = False
1208
+ state["_synchronized_on_receiver"] = False
1209
+
1210
+ state["_background_thread"] = None
1211
+ state["_stop_event"] = None
1212
+
1213
+ return state
1214
+
1215
+ def __setstate__(self, state):
1216
+ """Restore the scheme from pickling."""
1217
+ self.__dict__.update(state)
1218
+
1219
+ def __del__(self):
1220
+ """Clean up resources when the scheme is garbage collected."""
1221
+ try:
1222
+ self.shutdown()
1223
+ except Exception:
1224
+ # Silently ignore any errors during garbage collection cleanup
1225
+ pass
1226
+
1227
+ def shutdown(self) -> None:
1228
+ """Shutdown the scheme and release resources.
1229
+
1230
+ This method stops any background threads and cleans up connections.
1231
+ It is safe to call multiple times. Subclasses should override this
1232
+ method to add custom cleanup logic, but should call super().shutdown()
1233
+ to ensure base cleanup is performed.
1234
+ """
1235
+ # Stop background receiver thread if running
1236
+ if getattr(self, "_stop_event", None) is not None:
1237
+ self._stop_event.set()
1238
+ if getattr(self, "_background_thread", None) is not None:
1239
+ try:
1240
+ self._background_thread.join(timeout=5.0)
1241
+ except Exception:
1242
+ pass
1243
+ self._background_thread = None
1244
+ self._stop_event = None
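
Putting the pieces together, the intended lifecycle of a concrete scheme looks roughly as follows; `MyScheme`, `collector`, `inner_collector` and `worker_idx` are placeholders for this sketch, not names shipped with the package:

    # main process (sender side)
    scheme = MyScheme(strategy="tensordict")
    scheme.init_on_sender(model_id="policy", context=collector)  # after workers start
    scheme.connect()                  # rendezvous: workers pick up the initial weights
    scheme.send()                     # later: push fresh weights to every worker
    scheme.send(worker_ids=[0, 2])    # or only to selected workers

    # each worker process (receiver side)
    scheme.init_on_receiver(model_id="policy", context=inner_collector)
    scheme.connect(worker_idx=worker_idx)
    scheme.receive(timeout=1.0)       # poll for new weights inside the collection loop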