torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/services/base.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ServiceBase(ABC):
    """Abstract interface for a registry of named, distributed services.

    Concrete registries keep track of long-lived actors/services that several
    workers need to reach by name. Typical examples are:

    - tokenizers shared by a pool of inference workers
    - replay buffers used in distributed training
    - model registries acting as a central model store
    - metric aggregators

    Besides the abstract methods below, the class provides ``registry[name]``
    lookups and ``registry[name] = ServiceClass`` registrations, so a registry
    can be used much like a dictionary keyed by service name.
    """

    @abstractmethod
    def register(self, name: str, service_factory: type, *args, **kwargs) -> Any:
        """Create a service from ``service_factory`` and record it under ``name``.

        Implementations are expected to instantiate the service actor right
        away and make it reachable from every worker of the cluster.

        Args:
            name: Unique key under which the service can later be retrieved.
            service_factory: The class to be instantiated as a remote actor.
            *args: Positional arguments forwarded to the service constructor.
            **kwargs: Keyword arguments; backend-specific actor options (for
                instance ``num_cpus`` / ``num_gpus`` with Ray) may be mixed
                with arguments destined for the service constructor.

        Returns:
            A handle on the newly created remote actor.

        Raises:
            ValueError: When ``name`` is already taken.
        """

    @abstractmethod
    def get(self, name: str) -> Any:
        """Return the handle of the service registered under ``name``.

        The service may have been registered by any worker; implementations
        are expected to look it up in the shared, distributed registry.

        Args:
            name: Key the service was registered under.

        Returns:
            The remote actor handle of the service.

        Raises:
            KeyError: When no service is registered under ``name``.
        """

    @abstractmethod
    def __contains__(self, name: str) -> bool:
        """Tell whether a service is registered under ``name`` (``name in registry``).

        Args:
            name: Key to look for.

        Returns:
            ``True`` when the service exists, ``False`` otherwise.
        """

    @abstractmethod
    def list(self) -> list[str]:
        """Enumerate the registered services.

        Returns:
            The names of all services currently registered in the cluster.
        """

    @abstractmethod
    def reset(self) -> None:
        """Empty the registry.

        Every registered service is removed and its resources released; its
        actor is terminated, leaving an empty registry behind.

        Warning:
            Destructive: all services are terminated and any work they are
            doing is interrupted.
        """

    def __getitem__(self, name: str) -> Any:
        """Look a service up with ``registry["tokenizer"]``; delegates to :meth:`get`."""
        return self.get(name)

    def __setitem__(self, name: str, service_factory: type) -> None:
        """Register a service with ``registry["tokenizer"] = TokenizerClass``.

        Only a bare factory can be given this way: no constructor arguments
        and no actor options. Use :meth:`register` when those are needed.
        The handle returned by :meth:`register` is discarded.
        """
        self.register(name, service_factory)
