ai-edge-torch-nightly 0.2.0.dev20240805__py3-none-any.whl → 0.2.0.dev20240807__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +5 -5
- ai_edge_torch/{convert → _convert}/conversion.py +40 -50
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/{convert → _convert}/converter.py +83 -43
- ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
- ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
- ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
- ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
- ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
- ai_edge_torch/_convert/signature.py +100 -0
- ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
- ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
- ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
- ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
- ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
- ai_edge_torch/config.py +24 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/culprit.py +22 -22
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +5 -5
- ai_edge_torch/debug/utils.py +11 -2
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
- ai_edge_torch/generative/examples/t5/t5.py +2 -2
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/fx_passes/__init__.py +2 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
- ai_edge_torch/generative/layers/attention.py +35 -26
- ai_edge_torch/generative/layers/attention_utils.py +23 -12
- ai_edge_torch/generative/layers/builder.py +0 -1
- ai_edge_torch/generative/layers/feed_forward.py +6 -10
- ai_edge_torch/generative/layers/kv_cache.py +0 -1
- ai_edge_torch/generative/layers/model_config.py +2 -5
- ai_edge_torch/generative/layers/normalization.py +5 -7
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
- ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
- ai_edge_torch/generative/layers/unet/model_config.py +14 -15
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
- ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
- ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
- ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
- ai_edge_torch/generative/test/test_model_conversion.py +24 -25
- ai_edge_torch/generative/test/test_quantize.py +10 -5
- ai_edge_torch/generative/utilities/loader.py +12 -12
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
- ai_edge_torch/generative/utilities/t5_loader.py +12 -13
- ai_edge_torch/hlfb/__init__.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
- ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
- ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
- ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
- ai_edge_torch/lowertools/_shim.py +80 -0
- ai_edge_torch/lowertools/common_utils.py +89 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +201 -0
- ai_edge_torch/{convert/conversion_utils.py → lowertools/torch_xla_utils.py} +35 -214
- ai_edge_torch/model.py +14 -9
- ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
- ai_edge_torch/quantize/quant_config.py +7 -7
- ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/METADATA +1 -1
- ai_edge_torch_nightly-0.2.0.dev20240807.dist-info/RECORD +141 -0
- ai_edge_torch_nightly-0.2.0.dev20240805.dist-info/RECORD +0 -133
- /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
- /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
- /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/top_level.txt +0 -0
|
@@ -14,8 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
# UNet configuration class.
|
|
17
|
-
|
|
18
|
-
from dataclasses import field
|
|
17
|
+
import dataclasses
|
|
19
18
|
import enum
|
|
20
19
|
from typing import List, Optional
|
|
21
20
|
|
|
@@ -30,13 +29,13 @@ class SamplingType(enum.Enum):
|
|
|
30
29
|
CONVOLUTION = enum.auto()
|
|
31
30
|
|
|
32
31
|
|
|
33
|
-
@dataclass
|
|
32
|
+
@dataclasses.dataclass
|
|
34
33
|
class UpSamplingConfig:
|
|
35
34
|
mode: SamplingType
|
|
36
35
|
scale_factor: float
|
|
37
36
|
|
|
38
37
|
|
|
39
|
-
@dataclass
|
|
38
|
+
@dataclasses.dataclass
|
|
40
39
|
class DownSamplingConfig:
|
|
41
40
|
mode: SamplingType
|
|
42
41
|
in_channels: int
|
|
@@ -46,7 +45,7 @@ class DownSamplingConfig:
|
|
|
46
45
|
out_channels: Optional[int] = None
|
|
47
46
|
|
|
48
47
|
|
|
49
|
-
@dataclass
|
|
48
|
+
@dataclasses.dataclass
|
|
50
49
|
class ResidualBlock2DConfig:
|
|
51
50
|
in_channels: int
|
|
52
51
|
out_channels: int
|
|
@@ -56,7 +55,7 @@ class ResidualBlock2DConfig:
|
|
|
56
55
|
time_embedding_channels: Optional[int] = None
|
|
57
56
|
|
|
58
57
|
|
|
59
|
-
@dataclass
|
|
58
|
+
@dataclasses.dataclass
|
|
60
59
|
class AttentionBlock2DConfig:
|
|
61
60
|
dim: int
|
|
62
61
|
normalization_config: layers_cfg.NormalizationConfig
|
|
@@ -65,7 +64,7 @@ class AttentionBlock2DConfig:
|
|
|
65
64
|
attention_batch_size: int = 1
|
|
66
65
|
|
|
67
66
|
|
|
68
|
-
@dataclass
|
|
67
|
+
@dataclasses.dataclass
|
|
69
68
|
class CrossAttentionBlock2DConfig:
|
|
70
69
|
query_dim: int
|
|
71
70
|
cross_dim: int
|
|
@@ -75,7 +74,7 @@ class CrossAttentionBlock2DConfig:
|
|
|
75
74
|
attention_batch_size: int = 1
|
|
76
75
|
|
|
77
76
|
|
|
78
|
-
@dataclass
|
|
77
|
+
@dataclasses.dataclass
|
|
79
78
|
class FeedForwardBlock2DConfig:
|
|
80
79
|
dim: int
|
|
81
80
|
hidden_dim: int
|
|
@@ -84,7 +83,7 @@ class FeedForwardBlock2DConfig:
|
|
|
84
83
|
use_bias: bool
|
|
85
84
|
|
|
86
85
|
|
|
87
|
-
@dataclass
|
|
86
|
+
@dataclasses.dataclass
|
|
88
87
|
class TransformerBlock2DConfig:
|
|
89
88
|
pre_conv_normalization_config: layers_cfg.NormalizationConfig
|
|
90
89
|
attention_block_config: AttentionBlock2DConfig
|
|
@@ -92,7 +91,7 @@ class TransformerBlock2DConfig:
|
|
|
92
91
|
feed_forward_block_config: FeedForwardBlock2DConfig
|
|
93
92
|
|
|
94
93
|
|
|
95
|
-
@dataclass
|
|
94
|
+
@dataclasses.dataclass
|
|
96
95
|
class UpDecoderBlock2DConfig:
|
|
97
96
|
in_channels: int
|
|
98
97
|
out_channels: int
|
|
@@ -113,7 +112,7 @@ class UpDecoderBlock2DConfig:
|
|
|
113
112
|
context_dim: Optional[int] = None
|
|
114
113
|
|
|
115
114
|
|
|
116
|
-
@dataclass
|
|
115
|
+
@dataclasses.dataclass
|
|
117
116
|
class SkipUpDecoderBlock2DConfig:
|
|
118
117
|
in_channels: int
|
|
119
118
|
out_channels: int
|
|
@@ -136,7 +135,7 @@ class SkipUpDecoderBlock2DConfig:
|
|
|
136
135
|
context_dim: Optional[int] = None
|
|
137
136
|
|
|
138
137
|
|
|
139
|
-
@dataclass
|
|
138
|
+
@dataclasses.dataclass
|
|
140
139
|
class DownEncoderBlock2DConfig:
|
|
141
140
|
in_channels: int
|
|
142
141
|
out_channels: int
|
|
@@ -157,7 +156,7 @@ class DownEncoderBlock2DConfig:
|
|
|
157
156
|
context_dim: Optional[int] = None
|
|
158
157
|
|
|
159
158
|
|
|
160
|
-
@dataclass
|
|
159
|
+
@dataclasses.dataclass
|
|
161
160
|
class MidBlock2DConfig:
|
|
162
161
|
in_channels: int
|
|
163
162
|
normalization_config: layers_cfg.NormalizationConfig
|
|
@@ -173,7 +172,7 @@ class MidBlock2DConfig:
|
|
|
173
172
|
context_dim: Optional[int] = None
|
|
174
173
|
|
|
175
174
|
|
|
176
|
-
@dataclass
|
|
175
|
+
@dataclasses.dataclass
|
|
177
176
|
class AutoEncoderConfig:
|
|
178
177
|
"""Configurations of encoder/decoder in the autoencoder model."""
|
|
179
178
|
|
|
@@ -210,7 +209,7 @@ class AutoEncoderConfig:
|
|
|
210
209
|
mid_block_config: MidBlock2DConfig
|
|
211
210
|
|
|
212
211
|
|
|
213
|
-
@dataclass
|
|
212
|
+
@dataclasses.dataclass
|
|
214
213
|
class DiffusionModelConfig:
|
|
215
214
|
"""Configurations of Diffusion model."""
|
|
216
215
|
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
@@ -13,8 +13,6 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
import json
|
|
17
|
-
|
|
18
16
|
from ai_edge_quantizer import quantizer
|
|
19
17
|
from ai_edge_torch.generative.quantize import quant_attrs
|
|
20
18
|
from ai_edge_torch.generative.quantize import quant_recipe
|
|
@@ -25,7 +25,8 @@ class LayerQuantRecipe:
|
|
|
25
25
|
"""Quantization recipe for a single Edge Generative API layer (e.g. Attention).
|
|
26
26
|
|
|
27
27
|
Generic layer-scoped quantization recipe that specifies how this layer should
|
|
28
|
-
be quantized by the Edge Generative API. This is applicable to layers
|
|
28
|
+
be quantized by the Edge Generative API. This is applicable to layers
|
|
29
|
+
implemented
|
|
29
30
|
in ai_edge_torch/generative/layers/. Combinations of attributes that are not
|
|
30
31
|
supported during runtime will be detected when .verify() is called.
|
|
31
32
|
|
|
@@ -83,7 +84,8 @@ class LayerQuantRecipe:
|
|
|
83
84
|
class GenerativeQuantRecipe:
|
|
84
85
|
"""Quantization recipe for a model composed of the Edge Generative API layers.
|
|
85
86
|
|
|
86
|
-
Some layers can be specified with different `LayerQuantRecipe` for each block
|
|
87
|
+
Some layers can be specified with different `LayerQuantRecipe` for each block
|
|
88
|
+
by
|
|
87
89
|
providing a dictionary keyed by the TransformerBlock index, e.g. attention
|
|
88
90
|
and feedforward. For example,
|
|
89
91
|
|
|
@@ -102,11 +104,11 @@ class GenerativeQuantRecipe:
|
|
|
102
104
|
default: The quantization recipe for global scope of the model.
|
|
103
105
|
embedding: Recipe for the embedding table.
|
|
104
106
|
attention: Recipe for the attention blocks. This could be specified with
|
|
105
|
-
different LayerQuantRecipe for each block by providing a dictionary
|
|
106
|
-
|
|
107
|
+
different LayerQuantRecipe for each block by providing a dictionary keyed
|
|
108
|
+
by the TransformerBlock index.
|
|
107
109
|
feedforward: Recipe for the feedforward layers. This could be specified with
|
|
108
|
-
different LayerQuantRecipe for each block by providing a dictionary
|
|
109
|
-
|
|
110
|
+
different LayerQuantRecipe for each block by providing a dictionary keyed
|
|
111
|
+
by the TransformerBlock index.
|
|
110
112
|
"""
|
|
111
113
|
|
|
112
114
|
default: Optional[LayerQuantRecipe] = None
|
|
@@ -16,7 +16,8 @@
|
|
|
16
16
|
"""Helper functions to construct custom quantization recipes.
|
|
17
17
|
|
|
18
18
|
These are intended for more advanced users who want to configure their own
|
|
19
|
-
quantization recipes. For pre-constructed recipes, use `quant_recipes.py`
|
|
19
|
+
quantization recipes. For pre-constructed recipes, use `quant_recipes.py`
|
|
20
|
+
instead.
|
|
20
21
|
|
|
21
22
|
Typical usage example:
|
|
22
23
|
|
|
@@ -14,24 +14,23 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
# A suite of tests to validate experimental external KV Cache layers and models.
|
|
16
16
|
|
|
17
|
-
import unittest
|
|
18
|
-
|
|
19
17
|
from ai_edge_torch.generative.examples.experimental.gemma import gemma
|
|
20
18
|
from ai_edge_torch.generative.examples.experimental.phi import phi2
|
|
21
19
|
from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama # NOQA
|
|
22
20
|
from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
|
|
23
21
|
import ai_edge_torch.generative.layers.model_config as cfg
|
|
24
|
-
import numpy as np
|
|
25
22
|
import torch
|
|
26
23
|
|
|
24
|
+
from tensorflow.python.platform import googletest
|
|
25
|
+
|
|
27
26
|
|
|
28
|
-
class TestExternalKVLayers(
|
|
27
|
+
class TestExternalKVLayers(googletest.TestCase):
|
|
29
28
|
|
|
30
29
|
def _get_test_config(
|
|
31
30
|
self, num_layers, head_dim, num_query_groups, kv_cache_max_len
|
|
32
31
|
):
|
|
33
32
|
attn_config = cfg.AttentionConfig(
|
|
34
|
-
num_heads=1, num_query_groups=num_query_groups
|
|
33
|
+
num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
|
|
35
34
|
)
|
|
36
35
|
config = cfg.ModelConfig(
|
|
37
36
|
kv_cache_max_len=kv_cache_max_len,
|
|
@@ -117,7 +116,7 @@ class TestExternalKVLayers(unittest.TestCase):
|
|
|
117
116
|
self.assertEqual(input_specs[1].arg.name, "kv_v_0")
|
|
118
117
|
|
|
119
118
|
|
|
120
|
-
class TestExternalKVModels(
|
|
119
|
+
class TestExternalKVModels(googletest.TestCase):
|
|
121
120
|
|
|
122
121
|
def test_can_build_gemma(self):
|
|
123
122
|
gemma.define_and_run_2b(checkpoint_path=None, test_model=True)
|
|
@@ -130,4 +129,4 @@ class TestExternalKVModels(unittest.TestCase):
|
|
|
130
129
|
|
|
131
130
|
|
|
132
131
|
if __name__ == "__main__":
|
|
133
|
-
|
|
132
|
+
googletest.main()
|
|
@@ -16,15 +16,16 @@
|
|
|
16
16
|
|
|
17
17
|
import os
|
|
18
18
|
import tempfile
|
|
19
|
-
import unittest
|
|
20
19
|
|
|
21
20
|
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
|
22
21
|
from ai_edge_torch.generative.utilities import loader as loading_utils
|
|
23
22
|
import safetensors.torch
|
|
24
23
|
import torch
|
|
25
24
|
|
|
25
|
+
from tensorflow.python.platform import googletest
|
|
26
26
|
|
|
27
|
-
|
|
27
|
+
|
|
28
|
+
class TestLoader(googletest.TestCase):
|
|
28
29
|
"""Unit tests that check weight loader."""
|
|
29
30
|
|
|
30
31
|
def test_load_safetensors(self):
|
|
@@ -78,4 +79,4 @@ class TestLoader(unittest.TestCase):
|
|
|
78
79
|
|
|
79
80
|
|
|
80
81
|
if __name__ == "__main__":
|
|
81
|
-
|
|
82
|
+
googletest.main()
|
|
@@ -14,9 +14,6 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
# Testing model conversion for a few gen-ai models.
|
|
16
16
|
import copy
|
|
17
|
-
import os
|
|
18
|
-
import tempfile
|
|
19
|
-
import unittest
|
|
20
17
|
|
|
21
18
|
import ai_edge_torch
|
|
22
19
|
from ai_edge_torch.generative.examples.gemma import gemma
|
|
@@ -27,22 +24,24 @@ from ai_edge_torch.testing import model_coverage
|
|
|
27
24
|
import numpy as np
|
|
28
25
|
import torch
|
|
29
26
|
|
|
27
|
+
from tensorflow.python.platform import googletest
|
|
30
28
|
|
|
31
|
-
|
|
29
|
+
|
|
30
|
+
class TestModelConversion(googletest.TestCase):
|
|
32
31
|
"""Unit tests that check for model conversion and correctness."""
|
|
33
32
|
|
|
34
33
|
def test_toy_model_with_kv_cache(self):
|
|
35
34
|
config = toy_model_with_kv_cache.get_model_config()
|
|
36
|
-
pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
|
|
35
|
+
pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
|
|
37
36
|
idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
|
|
38
37
|
[10], dtype=torch.int64
|
|
39
38
|
)
|
|
40
39
|
|
|
41
40
|
edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
|
|
42
41
|
|
|
43
|
-
# TODO
|
|
42
|
+
# TODO: b/338288901 - re-enable test to check output tensors.
|
|
44
43
|
skip_output_check = True
|
|
45
|
-
if skip_output_check
|
|
44
|
+
if not skip_output_check:
|
|
46
45
|
self.assertTrue(
|
|
47
46
|
model_coverage.compare_tflite_torch(
|
|
48
47
|
edge_model,
|
|
@@ -57,16 +56,16 @@ class TestModelConversion(unittest.TestCase):
|
|
|
57
56
|
def test_toy_model_with_multi_batches(self):
|
|
58
57
|
config = toy_model_with_kv_cache.get_model_config()
|
|
59
58
|
config.batch_size = 2
|
|
60
|
-
pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
|
|
59
|
+
pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
|
|
61
60
|
idx, input_pos = torch.tensor([[1], [2]], dtype=torch.long), torch.tensor(
|
|
62
61
|
[10], dtype=torch.int64
|
|
63
62
|
)
|
|
64
63
|
|
|
65
64
|
edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
|
|
66
65
|
|
|
67
|
-
# TODO
|
|
66
|
+
# TODO: b/338288901 - re-enable test to check output tensors.
|
|
68
67
|
skip_output_check = True
|
|
69
|
-
if skip_output_check
|
|
68
|
+
if not skip_output_check:
|
|
70
69
|
self.assertTrue(
|
|
71
70
|
model_coverage.compare_tflite_torch(
|
|
72
71
|
edge_model,
|
|
@@ -81,16 +80,16 @@ class TestModelConversion(unittest.TestCase):
|
|
|
81
80
|
def test_toy_model_with_kv_cache_with_hlfb(self):
|
|
82
81
|
config = toy_model_with_kv_cache.get_model_config()
|
|
83
82
|
config.enable_hlfb = True
|
|
84
|
-
pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
|
|
83
|
+
pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
|
|
85
84
|
idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
|
|
86
85
|
[10], dtype=torch.int64
|
|
87
86
|
)
|
|
88
87
|
|
|
89
88
|
edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
|
|
90
89
|
|
|
91
|
-
# TODO
|
|
90
|
+
# TODO: b/338288901 - re-enable test to check output tensors.
|
|
92
91
|
skip_output_check = True
|
|
93
|
-
if skip_output_check
|
|
92
|
+
if not skip_output_check:
|
|
94
93
|
self.assertTrue(
|
|
95
94
|
model_coverage.compare_tflite_torch(
|
|
96
95
|
edge_model,
|
|
@@ -105,7 +104,7 @@ class TestModelConversion(unittest.TestCase):
|
|
|
105
104
|
def test_tiny_llama(self):
|
|
106
105
|
self.skipTest("b/338288901")
|
|
107
106
|
config = tiny_llama.get_fake_model_config_for_test()
|
|
108
|
-
pytorch_model = tiny_llama.TinyLLamma(config)
|
|
107
|
+
pytorch_model = tiny_llama.TinyLLamma(config).eval()
|
|
109
108
|
|
|
110
109
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
|
111
110
|
tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
|
|
@@ -114,9 +113,9 @@ class TestModelConversion(unittest.TestCase):
|
|
|
114
113
|
|
|
115
114
|
edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
|
|
116
115
|
|
|
117
|
-
# TODO
|
|
116
|
+
# TODO: b/338288901 - re-enable test to check output tensors.
|
|
118
117
|
skip_output_check = True
|
|
119
|
-
if skip_output_check
|
|
118
|
+
if not skip_output_check:
|
|
120
119
|
self.assertTrue(
|
|
121
120
|
model_coverage.compare_tflite_torch(
|
|
122
121
|
edge_model,
|
|
@@ -130,7 +129,7 @@ class TestModelConversion(unittest.TestCase):
|
|
|
130
129
|
|
|
131
130
|
def test_tiny_llama_multisig(self):
|
|
132
131
|
config = tiny_llama.get_fake_model_config_for_test()
|
|
133
|
-
pytorch_model = tiny_llama.TinyLLamma(config)
|
|
132
|
+
pytorch_model = tiny_llama.TinyLLamma(config).eval()
|
|
134
133
|
|
|
135
134
|
# prefill
|
|
136
135
|
seq_len = 10
|
|
@@ -151,9 +150,9 @@ class TestModelConversion(unittest.TestCase):
|
|
|
151
150
|
.convert()
|
|
152
151
|
)
|
|
153
152
|
|
|
154
|
-
# TODO
|
|
153
|
+
# TODO: b/338288901 - re-enable test to check output tensors.
|
|
155
154
|
skip_output_check = True
|
|
156
|
-
if skip_output_check
|
|
155
|
+
if not skip_output_check:
|
|
157
156
|
copied_model = copy.deepcopy(pytorch_model)
|
|
158
157
|
|
|
159
158
|
self.assertTrue(
|
|
@@ -188,9 +187,9 @@ class TestModelConversion(unittest.TestCase):
|
|
|
188
187
|
|
|
189
188
|
edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
|
|
190
189
|
|
|
191
|
-
# TODO
|
|
190
|
+
# TODO: b/338288901 - re-enable test to check output tensors.
|
|
192
191
|
skip_output_check = True
|
|
193
|
-
if skip_output_check
|
|
192
|
+
if not skip_output_check:
|
|
194
193
|
# TODO(talumbau, haoliang): debug numerical diff.
|
|
195
194
|
self.assertTrue(
|
|
196
195
|
model_coverage.compare_tflite_torch(
|
|
@@ -206,7 +205,7 @@ class TestModelConversion(unittest.TestCase):
|
|
|
206
205
|
def test_phi2(self):
|
|
207
206
|
self.skipTest("b/338288901")
|
|
208
207
|
config = phi2.get_fake_model_config_for_test()
|
|
209
|
-
pytorch_model = phi2.Phi2(config)
|
|
208
|
+
pytorch_model = phi2.Phi2(config).eval()
|
|
210
209
|
|
|
211
210
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
|
212
211
|
tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
|
|
@@ -215,9 +214,9 @@ class TestModelConversion(unittest.TestCase):
|
|
|
215
214
|
|
|
216
215
|
edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
|
|
217
216
|
|
|
218
|
-
# TODO
|
|
217
|
+
# TODO: b/338288901 - re-enable test to check output tensors.
|
|
219
218
|
skip_output_check = True
|
|
220
|
-
if skip_output_check
|
|
219
|
+
if not skip_output_check:
|
|
221
220
|
self.assertTrue(
|
|
222
221
|
model_coverage.compare_tflite_torch(
|
|
223
222
|
edge_model,
|
|
@@ -231,4 +230,4 @@ class TestModelConversion(unittest.TestCase):
|
|
|
231
230
|
|
|
232
231
|
|
|
233
232
|
if __name__ == "__main__":
|
|
234
|
-
|
|
233
|
+
googletest.main()
|
|
@@ -13,9 +13,8 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
import unittest
|
|
17
|
-
|
|
18
16
|
import ai_edge_torch
|
|
17
|
+
from ai_edge_torch import config
|
|
19
18
|
from ai_edge_torch.generative.examples.test_models import toy_model # NOQA
|
|
20
19
|
from ai_edge_torch.generative.quantize import quant_recipe
|
|
21
20
|
from ai_edge_torch.generative.quantize import quant_recipe_utils
|
|
@@ -29,8 +28,10 @@ from ai_edge_torch.testing import model_coverage
|
|
|
29
28
|
from parameterized import parameterized
|
|
30
29
|
import torch
|
|
31
30
|
|
|
31
|
+
from tensorflow.python.platform import googletest
|
|
32
|
+
|
|
32
33
|
|
|
33
|
-
class TestVerifyRecipes(
|
|
34
|
+
class TestVerifyRecipes(googletest.TestCase):
|
|
34
35
|
"""Unit tests that check for model quantization recipes."""
|
|
35
36
|
|
|
36
37
|
@parameterized.expand([
|
|
@@ -87,7 +88,7 @@ class TestVerifyRecipes(unittest.TestCase):
|
|
|
87
88
|
).verify()
|
|
88
89
|
|
|
89
90
|
|
|
90
|
-
class TestQuantizeConvert(
|
|
91
|
+
class TestQuantizeConvert(googletest.TestCase):
|
|
91
92
|
"""Test conversion with quantization."""
|
|
92
93
|
|
|
93
94
|
def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
|
|
@@ -111,6 +112,10 @@ class TestQuantizeConvert(unittest.TestCase):
|
|
|
111
112
|
(_attention_int8_dynamic_recipe()),
|
|
112
113
|
(_feedforward_int8_dynamic_recipe()),
|
|
113
114
|
])
|
|
115
|
+
@googletest.skipIf(
|
|
116
|
+
not config.Config.use_torch_xla,
|
|
117
|
+
reason="Not working with odml_torch at the moment.",
|
|
118
|
+
)
|
|
114
119
|
def test_quantize_convert_toy_sizes(self, quant_config):
|
|
115
120
|
config = toy_model.get_model_config()
|
|
116
121
|
pytorch_model = toy_model.ToySingleLayerModel(config)
|
|
@@ -157,4 +162,4 @@ class TestQuantizeConvert(unittest.TestCase):
|
|
|
157
162
|
|
|
158
163
|
|
|
159
164
|
if __name__ == "__main__":
|
|
160
|
-
|
|
165
|
+
googletest.main()
|
|
@@ -92,9 +92,7 @@ def load_pytorch_statedict(full_path: str):
|
|
|
92
92
|
|
|
93
93
|
|
|
94
94
|
class ModelLoader:
|
|
95
|
-
"""
|
|
96
|
-
Edge Generative API layer format.
|
|
97
|
-
"""
|
|
95
|
+
"""Utlity for loading model checkpoints to the Edge Generative API layer."""
|
|
98
96
|
|
|
99
97
|
@dataclass
|
|
100
98
|
class TensorNames:
|
|
@@ -116,12 +114,13 @@ class ModelLoader:
|
|
|
116
114
|
lm_head: str = None
|
|
117
115
|
|
|
118
116
|
def __init__(self, file_name: str, names: TensorNames) -> None:
|
|
119
|
-
"""ModelLoader constructor.
|
|
120
|
-
|
|
117
|
+
"""ModelLoader constructor.
|
|
118
|
+
|
|
119
|
+
Can be used to load multiple models of the same type.
|
|
121
120
|
|
|
122
121
|
Args:
|
|
123
|
-
file_name (str): Path to the checkpoint. Can be a directory or an
|
|
124
|
-
|
|
122
|
+
file_name (str): Path to the checkpoint. Can be a directory or an exact
|
|
123
|
+
file.
|
|
125
124
|
names (TensorNames): An instance of `TensorNames` to determine mappings.
|
|
126
125
|
"""
|
|
127
126
|
self._file_name = file_name
|
|
@@ -140,7 +139,8 @@ class ModelLoader:
|
|
|
140
139
|
|
|
141
140
|
Returns:
|
|
142
141
|
missing_keys (List[str]): a list of str containing the missing keys.
|
|
143
|
-
unexpected_keys (List[str]): a list of str containing the unexpected
|
|
142
|
+
unexpected_keys (List[str]): a list of str containing the unexpected
|
|
143
|
+
keys.
|
|
144
144
|
|
|
145
145
|
Raises:
|
|
146
146
|
ValueError: If conversion results in unmapped tensors and strict mode is
|
|
@@ -208,7 +208,7 @@ class ModelLoader:
|
|
|
208
208
|
if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
|
|
209
209
|
return load_pytorch_statedict
|
|
210
210
|
|
|
211
|
-
raise ValueError(
|
|
211
|
+
raise ValueError("File format not supported.")
|
|
212
212
|
|
|
213
213
|
def _map_feedforward(
|
|
214
214
|
self,
|
|
@@ -346,9 +346,9 @@ class ModelLoader:
|
|
|
346
346
|
q_per_kv = (
|
|
347
347
|
config.attn_config.num_heads // config.attn_config.num_query_groups
|
|
348
348
|
)
|
|
349
|
-
qs = torch.split(q, config.head_dim * q_per_kv)
|
|
350
|
-
ks = torch.split(k, config.head_dim)
|
|
351
|
-
vs = torch.split(v, config.head_dim)
|
|
349
|
+
qs = torch.split(q, config.attn_config.head_dim * q_per_kv)
|
|
350
|
+
ks = torch.split(k, config.attn_config.head_dim)
|
|
351
|
+
vs = torch.split(v, config.attn_config.head_dim)
|
|
352
352
|
cycled = [t for group in zip(qs, ks, vs) for t in group]
|
|
353
353
|
return torch.cat(cycled)
|
|
354
354
|
else:
|