torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/weight_update/_distributed.py (new file, +749 lines)
@@ -0,0 +1,749 @@
+from __future__ import annotations
+
+import random
+import socket
+import time
+import weakref
+from datetime import timedelta
+from typing import Any
+
+import torch
+from tensordict import TensorDictBase
+from torchrl._utils import logger as torchrl_logger
+
+from torchrl.weight_update.utils import _resolve_model
+
+from torchrl.weight_update.weight_sync_schemes import (
+    TransportBackend,
+    WeightStrategy,
+    WeightSyncScheme,
+)
+
+
+class DistributedWeightSyncScheme(WeightSyncScheme):
+    """Weight synchronization for torch.distributed.
+
+    This scheme uses torch.distributed primitives (send/recv) to synchronize
+    weights across distributed workers. Each worker gets its own transport,
+    following the same pattern as multiprocess collectors.
+
+    The scheme can create its own TCPStore for coordination if one is not provided.
+    Use `get_store_info()` after `init_on_sender()` to get connection details for workers.
+
+    Args:
+        backend (str): The distributed backend ("gloo", "nccl", etc.)
+        sync (bool): If True, weight updates are synchronous (blocking receive).
+            If False, a background thread monitors the store and applies weight
+            updates automatically. Defaults to True.
+        timeout (float): Timeout in seconds for TCPStore operations.
+            Defaults to 3600.0 (1 hour).
+    """
+
+    def __init__(
+        self,
+        backend: str = "gloo",
+        sync: bool = True,
+        timeout: float = 3600.0,
+    ):
+        super().__init__()
+        self.backend = backend
+        self.sync = sync
+        self._timeout = timeout
+        self._store = None
+        self._store_info = None
+        self._num_workers = None
+
+    def __getstate__(self):
+        """Custom serialization - exclude non-picklable objects."""
+        state = super().__getstate__()
+        # TCPStore cannot be pickled - remove it but keep _store_info
+        state["_store"] = None
+
+        # Thread and Event cannot be pickled
+        state["_background_thread"] = None
+        state["_stop_event"] = None
+
+        # Transports contain references to store/groups - exclude them
+        # The receiver will create its own transport in init_on_receiver
+        state["_sender_transports"] = {}
+        state["_receiver_transport"] = None
+        return state
+
+    def __setstate__(self, state):
+        """Custom deserialization."""
+        super().__setstate__(state)
+
+    def _init_on_sender_impl(
+        self,
+        *,
+        model_id: str,
+        context: Any = None,
+        num_workers: int,
+        model: Any = None,
+        weights: Any = None,
+        **kwargs,
+    ) -> None:
+        if kwargs:
+            raise RuntimeError(f"Unexpected kwargs: {kwargs.keys()}")
+        self.model_id = model_id
+        self._num_workers = num_workers
+
+        # Attach context so we can resolve the model and prepare
+        # weights on demand via scheme.prepare_weights().
+        weights_buffer = None
+        if context is not None:
+            self.context = context
+        if weights is not None:
+            self.weights = weights
+            weights_buffer = weights
+        if model is not None:
+            self.model = model
+        else:
+            # resolve model
+            try:
+                model = self.model
+            except (AttributeError, ValueError):
+                pass
+
+        if weights_buffer is None and model is not None:
+            weights_buffer = self._get_weights_buffer_from_model(model)
+
+        # Get base tcp_port from context if available to avoid port conflicts.
+        # The DistributedDataCollector uses tcp_port for init and tcp_port+1 for its store,
+        # so we use tcp_port+2 for the weight sync scheme's store.
+        base_tcp_port = (
+            getattr(context, "tcp_port", None) if context is not None else None
+        )
+        self._store = self._make_store(
+            is_master=True, num_workers=num_workers, base_tcp_port=base_tcp_port
+        )
+
+        for i in range(num_workers):
+            rank = i + 1  # Workers are 1-indexed in distributed
+            transport = self.create_transport(
+                store=self._store,
+                rank=rank,
+                weights_buffer=weights_buffer,
+                sync=self.sync,
+            )
+            self._register_worker_sender(worker_idx=i, transport=transport)
+
+    def _make_store(
+        self,
+        is_master: bool,
+        num_workers: int | None = None,
+        store_info: dict | None = None,
+        base_tcp_port: int | str | None = None,
+        max_retries: int = 10,
+    ) -> torch.distributed.TCPStore:
+        """Create a TCPStore for weight synchronization.
+
+        Args:
+            is_master: If True, creates the store as master (server).
+                If False, connects as client.
+            num_workers: Number of workers (required for master).
+            store_info: Dictionary with 'host' and 'port' keys (required for client).
+            base_tcp_port: Base TCP port from the collector. If provided, the store
+                will use base_tcp_port + 2 to avoid conflicts with the collector's
+                stores (which use base_tcp_port and base_tcp_port + 1).
+            max_retries: Maximum number of retry attempts for handling port conflicts.
+
+        Returns:
+            The created TCPStore.
+        """
+        if is_master:
+            # Create as master (server)
+            if num_workers is None:
+                raise ValueError(
+                    "num_workers is required when creating store as master"
+                )
+
+            hostname = socket.gethostname()
+            host = socket.gethostbyname(hostname)
+
+            # Use base_tcp_port + 2 if available (to avoid conflicts with collector's
+            # tcp_port and tcp_port + 1), otherwise find a free port dynamically.
+            initial_port = int(base_tcp_port) + 2 if base_tcp_port is not None else None
+
+            last_error = None
+            for attempt in range(max_retries):
+                if initial_port is None or attempt > 0:
+                    # Find a free port dynamically
+                    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                        s.bind(("", 0))
+                        self._store_port = s.getsockname()[1]
+                else:
+                    self._store_port = initial_port
+
+                try:
+                    store = torch.distributed.TCPStore(
+                        host_name=host,
+                        port=self._store_port,
+                        is_master=True,
+                        timeout=timedelta(seconds=self._timeout),
+                        wait_for_workers=False,  # Don't block - workers may not be started yet
+                    )
+                    self._store_info = {"host": host, "port": self._store_port}
+                    return store
+                except (RuntimeError, OSError) as e:
+                    error_msg = str(e).lower()
+                    if (
+                        "address already in use" in error_msg
+                        or "eaddrinuse" in error_msg
+                    ):
+                        last_error = e
+                        # Add small random delay to reduce collision probability
+                        time.sleep(random.uniform(0.01, 0.1))
+                        continue
+                    # For other errors, re-raise immediately
+                    raise
+
+            raise RuntimeError(
+                f"DistributedWeightSyncScheme: Failed to create TCPStore after {max_retries} attempts. "
+                f"Last error: {last_error}"
+            )
+        else:
+            # Connect as client
+            if store_info is None:
+                raise ValueError("store_info is required when connecting as client")
+            store = torch.distributed.TCPStore(
+                host_name=store_info["host"],
+                port=store_info["port"],
+                is_master=False,
+                timeout=timedelta(seconds=self._timeout),
+            )
+            return store
+
+    def _init_on_receiver_impl(
+        self,
+        *,
+        model_id: str,
+        context: Any = None,
+        store_info: dict | None = None,
+        worker_idx: int | None = None,
+        **kwargs,
+    ) -> None:
+        """Initialize scheme on the worker (receiver) side.
+
+        Expected kwargs (as provided by collectors):
+            - model_id: str  # e.g. "policy"
+            - context: Any  # collector / inner collector
+            - store: TCPStore | None  # distributed TCP store
+            - store_info: dict | None  # {"host": ..., "port": ...} to create store
+            - rank: int | None  # worker rank (1-indexed)
+        """
+        if context is None:
+            raise ValueError(
+                "DistributedWeightSyncScheme.init_on_receiver requires a 'context' "
+                "providing access to the model to be synchronized."
+            )
+        if worker_idx is None:
+            raise RuntimeError("rank was not provided.")
+        if kwargs:
+            raise RuntimeError(f"Unexpected kwargs: {kwargs.keys()}")
+
+        # Store model_id and context on scheme
+        self.model_id = model_id
+        self.context = context
+
+        # Get or create store
+        # Priority: provided store > provided store_info > self._store_info (from serialization)
+        # Connect to master's TCPStore as client
+        info = self._store_info
+        if info is None:
+            raise RuntimeError(
+                "TCPStore info not available. init_on_sender() must be called first on the sender side, before passing the scheme to the receiver."
+            )
+        self._store = self._make_store(is_master=False, store_info=info)
+
+        if (model := getattr(self, "model", None)) is not None:
+            self.model = model
+            weights_buffer = self._get_weights_buffer_from_model(model)
+        else:
+            raise RuntimeError("Couldn't find weights")
+        self._receiver_transport = self.create_transport(
+            store=self._store,
+            rank=worker_idx,
+            weights_buffer=weights_buffer,
+            sync=self.sync,
+        )
+
+        # Store worker_idx for synchronize_weights
+        self._worker_idx = worker_idx
+        # Note: Background thread for async mode is started in connect() after init_process_group
+
+    def _wait_for_instruction(self, timeout: float | None = None) -> str | None:
+        """Block until an instruction arrives via TCPStore.
+
+        Args:
+            timeout: Maximum time to wait for instruction (seconds).
+                None means block indefinitely.
+
+        Returns:
+            The instruction string (e.g., "receive", "stop"), or None if
+            stop event is set or timeout expires.
+        """
+        key = f"NODE_{self._worker_idx}_in"
+        start_time = time.monotonic()
+
+        while True:
+            if self._stop_event is not None and self._stop_event.is_set():
+                return None
+
+            try:
+                instruction = self._store.get(key)
+                self._store.delete_key(key)
+                # Decode bytes to string
+                return (
+                    instruction.decode()
+                    if isinstance(instruction, bytes)
+                    else instruction
+                )
+            except RuntimeError:
+                # Key doesn't exist yet, continue polling
+                pass
+
+            # Check timeout
+            if timeout is not None:
+                elapsed = time.monotonic() - start_time
+                if elapsed >= timeout:
+                    return None
+
+            time.sleep(0.01)
+
+    def _send_instruction(
+        self,
+        instruction: str = "receive",
+        worker_ids: int | list[int] | None = None,
+    ) -> None:
+        """Send instruction to receiver(s) via TCPStore.
+
+        Args:
+            instruction: The instruction to send (default: "receive").
+            worker_ids: Which workers to send to (None = all workers).
+        """
+        if self._store is None:
+            raise RuntimeError(
+                "Store not initialized. init_on_sender() must be called first."
+            )
+
+        if worker_ids is None:
+            target_workers = list(range(self._num_workers)) if self._num_workers else []
+        elif isinstance(worker_ids, int):
+            target_workers = [worker_ids]
+        else:
+            target_workers = list(worker_ids)
+
+        # Map instruction to TCPStore format
+        store_instruction = (
+            b"update_weights" if instruction == "receive" else instruction.encode()
+        )
+
+        for worker_idx in target_workers:
+            rank = worker_idx + 1  # Workers are 1-indexed in distributed
+            self._store.set(f"NODE_{rank}_in", store_instruction)
+
+    def _send_ack(self, message: str = "updated") -> None:
+        """Send acknowledgment back to sender via TCPStore.
+
+        Args:
+            message: The acknowledgment message (default: "updated").
+        """
+        if self._store is None or self._worker_idx is None:
+            return
+        self._store.set(f"NODE_{self._worker_idx}_out", message.encode())
+
+    def _wait_for_ack(
+        self,
+        worker_ids: int | list[int] | None = None,
+        timeout: float | None = None,
+    ) -> None:
+        """Wait for acknowledgment from receiver(s) via TCPStore.
+
+        Args:
+            worker_ids: Which workers to wait for (None = all workers).
+            timeout: Maximum time to wait (seconds). None means block indefinitely.
+        """
+        if self._store is None:
+            return
+
+        if worker_ids is None:
+            target_workers = list(range(self._num_workers)) if self._num_workers else []
+        elif isinstance(worker_ids, int):
+            target_workers = [worker_ids]
+        else:
+            target_workers = list(worker_ids)
+
+        for worker_idx in target_workers:
+            rank = worker_idx + 1
+            key = f"NODE_{rank}_out"
+            try:
+                status = self._store.get(key)
+                if status != b"updated":
+                    torchrl_logger.warning(
+                        f"Unexpected ack from worker {worker_idx}: {status}"
+                    )
+                self._store.delete_key(key)
+            except Exception as e:
+                torchrl_logger.warning(
+                    f"Timeout waiting for ack from worker {worker_idx}: {e}"
+                )
+
+    def _background_receive_loop(self):
+        """Background thread loop that waits for instructions and receives weights.
+
+        This loop:
+        1. Waits for an instruction via TCPStore
+        2. Receives weights via torch.distributed
+        3. Sends an acknowledgment back
+        4. Repeats until stop event is set
+        """
+        while not self._stop_event.is_set():
+            try:
+                instruction = self._wait_for_instruction()
+                if instruction is None:
+                    continue
+                if instruction in ("receive", "update_weights"):
+                    # Receive weights via torch.distributed
+                    weights = self._receiver_transport.receive_weights(
+                        model=self.model,
+                        strategy=self._strategy,
+                    )
+
+                    if weights is not None:
+                        # Cascade weight update to sub-collectors if context supports it
+                        model_id = self._model_id or "policy"
+                        if self.context is not None and hasattr(
+                            self.context, "update_policy_weights_"
+                        ):
+                            self.context.update_policy_weights_(
+                                model_id=model_id, policy_or_weights=weights
+                            )
+
+                    # Send acknowledgment
+                    self._send_ack("updated")
+
+                elif instruction == "stop":
+                    break
+                else:
+                    torchrl_logger.warning(
+                        f"DistributedWeightSyncScheme: Unknown instruction: {instruction}"
+                    )
+
+            except Exception as e:
+                if not self._stop_event.is_set():
+                    torchrl_logger.warning(
+                        f"DistributedWeightSyncScheme: Background receiver error: {e}"
+                    )
+
+    def _setup_connection_and_weights_on_sender_impl(
+        self, *, worker_idx: int | None = None, weights: Any | None = None
+    ) -> None:
+        """Send initial weights to all workers during connect().
+
+        If the sender has a stateful model (weights available), send them
+        to all workers so they start with the correct weights.
+
+        Note: This uses direct torch.distributed send/recv without TCPStore
+        signaling to avoid interfering with the main collection loop.
+        """
+        # Initialize torch.distributed process group if not already done
+        # This is a collective operation - all workers must call it
+        if not torch.distributed.is_initialized():
+            torch.distributed.init_process_group(
+                backend=self.backend,
+                rank=0,  # Sender is always rank 0
+                world_size=self._num_workers + 1,
+                timeout=timedelta(seconds=self._timeout),
+            )
+
+        # Check if we have weights to send
+        if weights is None and getattr(self, "model", None) is None:
+            self._store.set("STATELESS_MODEL", b"1")
+            return
+
+        self._store.set("STATELESS_MODEL", b"0")
+        # Prepare weights from model
+        weights = self._get_weights_buffer_from_model(self.model)
+        if weights is None or weights.is_empty():
+            return
+
+        # Send to all workers using direct torch.distributed (no TCPStore signaling)
+        for i, transport in enumerate(self._iterate_transports()):
+            if worker_idx is not None and i != worker_idx:
+                continue
+            transport.send_initial_weights(weights)
+
+    def _setup_connection_and_weights_on_receiver_impl(
+        self, *, worker_idx: int | None = None
+    ) -> None:
+        """Receive initial weights from sender during connect().
+
+        The receiver always has a model that needs weights, so we block
+        waiting for the initial weights from the sender.
+        """
+        # Use stored worker_idx if not provided
+        if worker_idx is None:
+            worker_idx = self._worker_idx
+
+        # Initialize torch.distributed process group if not already done
+        # This is a collective operation - sender and all workers must call it
+        if not torch.distributed.is_initialized():
+            torch.distributed.init_process_group(
+                backend=self.backend,
+                rank=worker_idx,
+                world_size=self._num_workers + 1,
+                timeout=timedelta(seconds=self._timeout),
+            )
+
+        if self._receiver_transport is None:
+            torchrl_logger.warning(
+                "DistributedWeightSyncScheme: No receiver transport, skipping initial weight sync"
+            )
+            return
+
+        stateless_model = self.receiver_transport._store.get("STATELESS_MODEL")
+        if stateless_model not in (b"0", b"1"):
+            raise RuntimeError(f"Invalid STATELESS_MODEL value: {stateless_model}")
+        if stateless_model != b"1":
+            # Receive initial weights (blocking, no TCPStore coordination)
+            weights = self._receiver_transport.receive_initial_weights()
+            if weights is not None and self.model is not None:
+                self._strategy.apply_weights(self.model, weights, inplace=False)
+
+        # Start background receiver thread AFTER initial weight sync is complete
+        # This prevents the background thread from consuming the initial sync messages
+        if self._background_thread is None:
+            self._start_background_receiver()
+
+    def shutdown(self) -> None:
+        """Stop background receiver thread and clean up."""
+        # Check if already shutdown
+        if getattr(self, "_is_shutdown", False):
+            return
+        self._is_shutdown = True
+
+        # Let base class handle background thread cleanup
+        super().shutdown()
+
+    @property
+    def model(self) -> Any | None:
+        """Get the model associated with this scheme.
+
+        Returns:
+            The model if set, None otherwise.
+        """
+        if self._model_ref is not None:
+            return self._model_ref()
+        if self._model_id is not None:
+            model = _resolve_model(self.context, self._model_id)
+            if model is None:
+                if self._model_id == "policy":
+                    torchrl_logger.debug("Creating policy from factory.")
+                    model = self.context.policy_factory[0]()
+                    self.context.policy = model
+                else:
+                    raise AttributeError(
+                        f"Model {self._model_id} was `None` in context {self.context}"
+                    )
+            self._model_ref = weakref.ref(model)
+            return model
+
+    @model.setter
+    def model(self, value: Any):
+        """Set the model for this scheme.
+
+        Args:
+            value: The model to set. If None, the setter is a no-op.
+        """
+        if value is None:
+            return
+        self._model_ref = weakref.ref(value)
+
+    def create_transport(self, **kwargs) -> TransportBackend:
+        """Create distributed transport for a specific worker."""
+        return DistributedTransport(**kwargs)
+
+
+class DistributedTransport:
+    """torch.distributed transport for communicating with a single distributed worker.
+
+    This transport handles weight updates for ONE specific distributed worker via
+    torch.distributed send/recv. Multiple transports are created for multiple workers,
+    following the same pattern as multiprocess collectors.
+    """
+
+    def __init__(
+        self,
+        *,
+        weights_buffer: TensorDictBase,
+        store: torch.distributed.Store = None,
+        rank: int | None = None,
+        sync: bool = True,
+    ):
+        """Initialize the DistributedTransport.
+
+        Args:
+            weights_buffer (TensorDictBase): a tensor buffer of weights.
+            store (torch.distributed.Store): A (TCP)Store for communication.
+            rank (int): Worker rank (1-indexed).
+            sync (bool): Whether to use synchronous weight updates.
+        """
+        self._store = store
+        self._rank = rank
+        self._sync = sync
+        self._weights_buffer = weights_buffer
+
+    def send_weights(self, weights: Any) -> None:
+        """Send weights to the distributed worker."""
+        if self._store is None or self._rank is None:
+            return
+
+        # Instruct worker to expect weight update
+        self._store.set(f"NODE_{self._rank}_in", b"update_weights")
+
+        # Send weights via torch.distributed
+        if self._sync:
+            weights.send(self._rank)
+        else:
+            weights.isend(self._rank)
+
+        # Wait for acknowledgment
+        status = self._store.get(f"NODE_{self._rank}_out")
+        if status != b"updated":
+            raise RuntimeError(f"Expected 'updated' but got status {status}.")
+        self._store.delete_key(f"NODE_{self._rank}_out")
+
+    def send_weights_async(self, weights: Any) -> None:
+        """Send weights to distributed worker without waiting for acknowledgment.
+
+        Use wait_ack() to wait for acknowledgment after sending to all workers.
+        """
+        if self._store is None or self._rank is None:
+            return
+
+        # Instruct worker to expect weight update
+        self._store.set(f"NODE_{self._rank}_in", b"update_weights")
+
+        # Send weights via torch.distributed
+        if self._sync:
+            weights.send(self._rank)
+        else:
+            weights.isend(self._rank)
+
+    def wait_ack(self) -> None:
+        """Wait for acknowledgment from distributed worker."""
+        if self._store is None or self._rank is None:
+            return
+
+        status = self._store.get(f"NODE_{self._rank}_out")
+        if status != b"updated":
+            raise RuntimeError(f"Expected 'updated' but got status {status}.")
+        self._store.delete_key(f"NODE_{self._rank}_out")
+
+    def receive_weights(
+        self,
+        timeout: float | None = None,
+        *,
+        weights: Any = None,
+        model: Any = None,
+        strategy: WeightStrategy | None = None,
+    ) -> Any | None:
+        r"""Receive weights via torch.distributed and apply them to the model.
+
+        The surrounding collector loop is responsible for checking the TCPStore
+        for the \"update_weights\" instruction. When this method is called we
+        assume that a weight update has been requested and the sender has
+        already performed the corresponding ``send()``.
+
+        Args:
+            timeout: Maximum time to wait for weights (seconds). If None,
+                blocks until weights are received.
+            weights: Pre-allocated weight buffer to receive into.
+            model: The model to apply weights to.
+            strategy: Strategy for applying weights to the model.
+
+        Returns:
+            The received weights, or None if timeout expires.
+        """
+        if self._store is None or self._rank is None:
+            return None
+
+        # Use provided weights buffer or fallback to stored one
+        weights_buffer = weights if weights is not None else self._weights_buffer
+
+        # Receive weights via torch.distributed into the buffer
+        if self._sync or timeout is None:
+            # Blocking receive - no timeout support
+            if self._sync:
+                weights_buffer.recv(src=0)
+            else:
+                weights_buffer.irecv(src=0)
+        else:
+            # Non-blocking receive with timeout support
+            futures = weights_buffer.irecv(src=0, return_premature=True)
+            if futures:
+                start_time = time.monotonic()
+                while True:
+                    # Check if all futures are complete
+                    all_complete = all(f.is_completed() for f in futures)
+                    if all_complete:
+                        break
+                    # Check timeout
+                    elapsed = time.monotonic() - start_time
+                    if elapsed >= timeout:
+                        # Timeout expired before receiving all weights
+                        return None
+                    # Small sleep to avoid busy-waiting
+                    time.sleep(0.001)
+
+        # Apply weights if model and strategy provided
+        if model is not None and strategy is not None:
+            strategy.apply_weights(model, weights_buffer)
+
+        return weights_buffer
+
+    def send_initial_weights(self, weights: Any) -> None:
+        """Send initial weights during connect() without TCPStore signaling.
+
+        This is used for the initial weight sync during connect() to avoid
+        interfering with the main collection loop's TCPStore-based coordination.
+        """
+        if self._rank is None:
+            return
+
+        # Note: No TCPStore signaling for initial sync - just direct send/recv
+        if self._sync:
+            weights.send(self._rank)
+        else:
+            weights.isend(self._rank)
+
+    def receive_initial_weights(self) -> Any:
+        """Receive initial weights during connect() without TCPStore signaling.
+
+        This is used for the initial weight sync during connect() to avoid
+        interfering with the main collection loop's TCPStore-based coordination.
+
+        Returns:
+            The received weights TensorDict.
+        """
+        if self._sync:
+            self._weights_buffer.recv(src=0)
+        else:
+            self._weights_buffer.irecv(src=0)
+        return self._weights_buffer
+
+    def setup_connection_and_weights_on_sender(self) -> None:
+        """No-op for DistributedTransport - handled by scheme."""
+
+    def setup_connection_and_weights_on_receiver(
+        self,
+        *,
+        worker_idx: int,
+        weights: Any = None,
+        model: Any = None,
+        strategy: WeightStrategy | None = None,
+    ) -> Any:
+        """No-op for DistributedTransport - handled by scheme."""
+        return None
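
To make the handshake above concrete, here is a minimal two-process sketch of the transport-level protocol only: the sender writes NODE_<rank>_in to the TCPStore, ships the TensorDict of weights with send(), and blocks on NODE_<rank>_out; the worker receives into a pre-allocated buffer and posts the acknowledgment. The addresses, ports, toy weight TensorDict, and the manual store.set("NODE_1_out", b"updated") (normally done by the scheme's _send_ack()) are assumptions for illustration; in practice a collector drives all of this through DistributedWeightSyncScheme.init_on_sender() / init_on_receiver() rather than by hand.

    # Hypothetical usage sketch (not part of the package): run as
    #   python sketch.py sender    and    python sketch.py worker
    # in two processes on one machine. Assumes ports 29500/29502 are free and
    # uses a toy weight TensorDict; scheme-level coordination is omitted.
    from __future__ import annotations

    import sys
    from datetime import timedelta

    import torch
    import torch.distributed
    from tensordict import TensorDict

    from torchrl.weight_update._distributed import DistributedTransport

    STORE_PORT = 29502  # coordination TCPStore (separate from the process-group port)


    def make_buffer() -> TensorDict:
        # Same structure on both ranks so send/recv buffers line up.
        return TensorDict({"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}, [])


    def run_sender() -> None:
        torch.distributed.init_process_group(
            "gloo", rank=0, world_size=2, init_method="tcp://127.0.0.1:29500"
        )
        store = torch.distributed.TCPStore(
            "127.0.0.1", STORE_PORT, is_master=True,
            timeout=timedelta(seconds=60), wait_for_workers=False,
        )
        weights = make_buffer()
        weights["weight"] += 1.0  # pretend these are freshly trained parameters
        transport = DistributedTransport(
            weights_buffer=weights, store=store, rank=1, sync=True
        )
        # Sets NODE_1_in=update_weights, sends the tensors to rank 1,
        # then blocks until the worker posts NODE_1_out=updated.
        transport.send_weights(weights)


    def run_worker() -> None:
        torch.distributed.init_process_group(
            "gloo", rank=1, world_size=2, init_method="tcp://127.0.0.1:29500"
        )
        store = torch.distributed.TCPStore(
            "127.0.0.1", STORE_PORT, is_master=False, timeout=timedelta(seconds=60)
        )
        transport = DistributedTransport(
            weights_buffer=make_buffer(), store=store, rank=1, sync=True
        )
        # A collector loop would normally poll NODE_1_in first; here we receive directly.
        received = transport.receive_weights()
        # Acknowledge so the sender's send_weights() call can return
        # (the scheme's _send_ack() does this in the real setup).
        store.set("NODE_1_out", b"updated")
        print(received["weight"].mean())


    if __name__ == "__main__":
        run_sender() if sys.argv[1] == "sender" else run_worker()

The sketch keeps the transport synchronous (sync=True); with sync=False the scheme instead starts a background thread that polls the store via _wait_for_instruction() and applies updates as they arrive.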