torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0

sota-implementations/grpo/README.md
@@ -0,0 +1,273 @@
# GRPO: Group Relative Policy Optimization

This is an implementation of GRPO for language models, built on top of TorchRL.

## Overview

GRPO is a method for training language models using reinforcement learning, with the following key features:
- Multi-GPU support with efficient device management
- Mixed precision training
- Gradient accumulation
- Automatic checkpointing
- Comprehensive logging with Weights & Biases
- Hydra configuration system
- Asynchronous training support with Ray

## Installation

Install dependencies:
```bash
# GSM8K deps
pip install -r sota-implementations/grpo/requirements_gsm8k.txt
# IFEval deps
pip install -r sota-implementations/grpo/requirements_ifeval.txt
```

## Hardware Requirements

- At least 3 CUDA-capable GPUs:
  - Training device(s)
  - vLLM inference device
  - Reference model device

### Device Management

The number of devices for each model component is specified using `num_devices`:

```bash
train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
```

This approach:
- Automatically handles device allocation
- Works correctly in both sync and async modes
- Prevents device conflicts between model components
- Is more portable across different machine configurations

## Configuration

The training configuration is managed through Hydra. There are two main configuration files:
- `config/grpo_gsm8k.yaml`: Configuration for GSM8K tasks (the default)
- `config/grpo_ifeval.yaml`: Configuration optimized for IFEval tasks

## Usage

### Basic Training

There are two training modes available:

#### Synchronous Mode (Default)
```bash
python sota-implementations/grpo/grpo-sync.py mode=sync train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
```

#### Asynchronous Mode (Recommended)
```bash
python sota-implementations/grpo/grpo-async.py mode=async train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
```

The key difference between sync and async modes is how data collection and optimization are handled:

**Synchronous Mode (grpo-sync.py)**:
```python
# Three nested loops:
for data in collector:              # Data collection loop
    for epoch in range(epochs):     # Epoch loop
        for batch in replay_buffer: # Buffer consumption loop
            # Optimize on batch
            loss = loss_fn(batch)
            loss.backward()
            optimizer.step()
    # Weight update
    weight_updater.push_weights(policy_training)
```

**Asynchronous Mode (grpo-async.py)**:
```python
# Start data collection in background
collector.start()

# Single optimization loop
for step in range(total_steps):
    # Sample and optimize
    batch = replay_buffer.sample()
    loss = loss_fn(batch)
    loss.backward()
    optimizer.step()
    # Update weights once in a while
    if cond():
        weight_updater.push_weights(policy_training)
```

Key differences:
1. **Data Collection**:
   - Sync: Data collection and optimization happen sequentially.

     *Note*: The `train.sync_iter=False` argument can be used to collect data while optimizing (see the example after this list). In that case, the maximum policy age is 1; with `train.sync_iter=True` (the default), the maximum policy age is `0`.

   - Async: Data collection runs in the background while optimization happens.

2. **Buffer Size**:
   - Sync: The buffer size must equal the batch size returned by the collector (`buffer_size = dialog_turns_per_batch`).
   - Async: The buffer can be larger than the batch size, allowing for more diverse sampling.

3. **Data Processing**:
   - Sync: Processes the same data multiple times (epochs).
   - Async: Each piece of data is processed a non-deterministic number of times.

4. **Weight updates**:
   - Sync: Weights are updated before every collection of data.
   - Async: Weights are updated at a given interval (in gradient steps). This requires synchronization between the training and inference processes, and frequent updates will cause both workers to wait for each other often.
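
For example, a sync run that overlaps collection with optimization (maximum policy age of 1) simply adds the `train.sync_iter=False` override to the usual sync command; the device counts below are just the values used in the examples above:

```bash
python sota-implementations/grpo/grpo-sync.py mode=sync train.sync_iter=False \
  train_model.num_devices=2 ref_model.num_devices=2 inference_model.num_devices=2
```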

The async mode offers better performance by:
- Running data collection and optimization concurrently
- More efficient GPU utilization
- Reduced memory overhead
- Better throughput
- More flexible buffer management

### Running GRPO on More Than One Node with SLURM

GRPO can be run across more than one node using SLURM, enabling distributed training for moderately scaled workloads.

Two scripts are provided for launching multi-node runs:

- `grpo-sync-multi-node.sbatch`: SLURM job script that launches sync GRPO across multiple nodes using Ray.
- `grpo-async-multi-node.sbatch`: SLURM job script that launches async GRPO across multiple nodes using Ray.

Example usage:

```bash
sbatch sota-implementations/grpo/grpo-sync-multi-node.sbatch
```

### KL Divergences in PPO: Reference vs Inference

KL divergence is a key regularization term in policy optimization algorithms like PPO and in LLM post-training. It measures how much the updated policy diverges from a baseline or reference policy, helping to prevent the new policy from drifting too far and ensuring stable learning.

There are two main types of KL divergences commonly used:

#### 1. KL to Reference Policy (KL[ref || policy])
- **Definition:** Measures how much the new (learned) policy diverges from a fixed reference policy (often the original, pre-trained model).
- **Implementation:** In GRPO, this is computed as `(ref_log_prob - cur_log_prob).expm1() - (ref_log_prob - cur_log_prob)`, a numerically stable estimator of the KL divergence built from log-probabilities (see the sketch after this section).
- **Usage:**
  - **LLM Post-Training:** This is the canonical choice in LLM post-training (e.g., RLHF, DPO, GRPO). The reference is usually the original language model before any RL fine-tuning. Penalizing KL[ref || policy] ensures the fine-tuned model stays close to the original, preserving language quality and preventing over-optimization.
  - **Effect:** Encourages the new policy not to deviate too much from the reference, maintaining fluency and generalization.

#### 2. KL to Inference Policy (KL[policy || inference])
- **Definition:** Measures how much the current policy diverges from the policy used to generate the data (the inference policy, sometimes called the behavior policy).
- **Implementation:** In GRPO, this is approximated as `prev_log_prob - cur_log_prob`, where `prev_log_prob` comes from the inference policy that generated the data (also shown in the sketch below).
- **Usage:**
  - **Canonical PPO:** In standard PPO (especially in RL for control), this is the canonical KL: KL[policy || inference]. The inference policy is the one that generated the trajectories in the replay buffer. Penalizing this KL ensures that the updated policy does not move too far from the data distribution, stabilizing importance sampling and learning.
  - **Effect:** Prevents the policy from making large, unstable updates relative to the data it was trained on.
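
As a concrete illustration of the two formulas quoted above, here is a minimal PyTorch sketch. The tensor names mirror the log-probability names used in this README; the `kl_penalties` helper is purely illustrative and not part of the TorchRL API:

```python
import torch

def kl_penalties(cur_log_prob, ref_log_prob, prev_log_prob):
    """Illustrative per-token KL estimates built from log-probabilities."""
    # KL[ref || policy], as quoted above: exp(d) - 1 - d with d = ref_log_prob - cur_log_prob.
    # expm1 keeps the computation stable for small d, and the result is always >= 0.
    d = ref_log_prob - cur_log_prob
    kl_to_ref = d.expm1() - d

    # KL[policy || inference], as quoted above: the plain log-ratio between the
    # behavior (inference) policy and the current policy.
    kl_to_inference = prev_log_prob - cur_log_prob
    return kl_to_ref, kl_to_inference

# Dummy log-probabilities for four tokens, just to show the call.
cur = torch.tensor([0.50, 0.40, 0.30, 0.20]).log()
ref = torch.tensor([0.40, 0.40, 0.35, 0.25]).log()
prev = torch.tensor([0.45, 0.42, 0.30, 0.22]).log()
print(kl_penalties(cur, ref, prev))
```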

#### Summary Table

| Setting            | Canonical KL Term       | Purpose                                    |
|--------------------|-------------------------|--------------------------------------------|
| PPO (RL control)   | KL[policy || inference] | Stabilize updates, match data distribution |
| LLM Post-Training  | KL[ref || policy]       | Stay close to pre-trained model            |

In GRPO, both types of KL can be used and controlled via configuration. Typically, for LLM post-training, the KL to reference is the most important for preserving model quality, while the KL to inference is more about stabilizing the optimization process.

The KL contributions to the loss are controlled via `train.kl_to_ref_coeff` and `train.kl_to_inference_coeff`, respectively.

Additionally, the KL-to-reference contribution can either be added to the reward during the grading of the LLM response or added directly to the loss, as controlled by the `train.kl_coef_in_loss` config option.
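
These options are regular Hydra settings, so they can be overridden from the command line like the other examples in this README; the coefficient values below are purely illustrative:

```bash
python grpo-sync.py mode=sync \
  train.kl_to_ref_coeff=0.1 \
  train.kl_to_inference_coeff=0.0 \
  train.kl_coef_in_loss=True
```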

In the original GRPO paper, the KL to reference (KL[ref || policy]) is added **directly to the loss function**, not to the reward. This means that the KL penalty acts as a regularizer during optimization, discouraging the policy from drifting too far from the reference model at every update step. This is in contrast to some RLHF-style approaches, where the KL penalty is added to the reward signal during data collection (i.e., the environment's reward is modified).

**Why does this matter?**
- **KL in the loss (as in GRPO):** The optimization explicitly balances the policy objective and the KL penalty at each gradient step, making the trade-off more direct and stable. This is the canonical approach in GRPO and is controlled by setting `train.kl_coef_in_loss=True` in the config.
- **KL in the reward:** The KL penalty is treated as part of the environment's reward, so the policy is trained to maximize this modified reward. This can sometimes make the effect of the KL less direct, as it is mixed with the task reward during data collection.

In summary, GRPO's approach of adding the KL to reference directly to the loss provides more explicit and stable regularization, and is the recommended setting for most LLM post-training scenarios.

### Run with IFEval Config

```bash
python grpo-sync.py mode=sync --config-name grpo_ifeval
```

### Override Config Values

```bash
# Change dataset
python grpo-sync.py mode=sync env.dataset=ifeval

# Modify training parameters
python grpo-sync.py mode=sync optimizer.lr=2e-5 optimizer.weight_decay=0.01

# Change model
python grpo-sync.py mode=sync model.name=meta-llama/Llama-2-7b-hf
```

### Hyperparameter Sweeps

```bash
# Learning rate sweep
python grpo-sync.py mode=sync --multirun optimizer.lr=1e-4,1e-5,1e-6

# Multiple parameters
python grpo-sync.py mode=sync --multirun \
  optimizer.lr=1e-4,1e-5 \
  policy.kl_coef=0.01,0.1
```

Don't forget to set `train.total_dialog_turns` to a reasonable value!
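
Since a sweep multiplies the number of runs, it can help to cap the run length explicitly in the same command; the value below is illustrative only:

```bash
python grpo-sync.py mode=sync --multirun \
  optimizer.lr=1e-4,1e-5 \
  train.total_dialog_turns=1000
```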

## Monitoring

Training progress is logged to Weights & Biases with the following metrics:
- Reward
- Advantage
- KL penalty
- Sequence length
- ESS (Effective Sample Size)
- Loss metrics (objective, clip fraction, etc.)
- Gradient norm
- Throughput metrics (in async mode)

## Checkpointing

Checkpoints are saved every `train.checkpoint_frequency` steps and contain:
- Model state
- Optimizer state
- Gradient scaler state (for mixed precision)
- Full configuration
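
Before resuming from one of these files it can help to confirm what was actually serialized. A minimal inspection sketch, assuming the checkpoint is a plain `torch.save` dictionary holding the four items listed above; the glob pattern follows the Output Structure section below, and the key layout is an assumption to verify against the training scripts:

```python
import glob
import torch

# Hypothetical inspection snippet: check grpo-sync.py / grpo-async.py for the exact
# checkpoint layout before relying on specific key names.
latest = sorted(glob.glob("outputs/*/*/checkpoints/checkpoint_*.pt"))[-1]
state = torch.load(latest, map_location="cpu", weights_only=False)  # config objects are not plain tensors
print(type(state), list(state.keys()) if isinstance(state, dict) else state)
```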

## Debugging Out-of-memory Issues

- vLLM: Reduce `inference_model.gpu_memory_utilization=FRACTION` or the number of environments run in parallel (`env.num_envs=N`).
- KL scoring: If KL scoring is computed on the batch of data, reduce the number of environments (`env.num_envs=N`) run in parallel.
- Training: Reduce the batch size (`train.optim_batch_size`).
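
These knobs can be combined in a single command; the values below are illustrative starting points rather than tuned recommendations:

```bash
python grpo-sync.py mode=sync \
  inference_model.gpu_memory_utilization=0.6 \
  env.num_envs=4 \
  train.optim_batch_size=1
```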

## Directory Structure

```
sota-implementations/grpo/
├── config/
│   ├── grpo_gsm8k.yaml    # Main configuration file (GSM8K)
│   └── grpo_ifeval.yaml   # Configuration file for the IFEval task
├── grpo-sync.py           # Synchronous training script
├── grpo-async.py          # Asynchronous training script
├── grpo_utils.py          # Utility functions
└── README.md              # This file
```

## Output Structure

Each run creates a timestamped directory under `outputs/`:
```
outputs/
└── YYYY-MM-DD/
    └── HH-MM-SS/
        ├── checkpoints/
        │   └── checkpoint_*.pt
        └── .hydra/
            └── config.yaml
```

For hyperparameter sweeps, outputs are stored under `multirun/`.