|
|
@@ -0,0 +1,453 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from torchrl._utils import logger
|
|
10
|
+
from torchrl.services.base import ServiceBase
|
|
11
|
+
|
|
12
|
+
RAY_ERR = None
|
|
13
|
+
try:
|
|
14
|
+
import ray
|
|
15
|
+
|
|
16
|
+
_has_ray = True
|
|
17
|
+
except ImportError as err:
|
|
18
|
+
_has_ray = False
|
|
19
|
+
RAY_ERR = err
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class _ServiceRegistryActor:
|
|
23
|
+
"""Internal actor that maintains the list of registered services.
|
|
24
|
+
|
|
25
|
+
This is a lightweight actor (1 CPU) that tracks which services have been
|
|
26
|
+
registered in a namespace. This ensures we only list our own services,
|
|
27
|
+
not other named actors in Ray.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
def __init__(self):
|
|
31
|
+
self._services: set[str] = set()
|
|
32
|
+
|
|
33
|
+
def add(self, name: str) -> None:
|
|
34
|
+
"""Add a service to the registry."""
|
|
35
|
+
self._services.add(name)
|
|
36
|
+
|
|
37
|
+
def remove(self, name: str) -> None:
|
|
38
|
+
"""Remove a service from the registry."""
|
|
39
|
+
self._services.discard(name)
|
|
40
|
+
|
|
41
|
+
def list(self) -> list[str]:
|
|
42
|
+
"""List all registered services."""
|
|
43
|
+
return sorted(self._services)
|
|
44
|
+
|
|
45
|
+
def clear(self) -> None:
|
|
46
|
+
"""Clear all registered services."""
|
|
47
|
+
self._services.clear()
|
|
48
|
+
|
|
49
|
+
def contains(self, name: str) -> bool:
|
|
50
|
+
"""Check if a service is registered."""
|
|
51
|
+
return name in self._services
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class RayService(ServiceBase):
|
|
55
|
+
"""Ray-based distributed service registry.
|
|
56
|
+
|
|
57
|
+
This class uses Ray's named actors feature to provide truly distributed
|
|
58
|
+
service discovery. When a service is registered by any worker, it becomes
|
|
59
|
+
immediately accessible to all other workers in the Ray cluster.
|
|
60
|
+
|
|
61
|
+
Services are registered as Ray actors with globally unique names. This
|
|
62
|
+
ensures that:
|
|
63
|
+
1. Services persist independently of the registering worker
|
|
64
|
+
2. All workers see the same services instantly
|
|
65
|
+
3. No custom synchronization is needed
|
|
66
|
+
|
|
67
|
+
Args:
|
|
68
|
+
ray_init_config (dict, optional): Configuration for ray.init(). Only
|
|
69
|
+
used if Ray is not already initialized. Common options:
|
|
70
|
+
- address (str): Ray cluster address, or "auto" to auto-detect
|
|
71
|
+
- num_cpus (int): Number of CPUs to use
|
|
72
|
+
- num_gpus (int): Number of GPUs to use
|
|
73
|
+
namespace (str, optional): Ray namespace for service isolation. Services
|
|
74
|
+
in different namespaces are isolated from each other. Defaults to
|
|
75
|
+
"torchrl_services".
|
|
76
|
+
|
|
77
|
+
Examples:
|
|
78
|
+
>>> # Basic usage
|
|
79
|
+
>>> services = RayService()
|
|
80
|
+
>>> services.register("tokenizer", TokenizerClass, num_cpus=1)
|
|
81
|
+
>>> tokenizer = services["tokenizer"]
|
|
82
|
+
>>>
|
|
83
|
+
>>> # With Ray options for dynamic configuration
|
|
84
|
+
>>> actor = services.register(
|
|
85
|
+
... "model",
|
|
86
|
+
... ModelClass,
|
|
87
|
+
... num_cpus=2,
|
|
88
|
+
... num_gpus=1,
|
|
89
|
+
... memory=10 * 1024**3,
|
|
90
|
+
... max_concurrency=4
|
|
91
|
+
... )
|
|
92
|
+
>>>
|
|
93
|
+
>>> # Check and retrieve
|
|
94
|
+
>>> if "tokenizer" in services:
|
|
95
|
+
... tok = services["tokenizer"]
|
|
96
|
+
>>>
|
|
97
|
+
>>> # List all services
|
|
98
|
+
>>> print(services.list())
|
|
99
|
+
['tokenizer', 'model']
|
|
100
|
+
"""
|
|
101
|
+
|
|
102
|
+
def __init__(
    self,
    ray_init_config: dict | None = None,
    namespace: str = "torchrl_services",
):
    """Connect to (or start) Ray and attach to this namespace's registry actor.

    Args:
        ray_init_config: Keyword arguments for ``ray.init``, only used when
            Ray is not running yet.
        namespace: Ray namespace the services live in.

    Raises:
        ImportError: If Ray could not be imported.
    """
    # Surface the missing dependency immediately, chaining the original
    # import failure for debugging.
    if not _has_ray:
        raise ImportError(
            "Ray is required for RayService. Install with: pip install ray"
        ) from RAY_ERR

    # Must be set first: both helpers below read ``self._namespace``.
    self._namespace = namespace
    self._ensure_ray_initialized(ray_init_config)
    self._registry_actor = self._get_or_create_registry_actor()
|
|
115
|
+
|
|
116
|
+
def _ensure_ray_initialized(self, ray_init_config: dict | None = None):
    """Start Ray if needed; otherwise check the namespace already in use.

    When Ray is not running yet, ``ray.init`` is called with a *copy* of
    ``ray_init_config`` (an empty dict when ``None``). ``namespace`` defaults
    to ``self._namespace`` unless the config already provides one.

    When Ray is already running, ``ray_init_config`` is ignored. A warning is
    logged if the current namespace differs from ``self._namespace``.

    Args:
        ray_init_config: Optional keyword arguments for ``ray.init``. The
            caller's dict is never modified.
    """
    if ray.is_initialized():
        # Ray already initialized - check if namespace matches.
        current_namespace = ray.get_runtime_context().namespace
        if current_namespace != self._namespace:
            logger.warning(
                f"Ray already initialized with namespace '{current_namespace}', "
                f"but RayService is using namespace '{self._namespace}'. "
                "Services may not be visible across namespaces."
            )
        return

    # Work on a copy so that inserting the default namespace does not leak
    # into the caller's dict (which may be reused elsewhere).
    config = dict(ray_init_config) if ray_init_config else {}
    config.setdefault("namespace", self._namespace)

    # Report the namespace Ray is actually started with: a namespace given in
    # ``ray_init_config`` takes precedence over ``self._namespace``.
    logger.info(f"Initializing Ray with namespace '{config['namespace']}'")
    ray.init(**config)
|
|
136
|
+
|
|
137
|
+
def _make_service_name(self, name: str) -> str:
    """Build the Ray actor name of a service: ``<namespace>::service::<name>``."""
    separator = "::"
    return f"{self._namespace}{separator}service{separator}{name}"
|
|
140
|
+
|
|
141
|
+
def _get_registry_actor_name(self) -> str:
    """Build the Ray actor name of this namespace's registry: ``<namespace>::_registry``."""
    suffix = "_registry"
    return f"{self._namespace}::{suffix}"
