ai-edge-torch-nightly 0.7.0.dev20250929__py3-none-any.whl → 0.8.0.dev20251206__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic; see the registry's page for this release for more details.

Files changed (57)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/fx_infra/_safe_run_decompositions.py +36 -1
  3. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +1 -20
  4. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +1 -20
  5. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +1 -20
  6. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +1 -20
  7. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +3 -27
  8. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +1 -20
  9. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +1 -20
  10. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +1 -20
  11. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -20
  12. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +1 -20
  13. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +1 -20
  14. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +1 -20
  15. ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +1 -20
  16. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +1 -30
  17. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +1 -30
  18. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -3
  19. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +1 -1
  20. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +1 -20
  21. ai_edge_torch/generative/layers/attention.py +25 -2
  22. ai_edge_torch/generative/layers/attention_test.py +13 -1
  23. ai_edge_torch/generative/layers/attention_utils.py +62 -1
  24. ai_edge_torch/generative/layers/attention_utils_test.py +20 -0
  25. ai_edge_torch/generative/layers/builder.py +4 -2
  26. ai_edge_torch/generative/layers/model_config.py +5 -0
  27. ai_edge_torch/generative/layers/normalization.py +8 -2
  28. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +35 -5
  29. ai_edge_torch/generative/layers/sdpa_with_kv_update.py +8 -3
  30. ai_edge_torch/generative/quantize/example.py +1 -1
  31. ai_edge_torch/generative/quantize/quant_attrs.py +8 -1
  32. ai_edge_torch/generative/quantize/quant_recipe.py +0 -13
  33. ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -19
  34. ai_edge_torch/generative/quantize/quant_recipes.py +16 -21
  35. ai_edge_torch/generative/quantize/supported_schemes.py +4 -1
  36. ai_edge_torch/generative/test/test_kv_cache.py +18 -6
  37. ai_edge_torch/generative/test/test_quantize.py +17 -26
  38. ai_edge_torch/generative/utilities/converter.py +183 -28
  39. ai_edge_torch/generative/utilities/export_config.py +2 -0
  40. ai_edge_torch/generative/utilities/litertlm_builder.py +61 -8
  41. ai_edge_torch/generative/utilities/loader.py +2 -1
  42. ai_edge_torch/lowertools/translate_recipe.py +8 -3
  43. ai_edge_torch/odml_torch/experimental/__init__.py +14 -0
  44. ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py +20 -0
  45. ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +438 -0
  46. ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +728 -0
  47. ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +371 -0
  48. ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py +37 -0
  49. ai_edge_torch/odml_torch/export.py +24 -7
  50. ai_edge_torch/odml_torch/lowerings/_basic.py +155 -0
  51. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -5
  52. ai_edge_torch/version.py +1 -1
  53. {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/METADATA +15 -3
  54. {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/RECORD +57 -51
  55. {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/WHEEL +1 -1
  56. {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/licenses}/LICENSE +0 -0
  57. {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/top_level.txt +0 -0
@@ -41,6 +41,20 @@ class TestKVLayers(googletest.TestCase):
41
41
  )
42
42
  return config
43
43
 
44
+ def _assert_kv_cache_entry_equal(self, kv1, kv2):
45
+ self.assertIsInstance(kv1, kv_utils.KVCacheEntry)
46
+ self.assertIsInstance(kv2, kv_utils.KVCacheEntry)
47
+ self.assertEqual(kv1.kv_layout, kv2.kv_layout)
48
+ self.assertTrue(torch.equal(kv1.k_cache, kv2.k_cache))
49
+ self.assertTrue(torch.equal(kv1.v_cache, kv2.v_cache))
50
+
51
+ def _assert_kv_cache_equal(self, kv1, kv2):
52
+ self.assertIsInstance(kv1, kv_utils.KVCache)
53
+ self.assertIsInstance(kv2, kv_utils.KVCache)
54
+ self.assertEqual(len(kv1.caches), len(kv2.caches))
55
+ for kv1_entry, kv2_entry in zip(kv1.caches, kv2.caches):
56
+ self._assert_kv_cache_entry_equal(kv1_entry, kv2_entry)
57
+
44
58
  def test_cache_udpate(self):
45
59
  N = 1
46
60
  HEAD_DIM = 2
@@ -118,7 +132,7 @@ class TestKVLayers(googletest.TestCase):
118
132
  flat, treespec = pytree.tree_flatten(kv)
119
133
  self.assertLen(flat, NUM_LAYERS * 2)
120
134
  kv_unflat = pytree.tree_unflatten(flat, treespec)
121
- self.assertEqual(kv, kv_unflat)
135
+ self._assert_kv_cache_equal(kv, kv_unflat)
122
136
 
123
137
  def test_pytree_roundtrip_kv_cache_derived(self):
124
138
  NUM_LAYERS = 4
@@ -134,7 +148,7 @@ class TestKVLayers(googletest.TestCase):
134
148
  flat, treespec = pytree.tree_flatten(kv)
135
149
  self.assertLen(flat, NUM_LAYERS * 2)
136
150
  kv_unflat = pytree.tree_unflatten(flat, treespec)
137
- self.assertEqual(kv, kv_unflat)
151
+ self._assert_kv_cache_equal(kv, kv_unflat)
138
152
 
139
153
  def test_pytree_roundtrip_kv_entry(self):
140
154
  attn_config = cfg.AttentionConfig(
@@ -144,8 +158,7 @@ class TestKVLayers(googletest.TestCase):
144
158
  flat, treespec = pytree.tree_flatten(kv)
145
159
  self.assertLen(flat, 2)
146
160
  kv_unflat = pytree.tree_unflatten(flat, treespec)
147
- self.assertEqual(kv, kv_unflat)
148
- self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
161
+ self._assert_kv_cache_entry_equal(kv, kv_unflat)
149
162
 
150
163
  def test_pytree_roundtrip_kv_entry_derived(self):
151
164
  attn_config = cfg.AttentionConfig(
@@ -157,8 +170,7 @@ class TestKVLayers(googletest.TestCase):
157
170
  flat, treespec = pytree.tree_flatten(kv)
158
171
  self.assertLen(flat, 2)
159
172
  kv_unflat = pytree.tree_unflatten(flat, treespec)
160
- self.assertEqual(kv, kv_unflat)
161
- self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
173
+ self._assert_kv_cache_entry_equal(kv, kv_unflat)
162
174
 
163
175
 
164
176
  if __name__ == "__main__":
@@ -79,18 +79,18 @@ class TestVerifyRecipes(parameterized.TestCase):
79
79
  Dtype.INT4,
80
80
  Mode.DYNAMIC_RANGE,
81
81
  Algorithm.MIN_MAX,
82
- Granularity.BLOCKWISE,
83
- 32,
82
+ Granularity.BLOCKWISE_32,
83
+ ),
84
+ (
85
+ Dtype.FP32,
86
+ Dtype.INT4,
87
+ Mode.DYNAMIC_RANGE,
88
+ Algorithm.MIN_MAX,
89
+ Granularity.BLOCKWISE_128,
84
90
  ),
85
91
  ])
86
92
  def test_verify_valid_recipes(
87
- self,
88
- activation,
89
- weight,
90
- mode,
91
- algo,
92
- granularity,
93
- block_size=None,
93
+ self, activation, weight, mode, algo, granularity
94
94
  ):
95
95
  quant_recipe.LayerQuantRecipe(
96
96
  activation, weight, mode, algo, granularity
@@ -108,21 +108,21 @@ class TestQuantizeConvert(parameterized.TestCase):
108
108
  def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
109
109
  return quant_config.QuantConfig(
110
110
  generative_recipe=quant_recipe.GenerativeQuantRecipe(
111
- attention=quant_recipe_utils.create_layer_quant_int8_dynamic(),
111
+ attention=quant_recipe_utils.create_layer_quant_dynamic(),
112
112
  )
113
113
  )
114
114
 
115
115
  def _feedforward_int8_dynamic_recipe() -> quant_config.QuantConfig:
116
116
  return quant_config.QuantConfig(
117
117
  generative_recipe=quant_recipe.GenerativeQuantRecipe(
118
- feedforward=quant_recipe_utils.create_layer_quant_int8_dynamic(),
118
+ feedforward=quant_recipe_utils.create_layer_quant_dynamic(),
119
119
  )
120
120
  )
121
121
 
122
122
  @parameterized.parameters([
123
123
  (quant_recipes.full_fp16_recipe()),
124
- (quant_recipes.full_int8_dynamic_recipe()),
125
- (quant_recipes.full_int8_weight_only_recipe()),
124
+ (quant_recipes.full_dynamic_recipe()),
125
+ (quant_recipes.full_weight_only_recipe()),
126
126
  (_attention_int8_dynamic_recipe()),
127
127
  (_feedforward_int8_dynamic_recipe()),
128
128
  ])
@@ -148,7 +148,7 @@ class TestQuantizeConvert(parameterized.TestCase):
148
148
  idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
149
149
  input_pos = torch.arange(0, 100, dtype=torch.int)
150
150
 
151
- quant_config = quant_recipes.full_int8_dynamic_recipe()
151
+ quant_config = quant_recipes.full_dynamic_recipe()
152
152
  quantized_model = ai_edge_torch.convert(
153
153
  pytorch_model, (idx, input_pos), quant_config=quant_config
154
154
  )
@@ -164,7 +164,9 @@ class TestQuantizeConvert(parameterized.TestCase):
164
164
  pytorch_model = toy_model.ToySingleLayerModel(config)
165
165
  idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
166
166
  input_pos = torch.arange(0, 100, dtype=torch.int)
167
- quant_config = quant_recipes.all_supported_int4_dynamic_block_recipe(32)
167
+ quant_config = quant_recipes.full_dynamic_recipe(
168
+ weight_dtype=Dtype.INT4, granularity=Granularity.BLOCKWISE_32
169
+ )
168
170
  quantized_model = ai_edge_torch.convert(
169
171
  pytorch_model, (idx, input_pos), quant_config=quant_config
170
172
  )
@@ -175,17 +177,6 @@ class TestQuantizeConvert(parameterized.TestCase):
175
177
  "Quantized model isn't smaller than F32 model.",
176
178
  )
177
179
 
178
- def test_unsupported_block_size(self):
179
- config = toy_model.get_model_config()
180
- pytorch_model = toy_model.ToySingleLayerModel(config)
181
- idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
182
- input_pos = torch.arange(0, 100, dtype=torch.int)
183
- self.assertRaises(
184
- ValueError,
185
- quant_recipes.all_supported_int4_dynamic_block_recipe,
186
- 36,
187
- )
188
-
189
180
  def test_quantize_convert_compare_toy(self):
190
181
  self.skipTest("b/338288901")
191
182
  config = toy_model_with_kv_cache.get_model_config()
@@ -19,15 +19,17 @@ import enum
19
19
  import os
20
20
  import pathlib
21
21
  import tempfile
22
- from typing import Any, Optional, Union
22
+ from typing import Callable, Dict, Optional, Union
23
23
  from absl import flags
24
24
  from ai_edge_torch._convert import converter as converter_utils
25
25
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
26
26
  from ai_edge_torch.generative.layers import lora as lora_utils
27
27
  import ai_edge_torch.generative.layers.model_config as cfg
28
+ from ai_edge_torch.generative.quantize import quant_attrs
28
29
  from ai_edge_torch.generative.quantize import quant_recipes
29
30
  from ai_edge_torch.generative.utilities import export_config as export_config_lib
30
31
  from ai_edge_torch.generative.utilities import litertlm_builder
32
+ from ai_edge_torch.generative.utilities import loader
31
33
  from ai_edge_torch.quantize import quant_config as qcfg
32
34
  import torch
33
35
 
@@ -94,6 +96,11 @@ def define_conversion_flags(
94
96
  (8, 64, 128, 256, 512, 1024),
95
97
  'List of the maximum sizes of prefill input tensors.',
96
98
  )
99
+ flags.DEFINE_integer(
100
+ 'decode_batch_size',
101
+ 1,
102
+ 'The batch size for the decode signature.',
103
+ )
97
104
  flags.DEFINE_integer(
98
105
  'kv_cache_max_len',
99
106
  1280,
@@ -102,14 +109,14 @@ def define_conversion_flags(
102
109
  flags.DEFINE_string(
103
110
  'quantize',
104
111
  'dynamic_int8',
105
- 'How the model should be quantized. Set to "none" to disable'
106
- ' quantization. See `QuantizationName` for supported quantization types.',
112
+ 'How the model should be quantized. Set to "none" to disable '
113
+ 'quantization. See `QuantizationName` for supported quantization types.',
107
114
  )
108
115
  flags.DEFINE_multi_integer(
109
116
  'lora_ranks',
110
117
  None,
111
- 'If set, the model will be converted with the provided list of LoRA'
112
- ' ranks.',
118
+ 'If set, the model will be converted with the provided list of LoRA '
119
+ 'ranks.',
113
120
  )
114
121
  flags.DEFINE_bool(
115
122
  'mask_as_input',
@@ -125,15 +132,61 @@ def define_conversion_flags(
125
132
  flags.DEFINE_bool(
126
133
  'custom_checkpoint_loader',
127
134
  False,
128
- 'If true, the conversion script will use a custom checkpoint loader which'
129
- ' will read a checkpoint from a remote source.',
135
+ 'If true, the conversion script will use a custom checkpoint loader '
136
+ 'which will read a checkpoint from a remote source.',
137
+ )
138
+ flags.DEFINE_bool(
139
+ 'gpu_dynamic_shapes',
140
+ False,
141
+ 'It is to support dynamic shapes on GPU effectively. If true, the graph '
142
+ 'sets the actual kv_cache size and prefill lengths when the graph is '
143
+ 'initialized for inference based on the flags, `kv_cache_max_len` and '
144
+ '`prefill_seq_lens` as the maximum of kv_cache size and prefill lengths '
145
+ 'in the graph.',
146
+ )
147
+ flags.DEFINE_bool(
148
+ 'export_gpu_dynamic_shape_verifications',
149
+ False,
150
+ 'If true, the conversion script will export signatures used only for '
151
+ 'verification of GPU dynamic shapes.',
130
152
  )
131
153
  return flags
132
154
 
133
155
 
156
+ # Context length for verifying GPU dynamic shapes.
157
+ _CONTEXT_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 1280
158
+ # Long prefill length for verifying GPU dynamic shapes.
159
+ _LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 1024
160
+ # Short prefill length for verifying GPU dynamic shapes.
161
+ _SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 64
162
+
163
+
164
+ def is_magic_number_(num: int) -> bool:
165
+ """Returns true if the number is a magic number, i.e. prime number > 10."""
166
+ if num < 10:
167
+ return False
168
+ if num % 2 == 0:
169
+ return False
170
+ for i in range(3, int(num / 2), 2):
171
+ if num % i == 0:
172
+ return False
173
+ return True
174
+
175
+
176
+ def get_magic_number_for(org_number: int) -> int:
177
+ """Returns the magic number for the given original number."""
178
+ while not is_magic_number_(org_number):
179
+ org_number += 1
180
+ return org_number
181
+
182
+
134
183
  def get_mask_cache_size_from_flags() -> int:
135
184
  """Returns the mask cache size according to the flags."""
136
- return 0 if flags.FLAGS.mask_as_input else flags.FLAGS.kv_cache_max_len
185
+ if flags.FLAGS.mask_as_input:
186
+ return 0
187
+ if flags.FLAGS.gpu_dynamic_shapes:
188
+ return get_magic_number_for(flags.FLAGS.kv_cache_max_len)
189
+ return flags.FLAGS.kv_cache_max_len
137
190
 
138
191
 
139
192
  def get_quant_recipe_from_flag(
@@ -155,18 +208,22 @@ def get_quant_recipe_from_flag(
155
208
  case QuantizationName.NONE:
156
209
  return None
157
210
  case QuantizationName.DYNAMIC_INT8:
158
- return quant_recipes.full_int8_dynamic_recipe(mcfg=model_config)
211
+ return quant_recipes.full_dynamic_recipe(mcfg=model_config)
159
212
  case QuantizationName.WEIGHT_ONLY_INT8:
160
- return quant_recipes.full_int8_weight_only_recipe(mcfg=model_config)
213
+ return quant_recipes.full_weight_only_recipe(mcfg=model_config)
161
214
  case QuantizationName.FP16:
162
215
  return quant_recipes.full_fp16_recipe()
163
216
  case QuantizationName.DYNAMIC_INT4_BLOCK32:
164
- return quant_recipes.all_supported_int4_dynamic_block_recipe(
165
- 32, mcfg=model_config
217
+ return quant_recipes.full_dynamic_recipe(
218
+ mcfg=model_config,
219
+ weight_dtype=quant_attrs.Dtype.INT4,
220
+ granularity=quant_attrs.Granularity.BLOCKWISE_32,
166
221
  )
167
222
  case QuantizationName.DYNAMIC_INT4_BLOCK128:
168
- return quant_recipes.all_supported_int4_dynamic_block_recipe(
169
- 128, mcfg=model_config
223
+ return quant_recipes.full_dynamic_recipe(
224
+ mcfg=model_config,
225
+ weight_dtype=quant_attrs.Dtype.INT4,
226
+ granularity=quant_attrs.Granularity.BLOCKWISE_128,
170
227
  )
171
228
  case _:
172
229
  raise ValueError(f'Unsupported quantization flag: {quantize}')
@@ -225,6 +282,10 @@ def convert_to_tflite(
225
282
  config: cfg.ModelConfig = None,
226
283
  lora_ranks: Optional[list[int]] = None,
227
284
  export_config: ExportConfig = None,
285
+ extra_model: torch.nn.Module = None,
286
+ extra_prefill_seq_lens: list[int] = None,
287
+ extra_kv_cache_max_len: int = 0,
288
+ extra_signature_prefix: str = '',
228
289
  ):
229
290
  """Converts a nn.Module model to multi-signature tflite model.
230
291
 
@@ -277,6 +338,15 @@ def convert_to_tflite(
277
338
  no LoRA signatures will be added.
278
339
  export_config (ExportConfig, optional): The export configuration. If None,
279
340
  it uses the default export configuration.
341
+ extra_model (torch.nn.Module, optional): PyTorch model to export in
342
+ addition to the pytorch_model. This model can have different
343
+ prefill_seq_lens and kv_cache_max_len.
344
+ extra_prefill_seq_lens (list[int], optional): The prefill sequence
345
+ lengths for extra_model. Meaningful only when extra_model is not None.
346
+ extra_kv_cache_max_len (int, optional): The maximum size of KV cache
347
+ buffer for extra_model. Meaningful only when extra_model is not None.
348
+ extra_signature_prefix (str, optional): The prefix of the extra model
349
+ signatures. Meaningful only when extra_model is not None.
280
350
  """
281
351
  # pylint: disable=protected-access
282
352
  torch._dynamo.config.cache_size_limit = 64
@@ -315,32 +385,51 @@ def convert_to_tflite(
315
385
  )
316
386
  output_file = os.path.join(output_path, output_filename)
317
387
 
318
- _export_helper(
388
+ converter = converter_utils.Converter()
389
+ _add_signatures(
390
+ converter,
319
391
  pytorch_model,
320
- output_file,
321
392
  prefill_seq_lens,
322
393
  kv_cache_max_len,
323
394
  pixel_values_size,
324
395
  pixel_seq_len,
325
- quantize,
326
396
  config,
327
397
  loras,
328
398
  export_config,
329
399
  )
400
+
401
+ if extra_model is not None and extra_prefill_seq_lens:
402
+ _add_signatures(
403
+ converter,
404
+ extra_model,
405
+ extra_prefill_seq_lens,
406
+ extra_kv_cache_max_len,
407
+ pixel_values_size,
408
+ pixel_seq_len,
409
+ config,
410
+ loras,
411
+ export_config,
412
+ signature_prefix=extra_signature_prefix,
413
+ )
414
+
415
+ edge_model = converter.convert(
416
+ quant_config=get_quant_recipe_from_flag(quantize, config),
417
+ )
418
+ edge_model.export(output_file)
330
419
  return output_file
331
420
 
332
421
 
333
- def _export_helper(
422
+ def _add_signatures(
423
+ converter: converter_utils.Converter,
334
424
  pytorch_model: torch.nn.Module,
335
- output_file: str,
336
425
  prefill_seq_lens: list[int],
337
426
  kv_cache_max_len: int,
338
427
  pixel_values_size: torch.Size,
339
428
  pixel_seq_len: int,
340
- quantize: str,
341
429
  config: cfg.ModelConfig,
342
430
  loras: list[None | lora_utils.LoRA],
343
431
  export_config: ExportConfig,
432
+ signature_prefix: str = '',
344
433
  ):
345
434
  """Helper function to export a model to tflite."""
346
435
  prefill_tokens_list = []
@@ -385,17 +474,14 @@ def _export_helper(
385
474
  kv_layout=export_config.kvcache_layout,
386
475
  )
387
476
 
388
- quant_config = get_quant_recipe_from_flag(quantize, config)
389
-
390
477
  # For export, we create a module that captures any non-exportable,
391
478
  # arugments, e.g. the generation config object.
392
479
  mod = ExportableModule(pytorch_model, export_config=export_config).eval()
393
480
 
394
- converter = converter_utils.Converter()
395
481
  for lora in loras:
396
482
  for i in range(len(prefill_seq_lens)):
397
483
  prefill_seq_len = prefill_seq_lens[i]
398
- prefill_signature_name = f'prefill_{prefill_seq_len}'
484
+ prefill_signature_name = f'{signature_prefix}prefill_{prefill_seq_len}'
399
485
 
400
486
  sample_kwargs = {
401
487
  'tokens': prefill_tokens_list[i],
@@ -450,16 +536,85 @@ def _export_helper(
450
536
  if lora is not None:
451
537
  sample_kwargs['lora'] = lora
452
538
 
539
+ decode_signature_name = f'{signature_prefix}decode'
540
+ if lora is not None:
541
+ decode_signature_name += f'_lora_r{lora.get_rank()}'
453
542
  converter.add_signature(
454
- 'decode' if lora is None else f'decode_lora_r{lora.get_rank()}',
543
+ decode_signature_name,
455
544
  mod,
456
545
  sample_kwargs=sample_kwargs,
457
546
  )
458
547
 
459
- edge_model = converter.convert(
460
- quant_config=quant_config,
548
+
549
+ def build_and_convert_to_tflite_from_flags(
550
+ model_builder: Callable[
551
+ [str, Callable[[str], Dict[str, torch.Tensor]], int], torch.nn.Module
552
+ ],
553
+ checkpoint_path: str = None,
554
+ output_name_prefix: str = None,
555
+ ):
556
+ """Builds a nn.Module model and converts it according to the flags."""
557
+ if checkpoint_path is None:
558
+ checkpoint_path = flags.FLAGS.checkpoint_path
559
+ if output_name_prefix is None:
560
+ output_name_prefix = flags.FLAGS.output_name_prefix
561
+
562
+ pytorch_model = model_builder(
563
+ checkpoint_path,
564
+ loader.maybe_get_custom_loader(
565
+ checkpoint_path, flags.FLAGS.custom_checkpoint_loader
566
+ ),
567
+ get_mask_cache_size_from_flags(),
568
+ )
569
+
570
+ # Extra model for GPU dynamic shape verification if needed.
571
+ extra_model = None
572
+ extra_prefill_seq_lens = None
573
+ extra_kv_cache_max_len = 0
574
+ if flags.FLAGS.gpu_dynamic_shapes:
575
+ prefill_seq_lens = [
576
+ get_magic_number_for(l) for l in flags.FLAGS.prefill_seq_lens
577
+ ]
578
+ kv_cache_max_len = get_magic_number_for(flags.FLAGS.kv_cache_max_len)
579
+
580
+ if flags.FLAGS.export_gpu_dynamic_shape_verifications:
581
+ extra_kv_cache_max_len = _CONTEXT_LENGTH_TO_VERIFY_MAGIC_NUMBERS
582
+ if extra_kv_cache_max_len > flags.FLAGS.kv_cache_max_len:
583
+ extra_kv_cache_max_len = flags.FLAGS.kv_cache_max_len
584
+ extra_model = model_builder(
585
+ checkpoint_path,
586
+ loader.maybe_get_custom_loader(
587
+ checkpoint_path, flags.FLAGS.custom_checkpoint_loader
588
+ ),
589
+ extra_kv_cache_max_len,
590
+ )
591
+ extra_prefill_seq_lens = []
592
+ if extra_kv_cache_max_len > _SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS:
593
+ extra_prefill_seq_lens.append(
594
+ _SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS
595
+ )
596
+ if extra_kv_cache_max_len > _LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS:
597
+ extra_prefill_seq_lens.append(
598
+ _LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS
599
+ )
600
+ else:
601
+ prefill_seq_lens = flags.FLAGS.prefill_seq_lens
602
+ kv_cache_max_len = flags.FLAGS.kv_cache_max_len
603
+
604
+ convert_to_tflite(
605
+ pytorch_model,
606
+ output_path=flags.FLAGS.output_path,
607
+ output_name_prefix=output_name_prefix,
608
+ prefill_seq_len=prefill_seq_lens,
609
+ kv_cache_max_len=kv_cache_max_len,
610
+ quantize=flags.FLAGS.quantize,
611
+ lora_ranks=flags.FLAGS.lora_ranks,
612
+ export_config=export_config_lib.get_from_flags(),
613
+ extra_model=extra_model,
614
+ extra_prefill_seq_lens=extra_prefill_seq_lens,
615
+ extra_kv_cache_max_len=extra_kv_cache_max_len,
616
+ extra_signature_prefix='test_' if extra_model is not None else '',
461
617
  )
462
- edge_model.export(output_file)
463
618
 
464
619
 
465
620
  def convert_to_litert(
@@ -56,5 +56,7 @@ def get_from_flags() -> ExportConfig:
56
56
  export_config.kvcache_layout = kv_utils.KV_LAYOUT_TRANSPOSED
57
57
  if flags.FLAGS.mask_as_input:
58
58
  export_config.mask_as_input = flags.FLAGS.mask_as_input
59
+ if flags.FLAGS.decode_batch_size:
60
+ export_config.decode_batch_size = flags.FLAGS.decode_batch_size
59
61
 
60
62
  return export_config
@@ -18,16 +18,19 @@
18
18
 
19
19
  import os
20
20
  import pathlib
21
+ from google.protobuf import text_format
21
22
 
22
23
  try:
23
24
  # pylint: disable=g-import-not-at-top
24
25
  from ai_edge_litert.internal import llm_metadata_pb2
25
26
  from ai_edge_litert.internal import litertlm_builder
27
+ from ai_edge_litert.internal import llm_model_type_pb2
26
28
  # pylint: enable=g-import-not-at-top
27
29
 
28
30
  _litertlm_builder_available = True
29
31
  except ImportError:
30
32
  llm_metadata_pb2 = None
33
+ llm_model_type_pb2 = None
31
34
  litertlm_builder = None
32
35
  _litertlm_builder_available = False
33
36
 
@@ -41,16 +44,19 @@ def build_litertlm(
41
44
  workdir: str,
42
45
  output_path: str,
43
46
  context_length: int,
44
- model_prompt_prefix: str | None,
45
- model_prompt_suffix: str | None,
46
- user_prompt_prefix: str | None,
47
- user_prompt_suffix: str | None,
48
- tokenizer_model_path: str | None,
49
- hf_tokenizer_model_path: str | None,
47
+ model_prompt_prefix: str | None = None,
48
+ model_prompt_suffix: str | None = None,
49
+ user_prompt_prefix: str | None = None,
50
+ user_prompt_suffix: str | None = None,
51
+ tokenizer_model_path: str | None = None,
52
+ hf_tokenizer_model_path: str | None = None,
50
53
  start_token: str | None = None,
51
54
  start_token_id: int | None = None,
52
55
  stop_tokens: str | list[str] | None = None,
53
56
  stop_token_ids: list[int] | None = None,
57
+ llm_model_type: str = 'generic',
58
+ jinja_prompt_template: str | None = None,
59
+ base_llm_metadata_path: str | None = None,
54
60
  **kwargs,
55
61
  ):
56
62
  """Builds a LiteRT-LM file from a TFlite model."""
@@ -58,10 +64,22 @@ def build_litertlm(
58
64
 
59
65
  if not is_litertlm_builder_available():
60
66
  raise ValueError('LiteRT-LM builder is not available.')
61
- assert llm_metadata_pb2 is not None
62
67
  assert litertlm_builder is not None
68
+ assert llm_metadata_pb2 is not None
69
+ assert llm_model_type_pb2 is not None
63
70
 
64
71
  llm_metadata = llm_metadata_pb2.LlmMetadata()
72
+ if base_llm_metadata_path:
73
+ if base_llm_metadata_path.endswith('.pb'):
74
+ with open(base_llm_metadata_path, 'rb') as f:
75
+ llm_metadata.ParseFromString(f.read())
76
+ elif base_llm_metadata_path.endswith('.textproto'):
77
+ with open(base_llm_metadata_path, 'r') as f:
78
+ text_format.Parse(f.read(), llm_metadata, allow_unknown_field=True)
79
+ else:
80
+ raise ValueError(
81
+ 'Base LLM metadata path must be a binary or text proto file.'
82
+ )
65
83
 
66
84
  if start_token_id is not None:
67
85
  llm_metadata.start_token.token_ids.ids.append(start_token_id)
@@ -96,7 +114,42 @@ def build_litertlm(
96
114
 
97
115
  llm_metadata.max_num_tokens = context_length
98
116
 
99
- llm_metadata_path = os.path.join(workdir, 'llm_metadata.pb')
117
+ mdl_type = llm_metadata.llm_model_type.WhichOneof('model_type')
118
+ if not mdl_type or mdl_type == 'generic_model':
119
+ match llm_model_type:
120
+ case litertlm_builder.LlmModelType.GENERIC:
121
+ llm_metadata.llm_model_type.CopyFrom(
122
+ llm_model_type_pb2.LlmModelType(
123
+ generic_model=llm_model_type_pb2.GenericModel()
124
+ )
125
+ )
126
+ case litertlm_builder.LlmModelType.GEMMA3N:
127
+ llm_metadata.llm_model_type.CopyFrom(
128
+ llm_model_type_pb2.LlmModelType(
129
+ gemma3n=llm_model_type_pb2.Gemma3N()
130
+ )
131
+ )
132
+ case litertlm_builder.LlmModelType.GEMMA3:
133
+ llm_metadata.llm_model_type.CopyFrom(
134
+ llm_model_type_pb2.LlmModelType(gemma3=llm_model_type_pb2.Gemma3())
135
+ )
136
+ case litertlm_builder.LlmModelType.QWEN3:
137
+ llm_metadata.llm_model_type.CopyFrom(
138
+ llm_model_type_pb2.LlmModelType(qwen3=llm_model_type_pb2.Qwen3())
139
+ )
140
+ case litertlm_builder.LlmModelType.QWEN2P5:
141
+ llm_metadata.llm_model_type.CopyFrom(
142
+ llm_model_type_pb2.LlmModelType(
143
+ qwen2p5=llm_model_type_pb2.Qwen2p5()
144
+ )
145
+ )
146
+ case _:
147
+ raise ValueError(f'Unsupported LLM model type: {llm_model_type}')
148
+
149
+ if jinja_prompt_template is not None:
150
+ llm_metadata.jinja_prompt_template = jinja_prompt_template
151
+
152
+ llm_metadata_path = os.path.join(workdir, 'llm_metadata_final.pb')
100
153
  with open(llm_metadata_path, 'wb') as f:
101
154
  f.write(llm_metadata.SerializeToString())
102
155
 
@@ -135,7 +135,8 @@ def load_pytorch_statedict(full_path: str):
135
135
 
136
136
  tensors = {}
137
137
  for file in files:
138
- this_file_tensors = torch.load(file)
138
+ map_location = "cpu" if not torch.cuda.is_available() else None
139
+ this_file_tensors = torch.load(file, map_location=map_location)
139
140
  for k in this_file_tensors:
140
141
  assert k not in tensors
141
142
  tensors.update(this_file_tensors)
@@ -80,8 +80,14 @@ def _get_granularity(
80
80
  return _QuantGranularity.CHANNELWISE
81
81
  if granularity == quant_attrs.Granularity.NONE:
82
82
  return _QuantGranularity.TENSORWISE
83
- if granularity == quant_attrs.Granularity.BLOCKWISE:
84
- return _QuantGranularity.BLOCKWISE
83
+ if granularity == quant_attrs.Granularity.BLOCKWISE_32:
84
+ return _QuantGranularity.BLOCKWISE_32
85
+ if granularity == quant_attrs.Granularity.BLOCKWISE_64:
86
+ return _QuantGranularity.BLOCKWISE_64
87
+ if granularity == quant_attrs.Granularity.BLOCKWISE_128:
88
+ return _QuantGranularity.BLOCKWISE_128
89
+ if granularity == quant_attrs.Granularity.BLOCKWISE_256:
90
+ return _QuantGranularity.BLOCKWISE_256
85
91
  raise ValueError('Unimplemented granularity')
86
92
 
87
93
 
@@ -108,7 +114,6 @@ def _set_quant_config(
108
114
  symmetric=True,
109
115
  granularity=_get_granularity(layer_recipe.granularity),
110
116
  dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
111
- block_size=layer_recipe.block_size,
112
117
  ),
113
118
  compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
114
119
  explicit_dequantize=_get_explicit_dequant_from_mode(
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,20 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Torch-TFL ops definitions, decompositions, and lowerings."""
16
+ from ai_edge_torch.odml_torch.experimental.torch_tfl import _decomps
17
+ from ai_edge_torch.odml_torch.experimental.torch_tfl import _lowerings
18
+ from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops
19
+
20
+ decomps = _decomps.decomps