torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395):
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,1032 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import socket
5
+
6
+ import time
7
+ import weakref
8
+ from dataclasses import dataclass
9
+ from datetime import timedelta
10
+ from typing import Any, Literal
11
+
12
+ import torch
13
+ from tensordict import TensorDict
14
+ from tensordict.base import TensorDictBase
15
+
16
+ from torchrl._utils import logger as torchrl_logger
17
+ from torchrl.weight_update.utils import _resolve_model
18
+ from torchrl.weight_update.weight_sync_schemes import (
19
+ TransportBackend,
20
+ WeightStrategy,
21
+ WeightSyncScheme,
22
+ )
23
+
24
# Default timeout for torch.distributed operations.
# NOTE(review): not referenced anywhere in this chunk — presumably passed to
# init_process_group by the scheme elsewhere in the file; confirm before removing.
_DIST_TIMEOUT = timedelta(seconds=60)
26
+
27
+
28
@dataclass
class ConnectionInfo:
    """Connection parameters for Ray distributed computing.

    Uses dataclass instead of UserDict to avoid Ray signature introspection
    issues with UserDict's __class_getitem__ in Python 3.11+
    (ValueError: no signature found for builtin type GenericAlias).
    """

    master_addr: str
    master_port: int
    world_size: int
    stateful_model: bool

    def get(self, key: str, default: Any = None) -> Any:
        """Dict-style lookup of a connection attribute by name.

        Args:
            key (str): The attribute name to retrieve.
            default: Value returned when the attribute does not exist.
                Defaults to None.

        Returns:
            The attribute's value, or ``default`` when absent.
        """
        if hasattr(self, key):
            return getattr(self, key)
        return default
