ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +31 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +400 -0
- ai_edge_torch/convert/converter.py +202 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +311 -0
- ai_edge_torch/convert/test/test_convert_composites.py +192 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/convert/to_channel_last_io.py +85 -0
- ai_edge_torch/debug/__init__.py +17 -0
- ai_edge_torch/debug/culprit.py +464 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/test/test_search_model.py +50 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/fx_passes/__init__.py +31 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +354 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +131 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +158 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
- ai_edge_torch/generative/layers/unet/builder.py +47 -0
- ai_edge_torch/generative/layers/unet/model_config.py +269 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/loader_test.py +80 -0
- ai_edge_torch/generative/test/test_model_conversion.py +235 -0
- ai_edge_torch/generative/test/test_quantize.py +162 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +328 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
- ai_edge_torch/generative/utilities/t5_loader.py +483 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +142 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +81 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from typing import Optional, Union
|
|
18
|
+
|
|
19
|
+
from ai_edge_torch.generative.quantize import quant_attrs
|
|
20
|
+
from ai_edge_torch.generative.quantize import supported_schemes
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class LayerQuantRecipe:
  """Quantization recipe for a single Edge Generative API layer (e.g. Attention).

  Generic layer-scoped quantization recipe that specifies how this layer should
  be quantized by the Edge Generative API. This is applicable to layers
  implemented in ai_edge_torch/generative/layers/. Combinations of attributes
  that are not supported during runtime will be detected when .verify() is
  called.

  Attributes:
    activation_dtype: Desired data type of activation tensors.
    weight_dtype: Desired data type of weight tensors.
    mode: Type of quantization.
    algorithm: Algorithm for calculating quantization parameters.
    granularity: Granularity of quantization.
  """

  activation_dtype: quant_attrs.Dtype
  weight_dtype: quant_attrs.Dtype
  mode: quant_attrs.Mode
  algorithm: quant_attrs.Algorithm
  granularity: quant_attrs.Granularity

  def __str__(self):
    """Returns a compact one-line summary built from the attribute names."""
    return (
        f'(a:{self.activation_dtype.name}, '
        f'w:{self.weight_dtype.name}, '
        f'{self.mode.name}, '
        f'{self.algorithm.name}, '
        f'{self.granularity.name})'
    )

  # repr() intentionally shares the compact form produced by str().
  __repr__ = __str__

  def verify(self):
    """Checks if all attributes configured are supported in runtime.

    The recipe is valid when its five attributes, taken in the order
    (activation_dtype, weight_dtype, mode, algorithm, granularity), equal one
    of the entries returned by supported_schemes.get_supported_layer_schemes().

    Raises:
      ValueError: If any attributes are incompatible.
    """
    scheme = (
        self.activation_dtype,
        self.weight_dtype,
        self.mode,
        self.algorithm,
        self.granularity,
    )
    # Only the first five positions of each supported entry take part in the
    # comparison.
    is_valid = any(
        tuple(supported[:5]) == scheme
        for supported in supported_schemes.get_supported_layer_schemes()
    )
    if not is_valid:
      # Point the user at the function that actually lists the supported
      # combinations.
      raise ValueError(
          'Unsupported LayerQuantRecipe configuration. See'
          ' supported_schemes.get_supported_layer_schemes()'
      )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass
