torchrl-0.11.0-cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/collectors/llm/weight_update/vllm_v2.py

@@ -0,0 +1,311 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import time
+
+from collections.abc import Iterator
+
+import torch
+from tensordict import TensorDictBase
+from torchrl._utils import logger as torchrl_logger
+from torchrl.collectors.weight_update import WeightUpdaterBase
+from torchrl.modules.llm.backends.vllm import RLvLLMEngine
+
+try:
+    pass
+
+    _has_transformers = True
+except ImportError:
+    _has_transformers = False
+
+
+class vLLMUpdaterV2(WeightUpdaterBase):
+    """Simplified vLLM weight updater using the RLvLLMEngine interface.
+
+    This updater works with any vLLM engine that implements the RLvLLMEngine
+    interface, automatically extracting configuration and handling weight updates
+    through the engine's own methods.
+
+    Args:
+        vllm_engine: A vLLM engine implementing the RLvLLMEngine interface.
+
+    .. note:: This class can be created through :class:`torchrl.collectors.llm.vLLMUpdater` with `v2=True`.
+
+    """
+
+    def __init__(self, vllm_engine: RLvLLMEngine):
+        # Check that vllm_engine implements the RLvLLMEngine interface
+        if not isinstance(vllm_engine, RLvLLMEngine):
+            raise TypeError(
+                f"vllm_engine must implement RLvLLMEngine interface, got {type(vllm_engine)}"
+            )
+
+        torchrl_logger.debug(f"=> in {type(self).__name__}.__init__")
+        self.vllm_engine = vllm_engine
+        self.initialized_group = None
+
+        # Extract configuration from engine
+        self.vllm_tp_size = vllm_engine.get_tp_size()
+        self.master_address = vllm_engine.get_master_address()
+        self.master_port = vllm_engine.get_master_port()
+        self.model_metadata = vllm_engine.get_model_metadata()
+
+        torchrl_logger.debug(
+            f"Initialized vLLMUpdaterV2 with tp_size={self.vllm_tp_size}"
+        )
+
+    def get_tp_size(self) -> int:
+        """Get the tensor parallel size."""
+        return self.vllm_tp_size
+
+    def init(
+        self, model_metadata: dict[str, tuple[torch.dtype, torch.Size]] | None = None
+    ) -> None:
+        """Initialize the weight updater.
+
+        Args:
+            model_metadata: Optional model metadata. If not provided, uses engine's metadata.
+        """
+        if model_metadata is not None:
+            self.model_metadata = model_metadata
+
+        # Initialize the engine's weight update group
+        self.vllm_engine.init_weight_update_group()
+        self.initialized_group = True
+        torchrl_logger.debug("Weight update group initialized")
+
+    def push_weights(
+        self, weights: Iterator[tuple[str, torch.Tensor]] | TensorDictBase
+    ):
+        """Push weights to the vLLM engine.
+
+        Args:
+            weights: Either an iterator of (name, tensor) pairs or a TensorDictBase
+        """
+        if isinstance(weights, TensorDictBase):
+            weights = iter(weights.flatten_keys(".").items())
+
+        if self.initialized_group is None:
+            raise RuntimeError("Weight updater not initialized. Call init() first.")
+
+        # Delegate to the engine's update_weights method
+        self.vllm_engine.update_weights(weights)
+        torchrl_logger.debug("Weight update completed")
+
+        # Call post-hooks to increment policy version
+        torchrl_logger.debug("Calling post-hooks...")
+        self._call_post_hooks()
+        torchrl_logger.debug("Post-hooks completed")
+
+    def push_weights_from_transformers(self, transformers_model):
+        """Push weights from a transformers model.
+
+        Args:
+            transformers_model: A transformers PreTrainedModel or TorchRL wrapper
+        """
+        if not _has_transformers:
+            raise ImportError("transformers not available")
+        t0 = time.time()
+        # Extract state dict from model, handling LoRA models properly
+        if hasattr(transformers_model, "model") and hasattr(
+            transformers_model.model, "state_dict"
+        ):
+            # TorchRL wrapper (e.g., TransformersWrapper)
+            model = transformers_model.model
+            # Check if it's a LoRA model
+            if hasattr(model, "merge_and_unload"):
+                state_dict = model.merge_and_unload().state_dict()
+            else:
+                state_dict = model.state_dict()
+        elif hasattr(transformers_model, "state_dict"):
+            # Direct transformers model
+            # Check if it's a LoRA model
+            if hasattr(transformers_model, "merge_and_unload"):
+                state_dict = transformers_model.merge_and_unload().state_dict()
+            else:
+                state_dict = transformers_model.state_dict()
+        else:
+            raise TypeError(
+                f"Cannot extract state_dict from {type(transformers_model)}"
+            )
+
+        t1 = time.time()
+        torchrl_logger.debug(f"Time to extract state_dict: {t1 - t0}")
+        # Convert to iterator for memory efficiency
+        weights_iter = iter(state_dict.items())
+        self.push_weights(weights_iter)
+        torchrl_logger.debug(f"Time to push weights: {time.time() - t1}")
+
+    def push_weights_from_transformers_optimized(
+        self, transformers_model, batch_size=50
+    ):
+        """Optimized version of push_weights_from_transformers with GPU pre-loading.
+
+        This method provides several optimizations:
+        1. Pre-loads all weights to GPU before transfer
+        2. Optionally batches weights for better memory management
+        3. Uses non-blocking transfers when possible
+
+        Args:
+            transformers_model: A transformers PreTrainedModel or TorchRL wrapper
+            batch_size: Number of weights to transfer in each batch (0 = no batching)
+        """
+        if not _has_transformers:
+            raise ImportError("transformers not available")
+
+        t0 = time.time()
+
+        # Extract state dict from model, handling LoRA models properly
+        if hasattr(transformers_model, "model") and hasattr(
+            transformers_model.model, "state_dict"
+        ):
+            # TorchRL wrapper (e.g., TransformersWrapper)
+            model = transformers_model.model
+            if hasattr(model, "merge_and_unload"):
+                state_dict = model.merge_and_unload().state_dict()
+            else:
+                state_dict = model.state_dict()
+        elif hasattr(transformers_model, "state_dict"):
+            # Direct transformers model
+            if hasattr(transformers_model, "merge_and_unload"):
+                state_dict = transformers_model.merge_and_unload().state_dict()
+            else:
+                state_dict = transformers_model.state_dict()
+        else:
+            raise TypeError(
+                f"Cannot extract state_dict from {type(transformers_model)}"
+            )
+
+        t1 = time.time()
+        torchrl_logger.debug(f"Time to extract state_dict: {t1 - t0:.3f}s")
+
+        # Pre-load all weights to GPU for faster transfer
+        gpu_weights = {}
+        with torch.device("cuda:0"):  # Ensure we're using the right GPU
+            for name, weight in state_dict.items():
+                if not weight.is_cuda:
+                    gpu_weights[name] = weight.cuda(non_blocking=True)
+                else:
+                    gpu_weights[name] = weight
+
+        # Synchronize to ensure all transfers are complete
+        torch.cuda.synchronize()
+        t2 = time.time()
+        torchrl_logger.debug(f"Time to move weights to GPU: {t2 - t1:.3f}s")
+
+        # Transfer weights (optionally in batches)
+        if batch_size > 0:
+            weight_items = list(gpu_weights.items())
+            for i in range(0, len(weight_items), batch_size):
+                batch = weight_items[i : i + batch_size]
+                self.push_weights(iter(batch))
+                torchrl_logger.debug(
+                    f"Transferred batch {i // batch_size + 1}/{(len(weight_items) + batch_size - 1) // batch_size}"
+                )
+        else:
+            # Transfer all at once
+            self.push_weights(iter(gpu_weights.items()))
+
+        t3 = time.time()
+        torchrl_logger.debug(
+            f"Time to push weights: {t3 - t2:.3f}s, total time: {t3 - t0:.3f}s"
+        )
+
+    # Required WeightUpdaterBase methods
+    def _sync_weights_with_worker(self, *, worker_id=None, server_weights=None):
+        """Sync weights with worker (delegates to push_weights)."""
+        if server_weights is None:
+            raise ValueError("server_weights cannot be None")
+
+        if hasattr(server_weights, "items"):
+            # Dict-like object
+            self.push_weights(iter(server_weights.items()))
+        else:
+            # Assume it's a model with state_dict
+            self.push_weights_from_transformers(server_weights)
+
+    def _get_server_weights(self):
+        """Not used - weights must be passed directly."""
+        return None
+
+    def _maybe_map_weights(self, server_weights):
+        """Map weights to expected format."""
+        return server_weights  # No mapping needed, handled in push_weights methods
+
+    def all_worker_ids(self):
+        """Return list of worker IDs."""
+        return [0]
+
+    def register_collector(self, collector):  # noqa: F821
+        """Register a collector and set up policy version increment post-hook.
+
+        Args:
+            collector: The collector to register (BaseCollector)
+        """
+        result = super().register_collector(collector)
+
+        # Only register the increment_version post-hook once for the first collector
+        # This avoids N^2 complexity where each weight update calls increment_version
+        # on all collectors N times (once per registered collector)
+        if len(self.post_hooks) == 0:
+            torchrl_logger.debug("Registering policy version increment post-hook")
+            self.register_post_hook(self._increment_all_collector_versions)
+
+        return result
+
+    def _increment_all_collector_versions(self):
+        """Increment version for all registered collectors efficiently."""
+        torchrl_logger.debug(
+            f"Incrementing policy version for {len(self.collectors)} collectors..."
+        )
+        for i, collector in enumerate(self.collectors):
+            try:
+                collector.increment_version()
+                torchrl_logger.debug(
+                    f"Incremented version for collector {i + 1}/{len(self.collectors)}"
+                )
+            except Exception as e:
+                torchrl_logger.warning(
+                    f"Failed to increment version for collector {i + 1}: {e}"
+                )
+        torchrl_logger.debug("All collector versions incremented")
+
+    @classmethod
+    def get_model_metadata(cls, model) -> dict[str, tuple[torch.dtype, torch.Size]]:
+        """Get model metadata from a model.
+
+        Args:
+            model: A model with state_dict() method (e.g., TransformersWrapper)
+
+        Returns:
+            dict: Mapping of parameter names to (dtype, shape) tuples
+        """
+        if hasattr(model, "model") and hasattr(model.model, "state_dict"):
+            # TorchRL wrapper (e.g., TransformersWrapper)
+            model_obj = model.model
+            # Check if it's a LoRA model
+            if hasattr(model_obj, "merge_and_unload"):
+                sd = model_obj.merge_and_unload().state_dict()
+            else:
+                sd = model_obj.state_dict()
+        elif hasattr(model, "state_dict"):
+            # Direct model
+            # Check if it's a LoRA model
+            if hasattr(model, "merge_and_unload"):
+                sd = model.merge_and_unload().state_dict()
+            else:
+                sd = model.state_dict()
+        else:
+            raise TypeError(f"Cannot extract state_dict from {type(model)}")
+
+        return {k: (v.dtype, v.shape) for k, v in sd.items()}
+
+    # Remove the weakrefs from the updater for serialization
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["_collector_wrs"] = None
+        return state