torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,109 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from abc import ABC, abstractmethod
8
+ from typing import Any
9
+
10
+
11
class ServiceBase(ABC):
    """Base class for distributed service registries.

    A service registry manages distributed actors/services that can be accessed
    across multiple workers. Common use cases include:

    - Tokenizers shared across inference workers
    - Replay buffers for distributed training
    - Model registries for centralized model storage
    - Metrics aggregators

    The registry provides a dict-like interface for registering and accessing
    services by name:

    - ``registry[name]`` retrieves a service (see :meth:`get`).
    - ``registry[name] = factory`` registers one (see :meth:`register`).
    - ``name in registry`` checks membership.
    - ``for name in registry`` iterates over the registered names (see :meth:`list`).
    """

    @abstractmethod
    def register(self, name: str, service_factory: type, *args, **kwargs) -> Any:
        """Register a service factory and create the service actor.

        This method registers a service with the given name and immediately
        creates the corresponding actor. The service becomes globally visible
        to all workers in the cluster.

        Args:
            name: Unique identifier for the service. This name is used to
                retrieve the service later.
            service_factory: Class to instantiate as a remote actor.
            *args: Positional arguments to pass to the service constructor.
            **kwargs: Keyword arguments for both actor configuration and
                service constructor. Actor configuration options are backend-specific
                (e.g., num_cpus, num_gpus for Ray).

        Returns:
            The remote actor handle.

        Raises:
            ValueError: If a service with this name already exists.
        """

    @abstractmethod
    def get(self, name: str) -> Any:
        """Get a service by name.

        Retrieves a previously registered service. If the service was registered
        by another worker, this method will find it in the distributed registry.

        Args:
            name: Service identifier.

        Returns:
            The remote actor handle for the service.

        Raises:
            KeyError: If the service is not found.
        """

    @abstractmethod
    def __contains__(self, name: str) -> bool:
        """Check if a service is registered.

        Args:
            name: Service identifier.

        Returns:
            True if the service exists, False otherwise.
        """

    @abstractmethod
    def list(self) -> list[str]:
        """List all registered service names.

        Returns:
            List of service names currently registered in the cluster.
        """

    @abstractmethod
    def reset(self) -> None:
        """Reset the service registry.

        This removes all registered services and cleans up associated resources.
        After calling reset(), the registry will be empty and all service actors
        will be terminated.

        Warning:
            This is a destructive operation. All services will be terminated and
            any ongoing work will be interrupted.
        """

    def __getitem__(self, name: str) -> Any:
        """Dict-like access: services["tokenizer"]."""
        return self.get(name)

    def __setitem__(self, name: str, service_factory: type) -> None:
        """Dict-like registration: services["tokenizer"] = TokenizerClass.

        Note: This only supports service_factory without additional arguments.
        For full control, use register() method instead.
        """
        self.register(name, service_factory)

    def __iter__(self):
        """Iterate over the registered service names, as returned by :meth:`list`.

        Defining this is required for ``for name in registry`` / ``list(registry)``.
        Without it, Python falls back to the legacy sequence protocol. That
        protocol calls ``self[0]``, ``self[1]``, ... through ``__getitem__``,
        and :meth:`get` raises ``KeyError`` on the first index instead of
        iterating.

        Returns:
            An iterator over the service names (str).
        """
        return iter(self.list())
@@ -0,0 +1,453 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from typing import Any
8
+
9
+ from torchrl._utils import logger
10
+ from torchrl.services.base import ServiceBase
11
+
12
+ RAY_ERR = None
13
+ try:
14
+ import ray
15
+
16
+ _has_ray = True
17
+ except ImportError as err:
18
+ _has_ray = False
19
+ RAY_ERR = err
20
+
21
+
22
class _ServiceRegistryActor:
    """Bookkeeping object recording which service names exist in a namespace.

    ``RayService`` wraps this class with ``ray.remote`` and runs it as a named
    detached actor reserving one CPU. An explicit record means that listing
    returns only the services created through the registry, not every named
    actor Ray knows about.
    """

    def __init__(self):
        # A set gives O(1) membership tests and silently collapses duplicates.
        self._services: set[str] = set()

    def contains(self, name: str) -> bool:
        """Return whether ``name`` has been recorded."""
        found = name in self._services
        return found

    def add(self, name: str) -> None:
        """Record ``name``; recording an existing name is a no-op."""
        self._services.add(name)

    def remove(self, name: str) -> None:
        """Forget ``name``; unknown names are ignored."""
        self._services.discard(name)

    def clear(self) -> None:
        """Forget every recorded name."""
        self._services.clear()

    def list(self) -> list[str]:
        """Return the recorded names in ascending sorted order."""
        ordered = sorted(self._services)
        return ordered
