torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from tensordict import NestedKey, TensorDictBase
|
|
5
|
+
from tensordict.nn import TensorDictModuleBase
|
|
6
|
+
from torch import nn, vmap
|
|
7
|
+
|
|
8
|
+
from torchrl._utils import logger, RL_WARNINGS
|
|
9
|
+
from torchrl.modules import MLP
|
|
10
|
+
from torchrl.objectives.value.advantages import _vmap_func
|
|
11
|
+
|
|
12
|
+
# Public names exported by this testing-helpers module (star-import surface).
__all__ = [
    "BiasModule",
    "LSTMNet",
    "NonSerializableBiasModule",
    "call_value_nets",
]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BiasModule(nn.Module):
    """Module holding one learnable scalar bias, used to verify weight sync.

    The forward pass shifts its input by the ``bias`` parameter, so a weight
    update is trivially observable in the output.
    """

    def __init__(self, value: float = 0.0):
        super().__init__()
        # Scalar float32 parameter initialised from `value`.
        initial = torch.tensor(value, dtype=torch.float)
        self.bias = nn.Parameter(initial)

    def forward(self, x):
        # Broadcasting adds the scalar bias to every element of `x`.
        shifted = x + self.bias
        return shifted
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class NonSerializableBiasModule(BiasModule):
    """A :class:`BiasModule` variant that refuses to be pickled.

    Tests use it to emulate a policy that cannot be shipped across process
    boundaries via serialization.
    """

    def __getstate__(self):
        # Pickling must always fail loudly so callers can detect the
        # non-serializable policy.
        message = "NonSerializableBiasModule cannot be pickled"
        raise RuntimeError(message)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class LSTMNet(nn.Module):
    """An embedder for an LSTM preceded by an MLP.

    The forward method returns the hidden states of the current state
    (input hidden states) and the output, as
    the environment returns the 'observation' and 'next_observation'.

    Because the LSTM kernel only returns the last hidden state, hidden states
    are padded with zeros such that they have the right size to be stored in a
    TensorDict of size [batch x time_steps].

    If a 2D tensor is provided as input, it is assumed that it is a batch of data
    with only one time step. This means that we explicitly assume that users will
    unsqueeze inputs of a single batch with multiple time steps.

    Args:
        out_features (int): number of output features.
        lstm_kwargs (dict): the keyword arguments for the
            :class:`~torch.nn.LSTM` layer. The dict is copied before use, so
            the caller's mapping is never mutated.
        mlp_kwargs (dict): the keyword arguments for the
            :class:`~torchrl.modules.MLP` layer.
        device (torch.device, optional): the device where the module should
            be instantiated.

    Keyword Args:
        lstm_backend (str, optional): one of ``"torchrl"`` or ``"torch"`` that
            indicates where the LSTM class is to be retrieved. The ``"torchrl"``
            backend (:class:`~torchrl.modules.LSTM`) is slower but works with
            :func:`~torch.vmap` and should work with :func:`~torch.compile`.
            Defaults to ``"torch"``.

    Examples:
        >>> batch = 7
        >>> time_steps = 6
        >>> in_features = 4
        >>> out_features = 10
        >>> hidden_size = 5
        >>> net = LSTMNet(
        ...     out_features,
        ...     {"input_size": hidden_size, "hidden_size": hidden_size},
        ...     {"out_features": hidden_size},
        ... )
        >>> # test single step vs multi-step
        >>> x = torch.randn(batch, time_steps, in_features)  # >3 dims = multi-step
        >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x)
        >>> x = torch.randn(batch, in_features)  # 2 dims = single step
        >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x)

    """

    def __init__(
        self,
        out_features: int,
        lstm_kwargs,
        mlp_kwargs,
        device=None,
        *,
        lstm_backend: str | None = None,
    ) -> None:
        super().__init__()
        # Copy before setting "batch_first" so the caller's dict is not
        # mutated (the previous implementation wrote into the caller's
        # kwargs via dict.update).
        lstm_kwargs = {**lstm_kwargs, "batch_first": True}
        self.mlp = MLP(device=device, **mlp_kwargs)
        if lstm_backend is None:
            lstm_backend = "torch"
        self.lstm_backend = lstm_backend
        if self.lstm_backend == "torch":
            LSTM = nn.LSTM
        else:
            # torchrl's LSTM is vmap/compile friendly (see class docstring).
            from torchrl.modules.tensordict_module.rnn import LSTM
        self.lstm = LSTM(device=device, **lstm_kwargs)
        # Lazy layer: output width of the LSTM is only known at first call.
        self.linear = nn.LazyLinear(out_features, device=device)

    def _lstm(
        self,
        input: torch.Tensor,
        hidden0_in: torch.Tensor | None = None,
        hidden1_in: torch.Tensor | None = None,
    ):
        # squeeze0: a 1D input was promoted to a batch of one.
        # squeeze1: a 2D input was promoted to a single time step.
        squeeze0 = False
        squeeze1 = False
        if input.ndimension() == 1:
            squeeze0 = True
            input = input.unsqueeze(0).contiguous()

        if input.ndimension() == 2:
            squeeze1 = True
            input = input.unsqueeze(1).contiguous()
        batch, steps = input.shape[:2]

        if hidden1_in is None and hidden0_in is None:
            # No hidden states supplied: start from zeros, matching the
            # [batch (x steps), num_layers, hidden_size] layout used below.
            shape = (batch, steps) if not squeeze1 else (batch,)
            hidden0_in, hidden1_in = (
                torch.zeros(
                    *shape,
                    self.lstm.num_layers,
                    self.lstm.hidden_size,
                    device=input.device,
                    dtype=input.dtype,
                )
                for _ in range(2)
            )
        elif hidden1_in is None or hidden0_in is None:
            # Either both hidden states are given or neither.
            raise RuntimeError(
                f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}"
            )
        elif squeeze0:
            hidden0_in = hidden0_in.unsqueeze(0)
            hidden1_in = hidden1_in.unsqueeze(0)

        # we only need the first hidden state
        if not squeeze1:
            _hidden0_in = hidden0_in[:, 0]
            _hidden1_in = hidden1_in[:, 0]
        else:
            _hidden0_in = hidden0_in
            _hidden1_in = hidden1_in
        # nn.LSTM expects (num_layers, batch, hidden); our storage layout is
        # (batch, num_layers, hidden), hence the transpose.
        hidden = (
            _hidden0_in.transpose(-3, -2).contiguous(),
            _hidden1_in.transpose(-3, -2).contiguous(),
        )

        y0, hidden = self.lstm(input, hidden)
        # dim 0 in hidden is num_layers, but that will conflict with tensordict
        hidden = tuple(_h.transpose(0, 1) for _h in hidden)
        y = self.linear(y0)

        out = [y, hidden0_in, hidden1_in, *hidden]
        if squeeze1:
            # squeezes time
            out[0] = out[0].squeeze(1)
        if not squeeze1:
            # we pad the hidden states with zero to make tensordict happy
            for i in range(3, 5):
                out[i] = torch.stack(
                    [torch.zeros_like(out[i]) for _ in range(input.shape[1] - 1)]
                    + [out[i]],
                    1,
                )
        if squeeze0:
            out = [_out.squeeze(0) for _out in out]
        return tuple(out)

    def forward(
        self,
        input: torch.Tensor,
        hidden0_in: torch.Tensor | None = None,
        hidden1_in: torch.Tensor | None = None,
    ):
        # Embed with the MLP first, then run the (optionally stateful) LSTM.
        input = self.mlp(input)
        return self._lstm(input, hidden0_in, hidden1_in)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def call_value_nets(
    value_net: TensorDictModuleBase,
    data: TensorDictBase,
    params: TensorDictBase,
    next_params: TensorDictBase,
    single_call: bool,
    value_key: NestedKey,
    detach_next: bool,
    vmap_randomness: str = "error",
):
    """Call value networks to compute values at t and t+1.

    This is a testing utility for computing value estimates in advantage
    calculations.

    Args:
        value_net: The value network module.
        data: Input tensordict with observations. Modified in place: the
            value estimates are written at ``value_key`` and
            ``("next", value_key)``.
        params: Parameters for the value network at time t.
        next_params: Parameters for the value network at time t+1.
        single_call: Whether to use a single forward pass for both t and t+1.
        value_key: The key where values are stored.
        detach_next: Whether to detach the next value from the computation graph.
        vmap_randomness: Randomness mode for vmap.

    Returns:
        Tuple of (value, value_next).
    """
    in_keys = value_net.in_keys
    if single_call:
        # Find the dimension marked "time"; the for-else leaves ndim=None
        # when no such dimension name exists.
        for i, name in enumerate(data.names):
            if name == "time":
                ndim = i + 1
                break
        else:
            ndim = None
        if ndim is not None:
            # get data at t and last of t+1
            idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),)
            idx = (slice(None),) * (ndim - 1) + (slice(None, -1),)
            idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),)
            # Concatenate the T steps at t with the last step of "next" so a
            # single pass yields T+1 estimates: idx takes the first T (time
            # t), idx_ takes the last T (time t+1).
            data_in = torch.cat(
                [
                    data.select(*in_keys, value_key, strict=False),
                    data.get("next").select(*in_keys, value_key, strict=False)[idx0],
                ],
                ndim - 1,
            )
        else:
            if RL_WARNINGS:
                logger.warning(
                    "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. "
                    "This warning can be turned off by setting the environment variable RL_WARNINGS to False."
                )
            ndim = data.ndim
            # Without a time mark, concatenate "data" and "next" back-to-back
            # along the last batch dim (length 2*T after cat); idx selects
            # the first half (t), idx_ the second half (t+1).
            idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),)
            idx_ = (slice(None),) * (ndim - 1) + (slice(data.shape[ndim - 1], None),)
            data_in = torch.cat(
                [
                    data.select(*in_keys, value_key, strict=False),
                    data.get("next").select(*in_keys, value_key, strict=False),
                ],
                ndim - 1,
            )

        # next_params should be None or be identical to params
        if next_params is not None and next_params is not params:
            raise ValueError(
                "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed."
            )
        if params is not None:
            # Functional call: temporarily bind `params` onto the module.
            with params.to_module(value_net):
                value_est = value_net(data_in).get(value_key)
        else:
            value_est = value_net(data_in).get(value_key)
        value, value_ = value_est[idx], value_est[idx_]
    else:
        # Two-pass route: stack data and "next" along a new leading dim and
        # evaluate both entries with one vmapped call.
        data_in = torch.stack(
            [
                data.select(*in_keys, value_key, strict=False),
                data.get("next").select(*in_keys, value_key, strict=False),
            ],
            0,
        )
        if (params is not None) ^ (next_params is not None):
            raise ValueError(
                "params and next_params must be either both provided or not."
            )
        elif params is not None:
            # Stack the two parameter sets to align with data_in's leading dim.
            params_stack = torch.stack([params, next_params], 0).contiguous()
            data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)(
                data_in, params_stack
            )
        else:
            data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in)
        value_est = data_out.get(value_key)
        value, value_ = value_est[0], value_est[1]
    data.set(value_key, value)
    data.set(("next", value_key), value_)
    if detach_next:
        # NOTE(review): only the *returned* value_ is detached; the tensor
        # stored in `data` above stays attached to the graph — confirm this
        # ordering is intentional.
        value_ = value_.detach()
    return value, value_
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from torchrl.data.utils import CloudpickleWrapper
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def decorate_thread_sub_func(func, num_threads):
    """Wrap *func* so each call first asserts the torch intra-op thread count.

    The wrapper is returned inside a ``CloudpickleWrapper`` so it can be
    shipped to subprocesses.
    """

    def _checked_call(*args, **kwargs):
        # Fail fast if the worker was not configured with the expected
        # number of threads.
        assert torch.get_num_threads() == num_threads
        return func(*args, **kwargs)

    return CloudpickleWrapper(_checked_call)
