ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl → 0.3.0.dev20240926__py3-none-any.whl
- ai_edge_torch/__init__.py +5 -4
- ai_edge_torch/_convert/conversion.py +112 -0
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/{convert → _convert}/converter.py +94 -48
- ai_edge_torch/_convert/fx_passes/__init__.py +22 -0
- ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +107 -44
- ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +23 -20
- ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +5 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -1
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +39 -9
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +17 -8
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +9 -8
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +31 -18
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +2 -2
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +34 -24
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
- ai_edge_torch/_convert/signature.py +66 -0
- ai_edge_torch/_convert/test/test_convert.py +495 -0
- ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
- ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
- ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -5
- ai_edge_torch/{convert → _convert}/to_channel_last_io.py +10 -3
- ai_edge_torch/config.py +27 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/culprit.py +72 -40
- ai_edge_torch/debug/test/test_culprit.py +7 -5
- ai_edge_torch/debug/test/test_search_model.py +8 -7
- ai_edge_torch/debug/utils.py +14 -3
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +69 -55
- ai_edge_torch/generative/examples/gemma/gemma2.py +267 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +57 -0
- ai_edge_torch/generative/examples/gemma/verify_util.py +143 -0
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +206 -0
- ai_edge_torch/generative/examples/openelm/verify.py +64 -0
- ai_edge_torch/generative/examples/phi/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/{phi2 → phi}/phi2.py +70 -51
- ai_edge_torch/generative/examples/phi/phi3.py +286 -0
- ai_edge_torch/generative/examples/phi/verify.py +65 -0
- ai_edge_torch/generative/examples/phi/verify_phi3.py +70 -0
- ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +101 -0
- ai_edge_torch/generative/examples/smollm/verify.py +62 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -13
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +27 -14
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +74 -9
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +179 -37
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +83 -58
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +28 -25
- ai_edge_torch/generative/examples/t5/t5.py +208 -159
- ai_edge_torch/generative/examples/t5/t5_attention.py +45 -30
- ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +69 -41
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +50 -64
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +41 -39
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +67 -54
- ai_edge_torch/generative/examples/tiny_llama/verify.py +64 -0
- ai_edge_torch/generative/fx_passes/__init__.py +4 -5
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +10 -7
- ai_edge_torch/generative/layers/attention.py +141 -102
- ai_edge_torch/generative/layers/attention_utils.py +53 -12
- ai_edge_torch/generative/layers/builder.py +37 -7
- ai_edge_torch/generative/layers/feed_forward.py +39 -14
- ai_edge_torch/generative/layers/kv_cache.py +162 -50
- ai_edge_torch/generative/layers/model_config.py +84 -30
- ai_edge_torch/generative/layers/normalization.py +185 -7
- ai_edge_torch/generative/layers/rotary_position_embedding.py +6 -4
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +48 -21
- ai_edge_torch/generative/layers/unet/blocks_2d.py +136 -77
- ai_edge_torch/generative/layers/unet/builder.py +7 -4
- ai_edge_torch/generative/layers/unet/model_config.py +17 -15
- ai_edge_torch/generative/quantize/example.py +7 -8
- ai_edge_torch/generative/quantize/quant_recipe.py +10 -7
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -1
- ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
- ai_edge_torch/generative/test/test_kv_cache.py +120 -0
- ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +9 -7
- ai_edge_torch/generative/test/test_model_conversion.py +124 -188
- ai_edge_torch/generative/test/test_model_conversion_large.py +251 -0
- ai_edge_torch/generative/test/test_quantize.py +76 -60
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/converter.py +82 -0
- ai_edge_torch/generative/utilities/loader.py +120 -57
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +165 -57
- ai_edge_torch/generative/utilities/t5_loader.py +110 -81
- ai_edge_torch/generative/utilities/verifier.py +247 -0
- ai_edge_torch/hlfb/__init__.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -7
- ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +39 -30
- ai_edge_torch/hlfb/test/test_mark_pattern.py +46 -20
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +24 -11
- ai_edge_torch/lowertools/__init__.py +18 -0
- ai_edge_torch/lowertools/_shim.py +80 -0
- ai_edge_torch/lowertools/common_utils.py +142 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +255 -0
- ai_edge_torch/lowertools/test_utils.py +60 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +284 -0
- ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +29 -14
- ai_edge_torch/model.py +53 -18
- ai_edge_torch/odml_torch/__init__.py +20 -0
- ai_edge_torch/odml_torch/_torch_future.py +61 -0
- ai_edge_torch/odml_torch/_torch_library.py +19 -0
- ai_edge_torch/odml_torch/composite/__init__.py +16 -0
- ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
- ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
- ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
- ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
- ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
- ai_edge_torch/odml_torch/export.py +357 -0
- ai_edge_torch/odml_torch/export_utils.py +168 -0
- ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +150 -0
- ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +25 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +258 -0
- ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +241 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +252 -0
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/odml_torch/lowerings/context.py +42 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +96 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
- ai_edge_torch/odml_torch/passes/__init__.py +38 -0
- ai_edge_torch/odml_torch/tf_integration.py +194 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +52 -24
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +43 -23
- ai_edge_torch/quantize/quant_config.py +13 -9
- ai_edge_torch/testing/model_coverage/model_coverage.py +29 -16
- ai_edge_torch/version.py +16 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/METADATA +7 -3
- ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/RECORD +177 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/WHEEL +1 -1
- ai_edge_torch/convert/conversion.py +0 -117
- ai_edge_torch/convert/conversion_utils.py +0 -400
- ai_edge_torch/convert/fx_passes/__init__.py +0 -59
- ai_edge_torch/convert/fx_passes/_pass_base.py +0 -49
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +0 -37
- ai_edge_torch/convert/test/test_convert.py +0 -311
- ai_edge_torch/convert/test/test_convert_composites.py +0 -192
- ai_edge_torch/convert/test/test_convert_multisig.py +0 -139
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +0 -66
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -64
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -161
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +0 -121
- /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
- /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{phi2 → openelm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/top_level.txt +0 -0
--- a/ai_edge_torch/generative/examples/t5/t5.py
+++ b/ai_edge_torch/generative/examples/t5/t5.py
@@ -17,17 +17,15 @@
 import copy
 import os
 from pathlib import Path
-from typing import Optional
-
-import numpy as np
-import torch
-import torch.nn as nn
+from typing import Optional

 from ai_edge_torch.generative.examples.t5.t5_attention import EncoderDecoderBlock  # NOQA
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.t5_loader as loading_utils
+import torch
+import torch.nn as nn

 ENCDEC_TENSOR_NAMES = {
     "ff_up_proj": "{prefix}.block.{}.layer.{num}.DenseReluDense.wi",
@@ -36,9 +34,11 @@ ENCDEC_TENSOR_NAMES = {
     "attn_key_proj": "{prefix}.block.{}.layer.0.SelfAttention.k",
     "attn_value_proj": "{prefix}.block.{}.layer.0.SelfAttention.v",
     "attn_output_proj": "{prefix}.block.{}.layer.0.SelfAttention.o",
-    "relative_attn_bias":
+    "relative_attn_bias": (
+        "{prefix}.block.0.layer.0.SelfAttention.relative_attention_bias"
+    ),
     "pre_attn_norm": "{prefix}.block.{}.layer.0.layer_norm",
-    "
+    "post_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
     "final_norm": "{prefix}.final_layer_norm",
 }

@@ -52,13 +52,19 @@ class T5Stack(nn.Module):
     self.config = config
     self.embed_tokens = embed_tokens
     self.is_decoder = config.is_decoder
-
-
-
-
-
+    # T5 has only one block config.
+    block_config = config.block_config(0)
+    self.transformer_blocks = nn.ModuleList([
+        EncoderDecoderBlock(
+            block_config,
+            config,
+            has_relative_attention_bias=bool(idx == 0),
+        )
+        for idx in range(config.num_layers)
+    ])
+    self.final_norm = builder.build_norm(
+        config.embedding_dim, config.final_norm_config
     )
-    self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config)

   def forward(
       self,
@@ -73,23 +79,23 @@
           torch.Tensor
       ] = None,  # should be for decoder case
   ):
