torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,891 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable
4
+ from typing import Any
5
+
6
+ import torch
7
+ import torch.distributed
8
+
9
+ from tensordict import TensorDict, TensorDictBase
10
+
11
+ from torch import multiprocessing as mp, nn
12
+
13
+ from torchrl._utils import logger as torchrl_logger, WEIGHT_SYNC_TIMEOUT
14
+
15
+ from torchrl.weight_update.utils import _resolve_model
16
+ from torchrl.weight_update.weight_sync_schemes import (
17
+ TransportBackend,
18
+ WeightStrategy,
19
+ WeightSyncScheme,
20
+ )
21
+
22
+
23
+ def _close_mp_queue(queue: mp.Queue) -> None:
24
+ """Close a multiprocessing Queue and wait for its feeder thread to exit."""
25
+ queue.close()
26
+ queue.join_thread()
27
+
28
+
29
class SharedMemTransport:
    """Shared memory transport for in-place weight updates.

    This transport uses queue-based buffer distribution for initialization, then
    updates shared memory tensors directly for subsequent weight updates.
    Workers automatically see weight updates without explicit communication.

    Initialization flow:
    - Shared memory buffers are created and sent to workers via per-worker queues
    - Workers receive the buffer reference and apply weights to their models
    - Subsequent updates are pure in-place shared memory (zero-copy)

    Both CPU and CUDA tensors maintain shared references when sent through mp.Queue.

    """

    def __init__(self):
        # worker_idx -> shared weights TensorDict. Workers that sit on the same
        # device may map to the very same object (see register_weights()).
        self._params_map = None  # a dict[worker_idx, TensorDictBase] map
        # worker_idx -> one-shot queue used to hand the shared buffer to that worker.
        self._weight_queues = (
            None  # Dict of per-worker queues for distributing shared weights
        )
        # Deduplicated list of shared buffers: one entry per distinct physical buffer.
        self._unique_weights = None

    @property
    def unique_weights(self) -> list[TensorDictBase]:
        """Get the unique weights.

        Returns:
            The unique weights.

        Raises:
            RuntimeError: If :meth:`register_weights` has not been called yet.
        """
        if self._unique_weights is None:
            raise RuntimeError("Unique weights not set. Call register_weights() first.")
        return self._unique_weights

    def register_weights(
        self, params_map: dict[int, TensorDictBase], init_queues: dict[int, mp.Queue]
    ) -> None:
        """Initialize per-worker queues for shared memory buffer distribution.

        Args:
            params_map: Mapping of worker index to that worker's shared weights.
                Workers on the same device may share the same TensorDict object.
            init_queues: Mapping of worker index to the queue used to ship the
                shared buffer reference to that worker.
        """
        from torchrl.collectors.utils import _cast

        self._weight_queues = init_queues
        self._params_map = params_map
        # Deduplicate by object identity BEFORE casting: _cast/apply produces a
        # new TensorDict, so comparing against the ids of already-appended
        # (casted) entries would never match a duplicate original. Track the
        # ids of the incoming objects instead so each physical buffer is
        # registered — and later updated — exactly once.
        self._unique_weights = []
        seen_ids: set[int] = set()
        for weights in params_map.values():
            if id(weights) in seen_ids:
                continue
            seen_ids.add(id(weights))
            weights = weights.data.apply(_cast, weights)
            self._unique_weights.append(weights)

    def setup_connection_and_weights_on_sender(self) -> None:
        """Send shared memory buffer reference to workers via their per-worker queues.

        Both CPU and CUDA tensors maintain shared references through queues.
        Each worker reads from its own dedicated queue, to avoid race conditions.

        Raises:
            RuntimeError: If the per-worker queues have not been created yet.
        """
        if self._weight_queues is None:
            raise RuntimeError("Queues not created yet. Call init_on_sender() first.")

        for worker_idx, queue in self._weight_queues.items():
            weights = self._params_map[worker_idx]
            queue.put(weights)

    def setup_connection_and_weights_on_receiver(
        self,
        *,
        worker_idx: int | None = None,
        weights: Any = None,
        model: Any = None,
        strategy: Any = None,
        timeout: float = WEIGHT_SYNC_TIMEOUT,
    ) -> TensorDictBase:
        """Receive shared memory buffer reference from sender via their per-worker queues.

        Each worker reads from its own dedicated queue, to avoid race conditions.

        Args:
            worker_idx: The worker index.
            weights: Ignored (weights come from queue).
            model: Ignored.
            strategy: Ignored.
            timeout: Timeout for reading from queue.

        Returns:
            The shared memory weights TensorDict.

        Raises:
            RuntimeError: If queues are missing or ``worker_idx`` is unknown.
        """
        if self._weight_queues is None:
            raise RuntimeError("Queues not created yet. Call init_on_sender() first.")

        if worker_idx not in self._weight_queues:
            raise RuntimeError(f"Worker {worker_idx} not registered in queues.")

        # Read from dedicated queue for this worker
        worker_queue = self._weight_queues[worker_idx]
        received_weights = worker_queue.get(timeout=timeout)
        return received_weights

    def send_weights(self, weights: Any) -> None:
        """Update weights in-place in shared memory.

        Args:
            weights: New weights to send. Can be a TensorDictBase or dict.

        Raises:
            ValueError: If weights type is unsupported.
            RuntimeError: If weights/buffers unexpectedly require gradients, or
                if :meth:`register_weights` has not been called yet.
        """
        # Update shared memory in-place (workers see this automatically)
        if isinstance(weights, dict):
            weights = TensorDict(weights)
        if not isinstance(weights, TensorDictBase):
            raise ValueError(f"Unsupported weights type: {type(weights)=}")
        # Unflatten if needed to match shared buffer structure
        weights_to_update = weights
        if any("." in key for key in weights.keys()):
            weights_to_update = weights.unflatten_keys(".")

        # Detach weights to allow in-place updates (gradients are not needed for weight sync)
        weights_to_update = weights_to_update.detach()

        if self._unique_weights is None:
            raise RuntimeError("Unique weights not set. Call register_weights() first.")
        # Loop-invariant sanity check hoisted out of the buffer loop.
        if weights_to_update.requires_grad:
            raise RuntimeError("Gradients should not be required for weights.")
        for buffer in self._unique_weights:
            if buffer.requires_grad:
                raise RuntimeError(
                    "Gradients should not be required for shared memory buffers."
                )
            buffer.update_(weights_to_update, non_blocking=True)
        # non_blocking copies must be fenced before the sender proceeds.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def receive_weights(
        self,
        timeout: float | None = None,
        *,
        weights: Any = None,
        model: Any = None,
        strategy: Any = None,
    ) -> Any | None:
        """Apply shared memory weights to the model.

        For shared memory, weights are already available (passed via the weights arg).
        This method applies them to the model, matching the pattern of other transports.

        Args:
            timeout: Ignored (shared memory access is instant).
            weights: The shared memory buffer containing current weights.
            model: The model to apply weights to.
            strategy: Strategy for applying weights.

        Returns:
            The applied weights, or None if not applied.
        """
        # Apply weights to model if provided (same pattern as other transports)
        if model is not None and strategy is not None and weights is not None:
            torchrl_logger.debug("Applying shared memory weights to model.")
            strategy.apply_weights(model, weights)
            return weights
        return None

    def send_ack(self, message: str = "updated") -> None:
        """No-op for shared memory - no acknowledgment needed."""
