torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/llm/envs.py
ADDED
|
@@ -0,0 +1,789 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import warnings
|
|
8
|
+
from collections.abc import Callable
|
|
9
|
+
|
|
10
|
+
from typing import Any, Literal, TYPE_CHECKING
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
from tensordict import (
|
|
15
|
+
is_leaf_nontensor,
|
|
16
|
+
LazyStackedTensorDict,
|
|
17
|
+
NestedKey,
|
|
18
|
+
set_list_to_stack,
|
|
19
|
+
TensorDict,
|
|
20
|
+
TensorDictBase,
|
|
21
|
+
unravel_key,
|
|
22
|
+
)
|
|
23
|
+
from tensordict.tensorclass import NonTensorData, NonTensorStack
|
|
24
|
+
from tensordict.utils import _zip_strict
|
|
25
|
+
from torch.utils.data import DataLoader
|
|
26
|
+
|
|
27
|
+
from torchrl._utils import _replace_last
|
|
28
|
+
from torchrl.data.map.hash import SipHash
|
|
29
|
+
from torchrl.data.tensor_specs import (
|
|
30
|
+
Bounded,
|
|
31
|
+
Categorical as CategoricalSpec,
|
|
32
|
+
Composite,
|
|
33
|
+
NonTensor,
|
|
34
|
+
Unbounded,
|
|
35
|
+
)
|
|
36
|
+
from torchrl.envs import EnvBase
|
|
37
|
+
from torchrl.envs.utils import _StepMDP
|
|
38
|
+
from torchrl.modules.utils.utils import _unpad_tensors
|
|
39
|
+
|
|
40
|
+
if TYPE_CHECKING:
|
|
41
|
+
import transformers
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class LLMEnv(EnvBase):
|
|
45
|
+
"""A text generation environment for language models.
|
|
46
|
+
|
|
47
|
+
This environment is designed to work with language models, where the observation is a string or a tensor of
|
|
48
|
+
integers representing a sequence of tokens. The action is also a string or a tensor of integers, which is
|
|
49
|
+
concatenated to the previous observation to form the new observation.
|
|
50
|
+
|
|
51
|
+
By default, this environment is meant to track history for a prompt. Users can append transforms to tailor
|
|
52
|
+
this to their use case, such as Chain of Thought (CoT) reasoning or other custom processing.
|
|
53
|
+
|
|
54
|
+
Users must append a transform to set the "done" condition, which would trigger the loading of the next prompt.
|
|
55
|
+
Prompts to the language model can be loaded when the environment is ``reset`` if the environment is created via
|
|
56
|
+
:meth:`~from_dataloader`.
|
|
57
|
+
|
|
58
|
+
.. note:: The default arguments of the `LLMEnv` class are set to make it easy to run this environment with
|
|
59
|
+
the vllm backend (:class:`~torchrl.modules.vLLMWrapper`).
|
|
60
|
+
|
|
61
|
+
Keyword Args:
|
|
62
|
+
token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when `from_text=False`).
|
|
63
|
+
Defaults to ``"tokens"``.
|
|
64
|
+
str_key (NestedKey, optional): The key in the tensordict where the string input is stored (when `from_text=True`).
|
|
65
|
+
Defaults to ``"text"``.
|
|
66
|
+
attention_key (NestedKey, optional): The key in the tensordict where the attention mask is stored.
|
|
67
|
+
Defaults to ``"attention_mask"``.
|
|
68
|
+
action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to
|
|
69
|
+
``"tokens_response"`` or ``"text_response"``.
|
|
70
|
+
reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`.
|
|
71
|
+
Defaults to ``"reward"``.
|
|
72
|
+
from_text (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``True``.
|
|
73
|
+
device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``.
|
|
74
|
+
vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an
|
|
75
|
+
unbounded vocabulary. Defaults to ``None``.
|
|
76
|
+
has_attention (bool, optional): If ``True``, an attention mask is to be used under the key indicated by
|
|
77
|
+
:attr:`attention_key`. Defaults to ``True``.
|
|
78
|
+
assign_reward (bool, optional): If ``True``, a zero-valued reward of shape equal to the action shape
|
|
79
|
+
is written during calls to `step()`. Defaults to ``False``.
|
|
80
|
+
assign_done (bool, optional): If ``True``, a zero-valued done and terminated state of shape equal to the
|
|
81
|
+
action shape is written during calls to `step()`. Defaults to ``False``.
|
|
82
|
+
|
|
83
|
+
.. note:: Regardless of the value assigned to `assign_done`, a done state will be written at the root
|
|
84
|
+
as it is a requirement for all TorchRL environments.
|
|
85
|
+
|
|
86
|
+
batch_size (int or torch.Size, optional): Batch size of the environment.
|
|
87
|
+
If left empty, an empty batch-size is assumed.
|
|
88
|
+
The batch size can be null (`torch.Size([])`) or one-dimensional. Batchless environments are not supported.
|
|
89
|
+
|
|
90
|
+
.. note:: When using a :class:`~torchrl.envs.DataLoadingPrimer` transform, the batch-size of the env
|
|
91
|
+
and the transform should match.
|
|
92
|
+
|
|
93
|
+
eos_token_id (int, optional): The token id of the end of the sequence. If passed, the `done` state
|
|
94
|
+
is set to `True` when detected. Defaults to `None`.
|
|
95
|
+
|
|
96
|
+
.. seealso:: :class:`~torchrl.envs.DataLoadingPrimer` for examples.
|
|
97
|
+
|
|
98
|
+
Methods:
|
|
99
|
+
from_dataloader: Creates an LLMEnv instance from a dataloader.
|
|
100
|
+
|
|
101
|
+
"""
|
|
102
|
+
|
|
103
|
+
_DEFAULT_TOKEN_KEY = "tokens"
|
|
104
|
+
_DEFAULT_STR_KEY = "text"
|
|
105
|
+
_DEFAULT_ATTENTION_KEY = "attention_mask"
|
|
106
|
+
_DEFAULT_ACTION_TOKENS_KEY = "tokens_response"
|
|
107
|
+
_DEFAULT_ACTION_STR_KEY = "text_response"
|
|
108
|
+
|
|
109
|
+
def __init__(
|
|
110
|
+
self,
|
|
111
|
+
*,
|
|
112
|
+
token_key: NestedKey | None = None,
|
|
113
|
+
str_key: NestedKey | None = None,
|
|
114
|
+
attention_key: NestedKey | None = None,
|
|
115
|
+
action_key: NestedKey | None = None,
|
|
116
|
+
reward_key: NestedKey = "reward",
|
|
117
|
+
from_text: bool = True,
|
|
118
|
+
device: torch.device | None = None,
|
|
119
|
+
vocab_size: int | None = None,
|
|
120
|
+
assign_reward: bool = False,
|
|
121
|
+
assign_done: bool = False,
|
|
122
|
+
batch_size: int | torch.Size | None = None,
|
|
123
|
+
has_attention: bool = True,
|
|
124
|
+
# Experimental
|
|
125
|
+
as_llm_data: bool = False,
|
|
126
|
+
eos_token_id: int | None = None,
|
|
127
|
+
) -> None:
|
|
128
|
+
self._warn_deprecated()
|
|
129
|
+
self.as_llm_data = as_llm_data
|
|
130
|
+
if token_key is None:
|
|
131
|
+
token_key = self._DEFAULT_TOKEN_KEY
|
|
132
|
+
if str_key is None:
|
|
133
|
+
str_key = self._DEFAULT_STR_KEY
|
|
134
|
+
if attention_key is None:
|
|
135
|
+
attention_key = self._DEFAULT_ATTENTION_KEY
|
|
136
|
+
if action_key is None:
|
|
137
|
+
if from_text:
|
|
138
|
+
action_key = self._DEFAULT_ACTION_STR_KEY
|
|
139
|
+
else:
|
|
140
|
+
action_key = self._DEFAULT_ACTION_TOKENS_KEY
|
|
141
|
+
self._batch_locked = True
|
|
142
|
+
if batch_size is None:
|
|
143
|
+
batch_size = ()
|
|
144
|
+
else:
|
|
145
|
+
if not isinstance(batch_size, (tuple, list)):
|
|
146
|
+
batch_size = (batch_size,)
|
|
147
|
+
elif len(batch_size) > 1:
|
|
148
|
+
raise TypeError(
|
|
149
|
+
f"batch-size of LLMEnv must be 0 or 1d. Got batch_size={batch_size}."
|
|
150
|
+
)
|
|
151
|
+
super().__init__(
|
|
152
|
+
device=device,
|
|
153
|
+
batch_size=batch_size,
|
|
154
|
+
)
|
|
155
|
+
self.has_attention = has_attention
|
|
156
|
+
self.from_text = from_text
|
|
157
|
+
self.vocab_size = vocab_size
|
|
158
|
+
self.token_key = unravel_key(token_key)
|
|
159
|
+
self.str_key = unravel_key(str_key)
|
|
160
|
+
if attention_key is not None:
|
|
161
|
+
attention_key = unravel_key(attention_key)
|
|
162
|
+
self.attention_key = attention_key
|
|
163
|
+
self.assign_reward = assign_reward
|
|
164
|
+
self.assign_done = assign_done
|
|
165
|
+
self.eos_token_id = eos_token_id
|
|
166
|
+
if eos_token_id is None:
|
|
167
|
+
warnings.warn(
|
|
168
|
+
"eos_token_id is missing. This means that the environment will not be able to capture its "
|
|
169
|
+
"done state automatically. This may lead to undefined behaviors when the generated text reaches "
|
|
170
|
+
"an eos_token.",
|
|
171
|
+
category=UserWarning,
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
# self.action_key = unravel_key(action_key)
|
|
175
|
+
if from_text:
|
|
176
|
+
self.full_observation_spec_unbatched = Composite(
|
|
177
|
+
{
|
|
178
|
+
self.str_key: NonTensor(
|
|
179
|
+
example_data="a string",
|
|
180
|
+
batched=True,
|
|
181
|
+
shape=(),
|
|
182
|
+
device=device,
|
|
183
|
+
)
|
|
184
|
+
}
|
|
185
|
+
)
|
|
186
|
+
self.full_action_spec_unbatched = Composite(
|
|
187
|
+
{
|
|
188
|
+
action_key: NonTensor(
|
|
189
|
+
example_data="a string", batched=True, shape=(), device=device
|
|
190
|
+
)
|
|
191
|
+
}
|
|
192
|
+
)
|
|
193
|
+
else:
|
|
194
|
+
if vocab_size is None:
|
|
195
|
+
observation_spec = {
|
|
196
|
+
token_key: Unbounded(shape=(-1,), dtype=torch.int64, device=device)
|
|
197
|
+
}
|
|
198
|
+
if self.has_attention:
|
|
199
|
+
observation_spec[attention_key] = Unbounded(
|
|
200
|
+
shape=(-1,), dtype=torch.int64, device=device
|
|
201
|
+
)
|
|
202
|
+
self.full_observation_spec_unbatched = Composite(observation_spec)
|
|
203
|
+
self.full_action_spec_unbatched = Composite(
|
|
204
|
+
{
|
|
205
|
+
action_key: Unbounded(
|
|
206
|
+
shape=(-1,), dtype=torch.int64, device=device
|
|
207
|
+
)
|
|
208
|
+
}
|
|
209
|
+
)
|
|
210
|
+
else:
|
|
211
|
+
self.full_observation_spec_unbatched = Composite(
|
|
212
|
+
{
|
|
213
|
+
token_key: Bounded(
|
|
214
|
+
shape=(-1,),
|
|
215
|
+
dtype=torch.int64,
|
|
216
|
+
low=0,
|
|
217
|
+
high=vocab_size,
|
|
218
|
+
device=device,
|
|
219
|
+
)
|
|
220
|
+
}
|
|
221
|
+
)
|
|
222
|
+
self.full_action_spec_unbatched = Composite(
|
|
223
|
+
{
|
|
224
|
+
action_key: Bounded(
|
|
225
|
+
shape=(-1,),
|
|
226
|
+
dtype=torch.int64,
|
|
227
|
+
low=0,
|
|
228
|
+
high=vocab_size,
|
|
229
|
+
device=device,
|
|
230
|
+
)
|
|
231
|
+
}
|
|
232
|
+
)
|
|
233
|
+
STR2STR_ERR = ValueError(
|
|
234
|
+
"from_text cannot be True when either of assign_reward / assign_done are True. "
|
|
235
|
+
"Tokens are required to compute the reward shape."
|
|
236
|
+
)
|
|
237
|
+
if self.assign_reward:
|
|
238
|
+
if self.from_text:
|
|
239
|
+
raise STR2STR_ERR
|
|
240
|
+
self.full_reward_spec_unbatched = Composite(
|
|
241
|
+
{reward_key: Unbounded(shape=(-1,), device=device)}
|
|
242
|
+
)
|
|
243
|
+
else:
|
|
244
|
+
self.full_reward_spec_unbatched = Composite(device=device)
|
|
245
|
+
|
|
246
|
+
if not self.assign_done:
|
|
247
|
+
# Use single done
|
|
248
|
+
self.full_done_spec_unbatched = Composite(
|
|
249
|
+
done=Unbounded(shape=(1,), dtype=torch.bool, device=device),
|
|
250
|
+
terminated=Unbounded(shape=(1,), dtype=torch.bool, device=device),
|
|
251
|
+
)
|
|
252
|
+
elif self.from_text:
|
|
253
|
+
raise STR2STR_ERR
|
|
254
|
+
else:
|
|
255
|
+
# Use single done
|
|
256
|
+
self.full_done_spec_unbatched = Composite(
|
|
257
|
+
tokens_data=Composite(
|
|
258
|
+
done=Unbounded(shape=(-1,), dtype=torch.bool, device=device),
|
|
259
|
+
terminated=Unbounded(shape=(-1,), dtype=torch.bool, device=device),
|
|
260
|
+
),
|
|
261
|
+
done=Unbounded(shape=(1,), dtype=torch.bool, device=device),
|
|
262
|
+
terminated=Unbounded(shape=(1,), dtype=torch.bool, device=device),
|
|
263
|
+
)
|
|
264
|
+
|
|
265
|
+
@classmethod
|
|
266
|
+
def _warn_deprecated(cls):
|
|
267
|
+
warnings.warn(
|
|
268
|
+
"LLMEnv is deprecated. Please use ChatEnv instead.",
|
|
269
|
+
category=DeprecationWarning,
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
    @classmethod
    def from_dataloader(
        cls,
        dataloader: DataLoader,
        *,
        tokenizer: transformers.PretrainedTokenizerBase | None = None,  # noqa
        token_key: NestedKey | None = None,
        str_key: NestedKey | None = None,
        attention_key: NestedKey | None = None,
        action_key: NestedKey | None = None,
        reward_key: NestedKey = "reward",
        from_text: bool = True,
        device: torch.device | None = None,
        vocab_size: int | None = None,
        batch_size: int | torch.Size | None = None,
        has_attention: bool = True,
        assign_reward: bool = False,
        assign_done: bool = False,
        primers: Composite | None = None,
        example_data: Any = None,
        stack_method: Callable[[Any], Any]
        | Literal["as_nested_tensor", "as_padded_tensor"] = None,
        repeats: int | None = None,
        group_repeats: bool = True,
        eos_token_id: int | None = None,
    ) -> LLMEnv:
        """Creates an LLMEnv instance from a dataloader.

        This method creates an LLMEnv instance and appends a DataLoadingPrimer to it, which populates ``data_keys`` (by default ``observation_key``) with data from the provided dataloader when the environment is reset.

        Args:
            dataloader (DataLoader): The dataloader to load data from.

        Keyword Args:
            tokenizer (transformers.PretrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``,
                "bert-base-uncased" will be used by default. If a string is provided, it should be the name of a
                pre-trained tokenizer.

                .. note:: Using the `tokenizer` will append a :class:`~torchrl.envs.Tokenizer` transform to the environment.
                    If `from_text` is set to `True`, the tokenizer will be called during every iteration and the rollout
                    will contain both tokens and text data.
                    If `from_text` is set to `False`, the tokenizer will be called during reset only, and the only
                    text data in the rollout will be the text sampled from the dataset.

            token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when `from_text=False`).
                Defaults to ``("tokens_in", "input_ids")``.
            str_key (NestedKey, optional): The key in the tensordict where the string input is stored (when `from_text=True`).
                Defaults to ``LLMEnv._DEFAULT_STR_KEY``.
            attention_key (NestedKey, optional): The key in the tensordict where the attention mask is stored.
                Defaults to ``LLMEnv._DEFAULT_ATTENTION_KEY``. When a tokenizer is passed, it must equal
                ``(*token_key[:-1], "attention_mask")``.
            action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to
                ``("tokens_out", "sequences")``.
            reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`.
                Defaults to ``"reward"``.
            from_text (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``True``.
            device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``.
            vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an
                unbounded vocabulary. Defaults to ``None``.
            has_attention (bool, optional): if ``True``, an attention mask is to be used under the key indicated by
                :attr:`attention_key`. Defaults to ``True``.
            assign_reward (bool, optional): if ``True``, a zero-valued reward of shape equal to the action shape
                is written during calls to `step()`. Defaults to ``False``.
            assign_done (bool, optional): if ``True``, a zero-valued done and terminated state of shape equal to the
                action shape is written during calls to `step()`. Defaults to ``False``.

                .. note:: regardless of the value assigned to `assign_done`, a done state will be written at the root
                    as it is a requirement for all TorchRL environments.

            batch_size (int or torch.Size, optional): Batch size of the environment.
                If left empty, the batch size is inferred from `dataloader.batch_size` if that attribute exists, otherwise
                it is set to `()`.
                The batch size can be null (`torch.Size([])`) or one-dimensional. Batchless environments are not supported.

                .. note:: When using a :class:`~torchrl.envs.DataLoadingPrimer` transform, the batch-size of the env
                    and the transform should match.

            primers (Composite | None, optional): The primers to use for each key in the dataloader.
                Defaults to ``None`` (inferred automatically from the first batch of data).
            example_data (Any, optional): accepted for backward compatibility; currently unused by this method.
            stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The
                method to use for stacking the data. Defaults to ``None``.
            repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in
                situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
                samples (rather than an advantage module).
            group_repeats (bool, optional): if ``True``, the batch-size is multiplied by the number of repeats such that
                all repeats are grouped in a single batch collected from the buffer. Defaults to ``True``.
            eos_token_id (int, optional): The token id of the end of the sequence. If passed, the `done` state
                is set to `True` when detected. Defaults to `None`.

        Returns:
            LLMEnv: The created LLMEnv instance.
        """
        cls._warn_deprecated()

        # Local import to avoid a circular dependency at module-import time.
        from torchrl.envs.llm import DataLoadingPrimer, Tokenizer

        # Fill in the default nested keys before any consistency checks.
        if str_key is None:
            str_key = LLMEnv._DEFAULT_STR_KEY
        if token_key is None:
            token_key = LLMEnv._DEFAULT_TOKEN_KEY
        if attention_key is None:
            attention_key = LLMEnv._DEFAULT_ATTENTION_KEY
        elif tokenizer is not None and attention_key != _replace_last(
            token_key, "attention_mask"
        ):
            raise ValueError(
                "When using the Tokenizer, attention key must match `(*token_key[:-1], 'attention_mask')` where "
                f"`token_key` is a tuple-typed nested key. Got attention_key={attention_key} while expecting "
                f"{_replace_last(token_key, 'attention_mask')}."
            )

        if tokenizer is not None:
            if from_text:
                # In this case, the tokenizer is appended to the env after each step
                if action_key is None:
                    action_key = cls._DEFAULT_ACTION_STR_KEY
                tokenizer_transform = Tokenizer(
                    tokenizer=tokenizer,
                    in_keys=[str_key],
                    out_keys=[token_key],
                    # Assume that the tokens are named according to _DEFAULT_ACTION_TOKENS_KEY
                    in_keys_inv=[action_key],
                    out_keys_inv=[cls._DEFAULT_ACTION_TOKENS_KEY],
                    call_before_reset=False,
                    # We should always see the required entries
                    missing_tolerance=False,
                )
            else:
                # FIXME: This is broken - do we need it anyway?
                raise RuntimeError(
                    "tokenizers can only be used whenever from_text is set to `True`."
                )

        # The primer feeds fresh dataloader samples into the env at reset time.
        primer = DataLoadingPrimer(
            dataloader=dataloader,
            primers=primers,
            stack_method=stack_method,
            repeats=repeats,
            device=device,
            group_repeats=group_repeats,
            batch_size=batch_size,
        )
        # Note: the env batch size is taken from the primer (which may expand it
        # by `repeats` when group_repeats is set), not from `batch_size` directly.
        env = LLMEnv(
            from_text=from_text,
            device=device,
            token_key=token_key,
            str_key=str_key,
            attention_key=attention_key,
            action_key=action_key,
            reward_key=reward_key,
            vocab_size=vocab_size,
            assign_reward=assign_reward,
            assign_done=assign_done,
            batch_size=primer.batch_size,
            has_attention=has_attention,
            eos_token_id=eos_token_id,
        )
        if tokenizer is not None:
            env = env.append_transform(tokenizer_transform)
        return env.append_transform(primer)
|
|
431
|
+
|
|
432
|
+
@staticmethod
|
|
433
|
+
def _check_obs_act_and_cat(obs, action, *, device):
|
|
434
|
+
if not isinstance(obs, str):
|
|
435
|
+
raise TypeError(f"Observation must be a string, got {type(obs)}.")
|
|
436
|
+
if not isinstance(action, str):
|
|
437
|
+
raise TypeError(f"Action must be a string, got {type(action)}.")
|
|
438
|
+
return NonTensorData(obs + action, device=device)
|
|
439
|
+
|
|
440
|
+
def _step(
|
|
441
|
+
self,
|
|
442
|
+
tensordict: TensorDictBase,
|
|
443
|
+
) -> TensorDictBase:
|
|
444
|
+
next_td = tensordict.empty()
|
|
445
|
+
self._make_next_obs(tensordict, next_td)
|
|
446
|
+
self._maybe_make_reward(tensordict, next_td)
|
|
447
|
+
self._maybe_make_done(tensordict, next_td)
|
|
448
|
+
if self.as_llm_data:
|
|
449
|
+
raise NotImplementedError()
|
|
450
|
+
return next_td
|
|
451
|
+
|
|
452
|
+
def _maybe_make_reward(
|
|
453
|
+
self, tensordict: TensorDictBase, next_td: TensorDictBase
|
|
454
|
+
) -> TensorDictBase:
|
|
455
|
+
if self.assign_reward:
|
|
456
|
+
next_td.set(
|
|
457
|
+
self.reward_key,
|
|
458
|
+
torch.zeros_like(
|
|
459
|
+
tensordict.get(self.action_key), dtype=self.reward_spec.dtype
|
|
460
|
+
),
|
|
461
|
+
)
|
|
462
|
+
return next_td
|
|
463
|
+
|
|
464
|
+
def _maybe_make_done(
|
|
465
|
+
self,
|
|
466
|
+
tensordict: TensorDictBase,
|
|
467
|
+
next_td: TensorDictBase,
|
|
468
|
+
resetting: bool = False,
|
|
469
|
+
) -> TensorDictBase:
|
|
470
|
+
if self.assign_done:
|
|
471
|
+
action = tensordict.get(self.action_key)
|
|
472
|
+
if action is None:
|
|
473
|
+
done = torch.zeros(
|
|
474
|
+
tensordict.shape + (1,), dtype=torch.bool, device=self.device
|
|
475
|
+
)
|
|
476
|
+
else:
|
|
477
|
+
done = torch.zeros_like(action, dtype=torch.bool)
|
|
478
|
+
next_td.set(("tokens_data", "terminated"), done)
|
|
479
|
+
next_td.set(("tokens_data", "done"), done.clone())
|
|
480
|
+
next_td.set(
|
|
481
|
+
"done", next_td.get(("tokens_data", "done")).any(-1, keepdim=True)
|
|
482
|
+
)
|
|
483
|
+
next_td.set(
|
|
484
|
+
"terminated",
|
|
485
|
+
next_td.get(("tokens_data", "terminated")).any(-1, keepdim=True),
|
|
486
|
+
)
|
|
487
|
+
if not resetting and self.eos_token_id is not None:
|
|
488
|
+
if self.from_text:
|
|
489
|
+
token_action_key = self._DEFAULT_ACTION_TOKENS_KEY
|
|
490
|
+
else:
|
|
491
|
+
token_action_key = self.action_key
|
|
492
|
+
action = tensordict.get(
|
|
493
|
+
token_action_key, as_padded_tensor=True, padding_value=-1
|
|
494
|
+
)
|
|
495
|
+
mask = action == -1
|
|
496
|
+
|
|
497
|
+
if action is None:
|
|
498
|
+
raise RuntimeError(
|
|
499
|
+
f"Couldn't find the tokenized action with key {token_action_key} to set the done state in tensordict "
|
|
500
|
+
f"with keys {list(tensordict.keys(True))}."
|
|
501
|
+
)
|
|
502
|
+
full_done = action == self.eos_token_id
|
|
503
|
+
done = full_done.any(-1, keepdim=True)
|
|
504
|
+
next_td.set("done", done)
|
|
505
|
+
next_td.set("terminated", done)
|
|
506
|
+
if self.assign_done:
|
|
507
|
+
full_done = _unpad_tensors(full_done, mask)
|
|
508
|
+
next_td.set(("tokens_data", "terminated"), full_done)
|
|
509
|
+
next_td.set(("tokens_data", "done"), full_done)
|
|
510
|
+
return next_td
|
|
511
|
+
|
|
512
|
+
def _make_next_obs(
|
|
513
|
+
self, tensordict: TensorDictBase, nex_td: TensorDictBase
|
|
514
|
+
) -> TensorDictBase:
|
|
515
|
+
# Cat action entry with prev obs
|
|
516
|
+
if self.from_text:
|
|
517
|
+
obs = tensordict[self.str_key]
|
|
518
|
+
action = tensordict[self.action_key]
|
|
519
|
+
if not tensordict.batch_size:
|
|
520
|
+
if not isinstance(obs, str) or not isinstance(action, str):
|
|
521
|
+
raise TypeError(
|
|
522
|
+
"The tensordict is batchless, yet the action and/or observations are not "
|
|
523
|
+
f"strings but {type(action)} and {type(obs)}, respectivly."
|
|
524
|
+
)
|
|
525
|
+
observation = self._check_obs_act_and_cat(
|
|
526
|
+
obs, action, device=self.device
|
|
527
|
+
)
|
|
528
|
+
else:
|
|
529
|
+
observation = NonTensorStack(
|
|
530
|
+
*[
|
|
531
|
+
self._check_obs_act_and_cat(_obs, _action, device=self.device)
|
|
532
|
+
for (_obs, _action) in _zip_strict(obs, action)
|
|
533
|
+
]
|
|
534
|
+
)
|
|
535
|
+
return nex_td.set(self.str_key, observation)
|
|
536
|
+
else:
|
|
537
|
+
try:
|
|
538
|
+
obs: torch.Tensor = tensordict.get(self.token_key)
|
|
539
|
+
action = tensordict.get(self.action_key)
|
|
540
|
+
if getattr(obs, "is_nested", False):
|
|
541
|
+
observation = torch.nested.as_nested_tensor(
|
|
542
|
+
[
|
|
543
|
+
torch.cat([_obs, _action], -1)
|
|
544
|
+
for _obs, _action in _zip_strict(
|
|
545
|
+
obs.unbind(0), action.unbind(0)
|
|
546
|
+
)
|
|
547
|
+
],
|
|
548
|
+
layout=obs.layout,
|
|
549
|
+
)
|
|
550
|
+
else:
|
|
551
|
+
observation = torch.cat([obs, action], -1)
|
|
552
|
+
if self.has_attention:
|
|
553
|
+
attention_mask = tensordict.get(self.attention_key)
|
|
554
|
+
attention_mask = torch.cat(
|
|
555
|
+
[attention_mask, attention_mask.new_ones(action.shape)], -1
|
|
556
|
+
)
|
|
557
|
+
nex_td.set(self.attention_key, attention_mask)
|
|
558
|
+
except TypeError:
|
|
559
|
+
raise TypeError(
|
|
560
|
+
"Failed to cat action and observation tensors. Check that from_text argument is correctly "
|
|
561
|
+
f"set in {type(self).__name__}."
|
|
562
|
+
)
|
|
563
|
+
return nex_td.set(self.token_key, observation)
|
|
564
|
+
|
|
565
|
+
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
|
|
566
|
+
# We should have an observation by this time, if not raise an exception
|
|
567
|
+
def check_token():
|
|
568
|
+
return not self.from_text and (
|
|
569
|
+
self.token_key not in tensordict.keys(isinstance(self.token_key, tuple))
|
|
570
|
+
)
|
|
571
|
+
|
|
572
|
+
def check_str():
|
|
573
|
+
return self.from_text and (
|
|
574
|
+
self.str_key not in tensordict.keys(isinstance(self.str_key, tuple))
|
|
575
|
+
)
|
|
576
|
+
|
|
577
|
+
if tensordict is None or check_token() or check_str():
|
|
578
|
+
raise KeyError(
|
|
579
|
+
f"Observation key {self.token_key}/{self.str_key} is not defined in tensordict with keys "
|
|
580
|
+
f"{list(tensordict.keys(True, True, is_leaf=is_leaf_nontensor))}. Make sure a TensorDictPrimer (eg, "
|
|
581
|
+
f"torchrl.envs.DataLoadingPrimer) is appended to the env transforms."
|
|
582
|
+
)
|
|
583
|
+
if not isinstance(tensordict, LazyStackedTensorDict) and tensordict.ndim:
|
|
584
|
+
tensordict = LazyStackedTensorDict(*tensordict.unbind(0))
|
|
585
|
+
td_reset = tensordict.copy()
|
|
586
|
+
if td_reset.device != self.device:
|
|
587
|
+
if self.device is None:
|
|
588
|
+
td_reset.clear_device_()
|
|
589
|
+
else:
|
|
590
|
+
td_reset = td_reset.to(self.device)
|
|
591
|
+
tensordict = self._maybe_make_done(tensordict, td_reset, resetting=True)
|
|
592
|
+
if self.as_llm_data:
|
|
593
|
+
raise NotImplementedError()
|
|
594
|
+
return tensordict
|
|
595
|
+
|
|
596
|
+
def _set_seed(self, seed: int | None):
|
|
597
|
+
return seed
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
class LLMHashingEnv(EnvBase):
|
|
601
|
+
"""A text generation environment that uses a hashing module to identify unique observations.
|
|
602
|
+
|
|
603
|
+
The primary goal of this environment is to identify token chains using a hashing function.
|
|
604
|
+
This allows the data to be stored in a :class:`~torchrl.data.MCTSForest` using nothing but hashes as node
|
|
605
|
+
identifiers, or easily prune repeated token chains in a data structure.
|
|
606
|
+
|
|
607
|
+
.. The following figure gives an overview of this workflow:
|
|
608
|
+
.. .. figure:: /_static/img/rollout-llm.png
|
|
609
|
+
.. :alt: Data collection loop with our LLM environment.
|
|
610
|
+
|
|
611
|
+
Args:
|
|
612
|
+
vocab_size (int): The size of the vocabulary. Can be omitted if the tokenizer is passed.
|
|
613
|
+
|
|
614
|
+
Keyword Args:
|
|
615
|
+
hashing_module (Callable[[torch.Tensor], torch.Tensor], optional):
|
|
616
|
+
A hashing function that takes a tensor as input and returns a hashed tensor.
|
|
617
|
+
Defaults to :class:`~torchrl.data.SipHash` if not provided.
|
|
618
|
+
observation_key (NestedKey, optional): The key for the observation in the TensorDict.
|
|
619
|
+
Defaults to "observation".
|
|
620
|
+
text_output (bool, optional): Whether to include the text output in the observation.
|
|
621
|
+
Defaults to `True`.
|
|
622
|
+
tokenizer (transformers.Tokenizer | None, optional):
|
|
623
|
+
A tokenizer function that converts text to tensors.
|
|
624
|
+
Only used when `text_output` is `True`.
|
|
625
|
+
Must implement the following methods: `decode` and `batch_decode`.
|
|
626
|
+
Defaults to ``None``.
|
|
627
|
+
text_key (NestedKey | None, optional): The key for the text output in the TensorDict.
|
|
628
|
+
Defaults to "text".
|
|
629
|
+
|
|
630
|
+
Examples:
|
|
631
|
+
>>> from tensordict import TensorDict
|
|
632
|
+
>>> from torchrl.envs import LLMHashingEnv
|
|
633
|
+
>>> from transformers import GPT2Tokenizer
|
|
634
|
+
>>> tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
|
635
|
+
>>> x = tokenizer(["Check out TorchRL!"])["input_ids"]
|
|
636
|
+
>>> env = LLMHashingEnv(tokenizer=tokenizer)
|
|
637
|
+
>>> td = TensorDict(observation=x, batch_size=[1])
|
|
638
|
+
>>> td = env.reset(td)
|
|
639
|
+
>>> print(td)
|
|
640
|
+
TensorDict(
|
|
641
|
+
fields={
|
|
642
|
+
done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
643
|
+
hashing: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
644
|
+
observation: Tensor(shape=torch.Size([1, 5]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
645
|
+
terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
646
|
+
text: NonTensorStack(
|
|
647
|
+
['Check out TorchRL!'],
|
|
648
|
+
batch_size=torch.Size([1]),
|
|
649
|
+
device=None)},
|
|
650
|
+
batch_size=torch.Size([1]),
|
|
651
|
+
device=None,
|
|
652
|
+
is_shared=False)
|
|
653
|
+
|
|
654
|
+
"""
|
|
655
|
+
|
|
656
|
+
def __init__(
|
|
657
|
+
self,
|
|
658
|
+
vocab_size: int | None = None,
|
|
659
|
+
*,
|
|
660
|
+
hashing_module: Callable[[torch.Tensor], torch.Tensor] = None,
|
|
661
|
+
observation_key: NestedKey = "observation",
|
|
662
|
+
text_output: bool = True,
|
|
663
|
+
tokenizer: Callable[[str | list[str]], torch.Tensor] | None = None,
|
|
664
|
+
text_key: NestedKey | None = "text",
|
|
665
|
+
):
|
|
666
|
+
super().__init__()
|
|
667
|
+
if vocab_size is None:
|
|
668
|
+
if tokenizer is None:
|
|
669
|
+
raise TypeError(
|
|
670
|
+
"You must provide a vocab_size integer if tokenizer is `None`."
|
|
671
|
+
)
|
|
672
|
+
vocab_size = tokenizer.vocab_size
|
|
673
|
+
self._batch_locked = False
|
|
674
|
+
if hashing_module is None:
|
|
675
|
+
hashing_module = SipHash()
|
|
676
|
+
|
|
677
|
+
self._hashing_module = hashing_module
|
|
678
|
+
self._tokenizer = tokenizer
|
|
679
|
+
self.observation_key = observation_key
|
|
680
|
+
observation_spec = {
|
|
681
|
+
observation_key: CategoricalSpec(n=vocab_size, shape=(-1,)),
|
|
682
|
+
"hashing": Unbounded(shape=(1,), dtype=torch.int64),
|
|
683
|
+
}
|
|
684
|
+
self.text_output = text_output
|
|
685
|
+
if not text_output:
|
|
686
|
+
text_key = None
|
|
687
|
+
elif text_key is None:
|
|
688
|
+
text_key = "text"
|
|
689
|
+
if text_key is not None:
|
|
690
|
+
observation_spec[text_key] = NonTensor(shape=())
|
|
691
|
+
self.text_key = text_key
|
|
692
|
+
self.observation_spec = Composite(observation_spec)
|
|
693
|
+
self.action_spec = Composite(action=CategoricalSpec(vocab_size, shape=(1,)))
|
|
694
|
+
_StepMDP(self)
|
|
695
|
+
|
|
696
|
+
@set_list_to_stack(True)
|
|
697
|
+
def make_tensordict(self, input: str | list[str]) -> TensorDict:
|
|
698
|
+
"""Converts a string or list of strings in a TensorDict with appropriate shape and device."""
|
|
699
|
+
list_len = len(input) if isinstance(input, list) else 0
|
|
700
|
+
tensordict = TensorDict(
|
|
701
|
+
{self.observation_key: self._tokenizer(input)}, device=self.device
|
|
702
|
+
)
|
|
703
|
+
if list_len:
|
|
704
|
+
tensordict.batch_size = [list_len]
|
|
705
|
+
return self.reset(tensordict)
|
|
706
|
+
|
|
707
|
+
def _reset(self, tensordict: TensorDictBase):
|
|
708
|
+
"""Initializes the environment with a given observation.
|
|
709
|
+
|
|
710
|
+
Args:
|
|
711
|
+
tensordict (TensorDictBase): A TensorDict containing the initial observation.
|
|
712
|
+
|
|
713
|
+
Returns:
|
|
714
|
+
A TensorDict containing the initial observation, its hash, and other relevant information.
|
|
715
|
+
|
|
716
|
+
"""
|
|
717
|
+
out = tensordict.empty()
|
|
718
|
+
obs = tensordict.get(self.observation_key, None)
|
|
719
|
+
if obs is None:
|
|
720
|
+
raise RuntimeError(
|
|
721
|
+
f"Resetting the {type(self).__name__} environment requires a prompt."
|
|
722
|
+
)
|
|
723
|
+
if self.text_output:
|
|
724
|
+
if obs.ndim > 1:
|
|
725
|
+
text = self._tokenizer.batch_decode(obs)
|
|
726
|
+
text = NonTensorStack.from_list(text)
|
|
727
|
+
else:
|
|
728
|
+
text = self._tokenizer.decode(obs)
|
|
729
|
+
text = NonTensorData(text)
|
|
730
|
+
out.set(self.text_key, text)
|
|
731
|
+
|
|
732
|
+
if obs.ndim > 1:
|
|
733
|
+
out.set("hashing", self._hashing_module(obs).unsqueeze(-1))
|
|
734
|
+
else:
|
|
735
|
+
out.set("hashing", self._hashing_module(obs.unsqueeze(0)).transpose(0, -1))
|
|
736
|
+
|
|
737
|
+
if not self.full_done_spec.is_empty():
|
|
738
|
+
out.update(self.full_done_spec.zero(tensordict.shape))
|
|
739
|
+
else:
|
|
740
|
+
out.set("done", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool))
|
|
741
|
+
out.set(
|
|
742
|
+
"terminated", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool)
|
|
743
|
+
)
|
|
744
|
+
return out
|
|
745
|
+
|
|
746
|
+
def _step(self, tensordict):
|
|
747
|
+
"""Takes an action (i.e., the next token to generate) and returns the next observation and reward.
|
|
748
|
+
|
|
749
|
+
Args:
|
|
750
|
+
tensordict: A TensorDict containing the current observation and action.
|
|
751
|
+
|
|
752
|
+
Returns:
|
|
753
|
+
A TensorDict containing the next observation, its hash, and other relevant information.
|
|
754
|
+
"""
|
|
755
|
+
out = tensordict.empty()
|
|
756
|
+
action = tensordict.get("action")
|
|
757
|
+
obs = torch.cat([tensordict.get(self.observation_key), action], -1)
|
|
758
|
+
kwargs = {self.observation_key: obs}
|
|
759
|
+
|
|
760
|
+
catval = torch.cat([tensordict.get("hashing"), action], -1)
|
|
761
|
+
if obs.ndim > 1:
|
|
762
|
+
new_hash = self._hashing_module(catval).unsqueeze(-1)
|
|
763
|
+
else:
|
|
764
|
+
new_hash = self._hashing_module(catval.unsqueeze(0)).transpose(0, -1)
|
|
765
|
+
|
|
766
|
+
if self.text_output:
|
|
767
|
+
if obs.ndim > 1:
|
|
768
|
+
text = self._tokenizer.batch_decode(obs)
|
|
769
|
+
text = NonTensorStack.from_list(text)
|
|
770
|
+
else:
|
|
771
|
+
text = self._tokenizer.decode(obs)
|
|
772
|
+
text = NonTensorData(text)
|
|
773
|
+
kwargs[self.text_key] = text
|
|
774
|
+
kwargs.update(
|
|
775
|
+
{
|
|
776
|
+
"hashing": new_hash,
|
|
777
|
+
"done": torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool),
|
|
778
|
+
"terminated": torch.zeros(
|
|
779
|
+
(*tensordict.batch_size, 1), dtype=torch.bool
|
|
780
|
+
),
|
|
781
|
+
}
|
|
782
|
+
)
|
|
783
|
+
return out.update(kwargs)
|
|
784
|
+
|
|
785
|
+
def _set_seed(self, *args):
|
|
786
|
+
"""Sets the seed for the environment's randomness.
|
|
787
|
+
|
|
788
|
+
.. note:: This environment has no randomness, so this method does nothing.
|
|
789
|
+
"""
|