torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,554 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import importlib.util
8
+ import math
9
+ from collections.abc import Callable, Sequence
10
+ from copy import copy
11
+
12
+ import numpy as np
13
+ import torch
14
+ from tensordict import NonTensorData, TensorDictBase
15
+ from tensordict.utils import NestedKey
16
+ from torchrl._utils import _can_be_pickled
17
+ from torchrl.data.tensor_specs import NonTensor, TensorSpec, Unbounded
18
+ from torchrl.data.utils import CloudpickleWrapper
19
+ from torchrl.envs import EnvBase
20
+ from torchrl.envs.transforms import ObservationTransform, Transform
21
+ from torchrl.record.loggers import Logger
22
+
23
+ _has_tv = importlib.util.find_spec("torchvision", None) is not None
24
+
25
+
26
class VideoRecorder(ObservationTransform):
    """Video Recorder transform.

    Will record a series of observations from an environment and write them
    to a Logger object when needed.

    Args:
        logger (Logger): a Logger instance where the video
            should be written. To save the video under a memmap tensor or an mp4 file, use
            the :class:`~torchrl.record.loggers.CSVLogger` class.
        tag (str): the video tag in the logger.
        in_keys (Sequence of NestedKey, optional): keys to be read to produce the video.
            Default is :obj:`"pixels"`.
        skip (int): frame interval in the output video.
            Default is ``2`` if the transform has a parent environment, and ``1`` if not.
            When left unset, the default is resolved on first access and then cached.
        center_crop (int, optional): value of square center crop.
        make_grid (bool, optional): if ``True``, a grid is created assuming that a
            tensor of shape [B x W x H x 3] is provided, with B being the batch
            size. Default is ``True`` if the transform has a parent environment, and ``False``
            if not. When left unset, the default is resolved on first access and then cached.
        out_keys (sequence of NestedKey, optional): destination keys. Defaults
            to ``in_keys`` if not provided.
        fps (int, optional): Frames per second of the output video. Defaults to the logger predefined ``fps``,
            and overrides it if provided.
        **kwargs (Dict[str, Any], optional): additional keyword arguments for
            :meth:`~torchrl.record.loggers.Logger.log_video`.

    Examples:
        The following example shows how to save a rollout under a video. First a few imports:

        >>> from torchrl.record import VideoRecorder
        >>> from torchrl.record.loggers.csv import CSVLogger
        >>> from torchrl.envs import TransformedEnv, DMControlEnv

        The video format is chosen in the logger. Wandb and tensorboard will take care of that
        on their own, CSV accepts various video formats.

        >>> logger = CSVLogger(exp_name="cheetah", log_dir="cheetah_videos", video_format="mp4")

        Some envs (eg, Atari games) natively return images, some require the user to ask for them.
        Check :class:`~torchrl.envs.GymEnv` or :class:`~torchrl.envs.DMControlEnv` to see how to render images
        in these contexts.

        >>> base_env = DMControlEnv("cheetah", "run", from_pixels=True)
        >>> env = TransformedEnv(base_env, VideoRecorder(logger=logger, tag="run_video"))
        >>> env.rollout(100)

        All transforms have a dump function, mostly a no-op except for ``VideoRecorder``, and :class:`~torchrl.envs.transforms.Compose`
        which will dispatch the `dumps` to all its members.

        >>> env.transform.dump()

        Our video is available under ``./cheetah_videos/cheetah/videos/run_video_0.mp4``!

        The transform can also be used within a dataset to save the video collected. Unlike in the environment case,
        images will come in a batch. The ``skip`` argument will enable to save the images only at specific intervals.

        >>> from torchrl.data.datasets import OpenXExperienceReplay
        >>> from torchrl.envs import Compose
        >>> from torchrl.record import VideoRecorder, CSVLogger
        >>> # Create a logger that saves videos as mp4 using 24 frames per sec
        >>> logger = CSVLogger("./dump", video_format="mp4", video_fps=24)
        >>> # We use the VideoRecorder transform to register the images coming from the batch.
        >>> # Setting the fps to 12 overrides the one set in the logger, not doing so keeps it unchanged.
        >>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")], fps=12)
        >>> # Each batch of data will have 10 consecutive videos of 200 frames each (maximum, since strict_length=False)
        >>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=2000, slice_len=200,
        ...             download=True, strict_length=False,
        ...             transform=t)
        >>> # Get a batch of data and visualize it
        >>> for data in dataset:
        ...     t.dump()
        ...     break

    """

    def __init__(
        self,
        logger: Logger,
        tag: str,
        in_keys: Sequence[NestedKey] | None = None,
        skip: int | None = None,
        center_crop: int | None = None,
        make_grid: bool | None = None,
        out_keys: Sequence[NestedKey] | None = None,
        fps: int | None = None,
        **kwargs,
    ) -> None:
        if in_keys is None:
            in_keys = ["pixels"]
        if out_keys is None:
            out_keys = copy(in_keys)
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        # Extra kwargs are forwarded verbatim to ``logger.log_video``; an
        # explicit ``fps`` takes precedence over a ``fps`` passed in kwargs.
        video_kwargs = {}
        video_kwargs.update(kwargs)
        if fps is not None:
            video_kwargs["fps"] = fps
        self.video_kwargs = video_kwargs
        # Number of completed dumps; used as the default logging step.
        self.iter = 0
        # Goes through the ``skip`` property setter (stored in ``self._skip``).
        self.skip = skip
        self.logger = logger
        self.tag = tag
        # Number of observations seen since the last dump (drives ``skip``).
        self.count = 0
        self.center_crop = center_crop
        # Goes through the ``make_grid`` property setter (stored in ``self._make_grid``).
        self.make_grid = make_grid
        if center_crop and not _has_tv:
            raise ImportError(
                "Could not load center_crop from torchvision. Make sure torchvision is installed."
            )
        # Buffered frames (uint8, CPU), stacked and flushed by ``dump``.
        self.obs = []

    @property
    def make_grid(self):
        """Whether batched frames are tiled into a single grid image.

        If unset, resolves to ``True`` when a parent env is attached and
        ``False`` otherwise; the resolved value is cached.
        """
        make_grid = self._make_grid
        if make_grid is None:
            if self.parent is not None:
                self._make_grid = True
                return True
            self._make_grid = False
            return False
        return make_grid

    @make_grid.setter
    def make_grid(self, value):
        self._make_grid = value

    @property
    def skip(self):
        """Frame interval: one observation out of ``skip`` is recorded.

        If unset, resolves to ``2`` when a parent env is attached and ``1``
        otherwise; the resolved value is cached.
        """
        skip = self._skip
        if skip is None:
            if self.parent is not None:
                self._skip = 2
                return 2
            self._skip = 1
            return 1
        return skip

    @skip.setter
    def skip(self, value):
        self._skip = value

    def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
        """Buffers (a processed copy of) ``observation`` every ``skip`` calls.

        Frames are normalized to channel-first uint8 CPU tensors before being
        stored in ``self.obs``. The observation itself is returned unchanged.

        Raises:
            RuntimeError: if the observation is neither 2-D (grayscale) nor
                has 3 channels (in the first or last of its three trailing dims).
        """
        if isinstance(observation, NonTensorData):
            observation_trsf = torch.tensor(observation.data)
        else:
            observation_trsf = observation
        self.count += 1
        if self.count % self.skip == 0:
            # Remember the frame shape before any permutation for error reporting
            # (``observation.shape`` would be the batch shape for NonTensorData).
            input_shape = observation_trsf.shape
            if (
                observation_trsf.ndim >= 3
                and observation_trsf.shape[-3] == 3
                and observation_trsf.shape[-2] > 3
                and observation_trsf.shape[-1] > 3
            ):
                # permute the channels to the last dim
                observation_trsf = observation_trsf.permute(
                    *range(observation_trsf.ndim - 3), -2, -1, -3
                )
            if not (
                observation_trsf.shape[-1] == 3 or observation_trsf.ndimension() == 2
            ):
                raise RuntimeError(f"Invalid observation shape, got: {input_shape}")
            # Clone so the stored frame is decoupled from the env's buffer: the
            # later ``.to("cpu", torch.uint8)`` is a no-op (no copy) when the data
            # already has that device/dtype.
            observation_trsf = observation_trsf.clone()

            if observation_trsf.ndimension() == 2:
                # Grayscale frame: add a singleton channel dim.
                observation_trsf = observation_trsf.unsqueeze(-3)
            else:
                # Channels are last here (guaranteed by the check above): move them first.
                trailing_dim = range(observation_trsf.ndimension() - 3)
                observation_trsf = observation_trsf.permute(*trailing_dim, -1, -3, -2)
            if self.center_crop:
                if not _has_tv:
                    raise ImportError(
                        "Could not import torchvision, `center_crop` not available. "
                        "Make sure torchvision is installed in your environment."
                    )
                from torchvision.transforms.functional import (
                    center_crop as center_crop_fn,
                )

                observation_trsf = center_crop_fn(
                    observation_trsf, [self.center_crop, self.center_crop]
                )
            if self.make_grid and observation_trsf.ndimension() >= 4:
                if not _has_tv:
                    raise ImportError(
                        "Could not import torchvision, `make_grid` not available. "
                        "Make sure torchvision is installed in your environment."
                    )
                from torchvision.utils import make_grid as tv_make_grid

                # Collapse all leading batch dims and tile into a ~square grid.
                obs_flat = observation_trsf.flatten(0, -4)
                observation_trsf = tv_make_grid(
                    obs_flat, nrow=int(math.ceil(math.sqrt(obs_flat.shape[0])))
                )
                self.obs.append(observation_trsf.to("cpu", torch.uint8))
            elif observation_trsf.ndimension() >= 4:
                # Batched frames without a grid: store each frame separately.
                self.obs.extend(observation_trsf.to("cpu", torch.uint8).flatten(0, -4))
            else:
                self.obs.append(observation_trsf.to("cpu", torch.uint8))
        return observation

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        return self._call(tensordict)

    def dump(self, suffix: str | None = None, step: int | None = None) -> None:
        """Writes the video to the ``self.logger`` attribute.

        Calling ``dump`` when no image has been stored is a no-op (apart from
        resetting the frame counter).

        Args:
            suffix (str, optional): a suffix for the video to be recorded.
            step (int, optional): the step to log the video at. If not provided,
                uses an internal counter that increments with each dump call.
        """
        if self.obs:
            # Shape: [1, T, C, H, W] -- a single video made of T buffered frames.
            obs = torch.stack(self.obs, 0).unsqueeze(0).cpu()
        else:
            obs = None
        self.obs = []
        if obs is not None:
            if suffix is None:
                tag = self.tag
            else:
                tag = "_".join([self.tag, suffix])
            if self.logger is not None:
                self.logger.log_video(
                    name=tag,
                    video=obs,
                    step=step if step is not None else self.iter,
                    **self.video_kwargs,
                )
            self.iter += 1
        self.count = 0

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        # Record the initial observation returned by reset as well.
        self._call(tensordict_reset)
        return tensordict_reset
