torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,663 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from abc import ABC, abstractmethod
8
+ from collections.abc import Callable
9
+ from functools import wraps
10
+ from typing import overload, TypeVar
11
+
12
+ import torch
13
+ from tensordict import is_tensor_collection
14
+ from tensordict.base import TensorDictBase
15
+
16
+ from torchrl.data.tensor_specs import DEVICE_TYPING, TensorSpec
17
+ from torchrl.envs.common import EnvBase
18
+ from torchrl.envs.transforms.transforms import Transform
19
+
20
T = TypeVar("T")


@overload
def _maybe_to_device(r: tuple, device: DEVICE_TYPING) -> tuple:
    ...


@overload
def _maybe_to_device(r: list, device: DEVICE_TYPING) -> list:
    ...


@overload
def _maybe_to_device(r: dict, device: DEVICE_TYPING) -> dict:
    ...


@overload
def _maybe_to_device(r: TensorDictBase, device: DEVICE_TYPING) -> TensorDictBase:
    ...


@overload
def _maybe_to_device(r: T, device: DEVICE_TYPING) -> T:
    ...


def _maybe_to_device(r, device):
    """Recursively move ``r`` (and everything nested inside it) to ``device``.

    Args:
        r: a tuple, list or dict (traversed recursively), an object exposing a
            ``to`` method (e.g. a tensor or a tensordict), or any other value.
        device: target device, forwarded verbatim to ``obj.to(device)``.

    Returns:
        Tuples, lists and dicts are rebuilt as plain ``tuple``/``list``/``dict``
        holding the converted items; objects with a ``to`` attribute are replaced
        by ``obj.to(device)``; every other value is returned unchanged.
    """
    if isinstance(r, (tuple, list)):
        converted = [_maybe_to_device(item, device) for item in r]
        if isinstance(r, tuple):
            return tuple(converted)
        return converted
    if isinstance(r, dict):
        return {key: _maybe_to_device(value, device) for key, value in r.items()}
    if not hasattr(r, "to"):
        return r
    return r.to(device)
58
+
59
+
60
@overload
def _maybe_clear_device(r: tuple) -> tuple:
    ...


@overload
def _maybe_clear_device(r: list) -> list:
    ...


@overload
def _maybe_clear_device(r: dict) -> dict:
    ...


@overload
def _maybe_clear_device(r: TensorDictBase) -> TensorDictBase:
    ...


@overload
def _maybe_clear_device(r: T) -> T:
    ...


def _maybe_clear_device(r):
    """Recursively make ``r`` device-agnostic before it crosses the Ray boundary.

    - tuples, lists and dicts are rebuilt as plain ``tuple``/``list``/``dict``
      with every element/value processed recursively;
    - plain :class:`torch.Tensor` leaves are moved to CPU;
    - tensor collections (``is_tensor_collection``) and :class:`TensorSpec`
      instances are cloned, moved to CPU and have their device cleared with
      ``clear_device_()``;
    - any other value is returned unchanged.

    Args:
        r: the value to process.

    Returns:
        The processed value (a new object for containers, tensors on another
        device, tensor collections and specs).
    """
    if isinstance(r, tuple):
        return tuple(_maybe_clear_device(r_i) for r_i in r)
    if isinstance(r, list):
        return [_maybe_clear_device(r_i) for r_i in r]
    if isinstance(r, dict):
        return {k: _maybe_clear_device(v) for k, v in r.items()}
    if isinstance(r, torch.Tensor):
        # A plain tensor has no device to "clear"; moving it to CPU is the
        # equivalent, so that it can be deserialized on a Ray actor (or client)
        # that does not have the same accelerator.
        return r.cpu()
    if is_tensor_collection(r) or isinstance(r, TensorSpec):
        # Clone first: ``clear_device_`` (trailing underscore) acts in place and
        # must not touch the caller's object.
        r = r.clone()
        r = r.cpu().clear_device_()
    return r