193
+
194
+
195
+ class SharedMemWeightSyncScheme(WeightSyncScheme):
196
+ """Weight synchronization using shared memory.
197
+
198
+ This scheme uses shared memory for in-place weight updates. Workers
199
+ automatically see weight updates without explicit message passing.
200
+
201
+ A background thread on the receiver side listens for "receive" instructions
202
+ from the sender. When an instruction arrives, the thread applies the current
203
+ shared memory weights to the model and sends an acknowledgment.
204
+
205
+ Args:
206
+ strategy: The weight transmission strategy (default: "tensordict").
207
+ sync: If True (default), send() blocks until receiver acknowledges.
208
+ If False, send() returns immediately (use send_async/wait_async).
209
+
210
+ Example:
211
+ >>> # Basic usage
212
+ >>> scheme = SharedMemWeightSyncScheme()
213
+ >>> # Weights are initialized via init_on_sender()
214
+ """
215
+
216
+ def __init__(
217
+ self,
218
+ strategy: str = "tensordict",
219
+ sync: bool = True,
220
+ ):
221
+ super().__init__(strategy)
222
+ self.sync = sync
223
+ # Create a single shared transport for all workers
224
+ self.shared_transport = SharedMemTransport()
225
+
226
+ # Create per-worker queues to avoid race conditions
227
+ # Each worker gets its own queue for weight initialization
228
+ self._weight_init_queues = {} # worker_idx -> Queue
229
+
230
+ # Instruction queues: sender puts "receive" instruction, receiver's background thread reads
231
+ self._instruction_queues: dict[int, mp.Queue] = {} # worker_idx -> Queue
232
+
233
+ # Acknowledgment queues: receiver puts "updated" ack, sender reads for sync mode
234
+ self._ack_queues: dict[int, mp.Queue] = {} # worker_idx -> Queue
235
+
236
+ # Receiver's instruction queue reference (set during init_on_receiver)
237
+ self._receiver_instruction_queue: mp.Queue | None = None
238
+ self._receiver_ack_queue: mp.Queue | None = None
239
+
240
+ def _init_on_sender_impl(
241
+ self,
242
+ *,
243
+ model_id: str | None = None,
244
+ context: Any = None,
245
+ weights: TensorDictBase | None = None,
246
+ model: nn.Module | None = None,
247
+ params_map: dict[int, TensorDictBase] | None = None,
248
+ devices: list[torch.device] | None = None,
249
+ device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None,
250
+ num_workers: int | None = None,
251
+ ctx: Any = None,
252
+ ) -> None:
253
+ """Initialize on the main process (sender side).
254
+
255
+ We create a map dict[worker_idx, weights_on_device]. Each model will be assigned a device. If two workers
256
+ share the same device, the entry in the dict will be the same.
257
+ To do this, we need to know the number of workers, their assigned device, and have access to the parameters.
258
+ If a context is provided, we read the devices from it. If not, the dict[worker_idx, device] map must be provided
259
+ explicitly.
260
+
261
+ In some cases, the policy on the worker side will be on multiple devices which may or may not be the same as the
262
+ devices on the main process. In this case, init_on_sender() needs to receive a mapping function as argument that
263
+ will take as input the worker_idx and the parameters and return a new set of parameters on the desired devices.
264
+
265
+ Args:
266
+ model_id: Identifier for the model being synchronized
267
+ context: Optional context object providing device_to_workers mapping and model access
268
+ weights: Pre-extracted weights as TensorDict (for policy factory usage)
269
+ model: Model to extract weights from
270
+ params_map: Direct mapping of worker_idx to weights on device (most explicit)
271
+ devices: List of devices for each worker
272
+ device_map_fn: Custom function to map worker_idx and weights to device-specific weights
273
+ num_workers: Number of workers (required with device_map_fn)
274
+ ctx: Multiprocessing context. Defaults to `mp.get_context()`.
275
+
276
+ Examples:
277
+ Simple usage with collector context (stateful policy):
278
+
279
+ >>> policy = make_stateful_policy()
280
+ >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict")
281
+ >>> collector = MultiSyncCollector(
282
+ ... create_env_fn=[lambda: GymEnv("CartPole-v1")],
283
+ ... policy=policy,
284
+ ... frames_per_batch=100,
285
+ ... total_frames=1000,
286
+ ... weight_sync_schemes={"policy": scheme},
287
+ ... )
288
+ >>> # scheme.init_on_sender() is called automatically by collector
289
+
290
+ Pre-initialized usage (policy factory):
291
+
292
+ >>> policy_on_main = make_stateful_policy()
293
+ >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict")
294
+ >>> # Must initialize before collector creation when using policy_factory
295
+ >>> scheme.init_on_sender(
296
+ ... model_id="policy",
297
+ ... weights=TensorDict.from_module(policy_on_main),
298
+ ... devices=[torch.device("cuda:0"), torch.device("cuda:1")],
299
+ ... num_workers=2,
300
+ ... )
301
+ >>> collector = MultiSyncCollector(
302
+ ... create_env_fn=[lambda: GymEnv("CartPole-v1")],
303
+ ... policy_factory=[make_stateful_policy],
304
+ ... frames_per_batch=100,
305
+ ... total_frames=1000,
306
+ ... weight_sync_schemes={"policy": scheme},
307
+ ... )
308
+
309
+ Direct params_map usage (advanced):
310
+
311
+ >>> weights_cpu = TensorDict.from_module(policy).share_memory_()
312
+ >>> weights_cuda = weights_cpu.to("cuda").share_memory_()
313
+ >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict")
314
+ >>> scheme.init_on_sender(
315
+ ... model_id="policy",
316
+ ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda},
317
+ ... )
318
+ """
319
+ # Plan: the goal of this init is to obtain a map dict[worker_idx, weights_on_device] that we can use to init
320
+ # the weights on the workers.
321
+ # Scenarios:
322
+ # - Easiest scenario: the user provides the map directly (params_map). Nothing to do other than creating
323
+ # the transport and registering the workers etc.
324
+ # - The user provides a model or its params and a device map. We need to create the map from the params
325
+ # explicitly.
326
+ # - The user provides a context (e.g. a Collector) and a model_id. Same as above, except that we need
327
+ # to collect the model from the context.
328
+ params_map = self._get_params_map(
329
+ context=context,
330
+ model_id=model_id,
331
+ weights=weights,
332
+ model=model,
333
+ params_map=params_map,
334
+ devices=devices,
335
+ device_map_fn=device_map_fn,
336
+ num_workers=num_workers,
337
+ )
338
+
339
+ # Create per-worker queues if not already created
340
+ # Collect all unique worker indices
341
+ all_workers = list(params_map.keys())
342
+
343
+ if ctx is None:
344
+ ctx = mp.get_context()
345
+ for worker_idx in all_workers:
346
+ if worker_idx not in self._weight_init_queues:
347
+ self._weight_init_queues[worker_idx] = ctx.Queue()
348
+ # Create instruction queues for background receiver
349
+ if worker_idx not in self._instruction_queues:
350
+ self._instruction_queues[worker_idx] = ctx.Queue()
351
+ # Create ack queues for synchronous mode
352
+ if worker_idx not in self._ack_queues:
353
+ self._ack_queues[worker_idx] = ctx.Queue()
354
+
355
+ # Set worker info in transport
356
+ self.shared_transport.register_weights(params_map, self._weight_init_queues)
357
+
358
+ # Store model_id and context on scheme
359
+ self.model_id = model_id
360
+ if context is not None:
361
+ self.context = context
362
+
363
+ def _get_params_map(
364
+ self,
365
+ context: Any = None,
366
+ model_id: str | None = None,
367
+ weights: TensorDictBase | None = None,
368
+ model: nn.Module | None = None,
369
+ params_map: dict[int, TensorDictBase] | None = None,
370
+ devices: list[torch.device] | None = None,
371
+ device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None,
372
+ num_workers: int | None = None,
373
+ ):
374
+ """Get the params_map for init_on_sender()."""
375
+ # Import _cast locally to avoid circular imports
376
+ from torchrl.collectors.utils import _cast
377
+
378
+ if params_map is not None:
379
+ # Sanity check: params_map must be a dict[int, TensorDictBase]
380
+ # All other args must be None
381
+ if (
382
+ not isinstance(params_map, dict)
383
+ or not all(isinstance(v, int) for v in params_map.keys())
384
+ or not all(isinstance(v, TensorDictBase) for v in params_map.values())
385
+ ):
386
+ raise ValueError("params_map must be a dict[int, TensorDictBase]")
387
+ if model_id is not None or weights is not None or model is not None:
388
+ raise ValueError(
389
+ "model_id, weights, and model cannot be provided if params_map is provided"
390
+ )
391
+ if context is not None:
392
+ raise ValueError("context cannot be provided if params_map is provided")
393
+ if devices is not None:
394
+ raise ValueError("devices cannot be provided if params_map is provided")
395
+ if device_map_fn is not None:
396
+ raise ValueError(
397
+ "device_map_fn cannot be provided if params_map is provided"
398
+ )
399
+ if num_workers is not None:
400
+ raise ValueError(
401
+ "num_workers cannot be provided if params_map is provided"
402
+ )
403
+ return params_map
404
+ elif context is not None:
405
+ if devices is not None:
406
+ raise ValueError("devices cannot be provided if context is provided")
407
+ # Sanity check: model_id must be provided if context is provided
408
+ # All other args must be None
409
+ if model_id is None:
410
+ raise ValueError("model_id must be provided if context is provided")
411
+ if model is not None:
412
+ raise ValueError("model cannot be provided if context is provided")
413
+ if weights is not None:
414
+ raise ValueError("weights cannot be provided if context is provided")
415
+ if device_map_fn is not None:
416
+ raise ValueError(
417
+ "device_map_fn cannot be provided if context is provided"
418
+ )
419
+ # Get device map: the devices are stored as policy_device in the collector -- other contexts will be customized later
420
+ devices = context.policy_device
421
+ if num_workers is not None and num_workers != len(devices):
422
+ raise ValueError(
423
+ "num_workers cannot be provided if context is provided"
424
+ )
425
+ # Get the weights
426
+ model = _resolve_model(context, model_id)
427
+ if model is None:
428
+ if model_id == "policy":
429
+ # we need to get a copy of the weights from the factory
430
+ model = context.policy_factory[0]()
431
+ weights = TensorDict.from_module(model)
432
+ elif model is not None:
433
+ if weights is not None:
434
+ raise ValueError("weights cannot be provided if model is provided")
435
+ weights = TensorDict.from_module(model)
436
+ if weights is not None:
437
+ weights = weights.data.apply(_cast, weights)
438
+ # To make the map, we need the list of devices, or the map fn
439
+ if devices is not None:
440
+ # Get the unique devices
441
+ devices_set = set(devices)
442
+ weights_devices = (
443
+ {p.device for p in weights.values(True, True)}
444
+ if weights is not None
445
+ else set()
446
+ )
447
+ if len(weights_devices) == 1:
448
+ weights_device = weights_devices.pop()
449
+ else:
450
+ weights_device = None
451
+
452
+ # Create device map with proper Parameter handling using _cast
453
+ # _cast ensures Parameters stay as Parameters (with requires_grad=False)
454
+ device_map = {}
455
+ for d in devices_set:
456
+ if d != weights_device:
457
+ # Move to device and apply _cast to preserve Parameter/Buffer types
458
+ weights_on_device = weights.to(d)
459
+ weights_on_device = weights_on_device.apply(_cast, weights)
460
+ device_map[d] = weights_on_device
461
+ else:
462
+ # Already on correct device, just apply _cast
463
+ device_map[d] = weights.apply(_cast, weights)
464
+
465
+ # Create the map
466
+ params_map = {
467
+ worker_idx: device_map[device]
468
+ for worker_idx, device in enumerate(devices)
469
+ }
470
+ return params_map
471
+ if device_map_fn is not None:
472
+ return {
473
+ worker_idx: device_map_fn(worker_idx, weights)
474
+ for worker_idx in range(num_workers)
475
+ }
476
+ raise ValueError(
477
+ "Either params_map, model_id + context or model/weights + devices must be provided."
478
+ )
479
+
480
+ def _init_on_receiver_impl(
481
+ self,
482
+ *,
483
+ model_id: str | None = None,
484
+ context: Any = None,
485
+ model: Any = None,
486
+ worker_idx: int | None = None,
487
+ **kwargs,
488
+ ) -> None:
489
+ """Initialize on worker process (receiver side).
490
+
491
+ Reads from the worker's dedicated queue to receive shared weights,
492
+ then registers them in the transport. The receiver then applies these weights
493
+ to the model.
494
+
495
+ Args:
496
+ model_id: Identifier for the model being synchronized
497
+ context: Optional context object providing model and worker_idx
498
+ model: Model being synchronized
499
+ worker_idx: Worker index
500
+ **kwargs: Alternative to context (model, worker_idx, timeout, etc.)
501
+ """
502
+ # Extract parameters from context or kwargs
503
+ if context is not None:
504
+ if model_id is None:
505
+ raise ValueError("model_id is required when context is provided")
506
+ if hasattr(context, "get_model"):
507
+ model = context.get_model(model_id)
508
+ elif model is None:
509
+ model = _resolve_model(context, model_id)
510
+ worker_idx = getattr(context, "worker_idx", worker_idx)
511
+
512
+ # Store on scheme directly
513
+ self.model_id = model_id
514
+ if context is not None:
515
+ self.context = context
516
+
517
+ # Register the model
518
+ if model is not None:
519
+ self.model = model
520
+
521
+ # Store worker_idx for synchronize_weights
522
+ self.worker_idx = worker_idx
523
+
524
+ # Store references to instruction and ack queues for this worker
525
+ # These are created by init_on_sender and passed via pickle
526
+ if worker_idx is not None:
527
+ if worker_idx in self._instruction_queues:
528
+ self._receiver_instruction_queue = self._instruction_queues[worker_idx]
529
+ if worker_idx in self._ack_queues:
530
+ self._receiver_ack_queue = self._ack_queues[worker_idx]
531
+
532
+ self.create_transport()
533
+
534
+ def _wait_for_instruction(self, timeout: float | None = None) -> str | None:
535
+ """Block until an instruction arrives from the sender.
536
+
537
+ Args:
538
+ timeout: Maximum time to wait for instruction (seconds).
539
+ None means block indefinitely.
540
+
541
+ Returns:
542
+ The instruction string (e.g., "receive", "stop"), or None if
543
+ stop event is set or timeout expires.
544
+ """
545
+ if self._receiver_instruction_queue is None:
546
+ raise RuntimeError(
547
+ "Instruction queue not set. init_on_receiver() must be called first."
548
+ )
549
+
550
+ try:
551
+ # Check stop event periodically while waiting
552
+ while True:
553
+ if self._stop_event is not None and self._stop_event.is_set():
554
+ return None
555
+ try:
556
+ # Use short timeout to allow checking stop event
557
+ instruction = self._receiver_instruction_queue.get(timeout=0.1)
558
+ return instruction
559
+ except Exception:
560
+ # Queue.Empty - continue waiting
561
+ if timeout is not None:
562
+ timeout -= 0.1
563
+ if timeout <= 0:
564
+ return None
565
+ except Exception as e:
566
+ torchrl_logger.warning(f"Error waiting for instruction: {e}")
567
+ return None
568
+
569
+ def _send_instruction(
570
+ self,
571
+ instruction: str = "receive",
572
+ worker_ids: int | list[int] | None = None,
573
+ ) -> None:
574
+ """Send instruction to receiver(s) to trigger weight reception.
575
+
576
+ Args:
577
+ instruction: The instruction to send (default: "receive").
578
+ worker_ids: Which workers to send to (None = all workers).
579
+ """
580
+ if not self._instruction_queues:
581
+ raise RuntimeError(
582
+ "Instruction queues not created. init_on_sender() must be called first."
583
+ )
584
+
585
+ if worker_ids is None:
586
+ target_workers = list(self._instruction_queues.keys())
587
+ elif isinstance(worker_ids, int):
588
+ target_workers = [worker_ids]
589
+ else:
590
+ target_workers = list(worker_ids)
591
+
592
+ for worker_idx in target_workers:
593
+ if worker_idx not in self._instruction_queues:
594
+ raise ValueError(f"Worker {worker_idx} not registered")
595
+ self._instruction_queues[worker_idx].put(instruction)
596
+
597
+ def _send_ack(self, message: str = "updated") -> None:
598
+ """Send acknowledgment back to sender after receiving weights.
599
+
600
+ Args:
601
+ message: The acknowledgment message (default: "updated").
602
+ """
603
+ if self._receiver_ack_queue is not None:
604
+ self._receiver_ack_queue.put(message)
605
+
606
+ def _wait_for_ack(
607
+ self,
608
+ worker_ids: int | list[int] | None = None,
609
+ timeout: float | None = None,
610
+ ) -> None:
611
+ """Wait for acknowledgment from receiver(s).
612
+
613
+ Args:
614
+ worker_ids: Which workers to wait for (None = all workers).
615
+ timeout: Maximum time to wait (seconds). None means block indefinitely.
616
+ """
617
+ if not self._ack_queues:
618
+ return # No ack queues, nothing to wait for
619
+
620
+ if worker_ids is None:
621
+ target_workers = list(self._ack_queues.keys())
622
+ elif isinstance(worker_ids, int):
623
+ target_workers = [worker_ids]
624
+ else:
625
+ target_workers = list(worker_ids)
626
+
627
+ for worker_idx in target_workers:
628
+ if worker_idx not in self._ack_queues:
629
+ raise ValueError(f"Worker {worker_idx} not registered")
630
+ try:
631
+ ack = self._ack_queues[worker_idx].get(timeout=timeout)
632
+ if ack != "updated":
633
+ torchrl_logger.warning(
634
+ f"Unexpected ack from worker {worker_idx}: {ack}"
635
+ )
636
+ except Exception as e:
637
+ torchrl_logger.warning(
638
+ f"Timeout waiting for ack from worker {worker_idx}: {e}"
639
+ )
640
+
641
    def create_transport(self, **kwargs) -> TransportBackend:
        """Create shared memory transport.

        Returns the shared transport instance that all workers will use.
        Since this is shared memory, there's only one transport shared by all workers.

        Args:
            **kwargs: Ignored; accepted for signature compatibility with other
                scheme implementations.

        Returns:
            TransportBackend: the single shared transport instance.

        Note:
            This is used internally by init_on_sender/init_on_receiver.
        """
        return self.shared_transport