|
|
144
|
+
|
|
145
|
+
def _get_or_create_registry_actor(self):
    """Return this namespace's registry actor, creating it when absent.

    The registry is a detached named actor, so it outlives the process that
    created it. Other ``RayService`` instances using the same namespace
    attach to it instead of creating their own.

    Returns:
        A Ray actor handle wrapping ``_ServiceRegistryActor``.
    """
    registry_name = self._get_registry_actor_name()

    try:
        # Fast path: the registry already exists in this namespace.
        return ray.get_actor(registry_name, namespace=self._namespace)
    except ValueError:
        # ``ray.get_actor`` raises ValueError when no such named actor exists.
        pass

    # Several workers may reach this point at the same time. Without
    # ``get_if_exists=True`` all but one of them would fail because the actor
    # name is already taken. With it, Ray hands back the existing actor.
    RemoteRegistry = ray.remote(_ServiceRegistryActor)
    registry = RemoteRegistry.options(
        name=registry_name,
        namespace=self._namespace,
        lifetime="detached",
        num_cpus=1,
        get_if_exists=True,
    ).remote()
    logger.info(
        f"Created service registry actor for namespace '{self._namespace}'"
    )
    return registry
|
|
166
|
+
|
|
167
|
+
def register(self, name: str, service_factory: type, *args, **kwargs) -> Any:
    """Register a service and create a named Ray actor.

    This method creates a detached Ray actor whose name is derived from
    ``name`` and the namespace. The actor is then reachable by name from
    every worker in the cluster through :meth:`get`.

    If the registry still lists ``name`` but its actor can no longer be
    found with ``ray.get_actor``, the stale entry is dropped and the service
    is registered again. This mirrors how :meth:`get` handles the same
    inconsistency.

    Args:
        name: Service identifier. Must be unique within the namespace.
        service_factory: Class to instantiate as a Ray actor.
        *args: Positional arguments for the service constructor.
        **kwargs: Ray actor options and service constructor arguments,
            mixed together. The keys ``num_cpus``, ``num_gpus``, ``memory``,
            ``object_store_memory``, ``resources``, ``accelerator_type``,
            ``max_concurrency``, ``max_restarts``, ``max_task_retries``,
            ``max_pending_calls`` and ``scheduling_strategy`` are removed
            and passed to the actor's ``.options()``. Every other key is
            forwarded to the service constructor. Consequently, a
            constructor argument that shares one of those names cannot be
            passed through this method.

    Returns:
        The Ray actor handle.

    Raises:
        ValueError: If a service with this name already exists and its
            actor can still be looked up.

    Examples:
        >>> services = RayService()
        >>>
        >>> # Basic registration
        >>> tokenizer = services.register("tokenizer", TokenizerClass)
        >>>
        >>> # With Ray resource specification
        >>> buffer = services.register(
        ...     "buffer",
        ...     ReplayBuffer,
        ...     num_cpus=2,
        ...     num_gpus=0,
        ...     size=1000000
        ... )
        >>>
        >>> # With advanced Ray options
        >>> model = services.register(
        ...     "model",
        ...     ModelClass,
        ...     num_cpus=4,
        ...     num_gpus=1,
        ...     memory=20 * 1024**3,
        ...     max_concurrency=10,
        ...     max_restarts=3,
        ... )
    """
    full_name = self._make_service_name(name)

    # Check if service already exists in our registry.
    if ray.get(self._registry_actor.contains.remote(name)):
        try:
            ray.get_actor(full_name, namespace=self._namespace)
        except ValueError:
            # Registry entry without a live actor. Without this cleanup the
            # name could never be registered again (short of reset()).
            logger.warning(
                f"Service '{name}' in registry but actor not found. "
                f"Removing stale entry and registering it again."
            )
            ray.get(self._registry_actor.remove.remote(name))
        else:
            raise ValueError(
                f"Service '{name}' already exists in namespace '{self._namespace}'. "
                f"Use a different name or retrieve the existing service with get()."
            )

    # Turn the factory into a Ray remote class.
    remote_cls = ray.remote(service_factory)

    # Naming options that make the actor discoverable by other workers.
    options = {
        "name": full_name,
        "namespace": self._namespace,
        "lifetime": "detached",
    }

    # Keys routed to ``.options()``. They are popped from kwargs so that they
    # do not reach the service constructor.
    ray_options = (
        "num_cpus",
        "num_gpus",
        "memory",
        "object_store_memory",
        "resources",
        "accelerator_type",
        "max_concurrency",
        "max_restarts",
        "max_task_retries",
        "max_pending_calls",
        "scheduling_strategy",
    )
    options.update({opt: kwargs.pop(opt) for opt in ray_options if opt in kwargs})

    # Apply options and create the actor. Only record it in the registry once
    # creation has succeeded.
    remote_actor = remote_cls.options(**options).remote(*args, **kwargs)
    ray.get(self._registry_actor.add.remote(name))

    logger.info(
        f"Registered service '{name}' as Ray actor '{full_name}' "
        f"with options: {options}"
    )

    return remote_actor