54
+
55
+
56
class RayTransport:
    """Ray transport for communicating with a single Ray actor.

    This transport handles weight updates for ONE specific remote actor
    using torch.distributed for efficient weight transfer. Ray is used for
    signaling/coordination, while the actual weight data is transferred via
    torch.distributed send/recv operations.

    Multiple transports are created for multiple actors, following the
    same pattern as multiprocess collectors.

    Args:
        remote_actor: The Ray actor handle for the remote collector/transform.
        worker_idx (int, optional): The worker index for this remote actor.
            Defaults to 0.
        backend (str): The torch.distributed backend to use ("gloo" or "nccl").
            Defaults to "gloo".
        connection_info_name (str): Name of the Ray actor storing connection info.
            Defaults to "connection_info".
        model_id (str, optional): The model identifier for weight synchronization.
    """

    def __init__(
        self,
        *,
        remote_actor=None,
        worker_idx: int | None = None,
        backend: str = "gloo",
        connection_info_name: str = "connection_info",
        model_id: str | None = None,
    ):
        """Initialize the RayTransport.

        Args:
            remote_actor: The Ray actor handle for the remote collector/transform.
            worker_idx (int, optional): The worker index for this remote actor.
                Defaults to 0.
            backend (str): The torch.distributed backend to use ("gloo" or "nccl").
                Defaults to "gloo".
            connection_info_name (str): Name of the Ray actor storing connection info.
                Defaults to "connection_info".
            model_id (str, optional): The model identifier for weight synchronization.

        Raises:
            ImportError: If Ray is not installed.
        """
        try:
            import ray

            self.ray = ray
        except ImportError as err:
            # Chain the original failure so the user sees why the import failed.
            raise ImportError("Ray is required for RayTransport") from err
        self._remote_actor = remote_actor
        self._worker_idx = worker_idx if worker_idx is not None else 0
        self._backend = backend
        self._connection_info_name = connection_info_name
        self._model_id = model_id

        # Distributed state
        self._dist_initialized = False
        self._weights_buffer: TensorDictBase | None = None
        self._stateful_model: bool = True

        # Async operation state (set by send_weights_async, cleared by wait_ack)
        self._pending_future = None
        self._pending_isend = None

        # Model reference (set by scheme on receiver side)
        self._model = None

    @property
    def _rank(self) -> int:
        """Get the torch.distributed rank for this worker.

        Returns:
            int: The rank (worker_idx + 1, since sender is rank 0).
        """
        return self._worker_idx + 1  # Sender is rank 0, workers are 1-indexed

    def set_model(self, model: Any) -> None:
        """Set the model for receiving weights.

        Args:
            model: The model to receive weights into.
        """
        self._model = model

    # ========================================================================
    # Sending Weights (Sender Side)
    # ========================================================================

    def send_weights(self, weights: Any) -> None:
        """Send weights to the remote actor via torch.distributed.

        This method:
        1. Signals the remote actor to start receiving via Ray remote call
        2. Sends weights via torch.distributed.isend
        3. Waits for both to complete

        Args:
            weights: The weights to send (typically a TensorDict).
        """
        if self._remote_actor is None:
            return

        # Step 1: Signal the remote actor via Ray to start receiving (async)
        future = self._remote_actor._receive_weights_scheme.remote()

        # Step 2: Send weights via torch.distributed (async)
        weights.isend(dst=self._rank)

        # Step 3: Wait for the Ray call to complete (receiver has applied weights)
        self.ray.get(future)

    def send_weights_async(self, weights: Any) -> None:
        """Send weights to Ray actor without waiting for completion.

        Use :meth:`wait_ack` to wait for completion after sending to all actors.

        Args:
            weights: The weights to send (typically a TensorDict).
        """
        if self._remote_actor is None:
            return

        # Step 1: Signal the actor via Ray to start receiving (async)
        self._pending_future = self._remote_actor._receive_weights_scheme.remote()

        # Step 2: Send weights via torch.distributed (async); keep the work
        # handles so wait_ack can drain them.
        self._pending_isend = weights.isend(dst=self._rank, return_early=True)

    def wait_ack(self) -> None:
        """Wait for Ray actor to finish applying weights.

        Raises:
            RuntimeError: If no pending future exists (i.e., :meth:`send_weights_async`
                was not called before this method).
        """
        if self._pending_future is None:
            raise RuntimeError("No pending future. Did you call send_weights_async?")
        self.ray.get(self._pending_future)
        if self._pending_isend is not None:
            for fut in self._pending_isend:
                fut.wait()
        self._pending_future = None
        self._pending_isend = None

    # ========================================================================
    # Receiving Weights (Receiver Side)
    # ========================================================================

    def receive_weights(
        self,
        timeout: float | None = None,
        *,
        weights: Any = None,
        model: Any = None,
        strategy: WeightStrategy | None = None,
    ) -> Any | None:
        """Receive weights from sender via torch.distributed.

        Args:
            timeout: Maximum time to wait for weights (seconds). If None,
                blocks until weights are received.
            weights: Pre-allocated weight buffer to receive into.
            model: The model to apply weights to.
            strategy: Strategy for applying weights to the model.

        Returns:
            The received weights, or None if timeout expires (or the model is
            not an ``nn.Module`` and the buffer is empty).

        Raises:
            RuntimeError: If neither a weights buffer nor a model is available,
                or if weights were received for a non-``nn.Module`` model.
        """
        from torchrl.collectors.utils import _cast

        # Use provided weights buffer or fallback to stored one
        weights_buffer = weights if weights is not None else self._weights_buffer
        if weights_buffer is None:
            if model is None:
                raise RuntimeError("No model available to receive weights")
            if isinstance(model, torch.nn.Module):
                weights_buffer = TensorDict.from_module(model)
                weights_buffer = weights_buffer.data.apply(_cast, weights_buffer)
            else:
                weights_buffer = TensorDict(lock=True)

        # Cache the weights buffer for future use
        if self._weights_buffer is None:
            self._weights_buffer = weights_buffer

        # Receive weights from rank 0
        if timeout is None:
            # Blocking receive
            weights_buffer.irecv(src=0)
        else:
            # Non-blocking receive with timeout support
            futures = weights_buffer.irecv(src=0, return_premature=True)
            if futures:
                start_time = time.monotonic()
                while not all(f.is_completed() for f in futures):
                    if time.monotonic() - start_time >= timeout:
                        # Timeout expired before receiving all weights
                        return None
                    # Small sleep to avoid busy-waiting
                    time.sleep(0.001)

        # Apply weights to model
        if not isinstance(model, torch.nn.Module):
            if not weights_buffer.is_empty():
                raise RuntimeError(
                    f"Cannot cast weights to model type: {type(model)} with weights: {weights_buffer}."
                )
            return None

        if strategy is not None:
            strategy.apply_weights(model, weights_buffer)
        else:
            weights_buffer.to_module(model)

        return weights_buffer

    # ========================================================================
    # Connection Setup
    # ========================================================================

    def setup_connection_and_weights_on_sender(self) -> None:
        """Initialize torch.distributed on sender side for this worker's rank.

        This is called by the scheme after it has created the connection info
        Ray actor. The actual ``init_process_group`` happens in the scheme since
        it's a collective operation that needs to happen for rank 0.

        Note:
            This method exists for interface compatibility but the real work
            happens in the scheme's :meth:`_setup_distributed_connection_sender`.
        """
        # The scheme handles the collective init_process_group for rank 0.

    def setup_connection_and_weights_on_receiver(
        self,
        *,
        worker_idx: int,
        strategy: WeightStrategy | None = None,
        model: Any | None = None,
        weights: Any | None = None,
    ) -> Any:
        """Join torch.distributed process group and receive initial weights.

        This method:
        1. Retrieves connection info from the shared Ray actor
        2. Initializes torch.distributed process group with rank=worker_idx+1
        3. Receives weights if model is stateful

        Args:
            worker_idx (int): The worker index for this transport.
            strategy (WeightStrategy, optional): The weight transmission strategy.
            model (nn.Module or compatible, optional): The model to receive weights for.
            weights (TensorDict, optional): Pre-allocated buffer for receiving weights.

        Returns:
            The received weights (TensorDict) if model is stateful, None otherwise.
        """
        if self._dist_initialized:
            # Already initialized, just receive weights if stateful.
            if self._stateful_model:
                # BUGFIX: receive_weights returns the weights TensorDict (or
                # None), not a tuple — the previous ``result[1] if result``
                # would have indexed the TensorDict with the key ``1``.
                return self.receive_weights(
                    weights=weights, model=model, strategy=strategy
                )
            return None

        self._worker_idx = worker_idx
        rank = self._rank

        # Wait for the connection-info actor to be published by the sender.
        while True:
            try:
                remote_connection_info = self.ray.get_actor(self._connection_info_name)
                break
            except ValueError:
                # Actor not registered yet — poll until the sender creates it.
                time.sleep(0.1)

        master_addr = self.ray.get(remote_connection_info.get.remote("master_addr"))
        master_port = self.ray.get(remote_connection_info.get.remote("master_port"))
        world_size = self.ray.get(remote_connection_info.get.remote("world_size"))
        stateful_model = self.ray.get(
            remote_connection_info.get.remote("stateful_model")
        )
        self._stateful_model = stateful_model

        # Set environment variables for torch.distributed
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)

        # Initialize process group on receiver
        torch.distributed.init_process_group(
            backend=self._backend,
            rank=rank,
            world_size=world_size,
        )
        self._dist_initialized = True

        # Receive initial weights if model is stateful
        if self._stateful_model:
            return self.receive_weights(model=model, weights=weights, strategy=strategy)
        return None
