ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/__init__.py +30 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +330 -0
- ai_edge_torch/convert/converter.py +171 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
- ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +273 -0
- ai_edge_torch/convert/test/test_convert_composites.py +171 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/debug/__init__.py +16 -0
- ai_edge_torch/debug/culprit.py +423 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +288 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +103 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +135 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/test_model_conversion.py +201 -0
- ai_edge_torch/generative/test/test_quantize.py +109 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +290 -0
- ai_edge_torch/generative/utilities/t5_loader.py +467 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +134 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +85 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
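The listing breaks down into three main areas: ai_edge_torch/convert (the PyTorch-to-TFLite conversion pipeline and its FX passes), ai_edge_torch/quantize (PT2E quantizer support), and ai_edge_torch/generative (transformer layers, quantization recipes, and example models such as the T5 example whose diff follows). As a minimal usage sketch, assuming the top-level ai_edge_torch.convert API in this nightly behaves as described in the project README; the TinyModel module and the output path below are illustrative only, not part of the package:

# Minimal conversion sketch. Assumption: ai_edge_torch.convert returns a
# callable edge model with an export() method, as described in the README.
import torch
import torch.nn as nn

import ai_edge_torch


class TinyModel(nn.Module):
  """Stand-in for any torch.nn.Module to be lowered to TFLite."""

  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(16, 4)

  def forward(self, x):
    return torch.relu(self.linear(x))


model = TinyModel().eval()
sample_inputs = (torch.randn(1, 16),)

# Trace the module with the sample inputs and lower it to a TFLite flatbuffer.
edge_model = ai_edge_torch.convert(model, sample_inputs)

# The converted model can be invoked directly for a quick numerical check,
# then serialized for on-device use.
_ = edge_model(*sample_inputs)
edge_model.export("/tmp/tiny_model.tflite")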
@@ -0,0 +1,608 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Example of building a T5 model.

import copy
import os
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from ai_edge_torch.generative.examples.t5.t5_attention import EncoderDecoderBlock  # NOQA
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.utilities.t5_loader as loading_utils

ENCDEC_TENSOR_NAMES = {
    "ff_up_proj": "{prefix}.block.{}.layer.{num}.DenseReluDense.wi",
    "ff_down_proj": "{prefix}.block.{}.layer.{num}.DenseReluDense.wo",
    "attn_query_proj": "{prefix}.block.{}.layer.0.SelfAttention.q",
    "attn_key_proj": "{prefix}.block.{}.layer.0.SelfAttention.k",
    "attn_value_proj": "{prefix}.block.{}.layer.0.SelfAttention.v",
    "attn_output_proj": "{prefix}.block.{}.layer.0.SelfAttention.o",
    "relative_attn_bias": "{prefix}.block.0.layer.0.SelfAttention.relative_attention_bias",
    "pre_attn_norm": "{prefix}.block.{}.layer.0.layer_norm",
    "pre_ff_norm": "{prefix}.block.{}.layer.1.layer_norm",
    "final_norm": "{prefix}.final_layer_norm",
}

TENSOR_NAMES = {"lm_head": "lm_head", "embedding": "shared"}


class T5Stack(nn.Module):

  def __init__(self, config, embed_tokens=None):
    super().__init__()
    self.config = config
    self.embed_tokens = embed_tokens
    self.is_decoder = config.is_decoder
    self.transformer_blocks = nn.ModuleList(
        [
            EncoderDecoderBlock(config, has_relative_attention_bias=bool(i == 0))
            for i in range(config.num_layers)
        ]
    )
    self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config)

  def forward(
      self,
      input_ids: torch.Tensor,
      input_pos: torch.Tensor,
      attention_mask: torch.Tensor,
      relative_position: torch.Tensor,
      encoder_hidden_states: Optional[
          torch.Tensor
      ] = None,  # should be for decoder case
      encoder_attention_mask: Optional[
          torch.Tensor
      ] = None,  # should be for decoder case
  ):
    input_shape = input_ids.size()
    inputs_embeds = self.embed_tokens(input_ids)
    batch_size, seq_length = input_shape
    hidden_states = inputs_embeds
    position_bias = None
    encoder_decoder_position_bias = None
    for i, layer_module in enumerate(self.transformer_blocks):
      # EncoderDecoderBlock.forward
      hidden_states, position_bias, encoder_decoder_position_bias = layer_module(
          hidden_states,
          input_pos,
          mask=attention_mask,
          relative_position=relative_position,
          position_bias=position_bias,
          encoder_hidden_states=encoder_hidden_states,
          encoder_attention_mask=encoder_attention_mask,
          encoder_decoder_position_bias=encoder_decoder_position_bias,
      )

    hidden_states = self.final_norm(hidden_states)
    return hidden_states


class T5(nn.Module):

  def __init__(self, config: cfg.ModelConfig):
    super().__init__()

    self.config = config
    # Construct model layers.
    self.tok_embedding = nn.Embedding(
        config.vocab_size, config.embedding_dim, padding_idx=0
    )

    encoder_config = copy.deepcopy(config)
    encoder_config.is_decoder = False
    encoder_config.attn_config.enable_kv_cache = False
    self.encoder = T5Stack(encoder_config, self.tok_embedding)

    decoder_config = copy.deepcopy(config)
    decoder_config.is_decoder = True
    self.decoder = T5Stack(decoder_config, self.tok_embedding)
    self.lm_head = nn.Linear(
        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
    )

    self.enc_attn_mask_cache = (
        torch.zeros(
            (config.kv_cache_max, config.kv_cache_max),
            dtype=torch.float32,
            device=torch.device("cpu"),
        )
        .unsqueeze(0)
        .unsqueeze(0)
    )

    self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
    )

    self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
        bidirectional=True,
        query_length=config.kv_cache_max,
        key_length=config.kv_cache_max,
        num_buckets=config.attn_config.relative_attention_num_buckets,
        max_distance=config.attn_config.relative_attention_max_distance,
    )

    self.dec_rel_pos_mask = attn_utils.build_relative_position_buckets(
        bidirectional=False,
        query_length=config.kv_cache_max,
        key_length=config.kv_cache_max,
        num_buckets=config.attn_config.relative_attention_num_buckets,
        max_distance=config.attn_config.relative_attention_max_distance,
    )

  @torch.inference_mode
  def forward(
      self,
      input_ids: torch.Tensor,
      input_pos: torch.Tensor,
      decoder_input_ids: torch.Tensor,
      decoder_input_pos: torch.Tensor,
      pad_mask: torch.Tensor,
  ) -> torch.Tensor:
    B, T = input_ids.size()
    assert (
        self.config.max_seq_len >= T
    ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"

    enc_mask = self.enc_attn_mask_cache.index_select(2, input_pos)
    enc_mask = enc_mask[:, :, :, : self.config.kv_cache_max]
    # Mask off any "pad" tokens that shouldn't contribute to self-attention
    enc_mask[:, :, :, :] += pad_mask
    dec_mask = self.dec_attn_mask_cache.index_select(2, decoder_input_pos)
    dec_mask = dec_mask[:, :, :, : self.config.kv_cache_max]
    enc_relative_position = self.enc_rel_pos_mask.index_select(2, input_pos)
    enc_relative_position = enc_relative_position[:, :, :, : self.config.kv_cache_max]
    dec_relative_position = self.enc_rel_pos_mask.index_select(2, decoder_input_pos)
    dec_relative_position = dec_relative_position[:, :, :, : self.config.kv_cache_max]
    enc_attention_mask = self.enc_attn_mask_cache.index_select(2, decoder_input_pos)
    # Mask off any "pad" tokens that shouldn't contribute to cross attention
    enc_attention_mask[:, :, :, :] += pad_mask

    # Convert encoder inputs in embeddings if needed
    encoder_hidden_states = self.encoder(
        input_ids=input_ids,
        input_pos=input_pos,
        attention_mask=enc_mask,
        relative_position=enc_relative_position,
    )

    # Decode
    decoder_out = self.decoder(
        input_ids=decoder_input_ids,
        input_pos=decoder_input_pos,
        attention_mask=dec_mask,
        relative_position=dec_relative_position,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=enc_attention_mask,
    )

    # Rescale output before projecting on vocab
    # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
    sequence_output = decoder_out * (self.config.embedding_dim**-0.5)

    lm_logits = self.lm_head(sequence_output)
    return lm_logits


class T5Encoder(nn.Module):

  def __init__(self, config: cfg.ModelConfig, embedding_layer):
    super().__init__()

    self.config = config
    # Construct model layers.
    assert embedding_layer != None, "Passed in embedding layer should not be None!"
    self.tok_embedding = embedding_layer

    encoder_config = copy.deepcopy(config)
    encoder_config.is_decoder = False
    encoder_config.attn_config.enable_kv_cache = False
    self.encoder = T5Stack(encoder_config, self.tok_embedding)

    self.enc_attn_mask_cache = (
        torch.zeros(
            (config.kv_cache_max, config.kv_cache_max),
            dtype=torch.float32,
            device=torch.device("cpu"),
        )
        .unsqueeze(0)
        .unsqueeze(0)
    )

    self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
        bidirectional=True,
        query_length=config.kv_cache_max,
        key_length=config.kv_cache_max,
        num_buckets=config.attn_config.relative_attention_num_buckets,
        max_distance=config.attn_config.relative_attention_max_distance,
    )

  @torch.inference_mode
  def forward(
      self,
      input_ids: torch.Tensor,
      input_pos: torch.Tensor,
      pad_mask: torch.Tensor,
  ) -> torch.Tensor:
    B, T = input_ids.size()
    assert (
        self.config.max_seq_len >= T
    ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"

    enc_mask = self.enc_attn_mask_cache.index_select(2, input_pos)
    enc_mask = enc_mask[:, :, :, : self.config.kv_cache_max]
    # Mask off any "pad" tokens that shouldn't contribute to self-attention
    enc_mask[:, :, :, :] += pad_mask
    enc_relative_position = self.enc_rel_pos_mask.index_select(2, input_pos)
    enc_relative_position = enc_relative_position[:, :, :, : self.config.kv_cache_max]

    # Convert encoder inputs in embeddings if needed
    encoder_hidden_states = self.encoder(
        input_ids=input_ids,
        input_pos=input_pos,
        attention_mask=enc_mask,
        relative_position=enc_relative_position,
    )

    return encoder_hidden_states


class T5Decoder(nn.Module):

  def __init__(self, config: cfg.ModelConfig, embedding_layer):
    super().__init__()

    self.config = config
    # Construct model layers.
    assert embedding_layer != None, "Passed in embedding layer should not be None!"
    self.tok_embedding = embedding_layer

    decoder_config = copy.deepcopy(config)
    decoder_config.is_decoder = True
    self.decoder = T5Stack(decoder_config, self.tok_embedding)
    self.lm_head = nn.Linear(
        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
    )

    self.enc_attn_mask_cache = (
        torch.zeros(
            (config.kv_cache_max, config.kv_cache_max),
            dtype=torch.float32,
            device=torch.device("cpu"),
        )
        .unsqueeze(0)
        .unsqueeze(0)
    )

    self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
        bidirectional=True,
        query_length=config.kv_cache_max,
        key_length=config.kv_cache_max,
        num_buckets=config.attn_config.relative_attention_num_buckets,
        max_distance=config.attn_config.relative_attention_max_distance,
    )

    self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
    )

  @torch.inference_mode
  def forward(
      self,
      encoder_hidden_states: torch.Tensor,
      decoder_input_ids: torch.Tensor,
      decoder_input_pos: torch.Tensor,
      pad_mask: torch.Tensor,
  ) -> torch.Tensor:
    dec_mask = self.dec_attn_mask_cache.index_select(2, decoder_input_pos)
    dec_mask = dec_mask[:, :, :, : self.config.kv_cache_max]
    dec_relative_position = self.enc_rel_pos_mask.index_select(2, decoder_input_pos)
    dec_relative_position = dec_relative_position[:, :, :, : self.config.kv_cache_max]
    enc_attention_mask = self.enc_attn_mask_cache.index_select(2, decoder_input_pos)
    # Mask off any "pad" tokens that shouldn't contribute to cross attention
    enc_attention_mask[:, :, :, :] += pad_mask

    # Decode
    decoder_out = self.decoder(
        input_ids=decoder_input_ids,
        input_pos=decoder_input_pos,
        attention_mask=dec_mask,
        relative_position=dec_relative_position,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=enc_attention_mask,
    )

    # Rescale output before projecting on vocab
    # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
    sequence_output = decoder_out * (self.config.embedding_dim**-0.5)

    lm_logits = self.lm_head(sequence_output)
    return lm_logits


def get_model_config_t5() -> cfg.ModelConfig:
  attn_config = cfg.AttentionConfig(
      num_heads=12,
      num_query_groups=12,
      qkv_use_bias=False,
      relative_attention_num_buckets=32,
      relative_attention_max_distance=128,
  )
  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.SEQUENTIAL,
      activation=cfg.ActivationType.RELU,
      intermediate_size=3072,
  )
  # T5 Confirmed as RMS Norm and eps = 1e-6 TJA.
  norm_config = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM,
      epsilon=1e-6,
  )

  config = cfg.ModelConfig(
      vocab_size=32128,
      num_layers=12,
      max_seq_len=512,
      embedding_dim=768,
      attn_config=attn_config,
      relative_attention=True,
      ff_config=ff_config,
      pre_attention_norm_config=norm_config,
      pre_ff_norm_config=norm_config,
      final_norm_config=norm_config,
      parallel_residual=False,
      lm_head_use_bias=False,
      enable_hlfb=True,
  )
  return config


def build_t5_model(checkpoint_path: str) -> nn.Module:
  config = get_model_config_t5()
  model = T5(config)
  # Need the encoder and decoder mappings.
  encoder_tensor_names = {
      k: v.replace("{prefix}", "encoder").replace("{num}", "1")
      for k, v in ENCDEC_TENSOR_NAMES.items()
  }
  decoder_tensor_names = ENCDEC_TENSOR_NAMES | {
      "cross_attn_query_proj": "{prefix}.block.{}.layer.1.EncDecAttention.q",
      "cross_attn_key_proj": "{prefix}.block.{}.layer.1.EncDecAttention.k",
      "cross_attn_value_proj": "{prefix}.block.{}.layer.1.EncDecAttention.v",
      "cross_attn_output_proj": "{prefix}.block.{}.layer.1.EncDecAttention.o",
      # In the decoder, the FF is layer 2 in the Transformer block
      "pre_ff_norm": "{prefix}.block.{}.layer.2.layer_norm",
      # In the decoder, the cross attention is layer 1 in the Transformer block
      "pre_cross_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
  }

  decoder_tensor_names = {
      k: v.replace("{prefix}", "decoder").replace("{num}", "2")
      for k, v in decoder_tensor_names.items()
  }

  # Additional layer norms for Cross Attention in decoder
  # decoder_tensor_names["pre_attn_norm"] = "{prefix}.block.{}.layer.1.layer_norm",
  tensor_names = {
      "encoder.": loading_utils.ModelLoader.TensorNames(**encoder_tensor_names),
      "decoder.": loading_utils.ModelLoader.TensorNames(**decoder_tensor_names),
      "": loading_utils.ModelLoader.TensorNames(**TENSOR_NAMES),
  }
  loader = loading_utils.ModelLoader(checkpoint_path, names=tensor_names)
  # The embedding is shared between the encoder and decoder, so we set
  # strict=False.
  loader.load(model, strict=False, fuse_attention=False)
  return model


def build_t5_encoder_model(
    config: cfg.ModelConfig, embedding_layer, checkpoint_path: str
) -> nn.Module:
  model = T5Encoder(config, embedding_layer)
  encoder_tensor_names = {
      k: v.replace("{prefix}", "encoder").replace("{num}", "1")
      for k, v in ENCDEC_TENSOR_NAMES.items()
  }

  # Additional layer norms for Cross Attention in decoder
  # decoder_tensor_names["pre_attn_norm"] = "{prefix}.block.{}.layer.1.layer_norm",
  tensor_names = {
      "encoder.": loading_utils.ModelLoader.TensorNames(**encoder_tensor_names),
      "": loading_utils.ModelLoader.TensorNames(**TENSOR_NAMES),
  }
  loader = loading_utils.ModelLoader(checkpoint_path, names=tensor_names)
  # The embedding is shared between the encoder and decoder, so we set
  # strict=False.
  loader.load(model, strict=False, fuse_attention=False)
  return model


def build_t5_decoder_model(
    config: cfg.ModelConfig, embedding_layer, checkpoint_path: str
) -> nn.Module:
  model = T5Decoder(config, embedding_layer)
  decoder_tensor_names = ENCDEC_TENSOR_NAMES | {
      "cross_attn_query_proj": "{prefix}.block.{}.layer.1.EncDecAttention.q",
      "cross_attn_key_proj": "{prefix}.block.{}.layer.1.EncDecAttention.k",
      "cross_attn_value_proj": "{prefix}.block.{}.layer.1.EncDecAttention.v",
      "cross_attn_output_proj": "{prefix}.block.{}.layer.1.EncDecAttention.o",
      # In the decoder, the FF is layer 2 in the Transformer block
      "pre_ff_norm": "{prefix}.block.{}.layer.2.layer_norm",
      # In the decoder, the cross attention is layer 1 in the Transformer block
      "pre_cross_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
  }

  decoder_tensor_names = {
      k: v.replace("{prefix}", "decoder").replace("{num}", "2")
      for k, v in decoder_tensor_names.items()
  }

  # Additional layer norms for Cross Attention in decoder
  # decoder_tensor_names["pre_attn_norm"] = "{prefix}.block.{}.layer.1.layer_norm",
  tensor_names = {
      "decoder.": loading_utils.ModelLoader.TensorNames(**decoder_tensor_names),
      "": loading_utils.ModelLoader.TensorNames(**TENSOR_NAMES),
  }
  loader = loading_utils.ModelLoader(checkpoint_path, names=tensor_names)
  # The embedding is shared between the encoder and decoder, so we set
  # strict=False.
  loader.load(model, strict=False, fuse_attention=False)
  return model


def get_sample_encoder_input_ids() -> torch.Tensor:
  idx = torch.tensor(
      [
          [
              3856, 27111, 10, 4425, 51, 4008, 31, 7, 2306, 16576, 47, 4381,
              16, 8, 3414, 13, 1410, 16, 932, 11, 1515, 2766, 6, 11, 4838,
              16, 23964, 16, 1797, 13, 24, 215, 5, 94, 47, 2017, 168, 1204,
              57, 6800, 7, 11, 9443, 38, 3673, 8, 4016, 13, 66, 70, 14234,
              5, 2449, 1215, 83, 17, 16, 8782, 70, 723, 30, 8, 6162, 13,
              1410, 12, 48, 833, 250, 13, 149, 231, 79, 1858, 16576, 5, 1,
          ]
      ]
  )
  return idx


def define_and_run_t5(checkpoint_path: str) -> None:
  t5_goldens = torch.load("t5_lm_logits.pt")

  model = build_t5_model(checkpoint_path)

  idx = get_sample_encoder_input_ids()
  tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
  tokens[0, :77] = idx
  input_pos = torch.arange(0, 512)

  decode_d_token = torch.tensor([[0]], dtype=torch.int64)
  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
  pad_mask = torch.zeros([model.config.kv_cache_max], dtype=torch.float32)
  pad_mask[77:] = float("-inf")
  lm_logits = model.forward(
      tokens, input_pos, decode_d_token, decode_d_input_pos, pad_mask
  )
  print("comparing with goldens..")
  assert torch.allclose(t5_goldens, lm_logits, atol=1e-05)


# TODO(haoliang): Move those tests.
def define_and_run_t5_split(checkpoint_path: str) -> None:
  t5_goldens = torch.load("t5_lm_logits.pt")
  config = get_model_config_t5()
  embedding_layer = nn.Embedding(config.vocab_size, config.embedding_dim, padding_idx=0)
  t5_encoder_model = build_t5_encoder_model(config, embedding_layer, checkpoint_path)
  t5_decoder_model = build_t5_decoder_model(config, embedding_layer, checkpoint_path)
  idx = get_sample_encoder_input_ids()

  tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")
  tokens[0, :77] = idx
  input_pos = torch.arange(0, 512)

  decode_d_token = torch.tensor([[0]], dtype=torch.int64)
  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
  pad_mask = torch.zeros([t5_encoder_model.config.kv_cache_max], dtype=torch.float32)
  pad_mask[77:] = float("-inf")
  hidden_states = t5_encoder_model.forward(tokens, input_pos, pad_mask)
  lm_logits = t5_decoder_model.forward(
      hidden_states, decode_d_token, decode_d_input_pos, pad_mask
  )
  print("comparing with goldens..")
  assert torch.allclose(t5_goldens, lm_logits, atol=1e-05)


if __name__ == "__main__":
  checkpoint = os.path.join(Path.home(), "Downloads/llm_data/t5")
  # define_and_run_t5(checkpoint)
  define_and_run_t5_split(checkpoint)
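The manifest above also lists ai_edge_torch/generative/examples/t5/convert_to_tflite.py as the released export script for this example. The sketch below is not that script; it only illustrates, under the assumption that the top-level ai_edge_torch.convert API accepts these modules, how the split encoder and decoder built in this file could be lowered to TFLite. The function name, sample shapes, and output paths mirror define_and_run_t5_split and are assumptions for illustration.

# Hedged sketch: export the split T5 encoder/decoder with ai_edge_torch.convert.
# Shapes mirror define_and_run_t5_split above; names and paths are illustrative.
import os
from pathlib import Path

import torch
import torch.nn as nn

import ai_edge_torch
from ai_edge_torch.generative.examples.t5 import t5


def convert_t5_split(checkpoint_path: str, output_dir: str = "/tmp") -> None:
  config = t5.get_model_config_t5()
  embedding_layer = nn.Embedding(
      config.vocab_size, config.embedding_dim, padding_idx=0
  )
  encoder = t5.build_t5_encoder_model(config, embedding_layer, checkpoint_path)
  decoder = t5.build_t5_decoder_model(config, embedding_layer, checkpoint_path)

  # Encoder signature: (input_ids, input_pos, pad_mask).
  tokens = torch.zeros((1, 512), dtype=torch.long)
  input_pos = torch.arange(0, 512)
  pad_mask = torch.zeros((config.kv_cache_max,), dtype=torch.float32)

  # Decoder signature: (encoder_hidden_states, decoder_input_ids,
  # decoder_input_pos, pad_mask).
  hidden_states = torch.zeros((1, 512, config.embedding_dim), dtype=torch.float32)
  decode_token = torch.tensor([[0]], dtype=torch.int64)
  decode_pos = torch.tensor([0], dtype=torch.int64)

  # One single-signature conversion per sub-model, exported as two flatbuffers.
  edge_encoder = ai_edge_torch.convert(encoder, (tokens, input_pos, pad_mask))
  edge_encoder.export(os.path.join(output_dir, "t5_encoder.tflite"))

  edge_decoder = ai_edge_torch.convert(
      decoder, (hidden_states, decode_token, decode_pos, pad_mask)
  )
  edge_decoder.export(os.path.join(output_dir, "t5_decoder.tflite"))


if __name__ == "__main__":
  convert_t5_split(os.path.join(Path.home(), "Downloads/llm_data/t5"))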