torchrl-0.11.0-cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/llm/dataset.py
@@ -0,0 +1,491 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import importlib.util
+import os
+from collections.abc import Sequence
+from pathlib import Path
+
+import torch
+from tensordict import TensorDict, TensorDictBase
+from tensordict.utils import NestedKey
+from torchrl._utils import logger as torchrl_logger
+from torchrl.data.replay_buffers import (
+    SamplerWithoutReplacement,
+    TensorDictReplayBuffer,
+    TensorStorage,
+)
+
+_has_transformers = importlib.util.find_spec("transformers") is not None
+_has_datasets = importlib.util.find_spec("datasets") is not None
+
+
+class TokenizedDatasetLoader:
+    """Loads a tokenized dataset, and caches a memory-mapped copy of it.
+
+    Args:
+        split (str): One of ``"train"`` or ``"valid"``.
+        max_length (int): the maximum sequence length.
+        dataset_name (str): the name of the dataset.
+        tokenizer_fn (callable): the tokenizing method constructor, such as
+            :class:`torchrl.data.llm.TensorDictTokenizer`. When called,
+            it should return a :class:`tensordict.TensorDict` instance
+            or a dictionary-like structure with the tokenized data.
+        pre_tokenization_hook (callable, optional): called on
+            the Dataset before tokenization. It should return a modified
+            Dataset object.
+            The intended use is for carrying out tasks that
+            require modifying the dataset as a whole as opposed to modifying
+            individual datapoints, for example discarding certain datapoints
+            based on a particular condition. Tokenization and other
+            "elementwise" operations on the data are performed by the process
+            function which is mapped over the dataset.
+        root_dir (path, optional): the path where the datasets are stored.
+            Defaults to ``"$HOME/.cache/torchrl/data"``.
+        from_disk (bool, optional): if ``True``, :func:`datasets.load_from_disk`
+            will be used. Otherwise, :func:`datasets.load_dataset` will be used.
+            Defaults to ``False``.
+        valid_size (int, optional): the validation dataset (if ``split``
+            starts with ``"valid"``) will be truncated to this size.
+            Defaults to 2000 items.
+        num_workers (int, optional): number of workers for :meth:`datasets.dataset.map`
+            which is called during tokenization.
+            Defaults to ``max(os.cpu_count() // 2, 1)``.
+        tokenizer_class (Type, optional): A tokenizer class, such as
+            :class:`~transformers.AutoTokenizer` (default).
+        tokenizer_model_name (str, optional): The model from which the vocabulary
+            should be gathered. Defaults to ``"gpt2"``.
+
+    The dataset will be stored in ``<root_dir>/<split>/<max_length>/``.
+
+    Examples:
+        >>> from torchrl.data.llm import TensorDictTokenizer
+        >>> from torchrl.data.llm.reward import pre_tokenization_hook
+        >>> split = "train"
+        >>> max_length = 550
+        >>> dataset_name = "CarperAI/openai_summarize_comparisons"
+        >>> loader = TokenizedDatasetLoader(
+        ...     split,
+        ...     max_length,
+        ...     dataset_name,
+        ...     TensorDictTokenizer,
+        ...     pre_tokenization_hook=pre_tokenization_hook,
+        ... )
+        >>> dataset = loader.load()
+        >>> print(dataset)
+        TensorDict(
+            fields={
+                attention_mask: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                input_ids: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)},
+            batch_size=torch.Size([185068]),
+            device=None,
+            is_shared=False)
+
+    """
+
+    def __init__(
+        self,
+        split,
+        max_length,
+        dataset_name,
+        tokenizer_fn: type[TensorDictTokenizer],
+        pre_tokenization_hook=None,
+        root_dir=None,
+        from_disk=False,
+        valid_size: int = 2000,
+        num_workers: int | None = None,
+        tokenizer_class=None,
+        tokenizer_model_name=None,
+    ):
+        self.split = split
+        self.max_length = max_length
+        self.dataset_name = dataset_name
+        self.tokenizer_fn = tokenizer_fn
+        self.pre_tokenization_hook = pre_tokenization_hook
+        self.root_dir = root_dir
+        self.from_disk = from_disk
+        self.valid_size = valid_size
+        if num_workers is None:
+            num_workers = max(os.cpu_count() // 2, 1)
+        self.num_workers = num_workers
+        if tokenizer_class is None:
+            from transformers import AutoTokenizer
+
+            tokenizer_class = AutoTokenizer
+        if tokenizer_model_name is None:
+            tokenizer_model_name = "gpt2"
+        self.make_tokenizer(
+            tokenizer_class=tokenizer_class, tokenizer_model_name=tokenizer_model_name
+        )
+
+    def make_tokenizer(self, *, tokenizer_class, tokenizer_model_name):
+        tokenizer = tokenizer_class.from_pretrained(tokenizer_model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+        self.tokenizer = tokenizer
+
+    def load(self):
+        """Loads a pre-processed, memory-mapped dataset if it exists, and creates it otherwise."""
+        root_dir = self.root_dir
+        max_length = self.max_length
+        split = self.split
+        if root_dir is None:
+            root_dir = Path(os.environ.get("HOME")) / ".cache/torchrl/data/"
+        os.makedirs(root_dir, exist_ok=True)
+        root_dir = Path(root_dir)
+        data_dir = root_dir / str(Path(self.dataset_name).name).split("-")[0]
+        data_dir_total = data_dir / split / str(max_length)
+        # search for data
+        torchrl_logger.info(f"Looking for data in {data_dir_total}")
+        if os.path.exists(data_dir_total):
+            dataset = TensorDict.load_memmap(data_dir_total)
+            return dataset
+        dataset = self._load_dataset()
+        dataset = self._tokenize(dataset)
+        prefix = (split, str(max_length))
+        result = self.dataset_to_tensordict(
+            dataset, data_dir=data_dir, prefix=prefix, valid_mask_key="valid_sample"
+        )
+        return result[prefix]
+
+    def _load_dataset(self):
+        """Loads a text dataset from ``datasets``.
+
+        Returns: a dataset of type ``datasets.Dataset``.
+        """
+        if not _has_datasets:
+            raise ImportError(
+                "preproc_data requires the datasets package to be installed."
+            )
+        from datasets import load_dataset, load_from_disk
+
+        if self.from_disk:
+            dataset = load_from_disk(str(self.dataset_name))[self.split]
+        else:
+            dataset = load_dataset(self.dataset_name, split=self.split)
+        if self.split.startswith("valid"):
+            # reduce size of validation dataset
+            dataset = dataset.select(range(self.valid_size))
+        if self.pre_tokenization_hook is not None:
+            dataset = self.pre_tokenization_hook(dataset)
+        return dataset
+
+    def _tokenize(
+        self,
+        dataset,
+        excluded_features: Sequence[str] | None = None,
+    ):
+        """Preprocesses a text dataset from ``datasets``.
+
+        Args:
+            dataset (datasets.Dataset): a dataset loaded using :meth:`_load_dataset`.
+            excluded_features (sequence of str, optional): the features to exclude
+                once tokenization is complete. Defaults to ``{"text", "prompt", "label", "valid_sample"}``.
+
+        Returns: a dataset of type ``datasets.Dataset``.
+        """
+        if not _has_transformers:
+            raise ImportError("The transformers library is missing.")
+
+        num_workers = self.num_workers
+        if excluded_features is None:
+            excluded_features = {"text", "prompt", "label", "valid_sample"}
+        tokenizer = self.tokenizer
+        # tokenize the dataset
+        # TODO: replace this by TensorDict.map
+        dataset = dataset.map(
+            self.tokenizer_fn(
+                tokenizer, max_length=self.max_length, return_tensordict=False
+            ),
+            desc="Tokenizing...",
+            num_proc=num_workers,
+            batched=True,
+        )
+        if not isinstance(dataset, TensorDictBase):
+            dataset_dict = dataset.to_dict()
+            if excluded_features:
+                dataset_dict = {
+                    key: value
+                    for key, value in dataset_dict.items()
+                    if key not in excluded_features
+                }
+            dataset = TensorDict.from_dict(
+                dataset_dict, auto_batch_size=True, batch_dims=1
+            )
+        elif excluded_features:
+            dataset = dataset.exclude(*excluded_features)
+        # keep non empty rows (i.e. where at least one token is not eos)
+        if "valid_sample" in dataset.keys():
+            mask = dataset.get("valid_sample")
+            dataset = dataset[mask]
+        return dataset
+
+    @staticmethod
+    def dataset_to_tensordict(
+        dataset: datasets.Dataset | TensorDict,  # noqa: F821
+        data_dir: Path,
+        prefix: NestedKey = None,
+        features: Sequence[str] = None,
+        batch_dims=1,
+        valid_mask_key=None,
+    ):
+        """Converts a dataset to a memory-mapped TensorDict.
+
+        If the dataset is already a :class:`TensorDict` instance, it is simply converted
+        to a memory-mapped TensorDict.
+        Otherwise, the dataset is expected to have a ``features`` attribute
+        which is a sequence of strings indicating the features that can be found
+        in the dataset. If it does not, the ``features`` must be passed explicitly
+        to this function.
+
+        Args:
+            dataset (datasets.Dataset, TensorDict or equivalent): a dataset to convert
+                to a memory-mapped TensorDict.
+                If ``features`` is ``None``, it must have a ``features`` attribute
+                with the list of keys to write in the tensordict.
+            data_dir (Path or equivalent): directory where the data should be written.
+            prefix (NestedKey, optional): the prefix of the dataset location. This can
+                be used to differentiate several copies of the same dataset that have
+                undergone different preprocessings.
+            features (sequence of str, optional): a sequence of str indicating the
+                features that can be found in the dataset.
+            batch_dims (int, optional): the number of batch dimensions of the data
+                (i.e. the number of dimensions along which the tensordict can be indexed).
+                Defaults to 1.
+            valid_mask_key (NestedKey, optional): if provided, this entry will be
+                tentatively gathered and used to filter the data. Defaults to
+                ``None`` (i.e., no filter key).
+
+        Returns: a TensorDict containing memory-mapped tensors with the dataset.
+
+        Examples:
+            >>> from datasets import Dataset
+            >>> import tempfile
+            >>> data = Dataset.from_dict({"tokens": torch.randint(20, (10, 11)), "labels": torch.zeros(10, 11)})
+            >>> with tempfile.TemporaryDirectory() as tmpdir:
+            ...     data_memmap = TokenizedDatasetLoader.dataset_to_tensordict(
+            ...         data, data_dir=tmpdir, prefix=("some", "prefix"), features=["tokens", "labels"]
+            ...     )
+            ...     print(data_memmap)
+            TensorDict(
+                fields={
+                    some: TensorDict(
+                        fields={
+                            prefix: TensorDict(
+                                fields={
+                                    labels: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
+                                    tokens: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)},
+                                batch_size=torch.Size([10]),
+                                device=None,
+                                is_shared=False)},
+                        batch_size=torch.Size([]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+
+        """
+        if not isinstance(dataset, TensorDict):
+            if features is None:
+                features = dataset.features
+            if prefix is None:
+                prefix = ()
+            data_dict = {key: torch.as_tensor(dataset[key]) for key in features}
+            out = TensorDict.from_dict(
+                data_dict, batch_dims=batch_dims, auto_batch_size=True
+            )
+        else:
+            out = dataset
+        if valid_mask_key is not None and valid_mask_key in out.keys(
+            include_nested=True
+        ):
+            out = out[out.get(valid_mask_key)]
+        out = TensorDict({prefix: out})
+        out.memmap_(prefix=data_dir)
+        return out
+
+
+def create_infinite_iterator(iterator):
+    """Iterates indefinitely over an iterator."""
+    while True:
+        yield from iterator
+
+
+def get_dataloader(
+    batch_size: int,
+    block_size: int,
+    tensorclass_type: type,
+    device: torch.device,
+    dataset_name: str | None = None,
+    infinite: bool = True,
+    prefetch: int = 0,
+    split: str = "train",
+    root_dir: str | None = None,
+    from_disk: bool = False,
+    num_workers: int | None = None,
+):
+    """Creates a dataset and returns a dataloader from it.
+
+    Args:
+        batch_size (int): the batch size of the dataloader samples.
+        block_size (int): the maximum length of a sequence in the dataloader.
+        tensorclass_type (tensorclass class): a tensorclass with a :meth:`from_dataset`
+            method that must accept three keyword arguments: ``split`` (see below),
+            ``max_length``, which is the block size to be used for training, and
+            ``dataset_name``, a string indicating the dataset. The ``root_dir``
+            and ``from_disk`` arguments should also be supported.
+        device (torch.device or equivalent): the device where the samples should
+            be cast.
+        dataset_name (str, optional): the dataset name. If not provided and if
+            the tensorclass supports it, a default dataset name will be gathered
+            for the tensorclass being used.
+        infinite (bool, optional): if ``True``, the iteration will be infinite
+            such that ``next(iterator)`` will always return a value.
+            Defaults to ``True``.
+        prefetch (int, optional): the number of items to be prefetched if
+            multithreaded dataloading is being used.
+        split (str, optional): the data split. Either ``"train"`` or ``"valid"``.
+            Defaults to ``"train"``.
+        root_dir (path, optional): the path where the datasets are stored.
+            Defaults to ``"$HOME/.cache/torchrl/data"``.
+        from_disk (bool, optional): if ``True``, :func:`datasets.load_from_disk`
+            will be used. Otherwise, :func:`datasets.load_dataset` will be used.
+            Defaults to ``False``.
+        num_workers (int, optional): number of workers for :meth:`datasets.dataset.map`
+            which is called during tokenization.
+            Defaults to ``max(os.cpu_count() // 2, 1)``.
+
+    Examples:
+        >>> from torchrl.data.llm.reward import PairwiseDataset
+        >>> dataloader = get_dataloader(
+        ...     batch_size=256, block_size=550, tensorclass_type=PairwiseDataset, device="cpu")
+        >>> for d in dataloader:
+        ...     print(d)
+        ...     break
+        PairwiseDataset(
+            chosen_data=RewardData(
+                attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                rewards=None,
+                end_scores=None,
+                batch_size=torch.Size([256]),
+                device=cpu,
+                is_shared=False),
+            rejected_data=RewardData(
+                attention_mask=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                input_ids=Tensor(shape=torch.Size([256, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                rewards=None,
+                end_scores=None,
+                batch_size=torch.Size([256]),
+                device=cpu,
+                is_shared=False),
+            batch_size=torch.Size([256]),
+            device=cpu,
+            is_shared=False)
+    """
+    data = tensorclass_type.from_dataset(
+        split=split,
+        dataset_name=dataset_name,
+        max_length=block_size,
+        root_dir=root_dir,
+        from_disk=from_disk,
+        num_workers=num_workers,
+    )
+    out = TensorDictReplayBuffer(
+        storage=TensorStorage(data),
+        collate_fn=lambda x: x.as_tensor().to(device, non_blocking=True),
+        sampler=SamplerWithoutReplacement(drop_last=True),
+        batch_size=batch_size,
+        prefetch=prefetch,
+    )
+    if infinite:
+        return create_infinite_iterator(out)
+    return out
+
+
+class TensorDictTokenizer:
+    """Factory for a process function that applies a tokenizer over a text example.
+
+    Args:
+        tokenizer (tokenizer from transformers library): the tokenizer to use.
+        max_length (int): maximum length of the sequence.
+        key (str, optional): the key where to find the text. Defaults to ``"text"``.
+        padding (str, optional): type of padding. Defaults to ``"max_length"``.
+        truncation (bool, optional): whether the sequences should be truncated to max_length.
+        return_tensordict (bool, optional): if ``True``, a TensorDict is returned.
+            Otherwise, the original data will be returned.
+        device (torch.device, optional): the device where to store the data.
+            This option is ignored if ``return_tensordict=False``.
+
+    See the transformers library for more information about tokenizers:
+        Padding and truncation: `<https://huggingface.co/docs/transformers/pad_truncation>`_
+
+    Returns: a :class:`tensordict.TensorDict` instance with the same batch-size
+        as the input data.
+
+    Examples:
+        >>> from transformers import AutoTokenizer
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> tokenizer.pad_token = 100
+        >>> process = TensorDictTokenizer(tokenizer, max_length=10)
+        >>> # example with a single input
+        >>> example = {"text": "I am a little worried"}
+        >>> process(example)
+        TensorDict(
+            fields={
+                attention_mask: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
+                input_ids: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
+            batch_size=torch.Size([]),
+            device=None,
+            is_shared=False)
+        >>> # example with multiple inputs
+        >>> example = {"text": ["Let me reassure you", "It will be ok"]}
+        >>> process(example)
+        TensorDict(
+            fields={
+                attention_mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
+                input_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
+            batch_size=torch.Size([2]),
+            device=None,
+            is_shared=False)
+
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        max_length,
+        key="text",
+        padding="max_length",
+        truncation=True,
+        return_tensordict=True,
+        device=None,
+    ):
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.key = key
+        self.padding = padding
+        self.truncation = truncation
+        self.return_tensordict = return_tensordict
+        self.device = device
+
+    def __call__(self, sample):
+        input = sample[self.key]
+        tokenized_sample = self.tokenizer(
+            input,
+            max_length=self.max_length,
+            padding=self.padding,
+            truncation=self.truncation,
+        )
+        batch_size = [] if isinstance(input, str) else [len(input)]
+        if self.return_tensordict:
+            return TensorDict.from_dict(
+                dict(tokenized_sample),
+                batch_size=batch_size,
+                device=self.device,
+                auto_batch_size=True,
+            )
+        return tokenized_sample