|
|
264
|
+
|
|
265
|
+
def get(self, name: str) -> Any:
    """Look up a registered service and return its Ray actor handle.

    The service may have been registered by any worker that uses the same
    namespace. The registry actor is consulted first. If the registry lists
    the name but Ray has no actor under the derived actor name, the stale
    registry entry is removed before the error is raised.

    Args:
        name: Service identifier, as used when the service was registered.

    Returns:
        The Ray actor handle.

    Raises:
        KeyError: If the service is not in the registry, or if it is in the
            registry but its Ray actor cannot be found.

    Examples:
        >>> services = RayService()
        >>> tokenizer = services.get("tokenizer")
        >>> result = ray.get(tokenizer.encode.remote("Hello world"))
    """
    is_registered = ray.get(self._registry_actor.contains.remote(name))
    if not is_registered:
        raise KeyError(
            f"Service '{name}' not found in namespace '{self._namespace}'. "
            f"Available services: {self.list()}"
        )

    actor_name = self._make_service_name(name)
    try:
        return ray.get_actor(actor_name, namespace=self._namespace)
    except ValueError as err:
        # ray.get_actor raises ValueError for unknown names: the registry is
        # out of sync with Ray, so drop the stale entry before reporting.
        logger.warning(
            f"Service '{name}' in registry but actor not found. "
            "Removing from registry."
        )
        ray.get(self._registry_actor.remove.remote(name))
        raise KeyError(
            f"Service '{name}' actor not found (removed from registry). "
            f"Available services: {self.list()}"
        ) from err