96
+
97
+
98
def _map_input_output_device(func: Callable):
    """Decorate a method so that data crossing the Ray boundary is device-mapped.

    The decorated method:

    1. passes every positional and keyword argument through
       ``_maybe_clear_device`` before calling ``func`` (see that helper for
       exactly which values are converted);
    2. moves the result to ``self._device`` with ``_maybe_to_device`` when that
       attribute holds a non-``None`` value, and otherwise passes the result
       through ``_maybe_clear_device``. An instance without a ``_device``
       attribute is treated like ``_device=None``, so the output is always
       either moved to the local device or device-cleared.

    Args:
        func: the method to decorate; its first parameter must be ``self``.

    Returns:
        The decorated method, with ``func``'s metadata preserved via
        :func:`functools.wraps`.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        args = _maybe_clear_device(args)
        kwargs = _maybe_clear_device(kwargs)
        result = func(self, *args, **kwargs)
        # A missing ``_device`` attribute is handled like ``None`` instead of
        # returning the result without any device handling.
        device = getattr(self, "_device", None)
        if device is not None:
            return _maybe_to_device(result, device)
        return _maybe_clear_device(result)

    return wrapper
125
+
126
+
127
+ class RayTransform(Transform, ABC):
128
+ """Base class for transforms that delegate operations to Ray remote actors.
129
+
130
+ This class provides a framework for creating transforms that offload their operations
131
+ to Ray remote actors, enabling:
132
+ - Resource isolation and dedicated CPU/GPU allocation
133
+ - Shared state across multiple environment instances
134
+ - Distributed computation for expensive operations
135
+
136
+ The class automatically handles:
137
+ - Ray actor lifecycle management (creation, reuse, cleanup)
138
+ - Device mapping between local client and remote actor contexts
139
+ - Transparent method delegation with proper error handling
140
+ - Local management of parent/container relationships
141
+
142
+ Subclasses only need to implement `_create_actor()` to specify how their
143
+ specific Ray actor should be created and configured.
144
+
145
+ Args:
146
+ num_cpus: CPU cores to allocate to the Ray actor
147
+ num_gpus: GPU devices to allocate to the Ray actor
148
+ device: Local device for tensor operations (client-side)
149
+ actor_name: Optional name for actor reuse across instances
150
+ **kwargs: Additional arguments passed to Transform base class
151
+
152
+ Example:
153
+ ```python
154
+ class MyRayTransform(RayTransform):
155
+ def _create_actor(self, **kwargs):
156
+ RemoteClass = self._ray.remote(num_cpus=self._num_cpus)(MyClass)
157
+ return RemoteClass.remote(**kwargs)
158
+ ```
159
+ """
160
+
161
+ @property
162
+ def _ray(self):
163
+ ray = self.__dict__.get("_ray_val", None)
164
+ if ray is not None:
165
+ return ray
166
+ # Import ray here to avoid requiring it as a dependency
167
+ try:
168
+ import ray
169
+ except ImportError:
170
+ raise ImportError(
171
+ "Ray is required for RayTransform. Install with: pip install ray"
172
+ )
173
+ self.__dict__["_ray_val"] = ray
174
+ return ray
175
+
176
+ @_ray.setter
177
+ def _ray(self, value):
178
+ self.__dict__["_ray_val"] = value
179
+
180
+ def __getstate__(self):
181
+ state = super().__getstate__()
182
+ state.pop("_ray_val", None)
183
+ return state
184
+
185
+ def __init__(
186
+ self,
187
+ *,
188
+ num_cpus: int | None = None,
189
+ num_gpus: int | None = None,
190
+ device: DEVICE_TYPING | None = None,
191
+ actor_name: str | None = None,
192
+ **kwargs,
193
+ ):
194
+ """Initialize the RayTransform.
195
+
196
+ Args:
197
+ num_cpus: Number of CPUs to allocate to the Ray actor
198
+ num_gpus: Number of GPUs to allocate to the Ray actor
199
+ device: Local device for tensor operations
200
+ actor_name: Name of the Ray actor (for reuse)
201
+ **kwargs: Additional arguments passed to Transform
202
+ """
203
+ super().__init__(in_keys=kwargs.get("in_keys"), out_keys=kwargs.get("out_keys"))
204
+
205
+ self._num_cpus = num_cpus
206
+ self._num_gpus = num_gpus
207
+ self._device = device
208
+ self._actor_name = actor_name
209
+ self._actor = None
210
+
211
+ # Initialize the Ray actor
212
+ self._initialize_actor(**kwargs)
213
+
214
+ def _initialize_actor(self, **kwargs):
215
+ """Initialize the Ray actor, either by reusing existing or creating new."""
216
+ # First attempt to get the actor if it already exists
217
+ if self._actor_name is not None:
218
+ try:
219
+ existing_actor = self._ray.get_actor(self._actor_name)
220
+ self._actor = existing_actor
221
+ return
222
+ except ValueError:
223
+ pass
224
+
225
+ # Create new actor
226
+ self._actor = self._create_actor(**kwargs)
227
+
228
+ @abstractmethod
229
+ def _create_actor(self, **kwargs):
230
+ """Create and return a Ray actor.
231
+
232
+ This method should be implemented by subclasses to create the specific
233
+ Ray actor needed for their operations.
234
+
235
+ Args:
236
+ **kwargs: Additional arguments for actor creation
237
+
238
+ Returns:
239
+ The created Ray actor
240
+ """
241
+
242
+ # Container management - handled locally, not delegated to remote actor
243
+ def set_container(self, container: Transform | EnvBase) -> None:
244
+ """Set the container for this transform. This is handled locally."""
245
+ result = super().set_container(container)
246
+
247
+ # After setting the container locally, provide batch size information to the remote actor
248
+ # This ensures the remote actor has the right batch size for proper shape handling
249
+ if self.parent is not None:
250
+ parent_batch_size = self.parent.batch_size
251
+
252
+ # Set the batch size directly on the remote actor to override its initialization
253
+ self._ray.get(self._actor._set_attr.remote("batch_size", parent_batch_size))
254
+
255
+ # Also disable validation on the remote actor since we'll handle consistency locally
256
+ self._ray.get(self._actor._set_attr.remote("_validated", True))
257
+
258
+ return result
259
+
260
+ def reset_parent(self) -> None:
261
+ """Reset the parent. This is handled locally."""
262
+ return super().reset_parent()
263
+
264
+ def clone(self):
265
+ """Clone the transform."""
266
+ # Use the parent's clone method to properly copy all Transform attributes
267
+ new_instance = super().clone()
268
+ # Then copy our specific Ray attributes to share the same actor
269
+ new_instance._actor = self._actor
270
+ new_instance._ray = self._ray
271
+ new_instance._device = getattr(self, "_device", None)
272
+ new_instance._num_cpus = self._num_cpus
273
+ new_instance._num_gpus = self._num_gpus
274
+ new_instance._actor_name = self._actor_name
275
+ return new_instance
276
+
277
+ def empty_cache(self):
278
+ """Empty cache."""
279
+ super().empty_cache()
280
+ return self._ray.get(self._actor.empty_cache.remote())
281
+
282
+ @property
283
+ def container(self) -> EnvBase | None:
284
+ """Returns the env containing the transform. This is handled locally."""
285
+ return super().container
286
+
287
+ @property
288
+ def parent(self) -> EnvBase | None:
289
+ """Returns the parent env of the transform. This is handled locally."""
290
+ return super().parent
291
+
292
+ @property
293
+ def base_env(self):
294
+ """Returns the base environment. This traverses the parent chain locally."""
295
+ return (
296
+ getattr(self.parent, "base_env", None) if self.parent is not None else None
297
+ )
298
+
299
+ def __repr__(self):
300
+ """String representation."""
301
+ try:
302
+ if hasattr(self, "_actor") and self._actor is not None:
303
+ return self._ray.get(self._actor.__repr__.remote())
304
+ else:
305
+ return f"{self.__class__.__name__}(actor=None)"
306
+ except Exception:
307
+ return f"{self.__class__.__name__}(actor={getattr(self, '_actor', 'None')})"
308
+
309
+ # Properties - access via generic attribute getter since Ray doesn't support direct property access
310
+ @property
311
+ def device(self):
312
+ """Get device property."""
313
+ return getattr(self, "_device", None)
314
+
315
+ @device.setter
316
+ def device(self, value):
317
+ """Set device property."""
318
+ raise NotImplementedError(
319
+ f"device setter is not implemented for {self.__class__.__name__}. Use transform.to() instead."
320
+ )
321
+
322
+ # TensorDictPrimer methods
323
+ def init(self, tensordict: TensorDictBase | None):
324
+ """Initialize."""
325
+ return self._ray.get(self._actor.init.remote(tensordict))
326
+
327
+ @_map_input_output_device
328
+ def _reset_func(
329
+ self, tensordict: TensorDictBase | None, tensordict_reset: TensorDictBase | None
330
+ ) -> TensorDictBase | None:
331
+ """Reset function."""
332
+ result = self._ray.get(
333
+ self._actor._reset_func.remote(tensordict, tensordict_reset)
334
+ )
335
+ return result
336
+
337
+ @_map_input_output_device
338
+ def _reset(
339
+ self, tensordict: TensorDictBase | None, tensordict_reset: TensorDictBase | None
340
+ ) -> TensorDictBase | None:
341
+ """Reset method for TensorDictPrimer."""
342
+ return self._ray.get(self._actor._reset.remote(tensordict, tensordict_reset))
343
+
344
@_map_input_output_device
def _reset_env_preprocess(
    self, tensordict: TensorDictBase | None
) -> TensorDictBase | None:
    """Forward the pre-reset hook to the remote actor.

    The original comment marks this as crucial for ``call_before_env_reset=True``.
    """
    resolve = self._ray.get
    return resolve(self._actor._reset_env_preprocess.remote(tensordict))
350
+
351
def close(self):
    """Ask the remote actor to close the transform and return its result."""
    resolve = self._ray.get
    return resolve(self._actor.close.remote())
354
+
355
@_map_input_output_device
def _apply_transform(self, obs: torch.Tensor | None) -> torch.Tensor | None:
    """Apply the remote transform to ``obs`` and return the transformed value."""
    resolve = self._ray.get
    return resolve(self._actor._apply_transform.remote(obs))
359
+
360
@_map_input_output_device
def _call(self, next_tensordict: TensorDictBase | None) -> TensorDictBase | None:
    """Run the remote actor's ``_call`` on ``next_tensordict`` and return its output."""
    resolve = self._ray.get
    return resolve(self._actor._call.remote(next_tensordict))