class GenerativeQuantRecipe:
  """Quantization recipe for a model composed of the Edge Generative API layers.

  Some layers can be specified with different `LayerQuantRecipe` for each block
  by providing a dictionary keyed by the TransformerBlock index, e.g. attention
  and feedforward. For example,

  ```
  default = LayerQuantRecipeA
  attention = { 2: LayerQuantRecipeB }
  feedforward = { 3: LayerQuantRecipeC }
  ```

  will apply LayerQuantRecipeA to the entire model, overridden by
  LayerQuantRecipeB for the TransformerBlock[2].attention layer and
  LayerQuantRecipeC for the TransformerBlock[3].feedforward layer. Any config
  with invalid indices will be ignored.

  Attributes:
    default: The quantization recipe for global scope of the model.
    embedding: Recipe for the embedding table.
    attention: Recipe for the attention blocks. Either a single
      LayerQuantRecipe, or a dictionary of LayerQuantRecipe keyed by the
      TransformerBlock index.
    feedforward: Recipe for the feedforward layers. Either a single
      LayerQuantRecipe, or a dictionary of LayerQuantRecipe keyed by the
      TransformerBlock index.
  """

  default: Optional[LayerQuantRecipe] = None
  embedding: Optional[LayerQuantRecipe] = None
  attention: Union[
      Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
  ] = None
  feedforward: Union[
      Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
  ] = None

  def __str__(self):
    """Returns a multi-line summary with one line per recipe field."""
    # NOTE(review): the leading whitespace of the field lines inside this
    # literal could not be recovered from the diff view; confirm it against the
    # packaged file.
    return f"""GenerativeQuantRecipe(
  Default: {self.default}
  Embedding: {self.embedding}
  Attention: {self.attention}
  Feedforward: {self.feedforward}
)"""

  # repr() shares the multi-line form produced by str().
  __repr__ = __str__

  def verify(self):
    """Checks if the recipe configured can be supported in runtime.

    Every configured LayerQuantRecipe is verified in the order default,
    embedding, attention, feedforward. Fields left as None are skipped.

    Raises:
      ValueError: If the recipe configured is invalid or unsupported.
    """
    # Fields that only ever hold a single layer recipe.
    for layer_recipe in (self.default, self.embedding):
      if layer_recipe is not None:
        layer_recipe.verify()

    # Fields that hold either a single recipe or a per-block dictionary.
    for per_block in (self.attention, self.feedforward):
      if per_block is None:
        continue
      if isinstance(per_block, dict):
        for recipe in per_block.values():
          recipe.verify()
      else:
        per_block.verify()
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Helper functions to construct custom quantization recipes.
|
|
17
|
+
|
|
18
|
+
These are intended for more advanced users who want to configure their own
|
|
19
|
+
quantization recipes. For pre-constructed recipes, use `quant_recipes.py` instead.
|
|
20
|
+
|
|
21
|
+
Typical usage example:
|
|
22
|
+
|
|
23
|
+
1. Applying a single layer recipe to the entire model
|
|
24
|
+
|
|
25
|
+
quant_recipe.GenerativeQuantRecipe(
|
|
26
|
+
default=quant_recipe_utils.create_layer_quant_int8_dynamic()
|
|
27
|
+
)
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from ai_edge_torch.generative.quantize import quant_attrs
|
|
31
|
+
from ai_edge_torch.generative.quantize import quant_recipe
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def create_layer_quant_int8_dynamic() -> quant_recipe.LayerQuantRecipe:
  """Builds the INT8 dynamic-range layer recipe.

  Returns:
    A LayerQuantRecipe with FP32 activations, INT8 weights, DYNAMIC_RANGE mode,
    the MIN_MAX algorithm and CHANNELWISE granularity.
  """
  attrs = dict(
      activation_dtype=quant_attrs.Dtype.FP32,
      weight_dtype=quant_attrs.Dtype.INT8,
      mode=quant_attrs.Mode.DYNAMIC_RANGE,
      algorithm=quant_attrs.Algorithm.MIN_MAX,
      granularity=quant_attrs.Granularity.CHANNELWISE,
  )
  return quant_recipe.LayerQuantRecipe(**attrs)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
  """Builds the FP16 weight-only layer recipe.

  Returns:
    A LayerQuantRecipe with FP32 activations, FP16 weights, WEIGHT_ONLY mode,
    the FLOAT_CAST algorithm and NONE granularity.
  """
  attrs = dict(
      activation_dtype=quant_attrs.Dtype.FP32,
      weight_dtype=quant_attrs.Dtype.FP16,
      mode=quant_attrs.Mode.WEIGHT_ONLY,
      algorithm=quant_attrs.Algorithm.FLOAT_CAST,
      granularity=quant_attrs.Granularity.NONE,
  )
  return quant_recipe.LayerQuantRecipe(**attrs)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Helper functions to create common and supported quantization recipes.
