torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
|
|
10
|
+
from torchrl.trainers.algorithms.configs.common import ConfigBase
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class AdamConfig(ConfigBase):
    """Configuration for the ``torch.optim.Adam`` optimizer.

    Instantiated via the Hydra-style ``_target_`` mechanism; with
    ``_partial_=True`` instantiation yields a partial that still expects
    the ``params`` argument (the parameters to optimize).
    """

    lr: float = 1e-3  # learning rate
    betas: tuple[float, float] = (0.9, 0.999)  # moment-estimate decay rates
    # NOTE(review): torch.optim.Adam's own default eps is 1e-8, and the sibling
    # AdamWConfig below uses 1e-8 — confirm 1e-4 here is intentional.
    eps: float = 1e-4
    weight_decay: float = 0.0
    amsgrad: bool = False
    _target_: str = "torch.optim.Adam"  # instantiation target
    _partial_: bool = True  # produce a partial; params are supplied later

    def __post_init__(self) -> None:
        """Post-initialization hook for Adam optimizer configurations."""
@dataclass
class AdamWConfig(ConfigBase):
    """Configuration for the ``torch.optim.AdamW`` optimizer.

    Instantiated via the Hydra-style ``_target_`` mechanism; with
    ``_partial_=True`` instantiation yields a partial that still expects
    the ``params`` argument (the parameters to optimize).
    """

    lr: float = 1e-3  # learning rate
    betas: tuple[float, float] = (0.9, 0.999)  # moment-estimate decay rates
    eps: float = 1e-8
    weight_decay: float = 1e-2  # decoupled weight decay (AdamW semantics)
    amsgrad: bool = False
    maximize: bool = False  # maximize the objective instead of minimizing
    foreach: bool | None = None  # None lets torch pick the implementation
    capturable: bool = False  # CUDA-graph capture support
    differentiable: bool = False
    fused: bool | None = None  # None lets torch pick the implementation
    _target_: str = "torch.optim.AdamW"  # instantiation target
    _partial_: bool = True  # produce a partial; params are supplied later

    def __post_init__(self) -> None:
        """Post-initialization hook for AdamW optimizer configurations."""
@dataclass
class AdamaxConfig(ConfigBase):
    """Configuration for the ``torch.optim.Adamax`` optimizer.

    Instantiated via the Hydra-style ``_target_`` mechanism; with
    ``_partial_=True`` instantiation yields a partial that still expects
    the ``params`` argument (the parameters to optimize).
    """

    lr: float = 2e-3  # learning rate
    betas: tuple[float, float] = (0.9, 0.999)  # moment-estimate decay rates
    eps: float = 1e-8
    weight_decay: float = 0.0
    _target_: str = "torch.optim.Adamax"  # instantiation target
    _partial_: bool = True  # produce a partial; params are supplied later

    def __post_init__(self) -> None:
        """Post-initialization hook for Adamax optimizer configurations."""
@dataclass
class SGDConfig(ConfigBase):
    """Configuration for the ``torch.optim.SGD`` optimizer.

    Instantiated via the Hydra-style ``_target_`` mechanism; with
    ``_partial_=True`` instantiation yields a partial that still expects
    the ``params`` argument (the parameters to optimize).
    """

    lr: float = 1e-3  # learning rate
    momentum: float = 0.0
    dampening: float = 0.0
    weight_decay: float = 0.0
    nesterov: bool = False  # requires momentum > 0 and dampening == 0
    maximize: bool = False  # maximize the objective instead of minimizing
    foreach: bool | None = None  # None lets torch pick the implementation
    differentiable: bool = False
    _target_: str = "torch.optim.SGD"  # instantiation target
    _partial_: bool = True  # produce a partial; params are supplied later

    def __post_init__(self) -> None:
        """Validate the configuration so bad combinations fail fast.

        Mirrors the constraint ``torch.optim.SGD`` enforces at construction
        time, surfacing the error at config creation instead of much later
        when the optimizer is built.

        Raises:
            ValueError: if ``nesterov`` is enabled without a positive
                ``momentum`` or with non-zero ``dampening``.
        """
        if self.nesterov and (self.momentum <= 0 or self.dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening"
            )
