torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,543 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import abc
|
|
8
|
+
import collections
|
|
9
|
+
import importlib
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import torch
|
|
13
|
+
from tensordict import TensorDict
|
|
14
|
+
from torch import nn, Tensor
|
|
15
|
+
from torch.nn import functional as F
|
|
16
|
+
|
|
17
|
+
from torchrl.data.llm.prompt import PromptData
|
|
18
|
+
|
|
19
|
+
_has_transformers = importlib.util.find_spec("transformers") is not None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class KLControllerBase(abc.ABC):
    """Abstract interface shared by all KL coefficient controllers.

    A concrete controller implements :meth:`update`, which receives the KL
    values gathered since the previous update and returns the coefficient
    that multiplies the KL term when the reward is computed. Controllers that
    wrap a model are expected to write that coefficient to the model's
    ``kl_coef`` attribute as well.
    """

    @abc.abstractmethod
    def update(self, kl_values: list[float]) -> float:
        """Consume the observed KL values and return the KL coefficient."""
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ConstantKLController(KLControllerBase):
    """Constant KL Controller.

    This controller maintains a fixed coefficient no matter what values it is
    updated with.

    Keyword Arguments:
        kl_coef (:obj:`float`, optional): The coefficient to multiply KL with
            when calculating the reward. Defaults to ``None``.
        model (nn.Module, optional): wrapped model that needs to be controlled.
            Must have an attribute ``"kl_coef"``. If provided, the
            ``"kl_coef"`` will be updated in-place.

    Raises:
        AttributeError: if ``model`` is provided but has no ``kl_coef``
            attribute.
    """

    def __init__(
        self,
        *,
        kl_coef: float | None = None,
        model: nn.Module | None = None,
    ):
        # Validate first so a rejected model never leaves a half-initialised
        # controller behind.
        if model is not None and not hasattr(model, "kl_coef"):
            raise AttributeError(
                "Model input to ConstantKLController doesn't have attribute 'kl_coef'"
            )
        self.model = model
        self.coef = kl_coef
        if model is not None:
            self.model.kl_coef = self.coef

    def update(self, kl_values: list[float] | None = None) -> float:
        """Re-apply the fixed coefficient and return it.

        Args:
            kl_values (sequence of float, optional): accepted for interface
                compatibility with other controllers; the values are ignored.

        Returns:
            The constant coefficient passed at construction time.
        """
        if self.model is not None:
            # Write the coefficient again in case the model attribute was
            # modified elsewhere between updates.
            self.model.kl_coef = self.coef
        return self.coef
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class AdaptiveKLController(KLControllerBase):
    """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences".

    Keyword Arguments:
        init_kl_coef (:obj:`float`): The starting value of the coefficient.
        target (:obj:`float`): The target KL value. When the observed KL is
            smaller, the coefficient is decreased, thereby relaxing the KL
            penalty in the training objective and allowing the model to stray
            further from the reference model. When the observed KL is greater
            than the target, the KL coefficient is increased, thereby pulling
            the model back towards the reference model.
        horizon (int): Scaling factor to control how aggressively we update
            the coefficient.
        model (nn.Module, optional): wrapped model that needs to be controlled.
            Must have an attribute ``"kl_coef"``. If provided, the
            ``"kl_coef"`` will be updated in-place.

    Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2
    Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py
    """

    def __init__(
        self,
        *,
        init_kl_coef: float,
        target: float,
        horizon: int,
        model: nn.Module | None = None,
    ):
        self.model = model
        self.coef = init_kl_coef
        self.target = target
        self.horizon = horizon
        if model is not None:
            self.model.kl_coef = self.coef

    def update(self, kl_values: list[float]) -> float:
        """Update ``self.coef`` adaptively.

        The mean of ``kl_values`` is negated and divided by the current
        coefficient, compared with ``target``, and the relative error is
        clipped to ``[-0.2, 0.2]``. The coefficient is then multiplied by
        ``1 + error * len(kl_values) / horizon``.

        Arguments:
            kl_values (sequence of float): The current KL value between the
                newest policy and the initial policy. An empty sequence
                leaves the coefficient unchanged.

        Returns:
            The updated coefficient.

        Raises:
            ValueError: if ``kl_values`` is ``None``.
        """
        if kl_values is None:
            raise ValueError(
                f"The kl_values were not provided to {type(self)}. "
                f"Make sure these values are provided for the scheduler to be updated "
                f"accordingly. "
            )
        n_steps = len(kl_values)
        if n_steps == 0:
            # The mean of an empty sequence is NaN; multiplying it into the
            # coefficient would corrupt it for every subsequent update.
            if self.model is not None:
                self.model.kl_coef = self.coef
            return self.coef
        # renormalize kls; converting the mean to a Python float keeps the
        # coefficient a plain number instead of silently becoming a tensor
        kl_value = -float(torch.as_tensor(kl_values).mean()) / self.coef
        proportional_error = float(
            np.clip(kl_value / self.target - 1, -0.2, 0.2)
        )  # ϵₜ
        mult = 1 + proportional_error * n_steps / self.horizon
        self.coef *= mult  # βₜ₊₁
        if self.model is not None:
            self.model.kl_coef = self.coef
        return self.coef
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class RolloutFromModel:
|
|
131
|
+
"""A class for performing rollouts with causal language models.
|
|
132
|
+
|
|
133
|
+
It is assumed that the model this class wraps takes as input tokenized text and
|
|
134
|
+
whose task is to predict the next word in a sentence having read the n previous
|
|
135
|
+
words.
|
|
136
|
+
|
|
137
|
+
Args:
|
|
138
|
+
model (transformers.Transformer): the model to be used. Should have a
|
|
139
|
+
:meth:`generate` method.
|
|
140
|
+
ref_model (transformers.Transformer): a frozen version of ``model``
|
|
141
|
+
where params are in their initial configuration. This is used to compute a
|
|
142
|
+
KL penalty for the reward, to stop the model from straying too far from the
|
|
143
|
+
reference model during training.
|
|
144
|
+
reward_model: (nn.Module, tensordict.nn.TensorDictModule): a model which, given
|
|
145
|
+
``input_ids`` and ``attention_mask``, calculates rewards for each token and
|
|
146
|
+
end_scores (the reward for the final token in each sequence).
|
|
147
|
+
kl_coef: (:obj:`float`, optional): initial kl coefficient.
|
|
148
|
+
max_new_tokens (int, optional): the maximum length of the sequence.
|
|
149
|
+
Defaults to 50.
|
|
150
|
+
score_clip (:obj:`float`, optional): Scores from the reward model are clipped to the
|
|
151
|
+
range ``(-score_clip, score_clip)``. Defaults to 10.
|
|
152
|
+
kl_scheduler (KLControllerBase, optional): the KL coefficient scheduler.
|
|
153
|
+
num_steps (int, optional): number of steps between two optimization.
|
|
154
|
+
|
|
155
|
+
Examples:
|
|
156
|
+
>>> from tensordict.nn import TensorDictModule
|
|
157
|
+
>>> from torchrl.modules.models.llm import GPT2RewardModel
|
|
158
|
+
>>> from torchrl.data.llm.utils import RolloutFromModel
|
|
159
|
+
>>> from torchrl.data.llm.dataset import get_dataloader
|
|
160
|
+
>>> from torchrl.data.llm.prompt import PromptData
|
|
161
|
+
>>> from transformers import GPT2LMHeadModel
|
|
162
|
+
>>>
|
|
163
|
+
>>> dl = get_dataloader(
|
|
164
|
+
... batch_size=4,
|
|
165
|
+
... block_size=550,
|
|
166
|
+
... tensorclass_type=PromptData,
|
|
167
|
+
... device="cpu",
|
|
168
|
+
... dataset_name="CarperAI/openai_summarize_tldr",
|
|
169
|
+
... )
|
|
170
|
+
>>> model = GPT2LMHeadModel.from_pretrained("gpt2")
|
|
171
|
+
>>> # we load ref_model with random weights so it differs from model
|
|
172
|
+
>>> ref_model = GPT2LMHeadModel(GPT2LMHeadModel.config_class())
|
|
173
|
+
>>> reward_model = GPT2RewardModel(model_path="gpt2")
|
|
174
|
+
>>> rollout_from_model = RolloutFromModel(model, ref_model, reward_model)
|
|
175
|
+
>>>
|
|
176
|
+
>>> batch = next(dl)
|
|
177
|
+
>>> rollout = rollout_from_model.rollout_from_data(batch)
|
|
178
|
+
>>> rollout
|
|
179
|
+
TensorDict(
|
|
180
|
+
fields={
|
|
181
|
+
action: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
182
|
+
attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
183
|
+
input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
184
|
+
next: TensorDict(
|
|
185
|
+
fields={
|
|
186
|
+
attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
187
|
+
done: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
188
|
+
input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
189
|
+
reward: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
190
|
+
reward_kl: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
191
|
+
reward_raw: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
|
|
192
|
+
batch_size=torch.Size([4, 50]),
|
|
193
|
+
device=cpu,
|
|
194
|
+
is_shared=False),
|
|
195
|
+
sample_log_prob: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
|
|
196
|
+
batch_size=torch.Size([4, 50]),
|
|
197
|
+
device=cpu,
|
|
198
|
+
is_shared=False)
|
|
199
|
+
|
|
200
|
+
"""
|
|
201
|
+
|
|
202
|
+
EOS_TOKEN_ID = 50256
|
|
203
|
+
|
|
204
|
+
def __init__(
    self,
    model,
    ref_model,
    reward_model,
    kl_coef=0.1,
    max_new_tokens=50,
    score_clip=10.0,
    kl_scheduler: KLControllerBase | None = None,
    num_steps: int | None = None,
):
    """Store the models and rollout settings and create the KL buffer.

    Args:
        model: the generative model.
        ref_model: the reference model.
        reward_model: the model used to score sequences.
        kl_coef: initial KL coefficient. Defaults to ``0.1``.
        max_new_tokens: stored as ``self.max_new_tokens``. Defaults to ``50``.
        score_clip: stored as ``self.score_clip``. Defaults to ``10.0``.
        kl_scheduler: optional KL coefficient scheduler.
        num_steps: when given, KL values are kept in a deque bounded to
            this length; otherwise a plain list is used.

    Raises:
        ImportError: if the ``transformers`` package cannot be found.
    """
    if not _has_transformers:
        raise ImportError(
            "transformers module couldn't be found. Make sure it is installed in your "
            "environment."
        )

    # Models taking part in the rollout.
    self.model = model
    self.ref_model = ref_model
    self.reward_model = reward_model

    # Generation and reward settings.
    self.max_new_tokens = max_new_tokens
    self.score_clip = score_clip
    self.kl_coef = kl_coef
    self.kl_scheduler = kl_scheduler

    # Without ``num_steps`` an unbounded list is used. Per the original
    # note, the stored values are detached scalars (cheap to keep) and the
    # scheduler update is responsible for clearing them.
    self._kl_queue = (
        [] if num_steps is None else collections.deque(maxlen=num_steps)
    )
