ai-edge-torch-nightly 0.7.0.dev20251007__py3-none-any.whl → 0.8.0.dev20251225__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (42)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/fx_infra/__init__.py +1 -0
  3. ai_edge_torch/fx_infra/_safe_run_decompositions.py +54 -1
  4. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -3
  5. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +1 -1
  6. ai_edge_torch/generative/layers/attention.py +25 -2
  7. ai_edge_torch/generative/layers/attention_test.py +13 -1
  8. ai_edge_torch/generative/layers/attention_utils.py +62 -1
  9. ai_edge_torch/generative/layers/attention_utils_test.py +20 -0
  10. ai_edge_torch/generative/layers/builder.py +4 -2
  11. ai_edge_torch/generative/layers/model_config.py +5 -0
  12. ai_edge_torch/generative/layers/normalization.py +8 -2
  13. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +35 -5
  14. ai_edge_torch/generative/layers/sdpa_with_kv_update.py +8 -3
  15. ai_edge_torch/generative/quantize/example.py +1 -1
  16. ai_edge_torch/generative/quantize/quant_attrs.py +8 -1
  17. ai_edge_torch/generative/quantize/quant_recipe.py +0 -13
  18. ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -19
  19. ai_edge_torch/generative/quantize/quant_recipes.py +16 -21
  20. ai_edge_torch/generative/quantize/supported_schemes.py +4 -1
  21. ai_edge_torch/generative/test/test_kv_cache.py +18 -6
  22. ai_edge_torch/generative/test/test_quantize.py +17 -26
  23. ai_edge_torch/generative/utilities/converter.py +97 -22
  24. ai_edge_torch/generative/utilities/litertlm_builder.py +61 -8
  25. ai_edge_torch/generative/utilities/loader.py +2 -1
  26. ai_edge_torch/lowertools/translate_recipe.py +8 -3
  27. ai_edge_torch/odml_torch/experimental/__init__.py +14 -0
  28. ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py +20 -0
  29. ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +438 -0
  30. ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +728 -0
  31. ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +371 -0
  32. ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py +37 -0
  33. ai_edge_torch/odml_torch/export.py +24 -7
  34. ai_edge_torch/odml_torch/lowerings/_basic.py +155 -0
  35. ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +94 -2
  36. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -5
  37. ai_edge_torch/version.py +1 -1
  38. {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/METADATA +15 -3
  39. {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/RECORD +42 -36
  40. {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/WHEEL +1 -1
  41. {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info/licenses}/LICENSE +0 -0
  42. {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/quantize/quant_recipe_utils.py

@@ -32,23 +32,29 @@ from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import quant_recipe
 
 
-def create_layer_quant_int8_dynamic() -> quant_recipe.LayerQuantRecipe:
+def create_layer_quant_dynamic(
+    weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
+    granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
+) -> quant_recipe.LayerQuantRecipe:
   return quant_recipe.LayerQuantRecipe(
       activation_dtype=quant_attrs.Dtype.FP32,
-      weight_dtype=quant_attrs.Dtype.INT8,
+      weight_dtype=weight_dtype,
       mode=quant_attrs.Mode.DYNAMIC_RANGE,
       algorithm=quant_attrs.Algorithm.MIN_MAX,
-      granularity=quant_attrs.Granularity.CHANNELWISE,
+      granularity=granularity,
   )
 
 
-def create_layer_quant_int8_weight_only() -> quant_recipe.LayerQuantRecipe:
+def create_layer_quant_weight_only(
+    weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
+    granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
+) -> quant_recipe.LayerQuantRecipe:
   return quant_recipe.LayerQuantRecipe(
       activation_dtype=quant_attrs.Dtype.FP32,
-      weight_dtype=quant_attrs.Dtype.INT8,
+      weight_dtype=weight_dtype,
       mode=quant_attrs.Mode.WEIGHT_ONLY,
       algorithm=quant_attrs.Algorithm.MIN_MAX,
-      granularity=quant_attrs.Granularity.CHANNELWISE,
+      granularity=granularity,
   )
 
 
@@ -60,16 +66,3 @@ def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
       algorithm=quant_attrs.Algorithm.FLOAT_CAST,
       granularity=quant_attrs.Granularity.NONE,
   )
-
-
-def create_layer_quant_int4_dynamic_block(
-    block_size: int,
-) -> quant_recipe.LayerQuantRecipe:
-  return quant_recipe.LayerQuantRecipe(
-      activation_dtype=quant_attrs.Dtype.FP32,
-      weight_dtype=quant_attrs.Dtype.INT4,
-      mode=quant_attrs.Mode.DYNAMIC_RANGE,
-      algorithm=quant_attrs.Algorithm.MIN_MAX,
-      granularity=quant_attrs.Granularity.BLOCKWISE,
-      block_size=block_size,
-  )
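
Migration note: the removed int4 helper's block size now travels inside the
Granularity enum rather than as a separate block_size argument. A minimal
sketch of the equivalent call under the new API, using only names visible in
this diff:

    from ai_edge_torch.generative.quantize import quant_attrs
    from ai_edge_torch.generative.quantize import quant_recipe_utils

    # 0.7.0: create_layer_quant_int4_dynamic_block(block_size=32)
    # 0.8.0: the block size is encoded in the granularity value.
    layer_recipe = quant_recipe_utils.create_layer_quant_dynamic(
        weight_dtype=quant_attrs.Dtype.INT4,
        granularity=quant_attrs.Granularity.BLOCKWISE_32,
    )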

ai_edge_torch/generative/quantize/quant_recipes.py

@@ -29,35 +29,44 @@ Typical usage example:
 
 from typing import Optional
 from ai_edge_torch.generative.layers import model_config
+from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import quant_recipe
 from ai_edge_torch.generative.quantize import quant_recipe_utils
 from ai_edge_torch.quantize import quant_config
 
 
-def full_int8_dynamic_recipe(
-    mcfg: Optional[model_config.ModelConfig] = None,
+def full_dynamic_recipe(
+    mcfg: model_config.ModelConfig | None = None,
+    weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
+    granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
 ) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
-          default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+          default=quant_recipe_utils.create_layer_quant_dynamic(
+              weight_dtype, granularity
+          ),
          _model_config=mcfg,
      )
  )
 
 
-def full_int8_weight_only_recipe(
-    mcfg: Optional[model_config.ModelConfig] = None,
+def full_weight_only_recipe(
+    mcfg: model_config.ModelConfig | None = None,
+    weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
+    granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
 ) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
-          default=quant_recipe_utils.create_layer_quant_int8_weight_only(),
+          default=quant_recipe_utils.create_layer_quant_weight_only(
+              weight_dtype, granularity
+          ),
          _model_config=mcfg,
      )
  )
 
 
 def full_fp16_recipe(
-    mcfg: Optional[model_config.ModelConfig] = None,
+    mcfg: model_config.ModelConfig | None = None,
 ) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
@@ -65,17 +74,3 @@ def full_fp16_recipe(
          _model_config=mcfg,
      )
  )
-
-
-def all_supported_int4_dynamic_block_recipe(
-    block_size: int,
-    mcfg: Optional[model_config.ModelConfig] = None,
-) -> quant_config.QuantConfig:
-  return quant_config.QuantConfig(
-      generative_recipe=quant_recipe.GenerativeQuantRecipe(
-          default=quant_recipe_utils.create_layer_quant_int4_dynamic_block(
-              block_size
-          ),
-          _model_config=mcfg,
-      )
-  )
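
The same migration applies one level up: the removed
all_supported_int4_dynamic_block_recipe(block_size) maps onto
full_dynamic_recipe, as the updated tests further down exercise. A sketch
(`pytorch_model`, `tokens`, and `input_pos` are placeholders):

    import ai_edge_torch
    from ai_edge_torch.generative.quantize import quant_attrs
    from ai_edge_torch.generative.quantize import quant_recipes

    # 0.7.0: quant_recipes.all_supported_int4_dynamic_block_recipe(128)
    quant_config = quant_recipes.full_dynamic_recipe(
        weight_dtype=quant_attrs.Dtype.INT4,
        granularity=quant_attrs.Granularity.BLOCKWISE_128,
    )
    edge_model = ai_edge_torch.convert(
        pytorch_model, (tokens, input_pos), quant_config=quant_config
    )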

ai_edge_torch/generative/quantize/supported_schemes.py

@@ -29,5 +29,8 @@ def get_supported_layer_schemes():
       (_t.FP32, _t.INT8, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.CHANNELWISE),
       (_t.FP32, _t.INT8, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.CHANNELWISE),
       (_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.FLOAT_CAST, _g.NONE),
-      (_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE),
+      (_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_32),
+      (_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_64),
+      (_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_128),
+      (_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_256),
   ]
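
With block size folded into the enum, checking whether a combination is
supported reduces to tuple membership. A sketch, assuming the _t/_m/_a/_g
aliases in this file map to the quant_attrs enums used throughout the diff:

    from ai_edge_torch.generative.quantize import quant_attrs
    from ai_edge_torch.generative.quantize import supported_schemes

    scheme = (
        quant_attrs.Dtype.FP32,  # activation dtype
        quant_attrs.Dtype.INT4,  # weight dtype
        quant_attrs.Mode.DYNAMIC_RANGE,
        quant_attrs.Algorithm.MIN_MAX,
        quant_attrs.Granularity.BLOCKWISE_64,
    )
    assert scheme in supported_schemes.get_supported_layer_schemes()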

ai_edge_torch/generative/test/test_kv_cache.py

@@ -41,6 +41,20 @@ class TestKVLayers(googletest.TestCase):
     )
     return config
 
+  def _assert_kv_cache_entry_equal(self, kv1, kv2):
+    self.assertIsInstance(kv1, kv_utils.KVCacheEntry)
+    self.assertIsInstance(kv2, kv_utils.KVCacheEntry)
+    self.assertEqual(kv1.kv_layout, kv2.kv_layout)
+    self.assertTrue(torch.equal(kv1.k_cache, kv2.k_cache))
+    self.assertTrue(torch.equal(kv1.v_cache, kv2.v_cache))
+
+  def _assert_kv_cache_equal(self, kv1, kv2):
+    self.assertIsInstance(kv1, kv_utils.KVCache)
+    self.assertIsInstance(kv2, kv_utils.KVCache)
+    self.assertEqual(len(kv1.caches), len(kv2.caches))
+    for kv1_entry, kv2_entry in zip(kv1.caches, kv2.caches):
+      self._assert_kv_cache_entry_equal(kv1_entry, kv2_entry)
+
   def test_cache_udpate(self):
     N = 1
     HEAD_DIM = 2
@@ -118,7 +132,7 @@ class TestKVLayers(googletest.TestCase):
     flat, treespec = pytree.tree_flatten(kv)
     self.assertLen(flat, NUM_LAYERS * 2)
     kv_unflat = pytree.tree_unflatten(flat, treespec)
-    self.assertEqual(kv, kv_unflat)
+    self._assert_kv_cache_equal(kv, kv_unflat)
 
   def test_pytree_roundtrip_kv_cache_derived(self):
     NUM_LAYERS = 4
@@ -134,7 +148,7 @@ class TestKVLayers(googletest.TestCase):
     flat, treespec = pytree.tree_flatten(kv)
     self.assertLen(flat, NUM_LAYERS * 2)
     kv_unflat = pytree.tree_unflatten(flat, treespec)
-    self.assertEqual(kv, kv_unflat)
+    self._assert_kv_cache_equal(kv, kv_unflat)
 
   def test_pytree_roundtrip_kv_entry(self):
     attn_config = cfg.AttentionConfig(
@@ -144,8 +158,7 @@ class TestKVLayers(googletest.TestCase):
     flat, treespec = pytree.tree_flatten(kv)
     self.assertLen(flat, 2)
     kv_unflat = pytree.tree_unflatten(flat, treespec)
-    self.assertEqual(kv, kv_unflat)
-    self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
+    self._assert_kv_cache_entry_equal(kv, kv_unflat)
 
   def test_pytree_roundtrip_kv_entry_derived(self):
     attn_config = cfg.AttentionConfig(
@@ -157,8 +170,7 @@ class TestKVLayers(googletest.TestCase):
     flat, treespec = pytree.tree_flatten(kv)
     self.assertLen(flat, 2)
     kv_unflat = pytree.tree_unflatten(flat, treespec)
-    self.assertEqual(kv, kv_unflat)
-    self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
+    self._assert_kv_cache_entry_equal(kv, kv_unflat)
 
 
 if __name__ == "__main__":

ai_edge_torch/generative/test/test_quantize.py

@@ -79,18 +79,18 @@ class TestVerifyRecipes(parameterized.TestCase):
           Dtype.INT4,
           Mode.DYNAMIC_RANGE,
           Algorithm.MIN_MAX,
-          Granularity.BLOCKWISE,
-          32,
+          Granularity.BLOCKWISE_32,
+      ),
+      (
+          Dtype.FP32,
+          Dtype.INT4,
+          Mode.DYNAMIC_RANGE,
+          Algorithm.MIN_MAX,
+          Granularity.BLOCKWISE_128,
       ),
   ])
   def test_verify_valid_recipes(
-      self,
-      activation,
-      weight,
-      mode,
-      algo,
-      granularity,
-      block_size=None,
+      self, activation, weight, mode, algo, granularity
   ):
     quant_recipe.LayerQuantRecipe(
         activation, weight, mode, algo, granularity
@@ -108,21 +108,21 @@ class TestQuantizeConvert(parameterized.TestCase):
   def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
     return quant_config.QuantConfig(
         generative_recipe=quant_recipe.GenerativeQuantRecipe(
-            attention=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+            attention=quant_recipe_utils.create_layer_quant_dynamic(),
        )
    )
 
   def _feedforward_int8_dynamic_recipe() -> quant_config.QuantConfig:
     return quant_config.QuantConfig(
         generative_recipe=quant_recipe.GenerativeQuantRecipe(
-            feedforward=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+            feedforward=quant_recipe_utils.create_layer_quant_dynamic(),
        )
    )
 
   @parameterized.parameters([
       (quant_recipes.full_fp16_recipe()),
-      (quant_recipes.full_int8_dynamic_recipe()),
-      (quant_recipes.full_int8_weight_only_recipe()),
+      (quant_recipes.full_dynamic_recipe()),
+      (quant_recipes.full_weight_only_recipe()),
       (_attention_int8_dynamic_recipe()),
       (_feedforward_int8_dynamic_recipe()),
   ])
@@ -148,7 +148,7 @@ class TestQuantizeConvert(parameterized.TestCase):
     idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
     input_pos = torch.arange(0, 100, dtype=torch.int)
 
-    quant_config = quant_recipes.full_int8_dynamic_recipe()
+    quant_config = quant_recipes.full_dynamic_recipe()
     quantized_model = ai_edge_torch.convert(
         pytorch_model, (idx, input_pos), quant_config=quant_config
     )
@@ -164,7 +164,9 @@ class TestQuantizeConvert(parameterized.TestCase):
     pytorch_model = toy_model.ToySingleLayerModel(config)
     idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
     input_pos = torch.arange(0, 100, dtype=torch.int)
-    quant_config = quant_recipes.all_supported_int4_dynamic_block_recipe(32)
+    quant_config = quant_recipes.full_dynamic_recipe(
+        weight_dtype=Dtype.INT4, granularity=Granularity.BLOCKWISE_32
+    )
     quantized_model = ai_edge_torch.convert(
         pytorch_model, (idx, input_pos), quant_config=quant_config
     )
@@ -175,17 +177,6 @@ class TestQuantizeConvert(parameterized.TestCase):
         "Quantized model isn't smaller than F32 model.",
     )
 
-  def test_unsupported_block_size(self):
-    config = toy_model.get_model_config()
-    pytorch_model = toy_model.ToySingleLayerModel(config)
-    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
-    input_pos = torch.arange(0, 100, dtype=torch.int)
-    self.assertRaises(
-        ValueError,
-        quant_recipes.all_supported_int4_dynamic_block_recipe,
-        36,
-    )
-
   def test_quantize_convert_compare_toy(self):
     self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()

ai_edge_torch/generative/utilities/converter.py

@@ -25,6 +25,7 @@ from ai_edge_torch._convert import converter as converter_utils
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import quant_recipes
 from ai_edge_torch.generative.utilities import export_config as export_config_lib
 from ai_edge_torch.generative.utilities import litertlm_builder
@@ -143,9 +144,23 @@ def define_conversion_flags(
       '`prefill_seq_lens` as the maximum of kv_cache size and prefill lengths '
       'in the graph.',
   )
+  flags.DEFINE_bool(
+      'export_gpu_dynamic_shape_verifications',
+      False,
+      'If true, the conversion script will export signatures used only for '
+      'verification of GPU dynamic shapes.',
+  )
   return flags
 
 
+# Context length for verifying GPU dynamic shapes.
+_CONTEXT_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 1280
+# Long prefill length for verifying GPU dynamic shapes.
+_LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 1024
+# Short prefill length for verifying GPU dynamic shapes.
+_SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 64
+
+
 def is_magic_number_(num: int) -> bool:
   """Returns true if the number is a magic number, i.e. prime number > 10."""
   if num < 10:
@@ -193,18 +208,22 @@ def get_quant_recipe_from_flag(
     case QuantizationName.NONE:
       return None
     case QuantizationName.DYNAMIC_INT8:
-      return quant_recipes.full_int8_dynamic_recipe(mcfg=model_config)
+      return quant_recipes.full_dynamic_recipe(mcfg=model_config)
     case QuantizationName.WEIGHT_ONLY_INT8:
-      return quant_recipes.full_int8_weight_only_recipe(mcfg=model_config)
+      return quant_recipes.full_weight_only_recipe(mcfg=model_config)
     case QuantizationName.FP16:
       return quant_recipes.full_fp16_recipe()
     case QuantizationName.DYNAMIC_INT4_BLOCK32:
-      return quant_recipes.all_supported_int4_dynamic_block_recipe(
-          32, mcfg=model_config
+      return quant_recipes.full_dynamic_recipe(
+          mcfg=model_config,
+          weight_dtype=quant_attrs.Dtype.INT4,
+          granularity=quant_attrs.Granularity.BLOCKWISE_32,
       )
     case QuantizationName.DYNAMIC_INT4_BLOCK128:
-      return quant_recipes.all_supported_int4_dynamic_block_recipe(
-          128, mcfg=model_config
+      return quant_recipes.full_dynamic_recipe(
+          mcfg=model_config,
+          weight_dtype=quant_attrs.Dtype.INT4,
+          granularity=quant_attrs.Granularity.BLOCKWISE_128,
       )
     case _:
       raise ValueError(f'Unsupported quantization flag: {quantize}')
@@ -263,6 +282,10 @@ def convert_to_tflite(
     config: cfg.ModelConfig = None,
     lora_ranks: Optional[list[int]] = None,
     export_config: ExportConfig = None,
+    extra_model: torch.nn.Module = None,
+    extra_prefill_seq_lens: list[int] = None,
+    extra_kv_cache_max_len: int = 0,
+    extra_signature_prefix: str = '',
 ):
   """Converts a nn.Module model to multi-signature tflite model.
 
@@ -315,6 +338,15 @@ def convert_to_tflite(
       no LoRA signatures will be added.
     export_config (ExportConfig, optional): The export configuration. If None,
       it uses the default export configuration.
+    extra_model (torch.nn.Module, optional): PyTorch model to export in
+      addition to the pytorch_model. This model can have different
+      prefill_seq_lens and kv_cache_max_len.
+    extra_prefill_seq_lens (list[int], optional): The prefill sequence
+      lengths for extra_model. Meaningful only when extra_model is not None.
+    extra_kv_cache_max_len (int, optional): The maximum size of KV cache
+      buffer for extra_model. Meaningful only when extra_model is not None.
+    extra_signature_prefix (str, optional): The prefix of the extra model
+      signatures. Meaningful only when extra_model is not None.
   """
   # pylint: disable=protected-access
   torch._dynamo.config.cache_size_limit = 64
@@ -353,32 +385,51 @@ def convert_to_tflite(
   )
   output_file = os.path.join(output_path, output_filename)
 
-  _export_helper(
+  converter = converter_utils.Converter()
+  _add_signatures(
+      converter,
       pytorch_model,
-      output_file,
       prefill_seq_lens,
       kv_cache_max_len,
       pixel_values_size,
       pixel_seq_len,
-      quantize,
       config,
       loras,
       export_config,
   )
+
+  if extra_model is not None and extra_prefill_seq_lens:
+    _add_signatures(
+        converter,
+        extra_model,
+        extra_prefill_seq_lens,
+        extra_kv_cache_max_len,
+        pixel_values_size,
+        pixel_seq_len,
+        config,
+        loras,
+        export_config,
+        signature_prefix=extra_signature_prefix,
+    )
+
+  edge_model = converter.convert(
+      quant_config=get_quant_recipe_from_flag(quantize, config),
+  )
+  edge_model.export(output_file)
   return output_file
 
 
-def _export_helper(
+def _add_signatures(
+    converter: converter_utils.Converter,
     pytorch_model: torch.nn.Module,
-    output_file: str,
     prefill_seq_lens: list[int],
    kv_cache_max_len: int,
    pixel_values_size: torch.Size,
    pixel_seq_len: int,
-    quantize: str,
    config: cfg.ModelConfig,
    loras: list[None | lora_utils.LoRA],
    export_config: ExportConfig,
+    signature_prefix: str = '',
 ):
   """Helper function to export a model to tflite."""
   prefill_tokens_list = []
@@ -423,17 +474,14 @@
       kv_layout=export_config.kvcache_layout,
   )
 
-  quant_config = get_quant_recipe_from_flag(quantize, config)
-
   # For export, we create a module that captures any non-exportable,
   # arugments, e.g. the generation config object.
   mod = ExportableModule(pytorch_model, export_config=export_config).eval()
 
-  converter = converter_utils.Converter()
   for lora in loras:
     for i in range(len(prefill_seq_lens)):
       prefill_seq_len = prefill_seq_lens[i]
-      prefill_signature_name = f'prefill_{prefill_seq_len}'
+      prefill_signature_name = f'{signature_prefix}prefill_{prefill_seq_len}'
 
       sample_kwargs = {
          'tokens': prefill_tokens_list[i],
@@ -488,17 +536,15 @@
    if lora is not None:
      sample_kwargs['lora'] = lora
 
+    decode_signature_name = f'{signature_prefix}decode'
+    if lora is not None:
+      decode_signature_name += f'_lora_r{lora.get_rank()}'
    converter.add_signature(
-        'decode' if lora is None else f'decode_lora_r{lora.get_rank()}',
+        decode_signature_name,
        mod,
        sample_kwargs=sample_kwargs,
    )
 
-  edge_model = converter.convert(
-      quant_config=quant_config,
-  )
-  edge_model.export(output_file)
-
 
 def build_and_convert_to_tflite_from_flags(
     model_builder: Callable[
@@ -521,11 +567,36 @@ def build_and_convert_to_tflite_from_flags(
       get_mask_cache_size_from_flags(),
   )
 
+  # Extra model for GPU dynamic shape verification if needed.
+  extra_model = None
+  extra_prefill_seq_lens = None
+  extra_kv_cache_max_len = 0
   if flags.FLAGS.gpu_dynamic_shapes:
     prefill_seq_lens = [
         get_magic_number_for(l) for l in flags.FLAGS.prefill_seq_lens
     ]
     kv_cache_max_len = get_magic_number_for(flags.FLAGS.kv_cache_max_len)
+
+    if flags.FLAGS.export_gpu_dynamic_shape_verifications:
+      extra_kv_cache_max_len = _CONTEXT_LENGTH_TO_VERIFY_MAGIC_NUMBERS
+      if extra_kv_cache_max_len > flags.FLAGS.kv_cache_max_len:
+        extra_kv_cache_max_len = flags.FLAGS.kv_cache_max_len
+      extra_model = model_builder(
+          checkpoint_path,
+          loader.maybe_get_custom_loader(
+              checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+          ),
+          extra_kv_cache_max_len,
+      )
+      extra_prefill_seq_lens = []
+      if extra_kv_cache_max_len > _SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS:
+        extra_prefill_seq_lens.append(
+            _SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS
+        )
+      if extra_kv_cache_max_len > _LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS:
+        extra_prefill_seq_lens.append(
+            _LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS
+        )
   else:
     prefill_seq_lens = flags.FLAGS.prefill_seq_lens
     kv_cache_max_len = flags.FLAGS.kv_cache_max_len
@@ -539,6 +610,10 @@ def build_and_convert_to_tflite_from_flags(
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config_lib.get_from_flags(),
+      extra_model=extra_model,
+      extra_prefill_seq_lens=extra_prefill_seq_lens,
+      extra_kv_cache_max_len=extra_kv_cache_max_len,
+      extra_signature_prefix='test_' if extra_model is not None else '',
   )
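
The refactor above splits signature collection (_add_signatures) from the
single convert/export step, which is what allows a second model to be bundled
into one flatbuffer. A rough call sketch using only parameters shown in this
diff; the remaining required arguments (output path, prefill lengths, etc.)
are elided, the model objects are placeholders, and 1031/4099 merely stand in
for the primes the magic-number helpers produce:

    convert_to_tflite(
        pytorch_model,  # main model, exported with magic dims (e.g. 1031/4099)
        extra_model=extra_model,  # same checkpoint, fixed dims for verification
        extra_prefill_seq_lens=[64, 1024],
        extra_kv_cache_max_len=1280,
        extra_signature_prefix='test_',  # -> test_prefill_64, test_decode, ...
    )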

ai_edge_torch/generative/utilities/litertlm_builder.py

@@ -18,16 +18,19 @@
 
 import os
 import pathlib
+from google.protobuf import text_format
 
 try:
   # pylint: disable=g-import-not-at-top
   from ai_edge_litert.internal import llm_metadata_pb2
   from ai_edge_litert.internal import litertlm_builder
+  from ai_edge_litert.internal import llm_model_type_pb2
   # pylint: enable=g-import-not-at-top
 
   _litertlm_builder_available = True
 except ImportError:
   llm_metadata_pb2 = None
+  llm_model_type_pb2 = None
   litertlm_builder = None
   _litertlm_builder_available = False
 
@@ -41,16 +44,19 @@ def build_litertlm(
     workdir: str,
     output_path: str,
     context_length: int,
-    model_prompt_prefix: str | None,
-    model_prompt_suffix: str | None,
-    user_prompt_prefix: str | None,
-    user_prompt_suffix: str | None,
-    tokenizer_model_path: str | None,
-    hf_tokenizer_model_path: str | None,
+    model_prompt_prefix: str | None = None,
+    model_prompt_suffix: str | None = None,
+    user_prompt_prefix: str | None = None,
+    user_prompt_suffix: str | None = None,
+    tokenizer_model_path: str | None = None,
+    hf_tokenizer_model_path: str | None = None,
     start_token: str | None = None,
     start_token_id: int | None = None,
     stop_tokens: str | list[str] | None = None,
     stop_token_ids: list[int] | None = None,
+    llm_model_type: str = 'generic',
+    jinja_prompt_template: str | None = None,
+    base_llm_metadata_path: str | None = None,
     **kwargs,
 ):
   """Builds a LiteRT-LM file from a TFlite model."""
@@ -58,10 +64,22 @@ def build_litertlm(
 
   if not is_litertlm_builder_available():
     raise ValueError('LiteRT-LM builder is not available.')
-  assert llm_metadata_pb2 is not None
   assert litertlm_builder is not None
+  assert llm_metadata_pb2 is not None
+  assert llm_model_type_pb2 is not None
 
   llm_metadata = llm_metadata_pb2.LlmMetadata()
+  if base_llm_metadata_path:
+    if base_llm_metadata_path.endswith('.pb'):
+      with open(base_llm_metadata_path, 'rb') as f:
+        llm_metadata.ParseFromString(f.read())
+    elif base_llm_metadata_path.endswith('.textproto'):
+      with open(base_llm_metadata_path, 'r') as f:
+        text_format.Parse(f.read(), llm_metadata, allow_unknown_field=True)
+    else:
+      raise ValueError(
+          'Base LLM metadata path must be a binary or text proto file.'
+      )
 
   if start_token_id is not None:
     llm_metadata.start_token.token_ids.ids.append(start_token_id)
@@ -96,7 +114,42 @@ def build_litertlm(
 
   llm_metadata.max_num_tokens = context_length
 
-  llm_metadata_path = os.path.join(workdir, 'llm_metadata.pb')
+  mdl_type = llm_metadata.llm_model_type.WhichOneof('model_type')
+  if not mdl_type or mdl_type == 'generic_model':
+    match llm_model_type:
+      case litertlm_builder.LlmModelType.GENERIC:
+        llm_metadata.llm_model_type.CopyFrom(
+            llm_model_type_pb2.LlmModelType(
+                generic_model=llm_model_type_pb2.GenericModel()
+            )
+        )
+      case litertlm_builder.LlmModelType.GEMMA3N:
+        llm_metadata.llm_model_type.CopyFrom(
+            llm_model_type_pb2.LlmModelType(
+                gemma3n=llm_model_type_pb2.Gemma3N()
+            )
+        )
+      case litertlm_builder.LlmModelType.GEMMA3:
+        llm_metadata.llm_model_type.CopyFrom(
+            llm_model_type_pb2.LlmModelType(gemma3=llm_model_type_pb2.Gemma3())
+        )
+      case litertlm_builder.LlmModelType.QWEN3:
+        llm_metadata.llm_model_type.CopyFrom(
+            llm_model_type_pb2.LlmModelType(qwen3=llm_model_type_pb2.Qwen3())
+        )
+      case litertlm_builder.LlmModelType.QWEN2P5:
+        llm_metadata.llm_model_type.CopyFrom(
+            llm_model_type_pb2.LlmModelType(
+                qwen2p5=llm_model_type_pb2.Qwen2p5()
+            )
+        )
+      case _:
+        raise ValueError(f'Unsupported LLM model type: {llm_model_type}')
+
+  if jinja_prompt_template is not None:
+    llm_metadata.jinja_prompt_template = jinja_prompt_template
+
+  llm_metadata_path = os.path.join(workdir, 'llm_metadata_final.pb')
   with open(llm_metadata_path, 'wb') as f:
     f.write(llm_metadata.SerializeToString())
 
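A hedged usage sketch for the new metadata seeding; arguments before
`workdir` (e.g. the tflite model path) are elided, the path values are
placeholders, and passing llm_model_type as a plain string assumes the
ai_edge_litert LlmModelType values compare equal to these strings:

    build_litertlm(
        # Leading model/tokenizer arguments elided; unchanged from 0.7.0.
        workdir='/tmp/litertlm_work',
        output_path='/tmp/model.litertlm',
        context_length=4096,
        llm_model_type='gemma3',
        base_llm_metadata_path='base_llm_metadata.textproto',
    )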

ai_edge_torch/generative/utilities/loader.py

@@ -135,7 +135,8 @@ def load_pytorch_statedict(full_path: str):
 
   tensors = {}
   for file in files:
-    this_file_tensors = torch.load(file)
+    map_location = "cpu" if not torch.cuda.is_available() else None
+    this_file_tensors = torch.load(file, map_location=map_location)
     for k in this_file_tensors:
       assert k not in tensors
     tensors.update(this_file_tensors)
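
The map_location change follows the standard PyTorch remedy for loading a
checkpoint whose tensors were serialized from GPU memory onto a CPU-only
host. A standalone illustration (the file name is a placeholder):

    import torch

    # Without remapping, CUDA-saved tensors fail to deserialize when no GPU
    # is present; map_location="cpu" forces them onto host memory.
    state_dict = torch.load("pytorch_model.bin", map_location="cpu")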

ai_edge_torch/lowertools/translate_recipe.py

@@ -80,8 +80,14 @@ def _get_granularity(
     return _QuantGranularity.CHANNELWISE
   if granularity == quant_attrs.Granularity.NONE:
     return _QuantGranularity.TENSORWISE
-  if granularity == quant_attrs.Granularity.BLOCKWISE:
-    return _QuantGranularity.BLOCKWISE
+  if granularity == quant_attrs.Granularity.BLOCKWISE_32:
+    return _QuantGranularity.BLOCKWISE_32
+  if granularity == quant_attrs.Granularity.BLOCKWISE_64:
+    return _QuantGranularity.BLOCKWISE_64
+  if granularity == quant_attrs.Granularity.BLOCKWISE_128:
+    return _QuantGranularity.BLOCKWISE_128
+  if granularity == quant_attrs.Granularity.BLOCKWISE_256:
+    return _QuantGranularity.BLOCKWISE_256
   raise ValueError('Unimplemented granularity')
 
 
@@ -108,7 +114,6 @@ def _set_quant_config(
           symmetric=True,
          granularity=_get_granularity(layer_recipe.granularity),
          dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
-          block_size=layer_recipe.block_size,
      ),
      compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
      explicit_dequantize=_get_explicit_dequant_from_mode(

ai_edge_torch/odml_torch/experimental/__init__.py (new file)

@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================