ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +30 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +330 -0
- ai_edge_torch/convert/converter.py +171 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
- ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +273 -0
- ai_edge_torch/convert/test/test_convert_composites.py +171 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/debug/__init__.py +16 -0
- ai_edge_torch/debug/culprit.py +423 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +288 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +103 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +135 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/test_model_conversion.py +201 -0
- ai_edge_torch/generative/test/test_quantize.py +109 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +290 -0
- ai_edge_torch/generative/utilities/t5_loader.py +467 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +134 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +85 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
import enum
|
|
18
|
+
from typing import Optional
|
|
19
|
+
|
|
20
|
+
from ai_edge_torch.generative.quantize import quant_attrs
|
|
21
|
+
from ai_edge_torch.generative.quantize import supported_schemes
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class LayerQuantRecipe:
  """Quantization recipe for a single Edge Generative API layer (e.g. Attention).

  Generic layer-scoped quantization recipe that specifies how this layer should
  be quantized by the Edge Generative API. This is applicable to layers
  implemented in ai_edge_torch/generative/layers/. Combinations of attributes
  that are not supported during runtime will be detected when .verify() is
  called.

  Attributes:
    activation_dtype: Desired data type of activation tensors.
    weight_dtype: Desired data type of weight tensors.
    mode: Type of quantization.
    algorithm: Algorithm for calculating quantization parameters.
    granularity: Granularity of quantization.
  """

  activation_dtype: quant_attrs.Dtype
  weight_dtype: quant_attrs.Dtype
  mode: quant_attrs.Mode
  algorithm: quant_attrs.Algorithm
  granularity: quant_attrs.Granularity

  def __str__(self):
    return (
        f'(a:{self.activation_dtype.name}, '
        f'w:{self.weight_dtype.name}, '
        f'{self.mode.name}, '
        f'{self.algorithm.name}, '
        f'{self.granularity.name})'
    )

  # Defined in the class body, so @dataclass keeps this instead of generating
  # its own field-by-field __repr__.
  __repr__ = __str__

  def verify(self):
    """Checks if all attributes configured are supported in runtime.

    The recipe is valid only if its attributes, taken in the order
    (activation_dtype, weight_dtype, mode, algorithm, granularity), exactly
    match one of the tuples returned by
    supported_schemes.get_supported_layer_schemes().

    Raises:
      ValueError: If any attributes are incompatible.
    """
    scheme = (
        self.activation_dtype,
        self.weight_dtype,
        self.mode,
        self.algorithm,
        self.granularity,
    )
    if scheme not in supported_schemes.get_supported_layer_schemes():
      # Name the rejected configuration and point at the function that
      # actually lists the supported combinations.
      raise ValueError(
          f'Unsupported LayerQuantRecipe configuration {self}. See '
          'supported_schemes.get_supported_layer_schemes() for the supported '
          'combinations.'
      )
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@dataclass
class TransformerQuantRecipe:
  """Quantization recipe for a model built from Edge Generative API layers.

  Attributes:
    default: Layer recipe used at the global scope of the model; None means no
      global recipe is configured.
  """

  default: Optional[LayerQuantRecipe] = None

  def __str__(self):
    return 'TransformerQuantRecipe(\n' f'Default: {self.default}\n' ')'

  __repr__ = __str__

  def verify(self):
    """Validates the configured recipe against what the runtime supports.

    A None default recipe is accepted without any check; otherwise the check
    is delegated to the layer recipe's own verify().

    Raises:
      ValueError: If the recipe configured is invalid or unsupported.
    """
    layer_recipe = self.default
    if layer_recipe is None:
      return
    layer_recipe.verify()
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Helper functions to construct custom quantization recipes.
|
|
17
|
+
|
|
18
|
+
These are intended for more advanced users who want to configure their own
|
|
19
|
+
quantization recipes. For pre-constructed recipes, use `quant_recipes.py` instead.
|
|
20
|
+
|
|
21
|
+
Typical usage example:
|
|
22
|
+
|
|
23
|
+
1. Applying a single layer recipe to the entire model
|
|
24
|
+
|
|
25
|
+
quant_recipe.TransformerQuantRecipe(
|
|
26
|
+
default=quant_recipe_utils.create_layer_quant_int8_dynamic()
|
|
27
|
+
)
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from ai_edge_torch.generative.quantize import quant_attrs
|
|
31
|
+
from ai_edge_torch.generative.quantize import quant_recipe
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def create_layer_quant_int8_dynamic() -> quant_recipe.LayerQuantRecipe:
  """Builds a layer recipe for INT8 dynamic-range quantization.

  Returns:
    A LayerQuantRecipe requesting FP32 activations, INT8 weights,
    DYNAMIC_RANGE mode, the MIN_MAX algorithm and CHANNELWISE granularity.
  """
  dtype = quant_attrs.Dtype
  recipe = quant_recipe.LayerQuantRecipe(
      activation_dtype=dtype.FP32,
      weight_dtype=dtype.INT8,
      mode=quant_attrs.Mode.DYNAMIC_RANGE,
      algorithm=quant_attrs.Algorithm.MIN_MAX,
      granularity=quant_attrs.Granularity.CHANNELWISE,
  )
  return recipe
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
  """Builds a layer recipe for FP16 weight-only quantization.

  Returns:
    A LayerQuantRecipe requesting FP32 activations, FP16 weights,
    WEIGHT_ONLY mode, the MIN_MAX algorithm and no (NONE) granularity.
  """
  dtype = quant_attrs.Dtype
  recipe = quant_recipe.LayerQuantRecipe(
      activation_dtype=dtype.FP32,
      weight_dtype=dtype.FP16,
      mode=quant_attrs.Mode.WEIGHT_ONLY,
      algorithm=quant_attrs.Algorithm.MIN_MAX,
      granularity=quant_attrs.Granularity.NONE,
  )
  return recipe
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Helper functions to create common and supported quantization recipes.
|
|
17
|
+
|
|
18
|
+
These recipes will work with models created with the Edge Generative API only.
|
|
19
|
+
They assume a Transformer architecture congruent with
|
|
20
|
+
ai_edge_torch/generative/layers/model_config.py:ModelConfig.
|
|
21
|
+
|
|
22
|
+
Typical usage example:
|
|
23
|
+
|
|
24
|
+
quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
|
|
25
|
+
edge_model = ai_edge_torch.convert(
|
|
26
|
+
model, (tokens, input_pos), quant_config=quant_config
|
|
27
|
+
)
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from ai_edge_torch.generative.quantize import quant_recipe
|
|
31
|
+
from ai_edge_torch.generative.quantize import quant_recipe_utils
|
|
32
|
+
from ai_edge_torch.quantize import quant_config
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def full_linear_int8_dynamic_recipe() -> quant_config.QuantConfig:
  """Returns a QuantConfig that uses INT8 dynamic-range quantization by default.

  The layer recipe from quant_recipe_utils.create_layer_quant_int8_dynamic()
  is installed as the transformer recipe's global default.
  """
  transformer_recipe = quant_recipe.TransformerQuantRecipe(
      default=quant_recipe_utils.create_layer_quant_int8_dynamic()
  )
  return quant_config.QuantConfig(transformer_recipe=transformer_recipe)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def full_fp16_recipe() -> quant_config.QuantConfig:
  """Returns a QuantConfig that uses FP16 weight-only quantization by default.

  The layer recipe from quant_recipe_utils.create_layer_quant_fp16() is
  installed as the transformer recipe's global default.
  """
  transformer_recipe = quant_recipe.TransformerQuantRecipe(
      default=quant_recipe_utils.create_layer_quant_fp16()
  )
  return quant_config.QuantConfig(transformer_recipe=transformer_recipe)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_supported_layer_schemes():
  """Lists the layer-scoped quantization schemes supported in runtime.

  Returns:
    List of tuple(activation_dtype, weight_dtype, mode, algorithm, granularity).
    A new list is built on every call.
  """
  # Kept as a function-scope import like before; presumably this avoids an
  # import cycle at module load time -- TODO confirm.
  from ai_edge_torch.generative.quantize import quant_attrs as qa

  int8_dynamic_range = (
      qa.Dtype.FP32,
      qa.Dtype.INT8,
      qa.Mode.DYNAMIC_RANGE,
      qa.Algorithm.MIN_MAX,
      qa.Granularity.CHANNELWISE,
  )
  fp16_weight_only = (
      qa.Dtype.FP32,
      qa.Dtype.FP16,
      qa.Mode.WEIGHT_ONLY,
      qa.Algorithm.MIN_MAX,
      qa.Granularity.NONE,
  )
  return [int8_dynamic_range, fp16_weight_only]
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
# Testing model conversion for a few gen-ai models.
|
|
16
|
+
import copy
|
|
17
|
+
import os
|
|
18
|
+
import tempfile
|
|
19
|
+
import unittest
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
import torch
|
|
23
|
+
|
|
24
|
+
import ai_edge_torch
|
|
25
|
+
from ai_edge_torch.generative.examples.gemma import gemma
|
|
26
|
+
from ai_edge_torch.generative.examples.phi2 import phi2
|
|
27
|
+
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
|
|
28
|
+
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
|
29
|
+
from ai_edge_torch.testing import model_coverage
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TestModelConversion(unittest.TestCase):
  """Unit tests that check for model conversion and correctness."""

  @staticmethod
  def _single_token_inputs(pos=10):
    """Returns (tokens, input_pos) for a single token [[1]] at position `pos`."""
    tokens = torch.tensor([[1]], dtype=torch.long)
    input_pos = torch.tensor([pos], dtype=torch.int64)
    return tokens, input_pos

  @staticmethod
  def _prefill_inputs(seq_len=10):
    """Returns (tokens, input_pos): prompt [1, 2, 3, 4] zero-padded to seq_len."""
    prompt = torch.from_numpy(np.array([1, 2, 3, 4]))
    tokens = torch.full((1, seq_len), 0, dtype=torch.long, device="cpu")
    tokens[0, : len(prompt)] = prompt
    input_pos = torch.arange(0, seq_len)
    return tokens, input_pos

  def _assert_tflite_matches_torch(self, edge_model, torch_model, args, **kwargs):
    """Asserts model_coverage.compare_tflite_torch succeeds on one input set."""
    self.assertTrue(
        model_coverage.compare_tflite_torch(
            edge_model, torch_model, args, num_valid_inputs=1, **kwargs
        )
    )

  def test_toy_model_with_kv_cache(self):
    self.skipTest("b/338288901")
    config = toy_model_with_kv_cache.get_model_config()
    model = toy_model_with_kv_cache.ToyModelWithKV(config)
    args = self._single_token_inputs()

    edge_model = ai_edge_torch.convert(model, args)

    self._assert_tflite_matches_torch(
        edge_model, model, args, atol=1e-5, rtol=1e-5
    )

  def test_toy_model_with_kv_cache_with_hlfb(self):
    self.skipTest("b/338288901")
    config = toy_model_with_kv_cache.get_model_config()
    config.enable_hlfb = True
    model = toy_model_with_kv_cache.ToyModelWithKV(config)
    args = self._single_token_inputs()

    edge_model = ai_edge_torch.convert(model, args)

    self._assert_tflite_matches_torch(
        edge_model, model, args, atol=1e-5, rtol=1e-5
    )

  def test_tiny_llama(self):
    self.skipTest("b/338288901")
    config = tiny_llama.get_fake_model_config_for_test()
    model = tiny_llama.TinyLLamma(config)
    args = self._prefill_inputs()

    edge_model = ai_edge_torch.convert(model, args)

    self._assert_tflite_matches_torch(
        edge_model, model, args, atol=1e-5, rtol=1e-5
    )

  def test_tiny_llama_multisig(self):
    self.skipTest("b/338288901")
    config = tiny_llama.get_fake_model_config_for_test()
    model = tiny_llama.TinyLLamma(config)

    prefill_args = self._prefill_inputs(seq_len=10)
    decode_args = self._single_token_inputs(pos=5)

    edge_model = (
        ai_edge_torch.signature("prefill", model, prefill_args)
        .signature("decode", model, decode_args)
        .convert()
    )

    # In PyTorch the KV cache is internal model state shared by prefill and
    # decode, while the two TFLite signatures currently cannot share a KV
    # cache (whatever prefill writes is not readable by decode). Decode is
    # therefore compared against a copy of the PyTorch model taken before the
    # prefill comparison below, so it is unaffected by what prefill writes.
    decode_reference = copy.deepcopy(model)

    self._assert_tflite_matches_torch(
        edge_model, model, prefill_args, signature_name="prefill"
    )
    self._assert_tflite_matches_torch(
        edge_model, decode_reference, decode_args, signature_name="decode"
    )

  def test_gemma(self):
    self.skipTest("b/338288901")
    config = gemma.get_fake_model_config_2b_for_test()
    model = gemma.Gemma(config)
    args = self._prefill_inputs()

    edge_model = ai_edge_torch.convert(model, args)

    # TODO(talumbau, haoliang): debug numerical diff.
    self._assert_tflite_matches_torch(
        edge_model, model, args, atol=1e-2, rtol=1e-5
    )

  def test_phi2(self):
    self.skipTest("b/338288901")
    config = phi2.get_fake_model_config_for_test()
    model = phi2.Phi2(config)
    args = self._prefill_inputs()

    edge_model = ai_edge_torch.convert(model, args)

    self._assert_tflite_matches_torch(
        edge_model, model, args, atol=1e-5, rtol=1e-5
    )