|
|
234
|
+
|
|
235
|
+
@torch.no_grad()
def rollout_from_data(self, batch):
    """Generate from ``batch`` and wrap the result as a rollout.

    Calls ``self.generate`` on the batch, then forwards the batch and the
    three values it returned to ``self.create_rollout_td``. The result of
    that call is returned unchanged. Gradient tracking is disabled for the
    duration of the call.
    """
    tokens, token_log_probs, policy_ref_log_ratio = self.generate(batch)
    rollout = self.create_rollout_td(
        batch, tokens, token_log_probs, policy_ref_log_ratio
    )
    return rollout
|
|
239
|
+
|
|
240
|
+
@torch.no_grad()
def create_rollout_td(self, batch, generated, log_probs, log_ratio):
    """A TensorDict wrapper for generated data.

    This function takes a batch plus the generated tokens and replicates the
    tensordict structure that would have been obtained from a rollout with a TorchRL
    env that sampled one token each timestep.

    As a side effect, the mean of the KL reward term of this rollout is appended
    to the internal KL queue consumed by ``step_scheduler``.

    Args:
        batch (TensorDict): A batch of data containing the original prompt together with a field
            "prompt_rindex" indicating the right index of the prompt.
        generated (torch.Tensor): Tokenized prompt followed by generated tokens. This can be obtained
            by calling the ``generate`` method.
        log_probs (torch.Tensor): The log probabilities of the generated tokens. Can be obtained by
            calling the ``generate`` method.
        log_ratio (torch.Tensor): The log ratio of the probabilities of the generated tokens
            according to the generative model and the reference model. Can be
            obtained by calling the ``generate`` method.

    Returns:
        A :class:`~tensordict.TensorDict` with the following keys:

        - ``"action"``: the sequence of actions (generated tokens)
        - ``"input_ids"``: the input_ids passed to the generative model at each time
          step.
        - ``"attention_mask"``: the attention_masks passed to the generative model at
          each time step
        - ``"sample_log_prob"``: the log probability of each token during generation
        - ``("next", "input_ids")``: the sequence of tokens after generation. Makes up
          part of the inputs that will be used for generating the next token.
        - ``("next", "attention_mask")``: updated attention_mask after token has been
          generated. Passed to the generative model on the next time step
        - ``("next", "terminated")``: Boolean array that is True at the time step where
          generation stopped, provided this happened before the last time step.
        - ``("next", "done")``: Boolean array that is True wherever ``"terminated"``
          is True and, in addition, on the last time step (token limit reached).
        - ``("next", "reward")``: The reward received at each time step, i.e. the sum
          of ``"reward_raw"`` and ``"reward_kl"``.
        - ``("next", "reward_raw")``: The raw reward from the reward model, without the
          KL term. It is the clipped difference between the score of the generated
          sequence and the score of the labelled sequence, and is non-zero only where
          ``"done"`` is True. This is mainly for debugging and logging, it is not used
          in training
        - ``("next", "reward_kl")``: The KL term from the reward, ``-kl_coef * log_ratio``.
          This is mainly for debugging and logging, it is not used in training.

    """
    rollout_generated = self._get_rollout_generated(generated, batch)
    # Padding and EOS share the same id, so the mask is simply "not EOS"
    # (the comparison already yields a boolean tensor).
    rollout_attention_mask = rollout_generated != self.EOS_TOKEN_ID

    done, terminated = self._get_done_status(generated, batch)
    action = self._get_action(generated, batch)
    end_scores, end_scores_labels = self._get_end_scores(
        rollout_generated, rollout_attention_mask, batch
    )

    # the reward is zero except for the timestep where we reached a stopping condition
    clipped_scores = torch.clip(
        end_scores - end_scores_labels, -self.score_clip, self.score_clip
    )
    # one score per sequence, broadcast against the per-step ``done`` mask
    reward_raw = clipped_scores.unsqueeze(-1).unsqueeze(-1)
    reward_raw = reward_raw * done
    reward_kl = -self.kl_coef * log_ratio.unsqueeze(-1)
    reward = reward_raw + reward_kl
    td = {
        "action": action,
        # snapshot i is the model input at step i, snapshot i + 1 its successor
        "input_ids": rollout_generated[:, :-1].clone(),
        "attention_mask": rollout_attention_mask[:, :-1].clone(),
        "sample_log_prob": log_probs,
        "next": {
            "input_ids": rollout_generated[:, 1:].clone(),
            "attention_mask": rollout_attention_mask[:, 1:].clone(),
            "done": done,
            "terminated": terminated,
            "reward": reward,
            "reward_raw": reward_raw,
            "reward_kl": reward_kl,
        },
    }
    # keep a detached scalar for the KL scheduler
    self._kl_queue.append(reward_kl.detach().mean())
    return TensorDict(
        td, batch_size=done.shape[:2], device=generated.device
    ).refine_names(..., "time")