52
+
53
+
54
+ class RayService(ServiceBase):
55
+ """Ray-based distributed service registry.
56
+
57
+ This class uses Ray's named actors feature to provide truly distributed
58
+ service discovery. When a service is registered by any worker, it becomes
59
+ immediately accessible to all other workers in the Ray cluster.
60
+
61
+ Services are registered as Ray actors with globally unique names. This
62
+ ensures that:
63
+ 1. Services persist independently of the registering worker
64
+ 2. All workers see the same services instantly
65
+ 3. No custom synchronization is needed
66
+
67
+ Args:
68
+ ray_init_config (dict, optional): Configuration for ray.init(). Only
69
+ used if Ray is not already initialized. Common options:
70
+ - address (str): Ray cluster address, or "auto" to auto-detect
71
+ - num_cpus (int): Number of CPUs to use
72
+ - num_gpus (int): Number of GPUs to use
73
+ namespace (str, optional): Ray namespace for service isolation. Services
74
+ in different namespaces are isolated from each other. Defaults to
75
+ "torchrl_services".
76
+
77
+ Examples:
78
+ >>> # Basic usage
79
+ >>> services = RayService()
80
+ >>> services.register("tokenizer", TokenizerClass, num_cpus=1)
81
+ >>> tokenizer = services["tokenizer"]
82
+ >>>
83
+ >>> # With Ray options for dynamic configuration
84
+ >>> actor = services.register(
85
+ ... "model",
86
+ ... ModelClass,
87
+ ... num_cpus=2,
88
+ ... num_gpus=1,
89
+ ... memory=10 * 1024**3,
90
+ ... max_concurrency=4
91
+ ... )
92
+ >>>
93
+ >>> # Check and retrieve
94
+ >>> if "tokenizer" in services:
95
+ ... tok = services["tokenizer"]
96
+ >>>
97
+ >>> # List all services
98
+ >>> print(services.list())
99
+ ['tokenizer', 'model']
100
+ """
101
+
102
def __init__(
    self,
    ray_init_config: dict | None = None,
    namespace: str = "torchrl_services",
):
    """Set up the service registry for ``namespace``.

    Args:
        ray_init_config: Options forwarded to ``ray.init`` when Ray is not yet
            running.
        namespace: Ray namespace that isolates this registry's services.

    Raises:
        ImportError: If the optional ``ray`` dependency failed to import.
    """
    # The module-level import of ray is optional; surface its failure here,
    # chained to the original import error.
    if not _has_ray:
        raise ImportError(
            "Ray is required for RayService. Install with: pip install ray"
        ) from RAY_ERR

    # The namespace must be stored first: both helpers below read it.
    self._namespace = namespace
    self._ensure_ray_initialized(ray_init_config)
    registry = self._get_or_create_registry_actor()
    self._registry_actor = registry
115
+
116
def _ensure_ray_initialized(self, ray_init_config: dict | None = None):
    """Initialize Ray if needed, or check the namespace of an already-running Ray.

    Args:
        ray_init_config: Keyword arguments for ``ray.init``. They are only used
            when Ray is not initialized yet. The dict is copied, never
            modified. When it has no ``"namespace"`` key, this service's
            namespace is filled in.
    """
    if ray.is_initialized():
        # Ray already initialized - check if namespace matches
        current_namespace = ray.get_runtime_context().namespace
        if current_namespace != self._namespace:
            logger.warning(
                f"Ray already initialized with namespace '{current_namespace}', "
                f"but RayService is using namespace '{self._namespace}'. "
                f"Services may not be visible across namespaces."
            )
        return

    # Work on a copy: callers may reuse their config dict and must not see a
    # "namespace" key appear in it as a side effect of constructing a service.
    config = dict(ray_init_config or {})
    config.setdefault("namespace", self._namespace)
    init_namespace = config["namespace"]
    if init_namespace != self._namespace:
        # Same situation as the already-initialized branch: the user asked for
        # a different Ray namespace than the one this registry uses.
        logger.warning(
            f"ray_init_config sets namespace '{init_namespace}', "
            f"but RayService is using namespace '{self._namespace}'. "
            f"Services may not be visible across namespaces."
        )
    # Log the namespace Ray is actually started with.
    logger.info(f"Initializing Ray with namespace '{init_namespace}'")
    ray.init(**config)
