torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,1032 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import socket
|
|
5
|
+
|
|
6
|
+
import time
|
|
7
|
+
import weakref
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from datetime import timedelta
|
|
10
|
+
from typing import Any, Literal
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
from tensordict import TensorDict
|
|
14
|
+
from tensordict.base import TensorDictBase
|
|
15
|
+
|
|
16
|
+
from torchrl._utils import logger as torchrl_logger
|
|
17
|
+
from torchrl.weight_update.utils import _resolve_model
|
|
18
|
+
from torchrl.weight_update.weight_sync_schemes import (
|
|
19
|
+
TransportBackend,
|
|
20
|
+
WeightStrategy,
|
|
21
|
+
WeightSyncScheme,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# Default timeout for torch.distributed collective/point-to-point operations
# (used when initializing process groups for weight transfer).
_DIST_TIMEOUT = timedelta(seconds=60)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class ConnectionInfo:
    """Connection parameters shared between Ray workers.

    A plain dataclass is used instead of ``UserDict`` because Ray's
    signature introspection chokes on ``UserDict.__class_getitem__`` on
    Python 3.11+ (``ValueError: no signature found for builtin type
    GenericAlias``).
    """

    master_addr: str
    master_port: int
    world_size: int
    stateful_model: bool

    def get(self, key: str, default: Any = None) -> Any:
        """Look up a connection field by name, dict-style.

        Args:
            key (str): name of the attribute to look up.
            default: value returned when the attribute is absent.
                Defaults to ``None``.

        Returns:
            The attribute value, or ``default`` when it does not exist.
        """
        return getattr(self, key, default)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class RayTransport:
|
|
57
|
+
"""Ray transport for communicating with a single Ray actor.
|
|
58
|
+
|
|
59
|
+
This transport handles weight updates for ONE specific remote actor
|
|
60
|
+
using torch.distributed for efficient weight transfer. Ray is used for
|
|
61
|
+
signaling/coordination, while the actual weight data is transferred via
|
|
62
|
+
torch.distributed send/recv operations.
|
|
63
|
+
|
|
64
|
+
Multiple transports are created for multiple actors, following the
|
|
65
|
+
same pattern as multiprocess collectors.
|
|
66
|
+
|
|
67
|
+
Args:
|
|
68
|
+
remote_actor: The Ray actor handle for the remote collector/transform.
|
|
69
|
+
worker_idx (int, optional): The worker index for this remote actor.
|
|
70
|
+
Defaults to 0.
|
|
71
|
+
backend (str): The torch.distributed backend to use ("gloo" or "nccl").
|
|
72
|
+
Defaults to "gloo".
|
|
73
|
+
connection_info_name (str): Name of the Ray actor storing connection info.
|
|
74
|
+
Defaults to "connection_info".
|
|
75
|
+
model_id (str, optional): The model identifier for weight synchronization.
|
|
76
|
+
"""
|
|
77
|
+
|
|
78
|
+
def __init__(
    self,
    *,
    remote_actor=None,
    worker_idx: int | None = None,
    backend: str = "gloo",
    connection_info_name: str = "connection_info",
    model_id: str | None = None,
):
    """Set up the transport for one remote Ray actor.

    Args:
        remote_actor: Ray actor handle of the remote collector/transform.
        worker_idx (int, optional): worker index for this remote actor.
            Defaults to 0.
        backend (str): torch.distributed backend ("gloo" or "nccl").
            Defaults to "gloo".
        connection_info_name (str): name of the Ray actor holding the
            connection info. Defaults to "connection_info".
        model_id (str, optional): identifier of the model being synced.

    Raises:
        ImportError: if Ray is not installed.
    """
    try:
        import ray
    except ImportError:
        raise ImportError("Ray is required for RayTransport")
    self.ray = ray

    self._remote_actor = remote_actor
    self._worker_idx = 0 if worker_idx is None else worker_idx
    self._backend = backend
    self._connection_info_name = connection_info_name
    self._model_id = model_id

    # torch.distributed bookkeeping
    self._dist_initialized = False
    self._weights_buffer: TensorDictBase | None = None
    self._stateful_model: bool = True

    # In-flight state for send_weights_async / wait_ack
    self._pending_future = None
    self._pending_isend = None

    # Model reference, populated by the scheme on the receiver side
    self._model = None
|
|
122
|
+
|
|
123
|
+
@property
def _rank(self) -> int:
    """torch.distributed rank assigned to this worker.

    The sender always occupies rank 0, so worker ``i`` maps to rank
    ``i + 1``.

    Returns:
        int: the rank for this transport's worker.
    """
    return 1 + self._worker_idx
|
|
131
|
+
|
|
132
|
+
def set_model(self, model: Any) -> None:
    """Attach the model that incoming weights will be applied to.

    Args:
        model: target module for received weights.
    """
    self._model = model
|
|
139
|
+
|
|
140
|
+
# ========================================================================
|
|
141
|
+
# Sending Weights (Sender Side)
|
|
142
|
+
# ========================================================================
|
|
143
|
+
|
|
144
|
+
def send_weights(self, weights: Any) -> None:
    """Push weights to the remote actor over torch.distributed.

    Sequence:
    1. a Ray remote call tells the actor to start receiving,
    2. the payload is shipped with ``torch.distributed`` ``isend``,
    3. we block on the Ray future until the actor has applied the weights.

    Args:
        weights: the weights to transfer (typically a TensorDict).
    """
    if self._remote_actor is None:
        return

    # Non-blocking Ray signal: actor enters its receive path.
    ack = self._remote_actor._receive_weights_scheme.remote()

    # Ship the payload asynchronously via torch.distributed.
    weights.isend(dst=self._rank)

    # Block until the actor reports the weights were applied.
    self.ray.get(ack)
|
|
166
|
+
|
|
167
|
+
def send_weights_async(self, weights: Any) -> None:
    """Start sending weights to the Ray actor without blocking.

    Pair with :meth:`wait_ack` once weights have been dispatched to all
    actors.

    Args:
        weights: the weights to transfer (typically a TensorDict).
    """
    if self._remote_actor is None:
        return

    # Non-blocking Ray signal: actor enters its receive path.
    self._pending_future = self._remote_actor._receive_weights_scheme.remote()

    # Fire off the torch.distributed transfer; keep the futures so that
    # wait_ack can drain them later.
    self._pending_isend = weights.isend(dst=self._rank, return_early=True)
|
|
183
|
+
|
|
184
|
+
def wait_ack(self) -> None:
    """Block until the Ray actor has finished applying the weights.

    Raises:
        RuntimeError: when :meth:`send_weights_async` was not called
            first, i.e. nothing is pending.
    """
    if self._pending_future is None:
        raise RuntimeError("No pending future. Did you call send_weights_async?")
    # Wait for the actor-side apply to finish.
    self.ray.get(self._pending_future)
    # Drain the outstanding torch.distributed isend futures, if any.
    pending = self._pending_isend
    if pending is not None:
        for work in pending:
            work.wait()
    self._pending_future = None
    self._pending_isend = None
|
|
200
|
+
|
|
201
|
+
# ========================================================================
|
|
202
|
+
# Receiving Weights (Receiver Side)
|
|
203
|
+
# ========================================================================
|
|
204
|
+
|
|
205
|
+
def receive_weights(
    self,
    timeout: float | None = None,
    *,
    weights: Any = None,
    model: Any = None,
    strategy: WeightStrategy | None = None,
) -> Any | None:
    """Receive weights from the sender (rank 0) via torch.distributed.

    A receive buffer is resolved in this order: the ``weights`` argument,
    the buffer cached from a previous call, or a fresh buffer extracted
    from ``model`` when it is an ``nn.Module``. The buffer is then filled
    by a torch.distributed receive from rank 0 and applied to ``model``.

    Args:
        timeout: Maximum time to wait for weights (seconds). If None,
            blocks until weights are received.
        weights: Pre-allocated weight buffer to receive into.
        model: The model to apply weights to.
        strategy: Strategy for applying weights to the model; when None,
            ``TensorDict.to_module`` is used directly.

    Returns:
        The received weights, or None if the timeout expires (or there
        is nothing to apply to a non-module ``model``).
    """
    from torchrl.collectors.utils import _cast

    # Use provided weights buffer or fallback to stored one
    weights_buffer = weights if weights is not None else self._weights_buffer
    if weights_buffer is None:
        if model is None:
            raise RuntimeError("No model available to receive weights")
        if isinstance(model, torch.nn.Module):
            # Mirror the module's parameter structure into a TensorDict,
            # then map each leaf through _cast (presumably detaching and
            # casting to transfer-friendly storage — see _cast).
            weights_buffer = TensorDict.from_module(model)
            weights_buffer = weights_buffer.data.apply(_cast, weights_buffer)
        else:
            # Non-module model: empty locked buffer (nothing to receive).
            weights_buffer = TensorDict(lock=True)

    # Cache the weights buffer for future use
    if self._weights_buffer is None:
        self._weights_buffer = weights_buffer

    # Receive weights from rank 0
    if timeout is None:
        # Blocking receive: irecv without return_premature waits for
        # completion inside tensordict.
        weights_buffer.irecv(src=0)
    else:
        # Non-blocking receive with timeout support
        futures = weights_buffer.irecv(src=0, return_premature=True)
        if futures:
            start_time = time.monotonic()
            while True:
                # Check if all futures are complete
                all_complete = all(f.is_completed() for f in futures)
                if all_complete:
                    break
                # Check timeout
                elapsed = time.monotonic() - start_time
                if elapsed >= timeout:
                    # Timeout expired before receiving all weights.
                    # NOTE(review): the buffer may be partially filled at
                    # this point — callers should discard it; confirm.
                    return None
                # Small sleep to avoid busy-waiting
                time.sleep(0.001)

    # Apply weights to model
    if not isinstance(model, torch.nn.Module):
        # A non-empty buffer cannot be applied to a non-module target.
        if not weights_buffer.is_empty():
            raise RuntimeError(
                f"Cannot cast weights to model type: {type(model)} with weights: {weights_buffer}."
            )
        return None

    if strategy is not None:
        strategy.apply_weights(model, weights_buffer)
    else:
        weights_buffer.to_module(model)

    return weights_buffer
|
|
278
|
+
|
|
279
|
+
# ========================================================================
|
|
280
|
+
# Connection Setup
|
|
281
|
+
# ========================================================================
|
|
282
|
+
|
|
283
|
+
def setup_connection_and_weights_on_sender(self) -> None:
    """No-op hook on the sender side of the connection handshake.

    The scheme performs the collective ``init_process_group`` for rank 0
    after creating the connection-info Ray actor; this method exists only
    so senders and receivers expose a symmetric transport interface.

    Note:
        The real work happens in the scheme's
        :meth:`_setup_distributed_connection_sender`.
    """
    # Intentionally empty — see the scheme's
    # _setup_distributed_connection_sender for the actual setup.
|
|
297
|
+
|
|
298
|
+
def setup_connection_and_weights_on_receiver(
    self,
    *,
    worker_idx: int,
    strategy: WeightStrategy | None = None,
    model: Any | None = None,
    weights: Any | None = None,
) -> Any:
    """Join the torch.distributed process group and receive initial weights.

    This method:
    1. Retrieves connection info from the shared named Ray actor
    2. Initializes the torch.distributed process group with rank=worker_idx+1
    3. Receives weights if the model is stateful

    Args:
        worker_idx (int): The worker index for this transport.
        strategy (WeightStrategy, optional): The weight transmission strategy.
        model (nn.Module or compatible, optional): The model to receive weights for.
        weights (TensorDict, optional): Pre-allocated buffer for receiving weights.

    Returns:
        The received weights (TensorDict) if the model is stateful, None otherwise.
    """
    if self._dist_initialized:
        # Process group already joined: only (re-)receive weights when the
        # model actually carries state.
        if self._stateful_model:
            result = self.receive_weights(
                weights=weights, model=model, strategy=strategy
            )
            return result[1] if result else None
        return None

    self._worker_idx = worker_idx
    rank = self._rank

    # Wait for the connection-info actor to be created by the sender.
    # (Fix: dropped the dead `i` retry counter that was incremented but
    # never read.)
    # NOTE(review): this loop has no upper bound -- if the sender never
    # registers the actor, the receiver spins forever. Confirm whether a
    # timeout should be added here.
    while True:
        try:
            remote_connection_info = self.ray.get_actor(self._connection_info_name)
        except ValueError:
            # Actor not registered yet; back off briefly and retry.
            time.sleep(0.1)
            continue
        break

    master_addr = self.ray.get(remote_connection_info.get.remote("master_addr"))
    master_port = self.ray.get(remote_connection_info.get.remote("master_port"))
    world_size = self.ray.get(remote_connection_info.get.remote("world_size"))
    stateful_model = self.ray.get(
        remote_connection_info.get.remote("stateful_model")
    )
    self._stateful_model = stateful_model

    # Set environment variables so torch.distributed can rendezvous.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)

    # Initialize process group on the receiver (collective with rank 0).
    torch.distributed.init_process_group(
        backend=self._backend,
        rank=rank,
        world_size=world_size,
    )
    self._dist_initialized = True

    # Receive initial weights if the model is stateful.
    if self._stateful_model:
        return self.receive_weights(model=model, weights=weights, strategy=strategy)
    return None
369
|
+
|
|
370
|
+
|
|
371
|
+
class RayWeightSyncScheme(WeightSyncScheme):
    """Weight synchronization for Ray distributed computing.

    This scheme uses torch.distributed to synchronize weights across distributed
    workers (Ray actors). The process group is initialized during the first
    ``synchronize_weights()`` call, with the sender as rank 0 and workers as
    rank ``worker_idx + 1``.

    Each remote collector gets its own transport, following the same pattern
    as multiprocess collectors.

    Args:
        strategy (str): The weight transmission strategy ("state_dict" or "tensordict").
            Defaults to "tensordict".
        backend (str): The torch.distributed backend to use ("gloo" or "nccl").
            Defaults to "gloo".
    """

    @property
    def connection_info_name(self) -> str:
        """Name of the named Ray actor storing connection info.

        The name is derived from ``model_id`` so that multiple schemes used
        with different models do not collide on the same actor name.

        Returns:
            The connection info actor name.
        """
        if self._model_id is not None:
            return f"connection_info_{self._model_id}"
        return "connection_info"

    def __init__(
        self,
        strategy: Literal["tensordict", "state_dict"] = "tensordict",
        backend: str = "gloo",
    ):
        """Initialize the RayWeightSyncScheme.

        Args:
            strategy (str): The weight transmission strategy ("state_dict" or "tensordict").
                Defaults to "tensordict".
            backend (str): The torch.distributed backend to use ("gloo" or "nccl").
                Defaults to "gloo".
        """
        super().__init__(strategy)
        self._backend = backend
        # True once init_process_group has been called on this side.
        self._dist_initialized = False
        # Populated by _init_on_sender_impl; receivers leave these untouched.
        self._remote_collectors: list | None = None
        self._num_workers: int = 0

    @property
    def model(self) -> Any | None:
        """The model associated with this scheme, resolved lazily.

        Resolution order: the cached weak reference first, then a lookup of
        ``model_id`` on the context. For the special id ``"policy"``, a
        missing model is built from ``context.policy_factory[0]`` and stored
        back on the context. Returns None implicitly when neither a model
        reference nor a model_id is set (or the weakref target was collected).

        Raises:
            AttributeError: If ``model_id`` resolves to None and is not "policy".
        """
        if self._model_ref is not None:
            # Weakref call may return None if the model has been collected.
            return self._model_ref()
        if self._model_id is not None:
            model = _resolve_model(self.context, self._model_id)
            if model is None:
                if self._model_id == "policy":
                    torchrl_logger.debug("Creating policy from factory.")
                    model = self.context.policy_factory[0]()
                    self.context.policy = model
                else:
                    raise AttributeError(
                        f"Model {self._model_id} was `None` in context {self.context}"
                    )
            # Cache as a weak reference to avoid keeping the model alive.
            self._model_ref = weakref.ref(model)
            return model

    @model.setter
    def model(self, value: Any):
        """Set the model for this scheme (stored as a weak reference).

        Args:
            value: The model to set. If None, the setter is a no-op.
        """
        if value is None:
            return
        self._model_ref = weakref.ref(value)

    def create_transport(
        self,
        *,
        remote_actor=None,
        worker_idx: int | None = None,
        # Legacy parameter name for backwards compatibility
        remote_collector=None,
        **kwargs,
    ) -> TransportBackend:
        """Create Ray-based transport for a specific remote actor.

        Args:
            remote_actor: The Ray actor handle for the remote collector/transform.
            worker_idx: The worker index for this remote actor.
            remote_collector: Legacy alias for remote_actor.
            **kwargs: Additional transport configuration (ignored here).

        Returns:
            RayTransport configured for this specific remote actor.
        """
        # Support legacy parameter name
        if remote_actor is None:
            remote_actor = remote_collector

        return RayTransport(
            remote_actor=remote_actor,
            worker_idx=worker_idx,
            backend=self._backend,
            connection_info_name=self.connection_info_name,
            model_id=self._model_id,
        )

    def _init_on_sender_impl(
        self,
        model_id: str,
        context: Any = None,
        **kwargs,
    ) -> None:
        """Initialize on the main process (sender side).

        This method sets up the torch.distributed connection info and shares it
        with all remote collectors so they can join the process group.

        Args:
            model_id: Identifier for the model being synchronized
            context: Optional context object providing remote_collectors
            **kwargs: Alternative to context (remote_collectors, source_model, etc.)

        Raises:
            ImportError: If Ray is not installed.
            ValueError: If remote_collectors cannot be resolved.
        """
        try:
            import ray

            self.ray = ray
        except ImportError:
            raise ImportError("Ray is required for RayWeightSyncScheme")

        # Extract parameters from context or kwargs (context wins).
        if context is not None:
            remote_collectors = getattr(context, "remote_collectors", None)
            num_workers = getattr(context, "num_workers", None) or getattr(
                context, "num_collectors", None
            )
        else:
            remote_collectors = kwargs.get("remote_collectors")
            num_workers = kwargs.get("num_workers") or kwargs.get("num_collectors")

        if remote_collectors is None:
            raise ValueError("remote_collectors must be provided via context or kwargs")
        if num_workers is None:
            # Fall back to counting the collectors themselves.
            num_workers = len(remote_collectors) if remote_collectors else 0

        # Store model_id and context on scheme
        self.model_id = model_id

        # Store remote collectors and num_workers for synchronize_weights
        self._remote_collectors = list(remote_collectors)
        self._num_workers = int(num_workers)

        # Register one transport per Ray actor, keyed by worker index.
        for worker_idx, remote_collector in enumerate(remote_collectors):
            transport = self.create_transport(
                remote_actor=remote_collector,
                worker_idx=worker_idx,
            )
            self._register_worker_sender(
                worker_idx=worker_idx,
                transport=transport,
            )

        # Set context with weak reference to avoid circular refs
        if context is not None:
            self.context = context

        # Store source model reference if provided for automatic weight extraction
        model = kwargs.get("model")
        if model is not None:
            self.model = model

        # Note: Distributed connection setup is deferred to synchronize_weights
        # because _receiver_schemes on workers won't exist until
        # register_scheme_receiver is called.

    def _init_on_receiver_impl(
        self,
        model_id: str,
        context: Any = None,
        **kwargs,
    ) -> None:
        """Initialize on worker process (receiver side).

        Args:
            model_id: Identifier for the model being synchronized
            context: Optional context object (typically the remote collector)
            **kwargs: Optional parameters (worker_idx, model, etc.)

        Raises:
            ImportError: If Ray is not installed.
        """
        try:
            import ray

            self.ray = ray
        except ImportError:
            raise ImportError("Ray is required for RayWeightSyncScheme")

        # Store model_id and context on scheme
        self.model_id = model_id
        self.context = context

        # Extract worker_idx from context or kwargs (context wins).
        if context is not None:
            worker_idx = getattr(context, "worker_idx", None)
        else:
            worker_idx = kwargs.get("worker_idx")

        self._worker_idx = worker_idx

        # Resolve the target model on this worker
        model = kwargs.get("model")
        if model is not None:
            self.model = model
        # Touch the weights to possibly instantiate a copy of the model
        # (policy factory with multi-collector).
        self.weights  # noqa

        # Create and register transport for receiver side.
        # Note: create_transport returns TransportBackend but we know it's RayTransport.
        transport = self.create_transport(
            remote_actor=None,  # Receiver doesn't need actor handle
            worker_idx=worker_idx,
        )
        if isinstance(transport, RayTransport):
            transport.set_model(model)
        self._register_transport_receiver(transport=transport)

    def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None:
        """Set up torch.distributed connection info and share with remote collectors.

        This method:
        1. Gets the master address and finds an available port
        2. Publishes connection info via a named Ray actor
        3. Initializes the torch.distributed process group with rank=0

        Args:
            timeout: Maximum time in seconds to wait for workers to be ready.
                Default is 300 seconds (5 minutes).
                NOTE(review): this parameter is not read in the body; the
                process-group timeout used below is ``_DIST_TIMEOUT``.

        Raises:
            RuntimeError: If remote_collectors have not been set.
        """
        if self._dist_initialized:
            return

        if self._remote_collectors is None or self._num_workers == 0:
            raise RuntimeError(
                "_setup_distributed_connection() requires remote_collectors to be set"
            )

        # Get master address (hostname/IP); fall back to loopback on failure.
        hostname = socket.gethostname()
        try:
            master_addr = socket.gethostbyname(hostname)
        except socket.gaierror:
            master_addr = "127.0.0.1"

        # Find an available port
        master_port = self._find_free_port()
        world_size = self._num_workers + 1  # +1 for the sender (rank 0)

        # Accessing self.weights raises when there is no stateful model.
        try:
            self.weights
            stateful_model = True
        except (AttributeError, RuntimeError, ValueError):
            stateful_model = False
        self._stateful_model = stateful_model

        # Publish connection info to workers via a named Ray actor.
        RemoteConnectionInfo = self.ray.remote(num_cpus=0)(ConnectionInfo).options(
            name=self.connection_info_name
        )
        self._connection_info_actor = RemoteConnectionInfo.remote(
            master_addr=master_addr,
            master_port=master_port,
            world_size=world_size,
            stateful_model=stateful_model,
        )

        # Set environment variables for torch.distributed
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)

        # Initialize process group on sender (rank 0).
        # Note: Workers will call init_process_group in their transport's
        # setup_connection_and_weights_on_receiver. The init_process_group is
        # a collective operation, so all ranks must call it together.
        torch.distributed.init_process_group(
            backend=self._backend,
            rank=0,
            world_size=world_size,
            timeout=_DIST_TIMEOUT,
        )
        self._dist_initialized = True

    def _setup_connection_and_weights_on_sender_impl(
        self,
        *,
        worker_idx: int | None = None,
        weights: Any | None = None,
    ) -> None:
        """Set up distributed connection and send initial weights to all workers.

        This method:
        1. Sets up the torch.distributed process group (if not done already)
        2. Sends initial weights to all workers if the model is stateful

        The distributed setup is done here (not in ``init_on_sender``) because
        workers need to have ``register_scheme_receiver`` called first.

        Args:
            worker_idx (int, optional): Not used in this implementation.
            weights (optional): Not used in this implementation (weights are
                extracted from the model).
        """
        # Set up distributed connection (with wait for workers to be ready)
        if not self._dist_initialized:
            self._setup_distributed_connection_sender()

        # Send the initial weights only when there is state to send.
        if self._stateful_model:
            self._send_weights_distributed()

    def _send_weights_distributed(self) -> None:
        """Send weights to all workers via torch.distributed point-to-point sends.

        Raises:
            RuntimeError: If no weights are available to send.
        """
        # Extract weights from model
        weights = self.weights
        if weights is None:
            raise RuntimeError("No weights available to send")

        # Send weights to each worker; worker_idx maps to rank worker_idx + 1
        # (rank 0 is the sender itself).
        futures = []
        for worker_idx in range(self._num_workers):
            rank = worker_idx + 1
            futures.extend(weights.isend(dst=rank, return_early=True))
        # Wait for all sends to complete
        for future in futures:
            future.wait()

    def _setup_connection_and_weights_on_receiver_impl(
        self, *, worker_idx: int | None = None
    ) -> None:
        """Join torch.distributed process group and receive initial weights.

        Delegates to the transport's
        :meth:`~RayTransport.setup_connection_and_weights_on_receiver`.

        Args:
            worker_idx (int, optional): The worker index. If None, uses the
                stored ``_worker_idx`` or defaults to 0.
        """
        if worker_idx is None:
            worker_idx = self._worker_idx
            if worker_idx is None:
                worker_idx = 0  # Default to worker 0

        transport = self.receiver_transport
        if transport is not None:
            # Transport handles joining process group and receiving weights
            transport.setup_connection_and_weights_on_receiver(
                worker_idx=worker_idx,
                model=self.model,
                weights=self.weights,
                strategy=self._strategy,
            )
        self._dist_initialized = True

    @staticmethod
    def _find_free_port() -> int:
        """Find a free port on the local machine.

        Binds an ephemeral socket to port 0 and reads back the OS-assigned
        port. The socket is closed before returning, so the port could in
        principle be reclaimed before use (standard race for this idiom).

        Returns:
            int: An available port number.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            s.listen(1)
            port = s.getsockname()[1]
            return port
|
758
|
+
|
|
759
|
+
class RayModuleTransformScheme(RayWeightSyncScheme):
|
|
760
|
+
"""Weight synchronization for RayModuleTransform.
|
|
761
|
+
|
|
762
|
+
This scheme uses torch.distributed to synchronize weights between
|
|
763
|
+
a trainer/collector and a RayModuleTransform actor. The sender is rank 0,
|
|
764
|
+
the transform's actor is rank 1.
|
|
765
|
+
|
|
766
|
+
This enables updating the weights of a module running inside a RayModuleTransform
|
|
767
|
+
from a parent collector or training loop.
|
|
768
|
+
|
|
769
|
+
Args:
|
|
770
|
+
strategy (str): The weight transmission strategy ("state_dict" or "tensordict").
|
|
771
|
+
Default is "tensordict".
|
|
772
|
+
backend (str): The torch.distributed backend to use ("gloo" or "nccl").
|
|
773
|
+
Default is "gloo".
|
|
774
|
+
|
|
775
|
+
Example:
|
|
776
|
+
>>> # Create scheme and transform
|
|
777
|
+
>>> scheme = RayModuleTransformScheme()
|
|
778
|
+
>>> transform = RayModuleTransform(module=my_module, weight_sync_scheme=scheme)
|
|
779
|
+
>>>
|
|
780
|
+
>>> # Create env with transform
|
|
781
|
+
>>> env = TransformedEnv(base_env, transform)
|
|
782
|
+
>>>
|
|
783
|
+
>>> # Pass scheme to parent collector
|
|
784
|
+
>>> collector = SomeCollector(
|
|
785
|
+
... env, policy,
|
|
786
|
+
... weight_sync_schemes={"transform_module": scheme}
|
|
787
|
+
... )
|
|
788
|
+
>>>
|
|
789
|
+
>>> # Update weights
|
|
790
|
+
>>> collector.update_policy_weights_(model_id="transform_module")
|
|
791
|
+
"""
|
|
792
|
+
|
|
793
|
+
def __init__(
    self,
    strategy: Literal["tensordict", "state_dict"] = "tensordict",
    backend: str = "gloo",
):
    """Create a scheme that syncs weights into a single RayModuleTransform actor.

    Args:
        strategy (str): How weights are shipped ("state_dict" or "tensordict").
            Defaults to "tensordict".
        backend (str): torch.distributed backend ("gloo" or "nccl").
            Defaults to "gloo".
    """
    super().__init__(strategy, backend)
    # The RayModuleTransform handle; filled in later via _set_transform().
    self._ray_transform = None
|
+
|
|
809
|
+
def _set_transform(self, ray_transform) -> None:
    """Record the RayModuleTransform this scheme will push weights to.

    RayModuleTransform invokes this when the scheme is handed to its
    constructor, so the scheme can later reach the transform's actor.

    Args:
        ray_transform: The RayModuleTransform instance.
    """
    self._ray_transform = ray_transform
818
|
+
|
|
819
|
+
def _init_on_sender_impl(
    self,
    model_id: str | None = None,
    context: Any = None,
    **kwargs,
) -> None:
    """Initialize on the main process (sender side).

    Uses the transform reference recorded via ``_set_transform`` (or passed
    through ``kwargs``) to build the transport targeting the transform's actor.

    Args:
        model_id: Identifier for the model being synchronized
        context: Optional context object (typically the collector)
        **kwargs: Optional parameters (ray_transform, model, etc.)

    Raises:
        ImportError: If Ray is not installed.
        ValueError: If no transform reference is available.
    """
    try:
        import ray
    except ImportError:
        raise ImportError("Ray is required for RayModuleTransformScheme")
    self.ray = ray

    # Prefer the transform stored by _set_transform; fall back to kwargs.
    ray_transform = self._ray_transform
    if ray_transform is None:
        ray_transform = kwargs.get("ray_transform")
    if ray_transform is None:
        raise ValueError(
            "ray_transform must be set via _set_transform() or provided in kwargs. "
            "Pass the scheme to RayModuleTransform constructor to set it automatically."
        )

    self.model_id = model_id

    # Exactly one receiver: the transform's actor.
    self._num_workers = 1

    # Build and register the transport aimed at the transform's actor handle.
    self._register_worker_sender(
        worker_idx=0,
        transport=self.create_transport(
            remote_actor=ray_transform._actor,
            worker_idx=0,
        ),
    )

    if context is not None:
        self.context = context

    # Keep a handle on the source model for automatic weight extraction.
    source_model = kwargs.get("model")
    if source_model is not None:
        self.model = source_model
877
|
+
|
|
878
|
+
def _init_on_receiver_impl(
    self,
    model_id: str,
    context: Any = None,
    **kwargs,
) -> None:
    """Initialize on the transform's actor (receiver side).

    Args:
        model_id: Identifier for the model being synchronized
        context: The ModuleTransform instance (the actor's underlying class)
        **kwargs: Optional parameters (worker_idx, model, etc.)

    Raises:
        ImportError: If Ray is not installed.
    """
    try:
        import ray
    except ImportError:
        raise ImportError("Ray is required for RayModuleTransformScheme")
    self.ray = ray

    self.model_id = model_id
    self.context = context

    # There is a single transform actor, so it is always worker 0.
    self._worker_idx = kwargs.get("worker_idx", 0)

    # Resolve the module that will receive weights: an explicit kwarg wins,
    # otherwise fall back to the transform's own .module attribute.
    target_module = kwargs.get("model")
    if target_module is None and context is not None:
        target_module = getattr(context, "module", None)
    if target_module is not None:
        self.model = target_module

    # Build the receiver-side transport and attach the target module to it.
    # create_transport returns TransportBackend, but here it's a RayTransport.
    transport = self.create_transport(
        remote_actor=None,
        worker_idx=self._worker_idx,
    )
    if isinstance(transport, RayTransport):
        transport.set_model(target_module)
    self._register_transport_receiver(transport=transport)
921
|
+
|
|
922
|
+
def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None:
    """Set up torch.distributed for the single transform actor.

    Overrides the parent to work with a single RayModuleTransform instead of
    multiple remote collectors: world size is fixed at 2 (sender rank 0,
    transform rank 1).

    Args:
        timeout (float): Maximum time in seconds to wait for connection setup.
            Defaults to 300.0 (5 minutes).
            NOTE(review): this parameter is not read in the body; the
            process-group timeout used below is ``_DIST_TIMEOUT``.

    Raises:
        RuntimeError: If ``ray_transform`` is not set.
    """
    if self._dist_initialized:
        return

    if self._ray_transform is None:
        raise RuntimeError(
            "_setup_distributed_connection() requires ray_transform to be set. "
            "Did you pass the scheme to RayModuleTransform?"
        )

    # Get master address (hostname/IP); fall back to loopback on failure.
    hostname = socket.gethostname()
    try:
        master_addr = socket.gethostbyname(hostname)
    except socket.gaierror:
        master_addr = "127.0.0.1"

    # Find an available port
    master_port = self._find_free_port()
    world_size = 2  # Sender (rank 0) + Transform (rank 1)

    # Check if the model has weights (stateful) without propagating errors.
    try:
        w = self.weights
        stateful_model = w is not None
    except (AttributeError, RuntimeError, ValueError):
        stateful_model = False
    self._stateful_model = stateful_model

    # Publish connection info to the transform's actor via a named Ray actor.
    RemoteConnectionInfo = self.ray.remote(num_cpus=0)(ConnectionInfo).options(
        name=self.connection_info_name
    )
    self._connection_info_actor = RemoteConnectionInfo.remote(
        master_addr=master_addr,
        master_port=master_port,
        world_size=world_size,
        stateful_model=stateful_model,
    )

    # Set environment variables for torch.distributed
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)

    # Now initialize the process group on the sender (rank 0).
    # The receiver is concurrently joining via the Ray call kicked off by
    # the caller, which is what allows this collective call to complete.
    torch.distributed.init_process_group(
        backend=self._backend,
        rank=0,
        world_size=world_size,
        timeout=_DIST_TIMEOUT,
    )
    self._dist_initialized = True
987
|
+
|
|
988
|
+
def _setup_connection_and_weights_on_sender_impl(
    self,
    *,
    worker_idx: int | None = None,
    weights: Any | None = None,
) -> None:
    """Set up distributed connection and send initial weights.

    Args:
        worker_idx (int, optional): The worker index. Not used for
            RayModuleTransformScheme as there is only one transform actor.
        weights (optional): Pre-extracted weights to send. If None, weights
            are extracted from the model.
    """
    # Kick off the receiver-side initialization on the actor FIRST, so that
    # rank 1 joins the process group concurrently with our own collective
    # init_process_group below; the future is awaited at the end.
    receiver_future = self._ray_transform._actor._init_weight_sync_scheme.remote(
        scheme=self, model_id=self.model_id
    )

    if not self._dist_initialized:
        self._setup_distributed_connection_sender()

    # Only push weights when the model actually carries state.
    if self._stateful_model:
        self._send_weights_distributed(weights=weights)

    # Block until the actor has finished its side of the setup.
    self.ray.get(receiver_future)
1013
|
+
|
|
1014
|
+
def _send_weights_distributed(self, weights: Any | None = None) -> None:
    """Send weights to the transform actor via torch.distributed.

    Args:
        weights (optional): Pre-extracted weights to send. If None, weights
            are extracted from the model via :attr:`weights`.

    Raises:
        RuntimeError: If no weights are available to send.
    """
    if weights is None:
        weights = self.weights
    if weights is None:
        raise RuntimeError("No weights available to send")

    # Send weights to the transform, which holds rank 1 in the 2-rank group.
    futures = weights.isend(dst=1, return_early=True)
    # Block until every underlying tensor send has completed.
    for future in futures:
        future.wait()
|