|
|
321
|
+
|
|
322
|
+
def _get_rollout_generated(self, generated, batch):
|
|
323
|
+
# stack the individual timesteps during generation into a single tensor
|
|
324
|
+
rollout_generated = []
|
|
325
|
+
arange = torch.arange(generated.shape[1], device=generated.device)
|
|
326
|
+
for rindex, row in zip(batch.prompt_rindex, generated):
|
|
327
|
+
tokens = []
|
|
328
|
+
for i in range(self.max_new_tokens + 1):
|
|
329
|
+
tokens.append(torch.where(arange < rindex + i, row, self.EOS_TOKEN_ID))
|
|
330
|
+
rollout_generated.append(torch.stack(tokens))
|
|
331
|
+
rollout_generated = torch.stack(rollout_generated)
|
|
332
|
+
return rollout_generated
|
|
333
|
+
|
|
334
|
+
def _get_done_status(self, generated, batch):
|
|
335
|
+
# done is True when we either first sample an EOS token or reach the maximum number
|
|
336
|
+
# of generated tokens
|
|
337
|
+
done_idx = torch.minimum(
|
|
338
|
+
(generated != self.EOS_TOKEN_ID).sum(dim=-1) - batch.prompt_rindex,
|
|
339
|
+
torch.as_tensor(self.max_new_tokens) - 1,
|
|
340
|
+
)
|
|
341
|
+
truncated_idx = (
|
|
342
|
+
torch.as_tensor(self.max_new_tokens, device=generated.device).expand_as(
|
|
343
|
+
done_idx
|
|
344
|
+
)
|
|
345
|
+
- 1
|
|
346
|
+
)
|
|
347
|
+
zeros = torch.zeros(
|
|
348
|
+
done_idx.numel(),
|
|
349
|
+
self.max_new_tokens,
|
|
350
|
+
dtype=torch.bool,
|
|
351
|
+
device=generated.device,
|
|
352
|
+
)
|
|
353
|
+
truncated = zeros.scatter(-1, truncated_idx.unsqueeze(-1), 1).unsqueeze(-1)
|
|
354
|
+
done = zeros.scatter(-1, done_idx.unsqueeze(-1), 1).unsqueeze(-1)
|
|
355
|
+
terminated = (
|
|
356
|
+
done & ~truncated
|
|
357
|
+
) # we assume that if it's not truncated, it was terminated
|
|
358
|
+
return truncated | terminated, terminated
|
|
359
|
+
|
|
360
|
+
def _get_action(self, generated, batch):
|
|
361
|
+
# the sequence of actions for each trajectory is just the generated token ids
|
|
362
|
+
action_idx = torch.arange(self.max_new_tokens, device=generated.device)
|
|
363
|
+
action_idx = action_idx + batch.prompt_rindex.unsqueeze(-1)
|
|
364
|
+
return generated.gather(-1, action_idx)
|
|
365
|
+
|
|
366
|
+
def _get_end_scores(self, rollout_generated, rollout_attention_mask, batch):
|
|
367
|
+
# calculate the reward for the finished sequence
|
|
368
|
+
_, end_scores = self.reward_model(
|
|
369
|
+
input_ids=rollout_generated[:, -1],
|
|
370
|
+
attention_mask=rollout_attention_mask[:, -1],
|
|
371
|
+
)
|
|
372
|
+
_, end_scores_labels = self.reward_model(
|
|
373
|
+
input_ids=batch.input_ids,
|
|
374
|
+
attention_mask=batch.attention_mask,
|
|
375
|
+
)
|
|
376
|
+
return end_scores, end_scores_labels
|
|
377
|
+
|
|
378
|
+
@classmethod
|
|
379
|
+
def _padded_right_to_left(cls, tensor, *, eos_token_id=None, dim=1):
|
|
380
|
+
if eos_token_id is None:
|
|
381
|
+
eos_token_id = cls.EOS_TOKEN_ID
|
|
382
|
+
mask = tensor != eos_token_id
|
|
383
|
+
out = torch.full_like(tensor, eos_token_id)
|
|
384
|
+
out[mask.flip(dim)] = tensor[mask]
|
|
385
|
+
return out
|
|
386
|
+
|
|
387
|
+
@classmethod
|
|
388
|
+
def _padded_left_to_right(
|
|
389
|
+
cls, tensor, *, sequence_length=None, eos_token_id=None, dim=1
|
|
390
|
+
):
|
|
391
|
+
# some care must be taken here, because generated sequences may have both left
|
|
392
|
+
# and right padding, and also may not terminated early if all sequences in the
|
|
393
|
+
# batch reached EOS before reaching the token limit
|
|
394
|
+
if sequence_length is None:
|
|
395
|
+
sequence_length = tensor.size(dim)
|
|
396
|
+
if dim < 0:
|
|
397
|
+
dim = tensor.ndim + dim
|
|
398
|
+
if eos_token_id is None:
|
|
399
|
+
eos_token_id = cls.EOS_TOKEN_ID
|
|
400
|
+
mask = tensor != eos_token_id
|
|
401
|
+
# convert [0, 0, 1, 1, 0] to [0, 0, 1, 1, 1] to avoid right eos
|
|
402
|
+
mask = ~((~mask).to(torch.uint8).cumprod(dim).bool())
|
|
403
|
+
shape = list(mask.shape)
|
|
404
|
+
shape[dim] = sequence_length
|
|
405
|
+
out = torch.full(torch.Size(shape), eos_token_id, device=tensor.device)
|
|
406
|
+
index = (slice(None),) * dim + (slice(tensor.size(dim)),)
|
|
407
|
+
out[index][mask.flip(dim)] = tensor[mask]
|
|
408
|
+
return out
|
|
409
|
+
|
|
410
|
+
@property
def _default_conf(self):
    """Generation config used by ``generate`` when none is supplied.

    Sampling is enabled, scores are returned alongside the sequences, padding
    uses ``EOS_TOKEN_ID`` and at most ``max_new_tokens`` tokens are produced.
    """
    from transformers import GenerationConfig

    settings = {
        "pad_token_id": self.EOS_TOKEN_ID,
        "max_new_tokens": self.max_new_tokens,
        "return_dict_in_generate": True,
        "output_scores": True,
        "do_sample": True,
    }
    return GenerationConfig(**settings)