136
+
137
def _make_service_name(self, name: str) -> str:
    """Return the Ray actor name used for service ``name``.

    The name has the form ``"<namespace>::service::<name>"``.
    """
    prefix = f"{self._namespace}::service"
    return f"{prefix}::{name}"
140
+
141
def _get_registry_actor_name(self) -> str:
    """Return the Ray actor name of this namespace's registry actor.

    The name has the form ``"<namespace>::_registry"``.
    """
    suffix = "_registry"
    return f"{self._namespace}::{suffix}"
144
+
145
def _get_or_create_registry_actor(self):
    """Return this namespace's registry actor, creating it if it does not exist.

    The registry is a detached named actor, so every ``RayService`` using the
    same namespace shares it. Several workers may construct a ``RayService``
    at the same time. If another worker creates the registry between our
    lookup and our creation attempt, the existing actor is returned instead
    of propagating Ray's "name already taken" error.

    Returns:
        The Ray actor handle of the registry.
    """
    registry_name = self._get_registry_actor_name()

    try:
        # Try to get existing registry
        return ray.get_actor(registry_name, namespace=self._namespace)
    except ValueError:
        pass  # Not created yet; create it below.

    RemoteRegistry = ray.remote(_ServiceRegistryActor)
    try:
        registry = RemoteRegistry.options(
            name=registry_name,
            namespace=self._namespace,
            lifetime="detached",
            num_cpus=1,
        ).remote()
    except ValueError:
        # Creating a named actor raises ValueError when the name is already
        # taken: a concurrent worker won the race, so use its registry.
        return ray.get_actor(registry_name, namespace=self._namespace)

    logger.info(
        f"Created service registry actor for namespace '{self._namespace}'"
    )
    return registry
166
+
167
def register(self, name: str, service_factory: type, *args, **kwargs) -> Any:
    """Register a service and create a named Ray actor.

    This method wraps ``service_factory`` with ``ray.remote``. It then creates
    a detached actor named ``"<namespace>::service::<name>"`` in this
    service's namespace and records ``name`` in the namespace's registry
    actor. The actor becomes visible to all workers in the cluster.

    Args:
        name: Service identifier. Must be unique within the namespace.
        service_factory: Class to instantiate as a Ray actor.
        *args: Positional arguments for the service constructor.
        **kwargs: Both Ray actor options and service constructor arguments.
            This method, not Ray, splits them.

            - These keys are passed to ``.options()``: ``num_cpus``,
              ``num_gpus``, ``memory``, ``object_store_memory``,
              ``resources``, ``accelerator_type``, ``max_concurrency``,
              ``max_restarts``, ``max_task_retries``, ``max_pending_calls``
              and ``scheduling_strategy``.
            - Every other keyword is forwarded to the constructor.
            - A constructor argument whose name collides with one of those
              keys therefore cannot be passed through this method.

    Returns:
        The Ray actor handle.

    Raises:
        ValueError: If a service with this name already exists.

    Examples:
        >>> services = RayService()
        >>>
        >>> # Basic registration
        >>> tokenizer = services.register("tokenizer", TokenizerClass)
        >>>
        >>> # With Ray resource specification
        >>> buffer = services.register(
        ...     "buffer",
        ...     ReplayBuffer,
        ...     num_cpus=2,
        ...     num_gpus=0,
        ...     size=1000000
        ... )
        >>>
        >>> # With advanced Ray options
        >>> model = services.register(
        ...     "model",
        ...     ModelClass,
        ...     num_cpus=4,
        ...     num_gpus=1,
        ...     memory=20 * 1024**3,
        ...     max_concurrency=10,
        ...     max_restarts=3,
        ... )
    """
    full_name = self._make_service_name(name)

    # Check if service already exists in our registry
    if ray.get(self._registry_actor.contains.remote(name)):
        raise ValueError(
            f"Service '{name}' already exists in namespace '{self._namespace}'. "
            f"Use a different name or retrieve the existing service with get()."
        )

    # Wrap the plain class so it can be instantiated as a Ray actor.
    remote_cls = ray.remote(service_factory)

    # Named + detached: the actor can be looked up by name from any worker and
    # persists independently of the registering worker.
    options = {
        "name": full_name,
        "namespace": self._namespace,
        "lifetime": "detached",
    }

    # Keyword arguments routed to ``.options()`` rather than the constructor.
    ray_option_keys = (
        "num_cpus",
        "num_gpus",
        "memory",
        "object_store_memory",
        "resources",
        "accelerator_type",
        "max_concurrency",
        "max_restarts",
        "max_task_retries",
        "max_pending_calls",
        "scheduling_strategy",
    )
    for opt in ray_option_keys:
        if opt in kwargs:
            options[opt] = kwargs.pop(opt)

    # Apply options and create the actor
    remote_actor = remote_cls.options(**options).remote(*args, **kwargs)

    # Add to registry
    ray.get(self._registry_actor.add.remote(name))

    logger.info(
        f"Registered service '{name}' as Ray actor '{full_name}' "
        f"with options: {options}"
    )

    return remote_actor