-    input_shape = input_ids.size()
     inputs_embeds = self.embed_tokens(input_ids)
-    batch_size, seq_length = input_shape
     hidden_states = inputs_embeds
     position_bias = None
     encoder_decoder_position_bias = None
-    for
+    for _, layer_module in enumerate(self.transformer_blocks):
       # EncoderDecoderBlock.forward
-      hidden_states, position_bias, encoder_decoder_position_bias =
-
-
-
-
-
-
-
-
+      hidden_states, position_bias, encoder_decoder_position_bias = (
+          layer_module(
+              hidden_states,
+              input_pos,
+              mask=attention_mask,
+              relative_position=relative_position,
+              position_bias=position_bias,
+              encoder_hidden_states=encoder_hidden_states,
+              encoder_attention_mask=encoder_attention_mask,
+              encoder_decoder_position_bias=encoder_decoder_position_bias,
+          )
       )

     hidden_states = self.final_norm(hidden_states)
@@ -109,7 +115,8 @@ class T5(nn.Module):

     encoder_config = copy.deepcopy(config)
     encoder_config.is_decoder = False
-
+    # T5 has only one block config.
+    encoder_config.block_config(0).attn_config.enable_kv_cache = False
     self.encoder = T5Stack(encoder_config, self.tok_embedding)

     decoder_config = copy.deepcopy(config)