|
|
421
|
+
|
|
422
|
+
def _get_scores(
|
|
423
|
+
self, scores: tuple, generated_tokens: Tensor = None, use_max=False, pad_to=None
|
|
424
|
+
):
|
|
425
|
+
scores = torch.stack(scores, 1)
|
|
426
|
+
if scores.shape[1] != self.max_new_tokens:
|
|
427
|
+
scores = F.pad(
|
|
428
|
+
scores,
|
|
429
|
+
(0, 0, 0, self.max_new_tokens - scores.shape[1]),
|
|
430
|
+
value=float("-inf"),
|
|
431
|
+
)
|
|
432
|
+
scores = F.log_softmax(scores, dim=-1)
|
|
433
|
+
if use_max:
|
|
434
|
+
scores = scores.max(dim=-1).values
|
|
435
|
+
else:
|
|
436
|
+
index = generated_tokens.unsqueeze(-1)
|
|
437
|
+
scores = torch.gather(scores, dim=-1, index=index)
|
|
438
|
+
if pad_to is not None:
|
|
439
|
+
pad = pad_to - scores.shape[1]
|
|
440
|
+
return F.pad(scores, (0, pad), value=-float("inf"))
|
|
441
|
+
return scores
|
|
442
|
+
|
|
443
|
+
@staticmethod
def logprobs_of_labels(logits, labels):
    """Log probabilities of the labels.

    The logits are normalised with a log-softmax over their last dimension and
    the entries selected by the labels (token ids) are returned, with the
    indexing dimension squeezed out.
    """
    all_logprobs = logits.log_softmax(dim=-1)
    picked = all_logprobs.gather(-1, labels.unsqueeze(-1))
    return picked.squeeze(-1)