@dataclass
class RMSpropConfig(ConfigBase):
    """Configuration for the ``torch.optim.RMSprop`` optimizer.

    Instantiated via the Hydra-style ``_target_`` mechanism; with
    ``_partial_=True`` instantiation yields a partial that still expects
    the ``params`` argument (the parameters to optimize).
    """

    lr: float = 1e-2  # learning rate
    alpha: float = 0.99  # smoothing constant of the squared-gradient average
    eps: float = 1e-8
    weight_decay: float = 0.0
    momentum: float = 0.0
    centered: bool = False  # normalize by the variance estimate when True
    maximize: bool = False  # maximize the objective instead of minimizing
    foreach: bool | None = None  # None lets torch pick the implementation
    differentiable: bool = False
    _target_: str = "torch.optim.RMSprop"  # instantiation target
    _partial_: bool = True  # produce a partial; params are supplied later

    def __post_init__(self) -> None:
        """Post-initialization hook for RMSprop optimizer configurations."""
@dataclass
class AdagradConfig(ConfigBase):
    """Configuration for the ``torch.optim.Adagrad`` optimizer.

    Instantiated via the Hydra-style ``_target_`` mechanism; with
    ``_partial_=True`` instantiation yields a partial that still expects
    the ``params`` argument (the parameters to optimize).
    """

    lr: float = 1e-2  # learning rate
    lr_decay: float = 0.0
    weight_decay: float = 0.0
    initial_accumulator_value: float = 0.0
    eps: float = 1e-10
    maximize: bool = False  # maximize the objective instead of minimizing
    foreach: bool | None = None  # None lets torch pick the implementation
    differentiable: bool = False
    _target_: str = "torch.optim.Adagrad"  # instantiation target
    _partial_: bool = True  # produce a partial; params are supplied later

    def __post_init__(self) -> None:
        """Post-initialization hook for Adagrad optimizer configurations."""
@dataclass
class AdadeltaConfig(ConfigBase):
    """Configuration for the ``torch.optim.Adadelta`` optimizer.

    Instantiated via the Hydra-style ``_target_`` mechanism; with
    ``_partial_=True`` instantiation yields a partial that still expects
    the ``params`` argument (the parameters to optimize).
    """

    lr: float = 1.0  # learning rate (Adadelta's canonical default)
    rho: float = 0.9  # decay rate of the squared-gradient average
    eps: float = 1e-6
    weight_decay: float = 0.0
    foreach: bool | None = None  # None lets torch pick the implementation
    maximize: bool = False  # maximize the objective instead of minimizing
    differentiable: bool = False
    _target_: str = "torch.optim.Adadelta"  # instantiation target
    _partial_: bool = True  # produce a partial; params are supplied later

    def __post_init__(self) -> None:
        """Post-initialization hook for Adadelta optimizer configurations."""
@dataclass
class RpropConfig(ConfigBase):
    """Configuration for the ``torch.optim.Rprop`` optimizer.

    Instantiated via the Hydra-style ``_target_`` mechanism; with
    ``_partial_=True`` instantiation yields a partial that still expects
    the ``params`` argument (the parameters to optimize).
    """

    lr: float = 1e-2  # learning rate
    etas: tuple[float, float] = (0.5, 1.2)  # (decrease, increase) multipliers
    step_sizes: tuple[float, float] = (1e-6, 50.0)  # (min, max) allowed step
    foreach: bool | None = None  # None lets torch pick the implementation
    maximize: bool = False  # maximize the objective instead of minimizing
    differentiable: bool = False
    _target_: str = "torch.optim.Rprop"  # instantiation target
    _partial_: bool = True  # produce a partial; params are supplied later

    def __post_init__(self) -> None:
        """Post-initialization hook for Rprop optimizer configurations."""