274
+
275
+
276
class TensorDictRecorder(Transform):
    """TensorDict recorder.

    When the ``dump`` method is called, this class will save a stack of the tensordicts resulting from
    :obj:`env.step(td)` to the file ``f"{out_file_base}_tensordict.t"`` (the base name gets
    ``f"_{suffix}"`` appended when a suffix is passed to :meth:`dump`).

    Args:
        out_file_base (str): a string defining the prefix of the file where the tensordict will be written.
        skip_reset (bool): if ``True``, the first TensorDict of the list will be discarded (usually the tensordict
            resulting from the call to :obj:`env.reset()`)
            default: True
        skip (int): frame interval for the saved tensordict.
            default: 4
        in_keys (Sequence of str, optional): keys to record. If empty or not provided,
            the whole tensordict is recorded.

    """

    def __init__(
        self,
        out_file_base: str,
        skip_reset: bool = True,
        skip: int = 4,
        in_keys: Sequence[str] | None = None,
    ) -> None:
        if in_keys is None:
            in_keys = []

        super().__init__(in_keys=in_keys)
        # Number of successful dumps.
        self.iter = 0
        self.out_file_base = out_file_base
        # Buffered tensordicts, stacked and saved by ``dump``.
        self.td = []
        self.skip_reset = skip_reset
        self.skip = skip
        # Number of tensordicts seen since the last dump (drives ``skip``).
        self.count = 0

    def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
        """Buffers ``next_tensordict`` (or a copy of its ``in_keys``) every ``skip`` calls.

        The input is returned unchanged.
        """
        self.count += 1
        if self.count % self.skip == 0:
            _td = next_tensordict
            if self.in_keys:
                _td = next_tensordict.select(*self.in_keys).to_tensordict()
            self.td.append(_td)
        return next_tensordict

    def dump(self, suffix: str | None = None) -> None:
        """Saves the buffered tensordicts with :func:`torch.save` and clears the buffer.

        Calling ``dump`` when nothing is left to save (after dropping the reset
        entry if ``skip_reset`` is ``True``) writes no file.

        Args:
            suffix (str, optional): a suffix appended (with ``"_"``) to ``out_file_base``.
        """
        # ``self.tag`` was never defined on this class; the file prefix is
        # ``out_file_base`` as documented.
        if suffix is None:
            tag = self.out_file_base
        else:
            tag = "_".join([self.out_file_base, suffix])

        td = self.td
        if self.skip_reset:
            td = td[1:]
        # ``torch.stack`` cannot stack an empty list.
        if td:
            torch.save(
                torch.stack(td, 0).contiguous(),
                f"{tag}_tensordict.t",
            )
            self.iter += 1
        self.count = 0
        self.td = []

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        # The reset output is buffered too; ``skip_reset`` lets ``dump`` drop it.
        self._call(tensordict_reset)
        return tensordict_reset