|
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
"""Helper classes for Ray-based weight synchronization tests.
|
|
7
|
+
|
|
8
|
+
This module contains Ray actor classes that need to be importable by Ray workers.
|
|
9
|
+
These classes are used in tests but must be defined at module level in a proper
|
|
10
|
+
Python package (not in test files) so Ray can serialize and import them on remote workers.
|
|
11
|
+
"""
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
from torchrl._utils import logger
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class WorkerVLLMNCCL:
    """Ray actor for vLLM inference worker (receiver) using NCCL collective communication.

    Lifecycle: ``__init__`` (cheap, stores config) -> ``setup()`` (builds the
    vLLM engine and receiver) -> ``init_metadata()`` (fetches model metadata
    from the trainer actor) -> ``init()`` (joins the NCCL group).
    """

    def __init__(
        self,
        scheme_config: dict,
        model_name: str = "Qwen/Qwen2.5-0.5B",
        trainer_actor_name: str = "Trainer",
    ):
        # Store config for deferred initialization -- the heavy vLLM engine
        # creation happens in setup() so actor construction stays non-blocking.
        self.scheme_config = scheme_config
        self.model_name = model_name
        self.trainer_actor_name = trainer_actor_name
        self.wrapper = None
        self.engine = None
        self.receiver = None
        self.scheme = None
        self.trainer = None
        self.model_metadata = None
        # Flipped to True by init(); declared here so the attribute always exists.
        self.initialized = False

    def setup(self):
        """Set up vLLM engine (deferred from __init__ to avoid blocking)."""
        from torchrl.modules.llm.backends import AsyncVLLM
        from torchrl.modules.llm.policies import vLLMWrapper

        # Create vLLM wrapper
        async_engine = AsyncVLLM.from_pretrained(
            self.model_name,
            num_replicas=2,  # Number of engine replicas
        )
        self.wrapper = vLLMWrapper(async_engine, input_mode="history")
        self.engine = self.wrapper.model

        # Create scheme from config
        from torchrl.weight_update.llm.vllm_nccl import VLLMWeightSyncScheme

        self.scheme = VLLMWeightSyncScheme(**self.scheme_config)

        # Create receiver (engine handles rank assignment automatically)
        self.receiver = self.scheme.create_receiver(self.engine)
        return "setup_complete"

    def init_metadata(self):
        """Initialize the receiver by fetching metadata from trainer.

        Raises:
            RuntimeError: if setup() has not been called yet.
        """
        import ray

        if self.receiver is None:
            raise RuntimeError("Must call setup() before init_metadata()")

        # Get trainer actor by name
        logger.info(f"Getting trainer actor by name {self.trainer_actor_name}")
        self.trainer = ray.get_actor(self.trainer_actor_name)

        # Fetch model metadata from trainer
        logger.info("Fetching model metadata from trainer (requires max_concurrency>1)")
        self.model_metadata = ray.get(self.trainer.get_model_metadata.remote())

    def init(self):
        """Join the NCCL workers group using the metadata from init_metadata().

        Raises:
            RuntimeError: if init_metadata() has not been called yet.
        """
        if self.model_metadata is None:
            raise RuntimeError("Must call init_metadata() before init()")

        # Initialize receiver with metadata
        logger.info("Initializing receiver...")
        self.receiver.init_all_workers_group(self.model_metadata)
        self.initialized = True
        logger.info("Receiver initialized")
        return "initialized"

    def get_engine(self):
        """Get the vLLM engine reference for RPC coordination."""
        if self.engine is None:
            raise RuntimeError("Must call setup() first")
        return self.engine

    def get_sample_output(self):
        """Get a sample output to verify model works."""
        # Simple inference test
        return "vllm_ready"

    @classmethod
    def as_remote(cls, *args, **kwargs):
        import ray

        # No GPUs needed for the actor itself - vLLM workers manage their own placement group (2 GPUs)
        # AsyncVLLM service doesn't act as NCCL rank 0 when used with external trainer
        return ray.remote(num_cpus=4, num_gpus=0, max_concurrency=4)(cls)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class WorkerTransformerNCCL:
    """Ray actor for transformer trainer (sender) using NCCL collective communication.

    Unlike the receiver actor, all heavy setup (model load, scheme/sender
    creation, metadata extraction) happens in ``__init__``; ``init()`` then
    joins the NCCL group.
    """

    def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"):
        from torchrl.weight_update.llm.vllm_nccl import (
            get_model_metadata,
            VLLMWeightSyncScheme,
        )
        from transformers import AutoModelForCausalLM

        # Create transformer model
        transformer = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype=torch.float16,
        )
        self.transformer = transformer.cuda()

        # Create scheme from config
        self.scheme = VLLMWeightSyncScheme(**scheme_config)

        # Create sender
        self.sender = self.scheme.create_sender()
        self.sender.register_model(self.transformer)

        # Extract and store model metadata
        self.model_metadata = get_model_metadata(self.transformer)
        # Flipped to True by init(); declared here so the attribute always exists.
        self.initialized = False

    def init(self, vllm_engine=None):
        """Initialize sender with optional vLLM engine for RPC coordination.

        Args:
            vllm_engine: Optional vLLM engine reference for calling collective_rpc

        Raises:
            RuntimeError: if model metadata is missing (i.e. __init__ did not
                complete successfully).
        """
        if self.model_metadata is None:
            # Metadata is produced in __init__ for this class; a None value
            # means construction did not complete.
            raise RuntimeError("Model metadata is missing; __init__ must complete before init()")

        self.sender.init_all_workers_group(self.model_metadata, vllm_engine=vllm_engine)
        self.initialized = True
        logger.info("Trainer initialized")
        return "initialized"

    def get_model_metadata(self):
        """Get model metadata to share with receiver."""
        return self.model_metadata

    def update_weights(self, modify_weights: bool = False):
        """Trigger a weight update broadcast.

        Args:
            modify_weights: If True, modifies weights before broadcasting
                for verification purposes.

        Returns:
            str: "updated" status message
        """
        # Optionally modify weights for testing
        if modify_weights:
            with torch.no_grad():
                first_param = next(self.transformer.parameters())
                first_param.add_(0.01)

        # Broadcast weights to all vLLM workers
        self.sender.update_weights()
        return "updated"

    def get_first_param_sum(self):
        """Get sum of first parameter for verification."""
        return next(self.transformer.parameters()).sum().item()

    @classmethod
    def as_remote(cls, *args, **kwargs):
        import ray

        return ray.remote(num_cpus=4, num_gpus=1, max_concurrency=4)(cls)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class WorkerVLLMDoubleBuffer:
    """Ray actor for vLLM inference worker (receiver) using double-buffered storage."""

    def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"):
        # Only record the configuration here; the expensive engine build is
        # deferred to setup().
        self.scheme_config = scheme_config
        self.model_name = model_name
        # Populated by setup().
        self.wrapper = None
        self.engine = None
        self.receiver = None
        self.scheme = None

    def setup(self):
        """Build the vLLM engine, the sync scheme, and the receiver."""
        from torchrl.modules.llm.backends import AsyncVLLM
        from torchrl.modules.llm.policies import vLLMWrapper
        from torchrl.weight_update.llm.vllm_double_buffer import (
            VLLMDoubleBufferSyncScheme,
        )

        # One engine replica is enough for this actor.
        engine_service = AsyncVLLM.from_pretrained(
            self.model_name,
            num_replicas=1,
        )
        self.wrapper = vLLMWrapper(engine_service, input_mode="history")
        self.engine = self.wrapper.model

        # Instantiate the scheme from the stored config and hook the receiver
        # up to the engine.
        self.scheme = VLLMDoubleBufferSyncScheme(**self.scheme_config)
        self.receiver = self.scheme.create_receiver(self.engine)
        logger.info("Receiver setup complete")
        return "setup_complete"

    def poll_and_apply_weights(self):
        """Check shared storage for fresh weights and load them into the engine."""
        if self.receiver is None:
            raise RuntimeError("Must call setup() first")
        return self.receiver.poll_and_apply()

    def get_sample_output(self):
        """Return a sentinel string to verify the actor is responsive."""
        return "vllm_ready"

    @classmethod
    def as_remote(cls, *args, **kwargs):
        import ray

        # vLLM worker needs 1 GPU
        return ray.remote(num_cpus=2, num_gpus=1, max_concurrency=4)(cls)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class WorkerTransformerDoubleBuffer:
    """Ray actor for transformer trainer (sender) using double-buffered storage."""

    def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"):
        from torchrl.weight_update.llm.vllm_double_buffer import (
            VLLMDoubleBufferSyncScheme,
        )
        from transformers import AutoModelForCausalLM

        # Load the model in half precision and place it on the GPU.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype=torch.float16,
        )
        self.transformer = model.cuda()

        # Build the sync scheme and its sender side, then attach the model.
        self.scheme = VLLMDoubleBufferSyncScheme(**scheme_config)
        self.sender = self.scheme.create_sender()
        self.sender.register_model(self.transformer)
        logger.info("Trainer setup complete")

    def update_weights(self, modify_weights: bool = False):
        """Write the current model weights to shared storage.

        Args:
            modify_weights: When True, perturb the first parameter in place
                before writing so the change is observable on the reader side.

        Returns:
            str: "updated" status message
        """
        if modify_weights:
            # Nudge a single parameter so a later read can verify the update.
            with torch.no_grad():
                next(self.transformer.parameters()).add_(0.01)

        self.sender.update_weights()
        return "updated"

    def get_first_param_sum(self):
        """Return the scalar sum of the first parameter (for verification)."""
        first = next(self.transformer.parameters())
        return first.sum().item()

    @classmethod
    def as_remote(cls, *args, **kwargs):
        import ray

        return ray.remote(num_cpus=2, num_gpus=1, max_concurrency=4)(cls)
|