torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections import OrderedDict
|
|
4
|
+
from collections.abc import Callable, Sequence
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from tensordict import TensorDictBase
|
|
8
|
+
from tensordict.nn import TensorDictModule
|
|
9
|
+
|
|
10
|
+
from torchrl._utils import accept_remote_rref_udf_invocation
|
|
11
|
+
from torchrl.collectors._base import _make_legacy_metaclass
|
|
12
|
+
from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE, ExplorationType
|
|
13
|
+
from torchrl.collectors._multi_async import MultiAsyncCollector
|
|
14
|
+
from torchrl.collectors._multi_base import _MultiCollectorMeta
|
|
15
|
+
from torchrl.data.utils import DEVICE_TYPING
|
|
16
|
+
from torchrl.envs import EnvBase
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@accept_remote_rref_udf_invocation
|
|
20
|
+
class AsyncCollector(MultiAsyncCollector):
|
|
21
|
+
"""Runs a single DataCollector on a separate process.
|
|
22
|
+
|
|
23
|
+
This is mostly useful for offline RL paradigms where the policy being
|
|
24
|
+
trained can differ from the policy used to collect data. In online
|
|
25
|
+
settings, a regular DataCollector should be preferred. This class is
|
|
26
|
+
merely a wrapper around a MultiAsyncCollector where a single process
|
|
27
|
+
is being created.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
create_env_fn (Callabled): Callable returning an instance of EnvBase
|
|
31
|
+
policy (Callable): Policy to be executed in the environment.
|
|
32
|
+
Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
|
|
33
|
+
If ``None`` is provided, the policy used will be a
|
|
34
|
+
:class:`~torchrl.collectors.RandomPolicy` instance with the environment
|
|
35
|
+
``action_spec``.
|
|
36
|
+
Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
|
|
37
|
+
This is the recommended usage of the collector.
|
|
38
|
+
Other callables are accepted too:
|
|
39
|
+
If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
|
|
40
|
+
instances) it will be wrapped in a `nn.Module` first.
|
|
41
|
+
Then, the collector will try to assess if these
|
|
42
|
+
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
|
|
43
|
+
|
|
44
|
+
- If the policy forward signature matches any of ``forward(self, tensordict)``,
|
|
45
|
+
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
|
|
46
|
+
any typing with a single argument typed as a subclass of ``TensorDictBase``)
|
|
47
|
+
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
|
|
48
|
+
|
|
49
|
+
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
|
|
50
|
+
|
|
51
|
+
.. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
|
|
52
|
+
pickled directly), the ``policy_factory`` should be used instead.
|
|
53
|
+
|
|
54
|
+
Keyword Args:
|
|
55
|
+
policy_factory (Callable[[], Callable], optional): a callable that returns
|
|
56
|
+
a policy instance. This is exclusive with the `policy` argument.
|
|
57
|
+
|
|
58
|
+
.. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
|
|
59
|
+
|
|
60
|
+
frames_per_batch (int): A keyword-only argument representing the
|
|
61
|
+
total number of elements in a batch.
|
|
62
|
+
total_frames (int, optional): A keyword-only argument representing the
|
|
63
|
+
total number of frames returned by the collector
|
|
64
|
+
during its lifespan. If the ``total_frames`` is not divisible by
|
|
65
|
+
``frames_per_batch``, an exception is raised.
|
|
66
|
+
Endless collectors can be created by passing ``total_frames=-1``.
|
|
67
|
+
Defaults to ``-1`` (never ending collector).
|
|
68
|
+
device (int, str or torch.device, optional): The generic device of the
|
|
69
|
+
collector. The ``device`` args fills any non-specified device: if
|
|
70
|
+
``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
|
|
71
|
+
``env_device`` is not specified, its value will be set to ``device``.
|
|
72
|
+
Defaults to ``None`` (No default device).
|
|
73
|
+
Supports a list of devices if one wishes to indicate a different device
|
|
74
|
+
for each worker. The list must be as long as the number of workers.
|
|
75
|
+
storing_device (int, str or torch.device, optional): The device on which
|
|
76
|
+
the output :class:`~tensordict.TensorDict` will be stored.
|
|
77
|
+
If ``device`` is passed and ``storing_device`` is ``None``, it will
|
|
78
|
+
default to the value indicated by ``device``.
|
|
79
|
+
For long trajectories, it may be necessary to store the data on a different
|
|
80
|
+
device than the one where the policy and env are executed.
|
|
81
|
+
Defaults to ``None`` (the output tensordict isn't on a specific device,
|
|
82
|
+
leaf tensors sit on the device where they were created).
|
|
83
|
+
Supports a list of devices if one wishes to indicate a different device
|
|
84
|
+
for each worker. The list must be as long as the number of workers.
|
|
85
|
+
env_device (int, str or torch.device, optional): The device on which
|
|
86
|
+
the environment should be cast (or executed if that functionality is
|
|
87
|
+
supported). If not specified and the env has a non-``None`` device,
|
|
88
|
+
``env_device`` will default to that value. If ``device`` is passed
|
|
89
|
+
and ``env_device=None``, it will default to ``device``. If the value
|
|
90
|
+
as such specified of ``env_device`` differs from ``policy_device``
|
|
91
|
+
and one of them is not ``None``, the data will be cast to ``env_device``
|
|
92
|
+
before being passed to the env (i.e., passing different devices to
|
|
93
|
+
policy and env is supported). Defaults to ``None``.
|
|
94
|
+
Supports a list of devices if one wishes to indicate a different device
|
|
95
|
+
for each worker. The list must be as long as the number of workers.
|
|
96
|
+
policy_device (int, str or torch.device, optional): The device on which
|
|
97
|
+
the policy should be cast.
|
|
98
|
+
If ``device`` is passed and ``policy_device=None``, it will default
|
|
99
|
+
to ``device``. If the value as such specified of ``policy_device``
|
|
100
|
+
differs from ``env_device`` and one of them is not ``None``,
|
|
101
|
+
the data will be cast to ``policy_device`` before being passed to
|
|
102
|
+
the policy (i.e., passing different devices to policy and env is
|
|
103
|
+
supported). Defaults to ``None``.
|
|
104
|
+
Supports a list of devices if one wishes to indicate a different device
|
|
105
|
+
for each worker. The list must be as long as the number of workers.
|
|
106
|
+
create_env_kwargs (dict, optional): A dictionary with the
|
|
107
|
+
keyword arguments used to create an environment. If a list is
|
|
108
|
+
provided, each of its elements will be assigned to a sub-collector.
|
|
109
|
+
max_frames_per_traj (int, optional): Maximum steps per trajectory.
|
|
110
|
+
Note that a trajectory can span across multiple batches (unless
|
|
111
|
+
``reset_at_each_iter`` is set to ``True``, see below).
|
|
112
|
+
Once a trajectory reaches ``n_steps``, the environment is reset.
|
|
113
|
+
If the environment wraps multiple environments together, the number
|
|
114
|
+
of steps is tracked for each environment independently. Negative
|
|
115
|
+
values are allowed, in which case this argument is ignored.
|
|
116
|
+
Defaults to ``None`` (i.e. no maximum number of steps).
|
|
117
|
+
init_random_frames (int, optional): Number of frames for which the
|
|
118
|
+
policy is ignored before it is called. This feature is mainly
|
|
119
|
+
intended to be used in offline/model-based settings, where a
|
|
120
|
+
batch of random trajectories can be used to initialize training.
|
|
121
|
+
If provided, it will be rounded up to the closest multiple of frames_per_batch.
|
|
122
|
+
Defaults to ``None`` (i.e. no random frames).
|
|
123
|
+
reset_at_each_iter (bool, optional): Whether environments should be reset
|
|
124
|
+
at the beginning of a batch collection.
|
|
125
|
+
Defaults to ``False``.
|
|
126
|
+
postproc (Callable, optional): A post-processing transform, such as
|
|
127
|
+
a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
|
|
128
|
+
instance.
|
|
129
|
+
Defaults to ``None``.
|
|
130
|
+
split_trajs (bool, optional): Boolean indicating whether the resulting
|
|
131
|
+
TensorDict should be split according to the trajectories.
|
|
132
|
+
See :func:`~torchrl.collectors.utils.split_trajectories` for more
|
|
133
|
+
information.
|
|
134
|
+
Defaults to ``False``.
|
|
135
|
+
exploration_type (ExplorationType, optional): interaction mode to be used when
|
|
136
|
+
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
|
|
137
|
+
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
|
|
138
|
+
or ``torchrl.envs.utils.ExplorationType.MEAN``.
|
|
139
|
+
reset_when_done (bool, optional): if ``True`` (default), an environment
|
|
140
|
+
that return a ``True`` value in its ``"done"`` or ``"truncated"``
|
|
141
|
+
entry will be reset at the corresponding indices.
|
|
142
|
+
update_at_each_batch (boolm optional): if ``True``, :meth:`update_policy_weights_()`
|
|
143
|
+
will be called before (sync) or after (async) each data collection.
|
|
144
|
+
Defaults to ``False``.
|
|
145
|
+
preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
|
|
146
|
+
that will be allowed to finished collecting their rollout before the rest are forced to end early.
|
|
147
|
+
num_threads (int, optional): number of threads for this process.
|
|
148
|
+
Defaults to the number of workers.
|
|
149
|
+
num_sub_threads (int, optional): number of threads of the subprocesses.
|
|
150
|
+
Should be equal to one plus the number of processes launched within
|
|
151
|
+
each subprocess (or one if a single process is launched).
|
|
152
|
+
Defaults to 1 for safety: if none is indicated, launching multiple
|
|
153
|
+
workers may charge the cpu load too much and harm performance.
|
|
154
|
+
set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
|
|
155
|
+
``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
|
|
156
|
+
a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
|
|
157
|
+
Truncated keys can be set through ``env.add_truncated_keys``.
|
|
158
|
+
Defaults to ``False``.
|
|
159
|
+
track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
|
|
160
|
+
This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
|
|
161
|
+
Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
|
|
162
|
+
the policy version.
|
|
163
|
+
Defaults to `False`.
|
|
164
|
+
|
|
165
|
+
"""
|
|
166
|
+
|
|
167
|
+
def __init__(
|
|
168
|
+
self,
|
|
169
|
+
create_env_fn: Callable[[], EnvBase],
|
|
170
|
+
policy: None
|
|
171
|
+
| (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
|
|
172
|
+
*,
|
|
173
|
+
policy_factory: Callable[[], Callable] | None = None,
|
|
174
|
+
frames_per_batch: int,
|
|
175
|
+
total_frames: int | None = -1,
|
|
176
|
+
device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
|
|
177
|
+
storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
|
|
178
|
+
env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
|
|
179
|
+
policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
|
|
180
|
+
create_env_kwargs: Sequence[dict[str, Any]] | None = None,
|
|
181
|
+
max_frames_per_traj: int | None = None,
|
|
182
|
+
init_random_frames: int | None = None,
|
|
183
|
+
reset_at_each_iter: bool = False,
|
|
184
|
+
postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
|
|
185
|
+
split_trajs: bool | None = None,
|
|
186
|
+
exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
|
|
187
|
+
reset_when_done: bool = True,
|
|
188
|
+
update_at_each_batch: bool = False,
|
|
189
|
+
preemptive_threshold: float | None = None,
|
|
190
|
+
num_threads: int | None = None,
|
|
191
|
+
num_sub_threads: int = 1,
|
|
192
|
+
set_truncated: bool = False,
|
|
193
|
+
track_policy_version: bool = False,
|
|
194
|
+
**kwargs,
|
|
195
|
+
):
|
|
196
|
+
super().__init__(
|
|
197
|
+
create_env_fn=[create_env_fn],
|
|
198
|
+
policy=policy,
|
|
199
|
+
policy_factory=policy_factory,
|
|
200
|
+
total_frames=total_frames,
|
|
201
|
+
create_env_kwargs=[create_env_kwargs]
|
|
202
|
+
if create_env_kwargs
|
|
203
|
+
else create_env_kwargs,
|
|
204
|
+
max_frames_per_traj=max_frames_per_traj,
|
|
205
|
+
frames_per_batch=frames_per_batch,
|
|
206
|
+
reset_at_each_iter=reset_at_each_iter,
|
|
207
|
+
init_random_frames=init_random_frames,
|
|
208
|
+
postproc=postproc,
|
|
209
|
+
split_trajs=split_trajs,
|
|
210
|
+
device=device,
|
|
211
|
+
policy_device=policy_device,
|
|
212
|
+
env_device=env_device,
|
|
213
|
+
storing_device=storing_device,
|
|
214
|
+
exploration_type=exploration_type,
|
|
215
|
+
reset_when_done=reset_when_done,
|
|
216
|
+
update_at_each_batch=update_at_each_batch,
|
|
217
|
+
preemptive_threshold=preemptive_threshold,
|
|
218
|
+
num_threads=num_threads,
|
|
219
|
+
num_sub_threads=num_sub_threads,
|
|
220
|
+
set_truncated=set_truncated,
|
|
221
|
+
track_policy_version=track_policy_version,
|
|
222
|
+
**kwargs,
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
# for RPC
|
|
226
|
+
def next(self):
    """Collect and return the next batch (re-exposed so RPC can target it)."""
    batch = super().next()
    return batch
|
|
228
|
+
|
|
229
|
+
# for RPC
|
|
230
|
+
def shutdown(
    self,
    timeout: float | None = None,
    close_env: bool = True,
    raise_on_error: bool = True,
) -> None:
    """Shut the collector down, forwarding all options to the parent class.

    Re-declared here so the method is directly addressable over RPC.
    """
    result = super().shutdown(
        timeout=timeout,
        close_env=close_env,
        raise_on_error=raise_on_error,
    )
    return result
|
|
239
|
+
|
|
240
|
+
# for RPC
|
|
241
|
+
def set_seed(self, seed: int, static_seed: bool = False) -> int:
    """Seed the collector via the parent implementation (re-exposed for RPC)."""
    return super().set_seed(seed, static_seed=static_seed)
|
|
243
|
+
|
|
244
|
+
# for RPC
|
|
245
|
+
def state_dict(self) -> OrderedDict:
    """Return the parent class's state dict (re-exposed for RPC)."""
    sd = super().state_dict()
    return sd
|
|
247
|
+
|
|
248
|
+
# for RPC
|
|
249
|
+
def load_state_dict(self, state_dict: OrderedDict) -> None:
    """Restore collector state via the parent implementation (re-exposed for RPC)."""
    result = super().load_state_dict(state_dict)
    return result
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
# Metaclass that emits the legacy-name deprecation machinery for this alias.
_LegacyAsyncCollectorMeta = _make_legacy_metaclass(_MultiCollectorMeta)


class aSyncDataCollector(AsyncCollector, metaclass=_LegacyAsyncCollectorMeta):
    """Deprecated version of :class:`~torchrl.collectors.AsyncCollector`."""
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""Re-exports of collector classes for backward compatibility."""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from torchrl.collectors._base import BaseCollector, DataCollectorBase
|
|
9
|
+
|
|
10
|
+
# Re-export constants for backward compatibility
|
|
11
|
+
from torchrl.collectors._constants import (
|
|
12
|
+
_Interruptor,
|
|
13
|
+
_InterruptorManager,
|
|
14
|
+
_is_osx,
|
|
15
|
+
_MAX_IDLE_COUNT,
|
|
16
|
+
_MIN_TIMEOUT,
|
|
17
|
+
_TIMEOUT,
|
|
18
|
+
cudagraph_mark_step_begin,
|
|
19
|
+
DEFAULT_EXPLORATION_TYPE,
|
|
20
|
+
INSTANTIATE_TIMEOUT,
|
|
21
|
+
WEIGHT_SYNC_TIMEOUT,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
from torchrl.collectors._multi_async import MultiAsyncCollector, MultiaSyncDataCollector
|
|
25
|
+
from torchrl.collectors._multi_base import (
|
|
26
|
+
MultiCollector,
|
|
27
|
+
MultiCollector as _MultiDataCollector,
|
|
28
|
+
)
|
|
29
|
+
from torchrl.collectors._multi_sync import MultiSyncCollector, MultiSyncDataCollector
|
|
30
|
+
from torchrl.collectors._runner import _main_async_collector
|
|
31
|
+
from torchrl.collectors._single import Collector, SyncDataCollector
|
|
32
|
+
from torchrl.collectors._single_async import AsyncCollector, aSyncDataCollector
|
|
33
|
+
|
|
34
|
+
# Public API of this re-export module: new canonical collector names plus
# their legacy aliases, internal helpers re-exported for backward
# compatibility, and shared timeout/exploration constants.
__all__ = [
    # New canonical names (preferred)
    "BaseCollector",
    "Collector",
    "AsyncCollector",
    "MultiCollector",
    "MultiSyncCollector",
    "MultiAsyncCollector",
    # Legacy names (backward-compatible aliases)
    "DataCollectorBase",
    "SyncDataCollector",
    "aSyncDataCollector",
    "_MultiDataCollector",
    "MultiSyncDataCollector",
    "MultiaSyncDataCollector",
    # Other exports
    "_main_async_collector",
    # Constants
    "_TIMEOUT",
    "INSTANTIATE_TIMEOUT",
    "WEIGHT_SYNC_TIMEOUT",
    "_MIN_TIMEOUT",
    "_MAX_IDLE_COUNT",
    "DEFAULT_EXPLORATION_TYPE",
    "_is_osx",
    "_Interruptor",
    "_InterruptorManager",
    "cudagraph_mark_step_begin",
]
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from .generic import (
|
|
7
|
+
DEFAULT_SLURM_CONF,
|
|
8
|
+
DistributedCollector,
|
|
9
|
+
DistributedDataCollector,
|
|
10
|
+
DistributedWeightUpdater,
|
|
11
|
+
)
|
|
12
|
+
from .ray import RayCollector
|
|
13
|
+
from .rpc import RPCCollector, RPCDataCollector, RPCWeightUpdater
|
|
14
|
+
from .sync import DistributedSyncCollector, DistributedSyncDataCollector
|
|
15
|
+
from .utils import submitit_delayed_launcher
|
|
16
|
+
|
|
17
|
+
# Public API of the distributed-collectors package: canonical names,
# legacy aliases, weight updaters, and the submitit launcher helper.
__all__ = [
    "DEFAULT_SLURM_CONF",
    # New canonical names (preferred)
    "DistributedCollector",
    "DistributedSyncCollector",
    "RPCCollector",
    # Legacy names (backward-compatible aliases)
    "DistributedDataCollector",
    "DistributedSyncDataCollector",
    "RPCDataCollector",
    # Other exports
    "DistributedWeightUpdater",
    "RPCWeightUpdater",
    "RayCollector",
    "submitit_delayed_launcher",
]
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
import random
|
|
9
|
+
import socket
|
|
10
|
+
from datetime import timedelta
|
|
11
|
+
|
|
12
|
+
import torch.distributed
|
|
13
|
+
|
|
14
|
+
from torchrl._utils import logger as torchrl_logger
|
|
15
|
+
|
|
16
|
+
# TCP port used for the distributed store; kept as a string because callers
# pass it straight into address strings / TCPStore configuration.
TCP_PORT = os.environ.get("TCP_PORT", "10003")
# Idle timeout in seconds. Coerced to int so the value has a consistent type
# whether it comes from the environment (str) or from the default (int).
# NOTE(review): the env var is spelled "RCP_..."; presumably "RPC_..." was
# intended — confirm before renaming, as the string is part of the runtime
# contract.
IDLE_TIMEOUT = int(os.environ.get("RCP_IDLE_TIMEOUT", 10))

# Maximum time (seconds) a node may take to connect before giving up.
MAX_TIME_TO_CONNECT = 1000

# Polling interval (seconds) used in busy-wait loops.
SLEEP_INTERVAL = 1e-6

DEFAULT_SLURM_CONF = {
    "timeout_min": 10,
    "slurm_partition": "train",
    "slurm_cpus_per_task": 32,
    "slurm_gpus_per_node": 0,
}  #: Default value of the SLURM jobs

DEFAULT_SLURM_CONF_MAIN = {
    "timeout_min": 10,
    "slurm_partition": "train",
    "slurm_cpus_per_task": 32,
    "slurm_gpus_per_node": 1,
}  #: Default value of the SLURM main job

# Default options handed to the TensorPipe RPC backend; "uv" restricts
# transports to libuv TCP.
DEFAULT_TENSORPIPE_OPTIONS = {
    "num_worker_threads": 16,
    "rpc_timeout": 10_000,
    "_transports": ["uv"],
}
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _find_free_port() -> int:
|
|
45
|
+
"""Find a free port by binding to port 0 and letting the OS choose."""
|
|
46
|
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
47
|
+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
48
|
+
s.bind(("", 0))
|
|
49
|
+
return s.getsockname()[1]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _create_tcpstore_with_retry(
    host_name: str,
    port: int | None,
    world_size: int,
    is_master: bool,
    timeout: float = 10.0,
    max_retries: int = 10,
    wait_for_workers: bool = True,
) -> tuple[torch.distributed.TCPStore, int]:
    """Create a TCPStore with retry logic for handling port conflicts.

    This function attempts to create a TCPStore, and if the port is already in use,
    it will retry with different random ports up to max_retries times.

    Args:
        host_name: The hostname for the TCPStore.
        port: The initial port to try. If None, a random port will be chosen.
        world_size: The world size for the TCPStore.
        is_master: Whether this is the master (server) process.
        timeout: Timeout in seconds for the TCPStore.
        max_retries: Maximum number of retry attempts.
        wait_for_workers: Whether the master should wait for workers.
            Only used when is_master=True.

    Returns:
        A tuple of (TCPStore, actual_port) where actual_port is the port
        that was successfully bound.

    Raises:
        RuntimeError: If unable to create a TCPStore after max_retries attempts.
    """
    # Hoisted out of the retry loop: previously `time` was re-imported inside
    # the except branch on every port conflict.
    import time

    last_error = None

    for attempt in range(max_retries):
        if port is None or attempt > 0:
            # For the first attempt use provided port, for retries find a new free port.
            # NOTE(review): picking a fresh random port on retry only makes
            # sense on the master side — a client must target the master's
            # port. In practice EADDRINUSE is a bind-time (master) error, so
            # the client path should not reach here; confirm if that changes.
            current_port = _find_free_port()
        else:
            current_port = int(port)

        try:
            if is_master:
                store = torch.distributed.TCPStore(
                    host_name=host_name,
                    port=current_port,
                    world_size=world_size,
                    is_master=True,
                    timeout=timedelta(seconds=timeout),
                    wait_for_workers=wait_for_workers,
                )
            else:
                store = torch.distributed.TCPStore(
                    host_name=host_name,
                    port=current_port,
                    is_master=False,
                    timeout=timedelta(seconds=timeout),
                )
            torchrl_logger.debug(
                f"TCPStore created successfully on {host_name}:{current_port} "
                f"(attempt {attempt + 1}/{max_retries})"
            )
            return store, current_port

        except (RuntimeError, OSError) as e:
            error_msg = str(e).lower()
            if "address already in use" in error_msg or "eaddrinuse" in error_msg:
                torchrl_logger.debug(
                    f"Port {current_port} already in use, "
                    f"retrying ({attempt + 1}/{max_retries})..."
                )
                last_error = e
                # Add small random delay to reduce collision probability
                time.sleep(random.uniform(0.01, 0.1))
                continue
            # For other errors, re-raise immediately
            raise

    raise RuntimeError(
        f"Failed to create TCPStore after {max_retries} attempts. Last error: {last_error}"
    )
|