264
+
265
def get(self, name: str) -> Any:
    """Look up the Ray actor handle registered under ``name``.

    The name is first checked against the registry actor; only if it is
    registered is the actor resolved with ``ray.get_actor`` in this
    service's namespace.

    Args:
        name: Service identifier, as passed to ``register``.

    Returns:
        The Ray actor handle returned by ``ray.get_actor``.

    Raises:
        KeyError: If ``name`` is not in the registry, or if it is in the
            registry but ``ray.get_actor`` cannot find the actor. In the
            second case the stale registry entry is removed before raising.

    Examples:
        >>> services = RayService()
        >>> tokenizer = services.get("tokenizer")
        >>> result = ray.get(tokenizer.encode.remote("Hello world"))
    """
    registered = ray.get(self._registry_actor.contains.remote(name))
    if not registered:
        raise KeyError(
            f"Service '{name}' not found in namespace '{self._namespace}'. "
            f"Available services: {self.list()}"
        )

    actor_name = self._make_service_name(name)
    try:
        return ray.get_actor(actor_name, namespace=self._namespace)
    except ValueError as err:
        # The registry knows the name but Ray has no such actor: purge the
        # stale entry so later lookups fail fast, then report the miss.
        logger.warning(
            f"Service '{name}' in registry but actor not found. "
            f"Removing from registry."
        )
        ray.get(self._registry_actor.remove.remote(name))
        raise KeyError(
            f"Service '{name}' actor not found (removed from registry). "
            f"Available services: {self.list()}"
        ) from err
309
+
310
def __contains__(self, name: str) -> bool:
    """Support ``name in services`` by asking the registry actor.

    Args:
        name: Service identifier.

    Returns:
        The registry actor's ``contains(name)`` result, fetched with
        ``ray.get``.

    Examples:
        >>> services = RayService()
        >>> if "tokenizer" in services:
        ...     tokenizer = services.get("tokenizer")
        ... else:
        ...     services.register("tokenizer", TokenizerClass)
    """
    lookup_ref = self._registry_actor.contains.remote(name)
    return ray.get(lookup_ref)
327
+
328
def list(self) -> list[str]:
    """Return the names of all services recorded in the registry actor.

    The names are the short identifiers passed to ``register`` (the registry
    stores ``name``, not the full actor name built by
    ``_make_service_name``). Ordering is whatever the registry actor returns.

    Returns:
        List of registered service names.

    Examples:
        >>> services = RayService()
        >>> services.register("tokenizer", TokenizerClass)
        >>> services.register("buffer", ReplayBuffer)
        >>> sorted(services.list())
        ['buffer', 'tokenizer']
    """
    names_ref = self._registry_actor.list.remote()
    return ray.get(names_ref)