if __name__ == "__main__":
  unittest.main()
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import unittest
|
|
17
|
+
|
|
18
|
+
from parameterized import parameterized
|
|
19
|
+
import torch
|
|
20
|
+
|
|
21
|
+
import ai_edge_torch
|
|
22
|
+
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
|
|
23
|
+
from ai_edge_torch.generative.quantize import quant_recipe
|
|
24
|
+
from ai_edge_torch.generative.quantize import quant_recipes
|
|
25
|
+
from ai_edge_torch.generative.quantize.quant_attrs import Algorithm
|
|
26
|
+
from ai_edge_torch.generative.quantize.quant_attrs import Dtype
|
|
27
|
+
from ai_edge_torch.generative.quantize.quant_attrs import Granularity
|
|
28
|
+
from ai_edge_torch.generative.quantize.quant_attrs import Mode
|
|
29
|
+
from ai_edge_torch.testing import model_coverage
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TestVerifyRecipes(unittest.TestCase):
  """Unit tests that check for model quantization recipes."""

  @parameterized.expand(
      [
          # Dtype/mode combinations that match no supported scheme; algorithm
          # and granularity take the defaults from the test signature.
          (Dtype.FP32, Dtype.FP32, Mode.DYNAMIC_RANGE),
          (Dtype.INT8, Dtype.INT8, Mode.DYNAMIC_RANGE),
          (Dtype.INT8, Dtype.FP16, Mode.DYNAMIC_RANGE),
          (Dtype.FP16, Dtype.INT8, Mode.DYNAMIC_RANGE),
          (Dtype.FP32, Dtype.FP32, Mode.WEIGHT_ONLY),
          (Dtype.INT8, Dtype.INT8, Mode.WEIGHT_ONLY),
          (Dtype.FP16, Dtype.INT8, Mode.WEIGHT_ONLY),
          (Dtype.INT8, Dtype.FP16, Mode.WEIGHT_ONLY),
          (Dtype.FP16, Dtype.FP16, Mode.WEIGHT_ONLY),
          # Variants of the two schemes listed by
          # supported_schemes.get_supported_layer_schemes() with exactly one
          # attribute changed, so a mode mismatch or a granularity mismatch
          # alone must be rejected.
          (
              Dtype.FP32,
              Dtype.INT8,
              Mode.WEIGHT_ONLY,
              Algorithm.MIN_MAX,
              Granularity.CHANNELWISE,
          ),
          (
              Dtype.FP32,
              Dtype.INT8,
              Mode.DYNAMIC_RANGE,
              Algorithm.MIN_MAX,
              Granularity.NONE,
          ),
          (
              Dtype.FP32,
              Dtype.FP16,
              Mode.DYNAMIC_RANGE,
              Algorithm.MIN_MAX,
              Granularity.NONE,
          ),
          (
              Dtype.FP32,
              Dtype.FP16,
              Mode.WEIGHT_ONLY,
              Algorithm.MIN_MAX,
              Granularity.CHANNELWISE,
          ),
      ]
  )
  def test_verify_invalid_recipes(
      self,
      activation,
      weight,
      mode,
      algo=Algorithm.MIN_MAX,
      granularity=Granularity.CHANNELWISE,
  ):
    """verify() must raise ValueError for every unsupported combination."""
    with self.assertRaises(ValueError):
      quant_recipe.LayerQuantRecipe(
          activation, weight, mode, algo, granularity
      ).verify()

  @parameterized.expand(
      [
          (Dtype.FP32, Dtype.INT8, Mode.DYNAMIC_RANGE, Granularity.CHANNELWISE),
          (Dtype.FP32, Dtype.FP16, Mode.WEIGHT_ONLY, Granularity.NONE),
      ]
  )
  def test_verify_valid_recipes(
      self,
      activation,
      weight,
      mode,
      granularity,
      algo=Algorithm.MIN_MAX,
  ):
    """verify() must accept every supported combination without raising."""
    quant_recipe.LayerQuantRecipe(activation, weight, mode, algo, granularity).verify()
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class TestQuantizeConvert(unittest.TestCase):
  """Test conversion with quantization."""

  def test_quantize_convert_toy(self):
    self.skipTest("b/338288901")
    config = toy_model_with_kv_cache.get_model_config()
    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
    tokens = torch.tensor([[1]], dtype=torch.long)
    input_pos = torch.tensor([10], dtype=torch.int64)
    sample_args = (tokens, input_pos)

    quantized_model = ai_edge_torch.convert(
        pytorch_model, sample_args, quant_config=quant_recipes.full_fp16_recipe()
    )
    float_model = ai_edge_torch.convert(pytorch_model, sample_args)

    # The quantized model's serialized `_tflite_model` should be smaller than
    # the float conversion of the same PyTorch model.
    quantized_size = len(quantized_model._tflite_model)
    float_size = len(float_model._tflite_model)
    self.assertLess(quantized_size, float_size)

    matches = model_coverage.compare_tflite_torch(
        quantized_model,
        pytorch_model,
        sample_args,
        num_valid_inputs=1,
        atol=1e-3,
        rtol=1e-3,
    )
    self.assertTrue(matches)


if __name__ == "__main__":
  unittest.main()
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
# This module contains common utility functions.
|