364
+
365
@_map_input_output_device
def forward(self, tensordict: TensorDictBase | None) -> TensorDictBase | None:
    """Run the remote actor's ``forward`` on ``tensordict`` and return its output."""
    resolve = self._ray.get
    return resolve(self._actor.forward.remote(tensordict))
369
+
370
@_map_input_output_device
def _inv_apply_transform(
    self, state: TensorDictBase | None
) -> TensorDictBase | None:
    """Apply the remote inverse transform to ``state`` and return the result."""
    resolve = self._ray.get
    return resolve(self._actor._inv_apply_transform.remote(state))
376
+
377
@_map_input_output_device
def _inv_call(self, tensordict: TensorDictBase | None) -> TensorDictBase | None:
    """Run the remote actor's ``_inv_call`` on ``tensordict`` and return its output."""
    resolve = self._ray.get
    return resolve(self._actor._inv_call.remote(tensordict))
381
+
382
@_map_input_output_device
def inv(self, tensordict: TensorDictBase | None) -> TensorDictBase | None:
    """Run the remote actor's ``inv`` on ``tensordict`` and return its output."""
    resolve = self._ray.get
    return resolve(self._actor.inv.remote(tensordict))
386
+
387
@_map_input_output_device
def _step(
    self, tensordict: TensorDictBase | None, next_tensordict: TensorDictBase | None
) -> TensorDictBase | None:
    """Forward ``_step`` (current and next tensordicts) to the remote actor."""
    resolve = self._ray.get
    return resolve(self._actor._step.remote(tensordict, next_tensordict))