342
+
343
+
344
class PixelRenderTransform(Transform):
    """A transform to call render on the parent environment and register the pixel observation in the tensordict.

    This transform offers an alternative to the ``from_pixels`` syntactic sugar, e.g. when instantiating
    an environment with ``from_pixels`` is expensive, or when ``from_pixels`` is not implemented.
    It can be used within a single environment or over batched environments alike.

    Args:
        out_keys (List[NestedKey] or NestedKey, optional): the key where the pixel observations are registered.
            Exactly one key is accepted. Defaults to ``["pixels"]``.
        preproc (Callable, optional): a preproc function. Can be used to reshape the observation, or apply
            any other transformation that makes it possible to register it in the output data.
        as_non_tensor (bool, optional): if ``True``, the data will be written as a :class:`~tensordict.NonTensorData`
            thereby relaxing the shape requirements. If not provided, it will be inferred automatically from the
            input data type and shape.
        render_method (str, optional): the name of the render method. Defaults to ``"render"``.
        pass_tensordict (bool, optional): if ``True``, the input tensordict will be passed to the
            render method. This enables rendering for stateless environments. Defaults to ``False``.
        **kwargs: additional keyword arguments to pass to the render function (e.g. ``mode="rgb_array"``).

    Examples:
        >>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator
        >>> from torchrl.record.loggers import CSVLogger
        >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
        >>>
        >>> def make_env():
        ...     env = GymEnv("CartPole-v1", render_mode="rgb_array")
        ...     env = env.append_transform(PixelRenderTransform())
        ...     return env
        >>>
        >>> if __name__ == "__main__":
        ...     logger = CSVLogger("dummy", video_format="mp4")
        ...
        ...     env = ParallelEnv(4, EnvCreator(make_env))
        ...
        ...     env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
        ...     env.rollout(3)
        ...
        ...     check_env_specs(env)
        ...
        ...     r = env.rollout(30)
        ...     print(env)
        ...     env.transform.dump()
        ...     env.close()

    This transform can also be used whenever a batched environment ``render()`` returns a single image:

    Examples:
        >>> from torchrl.envs import check_env_specs
        >>> from torchrl.envs.libs.vmas import VmasEnv
        >>> from torchrl.record.loggers import CSVLogger
        >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
        >>>
        >>> env = VmasEnv(
        ...     scenario="flocking",
        ...     num_envs=32,
        ...     continuous_actions=True,
        ...     max_steps=200,
        ...     device="cpu",
        ...     seed=None,
        ...     # Scenario kwargs
        ...     n_agents=5,
        ... )
        >>>
        >>> logger = CSVLogger("dummy", video_format="mp4")
        >>>
        >>> env = env.append_transform(PixelRenderTransform(mode="rgb_array", preproc=lambda x: x.copy()))
        >>> env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
        >>>
        >>> check_env_specs(env)
        >>>
        >>> r = env.rollout(30)
        >>> env.transform[-1].dump()

    The transform can be disabled using the :meth:`~torchrl.record.PixelRenderTransform.switch` method, which will
    turn the rendering on if it's off or off if it's on (an argument can also be passed to control this behavior).
    Since transforms are :class:`~torch.nn.Module` instances, :meth:`~torch.nn.Module.apply` can be used to control
    this behavior:

        >>> def switch(module):
        ...     if isinstance(module, PixelRenderTransform):
        ...         module.switch()
        >>> env.apply(switch)

    """

    def __init__(
        self,
        out_keys: list[NestedKey] | NestedKey | None = None,
        preproc: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor]
        | None = None,
        as_non_tensor: bool | None = None,
        render_method: str = "render",
        pass_tensordict: bool = False,
        **kwargs,
    ) -> None:
        if out_keys is None:
            out_keys = ["pixels"]
        elif isinstance(out_keys, (str, tuple)):
            # A single (possibly nested) key was passed.
            out_keys = [out_keys]
        if len(out_keys) != 1:
            raise RuntimeError(
                f"Expected one and only one out_key, got out_keys={out_keys}"
            )
        if preproc is not None and not _can_be_pickled(preproc):
            # Wrap callables that plain pickle rejects (e.g. lambdas) so the transform stays serializable.
            preproc = CloudpickleWrapper(preproc)
        self.preproc = preproc
        self.as_non_tensor = as_non_tensor
        self.kwargs = kwargs
        self.render_method = render_method
        self._enabled = True
        self.pass_tensordict = pass_tensordict
        super().__init__(in_keys=[], out_keys=out_keys)

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        """Render after a reset too, so the first observation carries pixels."""
        return self._call(tensordict_reset)

    @staticmethod
    def _stack_frames(frames: list) -> np.ndarray | torch.Tensor:
        """Turn a list of rendered frames into one numpy array or torch tensor."""
        if frames and isinstance(frames[0], np.ndarray):
            return np.asarray(frames)
        return torch.as_tensor(frames)

    def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
        """Call the parent's render method and write its output under ``out_keys[0]``."""
        if not self._enabled:
            return next_tensordict

        method = getattr(self.parent, self.render_method)
        if not self.pass_tensordict:
            array = method(**self.kwargs)
        else:
            array = method(next_tensordict, **self.kwargs)

        if self.preproc is not None:
            array = self.preproc(array)
        # Lists are stacked whenever the output may be written as a tensor, on every
        # call (previously only on the first call, while as_non_tensor was inferred).
        if not self.as_non_tensor and isinstance(array, list):
            array = self._stack_frames(array)
        if self.as_non_tensor is None:
            # A single (H, W, 3)-like frame for a batched parent cannot be written as a
            # batched tensor entry, so it is stored as non-tensor data instead.
            if (
                array.ndim == 3
                and array.shape[-1] == 3
                and self.parent.batch_size != ()
            ):
                self.as_non_tensor = True
            else:
                self.as_non_tensor = False
        if not self.as_non_tensor:
            try:
                next_tensordict.set(self.out_keys[0], array)
            except Exception as err:
                raise RuntimeError(
                    f"An exception was raised while writing the rendered array "
                    f"(shape={getattr(array, 'shape', None)}, dtype={getattr(array, 'dtype', None)}) in the tensordict with shape {next_tensordict.shape}. "
                    f"Consider adapting your preproc function in {type(self).__name__}. You can also "
                    f"pass keyword arguments to the render function of the parent environment, or save "
                    f"this observation as a non-tensor data with as_non_tensor=True."
                ) from err
        else:
            next_tensordict.set_non_tensor(self.out_keys[0], array)
        return next_tensordict

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        """Add the pixel entry to the spec by resetting the parent env and rendering once."""
        was_disabled = not self.enabled
        if was_disabled:
            # ``_call`` is a no-op while disabled; turn it on temporarily.
            self.switch()
        try:
            parent = self.parent
            td_in = parent.reset()
            self._call(td_in)
            obs = td_in.get(self.out_keys[0])
            if isinstance(obs, NonTensorData):
                spec = NonTensor(device=obs.device, dtype=obs.dtype, shape=obs.shape)
            else:
                spec = Unbounded(device=obs.device, dtype=obs.dtype, shape=obs.shape)
            observation_spec[self.out_keys[0]] = spec
        finally:
            # Restore the disabled state even if rendering or spec building failed.
            if was_disabled:
                self.switch()
        return observation_spec

    def switch(self, mode: str | bool | None = None):
        """Sets the transform on or off.

        Args:
            mode (str or bool, optional): if provided, sets the switch to the desired mode.
                ``"on"``, ``"off"``, ``True`` and ``False`` are accepted values.
                By default, ``switch`` sets the mode to the opposite of the current one.

        Raises:
            ValueError: if ``mode`` is a non-boolean value other than ``"on"`` or ``"off"``.
        """
        if mode is None:
            mode = not self._enabled
        if not isinstance(mode, bool):
            if mode not in ("on", "off"):
                raise ValueError("mode must be either 'on' or 'off', or a boolean.")
            mode = mode == "on"
        self._enabled = mode

    @property
    def enabled(self) -> bool:
        """Whether the render transform is enabled."""
        return self._enabled

    def set_container(self, container: Transform | EnvBase) -> None:
        """Attach the transform and check that the parent env exposes the render method."""
        out = super().set_container(container)
        if isinstance(self.parent, EnvBase):
            # Fail early if the requested render method is missing or not callable.
            method = getattr(self.parent, self.render_method, None)
            if method is None or not callable(method):
                raise ValueError(
                    f"The render method must exist and be a callable. Got render={method}."
                )
        return out
