torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/weight_update/weight_sync_schemes.py
@@ -0,0 +1,1244 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import abc
+import threading
+import warnings
+import weakref
+from collections import defaultdict
+from collections.abc import Callable, Iterator
+from typing import Any, Literal, overload, Protocol
+
+import torch
+
+from tensordict import TensorDict, TensorDictBase
+from torch import nn
+from torchrl._utils import logger as torchrl_logger
+
+__all__ = [
+    "TransportBackend",
+    "WeightStrategy",
+    "WeightSyncScheme",
+]
+
+from torchrl.weight_update.utils import _resolve_model
+
+
+# ============================================================================
+# Transport Layer Abstraction
+# ============================================================================
+
+
+class TransportBackend(Protocol):
+    """Abstract interface for different communication mechanisms."""
+
+    def send_weights(self, weights: Any) -> None:
+        """Send weights to the receiver."""
+        ...
+
+    def receive_weights(
+        self,
+        timeout: float | None = None,
+        *,
+        weights: Any = None,
+        model: Any = None,
+        strategy: WeightStrategy | None = None,
+    ) -> Any | None:
+        """Receive weights from the sender and apply them to the model.
+
+        Args:
+            timeout: Maximum time to wait for weights (seconds).
+                None means no timeout (blocking). Some transports may not
+                support timeout and will raise ValueError if specified.
+            weights: Pre-allocated weight buffer to receive into.
+            model: The model to apply weights to.
+            strategy: Strategy for applying weights to the model.
+
+        Returns:
+            The received/applied weights, or None if timeout/no weights available.
+        """
+        ...
+
+    def setup_connection_and_weights_on_sender(self) -> None:
+        """Synchronize weights on sender side before collection starts.
+
+        This is called once after workers are initialized to send the initial
+        weights. This can be a no-op (weights are sent via
+        send_weights).
+        """
+        ...
+
+    def setup_connection_and_weights_on_receiver(
+        self,
+        *,
+        worker_idx: int,
+        weights: Any = None,
+        model: Any = None,
+        strategy: WeightStrategy | None = None,
+    ) -> Any:
+        """Synchronize weights on worker side before collection starts.
+
+        This is called once in each worker after initialization to receive
+        the initial weights. This is a no-op (weights are received via
+        receive_weights).
+
+        Args:
+            worker_idx: The worker index.
+            weights: Pre-allocated weight buffer to receive into.
+            model: The model to apply weights to.
+            strategy: Strategy for applying weights to the model.
+
+        Returns:
+            The received weights (for SharedMemTransport) or None.
+        """
+        ...
+
+
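For illustration only, here is a minimal in-process transport satisfying the protocol above with a `queue.Queue`. `QueueTransport` is a hypothetical name and not part of the package, which ships its own transports elsewhere in this release (shared memory, multiprocessing pipes, Ray, RPC).

```python
import queue
from typing import Any


class QueueTransport:
    """Hypothetical transport: sender and receiver share one in-process queue."""

    def __init__(self) -> None:
        self._q: queue.Queue = queue.Queue()

    def send_weights(self, weights: Any) -> None:
        # Hand the weights over to the receiver side.
        self._q.put(weights)

    def receive_weights(self, timeout=None, *, weights=None, model=None, strategy=None):
        # timeout=None blocks, matching the protocol contract above.
        try:
            received = self._q.get(timeout=timeout)
        except queue.Empty:
            return None
        if model is not None and strategy is not None:
            strategy.apply_weights(model, received)
        return received

    def setup_connection_and_weights_on_sender(self) -> None:
        # No handshake needed for an in-process queue.
        pass

    def setup_connection_and_weights_on_receiver(self, *, worker_idx, weights=None, model=None, strategy=None):
        # Initial weights arrive later through receive_weights().
        return None
```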
+# ============================================================================
+# Weight Strategies
+# ============================================================================
+
+
+class WeightStrategy:
+    """Unified strategy for weight transmission.
+
+    This strategy handles both extraction and application of weights, supporting
+    both TensorDict and state_dict formats.
+
+    Args:
+        extract_as (str): Format for extracting weights. Can be:
+            - "tensordict" (default): Extract weights as TensorDict
+            - "state_dict": Extract weights as PyTorch state_dict
+
+    The application format is automatically detected based on the type of weights
+    received (dict -> state_dict, TensorDict -> tensordict).
+    """
+
+    def __init__(self, extract_as: Literal["tensordict", "state_dict"] = "tensordict"):
+        if extract_as == "state_dict":
+            warnings.warn(
+                "state_dict strategy is experimental. Use tensordict strategy for safer weight updates.",
+                UserWarning,
+            )
+        if extract_as not in ("tensordict", "state_dict"):
+            raise ValueError(
+                f"extract_as must be 'tensordict' or 'state_dict', got {extract_as}"
+            )
+        self.extract_as = extract_as
+
+    def extract_weights(self, source: Any) -> TensorDictBase | dict | None:
+        """Extract weights from source model in the specified format.
+
+        Args:
+            source: The model to extract weights from. Can be:
+                - nn.Module: PyTorch module
+                - TensorDictBase: TensorDict
+                - dict: State dictionary
+
+        Returns:
+            Weights in the format specified by `extract_as` constructor argument.
+        """
+        if self.extract_as == "tensordict":
+            # Extract as TensorDict
+            if isinstance(source, nn.Module):
+                return TensorDict.from_module(source)
+            elif isinstance(source, TensorDictBase):
+                return source
+            elif isinstance(source, dict):
+                # Convert state_dict to TensorDict
+                return TensorDict(source, batch_size=[])
+            else:
+                torchrl_logger.warning(
+                    f"Unsupported source type for TensorDict extraction: {type(source)}"
+                )
+                return TensorDict(lock=True)
+        elif self.extract_as == "state_dict":  # state_dict
+            # Extract as state_dict
+            if isinstance(source, nn.Module):
+                return source.state_dict()
+            elif isinstance(source, dict):
+                return source
+            elif isinstance(source, TensorDictBase):
+                # Convert TensorDict to state_dict
+                return source.flatten_keys().to_dict()
+            else:
+                torchrl_logger.warning(
+                    f"Unsupported source type for TensorDict extraction: {type(source)}"
+                )
+                return {}
+        else:
+            raise ValueError(
+                f"Unknown extract_as: {self.extract_as}. Must be 'tensordict' or 'state_dict'."
+            )
+
+    def apply_weights(
+        self, destination: Any, weights: Any, inplace: bool = True
+    ) -> None:
+        """Apply weights to destination model.
+
+        The format is automatically detected from the weights type:
+        - dict -> state_dict format
+        - TensorDictBase -> tensordict format
+
+        Args:
+            destination: The model to apply weights to. Can be:
+                - nn.Module: PyTorch module
+                - TensorDictBase: TensorDict
+                - dict: State dictionary
+            weights: The weights to apply (dict or TensorDictBase).
+            inplace: Whether to apply weights in place.
+        """
+        if weights is None:
+            return
+
+        # Auto-detect format from weights type
+        if isinstance(weights, dict):
+            weights = TensorDict(weights)
+            if any("." in key for key in weights.keys()):
+                weights = weights.unflatten_keys(".")
+        if isinstance(destination, nn.Module):
+            # Do not update in-place
+            if not inplace:
+                weights.to_module(destination)
+                return
+            else:
+                destination = TensorDict.from_module(destination)
+        elif isinstance(destination, dict):
+            if not inplace:
+                raise ValueError("Cannot update state_dict out of place")
+            destination = TensorDict(destination)
+            if any(isinstance(key, str) and "." in key for key in destination.keys()):
+                destination = destination.unflatten_keys(".")
+
+        if not isinstance(weights, TensorDictBase):
+            raise ValueError(
+                f"Unsupported weights type: {type(weights)}. Must be dict or TensorDictBase."
+            )
+        if not isinstance(destination, TensorDictBase):
+            if not weights.is_empty():
+                raise ValueError(
+                    "Non-empty weights are associated with a non-dict, non-td, non-Module destination."
+                )
+            return
+
+        try:
+            if not inplace:
+                destination.update(weights)
+            else:
+                destination.data.update_(weights.data)
+        except Exception as e:
+            raise KeyError(
+                f"Error updating destination. Destination keys: {destination.keys(True, True)}, weights keys: {weights.keys(True, True)}"
+            ) from e
+        return
+
+
+def _get_strategy(strategy: Literal["tensordict", "state_dict"]) -> WeightStrategy:
+    """Get strategy object from string name.
+
+    Args:
+        strategy: Either "tensordict" or "state_dict".
+
+    Returns:
+        WeightStrategy: Strategy configured with the specified extraction format.
+    """
+    if strategy not in ("tensordict", "state_dict"):
+        raise ValueError(
+            f"Unknown strategy: {strategy}. Must be 'tensordict' or 'state_dict'."
+        )
+    return WeightStrategy(extract_as=strategy)
+
+
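A minimal round-trip sketch of the strategy above, assuming the import path matches the file shown in this diff (`torchrl.weight_update.weight_sync_schemes`): extract weights from one module and mirror them into another of the same architecture.

```python
import torch
from torch import nn
from torchrl.weight_update.weight_sync_schemes import WeightStrategy

src = nn.Linear(4, 2)
dst = nn.Linear(4, 2)

strategy = WeightStrategy(extract_as="tensordict")
weights = strategy.extract_weights(src)  # TensorDict mirroring src's parameters
strategy.apply_weights(dst, weights)     # default inplace=True copies values into dst

assert torch.equal(src.weight, dst.weight)
```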
+# ============================================================================
+# Weight Synchronization Schemes
+# ============================================================================
+
+
+class WeightSyncScheme(metaclass=abc.ABCMeta):
+    """Configuration for how to synchronize ONE model across workers.
+
+    A scheme manages synchronization of ONE model across workers.
+    The collector maintains a dict of {model_id: scheme} pairs.
+
+    This class directly handles both sender and receiver functionality,
+    with behavior determined by whether init_on_sender() or init_on_receiver()
+    was called.
+    """
+
+    _model_id: str | None = None
+
+    # Transport management
+    _sender_transports: dict[int, TransportBackend] | None
+    _receiver_transport: TransportBackend | None
+    _shared_transport: TransportBackend | None
+
+    # Context and model references
+    _context_ref: weakref.ReferenceType[Any] | None
+    _model_ref: weakref.ReferenceType[Any] | None
+
+    # Strategy
+    _strategy: WeightStrategy
+
+    # Worker index (for receiver side)
+    _worker_idx: int | None
+
+    # Background thread
+    _background_thread = None
+    _stop_event = None
+
+    def __init__(self, strategy: Literal["state_dict", "tensordict"] = "tensordict"):
+        self.strategy_str = strategy
+        self._strategy = _get_strategy(strategy)
+        self._initialized_on_sender = False
+        self._initialized_on_receiver = False
+
+        # Transport management
+        self._sender_transports = None  # worker_idx -> transport
+        self._receiver_transport = None
+        self._shared_transport = None
+
+        # Context and model references
+        self._context_ref = None
+        self._model_ref = None
+
+        # Worker index
+        self._worker_idx = None
+
+    # ========================================================================
+    # Initialization
+    # ========================================================================
+
+    @property
+    def strategy(self) -> WeightStrategy:
+        return self._strategy
+
+    @strategy.setter
+    def strategy(self, value: WeightStrategy) -> None:
+        self._strategy = value
+
+    @overload
+    def init_on_sender(
+        self,
+        *,
+        model_id: str,
+        context: Any,
+    ) -> None:
+        ...
+
+    @overload
+    def init_on_sender(
+        self,
+        *,
+        params_map: dict[int, TensorDictBase],
+        model_id: str | None = None,
+    ) -> None:
+        ...
+
+    @overload
+    def init_on_sender(
+        self,
+        *,
+        params_map: dict[int, TensorDictBase],
+    ) -> None:
+        ...
+
+    @overload
+    def init_on_sender(
+        self,
+        *,
+        weights: TensorDictBase,
+        devices: list[torch.device],
+    ) -> None:
+        ...
+
+    @overload
+    def init_on_sender(
+        self,
+        *,
+        weights: TensorDictBase,
+        devices: list[torch.device],
+        model_id: str | None = None,
+    ) -> None:
+        ...
+
+    @overload
+    def init_on_sender(
+        self,
+        *,
+        model: nn.Module,
+        devices: list[torch.device],
+    ) -> None:
+        ...
+
+    @overload
+    def init_on_sender(
+        self,
+        *,
+        model: nn.Module,
+        devices: list[torch.device],
+        model_id: str | None = None,
+    ) -> None:
+        ...
+
+    @overload
+    def init_on_sender(
+        self,
+        *,
+        weights: TensorDictBase,
+        device_map_fn: Callable[[int, TensorDictBase], TensorDictBase],
+        num_workers: int,
+    ) -> None:
+        ...
+
+    @overload
+    def init_on_sender(
+        self,
+        *,
+        model: nn.Module,
+        device_map_fn: Callable[[int, TensorDictBase], TensorDictBase],
+        num_workers: int,
+        model_id: str | None = None,
+    ) -> None:
+        ...
+
+    @overload
+    def init_on_sender(self):
+        ...
+
+    def init_on_sender(
+        self,
+        *args,
+        **kwargs,
+    ) -> None:
+        """Initialize on the main process (sender side).
+
+        This method is called once in the collector's _run_processes() method,
+        after workers have been started and are ready to receive messages.
+        """
+        self._initialized_on_sender = True
+        try:
+            result = self._init_on_sender_impl(*args, **kwargs)
+        except Exception:
+            self._initialized_on_sender = False
+            raise
+        return result
+
+    def _init_on_sender_impl(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @property
+    def initialized_on_sender(self):
+        return getattr(self, "_initialized_on_sender", False)
+
+    @property
+    def initialized_on_receiver(self):
+        return getattr(self, "_initialized_on_receiver", False)
+
+    @overload
+    def init_on_receiver(
+        self,
+        model_id: str,
+        context: Any,
+        **kwargs,
+    ) -> None:
+        ...
+
+    @overload
+    def init_on_receiver(
+        self,
+        model_id: str,
+        context: None = None,
+        *,
+        worker_idx: int = ...,
+        model: Any | None = None,
+        **kwargs,
+    ) -> None:
+        ...
+
+    def init_on_receiver(
+        self,
+        *,
+        model_id: str,
+        context: Any = None,
+        **kwargs,
+    ) -> None:
+        """Initialize on worker process (receiver side).
+
+        This method is called once in each worker's initialization.
+
+        Args:
+            model_id: Identifier for the model being synchronized
+            context: Optional context object (e.g., inner collector)
+            **kwargs: Alternative to context (model, etc.)
+        """
+        if self.initialized_on_sender:
+            # emulate pickling to erase the current state
+            self.__setstate__(self.__getstate__())
+
+        self._initialized_on_receiver = True
+        try:
+            result = self._init_on_receiver_impl(
+                model_id=model_id, context=context, **kwargs
+            )
+        except Exception:
+            self._initialized_on_receiver = False
+            raise
+        return result
+
+    def _init_on_receiver_impl(
+        self,
+        model_id: str,
+        context: Any = None,
+        **kwargs,
+    ) -> None:
+        raise NotImplementedError
+
+    # ========================================================================
+    # Context and Model Management
+    # ========================================================================
+
+    @property
+    def context(self) -> Any | None:
+        """Get the context object (e.g., collector), if available.
+
+        Returns:
+            The context object if available, None otherwise.
+        """
+        if self._context_ref is not None:
+            return self._context_ref()
+        return None
+
+    @context.setter
+    def context(self, context: Any) -> None:
+        """Set the context object for resolving references.
+
+        Args:
+            context: The context object to resolve references from.
+        """
+        if context is not None:
+            self._context_ref = weakref.ref(context)
+        else:
+            self._context_ref = None
+
+    @property
+    def model_id(self) -> str | None:
+        """Get the model ID for this scheme.
+
+        Returns:
+            The model ID if set, None otherwise.
+        """
+        return self._model_id
+
+    @model_id.setter
+    def model_id(self, model_id: str) -> None:
+        """Set the model ID for this scheme.
+
+        Args:
+            model_id: The model ID to set.
+        """
+        self._model_id = model_id
+
+    @property
+    def worker_idx(self) -> int | None:
+        """Get the worker index for this scheme.
+
+        Returns:
+            The worker index if set, None otherwise.
+        """
+        return self._worker_idx
+
+    @worker_idx.setter
+    def worker_idx(self, worker_idx: int | None) -> None:
+        """Set the worker index for this scheme.
+
+        Args:
+            worker_idx: The worker index to set.
+        """
+        if self.initialized_on_sender and worker_idx is not None:
+            raise RuntimeError(
+                "Worker index cannot be set after initialization on sender"
+            )
+        self._worker_idx = worker_idx
+
+    @property
+    def model(self) -> Any | None:
+        """Get the model object, if available.
+
+        Returns:
+            The model object if available, None otherwise.
+        """
+        if self._model_ref is not None:
+            return self._model_ref()
+        if self._model_id is not None:
+            model = _resolve_model(self.context, self._model_id)
+            if model is None:
+                raise AttributeError(
+                    f"Model {self._model_id} was `None` in context {self.context}"
+                )
+            self._model_ref = weakref.ref(model)
+            return model
+
+    @model.setter
+    def model(self, model: Any) -> None:
+        """Set the model object for applying weights.
+
+        Args:
+            model: The model object to apply weights to.
+        """
+        if model is not None:
+            self._model_ref = weakref.ref(model)
+        else:
+            self._model_ref = None
+
+    @property
+    def weights(self) -> Any | None:
+        """Get the current weights, if available.
+
+        Returns:
+            The weights as TensorDict if available, None otherwise.
+        """
+        if (weights := getattr(self, "_weights", None)) is not None:
+            return weights
+        model = self.model
+        if model is not None:
+            return self._strategy.extract_weights(model)
+        return None
+
+    @weights.setter
+    def weights(self, value: Any):
+        self._weights = value
+
+    def _get_weights_buffer_from_model(self, model: nn.Module | Any) -> TensorDictBase:
+        from torchrl.collectors.utils import _cast
+
+        if isinstance(model, torch.nn.Module):
+            td = TensorDict.from_module(model)
+            td = td.data.apply(_cast, td)
+            return td
+        # Return an empty TD
+        return TensorDict()
+
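The properties above resolve everything lazily through weak references: setting `context` and `model_id` is enough for `model` and `weights` to be recovered on demand. A short sketch of that resolution path (`scheme` stands for a concrete subclass instance and `collector` is a placeholder for the owning object):

```python
# Hypothetical usage of the lazy resolution shown above.
scheme.model_id = "policy"
scheme.context = collector      # stored as a weakref, not a strong reference
policy = scheme.model           # resolved via _resolve_model(context, "policy")
weights = scheme.weights        # extracted from the model with the scheme's strategy
```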
+    # ========================================================================
+    # Transport Management
+    # ========================================================================
+
+    def _register_worker_sender(
+        self,
+        *,
+        worker_idx: int,
+        transport: TransportBackend | None = None,
+        **transport_kwargs,
+    ) -> None:
+        """Register a worker's communication.
+
+        Args:
+            worker_idx: The worker index.
+            transport: Optional pre-created transport.
+            **transport_kwargs: Transport-specific configuration.
+        """
+        if self._sender_transports is None:
+            if self._shared_transport is not None:
+                raise RuntimeError(
+                    "Cannot register transports on sender after shared transport is set"
+                )
+            self._sender_transports = {}
+        if worker_idx not in self._sender_transports:
+            if transport is not None:
+                self._sender_transports[worker_idx] = transport
+            else:
+                self._sender_transports[worker_idx] = self.create_transport(
+                    **transport_kwargs
+                )
+
+    def _register_transport_receiver(
+        self, transport: TransportBackend | None = None, **transport_kwargs
+    ) -> None:
+        """Register a single transport (for receiver side).
+
+        Args:
+            transport: Optional pre-created transport.
+            **transport_kwargs: Transport-specific configuration.
+        """
+        if transport is not None:
+            self._receiver_transport = transport
+        else:
+            self._receiver_transport = self.create_transport(**transport_kwargs)
+
+    def _iterate_transports(
+        self, worker_ids: int | list[int] | None = None
+    ) -> Iterator[TransportBackend]:
+        """Iterate over transports for specified workers."""
+        if worker_ids is None:
+            # All workers
+            if not self.sender_transports:
+                if self.receiver_transport is not None:
+                    yield self.receiver_transport
+            else:
+                # Make sure transports are sorted
+                for k in sorted(self.sender_transports.keys()):
+                    yield self.sender_transports[k]
+        else:
+            # Specific workers
+            if isinstance(worker_ids, int):
+                worker_ids = [worker_ids]
+            for worker_id in worker_ids:
+                if worker_id in self.sender_transports:
+                    yield self.sender_transports[worker_id]
+                else:
+                    raise ValueError(f"Worker {worker_id} not registered")
+
+    @abc.abstractmethod
+    def create_transport(self, **kwargs) -> TransportBackend:
+        """Create transport for communication.
+
+        Args:
+            **kwargs: Transport-specific configuration parameters.
+
+        Returns:
+            A transport backend instance.
+
+        Note:
+            This is used internally by init_on_sender/init_on_receiver.
+        """
+        ...
+
+    @property
+    def sender_transports(self) -> dict[int, TransportBackend]:
+        """Get the sender transports.
+
+        Returns:
+            The sender transports.
+        """
+        if self._shared_transport is not None:
+            return defaultdict(lambda: self._shared_transport)
+        return self._sender_transports
+
+    @property
+    def receiver_transport(self) -> TransportBackend | None:
+        """Get the receiver transport.
+
+        Returns:
+            The receiver transport.
+        """
+        if self._shared_transport is not None:
+            return self._shared_transport
+        return self._receiver_transport
+
+    @property
+    def shared_transport(self) -> TransportBackend | None:
+        """Get the shared transport.
+
+        Returns:
+            The shared transport.
+        """
+        if self._receiver_transport is not None:
+            raise RuntimeError(
+                "Receiver transport and shared transport cannot be used together"
+            )
+        if self._sender_transports is not None:
+            raise RuntimeError(
+                "Sender transports and shared transport cannot be used together"
+            )
+        return self._shared_transport
+
+    @shared_transport.setter
+    def shared_transport(self, shared_transport: TransportBackend | None) -> None:
+        """Set the shared transport.
+
+        Args:
+            shared_transport: The shared transport to set.
+        """
+        self._shared_transport = shared_transport
+
+    # ========================================================================
+    # Sending Weights (Sender Side)
+    # ========================================================================
+
+    def send(
+        self,
+        weights: Any = None,
+        worker_ids: int | list[int] | None = None,
+    ) -> None:
+        """Send weights synchronously to workers.
+
+        This method:
+        1. Prepares weights (extracts from model if weights=None)
+        2. Sends to specified workers (or all if worker_ids=None)
+        3. Waits for acknowledgments from those workers
+        4. Returns when workers have applied the weights
+
+        Args:
+            weights: Weights to send. Can be:
+                - None: Extract from model via context.get_model(model_id)
+                - nn.Module: Extract weights from module
+                - TensorDict: Use directly
+                - dict: Convert to TensorDict
+            worker_ids: Which workers to send to:
+                - None: Send to all workers (default)
+                - int: Send to single worker
+                - list[int]: Send to specific workers
+
+        Note: This is a blocking call that ensures specified workers are updated
+        before returning.
+        """
+        if not self.initialized_on_sender:
+            raise RuntimeError("Must be initialized on sender before sending weights")
+        if not self.synchronized_on_sender:
+            raise RuntimeError("Must be synchronized on sender before sending weights")
+
+        context = self.context
+
+        # Let the scheme prepare the weights
+        prepared_weights = self.prepare_weights(
+            weights=weights,
+            model_id=self._model_id,
+            strategy=self._strategy,
+            context=context,
+        )
+
+        transports = list(self._iterate_transports(worker_ids))
+
+        if not transports:
+            raise RuntimeError("No transports available.")
+
+        # Send to all workers first (non-blocking if transport supports it)
+        for transport in transports:
+            if hasattr(transport, "send_weights_async"):
+                transport.send_weights_async(prepared_weights)
+            else:
+                # Fallback for transports that don't support async send
+                transport.send_weights(prepared_weights)
+
+        # Wait for all acknowledgments
+        for transport in transports:
+            if hasattr(transport, "wait_ack"):
+                transport.wait_ack()
+
+    def prepare_weights(
+        self,
+        weights: Any,
+        model_id: str,
+        strategy: WeightStrategy,
+        context: Any = None,
+    ) -> Any:
+        """Prepare weights for sending.
+
+        This method handles weight extraction, conversion, and any scheme-specific
+        preparation (e.g., cache lookups for SharedMemWeightSyncScheme).
+
+        Args:
+            weights: Raw weights input (can be None, nn.Module, TensorDict, dict, str reference, etc.)
+            model_id: The model identifier (e.g., "policy")
+            strategy: WeightStrategy for extracting/converting weights
+            context: Optional context (e.g., collector) for model resolution
+
+        Returns:
+            Prepared weights ready to send via transport
+        """
+        # Default implementation: extract from model or pass through
+        if weights is None and context is not None:
+            # Try to resolve and extract from model in context
+            try:
+                model = _resolve_model(context, model_id)
+                return strategy.extract_weights(model)
+            except (AttributeError, KeyError):
+                pass
+            # Try fallback policy
+            if model_id == "policy" and hasattr(context, "_fallback_policy"):
+                if context._fallback_policy is not None:
+                    return strategy.extract_weights(context._fallback_policy)
+            return None
+
+        if isinstance(weights, nn.Module):
+            return strategy.extract_weights(weights)
+        elif isinstance(weights, str):
+            # String reference to model
+            if context is not None:
+                model = _resolve_model(context, weights)
+                return strategy.extract_weights(model)
+            raise ValueError(
+                f"Cannot resolve string reference '{weights}' without context"
+            )
+        else:
+            # Already extracted weights (TensorDict, dict, etc.)
+            return weights
+
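The `send()`/`prepare_weights()` pair above defines the sender-side contract: weights may be passed explicitly, as a module, as a string reference, or omitted and resolved from the context. A hedged sender-side sketch (`scheme` stands for any concrete `WeightSyncScheme` subclass that has already completed `init_on_sender()` and `connect()`; the trainer-loop names are hypothetical):

```python
def push_policy_update(scheme, policy, worker_ids=None):
    # Passing the module lets prepare_weights() extract with the scheme's
    # strategy; passing None would resolve the model from the context instead.
    scheme.send(weights=policy, worker_ids=worker_ids)

# Typical trainer step (hypothetical loop):
# optimizer.step()
# push_policy_update(scheme, policy)                   # broadcast to all workers
# push_policy_update(scheme, policy, worker_ids=[0])   # or target a single worker
```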
+    # ========================================================================
+    # Receiving Weights (Receiver Side)
+    # ========================================================================
+
+    def receive(self, timeout: float | None = None) -> TensorDictBase | None:
+        """Check for and apply new weights (non-blocking).
+
+        This method is called in the worker's main loop to check if
+        new weights have been sent. If weights are available, they
+        are applied to the registered model immediately, and the update
+        is cascaded to any sub-collectors via context.update_policy_weights_().
+
+        Args:
+            timeout: Maximum time to wait for weights (seconds).
+                None means no timeout (blocking). Some transports may not
+                support timeout and will raise ValueError if specified.
+
+        Returns:
+            The received weights if available, None otherwise.
+
+        Note: For SharedMemWeightSyncScheme, this always returns None
+        since workers automatically see updates via shared memory.
+        """
+        if not self.initialized_on_receiver:
+            raise RuntimeError(
+                "Must be initialized on receiver before receiving weights"
+            )
+        if not self.synchronized_on_receiver:
+            raise RuntimeError(
+                "Must be synchronized on receiver before receiving weights"
+            )
+
+        # Determine which transport to use
+        if self._receiver_transport is not None:
+            transport = self._receiver_transport
+        elif self._shared_transport is not None:
+            # Use shared transport directly (e.g., SharedMemWeightSyncScheme)
+            transport = self._shared_transport
+        else:
+            return None
+
+        # Try to receive weights - transport handles receiving and applying
+        result = transport.receive_weights(
+            timeout=timeout,
+            weights=self.weights,
+            model=self.model,
+            strategy=self._strategy,
+        )
+        if result is None:
+            return None
+
+        weights = result
+        model_id = self._model_id or "policy"
+
+        # Cascade weight update to sub-collectors if context supports it
+        if self.context is not None and hasattr(self.context, "update_policy_weights_"):
+            self.context.update_policy_weights_(
+                model_id=model_id, policy_or_weights=weights
+            )
+
+        # Send acknowledgment if transport supports it
+        if hasattr(transport, "send_ack"):
+            transport.send_ack("updated")
+
+        return weights
+
+    def apply_weights(self, weights: TensorDictBase, inplace: bool = True) -> None:
+        """Apply weights to the model.
+
+        Args:
+            weights: The weights to apply.
+            inplace: Whether to apply weights in place. Default is `True`.
+        """
+        if not self.initialized_on_receiver:
+            if self.initialized_on_sender:
+                raise RuntimeError("apply_weights() called on a sender side.")
+            raise RuntimeError(
+                "apply_weights() called before init_on_receiver has been called."
+            )
+
+        if self._model_ref is None:
+            raise ValueError("No model registered")
+
+        model = self.model
+        self._strategy.apply_weights(model, weights, inplace=inplace)
+
+        # Send acknowledgment if transport supports it
+        if self.receiver_transport is not None and hasattr(
+            self.receiver_transport, "send_ack"
+        ):
+            self.receiver_transport.send_ack("updated")
+
# ========================================================================
|
|
962
|
+
# Synchronization
|
|
963
|
+
# ========================================================================
|
|
964
|
+
|
|
965
|
+
@overload
|
|
966
|
+
def connect(self, *, worker_idx: int | None = None) -> None:
|
|
967
|
+
...
|
|
968
|
+
|
|
969
|
+
@overload
|
|
970
|
+
def connect(self, *, weights: Any | None = None) -> None:
|
|
971
|
+
...
|
|
972
|
+
|
|
973
|
+
def connect(
|
|
974
|
+
self, *, worker_idx: int | None = None, weights: Any | None = None
|
|
975
|
+
) -> None:
|
|
976
|
+
"""Method to be called once the workers have started.
|
|
977
|
+
|
|
978
|
+
Triggers a rendez-vous for the workers to receive their copy of the weights.
|
|
979
|
+
|
|
980
|
+
Dispatches to _setup_connection_and_weights_on_sender_impl() or _setup_connection_and_weights_on_receiver_impl()
|
|
981
|
+
based on which initialization was performed.
|
|
982
|
+
"""
|
|
983
|
+
if self.synchronized_on_receiver or self.synchronized_on_sender:
|
|
984
|
+
raise RuntimeError("Cannot synchronize weights on sender twice.")
|
|
985
|
+
if self._initialized_on_sender:
|
|
986
|
+
if worker_idx is not None:
|
|
987
|
+
# Safety check, we can consider removing this in the future.
|
|
988
|
+
raise RuntimeError(
|
|
989
|
+
"Cannot specify worker_idx on sender side during synchronization."
|
|
990
|
+
)
|
|
991
|
+
self.synchronized_on_sender = True
|
|
992
|
+
try:
|
|
993
|
+
self._setup_connection_and_weights_on_sender_impl(weights=weights)
|
|
994
|
+
except Exception:
|
|
995
|
+
self.synchronized_on_sender = False
|
|
996
|
+
raise
|
|
997
|
+
elif self._initialized_on_receiver:
|
|
998
|
+
if weights is not None:
|
|
999
|
+
# safety check: weights are passed to sender, not receiver for initial sync
|
|
1000
|
+
raise RuntimeError(
|
|
1001
|
+
"Cannot specify weights on receiver side during synchronization."
|
|
1002
|
+
)
|
|
1003
|
+
self.synchronized_on_receiver = True
|
|
1004
|
+
try:
|
|
1005
|
+
self._setup_connection_and_weights_on_receiver_impl(
|
|
1006
|
+
worker_idx=worker_idx
|
|
1007
|
+
)
|
|
1008
|
+
except Exception:
|
|
1009
|
+
self.synchronized_on_receiver = False
|
|
1010
|
+
raise
|
|
1011
|
+
else:
|
|
1012
|
+
raise RuntimeError(
|
|
1013
|
+
"Neither init_on_sender nor init_on_receiver have been called."
|
|
1014
|
+
)
|
|
1015
|
+
|
    def _setup_connection_and_weights_on_sender_impl(
        self,
        *,
        worker_idx: int | None = None,
        weights: Any | None = None,
    ) -> None:
        """Synchronize weights on sender side.

        Default implementation uses transport's setup_connection_and_weights_on_sender().
        Subclasses may override for custom behavior.
        """
        if self._shared_transport is not None:
            # We only need to synchronize once
            self.shared_transport.setup_connection_and_weights_on_sender()
            return

        idx = -1
        for idx, transport in enumerate(self._iterate_transports()):
            if worker_idx is not None and idx != worker_idx:
                continue
            transport.setup_connection_and_weights_on_sender()
        if idx == -1:
            raise RuntimeError("No transports available.")

    def _setup_connection_and_weights_on_receiver_impl(
        self, *, worker_idx: int | None = None
    ) -> None:
        """Synchronize weights on receiver side.

        Default implementation uses transport's setup_connection_and_weights_on_receiver().
        Subclasses may override for custom behavior.
        """
        if self.receiver_transport is None:
            return

        # Use stored worker_idx if not provided
        if worker_idx is None:
            worker_idx = self._worker_idx

        # Call transport's synchronize method with all relevant kwargs
        weights = self.receiver_transport.setup_connection_and_weights_on_receiver(
            worker_idx=worker_idx,
            weights=self.weights,
            model=self.model,
            strategy=self._strategy,
        )

        # Apply weights to model if received (SharedMemTransport case).
        # For other transports (MPTransport, etc.), weights is None and synchronization
        # happens later via receive(), so this is a no-op.
        if weights is not None:
            model = self.model
            self._strategy.apply_weights(model, weights, inplace=False)
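The two default implementations above only ask a transport for setup_connection_and_weights_on_sender() and setup_connection_and_weights_on_receiver(); the following hypothetical stub is a minimal sketch of that contract (a real transport would move data over shared memory, pipes, or a distributed store).

class _NoOpTransport:
    """Hypothetical stand-in showing the connection-setup contract used above."""

    def setup_connection_and_weights_on_sender(self) -> None:
        # A real transport would block here until the matching receiver has
        # joined the rendez-vous (e.g. by placing initial weights in a queue).
        pass

    def setup_connection_and_weights_on_receiver(
        self, *, worker_idx=None, weights=None, model=None, strategy=None
    ):
        # Returning None means "nothing to apply yet"; weights would instead
        # arrive later through receive(), as the comment above explains.
        return None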
    @property
    def synchronized_on_sender(self):
        return getattr(self, "_synchronized_on_sender", False)

    @synchronized_on_sender.setter
    def synchronized_on_sender(self, value: bool):
        self._synchronized_on_sender = value

    @property
    def synchronized_on_receiver(self):
        return getattr(self, "_synchronized_on_receiver", False)

    @synchronized_on_receiver.setter
    def synchronized_on_receiver(self, value: bool):
        self._synchronized_on_receiver = value

    # ========================================================================
    # Background Receiver
    # ========================================================================

    def _start_background_receiver(self):
        """Start daemon thread that monitors for weight update instructions.

        The background thread runs _background_receive_loop() which waits for
        instructions via _wait_for_instruction() and calls receive() when
        an instruction arrives.
        """
        if not self.initialized_on_receiver:
            raise RuntimeError(
                "_start_background_receiver must be called on the receiver side."
            )
        self._stop_event = threading.Event()
        self._background_thread = threading.Thread(
            target=self._background_receive_loop,
            daemon=True,
            name=f"WeightReceiver-{self._worker_idx}",
        )
        self._background_thread.start()

    def _background_receive_loop(self):
        """Background thread loop that waits for instructions and receives weights.

        Default implementation uses _wait_for_instruction() and receive().
        Subclasses may override for custom behavior.
        """
        while not self._stop_event.is_set():
            try:
                instruction = self._wait_for_instruction()
                if instruction is None:
                    # Stop signal received
                    break
                if instruction == "receive":
                    self.receive()
                elif instruction == "stop":
                    break
                else:
                    torchrl_logger.warning(f"Unknown instruction: {instruction}")
            except Exception as e:
                if not self._stop_event.is_set():
                    torchrl_logger.warning(f"Background receiver error: {e}")

    def _wait_for_instruction(self, timeout: float | None = None) -> str | None:
        """Block until an instruction arrives from the sender.

        This method should be overridden by subclasses to implement
        scheme-specific instruction waiting (e.g., queue.get(), store polling).

        Args:
            timeout: Maximum time to wait for instruction (seconds).
                None means block indefinitely.

        Returns:
            The instruction string (e.g., "receive", "stop"), or None if
            stop event is set or timeout expires.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _wait_for_instruction()"
        )

    def _send_instruction(
        self,
        instruction: str = "receive",
        worker_ids: int | list[int] | None = None,
    ) -> None:
        """Send instruction to receiver(s) to trigger weight reception.

        This method should be overridden by subclasses to implement
        scheme-specific instruction sending (e.g., queue.put(), store.set()).

        Args:
            instruction: The instruction to send (default: "receive").
            worker_ids: Which workers to send to (None = all workers).
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _send_instruction()"
        )
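Since _wait_for_instruction() and _send_instruction() are abstract above and their docstrings mention queue.get()/queue.put(), here is a hedged sketch of what a multiprocessing-queue-based subclass might look like; the queue attributes are assumptions of this example, not attributes defined by the base class.

import queue  # only needed for the Empty exception


class _QueueInstructionMixin:
    """Hypothetical mixin: one multiprocessing.Queue per worker for instructions."""

    def _wait_for_instruction(self, timeout: float | None = None) -> str | None:
        # self._instruction_queue is assumed to be created during init_on_receiver().
        try:
            return self._instruction_queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def _send_instruction(self, instruction: str = "receive", worker_ids=None) -> None:
        # self._instruction_queues is assumed to map worker index -> Queue.
        if worker_ids is None:
            worker_ids = list(self._instruction_queues)
        elif isinstance(worker_ids, int):
            worker_ids = [worker_ids]
        for idx in worker_ids:
            self._instruction_queues[idx].put(instruction)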
    def _send_ack(self, message: str = "updated") -> None:
        """Send acknowledgment back to sender after receiving weights.

        Called by the background receiver after successfully applying weights.
        Subclasses should override to implement scheme-specific acknowledgment.

        Args:
            message: The acknowledgment message (default: "updated").
        """
        # Default: use transport's send_ack if available
        transport = self._receiver_transport or self._shared_transport
        if transport is not None and hasattr(transport, "send_ack"):
            transport.send_ack(message)

    def _wait_for_ack(  # noqa: B027
        self,
        worker_ids: int | list[int] | None = None,
        timeout: float | None = None,
    ) -> None:
        """Wait for acknowledgment from receiver(s).

        Called by send() in synchronous mode to block until receivers confirm.
        Subclasses should override to implement scheme-specific waiting.

        Args:
            worker_ids: Which workers to wait for (None = all workers).
            timeout: Maximum time to wait (seconds). None means block indefinitely.
        """
        # Default: no-op (subclasses implement scheme-specific waiting)

    def __getstate__(self):
        """Prepare the scheme for pickling by excluding non-serializable runtime state."""
        state = self.__dict__.copy()
        # Remove non-serializable runtime state
        state["_context_ref"] = None
        state["_model_ref"] = None

        state["_initialized_on_sender"] = False
        state["_initialized_on_receiver"] = False

        state["_synchronized_on_sender"] = False
        state["_synchronized_on_receiver"] = False

        state["_background_thread"] = None
        state["_stop_event"] = None

        return state

    def __setstate__(self, state):
        """Restore the scheme from pickling."""
        self.__dict__.update(state)
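A small sketch of the pickling behaviour implemented above: a scheme that is already initialized and connected on the trainer can still be shipped to a worker, because __getstate__ drops the model/context references, the background thread, and the initialized/synchronized flags, so the worker's copy must re-initialize and re-connect. The init_on_receiver() arguments remain placeholders, as before.

import pickle

# On the trainer: the live scheme may hold threads and model references ...
blob = pickle.dumps(scheme)

# ... but the unpickled copy starts from a clean slate:
worker_scheme = pickle.loads(blob)
assert not worker_scheme.initialized_on_receiver
assert not worker_scheme.synchronized_on_receiver

# The worker therefore has to initialize and connect again:
worker_scheme.init_on_receiver(model_id="policy", model=policy)  # hypothetical args
worker_scheme.connect(worker_idx=worker_idx)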
    def __del__(self):
        """Clean up resources when the scheme is garbage collected."""
        try:
            self.shutdown()
        except Exception:
            # Silently ignore any errors during garbage collection cleanup
            pass

    def shutdown(self) -> None:
        """Shutdown the scheme and release resources.

        This method stops any background threads and cleans up connections.
        It is safe to call multiple times. Subclasses should override this
        method to add custom cleanup logic, but should call super().shutdown()
        to ensure base cleanup is performed.
        """
        # Stop background receiver thread if running
        if getattr(self, "_stop_event", None) is not None:
            self._stop_event.set()
        if getattr(self, "_background_thread", None) is not None:
            try:
                self._background_thread.join(timeout=5.0)
            except Exception:
                pass
            self._background_thread = None
            self._stop_event = None
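As the shutdown() docstring asks, a subclass that adds its own cleanup should still call super().shutdown() so the stop event is set and the background thread is joined; a minimal sketch follows (the base-class name and the extra transport attribute are assumptions of this example).

class _MyScheme(WeightSyncScheme):  # base-class name assumed for illustration
    def shutdown(self) -> None:
        # Scheme-specific cleanup first (hypothetical attribute) ...
        transport = getattr(self, "_my_transport", None)
        if transport is not None:
            transport.close()
        # ... then the base cleanup: set the stop event and join the thread.
        super().shutdown()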