393
+
394
def transform_env_device(self, device):
    """Return the remote transform's mapping of the env ``device``."""
    resolve = self._ray.get
    return resolve(self._actor.transform_env_device.remote(device))
397
+
398
def transform_env_batch_size(self, batch_size):
    """Return the remote transform's mapping of the env ``batch_size``."""
    resolve = self._ray.get
    return resolve(self._actor.transform_env_batch_size.remote(batch_size))
401
+
402
@_map_input_output_device
def transform_output_spec(self, output_spec):
    """Let the remote transform rewrite ``output_spec`` and return the new spec."""
    resolve = self._ray.get
    return resolve(self._actor.transform_output_spec.remote(output_spec))
406
+
407
@_map_input_output_device
def transform_input_spec(self, input_spec):
    """Let the remote transform rewrite ``input_spec`` and return the new spec."""
    resolve = self._ray.get
    return resolve(self._actor.transform_input_spec.remote(input_spec))
411
+
412
@_map_input_output_device
def transform_observation_spec(self, observation_spec):
    """Let the remote transform rewrite ``observation_spec`` and return the new spec."""
    resolve = self._ray.get
    return resolve(self._actor.transform_observation_spec.remote(observation_spec))
418
+
419
@_map_input_output_device
def transform_reward_spec(self, reward_spec):
    """Let the remote transform rewrite ``reward_spec`` and return the new spec."""
    resolve = self._ray.get
    return resolve(self._actor.transform_reward_spec.remote(reward_spec))
