evalscope 0.14.0__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of evalscope might be problematic.

Files changed (181)
  1. evalscope/arguments.py +2 -1
  2. evalscope/benchmarks/__init__.py +2 -2
  3. evalscope/benchmarks/aigc/__init__.py +0 -0
  4. evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
  5. evalscope/benchmarks/aigc/t2i/base.py +56 -0
  6. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
  7. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
  8. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
  9. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
  10. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
  11. evalscope/benchmarks/aime/aime24_adapter.py +1 -1
  12. evalscope/benchmarks/aime/aime25_adapter.py +4 -4
  13. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
  14. evalscope/benchmarks/arc/arc_adapter.py +1 -1
  15. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
  16. evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
  17. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
  18. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
  19. evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
  20. evalscope/benchmarks/data_adapter.py +16 -9
  21. evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
  22. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
  23. evalscope/benchmarks/general_qa/general_qa_adapter.py +3 -3
  24. evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
  25. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
  26. evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
  27. evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
  28. evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
  29. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
  30. evalscope/benchmarks/utils.py +7 -16
  31. evalscope/cli/start_app.py +1 -1
  32. evalscope/collections/evaluator.py +16 -4
  33. evalscope/config.py +7 -3
  34. evalscope/constants.py +11 -0
  35. evalscope/evaluator/evaluator.py +9 -3
  36. evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
  37. evalscope/metrics/__init__.py +49 -4
  38. evalscope/metrics/llm_judge.py +1 -1
  39. evalscope/metrics/named_metrics.py +13 -0
  40. evalscope/metrics/t2v_metrics/__init__.py +66 -0
  41. evalscope/metrics/t2v_metrics/clipscore.py +14 -0
  42. evalscope/metrics/t2v_metrics/constants.py +12 -0
  43. evalscope/metrics/t2v_metrics/itmscore.py +14 -0
  44. evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
  45. evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
  46. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
  47. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
  48. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
  49. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
  50. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
  51. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
  52. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
  53. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
  54. evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
  55. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
  56. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
  57. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
  58. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
  59. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
  60. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
  61. evalscope/metrics/t2v_metrics/models/model.py +45 -0
  62. evalscope/metrics/t2v_metrics/models/utils.py +25 -0
  63. evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
  64. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
  65. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
  66. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
  67. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
  68. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
  69. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
  70. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
  71. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
  72. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
  73. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
  74. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
  75. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
  76. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
  77. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
  78. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
  79. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
  80. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
  81. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
  82. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
  83. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
  84. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
  85. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
  86. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
  87. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
  88. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
  89. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
  90. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
  91. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
  92. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
  93. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
  94. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
  95. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
  96. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
  97. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
  98. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
  99. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
  100. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
  101. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
  102. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
  103. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
  104. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
  105. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
  106. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
  107. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
  108. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
  109. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
  110. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
  111. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
  112. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
  113. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
  114. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
  115. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
  116. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
  117. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
  118. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
  119. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
  120. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
  121. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
  122. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
  123. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
  124. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
  125. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
  126. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
  127. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
  128. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
  129. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
  130. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
  131. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
  132. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
  133. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
  134. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
  135. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
  136. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
  137. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
  138. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
  139. evalscope/metrics/t2v_metrics/score.py +78 -0
  140. evalscope/metrics/t2v_metrics/vqascore.py +14 -0
  141. evalscope/models/__init__.py +50 -14
  142. evalscope/models/adapters/__init__.py +17 -0
  143. evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
  144. evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
  145. evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
  146. evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
  147. evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
  148. evalscope/models/adapters/t2i_adapter.py +76 -0
  149. evalscope/models/custom/__init__.py +2 -1
  150. evalscope/models/custom/dummy_model.py +11 -13
  151. evalscope/models/local_model.py +82 -33
  152. evalscope/models/model.py +2 -42
  153. evalscope/models/register.py +26 -0
  154. evalscope/perf/benchmark.py +4 -3
  155. evalscope/perf/main.py +4 -2
  156. evalscope/perf/plugin/datasets/flickr8k.py +2 -1
  157. evalscope/perf/utils/benchmark_util.py +2 -2
  158. evalscope/perf/utils/db_util.py +16 -8
  159. evalscope/report/__init__.py +1 -0
  160. evalscope/report/app.py +117 -67
  161. evalscope/report/app_arguments.py +11 -0
  162. evalscope/report/generator.py +1 -1
  163. evalscope/run.py +3 -3
  164. evalscope/third_party/thinkbench/eval.py +19 -7
  165. evalscope/utils/chat_service.py +2 -2
  166. evalscope/utils/import_utils.py +66 -0
  167. evalscope/utils/utils.py +12 -4
  168. evalscope/version.py +2 -2
  169. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/METADATA +20 -3
  170. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/RECORD +178 -66
  171. tests/aigc/__init__.py +1 -0
  172. tests/aigc/test_t2i.py +87 -0
  173. tests/cli/test_run.py +20 -7
  174. tests/perf/test_perf.py +6 -3
  175. evalscope/metrics/code_metric.py +0 -98
  176. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
  177. evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
  178. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/LICENSE +0 -0
  179. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/WHEEL +0 -0
  180. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/entry_points.txt +0 -0
  181. {evalscope-0.14.0.dist-info → evalscope-0.15.1.dist-info}/top_level.txt +0 -0
evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py
@@ -0,0 +1,1844 @@
1
+ # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """ PyTorch T5 model."""
15
+
16
+ import copy
17
+ import math
18
+ import os
19
+ import torch
20
+ import warnings
21
+ from torch import nn
22
+ from torch.nn import CrossEntropyLoss
23
+ from torch.utils.checkpoint import checkpoint
24
+ from transformers.activations import ACT2FN
25
+ from transformers.modeling_outputs import (BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput,
26
+ Seq2SeqModelOutput)
27
+ from transformers.modeling_utils import PreTrainedModel
28
+ from transformers.models.t5.configuration_t5 import T5Config
29
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
30
+ from transformers.utils import (DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_model_forward,
31
+ is_torch_fx_proxy, logging, replace_return_docstrings)
32
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
33
+ from typing import Optional, Tuple, Union
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+ _CONFIG_FOR_DOC = 'T5Config'
38
+ _TOKENIZER_FOR_DOC = 'T5Tokenizer'
39
+ _CHECKPOINT_FOR_DOC = 't5-small'
40
+
41
+ ####################################################
42
+ # This dict contains ids and associated url
43
+ # for the pretrained weights provided with the models
44
+ ####################################################
45
+ T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
46
+ 't5-small',
47
+ 't5-base',
48
+ 't5-large',
49
+ 't5-3b',
50
+ 't5-11b',
51
+ # See all T5 models at https://huggingface.co/models?filter=t5
52
+ ]
53
+
54
+
55
+ ####################################################
56
+ # This is a conversion method from TF 1.0 to PyTorch
57
+ # More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
58
+ ####################################################
59
+ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
60
+ """Load tf checkpoints in a pytorch model."""
61
+ try:
62
+ import numpy as np
63
+ import re
64
+ import tensorflow as tf
65
+ except ImportError:
66
+ logger.error('Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see '
67
+ 'https://www.tensorflow.org/install/ for installation instructions.')
68
+ raise
69
+ tf_path = os.path.abspath(tf_checkpoint_path)
70
+ logger.info(f'Converting TensorFlow checkpoint from {tf_path}')
71
+ # Load weights from TF model
72
+ init_vars = tf.train.list_variables(tf_path)
73
+ names = []
74
+ tf_weights = {}
75
+ for name, shape in init_vars:
76
+ logger.info(f'Loading TF weight {name} with shape {shape}')
77
+ array = tf.train.load_variable(tf_path, name)
78
+ names.append(name)
79
+ tf_weights[name] = array
80
+
81
+ for txt_name in names:
82
+ name = txt_name.split('/')
83
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
84
+ # which are not required for using the pretrained model
85
+ if any(n in [
86
+ 'adam_v',
87
+ 'adam_m',
88
+ 'AdamWeightDecayOptimizer',
89
+ 'AdamWeightDecayOptimizer_1',
90
+ 'global_step',
91
+ ] for n in name):
92
+ logger.info(f"Skipping {'/'.join(name)}")
93
+ tf_weights.pop(txt_name, None)
94
+ continue
95
+ if '_slot_' in name[-1]:
96
+ logger.info(f"Skipping {'/'.join(name)}")
97
+ tf_weights.pop(txt_name, None)
98
+ continue
99
+ pointer = model
100
+ array = tf_weights[txt_name]
101
+
102
+ for m_name in name:
103
+ if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
104
+ scope_names = re.split(r'_(\d+)', m_name)
105
+ else:
106
+ scope_names = [m_name]
107
+ if scope_names[0] in ['kernel', 'scale', 'embedding']:
108
+ pointer = getattr(pointer, 'weight')
109
+ elif scope_names[0] == 'self_attention':
110
+ pointer = getattr(pointer, 'layer')
111
+ pointer = pointer[0]
112
+ elif scope_names[0] == 'enc_dec_attention':
113
+ pointer = getattr(pointer, 'layer')
114
+ pointer = pointer[1]
115
+ elif scope_names[0] == 'dense_relu_dense':
116
+ pointer = getattr(pointer, 'layer')
117
+ pointer = pointer[2]
118
+ elif scope_names[0] == 'rms_norm':
119
+ if hasattr(pointer, 'layer_norm'):
120
+ pointer = getattr(pointer, 'layer_norm')
121
+ elif hasattr(pointer, 'final_layer_norm'):
122
+ pointer = getattr(pointer, 'final_layer_norm')
123
+ elif scope_names[0] == 'scale':
124
+ pointer = getattr(pointer, 'weight')
125
+ elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta':
126
+ pointer = getattr(pointer, 'bias')
127
+ elif scope_names[0] == 'squad':
128
+ pointer = getattr(pointer, 'classifier')
129
+ elif scope_names[0] == 'decoder' and name[1] == 'logits':
130
+ continue
131
+ elif scope_names[0] == 'logits':
132
+ pointer = getattr(pointer, 'lm_head')
133
+ elif (scope_names[0] == 'wi' and len(scope_names) > 1 and scope_names[1].isdigit()):
134
+ pointer = getattr(pointer, f'wi_{scope_names[1]}')
135
+ continue
136
+ else:
137
+ try:
138
+ pointer = getattr(pointer, scope_names[0])
139
+ except AttributeError:
140
+ logger.info(f"Skipping {'/'.join(name)}")
141
+ continue
142
+ if len(scope_names) >= 2:
143
+ num = int(scope_names[1])
144
+ pointer = pointer[num]
145
+ if scope_names[0] not in ['kernel', 'scale', 'embedding']:
146
+ pointer = getattr(pointer, 'weight')
147
+ if scope_names[0] != 'embedding':
148
+ logger.info(f'Transposing numpy weight of shape {array.shape} for {name}')
149
+ array = np.transpose(array)
150
+ try:
151
+ assert (
152
+ pointer.shape == array.shape), f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched'
153
+ except AssertionError as e:
154
+ e.args += (pointer.shape, array.shape)
155
+ raise
156
+ logger.info(f'Initialize PyTorch weight {name}')
157
+ pointer.data = torch.from_numpy(array.astype(np.float32))
158
+ tf_weights.pop(txt_name, None)
159
+
160
+ logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
161
+ return model
162
+
163
+
164
+ ####################################################
165
+ # PyTorch Models are constructed by sub-classing
166
+ # - torch.nn.Module for the layers and
167
+ # - PreTrainedModel for the models (itself a sub-class of nn.Module)
168
+ ####################################################
169
+ PARALLELIZE_DOCSTRING = r"""
170
+ This is an experimental feature and is subject to change at a moment's notice.
171
+
172
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
173
+ it will evenly distribute blocks across all devices.
174
+
175
+ Args:
176
+ device_map (`Dict[int, list]`, optional, defaults to None):
177
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
178
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
179
+ have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
180
+ following number of attention modules:
181
+
182
+ - t5-small: 6
183
+ - t5-base: 12
184
+ - t5-large: 24
185
+ - t5-3b: 24
186
+ - t5-11b: 24
187
+
188
+ Example:
189
+
190
+ ```python
191
+ # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:
192
+ model = T5ForConditionalGeneration.from_pretrained("t5-3b")
193
+ device_map = {
194
+ 0: [0, 1, 2],
195
+ 1: [3, 4, 5, 6, 7, 8, 9],
196
+ 2: [10, 11, 12, 13, 14, 15, 16],
197
+ 3: [17, 18, 19, 20, 21, 22, 23],
198
+ }
199
+ model.parallelize(device_map)
200
+ ```
201
+ """
202
+ DEPARALLELIZE_DOCSTRING = r"""
203
+ Moves the model to cpu from a model parallel state.
204
+
205
+ Example:
206
+
207
+ ```python
208
+ # On a 4 GPU machine with t5-3b:
209
+ model = T5ForConditionalGeneration.from_pretrained("t5-3b")
210
+ device_map = {
211
+ 0: [0, 1, 2],
212
+ 1: [3, 4, 5, 6, 7, 8, 9],
213
+ 2: [10, 11, 12, 13, 14, 15, 16],
214
+ 3: [17, 18, 19, 20, 21, 22, 23],
215
+ }
216
+ model.parallelize(device_map) # Splits the model across several devices
217
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
218
+ ```
219
+ """
220
+
221
+
222
+ class T5LayerNorm(nn.Module):
223
+
224
+ def __init__(self, hidden_size, eps=1e-6):
225
+ """
226
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
227
+ """
228
+ super().__init__()
229
+ self.weight = nn.Parameter(torch.ones(hidden_size))
230
+ self.variance_epsilon = eps
231
+
232
+ def forward(self, hidden_states):
233
+
234
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
235
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
236
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
237
+ # half-precision inputs is done in fp32
238
+
239
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
240
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
241
+
242
+ # convert into half-precision if necessary
243
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
244
+ hidden_states = hidden_states.to(self.weight.dtype)
245
+
246
+ return self.weight * hidden_states
247
+
248
+
249
+ try:
250
+ from apex.normalization import FusedRMSNorm
251
+
252
+ T5LayerNorm = FusedRMSNorm # noqa
253
+
254
+ logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm')
255
+ except ImportError:
256
+ # using the normal T5LayerNorm
257
+ pass
258
+ except Exception:
259
+ logger.warning('discovered apex but it failed to load, falling back to T5LayerNorm')
260
+ pass
261
+
262
+ ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
263
+
264
+
265
+ class T5DenseActDense(nn.Module):
266
+
267
+ def __init__(self, config: T5Config):
268
+ super().__init__()
269
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
270
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
271
+ self.dropout = nn.Dropout(config.dropout_rate)
272
+ self.act = ACT2FN[config.dense_act_fn]
273
+
274
+ def forward(self, hidden_states):
275
+ hidden_states = self.wi(hidden_states)
276
+ hidden_states = self.act(hidden_states)
277
+ hidden_states = self.dropout(hidden_states)
278
+ hidden_states = self.wo(hidden_states)
279
+ return hidden_states
280
+
281
+
282
+ class T5DenseGatedActDense(nn.Module):
283
+
284
+ def __init__(self, config: T5Config):
285
+ super().__init__()
286
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
287
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
288
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
289
+ self.dropout = nn.Dropout(config.dropout_rate)
290
+ self.act = ACT2FN[config.dense_act_fn]
291
+
292
+ def forward(self, hidden_states):
293
+ hidden_gelu = self.act(self.wi_0(hidden_states))
294
+ hidden_linear = self.wi_1(hidden_states)
295
+ hidden_states = hidden_gelu * hidden_linear
296
+ hidden_states = self.dropout(hidden_states)
297
+ hidden_states = self.wo(hidden_states)
298
+ return hidden_states
299
+
300
+
301
+ class T5LayerFF(nn.Module):
302
+
303
+ def __init__(self, config: T5Config):
304
+ super().__init__()
305
+ if config.is_gated_act:
306
+ self.DenseReluDense = T5DenseGatedActDense(config)
307
+ else:
308
+ self.DenseReluDense = T5DenseActDense(config)
309
+
310
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
311
+ self.dropout = nn.Dropout(config.dropout_rate)
312
+
313
+ def forward(self, hidden_states):
314
+ forwarded_states = self.layer_norm(hidden_states)
315
+ forwarded_states = self.DenseReluDense(forwarded_states)
316
+ hidden_states = hidden_states + self.dropout(forwarded_states)
317
+ return hidden_states
318
+
319
+
320
+ class T5Attention(nn.Module):
321
+
322
+ def __init__(self, config: T5Config, has_relative_attention_bias=False):
323
+ super().__init__()
324
+ self.is_decoder = config.is_decoder
325
+ self.has_relative_attention_bias = has_relative_attention_bias
326
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
327
+ self.relative_attention_max_distance = config.relative_attention_max_distance
328
+ self.d_model = config.d_model
329
+ self.key_value_proj_dim = config.d_kv
330
+ self.n_heads = config.num_heads
331
+ self.dropout = config.dropout_rate
332
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
333
+
334
+ # Mesh TensorFlow initialization to avoid scaling before softmax
335
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
336
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
337
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
338
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
339
+
340
+ if self.has_relative_attention_bias:
341
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
342
+ self.pruned_heads = set()
343
+ self.gradient_checkpointing = False
344
+
345
+ def prune_heads(self, heads):
346
+ if len(heads) == 0:
347
+ return
348
+ heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads)
349
+ # Prune linear layers
350
+ self.q = prune_linear_layer(self.q, index)
351
+ self.k = prune_linear_layer(self.k, index)
352
+ self.v = prune_linear_layer(self.v, index)
353
+ self.o = prune_linear_layer(self.o, index, dim=1)
354
+ # Update hyper params
355
+ self.n_heads = self.n_heads - len(heads)
356
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
357
+ self.pruned_heads = self.pruned_heads.union(heads)
358
+
359
+ @staticmethod
360
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
361
+ """
362
+ Adapted from Mesh Tensorflow:
363
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
364
+
365
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
366
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
367
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
368
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
369
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
370
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
371
+
372
+ Args:
373
+ relative_position: an int32 Tensor
374
+ bidirectional: a boolean - whether the attention is bidirectional
375
+ num_buckets: an integer
376
+ max_distance: an integer
377
+
378
+ Returns:
379
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
380
+ """
381
+ relative_buckets = 0
382
+ if bidirectional:
383
+ num_buckets //= 2
384
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
385
+ relative_position = torch.abs(relative_position)
386
+ else:
387
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
388
+ # now relative_position is in the range [0, inf)
389
+
390
+ # half of the buckets are for exact increments in positions
391
+ max_exact = num_buckets // 2
392
+ is_small = relative_position < max_exact
393
+
394
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
395
+ relative_position_if_large = max_exact + (torch.log(relative_position.float() / max_exact)
396
+ / math.log(max_distance / max_exact) *
397
+ (num_buckets - max_exact)).to(torch.long)
398
+ relative_position_if_large = torch.min(
399
+ relative_position_if_large,
400
+ torch.full_like(relative_position_if_large, num_buckets - 1),
401
+ )
402
+
403
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
404
+ return relative_buckets
405
+
406
+ def compute_bias(self, query_length, key_length, device=None):
407
+ """Compute binned relative position bias"""
408
+ if device is None:
409
+ device = self.relative_attention_bias.weight.device
410
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
411
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
412
+ relative_position = (memory_position - context_position) # shape (query_length, key_length)
413
+ relative_position_bucket = self._relative_position_bucket(
414
+ relative_position, # shape (query_length, key_length)
415
+ bidirectional=(not self.is_decoder),
416
+ num_buckets=self.relative_attention_num_buckets,
417
+ max_distance=self.relative_attention_max_distance,
418
+ )
419
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
420
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
421
+ return values
422
+
423
+ def forward(
424
+ self,
425
+ hidden_states,
426
+ mask=None,
427
+ key_value_states=None,
428
+ position_bias=None,
429
+ past_key_value=None,
430
+ layer_head_mask=None,
431
+ query_length=None,
432
+ use_cache=False,
433
+ output_attentions=False,
434
+ ):
435
+ """
436
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
437
+ """
438
+ # Input is (batch_size, seq_length, dim)
439
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
440
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
441
+ batch_size, seq_length = hidden_states.shape[:2]
442
+
443
+ real_seq_length = seq_length
444
+
445
+ if past_key_value is not None:
446
+ assert (
447
+ len(past_key_value) == 2
448
+ ), f'past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states'
449
+ real_seq_length += (past_key_value[0].shape[2] if query_length is None else query_length)
450
+
451
+ key_length = (real_seq_length if key_value_states is None else key_value_states.shape[1])
452
+
453
+ def shape(states):
454
+ """projection"""
455
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
456
+
457
+ def unshape(states):
458
+ """reshape"""
459
+ return (states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim))
460
+
461
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
462
+ """projects hidden states correctly to key/query states"""
463
+ if key_value_states is None:
464
+ # self-attn
465
+ # (batch_size, n_heads, seq_length, dim_per_head)
466
+ hidden_states = shape(proj_layer(hidden_states))
467
+ elif past_key_value is None:
468
+ # cross-attn
469
+ # (batch_size, n_heads, seq_length, dim_per_head)
470
+ hidden_states = shape(proj_layer(key_value_states))
471
+
472
+ if past_key_value is not None:
473
+ if key_value_states is None:
474
+ # self-attn
475
+ # (batch_size, n_heads, key_length, dim_per_head)
476
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
477
+ else:
478
+ # cross-attn
479
+ hidden_states = past_key_value
480
+ return hidden_states
481
+
482
+ # get query states
483
+ query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
484
+
485
+ # get key/value states
486
+ key_states = project(
487
+ hidden_states,
488
+ self.k,
489
+ key_value_states,
490
+ past_key_value[0] if past_key_value is not None else None,
491
+ )
492
+ value_states = project(
493
+ hidden_states,
494
+ self.v,
495
+ key_value_states,
496
+ past_key_value[1] if past_key_value is not None else None,
497
+ )
498
+
499
+ # compute scores
500
+ scores = torch.matmul(query_states, key_states.transpose(
501
+ 3, 2)) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
502
+
503
+ if position_bias is None:
504
+ if not self.has_relative_attention_bias:
505
+ position_bias = torch.zeros(
506
+ (1, self.n_heads, real_seq_length, key_length),
507
+ device=scores.device,
508
+ dtype=scores.dtype,
509
+ )
510
+ if self.gradient_checkpointing and self.training:
511
+ position_bias.requires_grad = True
512
+ else:
513
+ position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
514
+
515
+ # if key and values are already calculated
516
+ # we want only the last query position bias
517
+ if past_key_value is not None:
518
+ position_bias = position_bias[:, :, -hidden_states.size(1):, :]
519
+
520
+ if mask is not None:
521
+ position_bias = (position_bias + mask) # (batch_size, n_heads, seq_length, key_length)
522
+
523
+ if self.pruned_heads:
524
+ mask = torch.ones(position_bias.shape[1])
525
+ mask[list(self.pruned_heads)] = 0
526
+ position_bias_masked = position_bias[:, mask.bool()]
527
+ else:
528
+ position_bias_masked = position_bias
529
+
530
+ scores += position_bias_masked
531
+ attn_weights = nn.functional.softmax(
532
+ scores.float(), dim=-1).type_as(scores) # (batch_size, n_heads, seq_length, key_length)
533
+ attn_weights = nn.functional.dropout(
534
+ attn_weights, p=self.dropout, training=self.training) # (batch_size, n_heads, seq_length, key_length)
535
+
536
+ # Mask heads if we want to
537
+ if layer_head_mask is not None:
538
+ attn_weights = attn_weights * layer_head_mask
539
+
540
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
541
+ attn_output = self.o(attn_output)
542
+
543
+ present_key_value_state = ((key_states, value_states) if (self.is_decoder and use_cache) else None)
544
+ outputs = (attn_output, ) + (present_key_value_state, ) + (position_bias, )
545
+
546
+ if output_attentions:
547
+ outputs = outputs + (attn_weights, )
548
+ return outputs
549
+
550
+
551
+ class T5LayerSelfAttention(nn.Module):
552
+
553
+ def __init__(self, config, has_relative_attention_bias=False):
554
+ super().__init__()
555
+ self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
556
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
557
+ self.dropout = nn.Dropout(config.dropout_rate)
558
+
559
+ def forward(
560
+ self,
561
+ hidden_states,
562
+ attention_mask=None,
563
+ position_bias=None,
564
+ layer_head_mask=None,
565
+ past_key_value=None,
566
+ use_cache=False,
567
+ output_attentions=False,
568
+ ):
569
+ normed_hidden_states = self.layer_norm(hidden_states)
570
+ attention_output = self.SelfAttention(
571
+ normed_hidden_states,
572
+ mask=attention_mask,
573
+ position_bias=position_bias,
574
+ layer_head_mask=layer_head_mask,
575
+ past_key_value=past_key_value,
576
+ use_cache=use_cache,
577
+ output_attentions=output_attentions,
578
+ )
579
+ hidden_states = hidden_states + self.dropout(attention_output[0])
580
+ outputs = (hidden_states, ) + attention_output[1:] # add attentions if we output them
581
+ return outputs
582
+
583
+
584
+ class T5LayerCrossAttention(nn.Module):
585
+
586
+ def __init__(self, config):
587
+ super().__init__()
588
+ self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
589
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
590
+ self.dropout = nn.Dropout(config.dropout_rate)
591
+
592
+ def forward(
593
+ self,
594
+ hidden_states,
595
+ key_value_states,
596
+ attention_mask=None,
597
+ position_bias=None,
598
+ layer_head_mask=None,
599
+ past_key_value=None,
600
+ use_cache=False,
601
+ query_length=None,
602
+ output_attentions=False,
603
+ ):
604
+ normed_hidden_states = self.layer_norm(hidden_states)
605
+ attention_output = self.EncDecAttention(
606
+ normed_hidden_states,
607
+ mask=attention_mask,
608
+ key_value_states=key_value_states,
609
+ position_bias=position_bias,
610
+ layer_head_mask=layer_head_mask,
611
+ past_key_value=past_key_value,
612
+ use_cache=use_cache,
613
+ query_length=query_length,
614
+ output_attentions=output_attentions,
615
+ )
616
+ layer_output = hidden_states + self.dropout(attention_output[0])
617
+ outputs = (layer_output, ) + attention_output[1:] # add attentions if we output them
618
+ return outputs
619
+
620
+
621
+ class T5Block(nn.Module):
622
+
623
+ def __init__(self, config, has_relative_attention_bias=False):
624
+ super().__init__()
625
+ self.is_decoder = config.is_decoder
626
+ self.layer = nn.ModuleList()
627
+ self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
628
+ if self.is_decoder:
629
+ self.layer.append(T5LayerCrossAttention(config))
630
+
631
+ self.layer.append(T5LayerFF(config))
632
+
633
+ def forward(
634
+ self,
635
+ hidden_states,
636
+ attention_mask=None,
637
+ position_bias=None,
638
+ encoder_hidden_states=None,
639
+ encoder_attention_mask=None,
640
+ encoder_decoder_position_bias=None,
641
+ layer_head_mask=None,
642
+ cross_attn_layer_head_mask=None,
643
+ past_key_value=None,
644
+ use_cache=False,
645
+ output_attentions=False,
646
+ return_dict=True,
647
+ ):
648
+
649
+ if past_key_value is not None:
650
+ if not self.is_decoder:
651
+ logger.warning('`past_key_values` is passed to the encoder. Please make sure this is intended.')
652
+ expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
653
+
654
+ if len(past_key_value) != expected_num_past_key_values:
655
+ raise ValueError(
656
+ f'There should be {expected_num_past_key_values} past states. '
657
+ f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
658
+ f'Got {len(past_key_value)} past key / value states')
659
+
660
+ self_attn_past_key_value = past_key_value[:2]
661
+ cross_attn_past_key_value = past_key_value[2:]
662
+ else:
663
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
664
+
665
+ self_attention_outputs = self.layer[0](
666
+ hidden_states,
667
+ attention_mask=attention_mask,
668
+ position_bias=position_bias,
669
+ layer_head_mask=layer_head_mask,
670
+ past_key_value=self_attn_past_key_value,
671
+ use_cache=use_cache,
672
+ output_attentions=output_attentions,
673
+ )
674
+ hidden_states, present_key_value_state = self_attention_outputs[:2]
675
+ attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
676
+
677
+ # clamp inf values to enable fp16 training
678
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
679
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
680
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
681
+
682
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
683
+ if do_cross_attention:
684
+ # the actual query length is unknown for cross attention
685
+ # if using past key value states. Need to inject it here
686
+ if present_key_value_state is not None:
687
+ query_length = present_key_value_state[0].shape[2]
688
+ else:
689
+ query_length = None
690
+
691
+ cross_attention_outputs = self.layer[1](
692
+ hidden_states,
693
+ key_value_states=encoder_hidden_states,
694
+ attention_mask=encoder_attention_mask,
695
+ position_bias=encoder_decoder_position_bias,
696
+ layer_head_mask=cross_attn_layer_head_mask,
697
+ past_key_value=cross_attn_past_key_value,
698
+ query_length=query_length,
699
+ use_cache=use_cache,
700
+ output_attentions=output_attentions,
701
+ )
702
+ hidden_states = cross_attention_outputs[0]
703
+
704
+ # clamp inf values to enable fp16 training
705
+ if (hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any()):
706
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
707
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
708
+
709
+ # Combine self attn and cross attn key value states
710
+ if present_key_value_state is not None:
711
+ present_key_value_state = (present_key_value_state + cross_attention_outputs[1])
712
+
713
+ # Keep cross-attention outputs and relative position weights
714
+ attention_outputs = attention_outputs + cross_attention_outputs[2:]
715
+
716
+ # Apply Feed Forward layer
717
+ hidden_states = self.layer[-1](hidden_states)
718
+
719
+ # clamp inf values to enable fp16 training
720
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
721
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
722
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
723
+
724
+ outputs = (hidden_states, )
725
+
726
+ if use_cache:
727
+ outputs = outputs + (present_key_value_state, ) + attention_outputs
728
+ else:
729
+ outputs = outputs + attention_outputs
730
+
731
+ return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
732
+
733
+
734
+ class T5PreTrainedModel(PreTrainedModel):
735
+ """
736
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
737
+ models.
738
+ """
739
+
740
+ config_class = T5Config
741
+ load_tf_weights = load_tf_weights_in_t5
742
+ base_model_prefix = 'transformer'
743
+ is_parallelizable = True
744
+ supports_gradient_checkpointing = True
745
+ _no_split_modules = ['T5Block']
746
+
747
+ @property
748
+ def dummy_inputs(self):
749
+ input_ids = torch.tensor(DUMMY_INPUTS)
750
+ input_mask = torch.tensor(DUMMY_MASK)
751
+ dummy_inputs = {
752
+ 'decoder_input_ids': input_ids,
753
+ 'input_ids': input_ids,
754
+ 'decoder_attention_mask': input_mask,
755
+ }
756
+ return dummy_inputs
757
+
758
+ def _init_weights(self, module):
759
+ """Initialize the weights"""
760
+ factor = (self.config.initializer_factor) # Used for testing weights initialization
761
+ if isinstance(module, T5LayerNorm):
762
+ module.weight.data.fill_(factor * 1.0)
763
+ elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
764
+ # Mesh TensorFlow embeddings initialization
765
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
766
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
767
+ if hasattr(module, 'lm_head') and not self.config.tie_word_embeddings:
768
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
769
+ elif isinstance(module, T5DenseActDense):
770
+ # Mesh TensorFlow FF initialization
771
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
772
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
773
+ module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model)**-0.5))
774
+ if hasattr(module.wi, 'bias') and module.wi.bias is not None:
775
+ module.wi.bias.data.zero_()
776
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff)**-0.5))
777
+ if hasattr(module.wo, 'bias') and module.wo.bias is not None:
778
+ module.wo.bias.data.zero_()
779
+ elif isinstance(module, T5DenseGatedActDense):
780
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model)**-0.5))
781
+ if hasattr(module.wi_0, 'bias') and module.wi_0.bias is not None:
782
+ module.wi_0.bias.data.zero_()
783
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model)**-0.5))
784
+ if hasattr(module.wi_1, 'bias') and module.wi_1.bias is not None:
785
+ module.wi_1.bias.data.zero_()
786
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff)**-0.5))
787
+ if hasattr(module.wo, 'bias') and module.wo.bias is not None:
788
+ module.wo.bias.data.zero_()
789
+ elif isinstance(module, T5Attention):
790
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
791
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
792
+ d_model = self.config.d_model
793
+ key_value_proj_dim = self.config.d_kv
794
+ n_heads = self.config.num_heads
795
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim)**-0.5))
796
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
797
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
798
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim)**-0.5))
799
+ if module.has_relative_attention_bias:
800
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model)**-0.5))
801
+
802
+ def _set_gradient_checkpointing(self, module, value=False):
803
+ if isinstance(module, (T5Attention, T5Stack)):
804
+ module.gradient_checkpointing = value
805
+
806
+ def _shift_right(self, input_ids):
807
+ decoder_start_token_id = self.config.decoder_start_token_id
808
+ pad_token_id = self.config.pad_token_id
809
+
810
+ assert decoder_start_token_id is not None, (
811
+ 'self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id.'
812
+ ' See T5 docs for more information')
813
+
814
+ # shift inputs to the right
815
+ if is_torch_fx_proxy(input_ids):
816
+ # Item assignment is not supported natively for proxies.
817
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1, ), decoder_start_token_id)
818
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
819
+ else:
820
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
821
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
822
+ shifted_input_ids[..., 0] = decoder_start_token_id
823
+
824
+ assert (pad_token_id is not None), 'self.model.config.pad_token_id has to be defined.'
825
+ # replace possible -100 values in labels by `pad_token_id`
826
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
827
+
828
+ return shifted_input_ids
829
+
830
+
831
+ class T5Stack(T5PreTrainedModel):
832
+
833
+ def __init__(self, config, embed_tokens=None):
834
+ super().__init__(config)
835
+
836
+ self.embed_tokens = embed_tokens
837
+ self.is_decoder = config.is_decoder
838
+
839
+ self.block = nn.ModuleList(
840
+ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)])
841
+ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
842
+ self.dropout = nn.Dropout(config.dropout_rate)
843
+
844
+ # Initialize weights and apply final processing
845
+ self.post_init()
846
+ # Model parallel
847
+ self.model_parallel = False
848
+ self.device_map = None
849
+ self.gradient_checkpointing = False
850
+
851
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
852
+ def parallelize(self, device_map=None):
853
+ # Check validity of device_map
854
+ self.device_map = (
855
+ get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map)
856
+ assert_device_map(self.device_map, len(self.block))
857
+ self.model_parallel = True
858
+ self.first_device = ('cpu' if 'cpu' in self.device_map.keys() else 'cuda:' + str(min(self.device_map.keys())))
859
+ self.last_device = 'cuda:' + str(max(self.device_map.keys()))
860
+ # Load onto devices
861
+ for k, v in self.device_map.items():
862
+ for layer in v:
863
+ cuda_device = 'cuda:' + str(k)
864
+ self.block[layer] = self.block[layer].to(cuda_device)
865
+
866
+ # Set embed_tokens to first layer
867
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
868
+ # Set final layer norm to last device
869
+ self.final_layer_norm = self.final_layer_norm.to(self.last_device)
870
+
871
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
872
+ def deparallelize(self):
873
+ self.model_parallel = False
874
+ self.device_map = None
875
+ self.first_device = 'cpu'
876
+ self.last_device = 'cpu'
877
+ for i in range(len(self.block)):
878
+ self.block[i] = self.block[i].to('cpu')
879
+ self.embed_tokens = self.embed_tokens.to('cpu')
880
+ self.final_layer_norm = self.final_layer_norm.to('cpu')
881
+ torch.cuda.empty_cache()
882
+
883
+ def get_input_embeddings(self):
884
+ return self.embed_tokens
885
+
886
+ def set_input_embeddings(self, new_embeddings):
887
+ self.embed_tokens = new_embeddings
888
+
889
+ def forward(
890
+ self,
891
+ input_ids=None,
892
+ attention_mask=None,
893
+ encoder_hidden_states=None,
894
+ encoder_attention_mask=None,
895
+ inputs_embeds=None,
896
+ head_mask=None,
897
+ cross_attn_head_mask=None,
898
+ past_key_values=None,
899
+ use_cache=None,
900
+ output_attentions=None,
901
+ output_hidden_states=None,
902
+ return_dict=None,
903
+ ):
904
+ # Model parallel
905
+ if self.model_parallel:
906
+ torch.cuda.set_device(self.first_device)
907
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
908
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
909
+ output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
910
+ output_hidden_states = (
911
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
912
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
913
+
914
+ if input_ids is not None and inputs_embeds is not None:
915
+ err_msg_prefix = 'decoder_' if self.is_decoder else ''
916
+ raise ValueError(
917
+ f'You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time')
918
+ elif input_ids is not None:
919
+ input_shape = input_ids.size()
920
+ input_ids = input_ids.view(-1, input_shape[-1])
921
+ elif inputs_embeds is not None:
922
+ input_shape = inputs_embeds.size()[:-1]
923
+ else:
924
+ err_msg_prefix = 'decoder_' if self.is_decoder else ''
925
+ raise ValueError(f'You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds')
926
+
927
+ if inputs_embeds is None:
928
+ assert (self.embed_tokens is not None), 'You have to initialize the model with valid token embeddings'
929
+ inputs_embeds = self.embed_tokens(input_ids)
930
+
931
+ batch_size, seq_length = input_shape
932
+
933
+ # required mask seq length can be calculated via length of past
934
+ mask_seq_length = (past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length)
935
+
936
+ if use_cache is True:
937
+ assert (self.is_decoder), f'`use_cache` can only be set to `True` if {self} is used as a decoder'
938
+
939
+ if attention_mask is None:
940
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
941
+ if (self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None):
942
+ encoder_seq_length = encoder_hidden_states.shape[1]
943
+ encoder_attention_mask = torch.ones(
944
+ batch_size,
945
+ encoder_seq_length,
946
+ device=inputs_embeds.device,
947
+ dtype=torch.long,
948
+ )
949
+
950
+ # initialize past_key_values with `None` if past does not exist
951
+ if past_key_values is None:
952
+ past_key_values = [None] * len(self.block)
953
+
954
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
955
+ # ourselves in which case we just need to make it broadcastable to all heads.
956
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
957
+
958
+ # If a 2D or 3D attention mask is provided for the cross-attention
959
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
960
+ if self.is_decoder and encoder_hidden_states is not None:
961
+ (
962
+ encoder_batch_size,
963
+ encoder_sequence_length,
964
+ _,
965
+ ) = encoder_hidden_states.size()
966
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
967
+ if encoder_attention_mask is None:
968
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
969
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
970
+ else:
971
+ encoder_extended_attention_mask = None
972
+
973
+ # Prepare head mask if needed
974
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
975
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
976
+ present_key_value_states = () if use_cache else None
977
+ all_hidden_states = () if output_hidden_states else None
978
+ all_attentions = () if output_attentions else None
979
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
980
+ position_bias = None
981
+ encoder_decoder_position_bias = None
982
+
983
+ hidden_states = self.dropout(inputs_embeds)
984
+
985
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
986
+ layer_head_mask = head_mask[i]
987
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
988
+ # Model parallel
989
+ if self.model_parallel:
990
+ torch.cuda.set_device(hidden_states.device)
991
+ # Ensure that attention_mask is always on the same device as hidden_states
992
+ if attention_mask is not None:
993
+ attention_mask = attention_mask.to(hidden_states.device)
994
+ if position_bias is not None:
995
+ position_bias = position_bias.to(hidden_states.device)
996
+ if encoder_hidden_states is not None:
997
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
998
+ if encoder_extended_attention_mask is not None:
999
+ encoder_extended_attention_mask = (encoder_extended_attention_mask.to(hidden_states.device))
1000
+ if encoder_decoder_position_bias is not None:
1001
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
1002
+ if layer_head_mask is not None:
1003
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
1004
+ if cross_attn_layer_head_mask is not None:
1005
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
1006
+ if output_hidden_states:
1007
+ all_hidden_states = all_hidden_states + (hidden_states, )
1008
+
1009
+ if self.gradient_checkpointing and self.training:
1010
+ if use_cache:
1011
+ logger.warning(
1012
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
1013
+ use_cache = False
1014
+
1015
+ def create_custom_forward(module):
1016
+
1017
+ def custom_forward(*inputs):
1018
+ return tuple(module(*inputs, use_cache, output_attentions))
1019
+
1020
+ return custom_forward
1021
+
1022
+ layer_outputs = checkpoint(
1023
+ create_custom_forward(layer_module),
1024
+ hidden_states,
1025
+ extended_attention_mask,
1026
+ position_bias,
1027
+ encoder_hidden_states,
1028
+ encoder_extended_attention_mask,
1029
+ encoder_decoder_position_bias,
1030
+ layer_head_mask,
1031
+ cross_attn_layer_head_mask,
1032
+ None, # past_key_value is always None with gradient checkpointing
1033
+ )
1034
+ else:
1035
+ layer_outputs = layer_module(
1036
+ hidden_states,
1037
+ attention_mask=extended_attention_mask,
1038
+ position_bias=position_bias,
1039
+ encoder_hidden_states=encoder_hidden_states,
1040
+ encoder_attention_mask=encoder_extended_attention_mask,
1041
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
1042
+ layer_head_mask=layer_head_mask,
1043
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
1044
+ past_key_value=past_key_value,
1045
+ use_cache=use_cache,
1046
+ output_attentions=output_attentions,
1047
+ )
1048
+
1049
+ # layer_outputs is a tuple with:
1050
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
1051
+ if use_cache is False:
1052
+ layer_outputs = layer_outputs[:1] + (None, ) + layer_outputs[1:]
1053
+
1054
+ hidden_states, present_key_value_state = layer_outputs[:2]
1055
+
1056
+ # We share the position biases between the layers - the first layer stores them
1057
+ # layer_outputs = hidden-states, key-value-states, (self-attention position bias), (self-attention weights),
1058
+ # (cross-attention position bias), (cross-attention weights)
1059
+ position_bias = layer_outputs[2]
1060
+ if self.is_decoder and encoder_hidden_states is not None:
1061
+ encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
1062
+ # append next layer key value states
1063
+ if use_cache:
1064
+ present_key_value_states = present_key_value_states + (present_key_value_state, )
1065
+
1066
+ if output_attentions:
1067
+ all_attentions = all_attentions + (layer_outputs[3], )
1068
+ if self.is_decoder:
1069
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5], )
1070
+
1071
+ # Model Parallel: If it's the last layer for that device, put things on the next device
1072
+ if self.model_parallel:
1073
+ for k, v in self.device_map.items():
1074
+ if i == v[-1] and 'cuda:' + str(k) != self.last_device:
1075
+ hidden_states = hidden_states.to('cuda:' + str(k + 1))
1076
+
1077
+ hidden_states = self.final_layer_norm(hidden_states)
1078
+ hidden_states = self.dropout(hidden_states)
1079
+
1080
+ # Add last layer
1081
+ if output_hidden_states:
1082
+ all_hidden_states = all_hidden_states + (hidden_states, )
1083
+
1084
+ if not return_dict:
1085
+ return tuple(v for v in [
1086
+ hidden_states,
1087
+ present_key_value_states,
1088
+ all_hidden_states,
1089
+ all_attentions,
1090
+ all_cross_attentions,
1091
+ ] if v is not None)
1092
+ return BaseModelOutputWithPastAndCrossAttentions(
1093
+ last_hidden_state=hidden_states,
1094
+ past_key_values=present_key_value_states,
1095
+ hidden_states=all_hidden_states,
1096
+ attentions=all_attentions,
1097
+ cross_attentions=all_cross_attentions,
1098
+ )
1099
+
1100
+
1101
+ T5_START_DOCSTRING = r"""
1102
+
1103
+ The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
1104
+ Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
1105
+ Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a
1106
+ text-to-text denoising generative setting.
1107
+
1108
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1109
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1110
+ etc.)
1111
+
1112
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1113
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
1114
+ and behavior.
1115
+
1116
+ Parameters:
1117
+ config ([`T5Config`]): Model configuration class with all the parameters of the model.
1118
+ Initializing with a config file does not load the weights associated with the model, only the
1119
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1120
+ """
1121
+
1122
+ T5_INPUTS_DOCSTRING = r"""
1123
+ Args:
1124
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1125
+ Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
1126
+ should be able to pad the inputs on both the right and the left.
1127
+
1128
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
1129
+ [`PreTrainedTokenizer.__call__`] for details.
1130
+
1131
+ [What are input IDs?](../glossary#input-ids)
1132
+
1133
+ To learn more about how to prepare `input_ids` for pretraining, take a look at [T5 Training](./t5#training).
1134
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1135
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1136
+
1137
+ - 1 for tokens that are **not masked**,
1138
+ - 0 for tokens that are **masked**.
1139
+
1140
+ [What are attention masks?](../glossary#attention-mask)
1141
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
1142
+ Indices of decoder input sequence tokens in the vocabulary.
1143
+
1144
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
1145
+ [`PreTrainedTokenizer.__call__`] for details.
1146
+
1147
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
1148
+
1149
+ T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
1150
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
1151
+
1152
+ To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
1153
+ Training](./t5#training).
1154
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
1155
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
1156
+ be used by default.
1157
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1158
+ Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
1159
+ 1]`:
1160
+
1161
+ - 1 indicates the head is **not masked**,
1162
+ - 0 indicates the head is **masked**.
1163
+
1164
+ decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1165
+ Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
1166
+ 1]`:
1167
+
1168
+ - 1 indicates the head is **not masked**,
1169
+ - 0 indicates the head is **masked**.
1170
+
1171
+ cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1172
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
1173
+ `[0, 1]`:
1174
+
1175
+ - 1 indicates the head is **not masked**,
1176
+ - 0 indicates the head is **masked**.
1177
+
1178
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
1179
+ Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
1180
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
1181
+ the output of the last layer of the encoder. Used in the cross-attention of the decoder.
1182
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1183
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1184
+
1185
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1186
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1187
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1188
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1189
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1190
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1191
+ model's internal embedding lookup matrix.
1192
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
1193
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
1194
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
1195
+ input (see `past_key_values`). This is useful if you want more control over how to convert
1196
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
1197
+
1198
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
1199
+ of `inputs_embeds`.
1200
+
1201
+ use_cache (`bool`, *optional*):
1202
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1203
+ `past_key_values`).
1204
+
1205
+ output_attentions (`bool`, *optional*):
1206
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1207
+ tensors for more detail.
1208
+ output_hidden_states (`bool`, *optional*):
1209
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1210
+ more detail.
1211
+ return_dict (`bool`, *optional*):
1212
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1213
+ """
1214
+
1215
+ T5_ENCODER_INPUTS_DOCSTRING = r"""
1216
+ Args:
1217
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1218
+ Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
1219
+ should be able to pad the inputs on both the right and the left.
1220
+
1221
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
1222
+ [`PreTrainedTokenizer.__call__`] for details.
1223
+
1224
+ To learn more about how to prepare `input_ids` for pretraining, take a look at [T5 Training](./t5#training).
1225
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1226
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1227
+
1228
+ - 1 for tokens that are **not masked**,
1229
+ - 0 for tokens that are **masked**.
1230
+
1231
+ [What are attention masks?](../glossary#attention-mask)
1232
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1233
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
1234
+
1235
+ - 1 indicates the head is **not masked**,
1236
+ - 0 indicates the head is **masked**.
1237
+
1238
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1239
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1240
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1241
+ model's internal embedding lookup matrix.
1242
+ output_attentions (`bool`, *optional*):
1243
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1244
+ tensors for more detail.
1245
+ output_hidden_states (`bool`, *optional*):
1246
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1247
+ more detail.
1248
+ return_dict (`bool`, *optional*):
1249
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1250
+ """
1251
+
1252
+ # Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1253
+ __HEAD_MASK_WARNING_MSG = """
1254
+ The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
1255
+ `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
1256
+ If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
1257
+ num_heads)`.
1258
+ """
1259
+
1260
+
1261
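To stay off the deprecation path described in the warning above, `head_mask` and `decoder_head_mask` can be passed explicitly. A minimal sketch, assuming the stock `transformers` classes and a placeholder checkpoint:

```python
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")  # placeholder checkpoint
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# All-ones masks keep every head active; set an entry to 0 to mask that head.
head_mask = torch.ones(model.config.num_layers, model.config.num_heads)
decoder_head_mask = torch.ones(model.config.num_decoder_layers, model.config.num_heads)

enc = tokenizer("summarize: a short input", return_tensors="pt")
labels = tokenizer("a short summary", return_tensors="pt").input_ids
outputs = model(**enc, labels=labels, head_mask=head_mask, decoder_head_mask=decoder_head_mask)
```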
+ @add_start_docstrings(
1262
+ 'The bare T5 Model transformer outputting raw hidden-states without any specific head on top.',
1263
+ T5_START_DOCSTRING,
1264
+ )
1265
+ class T5Model(T5PreTrainedModel):
1266
+ _keys_to_ignore_on_load_missing = [
1267
+ r'encoder.embed_tokens.weight',
1268
+ r'decoder.embed_tokens.weight',
1269
+ ]
1270
+ _keys_to_ignore_on_load_unexpected = [
1271
+ r'decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight',
1272
+ ]
1273
+
1274
+ def __init__(self, config: T5Config):
1275
+ super().__init__(config)
1276
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1277
+
1278
+ encoder_config = copy.deepcopy(config)
1279
+ encoder_config.is_decoder = False
1280
+ encoder_config.use_cache = False
1281
+ encoder_config.is_encoder_decoder = False
1282
+ self.encoder = T5Stack(encoder_config, self.shared)
1283
+
1284
+ decoder_config = copy.deepcopy(config)
1285
+ decoder_config.is_decoder = True
1286
+ decoder_config.is_encoder_decoder = False
1287
+ decoder_config.num_layers = config.num_decoder_layers
1288
+ self.decoder = T5Stack(decoder_config, self.shared)
1289
+
1290
+ # Initialize weights and apply final processing
1291
+ self.post_init()
1292
+
1293
+ # Model parallel
1294
+ self.model_parallel = False
1295
+ self.device_map = None
1296
+
1297
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1298
+ def parallelize(self, device_map=None):
1299
+ self.device_map = (
1300
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1301
+ if device_map is None else device_map)
1302
+ assert_device_map(self.device_map, len(self.encoder.block))
1303
+ self.encoder.parallelize(self.device_map)
1304
+ self.decoder.parallelize(self.device_map)
1305
+ self.model_parallel = True
1306
+
1307
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1308
+ def deparallelize(self):
1309
+ self.encoder.deparallelize()
1310
+ self.decoder.deparallelize()
1311
+ self.encoder = self.encoder.to('cpu')
1312
+ self.decoder = self.decoder.to('cpu')
1313
+ self.model_parallel = False
1314
+ self.device_map = None
1315
+ torch.cuda.empty_cache()
1316
+
1317
+ def get_input_embeddings(self):
1318
+ return self.shared
1319
+
1320
+ def set_input_embeddings(self, new_embeddings):
1321
+ self.shared = new_embeddings
1322
+ self.encoder.set_input_embeddings(new_embeddings)
1323
+ self.decoder.set_input_embeddings(new_embeddings)
1324
+
1325
+ def get_encoder(self):
1326
+ return self.encoder
1327
+
1328
+ def get_decoder(self):
1329
+ return self.decoder
1330
+
1331
+ def _prune_heads(self, heads_to_prune):
1332
+ """
1333
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1334
+ class PreTrainedModel
1335
+ """
1336
+ for layer, heads in heads_to_prune.items():
1337
+ self.encoder.layer[layer].attention.prune_heads(heads)
1338
+
1339
+ @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
1340
+ @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
1341
+ def forward(
1342
+ self,
1343
+ input_ids: Optional[torch.LongTensor] = None,
1344
+ attention_mask: Optional[torch.FloatTensor] = None,
1345
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1346
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1347
+ head_mask: Optional[torch.FloatTensor] = None,
1348
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1349
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1350
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1351
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1352
+ inputs_embeds: Optional[torch.Tensor] = None,
1353
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
1354
+ use_cache: Optional[bool] = None,
1355
+ output_attentions: Optional[bool] = None,
1356
+ output_hidden_states: Optional[bool] = None,
1357
+ return_dict: Optional[bool] = None,
1358
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
1359
+ r"""
1360
+ Returns:
1361
+
1362
+ Example:
1363
+
1364
+ ```python
1365
+ >>> from transformers import T5Tokenizer, T5Model
1366
+
1367
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1368
+ >>> model = T5Model.from_pretrained("t5-small")
1369
+
1370
+ >>> input_ids = tokenizer(
1371
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1372
+ ... ).input_ids # Batch size 1
1373
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
1374
+
1375
+ >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
1376
+ >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
1377
+ >>> decoder_input_ids = model._shift_right(decoder_input_ids)
1378
+
1379
+ >>> # forward pass
1380
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
1381
+ >>> last_hidden_states = outputs.last_hidden_state
1382
+ ```"""
1383
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1384
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
1385
+
1386
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1387
+ if head_mask is not None and decoder_head_mask is None:
1388
+ if self.config.num_layers == self.config.num_decoder_layers:
1389
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
1390
+ decoder_head_mask = head_mask
1391
+
1392
+ # Encode if needed (training, first prediction pass)
1393
+ if encoder_outputs is None:
1394
+ encoder_outputs = self.encoder(
1395
+ input_ids=input_ids,
1396
+ attention_mask=attention_mask,
1397
+ inputs_embeds=inputs_embeds,
1398
+ head_mask=head_mask,
1399
+ output_attentions=output_attentions,
1400
+ output_hidden_states=output_hidden_states,
1401
+ return_dict=return_dict,
1402
+ )
1403
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1404
+ encoder_outputs = BaseModelOutput(
1405
+ last_hidden_state=encoder_outputs[0],
1406
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1407
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1408
+ )
1409
+
1410
+ hidden_states = encoder_outputs[0]
1411
+
1412
+ # Set device for model parallelism
1413
+ if self.model_parallel:
1414
+ torch.cuda.set_device(self.decoder.first_device)
1415
+ hidden_states = hidden_states.to(self.decoder.first_device)
1416
+ if decoder_input_ids is not None:
1417
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1418
+ if attention_mask is not None:
1419
+ attention_mask = attention_mask.to(self.decoder.first_device)
1420
+ if decoder_attention_mask is not None:
1421
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1422
+
1423
+ # Decode
1424
+ decoder_outputs = self.decoder(
1425
+ input_ids=decoder_input_ids,
1426
+ attention_mask=decoder_attention_mask,
1427
+ inputs_embeds=decoder_inputs_embeds,
1428
+ past_key_values=past_key_values,
1429
+ encoder_hidden_states=hidden_states,
1430
+ encoder_attention_mask=attention_mask,
1431
+ head_mask=decoder_head_mask,
1432
+ cross_attn_head_mask=cross_attn_head_mask,
1433
+ use_cache=use_cache,
1434
+ output_attentions=output_attentions,
1435
+ output_hidden_states=output_hidden_states,
1436
+ return_dict=return_dict,
1437
+ )
1438
+
1439
+ if not return_dict:
1440
+ return decoder_outputs + encoder_outputs
1441
+
1442
+ return Seq2SeqModelOutput(
1443
+ last_hidden_state=decoder_outputs.last_hidden_state,
1444
+ past_key_values=decoder_outputs.past_key_values,
1445
+ decoder_hidden_states=decoder_outputs.hidden_states,
1446
+ decoder_attentions=decoder_outputs.attentions,
1447
+ cross_attentions=decoder_outputs.cross_attentions,
1448
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1449
+ encoder_hidden_states=encoder_outputs.hidden_states,
1450
+ encoder_attentions=encoder_outputs.attentions,
1451
+ )
1452
+
1453
+
1454
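One practical use of the `encoder_outputs` argument in `T5Model.forward` above is to run the encoder once and decode against it several times. A minimal sketch with a placeholder checkpoint and the stock `transformers` class, which behaves the same way here:

```python
from transformers import T5Tokenizer, T5Model

tokenizer = T5Tokenizer.from_pretrained("t5-small")  # placeholder checkpoint
model = T5Model.from_pretrained("t5-small")

enc = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="pt")
encoder_outputs = model.encoder(**enc)  # run the encoder once

for target in ["Studies show that", "Owning a dog helps"]:
    decoder_input_ids = model._shift_right(tokenizer(target, return_tensors="pt").input_ids)
    out = model(
        encoder_outputs=encoder_outputs,       # reused; the encoder is not run again
        attention_mask=enc.attention_mask,
        decoder_input_ids=decoder_input_ids,
    )
    print(target, out.last_hidden_state.shape)
```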
+ @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
1455
+ class T5ForConditionalGeneration(T5PreTrainedModel):
1456
+ _keys_to_ignore_on_load_missing = [
1457
+ r'encoder.embed_tokens.weight',
1458
+ r'decoder.embed_tokens.weight',
1459
+ r'lm_head.weight',
1460
+ ]
1461
+ _keys_to_ignore_on_load_unexpected = [
1462
+ r'decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight',
1463
+ ]
1464
+
1465
+ def __init__(self, config: T5Config):
1466
+ super().__init__(config)
1467
+ self.model_dim = config.d_model
1468
+
1469
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1470
+
1471
+ encoder_config = copy.deepcopy(config)
1472
+ encoder_config.is_decoder = False
1473
+ encoder_config.use_cache = False
1474
+ encoder_config.is_encoder_decoder = False
1475
+ self.encoder = T5Stack(encoder_config, self.shared)
1476
+
1477
+ decoder_config = copy.deepcopy(config)
1478
+ decoder_config.is_decoder = True
1479
+ decoder_config.is_encoder_decoder = False
1480
+ decoder_config.num_layers = config.num_decoder_layers
1481
+ self.decoder = T5Stack(decoder_config, self.shared)
1482
+
1483
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1484
+
1485
+ # Initialize weights and apply final processing
1486
+ self.post_init()
1487
+
1488
+ # Model parallel
1489
+ self.model_parallel = False
1490
+ self.device_map = None
1491
+
1492
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1493
+ def parallelize(self, device_map=None):
1494
+ self.device_map = (
1495
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1496
+ if device_map is None else device_map)
1497
+ assert_device_map(self.device_map, len(self.encoder.block))
1498
+ self.encoder.parallelize(self.device_map)
1499
+ self.decoder.parallelize(self.device_map)
1500
+ self.lm_head = self.lm_head.to(self.decoder.first_device)
1501
+ self.model_parallel = True
1502
+
1503
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1504
+ def deparallelize(self):
1505
+ self.encoder.deparallelize()
1506
+ self.decoder.deparallelize()
1507
+ self.encoder = self.encoder.to('cpu')
1508
+ self.decoder = self.decoder.to('cpu')
1509
+ self.lm_head = self.lm_head.to('cpu')
1510
+ self.model_parallel = False
1511
+ self.device_map = None
1512
+ torch.cuda.empty_cache()
1513
+
1514
+ def get_input_embeddings(self):
1515
+ return self.shared
1516
+
1517
+ def set_input_embeddings(self, new_embeddings):
1518
+ self.shared = new_embeddings
1519
+ self.encoder.set_input_embeddings(new_embeddings)
1520
+ self.decoder.set_input_embeddings(new_embeddings)
1521
+
1522
+ def set_output_embeddings(self, new_embeddings):
1523
+ self.lm_head = new_embeddings
1524
+
1525
+ def get_output_embeddings(self):
1526
+ return self.lm_head
1527
+
1528
+ def get_encoder(self):
1529
+ return self.encoder
1530
+
1531
+ def get_decoder(self):
1532
+ return self.decoder
1533
+
1534
+ @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
1535
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
1536
+ def forward(
1537
+ self,
1538
+ input_ids: Optional[torch.LongTensor] = None,
1539
+ attention_mask: Optional[torch.FloatTensor] = None,
1540
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1541
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1542
+ head_mask: Optional[torch.FloatTensor] = None,
1543
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1544
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1545
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1546
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1547
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1548
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1549
+ labels: Optional[torch.LongTensor] = None,
1550
+ use_cache: Optional[bool] = None,
1551
+ output_attentions: Optional[bool] = None,
1552
+ output_hidden_states: Optional[bool] = None,
1553
+ return_dict: Optional[bool] = None,
1554
+ reduction: Optional[str] = 'mean',
1555
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
1556
+ r"""
1557
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1558
+ Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ...,
1559
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
1560
+ labels in `[0, ..., config.vocab_size - 1]`
1561
+
1562
+ Returns:
1563
+
1564
+ Examples:
1565
+
1566
+ ```python
1567
+ >>> from transformers import T5Tokenizer, T5ForConditionalGeneration
1568
+
1569
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1570
+ >>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
1571
+
1572
+ >>> # training
1573
+ >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
1574
+ >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
1575
+ >>> outputs = model(input_ids=input_ids, labels=labels)
1576
+ >>> loss = outputs.loss
1577
+ >>> logits = outputs.logits
1578
+
1579
+ >>> # inference
1580
+ >>> input_ids = tokenizer(
1581
+ ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
1582
+ ... ).input_ids # Batch size 1
1583
+ >>> outputs = model.generate(input_ids)
1584
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
1585
+ >>> # studies have shown that owning a dog is good for you.
1586
+ ```"""
1587
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1588
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
1589
+
1590
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1591
+ if head_mask is not None and decoder_head_mask is None:
1592
+ if self.config.num_layers == self.config.num_decoder_layers:
1593
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
1594
+ decoder_head_mask = head_mask
1595
+
1596
+ # Encode if needed (training, first prediction pass)
1597
+ if encoder_outputs is None:
1598
+ # Convert encoder inputs in embeddings if needed
1599
+ encoder_outputs = self.encoder(
1600
+ input_ids=input_ids,
1601
+ attention_mask=attention_mask,
1602
+ inputs_embeds=inputs_embeds,
1603
+ head_mask=head_mask,
1604
+ output_attentions=output_attentions,
1605
+ output_hidden_states=output_hidden_states,
1606
+ return_dict=return_dict,
1607
+ )
1608
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1609
+ encoder_outputs = BaseModelOutput(
1610
+ last_hidden_state=encoder_outputs[0],
1611
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1612
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1613
+ )
1614
+
1615
+ hidden_states = encoder_outputs[0]
1616
+
1617
+ if self.model_parallel:
1618
+ torch.cuda.set_device(self.decoder.first_device)
1619
+
1620
+ if (labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None):
1621
+ # get decoder inputs from shifting lm labels to the right
1622
+ decoder_input_ids = self._shift_right(labels)
1623
+
1624
+ # Set device for model parallelism
1625
+ if self.model_parallel:
1626
+ torch.cuda.set_device(self.decoder.first_device)
1627
+ hidden_states = hidden_states.to(self.decoder.first_device)
1628
+ if decoder_input_ids is not None:
1629
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1630
+ if attention_mask is not None:
1631
+ attention_mask = attention_mask.to(self.decoder.first_device)
1632
+ if decoder_attention_mask is not None:
1633
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1634
+
1635
+ # Decode
1636
+ decoder_outputs = self.decoder(
1637
+ input_ids=decoder_input_ids,
1638
+ attention_mask=decoder_attention_mask,
1639
+ inputs_embeds=decoder_inputs_embeds,
1640
+ past_key_values=past_key_values,
1641
+ encoder_hidden_states=hidden_states,
1642
+ encoder_attention_mask=attention_mask,
1643
+ head_mask=decoder_head_mask,
1644
+ cross_attn_head_mask=cross_attn_head_mask,
1645
+ use_cache=use_cache,
1646
+ output_attentions=output_attentions,
1647
+ output_hidden_states=output_hidden_states,
1648
+ return_dict=return_dict,
1649
+ )
1650
+
1651
+ sequence_output = decoder_outputs[0]
1652
+
1653
+ # Set device for model parallelism
1654
+ if self.model_parallel:
1655
+ torch.cuda.set_device(self.encoder.first_device)
1656
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
1657
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
1658
+
1659
+ if self.config.tie_word_embeddings:
1660
+ # Rescale output before projecting on vocab
1661
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
1662
+ sequence_output = sequence_output * (self.model_dim**-0.5)
1663
+
1664
+ lm_logits = self.lm_head(sequence_output)
1665
+
1666
+ loss = None
1667
+ if labels is not None:
1668
+ loss_fct = CrossEntropyLoss(ignore_index=-100, reduction=reduction)
1669
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
1670
+ if reduction == 'none':
1671
+ loss = loss.view(lm_logits.size(0), -1).sum(1)
1672
+
1673
+ if not return_dict:
1674
+ output = (lm_logits, ) + decoder_outputs[1:] + encoder_outputs
1675
+ return ((loss, ) + output) if loss is not None else output
1676
+
1677
+ return Seq2SeqLMOutput(
1678
+ loss=loss,
1679
+ logits=lm_logits,
1680
+ past_key_values=decoder_outputs.past_key_values,
1681
+ decoder_hidden_states=decoder_outputs.hidden_states,
1682
+ decoder_attentions=decoder_outputs.attentions,
1683
+ cross_attentions=decoder_outputs.cross_attentions,
1684
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1685
+ encoder_hidden_states=encoder_outputs.hidden_states,
1686
+ encoder_attentions=encoder_outputs.attentions,
1687
+ )
1688
+
1689
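Unlike the stock `transformers` implementation, the `forward` above accepts a `reduction` argument; with `reduction='none'` the returned `loss` is a per-example summed cross-entropy rather than a batch mean, which is useful for likelihood-based scoring. A minimal sketch; the import of the vendored `T5ForConditionalGeneration` defined in this file is left symbolic because its module path depends on where this module sits in the package, and `t5-small` is a placeholder checkpoint:

```python
from transformers import T5Tokenizer
# from <this vendored module> import T5ForConditionalGeneration  # the class defined above

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

inputs = tokenizer(["question: is the dog brown? yes or no"] * 2, return_tensors="pt", padding=True)
labels = tokenizer(["yes", "no"], return_tensors="pt").input_ids

out = model(**inputs, labels=labels, reduction="none")
per_example_nll = out.loss  # shape (batch_size,): summed token-level NLL per sequence
```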
+ def prepare_inputs_for_generation(
1690
+ self,
1691
+ input_ids,
1692
+ past=None,
1693
+ attention_mask=None,
1694
+ head_mask=None,
1695
+ decoder_head_mask=None,
1696
+ cross_attn_head_mask=None,
1697
+ use_cache=None,
1698
+ encoder_outputs=None,
1699
+ **kwargs,
1700
+ ):
1701
+
1702
+ # cut decoder_input_ids if past is used
1703
+ if past is not None:
1704
+ input_ids = input_ids[:, -1:]
1705
+
1706
+ return {
1707
+ 'decoder_input_ids': input_ids,
1708
+ 'past_key_values': past,
1709
+ 'encoder_outputs': encoder_outputs,
1710
+ 'attention_mask': attention_mask,
1711
+ 'head_mask': head_mask,
1712
+ 'decoder_head_mask': decoder_head_mask,
1713
+ 'cross_attn_head_mask': cross_attn_head_mask,
1714
+ 'use_cache': use_cache,
1715
+ }
1716
+
1717
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1718
+ return self._shift_right(labels)
1719
+
1720
+ def _reorder_cache(self, past, beam_idx):
1721
+ # if decoder past is not included in output
1722
+ # speedy decoding is disabled and no need to reorder
1723
+ if past is None:
1724
+ logger.warning('You might want to consider setting `use_cache=True` to speed up decoding')
1725
+ return past
1726
+
1727
+ reordered_decoder_past = ()
1728
+ for layer_past_states in past:
1729
+ # get the correct batch idx from layer past batch dim
1730
+ # batch dim of `past` is at 2nd position
1731
+ reordered_layer_past_states = ()
1732
+ for layer_past_state in layer_past_states:
1733
+ # need to set correct `past` for each of the four key / value states
1734
+ reordered_layer_past_states = reordered_layer_past_states + (layer_past_state.index_select(
1735
+ 0, beam_idx.to(layer_past_state.device)), )
1736
+
1737
+ assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
1738
+ assert len(reordered_layer_past_states) == len(layer_past_states)
1739
+
1740
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states, )
1741
+ return reordered_decoder_past
1742
+
1743
+
1744
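Generation is what exercises `prepare_inputs_for_generation` and `_reorder_cache` above: with `use_cache=True` only the last decoder token is fed at each step, and beam search reorders the cached key/value states. A minimal sketch using the stock `transformers` class and a placeholder checkpoint:

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")  # placeholder checkpoint
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer(
    "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
).input_ids

# num_beams > 1 triggers beam search, which reorders the cached past key/values.
outputs = model.generate(input_ids, num_beams=4, max_length=32, use_cache=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```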
+ @add_start_docstrings(
1745
+ "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
1746
+ T5_START_DOCSTRING,
1747
+ )
1748
+ class T5EncoderModel(T5PreTrainedModel):
1749
+ authorized_missing_keys = [
1750
+ r'encoder.embed_tokens.weight',
1751
+ ]
1752
+
1753
+ def __init__(self, config: T5Config):
1754
+ super().__init__(config)
1755
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1756
+
1757
+ encoder_config = copy.deepcopy(config)
1758
+ encoder_config.use_cache = False
1759
+ encoder_config.is_encoder_decoder = False
1760
+ self.encoder = T5Stack(encoder_config, self.shared)
1761
+
1762
+ # Initialize weights and apply final processing
1763
+ self.post_init()
1764
+
1765
+ # Model parallel
1766
+ self.model_parallel = False
1767
+ self.device_map = None
1768
+
1769
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1770
+ def parallelize(self, device_map=None):
1771
+ self.device_map = (
1772
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1773
+ if device_map is None else device_map)
1774
+ assert_device_map(self.device_map, len(self.encoder.block))
1775
+ self.encoder.parallelize(self.device_map)
1776
+ self.model_parallel = True
1777
+
1778
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1779
+ def deparallelize(self):
1780
+ self.encoder.deparallelize()
1781
+ self.encoder = self.encoder.to('cpu')
1782
+ self.model_parallel = False
1783
+ self.device_map = None
1784
+ torch.cuda.empty_cache()
1785
+
1786
+ def get_input_embeddings(self):
1787
+ return self.shared
1788
+
1789
+ def set_input_embeddings(self, new_embeddings):
1790
+ self.shared = new_embeddings
1791
+ self.encoder.set_input_embeddings(new_embeddings)
1792
+
1793
+ def get_encoder(self):
1794
+ return self.encoder
1795
+
1796
+ def _prune_heads(self, heads_to_prune):
1797
+ """
1798
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1799
+ class PreTrainedModel
1800
+ """
1801
+ for layer, heads in heads_to_prune.items():
1802
+ self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
1803
+
1804
+ @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
1805
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
1806
+ def forward(
1807
+ self,
1808
+ input_ids: Optional[torch.LongTensor] = None,
1809
+ attention_mask: Optional[torch.FloatTensor] = None,
1810
+ head_mask: Optional[torch.FloatTensor] = None,
1811
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1812
+ output_attentions: Optional[bool] = None,
1813
+ output_hidden_states: Optional[bool] = None,
1814
+ return_dict: Optional[bool] = None,
1815
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
1816
+ r"""
1817
+ Returns:
1818
+
1819
+ Example:
1820
+
1821
+ ```python
1822
+ >>> from transformers import T5Tokenizer, T5EncoderModel
1823
+
1824
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1825
+ >>> model = T5EncoderModel.from_pretrained("t5-small")
1826
+ >>> input_ids = tokenizer(
1827
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1828
+ ... ).input_ids # Batch size 1
1829
+ >>> outputs = model(input_ids=input_ids)
1830
+ >>> last_hidden_states = outputs.last_hidden_state
1831
+ ```"""
1832
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
1833
+
1834
+ encoder_outputs = self.encoder(
1835
+ input_ids=input_ids,
1836
+ attention_mask=attention_mask,
1837
+ inputs_embeds=inputs_embeds,
1838
+ head_mask=head_mask,
1839
+ output_attentions=output_attentions,
1840
+ output_hidden_states=output_hidden_states,
1841
+ return_dict=return_dict,
1842
+ )
1843
+
1844
+ return encoder_outputs