@@ -0,0 +1,79 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Distributed service registry for TorchRL.
7
+
8
+ This module provides a service registry for managing distributed actors
9
+ (tokenizers, replay buffers, etc.) that can be accessed across workers.
10
+
11
+ Example:
12
+ >>> from torchrl.services import get_services
13
+ >>>
14
+ >>> # Worker 1: Register a tokenizer service
15
+ >>> services = get_services()
16
+ >>> services.register("tokenizer", TokenizerClass, num_cpus=1, num_gpus=0.1)
17
+ >>>
18
+ >>> # Worker 2: Access the same tokenizer
19
+ >>> services = get_services()
20
+ >>> tokenizer = services["tokenizer"]
21
+ >>> result = tokenizer.encode.remote(text)
22
+ """
23
+ from __future__ import annotations
24
+
25
+ from torchrl.services.base import ServiceBase
26
+ from torchrl.services.ray_service import RayService
27
+
28
+ __all__ = ["ServiceBase", "RayService", "get_services"]
29
+
30
+
31
def get_services(backend: str = "ray", **init_kwargs) -> ServiceBase:
    """Return a distributed service registry for the requested backend.

    The registry manages distributed actors shared across workers: a service
    registered by one worker is immediately visible to every other worker in
    the cluster.

    Args:
        backend: Name of the service backend. Only ``"ray"`` is currently
            supported.
        **init_kwargs: Backend-specific initialization arguments, forwarded
            as-is to the backend. For Ray:

            - ray_init_config (dict, optional): Arguments to pass to ray.init()
            - namespace (str, optional): Ray namespace for service isolation.
              Defaults to "torchrl_services".

    Returns:
        ServiceBase: A service registry instance.

    Raises:
        ValueError: If ``backend`` is anything other than ``"ray"``.
        ImportError: If the required backend library is not installed.

    Examples:
        >>> # Basic usage - register and access services
        >>> services = get_services()
        >>> services.register("tokenizer", TokenizerClass, num_cpus=1)
        >>> tokenizer = services["tokenizer"]
        >>>
        >>> # With custom Ray initialization
        >>> services = get_services(
        ...     backend="ray",
        ...     ray_init_config={"address": "auto"},
        ...     namespace="my_experiment"
        ... )
        >>>
        >>> # Check if service exists
        >>> if "tokenizer" in services:
        ...     tokenizer = services["tokenizer"]
        >>>
        >>> # List all registered services
        >>> service_names = services.list()
    """
    if backend == "ray":
        # Every keyword argument goes straight to the Ray-backed registry.
        return RayService(**init_kwargs)
    raise ValueError(
        f"Unsupported backend: {backend}. Currently only 'ray' is supported."
    )