torchrl 0.11.0__cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,512 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import gc
|
|
9
|
+
import time
|
|
10
|
+
from functools import partial
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
import hydra
|
|
14
|
+
|
|
15
|
+
from torchrl import merge_ray_runtime_env, torchrl_logger
|
|
16
|
+
from torchrl.data.llm.history import History
|
|
17
|
+
from torchrl.record.loggers.wandb import WandbLogger
|
|
18
|
+
from torchrl.weight_update.llm import get_model_metadata
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
import ray
|
|
22
|
+
except ImportError:
|
|
23
|
+
raise ImportError(
|
|
24
|
+
"Ray is required for async training. Please install ray with `pip install ray`."
|
|
25
|
+
)
|
|
26
|
+
import torch
|
|
27
|
+
import tqdm
|
|
28
|
+
|
|
29
|
+
from ei_utils import (
|
|
30
|
+
compute_device_allocation,
|
|
31
|
+
create_cosine_scheduler_with_warmup,
|
|
32
|
+
get_inference_model,
|
|
33
|
+
get_train_model,
|
|
34
|
+
log_training_metrics,
|
|
35
|
+
make_env,
|
|
36
|
+
make_weight_sync_scheme,
|
|
37
|
+
RemoteDataLogger,
|
|
38
|
+
)
|
|
39
|
+
from omegaconf import DictConfig
|
|
40
|
+
from ray.util.queue import Queue
|
|
41
|
+
|
|
42
|
+
try:
|
|
43
|
+
from tensordict import set_list_to_stack
|
|
44
|
+
except ImportError:
|
|
45
|
+
raise ImportError(
|
|
46
|
+
"TensorDict is required. Please install it with `pip install tensordict`."
|
|
47
|
+
)
|
|
48
|
+
from torch.amp.autocast_mode import autocast
|
|
49
|
+
from torch.amp.grad_scaler import GradScaler
|
|
50
|
+
from torchrl._utils import timeit
|
|
51
|
+
from torchrl.collectors.llm import RayLLMCollector
|
|
52
|
+
from torchrl.data import (
|
|
53
|
+
LazyStackStorage,
|
|
54
|
+
PrioritizedSampler,
|
|
55
|
+
ReplayBuffer,
|
|
56
|
+
TensorDictReplayBuffer,
|
|
57
|
+
)
|
|
58
|
+
from torchrl.data.llm.topk import TopKRewardSelector
|
|
59
|
+
from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
|
|
60
|
+
from torchrl.objectives.llm.sft import SFTLoss
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def setup_environment() -> None:
    """Setup required environment variables and configurations.

    Sets float32 as the default floating-point dtype, makes ``cuda:0`` both the
    default device for newly created tensors and the current CUDA device, and
    enables tensordict's list-to-stack behavior via
    ``set_list_to_stack(True).set()``.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for training")

    # Keep float32 as the default dtype; reduced precision is applied
    # selectively through ``autocast`` in the training loop.
    torch.set_default_dtype(torch.float32)
    torch.set_default_device("cuda:0")
    set_list_to_stack(True).set()

    # Make cuda:0 the current CUDA device. CUDA availability was already
    # checked above, so no second guard is needed here.
    torch.cuda.set_device("cuda:0")
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def train(
    replay_buffer: ReplayBuffer,
    cfg: DictConfig,
    collector: RayLLMCollector,
    devices: list[int] | None = None,
):
    """Main training loop for EI async.

    Data collection (performed by ``collector``, which writes into
    ``replay_buffer``) and optimization happen concurrently. The loop runs
    ``ceil(total_dialog_turns / optim_batch_size) * epochs`` iterations, each
    sampling one batch of ``optim_batch_size`` items. An optimizer step (and a
    scheduler step, if enabled) is taken every ``gradient_accumulation_steps``
    iterations.

    Args:
        replay_buffer: The replay buffer to store experiences. Training waits
            until it has received at least one write before sampling.
        cfg: The configuration object containing training parameters.
        collector: The collector object. It is started here and shut down when
            training finishes or raises.
        devices: The devices to use for the training model. The first entry is
            used for batch processing and the loss; if ``None`` or empty,
            batches are not moved and the loss gets no explicit device.

    Raises:
        RuntimeError: If the vLLM engine cannot be obtained from the
            collector's policy.
    """
    # Setup training model and tokenizer
    policy_training, train_tokenizer = get_train_model(
        cfg, devices=devices, chat_template_name="qwen"
    )
    # Use first device for batch processing. ``devices`` defaults to None, so
    # guard the indexing instead of failing with a TypeError.
    train_device = devices[0] if devices else None

    # Setup loss function
    loss_fn = SFTLoss(
        actor_network=policy_training,
        kl_to_ref_coeff=cfg.train.kl_to_ref_coeff,
        tokenizer=train_tokenizer,
        tokenizer_kwargs={"chat_template_name": "qwen"},
        device=torch.device(f"cuda:{train_device}")
        if train_device is not None
        else None,
        loss_function=cfg.train.loss_function,
        beta=cfg.train.minor_sft_beta,
    )
    if cfg.model.compile:
        loss_fn = torch.compile(loss_fn)

    # Create weight sync scheme (explicit scheme creation, as in GRPO)
    weight_sync_scheme = make_weight_sync_scheme(
        master_address="localhost",  # Since we're running locally
        master_port=None,  # Will auto-assign an open port
        vllm_tp_size=cfg.inference_model.num_devices
        if cfg.inference_model.num_devices is not None
        else len(cfg.inference_model.get("devices", [1])),
    )

    # Set up weight sender
    torchrl_logger.info("Setting up weight synchronization scheme...")
    sender = weight_sync_scheme.create_sender()
    sender.register_model(policy_training)

    # The collector's policy wraps the vLLM engine that receives the weights.
    vllm_engine = collector.policy.model if hasattr(collector, "policy") else None
    if vllm_engine is None:
        raise RuntimeError("Could not get vLLM engine from collector policy")

    # Initialize collective group
    torchrl_logger.info("Initializing collective group...")
    metadata = get_model_metadata(policy_training)
    sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)

    # First weight update
    with timeit("update_policy_weights"):
        sender.update_weights()
    timeit.print(prefix="First update_policy_weights_ time")
    timeit.reset()

    # Make optimizer
    torchrl_logger.info("Starting optimizer.")
    optimizer = torch.optim.Adam(
        policy_training.parameters(),
        lr=cfg.optimizer.lr,
        weight_decay=cfg.optimizer.weight_decay,
        fused=False,
    )
    # Loss scaling is only applied for float16 mixed precision. The scaler is
    # built ONCE so its dynamic scale and inf-backoff state carry over between
    # steps (re-creating it every step would reset the scale each time and
    # make ``scaler.update()`` meaningless).
    use_grad_scaler = bool(
        cfg.train.mixed_precision and cfg.train_model.torch_dtype == "float16"
    )
    scaler = GradScaler(enabled=use_grad_scaler)

    # Number of loop iterations: ceil(total_dialog_turns / optim_batch_size)
    # batches per epoch (ceil-division via negated floor-division).
    total_steps = (
        -(cfg.train.total_dialog_turns // -cfg.train.optim_batch_size)
        * cfg.train.epochs
    )
    # The optimizer (and scheduler) steps once every gradient_accumulation_steps
    # iterations, so this bounds the number of scheduler steps. Clamp to 1 so
    # the schedule length is never zero.
    total_optim_steps = max(
        1, total_steps // cfg.train.gradient_accumulation_steps
    )

    # Create scheduler if enabled
    scheduler = None
    if cfg.optimizer.scheduler.enabled:
        warmup_steps = cfg.optimizer.scheduler.warmup_steps
        num_cycles = cfg.optimizer.scheduler.num_cycles
        torchrl_logger.info(
            f"Creating {cfg.optimizer.scheduler.type} scheduler with {warmup_steps} warmup steps out of {total_optim_steps} total steps"
        )

        scheduler = create_cosine_scheduler_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_optim_steps,
            num_cycles=num_cycles,
        )

    # Make checkpoint dir
    checkpoint_dir = Path(cfg.logging.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Make wandb logger
    torchrl_logger.info("Starting wandb logger.")
    experiment_name = cfg.logging.experiment_name
    if experiment_name is not None:
        experiment_name = [experiment_name]
    else:
        experiment_name = []

    experiment_name.append(cfg.env.dataset)
    experiment_name.append(cfg.model.name)

    # Create local wandb logger for training metrics
    wandb_config = {
        "project": "ei-async",
        "exp_name": "-".join(["ei-async"] + experiment_name),
    }
    wandb_logger = WandbLogger(**wandb_config)

    # Remote collector-side metrics are pushed into this queue and drained
    # into wandb on logging steps.
    log_queue = Queue(maxsize=1000)
    collector.set_postproc(RemoteDataLogger(log_queue=log_queue))

    # Start collector
    collector.start()

    pbar = None
    try:
        # Wait for initial data
        while not replay_buffer.write_count:
            time.sleep(1)

        torchrl_logger.info(f"Total steps: {total_steps}")

        pbar = tqdm.tqdm(total=total_steps)
        grad_norm = 0.0  # Initialize grad_norm
        data_read_count = 0
        optim_step = 0
        start_time = time.time()

        for step in range(total_steps):
            pbar.update(1)
            pbar.set_description(f"Step {step}, writes: {replay_buffer.write_count}")

            with timeit("sampling"):
                # Sample batch and move to device
                batch = replay_buffer.sample(cfg.train.optim_batch_size)
                if train_device is not None:
                    batch = batch.to(train_device)

            # Staleness of the sampled data = current policy version minus the
            # version that generated it (only the first sample is inspected).
            # The data was produced by an older-or-equal policy, so the age is
            # non-negative; subtracting the other way round always gave <= 0
            # and the max_policy_age filter never fired.
            max_policy_age = (
                collector.policy_version
                - batch.view(-1)[0]["next", "policy_version"]
            ).max()
            if (
                cfg.train.max_policy_age is not None
                and max_policy_age > cfg.train.max_policy_age
            ):
                # Skip this batch, as it's too old
                torchrl_logger.info(f"Skipping batch with policy age {max_policy_age}")
                continue

            # For logging purposes, we get the first element's prompt history
            # and convert it to a single string
            history: History = batch.view(-1)[0]["next", "history", "prompt"]
            history_str: list[str] | str = history.apply_chat_template(
                tokenizer=train_tokenizer
            )
            while not isinstance(history_str, str):
                history_str = "\n".join(history_str)

            data_read_count += batch.numel()

            with timeit("forward_pass"):
                # Forward pass with mixed precision
                with autocast("cuda", enabled=cfg.train.mixed_precision):
                    loss = loss_fn(batch)
                    if loss.loss_kl_to_ref is not None:
                        loss_val = loss.loss_sft + loss.loss_kl_to_ref
                    else:
                        loss_val = loss.loss_sft
                    # Average over the accumulation window
                    loss_val = loss_val / cfg.train.gradient_accumulation_steps

            with timeit("backward_pass"):
                # Backward pass (scaled only for float16 mixed precision)
                if use_grad_scaler:
                    scaler.scale(loss_val).backward()
                else:
                    loss_val.backward()

            # Optimization step
            if ((step + 1) % cfg.train.gradient_accumulation_steps) == 0:
                with timeit("optim_step"):
                    if use_grad_scaler:
                        # Unscale before clipping so the norm is computed on
                        # the true gradients.
                        scaler.unscale_(optimizer)

                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        policy_training.parameters(),
                        cfg.optimizer.clip_grad_norm,
                    )

                    if use_grad_scaler:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    optimizer.zero_grad(set_to_none=True)

                    # Step the scheduler
                    if scheduler is not None:
                        scheduler.step()

                    # Increment optimization step counter
                    optim_step += 1

            # Update metrics
            if (step % cfg.train.logging_frequency) == 0:
                log_training_metrics(
                    wandb_logger=wandb_logger,
                    replay_buffer=replay_buffer,
                    batch=batch,
                    loss=loss,
                    grad_norm=grad_norm,
                    global_step=step,
                    data_read_count=data_read_count,
                    collector=collector,
                    start_time=start_time,
                    gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
                    history_str=history_str,
                )
                # Log additional metrics
                wandb_logger.log_scalar(
                    "learning_rate", float(optimizer.param_groups[0]["lr"]), step=step
                )
                wandb_logger.log_scalar("optim_step", optim_step, step=step)
                # Drain metrics pushed by the remote collector
                while not log_queue.empty():
                    logs = log_queue.get()
                    for k, v in logs.items():
                        wandb_logger.log_scalar(k, v)

            # Update policy weights
            if step % cfg.train.weight_update_frequency == 0:
                with timeit("update_policy_weights"):
                    torchrl_logger.info("Updating policy weights...")
                    sender.update_weights()
                    # TODO: do we need this? Does it interfere with other processes?
                    # torch.cuda.empty_cache()
                    gc.collect()

            # Checkpointing disabled to prevent disk space issues
            # if (step + 1) % cfg.train.checkpoint_frequency == 0:
            #     with timeit("save_checkpoint"):
            #         torchrl_logger.info(
            #             f"Saving checkpoint {(step+1) // cfg.train.checkpoint_frequency}..."
            #         )
            #         checkpoint = {
            #             "step": step,
            #             "model_state_dict": policy_training.model.state_dict(),
            #             "optimizer_state_dict": optimizer.state_dict(),
            #             "scaler_state_dict": scaler.state_dict(),
            #             "config": dict(cfg),
            #         }
            #         torch.save(checkpoint, checkpoint_dir / f"checkpoint_{step:04d}.pt")

            if step % cfg.train.weight_update_frequency == 0:
                timeit.print(prefix="timeit")
                for key, val in timeit.todict().items():
                    wandb_logger.log_scalar(f"timeit/{key}", val)
                timeit.reset()

            # Clear memory
            del loss_val
            # TODO: do we need this? Does it interfere with other processes?
            # torch.cuda.empty_cache()
            gc.collect()
    finally:
        # Always release the progress bar and the collector's workers, even
        # when training raises.
        if pbar is not None:
            pbar.close()
        collector.shutdown()
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def _check_num_devices_set(cfg) -> None:
    """Raise if the inference, ref or train model has no ``num_devices`` configured.

    Args:
        cfg: Hydra config; reads ``cfg.<model>.num_devices`` for ``inference_model``,
            ``ref_model`` and ``train_model`` (checked in that order).

    Raises:
        ValueError: if ``num_devices`` is ``None`` for any of the three models.
    """
    for key, label in (
        ("inference_model", "Inference model"),
        ("ref_model", "Ref model"),
        ("train_model", "Train model"),
    ):
        if getattr(cfg, key).num_devices is None:
            raise ValueError(f"{label} num_devices must be set via {key}.num_devices")


def _init_ray_cluster(cfg, device_config) -> None:
    """Call ``ray.init`` with settings from ``cfg.ray.init_config`` unless Ray is already up.

    Keys of ``cfg.ray.init_config`` starting with ``_`` are dropped. ``num_gpus`` is set
    from ``device_config["ray_num_gpus"]``. The result then goes through
    ``merge_ray_runtime_env`` (which the original code describes as merging with the
    default ``runtime_env``).

    Args:
        cfg: Hydra config holding ``cfg.ray.init_config``.
        device_config: dict returned by ``compute_device_allocation``; only
            ``"ray_num_gpus"`` is read here.
    """
    if ray.is_initialized():
        return
    # Convert OmegaConf to regular dict and filter out unsupported parameters.
    # NOTE(review): the conversion is shallow -- only top-level DictConfig values
    # become dicts; nested config nodes inside them are passed through as-is.
    ray_init_config = {
        k: dict(v) if isinstance(v, DictConfig) else v
        for k, v in dict(cfg.ray.init_config).items()
        if not k.startswith("_")
    }
    # Add computed GPU configuration and merge with default runtime_env
    ray_init_config["num_gpus"] = device_config["ray_num_gpus"]
    ray_init_config = merge_ray_runtime_env(ray_init_config)
    torchrl_logger.info(f"Ray init config: {ray_init_config=}")
    ray.init(**ray_init_config)


def _make_async_replay_buffer(cfg, replay_buffer_config):
    """Build the ``RayReplayBuffer`` shared by the collector and the training task.

    Storage is a CPU ``LazyStackStorage`` of size ``cfg.train.buffer_size``
    (256 when unset). ``TopKRewardSelector`` is installed as the transform.
    When ``cfg.train.prioritized_sampling`` is true, a ``TensorDictReplayBuffer``
    with a ``PrioritizedSampler`` keyed on ``("next", "reward")`` is used.
    Otherwise a plain ``ReplayBuffer`` with the default sampler is used.

    Args:
        cfg: Hydra config (``train.*`` and ``env.repeats`` are read).
        replay_buffer_config: plain-dict Ray options forwarded as ``remote_config``.

    Returns:
        The constructed ``RayReplayBuffer``.
    """
    rb_size = cfg.train.buffer_size
    if rb_size is None:
        # Fallback capacity when train.buffer_size is not set.
        rb_size = 256
    if cfg.train.prioritized_sampling:
        rb_cls = TensorDictReplayBuffer
        rb_sampler_cls = partial(
            PrioritizedSampler,
            max_capacity=rb_size,
            alpha=cfg.train.prioritized_sampling_alpha,
            beta=cfg.train.prioritized_sampling_beta,
            eps=cfg.train.prioritized_sampling_epsilon,
        )
        # Priorities are read from the ("next", "reward") entry of stored samples.
        rb_kwargs = {"priority_key": ("next", "reward")}
    else:
        rb_cls = ReplayBuffer
        rb_sampler_cls = None
        rb_kwargs = {}
    return RayReplayBuffer(
        storage=partial(
            LazyStackStorage,
            rb_size,
            device="cpu",
        ),
        transform_factory=partial(
            TopKRewardSelector,
            total_dialog_turns=cfg.env.repeats,
            topk_size=cfg.train.topk_size,
        ),
        batch_size=cfg.train.optim_batch_size,
        remote_config=replay_buffer_config,
        replay_buffer_cls=rb_cls,
        sampler=rb_sampler_cls,
        **rb_kwargs,
    )


def _make_async_collector(cfg, inference_policy, rb, collector_config, device_config):
    """Create the remote ``RayLLMCollector`` and block until it reports it is initialized.

    Args:
        cfg: Hydra config (``ref_model.num_devices`` and ``train.*`` are read).
        inference_policy: policy returned by ``get_inference_model``.
        rb: replay buffer the collector is given as ``replay_buffer``.
        collector_config: plain-dict Ray options; its ``"num_gpus"`` entry is
            overwritten here before being passed as ``remote_config``.
        device_config: dict from ``compute_device_allocation``; only
            ``"ref_model_devices"`` is read here.

    Returns:
        The ``RayLLMCollector`` instance.
    """
    # The ref model is instantiated inside the collector (make_env below receives the
    # ref-model devices), so the collector reserves the ref model's GPUs. The inference
    # model's devices are passed separately to get_inference_model.
    collector_config["num_gpus"] = cfg.ref_model.num_devices
    torchrl_logger.info(f"Starting collector with {collector_config=}")

    dialog_turns_per_batch = cfg.train.dialog_turns_per_batch
    if dialog_turns_per_batch is None:
        # Fall back to cfg.env.repeats when train.dialog_turns_per_batch is not set.
        dialog_turns_per_batch = cfg.env.repeats

    collector = RayLLMCollector(
        env=partial(make_env, cfg, devices=device_config["ref_model_devices"]),
        policy=inference_policy,
        dialog_turns_per_batch=dialog_turns_per_batch,
        total_dialog_turns=cfg.train.total_dialog_turns,
        replay_buffer=rb,
        ray_init_config=None,  # Ray is already initialized
        weight_updater=None,  # We'll create this after getting the remote LLM
        track_policy_version=True,
        remote_config=collector_config,
        verbose=True,
    )
    # Ensure collector is initialized by calling a method that will block until ready.
    # NOTE(review): this relies on the private `_collector` attribute of RayLLMCollector.
    ray.get(collector._collector.is_initialized.remote())
    return collector


@hydra.main(version_base=None, config_path="config", config_name="ei_gsm8k")
def main(cfg):
    """Launch asynchronous expert iteration.

    Validates the config and starts Ray if it is not running. Builds the inference
    policy, the Ray replay buffer and the remote collector. Runs ``train`` as a Ray
    remote task and blocks until it returns.

    Raises:
        ValueError: if ``cfg.train.sync`` is set, or if any model's ``num_devices``
            is missing.
    """
    # Force async mode
    if cfg.train.sync:
        raise ValueError(
            "expert-iteration-async.py must run in async mode (`python expert-iteration-async.py mode=async`). Please use expert-iteration-sync.py for sync mode (`python expert-iteration-sync.py mode=sync`)."
        )

    # Compute device allocation
    device_config = compute_device_allocation(cfg)

    # Validate before starting Ray so a missing setting fails fast instead of
    # surfacing only after the Ray cluster has been brought up.
    _check_num_devices_set(cfg)

    _init_ray_cluster(cfg, device_config)

    # Convert OmegaConf to regular dict for Ray configs
    replay_buffer_config = dict(cfg.ray.replay_buffer_config)
    collector_config = dict(cfg.ray.collector_config)
    train_handler_cfg = dict(cfg.ray.train_handler_config)

    inference_policy = get_inference_model(
        cfg, devices=device_config["inference_model_devices"]
    )
    torchrl_logger.info(f"Inference policy: {inference_policy}")

    torchrl_logger.info(f"Starting replay buffer with {replay_buffer_config=}")
    rb = _make_async_replay_buffer(cfg, replay_buffer_config)
    torchrl_logger.info(f"Replay buffer: {rb}")

    # Create remote collector using RayLLMCollector
    collector = _make_async_collector(
        cfg, inference_policy, rb, collector_config, device_config
    )
    torchrl_logger.info(f"Collector: {collector}")

    # Only num_cpus is taken from cfg.ray.train_handler_config; the GPU count
    # always comes from train_model.num_devices.
    train_handler_config = {
        "num_cpus": train_handler_cfg.get("num_cpus", 1),
        "num_gpus": cfg.train_model.num_devices,
    }
    torchrl_logger.info(f"Starting training handler with {train_handler_config=}")
    train_handler = ray.remote(
        **train_handler_config,
    )(train)

    # Launch training and block until the remote task returns.
    ray.get(
        train_handler.remote(rb, cfg, collector, device_config["train_model_devices"])
    )
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
if __name__ == "__main__":
    # One-time process setup has to happen before the Hydra entry point is invoked.
    setup_environment()
    # Hydra builds `cfg` from the config files / CLI, so no arguments are passed.
    main()
|