torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/llm/transforms/tokenizer.py
@@ -0,0 +1,321 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from collections.abc import Sequence

import torch
from tensordict import NonTensorData, NonTensorStack, TensorDictBase
from tensordict.nn import dispatch
from tensordict.utils import _zip_strict, NestedKey
from torch import Tensor
from torchrl._utils import _replace_last
from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec
from torchrl.envs import Transform, UnaryTransform
from torchrl.envs.transforms.utils import _set_missing_tolerance


class Tokenizer(UnaryTransform):
    r"""Applies a tokenization operation on the specified inputs.

    Args:
        in_keys (sequence of NestedKey): the keys of inputs to the tokenization operation.
        out_keys (sequence of NestedKey): the keys of the outputs of the tokenization operation.
        in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the tokenization operation during the inverse call.
        out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the tokenization operation during the inverse call.

    Keyword Args:
        tokenizer (transformers.PretrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``,
            "bert-base-uncased" will be used by default. If a string is provided, it should be the name of a
            pre-trained tokenizer.
        use_raw_nontensor (bool, optional): if ``False``, data is extracted from
            :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before the tokenization
            function is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
            inputs are given directly to the tokenization function, which must support those inputs. Default is ``False``.
        additional_tokens (List[str], optional): list of additional tokens to add to the tokenizer's vocabulary.

    .. note:: This transform can be used both to transform output strings into tokens and to transform back tokenized
        actions or states into strings. If the environment has a string state-spec, the transformed version will have
        a tokenized state-spec. If it is a string action spec, it will result in a tokenized action spec.

    """

    def __init__(
        self,
        in_keys: Sequence[NestedKey] | None = None,
        out_keys: Sequence[NestedKey] | None = None,
        in_keys_inv: Sequence[NestedKey] | None = None,
        out_keys_inv: Sequence[NestedKey] | None = None,
        *,
        tokenizer: transformers.PretrainedTokenizerBase = None,  # noqa: F821
        use_raw_nontensor: bool = False,
        additional_tokens: list[str] | None = None,
        skip_special_tokens: bool = True,
        add_special_tokens: bool = False,
        padding: bool = True,
        max_length: int | None = None,
        return_attention_mask: bool = True,
        missing_tolerance: bool = True,
        call_before_reset: bool = False,
    ):
        if tokenizer is None:
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        elif isinstance(tokenizer, str):
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(tokenizer)

        self.tokenizer = tokenizer
        self.add_special_tokens = add_special_tokens
        self.skip_special_tokens = skip_special_tokens
        self.padding = padding
        self.max_length = max_length
        self.return_attention_mask = return_attention_mask
        self.call_before_reset = call_before_reset
        if additional_tokens:
            self.tokenizer.add_tokens(additional_tokens)
        super().__init__(
            in_keys=in_keys,
            out_keys=out_keys,
            in_keys_inv=in_keys_inv,
            out_keys_inv=out_keys_inv,
            fn=self.call_tokenizer_fn,
            inv_fn=self.call_tokenizer_inv_fn,
            use_raw_nontensor=use_raw_nontensor,
        )
        self._missing_tolerance = missing_tolerance

    @property
    def device(self):
        if "_device" in self.__dict__:
            return self._device
        parent = self.parent
        if parent is None:
            return None
        device = parent.device
        self._device = device
        return device

    def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
        # Specialized for attention mask
        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
            value = next_tensordict.get(in_key, default=None)
            if value is not None:
                observation = self._apply_transform(value)
                if self.return_attention_mask:
                    observation, attention_mask = observation
                    next_tensordict.set(
                        _replace_last(out_key, "attention_mask"),
                        attention_mask,
                    )
                next_tensordict.set(
                    out_key,
                    observation,
                )
            elif (
                self.missing_tolerance
                and self.return_attention_mask
                and out_key in next_tensordict.keys(True)
            ):
                attention_key = _replace_last(out_key, "attention_mask")
                if attention_key not in next_tensordict:
                    next_tensordict[attention_key] = torch.ones_like(
                        next_tensordict.get(out_key)
                    )
            elif not self.missing_tolerance:
                raise KeyError(
                    f"{self}: '{in_key}' not found in tensordict {next_tensordict}"
                )
        return next_tensordict

    @dispatch(source="in_keys", dest="out_keys")
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
            data = tensordict.get(in_key, None)
            if data is not None:
                data = self._apply_transform(data)
                if self.return_attention_mask:
                    data, attention_mask = data
                    tensordict.set(
                        _replace_last(out_key, "attention_mask"),
                        attention_mask,
                    )
                tensordict.set(out_key, data)
            elif not self.missing_tolerance:
                raise KeyError(f"'{in_key}' not found in tensordict {tensordict}")
        return tensordict

    def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
        if self.call_before_reset:
            with _set_missing_tolerance(self, True):
                tensordict = self._call(tensordict)
        return tensordict

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        if self.call_before_reset:
            return tensordict_reset
        return super()._reset(tensordict, tensordict_reset)

    def call_tokenizer_fn(self, value: str | list[str]):
        device = self.device
        kwargs = {"add_special_tokens": self.add_special_tokens}
        if self.max_length is not None:
            kwargs["padding"] = "max_length"
            kwargs["max_length"] = self.max_length
        if isinstance(value, str):
            out = self.tokenizer.encode(value, return_tensors="pt", **kwargs)[0]
            # TODO: incorporate attention mask
            if self.return_attention_mask:
                attention_mask = torch.ones_like(out, dtype=torch.int64)
        else:
            kwargs["padding"] = (
                self.padding if self.max_length is None else "max_length"
            )
            kwargs["return_attention_mask"] = self.return_attention_mask
            # kwargs["return_token_type_ids"] = False
            out = self.tokenizer.batch_encode_plus(value, return_tensors="pt", **kwargs)
            if self.return_attention_mask:
                attention_mask = out["attention_mask"]
            out = out["input_ids"]

        if device is not None and out.device != device:
            out = out.to(device)
            if self.return_attention_mask:
                attention_mask = attention_mask.to(device)
        if self.return_attention_mask:
            return out, attention_mask
        return out

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Override _inv_call to account for ragged dims
        if not self.in_keys_inv:
            return tensordict
        for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
            data = tensordict.get(out_key, None, as_padded_tensor=True)
            if data is not None:
                item = self._inv_apply_transform(data)
                tensordict.set(in_key, item)
            elif not self.missing_tolerance:
                raise KeyError(f"'{out_key}' not found in tensordict {tensordict}")
        return tensordict

    def call_tokenizer_inv_fn(self, value: Tensor):
        if value.ndim == 1:
            out = self.tokenizer.decode(
                value.int(), skip_special_tokens=self.skip_special_tokens
            )
        else:
            out = self.tokenizer.batch_decode(
                value.int(), skip_special_tokens=self.skip_special_tokens
            )
        device = self._str_device
        if isinstance(out, list):
            result = NonTensorStack(*out)
            if device:
                result = result.to(device)
            return result
        return NonTensorData(out, device=device)

    @property
    def _str_device(self):
        parent = self.parent
        if parent is None:
            return None
        if self.in_keys:
            in_key = self.in_keys[0]
        elif self.in_keys_inv:
            in_key = self.in_keys_inv[0]
        else:
            return None
        if in_key in parent.observation_keys:
            return parent.full_observation_spec[in_key].device
        if in_key in parent.action_keys:
            return parent.full_action_spec[in_key].device
        if in_key in parent.state_keys:
            return parent.full_state_spec[in_key].device
        return None

    def transform_input_spec(self, input_spec: Composite) -> Composite:
        # We need to cap the spec to generate valid random strings
        for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
            if in_key in input_spec["full_state_spec"].keys(True, True):
                spec = input_spec["full_state_spec"]
            elif in_key in input_spec["full_action_spec"].keys(False, True):
                spec = input_spec["full_action_spec"]
            else:
                raise KeyError(
                    f"The input key {in_key} wasn't found in the env input specs."
                )
            local_spec = spec.pop(in_key)
            local_dtype = local_spec.dtype
            if local_dtype is None or local_dtype.is_floating_point:
                local_dtype = torch.int64
            new_shape = spec.shape
            if self.max_length is None:
                # Then we can't tell what the shape will be
                new_shape = new_shape + torch.Size((-1,))
            else:
                new_shape = new_shape + torch.Size((self.max_length,))
            spec[out_key] = Bounded(
                0,
                self.tokenizer.vocab_size,
                shape=new_shape,
                device=local_spec.device,
                dtype=local_dtype,
            )
        return input_spec

    transform_output_spec = Transform.transform_output_spec
    transform_reward_spec = Transform.transform_reward_spec
    transform_done_spec = Transform.transform_done_spec

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        attention_mask_keys = set()
        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
            new_shape = observation_spec.shape + torch.Size((-1,))
            try:
                in_spec = observation_spec[in_key]
                obs_dtype = in_spec.dtype
                device = in_spec.device
            except KeyError:
                # In some cases (e.g., the tokenizer is applied during reset on data that
                # originates from a dataloader) we don't have an in_spec
                in_spec = None
                obs_dtype = None
                device = observation_spec.device
            if obs_dtype is None or obs_dtype.is_floating_point:
                obs_dtype = torch.int64
            observation_spec[out_key] = Bounded(
                0,
                self.tokenizer.vocab_size,
                shape=new_shape,
                device=device,
                dtype=obs_dtype,
            )
            if self.return_attention_mask:
                attention_mask_key = _replace_last(out_key, "attention_mask")
                if attention_mask_key in attention_mask_keys:
                    raise KeyError(
                        "Conflicting attention_mask keys. Make sure the token tensors are "
                        "nested at different places in the tensordict such that `(*root, 'attention_mask')` "
                        "entries are unique."
                    )
                attention_mask_keys.add(attention_mask_key)
                attention_dtype = obs_dtype
                if attention_dtype is None or attention_dtype.is_floating_point:
                    attention_dtype = torch.int64
                observation_spec[attention_mask_key] = Bounded(
                    0,
                    2,
                    shape=new_shape,
                    device=device,
                    dtype=attention_dtype,
                )
        return observation_spec
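For context, a minimal usage sketch of the transform above (not part of the package contents): it tokenizes a string entry of a TensorDict and decodes it back. The import path is inferred from the file listing in this diff; `torchrl`, `tensordict` and `transformers` must be installed, and the key names and `max_length` value are illustrative choices, not package defaults.

# Minimal sketch: standalone use of Tokenizer outside an env.
# The import path follows the file layout listed above (assumption);
# the first call downloads the BERT tokenizer from the HuggingFace hub.
from tensordict import NonTensorStack, TensorDict

from torchrl.envs.llm.transforms.tokenizer import Tokenizer

t = Tokenizer(
    in_keys=["text"],
    out_keys=["tokens"],
    tokenizer="google-bert/bert-base-uncased",
    max_length=8,  # gives the token dimension a static shape in the specs
)
td = TensorDict(
    {"text": NonTensorStack("the quick brown fox", "jumps over")},
    batch_size=[2],
)
td = t(td)  # forward() writes "tokens" and, by default, "attention_mask"
print(td["tokens"].shape)          # expected: torch.Size([2, 8])
print(td["attention_mask"].shape)  # expected: torch.Size([2, 8])

# Decoding back to strings goes through the inverse tokenizer function:
print(t.call_tokenizer_inv_fn(td["tokens"]))  # NonTensorStack of strings

When attached to an environment, the same transform also rewrites the specs (see transform_observation_spec and transform_input_spec above) so that rollouts advertise the tokenized shapes.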