@dataclass
class ASGDConfig(ConfigBase):
    """Configuration for the ``torch.optim.ASGD`` optimizer.

    Instantiated via the Hydra-style ``_target_`` mechanism; with
    ``_partial_=True`` instantiation yields a partial that still expects
    the ``params`` argument (the parameters to optimize).
    """

    lr: float = 1e-2  # learning rate
    lambd: float = 1e-4  # decay term
    alpha: float = 0.75  # power used in eta update
    t0: float = 1e6  # point at which averaging starts
    weight_decay: float = 0.0
    foreach: bool | None = None  # None lets torch pick the implementation
    maximize: bool = False  # maximize the objective instead of minimizing
    differentiable: bool = False
    _target_: str = "torch.optim.ASGD"  # instantiation target
    _partial_: bool = True  # produce a partial; params are supplied later

    def __post_init__(self) -> None:
        """Post-initialization hook for ASGD optimizer configurations."""
@dataclass
class LBFGSConfig(ConfigBase):
    """Configuration for the ``torch.optim.LBFGS`` optimizer.

    Instantiated via the Hydra-style ``_target_`` mechanism; with
    ``_partial_=True`` instantiation yields a partial that still expects
    the ``params`` argument (the parameters to optimize).
    """

    lr: float = 1.0  # learning rate
    max_iter: int = 20  # max iterations per optimization step
    max_eval: int | None = None  # max function evaluations per step
    tolerance_grad: float = 1e-7
    tolerance_change: float = 1e-9
    history_size: int = 100
    line_search_fn: str | None = None  # e.g. "strong_wolfe", or None
    _target_: str = "torch.optim.LBFGS"  # instantiation target
    _partial_: bool = True  # produce a partial; params are supplied later

    def __post_init__(self) -> None:
        """Post-initialization hook for LBFGS optimizer configurations."""
@dataclass
class RAdamConfig(ConfigBase):
    """Configuration for the ``torch.optim.RAdam`` optimizer.

    Instantiated via the Hydra-style ``_target_`` mechanism; with
    ``_partial_=True`` instantiation yields a partial that still expects
    the ``params`` argument (the parameters to optimize).
    """

    lr: float = 1e-3  # learning rate
    betas: tuple[float, float] = (0.9, 0.999)  # moment-estimate decay rates
    eps: float = 1e-8
    weight_decay: float = 0.0
    _target_: str = "torch.optim.RAdam"  # instantiation target
    _partial_: bool = True  # produce a partial; params are supplied later

    def __post_init__(self) -> None:
        """Post-initialization hook for RAdam optimizer configurations."""
@dataclass
class NAdamConfig(ConfigBase):
    """Configuration for the ``torch.optim.NAdam`` optimizer.

    Instantiated via the Hydra-style ``_target_`` mechanism; with
    ``_partial_=True`` instantiation yields a partial that still expects
    the ``params`` argument (the parameters to optimize).
    """

    lr: float = 2e-3  # learning rate
    betas: tuple[float, float] = (0.9, 0.999)  # moment-estimate decay rates
    eps: float = 1e-8
    weight_decay: float = 0.0
    momentum_decay: float = 4e-3
    foreach: bool | None = None  # None lets torch pick the implementation
    _target_: str = "torch.optim.NAdam"  # instantiation target
    _partial_: bool = True  # produce a partial; params are supplied later

    def __post_init__(self) -> None:
        """Post-initialization hook for NAdam optimizer configurations."""
@dataclass
class SparseAdamConfig(ConfigBase):
    """Configuration for the ``torch.optim.SparseAdam`` optimizer.

    Instantiated via the Hydra-style ``_target_`` mechanism; with
    ``_partial_=True`` instantiation yields a partial that still expects
    the ``params`` argument (the parameters to optimize).
    """

    lr: float = 1e-3  # learning rate
    betas: tuple[float, float] = (0.9, 0.999)  # moment-estimate decay rates
    eps: float = 1e-8
    _target_: str = "torch.optim.SparseAdam"  # instantiation target
    _partial_: bool = True  # produce a partial; params are supplied later

    def __post_init__(self) -> None:
        """Post-initialization hook for SparseAdam optimizer configurations."""