651
+
652
+ def prepare_weights(
653
+ self,
654
+ weights: Any,
655
+ model_id: str,
656
+ strategy: WeightStrategy,
657
+ context: Any = None,
658
+ ) -> Any:
659
+ """Prepare weights for SharedMemWeightSyncScheme.
660
+
661
+ When weights=None, we extract fresh weights from the model and update
662
+ the shared memory buffer in-place so workers see the change.
663
+
664
+ Args:
665
+ weights: Raw weights input
666
+ model_id: The model identifier
667
+ strategy: WeightStrategy for extracting/converting weights
668
+ context: Optional context (e.g., collector) for cache lookup
669
+
670
+ Returns:
671
+ Shared memory weights ready to send
672
+ """
673
+ # If weights are explicitly provided, use them directly
674
+ if weights is not None:
675
+ fresh_weights = super().prepare_weights(
676
+ weights, model_id, strategy, context
677
+ )
678
+ else:
679
+ # Extract fresh weights from the model (base class handles this)
680
+ fresh_weights = super().prepare_weights(None, model_id, strategy, context)
681
+
682
+ if fresh_weights is None:
683
+ return None
684
+
685
+ # Update the shared memory buffer in-place so workers see the change
686
+ if self._shared_transport is not None and self.shared_transport.unique_weights:
687
+ shared_weights = self.shared_transport.unique_weights[0]
688
+ # In-place update of shared memory buffer with fresh weights
689
+ shared_weights.data.update_(fresh_weights.data)
690
+ return shared_weights
691
+
692
+ # If no shared transport, just return the fresh weights
693
+ return fresh_weights
694
+
695
    def send(
        self,
        weights: Any = None,
        worker_ids: int | list[int] | None = None,
    ) -> None:
        """Send weights via shared memory (in-place update).

        For SharedMemWeightSyncScheme:
        1. prepare_weights() updates the shared memory buffer in-place
        2. _send_instruction() tells workers to apply the new weights
        3. If sync=True, waits for acknowledgments from all workers

        Args:
            weights: Weights to send (can be None to extract from model).
            worker_ids: Which workers to notify (None = all workers).

        Raises:
            RuntimeError: if the scheme has not been initialized/synchronized
                on the sender side yet.
        """
        if not self.initialized_on_sender:
            raise RuntimeError("Must be initialized on sender before sending weights")
        if not self.synchronized_on_sender:
            raise RuntimeError("Must be synchronized on sender before sending weights")

        # prepare_weights updates the shared buffer in-place; the returned
        # value is not needed here because workers already hold the buffer.
        self.prepare_weights(
            weights=weights,
            model_id=self._model_id,
            strategy=self._strategy,
            context=self.context,
        )

        # Send instruction to workers' background threads to apply the weights
        self._send_instruction(instruction="receive", worker_ids=worker_ids)

        # Wait for acknowledgments if in synchronous mode
        if self.sync:
            self._wait_for_ack(worker_ids=worker_ids)