345
+
346
def reset(self) -> None:
    """Kill every registered service actor and clear the registry.

    Steps:
        1. For each name returned by :meth:`list`, resolve its actor with
           ``ray.get_actor`` and kill it with ``ray.kill``. A missing actor
           or a failed kill is logged as a warning and the loop continues
           (best-effort teardown).
        2. Clear the registry actor's state, regardless of how many kills
           succeeded.

    Warning:
        Destructive: it affects every worker using this namespace, and any
        work running on the killed actors is interrupted.

    Examples:
        >>> services = RayService(namespace="experiment")
        >>> services.register("tokenizer", TokenizerClass)
        >>> print(services.list())
        ['tokenizer']
        >>> services.reset()
        >>> print(services.list())
        []
    """
    service_names = self.list()
    # Count only actors that were actually killed, so the summary log does
    # not overstate the result when some lookups/kills fail.
    terminated = 0

    for name in service_names:
        full_name = self._make_service_name(name)
        try:
            actor = ray.get_actor(full_name, namespace=self._namespace)
            ray.kill(actor)
        except ValueError:
            # ray.get_actor raises ValueError when no actor has this name
            # (already gone or never created).
            logger.warning(f"Service '{name}' not found during reset")
        except Exception as e:
            # Best-effort: one failing actor must not block the reset.
            logger.warning(f"Failed to terminate service '{name}': {e}")
        else:
            terminated += 1
            logger.info(f"Terminated service '{name}' (actor: {full_name})")

    # Clear the registry
    ray.get(self._registry_actor.clear.remote())

    logger.info(
        f"Reset complete for namespace '{self._namespace}'. "
        f"Terminated {terminated} of {len(service_names)} services."
    )
389
+
390
def shutdown(self, raise_on_error: bool = True) -> None:
    """Tear down this RayService: kill all service actors, then the registry actor.

    Calls :meth:`reset` (kills every registered service actor and clears the
    registry), then looks up the registry actor by name in this namespace
    and kills it with ``no_restart=True``. This does NOT call
    ``ray.shutdown()``; the Ray runtime/cluster itself keeps running.

    Args:
        raise_on_error: If ``True`` (default), any exception raised during
            teardown is re-raised unchanged. If ``False``, it is logged as a
            warning and swallowed.

    Note:
        The steps run in sequence inside one ``try``: if :meth:`reset`
        raises, the registry actor is not killed.
    """
    try:
        self.reset()
        # kill the registry actor
        registry_actor = ray.get_actor(
            self._get_registry_actor_name(), namespace=self._namespace
        )
        ray.kill(registry_actor, no_restart=True)
    except Exception as e:
        if raise_on_error:
            # Bare raise keeps the original traceback intact.
            raise
        logger.warning(f"Error shutting down RayService: {e}")
404
+
405
def register_with_options(
    self,
    name: str,
    service_factory: type,
    actor_options: dict[str, Any],
    **constructor_kwargs,
) -> Any:
    """Register a service, giving Ray actor options and constructor kwargs separately.

    Convenience wrapper around :meth:`register`. ``actor_options`` and
    ``constructor_kwargs`` are merged into one keyword set and forwarded to
    ``register()``, which splits them again by key name.

    Args:
        name: Service identifier.
        service_factory: Class to instantiate as a Ray actor.
        actor_options: Dictionary of Ray actor options (num_cpus, num_gpus, etc.).
        **constructor_kwargs: Arguments to pass to the service constructor.

    Returns:
        The Ray actor handle returned by :meth:`register`.

    Raises:
        ValueError: If a key appears both in ``actor_options`` and in
            ``constructor_kwargs``. Merging would otherwise silently keep
            only one of the two values. ``register()`` may also raise
            ``ValueError`` (e.g. when the service name is already taken).

    Note:
        Routing happens in ``register()`` by key name, not by which argument
        a key came from. A constructor kwarg named like one of the Ray
        options that ``register()`` recognizes (e.g. ``num_cpus``) is still
        consumed as a Ray option. A key in ``actor_options`` that
        ``register()`` does not recognize is passed to the constructor.

    Examples:
        >>> services = RayService()
        >>>
        >>> # Explicit separation of concerns
        >>> model = services.register_with_options(
        ...     "model",
        ...     ModelClass,
        ...     actor_options={
        ...         "num_cpus": 4,
        ...         "num_gpus": 1,
        ...         "memory": 20 * 1024**3,
        ...         "max_concurrency": 10
        ...     },
        ...     model_path="/path/to/checkpoint",
        ...     batch_size=32
        ... )
        >>>
        >>> # Equivalent to:
        >>> # services.register(
        >>> #     "model", ModelClass,
        >>> #     num_cpus=4, num_gpus=1, memory=20*1024**3, max_concurrency=10,
        >>> #     model_path="/path/to/checkpoint", batch_size=32
        >>> # )
    """
    # Reject ambiguous input instead of letting dict-merge precedence pick a
    # winner silently (the actor option would be dropped).
    overlap = actor_options.keys() & constructor_kwargs.keys()
    if overlap:
        raise ValueError(
            f"Keys {sorted(overlap)} were passed both in actor_options and as "
            f"constructor keyword arguments for service '{name}'; "
            f"each key must appear in only one of them."
        )
    # Merge actor_options into kwargs for register()
    merged_kwargs = {**actor_options, **constructor_kwargs}
    return self.register(name, service_factory, **merged_kwargs)