@@ -130,23 +137,27 @@ class T5(nn.Module):
     )

     self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
     )

+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=
-        max_distance=
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

     self.dec_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=False,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=
-        max_distance=
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

   @torch.inference_mode
@@ -159,9 +170,10 @@ class T5(nn.Module):
       pad_mask: torch.Tensor,
   ) -> torch.Tensor:
     B, T = input_ids.size()
-    assert (
-
-
+    assert self.config.max_seq_len >= T, (
+        f"Cannot forward sequence of length {T}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )

     enc_mask = self.enc_attn_mask_cache.index_select(2, input_pos)
     enc_mask = enc_mask[:, :, :, : self.config.kv_cache_max]
@@ -170,10 +182,18 @@
     dec_mask = self.dec_attn_mask_cache.index_select(2, decoder_input_pos)
     dec_mask = dec_mask[:, :, :, : self.config.kv_cache_max]
     enc_relative_position = self.enc_rel_pos_mask.index_select(2, input_pos)
-    enc_relative_position = enc_relative_position[
-
-
-
+    enc_relative_position = enc_relative_position[
+        :, :, :, : self.config.kv_cache_max
+    ]
+    dec_relative_position = self.enc_rel_pos_mask.index_select(
+        2, decoder_input_pos
+    )
+    dec_relative_position = dec_relative_position[
+        :, :, :, : self.config.kv_cache_max
+    ]
+    enc_attention_mask = self.enc_attn_mask_cache.index_select(
+        2, decoder_input_pos
+    )
     # Mask off any "pad" tokens that shouldn't contribute to cross attention
     enc_attention_mask[:, :, :, :] += pad_mask

@@ -210,12 +230,15 @@ class T5Encoder(nn.Module):

     self.config = config
     # Construct model layers.
-    assert
+    assert (
+        embedding_layer != None
+    ), "Passed in embedding layer should not be None!"
     self.tok_embedding = embedding_layer

     encoder_config = copy.deepcopy(config)
     encoder_config.is_decoder = False
-
+    # T5 has only one block config.
+    encoder_config.block_config(0).attn_config.enable_kv_cache = False
     self.encoder = T5Stack(encoder_config, self.tok_embedding)

     self.enc_attn_mask_cache = (
@@ -228,12 +251,14 @@ class T5Encoder(nn.Module):
         .unsqueeze(0)
     )

+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=
-        max_distance=
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

   @torch.inference_mode
@@ -244,16 +269,19 @@ class T5Encoder(nn.Module):
       pad_mask: torch.Tensor,
   ) -> torch.Tensor:
     B, T = input_ids.size()
-    assert (
-
-
+    assert self.config.max_seq_len >= T, (
+        f"Cannot forward sequence of length {T}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )

     enc_mask = self.enc_attn_mask_cache.index_select(2, input_pos)
     enc_mask = enc_mask[:, :, :, : self.config.kv_cache_max]
     # Mask off any "pad" tokens that shouldn't contribute to self-attention
     enc_mask[:, :, :, :] += pad_mask
     enc_relative_position = self.enc_rel_pos_mask.index_select(2, input_pos)
-    enc_relative_position = enc_relative_position[
+    enc_relative_position = enc_relative_position[
+        :, :, :, : self.config.kv_cache_max
+    ]

     # Convert encoder inputs in embeddings if needed
     encoder_hidden_states = self.encoder(
@@ -273,7 +301,9 @@ class T5Decoder(nn.Module):

     self.config = config
     # Construct model layers.
-    assert
+    assert (
+        embedding_layer != None
+    ), "Passed in embedding layer should not be None!"
     self.tok_embedding = embedding_layer

     decoder_config = copy.deepcopy(config)
@@ -293,16 +323,20 @@ class T5Decoder(nn.Module):
         .unsqueeze(0)
     )

+    # T5 has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
         query_length=config.kv_cache_max,
         key_length=config.kv_cache_max,
-        num_buckets=
-        max_distance=
+        num_buckets=attn_config.relative_attention_num_buckets,
+        max_distance=attn_config.relative_attention_max_distance,
     )

     self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
     )

   @torch.inference_mode
@@ -315,9 +349,15 @@ class T5Decoder(nn.Module):
   ) -> torch.Tensor:
     dec_mask = self.dec_attn_mask_cache.index_select(2, decoder_input_pos)
     dec_mask = dec_mask[:, :, :, : self.config.kv_cache_max]
-    dec_relative_position = self.enc_rel_pos_mask.index_select(
-
-
+    dec_relative_position = self.enc_rel_pos_mask.index_select(
+        2, decoder_input_pos
+    )
+    dec_relative_position = dec_relative_position[
+        :, :, :, : self.config.kv_cache_max
+    ]
+    enc_attention_mask = self.enc_attn_mask_cache.index_select(
+        2, decoder_input_pos
+    )
     # Mask off any "pad" tokens that shouldn't contribute to cross attention
     enc_attention_mask[:, :, :, :] += pad_mask

@@ -342,6 +382,7 @@ class T5Decoder(nn.Module):
 def get_model_config_t5() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=12,
+      head_dim=64,
       num_query_groups=12,
       qkv_use_bias=False,
       relative_attention_num_buckets=32,
@@ -357,19 +398,20 @@ def get_model_config_t5() -> cfg.ModelConfig:
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-6,
   )
-
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      relative_attention=True,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
   config = cfg.ModelConfig(
       vocab_size=32128,
       num_layers=12,
       max_seq_len=512,
       embedding_dim=768,
-
-      relative_attention=True,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      pre_ff_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
-      parallel_residual=False,
       lm_head_use_bias=False,
       enable_hlfb=True,
   )
@@ -390,7 +432,7 @@ def build_t5_model(checkpoint_path: str) -> nn.Module:
      "cross_attn_value_proj": "{prefix}.block.{}.layer.1.EncDecAttention.v",
      "cross_attn_output_proj": "{prefix}.block.{}.layer.1.EncDecAttention.o",
      # In the decoder, the FF is layer 2 in the Transformer block
-     "
+     "post_attn_norm": "{prefix}.block.{}.layer.2.layer_norm",
      # In the decoder, the cross attention is layer 1 in the Transformer block
      "pre_cross_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
  }
@@ -446,7 +488,7 @@ def build_t5_decoder_model(
      "cross_attn_value_proj": "{prefix}.block.{}.layer.1.EncDecAttention.v",
      "cross_attn_output_proj": "{prefix}.block.{}.layer.1.EncDecAttention.o",
      # In the decoder, the FF is layer 2 in the Transformer block
-     "
+     "post_attn_norm": "{prefix}.block.{}.layer.2.layer_norm",
      # In the decoder, the cross attention is layer 1 in the Transformer block
      "pre_cross_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
  }
@@ -470,104 +512,101 @@ def build_t5_decoder_model(


 def get_sample_encoder_input_ids() -> torch.Tensor:
-  idx = torch.tensor(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-              1,
-          ]
-      ]
-  )
+  idx = torch.tensor([[
+      3856,
+      27111,
+      10,
+      4425,
+      51,
+      4008,
+      31,
+      7,
+      2306,
+      16576,
+      47,
+      4381,
+      16,
+      8,
+      3414,
+      13,
+      1410,
+      16,
+      932,
+      11,
+      1515,
+      2766,
+      6,
+      11,
+      4838,
+      16,
+      23964,
+      16,
+      1797,
+      13,
+      24,
+      215,
+      5,
+      94,
+      47,
+      2017,
+      168,
+      1204,
+      57,
+      6800,
+      7,
+      11,
+      9443,
+      38,
+      3673,
+      8,
+      4016,
+      13,
+      66,
+      70,
+      14234,
+      5,
+      2449,
+      1215,
+      83,
+      17,
+      16,
+      8782,
+      70,
+      723,
+      30,
+      8,
+      6162,
+      13,
+      1410,
+      12,
+      48,
+      833,
+      250,
+      13,
+      149,
+      231,
+      79,
+      1858,
+      16576,
+      5,
+      1,
+  ]])
   return idx


 def define_and_run_t5(checkpoint_path: str) -> None:
