torchrl 0.11.0__cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/llm/transforms/kl.py
@@ -0,0 +1,1544 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import warnings
|
|
8
|
+
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from contextlib import nullcontext
|
|
11
|
+
from copy import copy
|
|
12
|
+
from typing import Any, Literal, TYPE_CHECKING
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
from tensordict import NestedKey, set_list_to_stack, TensorDictBase, unravel_key
|
|
16
|
+
from tensordict.utils import _zip_strict, is_seq_of_nested_key, logger as torchrl_logger
|
|
17
|
+
from torch.nn.utils.rnn import pad_sequence
|
|
18
|
+
from torchrl.data import Composite, Unbounded
|
|
19
|
+
from torchrl.data.tensor_specs import DEVICE_TYPING
|
|
20
|
+
from torchrl.envs import EnvBase, Transform
|
|
21
|
+
from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform
|
|
22
|
+
from torchrl.envs.transforms.transforms import Compose
|
|
23
|
+
from torchrl.envs.transforms.utils import _set_missing_tolerance
|
|
24
|
+
from torchrl.modules.llm.policies.common import LLMWrapperBase
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
import transformers
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class RayKLRewardTransform(RayTransform):
|
|
31
|
+
"""A Ray-based implementation of :class:`~torchrl.envs.llm.transforms.kl.KLRewardTransform`.
|
|
32
|
+
|
|
33
|
+
This class creates a Ray remote actor from KLRewardTransform that can be shared across multiple workers.
|
|
34
|
+
All method calls are delegated to the remote actor, ensuring that multiple environments can
|
|
35
|
+
share the same KL computation resources.
|
|
36
|
+
|
|
37
|
+
To avoid serialization issues with large models, this class supports model factories
|
|
38
|
+
that create models on the remote actor rather than passing full models through Ray channels.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
ref_model (LLMWrapperBase, optional): the reference model. Prefer using a model factory instead
|
|
42
|
+
to avoid serialization issues.
|
|
43
|
+
|
|
44
|
+
Keyword Args:
|
|
45
|
+
ref_model_factory (Callable[[], LLMWrapperBase], optional): A callable that returns a reference model.
|
|
46
|
+
This allows for explicit resource control and avoids serialization issues.
|
|
47
|
+
num_cpus (int, optional): Number of CPUs to allocate to the Ray actor. Defaults to 1.
|
|
48
|
+
num_gpus (int, optional): Number of GPUs to allocate to the Ray actor. Defaults to 0.
|
|
49
|
+
device (torch.device, optional): Device to use on the remote Ray actor for tensor operations.
|
|
50
|
+
The local Ray transform will handle CPU serialization and device restoration automatically.
|
|
51
|
+
Defaults to None.
|
|
52
|
+
actor_name (str, optional): Name of the Ray actor to use. If provided, the actor will be reused if it already exists.
|
|
53
|
+
**kwargs: Additional keyword arguments to pass to KLRewardTransform.
|
|
54
|
+
|
|
55
|
+
Note:
|
|
56
|
+
When using model factories, the corresponding model argument (ref_model) should be None.
|
|
57
|
+
Model factories are preferred for large models to avoid serialization overhead.
|
|
58
|
+
|
|
59
|
+
Examples:
|
|
60
|
+
>>> # Option 1: Using model factory for explicit resource control
|
|
61
|
+
>>> def create_ref_model():
|
|
62
|
+
... return TransformersWrapper(ref_model, tokenizer=tokenizer, generate=False, return_log_probs=True)
|
|
63
|
+
>>> transform = RayKLRewardTransform(
|
|
64
|
+
... ref_model=None,
|
|
65
|
+
... ref_model_factory=create_ref_model,
|
|
66
|
+
... num_gpus=1,
|
|
67
|
+
... device=torch.device("cuda")
|
|
68
|
+
... )
|
|
69
|
+
|
|
70
|
+
>>> # Option 2: Pass model directly (Ray handles serialization)
|
|
71
|
+
>>> transform = RayKLRewardTransform(ref_model=ref_model, device=torch.device("cuda"))
|
|
72
|
+
"""
|
|
73
|
+
|
|
74
|
+
def __init__(
|
|
75
|
+
self,
|
|
76
|
+
ref_model: LLMWrapperBase | None = None,
|
|
77
|
+
*,
|
|
78
|
+
ref_model_factory: Callable[[], LLMWrapperBase] | None = None,
|
|
79
|
+
num_cpus: int | None = None,
|
|
80
|
+
num_gpus: int = 0,
|
|
81
|
+
device: DEVICE_TYPING | None = None,
|
|
82
|
+
actor_name: str | None = None,
|
|
83
|
+
**kwargs,
|
|
84
|
+
):
|
|
85
|
+
# Validate arguments: model and factory should not both be provided
|
|
86
|
+
if ref_model is not None and ref_model_factory is not None:
|
|
87
|
+
raise ValueError(
|
|
88
|
+
"Cannot provide both 'ref_model' and 'ref_model_factory'. Choose one."
|
|
89
|
+
)
|
|
90
|
+
if ref_model is None and ref_model_factory is None:
|
|
91
|
+
raise ValueError(
|
|
92
|
+
"Must provide exactly one of 'ref_model' or 'ref_model_factory'."
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
# Store creation parameters for actor creation
|
|
96
|
+
self._ref_model = ref_model
|
|
97
|
+
self._ref_model_factory = ref_model_factory
|
|
98
|
+
self._creation_kwargs = kwargs
|
|
99
|
+
# Store device separately for passing to remote actor
|
|
100
|
+
self._remote_device = device
|
|
101
|
+
|
|
102
|
+
# Default num_cpus
|
|
103
|
+
if num_cpus is None:
|
|
104
|
+
num_cpus = 1
|
|
105
|
+
|
|
106
|
+
# Call parent constructor without device (Ray transform handles CPU/device mapping)
|
|
107
|
+
super().__init__(
|
|
108
|
+
num_cpus=num_cpus,
|
|
109
|
+
num_gpus=num_gpus,
|
|
110
|
+
device=None, # Don't store device locally
|
|
111
|
+
actor_name=actor_name,
|
|
112
|
+
**kwargs,
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
def _create_actor(self, **kwargs):
|
|
116
|
+
"""Create the remote KLRewardTransform actor."""
|
|
117
|
+
# Create the remote KLRewardTransform with resource specifications
|
|
118
|
+
RemoteKLRewardTransform = self._ray.remote(
|
|
119
|
+
num_cpus=self._num_cpus, num_gpus=self._num_gpus
|
|
120
|
+
)(KLRewardTransform)
|
|
121
|
+
|
|
122
|
+
if self._actor_name is not None:
|
|
123
|
+
RemoteKLRewardTransform = RemoteKLRewardTransform.options(
|
|
124
|
+
name=self._actor_name
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
# Determine how to create model on the remote actor
|
|
128
|
+
ref_model_arg = self._ref_model
|
|
129
|
+
|
|
130
|
+
# If we have factory, we'll pass it and set model to None
|
|
131
|
+
creation_kwargs = self._creation_kwargs.copy()
|
|
132
|
+
if self._ref_model_factory is not None:
|
|
133
|
+
creation_kwargs["ref_model_factory"] = self._ref_model_factory
|
|
134
|
+
ref_model_arg = None
|
|
135
|
+
|
|
136
|
+
# Pass device to the remote actor
|
|
137
|
+
if self._remote_device is not None:
|
|
138
|
+
creation_kwargs["device"] = self._remote_device
|
|
139
|
+
|
|
140
|
+
# Create the shared actor
|
|
141
|
+
actor = RemoteKLRewardTransform.remote(
|
|
142
|
+
ref_model=ref_model_arg, **creation_kwargs
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
return actor
|
|
146
|
+
|
|
147
|
+
def __repr__(self):
|
|
148
|
+
"""String representation."""
|
|
149
|
+
try:
|
|
150
|
+
if hasattr(self, "_actor") and self._actor is not None:
|
|
151
|
+
return self._ray.get(self._actor.__repr__.remote())
|
|
152
|
+
else:
|
|
153
|
+
return "RayKLRewardTransform(actor=None)"
|
|
154
|
+
except Exception:
|
|
155
|
+
return f"RayKLRewardTransform(actor={getattr(self, '_actor', 'None')})"
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class KLRewardTransform(Transform, metaclass=_RayServiceMetaClass):
|
|
159
|
+
"""A legacy transform for computing KL divergence-based rewards.
|
|
160
|
+
|
|
161
|
+
**Deprecated**: This transform is maintained for backward compatibility but is no longer
|
|
162
|
+
the recommended approach. Use :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL` instead,
|
|
163
|
+
which provides better modularity and integration with the new wrapper design.
|
|
164
|
+
|
|
165
|
+
**Recent Changes:**
|
|
166
|
+
|
|
167
|
+
- **Legacy Status**: This transform is now considered legacy and may not work optimally
|
|
168
|
+
with the new modular wrapper design.
|
|
169
|
+
- **ChatHistory Integration**: Limited support for the new :class:`~torchrl.modules.llm.policies.ChatHistory` objects.
|
|
170
|
+
- **Input Mode Support**: May not handle all input modes (`"history"`, `"text"`, `"tokens"`) consistently.
|
|
171
|
+
|
|
172
|
+
**Recommendation**:
|
|
173
|
+
Use :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL` for new code, which provides:
|
|
174
|
+
- Better integration with the new wrapper design
|
|
175
|
+
- Consistent support for all input modes
|
|
176
|
+
- Proper handling of ChatHistory objects
|
|
177
|
+
- More modular and composable architecture
|
|
178
|
+
|
|
179
|
+
Args:
|
|
180
|
+
ref_model (LLMWrapperBase): the reference model.
|
|
181
|
+
|
|
182
|
+
Keyword Args:
|
|
183
|
+
ref_model_factory (Callable[[], LLMWrapperBase], optional): A callable that returns a reference model.
|
|
184
|
+
assistant_only (bool): whether to only compute KL on assistant tokens. Defaults to `True`.
|
|
185
|
+
tokenizer (transformers.AutoTokenizer): the tokenizer to use. Defaults to `None`.
|
|
186
|
+
detach (bool): whether to detach the KL from the computation graph. Defaults to `True`.
|
|
187
|
+
device (torch.device): the device to cast the tensors to. This is not the device of the specs, but the device
|
|
188
|
+
onto which the tensors will be moved. It allows to keep the model on a different device
|
|
189
|
+
than the upcoming data. When using Ray service, this device will be used on the remote actor.
|
|
190
|
+
Defaults to `None`.
|
|
191
|
+
padding_side (str): the side of the padding when using pad_sequence. Defaults to `"left"`.
|
|
192
|
+
use_ray_service (bool, optional): whether to use Ray service. Defaults to `False`.
|
|
193
|
+
actor_name (str, optional): the name of the Ray actor to use. Defaults to `None`.
|
|
194
|
+
|
|
195
|
+
Examples:
|
|
196
|
+
>>> # Legacy usage (not recommended for new code)
|
|
197
|
+
>>> transform = KLRewardTransform(gen_model, ref_model)
|
|
198
|
+
>>>
|
|
199
|
+
>>> # Recommended approach using RetrieveKL
|
|
200
|
+
>>> from torchrl.envs.llm.transforms.kl import RetrieveKL
|
|
201
|
+
>>> transform = RetrieveKL(gen_model, ref_model, assistant_only=True)
|
|
202
|
+
|
|
203
|
+
.. seealso::
|
|
204
|
+
:class:`~torchrl.envs.llm.transforms.kl.RetrieveKL`: The recommended transform for KL divergence computation.
|
|
205
|
+
:class:`~torchrl.envs.llm.transforms.kl.RetrieveLogProb`: Base transform for retrieving log-probabilities.
|
|
206
|
+
:class:`~torchrl.envs.llm.transforms.kl.KLComputation`: Transform for computing KL divergence between log-prob tensors.
|
|
207
|
+
"""
|
|
208
|
+
|
|
209
|
+
DEFAULT_IN_KEYS = ["reward"]
|
|
210
|
+
_RayServiceClass = RayKLRewardTransform
|
|
211
|
+
|
|
212
|
+
def __init__(
|
|
213
|
+
self,
|
|
214
|
+
ref_model: LLMWrapperBase | None = None,
|
|
215
|
+
*,
|
|
216
|
+
ref_model_factory: Callable[[], LLMWrapperBase] | None = None,
|
|
217
|
+
coef=1.0,
|
|
218
|
+
in_keys=None,
|
|
219
|
+
out_keys=None,
|
|
220
|
+
log_prob_key: NestedKey = ("log_probs", "full"),
|
|
221
|
+
device: torch.device | None = None,
|
|
222
|
+
add_to_reward: bool = True,
|
|
223
|
+
tokenizer: transformers.AutoTokenizer | None = None,
|
|
224
|
+
assistant_only: bool = True,
|
|
225
|
+
padding_side: str = "left",
|
|
226
|
+
use_ray_service: bool = False,
|
|
227
|
+
):
|
|
228
|
+
# Handle model factory - create model if factory is provided
|
|
229
|
+
if ref_model_factory is not None:
|
|
230
|
+
if ref_model is not None:
|
|
231
|
+
raise ValueError(
|
|
232
|
+
"Cannot provide both 'ref_model' and 'ref_model_factory'. Choose one."
|
|
233
|
+
)
|
|
234
|
+
ref_model = ref_model_factory()
|
|
235
|
+
elif ref_model is None:
|
|
236
|
+
raise ValueError(
|
|
237
|
+
"Must provide exactly one of 'ref_model' or 'ref_model_factory'."
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
if in_keys is None:
|
|
241
|
+
in_keys = self.DEFAULT_IN_KEYS
|
|
242
|
+
if out_keys is None:
|
|
243
|
+
out_keys = copy(in_keys)
|
|
244
|
+
if len(out_keys) == len(in_keys):
|
|
245
|
+
out_keys = out_keys + ["kl_penalty", "ref_log_probs"]
|
|
246
|
+
elif len(out_keys) != len(in_keys) + 2:
|
|
247
|
+
raise ValueError(
|
|
248
|
+
"The out_keys must have the same length as the in_keys (plus two additional optional kl entries for logging)."
|
|
249
|
+
)
|
|
250
|
+
super().__init__(in_keys=in_keys, out_keys=out_keys)
|
|
251
|
+
if not is_seq_of_nested_key(self.in_keys) or not is_seq_of_nested_key(
|
|
252
|
+
self.out_keys
|
|
253
|
+
):
|
|
254
|
+
raise ValueError(
|
|
255
|
+
f"invalid in_keys / out_keys:\nin_keys={self.in_keys} \nout_keys={self.out_keys}"
|
|
256
|
+
)
|
|
257
|
+
if len(self.in_keys) != 1 or len(self.out_keys) != 3:
|
|
258
|
+
raise ValueError(
|
|
259
|
+
f"Only one in_key/out_key is allowed, got in_keys={self.in_keys}, out_keys={self.out_keys}."
|
|
260
|
+
)
|
|
261
|
+
self._out_keys = [unravel_key(out_key) for out_key in self._out_keys]
|
|
262
|
+
|
|
263
|
+
if getattr(ref_model, "generate", False):
|
|
264
|
+
raise ValueError(
|
|
265
|
+
"The actor is configured to generate text, not compute the log-probs."
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
# update the in_keys for dispatch etc
|
|
269
|
+
self.in_keys = self.in_keys + ref_model.in_keys
|
|
270
|
+
self.in_keys = [unravel_key(in_key) for in_key in self.in_keys]
|
|
271
|
+
|
|
272
|
+
self.add_to_reward = add_to_reward
|
|
273
|
+
# check that the model has parameters
|
|
274
|
+
self.__dict__["ref_model"] = ref_model
|
|
275
|
+
|
|
276
|
+
# self._buffers["actor_params"] = params.clone().detach()
|
|
277
|
+
|
|
278
|
+
self.device = device
|
|
279
|
+
|
|
280
|
+
# find the sample log-prob key
|
|
281
|
+
self.log_prob_full_key = log_prob_key
|
|
282
|
+
|
|
283
|
+
self._tokenizer = tokenizer
|
|
284
|
+
self.assistant_only = assistant_only
|
|
285
|
+
self.padding_side = padding_side
|
|
286
|
+
|
|
287
|
+
if not isinstance(coef, torch.Tensor):
|
|
288
|
+
coef = torch.as_tensor(coef)
|
|
289
|
+
self.register_buffer("coef", coef)
|
|
290
|
+
# sanity check for the ref_model
|
|
291
|
+
if not getattr(ref_model, "input_mode", "tokens") == "tokens":
|
|
292
|
+
raise ValueError(
|
|
293
|
+
"The ref_model must be configured to use tokens as input. Please set the `input_mode` argument to `tokens`."
|
|
294
|
+
)
|
|
295
|
+
|
|
296
|
+
@property
|
|
297
|
+
def pad_output(self):
|
|
298
|
+
# We need pad_output to match the pad_output of the inference model
|
|
299
|
+
return self.ref_model.pad_output
|
|
300
|
+
|
|
301
|
+
@property
|
|
302
|
+
def tokenizer(self):
|
|
303
|
+
tokenizer = self._tokenizer
|
|
304
|
+
if tokenizer is not None:
|
|
305
|
+
return tokenizer
|
|
306
|
+
try:
|
|
307
|
+
return self.ref_model.tokenizer
|
|
308
|
+
except AttributeError:
|
|
309
|
+
raise AttributeError(
|
|
310
|
+
"The ref_model does not have a tokenizer. Please pass the tokenizer to the constructor."
|
|
311
|
+
)
|
|
312
|
+
|
|
313
|
+
def set_container(self, container: Transform | EnvBase) -> None:
|
|
314
|
+
result = super().set_container(container)
|
|
315
|
+
if self.action_key is None:
|
|
316
|
+
parent = getattr(self, "parent", None)
|
|
317
|
+
if parent is not None:
|
|
318
|
+
action_keys = parent.action_keys
|
|
319
|
+
if len(action_keys) != 1:
|
|
320
|
+
raise ValueError(
|
|
321
|
+
f"More than one action_key found. Please pass the `action_key` argument directly to {type(self).__name__}."
|
|
322
|
+
)
|
|
323
|
+
action_key = action_keys[0]
|
|
324
|
+
self.action_key = action_key
|
|
325
|
+
return result
|
|
326
|
+
|
|
327
|
+
def _reset(
|
|
328
|
+
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
|
|
329
|
+
) -> TensorDictBase:
|
|
330
|
+
with _set_missing_tolerance(self, True):
|
|
331
|
+
tensordict_reset = self._step(tensordict_reset, tensordict_reset)
|
|
332
|
+
return tensordict_reset
|
|
333
|
+
|
|
334
|
+
@property
|
|
335
|
+
def action_key(self) -> NestedKey:
|
|
336
|
+
# Get the action from the base env (a ChatEnv).
|
|
337
|
+
if self.parent.base_env.input_mode == "history":
|
|
338
|
+
return ("history", "full")
|
|
339
|
+
if self.parent.base_env.input_mode == "text":
|
|
340
|
+
return ("text", "full")
|
|
341
|
+
if self.parent.base_env.input_mode == "tokens":
|
|
342
|
+
return ("tokens", "full")
|
|
343
|
+
raise ValueError(f"Invalid input mode: {self.parent.base_env.input_mode}")
|
|
344
|
+
|
|
345
|
+
def _step(
|
|
346
|
+
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
|
|
347
|
+
) -> TensorDictBase:
|
|
348
|
+
original_device = None
|
|
349
|
+
if self.device is not None:
|
|
350
|
+
original_device = tensordict.device
|
|
351
|
+
tensordict = tensordict.to(self.device)
|
|
352
|
+
next_tensordict = next_tensordict.to(self.device)
|
|
353
|
+
# tensordict = self._get_text_response(tensordict, next_tensordict)
|
|
354
|
+
response = tensordict.get(self.action_key, None)
|
|
355
|
+
if response is None:
|
|
356
|
+
if not self.missing_tolerance:
|
|
357
|
+
raise RuntimeError(
|
|
358
|
+
f"Action with key {self.action_key} not found data {tensordict}"
|
|
359
|
+
)
|
|
360
|
+
# being called after reset or without action, skipping
|
|
361
|
+
if self.out_keys[0] != "reward" and self.parent is not None:
|
|
362
|
+
next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
|
|
363
|
+
return next_tensordict
|
|
364
|
+
|
|
365
|
+
# We use the ("tokens", "full") key to get the log-probs of the reference model
|
|
366
|
+
with torch.device(self.device) if self.device is not None else nullcontext():
|
|
367
|
+
td_input = tensordict.copy()
|
|
368
|
+
ref_log_prob_td = self.ref_model(td_input)
|
|
369
|
+
if self.pad_output:
|
|
370
|
+
ref_log_prob_padded = ref_log_prob_td.get(self.log_prob_full_key)
|
|
371
|
+
else:
|
|
372
|
+
ref_log_prob_unpadded = ref_log_prob_td.get(
|
|
373
|
+
self.log_prob_full_key, as_list=True # type: ignore[misc]
|
|
374
|
+
)
|
|
375
|
+
if self.assistant_only:
|
|
376
|
+
# Get the assistant mask
|
|
377
|
+
mask = tensordict.get(("masks", "all_assistant_mask"))
|
|
378
|
+
# mask will often be None - fall back on prompt / response separation
|
|
379
|
+
if mask is None:
|
|
380
|
+
if self.pad_output:
|
|
381
|
+
# simple case: just take the prompt length
|
|
382
|
+
prompt_length = tensordict.get(("tokens", "prompt")).shape[-1]
|
|
383
|
+
mask = tensordict.get(("masks", "all_attention_mask")).clone()
|
|
384
|
+
mask[..., :prompt_length] = False
|
|
385
|
+
else:
|
|
386
|
+
# simple case: just take the prompt length
|
|
387
|
+
prompt_length = [
|
|
388
|
+
t.size(-1)
|
|
389
|
+
for t in tensordict.get(("tokens", "prompt"), as_list=True) # type: ignore[misc]
|
|
390
|
+
]
|
|
391
|
+
mask = tensordict.get(("masks", "all_attention_mask"), as_list=True) # type: ignore[misc]
|
|
392
|
+
for i in range(len(prompt_length)):
|
|
393
|
+
mask[i] = mask[i].clone()
|
|
394
|
+
mask[i][..., : prompt_length[i]] = False
|
|
395
|
+
|
|
396
|
+
# we want to keep the batch dimension
|
|
397
|
+
ref_log_prob_list = []
|
|
398
|
+
if self.pad_output:
|
|
399
|
+
for i in range(ref_log_prob_padded.size(0)):
|
|
400
|
+
ref_log_prob_list.append(
|
|
401
|
+
ref_log_prob_padded[i].masked_fill(~mask[i], 0)
|
|
402
|
+
)
|
|
403
|
+
else:
|
|
404
|
+
for i in range(len(ref_log_prob_unpadded)):
|
|
405
|
+
ref_log_prob_list.append(
|
|
406
|
+
ref_log_prob_unpadded[i].masked_fill(~mask[i], 0)
|
|
407
|
+
)
|
|
408
|
+
if self.pad_output:
|
|
409
|
+
ref_log_prob = pad_sequence(
|
|
410
|
+
ref_log_prob_list,
|
|
411
|
+
batch_first=True,
|
|
412
|
+
padding_value=0,
|
|
413
|
+
padding_side=self.padding_side,
|
|
414
|
+
)
|
|
415
|
+
else:
|
|
416
|
+
ref_log_prob = torch.nested.nested_tensor(
|
|
417
|
+
ref_log_prob_list, layout=torch.strided
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
# we obtain the current log-probs (already computed) from the current tensordict
|
|
421
|
+
if self.pad_output:
|
|
422
|
+
curr_log_prob_padded = tensordict.get(self.log_prob_full_key)
|
|
423
|
+
else:
|
|
424
|
+
curr_log_prob_unpadded = tensordict.get(
|
|
425
|
+
self.log_prob_full_key, as_list=True # type: ignore[misc]
|
|
426
|
+
)
|
|
427
|
+
if self.assistant_only:
|
|
428
|
+
# we want to keep the batch dimension
|
|
429
|
+
curr_log_prob_list = []
|
|
430
|
+
if self.pad_output:
|
|
431
|
+
for i in range(curr_log_prob_padded.size(0)):
|
|
432
|
+
curr_log_prob_list.append(
|
|
433
|
+
curr_log_prob_padded[i].masked_fill(~mask[i], 0)
|
|
434
|
+
)
|
|
435
|
+
else:
|
|
436
|
+
for i in range(len(curr_log_prob_unpadded)):
|
|
437
|
+
curr_log_prob_list.append(
|
|
438
|
+
curr_log_prob_unpadded[i].masked_fill(~mask[i], 0)
|
|
439
|
+
)
|
|
440
|
+
if self.pad_output:
|
|
441
|
+
curr_log_prob = pad_sequence(
|
|
442
|
+
curr_log_prob_list,
|
|
443
|
+
batch_first=True,
|
|
444
|
+
padding_value=0,
|
|
445
|
+
padding_side=self.padding_side,
|
|
446
|
+
)
|
|
447
|
+
else:
|
|
448
|
+
curr_log_prob = torch.nested.nested_tensor(
|
|
449
|
+
curr_log_prob_list, layout=torch.strided
|
|
450
|
+
)
|
|
451
|
+
|
|
452
|
+
ref_log_prob = ref_log_prob.to(curr_log_prob.device)
|
|
453
|
+
# We want the log-probs to have a similar dim to the reward
|
|
454
|
+
curr_log_prob = curr_log_prob.unsqueeze(-1)
|
|
455
|
+
ref_log_prob = ref_log_prob.unsqueeze(-1)
|
|
456
|
+
|
|
457
|
+
for i in range(ref_log_prob.size(0)):
|
|
458
|
+
if ref_log_prob[i].shape != curr_log_prob[i].shape:
|
|
459
|
+
# Don't check shapes if nested
|
|
460
|
+
raise ValueError(
|
|
461
|
+
f"the log-probability tensor shapes must match, got cur_log_prob.shape={curr_log_prob[i].shape} and log_prob.shape={ref_log_prob[i].shape}. "
|
|
462
|
+
f"One possible reason is that the padding token is identical to the eos token, which means that the eos_token log_prob is truncated from the "
|
|
463
|
+
f"reference model output."
|
|
464
|
+
)
|
|
465
|
+
kl = curr_log_prob - ref_log_prob
|
|
466
|
+
if self.add_to_reward:
|
|
467
|
+
reward_key = self.in_keys[0]
|
|
468
|
+
reward = next_tensordict.get(reward_key)
|
|
469
|
+
# we use the unbiased consistent estimator of the KL: log_p(x) - log_q(x) when x ~ p(x)
|
|
470
|
+
if not reward.is_nested and ref_log_prob.is_nested:
|
|
471
|
+
reward = torch.nested.nested_tensor(
|
|
472
|
+
[rew.expand(lp.shape) for rew, lp in zip(reward, ref_log_prob)],
|
|
473
|
+
layout=torch.strided,
|
|
474
|
+
)
|
|
475
|
+
if reward is not None and reward.ndim != curr_log_prob.ndim:
|
|
476
|
+
raise ValueError(
|
|
477
|
+
"The number of dimensions of reward must be the same as the number of dimensions of the KL "
|
|
478
|
+
f"term. Got ndim={reward.ndim} and {curr_log_prob.ndim} respectively."
|
|
479
|
+
)
|
|
480
|
+
if reward is None:
|
|
481
|
+
reward = 0
|
|
482
|
+
reward = reward - self.coef * kl
|
|
483
|
+
next_tensordict.set(self.out_keys[0], reward)
|
|
484
|
+
next_tensordict.set(self.out_keys[1], kl)
|
|
485
|
+
next_tensordict.set(self.out_keys[2], ref_log_prob)
|
|
486
|
+
if original_device is not None:
|
|
487
|
+
next_tensordict = next_tensordict.to(original_device)
|
|
488
|
+
return next_tensordict
|
|
489
|
+
|
|
490
|
+
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
|
491
|
+
next_td = tensordict.pop("next")
|
|
492
|
+
next_td = self._step(tensordict, next_td)
|
|
493
|
+
return tensordict.set("next", next_td)
|
|
494
|
+
|
|
495
|
+
def transform_output_spec(self, output_spec: Composite) -> Composite:
|
|
496
|
+
in_key = unravel_key(self.in_keys[0])
|
|
497
|
+
out_key = unravel_key(self.out_keys[0])
|
|
498
|
+
|
|
499
|
+
if "full_observation_spec" in output_spec.keys():
|
|
500
|
+
observation_spec = output_spec["full_observation_spec"]
|
|
501
|
+
else:
|
|
502
|
+
observation_spec = Composite(
|
|
503
|
+
shape=output_spec.shape, device=output_spec.device
|
|
504
|
+
)
|
|
505
|
+
output_spec["full_observation_spec"] = observation_spec
|
|
506
|
+
|
|
507
|
+
if in_key == "reward" and out_key == "reward":
|
|
508
|
+
parent = self.parent
|
|
509
|
+
|
|
510
|
+
reward_keys = parent.reward_keys
|
|
511
|
+
if len(reward_keys) == 1:
|
|
512
|
+
reward_key = reward_keys[0]
|
|
513
|
+
shape = output_spec["full_reward_spec"].shape
|
|
514
|
+
elif "reward" in reward_keys:
|
|
515
|
+
reward_key = "reward"
|
|
516
|
+
shape = output_spec["full_reward_spec"].shape
|
|
517
|
+
else:
|
|
518
|
+
shape = output_spec.shape
|
|
519
|
+
reward_key = "reward"
|
|
520
|
+
# For LLMs, the shape of the reward is (batch, -1, 1)
|
|
521
|
+
shape = torch.Size((*shape, -1, 1))
|
|
522
|
+
reward_spec = Unbounded(
|
|
523
|
+
device=output_spec.device,
|
|
524
|
+
shape=shape,
|
|
525
|
+
)
|
|
526
|
+
output_spec["full_reward_spec"] = Composite(
|
|
527
|
+
{reward_key: reward_spec},
|
|
528
|
+
shape=output_spec["full_reward_spec"].shape,
|
|
529
|
+
)
|
|
530
|
+
elif in_key == "reward":
|
|
531
|
+
# TODO: we should at least allow to make this a component of the reward specs, to avoid a call during reset
|
|
532
|
+
parent = self.parent
|
|
533
|
+
reward_spec = output_spec["full_reward_spec"][parent.reward_key]
|
|
534
|
+
|
|
535
|
+
shape = output_spec["full_reward_spec"].shape
|
|
536
|
+
# For LLMs, the shape of the reward is (batch, -1, 1)
|
|
537
|
+
shape = torch.Size((*shape, -1, 1))
|
|
538
|
+
reward_spec = reward_spec.clone()
|
|
539
|
+
reward_spec.shape = shape
|
|
540
|
+
|
|
541
|
+
# then we need to populate the output keys
|
|
542
|
+
observation_spec[out_key] = reward_spec
|
|
543
|
+
else:
|
|
544
|
+
observation_spec = output_spec["full_observation_spec"]
|
|
545
|
+
reward_spec = observation_spec[in_key]
|
|
546
|
+
|
|
547
|
+
shape = observation_spec.shape
|
|
548
|
+
shape = torch.Size((*shape, -1, 1))
|
|
549
|
+
reward_spec = reward_spec.clone()
|
|
550
|
+
reward_spec.shape = shape
|
|
551
|
+
|
|
552
|
+
# then we need to populate the output keys
|
|
553
|
+
observation_spec[out_key] = reward_spec
|
|
554
|
+
|
|
555
|
+
observation_spec[self.out_keys[1]] = reward_spec.clone()
|
|
556
|
+
|
|
557
|
+
return output_spec
|
|
558
|
+
|
|
559
|
+
|
|
560
|
+
class RetrieveLogProb(Transform):
|
|
561
|
+
"""A transform to retrieve log-probabilities from a model for KL divergence computation.
|
|
562
|
+
|
|
563
|
+
This transform computes log-probabilities from a reference model, which can then be used
|
|
564
|
+
to compute KL divergence with another model's log-probabilities. It's designed to work
|
|
565
|
+
with the :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL` and :class:`~torchrl.envs.llm.transforms.kl.KLComputation` transforms.
|
|
566
|
+
|
|
567
|
+
Args:
|
|
568
|
+
model (LLMWrapperBase): the model to use to compute the log-probs.
|
|
569
|
+
|
|
570
|
+
Keyword Args:
|
|
571
|
+
log_probs_full_key (NestedKey): the key where the log-probs are stored.
|
|
572
|
+
If not provided, the key will be retrieved from the model's `log_probs_key` attribute
|
|
573
|
+
(i.e., `(model.log_probs_key, "full")`).
|
|
574
|
+
assistant_only (bool): whether to zero out the log-probs of the non-assistant tokens (i.e., steps of history
|
|
575
|
+
where the role is not `"assistant"`). Defaults to `True`.
|
|
576
|
+
|
|
577
|
+
.. note:: When `assistant_only=True`, the model must have `input_mode='history'` to properly identify
|
|
578
|
+
assistant tokens. For other input modes (`"text"` or `"tokens"`), set `assistant_only=False`.
|
|
579
|
+
This ensures users are conscious of the limitation that assistant token identification requires
|
|
580
|
+
structured conversation history.
|
|
581
|
+
|
|
582
|
+
tokenizer_kwargs (dict): the keyword arguments to pass to the tokenizer to be used to apply the chat template to the history when `assistant_only` is `True`.
|
|
583
|
+
To control the tokenization in the ref_model, pass the tokenizer kwargs to the ref_model constructor.
|
|
584
|
+
Defaults to `{"return_assistant_tokens_mask": True, "tokenize": True, "return_dict": True, "padding": False, "add_generation_prompt": False}`.
|
|
585
|
+
tokenizer (transformers.AutoTokenizer): the tokenizer to be used to tokenize the input and compute the assitant mask. If not provided, the tokenizer will be inferred from the `ref_model`.
|
|
586
|
+
detach (bool): whether to exclude the log-probs from the gradient computation. Defaults to `True`.
|
|
587
|
+
device (torch.device): the device to use for tensor creation. Defaults to `None`.
|
|
588
|
+
padding_side (str): the side of the padding when using pad_sequence. Defaults to `"left"`.
|
|
589
|
+
|
|
590
|
+
Examples:
|
|
591
|
+
>>> from torchrl.data.llm import History
|
|
592
|
+
>>> from torchrl.modules.llm import TransformersWrapper
|
|
593
|
+
>>> from torchrl.modules.llm.policies import ChatHistory
|
|
594
|
+
>>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
|
|
595
|
+
>>> from tensordict import TensorDict, set_list_to_stack
|
|
596
|
+
>>> import torch
|
|
597
|
+
>>>
|
|
598
|
+
>>> # Set up list to stack for History
|
|
599
|
+
>>> set_list_to_stack(True).set()
|
|
600
|
+
>>>
|
|
601
|
+
>>> # Create chat data
|
|
602
|
+
>>> chats = [
|
|
603
|
+
... [
|
|
604
|
+
... {"role": "system", "content": "You are a helpful assistant."},
|
|
605
|
+
... {"role": "user", "content": "Hello, how are you?"},
|
|
606
|
+
... {"role": "assistant", "content": "I'm doing well, thank you!"},
|
|
607
|
+
... ],
|
|
608
|
+
... [
|
|
609
|
+
... {"role": "system", "content": "You are a helpful assistant."},
|
|
610
|
+
... {"role": "user", "content": "What's the weather like?"},
|
|
611
|
+
... {"role": "assistant", "content": "I can't check the weather for you."},
|
|
612
|
+
... ],
|
|
613
|
+
... ]
|
|
614
|
+
>>> history = History.from_chats(chats)
|
|
615
|
+
>>> print(f"Created history with shape: {history.shape}")
|
|
616
|
+
Created history with shape: torch.Size([2, 3])
|
|
617
|
+
>>>
|
|
618
|
+
>>> # Setup tokenizer and model
|
|
619
|
+
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
|
|
620
|
+
>>> tokenizer.pad_token = tokenizer.eos_token
|
|
621
|
+
>>> model = OPTForCausalLM(OPTConfig()).eval()
|
|
622
|
+
>>>
|
|
623
|
+
>>> # Create reference model
|
|
624
|
+
>>> ref_model = TransformersWrapper(
|
|
625
|
+
... model,
|
|
626
|
+
... tokenizer=tokenizer,
|
|
627
|
+
... input_mode="history",
|
|
628
|
+
... generate=False,
|
|
629
|
+
... return_log_probs=True,
|
|
630
|
+
... pad_output=True,
|
|
631
|
+
... )
|
|
632
|
+
>>>
|
|
633
|
+
>>> # Create the RetrieveLogProb transform
|
|
634
|
+
>>> transform = RetrieveLogProb(
|
|
635
|
+
... ref_model,
|
|
636
|
+
... assistant_only=True,
|
|
637
|
+
... tokenizer=tokenizer,
|
|
638
|
+
... )
|
|
639
|
+
>>>
|
|
640
|
+
>>> # Prepare data using ChatHistory
|
|
641
|
+
>>> chat_history = ChatHistory(full=history)
|
|
642
|
+
>>> data = TensorDict(history=chat_history, batch_size=(2,))
|
|
643
|
+
>>>
|
|
644
|
+
>>> # Apply the transform to get reference log probabilities
|
|
645
|
+
>>> result = transform(data)
|
|
646
|
+
>>> log_probs_key = (ref_model.log_probs_key, "full")
|
|
647
|
+
>>> ref_log_probs = result.get(log_probs_key)
|
|
648
|
+
>>> print(f"Log-probs shape: {ref_log_probs.shape}")
|
|
649
|
+
Log-probs shape: torch.Size([2, 26])
|
|
650
|
+
|
|
651
|
+
.. note::
|
|
652
|
+
By default, the log-probabilities are stored as a list of tensors (one per sample, with variable length).
|
|
653
|
+
Use `as_padded_tensor=True` in `.get()` to obtain a batchable tensor (with padding).
|
|
654
|
+
The reference log probabilities are computed only for assistant tokens when `assistant_only=True`.
|
|
655
|
+
|
|
656
|
+
**Input Mode Compatibility:**
|
|
657
|
+
- When `assistant_only=True` (default), the model must have `input_mode='history'` to properly identify assistant tokens.
|
|
658
|
+
- When `assistant_only=False`, the transform works with any input mode (`"history"`, `"text"`, or `"tokens"`).
|
|
659
|
+
- This design ensures users are conscious of the limitation that assistant token identification requires structured conversation history.
|
|
660
|
+
|
|
661
|
+
.. seealso::
|
|
662
|
+
:class:`~torchrl.envs.llm.transforms.kl.RetrieveKL`: A higher-level transform that combines two `RetrieveLogProb` instances with `KLComputation`.
|
|
663
|
+
:class:`~torchrl.envs.llm.transforms.kl.KLComputation`: A transform that computes KL divergence between two log-prob tensors.
|
|
664
|
+
:class:`~torchrl.envs.llm.transforms.kl.KLRewardTransform`: A legacy transform for KL reward computation (use `RetrieveKL` instead).
|
|
665
|
+
"""
|
|
666
|
+
|
|
667
|
+
def __init__(
|
|
668
|
+
self,
|
|
669
|
+
model: LLMWrapperBase,
|
|
670
|
+
*,
|
|
671
|
+
log_probs_full_key: NestedKey | None = None,
|
|
672
|
+
assistant_only: bool = True,
|
|
673
|
+
tokenizer_kwargs: dict | None = None,
|
|
674
|
+
detach: bool = True,
|
|
675
|
+
device: torch.device | None = None,
|
|
676
|
+
tokenizer: transformers.AutoTokenizer | None = None,
|
|
677
|
+
padding_side: str = "left",
|
|
678
|
+
):
|
|
679
|
+
# Set up keys
|
|
680
|
+
if log_probs_full_key is None:
|
|
681
|
+
log_probs_full_key = (model.log_probs_key, "full")
|
|
682
|
+
elif (
|
|
683
|
+
not isinstance(log_probs_full_key, tuple)
|
|
684
|
+
or log_probs_full_key[-1] != "full"
|
|
685
|
+
):
|
|
686
|
+
warnings.warn(
|
|
687
|
+
f"The log_probs_full_key {log_probs_full_key} is not a tuple or does not end with 'full'. "
|
|
688
|
+
"This may cause issues with the KL computation. "
|
|
689
|
+
"Please use a tuple with the log_probs_key and 'full' as the last element."
|
|
690
|
+
)
|
|
691
|
+
self.log_probs_full_key = log_probs_full_key
|
|
692
|
+
|
|
693
|
+
# Set up input/output keys
|
|
694
|
+
in_keys = list(model.in_keys)
|
|
695
|
+
out_keys = [self.log_probs_full_key]
|
|
696
|
+
super().__init__(in_keys=in_keys, out_keys=out_keys)
|
|
697
|
+
|
|
698
|
+
# Store model and configuration
|
|
699
|
+
self.model = model
|
|
700
|
+
self.assistant_only = assistant_only
|
|
701
|
+
self.detach = detach
|
|
702
|
+
self.device = device
|
|
703
|
+
self.tokenizer = tokenizer
|
|
704
|
+
self.padding_side = padding_side
|
|
705
|
+
|
|
706
|
+
# Set up tokenizer kwargs
|
|
707
|
+
if tokenizer_kwargs is None:
|
|
708
|
+
tokenizer_kwargs = {}
|
|
709
|
+
tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
|
|
710
|
+
tokenizer_kwargs.setdefault("tokenize", True)
|
|
711
|
+
tokenizer_kwargs.setdefault("return_dict", True)
|
|
712
|
+
tokenizer_kwargs.setdefault("padding", False)
|
|
713
|
+
tokenizer_kwargs.setdefault("add_generation_prompt", False)
|
|
714
|
+
self.tokenizer_kwargs = tokenizer_kwargs
|
|
715
|
+
|
|
716
|
+
# Validate model configuration (after setting assistant_only)
|
|
717
|
+
self._validate_model_config(model)
|
|
718
|
+
|
|
719
|
+
    def _validate_model_config(self, model: LLMWrapperBase):
        """Validate model configuration."""
        if not getattr(model, "return_log_probs", True):
            raise ValueError(
                "The model must have `return_log_probs=True` to use the `RetrieveLogProb` transform."
            )
        if getattr(model, "generate", True):
            raise ValueError(
                "The model must have `generate=False` to use the `RetrieveLogProb` transform."
            )

        # Check input mode compatibility with assistant_only
        input_mode = getattr(model, "input_mode", "history")
        if self.assistant_only and input_mode != "history":
            raise ValueError(
                f"The model must have `input_mode='history'` when `assistant_only=True`. "
                f"Current input_mode is '{input_mode}'. "
                f"To use input_mode '{input_mode}', set `assistant_only=False`."
            )

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        next_td = tensordict.get("next")
        next_is_none = False
        if next_td is None:
            next_is_none = True
            next_td = tensordict
        output = self._step(tensordict, next_td)
        if next_is_none:
            return output
        return tensordict.set("next", output)

    def _mask_assistant_tokens(
        self, td: TensorDictBase, lp_key: NestedKey
    ) -> torch.Tensor:
        """Mask log-probs to only include assistant tokens.

        Args:
            td: TensorDict containing the data.
            lp_key: key of the log-probs in the TensorDict.

        Returns:
            The masked log-probs tensor.
        """
        with torch.device(self.device) if self.device is not None else nullcontext():
            # Get assistant mask
            assistant_masks = td.get(("masks", "all_assistant_mask"), as_list=True)  # type: ignore[misc]
            log_probs = td.get(lp_key, as_list=True)  # type: ignore[misc]
            log_probs = [
                torch.masked_fill(lp, ~mask, 0.0)
                for lp, mask in _zip_strict(log_probs, assistant_masks)
            ]
            if self.model.pad_output:
                log_probs = pad_sequence(
                    log_probs,
                    batch_first=True,
                    padding_value=0.0,
                    padding_side=self.padding_side,
                )
            else:
                log_probs = torch.nested.as_nested_tensor(
                    log_probs, layout=self.model.layout
                )
            return log_probs

    @set_list_to_stack(True)
    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        # Compute log-probs using the model.
        # Use tensordict since we want to process the "full" entry.
        ref_td = self.model(tensordict.copy())
        tmp_log_probs_key = (self.model.log_probs_key, "full")

        # Apply assistant masking if requested
        if self.assistant_only:
            log_probs = self._mask_assistant_tokens(ref_td, tmp_log_probs_key)
            ref_td.set(tmp_log_probs_key, log_probs)

        # Rename and store the log-probs
        if tmp_log_probs_key != self.log_probs_full_key:
            ref_td.rename_key_(tmp_log_probs_key, self.log_probs_full_key)
        next_tensordict.update(ref_td, keys_to_update=(self.log_probs_full_key,))

        return next_tensordict

    def transform_observation_spec(self, observation_spec: Composite) -> Composite:
        # Add kl to observation spec
        observation_spec["kl_penalty"] = Unbounded(
            device=observation_spec.device,
            shape=observation_spec.shape,
        )
        return observation_spec

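
# Illustrative sketch (not part of the library API): a minimal, self-contained example of the
# assistant-token masking performed in `_mask_assistant_tokens` above, on hypothetical toy
# per-sample log-probs and boolean assistant masks. The helper is never called by the library.
def _example_assistant_token_masking():
    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Two samples of different lengths: log-probs and their assistant-token masks.
    log_probs = [
        torch.tensor([-0.5, -1.2, -0.3]),
        torch.tensor([-0.7, -0.1, -2.0, -0.4]),
    ]
    masks = [
        torch.tensor([False, True, True]),
        torch.tensor([False, False, True, True]),
    ]
    # Zero out the log-probs of non-assistant tokens, as the transform does.
    masked = [torch.masked_fill(lp, ~m, 0.0) for lp, m in zip(log_probs, masks)]
    # masked[0] -> approx [0.0, -1.2, -0.3]; masked[1] -> approx [0.0, 0.0, -2.0, -0.4]
    # With `pad_output=True` the transform pads to a rectangular tensor; with
    # `pad_output=False` it builds a nested tensor instead. Only the padded case is shown.
    return pad_sequence(masked, batch_first=True, padding_value=0.0)

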
class RayRetrieveKL(RayTransform):
    """A Ray-based implementation of :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL`.

    This class creates a Ray remote actor from RetrieveKL that can be shared across multiple workers.
    All method calls are delegated to the remote actor, ensuring that multiple environments can
    share the same KL computation resources.

    To avoid serialization issues with large models, this class supports model factories
    that create models on the remote actor rather than passing full models through Ray channels.

    Args:
        gen_model (LLMWrapperBase | Literal["from_collector"]): the generation model, or "from_collector" for lazy initialization.
            Prefer using a model factory instead to avoid serialization issues.
        ref_model (LLMWrapperBase | None): the reference model. Prefer using a model factory instead
            to avoid serialization issues.

    Keyword Args:
        gen_model_factory (Callable[[], LLMWrapperBase], optional): A callable that returns a generation model.
            This allows for explicit resource control and avoids serialization issues.
        ref_model_factory (Callable[[], LLMWrapperBase], optional): A callable that returns a reference model.
            This allows for explicit resource control and avoids serialization issues.
        num_cpus (int, optional): Number of CPUs to allocate to the Ray actor. Defaults to 1.
        num_gpus (int, optional): Number of GPUs to allocate to the Ray actor. Defaults to 0.
        device (torch.device, optional): Device to use on the remote Ray actor for tensor operations.
            The local Ray transform will handle CPU serialization and device restoration automatically.
            Defaults to None.
        actor_name (str, optional): Name of the Ray actor to use. If provided, the actor will be reused if it already exists.
        **kwargs: Additional keyword arguments to pass to RetrieveKL.

    Note:
        When using model factories, the corresponding model arguments (gen_model, ref_model) should be None.
        Model factories are preferred for large models to avoid serialization overhead.

    Examples:
        >>> # Option 1: Using model factories for explicit resource control
        >>> def create_gen_model():
        ...     return TransformersWrapper(model, tokenizer=tokenizer, generate=False, return_log_probs=True)
        >>> def create_ref_model():
        ...     return TransformersWrapper(ref_model, tokenizer=tokenizer, generate=False, return_log_probs=True)
        >>> transform = RayRetrieveKL(
        ...     gen_model=None, ref_model=None,
        ...     gen_model_factory=create_gen_model,
        ...     ref_model_factory=create_ref_model,
        ...     num_gpus=1,
        ...     device=torch.device("cuda")
        ... )

        >>> # Option 2: Pass models directly (Ray handles serialization)
        >>> transform = RayRetrieveKL(gen_model=gen_model, ref_model=ref_model, device=torch.device("cuda"))
    """

    def __init__(
        self,
        gen_model: LLMWrapperBase | Literal["from_collector"] | None = "from_collector",
        ref_model: LLMWrapperBase | None = None,
        *,
        gen_model_factory: Callable[[], LLMWrapperBase] | None = None,
        ref_model_factory: Callable[[], LLMWrapperBase] | None = None,
        num_cpus: int | None = None,
        num_gpus: int = 0,
        device: DEVICE_TYPING | None = None,
        actor_name: str | None = None,
        **kwargs,
    ):
        # Validate arguments: models and factories should not both be provided
        if gen_model is not None and gen_model_factory is not None:
            raise ValueError(
                "Cannot provide both 'gen_model' and 'gen_model_factory'. Choose one."
            )
        if ref_model is not None and ref_model_factory is not None:
            raise ValueError(
                "Cannot provide both 'ref_model' and 'ref_model_factory'. Choose one."
            )

        # Store creation parameters for actor creation
        self._gen_model = gen_model
        self._ref_model = ref_model
        self._gen_model_factory = gen_model_factory
        self._ref_model_factory = ref_model_factory
        self._creation_kwargs = kwargs
        # Store device separately for passing to remote actor
        self._remote_device = device

        # Default num_cpus
        if num_cpus is None:
            num_cpus = 1

        # Call parent constructor without device (Ray transform handles CPU/device mapping)
        super().__init__(
            num_cpus=num_cpus,
            num_gpus=num_gpus,
            device=None,  # Don't store device locally
            actor_name=actor_name,
            **kwargs,
        )

    def _create_actor(self, **kwargs):
        """Create the remote RetrieveKL actor."""
        # Create the remote RetrieveKL with resource specifications
        RemoteRetrieveKL = self._ray.remote(
            num_cpus=self._num_cpus, num_gpus=self._num_gpus
        )(RetrieveKL)

        if self._actor_name is not None:
            RemoteRetrieveKL = RemoteRetrieveKL.options(name=self._actor_name)

        # Determine how to create models on the remote actor
        gen_model_arg = self._gen_model
        ref_model_arg = self._ref_model

        # If we have factories, we'll pass them and set models to None
        creation_kwargs = self._creation_kwargs.copy()
        if self._gen_model_factory is not None:
            creation_kwargs["gen_model_factory"] = self._gen_model_factory
            gen_model_arg = None
        if self._ref_model_factory is not None:
            creation_kwargs["ref_model_factory"] = self._ref_model_factory
            ref_model_arg = None

        # Pass device to the remote actor
        if self._remote_device is not None:
            creation_kwargs["device"] = self._remote_device

        # Create the shared actor
        actor = RemoteRetrieveKL.remote(
            gen_model=gen_model_arg, ref_model=ref_model_arg, **creation_kwargs
        )

        return actor

    def __repr__(self):
        """String representation."""
        try:
            if hasattr(self, "_actor") and self._actor is not None:
                return self._ray.get(self._actor.__repr__.remote())
            else:
                return "RayRetrieveKL(actor=None)"
        except Exception:
            return f"RayRetrieveKL(actor={getattr(self, '_actor', 'None')})"

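
# Illustrative sketch (not part of the library API): how the factory arguments of
# `RayRetrieveKL` above can be used so that the models are built on the remote actor instead
# of being serialized through Ray, and how a named actor lets several environments share the
# same KL resources. `make_wrapper` and the `actor_name` value are hypothetical; the helper
# is never called by the library and only mirrors the docstring examples.
def _example_ray_retrieve_kl_factories():
    from torchrl.modules.llm import TransformersWrapper
    from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM

    def make_wrapper(log_probs_key):
        # Build the (potentially large) wrapper lazily, on the Ray actor.
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
        tokenizer.pad_token = tokenizer.eos_token
        model = OPTForCausalLM(OPTConfig()).eval()
        return TransformersWrapper(
            model,
            tokenizer=tokenizer,
            input_mode="history",
            generate=False,
            return_log_probs=True,
            pad_output=True,
            log_probs_key=log_probs_key,
        )

    # Distinct log_probs_key values keep the generation and reference outputs separate.
    return RayRetrieveKL(
        gen_model=None,
        ref_model=None,
        gen_model_factory=lambda: make_wrapper("gen_log_probs"),
        ref_model_factory=lambda: make_wrapper("ref_log_probs"),
        actor_name="shared_retrieve_kl",  # reuse the same actor across environments
    )

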
class RetrieveKL(Compose, metaclass=_RayServiceMetaClass):
    """A transform to retrieve the KL divergence between two models' log-probabilities.

    This transform combines two :class:`~torchrl.envs.llm.transforms.kl.RetrieveLogProb` instances
    with a :class:`~torchrl.envs.llm.transforms.kl.KLComputation` to compute KL divergence
    between a generation model and a reference model.

    .. note::
        Both gen_model and ref_model must use the same pad_output value (True or False), otherwise KL computation will fail.

    Args:
        gen_model (LLMWrapperBase): the generation model, wrapped in such a way that it does not generate but computes the log-probs.
            In cases where the transform is used within a :class:`~torchrl.collectors.llm.LLMCollector` run on a remote worker, the
            policy may not be available ahead of time. In this case, the `gen_model` can be set to `"from_collector"` (default) to retrieve the
            policy from the collector. See :meth:`~torchrl.modules.llm.policies.LLMWrapperBase.get_new_version` for more details
            about generating a new version of the policy to gather the log-probs.
        ref_model (LLMWrapperBase): the reference model, wrapped in such a way that it does not generate but computes the log-probs.

    Keyword Args:
        gen_model_factory (Callable[[], LLMWrapperBase], optional): A callable that returns a generation model.
            This allows for explicit resource control and avoids serialization issues when using Ray.
        ref_model_factory (Callable[[], LLMWrapperBase], optional): A callable that returns a reference model.
            This allows for explicit resource control and avoids serialization issues when using Ray.
        assistant_only (bool): whether to only retrieve the log-probs of the assistant tokens (i.e., steps of history
            where the role is `"assistant"`). Defaults to `True`.

            .. note:: When `assistant_only=True`, both models must have `input_mode='history'` to properly identify assistant tokens.
                For other input modes (`"text"` or `"tokens"`), set `assistant_only=False`.
                This ensures users are conscious of the limitation that assistant token identification requires structured conversation history.

        gen_log_probs_full_key (str): the key where the log-probs of the generation model are stored. Defaults to `("log_probs", "full")`.
        ref_log_probs_full_key (str): the key where the log-probs of the reference model are stored. Defaults to `("ref_log_probs", "full")`.
        history_key (str): the key where the history is stored. Defaults to `"history"`.
        tokenizer_kwargs (dict): the keyword arguments to pass to the tokenizer to be used to apply the chat template to the history when `assistant_only` is `True`.
            To control the tokenization in the actor, pass the tokenizer kwargs to the actor constructor.
            Defaults to `{"return_assistant_tokens_mask": True, "tokenize": True, "return_tensors": "pt", "padding": True, "add_generation_prompt": False}`.
        detach (bool): whether to exclude the log-probs from the gradient computation. Defaults to `True`.
        device (torch.device): the device to cast the tensors to. This is not the device of the specs, but the device
            onto which the tensors will be moved. It allows the model to be kept on a different device
            than the upcoming data itself. When using the Ray service, this device will be used on the remote actor.
            Defaults to `None`.
        tokenizer (transformers.AutoTokenizer): the tokenizer to be used to tokenize the input and compute the assistant mask. If not provided, the tokenizer will be inferred from the `actor`.
        padding_side (str): the side of the padding when using pad_sequence. Defaults to `"left"`.
        kl_key (NestedKey): the key where the KL divergence is stored. Defaults to `"kl_penalty"`.
        add_to_reward (bool): whether to add the KL divergence to the reward. Defaults to `True`.
        coeff (float): the coefficient for the KL term when adding to reward. Defaults to `1.0`.
        use_ray_service (bool, optional): if ``True``, returns a :class:`RayRetrieveKL` instance instead,
            which creates a Ray actor for shared KL computation across multiple environments.
            Defaults to ``False``.
        actor_name (str, optional): the name of the Ray actor to use. Defaults to `None`.
        **kwargs: additional arguments to pass to the `RetrieveLogProb` transform.

    Examples:
        >>> from torchrl.data.llm import History
        >>> from torchrl.modules.llm import TransformersWrapper
        >>> from torchrl.modules.llm.policies import ChatHistory
        >>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
        >>> from tensordict import TensorDict, set_list_to_stack
        >>> import torch
        >>>
        >>> # Set up list to stack for History
        >>> set_list_to_stack(True).set()
        >>>
        >>> # Create chat data
        >>> chats = [
        ...     [
        ...         {"role": "system", "content": "You are a helpful assistant."},
        ...         {"role": "user", "content": "Hello, how are you?"},
        ...         {"role": "assistant", "content": "I'm doing well, thank you!"},
        ...     ],
        ...     [
        ...         {"role": "system", "content": "You are a helpful assistant."},
        ...         {"role": "user", "content": "What's the weather like?"},
        ...         {"role": "assistant", "content": "I can't check the weather for you."},
        ...     ],
        ... ]
        >>> history = History.from_chats(chats)
        >>> print(f"Created history with shape: {history.shape}")
        Created history with shape: torch.Size([2, 3])
        >>>
        >>> # Setup tokenizer and model
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
        >>> tokenizer.pad_token = tokenizer.eos_token
        >>> model = OPTForCausalLM(OPTConfig()).eval()
        >>>
        >>> # Create generation and reference models
        >>> gen_model = TransformersWrapper(
        ...     model,
        ...     tokenizer=tokenizer,
        ...     input_mode="history",
        ...     generate=False,
        ...     return_log_probs=True,
        ...     pad_output=True,
        ...     log_probs_key="gen_log_probs",
        ... )
        >>> ref_model = TransformersWrapper(
        ...     model,
        ...     tokenizer=tokenizer,
        ...     input_mode="history",
        ...     generate=False,
        ...     return_log_probs=True,
        ...     pad_output=True,
        ...     log_probs_key="ref_log_probs",
        ... )
        >>>
        >>> # Create RetrieveKL transform
        >>> transform = RetrieveKL(
        ...     gen_model=gen_model,
        ...     ref_model=ref_model,
        ...     assistant_only=True,
        ...     tokenizer=tokenizer,
        ... )
        >>>
        >>> # Prepare data with next tensordict using ChatHistory
        >>> chat_history = ChatHistory(full=history)
        >>> next_td = TensorDict(history=chat_history, batch_size=(2,))
        >>> data = TensorDict(history=chat_history, next=next_td, batch_size=(2,))
        >>>
        >>> # Apply transform
        >>> result = transform(data)
        >>> kl = result["next"].get("kl_penalty")
        >>> print(f"KL shape: {kl.shape}")
        KL shape: torch.Size([2, 26])

    Note:
        **Input Mode Compatibility:**

        - When `assistant_only=True`, both models must have `input_mode='history'` to properly identify assistant tokens.
        - When `assistant_only=False`, the transform works with any input mode (`"history"`, `"text"`, or `"tokens"`).
        - This design ensures users are conscious of the limitation that assistant token identification requires structured conversation history.

    .. seealso::
        :class:`~torchrl.envs.llm.transforms.kl.RetrieveLogProb`: The base transform for retrieving log-probabilities from a single model.
        :class:`~torchrl.envs.llm.transforms.kl.KLComputation`: The transform that computes KL divergence between two log-prob tensors.
        :class:`~torchrl.envs.llm.transforms.kl.KLRewardTransform`: A legacy transform for KL reward computation (use `RetrieveKL` instead).
    """

    _RayServiceClass = RayRetrieveKL

    def __init__(
        self,
        gen_model: LLMWrapperBase | Literal["from_collector"] = "from_collector",
        ref_model: LLMWrapperBase | None = None,
        *,
        gen_model_factory: Callable[[], LLMWrapperBase] | None = None,
        ref_model_factory: Callable[[], LLMWrapperBase] | None = None,
        assistant_only: bool = True,
        history_key: str = "history",
        tokenizer_kwargs: dict[str, Any] | None = None,
        detach: bool = True,
        device: torch.device | None = None,
        tokenizer: transformers.AutoTokenizer | None = None,
        padding_side: str = "left",
        gen_log_probs_full_key: NestedKey = ("log_probs", "full"),
        ref_log_probs_full_key: NestedKey = ("ref_log_probs", "full"),
        kl_key: NestedKey = "kl_penalty",
        add_to_reward: bool = True,
        coeff: float = 1.0,
        use_ray_service: bool = False,
        **kwargs,
    ):
        # Handle model factories - create models if factories are provided
        if gen_model_factory is not None:
            if gen_model is not None and gen_model != "from_collector":
                raise ValueError(
                    "Cannot provide both 'gen_model' and 'gen_model_factory'. Choose one."
                )
            gen_model = gen_model_factory()

        if ref_model_factory is not None:
            if ref_model is not None:
                raise ValueError(
                    "Cannot provide both 'ref_model' and 'ref_model_factory'. Choose one."
                )
            ref_model = ref_model_factory()

        if isinstance(gen_model, str) and gen_model == "from_collector":
            # Lazy init
            self._initialized = False
            self._init_params = {
                "ref_model": ref_model,
                "gen_model_factory": gen_model_factory,
                "ref_model_factory": ref_model_factory,
                "assistant_only": assistant_only,
                "history_key": history_key,
                "tokenizer_kwargs": tokenizer_kwargs,
                "detach": detach,
                "device": device,
                "tokenizer": tokenizer,
                "gen_log_probs_full_key": gen_log_probs_full_key,
                "ref_log_probs_full_key": ref_log_probs_full_key,
                "kl_key": kl_key,
                "add_to_reward": add_to_reward,
                "coeff": coeff,
                "padding_side": padding_side,
                **kwargs,
            }
            super().__init__()
            return

        self._initialized = True

        # Check pad_output consistency if both models are provided
        if hasattr(gen_model, "pad_output") and hasattr(ref_model, "pad_output"):
            if gen_model.pad_output != ref_model.pad_output:
                raise ValueError(
                    f"pad_output mismatch: gen_model.pad_output={gen_model.pad_output}, "
                    f"ref_model.pad_output={ref_model.pad_output}. "
                    "Both models must use the same padding strategy for KL computation."
                )

        if not getattr(gen_model, "return_log_probs", True):
            raise ValueError(
                "The generation model must have `return_log_probs=True` to use the `RetrieveKL` transform."
            )
        elif getattr(gen_model, "generate", False):
            raise ValueError(
                "The generation model must have `generate=False` to use the `RetrieveKL` transform."
            )

        if not getattr(ref_model, "return_log_probs", True):
            raise ValueError(
                "The reference model must have `return_log_probs=True` to use the `RetrieveKL` transform."
            )
        elif getattr(ref_model, "generate", False):
            raise ValueError(
                "The reference model must have `generate=False` to use the `RetrieveKL` transform."
            )
        if getattr(gen_model, "log_probs_key", "gen_log_probs") == getattr(
            ref_model, "log_probs_key", "log_probs"
        ):
            raise ValueError(
                "The generation and reference models must have different `log_probs_key` values to use the `RetrieveKL` transform."
            )
        if gen_model is None:
            raise ValueError("gen_model cannot be None when not using 'from_collector'")
        if ref_model is None:
            raise ValueError("ref_model cannot be None")

        t1 = RetrieveLogProb(
            gen_model,
            log_probs_full_key=gen_log_probs_full_key,
            assistant_only=assistant_only,
            tokenizer_kwargs=tokenizer_kwargs,
            detach=detach,
            device=device,
            tokenizer=tokenizer,
            padding_side=padding_side,
            **kwargs,
        )
        t2 = RetrieveLogProb(
            ref_model,
            log_probs_full_key=ref_log_probs_full_key,
            assistant_only=assistant_only,
            tokenizer_kwargs=tokenizer_kwargs,
            detach=detach,
            device=device,
            tokenizer=tokenizer,
            padding_side=padding_side,
            **kwargs,
        )
        t3 = KLComputation(
            gen_log_probs_full_key=gen_log_probs_full_key,
            ref_log_probs_full_key=ref_log_probs_full_key,
            kl_key=kl_key,
            add_to_reward=add_to_reward,
            coeff=coeff,
        )
        super().__init__(t1, t2, t3)

    def _init_deferred(self):
        torchrl_logger.info("Initializing RetrieveKL transform")
        container = self.container
        if container is None:
            # Also log, since this will sometimes be hidden within the AttributeError
            torchrl_logger.warning(
                "The container is not set. Please set the container before calling this method."
            )
            raise ValueError(
                "The container is not set. Please set the container before calling this method."
            )
        container.empty_cache()
        self.empty_cache()
        collector = self.collector
        if collector is None:
            # Also log, since this will sometimes be hidden within the AttributeError
            torchrl_logger.warning(
                "The collector is not set. Please set the collector before calling this method."
            )
            raise ValueError(
                "The collector is not set. Please set the collector before calling this method."
            )
        ref_model = self._init_params["ref_model"]
        pad_output = getattr(ref_model, "pad_output", None)
        gen_log_probs_full_key = self._init_params["gen_log_probs_full_key"]
        if (
            not isinstance(gen_log_probs_full_key, tuple)
            or gen_log_probs_full_key[-1] != "full"
        ):
            raise ValueError(
                f"The gen_log_probs_full_key {gen_log_probs_full_key} is not a tuple or does not end with 'full'. "
                "This may cause issues with the KL computation. "
                "Please use a tuple with the log_probs_key and 'full' as the last element."
            )
        log_probs_key = gen_log_probs_full_key[:-1]
        gen_model = collector.policy.get_new_version(
            generate=False,
            return_log_probs=True,
            log_probs_key=log_probs_key,
            input_mode=ref_model.input_mode,
            input_key=(ref_model.input_mode, "full"),
            pad_output=pad_output,  # pass pad_output from the ref_model
        )
        # Create the transforms manually instead of calling __init__
        t1 = RetrieveLogProb(
            gen_model,
            log_probs_full_key=gen_log_probs_full_key,
            assistant_only=self._init_params["assistant_only"],
            tokenizer_kwargs=self._init_params["tokenizer_kwargs"],
            detach=self._init_params["detach"],
            device=self._init_params["device"],
            tokenizer=self._init_params["tokenizer"],
            padding_side=self._init_params["padding_side"],
        )
        ref_log_probs_full_key = self._init_params["ref_log_probs_full_key"]
        if (
            not isinstance(ref_log_probs_full_key, tuple)
            or ref_log_probs_full_key[-1] != "full"
        ):
            raise ValueError(
                f"The ref_log_probs_full_key {ref_log_probs_full_key} is not a tuple or does not end with 'full'. "
                "This may cause issues with the KL computation. "
                "Please use a tuple with the log_probs_key and 'full' as the last element."
            )
        t2 = RetrieveLogProb(
            ref_model,
            log_probs_full_key=ref_log_probs_full_key,
            assistant_only=self._init_params["assistant_only"],
            tokenizer_kwargs=self._init_params["tokenizer_kwargs"],
            detach=self._init_params["detach"],
            device=self._init_params["device"],
            tokenizer=self._init_params["tokenizer"],
            padding_side=self._init_params["padding_side"],
        )
        t3 = KLComputation(
            gen_log_probs_full_key=gen_log_probs_full_key,
            ref_log_probs_full_key=ref_log_probs_full_key,
            kl_key=self._init_params["kl_key"],
            add_to_reward=self._init_params["add_to_reward"],
            coeff=self._init_params["coeff"],
        )
        # Populate the (still empty) Compose with the transforms
        self.transforms.extend([t1, t2, t3])
        del self._init_params
        self._initialized = True
        torchrl_logger.info("Successfully initialized")

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        if not self._initialized:
            self._init_deferred()
        return super()._step(tensordict, next_tensordict)

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        if not self._initialized:
            self._init_deferred()
        return super()._reset(tensordict, tensordict_reset)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        if not self._initialized:
            self._init_deferred()
        return super().forward(tensordict)

    def transform_observation_spec(self, observation_spec: Composite) -> Composite:
        if not self._initialized:
            self._init_deferred()
        return super().transform_observation_spec(observation_spec)

    def transform_reward_spec(self, reward_spec: Composite) -> Composite:
        if not self._initialized:
            self._init_deferred()
        return super().transform_reward_spec(reward_spec)

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        if not self._initialized:
            self._init_deferred()
        return super()._inv_call(tensordict)

    def transform_action_spec(self, action_spec: Composite) -> Composite:
        if not self._initialized:
            self._init_deferred()
        return super().transform_action_spec(action_spec)

    def transform_input_spec(self, input_spec: Composite) -> Composite:
        if not self._initialized:
            self._init_deferred()
        return super().transform_input_spec(input_spec)

    def transform_output_spec(self, output_spec: Composite) -> Composite:
        if not self._initialized:
            self._init_deferred()
        return super().transform_output_spec(output_spec)

    def transform_state_spec(self, state_spec: Composite) -> Composite:
        if not self._initialized:
            self._init_deferred()
        return super().transform_state_spec(state_spec)

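
# Illustrative sketch (not part of the library API): `RetrieveKL` above is essentially a
# `Compose` of two `RetrieveLogProb` transforms and a `KLComputation`, so an equivalent
# pipeline can be assembled by hand when finer control over the keys is needed. `gen_model`
# and `ref_model` are assumed to be `LLMWrapperBase` instances configured with
# `generate=False` and `return_log_probs=True`; the helper is never called by the library.
def _example_manual_kl_pipeline(gen_model, ref_model):
    return Compose(
        RetrieveLogProb(gen_model, log_probs_full_key=("log_probs", "full")),
        RetrieveLogProb(ref_model, log_probs_full_key=("ref_log_probs", "full")),
        KLComputation(
            gen_log_probs_full_key=("log_probs", "full"),
            ref_log_probs_full_key=("ref_log_probs", "full"),
            kl_key="kl_penalty",
            add_to_reward=True,
            coeff=1.0,
        ),
    )

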
class KLComputation(Transform):
    """A transform to compute the KL divergence between two log-prob tensors and optionally add it to the reward.

    This transform computes the KL divergence between generation and reference log-probabilities
    and can optionally subtract it from the reward (as a KL penalty). It is designed to work
    with the :class:`~torchrl.envs.llm.transforms.kl.RetrieveLogProb` and :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL` transforms.

    .. note::
        Both input log-prob tensors must use the same padding strategy (pad_output) for correct KL computation.

    Args:
        gen_log_probs_full_key (NestedKey): the key where the generation model log-probs are stored.
            Defaults to `("log_probs", "full")`.
        ref_log_probs_full_key (NestedKey): the key where the reference model log-probs are stored.
            Defaults to `("ref_log_probs", "full")`.

    Keyword Args:
        kl_key (NestedKey): the key where the KL divergence is stored. Defaults to `"kl_penalty"`.
        add_to_reward (bool): whether to add the KL divergence to the reward. Defaults to `True`.
        coeff (float): the coefficient for the KL term when adding to reward. Defaults to `1.0`.
        padding_side (str): the side of the padding when using pad_sequence. Defaults to `"left"`.

    Examples:
        >>> from tensordict import TensorDict
        >>> import torch
        >>>
        >>> # Create sample log-probs
        >>> gen_log_probs = torch.randn(2, 10)  # 2 samples, 10 tokens each
        >>> ref_log_probs = torch.randn(2, 10)
        >>>
        >>> # Create data with next tensordict
        >>> next_td = TensorDict(
        ...     {
        ...         ("gen_log_probs", "full"): gen_log_probs,
        ...         ("ref_log_probs", "full"): ref_log_probs,
        ...         "reward": torch.randn(2, 10, 1),
        ...     },
        ...     batch_size=(2,)
        ... )
        >>> data = TensorDict(next=next_td, batch_size=(2,))
        >>>
        >>> # Create KLComputation transform
        >>> kl_transform = KLComputation(
        ...     gen_log_probs_full_key=("gen_log_probs", "full"),
        ...     ref_log_probs_full_key=("ref_log_probs", "full"),
        ...     kl_key="kl_penalty",
        ...     add_to_reward=True,
        ...     coeff=1.0,
        ... )
        >>>
        >>> # Apply transform
        >>> result = kl_transform(data)
        >>> kl = result["next"].get("kl_penalty")
        >>> print(f"KL shape: {kl.shape}")
        KL shape: torch.Size([2, 10])

    .. seealso::
        :class:`~torchrl.envs.llm.transforms.kl.RetrieveLogProb`: The base transform for retrieving log-probabilities from a single model.
        :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL`: A higher-level transform that combines two `RetrieveLogProb` instances with `KLComputation`.
        :class:`~torchrl.envs.llm.transforms.kl.KLRewardTransform`: A legacy transform for KL reward computation (use `RetrieveKL` instead).
    """

    def __init__(
        self,
        gen_log_probs_full_key: NestedKey = ("log_probs", "full"),
        ref_log_probs_full_key: NestedKey = ("ref_log_probs", "full"),
        *,
        kl_key: NestedKey = "kl_penalty",
        add_to_reward: bool = True,
        coeff: float = 1.0,
        padding_side: str = "left",
    ):
        in_keys = [gen_log_probs_full_key, ref_log_probs_full_key]
        if add_to_reward:
            in_keys.append("reward")
        out_keys = [kl_key]
        if add_to_reward:
            out_keys.append("reward")
        super().__init__(in_keys=in_keys, out_keys=out_keys)

        self.gen_log_probs_full_key = gen_log_probs_full_key
        self.ref_log_probs_full_key = ref_log_probs_full_key
        self.kl_key = kl_key
        self.add_to_reward = add_to_reward
        self.coeff = coeff
        self.padding_side = padding_side

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        next_td = tensordict.get("next")
        has_next_td = True
        if next_td is None:
            next_td = tensordict
            has_next_td = False
        next_td = self._step(tensordict, next_td)
        if has_next_td:
            return tensordict.set("next", next_td)
        return next_td

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        # Get log-probs
        gen_log_probs = next_tensordict.get(self.gen_log_probs_full_key, as_list=True)  # type: ignore[misc]
        ref_log_probs = next_tensordict.get(self.ref_log_probs_full_key, as_list=True)  # type: ignore[misc]

        if gen_log_probs is None or ref_log_probs is None:
            raise ValueError(
                f"Log-probs not found. Expected keys: {self.gen_log_probs_full_key}, {self.ref_log_probs_full_key}"
            )

        # Check batch sizes
        if len(gen_log_probs) != len(ref_log_probs):
            raise ValueError(
                f"Batch size mismatch: gen_log_probs has {len(gen_log_probs)} samples, ref_log_probs has {len(ref_log_probs)} samples"
            )

        # Check individual sequence lengths
        for i, (gen_lp, ref_lp) in enumerate(_zip_strict(gen_log_probs, ref_log_probs)):
            if gen_lp.shape != ref_lp.shape:
                raise ValueError(
                    f"Sample {i} has different shapes: gen_log_probs[{i}].shape={gen_lp.shape}, ref_log_probs[{i}].shape={ref_lp.shape}"
                )

        # Compute KL divergence: KL(p||q) = E_p[log p - log q]
        # Here gen_log_probs = log p, ref_log_probs = log q
        kl = [
            gen_lp - ref_lp
            for gen_lp, ref_lp in _zip_strict(gen_log_probs, ref_log_probs)
        ]

        kl = torch.nested.as_nested_tensor(kl, layout=torch.strided)

        next_tensordict.set(self.kl_key, kl)

        # Add to reward if requested
        if self.add_to_reward:
            reward = next_tensordict.get("reward", as_list=True)  # type: ignore[misc]
            if reward is not None:
                if isinstance(reward, list):
                    if reward[0].ndim != kl[0].ndim + 1:
                        raise ValueError(
                            f"The rewards have shape {reward[0].shape} but the kl has shape {kl[0].shape}. "
                            "The rewards should have one more dimension than the KL."
                        )
                    reward = [
                        r - self.coeff * k.unsqueeze(-1)
                        for r, k in _zip_strict(reward, kl)
                    ]
                    next_tensordict.set(
                        "reward",
                        torch.nested.as_nested_tensor(reward, layout=torch.strided),
                    )
                else:
                    if reward.ndim != kl.ndim + 1:
                        raise ValueError(
                            f"The rewards have shape {reward.shape} but the kl has shape {kl.shape}. "
                            "The rewards should have one more dimension than the KL."
                        )
                    reward = reward - self.coeff * kl.unsqueeze(-1)
                    next_tensordict.set("reward", reward)

        return next_tensordict

    def transform_observation_spec(self, observation_spec: Composite) -> Composite:
        # Add kl to observation spec
        observation_spec[self.kl_key] = Unbounded(
            device=observation_spec.device,
            shape=observation_spec.shape,
        )
        return observation_spec

    def transform_reward_spec(self, reward_spec: Composite) -> Composite:
        # Adjust the reward spec when the KL penalty is folded into the reward
        if self.add_to_reward:
            shape = reward_spec["reward"].shape
            # For LLMs, the shape of the reward is (batch, -1, 1)
            shape = torch.Size((*shape, -1, 1))
            reward_spec["reward"] = reward_spec["reward"].clone()
            reward_spec["reward"].shape = shape
        return reward_spec