@dataclass
class LionConfig(ConfigBase):
    """Configuration for the Lion optimizer.

    NOTE(review): ``torch.optim.Lion`` does not exist in stock PyTorch
    releases to date, so instantiating this config is expected to fail
    unless the runtime provides that attribute — confirm the intended
    target path (e.g. a third-party Lion implementation).
    """

    lr: float = 1e-4  # learning rate (Lion typically uses smaller lr than Adam)
    betas: tuple[float, float] = (0.9, 0.99)
    weight_decay: float = 0.0
    _target_: str = "torch.optim.Lion"  # instantiation target — see NOTE above
    _partial_: bool = True  # produce a partial; params are supplied later

    def __post_init__(self) -> None:
        """Post-initialization hook for Lion optimizer configurations."""
@@ -0,0 +1,191 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from torchrl.trainers.algorithms.configs.common import ConfigBase
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class WeightSyncSchemeConfig(ConfigBase):
    """Base configuration for weight synchronization schemes.

    NOTE(review): the target points at ``WeightSyncScheme`` itself, which
    looks like the abstract base of the concrete schemes below — confirm
    whether this config is meant to be instantiated directly.
    """

    _target_: str = "torchrl.weight_update.WeightSyncScheme"  # instantiation target
    _partial_: bool = False  # instantiate eagerly, not as a partial

    # Common argument for all schemes
    strategy: str = "tensordict"  # "tensordict" or "state_dict"

    def __post_init__(self) -> None:
        """Post-initialization hook for weight sync scheme configurations."""
@dataclass
class MultiProcessWeightSyncSchemeConfig(ConfigBase):
    """Configuration for MultiProcessWeightSyncScheme.

    Weight synchronization for multiprocess operations using pipes.
    This scheme creates transports that communicate via multiprocessing pipes.
    """

    _target_: str = "torchrl.weight_update.MultiProcessWeightSyncScheme"  # instantiation target
    _partial_: bool = False  # instantiate eagerly, not as a partial

    strategy: str = "tensordict"  # "tensordict" or "state_dict"

    def __post_init__(self) -> None:
        """Post-initialization hook for multiprocess weight sync scheme configurations."""
@dataclass
class SharedMemWeightSyncSchemeConfig(ConfigBase):
    """Configuration for SharedMemWeightSyncScheme.

    Weight synchronization using shared memory for in-place weight updates.
    Workers automatically see weight updates without explicit message passing.
    """

    _target_: str = "torchrl.weight_update.SharedMemWeightSyncScheme"  # instantiation target
    _partial_: bool = False  # instantiate eagerly, not as a partial

    strategy: str = "tensordict"  # "tensordict" or "state_dict"

    def __post_init__(self) -> None:
        """Post-initialization hook for shared memory weight sync scheme configurations."""
@dataclass
class NoWeightSyncSchemeConfig(ConfigBase):
    """Configuration for NoWeightSyncScheme.

    No-op weight synchronization scheme that disables weight synchronization entirely.
    """

    _target_: str = "torchrl.weight_update.NoWeightSyncScheme"  # instantiation target
    _partial_: bool = False  # instantiate eagerly, not as a partial

    strategy: str = "tensordict"  # Not really used, but kept for consistency

    def __post_init__(self) -> None:
        """Post-initialization hook for no weight sync scheme configurations."""
@dataclass
class RayWeightSyncSchemeConfig(ConfigBase):
    """Configuration for RayWeightSyncScheme.

    Weight synchronization for Ray distributed computing. Uses Ray's object store
    and remote calls to synchronize weights across distributed workers (Ray actors).
    """

    _target_: str = "torchrl.weight_update.RayWeightSyncScheme"  # instantiation target
    _partial_: bool = False  # instantiate eagerly, not as a partial

    strategy: str = "tensordict"  # "tensordict" or "state_dict"

    def __post_init__(self) -> None:
        """Post-initialization hook for Ray weight sync scheme configurations."""
