ai-edge-torch-nightly 0.5.0.dev20250511__py3-none-any.whl → 0.5.0.dev20250513__py3-none-any.whl

This diff compares the content of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
@@ -49,7 +49,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=2048,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config

 flags = converter.define_conversion_flags("amd-llama-135m")
-ExportConfig = export_config.ExportConfig


 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )


@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config

 flags = converter.define_conversion_flags("openelm")
-ExportConfig = export_config.ExportConfig


 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )


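Both example converters (amd-llama-135m above and OpenELM here) make the same change: the module-level ExportConfig alias is dropped and the export configuration is built from the conversion flags instead. A minimal sketch of the new pattern, assuming the scripts are driven by absl.app as the main(_) signature suggests; only the pieces visible in this diff are shown:

from absl import app

from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities import export_config

flags = converter.define_conversion_flags("openelm")


def main(_):
  # Export settings now come from the command-line flags rather than a
  # default-constructed ExportConfig(); the result is what the scripts pass as
  # export_config= to converter.convert_to_tflite(...) in the hunks above.
  cfg = export_config.get_from_flags()
  print(cfg)


if __name__ == "__main__":
  app.run(main)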
@@ -51,7 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for an OpenELM model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=True
   )
   num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
   num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
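For context on the enable_hlfb flag toggled here and in the amd-llama-135m config above: in ai-edge-torch, HLFB refers to high-level function boundaries, which wrap an op such as RMSNorm in a composite so it can be lowered as a fused unit during conversion. A minimal sketch of building the updated normalization config, reusing the cfg alias these example files already use:

import ai_edge_torch.generative.layers.model_config as cfg

# RMSNorm with a high-level function boundary enabled, matching the new
# OpenELM example config above (the amd-llama-135m variant omits epsilon).
norm_config = cfg.NormalizationConfig(
    type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=True
)
print(norm_config)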
@@ -27,37 +27,49 @@ Typical usage example:
     )
 """

+from typing import Optional
+from ai_edge_torch.generative.layers import model_config
 from ai_edge_torch.generative.quantize import quant_recipe
 from ai_edge_torch.generative.quantize import quant_recipe_utils
 from ai_edge_torch.quantize import quant_config


-def full_int8_dynamic_recipe() -> quant_config.QuantConfig:
+def full_int8_dynamic_recipe(
+    mcfg: Optional[model_config.ModelConfig] = None,
+) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
           default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+          _model_config=mcfg,
       )
   )


-def full_int8_weight_only_recipe() -> quant_config.QuantConfig:
+def full_int8_weight_only_recipe(
+    mcfg: Optional[model_config.ModelConfig] = None,
+) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
           default=quant_recipe_utils.create_layer_quant_int8_weight_only(),
+          _model_config=mcfg,
       )
   )


-def full_fp16_recipe() -> quant_config.QuantConfig:
+def full_fp16_recipe(
+    mcfg: Optional[model_config.ModelConfig] = None,
+) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
-          default=quant_recipe_utils.create_layer_quant_fp16()
+          default=quant_recipe_utils.create_layer_quant_fp16(),
+          _model_config=mcfg,
       )
   )


 def all_supported_int4_dynamic_block_recipe(
     block_size: int,
+    mcfg: Optional[model_config.ModelConfig] = None,
 ) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
@@ -65,5 +77,6 @@ def all_supported_int4_dynamic_block_recipe(
               block_size
           ),
           embedding=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+          _model_config=mcfg,
       )
   )
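All four recipe helpers now accept an optional mcfg argument and forward it as _model_config on the GenerativeQuantRecipe. A minimal usage sketch, assuming the OpenELM example's get_model_config() helper shown earlier in this diff as the source of the model config:

from ai_edge_torch.generative.examples.openelm import openelm
from ai_edge_torch.generative.quantize import quant_recipes

mcfg = openelm.get_model_config()

# The model config now rides along inside the recipe, so callers no longer
# have to attach it to the returned QuantConfig after the fact.
int8_cfg = quant_recipes.full_int8_dynamic_recipe(mcfg=mcfg)
int4_cfg = quant_recipes.all_supported_int4_dynamic_block_recipe(32, mcfg=mcfg)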
@@ -26,6 +26,7 @@ from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.quantize import quant_recipes
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.quantize import quant_config as qcfg
 import torch

 ExportConfig = export_config.ExportConfig
@@ -123,7 +124,8 @@ def define_conversion_flags(

 def get_quant_recipe_from_flag(
     quantize: str,
-) -> Optional[quant_recipes.QuantizationRecipe]:
+    model_config: cfg.ModelConfig,
+) -> Optional[qcfg.QuantConfig]:
   """Processes the quantization flag and returns the corresponding recipe.

   Args:
@@ -139,15 +141,19 @@ def get_quant_recipe_from_flag(
     case QuantizationName.NONE:
       return None
     case QuantizationName.DYNAMIC_INT8:
-      return quant_recipes.full_int8_dynamic_recipe()
+      return quant_recipes.full_int8_dynamic_recipe(mcfg=model_config)
     case QuantizationName.WEIGHT_ONLY_INT8:
-      return quant_recipes.full_int8_weight_only_recipe()
+      return quant_recipes.full_int8_weight_only_recipe(mcfg=model_config)
     case QuantizationName.FP16:
       return quant_recipes.full_fp16_recipe()
     case QuantizationName.DYNAMIC_INT4_BLOCK32:
-      return quant_recipes.full_int4_dynamic_block_recipe(32)
+      return quant_recipes.all_supported_int4_dynamic_block_recipe(
+          32, mcfg=model_config
+      )
     case QuantizationName.DYNAMIC_INT4_BLOCK128:
-      return quant_recipes.full_int4_dynamic_block_recipe(128)
+      return quant_recipes.all_supported_int4_dynamic_block_recipe(
+          128, mcfg=model_config
+      )
     case _:
       raise ValueError(f'Unsupported quantization flag: {quantize}')

@@ -351,8 +357,7 @@ def _export_helper(
       kv_layout=export_config.kvcache_layout,
   )

-  quant_config = get_quant_recipe_from_flag(quantize)
-  quant_config._model_config = config
+  quant_config = get_quant_recipe_from_flag(quantize, config)

   # For export, we create a module that captures any non-exportable,
   # arugments, e.g. the generation config object.
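The call-site change above means get_quant_recipe_from_flag receives the model config up front instead of the caller mutating the private _model_config attribute afterwards. A hedged sketch of invoking the helper directly, assuming the QuantizationName enum matched above is exposed by the converter module and that the OpenELM example's get_model_config() supplies the config:

from ai_edge_torch.generative.examples.openelm import openelm
from ai_edge_torch.generative.utilities import converter

config = openelm.get_model_config()

# Roughly what _export_helper now does with the --quantize flag value; the
# returned QuantConfig already carries the model config via the recipe.
quant_config = converter.get_quant_recipe_from_flag(
    converter.QuantizationName.DYNAMIC_INT8, config
)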
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================

-__version__ = "0.5.0.dev20250511"
+__version__ = "0.5.0.dev20250513"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250511
+Version: 0.5.0.dev20250513
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
-ai_edge_torch/version.py,sha256=hwgGQ5rNXzjaW8x5d5_q1vreBcKdC0qd0Sd_5QYRF_o,706
+ai_edge_torch/version.py,sha256=Q2u2GS0KjqxWhznlOZBgkCi4NAQcdpjJzkUYdcGYQ5o,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -52,8 +52,8 @@ ai_edge_torch/generative/custom_ops/bmm_4d.py,sha256=JmVbZCujG_wuBchma8QF3DSBfVc
 ai_edge_torch/generative/custom_ops/dynamic_update_slice.py,sha256=ZGAq2CfWZsfef5mHulsWmyUx0dDWJX6J6xPjhBrjQdM,2097
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif89PyCXbdXT5spOeDvdM5luJ-a5HaXHM86v4JnU,2766
-ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=z5MWiZLnsQzhNYMiQbcI9i0ki-dtkbimCptkiTFZxwo,1586
+ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=XsDXx6k0kE_OYu_dr7GEC26jCepV1Kv39iH-kpuqA4M,2794
+ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=hiuMFJ8QPymGMM6PiSQqQrfR4M1mblpPuDfjjabcr_w,1560
 ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=o13NkFlBgawBsjdJup05VMUjAPvDRAmig6VyEkX8q6U,2426
 ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=l0OrPGmX8WscuG9MIgtd0sqR4BeReNAu7fADzyPbnZw,1580
@@ -86,8 +86,8 @@ ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6X
 ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
 ai_edge_torch/generative/examples/moonshine/moonshine.py,sha256=nZ2b8u4TmsB5sgdClgAuH8E78bcTv9RCnF9666HqP2M,3394
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=Hgp31zIQdJsTweRMr0U3d2SKW1h2nWnqWt1FlmuQqiI,1551
-ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
+ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=wRdT7bWbCX8g4TbzKbjcLx6vmKtuT5-g-ipg19hJW-M,1525
+ai_edge_torch/generative/examples/openelm/openelm.py,sha256=hPcXYHj-nBP56TOeQQejB3HRzv6yHSftHOx0OEPP5M8,4574
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=4W26ZtPF5Cb9mpHYuRM4b2QB_4W76zf4WV36KzexVjs,2446
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=7HHXkC-IIu7ieBvBI4RlXs_oITz7R8a6YVYQskAs_Uk,2023
@@ -180,7 +180,7 @@ ai_edge_torch/generative/quantize/example.py,sha256=1lfVNUd2cEyRUnoZ7BLbRJ9IN-FT
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=plMsd7JBi98r2NHsAdMdvS6TPTXAoRFLCwOXu8H3-24,2004
 ai_edge_torch/generative/quantize/quant_recipe.py,sha256=CEW-ewHxwb59x_GISx4jr7WMihvn-jKWVcBonllzDS4,5724
 ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=h3k_na6rbR08Ip79-2JbkeH8RDk_rrnEGiytuzFDhqc,2678
-ai_edge_torch/generative/quantize/quant_recipes.py,sha256=5UkUAT0qsWzLtNAeX-M5hEMi-kqoLV70_F76QiXmVZ4,2424
+ai_edge_torch/generative/quantize/quant_recipes.py,sha256=45DJfcQXZ1FA1qI4LgYoYE4UD4yvfIYoY9LgYTeKFVw,2845
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=TwR2FpQuBEORy6FshEyHNBMKARWlA2MVtTfX9tXV5aE,1488
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvICcDWTcIDJ3rociko-wM,3267
@@ -192,7 +192,7 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hP
 ai_edge_torch/generative/test/test_quantize.py,sha256=kKJ01wscTC2t_Ylr7huO5gNKES01gm3dT1gx52z15PA,7356
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=LrBqxXVxkOWh4abcHfY4QXRpYxjjfEYd4ifrpGGbebI,14441
+ai_edge_torch/generative/utilities/converter.py,sha256=OMBy_nos9mEGMQOAD8o0on-gAkRk-kliodFSTthD5BE,14612
 ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
 ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
 ai_edge_torch/generative/utilities/model_builder.py,sha256=IG-88o7nWI9XrNDnwnQ-MoilsuqJ7KwrnbP3bn2EY9U,6334
@@ -251,8 +251,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.5.0.dev20250511.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.5.0.dev20250511.dist-info/METADATA,sha256=PmlXdlLctkno1gMu-BWqW8CjHcSargbvVYhYycNMKTs,2074
-ai_edge_torch_nightly-0.5.0.dev20250511.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.5.0.dev20250511.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.5.0.dev20250511.dist-info/RECORD,,
+ai_edge_torch_nightly-0.5.0.dev20250513.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.5.0.dev20250513.dist-info/METADATA,sha256=qh5r3x7C0ksa3D2WriWd0yePFgxK8urh9aSsJCC_gjY,2074
+ai_edge_torch_nightly-0.5.0.dev20250513.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.5.0.dev20250513.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.5.0.dev20250513.dist-info/RECORD,,