ai-edge-torch-nightly 0.5.0.dev20250511__py3-none-any.whl → 0.5.0.dev20250513__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +3 -1
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/openelm/openelm.py +1 -1
- ai_edge_torch/generative/quantize/quant_recipes.py +17 -4
- ai_edge_torch/generative/utilities/converter.py +12 -7
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250511.dist-info → ai_edge_torch_nightly-0.5.0.dev20250513.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250511.dist-info → ai_edge_torch_nightly-0.5.0.dev20250513.dist-info}/RECORD +12 -12
- {ai_edge_torch_nightly-0.5.0.dev20250511.dist-info → ai_edge_torch_nightly-0.5.0.dev20250513.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250511.dist-info → ai_edge_torch_nightly-0.5.0.dev20250513.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250511.dist-info → ai_edge_torch_nightly-0.5.0.dev20250513.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py
@@ -49,7 +49,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=2048,
   )
-  norm_config = cfg.NormalizationConfig(
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("amd-llama-135m")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("openelm")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )
 
 
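In both conversion scripts the module-level ExportConfig alias is removed and the export configuration is now built from the command-line flags via export_config.get_from_flags(). Below is a minimal sketch of the resulting script shape; the converter.convert_to_tflite entry point, the build_model helper, and the checkpoint_path/output_path flag names are assumptions based on the surrounding example scripts, not part of this diff.

from absl import app
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities import export_config

flags = converter.define_conversion_flags("amd-llama-135m")


def main(_):
  # Hypothetical model construction; the real script reads its own flags.
  pytorch_model = amd_llama_135m.build_model(flags.FLAGS.checkpoint_path)
  converter.convert_to_tflite(  # assumed entry point; only its kwargs appear in the hunks
      pytorch_model,
      output_path=flags.FLAGS.output_path,  # assumed flag name
      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
      quantize=flags.FLAGS.quantize,
      lora_ranks=flags.FLAGS.lora_ranks,
      export_config=export_config.get_from_flags(),  # replaces the removed ExportConfig alias
  )


if __name__ == "__main__":
  app.run(main)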
ai_edge_torch/generative/examples/openelm/openelm.py
@@ -51,7 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for an OpenELM model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=True
   )
   num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
   num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
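Both example model configs now enable high-level function boundaries (HLFB) on their RMSNorm layers. A minimal sketch of the updated constructions, using only fields visible in these hunks:

import ai_edge_torch.generative.layers.model_config as cfg

# AMD-Llama-135M: RMSNorm with HLFB enabled.
norm_config = cfg.NormalizationConfig(
    type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
)

# OpenELM: the same, keeping its explicit epsilon.
openelm_norm_config = cfg.NormalizationConfig(
    type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=True
)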
ai_edge_torch/generative/quantize/quant_recipes.py
@@ -27,37 +27,49 @@ Typical usage example:
   )
 """
 
+from typing import Optional
+from ai_edge_torch.generative.layers import model_config
 from ai_edge_torch.generative.quantize import quant_recipe
 from ai_edge_torch.generative.quantize import quant_recipe_utils
 from ai_edge_torch.quantize import quant_config
 
 
-def full_int8_dynamic_recipe(
+def full_int8_dynamic_recipe(
+    mcfg: Optional[model_config.ModelConfig] = None,
+) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
           default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+          _model_config=mcfg,
       )
   )
 
 
-def full_int8_weight_only_recipe(
+def full_int8_weight_only_recipe(
+    mcfg: Optional[model_config.ModelConfig] = None,
+) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
           default=quant_recipe_utils.create_layer_quant_int8_weight_only(),
+          _model_config=mcfg,
       )
   )
 
 
-def full_fp16_recipe(
+def full_fp16_recipe(
+    mcfg: Optional[model_config.ModelConfig] = None,
+) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
-          default=quant_recipe_utils.create_layer_quant_fp16()
+          default=quant_recipe_utils.create_layer_quant_fp16(),
+          _model_config=mcfg,
       )
   )
 
 
 def all_supported_int4_dynamic_block_recipe(
     block_size: int,
+    mcfg: Optional[model_config.ModelConfig] = None,
 ) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
@@ -65,5 +77,6 @@ def all_supported_int4_dynamic_block_recipe(
               block_size
           ),
           embedding=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+          _model_config=mcfg,
       )
   )
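Every recipe helper now accepts an optional ModelConfig and stores it on the GenerativeQuantRecipe as _model_config, so callers no longer have to patch it on afterwards. A short usage sketch; the helper names and the mcfg keyword come from the hunks above, while obtaining the config from amd_llama_135m.get_model_config() is just one convenient, assumed source.

from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
from ai_edge_torch.generative.quantize import quant_recipes

# get_model_config() is the function touched in the first hunk of this diff.
model_cfg = amd_llama_135m.get_model_config()

# New: the model config rides along inside the recipe.
quant_cfg = quant_recipes.full_int8_dynamic_recipe(mcfg=model_cfg)

# Omitting mcfg keeps the previous behavior (no _model_config attached).
fp16_cfg = quant_recipes.full_fp16_recipe()

# The int4 block recipe keeps its positional block size and gains the same keyword.
int4_cfg = quant_recipes.all_supported_int4_dynamic_block_recipe(32, mcfg=model_cfg)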
ai_edge_torch/generative/utilities/converter.py
@@ -26,6 +26,7 @@ from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.quantize import quant_recipes
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 
 ExportConfig = export_config.ExportConfig
@@ -123,7 +124,8 @@ def define_conversion_flags(
 
 def get_quant_recipe_from_flag(
     quantize: str,
-
+    model_config: cfg.ModelConfig,
+) -> Optional[qcfg.QuantConfig]:
   """Processes the quantization flag and returns the corresponding recipe.
 
   Args:
@@ -139,15 +141,19 @@ def get_quant_recipe_from_flag(
     case QuantizationName.NONE:
       return None
     case QuantizationName.DYNAMIC_INT8:
-      return quant_recipes.full_int8_dynamic_recipe()
+      return quant_recipes.full_int8_dynamic_recipe(mcfg=model_config)
     case QuantizationName.WEIGHT_ONLY_INT8:
-      return quant_recipes.full_int8_weight_only_recipe()
+      return quant_recipes.full_int8_weight_only_recipe(mcfg=model_config)
     case QuantizationName.FP16:
       return quant_recipes.full_fp16_recipe()
     case QuantizationName.DYNAMIC_INT4_BLOCK32:
-      return quant_recipes.
+      return quant_recipes.all_supported_int4_dynamic_block_recipe(
+          32, mcfg=model_config
+      )
    case QuantizationName.DYNAMIC_INT4_BLOCK128:
-      return quant_recipes.
+      return quant_recipes.all_supported_int4_dynamic_block_recipe(
+          128, mcfg=model_config
+      )
     case _:
       raise ValueError(f'Unsupported quantization flag: {quantize}')
 
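get_quant_recipe_from_flag now takes the model config as a second argument and forwards it to the recipe helpers; this is what lets _export_helper below drop the manual quant_config._model_config assignment. A hedged sketch of a direct call; the "dynamic_int8" string is a placeholder for whatever value QuantizationName.DYNAMIC_INT8 corresponds to in the conversion flags.

from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
from ai_edge_torch.generative.utilities import converter

model_cfg = amd_llama_135m.get_model_config()

# Placeholder flag value; the accepted strings are defined by define_conversion_flags().
quant_cfg = converter.get_quant_recipe_from_flag("dynamic_int8", model_cfg)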
@@ -351,8 +357,7 @@ def _export_helper(
       kv_layout=export_config.kvcache_layout,
   )
 
-  quant_config = get_quant_recipe_from_flag(quantize)
-  quant_config._model_config = config
+  quant_config = get_quant_recipe_from_flag(quantize, config)
 
   # For export, we create a module that captures any non-exportable,
   # arugments, e.g. the generation config object.
ai_edge_torch/version.py CHANGED

{ai_edge_torch_nightly-0.5.0.dev20250511.dist-info → ai_edge_torch_nightly-0.5.0.dev20250513.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250511
+Version: 0.5.0.dev20250513
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.5.0.dev20250511.dist-info → ai_edge_torch_nightly-0.5.0.dev20250513.dist-info}/RECORD
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=Q2u2GS0KjqxWhznlOZBgkCi4NAQcdpjJzkUYdcGYQ5o,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -52,8 +52,8 @@ ai_edge_torch/generative/custom_ops/bmm_4d.py,sha256=JmVbZCujG_wuBchma8QF3DSBfVc
 ai_edge_torch/generative/custom_ops/dynamic_update_slice.py,sha256=ZGAq2CfWZsfef5mHulsWmyUx0dDWJX6J6xPjhBrjQdM,2097
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=
-ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=XsDXx6k0kE_OYu_dr7GEC26jCepV1Kv39iH-kpuqA4M,2794
+ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=hiuMFJ8QPymGMM6PiSQqQrfR4M1mblpPuDfjjabcr_w,1560
 ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=o13NkFlBgawBsjdJup05VMUjAPvDRAmig6VyEkX8q6U,2426
 ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=l0OrPGmX8WscuG9MIgtd0sqR4BeReNAu7fADzyPbnZw,1580
@@ -86,8 +86,8 @@ ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6X
 ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
 ai_edge_torch/generative/examples/moonshine/moonshine.py,sha256=nZ2b8u4TmsB5sgdClgAuH8E78bcTv9RCnF9666HqP2M,3394
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/openelm/openelm.py,sha256=
+ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=wRdT7bWbCX8g4TbzKbjcLx6vmKtuT5-g-ipg19hJW-M,1525
+ai_edge_torch/generative/examples/openelm/openelm.py,sha256=hPcXYHj-nBP56TOeQQejB3HRzv6yHSftHOx0OEPP5M8,4574
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=4W26ZtPF5Cb9mpHYuRM4b2QB_4W76zf4WV36KzexVjs,2446
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=7HHXkC-IIu7ieBvBI4RlXs_oITz7R8a6YVYQskAs_Uk,2023
@@ -180,7 +180,7 @@ ai_edge_torch/generative/quantize/example.py,sha256=1lfVNUd2cEyRUnoZ7BLbRJ9IN-FT
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=plMsd7JBi98r2NHsAdMdvS6TPTXAoRFLCwOXu8H3-24,2004
 ai_edge_torch/generative/quantize/quant_recipe.py,sha256=CEW-ewHxwb59x_GISx4jr7WMihvn-jKWVcBonllzDS4,5724
 ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=h3k_na6rbR08Ip79-2JbkeH8RDk_rrnEGiytuzFDhqc,2678
-ai_edge_torch/generative/quantize/quant_recipes.py,sha256=
+ai_edge_torch/generative/quantize/quant_recipes.py,sha256=45DJfcQXZ1FA1qI4LgYoYE4UD4yvfIYoY9LgYTeKFVw,2845
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=TwR2FpQuBEORy6FshEyHNBMKARWlA2MVtTfX9tXV5aE,1488
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvICcDWTcIDJ3rociko-wM,3267
@@ -192,7 +192,7 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hP
 ai_edge_torch/generative/test/test_quantize.py,sha256=kKJ01wscTC2t_Ylr7huO5gNKES01gm3dT1gx52z15PA,7356
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=
+ai_edge_torch/generative/utilities/converter.py,sha256=OMBy_nos9mEGMQOAD8o0on-gAkRk-kliodFSTthD5BE,14612
 ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
 ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
 ai_edge_torch/generative/utilities/model_builder.py,sha256=IG-88o7nWI9XrNDnwnQ-MoilsuqJ7KwrnbP3bn2EY9U,6334
@@ -251,8 +251,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
+ai_edge_torch_nightly-0.5.0.dev20250513.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.5.0.dev20250513.dist-info/METADATA,sha256=qh5r3x7C0ksa3D2WriWd0yePFgxK8urh9aSsJCC_gjY,2074
+ai_edge_torch_nightly-0.5.0.dev20250513.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.5.0.dev20250513.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.5.0.dev20250513.dist-info/RECORD,,

File without changes
File without changes