ai-edge-torch-nightly 0.7.0.dev20250929__py3-none-any.whl → 0.8.0.dev20251206__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release. This version of ai-edge-torch-nightly has been flagged as potentially problematic; see the registry page for details.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/fx_infra/_safe_run_decompositions.py +36 -1
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +3 -27
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +1 -20
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +1 -30
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +1 -30
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +1 -1
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +1 -20
- ai_edge_torch/generative/layers/attention.py +25 -2
- ai_edge_torch/generative/layers/attention_test.py +13 -1
- ai_edge_torch/generative/layers/attention_utils.py +62 -1
- ai_edge_torch/generative/layers/attention_utils_test.py +20 -0
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/model_config.py +5 -0
- ai_edge_torch/generative/layers/normalization.py +8 -2
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +35 -5
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +8 -3
- ai_edge_torch/generative/quantize/example.py +1 -1
- ai_edge_torch/generative/quantize/quant_attrs.py +8 -1
- ai_edge_torch/generative/quantize/quant_recipe.py +0 -13
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -19
- ai_edge_torch/generative/quantize/quant_recipes.py +16 -21
- ai_edge_torch/generative/quantize/supported_schemes.py +4 -1
- ai_edge_torch/generative/test/test_kv_cache.py +18 -6
- ai_edge_torch/generative/test/test_quantize.py +17 -26
- ai_edge_torch/generative/utilities/converter.py +183 -28
- ai_edge_torch/generative/utilities/export_config.py +2 -0
- ai_edge_torch/generative/utilities/litertlm_builder.py +61 -8
- ai_edge_torch/generative/utilities/loader.py +2 -1
- ai_edge_torch/lowertools/translate_recipe.py +8 -3
- ai_edge_torch/odml_torch/experimental/__init__.py +14 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py +20 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +438 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +728 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +371 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py +37 -0
- ai_edge_torch/odml_torch/export.py +24 -7
- ai_edge_torch/odml_torch/lowerings/_basic.py +155 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -5
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/METADATA +15 -3
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/RECORD +57 -51
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/WHEEL +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.7.0.dev20250929.dist-info → ai_edge_torch_nightly-0.8.0.dev20251206.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/test/test_kv_cache.py

@@ -41,6 +41,20 @@ class TestKVLayers(googletest.TestCase):
     )
     return config
 
+  def _assert_kv_cache_entry_equal(self, kv1, kv2):
+    self.assertIsInstance(kv1, kv_utils.KVCacheEntry)
+    self.assertIsInstance(kv2, kv_utils.KVCacheEntry)
+    self.assertEqual(kv1.kv_layout, kv2.kv_layout)
+    self.assertTrue(torch.equal(kv1.k_cache, kv2.k_cache))
+    self.assertTrue(torch.equal(kv1.v_cache, kv2.v_cache))
+
+  def _assert_kv_cache_equal(self, kv1, kv2):
+    self.assertIsInstance(kv1, kv_utils.KVCache)
+    self.assertIsInstance(kv2, kv_utils.KVCache)
+    self.assertEqual(len(kv1.caches), len(kv2.caches))
+    for kv1_entry, kv2_entry in zip(kv1.caches, kv2.caches):
+      self._assert_kv_cache_entry_equal(kv1_entry, kv2_entry)
+
   def test_cache_udpate(self):
     N = 1
     HEAD_DIM = 2
@@ -118,7 +132,7 @@ class TestKVLayers(googletest.TestCase):
     flat, treespec = pytree.tree_flatten(kv)
     self.assertLen(flat, NUM_LAYERS * 2)
     kv_unflat = pytree.tree_unflatten(flat, treespec)
-    self.
+    self._assert_kv_cache_equal(kv, kv_unflat)
 
   def test_pytree_roundtrip_kv_cache_derived(self):
     NUM_LAYERS = 4
@@ -134,7 +148,7 @@ class TestKVLayers(googletest.TestCase):
     flat, treespec = pytree.tree_flatten(kv)
     self.assertLen(flat, NUM_LAYERS * 2)
     kv_unflat = pytree.tree_unflatten(flat, treespec)
-    self.
+    self._assert_kv_cache_equal(kv, kv_unflat)
 
   def test_pytree_roundtrip_kv_entry(self):
     attn_config = cfg.AttentionConfig(
@@ -144,8 +158,7 @@ class TestKVLayers(googletest.TestCase):
     flat, treespec = pytree.tree_flatten(kv)
     self.assertLen(flat, 2)
     kv_unflat = pytree.tree_unflatten(flat, treespec)
-    self.
-    self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
+    self._assert_kv_cache_entry_equal(kv, kv_unflat)
 
   def test_pytree_roundtrip_kv_entry_derived(self):
     attn_config = cfg.AttentionConfig(
@@ -157,8 +170,7 @@ class TestKVLayers(googletest.TestCase):
     flat, treespec = pytree.tree_flatten(kv)
     self.assertLen(flat, 2)
     kv_unflat = pytree.tree_unflatten(flat, treespec)
-    self.
-    self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
+    self._assert_kv_cache_entry_equal(kv, kv_unflat)
 
 
 if __name__ == "__main__":
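The new helpers assert tensor-level equality after a pytree flatten/unflatten round trip instead of only checking types. A minimal standalone sketch of the pattern under test (plain torch; the dict merely stands in for a KVCacheEntry):

    import torch
    import torch.utils._pytree as pytree

    kv = {"k_cache": torch.zeros(1, 4, 2), "v_cache": torch.zeros(1, 4, 2)}
    flat, treespec = pytree.tree_flatten(kv)           # leaves plus structure
    kv_unflat = pytree.tree_unflatten(flat, treespec)  # rebuild the container
    assert all(
        torch.equal(a, b)
        for a, b in zip(flat, pytree.tree_flatten(kv_unflat)[0])
    )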
ai_edge_torch/generative/test/test_quantize.py

@@ -79,18 +79,18 @@ class TestVerifyRecipes(parameterized.TestCase):
           Dtype.INT4,
           Mode.DYNAMIC_RANGE,
           Algorithm.MIN_MAX,
-          Granularity.
-
+          Granularity.BLOCKWISE_32,
+      ),
+      (
+          Dtype.FP32,
+          Dtype.INT4,
+          Mode.DYNAMIC_RANGE,
+          Algorithm.MIN_MAX,
+          Granularity.BLOCKWISE_128,
       ),
   ])
   def test_verify_valid_recipes(
-      self,
-      activation,
-      weight,
-      mode,
-      algo,
-      granularity,
-      block_size=None,
+      self, activation, weight, mode, algo, granularity
   ):
     quant_recipe.LayerQuantRecipe(
         activation, weight, mode, algo, granularity
@@ -108,21 +108,21 @@ class TestQuantizeConvert(parameterized.TestCase):
  def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
    return quant_config.QuantConfig(
        generative_recipe=quant_recipe.GenerativeQuantRecipe(
-            attention=quant_recipe_utils.
+            attention=quant_recipe_utils.create_layer_quant_dynamic(),
        )
    )
 
  def _feedforward_int8_dynamic_recipe() -> quant_config.QuantConfig:
    return quant_config.QuantConfig(
        generative_recipe=quant_recipe.GenerativeQuantRecipe(
-            feedforward=quant_recipe_utils.
+            feedforward=quant_recipe_utils.create_layer_quant_dynamic(),
        )
    )
 
  @parameterized.parameters([
      (quant_recipes.full_fp16_recipe()),
-      (quant_recipes.
-      (quant_recipes.
+      (quant_recipes.full_dynamic_recipe()),
+      (quant_recipes.full_weight_only_recipe()),
      (_attention_int8_dynamic_recipe()),
      (_feedforward_int8_dynamic_recipe()),
  ])
@@ -148,7 +148,7 @@ class TestQuantizeConvert(parameterized.TestCase):
    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
    input_pos = torch.arange(0, 100, dtype=torch.int)
 
-    quant_config = quant_recipes.
+    quant_config = quant_recipes.full_dynamic_recipe()
    quantized_model = ai_edge_torch.convert(
        pytorch_model, (idx, input_pos), quant_config=quant_config
    )
@@ -164,7 +164,9 @@ class TestQuantizeConvert(parameterized.TestCase):
    pytorch_model = toy_model.ToySingleLayerModel(config)
    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
    input_pos = torch.arange(0, 100, dtype=torch.int)
-    quant_config = quant_recipes.
+    quant_config = quant_recipes.full_dynamic_recipe(
+        weight_dtype=Dtype.INT4, granularity=Granularity.BLOCKWISE_32
+    )
    quantized_model = ai_edge_torch.convert(
        pytorch_model, (idx, input_pos), quant_config=quant_config
    )
@@ -175,17 +177,6 @@ class TestQuantizeConvert(parameterized.TestCase):
        "Quantized model isn't smaller than F32 model.",
    )
 
-  def test_unsupported_block_size(self):
-    config = toy_model.get_model_config()
-    pytorch_model = toy_model.ToySingleLayerModel(config)
-    idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
-    input_pos = torch.arange(0, 100, dtype=torch.int)
-    self.assertRaises(
-        ValueError,
-        quant_recipes.all_supported_int4_dynamic_block_recipe,
-        36,
-    )
-
  def test_quantize_convert_compare_toy(self):
    self.skipTest("b/338288901")
    config = toy_model_with_kv_cache.get_model_config()
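The removed test exercised an arbitrary block size (36), which the Granularity enum now makes unrepresentable, so the failure mode no longer exists. A hedged sketch of the consolidated recipe entry points exercised above (call names as they appear in this diff; the surrounding conversion code is illustrative):

    from ai_edge_torch.generative.quantize import quant_recipes
    from ai_edge_torch.generative.quantize.quant_attrs import Dtype, Granularity

    # Default int8 dynamic-range quantization.
    quant_config = quant_recipes.full_dynamic_recipe()

    # Int4 weights with block size 32, expressed as a granularity enum
    # (BLOCKWISE_32/_64/_128/_256) instead of a free-form block_size argument.
    quant_config_int4 = quant_recipes.full_dynamic_recipe(
        weight_dtype=Dtype.INT4, granularity=Granularity.BLOCKWISE_32
    )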
ai_edge_torch/generative/utilities/converter.py

@@ -19,15 +19,17 @@ import enum
 import os
 import pathlib
 import tempfile
-from typing import
+from typing import Callable, Dict, Optional, Union
 from absl import flags
 from ai_edge_torch._convert import converter as converter_utils
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import quant_recipes
 from ai_edge_torch.generative.utilities import export_config as export_config_lib
 from ai_edge_torch.generative.utilities import litertlm_builder
+from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 
@@ -94,6 +96,11 @@ def define_conversion_flags(
      (8, 64, 128, 256, 512, 1024),
      'List of the maximum sizes of prefill input tensors.',
  )
+  flags.DEFINE_integer(
+      'decode_batch_size',
+      1,
+      'The batch size for the decode signature.',
+  )
  flags.DEFINE_integer(
      'kv_cache_max_len',
      1280,
@@ -102,14 +109,14 @@ def define_conversion_flags(
  flags.DEFINE_string(
      'quantize',
      'dynamic_int8',
-      'How the model should be quantized. Set to "none" to disable'
-      '
+      'How the model should be quantized. Set to "none" to disable '
+      'quantization. See `QuantizationName` for supported quantization types.',
  )
  flags.DEFINE_multi_integer(
      'lora_ranks',
      None,
-      'If set, the model will be converted with the provided list of LoRA'
-      '
+      'If set, the model will be converted with the provided list of LoRA '
+      'ranks.',
  )
  flags.DEFINE_bool(
      'mask_as_input',
@@ -125,15 +132,61 @@ def define_conversion_flags(
  flags.DEFINE_bool(
      'custom_checkpoint_loader',
      False,
-      'If true, the conversion script will use a custom checkpoint loader
-      ' will read a checkpoint from a remote source.',
+      'If true, the conversion script will use a custom checkpoint loader '
+      'which will read a checkpoint from a remote source.',
+  )
+  flags.DEFINE_bool(
+      'gpu_dynamic_shapes',
+      False,
+      'It is to support dynamic shapes on GPU effectively. If true, the graph '
+      'sets the actual kv_cache size and prefill lengths when the graph is '
+      'initialized for inference based on the flags, `kv_cache_max_len` and '
+      '`prefill_seq_lens` as the maximum of kv_cache size and prefill lengths '
+      'in the graph.',
+  )
+  flags.DEFINE_bool(
+      'export_gpu_dynamic_shape_verifications',
+      False,
+      'If true, the conversion script will export signatures used only for '
+      'verification of GPU dynamic shapes.',
  )
  return flags
 
 
+# Context length for verifying GPU dynamic shapes.
+_CONTEXT_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 1280
+# Long prefill length for verifying GPU dynamic shapes.
+_LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 1024
+# Short prefill length for verifying GPU dynamic shapes.
+_SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 64
+
+
+def is_magic_number_(num: int) -> bool:
+  """Returns true if the number is a magic number, i.e. prime number > 10."""
+  if num < 10:
+    return False
+  if num % 2 == 0:
+    return False
+  for i in range(3, int(num / 2), 2):
+    if num % i == 0:
+      return False
+  return True
+
+
+def get_magic_number_for(org_number: int) -> int:
+  """Returns the magic number for the given original number."""
+  while not is_magic_number_(org_number):
+    org_number += 1
+  return org_number
+
+
 def get_mask_cache_size_from_flags() -> int:
   """Returns the mask cache size according to the flags."""
-
+  if flags.FLAGS.mask_as_input:
+    return 0
+  if flags.FLAGS.gpu_dynamic_shapes:
+    return get_magic_number_for(flags.FLAGS.kv_cache_max_len)
+  return flags.FLAGS.kv_cache_max_len
 
 
 def get_quant_recipe_from_flag(
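Here a "magic number" is an odd prime greater than 10; presumably the distinctive prime constants are easy to locate and swap for real sizes when the graph is initialized, per the gpu_dynamic_shapes flag text above. A worked check of the rounding behavior (standalone restatement for illustration, not an exported API):

    # Mirrors is_magic_number_/get_magic_number_for from the hunk above.
    def is_magic(num: int) -> bool:
      """True for odd primes greater than 10 (trial division)."""
      if num < 10 or num % 2 == 0:
        return False
      return all(num % i for i in range(3, num // 2, 2))

    def next_magic(n: int) -> int:
      while not is_magic(n):
        n += 1
      return n

    assert next_magic(64) == 67      # 65 = 5 * 13, 66 even, 67 prime
    assert next_magic(1280) == 1283  # 1283 is prime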
@@ -155,18 +208,22 @@ def get_quant_recipe_from_flag(
    case QuantizationName.NONE:
      return None
    case QuantizationName.DYNAMIC_INT8:
-      return quant_recipes.
+      return quant_recipes.full_dynamic_recipe(mcfg=model_config)
    case QuantizationName.WEIGHT_ONLY_INT8:
-      return quant_recipes.
+      return quant_recipes.full_weight_only_recipe(mcfg=model_config)
    case QuantizationName.FP16:
      return quant_recipes.full_fp16_recipe()
    case QuantizationName.DYNAMIC_INT4_BLOCK32:
-      return quant_recipes.
-
+      return quant_recipes.full_dynamic_recipe(
+          mcfg=model_config,
+          weight_dtype=quant_attrs.Dtype.INT4,
+          granularity=quant_attrs.Granularity.BLOCKWISE_32,
      )
    case QuantizationName.DYNAMIC_INT4_BLOCK128:
-      return quant_recipes.
-
+      return quant_recipes.full_dynamic_recipe(
+          mcfg=model_config,
+          weight_dtype=quant_attrs.Dtype.INT4,
+          granularity=quant_attrs.Granularity.BLOCKWISE_128,
      )
    case _:
      raise ValueError(f'Unsupported quantization flag: {quantize}')
@@ -225,6 +282,10 @@ def convert_to_tflite(
    config: cfg.ModelConfig = None,
    lora_ranks: Optional[list[int]] = None,
    export_config: ExportConfig = None,
+    extra_model: torch.nn.Module = None,
+    extra_prefill_seq_lens: list[int] = None,
+    extra_kv_cache_max_len: int = 0,
+    extra_signature_prefix: str = '',
 ):
   """Converts a nn.Module model to multi-signature tflite model.
 
@@ -277,6 +338,15 @@ def convert_to_tflite(
      no LoRA signatures will be added.
    export_config (ExportConfig, optional): The export configuration. If None,
      it uses the default export configuration.
+    extra_model (torch.nn.Module, optional): PyTorch model to export in
+      addition to the pytorch_model. This model can have different
+      prefill_seq_lens and kv_cache_max_len.
+    extra_prefill_seq_lens (list[int], optional): The prefill sequence
+      lengths for extra_model. Meaningful only when extra_model is not None.
+    extra_kv_cache_max_len (int, optional): The maximum size of KV cache
+      buffer for extra_model. Meaningful only when extra_model is not None.
+    extra_signature_prefix (str, optional): The prefix of the extra model
+      signatures. Meaningful only when extra_model is not None.
   """
   # pylint: disable=protected-access
   torch._dynamo.config.cache_size_limit = 64
@@ -315,32 +385,51 @@ def convert_to_tflite(
  )
  output_file = os.path.join(output_path, output_filename)
 
-
+  converter = converter_utils.Converter()
+  _add_signatures(
+      converter,
      pytorch_model,
-      output_file,
      prefill_seq_lens,
      kv_cache_max_len,
      pixel_values_size,
      pixel_seq_len,
-      quantize,
      config,
      loras,
      export_config,
  )
+
+  if extra_model is not None and extra_prefill_seq_lens:
+    _add_signatures(
+        converter,
+        extra_model,
+        extra_prefill_seq_lens,
+        extra_kv_cache_max_len,
+        pixel_values_size,
+        pixel_seq_len,
+        config,
+        loras,
+        export_config,
+        signature_prefix=extra_signature_prefix,
+    )
+
+  edge_model = converter.convert(
+      quant_config=get_quant_recipe_from_flag(quantize, config),
+  )
+  edge_model.export(output_file)
  return output_file
 
 
-def
+def _add_signatures(
+    converter: converter_utils.Converter,
    pytorch_model: torch.nn.Module,
-    output_file: str,
    prefill_seq_lens: list[int],
    kv_cache_max_len: int,
    pixel_values_size: torch.Size,
    pixel_seq_len: int,
-    quantize: str,
    config: cfg.ModelConfig,
    loras: list[None | lora_utils.LoRA],
    export_config: ExportConfig,
+    signature_prefix: str = '',
 ):
   """Helper function to export a model to tflite."""
   prefill_tokens_list = []
@@ -385,17 +474,14 @@ def _export_helper(
      kv_layout=export_config.kvcache_layout,
  )
 
-  quant_config = get_quant_recipe_from_flag(quantize, config)
-
  # For export, we create a module that captures any non-exportable,
  # arugments, e.g. the generation config object.
  mod = ExportableModule(pytorch_model, export_config=export_config).eval()
 
-  converter = converter_utils.Converter()
  for lora in loras:
    for i in range(len(prefill_seq_lens)):
      prefill_seq_len = prefill_seq_lens[i]
-      prefill_signature_name = f'prefill_{prefill_seq_len}'
+      prefill_signature_name = f'{signature_prefix}prefill_{prefill_seq_len}'
 
      sample_kwargs = {
          'tokens': prefill_tokens_list[i],
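Because prefill and decode names now both carry signature_prefix, one .tflite can hold two models' signatures side by side. With gpu_dynamic_shapes the main model's dimensions are rounded up to primes while the verification model keeps the fixed 64/1024 lengths, so an exported graph might expose names like these (derived from the f-strings above; the exact set depends on flags):

    prefill_67, prefill_1031, decode                 # main model, magic-number dims
    test_prefill_64, test_prefill_1024, test_decode  # extra verification model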
@@ -450,16 +536,85 @@ def _export_helper(
    if lora is not None:
      sample_kwargs['lora'] = lora
 
+    decode_signature_name = f'{signature_prefix}decode'
+    if lora is not None:
+      decode_signature_name += f'_lora_r{lora.get_rank()}'
    converter.add_signature(
-
+        decode_signature_name,
        mod,
        sample_kwargs=sample_kwargs,
    )
 
-
-
+
+def build_and_convert_to_tflite_from_flags(
+    model_builder: Callable[
+        [str, Callable[[str], Dict[str, torch.Tensor]], int], torch.nn.Module
+    ],
+    checkpoint_path: str = None,
+    output_name_prefix: str = None,
+):
+  """Builds a nn.Module model and converts it according to the flags."""
+  if checkpoint_path is None:
+    checkpoint_path = flags.FLAGS.checkpoint_path
+  if output_name_prefix is None:
+    output_name_prefix = flags.FLAGS.output_name_prefix
+
+  pytorch_model = model_builder(
+      checkpoint_path,
+      loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      get_mask_cache_size_from_flags(),
+  )
+
+  # Extra model for GPU dynamic shape verification if needed.
+  extra_model = None
+  extra_prefill_seq_lens = None
+  extra_kv_cache_max_len = 0
+  if flags.FLAGS.gpu_dynamic_shapes:
+    prefill_seq_lens = [
+        get_magic_number_for(l) for l in flags.FLAGS.prefill_seq_lens
+    ]
+    kv_cache_max_len = get_magic_number_for(flags.FLAGS.kv_cache_max_len)
+
+    if flags.FLAGS.export_gpu_dynamic_shape_verifications:
+      extra_kv_cache_max_len = _CONTEXT_LENGTH_TO_VERIFY_MAGIC_NUMBERS
+      if extra_kv_cache_max_len > flags.FLAGS.kv_cache_max_len:
+        extra_kv_cache_max_len = flags.FLAGS.kv_cache_max_len
+      extra_model = model_builder(
+          checkpoint_path,
+          loader.maybe_get_custom_loader(
+              checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+          ),
+          extra_kv_cache_max_len,
+      )
+      extra_prefill_seq_lens = []
+      if extra_kv_cache_max_len > _SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS:
+        extra_prefill_seq_lens.append(
+            _SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS
+        )
+      if extra_kv_cache_max_len > _LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS:
+        extra_prefill_seq_lens.append(
+            _LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS
+        )
+  else:
+    prefill_seq_lens = flags.FLAGS.prefill_seq_lens
+    kv_cache_max_len = flags.FLAGS.kv_cache_max_len
+
+  convert_to_tflite(
+      pytorch_model,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=output_name_prefix,
+      prefill_seq_len=prefill_seq_lens,
+      kv_cache_max_len=kv_cache_max_len,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
+      export_config=export_config_lib.get_from_flags(),
+      extra_model=extra_model,
+      extra_prefill_seq_lens=extra_prefill_seq_lens,
+      extra_kv_cache_max_len=extra_kv_cache_max_len,
+      extra_signature_prefix='test_' if extra_model is not None else '',
  )
-  edge_model.export(output_file)
 
 
 def convert_to_litert(
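The per-model convert_to_tflite.py examples in the file list above each drop roughly 20 lines because this helper now owns the flag plumbing. A hedged sketch of the resulting example-script shape (the qwen module, its builder name, and the define_conversion_flags argument are illustrative, not taken from this diff):

    from absl import app
    from ai_edge_torch.generative.utilities import converter

    # Hypothetical model module; each real example passes its own builder.
    from ai_edge_torch.generative.examples.qwen import qwen

    converter.define_conversion_flags('qwen')

    def main(_):
      converter.build_and_convert_to_tflite_from_flags(qwen.build_model)

    if __name__ == '__main__':
      app.run(main)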
ai_edge_torch/generative/utilities/export_config.py

@@ -56,5 +56,7 @@ def get_from_flags() -> ExportConfig:
    export_config.kvcache_layout = kv_utils.KV_LAYOUT_TRANSPOSED
  if flags.FLAGS.mask_as_input:
    export_config.mask_as_input = flags.FLAGS.mask_as_input
+  if flags.FLAGS.decode_batch_size:
+    export_config.decode_batch_size = flags.FLAGS.decode_batch_size
 
  return export_config
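A hedged example invocation tying the new flags together (the script path comes from the file list above; the checkpoint path and the exact --quantize spelling for QuantizationName.DYNAMIC_INT4_BLOCK32 are assumptions):

    python ai_edge_torch/generative/examples/qwen/convert_to_tflite.py \
        --checkpoint_path=/tmp/qwen_checkpoint \
        --quantize=dynamic_int4_block32 \
        --decode_batch_size=1 \
        --gpu_dynamic_shapes \
        --export_gpu_dynamic_shape_verifications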
ai_edge_torch/generative/utilities/litertlm_builder.py

@@ -18,16 +18,19 @@
 
 import os
 import pathlib
+from google.protobuf import text_format
 
 try:
   # pylint: disable=g-import-not-at-top
   from ai_edge_litert.internal import llm_metadata_pb2
   from ai_edge_litert.internal import litertlm_builder
+  from ai_edge_litert.internal import llm_model_type_pb2
   # pylint: enable=g-import-not-at-top
 
   _litertlm_builder_available = True
 except ImportError:
   llm_metadata_pb2 = None
+  llm_model_type_pb2 = None
   litertlm_builder = None
   _litertlm_builder_available = False
 
@@ -41,16 +44,19 @@ def build_litertlm(
    workdir: str,
    output_path: str,
    context_length: int,
-    model_prompt_prefix: str | None,
-    model_prompt_suffix: str | None,
-    user_prompt_prefix: str | None,
-    user_prompt_suffix: str | None,
-    tokenizer_model_path: str | None,
-    hf_tokenizer_model_path: str | None,
+    model_prompt_prefix: str | None = None,
+    model_prompt_suffix: str | None = None,
+    user_prompt_prefix: str | None = None,
+    user_prompt_suffix: str | None = None,
+    tokenizer_model_path: str | None = None,
+    hf_tokenizer_model_path: str | None = None,
    start_token: str | None = None,
    start_token_id: int | None = None,
    stop_tokens: str | list[str] | None = None,
    stop_token_ids: list[int] | None = None,
+    llm_model_type: str = 'generic',
+    jinja_prompt_template: str | None = None,
+    base_llm_metadata_path: str | None = None,
    **kwargs,
 ):
   """Builds a LiteRT-LM file from a TFlite model."""
@@ -58,10 +64,22 @@ def build_litertlm(
 
  if not is_litertlm_builder_available():
    raise ValueError('LiteRT-LM builder is not available.')
-  assert llm_metadata_pb2 is not None
  assert litertlm_builder is not None
+  assert llm_metadata_pb2 is not None
+  assert llm_model_type_pb2 is not None
 
  llm_metadata = llm_metadata_pb2.LlmMetadata()
+  if base_llm_metadata_path:
+    if base_llm_metadata_path.endswith('.pb'):
+      with open(base_llm_metadata_path, 'rb') as f:
+        llm_metadata.ParseFromString(f.read())
+    elif base_llm_metadata_path.endswith('.textproto'):
+      with open(base_llm_metadata_path, 'r') as f:
+        text_format.Parse(f.read(), llm_metadata, allow_unknown_field=True)
+    else:
+      raise ValueError(
+          'Base LLM metadata path must be a binary or text proto file.'
+      )
 
  if start_token_id is not None:
    llm_metadata.start_token.token_ids.ids.append(start_token_id)
@@ -96,7 +114,42 @@ def build_litertlm(
 
  llm_metadata.max_num_tokens = context_length
 
-
+  mdl_type = llm_metadata.llm_model_type.WhichOneof('model_type')
+  if not mdl_type or mdl_type == 'generic_model':
+    match llm_model_type:
+      case litertlm_builder.LlmModelType.GENERIC:
+        llm_metadata.llm_model_type.CopyFrom(
+            llm_model_type_pb2.LlmModelType(
+                generic_model=llm_model_type_pb2.GenericModel()
+            )
+        )
+      case litertlm_builder.LlmModelType.GEMMA3N:
+        llm_metadata.llm_model_type.CopyFrom(
+            llm_model_type_pb2.LlmModelType(
+                gemma3n=llm_model_type_pb2.Gemma3N()
+            )
+        )
+      case litertlm_builder.LlmModelType.GEMMA3:
+        llm_metadata.llm_model_type.CopyFrom(
+            llm_model_type_pb2.LlmModelType(gemma3=llm_model_type_pb2.Gemma3())
+        )
+      case litertlm_builder.LlmModelType.QWEN3:
+        llm_metadata.llm_model_type.CopyFrom(
+            llm_model_type_pb2.LlmModelType(qwen3=llm_model_type_pb2.Qwen3())
+        )
+      case litertlm_builder.LlmModelType.QWEN2P5:
+        llm_metadata.llm_model_type.CopyFrom(
+            llm_model_type_pb2.LlmModelType(
+                qwen2p5=llm_model_type_pb2.Qwen2p5()
+            )
+        )
+      case _:
+        raise ValueError(f'Unsupported LLM model type: {llm_model_type}')
+
+  if jinja_prompt_template is not None:
+    llm_metadata.jinja_prompt_template = jinja_prompt_template
+
+  llm_metadata_path = os.path.join(workdir, 'llm_metadata_final.pb')
  with open(llm_metadata_path, 'wb') as f:
    f.write(llm_metadata.SerializeToString())
 
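The base_llm_metadata_path hook means a model family can ship a hand-written textproto and let the builder fill in the rest. A hedged sketch of preparing such a file (assumes ai_edge_litert's internal proto is importable; the parse mirrors the builder's own lenient parse above):

    from google.protobuf import text_format
    from ai_edge_litert.internal import llm_metadata_pb2

    llm_metadata = llm_metadata_pb2.LlmMetadata()
    with open('llm_metadata.textproto') as f:
      # allow_unknown_field tolerates fields from newer schema revisions.
      text_format.Parse(f.read(), llm_metadata, allow_unknown_field=True)
    with open('llm_metadata.pb', 'wb') as f:
      f.write(llm_metadata.SerializeToString())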
ai_edge_torch/generative/utilities/loader.py

@@ -135,7 +135,8 @@ def load_pytorch_statedict(full_path: str):
 
  tensors = {}
  for file in files:
-
+    map_location = "cpu" if not torch.cuda.is_available() else None
+    this_file_tensors = torch.load(file, map_location=map_location)
    for k in this_file_tensors:
      assert k not in tensors
    tensors.update(this_file_tensors)
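The map_location guard keeps CPU-only hosts from failing when a checkpoint was saved on GPU. Standalone illustration of the same idiom (the file name is a placeholder):

    import torch

    # Tensors saved from CUDA deserialize onto CPU when CUDA is unavailable;
    # map_location=None preserves the saved device placement otherwise.
    map_location = "cpu" if not torch.cuda.is_available() else None
    state_dict = torch.load("pytorch_model.bin", map_location=map_location)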
ai_edge_torch/lowertools/translate_recipe.py

@@ -80,8 +80,14 @@ def _get_granularity(
    return _QuantGranularity.CHANNELWISE
  if granularity == quant_attrs.Granularity.NONE:
    return _QuantGranularity.TENSORWISE
-  if granularity == quant_attrs.Granularity.
-    return _QuantGranularity.
+  if granularity == quant_attrs.Granularity.BLOCKWISE_32:
+    return _QuantGranularity.BLOCKWISE_32
+  if granularity == quant_attrs.Granularity.BLOCKWISE_64:
+    return _QuantGranularity.BLOCKWISE_64
+  if granularity == quant_attrs.Granularity.BLOCKWISE_128:
+    return _QuantGranularity.BLOCKWISE_128
+  if granularity == quant_attrs.Granularity.BLOCKWISE_256:
+    return _QuantGranularity.BLOCKWISE_256
  raise ValueError('Unimplemented granularity')
 
 
@@ -108,7 +114,6 @@ def _set_quant_config(
        symmetric=True,
        granularity=_get_granularity(layer_recipe.granularity),
        dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
-        block_size=layer_recipe.block_size,
    ),
    compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
    explicit_dequantize=_get_explicit_dequant_from_mode(
ai_edge_torch/odml_torch/experimental/__init__.py (new file)

@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py (new file)

@@ -0,0 +1,20 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Torch-TFL ops definitions, decompositions, and lowerings."""
+from ai_edge_torch.odml_torch.experimental.torch_tfl import _decomps
+from ai_edge_torch.odml_torch.experimental.torch_tfl import _lowerings
+from ai_edge_torch.odml_torch.experimental.torch_tfl import _ops
+
+decomps = _decomps.decomps