torchrl-0.11.0-cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
```diff
--- /dev/null
+++ b/torchrl/collectors/llm/ray_collector.py
@@ -0,0 +1,275 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import copy
+
+import warnings
+from collections.abc import Callable, Iterator
+from typing import Any
+
+import torch
+from tensordict import TensorDictBase
+from tensordict.nn import TensorDictModuleBase
+from torchrl.collectors.llm import LLMCollector
+from torchrl.collectors.weight_update import WeightUpdaterBase
+from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
+from torchrl.envs import EnvBase
+from torchrl.envs.llm.transforms.policy_version import PolicyVersion
+
+RAY_ERR = None
+try:
+    import ray
+
+    _has_ray = True
+except ImportError as err:
+    _has_ray = False
+    RAY_ERR = err
+
+
+class RayLLMCollector(LLMCollector):
+    """A lightweight Ray implementation of the LLM Collector that can be extended and sampled remotely.
+
+    Args:
+        env (EnvBase or EnvBase constructor): the environment to be used for data collection.
+
+    Keyword Args:
+        policy (Callable[[TensorDictBase], TensorDictBase]): the policy to be used for data collection.
+        policy_factory (Callable[[], Callable], optional): a callable that returns
+            a policy instance. This is exclusive with the `policy` argument.
+        dialog_turns_per_batch (int): A keyword-only argument representing the total
+            number of elements in a batch.
+        total_dialog_turns (int): A keyword-only argument representing the total
+            number of dialog turns returned by the collector during its lifespan.
+        yield_only_last_steps (bool, optional): whether to yield every step of a trajectory, or only the
+            last (done) steps.
+        yield_completed_trajectories (bool, optional): whether to yield batches of rollouts with a given number of steps
+            or single, completed trajectories.
+        postproc (Callable, optional): A post-processing transform.
+        async_envs (bool, optional): if True, the environment will be run asynchronously.
+        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
+            but populate the buffer instead.
+        reset_at_each_iter (bool, optional): if True, the environment will be reset at each iteration.
+        flatten_data (bool, optional): if True, the collector will flatten the collected data
+            before returning it.
+        weight_updater (WeightUpdaterBase or constructor, optional): An instance of WeightUpdaterBase
+            or its subclass, responsible for updating the policy weights on remote inference workers.
+        ray_init_config (dict[str, Any], optional): keyword arguments to pass to ray.init().
+        remote_config (dict[str, Any], optional): keyword arguments to pass to cls.as_remote().
+        num_cpus (int, optional): Number of CPUs of the actor. Defaults to `None` (taken from remote_config).
+        num_gpus (int, optional): Number of GPUs of the actor. Defaults to `None` (taken from remote_config).
+        sync_iter (bool, optional): if `True`, items yielded by the collector will be synced to the local process.
+            If `False`, the collector will collect the next batch of data in between yielding.
+            This has no effect when data is collected through the :meth:`start` method.
+            For example:
+
+            >>> collector = RayLLMCollector(..., sync_iter=True)
+            >>> for data in collector:  # blocking
+            ...     # expensive operation - collector is idle
+            >>> collector = RayLLMCollector(..., sync_iter=False)
+            >>> for data in collector:  # non-blocking
+            ...     # expensive operation - collector is collecting data
+
+            This is somewhat equivalent to using :class:`~torchrl.collectors.MultiSyncCollector` (`sync_iter=True`) or
+            :class:`~torchrl.collectors.MultiAsyncDataCollector` (`sync_iter=False`).
+            Defaults to `True`.
+        verbose (bool, optional): if ``True``, the collector will print progress information.
+            Defaults to `False`.
+    """
+
+    def __init__(
+        self,
+        env: EnvBase | Callable[[], EnvBase],
+        *,
+        policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        policy_factory: Callable[[], Callable[[TensorDictBase], TensorDictBase]]
+        | None = None,
+        dialog_turns_per_batch: int,
+        total_dialog_turns: int = -1,
+        yield_only_last_steps: bool | None = None,
+        yield_completed_trajectories: bool | None = None,
+        postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        async_envs: bool | None = None,
+        replay_buffer: ReplayBuffer | None = None,
+        reset_at_each_iter: bool = False,
+        flatten_data: bool | None = None,
+        weight_updater: WeightUpdaterBase
+        | Callable[[], WeightUpdaterBase]
+        | None = None,
+        ray_init_config: dict[str, Any] | None = None,
+        remote_config: dict[str, Any] | None = None,
+        track_policy_version: bool | PolicyVersion = False,
+        sync_iter: bool = True,
+        verbose: bool = False,
+        num_cpus: int | None = None,
+        num_gpus: int | None = None,
+    ) -> None:
+        if not _has_ray:
+            raise RuntimeError(
+                "ray library not found, unable to create a RayLLMCollector. "
+            ) from RAY_ERR
+        if not ray.is_initialized():
+            if ray_init_config is None:
+                from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG
+
+                ray_init_config = DEFAULT_RAY_INIT_CONFIG
+            ray.init(**ray_init_config)
+        if not sync_iter:
+            remote_config = copy.copy(remote_config)
+            remote_config.setdefault("max_concurrency", 2)
+        if num_cpus is not None:
+            remote_config.setdefault("num_cpus", num_cpus)
+        if num_gpus is not None:
+            remote_config.setdefault("num_gpus", num_gpus)
+        remote_cls = LLMCollector.as_remote(remote_config).remote
+        self.sync_iter = sync_iter
+        self._collector = remote_cls(
+            env=env,
+            policy=policy,
+            policy_factory=policy_factory,
+            dialog_turns_per_batch=dialog_turns_per_batch,
+            total_dialog_turns=total_dialog_turns,
+            yield_only_last_steps=yield_only_last_steps,
+            yield_completed_trajectories=yield_completed_trajectories,
+            postproc=postproc,
+            async_envs=async_envs,
+            replay_buffer=replay_buffer,
+            reset_at_each_iter=reset_at_each_iter,
+            flatten_data=flatten_data,
+            weight_updater=weight_updater,
+            track_policy_version=track_policy_version,
+            verbose=verbose,
+        )
+
+    def set_postproc(self, postproc: Callable[[TensorDictBase], TensorDictBase]):
+        return ray.get(self._collector.set_postproc.remote(postproc))
+
+    def _next_remote(self) -> None:
+        return self._collector.next.remote()
+
+    def next(self) -> None:
+        """Get the next batch of data from the collector.
+
+        Returns:
+            None as the data is written directly to the replay buffer.
+        """
+        return ray.get(self._next_remote())
+
+    def __iter__(self) -> Iterator[None]:
+        """Returns an iterator that yields None as the collector writes directly to the replay buffer."""
+        if not self.sync_iter:
+            future = self._next_remote()
+        else:
+            future = None
+        while True:
+            try:
+                if self.sync_iter:
+                    yield self.next()
+                else:
+                    result = ray.get(future)
+                    future = self._next_remote()
+                    yield result
+            except StopIteration:
+                break
+
+    def start(self):
+        """Starts the collector in a background thread."""
+        pending_task = self._collector.start.remote()
+        return ray.get(pending_task)
+
+    def is_running(self):
+        return ray.get(self._collector.is_running.remote())
+
+    def shutdown(self):
+        """Shuts down the collector."""
+        pending_task = self._collector.shutdown.remote()
+        return ray.get(pending_task)
+
+    def async_shutdown(self, timeout=None):
+        """Shuts down the collector asynchronously."""
+        pending_task = self._collector.async_shutdown.remote(timeout=timeout)
+        return ray.get(pending_task)
+
+    def update_policy_weights_(
+        self,
+        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
+        *,
+        worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
+        **kwargs,
+    ):
+        """Updates the policy weights on remote workers.
+
+        Args:
+            policy_or_weights: The weights to update with. Can be:
+                - TensorDictModuleBase: A policy module whose weights will be extracted
+                - TensorDictBase: A TensorDict containing weights
+                - dict: A regular dict containing weights
+                - None: Will try to get weights from server using _get_server_weights()
+            worker_ids: The workers to update. If None, updates all workers.
+        """
+        if "policy_weights" in kwargs:
+            warnings.warn(
+                "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
+                DeprecationWarning,
+            )
+            policy_or_weights = kwargs.pop("policy_weights")
+
+        pending_task = self._collector.update_policy_weights_.remote(
+            policy_or_weights=policy_or_weights, worker_ids=worker_ids
+        )
+        return ray.get(pending_task)
+
+    @property
+    def total_dialog_turns(self):
+        """Total number of dialog turns to collect."""
+        return ray.get(self._collector.total_dialog_turns.remote)
+
+    @property
+    def dialog_turns_per_batch(self) -> int:
+        """Number of dialog turns per batch."""
+        return ray.get(self._collector.dialog_turns_per_batch.remote)
+
+    @property
+    def rollout(self) -> Callable[[], TensorDictBase]:
+        """Returns the rollout function."""
+        return ray.get(self._collector.rollout.remote())
+
+    def init_updater(self, *args, **kwargs):
+        """Initialize the weight updater with custom arguments.
+
+        This method calls init_updater on the remote collector.
+
+        Args:
+            *args: Positional arguments for weight updater initialization
+            **kwargs: Keyword arguments for weight updater initialization
+        """
+        ray.get(self._collector.init_updater.remote(*args, **kwargs))
+
+    @property
+    def policy_version(self) -> str | int | None:
+        """The current version of the policy.
+
+        Returns:
+            The current version number (int) or UUID (str), or None if version tracking is disabled.
+        """
+        return ray.get(self._collector.get_policy_version.remote())
+
+    @property
+    def weight_updater(self) -> WeightUpdaterBase:
+        """The weight updater instance.
+
+        We can pass the weight updater because it's stateless, hence serializable.
+        """
+        return ray.get(self._collector.weight_updater.remote)
+
+    @weight_updater.setter
+    def weight_updater(self, weight_updater: WeightUpdaterBase):
+        """Set the weight updater instance."""
+        ray.get(self._collector.set_weight_updater.remote(weight_updater))
+        weight_updater.register_collector(self)
+
+    def increment_version(self):
+        """Increment the policy version."""
+        return ray.get(self._collector.increment_version.remote())
```
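The class above is essentially a thin Ray proxy around `LLMCollector`. As a rough illustration (not part of the package diff), here is a minimal sketch of how it might be driven together with a shared replay buffer; `make_env` and `policy` are hypothetical placeholders for an LLM environment constructor and policy wrapper, and the import path for `RayLLMCollector` is assumed from the file layout listed above:

```python
# Minimal sketch (not from the diff): driving RayLLMCollector with a replay buffer.
import ray
from torchrl.collectors.llm import RayLLMCollector  # path assumed from the file list
from torchrl.data import ReplayBuffer

ray.init()
rb = ReplayBuffer()

collector = RayLLMCollector(
    make_env,                      # hypothetical EnvBase constructor
    policy=policy,                 # hypothetical LLM policy wrapper
    dialog_turns_per_batch=64,
    total_dialog_turns=10_000,
    replay_buffer=rb,              # data goes to the buffer, the iterator yields None
    sync_iter=False,               # collect the next batch while the trainer consumes this one
)

for _ in collector:                # yields None; sample from the buffer instead
    batch = rb.sample(32)
    ...                            # training step goes here
collector.shutdown()
```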
```diff
--- /dev/null
+++ b/torchrl/collectors/llm/utils.py
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import importlib.util
+
+from queue import Full as QueueFull, Queue
+
+from tensordict import TensorDictBase
+
+from torchrl._utils import logger as torchrl_logger
+
+_has_ray = importlib.util.find_spec("ray") is not None
+
+
+class _QueueAsRB:
+    def __init__(self, queue: Queue | ray.util.queue.Queue):  # noqa
+        if not _has_ray:
+            raise ImportError("Ray not installed.")
+        self.queue = queue
+
+    def extend(self, data: TensorDictBase):
+        from ray.util.queue import Full as RayQueueFull
+
+        # unbind the data and put in the queue
+        for item in data.unbind(0):
+            while True:
+                try:
+                    self.queue.put_nowait(item)
+                    break
+                except (QueueFull, RayQueueFull):
+                    self.queue.get()  # Remove the oldest item to make space
+                    torchrl_logger.warn("rollout queue full. Discarding data.")
+        return
```
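`_QueueAsRB` lets a bounded (Ray) queue stand in for a replay buffer's `extend`: the batch is unbound along dim 0 and pushed item by item, and when the queue is full the oldest entry is popped so the freshest rollout data always fits. The behaviour can be sketched with a plain `queue.Queue` (the class itself additionally requires Ray to be installed):

```python
# Minimal sketch (not from the diff): the drop-oldest behaviour _QueueAsRB.extend relies on.
from queue import Queue

import torch
from tensordict import TensorDict

q = Queue(maxsize=4)
data = TensorDict({"step": torch.arange(6)}, batch_size=[6])

for item in data.unbind(0):
    if q.full():
        q.get()          # drop the oldest item, as _QueueAsRB does on QueueFull
    q.put_nowait(item)

print(q.qsize())         # 4 -> only the last four steps remain
```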
```diff
--- /dev/null
+++ b/torchrl/collectors/llm/weight_update/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from .vllm import vLLMUpdater
+from .vllm_v2 import vLLMUpdaterV2
+
+__all__ = ["vLLMUpdater", "vLLMUpdaterV2"]
```
```diff
--- /dev/null
+++ b/torchrl/collectors/llm/weight_update/vllm.py
@@ -0,0 +1,348 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import abc
+
+import importlib.util
+
+import torch
+import torch.cuda
+import torch.distributed
+
+from tensordict import TensorDict, TensorDictBase
+from tensordict.nn import TensorDictModuleBase
+
+from torchrl._utils import logger as torchrl_logger
+
+from torchrl.collectors.weight_update import WeightUpdaterBase
+from torchrl.modules.llm.backends import stateless_init_process_group
+
+_has_vllm = importlib.util.find_spec("vllm") is not None
+
+_has_ray = importlib.util.find_spec("ray") is not None
+
+
+class vLLMUpdaterMeta(abc.ABCMeta):
+    """Metaclass for vLLMUpdater that allows switching between V1 and V2 implementations.
+
+    When instantiating vLLMUpdater with v2=True, returns a vLLMUpdaterV2 instance instead.
+    This provides a unified entry point for both updater versions while maintaining
+    backward compatibility.
+    """
+
+    def __call__(cls, *args, v2=False, **kwargs):
+        if v2:
+            # Import V2 here to avoid circular imports
+            from .vllm_v2 import vLLMUpdaterV2
+
+            # V2 has a different signature - it expects a vllm_engine parameter
+            # If the user is providing the old signature, we need to handle this gracefully
+            if args or any(
+                k in kwargs
+                for k in [
+                    "master_address",
+                    "master_port",
+                    "model_metadata",
+                    "vllm_tp_size",
+                ]
+            ):
+                # Old signature detected - we can't auto-convert, user needs to update their code
+                raise TypeError(
+                    "When using v2=True, you must provide a vllm_engine parameter instead of "
+                    "the v1 parameters (master_address, master_port, model_metadata, vllm_tp_size). "
+                    "See vLLMUpdaterV2 documentation for details."
+                )
+
+            # Forward to V2 constructor
+            return vLLMUpdaterV2(*args, **kwargs)
+        else:
+            # Use original V1 constructor
+            return super().__call__(*args, **kwargs)
+
+
+class vLLMUpdater(WeightUpdaterBase, metaclass=vLLMUpdaterMeta):
+    """A class that sends weights to vLLM workers.
+
+    This class handles synchronizing weights between a training policy and vLLM inference workers.
+    It supports both local vLLM instances and remote Ray actors.
+
+    Args:
+        master_address (str, optional): The master address for distributed training. Defaults to localhost.
+        master_port (int, optional): The master port for distributed training. If None, will auto-assign.
+        model_metadata (dict[str, tuple[torch.dtype, torch.Size]], optional): Model metadata mapping
+            parameter names to their dtype and shape. If not provided, will be extracted from policy.
+        vllm_tp_size (int, optional): vLLM tensor parallel size. Defaults to 1.
+        v2 (bool, optional): If True, returns a vLLMUpdaterV2 instance instead. This is an experimental
+            feature that provides better integration with AsyncVLLM engines. When using v2=True, you must
+            provide a vllm_engine parameter instead of the above parameters. Defaults to False.
+
+    Methods:
+        init: Initialize the updater with model metadata and initialize the group.
+        _sync_weights_with_worker: Synchronize weights with a vLLM worker.
+        _get_server_weights: Not used - weights must be passed directly.
+        _maybe_map_weights: No mapping needed.
+        all_worker_ids: Returns [0] since we only have one worker.
+
+    .. note::
+        This class assumes the policy is a transformers model that can be loaded by vLLM.
+        The policy must have a state_dict() method that returns the model weights.
+
+    .. warning::
+        The v2=True option is experimental and may have backward-compatibility breaking changes
+        in future releases. However, it is generally considered a better option for working with
+        AsyncVLLM engines and provides improved performance and reliability.
+    """
+
+    def __init__(
+        self,
+        master_address: str | None = None,
+        master_port: int | None = None,
+        model_metadata: dict[str, tuple[torch.dtype, torch.Size]] | None = None,
+        vllm_tp_size: int | None = None,
+    ):
+        torchrl_logger.debug(f"=> in {type(self).__name__}.__init__")
+        self.master_address = master_address
+        self.master_port = master_port
+        self.model_metadata = model_metadata
+        self.initialized_group = None
+        if vllm_tp_size is None:
+            vllm_tp_size = 1
+        self.vllm_tp_size = vllm_tp_size
+        self._model_ref = None
+
+    def init(self, model_metadata: dict[str, tuple[torch.dtype, torch.Size]]) -> None:
+        """Initialize the updater with model metadata and initialize the group.
+
+        Args:
+            model_metadata (dict[str, tuple[torch.dtype, torch.Size]]): The model metadata mapping
+                parameter names to their dtype and shape.
+        """
+        self.model_metadata = model_metadata
+        self.maybe_init_group()
+
+    @property
+    def master_address(self):
+        if self._master_address is None:
+            self._master_address = "localhost"  # get_ip()
+        return self._master_address
+
+    @master_address.setter
+    def master_address(self, value):
+        self._master_address = value
+
+    @property
+    def master_port(self):
+        if self._master_port is None:
+            try:
+                from vllm.utils import get_open_port
+
+                self._master_port = get_open_port()
+            except ImportError:
+                # Fallback if vLLM get_open_port not available
+                import socket
+
+                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                    s.bind(("", 0))
+                    self._master_port = s.getsockname()[1]
+        return self._master_port
+
+    @master_port.setter
+    def master_port(self, value):
+        self._master_port = value
+
+    def _get_model_ref(self):
+        """Get a reference to the model actor, either directly or through Ray.
+
+        Returns:
+            The model reference that can be used for weight updates
+        """
+        if self._model_ref is not None:
+            return self._model_ref
+
+        if hasattr(self.collector, "_collector"):
+            # We're dealing with a RayLLMCollector
+            import ray
+
+            # Get direct reference to the model actor
+            self._model_ref = ray.get(
+                self.collector._collector.get_policy_model.remote()
+            )
+        else:
+            # We're dealing with a local collector
+            self._model_ref = self.collector.policy.model
+
+        return self._model_ref
+
+    def _init_group(self):
+        import ray
+
+        torchrl_logger.debug(f"=> in {type(self).__name__}._init_group")
+        weight_sync_world_size = self.vllm_tp_size + 1
+        torchrl_logger.debug(f"initializing group with {weight_sync_world_size=}...")
+        torchrl_logger.debug(f"vllm_tp_size={self.vllm_tp_size}")
+
+        model_ref = self._get_model_ref()
+
+        torchrl_logger.debug(f"model_ref: {model_ref}")
+        # Initialize the weight update group
+        torchrl_logger.debug("Calling init_weight_update_group...")
+        init_weight_update_group_getter = model_ref.collective_rpc.remote(
+            "init_weight_update_group",
+            args=(self.master_address, self.master_port, 1, weight_sync_world_size),
+        )
+        torchrl_logger.debug("init_weight_update_group remote call succeeded")
+
+        torchrl_logger.debug("Calling stateless_init_process_group within updater...")
+        self.vllm_comm_group = stateless_init_process_group(
+            self.master_address,
+            self.master_port,
+            0,
+            weight_sync_world_size,
+            torch.device("cuda:0"),
+        )
+
+        ray.get(init_weight_update_group_getter)
+        torchrl_logger.debug("init_weight_update_group getter succeeded")
+
+        torchrl_logger.debug("group initialized")
+        self.initialized_group = True
+
+    def maybe_init_group(self):
+        if self.initialized_group is None and self.collector is not None:
+            self._init_group()
+
+    def _sync_weights_with_worker(
+        self,
+        *,
+        worker_id: int | torch.device | None = None,
+        server_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
+    ) -> None:
+        """Synchronize weights with a vLLM worker.
+
+        Args:
+            worker_id: Not used - we only have one worker.
+            server_weights: The weights to sync. Can be:
+                - TensorDictModuleBase: A policy module whose weights will be extracted
+                - TensorDictBase: A TensorDict containing weights
+                - dict: A regular dict containing weights
+                - None: Will try to get weights from server using _get_server_weights()
+        """
+        if server_weights is None:
+            raise ValueError("server_weights cannot be None for vLLM updater")
+
+        if self.initialized_group is None:
+            raise RuntimeError(
+                "Failed to update weights because sender is not initialized."
+            )
+        if self.model_metadata is None:
+            raise RuntimeError(
+                "Failed to update weights because model metadata is not set. "
+                "In async mode, you must set the model metadata in the training actor "
+                "before any weight updates."
+            )
+
+        model_ref = self._get_model_ref()
+
+        # First broadcast metadata
+        torchrl_logger.debug("broadcasting with update_weight_broadcast")
+        remotes = []
+        for k, (dtype, shape) in self.model_metadata.items():
+            remotes.append(
+                model_ref.collective_rpc.remote(
+                    "update_weight_broadcast", args=(k, dtype, shape)
+                )
+            )
+
+        # # Then update weights
+        # remotes = []
+        # pbar = tqdm.tqdm(server_weights.items(), desc="Updating weights", total=len(server_weights))
+        # for k, val in pbar:
+        #     pbar.set_description(f"Updating {k}")
+        #     remotes.append(model_ref.collective_rpc.remote("update_weight", args=(k, val)))
+        # # ray.get(remotes)
+
+        # if self.vllm_comm_group is not True:
+        torchrl_logger.debug("broadcasting...")
+        for k in self.model_metadata:
+            val = server_weights[k].to(torch.device("cuda:0"))
+            self.vllm_comm_group.broadcast(
+                val,
+                src=0,
+                stream=torch.cuda.current_stream(),
+            )
+            del val
+        import ray
+
+        ray.get(remotes)
+        torchrl_logger.debug("done broadcasting")
+        torch.cuda.synchronize()
+
+    def _get_server_weights(self) -> TensorDictBase | None:
+        """Not used - weights must be passed directly via policy."""
+        return None
+
+    def _maybe_map_weights(
+        self, server_weights: TensorDictBase | TensorDictModuleBase | dict
+    ) -> TensorDictBase:
+        """Map weights from any format to the format expected by vLLM.
+
+        Args:
+            server_weights: The weights to map. Can be:
+                - TensorDictModuleBase: A policy module whose weights will be extracted
+                - TensorDictBase: A TensorDict containing weights
+                - dict: A regular dict containing weights
+
+        Returns:
+            TensorDictBase: The mapped weights in TensorDict format
+        """
+        if isinstance(server_weights, TensorDictModuleBase):
+            # Extract weights from policy module using merge_and_unload for LLMs
+            if not hasattr(server_weights, "model"):
+                raise ValueError("TensorDictModuleBase must have a 'model' attribute")
+            # Check if it's a LoRA model
+            if hasattr(server_weights.model, "merge_and_unload"):
+                state_dict = server_weights.model.merge_and_unload().state_dict()
+            else:
+                state_dict = server_weights.model.state_dict()
+            return TensorDict(state_dict, [])
+        elif isinstance(server_weights, TensorDictBase):
+            return server_weights
+        elif isinstance(server_weights, dict):
+            return TensorDict(server_weights, [])
+        else:
+            raise TypeError(
+                f"server_weights must be TensorDictModuleBase, TensorDictBase or dict, got {type(server_weights)}"
+            )
+
+    @classmethod
+    def get_model_metadata(
+        cls, model: TensorDictModuleBase
+    ) -> dict[str, tuple[torch.dtype, torch.Size]]:
+        """Get the model metadata from a model.
+
+        Args:
+            model (TensorDictModuleBase): The model to get the metadata from.
+                Must be a TransformersWrapper or equivalent.
+
+        Returns:
+            dict[str, tuple[torch.dtype, torch.Size]]: The model metadata.
+        """
+        # Check if the model has a LoRA adapter
+        if hasattr(model.model, "merge_and_unload"):
+            sd = model.model.merge_and_unload().state_dict()
+        else:
+            sd = model.model.state_dict()
+        model_metadata = {k: (v.dtype, v.shape) for k, v in sd.items()}
+        return model_metadata
+
+    def all_worker_ids(self) -> list[int]:
+        """Returns [0] since we only have one worker."""
+        return [0]
+
+    def register_collector(self, collector: BaseCollector):  # noqa: F821
+        result = super().register_collector(collector)
+        self.register_post_hook(collector.increment_version)
+        return result
```