423
+
424
@_map_input_output_device
def transform_done_spec(self, done_spec):
    """Let the remote transform rewrite ``done_spec`` and return the new spec."""
    resolve = self._ray.get
    return resolve(self._actor.transform_done_spec.remote(done_spec))
428
+
429
@_map_input_output_device
def transform_action_spec(self, action_spec):
    """Let the remote transform rewrite ``action_spec`` and return the new spec."""
    resolve = self._ray.get
    return resolve(self._actor.transform_action_spec.remote(action_spec))
433
+
434
@_map_input_output_device
def transform_state_spec(self, state_spec):
    """Let the remote transform rewrite ``state_spec`` and return the new spec."""
    resolve = self._ray.get
    return resolve(self._actor.transform_state_spec.remote(state_spec))
438
+
439
def dump(self, **kwargs):
    """Invoke ``dump`` on the remote actor, forwarding all keyword arguments."""
    resolve = self._ray.get
    return resolve(self._actor.dump.remote(**kwargs))
442
+
443
def set_missing_tolerance(self, mode=False):
    """Set ``missing_tolerance`` on the remote transform to ``mode``."""
    resolve = self._ray.get
    return resolve(self._actor.set_missing_tolerance.remote(mode))
446
+
447
@property
def missing_tolerance(self):
    """Read the remote transform's ``missing_tolerance`` value.

    ``missing_tolerance`` is read as an attribute of the remote transform through
    the actor's ``__getattribute__``. This is the same route the ``primers`` and
    key properties of this class use.

    Calling ``self._actor.missing_tolerance.remote()`` treats it as an actor
    *method* instead. Ray actor handles only expose methods, so that lookup fails
    whenever the remote transform defines ``missing_tolerance`` as a property.
    """
    # NOTE(review): assumes the remote transform exposes ``missing_tolerance`` as an
    # attribute/property (torchrl's Transform does) rather than a method — confirm.
    resolve = self._ray.get
    return resolve(self._actor.__getattribute__.remote("missing_tolerance"))
451
+
452
@property
def primers(self):
    """Fetch the ``primers`` attribute from the remote transform."""
    resolve = self._ray.get
    return resolve(self._actor.__getattribute__.remote("primers"))