|
|
309
|
+
|
|
310
|
+
def __contains__(self, name: str) -> bool:
    """Tell whether ``name`` is present in the service registry.

    Supports the ``name in services`` syntax. Only the registry actor is
    queried; the underlying Ray actor is not looked up.

    Args:
        name: Service identifier.

    Returns:
        True if the registry contains the service, False otherwise.

    Examples:
        >>> services = RayService()
        >>> if "tokenizer" in services:
        ...     tokenizer = services.get("tokenizer")
        ... else:
        ...     services.register("tokenizer", TokenizerClass)
    """
    membership_ref = self._registry_actor.contains.remote(name)
    return ray.get(membership_ref)
|
|
327
|
+
|
|
328
|
+
def list(self) -> list[str]:
    """Return the names of all services in the registry for this namespace.

    The names come from the registry actor, so services registered by any
    worker using the same namespace are included. The names are the short
    service identifiers, not the prefixed Ray actor names.

    Returns:
        List of service names.

    Examples:
        >>> services = RayService()
        >>> services.register("tokenizer", TokenizerClass)
        >>> services.register("buffer", ReplayBuffer)
        >>> sorted(services.list())
        ['buffer', 'tokenizer']
    """
    names_ref = self._registry_actor.list.remote()
    return ray.get(names_ref)
|
|
345
|
+
|
|
346
|
+
def reset(self) -> None:
    """Kill every registered service actor and clear the registry.

    Each name currently in the registry is mapped to its Ray actor name,
    looked up, and killed. A service whose actor cannot be found, or whose
    termination raises, is logged as a warning and skipped. The registry is
    cleared afterwards in every case, so ``list()`` is empty once this
    returns. The final log line reports how many actors were actually
    terminated out of how many were registered.

    Warning:
        This is a destructive operation that affects all workers in the
        namespace. Any ongoing work on the killed actors is interrupted.

    Examples:
        >>> services = RayService(namespace="experiment")
        >>> services.register("tokenizer", TokenizerClass)
        >>> print(services.list())
        ['tokenizer']
        >>> services.reset()
        >>> print(services.list())
        []
    """
    service_names = self.list()
    terminated = 0

    for name in service_names:
        full_name = self._make_service_name(name)
        try:
            actor = ray.get_actor(full_name, namespace=self._namespace)
            ray.kill(actor)
        except ValueError:
            # ray.get_actor raises ValueError when no actor has this name.
            logger.warning(f"Service '{name}' not found during reset")
        except Exception as e:
            # Best effort: keep going so the remaining services still get killed.
            logger.warning(f"Failed to terminate service '{name}': {e}")
        else:
            terminated += 1
            logger.info(f"Terminated service '{name}' (actor: {full_name})")

    # Clear the registry even if some actors could not be terminated.
    ray.get(self._registry_actor.clear.remote())

    # Report only successful kills, not every registered name.
    logger.info(
        f"Reset complete for namespace '{self._namespace}'. "
        f"Terminated {terminated} of {len(service_names)} services."
    )
