torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/llm/prompt.py ADDED
@@ -0,0 +1,198 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import torch
+from tensordict import tensorclass, TensorDict
+
+from torchrl.data.llm.dataset import TensorDictTokenizer, TokenizedDatasetLoader
+
+DEFAULT_DATASET = "CarperAI/openai_summarize_tldr"
+
+
+@tensorclass
+class PromptData:
+    """A prompt dataset."""
+
+    input_ids: torch.Tensor
+    attention_mask: torch.Tensor
+    prompt_rindex: torch.Tensor
+    labels: torch.Tensor | None = None
+    logits: torch.Tensor | None = None
+    loss: torch.Tensor | None = None
+
+    def mask_label(self, pad_token_id=50256):
+        _, block_size = self.input_ids.shape
+        attention_mask = (
+            torch.arange(block_size, device=self.prompt_rindex.device)
+            < self.prompt_rindex[:, None]
+        ).to(torch.int64)
+        input_ids = torch.where(attention_mask == 1, self.input_ids, pad_token_id)
+        return self.__class__(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            prompt_rindex=self.prompt_rindex,
+            loss=self.loss,
+            batch_size=[],
+        )
+
+    @classmethod
+    def from_dataset(
+        cls,
+        split,
+        dataset_name=None,
+        max_length=550,
+        root_dir=None,
+        from_disk=False,
+        num_workers: int | None = None,
+    ):
+        """Returns a :class:`PromptData` from a dataset name.
+
+        Args:
+            split (str): ``"train"`` or ``"valid"`` depending on the data split needed.
+            dataset_name (str, optional): name of the dataset to be processed. Defaults to
+                ``"CarperAI/openai_summarize_tldr"``.
+            max_length (int, optional): maximum length of the dataset sequences.
+                Defaults to 550.
+            root_dir (path, optional): the path where the datasets are stored.
+                Defaults to ``"$HOME/.cache/torchrl/data"``.
+            from_disk (bool, optional): if ``True``, :func:`datasets.load_from_disk`
+                will be used. Otherwise, :func:`datasets.load_dataset` will be used.
+                Defaults to ``False``.
+            num_workers (int, optional): number of workers for :meth:`datasets.dataset.map`
+                which is called during tokenization.
+                Defaults to ``max(os.cpu_count() // 2, 1)``.
+
+        Returns: a :class:`PromptData` instance containing a memory-mapped
+            version of the required dataset.
+
+        Examples:
+            >>> data = PromptData.from_dataset("train")
+            >>> print(data)
+            PromptDataTLDR(
+                attention_mask=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                input_ids=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                prompt_rindex=MemoryMappedTensor(shape=torch.Size([116722]), device=cpu, dtype=torch.int64, is_shared=False),
+                labels=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                logits=None,
+                loss=None,
+                batch_size=torch.Size([116722]),
+                device=None,
+                is_shared=False)
+            >>> # data can be sampled from using regular indexing
+            >>> sub_data = data[:3]
+
+        """
+        dataset_name = dataset_name if dataset_name is not None else DEFAULT_DATASET
+        loader = TokenizedDatasetLoader(
+            split,
+            max_length,
+            dataset_name,
+            PromptTensorDictTokenizer,
+            root_dir=root_dir,
+            from_disk=from_disk,
+            num_workers=num_workers,
+        )
+        data = loader.load()
+        return cls(**data, labels=data["input_ids"], batch_size=data.shape)
+
+
+class PromptTensorDictTokenizer(TensorDictTokenizer):
+    """Tokenization recipe for prompt datasets.
+
+    Returns a tokenizer function, which reads an example containing a prompt
+    and a label and tokenizes them.
+
+    Args:
+        tokenizer (tokenizer from transformers library): the tokenizer to use.
+        max_length (int): maximum length of the sequence.
+        key (str, optional): the key where to find the text. Defaults to ``"prompt"``.
+        padding (str, optional): type of padding. Defaults to ``"max_length"``.
+        truncation (bool, optional): whether the sequences should be truncated to max_length.
+        return_tensordict (bool, optional): if ``True``, a TensorDict is returned.
+            Otherwise, the original data will be returned.
+        device (torch.device, optional): the device where to store the data.
+            This option is ignored if ``return_tensordict=False``.
+
+    The :meth:`__call__` method of this class will execute the following operations:
+
+    - Read the ``prompt`` string concatenated with the ``label`` string and tokenize
+      them. The results will be stored in the ``"input_ids"`` TensorDict entry.
+    - Write a ``"prompt_rindex"`` entry with the index of the last valid
+      token from the prompt.
+    - Write a ``"valid_sample"`` which identifies which entry in the
+      tensordict has enough tokens to meet the ``max_length`` criterion.
+    - Return a :class:`tensordict.TensorDict` instance with tokenized inputs.
+
+    The tensordict batch-size will match the batch-size of the input.
+
+    Examples:
+        >>> from transformers import AutoTokenizer
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> tokenizer.pad_token = tokenizer.eos_token
+        >>> example = {
+        ...     "prompt": ["This prompt is long enough to be tokenized.", "this one too!"],
+        ...     "label": ["Indeed it is.", 'It might as well be.'],
+        ... }
+        >>> fn = PromptTensorDictTokenizer(tokenizer, 50)
+        >>> print(fn(example))
+        TensorDict(
+            fields={
+                attention_mask: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False),
+                input_ids: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False),
+                prompt_rindex: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
+                valid_sample: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.bool, is_shared=False)},
+            batch_size=torch.Size([2]),
+            device=None,
+            is_shared=False)
+
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        max_length,
+        key="prompt",
+        padding="max_length",
+        truncation=True,
+        return_tensordict=True,
+        device=None,
+    ):
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.key = key
+        self.padding = padding
+        self.truncation = truncation
+        self.return_tensordict = return_tensordict
+        self.device = device
+
+    def __call__(self, sample):
+        tokenizer = self.tokenizer
+        max_length = self.max_length
+
+        tokenized_prompts = tokenizer(
+            sample[self.key], max_length=max_length, truncation=True
+        )
+        prompt_rindex = [len(prompt) - 1 for prompt in tokenized_prompts["input_ids"]]
+        tokenized_example = tokenizer(
+            [
+                prompt + label
+                for prompt, label in zip(sample[self.key], sample["label"])
+            ],
+            max_length=max_length,
+            padding=self.padding,
+            truncation=self.truncation,
+        )
+        tokenized_example["prompt_rindex"] = prompt_rindex
+        # drop any examples whose total length when tokenized exceeds block size
+        # with recommended block size of 550, this is only ~0.1% of available examples.
+        # NOTE: to mark as discarded we just save the mask as we cannot change the shape here
+        tokenized_example["valid_sample"] = [True] * len(tokenized_example["input_ids"])
+        for i, input_ids in enumerate(tokenized_example["input_ids"]):
+            if input_ids[-1] != tokenizer.eos_token_id:
+                tokenized_example["valid_sample"][i] = False
+        if self.return_tensordict:
+            return TensorDict.from_dict(dict(tokenized_example), device=self.device)
+        return tokenized_example
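For illustration, a minimal sketch of the prompt masking implemented by `PromptData.mask_label` above; it is not part of the wheel contents and the tensor values are made up. Tokens at positions at or beyond `prompt_rindex` are replaced by the pad token, leaving prompt-only inputs.

import torch

from torchrl.data.llm.prompt import PromptData

# Two toy sequences of 10 tokens; the prompts end at indices 4 and 6.
input_ids = torch.arange(1, 21).reshape(2, 10)
data = PromptData(
    input_ids=input_ids,
    attention_mask=torch.ones_like(input_ids),
    prompt_rindex=torch.tensor([4, 6]),
    batch_size=[2],
)
masked = data.mask_label(pad_token_id=0)
# Everything past the prompt r-index has been padded out.
assert (masked.input_ids[0, 4:] == 0).all()
assert (masked.input_ids[1, 6:] == 0).all()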
torchrl/data/llm/reward.py ADDED
@@ -0,0 +1,225 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import importlib
+
+import torch
+from tensordict import tensorclass
+from torchrl.data.llm.dataset import TensorDictTokenizer, TokenizedDatasetLoader
+
+DEFAULT_DATASET = "CarperAI/openai_summarize_comparisons"
+_has_datasets = importlib.util.find_spec("datasets") is not None
+_has_tqdm = importlib.util.find_spec("tqdm") is not None
+
+
+@tensorclass
+class RewardData:
+    """A dataclass for reward model training."""
+
+    input_ids: torch.Tensor
+    attention_mask: torch.Tensor
+    rewards: torch.Tensor | None = None
+    end_scores: torch.Tensor | None = None
+
+
+@tensorclass
+class PairwiseDataset:
+    """Represents a dataset in a pairwise manner (chosen vs rejected).
+
+    Attributes:
+        chosen_data: data to be chosen.
+        rejected_data: corresponding data to be rejected.
+
+    Examples:
+        >>> data = PairwiseDataset.from_dataset("train", max_length=550)
+        >>> print(data)
+        PairwiseDataset(
+            chosen_data=RewardData(
+                attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                rewards=None,
+                end_scores=None,
+                batch_size=torch.Size([92534]),
+                device=None,
+                is_shared=False),
+            rejected_data=RewardData(
+                attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                rewards=None,
+                end_scores=None,
+                batch_size=torch.Size([92534]),
+                device=None,
+                is_shared=False),
+            batch_size=torch.Size([92534]),
+            device=None,
+            is_shared=False)
+
+    """
+
+    chosen_data: RewardData
+    rejected_data: RewardData
+
+    @classmethod
+    def from_dataset(
+        cls,
+        split,
+        dataset_name: str | None = None,
+        max_length: int = 550,
+        root_dir: str | None = None,
+        from_disk: bool = False,
+        num_workers: int | None = None,
+    ):
+        """Returns a :class:`PairwiseDataset` from a dataset name.
+
+        Args:
+            split (str): ``"train"`` or ``"valid"`` depending on the data split needed.
+            dataset_name (str, optional): name of the dataset to be processed. Defaults to
+                ``"CarperAI/openai_summarize_comparisons"``.
+            max_length (int, optional): maximum length of the dataset sequences.
+                Defaults to 550.
+            root_dir (path, optional): the path where the datasets are stored.
+                Defaults to ``"$HOME/.cache/torchrl/data"``
+            from_disk (bool, optional): if ``True``, :func:`datasets.load_from_disk`
+                will be used. Otherwise, :func:`datasets.load_dataset` will be used.
+                Defaults to ``False``.
+
+        Returns: a :class:`PairwiseDataset` instance containing a memory-mapped
+            version of the required dataset.
+
+        Examples:
+            >>> data = PairwiseDataset.from_dataset("train")
+            >>> print(data)
+            PairwiseDataset(
+                chosen_data=RewardData(
+                    attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                    input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                    rewards=None,
+                    end_scores=None,
+                    batch_size=torch.Size([92534]),
+                    device=None,
+                    is_shared=False),
+                rejected_data=RewardData(
+                    attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                    input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                    rewards=None,
+                    end_scores=None,
+                    batch_size=torch.Size([92534]),
+                    device=None,
+                    is_shared=False),
+                batch_size=torch.Size([92534]),
+                device=None,
+                is_shared=False)
+            >>> # data can be sampled from using regular indexing
+            >>> sub_data = data[:3]
+
+        """
+        if dataset_name is None:
+            dataset_name = DEFAULT_DATASET
+        loader = TokenizedDatasetLoader(
+            split,
+            max_length,
+            dataset_name,
+            TensorDictTokenizer,
+            pre_tokenization_hook,
+            root_dir=root_dir,
+            from_disk=from_disk,
+            num_workers=num_workers,
+        )
+        data = loader.load()
+        maxidx = data.shape[0] // 2
+        batch_size = [maxidx]
+        # this is a zero-copy creation, as we index memmap-arrays without
+        # creating new storage.
+        chosen_data = data[:maxidx]
+        rejected_data = data[maxidx:]
+        return cls(
+            chosen_data=RewardData(
+                **chosen_data,
+                batch_size=batch_size,
+            ),
+            rejected_data=RewardData(
+                **rejected_data,
+                batch_size=batch_size,
+            ),
+            batch_size=batch_size,
+        )
+
+
+def pre_tokenization_hook(dataset, min_length=5):
+    """Pre-tokenizer for the reward model (comparison) dataset.
+
+    This function selects all samples where the length of the prompt is
+    sufficient and where the ``"chosen"`` and ``"rejected"`` entries differ.
+
+    Args:
+        dataset (datasets.Dataset): the dataset to process. Should have entries
+            ``"prompt"``, ``"chosen"`` and ``"rejected"``.
+        min_length (int, optional): minimum length of a prompt (in word count).
+
+    Returns: a new ``datasets.Dataset`` with selected prompts under ``"text"``.
+        The first half are the chosen strings and the second the rejected ones,
+        always preceded by the original prompt.
+
+    Examples:
+        >>> from datasets import Dataset
+        >>> data = Dataset.from_dict({
+        ...     "prompt": ["I'm the king"],
+        ...     "chosen": ["It is true, you are the king"],
+        ...     "rejected": ["No, I am the king, you are not"]})
+        >>> print(pre_tokenization_hook(data))
+        Dataset({
+            features: ['text'],
+            num_rows: 2
+        })
+        >>> data = Dataset.from_dict({
+        ...     "prompt": ["I'm the king"],
+        ...     "chosen": ["It is true, you are the king"],
+        ...     "rejected": ["It is true, you are the king"]}) # chosen and rejected match
+        >>> print(pre_tokenization_hook(data))
+        Dataset({
+            features: [],
+            num_rows: 0
+        })
+        >>> data = Dataset.from_dict({
+        ...     "prompt": ["I'm the king"],
+        ...     "chosen": ["Yes"],
+        ...     "rejected": ["No"]}) # chosen and rejected are too short
+        >>> print(pre_tokenization_hook(data))
+        Dataset({
+            features: [],
+            num_rows: 0
+        })
+
+    """
+    if not _has_datasets:
+        raise ImportError(
+            "datasets module couldn't be found. Make sure it is installed in your environment."
+        )
+    from datasets import Dataset as HFDataset
+
+    chosen = []
+    rejected = []
+    if _has_tqdm:
+        from tqdm import tqdm
+
+        pbar = tqdm(dataset)
+    else:
+        pbar = dataset
+    for sample in pbar:
+        prompt = sample["prompt"]
+        chosen_summary = sample["chosen"]
+        rejected_summary = sample["rejected"]
+        if chosen_summary == rejected_summary:
+            continue
+        if (
+            len(chosen_summary.split()) < min_length
+            or len(rejected_summary.split()) < min_length
+        ):
+            continue
+        chosen.append({"text": prompt + "\n" + chosen_summary})
+        rejected.append({"text": prompt + "\n" + rejected_summary})
+
+    return HFDataset.from_list(chosen + rejected)
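A minimal sketch (not part of the wheel contents) of the zero-copy split performed in `PairwiseDataset.from_dataset` above: the loader output stacks the tokenized `prompt + chosen` strings in the first half and the matching `prompt + rejected` strings in the second half, so slicing at the midpoint pairs row i of `chosen_data` with row i of `rejected_data`. A toy TensorDict stands in for `loader.load()`.

import torch
from tensordict import TensorDict

# Stand-in for loader.load(): 6 rows, the first 3 "chosen", the last 3 "rejected".
data = TensorDict(
    {
        "input_ids": torch.arange(12).reshape(6, 2),
        "attention_mask": torch.ones(6, 2, dtype=torch.int64),
    },
    batch_size=[6],
)
maxidx = data.shape[0] // 2
chosen_data, rejected_data = data[:maxidx], data[maxidx:]
# Slicing returns views over the same storage; nothing is copied.
assert chosen_data["input_ids"].data_ptr() == data["input_ids"].data_ptr()
assert chosen_data["input_ids"].shape == rejected_data["input_ids"].shape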
torchrl/data/llm/topk.py ADDED
@@ -0,0 +1,186 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from collections import defaultdict, deque
+from typing import Any
+
+import torch
+from tensordict import NestedKey, TensorDictBase
+from torchrl._utils import logger as torchrl_logger
+from torchrl.envs.transforms import Transform
+
+
+class TopKRewardSelector(Transform):
+    """A replay-buffer transform that selects the top-k rewards for each prompt.
+
+    Args:
+        total_dialog_turns (int): Number of dialog turns to keep in memory for the top-k selection.
+        topk_size (int): Number of top-k rewards to select. Must be smaller than or equal to total_dialog_turns.
+        prompt_key (NestedKey): Key to the prompt in the tensordict. Defaults to ("text", "prompt").
+        rewards_key (NestedKey): Key to the rewards in the tensordict. Defaults to ("next", "reward").
+        done_key (NestedKey): Key to the done state in the tensordict. Defaults to ("next", "done").
+        verbose (bool): Whether to print verbose information. Defaults to `True`.
+
+    Example:
+        >>> from torchrl.data import ReplayBuffer, LazyStackStorage, SamplerWithoutReplacement
+        >>> from tensordict import TensorDict, lazy_stack
+        >>> import torch
+        >>> from torchrl.data.llm.topk import TopKRewardSelector
+        >>> # Create a replay buffer with 50 items, a sampler that samples without replacement, and a batch size of 5
+        >>> rb = ReplayBuffer(
+        ...     storage=LazyStackStorage(50),
+        ...     sampler=SamplerWithoutReplacement,
+        ...     batch_size=5,
+        ... )
+        >>> # Create a tensordict with 50 items, each with 10 dialog turns
+        >>> td = lazy_stack(
+        ...     [
+        ...         TensorDict(
+        ...             {
+        ...                 ("next", "done"): torch.full((1, 1), True),
+        ...                 # Reward for i+5 tokens
+        ...                 ("next", "reward"): torch.full((i + 5, 1), i),
+        ...                 # total of 10 dialogs per prompt
+        ...                 "text": f"Prompt {i // 5}",
+        ...             }
+        ...         )
+        ...         for i in range(50)
+        ...     ]
+        ... )
+        >>> # Create a top-k reward selector with 5 dialog turns and a top-k size of 3
+        >>> topk = TopKRewardSelector(total_dialog_turns=5, topk_size=3)
+        >>> rb.append_transform(topk)
+        >>> for _td in td.chunk(25):
+        ...     rb.extend(_td)
+        >>> # Only wrote top3 of 50 items in 10 groups of 5
+        >>> assert rb.write_count == 30
+        >>> assert len(rb) == 30
+        >>> r3 = rb[:3].get(("next", "reward"), as_padded_tensor=True).squeeze()
+        >>> # 0 and 1 are missing because they're not part of the top-k
+        >>> assert (
+        ...     r3 == torch.tensor(
+        ...         [
+        ...             [4, 4, 4, 4, 4, 4, 4, 4, 4],
+        ...             [3, 3, 3, 3, 3, 3, 3, 3, 0],
+        ...             [2, 2, 2, 2, 2, 2, 2, 0, 0],
+        ...         ]
+        ...     )
+        ... ).all()
+    """
+
+    def __init__(
+        self,
+        total_dialog_turns: int,
+        topk_size: int,
+        prompt_key: NestedKey = ("text", "prompt"),
+        rewards_key: NestedKey = ("next", "reward"),
+        done_key: NestedKey = ("next", "done"),
+        verbose: bool = True,
+    ):
+        super().__init__()
+        self.in_keys = [prompt_key, rewards_key, done_key]
+        self.prompt_key = prompt_key
+        self.rewards_key = rewards_key
+        self.done_key = done_key
+        self.queues = defaultdict(lambda: deque(maxlen=total_dialog_turns))
+        self.total_dialog_turns = total_dialog_turns
+        self.topk_size = topk_size
+        if topk_size > total_dialog_turns:
+            raise ValueError(
+                f"topk_size must be smaller than or equal to total_dialog_turns, got {topk_size=} and {total_dialog_turns=}"
+            )
+        self.verbose = verbose
+
+    def forward(self, tensordict: TensorDictBase) -> Any:
+        return tensordict
+
+    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        # Tensordict can be any number of dims, but it must contain entire trajectories
+        if tensordict.ndim == 1:
+            # Check how many done states we have
+            num_done = tensordict[self.done_key].sum()
+            if num_done > 1:
+                done_idx = tensordict[self.done_key].nonzero(as_tuple=True)[0] + 1
+                splits = torch.cat([done_idx.new_zeros((1,)), done_idx], dim=0).diff()
+                tensordicts = tensordict.split(splits)
+                tensordicts = [self._inv_call(td) for td in tensordicts]
+                tensordicts = [td for td in tensordicts if td is not None]
+                return torch.cat(tensordicts) if tensordicts else None
+            # Then we have a single trajectory
+            if not tensordict[-1][self.done_key].all():
+                raise RuntimeError("Expected the trajectory to be done.")
+            prompt = tensordict[0][self.prompt_key]
+            if not isinstance(prompt, str):
+                raise TypeError(f"Expected a string as prompt, got {type(prompt)=}")
+            self.queues[prompt].append(tensordict)
+            if len(self.queues[prompt]) == self.total_dialog_turns:
+                if self.verbose:
+                    torchrl_logger.info(f"Getting top-k rewards for {prompt=}")
+                # Cat is the most robust way to combine the trajs
+                tds = torch.cat(list(self.queues[prompt]), -1)
+                # Collect rewards
+                reward = tds.get(self.rewards_key, as_nested_tensor=True)
+                reward = self._aggregate_rewards(reward)
+                # Check if all rewards are equal
+                if (reward == reward[0]).all():
+                    # If all rewards are equal, we can't select top-k
+                    if self.verbose:
+                        torchrl_logger.warning(
+                            f"All rewards are equal ({reward.unique()=})"
+                        )
+                    return
+                # Filter out rewards below median
+                median_reward = reward.median(dim=-1, keepdim=True)[0]
+                mask = reward > median_reward
+                filtered_reward = reward[mask]
+                filtered_indices = mask.nonzero(as_tuple=True)[0]
+                # Get top-k from filtered rewards
+                topk_reward = filtered_reward.topk(
+                    k=min(self.topk_size, len(filtered_indices)), dim=-1
+                )
+                if not topk_reward.indices.numel():
+                    if self.verbose:
+                        torchrl_logger.warning(
+                            f"No top-{self.topk_size} rewards found ({reward=})"
+                        )
+                    return
+                # Map back to original indices
+                selected_indices = filtered_indices[topk_reward.indices]
+                tds = tds[selected_indices]
+                if self.verbose:
+                    torchrl_logger.info(
+                        f"Selected top-{self.topk_size} rewards, with reward {topk_reward.values=}"
+                    )
+                return tds
+            return
+        elif tensordict.ndim > 2:
+            # keep the time dim at the end
+            tensordict = tensordict.flatten(0, -2)
+        trajs = tensordict.unbind(-1)
+        # Iterate over the trajectories
+        result = []
+        for traj in trajs:
+            td_out = self._inv_call(traj)
+            if td_out is None:
+                continue
+            result.append(td_out)
+        if result:
+            return torch.cat(result, -1)
+        return
+
+    def _aggregate_rewards(self, reward: torch.Tensor) -> torch.Tensor:
+        """Aggregate the rewards across the dialog turns.
+
+        `reward` is expected to be a nested tensor.
+
+        The default implementation is to take the mean of the rewards across the dialog turns.
+        """
+        # reward = reward.to_padded_tensor(padding=0.0)
+        if reward.ndim < 2 or reward.ndim > 3:
+            raise ValueError(
+                f"Expected reward to be a 2D or 3D tensor, got {reward.ndim}D tensor"
+            )
+        return reward.mean(dim=-2).squeeze(-1)
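The selection rule implemented in `_inv_call` above, reproduced as a standalone sketch with made-up per-trajectory rewards (not part of the wheel contents): rewards at or below the median are dropped, the top-k of the remainder are kept, and the surviving indices are mapped back to the original trajectories.

import torch

reward = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])  # one aggregated reward per trajectory
topk_size = 3
median_reward = reward.median(dim=-1, keepdim=True)[0]  # tensor([2.])
mask = reward > median_reward                            # keeps trajectories 3 and 4
filtered_reward = reward[mask]
filtered_indices = mask.nonzero(as_tuple=True)[0]
topk_reward = filtered_reward.topk(k=min(topk_size, len(filtered_indices)), dim=-1)
selected_indices = filtered_indices[topk_reward.indices]  # map back to original indices
print(selected_indices.tolist())  # [4, 3]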