730
+
731
+ @property
732
+ def weights(self) -> Any | None:
733
+ """Get the current weights from shared memory.
734
+
735
+ For SharedMemWeightSyncScheme:
736
+ - On sender side: weights are in transport's _unique_weights
737
+ - On receiver side: weights are in _receiver_shared_weights (stored during connect())
738
+
739
+ Returns:
740
+ The weights TensorDict if available, None otherwise.
741
+ """
742
+ # On receiver side, use the stored shared buffer reference
743
+ if (
744
+ hasattr(self, "_receiver_shared_weights")
745
+ and self._receiver_shared_weights is not None
746
+ ):
747
+ return self._receiver_shared_weights
748
+
749
+ # On sender side, get from the shared transport
750
+ if self._shared_transport is not None and self.shared_transport.unique_weights:
751
+ return self.shared_transport.unique_weights[0]
752
+
753
+ # Fall back to parent implementation
754
+ return super().weights
755
+
756
    def _setup_connection_and_weights_on_receiver_impl(
        self, *, worker_idx: int | None = None
    ) -> None:
        """Synchronize weights on receiver side for shared memory.

        Reads the shared memory buffer from the queue and applies it to the model.
        Then starts a background thread that listens for "receive" instructions
        from the sender and applies weights when instructed.

        If a receiver_transport is set (e.g., for MultiProcessWeightSyncScheme),
        defers to the base class implementation.

        Args:
            worker_idx: Index of this worker; falls back to ``self.worker_idx``
                when omitted.

        Raises:
            RuntimeError: if no shared transport is configured, or if no worker
                index can be determined.
        """
        # If receiver_transport is set (e.g., MultiProcess subclass), use base behavior
        if self._receiver_transport is not None:
            return super()._setup_connection_and_weights_on_receiver_impl(
                worker_idx=worker_idx
            )

        # SharedMem-specific: use shared_transport
        if self._shared_transport is None:
            raise RuntimeError(
                "SharedMemWeightSyncScheme requires shared_transport to be set."
            )

        # Use stored worker_idx (set by init_on_receiver) if not provided
        if worker_idx is None:
            worker_idx = self.worker_idx

        if worker_idx is None:
            raise RuntimeError(
                "worker_idx must be provided for _setup_connection_and_weights_on_receiver_impl."
            )

        # Read shared memory buffer from queue
        weights = self._shared_transport.setup_connection_and_weights_on_receiver(
            worker_idx=worker_idx
        )

        # Store the shared buffer reference for later receive() calls
        # This is the actual shared memory buffer that the sender updates
        self._receiver_shared_weights = weights

        # Apply weights to model.
        # NOTE(review): inplace=False on this first application — presumably so
        # the module binds directly to the shared storage; confirm against
        # WeightStrategy.apply_weights semantics.
        if weights is not None and self.model is not None:
            self._strategy.apply_weights(self.model, weights, inplace=False)

        # Start background receiver thread that listens for instructions
        self._start_background_receiver()