-
+  current_dir = Path(__file__).parent.resolve()
+  t5_goldens = torch.load(current_dir / "t5_lm_logits.pt")

   model = build_t5_model(checkpoint_path)

   idx = get_sample_encoder_input_ids()
-  tokens = torch.full((1, 512), 0, dtype=torch.
+  tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
   tokens[0, :77] = idx
-  input_pos = torch.arange(0, 512)
+  input_pos = torch.arange(0, 512, dtype=torch.int)

-  decode_d_token = torch.tensor([[0]], dtype=torch.
-  decode_d_input_pos = torch.tensor([0], dtype=torch.
+  decode_d_token = torch.tensor([[0]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
   pad_mask = torch.zeros([model.config.kv_cache_max], dtype=torch.float32)
   pad_mask[77:] = float("-inf")
   lm_logits = model.forward(
@@ -579,20 +618,30 @@ def define_and_run_t5(checkpoint_path: str) -> None:

 # TODO(haoliang): Move those tests.
 def define_and_run_t5_split(checkpoint_path: str) -> None:
-
+  current_dir = Path(__file__).parent.resolve()
+  t5_goldens = torch.load(current_dir / "t5_lm_logits.pt")
+
   config = get_model_config_t5()
-  embedding_layer = nn.Embedding(
-
-
+  embedding_layer = nn.Embedding(
+      config.vocab_size, config.embedding_dim, padding_idx=0
+  )
+  t5_encoder_model = build_t5_encoder_model(
+      config, embedding_layer, checkpoint_path
+  )
+  t5_decoder_model = build_t5_decoder_model(
+      config, embedding_layer, checkpoint_path
+  )
   idx = get_sample_encoder_input_ids()

-  tokens = torch.full((1, 512), 0, dtype=torch.
+  tokens = torch.full((1, 512), 0, dtype=torch.int, device="cpu")
   tokens[0, :77] = idx
-  input_pos = torch.arange(0, 512)
+  input_pos = torch.arange(0, 512, dtype=torch.int)

-  decode_d_token = torch.tensor([[0]], dtype=torch.
-  decode_d_input_pos = torch.tensor([0], dtype=torch.
-  pad_mask = torch.zeros(
+  decode_d_token = torch.tensor([[0]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
+  pad_mask = torch.zeros(
+      [t5_encoder_model.config.kv_cache_max], dtype=torch.float32
+  )
   pad_mask[77:] = float("-inf")
   hidden_states = t5_encoder_model.forward(tokens, input_pos, pad_mask)
   lm_logits = t5_decoder_model.forward(