@dataclass
class RayModuleTransformSchemeConfig(ConfigBase):
    """Configuration for RayModuleTransformScheme.

    Weight synchronization for RayModuleTransform actors. This scheme is designed
    specifically for updating models hosted within Ray actors.
    """

    _target_: str = "torchrl.weight_update.RayModuleTransformScheme"  # instantiation target
    _partial_: bool = False  # instantiate eagerly, not as a partial

    strategy: str = "tensordict"  # "tensordict" or "state_dict"

    def __post_init__(self) -> None:
        """Post-initialization hook for Ray module transform scheme configurations."""
@dataclass
class RPCWeightSyncSchemeConfig(ConfigBase):
    """Configuration for RPCWeightSyncScheme.

    Synchronizes weights across distributed workers through
    ``torch.distributed.rpc`` calls.

    Attributes:
        _target_: Fully-qualified class instantiated from this config.
        _partial_: Whether instantiation yields a partial (it does not).
        strategy: Weight exchange format, ``"tensordict"`` or ``"state_dict"``.
    """

    _target_: str = "torchrl.weight_update.RPCWeightSyncScheme"
    _partial_: bool = False

    # Either "tensordict" or "state_dict".
    strategy: str = "tensordict"

    def __post_init__(self) -> None:
        """No-op hook; this config needs no validation."""
@dataclass
class DistributedWeightSyncSchemeConfig(ConfigBase):
    """Configuration for DistributedWeightSyncScheme.

    Synchronizes weights across distributed workers with ``torch.distributed``
    point-to-point primitives (send/recv).

    Attributes:
        _target_: Fully-qualified class instantiated from this config.
        _partial_: Whether instantiation yields a partial (it does not).
        backend: torch.distributed backend name ("gloo", "nccl", ...).
        sync: If True, transfers block until completed.
        strategy: Weight exchange format, ``"tensordict"`` or ``"state_dict"``.
    """

    _target_: str = "torchrl.weight_update.DistributedWeightSyncScheme"
    _partial_: bool = False

    backend: str = "gloo"  # e.g. "gloo" or "nccl"
    sync: bool = True
    strategy: str = "tensordict"  # either "tensordict" or "state_dict"

    def __post_init__(self) -> None:
        """No-op hook; this config needs no validation."""
@dataclass
class VLLMWeightSyncSchemeConfig(ConfigBase):
    """Configuration for VLLMWeightSyncScheme.

    Broadcasts weights from a trainer to vLLM inference workers over
    collective communication (NCCL), with support for parallel replicas.

    Attributes:
        _target_: Fully-qualified class instantiated from this config.
        _partial_: Whether instantiation yields a partial (it does not).
        master_address: Rendezvous host; ``None`` falls back to "localhost".
        master_port: Rendezvous port; auto-assigned when ``None``.
        gpus_per_replica: GPUs per vLLM replica (tp_size x dp_size x pp_size).
        num_replicas: Number of vLLM replicas to broadcast to.
        strategy: Weight exchange format, ``"tensordict"`` or ``"state_dict"``.
        device: Trainer-side device (torch.device | str | int).
    """

    _target_: str = "torchrl.weight_update.llm.VLLMWeightSyncScheme"
    _partial_: bool = False

    master_address: str | None = None  # None -> "localhost"
    master_port: int | None = None  # None -> auto-assigned
    gpus_per_replica: int = 1  # tp_size x dp_size x pp_size
    num_replicas: int = 1
    strategy: str = "tensordict"  # either "tensordict" or "state_dict"
    device: Any = 0  # accepts torch.device | str | int

    def __post_init__(self) -> None:
        """No-op hook; this config needs no validation."""
