torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# This source code is licensed under the MIT license found in the
|
|
3
|
+
# LICENSE file in the root directory of this source tree.
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import uuid
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from typing import cast, Literal
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
from tensordict import NonTensorData, TensorDictBase
|
|
14
|
+
from torchrl.data.tensor_specs import Composite, NonTensor, Unbounded
|
|
15
|
+
from torchrl.envs.transforms.transforms import Transform
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
|
|
19
|
+
class VersionChange:
|
|
20
|
+
"""Records a single version change event."""
|
|
21
|
+
|
|
22
|
+
timestamp: datetime
|
|
23
|
+
old_version: str | int | None
|
|
24
|
+
new_version: str | int
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class PolicyVersion(Transform):
    """A transform that keeps track of the version of the policy.

    This transform is used to track policy versions during training, particularly in asynchronous
    settings where policy weights are updated periodically. It is designed to work seamlessly with
    :class:`~torchrl.collectors.llm.LLMCollector` to ensure data collection and training remain in sync.

    The version can be either a UUID (string) or an integer counter. When used with :class:`~torchrl.collectors.llm.LLMCollector`,
    the version is automatically incremented each time the policy weights are updated.

    On every step and every reset, the current version is written under the ``"policy_version"``
    key of the output tensordict. UUID versions are stored as non-tensor strings. Integer versions
    are stored as an ``int64`` tensor with the tensordict's batch shape, which matches the spec
    produced by :meth:`transform_observation_spec`.

    Example usage with :class:`~torchrl.collectors.llm.LLMCollector`:

    .. code-block:: python

        >>> # Create a policy version tracker
        >>> policy_version = PolicyVersion(version_type="int")  # or "uuid" for UUID-based versioning
        >>> # Create collector with version tracking
        >>> collector = LLMCollector(
        ...     env=env,
        ...     policy=policy,
        ...     track_policy_version=policy_version,  # Pass the version tracker
        ...     # ... other arguments
        ... )
        >>> # The version will be automatically incremented when weights are updated
        >>> collector.update_policy_weights_(new_weights)
        >>> # The version is stored in the collected data
        >>> for batch in collector:
        ...     current_version = batch["policy_version"]

    Args:
        version_type: The type of versioning to use. Can be either:

            - ``str`` or ``"uuid"``: Uses UUID4 strings for versions (good for distributed systems)
            - ``int`` or ``"int"``: Uses an integer counter starting at 0 (good for debugging)

            Defaults to ``int``.

    Raises:
        ValueError: If ``version_type`` is not one of the values listed above.

    Attributes:
        version_history: List of :class:`VersionChange` records, one per version change
            (the initial version included). This class only ever appends to it.
    """

    def __init__(self, version_type: type | Literal["uuid", "int"] = int):
        super().__init__()
        self.version_type = version_type
        # Every version change (including the initial one) is appended here.
        self.version_history: list[VersionChange] = []
        self._current_version: str | int | None = None
        # Creates the initial version (0 or a fresh UUID); also validates version_type.
        self._increment_version(init=True)
        # NOTE(review): this attribute is not read anywhere in this file; it may be a typo
        # for an attribute the base Transform inspects. Kept unchanged — confirm before relying on it.
        self.cal_on_reset = True

    def _version_kind(self) -> Literal["uuid", "int"]:
        """Normalize ``self.version_type`` to either ``"uuid"`` or ``"int"``.

        Raises:
            ValueError: If ``self.version_type`` is not a supported value.
        """
        if self.version_type in (str, "uuid"):
            return "uuid"
        if self.version_type in (int, "int"):
            return "int"
        raise ValueError(f"Invalid version type: {self.version_type}")

    @property
    def version(self) -> str | int:
        """The current version of the policy.

        Raises:
            RuntimeError: If no version has been set yet.
        """
        if self._current_version is None:
            raise RuntimeError("Version not initialized")
        return self._current_version

    @version.setter
    def version(self, value: str | int) -> None:
        self._current_version = value

    def increment_version(self) -> None:
        """Increment the version number.

        This is called automatically by LLMCollector when policy weights are updated.
        Can also be called manually if needed.
        """
        self._increment_version()

    def _increment_version(self, init: bool = False) -> str | int:
        """Advance the version and record the change in ``version_history``.

        Args:
            init: If True, create the initial version instead of advancing the current one.

        Returns:
            The new version.

        Raises:
            ValueError: If ``version_type`` is invalid.
        """
        old_version = self._current_version
        if self._version_kind() == "uuid":
            self._increment_version_uuid(init)
        else:
            self._increment_version_int(init)

        # Record the version change
        self.version_history.append(
            VersionChange(
                timestamp=datetime.now(),
                old_version=old_version,
                new_version=self.version,
            )
        )
        return self.version

    def _increment_version_uuid(self, init: bool = False) -> None:
        """Generate a new UUID version.

        Args:
            init: If True, this is the initial version creation. A fresh UUID is
                generated either way, so this flag does not change the outcome.
        """
        self.version = str(uuid.uuid4())

    def _increment_version_int(self, init: bool = False) -> None:
        """Increment the integer version counter.

        Args:
            init: If True, initialize counter to 0, otherwise increment by 1.
        """
        if init:
            self.version = 0
        else:
            # Cast to int to ensure type safety
            current = cast(int, self.version)
            self.version = current + 1

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        """Write the current version into the tensordict produced by a reset.

        The version itself is not changed on reset.

        Args:
            tensordict: The input tensordict of the reset call (unused).
            tensordict_reset: The tensordict returned by the reset.

        Returns:
            ``tensordict_reset`` with the ``"policy_version"`` entry set.
        """
        tensordict_reset = self._step(None, tensordict_reset)
        return tensordict_reset

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        """Add the current version to the tensordict.

        This method is called on each environment step to ensure the collected
        data is tagged with the correct policy version.

        Args:
            tensordict: The input tensordict of the step (unused; ``None`` when called from ``_reset``).
            next_tensordict: The tensordict to update with version info.

        Returns:
            ``next_tensordict`` with the ``"policy_version"`` entry set.
        """
        if self._version_kind() == "uuid":
            version = NonTensorData(self.version).expand(next_tensordict.shape)
        else:
            # Use int64 so the data matches the Unbounded(int64) spec declared
            # in transform_observation_spec (previously a float32 tensor was written).
            version = torch.full(
                next_tensordict.shape,
                cast(int, self.version),
                dtype=torch.int64,
                device=next_tensordict.device,
            )

        next_tensordict.set("policy_version", version)
        return next_tensordict

    def transform_observation_spec(self, spec: Composite) -> Composite:
        """Update the environment spec to include the version field.

        Args:
            spec: The environment spec to update (modified in place).

        Returns:
            Updated spec including the ``"policy_version"`` field.
        """
        if self._version_kind() == "uuid":
            # Example data is a str, matching the values written by _step.
            spec["policy_version"] = NonTensor(
                example_data=str(uuid.uuid4()), shape=spec.shape, device=spec.device
            )
        else:
            spec["policy_version"] = Unbounded(
                shape=spec.shape, dtype=torch.int64, device=spec.device
            )
        return spec
