torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/weight_update/_mp.py (new file)
@@ -0,0 +1,624 @@
+ from __future__ import annotations
+
+ from collections.abc import Callable
+ from typing import Any
+
+ import torch
+ from tensordict import TensorDictBase
+ from torch import multiprocessing as mp, nn
+ from torchrl.weight_update._shared import SharedMemWeightSyncScheme
+ from torchrl.weight_update.utils import _resolve_model
+
+ from torchrl.weight_update.weight_sync_schemes import TransportBackend
+
+
+ class MultiProcessWeightSyncScheme(SharedMemWeightSyncScheme):
+     """Weight synchronization for multiprocess operations using queues.
+
+     This scheme creates transports that communicate via multiprocessing queues.
+     Unlike the parent SharedMemWeightSyncScheme which uses shared memory for in-place
+     updates, this scheme sends actual weight copies through queues to workers.
+
+     A background thread on the receiver side listens for "receive" instructions
+     from the sender. When an instruction arrives, the thread receives the weights
+     from the weight queue and applies them to the model.
+
+     It follows the same two-phase pattern as SharedMemWeightSyncScheme:
+
+     1. **init_on_sender()**: Stores the recipe for creating device-specific weights
+        (model reference, devices, mapping functions) without creating actual copies
+     2. **synchronize_weights()**: Creates device-specific weight copies on-demand,
+        sends them sequentially to workers via queues, allowing garbage collection
+        between workers to minimize memory usage
+
+     This approach avoids holding multiple weight copies in memory simultaneously,
+     which is especially beneficial for large models with many workers.
+
+     Synchronization flow:
+     - **init_on_sender()**: Store configuration and register worker queues
+     - **synchronize_weights()**: Create and send initial weights on-demand
+     - **init_on_receiver()**: Create receiver that reads from queue
+     - **send()**: Extract and send weight updates, wait for acknowledgments
+
+     Args:
+         strategy: The weight transmission strategy (default: "tensordict").
+             Can be "tensordict" or "state_dict".
+         sync: If True (default), send() blocks until receiver acknowledges.
+             If False, send() returns immediately (use send_async/wait_async).
+
+     Example:
+         >>> # Basic usage with collector
+         >>> scheme = MultiProcessWeightSyncScheme()
+         >>> collector = MultiSyncCollector(
+         ...     create_env_fn=[lambda: GymEnv("CartPole-v1")] * 3,
+         ...     policy=policy,
+         ...     frames_per_batch=100,
+         ...     total_frames=1000,
+         ...     weight_sync_schemes={"policy": scheme},
+         ... )
+         >>> # scheme.collect() is called automatically by collector
+         >>> # Weights are created on-demand and sent to workers efficiently
+
+     Note:
+         The on-demand weight creation means that synchronize_weights() will be
+         slower than if weights were pre-computed, but memory usage is significantly
+         reduced, especially when workers use different devices or when the model
+         is large.
+     """
+
+     def __init__(self, strategy: str = "tensordict", sync: bool = True):
+         """Initialize the MultiProcessWeightSyncScheme.
+
+         Args:
+             strategy: The weight transmission strategy (default: "tensordict").
+             sync: If True (default), send() blocks until receiver acknowledges.
+         """
+         super().__init__(strategy, sync=sync)
+         # Override parent's shared transport - we don't use shared memory
+         self._shared_transport = None
+
+     def _init_on_sender_impl(
+         self,
+         *,
+         model_id: str | None = None,
+         context: Any = None,
+         weights: TensorDictBase | None = None,
+         model: nn.Module | None = None,
+         params_map: dict[int, TensorDictBase] | None = None,
+         devices: list[torch.device] | None = None,
+         device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None,
+         num_workers: int | None = None,
+         ctx: Any = None,
+         **kwargs,
+     ) -> None:
+         """Initialize on the main process (sender side).
+
+         This method stores the configuration needed to create device-specific weight
+         copies during synchronization. Weight copies are created on-demand during
+         `synchronize_weights()` to reduce memory usage.
+
+         Similar to `SharedMemWeightSyncScheme`, this follows a two-phase pattern:
+         1. `init_on_sender()`: Store the recipe for creating weights
+         2. `synchronize_weights()`: Create and send weights on-demand
+
+         Args:
+             model_id: Identifier for the model being synchronized (e.g., "policy").
+                 Required when using context.
+             context: Optional context object (e.g., collector) providing:
+                 - num_workers: Number of worker processes
+                 - policy_device: List of devices for each worker
+                 When provided, model_id is used to resolve the model from context.
+             weights: Pre-extracted weights as TensorDict. Mutually exclusive with
+                 model and context. Used when weights are already available.
+             model: Model to extract weights from. Mutually exclusive with weights
+                 and context.
+             params_map: Pre-computed mapping of worker_idx to device-specific weights.
+                 Most explicit option. When provided, all other parameters must be None.
+             devices: List of devices for each worker. Used with weights or model to
+                 automatically create device-specific copies. Length must equal num_workers.
+             device_map_fn: Custom function (worker_idx, weights) -> device_weights.
+                 Allows full control over device mapping. Requires num_workers.
+             num_workers: Number of workers. Required with device_map_fn, inferred
+                 from devices length otherwise.
+             ctx: The multiprocessing context to use. Defaults to `multiprocessing.get_context()`.
+             **kwargs: Reserved for future use.
+
+         Examples:
+             Simple usage with collector context (most common):
+
+             >>> scheme = MultiProcessWeightSyncScheme()
+             >>> collector = MultiSyncCollector(
+             ...     create_env_fn=[lambda: GymEnv("CartPole-v1")] * 3,
+             ...     policy=policy,
+             ...     frames_per_batch=100,
+             ...     weight_sync_schemes={"policy": scheme},
+             ... )
+             >>> # scheme.init_on_sender() is called automatically by collector
+
+             Direct initialization with explicit devices:
+
+             >>> scheme = MultiProcessWeightSyncScheme()
+             >>> weights = TensorDict.from_module(policy)
+             >>> scheme.init_on_sender(
+             ...     weights=weights,
+             ...     devices=[torch.device("cpu"), torch.device("cuda:0")],
+             ...     num_workers=2,
+             ... )
+
+             Advanced: Pre-computed params_map:
+
+             >>> weights_cpu = TensorDict.from_module(policy)
+             >>> weights_cuda = weights_cpu.to("cuda")
+             >>> scheme.init_on_sender(
+             ...     params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda},
+             ...     num_workers=3,
+             ... )
+         """
+         # Get params_map from parent class logic
+         params_map_result = self._get_params_map(
+             context=context,
+             model_id=model_id,
+             weights=weights,
+             model=model,
+             params_map=params_map,
+             devices=devices,
+             device_map_fn=device_map_fn,
+             num_workers=num_workers,
+         )
+
+         # Store the mapping recipe for later use in synchronize_weights
+         # Don't store params_map directly to save memory - we'll recompute on demand
+         # Note: We don't store context directly to avoid pickle issues -
+         # it's available via _context_ref
+         self._device_mapping_info = {
+             "model_id": model_id,
+             "weights": weights,
+             "model": model,
+             "params_map": params_map,
+             "devices": devices,
+             "device_map_fn": device_map_fn,
+             "num_workers": num_workers
+             if num_workers is not None
+             else len(params_map_result),
+         }
+
+         # Create per-worker queues for weight distribution
+         # Each worker gets its own queue for receiving weights
+         all_workers = list(params_map_result.keys())
+         if not hasattr(self, "_weight_init_queues"):
+             self._weight_init_queues = {}
+
+         if ctx is None:
+             ctx = mp.get_context()
+         for worker_idx in all_workers:
+             if worker_idx not in self._weight_init_queues:
+                 self._weight_init_queues[worker_idx] = ctx.Queue()
+             # Create instruction queues for background receiver
+             if worker_idx not in self._instruction_queues:
+                 self._instruction_queues[worker_idx] = ctx.Queue()
+             # Create ack queues for synchronous mode
+             if worker_idx not in self._ack_queues:
+                 self._ack_queues[worker_idx] = ctx.Queue()
+
+         # Store model_id and context on scheme
+         self.model_id = model_id
+         if context is not None:
+             self.context = context
+
+         # Register workers with their queues
+         for worker_idx in all_workers:
+             queue = self._weight_init_queues[worker_idx]
+             ack_queue = self._ack_queues[worker_idx]
+             # Create MPTransport for this worker with ack queue
+             transport = MPTransport(weight_queue=queue, ack_queue=ack_queue)
+             self._register_worker_sender(worker_idx=worker_idx, transport=transport)
+
+     def _init_on_receiver_impl(
+         self,
+         *,
+         model_id: str,
+         context: Any = None,
+         **kwargs,
+     ) -> None:
+         """Initialize on worker process (receiver side).
+
+         Args:
+             model_id: Identifier for the model being synchronized
+             context: Optional context object providing worker_idx and model
+             **kwargs: Alternative to context (worker_idx, model, etc.)
+         """
+         # Extract parameters from context or kwargs
+         if context is not None:
+             worker_idx = getattr(context, "worker_idx", None)
+             if hasattr(context, "get_model"):
+                 model = context.get_model(model_id)
+             else:
+                 model = _resolve_model(context, model_id)
+         else:
+             worker_idx = kwargs.get("worker_idx")
+             model = kwargs.get("model")
+
+         if worker_idx is None:
+             raise ValueError("worker_idx must be provided via context or kwargs")
+
+         # Get the queue for this worker
+         if worker_idx not in self._weight_init_queues:
+             raise ValueError(
+                 f"Worker {worker_idx} not registered. init_on_sender() must be called first."
+             )
+
+         queue = self._weight_init_queues[worker_idx]
+         ack_queue = self._ack_queues.get(worker_idx)
+
+         # Store on scheme directly
+         self.model_id = model_id
+         if context is not None:
+             self.context = context
+
+         # Store instruction and ack queue references for this worker
+         if worker_idx in self._instruction_queues:
+             self._receiver_instruction_queue = self._instruction_queues[worker_idx]
+         if worker_idx in self._ack_queues:
+             self._receiver_ack_queue = self._ack_queues[worker_idx]
+
+         # Create transport with the worker's queue and ack queue
+         transport = MPTransport(weight_queue=queue, ack_queue=ack_queue)
+         self._register_transport_receiver(transport=transport)
+
+         if model is not None:
+             self.model = model
+
+         # Store worker_idx for synchronize_weights
+         self.worker_idx = worker_idx
+
+     def send(
+         self,
+         weights: Any = None,
+         worker_ids: int | list[int] | None = None,
+     ) -> None:
+         """Send weights synchronously to workers.
+
+         This method:
+         1. Prepares weights (extracts from model if weights=None)
+         2. Sends weights to the weight queue
+         3. Sends "receive" instruction to workers' background threads
+         4. If sync=True, waits for acknowledgments from those workers
+
+         Args:
+             weights: Weights to send. Can be:
+                 - None: Extract from model via context.get_model(model_id)
+                 - nn.Module: Extract weights from module
+                 - TensorDict: Use directly
+                 - dict: Convert to TensorDict
+             worker_ids: Which workers to send to:
+                 - None: Send to all workers (default)
+                 - int: Send to single worker
+                 - list[int]: Send to specific workers
+
+         Note: If sync=True (default), this is a blocking call that ensures
+         specified workers are updated before returning.
+         """
+         if not self.initialized_on_sender:
+             raise RuntimeError("Must be initialized on sender before sending weights")
+         if not self.synchronized_on_sender:
+             raise RuntimeError("Must be synchronized on sender before sending weights")
+
+         model_id = self.model_id
+         context = self.context
+
+         # Let the scheme prepare the weights
+         prepared_weights = self.prepare_weights(
+             weights=weights,
+             model_id=model_id,
+             strategy=self._strategy,
+             context=context,
+         )
+
+         transports = list(self._iterate_transports(worker_ids))
+
+         # Send weights to all workers first via queue (non-blocking)
+         for transport in transports:
+             if hasattr(transport, "send_weights_async"):
+                 # For MPTransport, pass model_id; other transports don't need it
+                 transport.send_weights_async(prepared_weights, model_id=model_id)
+             else:
+                 # Fallback for transports that don't support async send
+                 transport.send_weights(prepared_weights)
+
+         # Send instruction to workers' background threads to receive the weights
+         self._send_instruction(instruction="receive", worker_ids=worker_ids)
+
+         # Wait for all acknowledgments if in synchronous mode
+         if self.sync:
+             self._wait_for_ack(worker_ids=worker_ids)
+
+     def _setup_connection_and_weights_on_sender_impl(
+         self,
+         *,
+         worker_idx: int | None = None,
+         weights: Any | None = None,
+     ) -> None:
+         """Synchronize weights with workers before collection starts.
+
+         Computes device-specific weight copies on-demand and sends them to workers
+         sequentially via queues. This is called once after workers are initialized
+         but before they start collecting data.
+
+         Unlike send(), this does not wait for acknowledgments since workers are still
+         in their initialization phase.
+
+         This approach creates weight copies on-demand and sends them sequentially,
+         allowing garbage collection between workers to reduce memory usage.
+
+         Raises:
+             RuntimeError: If init_on_sender() was not called first.
+         """
+         # Get the device mapping info stored during init_on_sender
+         if not hasattr(self, "_device_mapping_info"):
+             raise RuntimeError(
+                 "synchronize_weights() requires init_on_sender() to be called first"
+             )
+
+         mapping_info = self._device_mapping_info
+
+         # Get context from weakref
+         context = self.context
+
+         # Compute params_map on-demand
+         # Extract with explicit type casting for type checker
+         model_id = mapping_info["model_id"]
+         weights = mapping_info["weights"]
+         model = mapping_info["model"]
+         params_map_arg = mapping_info["params_map"]
+         devices = mapping_info["devices"]
+         device_map_fn = mapping_info["device_map_fn"]
+         num_workers = mapping_info["num_workers"]
+
+         params_map = self._get_params_map(
+             context=context,
+             model_id=model_id,
+             weights=weights,
+             model=model,
+             params_map=params_map_arg,
+             devices=devices,
+             device_map_fn=device_map_fn,
+             num_workers=num_workers,
+         )
+
+         # Send to workers sequentially via queues (no ACK - workers are still initializing)
+         # This allows GC to clean up each worker's weights before creating the next
+         for i, transport in enumerate(self._iterate_transports()):
+             if worker_idx is not None and i != worker_idx:
+                 continue
+             worker_weights = params_map[i]
+             if hasattr(transport, "send_weights_async"):
+                 transport.send_weights_async(worker_weights, model_id=self._model_id)
+             else:
+                 raise RuntimeError(
+                     f"Transport {type(transport)} does not support async send for synchronization"
+                 )
+
+         # Clean up the mapping info after synchronization
+         delattr(self, "_device_mapping_info")
+
+     def _setup_connection_and_weights_on_receiver_impl(
+         self, *, worker_idx: int | None = None
+     ) -> None:
+         """Receive initial weights and start background receiver thread.
+
+         This method:
+         1. Receives initial weights from the sender via queue
+         2. Applies them to the model
+         3. Starts a background thread that listens for "receive" instructions
+
+         Args:
+             worker_idx: The worker index.
+         """
+         # Use stored worker_idx if not provided
+         if worker_idx is None:
+             worker_idx = self._worker_idx
+
+         if worker_idx is None:
+             raise RuntimeError(
+                 "worker_idx must be provided for _setup_connection_and_weights_on_receiver_impl."
+             )
+
+         # Receive initial weights from queue via transport
+         if self._receiver_transport is None:
+             raise RuntimeError("Receiver transport not set.")
+
+         weights = self._receiver_transport.setup_connection_and_weights_on_receiver(
+             worker_idx=worker_idx,
+             weights=self.weights,
+             model=self.model,
+             strategy=self._strategy,
+         )
+
+         # Store received weights for later use
+         if weights is not None:
+             self._receiver_weights = weights
+
+         # Apply weights to model
+         if weights is not None and self.model is not None:
+             self._strategy.apply_weights(self.model, weights, inplace=False)
+
+         # Start background receiver thread
+         self._start_background_receiver()
+
+     def _background_receive_loop(self):
+         """Background thread loop that waits for instructions and receives weights.
+
+         This loop:
+         1. Waits for a "receive" instruction from the sender
+         2. Receives weights from the weight queue
+         3. Applies them to the model
+         4. Sends an acknowledgment back to the sender
+         5. Repeats until stop event is set or "stop" instruction received
+         """
+         from torchrl._utils import logger as torchrl_logger
+
+         while not self._stop_event.is_set():
+             try:
+                 instruction = self._wait_for_instruction()
+                 if instruction is None:
+                     # Stop event was set or timeout
+                     continue
+                 if instruction == "receive":
+                     # Receive weights from transport (blocking)
+                     if self._receiver_transport is not None:
+                         weights = self._receiver_transport.receive_weights(
+                             model=self.model,
+                             strategy=self._strategy,
+                         )
+
+                         if weights is not None:
+                             # Cascade weight update to sub-collectors if context supports it
+                             model_id = self._model_id or "policy"
+                             if self.context is not None and hasattr(
+                                 self.context, "update_policy_weights_"
+                             ):
+                                 self.context.update_policy_weights_(
+                                     model_id=model_id, policy_or_weights=weights
+                                 )
+
+                     # Send acknowledgment
+                     self._send_ack("updated")
+
+                 elif instruction == "stop":
+                     break
+                 else:
+                     torchrl_logger.warning(
+                         f"MultiProcessWeightSyncScheme: Unknown instruction: {instruction}"
+                     )
+             except Exception as e:
+                 if not self._stop_event.is_set():
+                     torchrl_logger.warning(
+                         f"MultiProcessWeightSyncScheme: Background receiver error: {e}"
+                     )
+
+     def create_transport(self, **kwargs) -> TransportBackend:
+         """Create an MPTransport using the provided queue.
+
+         Note:
+             This is used internally by init_on_sender/init_on_receiver.
+         """
+         queue = kwargs.get("queue")
+         return MPTransport(weight_queue=queue, ack_queue=None)
+
+
+ class MPTransport:
+     """Multiprocessing transport using queues.
+
+     This transport uses queues for weight distribution and synchronization.
+     Similar to SharedMemTransport's queue-based approach, MPTransport uses
+     queues to send initial weights to workers during synchronization.
+
+     Initialization flow:
+     - synchronize_weights() extracts weights and sends to all workers via queues
+     - Workers receive the initial weights via setup_connection_and_weights_on_receiver()
+     - Subsequent updates use send_weights_async() followed by acknowledgments
+
+     Args:
+         weight_queue (mp.Queue): The queue to use for sending weights.
+         ack_queue (mp.Queue): The queue to use for receiving acknowledgments.
+         timeout (float): The timeout for waiting for acknowledgment. Default is 10 seconds.
+     """
+
+     def __init__(self, weight_queue, ack_queue=None, timeout: float = 10.0):
+         self.timeout = timeout
+         self.weight_queue = weight_queue
+         self.ack_queue = ack_queue
+
+     def send_weights_async(self, weights: Any, model_id: str = "policy") -> None:
+         """Send weights through the queue without waiting for acknowledgment.
+
+         Use wait_ack() to wait for acknowledgment after sending to all workers.
+         """
+         # Send in format expected by worker loop: ((model_id, weights), "update_weights")
+         self.weight_queue.put(((model_id, weights), "update_weights"))
+
+     def receive_weights(
+         self,
+         timeout: float | None = None,
+         *,
+         weights: Any = None,
+         model: Any = None,
+         strategy: Any = None,
+     ) -> Any | None:
+         """Receive weights from the queue (used in worker process).
+
+         This method only handles weight update messages. Other messages
+         (like "close", "continue", etc.) are ignored and should be handled
+         by the main worker loop.
+
+         Args:
+             timeout: Maximum time to wait for weights (seconds).
+                 None means use the transport's default timeout.
+             weights: Ignored (weights come from queue).
+             model: The model to apply weights to.
+             strategy: Strategy for applying weights to the model.
+
+         Returns:
+             The received weights, or None if no data available.
+         """
+         # Use transport's default timeout if not specified
+         if timeout is None:
+             timeout = self.timeout
+         data_in, msg = self.weight_queue.get(timeout=timeout)
+         if msg == "update_weights":
+             # data_in is (model_id, weights) - we ignore model_id, scheme knows it
+             _model_id, received_weights = data_in
+
+             # Apply weights to model if provided
+             if model is not None and strategy is not None:
+                 strategy.apply_weights(model, received_weights)
+
+             return received_weights
+         else:
+             raise ValueError(f"Expected 'update_weights' but got {msg}")
+
+     def setup_connection_and_weights_on_sender(self) -> None:
+         """No-op for MPTransport - weights are sent via scheme's synchronize_weights().
+
+         The actual sending happens in MultiProcessWeightSyncScheme._setup_connection_and_weights_on_sender_impl(), which:
+         1. Extracts weights from the context (e.g., collector.policy)
+         2. Calls send_weights_async() on all worker transports
+         3. Sends initial weights through queues to all workers
+
+         This is similar to SharedMemTransport.setup_connection_and_weights_on_sender() which
+         sends shared memory buffer references via queues.
+         """
+
+     def setup_connection_and_weights_on_receiver(
+         self,
+         *,
+         worker_idx: int,
+         weights: Any = None,
+         model: Any = None,
+         strategy: Any = None,
+     ) -> Any:
+         """Receive initial weights from sender during worker initialization.
+
+         This method blocks waiting for the initial weights to be sent from the main process
+         via queue. Similar to SharedMemTransport.setup_connection_and_weights_on_receiver() which receives
+         shared memory buffer references via queues, this receives the actual weights via queues.
+
+         The received weights are then applied to the worker's model by the scheme's synchronize_weights().
+
+         Args:
+             worker_idx: The worker index (used for logging/debugging).
+             weights: Ignored (weights come from queue).
+             model: Ignored.
+             strategy: Ignored.
+
+         Returns:
+             The received weights if available, None otherwise (weights will come later via receive()).
+         """
+         # Wait for initial weights (blocking)
+         data_in, msg = self.weight_queue.get(timeout=self.timeout)
+         if msg == "update_weights":
+             # data_in is (model_id, weights), extract just the weights
+             _, received_weights = data_in
+             return received_weights
+         else:
+             raise ValueError(f"Expected 'update_weights' but got {msg}")
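The snippet below is a minimal, single-process sketch of the queue protocol MPTransport implements: send_weights_async() enqueues ((model_id, weights), "update_weights") and receive_weights() pops and unpacks that message. It assumes MPTransport is importable from torchrl.weight_update._mp (the module added in this diff); in real use the scheme creates the queues, the send side runs in the trainer process, and the receive side runs in a worker's background thread.

from tensordict import TensorDict
from torch import multiprocessing as mp, nn

# Assumed import path: the module shown in this diff.
from torchrl.weight_update._mp import MPTransport

if __name__ == "__main__":
    model = nn.Linear(4, 2)
    weights = TensorDict.from_module(model)

    ctx = mp.get_context()
    transport = MPTransport(weight_queue=ctx.Queue(), ack_queue=ctx.Queue())

    # Sender side: enqueue a weight copy tagged with the model id.
    transport.send_weights_async(weights, model_id="policy")

    # Receiver side: pop the ((model_id, weights), "update_weights") message.
    # Without `model`/`strategy`, the weights are returned as-is.
    received = transport.receive_weights(timeout=5.0)
    print(received["weight"].shape)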
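For the default strategy="tensordict", "applying weights" on the receiver amounts to writing a TensorDict of parameters back into the destination module. The strategy object itself is defined elsewhere in torchrl and is not part of this file, so the sketch below only mirrors the effect of strategy.apply_weights(model, received_weights) using plain tensordict calls.

import torch
from tensordict import TensorDict
from torch import nn

trainer_model = nn.Linear(4, 2)  # the copy being trained on the sender side
worker_model = nn.Linear(4, 2)   # the stale copy living in a worker

# Extract the sender's parameters as a TensorDict (the "tensordict" strategy).
weights = TensorDict.from_module(trainer_model)

# Writing them back into the worker module has the same effect as the
# receiver's strategy.apply_weights(model, received_weights) call.
weights.to_module(worker_model)

assert torch.equal(worker_model.weight, trainer_model.weight)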