@dataclass
class VLLMDoubleBufferSyncSchemeConfig(ConfigBase):
    """Configuration for VLLMDoubleBufferSyncScheme.

    Transfers weights to vLLM through double-buffered, memory-mapped storage
    on the filesystem, relying on TensorDict's memory-mapping capabilities.

    Attributes:
        _target_: Fully-qualified class instantiated from this config.
        _partial_: Whether instantiation yields a partial (it does not).
        remote_addr: Directory the sender writes weights into (required).
        local_addr: Directory the receiver reads weights from; defaults to
            ``remote_addr`` when left unset.
        num_threads: Thread count used for memmap operations.
        strategy: Weight exchange format, ``"tensordict"`` or ``"state_dict"``.

    Raises:
        ValueError: If ``remote_addr`` is not provided.
    """

    _target_: str = "torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme"
    _partial_: bool = False

    remote_addr: str | None = None  # sender-side weight directory
    local_addr: str | None = None  # receiver-side weight directory
    num_threads: int = 1  # threads for memmap operations
    strategy: str = "tensordict"  # either "tensordict" or "state_dict"

    def __post_init__(self) -> None:
        """Require ``remote_addr`` and default ``local_addr`` to it."""
        if self.remote_addr is None:
            raise ValueError("remote_addr is required for VLLMDoubleBufferSyncScheme")
        # Same directory on both ends unless the caller says otherwise.
        self.local_addr = self.remote_addr if self.local_addr is None else self.local_addr
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from torchrl.trainers.algorithms.configs.common import ConfigBase
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class WeightUpdaterConfig(ConfigBase):
    """Base configuration shared by all weight-updater configs.

    Attributes:
        _target_: Fully-qualified class instantiated from this config.
        _partial_: Whether instantiation yields a partial (it does).
    """

    _target_: str = "torchrl.collectors.WeightUpdaterBase"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook; this config needs no validation."""
@dataclass
class VanillaWeightUpdaterConfig(ConfigBase):
    """Configuration for VanillaWeightUpdater.

    The simplest updater: local policy weights are refreshed by fetching
    them directly from a given source.

    Attributes:
        _target_: Fully-qualified class instantiated from this config.
        _partial_: Whether instantiation yields a partial (it does).
        weight_getter: Optional zero-arg callable returning a TensorDictBase.
        policy_weights: The TensorDictBase holding the policy weights.
    """

    _target_: str = "torchrl.collectors.VanillaWeightUpdater"
    _partial_: bool = True

    # Arguments forwarded to the constructor.
    weight_getter: Any = None  # Callable[[], TensorDictBase] | None
    policy_weights: Any = None  # TensorDictBase

    def __post_init__(self) -> None:
        """No-op hook; this config needs no validation."""
@dataclass
class MultiProcessedWeightUpdaterConfig(ConfigBase):
    """Configuration for MultiProcessedWeightUpdater.

    Remote updater that keeps policy weights in sync across several
    processes or devices in a multiprocessed setup.

    Attributes:
        _target_: Fully-qualified class instantiated from this config.
        _partial_: Whether instantiation yields a partial (it does).
        get_server_weights: Optional zero-arg callable returning a
            TensorDictBase.
        policy_weights: Mapping from torch.device to TensorDictBase.
    """

    _target_: str = "torchrl.collectors.MultiProcessedWeightUpdater"
    _partial_: bool = True

    # Arguments forwarded to the constructor.
    get_server_weights: Any = None  # Callable[[], TensorDictBase] | None
    policy_weights: Any = None  # dict[torch.device, TensorDictBase]

    def __post_init__(self) -> None:
        """No-op hook; this config needs no validation."""
@dataclass
class RayWeightUpdaterConfig(ConfigBase):
    """Configuration for RayWeightUpdater.

    Remote updater that propagates policy weights to remote workers through
    Ray's distributed computing facilities.

    Attributes:
        _target_: Fully-qualified class instantiated from this config.
        _partial_: Whether instantiation yields a partial (it does).
        policy_weights: The TensorDictBase holding the policy weights.
        remote_collectors: List of remote collector handles to update.
        max_interval: Maximum number of iterations between updates.
    """

    _target_: str = "torchrl.collectors.RayWeightUpdater"
    _partial_: bool = True

    # Arguments forwarded to the constructor.
    policy_weights: Any = None  # TensorDictBase
    remote_collectors: Any = None  # list
    max_interval: int = 0

    def __post_init__(self) -> None:
        """No-op hook; this config needs no validation."""
@dataclass
class RPCWeightUpdaterConfig(ConfigBase):
    """Configuration for RPCWeightUpdater.

    Remote updater that pushes policy weights to remote workers over RPC
    communication.

    Attributes:
        _target_: Fully-qualified class instantiated from this config.
        _partial_: Whether instantiation yields a partial (it does).
        collector_infos: RPC worker info objects for the collectors.
        collector_class: Class of the remote collectors.
        collector_rrefs: Remote references (RRefs) to the collectors.
        policy_weights: The TensorDictBase holding the policy weights.
        num_workers: Number of remote workers to update.
    """

    _target_: str = "torchrl.collectors.distributed.RPCWeightUpdater"
    _partial_: bool = True

    # Arguments forwarded to the constructor.
    collector_infos: Any = None
    collector_class: Any = None
    collector_rrefs: Any = None
    policy_weights: Any = None  # TensorDictBase
    num_workers: int = 0

    def __post_init__(self) -> None:
        """No-op hook; this config needs no validation."""
@dataclass
class DistributedWeightUpdaterConfig(ConfigBase):
    """Configuration for DistributedWeightUpdater.

    Remote updater that synchronizes policy weights across distributed
    workers, coordinating through a dictionary-like store.

    Attributes:
        _target_: Fully-qualified class instantiated from this config.
        _partial_: Whether instantiation yields a partial (it does).
        store: Dictionary-like store used for worker communication.
        policy_weights: The TensorDictBase holding the policy weights.
        num_workers: Number of distributed workers to update.
        sync: If True, updates block until completed.
    """

    _target_: str = "torchrl.collectors.distributed.DistributedWeightUpdater"
    _partial_: bool = True

    # Arguments forwarded to the constructor.
    store: Any = None  # dict[str, str]
    policy_weights: Any = None  # TensorDictBase
    num_workers: int = 0
    sync: bool = True

    def __post_init__(self) -> None:
        """No-op hook; this config needs no validation."""
@dataclass
class RemoteModuleWeightUpdaterConfig(ConfigBase):
    """Configuration for RemoteModuleWeightUpdater.

    Updater for remote ``nn.Module`` instances where weights must be passed
    explicitly, i.e. when the master collector has no direct access to the
    workers' weights.

    Attributes:
        _target_: Fully-qualified class instantiated from this config.
        _partial_: Whether instantiation yields a partial (it does).
    """

    _target_: str = "torchrl.collectors.RemoteModuleWeightUpdater"
    _partial_: bool = True

    def __post_init__(self) -> None:
        """No-op hook; this config needs no validation."""
@dataclass
|
|
142
|
+
class vLLMUpdaterConfig(ConfigBase):
|
|
143
|
+
"""Configuration for vLLMUpdater.
|
|
144
|
+
|
|
145
|
+
A weight updater that sends weights to vLLM workers, supporting both local
|
|
146
|
+
vLLM instances and remote Ray actors for LLM inference.
|
|
147
|
+
"""
|
|
148
|
+
|
|
149
|
+
_target_: str = "torchrl.collectors.llm.vLLMUpdater"
|
|
150
|
+
_partial_: bool = True
|
|
151
|
+
|
|
152
|
+
# Constructor arguments
|
|
153
|
+
master_address: str | None = None
|
|
154
|
+
master_port: int | None = None
|
|
155
|
+
model_metadata: Any = None # dict[str, tuple[torch.dtype, torch.Size]] | None
|
|
156
|
+
vllm_tp_size: int | None = None
|
|
157
|
+
|
|
158
|
+
def __post_init__(self) -> None:
|
|
159
|
+
"""Post-initialization hook for vLLM updater configurations."""
|