|
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import re
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from typing import Literal
|
|
11
|
+
|
|
12
|
+
from tensordict import lazy_stack, TensorDictBase
|
|
13
|
+
from torchrl._utils import logger as torchrl_logger
|
|
14
|
+
|
|
15
|
+
from torchrl.data.llm.history import History
|
|
16
|
+
from torchrl.envs import Transform
|
|
17
|
+
from torchrl.envs.common import EnvBase
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class AddThinkingPrompt(Transform):
|
|
21
|
+
"""A transform that adds thinking prompts to encourage the LLM to reconsider its response.
|
|
22
|
+
|
|
23
|
+
This transform can either add a new thinking prompt as a separate message or edit the last
|
|
24
|
+
assistant response to include a thinking prompt before the final answer. This is useful for
|
|
25
|
+
training LLMs to self-correct and think more carefully when their initial responses are
|
|
26
|
+
incorrect or incomplete.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
cond (Callable[[TensorDictBase], bool], optional): Condition function that determines
|
|
30
|
+
when to add the thinking prompt. Takes a tensordict and returns `True` if the prompt
|
|
31
|
+
should be added.
|
|
32
|
+
prompt (str, optional): The thinking prompt to add. If None, a default prompt is used.
|
|
33
|
+
Defaults to `"But wait, let me think about this more carefully..."`.
|
|
34
|
+
random_prompt (bool, optional): Whether to randomly select from predefined prompts.
|
|
35
|
+
Defaults to `False`.
|
|
36
|
+
role (Literal["user", "assistant"], optional): The role for the thinking prompt.
|
|
37
|
+
If `"assistant"`, the prompt is added to the assistant's response. If `"user"`, it's
|
|
38
|
+
added as a separate user message. Defaults to `"assistant"`.
|
|
39
|
+
edit_last_turn (bool, optional): Whether to edit the last assistant response instead
|
|
40
|
+
of adding a new message. Only works with `role="assistant"`. Defaults to `True`.
|
|
41
|
+
zero_reward (bool, optional): Whether to zero out the reward when the thinking prompt
|
|
42
|
+
is added. If `None`, defaults to the value of `edit_last_turn`. Defaults to the same value as `edit_last_turn`.
|
|
43
|
+
undo_done (bool, optional): Whether to undo the done flag when the thinking prompt
|
|
44
|
+
is added. Defaults to `True`.
|
|
45
|
+
egocentric (bool, optional): Whether the thinking prompt is written from the perspective of the assistant.
|
|
46
|
+
Defaults to `None`, which means that the prompt is written from the perspective of the user if `role="user"`
|
|
47
|
+
and from the perspective of the assistant if `role="assistant"`.
|
|
48
|
+
|
|
49
|
+
Examples:
|
|
50
|
+
>>> from torchrl.envs.llm.transforms import AddThinkingPrompt
|
|
51
|
+
>>> from torchrl.envs.llm import GSM8KEnv
|
|
52
|
+
>>> from transformers import AutoTokenizer
|
|
53
|
+
>>> import torch
|
|
54
|
+
>>>
|
|
55
|
+
>>> # Create environment with thinking prompt transform
|
|
56
|
+
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
|
|
57
|
+
>>> env = GSM8KEnv(tokenizer=tokenizer, max_steps=10)
|
|
58
|
+
>>> env = env.append_transform(
|
|
59
|
+
... AddThinkingPrompt(
|
|
60
|
+
... cond=lambda td: td["reward"] < 50,
|
|
61
|
+
... role="assistant",
|
|
62
|
+
... edit_last_turn=True,
|
|
63
|
+
... zero_reward=True,
|
|
64
|
+
... undo_done=True
|
|
65
|
+
... )
|
|
66
|
+
... )
|
|
67
|
+
>>>
|
|
68
|
+
>>> # Test with wrong answer (low reward)
|
|
69
|
+
>>> reset = env.reset()
|
|
70
|
+
>>> wrong_answer = (
|
|
71
|
+
... "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. "
|
|
72
|
+
... "Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
|
|
73
|
+
... "To find the total, I need to add April and May: 48 + 24 = 72. "
|
|
74
|
+
... "Therefore, Natalia sold 72 clips altogether in April and May.</think>"
|
|
75
|
+
... "<answer>322 clips</answer><|im_end|>"
|
|
76
|
+
... )
|
|
77
|
+
>>> reset["text_response"] = [wrong_answer]
|
|
78
|
+
>>> s = env.step(reset)
|
|
79
|
+
>>> assert (s["next", "reward"] == 0).all() # Reward zeroed
|
|
80
|
+
>>> assert (s["next", "done"] == 0).all() # Done undone
|
|
81
|
+
>>> assert s["next", "history"].shape == (1, 3) # History modified
|
|
82
|
+
>>>
|
|
83
|
+
>>> # Test with correct answer (high reward)
|
|
84
|
+
>>> reset = env.reset()
|
|
85
|
+
>>> correct_answer = (
|
|
86
|
+
... "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. "
|
|
87
|
+
... "Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
|
|
88
|
+
... "To find the total, I need to add April and May: 48 + 24 = 72. "
|
|
89
|
+
... "Therefore, Natalia sold 72 clips altogether in April and May.</think>"
|
|
90
|
+
... "<answer>72</answer><|im_end|>"
|
|
91
|
+
... )
|
|
92
|
+
>>> reset["text_response"] = [correct_answer]
|
|
93
|
+
>>> s = env.step(reset)
|
|
94
|
+
>>> assert (s["next", "reward"] != 0).all() # Reward not zeroed
|
|
95
|
+
>>> assert s["next", "done"].all() # Done remains True
|
|
96
|
+
>>> assert s["next", "history"].shape == (1, 3) # History unchanged
|
|
97
|
+
"""
|
|
98
|
+
|
|
99
|
+
# Predefined thinking prompts
|
|
100
|
+
DEFAULT_PROMPTS_EG = [
|
|
101
|
+
"But wait, let me think about this more carefully...",
|
|
102
|
+
"Actually, let me reconsider this...",
|
|
103
|
+
"But we can do better. Let me think about it step by step...",
|
|
104
|
+
"Wait, I need to double-check my reasoning...",
|
|
105
|
+
"Actually, let me think about it more carefully...",
|
|
106
|
+
"It looks like I made a mistake. Let me think about it step by step...",
|
|
107
|
+
]
|
|
108
|
+
DEFAULT_PROMPTS_COG = [
|
|
109
|
+
"But wait, think about this more carefully...",
|
|
110
|
+
"Actually, reconsider this...",
|
|
111
|
+
"But we can do better. Let's think about it step by step...",
|
|
112
|
+
"Wait, you need to double-check your reasoning...",
|
|
113
|
+
"Actually, think about it more carefully...",
|
|
114
|
+
"It looks like you made a mistake. Can you see what went wrong? Let's think about it step by step...",
|
|
115
|
+
]
|
|
116
|
+
|
|
117
|
+
def __init__(
    self,
    cond: Callable[[TensorDictBase], bool],
    prompt: str | None = None,
    random_prompt: bool = False,
    role: Literal["user", "assistant"] = "assistant",
    edit_last_turn: bool = True,
    zero_reward: bool | None = None,
    undo_done: bool = True,
    egocentric: bool | None = None,
) -> None:
    """Store the configuration that controls when and how the prompt is injected.

    Args:
        cond: Predicate evaluated on the next tensordict at every step; the
            thinking prompt is injected only when it returns a truthy value.
        prompt: Fixed prompt text. Defaults to the first entry of
            ``DEFAULT_PROMPTS_EG`` (egocentric) or ``DEFAULT_PROMPTS_COG``.
        random_prompt: If ``True``, a prompt is drawn from the default pool on
            every access to :attr:`prompt` instead of using the fixed one.
        role: Role of the message appended when ``edit_last_turn`` is False.
        edit_last_turn: If ``True``, rewrite the last assistant turn instead of
            appending a new message. Only valid with ``role="assistant"``.
        zero_reward: Whether to zero the reward entries when the prompt is
            injected. Defaults to the value of ``edit_last_turn``.
        undo_done: Whether to zero the done entries when the prompt is injected.
        egocentric: Whether to use the first-person prompt pool. Defaults to
            ``role == "assistant"``.

    Raises:
        ValueError: if ``edit_last_turn`` is set while ``role`` is not
            ``"assistant"``.
    """
    super().__init__()

    # Only an assistant turn can be rewritten in place.
    if edit_last_turn and role != "assistant":
        raise ValueError("edit_last_turn can only be used with role='assistant'")

    # Resolve the arguments whose defaults depend on other arguments.
    if egocentric is None:
        egocentric = role == "assistant"
    if prompt is None:
        default_pool = (
            self.DEFAULT_PROMPTS_EG if egocentric else self.DEFAULT_PROMPTS_COG
        )
        prompt = default_pool[0]
    if zero_reward is None:
        zero_reward = edit_last_turn

    self.cond = cond
    self.role = role
    self.egocentric = egocentric
    self._prompt = prompt
    self.random_prompt = random_prompt
    self.edit_last_turn = edit_last_turn
    self.zero_reward = zero_reward
    self.undo_done = undo_done
