torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from .vllm_double_buffer import (
|
|
7
|
+
VLLMDoubleBufferSyncScheme,
|
|
8
|
+
VLLMDoubleBufferTransport,
|
|
9
|
+
VLLMDoubleBufferWeightReceiver,
|
|
10
|
+
VLLMDoubleBufferWeightSender,
|
|
11
|
+
)
|
|
12
|
+
from .vllm_nccl import (
|
|
13
|
+
get_model_metadata,
|
|
14
|
+
VLLMCollectiveTransport,
|
|
15
|
+
VLLMWeightReceiver,
|
|
16
|
+
VLLMWeightSender,
|
|
17
|
+
VLLMWeightSyncScheme,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
# Public names re-exported by ``torchrl.weight_update.llm``.
__all__ = [
    # Weight synchronization with vLLM over NCCL collectives.
    "VLLMWeightSyncScheme",
    "VLLMWeightSender",
    "VLLMWeightReceiver",
    "VLLMCollectiveTransport",
    "get_model_metadata",
    # Weight synchronization with vLLM through memory-mapped storage.
    "VLLMDoubleBufferSyncScheme",
    "VLLMDoubleBufferWeightSender",
    "VLLMDoubleBufferWeightReceiver",
    "VLLMDoubleBufferTransport",
]
|
|
@@ -0,0 +1,370 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
"""vLLM weight synchronization using double-buffered shared memory.
|
|
7
|
+
|
|
8
|
+
This module provides weight synchronization for vLLM engines using a double-buffer
|
|
9
|
+
approach with memory-mapped TensorDict storage.
|
|
10
|
+
|
|
11
|
+
**Architecture Overview**
|
|
12
|
+
|
|
13
|
+
The double-buffer synchronization uses a simpler architecture compared to NCCL:
|
|
14
|
+
|
|
15
|
+
1. **Sender (Trainer)**
|
|
16
|
+
- Extracts weights from the training model
|
|
17
|
+
- Writes weights to shared directory using TensorDict.memmap
|
|
18
|
+
- No coordination needed - receiver pulls when ready
|
|
19
|
+
|
|
20
|
+
2. **Receiver (vLLM Worker)**
|
|
21
|
+
- Uses RPC to tell all vLLM workers to load from shared directory
|
|
22
|
+
- Each worker reads weights and calls model.load_weights()
|
|
23
|
+
- Can trigger at any time (pull-based)
|
|
24
|
+
|
|
25
|
+
**Key Differences from NCCL**
|
|
26
|
+
|
|
27
|
+
- **Async vs Sync**: Double-buffer is asynchronous (no coordination required)
|
|
28
|
+
- **Push vs Pull**: Sender writes, receiver pulls when ready via RPC
|
|
29
|
+
- **Simplicity**: No NCCL collectives, uses file I/O
|
|
30
|
+
- **Storage**: Uses shared filesystem instead of GPU-GPU transfer
|
|
31
|
+
|
|
32
|
+
**RPC Pattern**
|
|
33
|
+
|
|
34
|
+
Like the NCCL implementation, this uses RPC to coordinate workers:
|
|
35
|
+
- RPC tells workers: "load weights from this directory"
|
|
36
|
+
- Workers read from shared storage independently
|
|
37
|
+
- Each worker calls `model_runner.model.load_weights()`
|
|
38
|
+
|
|
39
|
+
**Usage Example**
|
|
40
|
+
|
|
41
|
+
.. code-block:: python
|
|
42
|
+
|
|
43
|
+
# Create scheme with shared directory
|
|
44
|
+
scheme = VLLMDoubleBufferSyncScheme(
|
|
45
|
+
remote_addr="/shared/weights",
|
|
46
|
+
num_threads=4
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
# Sender side (trainer)
|
|
50
|
+
sender = scheme.create_sender()
|
|
51
|
+
sender.register_model(policy_model)
|
|
52
|
+
sender.update_weights() # Writes to /shared/weights
|
|
53
|
+
|
|
54
|
+
# Receiver side (vLLM worker - AsyncVLLM)
|
|
55
|
+
receiver = scheme.create_receiver(vllm_engine)
|
|
56
|
+
receiver.poll_and_apply() # RPC to workers -> load from /shared/weights
|
|
57
|
+
|
|
58
|
+
**Node-to-Node Transfer**
|
|
59
|
+
|
|
60
|
+
For distributed setups, you can use different addresses:
|
|
61
|
+
- Sender writes to local path
|
|
62
|
+
- Use NFS, rsync, or other file sync mechanisms
|
|
63
|
+
- Receiver reads from its local mount point
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
from __future__ import annotations
|
|
67
|
+
|
|
68
|
+
from typing import Any, Literal
|
|
69
|
+
|
|
70
|
+
from tensordict import TensorDict, TensorDictBase
|
|
71
|
+
from torchrl._utils import logger
|
|
72
|
+
from torchrl.weight_update.weight_sync_schemes import WeightStrategy, WeightSyncScheme
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class VLLMDoubleBufferTransport:
    """Transport for vLLM weights through memory-mapped TensorDict storage.

    The sending side writes weights into the ``remote_addr`` directory with the
    TensorDict ``memmap`` API. The receiving side loads them back from
    ``local_addr`` with ``TensorDict.load_memmap``. No process group or network
    connection is involved.

    .. note::
        Every call to :meth:`send_weights` writes into the same ``remote_addr``
        directory; this class does not alternate between two buffers.
        NOTE(review): a read that overlaps a write may observe a partially
        written checkpoint. Confirm that callers serialize writes and reads.

    Args:
        remote_addr: Directory path where the sender writes weights.
        local_addr: Directory path where the receiver reads weights.
            If None, ``remote_addr`` is used (e.g. for same-machine testing).
        num_threads: Number of threads passed to ``memmap`` when writing.
    """

    def __init__(
        self, remote_addr: str, local_addr: str | None = None, num_threads: int = 1
    ):
        self.remote_addr = remote_addr
        # Without a dedicated read location, read from where the sender writes.
        self.local_addr = remote_addr if local_addr is None else local_addr
        self.num_threads = num_threads

    def send_weights(self, model_id: str, weights: Any) -> None:
        """Write the weights to the shared ``remote_addr`` directory.

        Args:
            model_id: Identifier for the model (only used in log messages).
            weights: A ``dict`` of tensors, a ``TensorDictBase``, or any other
                object exposing a TensorDict-style ``memmap`` method.

        Raises:
            TypeError: If ``weights`` is neither a dict nor an object with a
                ``memmap`` method.
        """
        if isinstance(weights, dict):
            weights = TensorDict(weights, batch_size=[])
        elif not hasattr(weights, "memmap"):
            # Fail with a clear message instead of an AttributeError on .memmap.
            raise TypeError(
                "weights must be a dict or a TensorDict-like object with a "
                f"'memmap' method, got {type(weights).__name__}."
            )
        elif isinstance(weights, TensorDictBase) and weights.batch_size == ():
            # Batch-less TensorDicts are cloned before being written, as in the
            # original implementation. NOTE(review): the reason is not visible
            # here, and the clone presumably copies every weight. Confirm
            # whether it is still needed.
            weights = weights.clone()

        logger.info(
            "Writing weights for model '%s' to %s", model_id, self.remote_addr
        )
        weights.memmap(self.remote_addr, num_threads=self.num_threads)
        logger.info("Weights written successfully to %s", self.remote_addr)

    def receive_weights(
        self,
        timeout: float | None = None,
        *,
        weights: Any = None,
        model: Any = None,
        strategy: Any = None,
    ) -> Any | None:
        """Load the weights stored under ``local_addr``.

        Args:
            timeout: Ignored; reading does not wait for the sender.
            weights: Ignored.
            model: Ignored.
            strategy: Ignored.

        Returns:
            TensorDict whose nested keys are flattened with ``"."`` as the
            separator.
        """
        logger.info("Reading weights from %s", self.local_addr)
        received_weights = TensorDict.load_memmap(self.local_addr).flatten_keys(".")
        logger.info("Weights read successfully from %s", self.local_addr)
        return received_weights

    def check_connection(self) -> bool:
        """Check if the transport is ready.

        There is no connection to establish for file-based storage, so this
        always returns True.
        """
        return True
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class VLLMDoubleBufferSyncScheme(WeightSyncScheme):
    """Weight synchronization scheme for vLLM using memory-mapped storage.

    This scheme moves weights from a trainer to vLLM inference workers through
    a directory of memory-mapped TensorDict files. It is simpler than the
    NCCL-based approach and needs no process-group coordination. The sender
    writes, and the receiver reads whenever it is asked to.

    Args:
        remote_addr: Directory path where the sender writes weights.
        local_addr: Directory path where the receiver reads weights.
            If None, ``remote_addr`` is used (e.g. for same-machine testing).
        num_threads: Number of threads for memmap operations. Defaults to 1.
        strategy: Weight extraction strategy, ``"tensordict"`` (default) or
            ``"state_dict"``.

    Example:
        >>> # Local testing (same machine)
        >>> scheme = VLLMDoubleBufferSyncScheme(
        ...     remote_addr="/tmp/weights",
        ...     strategy="tensordict"
        ... )
        >>>
        >>> # Distributed setup (different machines)
        >>> # On trainer node:
        >>> scheme = VLLMDoubleBufferSyncScheme(
        ...     remote_addr="/mnt/shared/weights",  # NFS mount
        ...     num_threads=4
        ... )
        >>>
        >>> # On vLLM worker node:
        >>> scheme = VLLMDoubleBufferSyncScheme(
        ...     remote_addr="/mnt/shared/weights",  # Same NFS mount
        ...     num_threads=4
        ... )
    """

    def __init__(
        self,
        remote_addr: str,
        local_addr: str | None = None,
        num_threads: int = 1,
        strategy: Literal["tensordict", "state_dict"] = "tensordict",
    ):
        # NOTE(review): WeightSyncScheme.__init__ is not called here. Confirm
        # that no inherited method relies on state the base initializer sets.
        self.remote_addr = remote_addr
        self.local_addr = local_addr if local_addr is not None else remote_addr
        self.num_threads = num_threads
        # Only the strategy name is stored; the sender and receiver each build
        # their own WeightStrategy from it.
        self.strategy_name = strategy

    def create_transport(self, **kwargs) -> VLLMDoubleBufferTransport:
        """Create a transport bound to this scheme's directories.

        Args:
            **kwargs: Not used for file-based transport (kept for API compatibility).

        Returns:
            A VLLMDoubleBufferTransport using this scheme's ``remote_addr``,
            ``local_addr`` and ``num_threads``.
        """
        return VLLMDoubleBufferTransport(
            remote_addr=self.remote_addr,
            local_addr=self.local_addr,
            num_threads=self.num_threads,
        )

    def create_sender(self) -> VLLMDoubleBufferWeightSender:
        """Create a weight sender for the trainer process."""
        return VLLMDoubleBufferWeightSender(self)

    def create_receiver(self, vllm_engine) -> VLLMDoubleBufferWeightReceiver:
        """Create a weight receiver for a vLLM engine.

        Args:
            vllm_engine: The vLLM engine the weights are loaded into. If the
                engine exposes ``collective_rpc`` (e.g. ``AsyncVLLM``), its
                workers reload the weights from storage via RPC. Otherwise the
                engine, or its ``llm_engine`` attribute when present, must expose
                ``model_executor.driver_worker.model_runner.model``.

        Returns:
            A VLLMDoubleBufferWeightReceiver bound to this scheme.
        """
        return VLLMDoubleBufferWeightReceiver(self, vllm_engine)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class VLLMDoubleBufferWeightSender:
    """Trainer-side writer that publishes model weights to shared storage.

    Weights are extracted from a registered model with the scheme's
    ``WeightStrategy``. They are then handed to the scheme's transport, which
    memory-maps them into the shared directory.

    Example:
        >>> sender = scheme.create_sender()
        >>> sender.register_model(policy_model)
        >>>
        >>> # During training loop
        >>> sender.update_weights()  # Writes current weights to shared storage
    """

    def __init__(self, scheme: VLLMDoubleBufferSyncScheme):
        self._scheme = scheme
        self._strategy = WeightStrategy(extract_as=scheme.strategy_name)
        # Both attributes are filled in by register_model().
        self._model_ref = None
        self._transport = None

    def register_model(self, model: Any) -> None:
        """Record the model to publish and create the transport.

        Args:
            model: Object the weights are extracted from (e.g. a
                TransformersWrapper). Only a weak reference is kept, so the
                sender does not keep the model alive.
        """
        import weakref

        self._model_ref = weakref.ref(model)
        self._transport = self._scheme.create_transport()
        logger.info(
            f"Registered model for double-buffer weight sync to {self._scheme.remote_addr}"
        )

    def update_weights(self, weights: Any | None = None) -> None:
        """Publish weights to shared storage.

        Args:
            weights: Weights to publish. If None, they are extracted from the
                model given to :meth:`register_model`. An object that has a
                ``state_dict`` attribute is passed through the weight strategy
                first; anything else is sent as-is.
                NOTE(review): the ``state_dict`` check is duck-typed, so
                non-module objects with that attribute are also routed through
                the strategy. Confirm this is intended.

        Raises:
            RuntimeError: If :meth:`register_model` was never called, or if the
                registered model has been garbage-collected.
        """
        transport = self._transport
        if transport is None:
            raise RuntimeError("Transport not initialized. Call register_model first.")

        if weights is None:
            model = self._model_ref()
            if model is None:
                raise RuntimeError("Model reference is dead")
            weights = self._strategy.extract_weights(model)
        elif hasattr(weights, "state_dict"):
            weights = self._strategy.extract_weights(weights)

        transport.send_weights("vllm_model", weights)
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
class VLLMDoubleBufferWeightReceiver:
    """Loads weights published to shared storage into a vLLM engine.

    Weights are read from the scheme's ``local_addr`` directory. An engine that
    exposes ``collective_rpc`` (e.g. ``AsyncVLLM``) has every worker reload the
    weights from storage. Other engines get them loaded directly into the
    driver worker's model via ``load_weights``.

    Example:
        >>> receiver = scheme.create_receiver(vllm_engine)
        >>>
        >>> # Poll for new weights
        >>> if receiver.poll_and_apply():
        ...     print("Weights updated!")
    """

    def __init__(self, scheme: VLLMDoubleBufferSyncScheme, vllm_engine):
        self._scheme = scheme
        # NOTE(review): not read by any method of this class.
        self._strategy = WeightStrategy(extract_as=scheme.strategy_name)
        self._vllm_engine = vllm_engine
        self._transport = scheme.create_transport()
        logger.info(
            "Initialized double-buffer receiver reading from %s",
            self._scheme.local_addr,
        )

    def apply_weights(self, weights: TensorDict, inplace: bool = True) -> None:
        """Load ``weights`` into the vLLM engine.

        If the engine exposes ``collective_rpc``, every worker is asked to run
        ``load_weights_from_storage(local_addr, num_threads)``. Each worker then
        re-reads the weights from shared storage, and ``weights`` is only used
        to report how many entries are loaded. This call blocks until all
        workers finish.

        Otherwise the (name, tensor) pairs of ``weights`` are passed to
        ``model.load_weights`` on the driver worker. The driver worker is taken
        from the engine, or from its ``llm_engine`` attribute when present.

        Args:
            weights: TensorDict with flattened keys containing weights.
            inplace: Must be True; out-of-place application is not supported.

        Raises:
            ValueError: If ``inplace`` is False.
        """
        if not inplace:
            raise ValueError("Cannot apply weights out of place for vLLM double-buffer")
        logger.info("Applying weights to vLLM engine via RPC")

        # (name, tensor) pairs, the format handed to model.load_weights below.
        weights_list = list(weights.items())

        if hasattr(self._vllm_engine, "collective_rpc"):
            logger.info(
                "Using RPC to load %d weights across all replicas", len(weights_list)
            )
            futures = self._vllm_engine.collective_rpc(
                method="load_weights_from_storage",
                args=(str(self._scheme.local_addr), self._transport.num_threads),
            )
            # Block until every worker has finished loading.
            import ray

            ray.get(futures)
            logger.info("Weights loaded successfully via RPC")
            return

        # Local (non-RPC) engine: load straight into the driver worker's model.
        logger.info("Using direct load for local LLM")
        engine = getattr(self._vllm_engine, "llm_engine", self._vllm_engine)
        worker = engine.model_executor.driver_worker
        worker.model_runner.model.load_weights(weights_list)
        logger.info("Weights loaded successfully")

    def poll_and_apply(self, timeout: float = 180.0) -> bool:
        """Read weights from shared storage and apply them to the engine.

        Args:
            timeout: Not used for file-based transport (kept for API compatibility).

        Returns:
            False if the storage directory does not exist yet, meaning nothing
            has been published. True once the weights have been read and
            applied. Errors raised while reading or applying weights propagate.
        """
        import os

        # Nothing published yet: report "no update" instead of failing in load_memmap.
        if not os.path.isdir(self._scheme.local_addr):
            return False
        weights = self._transport.receive_weights()
        self.apply_weights(weights)
        return True
|