torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,352 @@
# Expert Iteration: Learning from Top-K Responses

This is an implementation of Expert Iteration for language models, built on top of TorchRL.
Expert Iteration is a reinforcement learning-like method that learns from the best-performing responses in a batch, rather than using all responses equally.

The idea of these scripts is extremely simple:
- Collect some trajectories with a pre-trained version of the model;
- Select the top-K best trajectories of the batch (based on their reward);
- Train the model with SFT on these trajectories;
- Update the inference model.

## Overview

The version of Expert Iteration presented here has the following features:

- **Top-K Selection**: Only the best-performing responses are used for training, improving sample efficiency
- **KL Regularization**: Maintains model quality by penalizing divergence from a reference model
- **Multi-GPU support** with efficient device management
- **Mixed precision training** for memory efficiency
- **Gradient accumulation** for larger effective batch sizes
- **Automatic checkpointing** and comprehensive logging with Weights & Biases
- **Hydra configuration system** for easy experimentation
- **Asynchronous training support** with Ray for improved throughput
- **Prioritized sampling** so that samples with higher rewards are more likely to be sampled

## Key Differences from GRPO and other RL algorithms

### 1. Top-K Reward Selection

Unlike other RL post-training recipes (e.g., GRPO), which use all responses,
Expert Iteration employs a `TopKRewardSelector` transform that:

- Collects multiple responses for each prompt (controlled by `env.repeats`)
- Selects only the top-k responses based on reward (controlled by `train.topk_size`)
- Writes only the best responses to the replay buffer, improving training efficiency

```python
# Example: For each prompt, generate 32 responses but only keep the best 4
env.repeats = 32  # Generate 32 responses per prompt
train.topk_size = 4  # Keep only the top 4 responses
```

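For intuition, the snippet below is a toy, self-contained rendition of the selection step for a single prompt; it is not the `TopKRewardSelector` implementation, and the rewards are random stand-ins:

```python
import torch

# Toy illustration of the selection step: generate `repeats` responses for one
# prompt, score them, and keep only the `topk_size` highest-reward ones before
# they are written to the replay buffer.
repeats, topk_size = 32, 4
rewards = torch.randn(repeats)              # stand-in for per-response rewards
top_rewards, top_indices = rewards.topk(topk_size)
print(top_indices.tolist())                 # the responses that would be kept
```
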
### 2. KL Divergence Handling

Expert Iteration uses a different approach to KL regularization:

- **No KL in reward**: Unlike GRPO's `KLRewardTransform`, Expert Iteration doesn't add KL penalties to the reward signal
- **KL in loss function**: KL divergence is computed directly in the loss function using `SFTLoss` with `kl_to_ref_coeff`
- **Reference log probabilities**: The `RetrieveLogProb` transform extracts reference model log probabilities for KL computation

```python
# KL is handled in the loss function, not in the reward
loss_fn = SFTLoss(
    actor_network=policy_training,
    kl_to_ref_coeff=cfg.train.kl_to_ref_coeff,  # KL penalty coefficient
    tokenizer=train_tokenizer,
    tokenizer_kwargs={"chat_template_name": "qwen"},
    device=train_device,
)
```

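Conceptually, the resulting objective adds a KL penalty to the supervised term on the top-k responses. The snippet below only mirrors that structure (it is not the `SFTLoss` internals), using random stand-in log-probabilities:

```python
import torch

# loss = SFT negative log-likelihood + kl_to_ref_coeff * KL(policy || reference),
# with the KL estimated from per-token log-probabilities of the generated tokens.
kl_to_ref_coeff = 1.0
logp_policy = torch.rand(16).log()  # stand-in: token log-probs under the trained policy
logp_ref = torch.rand(16).log()     # stand-in: token log-probs under the reference model
sft_loss = -logp_policy.mean()                  # supervised fine-tuning term
kl_estimate = (logp_policy - logp_ref).mean()   # simple per-token KL estimate
loss = sft_loss + kl_to_ref_coeff * kl_estimate
```
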
### 3. Reduced Weight Updates

Expert Iteration can afford fewer policy weight updates because of its selective training approach. One can freely choose longer weight-update intervals via `update_weight_frequency` (e.g., every 100 or more optimization steps).

## Installation

Install dependencies:
```bash
# GSM8K deps
pip install -r sota-implementations/expert-iteration/requirements_gsm8k.txt
# IFEval deps
pip install -r sota-implementations/expert-iteration/requirements_ifeval.txt
```

## Hardware Requirements

- At least 3 CUDA-capable GPUs:
  - Training device(s)
  - vLLM inference device
  - Reference model device

### Device Management

The number of devices for each model component is specified using `num_devices`:

```bash
train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
```

This approach:

- Automatically handles device allocation;
- Works correctly in both sync and async modes;
- Prevents device conflicts between model components;
- Is more portable across different machine configurations.

## Configuration

The training configuration is managed through Hydra. There are two main configuration files:

- `config/ei_gsm8k.yaml`: Default configuration, for GSM8K tasks
- `config/ei_ifeval.yaml`: Configuration optimized for IFEval tasks

## Usage

### Basic Training

There are two training modes available:

#### Synchronous Mode (Default)

```bash
python sota-implementations/expert-iteration/expert-iteration-sync.py mode=sync train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
```

#### Asynchronous Mode (Recommended)

```bash
python sota-implementations/expert-iteration/expert-iteration-async.py mode=async train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
```

The key difference between sync and async modes is how data collection and optimization are handled:

**Synchronous Mode (expert-iteration-sync.py)**:
```python
# Three nested loops:
for data in collector:  # Data collection loop
    for epoch in range(epochs):  # Epoch loop
        for batch in replay_buffer:  # Buffer consumption loop
            # Optimize on batch (only top-k responses)
            loss = loss_fn(batch)
            loss.backward()
            optimizer.step()
    # Weight update
    weight_updater.push_weights(policy_training)
```

**Asynchronous Mode (expert-iteration-async.py)**:
```python
# Start data collection in background
collector.start()

# Single optimization loop
for step in range(total_steps):
    # Sample and optimize (only top-k responses)
    batch = replay_buffer.sample()
    loss = loss_fn(batch)
    loss.backward()
    optimizer.step()
    # Update weights once in a while
    if step % weight_update_frequency == 0:
        weight_updater.push_weights(policy_training)
```

Key differences:

1. **Data Collection**:
   - Sync: Data collection and optimization happen sequentially (unless `train.sync_iter=false`)
   - Async: Data collection runs in the background while optimization happens

2. **Buffer Size**:
   - Sync: The buffer size must equal the batch size returned by the collector
   - Async: The buffer can be larger than the batch size, allowing for more diverse sampling

3. **Data Processing**:
   - Sync: The same data is processed multiple times (epochs)
   - Async: Each piece of data is processed a non-deterministic number of times

4. **Weight Updates**:
   - Sync: Weights are updated before every data collection
   - Async: Weights are updated at a given interval (measured in gradient steps)

Async mode offers better performance through:

- Concurrent data collection and optimization
- More efficient GPU utilization
- Reduced memory overhead
- Better throughput
- More flexible buffer management

### Top-K Configuration

The key parameters for top-k selection are:

```yaml
env:
  repeats: 32  # Number of responses to generate per prompt
train:
  topk_size: 4  # Number of best responses to keep for training
```

**Recommendations**:

- Higher `repeats` values provide more diversity but increase computation
- `topk_size` should be 10-20% of `repeats` for good selection pressure
- Typical values: `repeats=32, topk_size=4` or `repeats=64, topk_size=8`

It is critical to have a reward function that is granular enough for `top-k` to be of any use: a binary reward, where most responses share the same value, will not provide much insight into which outputs outrank the others.

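A toy illustration of that point, with made-up reward values:

```python
import torch

# With a binary reward, most entries tie, so "top-k" barely discriminates;
# a graded reward yields a meaningful ranking of responses.
binary = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0])
graded = torch.tensor([0.9, 0.7, 0.65, 0.3, 0.2, 0.1, 0.05, 0.0])
print(binary.topk(4).indices.tolist())  # all four 1.0 entries tie: top-k cannot rank them
print(graded.topk(4).indices.tolist())  # [0, 1, 2, 3]: a meaningful ranking of the best responses
```
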
### KL Regularization

KL divergence is controlled via the `kl_to_ref_coeff` parameter:

```yaml
train:
  kl_to_ref_coeff: 1.0  # KL penalty coefficient
```

**Recommendations**:

- Start with `kl_to_ref_coeff=1.0` and adjust based on model quality.
- Higher values keep the model closer to the reference.
- Lower values allow more exploration but risk quality degradation.

**Note**: Expert Iteration is a rather simple algorithm with few convergence guarantees. Starting with a high KL regularization coefficient and progressively lowering it is advisable.

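A minimal sketch of such an annealing schedule is shown below; the start/end values and step count are illustrative, and the training scripts do not necessarily expose this out of the box:

```python
def kl_coeff(step: int, start: float = 1.0, end: float = 0.1, total_steps: int = 1000) -> float:
    """Linearly anneal the KL-to-reference coefficient from `start` to `end`."""
    frac = min(step / total_steps, 1.0)
    return start + frac * (end - start)

print(kl_coeff(0), kl_coeff(500), kl_coeff(1000))  # 1.0 0.55 0.1
```
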
### Run with IFEval Config

```bash
python expert-iteration-sync.py mode=sync --config-name ei_ifeval
```

### Override Config Values

```bash
# Change dataset
python expert-iteration-sync.py mode=sync env.dataset=ifeval

# Modify top-k parameters
python expert-iteration-sync.py mode=sync env.repeats=64 train.topk_size=8

# Adjust KL regularization
python expert-iteration-sync.py mode=sync train.kl_to_ref_coeff=0.5

# Change model
python expert-iteration-sync.py mode=sync model.name=meta-llama/Llama-2-7b-hf
```

### Hyperparameter Sweeps

```bash
# Top-k size sweep
python expert-iteration-sync.py mode=sync --multirun train.topk_size=2,4,8

# KL coefficient sweep
python expert-iteration-sync.py mode=sync --multirun train.kl_to_ref_coeff=0.5,1.0,2.0

# Multiple parameters
python expert-iteration-sync.py mode=sync --multirun \
    train.topk_size=4,8 \
    train.kl_to_ref_coeff=0.5,1.0
```

Don't forget to set `train.total_dialog_turns` to a reasonable value!

## Monitoring

Training progress is logged to Weights & Biases with the following metrics:

- **Reward**: Average reward of responses in the buffer
- **Sequence length**: Average length of generated responses
- **KL divergence**: KL divergence from the reference model
- **Loss metrics**: SFT loss, KL loss, and total loss
- **Gradient norm**: Gradient clipping statistics
- **Throughput metrics**: Steps per second, gradient steps per write
- **Buffer statistics**: Write count, policy version tracking

### Collector Logging

The collector is given a `RemoteDataLogger` postproc hook that passes the data to a Ray queue, which the training node consumes for logging.

This approach ensures:

- A single wandb run with all metrics (training + collector)
- No conflicts between multiple wandb loggers
- Centralized logging through the main process

The collector logs the following metrics:

- **Collector rewards**: Mean, std, min, max of rewards from collected data
- **Response lengths**: Mean, std, min, max of response lengths
- **Policy versions**: Mean, min, max of policy versions (for async mode)
- **Time elapsed**: Time between collection batches

To add new collector metrics, modify the `log_data` method of `RemoteDataLogger` in `ei_utils.py`.

## Checkpointing

Checkpoints are saved every `train.checkpoint_frequency` steps and contain:

- Model state
- Optimizer state
- Gradient scaler state (for mixed precision)
- Full configuration

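As a quick sanity check, a checkpoint file can be inspected with `torch.load`; the path below is illustrative and the exact key names are an assumption, not guaranteed by the training scripts:

```python
import torch

# Inspect a saved checkpoint (illustrative path; assumes the checkpoint is a plain dict).
ckpt = torch.load(
    "outputs/2024-01-01/12-00-00/checkpoints/checkpoint_1000.pt", map_location="cpu"
)
print(sorted(ckpt.keys()))  # expect entries for the model, optimizer, scaler and config
```
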
## Debugging Out-of-Memory Issues

- **vLLM**: Reduce `inference_model.gpu_memory_utilization=FRACTION` or the number of environments run in parallel (`env.num_envs=N`)
- **Reference model**: If the reference model computation is memory-intensive, reduce the number of environments run in parallel (`env.num_envs=N`)
- **Training**: Reduce the batch size (`train.optim_batch_size`)
- **Top-k**: Reduce `env.repeats` to generate fewer responses per prompt

## Directory Structure

```
sota-implementations/expert-iteration/
├── config/
│   ├── ei_gsm8k.yaml        # Main configuration file
│   ├── ei_ifeval.yaml       # Configuration for IFEval task
│   └── mode/
│       ├── async.yaml       # Async mode settings
│       └── sync.yaml        # Sync mode settings
├── expert-iteration-sync.py   # Synchronous training script
├── expert-iteration-async.py  # Asynchronous training script
├── ei_utils.py                # Utility functions
└── README.md                  # This file
```

## Output Structure

Each run creates a timestamped directory under `outputs/`:

```
outputs/
└── YYYY-MM-DD/
    └── HH-MM-SS/
        ├── checkpoints/
        │   └── checkpoint_*.pt
        └── .hydra/
            └── config.yaml
```

For hyperparameter sweeps, outputs are stored under `multirun/`.

## Theoretical Background

Expert Iteration is based on the principle of learning from the best examples rather than all examples. The key insights are:

1. **Selective Learning**: By only training on high-quality responses, the model learns more efficiently
2. **Quality over Quantity**: A smaller dataset of high-quality examples can be more effective than a larger dataset of mixed quality
3. **Iterative Improvement**: Each iteration produces better responses, which become the training data for the next iteration

This approach is particularly effective for language model training, where:

- Response quality varies significantly
- High-quality responses are rare but valuable
- The model can learn to imitate good responses more effectively than to avoid bad ones

In theory, one could use Expert Iteration with samples gathered from other LLMs or from expert datasets, although convergence would be harder to control because the KL regularization term cannot be used.

## Comparison with Other Methods

| Method | Training Data | KL Handling | Update Frequency |
|--------|---------------|-------------|------------------|
| **Expert Iteration** | Top-k responses | In loss function | Reduced (can be less frequent) |
| **GRPO** | All responses | In reward / loss | Standard |
| **DPO** | Preference pairs | Implicit in loss | Standard |

Expert Iteration's key advantage is its sample efficiency: by focusing on the best responses, it can achieve better performance with fewer training examples and less frequent policy updates.