torchrl 0.11.0 (cp314-cp314-win_amd64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
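
The largest addition in this release is `torchrl/data/llm/history.py` (+1378 lines), reproduced in the diff below. As a quick orientation, here is a minimal, hedged sketch of how its `History` chat-template API is meant to be used, assembled from the docstrings in that file. It is not part of the wheel, and the exact model checkpoint is only illustrative.

```python
# Sketch only: exercises the History API added by torchrl/data/llm/history.py,
# following the docstrings reproduced in the diff below.
import tensordict
from transformers import AutoTokenizer
from torchrl.data.llm.history import History

# History.__post_init__ requires tensordict's list-to-stack mode.
tensordict.set_list_to_stack(True).set()

history = History.from_chats(
    [
        [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
    ]
)

# A Qwen checkpoint selects the bundled "qwen" template, which marks
# assistant messages with {% generation %} blocks for token masking.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
result = history.apply_chat_template(
    tokenizer=tokenizer,
    add_generation_prompt=False,
    return_dict=True,
    return_assistant_tokens_mask=True,
)
print(result["assistant_masks"].sum())  # count of assistant-generated tokens
```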
@@ -0,0 +1,1378 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import dataclasses
+
+import re
+from typing import Literal, TYPE_CHECKING
+
+import torch
+
+from tensordict import (
+    lazy_stack,
+    LazyStackedTensorDict,
+    list_to_stack,
+    TensorClass,
+    TensorDict,
+)
+from tensordict.utils import _maybe_correct_neg_dim
+from torchrl._utils import logger as torchrl_logger
+
+if TYPE_CHECKING:
+    import transformers
+
+
+# Global storage for custom templates and their metadata
+_CHAT_TEMPLATES = {
+    "chatml_format": """{% for message in messages %}
+{%- if message['role'] == 'assistant' %}
+{% generation %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endgeneration %}
+{%- else %}
+{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
+{%- endif %}
+{% endfor %}
+{%- if add_generation_prompt %}
+{% generation %}{{- '<|im_start|>assistant\n' }}{% endgeneration %}
+{%- endif %}
+""",
+    "qwen": """
+{%- if tools %}
+{{- '<|im_start|>system\\n' }}
+{%- if messages[0]['role'] == 'system' %}
+{{- messages[0]['content'] }}
+{%- else %}
+{{- 'You are a helpful assistant.' }}
+{%- endif %}
+{{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}
+{%- for tool in tools %}
+{{- "\\n" }}
+{{- tool | tojson }}
+{%- endfor %}
+{{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}
+{%- else %}
+{%- if messages[0]['role'] == 'system' %}
+{{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}
+{%- else %}
+{{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}
+{%- endif %}
+{%- endif %}
+{%- for message in messages %}
+{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+{{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}
+{%- elif (message.role == "assistant" and not message.tool_calls) %}
+{% generation %} {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }} {% endgeneration %}
+{%- elif message.role == "assistant" %}
+{% generation %}{{- '<|im_start|>' + message.role }}
+{%- if message.content %}
+{{- '\\n' + message.content }}
+{%- endif %}
+{%- for tool_call in message.tool_calls %}
+{%- if tool_call.function is defined %}
+{%- set tool_call = tool_call.function %}
+{%- endif %}
+{{- '\\n<tool_call>\\n{\\\"name\\\": \\\"' }}
+{{- tool_call.name }}
+{{- '\\\", \\\"arguments\\\": ' }}
+{{- tool_call.arguments | tojson }}
+{{- '}\\n</tool_call>' }}
+{%- endfor %}
+{{- '<|im_end|>\\n' }}{% endgeneration %}
+{%- elif message.role == "tool" %}
+{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
+{{- '<|im_start|>tool' }}
+{%- endif %}
+{{- '\\n<tool_response>\\n' }}
+{%- if message.tool_responses %}
+{{- message.tool_responses }}
+{%- else %}
+{{- message.content }}
+{%- endif %}
+{{- '\\n</tool_response>' }}
+{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+{{- '<|im_end|>\\n' }}
+{%- endif %}
+{%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+{% generation %}{{- '<|im_start|>assistant\\n' }}{% endgeneration %}
+{%- endif %}
+""",
+    "dialogpt": """{% for message in messages %}{% if message['role'] == 'assistant' %}{% generation %}{{ message['content'] }}{% endgeneration %}{{ eos_token }}{% elif message['role'] == 'user' %}{{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{% generation %}{{ ' ' }}{% endgeneration %}{% endif %}""",
+    "falcon": """{% for message in messages %}{% if message['role'] == 'assistant' %}{% generation %}{{ 'Assistant: ' + message['content'] }}{% endgeneration %}\n\n{% elif message['role'] == 'user' %}{{ 'User: ' + message['content'] }}\n\n{% elif message['role'] == 'system' %}{{ message['content'] }}\n\n{% endif %}{% endfor %}{% if add_generation_prompt %}{% generation %}{{ 'Assistant: ' }}{% endgeneration %}{% endif %}""",
+    "deepseek": """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{% generation %}{{ 'Assistant: ' + message['content'] + eos_token }}{% endgeneration %}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{% generation %}{{ 'Assistant:' }}{% endgeneration %}{% endif %}""",
+    "llama": """{{- bos_token }}
+{%- if messages[0]['role'] == 'system' %}
+{%- set system_message = messages[0]['content']|trim %}
+{%- set messages = messages[1:] %}
+{%- else %}
+{%- set system_message = "" %}
+{%- endif %}
+{%- if system_message %}
+{{- "<|header_start|>system<|header_end|>\n\n" }}
+{{- system_message }}
+{{- "<|eot|>" }}
+{%- endif %}
+{%- for message in messages %}
+{%- if message['role'] == 'assistant' %}
+{% generation %}{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
+{%- if message['content'] is string %}
+{{- message['content'] }}
+{%- else %}
+{%- for content in message['content'] %}
+{%- if content['type'] == 'text' %}
+{{- content['text'] | trim }}
+{%- endif %}
+{%- endfor %}
+{%- endif %}
+{{- "<|eot|>" }}{% endgeneration %}
+{%- else %}
+{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
+{%- if message['content'] is string %}
+{{- message['content'] }}
+{%- else %}
+{%- for content in message['content'] %}
+{%- if content['type'] == 'text' %}
+{{- content['text'] | trim }}
+{%- endif %}
+{%- endfor %}
+{%- endif %}
+{{- "<|eot|>" }}
+{%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+{% generation %}{{- '<|header_start|>assistant<|header_end|>\n\n' }}{% endgeneration %}
+{%- endif %}""",
+}
+
+# Global storage for custom template metadata
+_CUSTOM_INVERSE_PARSERS = {}
+_CUSTOM_MODEL_FAMILY_KEYWORDS = {}
+
+
+def add_chat_template(
+    template_name: str,
+    template: str,
+    inverse_parser: callable | None = None,
+    model_family_keywords: list[str] | None = None,
+) -> None:
+    r"""Add a custom chat template to the global template dictionary.
+
+    This function allows you to add custom chat templates for new model families
+    that support assistant token masking via the `{% generation %}` keyword.
+
+    Args:
+        template_name (str): The name of the template (e.g., "llama", "mistral").
+            This name will be used in the `chat_template_name` parameter of
+            `History.apply_chat_template()` and `History.from_text()`.
+        template (str): The Jinja2 template string. Must include `{% generation %}`
+            blocks around assistant message content to enable token masking.
+        inverse_parser (callable, optional): A function that parses formatted text back
+            into a History object. Should have signature `(text: str) -> History`.
+            If None, a basic parser will be used.
+        model_family_keywords (list[str], optional): Keywords to detect this model family
+            in the auto-detection logic. For example, ["llama", "meta-llama"] for Llama models.
+            If provided, the template will be automatically selected for models containing
+            these keywords in their name.
+
+    Example:
+        >>> from torchrl.data.llm.chat import add_chat_template, History
+        >>> from transformers import AutoTokenizer
+        >>>
+        >>> # Add a custom template for Llama models
+        >>> llama_template = '''
+        ... {% for message in messages %}
+        ... {%- if message['role'] == 'user' %}
+        ... {{ '<s>[INST] ' + message['content'] + ' [/INST]' }}
+        ... {%- elif message['role'] == 'assistant' %}
+        ... {% generation %}{{ message['content'] + '</s>' }}{% endgeneration %}
+        ... {%- endif %}
+        ... {% endfor %}
+        ... {%- if add_generation_prompt %}
+        ... {% generation %}{{ ' ' }}{% endgeneration %}
+        ... {%- endif %}
+        ... '''
+        >>>
+        >>> def parse_llama_text(text: str) -> History:
+        ...     # Custom parser for Llama format
+        ...     import re
+        ...     pattern = r'<s>\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?)</s>'
+        ...     matches = re.findall(pattern, text, re.DOTALL)
+        ...     messages = []
+        ...     for user_content, assistant_content in matches:
+        ...         messages.append(History(role="user", content=user_content.strip()))
+        ...         messages.append(History(role="assistant", content=assistant_content.strip()))
+        ...     return lazy_stack(messages)
+        >>>
+        >>> # Add the template with auto-detection
+        >>> add_chat_template(
+        ...     template_name="llama",
+        ...     template=llama_template,
+        ...     inverse_parser=parse_llama_text,
+        ...     model_family_keywords=["llama", "meta-llama"]
+        ... )
+        >>>
+        >>> # Now you can use it with auto-detection
+        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+        >>> history = History.from_chats([[
+        ...     {"role": "user", "content": "Hello"},
+        ...     {"role": "assistant", "content": "Hi there!"}
+        ... ]])
+        >>>
+        >>> # Auto-detection will use the llama template
+        >>> result = history.apply_chat_template(
+        ...     tokenizer=tokenizer,
+        ...     add_generation_prompt=False,
+        ...     return_dict=True,
+        ...     return_assistant_tokens_mask=True,
+        ... )
+        >>>
+        >>> # Or use it explicitly
+        >>> result = history.apply_chat_template(
+        ...     tokenizer=tokenizer,
+        ...     chat_template_name="llama",
+        ...     add_generation_prompt=False,
+        ...     return_dict=True,
+        ...     return_assistant_tokens_mask=True,
+        ... )
+
+    .. note:
+        - The template must include `{% generation %}` blocks around assistant message
+          content to enable assistant token masking.
+        - The inverse parser should handle the specific format of your template.
+        - Model family keywords are case-insensitive and matched against the tokenizer's
+          `name_or_path` attribute.
+        - Templates are stored globally and persist for the duration of the Python session.
+    """
+    global _CHAT_TEMPLATES, _CUSTOM_INVERSE_PARSERS, _CUSTOM_MODEL_FAMILY_KEYWORDS  # noqa: F824
+
+    # Validate template contains generation blocks
+    if "{% generation %}" not in template:
+        raise ValueError(
+            f"Template '{template_name}' must include '{{% generation %}}' blocks "
+            "around assistant message content to enable token masking."
+        )
+
+    # Add template to dictionary
+    _CHAT_TEMPLATES[template_name] = template
+
+    # Store inverse parser if provided
+    if inverse_parser is not None:
+        _CUSTOM_INVERSE_PARSERS[template_name] = inverse_parser
+
+    # Store model family keywords if provided
+    if model_family_keywords is not None:
+        _CUSTOM_MODEL_FAMILY_KEYWORDS[template_name] = model_family_keywords
+
+    torchrl_logger.info(
+        f"Added custom chat template '{template_name}' with assistant token masking support"
+    )
+
+
+# We need the 'shadow' flag to avoid having tensordict complaining about 'type'/'size' etc. fields
+class ContentBase(TensorClass["nocast", "shadow"]):
+    """Base class for all message content types.
+
+    Attributes:
+        type (str): The type of the content.
+        text (str, optional): The text content.
+        url (str, optional): The URL content.
+        data (str, optional): The data content.
+        mime_type (str, optional): The MIME type of the content.
+        name (str, optional): The name of the content.
+        size (int, optional): The size of the content.
+        function_name (str, optional): The name of the function.
+        function_args (dict, optional): The arguments of the function.
+
+    Examples:
+        >>> from tensordict import lazy_stack
+        >>> content1 = ContentBase(type="text", text="Hello, world!")
+        >>> print(content1)
+        ContentBase(
+            text=NonTensorData(data=Hello, world!, batch_size=torch.Size([]), device=None),
+            type=NonTensorData(data=text, batch_size=torch.Size([]), device=None),
+            url=None,
+            data=None,
+            mime_type=None,
+            name=None,
+            size=None,
+            function_name=None,
+            function_args=None,
+            batch_size=torch.Size([]),
+            device=None,
+            is_shared=False)
+        >>> content2 = ContentBase(type="image", url="https://example.com/image.jpg")
+        >>> print(content2)
+        ContentBase(
+            type=NonTensorData(data=image, batch_size=torch.Size([]), device=None),
+            url=NonTensorData(data=https://example.com/image.jpg, batch_size=torch.Size([]), device=None),
+            text=None,
+            data=None,
+            mime_type=None,
+            name=None,
+            size=None,
+            function_name=None,
+            function_args=None,
+            batch_size=torch.Size([]),
+            device=None,
+            is_shared=False)
+        >>> content = lazy_stack([content1, content2])
+        >>> print(content)
+        ContentBase(
+            type=NonTensorStack(
+                ['text', 'image'],
+                batch_size=torch.Size([2]),
+                device=None),
+            url=None,
+            data=None,
+            mime_type=None,
+            name=None,
+            size=None,
+            function_name=None,
+            function_args=None,
+            text=None,
+            batch_size=torch.Size([2]),
+            device=None,
+            is_shared=False)
+        >>> # A content is typically used in a History object. Usually, its batch dimension is
+        >>> # one dimension greater than the History object.
+        >>> history = History(role="user", content=content)
+
+    """
+
+    type: Literal[
+        "text", "image", "audio", "video", "file", "function_call"
+    ]  # Required: "text", "image", "audio", "video", "file", "function_call"
+
+    # Text content
+    text: str | None = None
+
+    # Media/file content (either URL or data)
+    url: str | None = None  # HTTP URL to content
+    data: str | None = None  # Base64 encoded content
+
+    # Metadata
+    mime_type: str | None = None  # "image/jpeg", "audio/mp3", "application/pdf"
+    name: str | None = None  # Original filename or description
+    size: int | None = None  # File size in bytes
+
+    # Function calling (for AI agents)
+    function_name: str | None = None
+    function_args: dict | None = None
+
+
+class History(TensorClass["nocast"]):
+    """A class representing a structured history of messages in a conversation, designed for efficient manipulation and integration with language models.
+
+    The `History` class provides a centralized API for managing conversational data, offering several advantages over
+    traditional list-based approaches:
+
+    - Centralized API for conversion to and from string formats, facilitating seamless integration with language models.
+    - Efficient methods to append, extend, and reshape history elements, enabling dynamic construction of conversation
+      trajectories, especially useful in reinforcement learning environments.
+    - Interoperability with the `transformers` API, allowing for easy tokenization and preparation of input data.
+    - **Assistant token masking support** across multiple model families for reinforcement learning applications.
+
+    **Recent Changes:**
+    - **ChatHistory Integration**: History objects are now used within :class:`~torchrl.modules.llm.policies.ChatHistory`
+      containers for structured conversation management in LLM environments.
+    - **Modular Wrapper Support**: Both vLLMWrapper and TransformersWrapper now use History objects when `input_mode="history"`
+      is specified, providing consistent conversation state management.
+    - **Environment Integration**: ChatEnv and related environments use History objects for state management and conversation tracking.
+
+    .. note:: The `"<none>"` role is used to indicate that the element is a placeholder,
+        for example when the tool call was not executed but a stack requires a certain number of elements
+        per batch to have congruent shapes. The :meth:`~torchrl.data.llm.chat.History.apply_chat_template`
+        method will remove the `<none>` role from the history.
+
+    **Assistant Token Masking Support:**
+
+    The class supports assistant token masking across multiple model families, allowing you to identify which tokens
+    in a conversation were generated by the assistant. This is crucial for reinforcement learning applications.
+
+    **Supported Model Families:**
+
+    - **Qwen family** (e.g., `Qwen/Qwen2.5-0.5B`): Custom template with full tool calling support
+    - **DialoGPT family** (e.g., `microsoft/DialoGPT-medium`): Custom template for conversation format
+    - **Falcon family** (e.g., `tiiuae/falcon-7b-instruct`): Custom template for instruction format
+    - **DeepSeek family** (e.g., `deepseek-ai/deepseek-coder-6.7b-base`): Custom template with native format
+    - **Other models** (OPT, GPT, MPT, BLOOM, Pythia, Phi, etc.): Default `chatml_format` template
+
+    **Example with Assistant Token Masking:**
+
+    .. code-block:: python
+
+        >>> from torchrl.data.llm.chat import History
+        >>> from torchrl.modules.llm.policies import ChatHistory
+        >>> from transformers import AutoTokenizer
+        >>>
+        >>> # Create a conversation history
+        >>> history = History.from_chats([[
+        ...     {"role": "user", "content": "Hello"},
+        ...     {"role": "assistant", "content": "Hi there!"},
+        ...     {"role": "user", "content": "How are you?"},
+        ...     {"role": "assistant", "content": "I'm doing well, thanks!"}
+        ... ]])
+        >>>
+        >>> # Create ChatHistory container for LLM wrapper
+        >>> chat_history = ChatHistory(prompt=history)
+        >>>
+        >>> # Load any supported tokenizer
+        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+        >>>
+        >>> # Apply chat template with assistant token masking
+        >>> result = history.apply_chat_template(
+        ...     tokenizer=tokenizer,
+        ...     add_generation_prompt=False,
+        ...     return_dict=True,
+        ...     return_assistant_tokens_mask=True,
+        ... )
+        >>>
+        >>> # The result contains an assistant_masks tensor
+        >>> assistant_masks = result["assistant_masks"]
+        >>> print(f"Assistant tokens: {assistant_masks.sum().item()}")
+
+    **Integration with LLM Wrappers:**
+
+    History objects work seamlessly with the new modular wrapper design:
+
+    .. code-block:: python
+
+        >>> from torchrl.modules.llm import TransformersWrapper
+        >>> from torchrl.modules.llm.policies import ChatHistory
+        >>>
+        >>> # Create wrapper with history input mode
+        >>> wrapper = TransformersWrapper(
+        ...     model, tokenizer=tokenizer,
+        ...     input_mode="history",
+        ...     generate=True,
+        ...     return_log_probs=True
+        ... )
+        >>>
+        >>> # Use History with ChatHistory container
+        >>> history = History.from_chats([[
+        ...     {"role": "user", "content": "Hello"},
+        ...     {"role": "assistant", "content": "Hi there!"}
+        ... ]])
+        >>> chat_history = ChatHistory(prompt=history)
+        >>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,)))
+        >>> print(result["history"].response)  # New response from LLM
+
+    Attributes:
+        role (str): The role of the message sender.
+        content (str): The content of the message.
+        is_complete (bool): Whether the message was properly terminated with an end token. Defaults to `True`.
+        tool_calls (list[dict] | None): Optional list of tool calls in the message.
+        tool_responses (list[str] | None): Optional list of tool responses.
+
+    Methods:
+        apply_chat_template: converts the `History` object to str / tokens.
+        append: append one element to the list of items along a given dimension.
+        extend: extend the list of items along a given dimension.
+
+    Examples:
+        >>> # With tensordict < 0.10, we need to tell the lib that lists constitute batches
+        >>> import tensordict
+        >>> tensordict.set_list_to_stack(True).set()
+        >>> import transformers
+        >>> history0 = History(
+        ...     role='system',
+        ...     content='''CONTENT
+        ... This is the setup''',
+        ... )
+        >>> history1 = History(
+        ...     role='user',
+        ...     content='''CONTENT
+        ... This is the first user prompt''',
+        ... )
+        >>> history2 = History(
+        ...     role='assistant',
+        ...     content='''CONTENT
+        ... This is the second prompt, the first for the assistant.''',
+        ... )
+        >>> history = torch.stack([history0, history1, history2])
+        >>> assert history.role == ['system', 'user', 'assistant']
+        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("GPT2")
+        >>> # Apply a template to pass the history to an LLM. Note that the output has
+        >>> # an additional prompt to elict an answer from the LLM thanks to the 'add_generation_prompt' argument.
+        >>> parsed_string = history.apply_chat_template(tokenizer=tokenizer, add_generation_prompt=True)
+        >>> parsed_string
+        <|im_start|>system
+        CONTENT
+        This is the setup<|im_end|>
+
+        <|im_start|>user
+        CONTENT
+        This is the first user prompt<|im_end|>
+
+        <|im_start|>assistant
+        CONTENT
+        This is the second prompt, the first for the assistant.<|im_end|>
+
+        <|im_start|>assistant
+
+    .. seealso::
+        :class:`~torchrl.modules.llm.policies.ChatHistory`: Container for managing conversation data in LLM environments.
+        :class:`~torchrl.modules.llm.policies.Text`: Container for text data.
+        :class:`~torchrl.modules.llm.policies.Tokens`: Container for token data.
+    """
+
+    role: str | list[str] | list[list[str]]
+    content: str | ContentBase | list[str] | list[ContentBase] | list[list[str]] | list[
+        list[ContentBase]
+    ]
+    is_complete: bool = True
+    tool_calls: list[dict] | None = None
+    tool_responses: list[str] | None = None
+
+    def __post_init__(self):
+        if not list_to_stack():
+            raise RuntimeError(
+                "Please set the list_to_stack to True using tensordict.set_list_to_stack(True).set() at the beginning of your script, "
+                "or the LIST_TO_STACK=1 environment variable."
+            )
+
+    def apply_chat_template(
+        self,
+        *,
+        tokenizer: transformers.AutoTokenizer | transformers.AutoProcessor,  # noqa
+        add_generation_prompt: bool = True,
+        chat_template: str | None = None,
+        chat_template_name: str | None = None,
+        continue_final_message: bool = False,
+        tokenize: bool | None = None,
+        padding: bool | str = False,
+        truncation: bool | str = False,
+        return_tensors: str | None = None,
+        return_dict: bool | None = None,
+        return_assistant_tokens_mask: bool = False,
+        **kwargs,
+    ) -> str | list[str] | TensorDict:
+        """Applies a chat template to the history.
+
+        Keyword Args:
+            tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor): The tokenizer to use.
+            add_generation_prompt (bool, optional): Whether to add a generation prompt (e.g. `"<|im_start|>assistant"`). Defaults to `True`.
+            chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template.
+            chat_template_name (str, optional): The name of the chat template to use.
+                Prevalent over `tokenizer.chat_template`. If `None`, the method will automatically detect the model family and use the appropriate template.
+                Defaults to `None`.
+            continue_final_message (bool, optional): Whether to continue the final message. Defaults to `False`.
+            tokenize (bool, optional): Whether to tokenize the output. Defaults to `False`.
+            padding (bool | str, optional): The padding strategy to use. Defaults to `False`.
+            truncation (bool | str, optional): The truncation strategy to use. Defaults to `False`.
+            return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt".
+            return_dict (bool, optional): Whether to return a dictionary. Defaults to `False`.
+            return_assistant_tokens_mask (bool, optional): Whether to return a mask of the assistant generated tokens.
+                If `True`, the mask will be written to the `assistant_masks` key.
+                For tokens generated by the assistant, the mask will contain `1`.
+                For user and system tokens, the mask will contain `0`.
+                This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
+                Defaults to `False`.
+
+                .. note:: Assistant token masking is supported across multiple model families:
+                    - **Qwen family**: Uses custom template with full tool calling support
+                    - **DialoGPT family**: Uses custom template for conversation format
+                    - **Falcon family**: Uses custom template for instruction format
+                    - **DeepSeek family**: Uses custom template with native format
+                    - **Other models**: Use the default `chatml_format` template
+
+                    The method automatically detects the model family and selects the appropriate template.
+
+            **kwargs: Additional keyword arguments to pass to the tokenizer `apply_chat_template` method.
+
+        Returns:
+            The formatted history.
+        """
+        if chat_template is None:
+            if chat_template_name is not None:
+                chat_template = _CHAT_TEMPLATES[chat_template_name]
+                chat_template_name = None
+            elif tokenizer is None:
+                raise RuntimeError(
+                    "You must specify a tokenizer to use when chat_template is not specified."
+                )
+            else:
+                # Auto-detect model family and use appropriate template
+                model_name = getattr(tokenizer, "name_or_path", "").lower()
+
+                # First check for custom model family keywords
+                custom_template_found = False
+                for template_name, keywords in _CUSTOM_MODEL_FAMILY_KEYWORDS.items():
+                    if any(keyword.lower() in model_name for keyword in keywords):
+                        chat_template = _CHAT_TEMPLATES[template_name]
+                        chat_template_name = None
+                        custom_template_found = True
+                        break
+
+                if not custom_template_found:
+                    # Fall back to built-in model family detection
+                    if "qwen" in model_name:
+                        # We prefer our implementation of the Qwen template,
+                        # since it accounts for the assistant's masking.
+                        chat_template = _CHAT_TEMPLATES["qwen"]
+                        chat_template_name = None
+                    elif "dialogpt" in model_name or "microsoft/dialo" in model_name:
+                        # DialoGPT family - use our custom template
+                        chat_template = _CHAT_TEMPLATES["dialogpt"]
+                        chat_template_name = None
+                    elif "falcon" in model_name or "tiiuae/falcon" in model_name:
+                        # Falcon family - use our custom template
+                        chat_template = _CHAT_TEMPLATES["falcon"]
+                        chat_template_name = None
+                    elif "deepseek" in model_name:
+                        # DeepSeek family - use our custom template with generation keyword
+                        chat_template = _CHAT_TEMPLATES["deepseek"]
+                        chat_template_name = None
+                    elif "llama" in model_name:
+                        # Llama family - use our custom template
+                        chat_template = _CHAT_TEMPLATES["llama"]
+                        chat_template_name = None
+                    else:
+                        # For other models, check if their default template supports generation
+                        if (
+                            hasattr(tokenizer, "chat_template")
+                            and tokenizer.chat_template
+                            and "{% generation %}" in tokenizer.chat_template
+                        ):
+                            # Use the model's own template if it supports generation
+                            chat_template = tokenizer.chat_template
+                        else:
+                            # Use our default chatml_format template
+                            chat_template = _CHAT_TEMPLATES["chatml_format"]
+        if chat_template is None:
+            chat_template = _CHAT_TEMPLATES["chatml_format"]
+        if tokenize is None:
+            if return_assistant_tokens_mask or return_tensors is not None:
+                tokenize = True
+            else:
+                tokenize = False
+        if tokenize:
+            if return_tensors is None:
+                return_tensors = "pt"
+        if return_dict is None and return_assistant_tokens_mask:
+            return_dict = True
+        elif return_dict is None:
+            return_dict = False
+
+        if self.ndim > 1:
+            result = [
+                self[i].apply_chat_template(
+                    tokenizer=tokenizer,
+                    add_generation_prompt=add_generation_prompt,
+                    chat_template=chat_template,
+                    chat_template_name=chat_template_name,
+                    tokenize=tokenize,
+                    padding=padding,
+                    truncation=truncation,
+                    return_tensors=return_tensors,
+                    continue_final_message=continue_final_message,
+                    return_dict=return_dict,
+                    return_assistant_tokens_mask=return_assistant_tokens_mask,
+                    **kwargs,
+                )
+                for i in range(self.batch_size[0])
+            ]
+            if return_dict:
+                return lazy_stack(result)
+            else:
+                return result
+        self_flat = self.view(-1)
+        # tolist_first=True is needed to avoid having a list of dict of dicts, but a list of dicts of lists of dicts
+        self_flat = self_flat.tolist(tolist_first=True)
+        # Remove the "<none>" role
+        self_flat = [item for item in self_flat if item["role"] != "<none>"]
+        result = tokenizer.apply_chat_template(
+            conversation=self_flat,
+            add_generation_prompt=add_generation_prompt,
+            chat_template=chat_template,
+            tokenize=tokenize,
+            padding=padding,
+            truncation=truncation,
+            return_tensors=return_tensors,
+            continue_final_message=continue_final_message,
+            return_dict=return_dict,
+            return_assistant_tokens_mask=return_assistant_tokens_mask,
+            **kwargs,
+        )
+        if not isinstance(result, (torch.Tensor, list, str)):
+            result = TensorDict.from_dict(result, auto_batch_size=True, batch_dims=1)
+            # If self has a batch_dims of 1, we have just the time dimension, so we need to remove the batch dim from the result
+            if self.batch_dims == 1:
+                if result.batch_size[0] != 1:
+                    raise RuntimeError(
+                        f"Expected a batch size of 1, got {result.batch_size[0]}."
+                    )
+                result = result.squeeze(0)
+        return result
+
@classmethod
|
|
712
|
+
def from_text(
|
|
713
|
+
cls,
|
|
714
|
+
text: str | list[str],
|
|
715
|
+
chat_template_name: str | None = None,
|
|
716
|
+
# currently without effect
|
|
717
|
+
chat_template: str | None = None,
|
|
718
|
+
tokenizer: transformers.AutoTokenizer # noqa: F821
|
|
719
|
+
| transformers.AutoProcessor # noqa: F821
|
|
720
|
+
| None = None,
|
|
721
|
+
) -> History:
|
|
722
|
+
r"""Inverts a chat template into a History object.
|
|
723
|
+
|
|
724
|
+
Args:
|
|
725
|
+
text (str | list[str]): The chat template to invert.
|
|
726
|
+
chat_template_name (str, optional): The name of the chat template to use.
|
|
727
|
+
tokenizer (transformers.AutoTokenizer | transformers.AutoProcessor, optional): The tokenizer to use.
|
|
728
|
+
|
|
729
|
+
Returns:
|
|
730
|
+
History: The inverted History object.
|
|
731
|
+
|
|
732
|
+
Examples:
|
|
733
|
+
>>> from torchrl.data.llm.history import History
|
|
734
|
+
>>> from transformers import AutoTokenizer
|
|
735
|
+
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
|
|
736
|
+
>>> text = "<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n<|im_start|>user\nWrite a python script that gives the capital of France or Germany.\n<|im_end|>\n<|im_start|>assistant\n<think>The capital of France is Paris, the capital of Germany is Berlin.</think>\n<answer><python>\n"
|
|
737
|
+
>>> history = History.from_text(text, tokenizer=tokenizer)
|
|
738
|
+
>>> print(history)
|
|
739
|
+
History(
|
|
740
|
+
content=NonTensorStack(
|
|
741
|
+
['You are a helpful assistant.', 'Write a python s...,
|
|
742
|
+
batch_size=torch.Size([3]),
|
|
743
|
+
device=None),
|
|
744
|
+
is_complete=NonTensorStack(
|
|
745
|
+
[True, True, False],
|
|
746
|
+
batch_size=torch.Size([3]),
|
|
747
|
+
device=None),
|
|
748
|
+
role=NonTensorStack(
|
|
749
|
+
['system', 'user', 'assistant'],
|
|
750
|
+
batch_size=torch.Size([3]),
|
|
751
|
+
device=None),
|
|
752
|
+
tool_calls=None,
|
|
753
|
+
tool_responses=None,
|
|
754
|
+
batch_size=torch.Size([3]),
|
|
755
|
+
device=None,
|
|
756
|
+
is_shared=False)
|
|
757
|
+
"""
|
|
758
|
+
if chat_template_name is None:
|
|
759
|
+
if chat_template is not None:
|
|
760
|
+
# TODO: find best match given template
|
|
761
|
+
pass
|
|
762
|
+
|
|
763
|
+
model_name = getattr(tokenizer, "name_or_path", "").lower()
|
|
764
|
+
# First check for custom model family keywords
|
|
765
|
+
custom_template_found = False
|
|
766
|
+
for template_name, keywords in _CUSTOM_MODEL_FAMILY_KEYWORDS.items():
|
|
767
|
+
if any(keyword.lower() in model_name for keyword in keywords):
|
|
768
|
+
chat_template_name = template_name
|
|
769
|
+
custom_template_found = True
|
|
770
|
+
break
|
|
771
|
+
|
|
772
|
+
if not custom_template_found:
|
|
773
|
+
# Fall back to built-in model family detection
|
|
774
|
+
if "qwen" in model_name:
|
|
775
|
+
# We can automatically detect the template name from the tokenizer
|
|
776
|
+
# and use the precoded parser.
|
|
777
|
+
chat_template_name = "qwen"
|
|
778
|
+
elif "dialogpt" in model_name or "microsoft/dialo" in model_name:
|
|
779
|
+
chat_template_name = "dialogpt"
|
|
780
|
+
elif "falcon" in model_name or "tiiuae/falcon" in model_name:
|
|
781
|
+
chat_template_name = "falcon"
|
|
782
|
+
elif "deepseek" in model_name:
|
|
783
|
+
chat_template_name = "deepseek"
|
|
784
|
+
elif "llama" in model_name:
|
|
785
|
+
chat_template_name = "llama"
|
|
786
|
+
else:
|
|
787
|
+
chat_template_name = "chatml_format"
|
|
788
|
+
|
|
789
|
+
# Get the appropriate inverse parser function
|
|
790
|
+
if chat_template_name in ("chatml_format",):
|
|
791
|
+
func = cls._inv_chatml
|
|
792
|
+
elif chat_template_name in ("qwen",):
|
|
793
|
+
func = cls._inv_qwen
|
|
794
|
+
elif chat_template_name in ("dialogpt",):
|
|
795
|
+
func = cls._inv_dialogpt
|
|
796
|
+
elif chat_template_name in ("falcon",):
|
|
797
|
+
func = cls._inv_falcon
|
|
798
|
+
elif chat_template_name in ("deepseek",):
|
|
799
|
+
func = cls._inv_deepseek
|
|
800
|
+
elif chat_template_name in ("llama",):
|
|
801
|
+
func = cls._inv_llama
|
|
802
|
+
elif chat_template_name in _CUSTOM_INVERSE_PARSERS:
|
|
803
|
+
# Use custom inverse parser
|
|
804
|
+
func = _CUSTOM_INVERSE_PARSERS[chat_template_name]
|
|
805
|
+
else:
|
|
806
|
+
raise NotImplementedError(
|
|
807
|
+
f"chat_template_name '{chat_template_name}' is not supported. "
|
|
808
|
+
"Supported templates: 'chatml_format', 'qwen', 'dialogpt', 'falcon', 'deepseek'. "
|
|
809
|
+
"Use add_chat_template() to add custom templates."
|
|
810
|
+
)
|
|
811
|
+
if isinstance(text, list):
|
|
812
|
+
list_of_histories = [func(t) for t in text]
|
|
813
|
+
try:
|
|
814
|
+
return lazy_stack(list_of_histories)
|
|
815
|
+
except RuntimeError as e:
|
|
816
|
+
raise RuntimeError(
|
|
817
|
+
f"Failed to stack histories: {list_of_histories=}"
|
|
818
|
+
) from e
|
|
819
|
+
return func(text)
|
|
820
|
+
|
|
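    # How the dispatch above behaves (summary plus a small sketch): when no
    # `chat_template_name` is passed, the parser is picked from the tokenizer's
    # `name_or_path`: custom families registered in _CUSTOM_MODEL_FAMILY_KEYWORDS win,
    # then the built-in families (qwen, dialogpt, falcon, deepseek, llama), and
    # everything else falls back to the generic "chatml_format" parser. Sketch
    # (assumes the Qwen tokenizer from the docstring example is available locally):
    #
    #     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
    #     batch = History.from_text(
    #         ["<|im_start|>user\nhi<|im_end|>", "<|im_start|>user\nhello<|im_end|>"],
    #         tokenizer=tokenizer,
    #     )
    #     # a list input is parsed per element and lazily stacked, adding one leading
    #     # batch dimension: batch.shape[0] == 2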
    @classmethod
    def _inv_chatml(cls, text: str) -> History:
        """Inverts a chatml string into a History object.

        Args:
            text (str): The chatml string to invert.

        Returns:
            History: The inverted History object.
        """
        import json

        torchrl_logger.debug(f"Inverting chatml:\n{text}")
        # Find all complete blocks (ending with im_end or endoftext)
        complete_pattern = r"<\|im_start\|>(.*?)\n(.*?)<\|(im_end|endoftext)\|>"
        complete_matches = re.findall(complete_pattern, text, flags=re.DOTALL)

        # Find any incomplete block at the end
        incomplete_pattern = r"<\|im_start\|>(.*?)\n(.*?)$"
        incomplete_matches = []
        if complete_matches:
            # Look for incomplete block after the last complete one
            last_complete = complete_matches[-1]
            last_complete_text = f"<|im_start|>{last_complete[0]}\n{last_complete[1]}<|{last_complete[2]}|>"
            remaining_text = text[
                text.rindex(last_complete_text) + len(last_complete_text) :
            ]
            if remaining_text.strip():
                incomplete_match = re.search(
                    incomplete_pattern, remaining_text, flags=re.DOTALL
                )
                if incomplete_match:
                    incomplete_matches = [
                        (incomplete_match.group(1), incomplete_match.group(2), None)
                    ]
        else:
            # No complete blocks, check entire text for incomplete block
            incomplete_match = re.search(incomplete_pattern, text, flags=re.DOTALL)
            if incomplete_match:
                incomplete_matches = [
                    (incomplete_match.group(1), incomplete_match.group(2), None)
                ]

        # Combine complete and incomplete matches
        matches = complete_matches + incomplete_matches

        # Define tool patterns - same as Qwen for consistency
        tool_call_pattern = re.compile(r"<tool_call>\n(.*?)\n</tool_call>", re.DOTALL)
        tool_response_pattern = re.compile(
            r"<tool_response>\n(.*?)\n</tool_response>", re.DOTALL
        )

        parsed_messages = []
        for match in matches:
            role = match[0].strip()
            content = match[1].strip()
            is_complete = match[2] is not None  # None indicates incomplete

            # Initialize message dict
            message_dict = {
                "role": role,
                "content": content,
                "is_complete": is_complete,
                "tool_calls": None,
                "tool_responses": None,
            }

            # Find tool calls within the message
            tool_calls = tool_call_pattern.findall(content)
            if tool_calls:
                tool_calls_list = []
                for tool_call in tool_calls:
                    try:
                        tool_call_dict = json.loads(tool_call)
                        tool_calls_list.append(tool_call_dict)
                    except json.JSONDecodeError:
                        continue
                if tool_calls_list:
                    message_dict["tool_calls"] = tool_calls_list

            # Check for tool responses
            tool_responses = tool_response_pattern.findall(content)
            if tool_responses:
                message_dict["tool_responses"] = tool_responses

            parsed_messages.append(cls(**message_dict))

        if not parsed_messages:
            raise RuntimeError(
                f"Couldn't get a single item out of text {text}. A common cause "
                "is that special tokens were omitted; did you set include_stop_str_in_output/skip_special_tokens=False?"
            )

        return lazy_stack(parsed_messages)

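    # _inv_chatml above keeps a trailing <|im_start|> block that has no closing
    # <|im_end|>/<|endoftext|> token and marks it with is_complete=False; this is how
    # a partially generated assistant turn (as in the from_text docstring example)
    # survives the round-trip.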
    @classmethod
    def _inv_qwen(cls, template):
        import json

        # Define regex patterns for different parts of the template
        message_pattern = re.compile(
            r"<\|im_start\|>(.*?)(?:<\|(im_end|endoftext)\|>|$)", re.DOTALL
        )
        tool_call_pattern = re.compile(r"<tool_call>\n(.*?)\n</tool_call>", re.DOTALL)
        tool_response_pattern = re.compile(
            r"<tool_response>\n(.*?)\n</tool_response>", re.DOTALL
        )

        # Find all messages and track if they end with a proper token
        messages = []
        is_complete_list = []
        for match in message_pattern.finditer(template):
            full_match = match.group(0)
            messages.append(match.group(1))
            # Check if the message ends with a proper token
            is_complete_list.append(
                full_match.endswith("<|im_end|>")
                or full_match.endswith("<|endoftext|>")
            )

        parsed_messages = []
        for message, is_complete in zip(messages, is_complete_list):
            # Split the message into role and content
            parts = message.split("\n", 1)
            if len(parts) < 2:
                continue
            role, content = parts[0], parts[1]

            # Initialize message dict
            message_dict = {
                "role": role.strip(),
                "content": content.strip(),
                "is_complete": is_complete,
                "tool_calls": None,
                "tool_responses": None,
            }

            # Find tool calls within the message
            tool_calls = tool_call_pattern.findall(content)
            if tool_calls:
                tool_calls_list = []
                for tool_call in tool_calls:
                    try:
                        tool_call_dict = json.loads(tool_call)
                        tool_calls_list.append(tool_call_dict)
                    except json.JSONDecodeError:
                        continue
                if tool_calls_list:
                    message_dict["tool_calls"] = tool_calls_list

            # Check for tool responses
            tool_responses = tool_response_pattern.findall(content)
            if tool_responses:
                message_dict["tool_responses"] = tool_responses

            parsed_messages.append(cls(**message_dict))

        if not parsed_messages:
            raise RuntimeError(
                f"Couldn't get a single item out of text {template}. A common cause "
                "is that special tokens were omitted; did you set include_stop_str_in_output/skip_special_tokens=False?"
            )

        return lazy_stack(parsed_messages)

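    # Tool-call handling in the two parsers above: anything wrapped in
    # <tool_call>...</tool_call> is json.loads'ed into `tool_calls` (malformed JSON is
    # skipped), and <tool_response>...</tool_response> blocks are collected verbatim
    # into `tool_responses`. A message carrying a call would look roughly like this
    # (illustrative payload, not a fixed schema):
    #
    #     <|im_start|>assistant
    #     <tool_call>
    #     {"name": "get_capital", "arguments": {"country": "France"}}
    #     </tool_call><|im_end|>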
    @classmethod
    def _inv_dialogpt(cls, text: str) -> History:
        """Inverts a DialogPT string into a History object.

        Args:
            text (str): The DialogPT string to invert.

        Returns:
            History: The inverted History object.
        """
        torchrl_logger.debug(f"Inverting DialogPT:\n{text}")

        # DialogPT format is simple: alternating user/assistant messages
        # Split by lines and parse
        lines = text.strip().split("\n")
        parsed_messages = []

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # Determine role based on content
            if line.startswith("Assistant:"):
                role = "assistant"
                content = line[len("Assistant:") :].strip()
            elif line.startswith("User:"):
                role = "user"
                content = line[len("User:") :].strip()
            else:
                # Default to user if no prefix
                role = "user"
                content = line

            message_dict = {
                "role": role,
                "content": content,
                "is_complete": True,  # DialogPT doesn't have explicit end tokens
                "tool_calls": None,
                "tool_responses": None,
            }

            parsed_messages.append(cls(**message_dict))

        if not parsed_messages:
            raise RuntimeError(f"Couldn't get a single item out of text {text}.")

        return lazy_stack(parsed_messages)

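    # The DialogPT parser above is purely line-based: each non-empty line becomes one
    # message, "Assistant:"/"User:" prefixes set the role, and unprefixed lines default
    # to the user role, with is_complete always True.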
    @classmethod
    def _inv_falcon(cls, text: str) -> History:
        """Inverts a Falcon string into a History object.

        Args:
            text (str): The Falcon string to invert.

        Returns:
            History: The inverted History object.
        """
        torchrl_logger.debug(f"Inverting Falcon:\n{text}")

        # Falcon format: "User: ... Assistant: ..."
        # Split by "User:" and "Assistant:" prefixes
        import re

        # Pattern to match User: and Assistant: messages
        pattern = r"(User:|Assistant:)\s*(.*?)(?=(User:|Assistant:)|$)"
        matches = re.findall(pattern, text, re.DOTALL)

        parsed_messages = []
        for match in matches:
            # The look-ahead group makes findall return 3-tuples, so index explicitly
            if len(match) < 2:
                continue
            prefix, content = match[0], match[1]
            content = content.strip()
            if not content:
                continue

            if prefix == "User:":
                role = "user"
            elif prefix == "Assistant:":
                role = "assistant"
            else:
                continue

            message_dict = {
                "role": role,
                "content": content,
                "is_complete": True,  # Falcon doesn't have explicit end tokens
                "tool_calls": None,
                "tool_responses": None,
            }

            parsed_messages.append(cls(**message_dict))

        if not parsed_messages:
            raise RuntimeError(f"Couldn't get a single item out of text {text}.")

        return lazy_stack(parsed_messages)

    @classmethod
    def _inv_deepseek(cls, text: str) -> History:
        """Inverts a DeepSeek string into a History object.

        Args:
            text (str): The DeepSeek string to invert.

        Returns:
            History: The inverted History object.
        """
        torchrl_logger.debug(f"Inverting DeepSeek:\n{text}")
        import re

        # Remove leading/trailing special tokens (e.g. BOS/EOS markers wrapped in <...>)
        text = re.sub(r"^<[^>]+>", "", text)  # Remove leading <...>
        text = re.sub(r"<[^>]+>$", "", text)  # Remove trailing <...>
        # Remove any REDACTED_SPECIAL_TOKEN if present
        text = re.sub(r"REDACTED_SPECIAL_TOKEN", "", text)
        # Pattern to match User: and Assistant: messages
        pattern = r"(User:|Assistant:)\s*(.*?)(?=(User:|Assistant:)|$)"
        matches = re.findall(pattern, text, re.DOTALL)
        parsed_messages = []
        for match in matches:
            if len(match) < 2:
                continue
            prefix, content = match[0], match[1]
            content = content.strip()
            if not content:
                continue
            if prefix == "User:":
                role = "user"
            elif prefix == "Assistant:":
                role = "assistant"
            else:
                continue
            message_dict = {
                "role": role,
                "content": content,
                "is_complete": True,  # DeepSeek doesn't have explicit end tokens
                "tool_calls": None,
                "tool_responses": None,
            }
            parsed_messages.append(cls(**message_dict))
        if not parsed_messages:
            raise RuntimeError(f"Couldn't get a single item out of text {text}.")
        return lazy_stack(parsed_messages)

    @classmethod
    def _inv_llama(cls, text: str) -> History:
        import re

        messages = []

        # Remove BOS token if present
        if text.startswith("<|begin_of_text|>"):
            text = text[len("<|begin_of_text|>") :]

        # Pattern to match complete message blocks: <|header_start|>role<|header_end|>\n\ncontent<|eot|>
        complete_pattern = r"<\|header_start\|>(\w+)<\|header_end\|>\n\n(.*?)<\|eot\|>"
        complete_matches = re.findall(complete_pattern, text, re.DOTALL)

        # Pattern to match incomplete message blocks: <|header_start|>role<|header_end|>\n\ncontent (without <|eot|>)
        incomplete_pattern = r"<\|header_start\|>(\w+)<\|header_end\|>\n\n(.*?)$"

        # Find any incomplete message at the end
        incomplete_matches = []
        if complete_matches:
            # Look for incomplete message after the last complete one
            last_complete_end = text.rfind("<|eot|>")
            if last_complete_end != -1:
                remaining_text = text[last_complete_end + len("<|eot|>") :]
                if remaining_text.strip():
                    incomplete_match = re.search(
                        incomplete_pattern, remaining_text, re.DOTALL
                    )
                    if incomplete_match:
                        incomplete_matches = [
                            (
                                incomplete_match.group(1),
                                incomplete_match.group(2),
                                False,
                            )
                        ]
        else:
            # No complete messages, check entire text for incomplete message
            incomplete_match = re.search(incomplete_pattern, text, re.DOTALL)
            if incomplete_match:
                incomplete_matches = [
                    (incomplete_match.group(1), incomplete_match.group(2), False)
                ]

        # Process complete messages
        for role, content in complete_matches:
            if content.strip():
                messages.append(
                    cls(role=role, content=content.strip(), is_complete=True)
                )

        # Process incomplete messages
        for role, content, is_complete in incomplete_matches:
            if content.strip():
                messages.append(
                    cls(role=role, content=content.strip(), is_complete=is_complete)
                )

        if not messages:
            raise RuntimeError(f"Couldn't parse Llama format from text: {text}")

        from tensordict import lazy_stack

        return lazy_stack(messages)

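    # The Llama parser above expects header-token framing, i.e.
    # <|header_start|>role<|header_end|>\n\ncontent<|eot|>, strips an optional
    # <|begin_of_text|> prefix, and flags a trailing block without <|eot|> as
    # is_complete=False, mirroring the ChatML handling.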
    def append(
        self, history: History, *, inplace: bool = True, dim: int = -1
    ) -> History:
        """Appends a new history to the current one.

        Args:
            history (History): The new history to append.
            inplace (bool, optional): Whether to perform the operation in-place. Defaults to `True`.
            dim (int, optional): The dimension to append along. Defaults to -1.

        Returns:
            History: The appended History object.
        """
        # TODO: we should remove the <none> role from the history before appending / extending
        # It works when keeping them, but it may lead to a lot of useless padding in between valid messages
        if not self.batch_dims:
            raise RuntimeError(
                "Cannot append an element to a batchless History. Call unsqueeze(dim=0) first on self."
            )
        if self.batch_dims != history.batch_dims + 1:
            raise RuntimeError(
                f"The new history to append must have one less dimension than self. Got self.ndim={self.ndim} and history.ndim={history.ndim}."
            )
        dim = _maybe_correct_neg_dim(dim, self.batch_size)
        if inplace:
            if (
                isinstance(self._tensordict, LazyStackedTensorDict)
                and self._tensordict.stack_dim == dim
            ):
                td = history._tensordict
                if td.device != self.device:
                    if self.device is None:
                        td = td.copy().clear_device_()
                    else:
                        td = td.to(self.device)
                self._tensordict.append(td)
                return self
            else:
                td = history._tensordict
                if td.device != self.device:
                    if self.device is None:
                        td = td.copy().clear_device_()
                    else:
                        td = td.to(self.device)
                td = lazy_stack(list(self._tensordict.unbind(dim)) + [td], dim=dim)
                self.__dict__["_tensordict"] = td
                return self
        if history.device != self.device:
            if self.device is None:
                history = history.copy().clear_device_()
            else:
                history = history.to(self.device)
        return lazy_stack(list(self.unbind(dim)) + [history], dim=dim)

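    # Usage sketch for append (hypothetical single-turn example; History fields other
    # than role/content keep their defaults):
    #
    #     chat = History.from_chats([{"role": "user", "content": "hi"}])   # batch_size (1,)
    #     chat = chat.append(History(role="assistant", content="hello"))   # batch_size (2,)
    #
    # The appended element must have exactly one dimension less than self, and with the
    # default inplace=True the underlying lazy stack is grown in place.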
    def extend(
        self, history: History, *, inplace: bool = True, dim: int = 0
    ) -> History:
        if not self.batch_dims:
            raise RuntimeError(
                "Cannot add an element to a batchless History. Call unsqueeze(dim=0) first on self."
            )
        if self.batch_dims != history.batch_dims:
            raise RuntimeError(
                f"The new history to extend must have as many dimensions as self. Got self.ndim={self.ndim} and history.ndim={history.ndim}."
            )
        dim = _maybe_correct_neg_dim(dim, self.batch_size)
        # if self.ndim > 1 and dim >= self.ndim - 1:
        #     # then we need to append each element independently
        #     result = []
        #     for hist, new_hist in zip(self.unbind(0), history.unbind(0)):
        #         hist_c = hist.extend(new_hist, inplace=inplace, dim=dim - 1)
        #         result.append(hist_c)
        #     if inplace:
        #         return self
        #     return lazy_stack(result)
        if inplace:
            if (
                isinstance(self._tensordict, LazyStackedTensorDict)
                and self._tensordict.stack_dim == dim
            ):
                td = history._tensordict
                if td.device != self.device:
                    if self.device is None:
                        td = td.copy().clear_device_()
                    else:
                        td = td.to(self.device)
                self._tensordict.extend(td)
                return self
            else:
                td = lazy_stack(
                    list(self._tensordict.unbind(dim))
                    + list(history._tensordict.unbind(dim)),
                    dim=dim,
                )
                if td.device != self.device:
                    if self.device is None:
                        td = td.copy().clear_device_()
                    else:
                        td = td.to(self.device)
                self.__dict__["_tensordict"] = td
                return self
        if history.device != self.device:
            if self.device is None:
                history = history.copy().clear_device_()
            else:
                history = history.to(self.device)
        return torch.stack(list(self.unbind(dim)) + list(history.unbind(dim)), dim=dim)

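    # append/extend fast path: when the backing storage is already a
    # LazyStackedTensorDict stacked along `dim`, the new entries are pushed onto that
    # stack directly; otherwise the history is unbound and re-stacked lazily. Devices
    # are reconciled first (cleared when self has no device, moved otherwise).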
    @classmethod
    def default_spec(cls, shape=(-1,)):
        """A default spec to use in transforms / envs that return History objects.

        Args:
            shape (torch.Size, optional): The shape of the returned History spec. Defaults to `(-1,)` (variable length
                along the time dimension).

        Example:
            >>> import tensordict
            >>> from torchrl.data import History
            >>> tensordict.set_list_to_stack(True).set()
            >>>
            >>> history = History(role=["system", "user"], content=["a message", "another message"], batch_size=(2,))
            >>> spec = history.default_spec()
            >>> print(spec)
            Composite(
                role: NonTensor(
                    shape=torch.Size([-1]),
                    space=None,
                    device=None,
                    dtype=None,
                    domain=None,
                    example_data=foo),
                content: NonTensor(
                    shape=torch.Size([-1]),
                    space=None,
                    device=None,
                    dtype=None,
                    domain=None,
                    example_data=foo),
                device=None,
                shape=torch.Size([-1]))
            >>> print(spec.zero())
            History(
                content=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None),
                role=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None),
                batch_size=torch.Size([1]),
                device=None,
                is_shared=False)

        """
        from torchrl.data import Composite, NonTensor

        def get_default_value(field):
            if field.default is not dataclasses.MISSING:
                return field.default
            elif field.type in (str, "str"):
                return "foo"
            else:
                return None

        defaults = {
            k: NonTensor(
                example_data=get_default_value(cls.__dataclass_fields__[k]),
                shape=shape,
            )
            for k in cls.__dataclass_fields__
        }

        return Composite(defaults, shape=shape[:-1], data_cls=cls)

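    # default_spec builds one NonTensor leaf per dataclass field, using the field's
    # declared default when it exists and the placeholder "foo" for string fields; the
    # shape (-1,) encodes "any number of messages", and the Composite itself keeps
    # shape[:-1].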
    @classmethod
    def from_chats(cls, chats: list[list[dict]]) -> History:
        """Create a History object from a list of chats.

        Args:
            chats (list[list[dict]]): A list of chats, where each chat is a list of dictionaries.
        """
        if isinstance(chats[0], dict):
            return lazy_stack([cls(**chat) for chat in chats])
        else:
            return lazy_stack([cls.from_chats(chat) for chat in chats])
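    # from_chats sketch (illustrative; the nested form mirrors the recursive call above):
    #
    #     single = History.from_chats([{"role": "user", "content": "hi"}])
    #     # -> batch_size (1,)
    #     batched = History.from_chats([
    #         [{"role": "user", "content": "hi"}],
    #         [{"role": "user", "content": "hello"}, {"role": "assistant", "content": "hey"}],
    #     ])
    #     # -> lazily stacked with a leading batch dim of 2 and a ragged message dim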