|
|
17
|
+
|
|
18
|
+
These recipes will work with models created with the Edge Generative API only.
|
|
19
|
+
Assume Transformer architecture congruent with
|
|
20
|
+
ai_edge_torch/generative/layers/model_config.py:ModelConfig.
|
|
21
|
+
|
|
22
|
+
Typical usage example:
|
|
23
|
+
|
|
24
|
+
quant_config = quant_recipes.full_int8_dynamic_recipe()
|
|
25
|
+
edge_model = ai_edge_torch.convert(
|
|
26
|
+
model, (tokens, input_pos), quant_config=quant_config
|
|
27
|
+
)
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from ai_edge_torch.generative.quantize import quant_recipe
|
|
31
|
+
from ai_edge_torch.generative.quantize import quant_recipe_utils
|
|
32
|
+
from ai_edge_torch.quantize import quant_config
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def full_int8_dynamic_recipe() -> quant_config.QuantConfig:
  """Creates a QuantConfig that uses INT8 dynamic quantization by default.

  Returns:
    A QuantConfig whose generative recipe sets the layer recipe from
    quant_recipe_utils.create_layer_quant_int8_dynamic() as `default`.
  """
  layer_recipe = quant_recipe_utils.create_layer_quant_int8_dynamic()
  generative_recipe = quant_recipe.GenerativeQuantRecipe(default=layer_recipe)
  return quant_config.QuantConfig(generative_recipe=generative_recipe)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def full_fp16_recipe() -> quant_config.QuantConfig:
  """Creates a QuantConfig that uses FP16 weight quantization by default.

  Returns:
    A QuantConfig whose generative recipe sets the layer recipe from
    quant_recipe_utils.create_layer_quant_fp16() as `default`.
  """
  layer_recipe = quant_recipe_utils.create_layer_quant_fp16()
  generative_recipe = quant_recipe.GenerativeQuantRecipe(default=layer_recipe)
  return quant_config.QuantConfig(generative_recipe=generative_recipe)
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_supported_layer_schemes():
  """Lists the layer-scoped quantization schemes supported in runtime.

  Returns:
    List of tuple(activation_dtype, weight_dtype, mode, algorithm, granularity).
  """
  # The attribute enums are imported at call time, inside the function.
  from ai_edge_torch.generative.quantize import quant_attrs as _qa

  dt, md, alg, gr = _qa.Dtype, _qa.Mode, _qa.Algorithm, _qa.Granularity

  schemes = []
  schemes.append((dt.FP32, dt.INT8, md.DYNAMIC_RANGE, alg.MIN_MAX, gr.CHANNELWISE))
  schemes.append((dt.FP32, dt.INT8, md.WEIGHT_ONLY, alg.MIN_MAX, gr.CHANNELWISE))
  schemes.append((dt.FP32, dt.FP16, md.WEIGHT_ONLY, alg.FLOAT_CAST, gr.NONE))
  return schemes
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
# Testing weight loader utilities.
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
import tempfile
|
|
19
|
+
import unittest
|
|
20
|
+
|
|
21
|
+
import safetensors.torch
|
|
22
|
+
import torch
|
|
23
|
+
|
|
24
|
+
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
|
25
|
+
from ai_edge_torch.generative.utilities import loader as loading_utils
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class TestLoader(unittest.TestCase):
  """Unit tests that check weight loader."""

  def test_load_safetensors(self):
    """Saves two tensors as safetensors and checks both keys load back."""
    with tempfile.TemporaryDirectory() as temp_dir:
      file_path = os.path.join(temp_dir, "test.safetensors")
      safetensors.torch.save_file(
          {"weight": torch.randn(20, 10), "bias": torch.randn(20)}, file_path
      )

      loaded_tensors = loading_utils.load_safetensors(file_path)
      for key in ("weight", "bias"):
        self.assertIn(key, loaded_tensors)

  def test_load_statedict(self):
    """Saves a Linear layer's state dict and checks both keys load back."""
    with tempfile.TemporaryDirectory() as temp_dir:
      file_path = os.path.join(temp_dir, "test.pt")
      torch.save(torch.nn.Linear(10, 5).state_dict(), file_path)

      loaded_tensors = loading_utils.load_pytorch_statedict(file_path)
      for key in ("weight", "bias"):
        self.assertIn(key, loaded_tensors)

  def test_model_loader(self):
    """Loads a one-layer TinyLlama checkpoint of random weights strictly."""
    # Tensor name -> shape of the random checkpoint written for this test.
    shapes = {
        "lm_head.weight": (32000, 2048),
        "model.embed_tokens.weight": (32000, 2048),
        "model.layers.0.input_layernorm.weight": (2048,),
        "model.layers.0.mlp.down_proj.weight": (2048, 5632),
        "model.layers.0.mlp.gate_proj.weight": (5632, 2048),
        "model.layers.0.mlp.up_proj.weight": (5632, 2048),
        "model.layers.0.post_attention_layernorm.weight": (2048,),
        "model.layers.0.self_attn.k_proj.weight": (256, 2048),
        "model.layers.0.self_attn.o_proj.weight": (2048, 2048),
        "model.layers.0.self_attn.q_proj.weight": (2048, 2048),
        "model.layers.0.self_attn.v_proj.weight": (256, 2048),
        "model.norm.weight": (2048,),
    }
    with tempfile.TemporaryDirectory() as temp_dir:
      file_path = os.path.join(temp_dir, "test.safetensors")
      test_weights = {name: torch.randn(shape) for name, shape in shapes.items()}
      safetensors.torch.save_file(test_weights, file_path)

      # The checkpoint above only provides weights for layer 0.
      cfg = tiny_llama.get_model_config()
      cfg.num_layers = 1
      model = tiny_llama.TinyLLamma(cfg)

      loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
      # If this returns successfully, all the tensors were initialized.
      loader.load(model, strict=True)