|
|
157
|
+
|
|
158
|
+
@property
def prompt(self) -> str:
    """Prompt text to inject.

    When ``random_prompt`` is false this is the prompt fixed at construction.
    Otherwise a new prompt is drawn with ``random.choice`` on every access,
    from ``DEFAULT_PROMPTS_EG`` when ``egocentric`` is true and from
    ``DEFAULT_PROMPTS_COG`` otherwise.
    """
    if not self.random_prompt:
        return self._prompt

    import random

    pool = self.DEFAULT_PROMPTS_EG if self.egocentric else self.DEFAULT_PROMPTS_COG
    return random.choice(pool)
|
|
167
|
+
|
|
168
|
+
def _step(
    self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
    """Inject the thinking prompt into ``next_tensordict`` when ``cond`` holds.

    Batched inputs are split along their first dimension and processed one
    element at a time; the per-element results are lazily stacked and written
    back into ``next_tensordict``.

    When ``cond(next_tensordict)`` is truthy, the prompt history is edited:
    either the last assistant turn has its answer replaced by the prompt
    (``edit_last_turn=True``) or a new message carrying the prompt is appended.
    Done entries (``undo_done``) and reward entries (``zero_reward``) are then
    zeroed for every done/reward key of the parent env.

    Args:
        tensordict: The current tensordict.
        next_tensordict: The next tensordict holding the most recent history,
            reward and done entries.

    Returns:
        ``next_tensordict``, updated in place.

    Raises:
        RuntimeError: if the transform has no parent env, or if the parent's
            base env is not in ``"history"`` input mode.
    """
    if next_tensordict.batch_dims >= 1:
        # Recurse on each batch element, then write the stacked results back.
        per_element = [
            self._step(td, next_td)
            for td, next_td in zip(tensordict.unbind(0), next_tensordict.unbind(0))
        ]
        next_tensordict.update(lazy_stack(per_element))
        return next_tensordict

    # The transform only works on top of a ChatEnv operating on histories.
    parent = self.parent
    if parent is None:
        raise RuntimeError("AddThinkingPrompt must be used with a ChatEnv")
    if parent.base_env.input_mode != "history":
        raise RuntimeError(
            "AddThinkingPrompt must be used with a ChatEnv in history mode"
        )

    if not self.cond(next_tensordict):
        torchrl_logger.info("Not adding thinking prompt.")
        return next_tensordict

    torchrl_logger.info("Adding thinking prompt.")
    history: History = next_tensordict["history"].prompt
    last_turn = history[..., -1]

    if self.edit_last_turn:
        # Swap the last assistant turn for a copy whose answer is replaced by
        # the thinking prompt.
        new_turn = History(
            role="assistant",
            content=self._replace_answer_with_prompt(last_turn.content),
            batch_size=last_turn.batch_size,
            device=last_turn.device,
        )
        history = history[..., :-1].append(new_turn)
    else:
        # Append the thinking prompt as a brand-new message.
        text = self.prompt
        history = history.append(History(role=self.role, content=text))
    next_tensordict["history"].prompt = history

    # ``parent`` was checked to be non-None above. Gather the entries to reset
    # so the episode continues (done) and/or the turn is not rewarded (reward).
    keys_to_zero = []
    if self.undo_done:
        keys_to_zero.extend(parent.done_keys)
    if self.zero_reward:
        keys_to_zero.extend(parent.reward_keys)
    for key in keys_to_zero:
        value = next_tensordict.get(key)
        if value is not None:
            next_tensordict.set(key, value.zero_())
    return next_tensordict
|
|
250
|
+
|
|
251
|
+
def _replace_answer_with_prompt(self, content: str) -> str:
|
|
252
|
+
"""Replace the last answer section with a thinking prompt.
|
|
253
|
+
|
|
254
|
+
This method uses regex to find and replace the last <answer>...</answer> section
|
|
255
|
+
with the thinking prompt, preserving any content before the answer tag.
|
|
256
|
+
Only the last answer block is replaced to avoid interfering with earlier
|
|
257
|
+
examples or instructions that might contain answer tags.
|
|
258
|
+
|
|
259
|
+
Args:
|
|
260
|
+
content: The original content string
|
|
261
|
+
|
|
262
|
+
Returns:
|
|
263
|
+
The modified content with the last answer replaced by the thinking prompt
|
|
264
|
+
"""
|
|
265
|
+
# Pattern to match <answer>...</answer> with optional EOS token
|
|
266
|
+
# Use non-greedy matching and be more specific about the end
|
|
267
|
+
answer_pattern = r"<answer>.*?</answer>(?:\s*<\|im_end\|>)?"
|
|
268
|
+
|
|
269
|
+
# Check if there's an answer tag
|
|
270
|
+
if "<answer>" in content:
|
|
271
|
+
# Find all matches to get the last one
|
|
272
|
+
matches = list(re.finditer(answer_pattern, content, flags=re.DOTALL))
|
|
273
|
+
|
|
274
|
+
if matches:
|
|
275
|
+
# Get the last match
|
|
276
|
+
last_match = matches[-1]
|
|
277
|
+
start, end = last_match.span()
|
|
278
|
+
|
|
279
|
+
# Replace only the last answer section with the thinking prompt
|
|
280
|
+
prompt = self.prompt
|
|
281
|
+
modified_content = content[:start] + prompt + content[end:]
|
|
282
|
+
|
|
283
|
+
# Clean up any trailing whitespace
|
|
284
|
+
modified_content = modified_content.rstrip()
|
|
285
|
+
|
|
286
|
+
# Ensure we end with the EOS token if the original content had it
|
|
287
|
+
if content.endswith("<|im_end|>"):
|
|
288
|
+
modified_content = modified_content.rstrip() + "<|im_end|>"
|
|
289
|
+
|
|
290
|
+
# Ensure proper spacing around the prompt
|
|
291
|
+
if not modified_content.endswith(prompt):
|
|
292
|
+
# If the prompt wasn't properly inserted, append it
|
|
293
|
+
modified_content = content.rstrip()
|
|
294
|
+
if modified_content.endswith("<|im_end|>"):
|
|
295
|
+
modified_content = modified_content[
|
|
296
|
+
: -len("<|im_end|>")
|
|
297
|
+
].rstrip()
|
|
298
|
+
modified_content = modified_content + "\n\n" + prompt + "<|im_end|>"
|
|
299
|
+
else:
|
|
300
|
+
# No matches found, just append the prompt
|
|
301
|
+
prompt = self.prompt
|
|
302
|
+
modified_content = content.rstrip() + "\n\n" + prompt
|
|
303
|
+
|
|
304
|
+
else:
|
|
305
|
+
# No answer tag found, just append the prompt
|
|
306
|
+
prompt = self.prompt
|
|
307
|
+
modified_content = content.rstrip() + "\n\n" + prompt
|
|
308
|
+
|
|
309
|
+
return modified_content
|
|
310
|
+
|
|
311
|
+
def _reset(
|
|
312
|
+
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
|
|
313
|
+
) -> TensorDictBase:
|
|
314
|
+
"""Reset the transform state.
|
|
315
|
+
|
|
316
|
+
Args:
|
|
317
|
+
tensordict: The current tensordict
|
|
318
|
+
tensordict_reset: The reset tensordict
|
|
319
|
+
|
|
320
|
+
Returns:
|
|
321
|
+
The reset tensordict
|
|
322
|
+
"""
|
|
323
|
+
return tensordict_reset
|