456
+
457
@primers.setter
def primers(self, value):
    """Cache ``value`` locally as ``_primers`` and mirror it on the remote actor, if any."""
    self.__dict__.update(_primers=value)
    if not hasattr(self, "_actor"):
        return
    self._ray.get(self._actor._set_attr.remote("primers", value))
463
+
464
def to(self, *args, **kwargs):
    """Remember the requested device locally, then defer to the base ``to()``.

    Nothing is forwarded to the remote actor. When the arguments name a device,
    it is stored in ``self._device``, which the ``device`` property reads back.
    """
    # NOTE(review): torch._C._nn._parse_to is a private torch helper (the argument
    # parser behind nn.Module.to); its signature may change between releases.
    target_device, _dtype, _non_blocking, _memory_format = torch._C._nn._parse_to(
        *args, **kwargs
    )
    if target_device is None:
        return super().to(*args, **kwargs)
    self._device = target_device
    return super().to(*args, **kwargs)
474
+
475
+ # Properties that should be accessed from the remote actor
476
@property
def in_keys(self):
    """Fetch ``in_keys`` from the remote transform."""
    resolve = self._ray.get
    return resolve(self._actor.__getattribute__.remote("in_keys"))
480
+
481
@in_keys.setter
def in_keys(self, value):
    """Cache ``in_keys`` locally as ``_in_keys`` and mirror it remotely when an actor exists."""
    self.__dict__.update(_in_keys=value)
    if not hasattr(self, "_actor"):
        return
    self._ray.get(self._actor._set_attr.remote("in_keys", value))
487
+
488
@property
def out_keys(self):
    """Fetch ``out_keys`` from the remote transform."""
    resolve = self._ray.get
    return resolve(self._actor.__getattribute__.remote("out_keys"))
492
+
493
@out_keys.setter
def out_keys(self, value):
    """Cache ``out_keys`` locally as ``_out_keys`` and mirror it remotely when an actor exists."""
    self.__dict__.update(_out_keys=value)
    if not hasattr(self, "_actor"):
        return
    self._ray.get(self._actor._set_attr.remote("out_keys", value))
499
+
500
@property
def in_keys_inv(self):
    """Fetch ``in_keys_inv`` from the remote transform."""
    resolve = self._ray.get
    return resolve(self._actor.__getattribute__.remote("in_keys_inv"))
504
+
505
@in_keys_inv.setter
def in_keys_inv(self, value):
    """Cache ``in_keys_inv`` locally as ``_in_keys_inv`` and mirror it remotely when an actor exists."""
    self.__dict__.update(_in_keys_inv=value)
    if not hasattr(self, "_actor"):
        return
    self._ray.get(self._actor._set_attr.remote("in_keys_inv", value))
511
+
512
@property
def out_keys_inv(self):
    """Fetch ``out_keys_inv`` from the remote transform."""
    resolve = self._ray.get
    return resolve(self._actor.__getattribute__.remote("out_keys_inv"))
516
+
517
@out_keys_inv.setter
def out_keys_inv(self, value):
    """Cache ``out_keys_inv`` locally as ``_out_keys_inv`` and mirror it remotely when an actor exists."""
    self.__dict__.update(_out_keys_inv=value)
    if not hasattr(self, "_actor"):
        return
    self._ray.get(self._actor._set_attr.remote("out_keys_inv", value))