if __name__ == "__main__":
  unittest.main()
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
# Testing model conversion for a few gen-ai models.
|
|
16
|
+
import copy
|
|
17
|
+
import os
|
|
18
|
+
import tempfile
|
|
19
|
+
import unittest
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
import torch
|
|
23
|
+
|
|
24
|
+
import ai_edge_torch
|
|
25
|
+
from ai_edge_torch.generative.examples.gemma import gemma
|
|
26
|
+
from ai_edge_torch.generative.examples.phi2 import phi2
|
|
27
|
+
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
|
|
28
|
+
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
|
29
|
+
from ai_edge_torch.testing import model_coverage
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TestModelConversion(unittest.TestCase):
  """Unit tests that check for model conversion and correctness."""

  def _assert_outputs_match(self, edge_model, torch_model, inputs, **kwargs):
    """Asserts that compare_tflite_torch() passes for one valid input.

    Args:
      edge_model: The converted model.
      torch_model: The PyTorch model to compare against.
      inputs: Tuple of input tensors passed to the comparison.
      **kwargs: Extra keyword arguments forwarded to compare_tflite_torch().
    """
    self.assertTrue(
        model_coverage.compare_tflite_torch(
            edge_model, torch_model, inputs, num_valid_inputs=1, **kwargs
        )
    )

  @staticmethod
  def _padded_prompt_inputs():
    """Returns (tokens, input_pos) for a 4-token prompt padded to length 10."""
    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
    tokens[0, :4] = idx
    input_pos = torch.arange(0, 10)
    return tokens, input_pos

  def test_toy_model_with_kv_cache(self):
    """Converts the toy KV-cache model with a single-token input."""
    config = toy_model_with_kv_cache.get_model_config()
    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
    idx = torch.tensor([[1]], dtype=torch.long)
    input_pos = torch.tensor([10], dtype=torch.int64)
    inputs = (idx, input_pos)

    edge_model = ai_edge_torch.convert(pytorch_model, inputs)

    # TODO(b/338288901): re-enable test to check output tensors.
    skip_output_check = True
    if not skip_output_check:
      self._assert_outputs_match(
          edge_model, pytorch_model, inputs, atol=1e-5, rtol=1e-5
      )

  def test_toy_model_with_multi_batches(self):
    """Converts the toy KV-cache model configured with batch_size = 2."""
    config = toy_model_with_kv_cache.get_model_config()
    config.batch_size = 2
    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
    idx = torch.tensor([[1], [2]], dtype=torch.long)
    input_pos = torch.tensor([10], dtype=torch.int64)
    inputs = (idx, input_pos)

    edge_model = ai_edge_torch.convert(pytorch_model, inputs)

    # TODO(b/338288901): re-enable test to check output tensors.
    skip_output_check = True
    if not skip_output_check:
      self._assert_outputs_match(
          edge_model, pytorch_model, inputs, atol=1e-5, rtol=1e-5
      )

  def test_toy_model_with_kv_cache_with_hlfb(self):
    """Converts the toy KV-cache model with enable_hlfb switched on."""
    config = toy_model_with_kv_cache.get_model_config()
    config.enable_hlfb = True
    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
    idx = torch.tensor([[1]], dtype=torch.long)
    input_pos = torch.tensor([10], dtype=torch.int64)
    inputs = (idx, input_pos)

    edge_model = ai_edge_torch.convert(pytorch_model, inputs)

    # TODO(b/338288901): re-enable test to check output tensors.
    skip_output_check = True
    if not skip_output_check:
      self._assert_outputs_match(
          edge_model, pytorch_model, inputs, atol=1e-5, rtol=1e-5
      )

  def test_tiny_llama(self):
    """Converts the fake-config TinyLlama model (currently skipped)."""
    self.skipTest("b/338288901")
    config = tiny_llama.get_fake_model_config_for_test()
    pytorch_model = tiny_llama.TinyLLamma(config)

    inputs = self._padded_prompt_inputs()

    edge_model = ai_edge_torch.convert(pytorch_model, inputs)

    # TODO(b/338288901): re-enable test to check output tensors.
    skip_output_check = True
    if not skip_output_check:
      self._assert_outputs_match(
          edge_model, pytorch_model, inputs, atol=1e-5, rtol=1e-5
      )

  def test_tiny_llama_multisig(self):
    """Converts TinyLlama with separate prefill and decode signatures."""
    config = tiny_llama.get_fake_model_config_for_test()
    pytorch_model = tiny_llama.TinyLLamma(config)

    # prefill
    seq_len = 10
    prefill_tokens = torch.full((1, seq_len), 0, dtype=torch.long, device="cpu")
    prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
    prefill_tokens[0, : len(prompt_token)] = prompt_token
    prefill_input_pos = torch.arange(0, seq_len)
    prefill_inputs = (prefill_tokens, prefill_input_pos)

    # decode
    decode_token = torch.tensor([[1]], dtype=torch.long)
    decode_input_pos = torch.tensor([5], dtype=torch.int64)
    decode_inputs = (decode_token, decode_input_pos)

    converter = ai_edge_torch.signature("prefill", pytorch_model, prefill_inputs)
    converter = converter.signature("decode", pytorch_model, decode_inputs)
    edge_model = converter.convert()

    # TODO(b/338288901): re-enable test to check output tensors.
    skip_output_check = True
    if not skip_output_check:
      # The decode comparison runs against a copy taken before the prefill
      # comparison touches the original model.
      copied_model = copy.deepcopy(pytorch_model)

      self._assert_outputs_match(
          edge_model, pytorch_model, prefill_inputs, signature_name="prefill"
      )
      self._assert_outputs_match(
          edge_model, copied_model, decode_inputs, signature_name="decode"
      )

  def test_gemma(self):
    """Converts the fake-config Gemma 2B model (currently skipped)."""
    self.skipTest("b/338288901")
    config = gemma.get_fake_model_config_2b_for_test()
    model = gemma.Gemma(config)

    inputs = self._padded_prompt_inputs()

    edge_model = ai_edge_torch.convert(model, inputs)

    # TODO(b/338288901): re-enable test to check output tensors.
    skip_output_check = True
    if not skip_output_check:
      # TODO(talumbau, haoliang): debug numerical diff.
      self._assert_outputs_match(
          edge_model, model, inputs, atol=1e-2, rtol=1e-5
      )

  def test_phi2(self):
    """Converts the fake-config Phi-2 model (currently skipped)."""
    self.skipTest("b/338288901")
    config = phi2.get_fake_model_config_for_test()
    pytorch_model = phi2.Phi2(config)

    inputs = self._padded_prompt_inputs()

    edge_model = ai_edge_torch.convert(pytorch_model, inputs)

    # TODO(b/338288901): re-enable test to check output tensors.
    skip_output_check = True
    if not skip_output_check:
      self._assert_outputs_match(
          edge_model, pytorch_model, inputs, atol=1e-5, rtol=1e-5
      )


if __name__ == "__main__":
  unittest.main()
|