804
+
805
    def _background_receive_loop(self):
        """Background thread loop that waits for instructions and applies weights.

        This loop:
        1. Waits for a "receive" instruction from the sender
        2. Applies the current shared memory weights to the model
        3. Sends an acknowledgment back to the sender
        4. Repeats until stop event is set or "stop" instruction received
        """
        while not self._stop_event.is_set():
            try:
                instruction = self._wait_for_instruction()
                if instruction is None:
                    # Stop event was set or timeout; loop condition re-checks
                    # the stop event.
                    continue
                if instruction == "receive":
                    # Apply the current shared memory weights to the model.
                    # The weights are already updated in shared memory by the
                    # sender, so this is a local (in-place) apply only.
                    if (
                        self._receiver_shared_weights is not None
                        and self.model is not None
                    ):
                        self._strategy.apply_weights(
                            self.model, self._receiver_shared_weights, inplace=True
                        )

                    # Cascade weight update to sub-collectors if context supports it
                    model_id = self._model_id or "policy"
                    if self.context is not None and hasattr(
                        self.context, "update_policy_weights_"
                    ):
                        self.context.update_policy_weights_(
                            model_id=model_id,
                            policy_or_weights=self._receiver_shared_weights,
                        )

                    # Send acknowledgment so a synchronous sender can unblock
                    self._send_ack("updated")
                elif instruction == "stop":
                    break
                else:
                    torchrl_logger.warning(
                        f"SharedMemWeightSyncScheme: Unknown instruction: {instruction}"
                    )
            except Exception as e:
                # Best-effort: log and keep the loop alive unless we are
                # shutting down (errors during shutdown are expected noise).
                if not self._stop_event.is_set():
                    torchrl_logger.warning(
                        f"SharedMemWeightSyncScheme: Background receiver error: {e}"
                    )
