torchrl 0.11.0__cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/weight_update/_distributed.py (new file)
@@ -0,0 +1,749 @@
from __future__ import annotations

import random
import socket
import time
import weakref
from datetime import timedelta
from typing import Any

import torch
from tensordict import TensorDictBase
from torchrl._utils import logger as torchrl_logger

from torchrl.weight_update.utils import _resolve_model

from torchrl.weight_update.weight_sync_schemes import (
    TransportBackend,
    WeightStrategy,
    WeightSyncScheme,
)


class DistributedWeightSyncScheme(WeightSyncScheme):
    """Weight synchronization for torch.distributed.

    This scheme uses torch.distributed primitives (send/recv) to synchronize
    weights across distributed workers. Each worker gets its own transport,
    following the same pattern as multiprocess collectors.

    The scheme can create its own TCPStore for coordination if one is not provided.
    Use `get_store_info()` after `init_on_sender()` to get connection details for workers.

    Args:
        backend (str): The distributed backend ("gloo", "nccl", etc.)
        sync (bool): If True, weight updates are synchronous (blocking receive).
            If False, a background thread monitors the store and applies weight
            updates automatically. Defaults to True.
        timeout (float): Timeout in seconds for TCPStore operations.
            Defaults to 3600.0 (1 hour).
    """

    def __init__(
        self,
        backend: str = "gloo",
        sync: bool = True,
        timeout: float = 3600.0,
    ):
        super().__init__()
        self.backend = backend
        self.sync = sync
        self._timeout = timeout
        self._store = None
        self._store_info = None
        self._num_workers = None

    def __getstate__(self):
        """Custom serialization - exclude non-picklable objects."""
        state = super().__getstate__()
        # TCPStore cannot be pickled - remove it but keep _store_info
        state["_store"] = None

        # Thread and Event cannot be pickled
        state["_background_thread"] = None
        state["_stop_event"] = None

        # Transports contain references to store/groups - exclude them
        # The receiver will create its own transport in init_on_receiver
        state["_sender_transports"] = {}
        state["_receiver_transport"] = None
        return state

    def __setstate__(self, state):
        """Custom deserialization."""
        super().__setstate__(state)

    def _init_on_sender_impl(
        self,
        *,
        model_id: str,
        context: Any = None,
        num_workers: int,
        model: Any = None,
        weights: Any = None,
        **kwargs,
    ) -> None:
        if kwargs:
            raise RuntimeError(f"Unexpected kwargs: {kwargs.keys()}")
        self.model_id = model_id
        self._num_workers = num_workers

        # Attach context so we can resolve the model and prepare
        # weights on demand via scheme.prepare_weights().
        weights_buffer = None
        if context is not None:
            self.context = context
        if weights is not None:
            self.weights = weights
            weights_buffer = weights
        if model is not None:
            self.model = model
        else:
            # resolve model
            try:
                model = self.model
            except (AttributeError, ValueError):
                pass

        if weights_buffer is None and model is not None:
            weights_buffer = self._get_weights_buffer_from_model(model)

        # Get base tcp_port from context if available to avoid port conflicts.
        # The DistributedDataCollector uses tcp_port for init and tcp_port+1 for its store,
        # so we use tcp_port+2 for the weight sync scheme's store.
        base_tcp_port = (
            getattr(context, "tcp_port", None) if context is not None else None
        )
        self._store = self._make_store(
            is_master=True, num_workers=num_workers, base_tcp_port=base_tcp_port
        )

        for i in range(num_workers):
            rank = i + 1  # Workers are 1-indexed in distributed
            transport = self.create_transport(
                store=self._store,
                rank=rank,
                weights_buffer=weights_buffer,
                sync=self.sync,
            )
            self._register_worker_sender(worker_idx=i, transport=transport)

    def _make_store(
        self,
        is_master: bool,
        num_workers: int | None = None,
        store_info: dict | None = None,
        base_tcp_port: int | str | None = None,
        max_retries: int = 10,
    ) -> torch.distributed.TCPStore:
        """Create a TCPStore for weight synchronization.

        Args:
            is_master: If True, creates the store as master (server).
                If False, connects as client.
            num_workers: Number of workers (required for master).
            store_info: Dictionary with 'host' and 'port' keys (required for client).
            base_tcp_port: Base TCP port from the collector. If provided, the store
                will use base_tcp_port + 2 to avoid conflicts with the collector's
                stores (which use base_tcp_port and base_tcp_port + 1).
            max_retries: Maximum number of retry attempts for handling port conflicts.

        Returns:
            The created TCPStore.
        """
        if is_master:
            # Create as master (server)
            if num_workers is None:
                raise ValueError(
                    "num_workers is required when creating store as master"
                )

            hostname = socket.gethostname()
            host = socket.gethostbyname(hostname)

            # Use base_tcp_port + 2 if available (to avoid conflicts with collector's
            # tcp_port and tcp_port + 1), otherwise find a free port dynamically.
            initial_port = int(base_tcp_port) + 2 if base_tcp_port is not None else None

            last_error = None
            for attempt in range(max_retries):
                if initial_port is None or attempt > 0:
                    # Find a free port dynamically
                    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                        s.bind(("", 0))
                        self._store_port = s.getsockname()[1]
                else:
                    self._store_port = initial_port

                try:
                    store = torch.distributed.TCPStore(
                        host_name=host,
                        port=self._store_port,
                        is_master=True,
                        timeout=timedelta(seconds=self._timeout),
                        wait_for_workers=False,  # Don't block - workers may not be started yet
                    )
                    self._store_info = {"host": host, "port": self._store_port}
                    return store
                except (RuntimeError, OSError) as e:
                    error_msg = str(e).lower()
                    if (
                        "address already in use" in error_msg
                        or "eaddrinuse" in error_msg
                    ):
                        last_error = e
                        # Add small random delay to reduce collision probability
                        time.sleep(random.uniform(0.01, 0.1))
                        continue
                    # For other errors, re-raise immediately
                    raise

            raise RuntimeError(
                f"DistributedWeightSyncScheme: Failed to create TCPStore after {max_retries} attempts. "
                f"Last error: {last_error}"
            )
        else:
            # Connect as client
            if store_info is None:
                raise ValueError("store_info is required when connecting as client")
            store = torch.distributed.TCPStore(
                host_name=store_info["host"],
                port=store_info["port"],
                is_master=False,
                timeout=timedelta(seconds=self._timeout),
            )
            return store

    def _init_on_receiver_impl(
        self,
        *,
        model_id: str,
        context: Any = None,
        store_info: dict | None = None,
        worker_idx: int | None = None,
        **kwargs,
    ) -> None:
        """Initialize scheme on the worker (receiver) side.

        Expected kwargs (as provided by collectors):
            - model_id: str  # e.g. "policy"
            - context: Any  # collector / inner collector
            - store: TCPStore | None  # distributed TCP store
            - store_info: dict | None  # {"host": ..., "port": ...} to create store
            - rank: int | None  # worker rank (1-indexed)
        """
        if context is None:
            raise ValueError(
                "DistributedWeightSyncScheme.init_on_receiver requires a 'context' "
                "providing access to the model to be synchronized."
            )
        if worker_idx is None:
            raise RuntimeError("rank was not provided.")
        if kwargs:
            raise RuntimeError(f"Unexpected kwargs: {kwargs.keys()}")

        # Store model_id and context on scheme
        self.model_id = model_id
        self.context = context

        # Get or create store
        # Priority: provided store > provided store_info > self._store_info (from serialization)
        # Connect to master's TCPStore as client
        info = self._store_info
        if info is None:
            raise RuntimeError(
                "TCPStore info not available. init_on_sender() must be called first on the sender side, before passing the scheme to the receiver."
            )
        self._store = self._make_store(is_master=False, store_info=info)

        if (model := getattr(self, "model", None)) is not None:
            self.model = model
            weights_buffer = self._get_weights_buffer_from_model(model)
        else:
            raise RuntimeError("Couldn't find weights")
        self._receiver_transport = self.create_transport(
            store=self._store,
            rank=worker_idx,
            weights_buffer=weights_buffer,
            sync=self.sync,
        )

        # Store worker_idx for synchronize_weights
        self._worker_idx = worker_idx
        # Note: Background thread for async mode is started in connect() after init_process_group

    def _wait_for_instruction(self, timeout: float | None = None) -> str | None:
        """Block until an instruction arrives via TCPStore.

        Args:
            timeout: Maximum time to wait for instruction (seconds).
                None means block indefinitely.

        Returns:
            The instruction string (e.g., "receive", "stop"), or None if
            stop event is set or timeout expires.
        """
        key = f"NODE_{self._worker_idx}_in"
        start_time = time.monotonic()

        while True:
            if self._stop_event is not None and self._stop_event.is_set():
                return None

            try:
                instruction = self._store.get(key)
                self._store.delete_key(key)
                # Decode bytes to string
                return (
                    instruction.decode()
                    if isinstance(instruction, bytes)
                    else instruction
                )
            except RuntimeError:
                # Key doesn't exist yet, continue polling
                pass

            # Check timeout
            if timeout is not None:
                elapsed = time.monotonic() - start_time
                if elapsed >= timeout:
                    return None

            time.sleep(0.01)

    def _send_instruction(
        self,
        instruction: str = "receive",
        worker_ids: int | list[int] | None = None,
    ) -> None:
        """Send instruction to receiver(s) via TCPStore.

        Args:
            instruction: The instruction to send (default: "receive").
            worker_ids: Which workers to send to (None = all workers).
        """
        if self._store is None:
            raise RuntimeError(
                "Store not initialized. init_on_sender() must be called first."
            )

        if worker_ids is None:
            target_workers = list(range(self._num_workers)) if self._num_workers else []
        elif isinstance(worker_ids, int):
            target_workers = [worker_ids]
        else:
            target_workers = list(worker_ids)

        # Map instruction to TCPStore format
        store_instruction = (
            b"update_weights" if instruction == "receive" else instruction.encode()
        )

        for worker_idx in target_workers:
            rank = worker_idx + 1  # Workers are 1-indexed in distributed
            self._store.set(f"NODE_{rank}_in", store_instruction)

    def _send_ack(self, message: str = "updated") -> None:
        """Send acknowledgment back to sender via TCPStore.

        Args:
            message: The acknowledgment message (default: "updated").
        """
        if self._store is None or self._worker_idx is None:
            return
        self._store.set(f"NODE_{self._worker_idx}_out", message.encode())

    def _wait_for_ack(
        self,
        worker_ids: int | list[int] | None = None,
        timeout: float | None = None,
    ) -> None:
        """Wait for acknowledgment from receiver(s) via TCPStore.

        Args:
            worker_ids: Which workers to wait for (None = all workers).
            timeout: Maximum time to wait (seconds). None means block indefinitely.
        """
        if self._store is None:
            return

        if worker_ids is None:
            target_workers = list(range(self._num_workers)) if self._num_workers else []
        elif isinstance(worker_ids, int):
            target_workers = [worker_ids]
        else:
            target_workers = list(worker_ids)

        for worker_idx in target_workers:
            rank = worker_idx + 1
            key = f"NODE_{rank}_out"
            try:
                status = self._store.get(key)
                if status != b"updated":
                    torchrl_logger.warning(
                        f"Unexpected ack from worker {worker_idx}: {status}"
                    )
                self._store.delete_key(key)
            except Exception as e:
                torchrl_logger.warning(
                    f"Timeout waiting for ack from worker {worker_idx}: {e}"
                )

    def _background_receive_loop(self):
        """Background thread loop that waits for instructions and receives weights.

        This loop:
        1. Waits for an instruction via TCPStore
        2. Receives weights via torch.distributed
        3. Sends an acknowledgment back
        4. Repeats until stop event is set
        """
        while not self._stop_event.is_set():
            try:
                instruction = self._wait_for_instruction()
                if instruction is None:
                    continue
                if instruction in ("receive", "update_weights"):
                    # Receive weights via torch.distributed
                    weights = self._receiver_transport.receive_weights(
                        model=self.model,
                        strategy=self._strategy,
                    )

                    if weights is not None:
                        # Cascade weight update to sub-collectors if context supports it
                        model_id = self._model_id or "policy"
                        if self.context is not None and hasattr(
                            self.context, "update_policy_weights_"
                        ):
                            self.context.update_policy_weights_(
                                model_id=model_id, policy_or_weights=weights
                            )

                    # Send acknowledgment
                    self._send_ack("updated")

                elif instruction == "stop":
                    break
                else:
                    torchrl_logger.warning(
                        f"DistributedWeightSyncScheme: Unknown instruction: {instruction}"
                    )

            except Exception as e:
                if not self._stop_event.is_set():
                    torchrl_logger.warning(
                        f"DistributedWeightSyncScheme: Background receiver error: {e}"
                    )

    def _setup_connection_and_weights_on_sender_impl(
        self, *, worker_idx: int | None = None, weights: Any | None = None
    ) -> None:
        """Send initial weights to all workers during connect().

        If the sender has a stateful model (weights available), send them
        to all workers so they start with the correct weights.

        Note: This uses direct torch.distributed send/recv without TCPStore
        signaling to avoid interfering with the main collection loop.
        """
        # Initialize torch.distributed process group if not already done
        # This is a collective operation - all workers must call it
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(
                backend=self.backend,
                rank=0,  # Sender is always rank 0
                world_size=self._num_workers + 1,
                timeout=timedelta(seconds=self._timeout),
            )

        # Check if we have weights to send
        if weights is None and getattr(self, "model", None) is None:
            self._store.set("STATELESS_MODEL", b"1")
            return

        self._store.set("STATELESS_MODEL", b"0")
        # Prepare weights from model
        weights = self._get_weights_buffer_from_model(self.model)
        if weights is None or weights.is_empty():
            return

        # Send to all workers using direct torch.distributed (no TCPStore signaling)
        for i, transport in enumerate(self._iterate_transports()):
            if worker_idx is not None and i != worker_idx:
                continue
            transport.send_initial_weights(weights)

    def _setup_connection_and_weights_on_receiver_impl(
        self, *, worker_idx: int | None = None
    ) -> None:
        """Receive initial weights from sender during connect().

        The receiver always has a model that needs weights, so we block
        waiting for the initial weights from the sender.
        """
        # Use stored worker_idx if not provided
        if worker_idx is None:
            worker_idx = self._worker_idx

        # Initialize torch.distributed process group if not already done
        # This is a collective operation - sender and all workers must call it
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(
                backend=self.backend,
                rank=worker_idx,
                world_size=self._num_workers + 1,
                timeout=timedelta(seconds=self._timeout),
            )

        if self._receiver_transport is None:
            torchrl_logger.warning(
                "DistributedWeightSyncScheme: No receiver transport, skipping initial weight sync"
            )
            return

        stateless_model = self.receiver_transport._store.get("STATELESS_MODEL")
        if stateless_model not in (b"0", b"1"):
            raise RuntimeError(f"Invalid STATELESS_MODEL value: {stateless_model}")
        if stateless_model != b"1":
            # Receive initial weights (blocking, no TCPStore coordination)
            weights = self._receiver_transport.receive_initial_weights()
            if weights is not None and self.model is not None:
                self._strategy.apply_weights(self.model, weights, inplace=False)

        # Start background receiver thread AFTER initial weight sync is complete
        # This prevents the background thread from consuming the initial sync messages
        if self._background_thread is None:
            self._start_background_receiver()

    def shutdown(self) -> None:
        """Stop background receiver thread and clean up."""
        # Check if already shutdown
        if getattr(self, "_is_shutdown", False):
            return
        self._is_shutdown = True

        # Let base class handle background thread cleanup
        super().shutdown()

    @property
    def model(self) -> Any | None:
        """Get the model associated with this scheme.

        Returns:
            The model if set, None otherwise.
        """
        if self._model_ref is not None:
            return self._model_ref()
        if self._model_id is not None:
            model = _resolve_model(self.context, self._model_id)
            if model is None:
                if self._model_id == "policy":
                    torchrl_logger.debug("Creating policy from factory.")
                    model = self.context.policy_factory[0]()
                    self.context.policy = model
                else:
                    raise AttributeError(
                        f"Model {self._model_id} was `None` in context {self.context}"
                    )
            self._model_ref = weakref.ref(model)
            return model

    @model.setter
    def model(self, value: Any):
        """Set the model for this scheme.

        Args:
            value: The model to set. If None, the setter is a no-op.
        """
        if value is None:
            return
        self._model_ref = weakref.ref(value)

    def create_transport(self, **kwargs) -> TransportBackend:
        """Create distributed transport for a specific worker."""
        return DistributedTransport(**kwargs)


class DistributedTransport:
    """torch.distributed transport for communicating with a single distributed worker.

    This transport handles weight updates for ONE specific distributed worker via
    torch.distributed send/recv. Multiple transports are created for multiple workers,
    following the same pattern as multiprocess collectors.
    """

    def __init__(
        self,
        *,
        weights_buffer: TensorDictBase,
        store: torch.distributed.Store = None,
        rank: int | None = None,
        sync: bool = True,
    ):
        """Initialize the DistributedTransport.

        Args:
            weights_buffer (TensorDictBase): a tensor buffer of weights.
            store (torch.distributed.Store): A (TCP)Store for communication.
            rank (int): Worker rank (1-indexed).
            sync (bool): Whether to use synchronous weight updates.
        """
        self._store = store
        self._rank = rank
        self._sync = sync
        self._weights_buffer = weights_buffer

    def send_weights(self, weights: Any) -> None:
        """Send weights to the distributed worker."""
        if self._store is None or self._rank is None:
            return

        # Instruct worker to expect weight update
        self._store.set(f"NODE_{self._rank}_in", b"update_weights")

        # Send weights via torch.distributed
        if self._sync:
            weights.send(self._rank)
        else:
            weights.isend(self._rank)

        # Wait for acknowledgment
        status = self._store.get(f"NODE_{self._rank}_out")
        if status != b"updated":
            raise RuntimeError(f"Expected 'updated' but got status {status}.")
        self._store.delete_key(f"NODE_{self._rank}_out")

    def send_weights_async(self, weights: Any) -> None:
        """Send weights to distributed worker without waiting for acknowledgment.

        Use wait_ack() to wait for acknowledgment after sending to all workers.
        """
        if self._store is None or self._rank is None:
            return

        # Instruct worker to expect weight update
        self._store.set(f"NODE_{self._rank}_in", b"update_weights")

        # Send weights via torch.distributed
        if self._sync:
            weights.send(self._rank)
        else:
            weights.isend(self._rank)

    def wait_ack(self) -> None:
        """Wait for acknowledgment from distributed worker."""
        if self._store is None or self._rank is None:
            return

        status = self._store.get(f"NODE_{self._rank}_out")
        if status != b"updated":
            raise RuntimeError(f"Expected 'updated' but got status {status}.")
        self._store.delete_key(f"NODE_{self._rank}_out")

    def receive_weights(
        self,
        timeout: float | None = None,
        *,
        weights: Any = None,
        model: Any = None,
        strategy: WeightStrategy | None = None,
    ) -> Any | None:
        r"""Receive weights via torch.distributed and apply them to the model.

        The surrounding collector loop is responsible for checking the TCPStore
        for the \"update_weights\" instruction. When this method is called we
        assume that a weight update has been requested and the sender has
        already performed the corresponding ``send()``.

        Args:
            timeout: Maximum time to wait for weights (seconds). If None,
                blocks until weights are received.
            weights: Pre-allocated weight buffer to receive into.
            model: The model to apply weights to.
            strategy: Strategy for applying weights to the model.

        Returns:
            The received weights, or None if timeout expires.
        """
        if self._store is None or self._rank is None:
            return None

        # Use provided weights buffer or fallback to stored one
        weights_buffer = weights if weights is not None else self._weights_buffer

        # Receive weights via torch.distributed into the buffer
        if self._sync or timeout is None:
            # Blocking receive - no timeout support
            if self._sync:
                weights_buffer.recv(src=0)
            else:
                weights_buffer.irecv(src=0)
        else:
            # Non-blocking receive with timeout support
            futures = weights_buffer.irecv(src=0, return_premature=True)
            if futures:
                start_time = time.monotonic()
                while True:
                    # Check if all futures are complete
                    all_complete = all(f.is_completed() for f in futures)
                    if all_complete:
                        break
                    # Check timeout
                    elapsed = time.monotonic() - start_time
                    if elapsed >= timeout:
                        # Timeout expired before receiving all weights
                        return None
                    # Small sleep to avoid busy-waiting
                    time.sleep(0.001)

        # Apply weights if model and strategy provided
        if model is not None and strategy is not None:
            strategy.apply_weights(model, weights_buffer)

        return weights_buffer

    def send_initial_weights(self, weights: Any) -> None:
        """Send initial weights during connect() without TCPStore signaling.

        This is used for the initial weight sync during connect() to avoid
        interfering with the main collection loop's TCPStore-based coordination.
        """
        if self._rank is None:
            return

        # Note: No TCPStore signaling for initial sync - just direct send/recv
        if self._sync:
            weights.send(self._rank)
        else:
            weights.isend(self._rank)

    def receive_initial_weights(self) -> Any:
        """Receive initial weights during connect() without TCPStore signaling.

        This is used for the initial weight sync during connect() to avoid
        interfering with the main collection loop's TCPStore-based coordination.

        Returns:
            The received weights TensorDict.
        """
        if self._sync:
            self._weights_buffer.recv(src=0)
        else:
            self._weights_buffer.irecv(src=0)
        return self._weights_buffer

    def setup_connection_and_weights_on_sender(self) -> None:
        """No-op for DistributedTransport - handled by scheme."""

    def setup_connection_and_weights_on_receiver(
        self,
        *,
        worker_idx: int,
        weights: Any = None,
        model: Any = None,
        strategy: WeightStrategy | None = None,
    ) -> Any:
        """No-op for DistributedTransport - handled by scheme."""
        return None
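
For orientation, here is a minimal usage sketch of the new scheme. It is not part of the wheel contents: only the DistributedWeightSyncScheme constructor arguments are taken directly from the diff above, while the init_on_sender()/init_on_receiver()/connect() entry points are assumed to forward the keyword arguments of the _impl methods shown there, and worker_context is a placeholder for whatever collector object the worker side passes as context.

import torch.nn as nn

from torchrl.weight_update._distributed import DistributedWeightSyncScheme

policy = nn.Linear(4, 2)
scheme = DistributedWeightSyncScheme(backend="gloo", sync=True, timeout=3600.0)

# Sender side (rank 0): registers one DistributedTransport per worker and creates
# the scheme's own TCPStore (collector tcp_port + 2, or a dynamically chosen free port).
# scheme.init_on_sender(model_id="policy", model=policy, num_workers=1)

# The scheme is then pickled and shipped to the worker; __getstate__ drops the TCPStore
# and transports but keeps _store_info, so the worker can reconnect as a client:
# worker_scheme.init_on_receiver(model_id="policy", context=worker_context, worker_idx=1)
# worker_scheme.connect()  # joins the process group and receives the initial weights

# Each later update follows the TCPStore handshake implemented above: the sender writes
# b"update_weights" to "NODE_<rank>_in", pushes the TensorDict of weights with
# send()/isend(), and waits for b"updated" on "NODE_<rank>_out".

Note that the TCPStore is used only for control-flow signaling (instructions, acknowledgments, and the STATELESS_MODEL flag); the weight tensors themselves travel over the torch.distributed process group via TensorDict send/recv.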