523
+
524
+ # Generic attribute access for any remaining attributes
525
def __getattr__(self, name):
    """Fallback lookup for names that normal attribute lookup did not find.

    Back-references (``parent``, ``container``, ``base_env`` and their private
    forms) are never forwarded and always raise ``AttributeError``. Any other
    name is resolved on the remote actor only when an actor is attached and the
    name is in a small allow-list; everything else raises ``AttributeError``.
    """

    def _no_such_attr():
        return AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )

    # Back-references that must stay with the local Transform machinery.
    if name in {"parent", "container", "base_env", "_parent", "_container"}:
        raise _no_such_attr()

    # Read the instance dict directly: ``self._actor`` would re-enter this hook.
    actor = self.__dict__.get("_actor", None)
    if actor is None:
        raise _no_such_attr()

    # Conservative allow-list of names that may be resolved remotely.
    forwarded_names = {
        "_call",
        "_reset",
        "_inv_call",
        "forward",
        "inv",
        "_apply_transform",
        "_inv_apply_transform",
        "_reset_func",
        "init",
        "primers",
        "dataloader",
    }
    if name not in forwarded_names:
        raise _no_such_attr()

    try:
        # NOTE(review): this probes the remote member by *calling* it with no
        # arguments, so a zero-argument remote method actually runs here and its
        # (non-callable) return value is handed back in place of the method.
        fetched = self._ray.get(getattr(actor, name).remote())
    except (AttributeError, TypeError):
        # The zero-argument probe failed: if the actor handle still exposes a
        # member under this name, return a proxy that calls it with real args.
        try:
            remote_method = getattr(actor, name)
        except AttributeError:
            pass
        else:
            return lambda *args, **kwargs: self._ray.get(
                remote_method.remote(*args, **kwargs)
            )
    else:
        if not callable(fetched):
            return fetched
        # A callable came back: re-issue the remote call with the caller's args.
        return lambda *args, **kwargs: self._ray.get(
            getattr(actor, name).remote(*args, **kwargs)
        )
    raise _no_such_attr()
592
+
593
def __setattr__(self, name, value):
    """Route an assignment either to this proxy or to the remote actor.

    Names in the local allow-list are always assigned on this object. That list
    covers the proxy bookkeeping, ``nn.Module`` internals, the key properties and
    their private caches. Any other name is sent to the remote actor's
    ``_set_attr`` when an actor is attached. It is assigned locally when no actor
    is attached or when the remote assignment raises.
    """
    keep_local = {
        "_actor",
        "_ray",
        "_parent",
        "_container",
        "_missing_tolerance",
        "_in_keys",
        "_out_keys",
        "_in_keys_inv",
        "_out_keys_inv",
        "in_keys",
        "out_keys",
        "in_keys_inv",
        "out_keys_inv",
        "_modules",
        "_parameters",
        "_buffers",
        "_device",
    }
    if name in keep_local:
        return super().__setattr__(name, value)

    try:
        remote_ready = hasattr(self, "_actor") and self._actor is not None
        if not remote_ready:
            super().__setattr__(name, value)
        else:
            self._ray.get(self._actor._set_attr.remote(name, value))
    except Exception:
        # Best-effort: anything that cannot be set remotely is kept locally.
        super().__setattr__(name, value)
628
+
629
+
630
class _RayServiceMetaClass(type):
    """Metaclass that lets a class hand construction off to a Ray-backed class.

    Calling a class built with this metaclass accepts an extra keyword-only
    ``use_ray_service`` flag. The metaclass consumes the flag, so it never reaches
    ``__init__``. When the flag is false the class is instantiated normally. When
    it is true the remaining arguments go to the class stored in
    ``_RayServiceClass``, and that object is returned instead.

    Usage:
        >>> class MyRayClass():
        ...     def __init__(self, **kwargs):
        ...         ...
        ...
        >>> class MyClass(metaclass=_RayServiceMetaClass):
        ...     _RayServiceClass = MyRayClass
        ...
        ...     def __init__(self, **kwargs):
        ...         # Regular implementation
        ...         pass
        ...
        >>> obj1 = MyClass(use_ray_service=False)  # a MyClass instance
        >>> obj2 = MyClass(use_ray_service=True)  # a MyRayClass instance
    """

    def __call__(cls, *args, use_ray_service=False, **kwargs):
        """Build an instance of ``cls`` or of its ``_RayServiceClass``.

        Raises:
            ValueError: if ``use_ray_service`` is truthy and ``cls`` has no
                ``_RayServiceClass`` attribute.
        """
        if not use_ray_service:
            return super().__call__(*args, **kwargs)
        if not hasattr(cls, "_RayServiceClass"):
            raise ValueError(
                f"Class {cls.__name__} does not have a _RayServiceClass attribute"
            )
        return cls._RayServiceClass(*args, **kwargs)