ai-edge-torch-nightly 0.2.0.dev20240730__py3-none-any.whl → 0.2.0.dev20240802__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/convert/conversion.py +12 -8
- ai_edge_torch/convert/conversion_utils.py +38 -20
- ai_edge_torch/convert/converter.py +11 -5
- ai_edge_torch/convert/fx_passes/__init__.py +3 -4
- ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +45 -36
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
- ai_edge_torch/convert/test/test_convert.py +39 -16
- ai_edge_torch/convert/test/test_convert_composites.py +115 -86
- ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
- ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
- ai_edge_torch/convert/to_channel_last_io.py +6 -2
- ai_edge_torch/debug/culprit.py +41 -16
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +4 -3
- ai_edge_torch/debug/utils.py +3 -1
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +26 -13
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +15 -7
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +47 -16
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +42 -12
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
- ai_edge_torch/generative/examples/t5/t5.py +158 -125
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
- ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/fx_passes/__init__.py +1 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
- ai_edge_torch/generative/layers/attention.py +19 -11
- ai_edge_torch/generative/layers/builder.py +3 -4
- ai_edge_torch/generative/layers/kv_cache.py +4 -3
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
- ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
- ai_edge_torch/generative/layers/unet/builder.py +7 -4
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
- ai_edge_torch/generative/quantize/example.py +2 -3
- ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +10 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
- ai_edge_torch/generative/test/loader_test.py +5 -4
- ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
- ai_edge_torch/generative/test/test_model_conversion.py +2 -3
- ai_edge_torch/generative/test/test_quantize.py +45 -47
- ai_edge_torch/generative/utilities/loader.py +55 -28
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
- ai_edge_torch/generative/utilities/t5_loader.py +77 -48
- ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
- ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
- ai_edge_torch/model.py +8 -5
- ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
- ai_edge_torch/quantize/quant_config.py +6 -2
- ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/RECORD +89 -89
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/top_level.txt +0 -0
ai_edge_torch/debug/culprit.py
CHANGED
@@ -23,14 +23,13 @@ import os
 import sys
 from typing import Any, Callable, Generator, List, Optional, Tuple, Union
 
+import ai_edge_torch
+from ai_edge_torch.debug import utils
 from functorch.compile import minifier as fx_minifier
 import torch
 from torch._functorch import aot_autograd
 import torch.utils._pytree as pytree
 
-import ai_edge_torch
-from ai_edge_torch.debug import utils
-
 _torch_float_dtypes = {
     torch.float32,
     torch.float,
@@ -120,21 +119,29 @@ class Culprit(SearchResult):
     # TODO (b/321263453): Support Python code gen with sample arg tensor values.
     random_inputs = True
 
-    graph_module_code = self.graph_module.print_readable(
+    graph_module_code = self.graph_module.print_readable(
+        print_output=False
+    ).rstrip()
 
     input_strs = []
     for value in self.inputs:
       if torch.is_tensor(value):
         if not random_inputs:
-          input_strs.append(
-
+          input_strs.append(
+              f"# size={_get_shape_str(value)}, dtype={value.dtype}"
+          )
+          input_strs.append(
+              f"torch.load(io.BytesIO({_tensor_to_buffer(value)})),"
+          )
         else:
           input_strs.append(_tensor_to_random_tensor_call(value) + ",")
       else:
         input_strs.append(str(value) + ",")
 
     inputs_code = (
-        "_args = (\n"
+        "_args = (\n"
+        + "\n".join([" " * 4 + code for code in input_strs])
+        + "\n)"
     )
 
     code = graph_module_code + "\n\n" + inputs_code
@@ -157,7 +164,9 @@ class Culprit(SearchResult):
         + "from torch import device\n"
         + "import ai_edge_torch\n\n"
         + definitions
-        +
+        + "\n\n_edge_model ="
+        f" ai_edge_torch.convert({_CULPRIT_GRAPH_MODULE_NAME}().eval(),"
+        " _args)\n"
     )
     if self._runtime_errors:
       code += "_edge_model(*_args)\n"
@@ -212,7 +221,9 @@ def _normalize_getitem_nodes(fx_gm: torch.fx.GraphModule):
   return fx_gm
 
 
-def _erase_unused_inputs(
+def _erase_unused_inputs(
+    fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]
+):
   fx_gm = copy.deepcopy(fx_gm)
   inputs = tuple(inputs)
   args = fx_gm.graph.process_inputs(*inputs)
@@ -316,7 +327,9 @@ def _erase_sub_gm_from_gm(
   return fx_gm, fx_inputs
 
 
-def _normalize_minified_fx_gm(
+def _normalize_minified_fx_gm(
+    fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]
+):
   fx_gm, inputs = _erase_unused_inputs(fx_gm, inputs)
   fx_gm = _lift_dead_ops_to_outputs(fx_gm)
   fx_gm, _ = aot_autograd.aot_export_module(fx_gm, inputs, trace_joint=False)
@@ -374,7 +387,8 @@ def _search_model(
       ep = torch.export.export(model, export_args)
     except Exception as err:
       raise ValueError(
-          "Your model is not exportable by torch.export.export. Please modify
+          "Your model is not exportable by torch.export.export. Please modify"
+          " your model to be torch-exportable first."
       ) from err
   else:
     ep = model
@@ -392,7 +406,9 @@ def _search_model(
     xla_hlo_debug_value = os.environ["XLA_HLO_DEBUG"]
     del os.environ["XLA_HLO_DEBUG"]
 
-  create_minified_hlo_graph =
+  create_minified_hlo_graph = (
+      torch._functorch.fx_minifier.create_minified_hlo_graph
+  )
   torch._functorch.fx_minifier.create_minified_hlo_graph = (
       lambda *args, **kwargs: None
   )
@@ -403,7 +419,9 @@ def _search_model(
     if xla_hlo_debug_value is not None:
      os.environ["XLA_HLO_DEBUG"] = xla_hlo_debug_value
 
-    torch._functorch.fx_minifier.create_minified_hlo_graph =
+    torch._functorch.fx_minifier.create_minified_hlo_graph = (
+        create_minified_hlo_graph
+    )
 
  found_culprits_num = 0
  while True:
@@ -420,7 +438,9 @@ def _search_model(
           max_granularity=max_granularity,
       )
 
-      min_fx_gm, min_inputs = _normalize_minified_fx_gm(
+      min_fx_gm, min_inputs = _normalize_minified_fx_gm(
+          raw_min_fx_gm, raw_min_inputs
+      )
       found_culprits_num += 1
       yield SearchResult(min_fx_gm, min_inputs)
 
@@ -429,7 +449,10 @@ def _search_model(
       )
 
     except RuntimeError as e:
-      if
+      if (
+          str(e) == "Input graph did not fail the tester"
+          and found_culprits_num > 0
+      ):
         break
       raise e
 
@@ -467,5 +490,7 @@ def find_culprits(
       enable_fx_minifier_logging=enable_fx_minifier_logging,
   ):
     yield Culprit(
-        search_result.graph_module,
+        search_result.graph_module,
+        search_result.inputs,
+        _runtime_errors=runtime_errors,
    )
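The hunks above also show find_culprits threading a new runtime_errors option through to each Culprit (_runtime_errors=runtime_errors), so the generated repro script can exercise the converted model as well (code += "_edge_model(*_args)\n"). A minimal usage sketch under those assumptions, relying on the Culprit.print_code helper from the package's debugging workflow; MyModule and the input shape are illustrative, not from this diff:

import torch
from ai_edge_torch.debug import find_culprits

model = MyModule().eval()  # hypothetical module that fails to convert
args = (torch.rand(1, 3, 224, 224),)

for culprit in find_culprits(model, args, runtime_errors=True):
  culprit.print_code()  # prints a self-contained repro script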
ai_edge_torch/debug/test/test_culprit.py
CHANGED

@@ -19,16 +19,17 @@ import io
 import sys
 import unittest
 
-import torch
-
 from ai_edge_torch.debug import find_culprits
+import torch
 
 _test_culprit_lib = torch.library.Library("test_culprit", "DEF")
 
 _test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")
 
 
-@torch.library.impl(
+@torch.library.impl(
+    _test_culprit_lib, "non_lowerable_op", "CompositeExplicitAutograd"
+)
 def non_lowerable_op(x):
   if x.max() > 10.0:
     return x + 1.0
ai_edge_torch/debug/test/test_search_model.py
CHANGED

@@ -16,9 +16,8 @@
 
 import unittest
 
-import torch
-
 from ai_edge_torch.debug import _search_model
+import torch
 
 
 class TestSearchModel(unittest.TestCase):
@@ -43,7 +42,9 @@ class TestSearchModel(unittest.TestCase):
 
     results = list(_search_model(find_subgraph_with_sub, model, args))
     self.assertEqual(len(results), 2)
-    self.assertIn(
+    self.assertIn(
+        torch.ops.aten.sub.Tensor, [n.target for n in results[0].graph.nodes]
+    )
 
 
 if __name__ == "__main__":
ai_edge_torch/debug/utils.py
CHANGED
@@ -21,7 +21,9 @@ import torch.fx._pytree as fx_pytree
 from torch.utils import _pytree as pytree
 
 
-def exported_program_to_fx_graph_module_and_inputs(
+def exported_program_to_fx_graph_module_and_inputs(
+    ep: torch.export.ExportedProgram,
+):
   fx_gm = ep.graph_module
   fx_inputs = pytree.tree_map(
       torch.tensor, ep._graph_module_flat_inputs(*ep.example_inputs)
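The reflowed helper above takes a torch.export.ExportedProgram and, as its name suggests, recovers the underlying fx.GraphModule together with its flattened tensor inputs. A brief sketch of the call pattern, assuming it returns that pair (model and sample_args are illustrative):

import torch
from ai_edge_torch.debug import utils

ep = torch.export.export(model, sample_args)  # any torch-exportable model
fx_gm, fx_inputs = utils.exported_program_to_fx_graph_module_and_inputs(ep)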
ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py
CHANGED

@@ -20,12 +20,11 @@
 import os
 from pathlib import Path
 
-import torch
-
 import ai_edge_torch
 from ai_edge_torch.generative.examples.experimental.gemma import gemma
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
+import torch
 
 
 def convert_gemma_to_tflite(
@@ -79,7 +78,9 @@ def convert_gemma_to_tflite(
       )
       .convert(quant_config=quant_config)
   )
-  edge_model.export(
+  edge_model.export(
+      f'/tmp/gemma_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  )
 
 
 if __name__ == '__main__':
ai_edge_torch/generative/examples/experimental/gemma/gemma.py
CHANGED

@@ -21,16 +21,15 @@ import os
 from pathlib import Path
 from typing import Tuple
 
-import numpy as np
-import torch
-import torch.nn as nn
-
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+import torch.nn as nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -81,7 +80,9 @@ class Gemma(nn.Module):
         device=torch.device("cpu"),
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -93,9 +94,10 @@ class Gemma(nn.Module):
       kv_cache: kv_utils.EKVCache,
   ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
     B, T = tokens.size()
-    assert (
-
-
+    assert self.config.max_seq_len >= T, (
+        f"Cannot forward sequence of length {T}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py
CHANGED

@@ -19,12 +19,11 @@
 import os
 from pathlib import Path
 
-import torch
-
 import ai_edge_torch
 from ai_edge_torch.generative.examples.experimental.phi import phi2
 from ai_edge_torch.generative.layers.experimental import ekv_cache
 from ai_edge_torch.generative.quantize import quant_recipes
+import torch
 
 
 def convert_phi2_to_tflite(
@@ -46,7 +45,9 @@ def convert_phi2_to_tflite(
     quantize (bool, optional): Whether the model should be quanized.
       Defaults to True.
   """
-  pytorch_model = phi2.build_model(
+  pytorch_model = phi2.build_model(
+      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+  )
   # Tensors used to trace the model graph during conversion.
   prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
   prefill_input_pos = torch.arange(0, prefill_seq_len)
@@ -76,7 +77,9 @@ def convert_phi2_to_tflite(
       )
       .convert(quant_config=quant_config)
   )
-  edge_model.export(
+  edge_model.export(
+      f'/tmp/phi2_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  )
 
 
 if __name__ == '__main__':
ai_edge_torch/generative/examples/experimental/phi/phi2.py
CHANGED

@@ -22,16 +22,15 @@ import os
 from pathlib import Path
 from typing import Tuple
 
-import numpy as np
-import torch
-import torch.nn as nn
-
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+import torch.nn as nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
@@ -76,7 +75,9 @@ class Phi2(nn.Module):
         device=torch.device("cpu"),
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -88,9 +89,10 @@ class Phi2(nn.Module):
       kv_cache: kv_utils.EKVCache,
   ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
     B, T = tokens.size()
-    assert (
-
-
+    assert self.config.max_seq_len >= T, (
+        f"Cannot forward sequence of length {T}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py
CHANGED

@@ -20,12 +20,11 @@
 import os
 from pathlib import Path
 
-import torch
-
 import ai_edge_torch
 from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama  # NOQA
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
+import torch
 
 
 def convert_tiny_llama_to_tflite(
ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py
CHANGED

@@ -22,16 +22,15 @@ import os
 from pathlib import Path
 from typing import Tuple
 
-import numpy as np
-import torch
-import torch.nn as nn
-
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+import torch.nn as nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -78,7 +77,9 @@ class TinyLLamma(nn.Module):
         device=torch.device("cpu"),
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -90,9 +91,10 @@ class TinyLLamma(nn.Module):
       kv_cache: kv_utils.EKVCache,
   ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
     B, T = tokens.size()
-    assert (
-
-
+    assert self.config.max_seq_len >= T, (
+        f"Cannot forward sequence of length {T}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
CHANGED

@@ -16,11 +16,10 @@
 import os
 from pathlib import Path
 
-import torch
-
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma
 from ai_edge_torch.generative.quantize import quant_recipes
+import torch
 
 
 def convert_gemma_to_tflite(
@@ -58,7 +57,9 @@ def convert_gemma_to_tflite(
       .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
-  edge_model.export(
+  edge_model.export(
+      f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+  )
 
 
 if __name__ == '__main__':
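The export call reformatted above is the tail of the multi-signature conversion chain these convert_to_tflite scripts share. Schematically, with the prefill trace tensors taken from the sibling phi2 diff and the decode tensors sketched in (their dtypes and shapes are assumptions, not quoted from this file):

import torch
import ai_edge_torch

prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
prefill_input_pos = torch.arange(0, prefill_seq_len)
decode_token = torch.tensor([[0]], dtype=torch.long)  # assumed single-token decode input
decode_input_pos = torch.tensor([0], dtype=torch.int64)  # assumed position tensor

edge_model = (
    ai_edge_torch.signature(
        'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
    )
    .signature('decode', pytorch_model, (decode_token, decode_input_pos))
    .convert(quant_config=quant_config)
)
edge_model.export(f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')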
ai_edge_torch/generative/examples/gemma/gemma.py
CHANGED

@@ -17,15 +17,14 @@
 import os
 from pathlib import Path
 
-import numpy as np
-import torch
-import torch.nn as nn
-
 from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+import torch.nn as nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -76,7 +75,9 @@ class Gemma(nn.Module):
         device=torch.device("cpu"),
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -86,9 +87,10 @@ class Gemma(nn.Module):
   @torch.inference_mode
   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
     B, T = idx.size()
-    assert (
-
-
+    assert self.config.max_seq_len >= T, (
+        f"Cannot forward sequence of length {T}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -171,7 +173,9 @@ def define_and_run_2b() -> None:
   input_pos = torch.arange(0, kv_cache_max_len)
   lm_logits = model.forward(tokens, input_pos)
   print("comparing with goldens..")
-  assert torch.allclose(
+  assert torch.allclose(
+      gemma_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+  )
 
 
 if __name__ == "__main__":
ai_edge_torch/generative/examples/phi2/convert_to_tflite.py
CHANGED

@@ -16,11 +16,10 @@
 import os
 from pathlib import Path
 
-import torch
-
 import ai_edge_torch
 from ai_edge_torch.generative.examples.phi2 import phi2
 from ai_edge_torch.generative.quantize import quant_recipes
+import torch
 
 
 def convert_phi2_to_tflite(
@@ -41,7 +40,9 @@ def convert_phi2_to_tflite(
     quantize (bool, optional): Whether the model should be quanized.
      Defaults to True.
   """
-  pytorch_model = phi2.build_model(
+  pytorch_model = phi2.build_model(
+      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+  )
   # Tensors used to trace the model graph during conversion.
   prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
   prefill_input_pos = torch.arange(0, prefill_seq_len)
@@ -56,7 +57,9 @@ def convert_phi2_to_tflite(
       .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
-  edge_model.export(
+  edge_model.export(
+      f'/tmp/phi2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+  )
 
 
 if __name__ == '__main__':
ai_edge_torch/generative/examples/phi2/phi2.py
CHANGED

@@ -18,15 +18,14 @@
 import os
 from pathlib import Path
 
-import numpy as np
-import torch
-import torch.nn as nn
-
 from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+import torch.nn as nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
@@ -71,7 +70,9 @@ class Phi2(nn.Module):
         device=torch.device("cpu"),
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -81,9 +82,10 @@ class Phi2(nn.Module):
   @torch.inference_mode
   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
     B, T = idx.size()
-    assert (
-
-
+    assert self.config.max_seq_len >= T, (
+        f"Cannot forward sequence of length {T}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -160,7 +162,9 @@ def define_and_run() -> None:
   input_pos = torch.arange(0, kv_cache_max_len)
   lm_logits = model.forward(tokens, input_pos)
   print("comparing with goldens..")
-  assert torch.allclose(
+  assert torch.allclose(
+      phi2_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+  )
 
 
 if __name__ == "__main__":
|
@@ -73,7 +73,9 @@ class SelfAttention(nn.Module):
|
|
|
73
73
|
|
|
74
74
|
class CrossAttention(nn.Module):
|
|
75
75
|
|
|
76
|
-
def __init__(
|
|
76
|
+
def __init__(
|
|
77
|
+
self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True
|
|
78
|
+
):
|
|
77
79
|
super().__init__()
|
|
78
80
|
self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
|
|
79
81
|
self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
|
|
ai_edge_torch/generative/examples/stable_diffusion/clip.py
CHANGED

@@ -13,25 +13,34 @@
 # limitations under the License.
 # ==============================================================================
 
-import torch
-from torch import nn
-
 from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attention_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj=
-
+    ff_up_proj=(
+        "cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc1"
+    ),
+    ff_down_proj=(
+        "cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc2"
+    ),
     attn_query_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.q_proj",
     attn_key_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.k_proj",
     attn_value_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.v_proj",
     attn_output_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.out_proj",
-    pre_attn_norm=
-
-
+    pre_attn_norm=(
+        "cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm1"
+    ),
+    pre_ff_norm=(
+        "cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm2"
+    ),
+    embedding=(
+        "cond_stage_model.transformer.text_model.embeddings.token_embedding"
+    ),
     embedding_position="cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
     final_norm="cond_stage_model.transformer.text_model.final_layer_norm",
     lm_head=None,
@@ -54,7 +63,9 @@ class CLIP(nn.Module):
     self.transformer_blocks = nn.ModuleList(
         TransformerBlock(config) for _ in range(config.num_layers)
     )
-    self.final_norm = builder.build_norm(
+    self.final_norm = builder.build_norm(
+        config.embedding_dim, config.final_norm_config
+    )
 
     self.mask_cache = attention_utils.build_causal_mask_cache(
         size=config.max_seq_len, dtype=torch.float32