torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
import weakref
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from torchrl._utils import logger as torchrl_logger
|
|
8
|
+
|
|
9
|
+
from torchrl.weight_update.utils import _resolve_model
|
|
10
|
+
from torchrl.weight_update.weight_sync_schemes import (
|
|
11
|
+
TransportBackend,
|
|
12
|
+
WeightStrategy,
|
|
13
|
+
WeightSyncScheme,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class RPCWeightSyncScheme(WeightSyncScheme):
    """Weight synchronization for torch.distributed.rpc.

    This scheme uses RPC calls to synchronize weights across distributed
    workers. Each remote collector gets its own transport, following the
    same pattern as multiprocess collectors.
    """

    def _init_on_sender_impl(
        self,
        *,
        model_id: str,
        context: Any = None,
        num_workers: int,
    ) -> None:
        """Initialize the scheme on the trainer (sender) side.

        Args:
            model_id: Identifier of the model to synchronize (e.g. ``"policy"``).
            context: Object exposing ``collector_infos``, ``collector_rrefs``
                and ``collector_class`` (typically the distributed collector).
            num_workers: Number of remote collectors to create transports for.

        Raises:
            RuntimeError: If no context is given, or if the context is missing
                any of the required attributes.
        """
        self.model_id = model_id
        if context is None:
            raise RuntimeError(f"Expected a context for {type(self).__name__}.")
        self.context = context

        infos = getattr(self.context, "collector_infos", None)
        rrefs = getattr(self.context, "collector_rrefs", None)
        collector_cls = getattr(self.context, "collector_class", None)
        if infos is None or rrefs is None or collector_cls is None:
            raise RuntimeError(
                "RPCWeightSyncScheme requires a context with the following attributes: "
                "(context.collector_infos, context.collector_rrefs, context.collector_class)"
            )

        # One transport per remote collector. Rank 0 is the main/trainer
        # process, so worker i maps to torch.distributed rank i + 1.
        for worker_idx in range(num_workers):
            transport = self.create_transport(
                collector_info=infos[worker_idx],
                collector_rref=rrefs[worker_idx],
                collector_class=collector_cls,
                worker_rank=worker_idx + 1,
            )
            self._register_worker_sender(worker_idx=worker_idx, transport=transport)

    def _init_on_receiver_impl(
        self, *, model_id: str, context: Any = None, worker_idx: int | None = None
    ) -> None:
        """Initialize scheme on the worker (receiver) side.

        Expected kwargs (as provided by collectors):
        - model_id: str         # e.g. "policy"
        - context: Any          # collector / inner collector
        - worker_idx: int | None  # worker index (optional)
        """
        if context is None:
            raise ValueError(
                "RPCWeightSyncScheme.init_on_receiver requires a 'context' "
                "providing access to the model to be synchronized."
            )

        self.model_id = model_id
        self.worker_idx = worker_idx
        self.context = context
        # Touch the weights property so that any lazily-built state it
        # depends on is materialized up-front.
        self.weights  # noqa

        self._receiver_transport = RPCTransport(worker_rank=worker_idx)

    @property
    def model(self) -> Any | None:
        """Resolve the model to synchronize, caching it behind a weakref.

        Resolution order: live weakref first, then lookup by model id in the
        context. A missing "policy" model is built from the context's policy
        factory; any other missing model raises.
        """
        if self._model_ref is not None:
            return self._model_ref()
        if self._model_id is None:
            return None
        model = _resolve_model(self.context, self._model_id)
        if model is None:
            if self._model_id != "policy":
                raise AttributeError(
                    f"Model {self._model_id} was `None` in context {self.context}"
                )
            torchrl_logger.debug("Creating policy from factory.")
            model = self.context.policy_factory[0]()
            self.context.policy = model
        self._model_ref = weakref.ref(model)
        return model

    @model.setter
    def model(self, value: Any):
        # Ignore None assignments; otherwise hold the model weakly so the
        # scheme does not keep it alive on its own.
        if value is not None:
            self._model_ref = weakref.ref(value)

    def create_transport(
        self,
        *,
        collector_info=None,
        collector_rref=None,
        collector_class=None,
        worker_rank=None,
        **kwargs,
    ) -> TransportBackend:
        """Create RPC-based transport for a specific remote collector.

        Args:
            collector_info: RPC worker info for the remote collector.
            collector_rref: RPC remote reference to the collector.
            collector_class: Class of the remote collector.
            worker_rank: The torch.distributed rank of the remote worker.
            **kwargs: Additional transport configuration (unused).

        Returns:
            RPCTransport configured for this specific remote collector.
        """
        transport = RPCTransport(
            collector_info=collector_info,
            collector_rref=collector_rref,
            collector_class=collector_class,
            worker_rank=worker_rank,
        )
        return transport
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class RPCTransport:
    """RPC transport for communicating with a single RPC remote collector.

    This transport handles weight updates for ONE specific remote collector via
    torch.distributed primitives (send/recv) with RPC used for signaling.
    Multiple transports are created for multiple collectors, following the same
    pattern as the DistributedDataCollector.
    """

    def __init__(
        self,
        collector_info=None,
        collector_rref=None,
        collector_class=None,
        worker_rank=None,
    ):
        """Store the remote collector's identity and rank.

        Args:
            collector_info: RPC worker info for the remote collector.
            collector_rref: RPC remote reference to the collector.
            collector_class: Class of the remote collector.
            worker_rank: The torch.distributed rank of the remote worker.
        """
        self._collector_info = collector_info
        self._collector_rref = collector_rref
        self._collector_class = collector_class
        self._worker_rank = worker_rank  # The torch.distributed rank of this worker
        # Pending async handles set by send_weights_async(), consumed by wait_ack().
        self._pending_future = None
        self._pending_send = None

    def send_weights(self, weights: Any) -> None:
        """Send weights to the remote collector using torch.distributed.

        Uses torch.distributed.send() for the actual weight transfer and RPC
        for signaling the remote collector to receive.

        Order is critical to avoid deadlock:
        1. Signal receiver via RPC to start recv() (non-blocking)
        2. Send weights via torch.distributed (blocking until recv completes)
        """
        if self._collector_info is None or self._collector_rref is None:
            # Transport was never wired to a remote collector: nothing to do.
            return
        if self._worker_rank is None:
            raise RuntimeError("worker_rank must be set for RPC transport")

        # Step 1: Signal the remote collector via RPC to start receiving (async)
        # Use rref.rpc_async() to properly call the instance method on the remote object
        future = self._collector_rref.rpc_async()._receive_weights_scheme()

        # Step 2: Send weights via torch.distributed (blocks until receiver calls recv())
        weights.send(self._worker_rank)

        # Step 3: Wait for RPC to complete (receiver has applied weights)
        future.wait()

    def send_weights_async(self, weights: Any) -> None:
        """Send weights to remote collector asynchronously.

        Uses torch.distributed.isend() for the actual weight transfer and RPC
        for signaling. Use wait_ack() to wait for completion.

        Order is critical to avoid deadlock:
        1. Signal receiver via RPC to start recv() (non-blocking)
        2. Send weights via torch.distributed.isend() (non-blocking)
        3. wait_ack() waits for both to complete
        """
        if self._collector_info is None or self._collector_rref is None:
            return
        if self._worker_rank is None:
            raise RuntimeError("worker_rank must be set for RPC transport")

        # Step 1: Signal the remote collector via RPC to start receiving (async)
        # Use rref.rpc_async() to properly call the instance method on the remote object
        self._pending_future = (
            self._collector_rref.rpc_async()._receive_weights_scheme()
        )

        # Step 2: Send weights asynchronously via torch.distributed.
        # BUGFIX: the isend() handle was previously discarded, so wait_ack()
        # could never wait for the distributed send itself. Keep it so
        # wait_ack() can block on completion.
        self._pending_send = weights.isend(self._worker_rank)

    def wait_ack(self) -> None:
        """Wait for both the RPC call and the distributed send to complete."""
        # Wait for the distributed send first (if isend returned a waitable
        # handle; some backends may return a non-waitable token instead).
        pending_send = getattr(self, "_pending_send", None)
        self._pending_send = None
        if pending_send is not None and hasattr(pending_send, "wait"):
            pending_send.wait()
        # Wait for the RPC call to complete (receiver has applied weights).
        if getattr(self, "_pending_future", None) is not None:
            self._pending_future.wait()
            self._pending_future = None

    def receive_weights(
        self,
        timeout: float | None = None,
        *,
        weights: Any = None,
        model: Any = None,
        strategy: WeightStrategy | None = None,
    ) -> Any | None:
        """Receive weights from sender using torch.distributed.

        Args:
            timeout: Maximum time to wait for weights (seconds). If None,
                blocks until weights are received.
            weights: Pre-allocated weight buffer to receive into.
            model: The model to apply weights to.
            strategy: Strategy for applying weights to the model.

        Returns:
            The received weights, or None if timeout expires.
        """
        if weights is None:
            # No buffer to receive into: nothing to do.
            return None

        if timeout is None:
            # Blocking receive from rank 0 (the trainer/main process).
            weights.recv(0)
        else:
            # Non-blocking receive with timeout support: poll the premature
            # futures until all complete or the deadline passes.
            futures = weights.irecv(src=0, return_premature=True)
            if futures:
                start_time = time.monotonic()
                while True:
                    # Check if all futures are complete
                    if all(f.is_completed() for f in futures):
                        break
                    # Timeout expired before receiving all weights
                    if time.monotonic() - start_time >= timeout:
                        return None
                    # Small sleep to avoid busy-waiting
                    time.sleep(0.001)

        # Apply the received weights to the model
        if model is not None and strategy is not None:
            strategy.apply_weights(model, weights)

        return weights

    def setup_connection_and_weights_on_sender(self) -> None:
        """No-op for RPCTransport - weights are sent via send_weights()."""

    def setup_connection_and_weights_on_receiver(
        self,
        *,
        worker_idx: int,
        weights: Any = None,
        model: Any = None,
        strategy: WeightStrategy | None = None,
    ) -> Any:
        """No-op for RPCTransport - weights are received via receive()."""
        return None
|