torchrl 0.11.0__cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,508 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import gc
|
|
9
|
+
import math
|
|
10
|
+
from functools import partial
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
import hydra
|
|
14
|
+
|
|
15
|
+
from torchrl import merge_ray_runtime_env, torchrl_logger
|
|
16
|
+
from torchrl.data.llm.history import History
|
|
17
|
+
from torchrl.record.loggers.wandb import WandbLogger
|
|
18
|
+
from torchrl.weight_update.llm import get_model_metadata
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
import ray
|
|
22
|
+
except ImportError:
|
|
23
|
+
raise ImportError(
|
|
24
|
+
"Ray is required for sync training. Please install ray with `pip install ray`."
|
|
25
|
+
)
|
|
26
|
+
import time
|
|
27
|
+
|
|
28
|
+
import torch
|
|
29
|
+
import tqdm
|
|
30
|
+
|
|
31
|
+
from ei_utils import (
|
|
32
|
+
compute_device_allocation,
|
|
33
|
+
create_cosine_scheduler_with_warmup,
|
|
34
|
+
get_inference_model,
|
|
35
|
+
get_train_model,
|
|
36
|
+
log_training_metrics,
|
|
37
|
+
make_env,
|
|
38
|
+
make_weight_sync_scheme,
|
|
39
|
+
RemoteDataLogger,
|
|
40
|
+
)
|
|
41
|
+
from omegaconf import DictConfig
|
|
42
|
+
from ray.util.queue import Queue
|
|
43
|
+
|
|
44
|
+
try:
|
|
45
|
+
from tensordict import set_list_to_stack
|
|
46
|
+
except ImportError:
|
|
47
|
+
raise ImportError(
|
|
48
|
+
"TensorDict is required. Please install it with `pip install tensordict`."
|
|
49
|
+
)
|
|
50
|
+
from torch.amp.autocast_mode import autocast
|
|
51
|
+
from torch.amp.grad_scaler import GradScaler
|
|
52
|
+
from torchrl._utils import timeit
|
|
53
|
+
from torchrl.collectors.llm import RayLLMCollector
|
|
54
|
+
from torchrl.data import LazyStackStorage, ReplayBuffer, SamplerWithoutReplacement
|
|
55
|
+
from torchrl.data.llm.topk import TopKRewardSelector
|
|
56
|
+
from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
|
|
57
|
+
from torchrl.objectives.llm.sft import SFTLoss
|
|
58
|
+
|
|
59
|
+
# NOTE(review): presumably the default number of dialog turns gathered per
# collection batch (inferred from the name only). It is not referenced in the
# visible part of this script; likely consumed by ``main`` further down — confirm.
DEFAULT_DIALOG_TURNS_PER_BATCH = 256
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def setup_environment() -> None:
    """Set up global torch/tensordict defaults required for training.

    Makes float32 the default dtype and ``cuda:0`` both the default device for
    newly created tensors and the current CUDA device. Also globally enables
    tensordict's ``set_list_to_stack(True)`` behavior.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for training")

    # Set default dtype to float32 for mixed precision training
    torch.set_default_dtype(torch.float32)
    torch.set_default_device("cuda:0")
    set_list_to_stack(True).set()

    # Make cuda:0 the current CUDA device. CUDA availability was already
    # verified above, so no second guard is needed.
    torch.cuda.set_device("cuda:0")
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _use_fp16_loss_scaling(cfg: DictConfig) -> bool:
    """Return whether gradients should go through an active ``GradScaler``.

    Loss scaling is used only when mixed precision is enabled and the training
    model's ``torch_dtype`` is ``"float16"``. For any other configuration the
    scaler is created disabled and acts as a pass-through.
    """
    return bool(cfg.train.mixed_precision) and cfg.train_model.torch_dtype == "float16"


def _update_policy_weights(sender) -> None:
    """Push the current training weights to the inference engine via ``sender``.

    Runs under the ``update_policy_weights`` timer. Afterwards it releases
    cached CUDA memory and runs the garbage collector.
    """
    with timeit("update_policy_weights"):
        torchrl_logger.info("Updating policy weights...")
        sender.update_weights()
        torch.cuda.empty_cache()
        gc.collect()


def _make_wandb_logger(cfg: DictConfig) -> WandbLogger:
    """Create the local wandb logger used for training metrics.

    The run name joins ``"ei-sync"`` with the optional
    ``cfg.logging.experiment_name``, ``cfg.env.dataset`` and ``cfg.model.name``
    using ``"-"``. The project is ``"ei-sync"``.
    """
    torchrl_logger.info("Starting wandb logger.")
    experiment_name = []
    if cfg.logging.experiment_name is not None:
        experiment_name.append(cfg.logging.experiment_name)
    experiment_name.append(cfg.env.dataset)
    experiment_name.append(cfg.model.name)

    wandb_config = {
        "project": "ei-sync",
        "exp_name": "-".join(["ei-sync"] + experiment_name),
    }
    return WandbLogger(**wandb_config)


def train(
    replay_buffer: ReplayBuffer,
    cfg: DictConfig,
    collector: RayLLMCollector,
    devices: list[int] | None = None,
):
    """Main training loop for EI sync.

    Data collection and optimization happen in separate, consecutive steps.
    For every round yielded by ``collector``, the policy is trained for
    ``cfg.train.epochs`` passes over ``replay_buffer``. The collector is
    expected to write directly into that buffer. An optimizer step is taken
    every ``cfg.train.gradient_accumulation_steps`` micro-batches.

    Args:
        replay_buffer: The replay buffer to store experiences. The sampler will
            typically be a `SamplerWithoutReplacement`.
        cfg: The configuration object containing training parameters.
        collector: The collector object. Its ``policy.model`` is used as the
            vLLM engine that receives weight updates.
        devices: The devices to use for the training model. The first one is
            also used for batches and the loss; may be ``None``.

    Raises:
        RuntimeError: If no vLLM engine can be obtained from ``collector.policy``.
        ValueError: If the collector yields data instead of writing it to the
            replay buffer.
    """
    # Setup training model and tokenizer
    policy_training, train_tokenizer = get_train_model(
        cfg, devices=devices, chat_template_name="qwen"
    )
    # Use the first device for batch processing. ``devices`` defaults to None,
    # so guard the indexing instead of failing with a TypeError.
    train_device = devices[0] if devices else None

    # Setup loss function
    loss_fn = SFTLoss(
        actor_network=policy_training,
        kl_to_ref_coeff=cfg.train.kl_to_ref_coeff,
        tokenizer=train_tokenizer,
        tokenizer_kwargs={"chat_template_name": "qwen"},
        device=torch.device(f"cuda:{train_device}")
        if train_device is not None
        else None,
        loss_function=cfg.train.loss_function,
        beta=cfg.train.minor_sft_beta,
    )
    if cfg.model.compile:
        loss_fn = torch.compile(loss_fn)

    # Create weight sync scheme
    weight_sync_scheme = make_weight_sync_scheme(
        master_address="localhost",  # Since we're running locally
        master_port=None,  # Will auto-assign an open port
        vllm_tp_size=cfg.inference_model.num_devices
        if cfg.inference_model.num_devices is not None
        else len(cfg.inference_model.get("devices", [1])),
    )

    # Set up weight sender
    torchrl_logger.info("Setting up weight synchronization scheme...")
    sender = weight_sync_scheme.create_sender()
    sender.register_model(policy_training)

    # Get vLLM engine reference from collector's policy
    vllm_engine = collector.policy.model if hasattr(collector, "policy") else None
    if vllm_engine is None:
        raise RuntimeError("Could not get vLLM engine from collector policy")

    # Initialize collective group
    torchrl_logger.info("Initializing collective group...")
    metadata = get_model_metadata(policy_training)
    sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)

    # First weight update
    with timeit("update_policy_weights"):
        sender.update_weights()
    timeit.print(prefix="First update_policy_weights_ time")
    timeit.reset()

    # Make optimizer
    torchrl_logger.info("Starting optimizer.")
    optimizer = torch.optim.Adam(
        policy_training.parameters(),
        lr=cfg.optimizer.lr,
        weight_decay=cfg.optimizer.weight_decay,
        fused=False,
    )
    # Create ONE scaler for the whole run. Its dynamic loss scale must persist
    # across steps so that update() can back off after overflows and grow
    # again. When fp16 loss scaling is not needed it is disabled:
    # scale/unscale_/update are no-ops and step() simply calls optimizer.step().
    scaler = GradScaler(enabled=_use_fp16_loss_scaling(cfg))

    # Calculate total optimization steps for scheduler
    # The training loop structure: for each collector iteration, we do cfg.train.epochs epochs
    # Each epoch processes the entire replay buffer, and optimization happens every gradient_accumulation_steps
    # This is a conservative estimate based on the total dialog turns; it can be
    # refined based on the actual training dynamics
    total_optim_steps = (
        cfg.train.total_dialog_turns
        * cfg.train.epochs
        // cfg.train.gradient_accumulation_steps
    )

    # Create scheduler if enabled
    scheduler = None
    if cfg.optimizer.scheduler.enabled:
        warmup_steps = cfg.optimizer.scheduler.warmup_steps
        num_cycles = cfg.optimizer.scheduler.num_cycles
        torchrl_logger.info(
            f"Creating {cfg.optimizer.scheduler.type} scheduler with {warmup_steps} warmup steps out of {total_optim_steps} total steps"
        )

        scheduler = create_cosine_scheduler_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_optim_steps,
            num_cycles=num_cycles,
        )

    # Make checkpoint dir
    checkpoint_dir = Path(cfg.logging.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Create local wandb logger for training metrics
    wandb_logger = _make_wandb_logger(cfg)

    # Pass the logging actor reference to the collector
    log_queue = Queue(maxsize=1000)
    collector.set_postproc(RemoteDataLogger(log_queue=log_queue))

    # Training loop
    torchrl_logger.info("Starting training loop.")
    pbar = tqdm.tqdm(total=cfg.train.total_dialog_turns)
    grad_norm = 0.0  # Initialize grad_norm
    data_read_count = 0

    global_step = 0
    optim_step = 0  # Track optimization steps separately for scheduler
    start_time = time.time()
    write_count = replay_buffer.write_count
    try:
        for data in collector:
            new_write_count = replay_buffer.write_count
            if new_write_count == write_count:
                torchrl_logger.warning("No new writes to replay buffer")
                continue
            pbar.update(new_write_count - write_count)
            write_count = new_write_count

            # data is None as the collector directly writes to the replay buffer
            if data is not None:
                raise ValueError("Data is not None")

            for _ in range(cfg.train.epochs):
                # Iterate over the replay buffer
                for batch in replay_buffer:
                    if train_device is not None:
                        batch = batch.to(train_device)
                    global_step += 1
                    pbar.set_description(
                        f"Gradient step {global_step}, writes: {replay_buffer.write_count}, batch size: {batch.shape}"
                    )
                    # For logging purposes, we get the first element of the
                    # flattened batch's prompt history and convert it to a string
                    history: History = batch.view(-1)[0]["next", "history", "prompt"]
                    history_str: list[str] | str = history.apply_chat_template(
                        tokenizer=train_tokenizer
                    )
                    while not isinstance(history_str, str):
                        history_str = "\n".join(history_str)

                    data_read_count += batch.numel()

                    with timeit("forward_pass"):
                        # Forward pass with mixed precision
                        with autocast("cuda", enabled=cfg.train.mixed_precision):
                            loss = loss_fn(batch)
                            if loss.loss_kl_to_ref is not None:
                                loss_val = loss.loss_sft + loss.loss_kl_to_ref
                            else:
                                loss_val = loss.loss_sft
                            loss_val = loss_val / cfg.train.gradient_accumulation_steps

                    with timeit("backward_pass"):
                        # Backward pass (scale() is a pass-through when the scaler is disabled)
                        scaler.scale(loss_val).backward()

                    # Optimization step. global_step was already incremented for
                    # this micro-batch, so stepping on its multiples accumulates
                    # exactly gradient_accumulation_steps micro-batches per update.
                    if global_step % cfg.train.gradient_accumulation_steps == 0:
                        with timeit("optim_step"):
                            # Unscale first so clipping sees the true gradient norm.
                            scaler.unscale_(optimizer)

                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                policy_training.parameters(),
                                cfg.optimizer.clip_grad_norm,
                            )

                            scaler.step(optimizer)
                            scaler.update()
                            optimizer.zero_grad(set_to_none=True)

                            # Step the scheduler
                            if scheduler is not None:
                                scheduler.step()

                            # Increment optimization step counter
                            optim_step += 1

                    # Clear memory
                    del loss_val
                    torch.cuda.empty_cache()
                    gc.collect()

                    # Update metrics
                    if (global_step % cfg.train.logging_frequency) == 0:
                        log_training_metrics(
                            wandb_logger=wandb_logger,
                            replay_buffer=replay_buffer,
                            batch=batch,
                            loss=loss,
                            grad_norm=grad_norm,
                            global_step=global_step,
                            data_read_count=data_read_count,
                            collector=collector,
                            start_time=start_time,
                            gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
                            history_str=history_str,
                        )
                        # Log additional metrics
                        wandb_logger.log_scalar(
                            "learning_rate",
                            float(optimizer.param_groups[0]["lr"]),
                            step=global_step,
                        )
                        wandb_logger.log_scalar("optim_step", optim_step, step=global_step)
                        while not log_queue.empty():
                            logs = log_queue.get()
                            for k, v in logs.items():
                                wandb_logger.log_scalar(k, v, step=global_step)

                    # Update policy weights. This uses the same step phase as the
                    # optimizer step, so weights are pushed right after an update.
                    if (
                        cfg.train.weight_update_frequency is not None
                        and global_step % cfg.train.weight_update_frequency == 0
                    ):
                        _update_policy_weights(sender)

                    # Checkpointing disabled to prevent disk space issues
                    # if (global_step + 1) % cfg.train.checkpoint_frequency == 0:
                    #     with timeit("save_checkpoint"):
                    #         torchrl_logger.info(
                    #             f"Saving checkpoint {(global_step+1) // cfg.train.checkpoint_frequency}..."
                    #         )
                    #         checkpoint = {
                    #             "step": global_step,
                    #             "model_state_dict": policy_training.model.state_dict(),
                    #             "optimizer_state_dict": optimizer.state_dict(),
                    #             "scaler_state_dict": scaler.state_dict(),
                    #             "config": dict(cfg),
                    #         }
                    #         torch.save(checkpoint, checkpoint_dir / f"checkpoint_{global_step:04d}.pt")

            # If weight_update_frequency is not set, push the weights once per
            # collection round, after all epochs over the replay buffer.
            if cfg.train.weight_update_frequency is None:
                _update_policy_weights(sender)

            timeit.print(prefix="timeit")
            for key, val in timeit.todict().items():
                wandb_logger.log_scalar(f"timeit/{key}", val)
            timeit.reset()

            if cfg.train.empty_replay_buffer:
                replay_buffer.empty(empty_write_count=False)
    finally:
        # Always release the progress bar and the (Ray-backed) collector,
        # even if training raises.
        pbar.close()
        collector.shutdown()
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def _require_num_devices(cfg):
    """Raise ``ValueError`` if any model's ``num_devices`` is unset.

    Meant to run before any device allocation or Ray startup, so that a
    missing value fails fast with an explicit message.
    """
    if cfg.inference_model.num_devices is None:
        raise ValueError(
            "Inference model num_devices must be set via inference_model.num_devices"
        )
    if cfg.ref_model.num_devices is None:
        raise ValueError("Ref model num_devices must be set via ref_model.num_devices")
    if cfg.train_model.num_devices is None:
        raise ValueError(
            "Train model num_devices must be set via train_model.num_devices"
        )


def _init_ray_if_needed(cfg, device_config):
    """Start Ray from ``cfg.ray.init_config`` unless Ray is already running.

    Keys starting with ``_`` are dropped and top-level ``DictConfig`` values
    are converted to plain dicts. ``num_gpus`` is taken from the computed
    device allocation, and the default runtime env is merged in.
    """
    if ray.is_initialized():
        return
    # Convert OmegaConf to regular dict and filter out unsupported parameters
    ray_init_config = {
        k: dict(v) if isinstance(v, DictConfig) else v
        for k, v in dict(cfg.ray.init_config).items()
        if not k.startswith("_")
    }
    # Add computed GPU configuration and merge with default runtime_env
    ray_init_config["num_gpus"] = device_config["ray_num_gpus"]
    ray_init_config = merge_ray_runtime_env(ray_init_config)
    torchrl_logger.info(f"Ray init config: {ray_init_config=}")
    ray.init(**ray_init_config)


def _resolve_dialog_turns_per_batch(cfg):
    """Return ``cfg.train.dialog_turns_per_batch``, or the module default if unset."""
    dialog_turns_per_batch = cfg.train.dialog_turns_per_batch
    if dialog_turns_per_batch is None:
        # Hardcoded for now
        dialog_turns_per_batch = DEFAULT_DIALOG_TURNS_PER_BATCH
    return dialog_turns_per_batch


def _compute_replay_buffer_size(cfg, dialog_turns_per_batch):
    """Return the replay-buffer capacity.

    Uses ``cfg.train.buffer_size`` when it is set. Otherwise it picks a large
    fixed capacity when the buffer is emptied during training
    (``cfg.train.empty_replay_buffer``), or the number of turns expected to
    survive top-k selection for one batch.
    """
    if cfg.train.buffer_size is not None:
        return cfg.train.buffer_size
    if cfg.train.empty_replay_buffer:
        # we can just set a big number, the buffer will be emptied anyway
        return 1000000
    # NOTE(review): assumes TopKRewardSelector keeps `topk_size` out of every
    # `env.repeats` turns — confirm against the transform.
    return math.ceil(dialog_turns_per_batch * cfg.train.topk_size / cfg.env.repeats)


@hydra.main(version_base=None, config_path="config", config_name="ei_gsm8k")
def main(cfg):
    """Launch synchronous expert-iteration training on Ray.

    Steps:
        1. Validate the config.
        2. Start Ray if it is not already running.
        3. Build the inference policy, the remote replay buffer (with top-k
           reward selection) and a ``RayLLMCollector``.
        4. Run ``train`` as a Ray remote task, blocking until it returns.

    Raises:
        ValueError: if ``cfg.train.sync`` is false, or if any of
            ``inference_model`` / ``ref_model`` / ``train_model``
            ``num_devices`` is unset.
    """
    # Force sync mode
    if not cfg.train.sync:
        raise ValueError(
            "expert-iteration-sync.py must run in sync mode (`python expert-iteration-sync.py mode=sync`). Please use expert-iteration-async.py for async mode (`python expert-iteration-async.py mode=async`)."
        )

    # Validate device counts before compute_device_allocation / ray.init so a
    # missing value fails fast with a clear message instead of after (or
    # inside) device allocation and cluster startup.
    _require_num_devices(cfg)

    # Compute device allocation
    device_config = compute_device_allocation(cfg)
    _init_ray_if_needed(cfg, device_config)

    # Convert OmegaConf to regular dict for Ray configs
    replay_buffer_config = dict(cfg.ray.replay_buffer_config)
    collector_config = dict(cfg.ray.collector_config)
    train_handler_config = dict(cfg.ray.train_handler_config)

    inference_policy = get_inference_model(
        cfg, devices=device_config["inference_model_devices"]
    )
    torchrl_logger.info(f"Inference policy: {inference_policy}")

    # Resolved once; used both for buffer sizing and by the collector.
    dialog_turns_per_batch = _resolve_dialog_turns_per_batch(cfg)

    torchrl_logger.info(f"Starting replay buffer with {replay_buffer_config=}")
    rb_size = _compute_replay_buffer_size(cfg, dialog_turns_per_batch)
    rb = RayReplayBuffer(
        storage=partial(
            LazyStackStorage,
            rb_size,
            device="cpu",
        ),
        sampler=SamplerWithoutReplacement,
        transform_factory=partial(
            TopKRewardSelector,
            total_dialog_turns=cfg.env.repeats,
            topk_size=cfg.train.topk_size,
        ),
        batch_size=cfg.train.optim_batch_size,
        remote_config=replay_buffer_config,
    )
    torchrl_logger.info(f"Replay buffer: {rb}")

    # Create remote collector using RayLLMCollector.
    # The ref model is instantiated inside the collector (make_env below is
    # given ref_model_devices), so the collector actor gets the ref model's
    # device count. The inference policy is built separately above via
    # get_inference_model.
    collector_config["num_gpus"] = cfg.ref_model.num_devices
    torchrl_logger.info(f"Starting collector with {collector_config=}")

    collector = RayLLMCollector(
        env=partial(make_env, cfg, devices=device_config["ref_model_devices"]),
        policy=inference_policy,
        dialog_turns_per_batch=dialog_turns_per_batch,
        total_dialog_turns=cfg.train.total_dialog_turns,
        replay_buffer=rb,
        ray_init_config=None,  # Ray is already initialized
        weight_updater=None,  # We'll create this after getting the remote LLM
        track_policy_version=True,
        remote_config=collector_config,
        sync_iter=cfg.train.sync_iter,
        verbose=True,
    )
    # Ensure collector is initialized by calling a method that will block until ready
    ray.get(collector._collector.is_initialized.remote())
    torchrl_logger.info(f"Collector: {collector}")

    # Only num_cpus is taken from the configured options; num_gpus always
    # comes from train_model.num_devices.
    train_handler_config = {
        "num_cpus": train_handler_config.get("num_cpus", 1),
        "num_gpus": cfg.train_model.num_devices,
    }
    torchrl_logger.info(f"Starting training handler with {train_handler_config=}")
    train_handler = ray.remote(
        **train_handler_config,
    )(train)

    # launch training (blocks until the remote train task returns)
    ray.get(
        train_handler.remote(rb, cfg, collector, device_config["train_model_devices"])
    )
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
if __name__ == "__main__":
    # Prepare the process environment first, then hand control to the
    # Hydra-decorated main() entry point.
    setup_environment()
    main()
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
torch==2.7.0
|
|
2
|
+
transformers==4.52.4
|
|
3
|
+
peft==0.15.2
|
|
4
|
+
bitsandbytes==0.46.0
|
|
5
|
+
datasets==3.6.0
|
|
6
|
+
wandb==0.19.11
|
|
7
|
+
hydra-core==1.3.2
|
|
8
|
+
ray==2.52.1
|
|
9
|
+
tqdm==4.67.1
|
|
10
|
+
tensordict==0.9.0
|
|
11
|
+
vllm==0.9.0.1
|
|
12
|
+
accelerate==1.7.0
|
|
13
|
+
xformers==0.0.30
|
|
14
|
+
nltk==3.9.1
|
|
15
|
+
langdetect==1.0.9
|
|
16
|
+
immutabledict==4.2.1
|