ai-edge-torch-nightly 0.2.0.dev20240801-py3-none-any.whl → 0.2.0.dev20240803-py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry and is provided for informational purposes only.
Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/__init__.py +1 -0
- ai_edge_torch/convert/conversion.py +12 -8
- ai_edge_torch/convert/conversion_utils.py +38 -20
- ai_edge_torch/convert/converter.py +11 -5
- ai_edge_torch/convert/fx_passes/__init__.py +3 -4
- ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +46 -40
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
- ai_edge_torch/convert/test/test_convert.py +39 -16
- ai_edge_torch/convert/test/test_convert_composites.py +115 -86
- ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
- ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
- ai_edge_torch/convert/to_channel_last_io.py +6 -2
- ai_edge_torch/debug/culprit.py +41 -16
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +4 -3
- ai_edge_torch/debug/utils.py +3 -1
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +14 -6
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +14 -7
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +41 -16
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +36 -13
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
- ai_edge_torch/generative/examples/t5/t5.py +158 -125
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
- ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/fx_passes/__init__.py +1 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
- ai_edge_torch/generative/layers/attention.py +19 -11
- ai_edge_torch/generative/layers/builder.py +3 -4
- ai_edge_torch/generative/layers/kv_cache.py +4 -3
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
- ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
- ai_edge_torch/generative/layers/unet/builder.py +7 -4
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
- ai_edge_torch/generative/quantize/example.py +2 -3
- ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
- ai_edge_torch/generative/test/loader_test.py +5 -4
- ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
- ai_edge_torch/generative/test/test_model_conversion.py +2 -3
- ai_edge_torch/generative/test/test_quantize.py +45 -48
- ai_edge_torch/generative/utilities/loader.py +55 -28
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
- ai_edge_torch/generative/utilities/t5_loader.py +77 -48
- ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
- ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
- ai_edge_torch/model.py +8 -5
- ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
- ai_edge_torch/quantize/quant_config.py +6 -2
- ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
- ai_edge_torch/version.py +16 -0
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/RECORD +89 -88
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/top_level.txt +0 -0
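To reproduce a comparison like this locally, a minimal sketch using only the Python standard library is enough. This is not how the registry itself generates its diff, and the wheel filenames in the commented call are the expected names for these two nightlies, not files shipped with this page:

import difflib
import zipfile


def diff_wheels(old_whl: str, new_whl: str) -> None:
  """Prints a unified diff of the Python sources shared by two wheels."""
  with zipfile.ZipFile(old_whl) as old, zipfile.ZipFile(new_whl) as new:
    shared = sorted(
        set(n for n in old.namelist() if n.endswith(".py"))
        & set(n for n in new.namelist() if n.endswith(".py"))
    )
    for name in shared:
      old_src = old.read(name).decode("utf-8").splitlines(keepends=True)
      new_src = new.read(name).decode("utf-8").splitlines(keepends=True)
      for line in difflib.unified_diff(old_src, new_src, fromfile=name, tofile=name):
        print(line, end="")


# Assumed local filenames for the two nightly wheels:
# diff_wheels(
#     "ai_edge_torch_nightly-0.2.0.dev20240801-py3-none-any.whl",
#     "ai_edge_torch_nightly-0.2.0.dev20240803-py3-none-any.whl",
# )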
ai_edge_torch/generative/examples/t5/t5.py

@@ -19,15 +19,14 @@ import os
 from pathlib import Path
 from typing import Optional, Tuple

-import numpy as np
-import torch
-import torch.nn as nn
-
 from ai_edge_torch.generative.examples.t5.t5_attention import EncoderDecoderBlock  # NOQA
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.t5_loader as loading_utils
+import numpy as np
+import torch
+import torch.nn as nn

 ENCDEC_TENSOR_NAMES = {
 "ff_up_proj": "{prefix}.block.{}.layer.{num}.DenseReluDense.wi",

@@ -36,7 +35,9 @@ ENCDEC_TENSOR_NAMES = {
 "attn_key_proj": "{prefix}.block.{}.layer.0.SelfAttention.k",
 "attn_value_proj": "{prefix}.block.{}.layer.0.SelfAttention.v",
 "attn_output_proj": "{prefix}.block.{}.layer.0.SelfAttention.o",
-"relative_attn_bias":
+"relative_attn_bias": (
+"{prefix}.block.0.layer.0.SelfAttention.relative_attention_bias"
+),
 "pre_attn_norm": "{prefix}.block.{}.layer.0.layer_norm",
 "pre_ff_norm": "{prefix}.block.{}.layer.1.layer_norm",
 "final_norm": "{prefix}.final_layer_norm",

@@ -52,13 +53,13 @@ class T5Stack(nn.Module):
 self.config = config
 self.embed_tokens = embed_tokens
 self.is_decoder = config.is_decoder
-self.transformer_blocks = nn.ModuleList(
-[… 4 removed lines not captured in this diff view]
+self.transformer_blocks = nn.ModuleList([
+EncoderDecoderBlock(config, has_relative_attention_bias=bool(i == 0))
+for i in range(config.num_layers)
+])
+self.final_norm = builder.build_norm(
+config.embedding_dim, config.final_norm_config
 )
-self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config)

 def forward(
 self,

@@ -81,15 +82,17 @@ class T5Stack(nn.Module):
 encoder_decoder_position_bias = None
 for i, layer_module in enumerate(self.transformer_blocks):
 # EncoderDecoderBlock.forward
-hidden_states, position_bias, encoder_decoder_position_bias =
-[… 8 removed lines not captured in this diff view]
+hidden_states, position_bias, encoder_decoder_position_bias = (
+layer_module(
+hidden_states,
+input_pos,
+mask=attention_mask,
+relative_position=relative_position,
+position_bias=position_bias,
+encoder_hidden_states=encoder_hidden_states,
+encoder_attention_mask=encoder_attention_mask,
+encoder_decoder_position_bias=encoder_decoder_position_bias,
+)
 )

 hidden_states = self.final_norm(hidden_states)

@@ -130,7 +133,9 @@ class T5(nn.Module):
 )

 self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
-size=config.kv_cache_max,
+size=config.kv_cache_max,
+dtype=torch.float32,
+device=torch.device("cpu"),
 )

 self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(

@@ -159,9 +164,10 @@ class T5(nn.Module):
 pad_mask: torch.Tensor,
 ) -> torch.Tensor:
 B, T = input_ids.size()
-assert (
-[… 2 removed lines not captured in this diff view]
+assert self.config.max_seq_len >= T, (
+f"Cannot forward sequence of length {T}, max seq length is only"
+f" {self.config.max_seq_len}"
+)

 enc_mask = self.enc_attn_mask_cache.index_select(2, input_pos)
 enc_mask = enc_mask[:, :, :, : self.config.kv_cache_max]

@@ -170,10 +176,18 @@ class T5(nn.Module):
 dec_mask = self.dec_attn_mask_cache.index_select(2, decoder_input_pos)
 dec_mask = dec_mask[:, :, :, : self.config.kv_cache_max]
 enc_relative_position = self.enc_rel_pos_mask.index_select(2, input_pos)
-enc_relative_position = enc_relative_position[
-[… 3 removed lines not captured in this diff view]
+enc_relative_position = enc_relative_position[
+:, :, :, : self.config.kv_cache_max
+]
+dec_relative_position = self.enc_rel_pos_mask.index_select(
+2, decoder_input_pos
+)
+dec_relative_position = dec_relative_position[
+:, :, :, : self.config.kv_cache_max
+]
+enc_attention_mask = self.enc_attn_mask_cache.index_select(
+2, decoder_input_pos
+)
 # Mask off any "pad" tokens that shouldn't contribute to cross attention
 enc_attention_mask[:, :, :, :] += pad_mask

@@ -210,7 +224,9 @@ class T5Encoder(nn.Module):

 self.config = config
 # Construct model layers.
-assert
+assert (
+embedding_layer != None
+), "Passed in embedding layer should not be None!"
 self.tok_embedding = embedding_layer

 encoder_config = copy.deepcopy(config)

@@ -244,16 +260,19 @@ class T5Encoder(nn.Module):
 pad_mask: torch.Tensor,
 ) -> torch.Tensor:
 B, T = input_ids.size()
-assert (
-[… 2 removed lines not captured in this diff view]
+assert self.config.max_seq_len >= T, (
+f"Cannot forward sequence of length {T}, max seq length is only"
+f" {self.config.max_seq_len}"
+)

 enc_mask = self.enc_attn_mask_cache.index_select(2, input_pos)
 enc_mask = enc_mask[:, :, :, : self.config.kv_cache_max]
 # Mask off any "pad" tokens that shouldn't contribute to self-attention
 enc_mask[:, :, :, :] += pad_mask
 enc_relative_position = self.enc_rel_pos_mask.index_select(2, input_pos)
-enc_relative_position = enc_relative_position[
+enc_relative_position = enc_relative_position[
+:, :, :, : self.config.kv_cache_max
+]

 # Convert encoder inputs in embeddings if needed
 encoder_hidden_states = self.encoder(

@@ -273,7 +292,9 @@ class T5Decoder(nn.Module):

 self.config = config
 # Construct model layers.
-assert
+assert (
+embedding_layer != None
+), "Passed in embedding layer should not be None!"
 self.tok_embedding = embedding_layer

 decoder_config = copy.deepcopy(config)

@@ -302,7 +323,9 @@ class T5Decoder(nn.Module):
 )

 self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
-size=config.kv_cache_max,
+size=config.kv_cache_max,
+dtype=torch.float32,
+device=torch.device("cpu"),
 )

 @torch.inference_mode

@@ -315,9 +338,15 @@ class T5Decoder(nn.Module):
 ) -> torch.Tensor:
 dec_mask = self.dec_attn_mask_cache.index_select(2, decoder_input_pos)
 dec_mask = dec_mask[:, :, :, : self.config.kv_cache_max]
-dec_relative_position = self.enc_rel_pos_mask.index_select(
-[… 2 removed lines not captured in this diff view]
+dec_relative_position = self.enc_rel_pos_mask.index_select(
+2, decoder_input_pos
+)
+dec_relative_position = dec_relative_position[
+:, :, :, : self.config.kv_cache_max
+]
+enc_attention_mask = self.enc_attn_mask_cache.index_select(
+2, decoder_input_pos
+)
 # Mask off any "pad" tokens that shouldn't contribute to cross attention
 enc_attention_mask[:, :, :, :] += pad_mask

@@ -470,89 +499,85 @@ def build_t5_decoder_model(


 def get_sample_encoder_input_ids() -> torch.Tensor:
-idx = torch.tensor(
-[… 78 removed lines not captured in this diff view]
-1,
-]
-]
-)
+idx = torch.tensor([[
+3856,
+27111,
+10,
+4425,
+51,
+4008,
+31,
+7,
+2306,
+16576,
+47,
+4381,
+16,
+8,
+3414,
+13,
+1410,
+16,
+932,
+11,
+1515,
+2766,
+6,
+11,
+4838,
+16,
+23964,
+16,
+1797,
+13,
+24,
+215,
+5,
+94,
+47,
+2017,
+168,
+1204,
+57,
+6800,
+7,
+11,
+9443,
+38,
+3673,
+8,
+4016,
+13,
+66,
+70,
+14234,
+5,
+2449,
+1215,
+83,
+17,
+16,
+8782,
+70,
+723,
+30,
+8,
+6162,
+13,
+1410,
+12,
+48,
+833,
+250,
+13,
+149,
+231,
+79,
+1858,
+16576,
+5,
+1,
+]])
 return idx


@@ -584,9 +609,15 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:
 t5_goldens = torch.load(current_dir / "t5_lm_logits.pt")

 config = get_model_config_t5()
-embedding_layer = nn.Embedding(
-[… 2 removed lines not captured in this diff view]
+embedding_layer = nn.Embedding(
+config.vocab_size, config.embedding_dim, padding_idx=0
+)
+t5_encoder_model = build_t5_encoder_model(
+config, embedding_layer, checkpoint_path
+)
+t5_decoder_model = build_t5_decoder_model(
+config, embedding_layer, checkpoint_path
+)
 idx = get_sample_encoder_input_ids()

 tokens = torch.full((1, 512), 0, dtype=torch.long, device="cpu")

@@ -595,7 +626,9 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:

 decode_d_token = torch.tensor([[0]], dtype=torch.int64)
 decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
-pad_mask = torch.zeros(
+pad_mask = torch.zeros(
+[t5_encoder_model.config.kv_cache_max], dtype=torch.float32
+)
 pad_mask[77:] = float("-inf")
 hidden_states = t5_encoder_model.forward(tokens, input_pos, pad_mask)
 lm_logits = t5_decoder_model.forward(
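Beyond the mechanical rewrapping, the t5.py hunks make the dtype and device of the decoder's causal mask cache explicit (dtype=torch.float32, device=torch.device("cpu")). The body of attn_utils.build_causal_mask_cache is not part of this diff; the following is only a minimal sketch of what such a helper conventionally computes, under that assumption:

import torch


def build_causal_mask_cache_sketch(
    size: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
  """Illustrative stand-in: a (1, 1, size, size) additive causal mask."""
  # Zeros on and below the diagonal keep past and current positions; -inf
  # above the diagonal removes future positions from the softmax.
  mask = torch.full((size, size), float("-inf"), dtype=dtype, device=device)
  return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)

The unchanged context lines above show how such a cache is consumed: the model slices it per step with index_select(2, decoder_input_pos) and truncates it to kv_cache_max.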
ai_edge_torch/generative/examples/t5/t5_attention.py

@@ -16,16 +16,15 @@

 from typing import Optional, Tuple

-import torch
-from torch import nn
-import torch.nn.functional as F
-
 from ai_edge_torch.generative.layers.attention import CrossAttention
 import ai_edge_torch.generative.layers.builder as builder
 from ai_edge_torch.generative.layers.kv_cache import KVCache
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention  # NOQA
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
+import torch
+from torch import nn
+import torch.nn.functional as F

 BATCH_SIZE = 1

@@ -181,9 +180,13 @@ class T5Attention(CrossAttention):
 """

 x = self.pre_atten_norm(x)
-B, T, C =
+B, T, C = (
+x.size()
+)  # batch size, sequence length, embedding dimensionality (n_embd)
 query_states = self.q_projection(x)
-query_states = query_states.reshape(
+query_states = query_states.reshape(
+B, T, -1, self.head_dim
+)  # (B, T, nh_q, hs)

 if key_value_states is not None:
 (

@@ -223,7 +226,12 @@ class T5Attention(CrossAttention):

 mask = mask + position_bias
 y = self.sdpa_func(
-query_states,
+query_states,
+key_states,
+value_states,
+self.head_dim,
+mask=mask,
+scale=1.0,
 )
 y = y.reshape(B, T, C)  # re-assemble all head outputs side by side
 # output projection
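The last hunk spells out the full argument list of self.sdpa_func, including scale=1.0; together with the context line mask = mask + position_bias it reflects the T5 convention of folding the learned relative-position bias into the additive mask and skipping the usual 1/sqrt(head_dim) rescaling. A rough stand-in using PyTorch's built-in SDPA (this is not the repo's sdpa_func, and the explicit scale= argument needs a reasonably recent PyTorch):

import torch
import torch.nn.functional as F


def t5_style_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    mask: torch.Tensor,
    position_bias: torch.Tensor,
) -> torch.Tensor:
  # T5 adds its learned relative-position bias to the attention logits and
  # does not rescale them, so the bias is folded into the additive mask and
  # the scale is pinned to 1.0.
  return F.scaled_dot_product_attention(
      q, k, v, attn_mask=mask + position_bias, scale=1.0
  )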
ai_edge_torch/generative/examples/test_models/toy_model.py

@@ -15,15 +15,14 @@
 # A toy example which has a single-layer transformer block.
 from typing import Tuple

-import numpy as np
-import torch
-import torch.nn as nn
-
 import ai_edge_torch
 from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
+import numpy as np
+import torch
+import torch.nn as nn

 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 KV_CACHE_MAX_LEN = 100

@@ -72,7 +71,10 @@ class ToySingleLayerModel(torch.nn.Module):

 def get_model_config() -> cfg.ModelConfig:
 attn_config = cfg.AttentionConfig(
-num_heads=32,
+num_heads=32,
+num_query_groups=4,
+rotary_percentage=1.0,
+enable_kv_cache=False,
 )
 ff_config = cfg.FeedForwardConfig(
 type=cfg.FeedForwardType.GATED,
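The second hunk fills out the toy model's cfg.AttentionConfig with num_query_groups=4 next to num_heads=32, i.e. grouped-query attention with 8 query heads per key/value group. A small shape check of what that ratio means (illustrative only; the real attention layers live in ai_edge_torch/generative/layers, not in this toy file):

import torch

num_heads, num_query_groups, head_dim, seq_len = 32, 4, 16, 8
q = torch.randn(1, num_heads, seq_len, head_dim)
k = torch.randn(1, num_query_groups, seq_len, head_dim)
v = torch.randn(1, num_query_groups, seq_len, head_dim)

# Each of the 4 K/V groups is shared by 32 // 4 = 8 query heads, so K and V
# are expanded along the head dimension before the usual attention math.
k = k.repeat_interleave(num_heads // num_query_groups, dim=1)
v = v.repeat_interleave(num_heads // num_query_groups, dim=1)
assert q.shape == k.shape == v.shape == (1, 32, 8, 16)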
ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py

@@ -16,16 +16,15 @@

 from typing import Tuple

-import torch
-import torch.nn as nn
-import torch_xla
-
 import ai_edge_torch
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
 import ai_edge_torch.generative.layers.model_config as cfg
+import torch
+import torch.nn as nn
+import torch_xla

 RoPECache = Tuple[torch.Tensor, torch.Tensor]

ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

@@ -15,16 +15,15 @@
 # A toy example which has basic transformer block (w/ KV-Cache).
 from typing import List, Tuple

-import numpy as np
-import torch
-import torch.nn as nn
-import torch_xla
-
 import ai_edge_torch
 from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
+import numpy as np
+import torch
+import torch.nn as nn
+import torch_xla

 RoPECache = Tuple[torch.Tensor, torch.Tensor]

ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py

@@ -16,11 +16,10 @@
 import os
 from pathlib import Path

-import torch
-
 import ai_edge_torch
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.quantize import quant_recipes
+import torch


 def convert_tiny_llama_to_tflite(

@@ -58,7 +57,9 @@ def convert_tiny_llama_to_tflite(
 .signature('decode', pytorch_model, (decode_token, decode_input_pos))
 .convert(quant_config=quant_config)
 )
-edge_model.export(
+edge_model.export(
+f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+)


 if __name__ == '__main__':
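For context, the .signature(...), .convert(...) and edge_model.export(...) calls in the second hunk are the tail of a multi-signature conversion chain. The sketch below reconstructs its overall shape under the assumption that the first signature, which is outside the excerpt, registers a 'prefill' entry point; all argument values are left as parameters rather than copied from the file:

import ai_edge_torch


def convert_two_signature_model(
    pytorch_model, prefill_args, decode_args, quant_config, output_path
):
  # One TFLite flatbuffer exposing two entry points: 'prefill' and 'decode'.
  edge_model = (
      ai_edge_torch.signature('prefill', pytorch_model, prefill_args)
      .signature('decode', pytorch_model, decode_args)
      .convert(quant_config=quant_config)
  )
  edge_model.export(output_path)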
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py

@@ -17,15 +17,14 @@
 import os
 from pathlib import Path

-import numpy as np
-import torch
-import torch.nn as nn
-
 from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+import torch.nn as nn

 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 ff_up_proj="model.layers.{}.mlp.up_proj",

@@ -72,7 +71,9 @@ class TinyLLamma(nn.Module):
 device=torch.device("cpu"),
 )
 self.mask_cache = attn_utils.build_causal_mask_cache(
-size=config.kv_cache_max,
+size=config.kv_cache_max,
+dtype=torch.float32,
+device=torch.device("cpu"),
 )
 self.config = config

@@ -82,9 +83,10 @@ class TinyLLamma(nn.Module):
 @torch.inference_mode
 def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
 B, T = idx.size()
-assert (
-[… 2 removed lines not captured in this diff view]
+assert self.config.max_seq_len >= T, (
+f"Cannot forward sequence of length {T}, max seq length is only"
+f" {self.config.max_seq_len}"
+)

 cos, sin = self.rope_cache
 cos = cos.index_select(0, input_pos)
ai_edge_torch/generative/fx_passes/__init__.py

@@ -12,11 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import torch
-
 from ai_edge_torch.convert.fx_passes import CanonicalizePass
 from ai_edge_torch.convert.fx_passes import run_passes
 from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass  # NOQA
+import torch


 def run_generative_passes(
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py

@@ -12,10 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import torch
-
 from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
 from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
+import torch


 class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):

@@ -36,7 +35,11 @@ class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
 # Composite info:
 # - name: odml.scaled_dot_product_attention
 # - inputs: q, k, v, mask
-if
+if (
+name == "odml.scaled_dot_product_attention"
+and is_input
+and io_position == 3
+):
 if self.is_zero_tensor_node(source):
 # Remove the mark_tensor call on the mask input by
 # replacing the target with an identity function.