|
|
453
|
+
|
|
454
|
+
@torch.no_grad()
|
|
455
|
+
def _log_ratio(self, generated, prompt_rindex):
|
|
456
|
+
# get the scores and normalise for log probabilities
|
|
457
|
+
attention_mask = (generated != self.EOS_TOKEN_ID).bool()
|
|
458
|
+
logits = self.model(
|
|
459
|
+
input_ids=generated, attention_mask=attention_mask, return_dict=True
|
|
460
|
+
).logits
|
|
461
|
+
logprobs = self.logprobs_of_labels(logits[:, :-1], generated[:, 1:])
|
|
462
|
+
ref_logits = self.ref_model(
|
|
463
|
+
input_ids=generated.to(self.ref_model.device),
|
|
464
|
+
attention_mask=attention_mask.to(self.ref_model.device),
|
|
465
|
+
return_dict=True,
|
|
466
|
+
).logits.to(logits.device)
|
|
467
|
+
ref_logprobs = self.logprobs_of_labels(ref_logits[:, :-1], generated[:, 1:])
|
|
468
|
+
log_ratio = logprobs - ref_logprobs
|
|
469
|
+
log_ratio = log_ratio.masked_fill(~attention_mask[:, :-1], 0)
|
|
470
|
+
log_ratio = torch.stack(
|
|
471
|
+
[
|
|
472
|
+
row[rindex - 1 : rindex + self.max_new_tokens - 1]
|
|
473
|
+
for row, rindex in zip(log_ratio, prompt_rindex)
|
|
474
|
+
],
|
|
475
|
+
dim=0,
|
|
476
|
+
)
|
|
477
|
+
return log_ratio
|
|
478
|
+
|
|
479
|
+
def _get_generated_tokens(self, generated, rindex):
|
|
480
|
+
# extracts the generated tokens from the full sequence of prompt + generated
|
|
481
|
+
idx = torch.arange(generated.shape[1], device=generated.device)
|
|
482
|
+
rindex = rindex.unsqueeze(-1)
|
|
483
|
+
mask = (idx >= rindex) & (idx < rindex + self.max_new_tokens)
|
|
484
|
+
return generated[mask].reshape(-1, self.max_new_tokens)
|
|
485
|
+
|
|
486
|
+
@torch.no_grad()
def generate(self, batch: PromptData, generation_config=None):
    """Generates a sequence of tokens from a batch of data sampled from the data collector.

    Args:
        batch (PromptData): the data to be used. Must have ``input_ids``
            and ``prompt_rindex`` fields.
        generation_config (GenerationConfig, optional): the configuration for the
            call to generate. Defaults to the instance's default configuration.

    Returns:
        generated (torch.Tensor): a [B x (Ti +To)] sequence of integers (tokens),
            where Ti is the length of the input sequence and To is the length
            of the generated sequence.
        log_probs_gen: the log-probabilities of the token generated.
        log_ratio: the log ratio between probabilities under the generative
            model and the frozen version.

    """
    # huggingface models expect left padding for generation, so the padding
    # tokens of the prompt are moved to the left first
    # NOTE(review): ``mask_label`` is defined outside this block; presumably it
    # hides the label part of the batch -- confirm.
    prompt_ids = self._padded_right_to_left(batch.mask_label().input_ids)

    if generation_config is None:
        generation_config = self._default_conf

    # generate and capture scores
    outputs = self.model.generate(
        input_ids=prompt_ids,
        attention_mask=(prompt_ids != self.EOS_TOKEN_ID).bool(),
        generation_config=generation_config,
    )

    # the samples are inserted into a tensor prepopulated with padding tokens,
    # thereby moving back to right padding for the reward model
    generated = self._padded_left_to_right(
        outputs.sequences,
        sequence_length=prompt_ids.shape[1] + self.max_new_tokens,
        eos_token_id=self.EOS_TOKEN_ID,
    )

    # get the scores and normalise for log probabilities
    new_tokens = self._get_generated_tokens(generated, batch.prompt_rindex)
    log_probs_gen = self._get_scores(outputs.scores, new_tokens)

    log_ratio = self._log_ratio(generated, batch.prompt_rindex)
    return generated, log_probs_gen, log_ratio