369
+
370
+
371
class RayWeightSyncScheme(WeightSyncScheme):
    """Weight synchronization for Ray distributed computing.

    This scheme uses torch.distributed to synchronize weights across distributed
    workers (Ray actors). The process group is initialized during the first
    ``synchronize_weights()`` call, with the sender as rank 0 and workers as
    rank ``worker_idx + 1``.

    Each remote collector gets its own transport, following the same pattern
    as multiprocess collectors.

    Args:
        strategy (str): The weight transmission strategy ("state_dict" or "tensordict").
            Defaults to "tensordict".
        backend (str): The torch.distributed backend to use ("gloo" or "nccl").
            Defaults to "gloo".
    """

    @property
    def connection_info_name(self) -> str:
        """Get the name of the Ray actor storing connection info.

        Returns a unique name based on model_id to avoid collisions when
        multiple schemes are used with different models.

        Returns:
            The connection info actor name.
        """
        if self._model_id is not None:
            return f"connection_info_{self._model_id}"
        # Fallback shared name when no model_id is set; may collide if two
        # id-less schemes coexist in the same Ray cluster.
        return "connection_info"

    def __init__(
        self,
        strategy: Literal["tensordict", "state_dict"] = "tensordict",
        backend: str = "gloo",
    ):
        """Initialize the RayWeightSyncScheme.

        Args:
            strategy (str): The weight transmission strategy ("state_dict" or "tensordict").
                Defaults to "tensordict".
            backend (str): The torch.distributed backend to use ("gloo" or "nccl").
                Defaults to "gloo".
        """
        super().__init__(strategy)
        self._backend = backend
        # Flipped to True once init_process_group has run on this side.
        self._dist_initialized = False
        # Populated by _init_on_sender_impl on the sender side.
        self._remote_collectors: list | None = None
        self._num_workers: int = 0

    @property
    def model(self) -> Any | None:
        """Get the model associated with this scheme.

        Returns:
            The model if set, None otherwise.
        """
        if self._model_ref is not None:
            return self._model_ref()
        if self._model_id is not None:
            model = _resolve_model(self.context, self._model_id)
            if model is None:
                if self._model_id == "policy":
                    # Special case: lazily build the policy from the
                    # context's policy factory and cache it on the context.
                    torchrl_logger.debug("Creating policy from factory.")
                    model = self.context.policy_factory[0]()
                    self.context.policy = model
                else:
                    raise AttributeError(
                        f"Model {self._model_id} was `None` in context {self.context}"
                    )
            # Cache a weak reference so the scheme does not keep the model alive.
            self._model_ref = weakref.ref(model)
            return model
        # Implicitly returns None when neither a model ref nor a model_id is set.

    @model.setter
    def model(self, value: Any):
        """Set the model for this scheme.

        Args:
            value: The model to set. If None, the setter is a no-op.
        """
        if value is None:
            return
        self._model_ref = weakref.ref(value)

    def create_transport(
        self,
        *,
        remote_actor=None,
        worker_idx: int | None = None,
        # Legacy parameter name for backwards compatibility
        remote_collector=None,
        **kwargs,
    ) -> TransportBackend:
        """Create Ray-based transport for a specific remote actor.

        Args:
            remote_actor: The Ray actor handle for the remote collector/transform.
            worker_idx: The worker index for this remote actor.
            remote_collector: Legacy alias for remote_actor.
            **kwargs: Additional transport configuration.

        Returns:
            RayTransport configured for this specific remote actor.
        """
        # Support legacy parameter name
        if remote_actor is None:
            remote_actor = remote_collector

        return RayTransport(
            remote_actor=remote_actor,
            worker_idx=worker_idx,
            backend=self._backend,
            connection_info_name=self.connection_info_name,
            model_id=self._model_id,
        )

    def _init_on_sender_impl(
        self,
        model_id: str,
        context: Any = None,
        **kwargs,
    ) -> None:
        """Initialize on the main process (sender side).

        This method sets up the torch.distributed connection info and shares it
        with all remote collectors so they can join the process group.

        Args:
            model_id: Identifier for the model being synchronized
            context: Optional context object providing remote_collectors
            **kwargs: Alternative to context (remote_collectors, model, etc.)
        """
        try:
            import ray

            self.ray = ray
        except ImportError:
            raise ImportError("Ray is required for RayWeightSyncScheme")

        # Extract parameters from context or kwargs
        if context is not None:
            remote_collectors = getattr(context, "remote_collectors", None)
            # Accept either attribute name for the worker count.
            num_workers = getattr(context, "num_workers", None) or getattr(
                context, "num_collectors", None
            )
        else:
            remote_collectors = kwargs.get("remote_collectors")
            num_workers = kwargs.get("num_workers") or kwargs.get("num_collectors")

        if remote_collectors is None:
            raise ValueError("remote_collectors must be provided via context or kwargs")
        if num_workers is None:
            num_workers = len(remote_collectors) if remote_collectors else 0

        # Store model_id and context on scheme
        self.model_id = model_id

        # Store remote collectors and num_workers for synchronize_weights
        self._remote_collectors = list(remote_collectors)
        self._num_workers = int(num_workers)

        # Register each Ray actor with explicit transport kwargs
        for worker_idx, remote_collector in enumerate(remote_collectors):
            transport = self.create_transport(
                remote_actor=remote_collector,
                worker_idx=worker_idx,
            )
            self._register_worker_sender(
                worker_idx=worker_idx,
                transport=transport,
            )

        # Set context with weak reference to avoid circular refs
        if context is not None:
            self.context = context

        # Store source model reference if provided for automatic weight extraction
        model = kwargs.get("model")
        if model is not None:
            self.model = model

        # Note: Distributed connection setup is deferred to synchronize_weights
        # because _receiver_schemes on workers won't exist until register_scheme_receiver is called

    def _init_on_receiver_impl(
        self,
        model_id: str,
        context: Any = None,
        **kwargs,
    ) -> None:
        """Initialize on worker process (receiver side).

        Args:
            model_id: Identifier for the model being synchronized
            context: Optional context object (typically the remote collector)
            **kwargs: Optional parameters (worker_idx, model, etc.)
        """
        try:
            import ray

            self.ray = ray
        except ImportError:
            raise ImportError("Ray is required for RayWeightSyncScheme")

        # Store model_id and context on scheme
        self.model_id = model_id
        self.context = context

        # Extract worker_idx from context or kwargs
        if context is not None:
            worker_idx = getattr(context, "worker_idx", None)
        else:
            worker_idx = kwargs.get("worker_idx")

        self._worker_idx = worker_idx

        # Resolve the target model on this worker
        model = kwargs.get("model")
        if model is not None:
            self.model = model
            # get the weights to possibly instantiate a copy of the model (policy factory with multi-collector)
            self.weights  # noqa

        # Create and register transport for receiver side
        # Note: create_transport returns TransportBackend but we know it's RayTransport
        transport = self.create_transport(
            remote_actor=None,  # Receiver doesn't need actor handle
            worker_idx=worker_idx,
        )
        if isinstance(transport, RayTransport):
            transport.set_model(model)
        self._register_transport_receiver(transport=transport)

    def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None:
        """Set up torch.distributed connection info and share with remote collectors.

        This method:
        1. Gets master address and finds an available port
        2. Stores connection info in Ray's object store as a named actor
        3. Initializes torch.distributed process group with rank=0

        Args:
            timeout: Maximum time in seconds to wait for workers to be ready.
                Default is 300 seconds (5 minutes).
                NOTE(review): this parameter is currently unused; the
                init_process_group call below uses _DIST_TIMEOUT instead.
        """
        if self._dist_initialized:
            return

        if self._remote_collectors is None or self._num_workers == 0:
            raise RuntimeError(
                "_setup_distributed_connection() requires remote_collectors to be set"
            )

        # Get master address (hostname/IP)
        hostname = socket.gethostname()
        try:
            master_addr = socket.gethostbyname(hostname)
        except socket.gaierror:
            # Hostname not resolvable; fall back to loopback.
            master_addr = "127.0.0.1"

        # Find an available port
        master_port = self._find_free_port()
        world_size = self._num_workers + 1  # +1 for the sender (rank 0)

        # Probe for weights: if extraction fails, treat the model as stateless
        # and skip the initial weight broadcast.
        try:
            self.weights
            stateful_model = True
        except (AttributeError, RuntimeError, ValueError):
            stateful_model = False
        self._stateful_model = stateful_model

        # Connection info to share with workers via named Ray actor
        RemoteConnectionInfo = self.ray.remote(num_cpus=0)(ConnectionInfo).options(
            name=self.connection_info_name
        )
        self._connection_info_actor = RemoteConnectionInfo.remote(
            master_addr=master_addr,
            master_port=master_port,
            world_size=world_size,
            stateful_model=stateful_model,
        )

        # Set environment variables for torch.distributed
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)

        # Initialize process group on sender (rank 0)
        # Note: Workers will call init_process_group in their transport's
        # setup_connection_and_weights_on_receiver. The init_process_group is
        # a collective operation, so all ranks must call it together.
        torch.distributed.init_process_group(
            backend=self._backend,
            rank=0,
            world_size=world_size,
            timeout=_DIST_TIMEOUT,
        )
        self._dist_initialized = True

    def _setup_connection_and_weights_on_sender_impl(
        self,
        *,
        worker_idx: int | None = None,
        weights: Any | None = None,
    ) -> None:
        """Set up distributed connection and send initial weights to all workers.

        This method:
        1. Sets up torch.distributed process group (waits for workers if needed)
        2. Sends initial weights to all workers via their transports

        The distributed setup is done here (not in ``init_on_sender``) because
        workers need to have ``register_scheme_receiver`` called first.

        Args:
            worker_idx (int, optional): Not used in this implementation.
            weights (optional): Not used in this implementation (weights are
                extracted from the model).
        """
        # Set up distributed connection (with wait for workers to be ready)
        if not self._dist_initialized:
            self._setup_distributed_connection_sender()

        # Send the initial weights
        if self._stateful_model:
            self._send_weights_distributed()

    def _send_weights_distributed(self) -> None:
        """Send weights to all workers via torch.distributed.

        Raises:
            RuntimeError: If no weights are available to send.
        """
        # Extract weights from model
        weights = self.weights
        if weights is None:
            raise RuntimeError("No weights available to send")

        # Send weights to each worker (ranks 1 to num_workers)
        # NOTE(review): assumes `weights` supports TensorDict-style
        # `isend(dst=..., return_early=True)` — confirm for state_dict strategy.
        futures = []
        for worker_idx in range(self._num_workers):
            rank = worker_idx + 1
            futures.extend(weights.isend(dst=rank, return_early=True))
        # Wait for all sends to complete
        for future in futures:
            future.wait()

    def _setup_connection_and_weights_on_receiver_impl(
        self, *, worker_idx: int | None = None
    ) -> None:
        """Join torch.distributed process group and receive initial weights.

        Delegates to the transport's :meth:`~RayTransport.setup_connection_and_weights_on_receiver`.

        Args:
            worker_idx (int, optional): The worker index. If None, uses the stored
                ``_worker_idx`` or defaults to 0.
        """
        if worker_idx is None:
            worker_idx = self._worker_idx
        if worker_idx is None:
            worker_idx = 0  # Default to worker 0

        transport = self.receiver_transport
        if transport is not None:
            # Transport handles joining process group and receiving weights.
            # NOTE(review): self.weights is used here as the receive buffer —
            # confirm it is constructible on the receiver for all strategies.
            transport.setup_connection_and_weights_on_receiver(
                worker_idx=worker_idx,
                model=self.model,
                weights=self.weights,
                strategy=self._strategy,
            )
            self._dist_initialized = True

    @staticmethod
    def _find_free_port() -> int:
        """Find a free port on the local machine.

        Returns:
            int: An available port number.

        Note:
            The port is released before use, so there is a small window in
            which another process could claim it.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            s.listen(1)
            port = s.getsockname()[1]
        return port
756
+ return port
757
+
758
+
759
class RayModuleTransformScheme(RayWeightSyncScheme):
    """Weight synchronization for RayModuleTransform.

    This scheme uses torch.distributed to synchronize weights between
    a trainer/collector and a RayModuleTransform actor. The sender is rank 0,
    the transform's actor is rank 1.

    This enables updating the weights of a module running inside a RayModuleTransform
    from a parent collector or training loop.

    Args:
        strategy (str): The weight transmission strategy ("state_dict" or "tensordict").
            Default is "tensordict".
        backend (str): The torch.distributed backend to use ("gloo" or "nccl").
            Default is "gloo".

    Example:
        >>> # Create scheme and transform
        >>> scheme = RayModuleTransformScheme()
        >>> transform = RayModuleTransform(module=my_module, weight_sync_scheme=scheme)
        >>>
        >>> # Create env with transform
        >>> env = TransformedEnv(base_env, transform)
        >>>
        >>> # Pass scheme to parent collector
        >>> collector = SomeCollector(
        ...     env, policy,
        ...     weight_sync_schemes={"transform_module": scheme}
        ... )
        >>>
        >>> # Update weights
        >>> collector.update_policy_weights_(model_id="transform_module")
    """

    def __init__(
        self,
        strategy: Literal["tensordict", "state_dict"] = "tensordict",
        backend: str = "gloo",
    ):
        """Initialize the RayModuleTransformScheme.

        Args:
            strategy (str): The weight transmission strategy ("state_dict" or "tensordict").
                Defaults to "tensordict".
            backend (str): The torch.distributed backend to use ("gloo" or "nccl").
                Defaults to "gloo".
        """
        super().__init__(strategy, backend)
        # Set later via _set_transform() by the owning RayModuleTransform.
        self._ray_transform = None

    def _set_transform(self, ray_transform) -> None:
        """Store reference to the RayModuleTransform.

        Called by RayModuleTransform when the scheme is passed to it.

        Args:
            ray_transform: The RayModuleTransform instance.
        """
        self._ray_transform = ray_transform

    def _init_on_sender_impl(
        self,
        model_id: str | None = None,
        context: Any = None,
        **kwargs,
    ) -> None:
        """Initialize on the main process (sender side).

        Uses the stored transform reference (set via _set_transform) to
        create transport for the transform's actor.

        Args:
            model_id: Identifier for the model being synchronized
            context: Optional context object (typically the collector)
            **kwargs: Optional parameters (ray_transform, model, etc.)
        """
        try:
            import ray

            self.ray = ray
        except ImportError:
            raise ImportError("Ray is required for RayModuleTransformScheme")

        # Get transform reference - either stored via _set_transform or from kwargs
        ray_transform = self._ray_transform
        if ray_transform is None:
            ray_transform = kwargs.get("ray_transform")
            if ray_transform is None:
                raise ValueError(
                    "ray_transform must be set via _set_transform() or provided in kwargs. "
                    "Pass the scheme to RayModuleTransform constructor to set it automatically."
                )

        # Store model_id
        self.model_id = model_id

        # Single worker (the transform's actor)
        self._num_workers = 1

        # Create transport for the transform's actor
        # The actor handle is ray_transform._actor
        transport = self.create_transport(
            remote_actor=ray_transform._actor,
            worker_idx=0,
        )
        self._register_worker_sender(
            worker_idx=0,
            transport=transport,
        )

        # Set context if provided
        if context is not None:
            self.context = context

        # Store source model reference if provided for automatic weight extraction
        model = kwargs.get("model")
        if model is not None:
            self.model = model

    def _init_on_receiver_impl(
        self,
        model_id: str,
        context: Any = None,
        **kwargs,
    ) -> None:
        """Initialize on the transform's actor (receiver side).

        Args:
            model_id: Identifier for the model being synchronized
            context: The ModuleTransform instance (the actor's underlying class)
            **kwargs: Optional parameters (worker_idx, model, etc.)
        """
        try:
            import ray

            self.ray = ray
        except ImportError:
            raise ImportError("Ray is required for RayModuleTransformScheme")

        # Store model_id and context
        self.model_id = model_id
        self.context = context

        # Single transform actor is always worker_idx=0
        self._worker_idx = kwargs.get("worker_idx", 0)

        # Resolve the target model from context (ModuleTransform has a .module attribute)
        model = kwargs.get("model")
        if model is None and context is not None:
            model = getattr(context, "module", None)
        if model is not None:
            self.model = model

        # Create and register transport for receiver side
        # Note: create_transport returns TransportBackend but we know it's RayTransport
        transport = self.create_transport(
            remote_actor=None,
            worker_idx=self._worker_idx,
        )
        if isinstance(transport, RayTransport):
            transport.set_model(model)
        self._register_transport_receiver(transport=transport)

    def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None:
        """Set up torch.distributed for the single transform actor.

        Overrides parent to work with a single RayModuleTransform instead of
        multiple remote collectors.

        Args:
            timeout (float): Maximum time in seconds to wait for connection setup.
                Defaults to 300.0 (5 minutes).
                NOTE(review): this parameter is currently unused; the
                init_process_group call below uses _DIST_TIMEOUT instead.

        Raises:
            RuntimeError: If ``ray_transform`` is not set.
        """
        if self._dist_initialized:
            return

        if self._ray_transform is None:
            raise RuntimeError(
                "_setup_distributed_connection() requires ray_transform to be set. "
                "Did you pass the scheme to RayModuleTransform?"
            )

        # Get master address (hostname/IP)
        hostname = socket.gethostname()
        try:
            master_addr = socket.gethostbyname(hostname)
        except socket.gaierror:
            # Hostname not resolvable; fall back to loopback.
            master_addr = "127.0.0.1"

        # Find an available port
        master_port = self._find_free_port()
        world_size = 2  # Sender (rank 0) + Transform (rank 1)

        # Check if model has weights
        try:
            w = self.weights
            stateful_model = w is not None
        except (AttributeError, RuntimeError, ValueError):
            stateful_model = False
        self._stateful_model = stateful_model

        # Connection info to share with the transform's actor
        RemoteConnectionInfo = self.ray.remote(num_cpus=0)(ConnectionInfo).options(
            name=self.connection_info_name
        )
        self._connection_info_actor = RemoteConnectionInfo.remote(
            master_addr=master_addr,
            master_port=master_port,
            world_size=world_size,
            stateful_model=stateful_model,
        )

        # Set environment variables for torch.distributed
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)

        # Now initialize process group on sender (rank 0)
        # The receiver is concurrently joining via the Ray call above
        torch.distributed.init_process_group(
            backend=self._backend,
            rank=0,
            world_size=world_size,
            timeout=_DIST_TIMEOUT,
        )
        self._dist_initialized = True

    def _setup_connection_and_weights_on_sender_impl(
        self,
        *,
        worker_idx: int | None = None,
        weights: Any | None = None,
    ) -> None:
        """Set up distributed connection and send initial weights.

        Args:
            worker_idx (int, optional): The worker index. Not used for
                RayModuleTransformScheme as there is only one transform actor.
            weights (optional): Pre-extracted weights to send. If None, weights
                are extracted from the model.
        """
        # Kick off receiver-side initialization asynchronously so that the
        # actor joins the (collective) process group while we set up rank 0.
        receiver_future = self._ray_transform._actor._init_weight_sync_scheme.remote(
            scheme=self, model_id=self.model_id
        )

        if not self._dist_initialized:
            self._setup_distributed_connection_sender()

        if self._stateful_model:
            self._send_weights_distributed(weights=weights)

        # Block until the receiver has finished joining and receiving.
        self.ray.get(receiver_future)

    def _send_weights_distributed(self, weights: Any | None = None) -> None:
        """Send weights to the transform actor via torch.distributed.

        Args:
            weights (optional): Pre-extracted weights to send. If None, weights
                are extracted from the model via :attr:`weights`.

        Raises:
            RuntimeError: If no weights are available to send.
        """
        if weights is None:
            weights = self.weights
        if weights is None:
            raise RuntimeError("No weights available to send")

        # Send weights to the transform (rank 1)
        # NOTE(review): assumes `weights` supports TensorDict-style
        # `isend(dst=..., return_early=True)` — confirm for state_dict strategy.
        futures = weights.isend(dst=1, return_early=True)
        for future in futures:
            future.wait()