|
|
389
|
+
|
|
390
|
+
def shutdown(self, raise_on_error: bool = True) -> None:
    """Tear down this RayService: its service actors and its registry actor.

    First calls ``reset()``, which kills every registered service actor and
    clears the registry. Then it looks up the registry actor by name and
    kills it with ``no_restart=True``. If ``reset()`` raises, the registry
    actor is not killed. This method does not call ``ray.shutdown()``; the
    Ray cluster and this process's connection to it are left running.

    Args:
        raise_on_error: If True (default), any exception raised during
            teardown is propagated unchanged. If False, it is logged as a
            warning and swallowed.

    Raises:
        Exception: Whatever ``reset()`` or the Ray calls raised, when
            ``raise_on_error`` is True.
    """
    try:
        self.reset()
        # Kill the registry actor itself so it is not restarted.
        registry_actor = ray.get_actor(
            self._get_registry_actor_name(), namespace=self._namespace
        )
        ray.kill(registry_actor, no_restart=True)
    except Exception as e:
        if raise_on_error:
            # Bare raise preserves the original traceback.
            raise
        logger.warning(f"Error shutting down RayService: {e}")
|
|
404
|
+
|
|
405
|
+
def register_with_options(
    self,
    name: str,
    service_factory: type,
    actor_options: dict[str, Any],
    **constructor_kwargs,
) -> Any:
    """Register a service with Ray actor options kept apart from constructor args.

    ``register()`` receives a single ``**kwargs`` mapping and pulls out a fixed
    list of option names for Ray. This method instead forwards
    ``actor_options`` verbatim to the actor class's ``.options()`` and
    ``constructor_kwargs`` verbatim to the service constructor. As a result:

    * any Ray actor option can be set (e.g. ``runtime_env``), not only the
      subset ``register()`` recognizes, and
    * a constructor argument that shares a name with a Ray option (e.g.
      ``memory`` or ``num_gpus``) reaches the constructor instead of being
      consumed as an actor option.

    Like ``register()``, the actor is created with the service's full actor
    name, in this namespace, with ``lifetime="detached"``. The name is then
    added to the registry.

    Args:
        name: Service identifier.
        service_factory: Class to instantiate as a Ray actor.
        actor_options: Ray actor options (num_cpus, num_gpus, memory, ...).
            Must not contain ``name``, ``namespace`` or ``lifetime``, which
            are managed by RayService. The mapping is not modified.
        **constructor_kwargs: Keyword arguments passed to the service
            constructor.

    Returns:
        The Ray actor handle.

    Raises:
        ValueError: If ``actor_options`` contains a reserved key, or if a
            service with this name already exists in the namespace.

    Examples:
        >>> services = RayService()
        >>>
        >>> # Explicit separation of concerns
        >>> model = services.register_with_options(
        ...     "model",
        ...     ModelClass,
        ...     actor_options={
        ...         "num_cpus": 4,
        ...         "num_gpus": 1,
        ...         "memory": 20 * 1024**3,
        ...         "max_concurrency": 10
        ...     },
        ...     model_path="/path/to/checkpoint",
        ...     batch_size=32
        ... )
    """
    # These keys tie the actor to the registry; letting callers override them
    # would break lookups through get()/reset().
    reserved = actor_options.keys() & {"name", "namespace", "lifetime"}
    if reserved:
        raise ValueError(
            f"actor_options must not set {sorted(reserved)}; "
            f"these are managed by RayService."
        )

    full_name = self._make_service_name(name)

    # Same duplicate check as register().
    if ray.get(self._registry_actor.contains.remote(name)):
        raise ValueError(
            f"Service '{name}' already exists in namespace '{self._namespace}'. "
            f"Use a different name or retrieve the existing service with get()."
        )

    # Build a new dict so the caller's actor_options is never mutated.
    options = {
        "name": full_name,
        "namespace": self._namespace,
        "lifetime": "detached",
        **actor_options,
    }
    remote_cls = ray.remote(service_factory)
    remote_actor = remote_cls.options(**options).remote(**constructor_kwargs)

    ray.get(self._registry_actor.add.remote(name))

    logger.info(
        f"Registered service '{name}' as Ray actor '{full_name}' "
        f"with options: {options}"
    )

    return remote_actor
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
"""Testing utilities for TorchRL.
|
|
7
|
+
|
|
8
|
+
This module provides helper classes and utilities for testing TorchRL functionality,
|
|
9
|
+
particularly for distributed and Ray-based tests that require importable classes.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from torchrl.testing.assertions import (
|
|
13
|
+
check_rollout_consistency_multikey_env,
|
|
14
|
+
rand_reset,
|
|
15
|
+
rollout_consistency_assertion,
|
|
16
|
+
)
|
|
17
|
+
from torchrl.testing.env_creators import (
|
|
18
|
+
get_transform_out,
|
|
19
|
+
make_envs,
|
|
20
|
+
make_multithreaded_env,
|
|
21
|
+
)
|
|
22
|
+
from torchrl.testing.gym_helpers import (
|
|
23
|
+
BREAKOUT_VERSIONED,
|
|
24
|
+
CARTPOLE_VERSIONED,
|
|
25
|
+
CLIFFWALKING_VERSIONED,
|
|
26
|
+
HALFCHEETAH_VERSIONED,
|
|
27
|
+
PENDULUM_VERSIONED,
|
|
28
|
+
PONG_VERSIONED,
|
|
29
|
+
)
|
|
30
|
+
from torchrl.testing.llm_mocks import (
|
|
31
|
+
DummyStrDataLoader,
|
|
32
|
+
DummyTensorDataLoader,
|
|
33
|
+
MockTransformerConfig,
|
|
34
|
+
MockTransformerModel,
|
|
35
|
+
MockTransformerOutput,
|
|
36
|
+
)
|
|
37
|
+
from torchrl.testing.modules import (
|
|
38
|
+
BiasModule,
|
|
39
|
+
call_value_nets,
|
|
40
|
+
LSTMNet,
|
|
41
|
+
NonSerializableBiasModule,
|
|
42
|
+
)
|
|
43
|
+
from torchrl.testing.ray_helpers import (
|
|
44
|
+
WorkerTransformerDoubleBuffer,
|
|
45
|
+
WorkerTransformerNCCL,
|
|
46
|
+
WorkerVLLMDoubleBuffer,
|
|
47
|
+
WorkerVLLMNCCL,
|
|
48
|
+
)
|
|
49
|
+
from torchrl.testing.utils import (
|
|
50
|
+
capture_log_records,
|
|
51
|
+
dtype_fixture,
|
|
52
|
+
generate_seeds,
|
|
53
|
+
get_available_devices,
|
|
54
|
+
get_default_devices,
|
|
55
|
+
IS_WIN,
|
|
56
|
+
make_tc,
|
|
57
|
+
mp_ctx,
|
|
58
|
+
PYTHON_3_9,
|
|
59
|
+
retry,
|
|
60
|
+
set_global_var,
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
# Public API of torchrl.testing, grouped by the submodule each name is
# imported from above. Kept as one plain list literal so static analyzers can
# read it.
__all__ = [
    # torchrl.testing.assertions
    "check_rollout_consistency_multikey_env",
    "rand_reset", "rollout_consistency_assertion",
    # torchrl.testing.env_creators
    "get_transform_out", "make_envs", "make_multithreaded_env",
    # torchrl.testing.gym_helpers
    "BREAKOUT_VERSIONED", "CARTPOLE_VERSIONED", "CLIFFWALKING_VERSIONED",
    "HALFCHEETAH_VERSIONED", "PENDULUM_VERSIONED", "PONG_VERSIONED",
    # torchrl.testing.llm_mocks
    "DummyStrDataLoader", "DummyTensorDataLoader", "MockTransformerConfig",
    "MockTransformerModel", "MockTransformerOutput",
    # torchrl.testing.modules
    "BiasModule", "call_value_nets", "LSTMNet", "NonSerializableBiasModule",
    # torchrl.testing.ray_helpers
    "WorkerTransformerDoubleBuffer", "WorkerTransformerNCCL",
    "WorkerVLLMDoubleBuffer", "WorkerVLLMNCCL",
    # torchrl.testing.utils
    "capture_log_records", "dtype_fixture", "generate_seeds",
    "get_available_devices", "get_default_devices", "IS_WIN", "make_tc",
    "mp_ctx", "PYTHON_3_9", "retry", "set_global_var",
]
|