|
|
536
|
+
|
|
537
|
+
def step_scheduler(self):
    """Update ``kl_coef`` from the KL scheduler and reset the KL history.

    The accumulated KL values are handed to ``kl_scheduler.update`` and its
    return value becomes the new ``kl_coef``. The history is then emptied in
    place, so a bounded deque keeps its ``maxlen``.
    """
    # recover true kl
    self.kl_coef = self.kl_scheduler.update(self._kl_queue)
    if isinstance(self._kl_queue, (list, collections.deque)):
        # remove all values in one call instead of repeatedly searching for
        # and removing the first element
        self._kl_queue.clear()
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from .hash import BinaryToDecimal, RandomProjectionHash, SipHash
|
|
7
|
+
from .query import HashToInt, QueryModule
|
|
8
|
+
from .tdstorage import TensorDictMap, TensorMap
|
|
9
|
+
from .tree import MCTSForest, Tree
|
|
10
|
+
|
|
11
|
+
# Public API of this package: the names re-exported from the ``hash``,
# ``query``, ``tdstorage`` and ``tree`` submodules imported above.
__all__ = [
    "BinaryToDecimal",
    "RandomProjectionHash",
    "SipHash",
    "HashToInt",
    "QueryModule",
    "TensorDictMap",
    "TensorMap",
    "MCTSForest",
    "Tree",
]
|