@@ -0,0 +1,107 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Testing utilities for TorchRL.
7
+
8
+ This module provides helper classes and utilities for testing TorchRL functionality,
9
+ particularly for distributed and Ray-based tests that require importable classes.
10
+ """
11
+
12
+ from torchrl.testing.assertions import (
13
+ check_rollout_consistency_multikey_env,
14
+ rand_reset,
15
+ rollout_consistency_assertion,
16
+ )
17
+ from torchrl.testing.env_creators import (
18
+ get_transform_out,
19
+ make_envs,
20
+ make_multithreaded_env,
21
+ )
22
+ from torchrl.testing.gym_helpers import (
23
+ BREAKOUT_VERSIONED,
24
+ CARTPOLE_VERSIONED,
25
+ CLIFFWALKING_VERSIONED,
26
+ HALFCHEETAH_VERSIONED,
27
+ PENDULUM_VERSIONED,
28
+ PONG_VERSIONED,
29
+ )
30
+ from torchrl.testing.llm_mocks import (
31
+ DummyStrDataLoader,
32
+ DummyTensorDataLoader,
33
+ MockTransformerConfig,
34
+ MockTransformerModel,
35
+ MockTransformerOutput,
36
+ )
37
+ from torchrl.testing.modules import (
38
+ BiasModule,
39
+ call_value_nets,
40
+ LSTMNet,
41
+ NonSerializableBiasModule,
42
+ )
43
+ from torchrl.testing.ray_helpers import (
44
+ WorkerTransformerDoubleBuffer,
45
+ WorkerTransformerNCCL,
46
+ WorkerVLLMDoubleBuffer,
47
+ WorkerVLLMNCCL,
48
+ )
49
+ from torchrl.testing.utils import (
50
+ capture_log_records,
51
+ dtype_fixture,
52
+ generate_seeds,
53
+ get_available_devices,
54
+ get_default_devices,
55
+ IS_WIN,
56
+ make_tc,
57
+ mp_ctx,
58
+ PYTHON_3_9,
59
+ retry,
60
+ set_global_var,
61
+ )
62
+
63
# Public API of ``torchrl.testing``: the helpers re-exported by the imports
# above, grouped by the submodule each one comes from. Keep this list in sync
# with those imports -- ``from torchrl.testing import *`` raises
# AttributeError if a listed name is not actually bound in this module.
__all__ = [
    # Assertions (torchrl.testing.assertions)
    "check_rollout_consistency_multikey_env",
    "rand_reset",
    "rollout_consistency_assertion",
    # Environment creators (torchrl.testing.env_creators)
    "get_transform_out",
    "make_envs",
    "make_multithreaded_env",
    # Gym helpers (torchrl.testing.gym_helpers)
    "BREAKOUT_VERSIONED",
    "CARTPOLE_VERSIONED",
    "CLIFFWALKING_VERSIONED",
    "HALFCHEETAH_VERSIONED",
    "PENDULUM_VERSIONED",
    "PONG_VERSIONED",
    # LLM mocks (torchrl.testing.llm_mocks)
    "DummyStrDataLoader",
    "DummyTensorDataLoader",
    "MockTransformerConfig",
    "MockTransformerModel",
    "MockTransformerOutput",
    # Modules (torchrl.testing.modules)
    "BiasModule",
    "call_value_nets",
    "LSTMNet",
    "NonSerializableBiasModule",
    # Ray helpers (torchrl.testing.ray_helpers)
    "WorkerTransformerDoubleBuffer",
    "WorkerTransformerNCCL",
    "WorkerVLLMDoubleBuffer",
    "WorkerVLLMNCCL",
    # Utils (torchrl.testing.utils)
    "capture_log_records",
    "dtype_fixture",
    "generate_seeds",
    "get_available_devices",
    "get_default_devices",
    "IS_WIN",
    "make_tc",
    "mp_ctx",
    "PYTHON_3_9",
    "retry",
    "set_global_var",
]