854
+
855
+ def __getstate__(self):
856
+ """Prepare the scheme for pickling."""
857
+ state = super().__getstate__()
858
+ # mp.Queue objects can be pickled and shared across processes
859
+ # Keep them in state so workers have access
860
+ return state
861
+
862
+ def shutdown(self) -> None:
863
+ """Stop the background receiver thread and clean up."""
864
+ # Check if already shutdown
865
+ if getattr(self, "_is_shutdown", False):
866
+ return
867
+ self._is_shutdown = True
868
+
869
+ # Signal all workers to stop
870
+ instruction_queues = getattr(self, "_instruction_queues", None)
871
+ if instruction_queues:
872
+ for _, queue in instruction_queues.items():
873
+ queue.put("stop")
874
+
875
+ # Let base class handle background thread cleanup
876
+ super().shutdown()
877
+
878
+ # Close all multiprocessing queues created by the scheme.
879
+ queues_to_close = []
880
+ for name in ("_weight_init_queues", "_instruction_queues", "_ack_queues"):
881
+ mapping = getattr(self, name, None)
882
+ if not mapping:
883
+ continue
884
+ queues_to_close.extend(mapping.values())
885
+ setattr(self, name, {})
886
+
887
+ unique = {}
888
+ for q in queues_to_close:
889
+ unique[id(q)] = q
890
+ for q in unique.values():
891
+ _close_mp_queue(q)