ai-edge-torch-nightly 0.7.0.dev20251007__py3-none-any.whl → 0.8.0.dev20251225__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/fx_infra/__init__.py +1 -0
- ai_edge_torch/fx_infra/_safe_run_decompositions.py +54 -1
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +1 -1
- ai_edge_torch/generative/layers/attention.py +25 -2
- ai_edge_torch/generative/layers/attention_test.py +13 -1
- ai_edge_torch/generative/layers/attention_utils.py +62 -1
- ai_edge_torch/generative/layers/attention_utils_test.py +20 -0
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/model_config.py +5 -0
- ai_edge_torch/generative/layers/normalization.py +8 -2
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +35 -5
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +8 -3
- ai_edge_torch/generative/quantize/example.py +1 -1
- ai_edge_torch/generative/quantize/quant_attrs.py +8 -1
- ai_edge_torch/generative/quantize/quant_recipe.py +0 -13
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -19
- ai_edge_torch/generative/quantize/quant_recipes.py +16 -21
- ai_edge_torch/generative/quantize/supported_schemes.py +4 -1
- ai_edge_torch/generative/test/test_kv_cache.py +18 -6
- ai_edge_torch/generative/test/test_quantize.py +17 -26
- ai_edge_torch/generative/utilities/converter.py +97 -22
- ai_edge_torch/generative/utilities/litertlm_builder.py +61 -8
- ai_edge_torch/generative/utilities/loader.py +2 -1
- ai_edge_torch/lowertools/translate_recipe.py +8 -3
- ai_edge_torch/odml_torch/experimental/__init__.py +14 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py +20 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +438 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +728 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +371 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py +37 -0
- ai_edge_torch/odml_torch/export.py +24 -7
- ai_edge_torch/odml_torch/lowerings/_basic.py +155 -0
- ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +94 -2
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -5
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/METADATA +15 -3
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/RECORD +42 -36
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/WHEEL +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/top_level.txt +0 -0
|
@@ -32,23 +32,29 @@ from ai_edge_torch.generative.quantize import quant_attrs
|
|
|
32
32
|
from ai_edge_torch.generative.quantize import quant_recipe
|
|
33
33
|
|
|
34
34
|
|
|
35
|
-
def
|
|
35
|
+
def create_layer_quant_dynamic(
|
|
36
|
+
weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
|
|
37
|
+
granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
|
|
38
|
+
) -> quant_recipe.LayerQuantRecipe:
|
|
36
39
|
return quant_recipe.LayerQuantRecipe(
|
|
37
40
|
activation_dtype=quant_attrs.Dtype.FP32,
|
|
38
|
-
weight_dtype=
|
|
41
|
+
weight_dtype=weight_dtype,
|
|
39
42
|
mode=quant_attrs.Mode.DYNAMIC_RANGE,
|
|
40
43
|
algorithm=quant_attrs.Algorithm.MIN_MAX,
|
|
41
|
-
granularity=
|
|
44
|
+
granularity=granularity,
|
|
42
45
|
)
|
|
43
46
|
|
|
44
47
|
|
|
45
|
-
def
|
|
48
|
+
def create_layer_quant_weight_only(
|
|
49
|
+
weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
|
|
50
|
+
granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
|
|
51
|
+
) -> quant_recipe.LayerQuantRecipe:
|
|
46
52
|
return quant_recipe.LayerQuantRecipe(
|
|
47
53
|
activation_dtype=quant_attrs.Dtype.FP32,
|
|
48
|
-
weight_dtype=
|
|
54
|
+
weight_dtype=weight_dtype,
|
|
49
55
|
mode=quant_attrs.Mode.WEIGHT_ONLY,
|
|
50
56
|
algorithm=quant_attrs.Algorithm.MIN_MAX,
|
|
51
|
-
granularity=
|
|
57
|
+
granularity=granularity,
|
|
52
58
|
)
|
|
53
59
|
|
|
54
60
|
|
|
@@ -60,16 +66,3 @@ def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
|
|
|
60
66
|
algorithm=quant_attrs.Algorithm.FLOAT_CAST,
|
|
61
67
|
granularity=quant_attrs.Granularity.NONE,
|
|
62
68
|
)
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
def create_layer_quant_int4_dynamic_block(
|
|
66
|
-
block_size: int,
|
|
67
|
-
) -> quant_recipe.LayerQuantRecipe:
|
|
68
|
-
return quant_recipe.LayerQuantRecipe(
|
|
69
|
-
activation_dtype=quant_attrs.Dtype.FP32,
|
|
70
|
-
weight_dtype=quant_attrs.Dtype.INT4,
|
|
71
|
-
mode=quant_attrs.Mode.DYNAMIC_RANGE,
|
|
72
|
-
algorithm=quant_attrs.Algorithm.MIN_MAX,
|
|
73
|
-
granularity=quant_attrs.Granularity.BLOCKWISE,
|
|
74
|
-
block_size=block_size,
|
|
75
|
-
)
|
|
@@ -29,35 +29,44 @@ Typical usage example:
|
|
|
29
29
|
|
|
30
30
|
from typing import Optional
|
|
31
31
|
from ai_edge_torch.generative.layers import model_config
|
|
32
|
+
from ai_edge_torch.generative.quantize import quant_attrs
|
|
32
33
|
from ai_edge_torch.generative.quantize import quant_recipe
|
|
33
34
|
from ai_edge_torch.generative.quantize import quant_recipe_utils
|
|
34
35
|
from ai_edge_torch.quantize import quant_config
|
|
35
36
|
|
|
36
37
|
|
|
37
|
-
def
|
|
38
|
-
mcfg:
|
|
38
|
+
def full_dynamic_recipe(
|
|
39
|
+
mcfg: model_config.ModelConfig | None = None,
|
|
40
|
+
weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
|
|
41
|
+
granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
|
|
39
42
|
) -> quant_config.QuantConfig:
|
|
40
43
|
return quant_config.QuantConfig(
|
|
41
44
|
generative_recipe=quant_recipe.GenerativeQuantRecipe(
|
|
42
|
-
default=quant_recipe_utils.
|
|
45
|
+
default=quant_recipe_utils.create_layer_quant_dynamic(
|
|
46
|
+
weight_dtype, granularity
|
|
47
|
+
),
|
|
43
48
|
_model_config=mcfg,
|
|
44
49
|
)
|
|
45
50
|
)
|
|
46
51
|
|
|
47
52
|
|
|
48
|
-
def
|
|
49
|
-
mcfg:
|
|
53
|
+
def full_weight_only_recipe(
|
|
54
|
+
mcfg: model_config.ModelConfig | None = None,
|
|
55
|
+
weight_dtype: quant_attrs.Dtype = quant_attrs.Dtype.INT8,
|
|
56
|
+
granularity: quant_attrs.Granularity = quant_attrs.Granularity.CHANNELWISE,
|
|
50
57
|
) -> quant_config.QuantConfig:
|
|
51
58
|
return quant_config.QuantConfig(
|
|
52
59
|
generative_recipe=quant_recipe.GenerativeQuantRecipe(
|
|
53
|
-
default=quant_recipe_utils.
|
|
60
|
+
default=quant_recipe_utils.create_layer_quant_weight_only(
|
|
61
|
+
weight_dtype, granularity
|
|
62
|
+
),
|
|
54
63
|
_model_config=mcfg,
|
|
55
64
|
)
|
|
56
65
|
)
|
|
57
66
|
|
|
58
67
|
|
|
59
68
|
def full_fp16_recipe(
|
|
60
|
-
mcfg:
|
|
69
|
+
mcfg: model_config.ModelConfig | None = None,
|
|
61
70
|
) -> quant_config.QuantConfig:
|
|
62
71
|
return quant_config.QuantConfig(
|
|
63
72
|
generative_recipe=quant_recipe.GenerativeQuantRecipe(
|
|
@@ -65,17 +74,3 @@ def full_fp16_recipe(
|
|
|
65
74
|
_model_config=mcfg,
|
|
66
75
|
)
|
|
67
76
|
)
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
def all_supported_int4_dynamic_block_recipe(
|
|
71
|
-
block_size: int,
|
|
72
|
-
mcfg: Optional[model_config.ModelConfig] = None,
|
|
73
|
-
) -> quant_config.QuantConfig:
|
|
74
|
-
return quant_config.QuantConfig(
|
|
75
|
-
generative_recipe=quant_recipe.GenerativeQuantRecipe(
|
|
76
|
-
default=quant_recipe_utils.create_layer_quant_int4_dynamic_block(
|
|
77
|
-
block_size
|
|
78
|
-
),
|
|
79
|
-
_model_config=mcfg,
|
|
80
|
-
)
|
|
81
|
-
)
|
|
@@ -29,5 +29,8 @@ def get_supported_layer_schemes():
|
|
|
29
29
|
(_t.FP32, _t.INT8, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.CHANNELWISE),
|
|
30
30
|
(_t.FP32, _t.INT8, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.CHANNELWISE),
|
|
31
31
|
(_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.FLOAT_CAST, _g.NONE),
|
|
32
|
-
(_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.
|
|
32
|
+
(_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_32),
|
|
33
|
+
(_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_64),
|
|
34
|
+
(_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_128),
|
|
35
|
+
(_t.FP32, _t.INT4, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.BLOCKWISE_256),
|
|
33
36
|
]
|
|
@@ -41,6 +41,20 @@ class TestKVLayers(googletest.TestCase):
|
|
|
41
41
|
)
|
|
42
42
|
return config
|
|
43
43
|
|
|
44
|
+
def _assert_kv_cache_entry_equal(self, kv1, kv2):
|
|
45
|
+
self.assertIsInstance(kv1, kv_utils.KVCacheEntry)
|
|
46
|
+
self.assertIsInstance(kv2, kv_utils.KVCacheEntry)
|
|
47
|
+
self.assertEqual(kv1.kv_layout, kv2.kv_layout)
|
|
48
|
+
self.assertTrue(torch.equal(kv1.k_cache, kv2.k_cache))
|
|
49
|
+
self.assertTrue(torch.equal(kv1.v_cache, kv2.v_cache))
|
|
50
|
+
|
|
51
|
+
def _assert_kv_cache_equal(self, kv1, kv2):
|
|
52
|
+
self.assertIsInstance(kv1, kv_utils.KVCache)
|
|
53
|
+
self.assertIsInstance(kv2, kv_utils.KVCache)
|
|
54
|
+
self.assertEqual(len(kv1.caches), len(kv2.caches))
|
|
55
|
+
for kv1_entry, kv2_entry in zip(kv1.caches, kv2.caches):
|
|
56
|
+
self._assert_kv_cache_entry_equal(kv1_entry, kv2_entry)
|
|
57
|
+
|
|
44
58
|
def test_cache_udpate(self):
|
|
45
59
|
N = 1
|
|
46
60
|
HEAD_DIM = 2
|
|
@@ -118,7 +132,7 @@ class TestKVLayers(googletest.TestCase):
|
|
|
118
132
|
flat, treespec = pytree.tree_flatten(kv)
|
|
119
133
|
self.assertLen(flat, NUM_LAYERS * 2)
|
|
120
134
|
kv_unflat = pytree.tree_unflatten(flat, treespec)
|
|
121
|
-
self.
|
|
135
|
+
self._assert_kv_cache_equal(kv, kv_unflat)
|
|
122
136
|
|
|
123
137
|
def test_pytree_roundtrip_kv_cache_derived(self):
|
|
124
138
|
NUM_LAYERS = 4
|
|
@@ -134,7 +148,7 @@ class TestKVLayers(googletest.TestCase):
|
|
|
134
148
|
flat, treespec = pytree.tree_flatten(kv)
|
|
135
149
|
self.assertLen(flat, NUM_LAYERS * 2)
|
|
136
150
|
kv_unflat = pytree.tree_unflatten(flat, treespec)
|
|
137
|
-
self.
|
|
151
|
+
self._assert_kv_cache_equal(kv, kv_unflat)
|
|
138
152
|
|
|
139
153
|
def test_pytree_roundtrip_kv_entry(self):
|
|
140
154
|
attn_config = cfg.AttentionConfig(
|
|
@@ -144,8 +158,7 @@ class TestKVLayers(googletest.TestCase):
|
|
|
144
158
|
flat, treespec = pytree.tree_flatten(kv)
|
|
145
159
|
self.assertLen(flat, 2)
|
|
146
160
|
kv_unflat = pytree.tree_unflatten(flat, treespec)
|
|
147
|
-
self.
|
|
148
|
-
self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
|
|
161
|
+
self._assert_kv_cache_entry_equal(kv, kv_unflat)
|
|
149
162
|
|
|
150
163
|
def test_pytree_roundtrip_kv_entry_derived(self):
|
|
151
164
|
attn_config = cfg.AttentionConfig(
|
|
@@ -157,8 +170,7 @@ class TestKVLayers(googletest.TestCase):
|
|
|
157
170
|
flat, treespec = pytree.tree_flatten(kv)
|
|
158
171
|
self.assertLen(flat, 2)
|
|
159
172
|
kv_unflat = pytree.tree_unflatten(flat, treespec)
|
|
160
|
-
self.
|
|
161
|
-
self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
|
|
173
|
+
self._assert_kv_cache_entry_equal(kv, kv_unflat)
|
|
162
174
|
|
|
163
175
|
|
|
164
176
|
if __name__ == "__main__":
|
|
@@ -79,18 +79,18 @@ class TestVerifyRecipes(parameterized.TestCase):
|
|
|
79
79
|
Dtype.INT4,
|
|
80
80
|
Mode.DYNAMIC_RANGE,
|
|
81
81
|
Algorithm.MIN_MAX,
|
|
82
|
-
Granularity.
|
|
83
|
-
|
|
82
|
+
Granularity.BLOCKWISE_32,
|
|
83
|
+
),
|
|
84
|
+
(
|
|
85
|
+
Dtype.FP32,
|
|
86
|
+
Dtype.INT4,
|
|
87
|
+
Mode.DYNAMIC_RANGE,
|
|
88
|
+
Algorithm.MIN_MAX,
|
|
89
|
+
Granularity.BLOCKWISE_128,
|
|
84
90
|
),
|
|
85
91
|
])
|
|
86
92
|
def test_verify_valid_recipes(
|
|
87
|
-
self,
|
|
88
|
-
activation,
|
|
89
|
-
weight,
|
|
90
|
-
mode,
|
|
91
|
-
algo,
|
|
92
|
-
granularity,
|
|
93
|
-
block_size=None,
|
|
93
|
+
self, activation, weight, mode, algo, granularity
|
|
94
94
|
):
|
|
95
95
|
quant_recipe.LayerQuantRecipe(
|
|
96
96
|
activation, weight, mode, algo, granularity
|
|
@@ -108,21 +108,21 @@ class TestQuantizeConvert(parameterized.TestCase):
|
|
|
108
108
|
def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
|
|
109
109
|
return quant_config.QuantConfig(
|
|
110
110
|
generative_recipe=quant_recipe.GenerativeQuantRecipe(
|
|
111
|
-
attention=quant_recipe_utils.
|
|
111
|
+
attention=quant_recipe_utils.create_layer_quant_dynamic(),
|
|
112
112
|
)
|
|
113
113
|
)
|
|
114
114
|
|
|
115
115
|
def _feedforward_int8_dynamic_recipe() -> quant_config.QuantConfig:
|
|
116
116
|
return quant_config.QuantConfig(
|
|
117
117
|
generative_recipe=quant_recipe.GenerativeQuantRecipe(
|
|
118
|
-
feedforward=quant_recipe_utils.
|
|
118
|
+
feedforward=quant_recipe_utils.create_layer_quant_dynamic(),
|
|
119
119
|
)
|
|
120
120
|
)
|
|
121
121
|
|
|
122
122
|
@parameterized.parameters([
|
|
123
123
|
(quant_recipes.full_fp16_recipe()),
|
|
124
|
-
(quant_recipes.
|
|
125
|
-
(quant_recipes.
|
|
124
|
+
(quant_recipes.full_dynamic_recipe()),
|
|
125
|
+
(quant_recipes.full_weight_only_recipe()),
|
|
126
126
|
(_attention_int8_dynamic_recipe()),
|
|
127
127
|
(_feedforward_int8_dynamic_recipe()),
|
|
128
128
|
])
|
|
@@ -148,7 +148,7 @@ class TestQuantizeConvert(parameterized.TestCase):
|
|
|
148
148
|
idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
|
|
149
149
|
input_pos = torch.arange(0, 100, dtype=torch.int)
|
|
150
150
|
|
|
151
|
-
quant_config = quant_recipes.
|
|
151
|
+
quant_config = quant_recipes.full_dynamic_recipe()
|
|
152
152
|
quantized_model = ai_edge_torch.convert(
|
|
153
153
|
pytorch_model, (idx, input_pos), quant_config=quant_config
|
|
154
154
|
)
|
|
@@ -164,7 +164,9 @@ class TestQuantizeConvert(parameterized.TestCase):
|
|
|
164
164
|
pytorch_model = toy_model.ToySingleLayerModel(config)
|
|
165
165
|
idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
|
|
166
166
|
input_pos = torch.arange(0, 100, dtype=torch.int)
|
|
167
|
-
quant_config = quant_recipes.
|
|
167
|
+
quant_config = quant_recipes.full_dynamic_recipe(
|
|
168
|
+
weight_dtype=Dtype.INT4, granularity=Granularity.BLOCKWISE_32
|
|
169
|
+
)
|
|
168
170
|
quantized_model = ai_edge_torch.convert(
|
|
169
171
|
pytorch_model, (idx, input_pos), quant_config=quant_config
|
|
170
172
|
)
|
|
@@ -175,17 +177,6 @@ class TestQuantizeConvert(parameterized.TestCase):
|
|
|
175
177
|
"Quantized model isn't smaller than F32 model.",
|
|
176
178
|
)
|
|
177
179
|
|
|
178
|
-
def test_unsupported_block_size(self):
|
|
179
|
-
config = toy_model.get_model_config()
|
|
180
|
-
pytorch_model = toy_model.ToySingleLayerModel(config)
|
|
181
|
-
idx = torch.unsqueeze(torch.arange(0, 100, dtype=torch.int), 0)
|
|
182
|
-
input_pos = torch.arange(0, 100, dtype=torch.int)
|
|
183
|
-
self.assertRaises(
|
|
184
|
-
ValueError,
|
|
185
|
-
quant_recipes.all_supported_int4_dynamic_block_recipe,
|
|
186
|
-
36,
|
|
187
|
-
)
|
|
188
|
-
|
|
189
180
|
def test_quantize_convert_compare_toy(self):
|
|
190
181
|
self.skipTest("b/338288901")
|
|
191
182
|
config = toy_model_with_kv_cache.get_model_config()
|
|
@@ -25,6 +25,7 @@ from ai_edge_torch._convert import converter as converter_utils
|
|
|
25
25
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
|
26
26
|
from ai_edge_torch.generative.layers import lora as lora_utils
|
|
27
27
|
import ai_edge_torch.generative.layers.model_config as cfg
|
|
28
|
+
from ai_edge_torch.generative.quantize import quant_attrs
|
|
28
29
|
from ai_edge_torch.generative.quantize import quant_recipes
|
|
29
30
|
from ai_edge_torch.generative.utilities import export_config as export_config_lib
|
|
30
31
|
from ai_edge_torch.generative.utilities import litertlm_builder
|
|
@@ -143,9 +144,23 @@ def define_conversion_flags(
|
|
|
143
144
|
'`prefill_seq_lens` as the maximum of kv_cache size and prefill lengths '
|
|
144
145
|
'in the graph.',
|
|
145
146
|
)
|
|
147
|
+
flags.DEFINE_bool(
|
|
148
|
+
'export_gpu_dynamic_shape_verifications',
|
|
149
|
+
False,
|
|
150
|
+
'If true, the conversion script will export signatures used only for '
|
|
151
|
+
'verification of GPU dynamic shapes.',
|
|
152
|
+
)
|
|
146
153
|
return flags
|
|
147
154
|
|
|
148
155
|
|
|
156
|
+
# Context length for verifying GPU dynamic shapes.
|
|
157
|
+
_CONTEXT_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 1280
|
|
158
|
+
# Long prefill length for verifying GPU dynamic shapes.
|
|
159
|
+
_LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 1024
|
|
160
|
+
# Short prefill length for verifying GPU dynamic shapes.
|
|
161
|
+
_SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS = 64
|
|
162
|
+
|
|
163
|
+
|
|
149
164
|
def is_magic_number_(num: int) -> bool:
|
|
150
165
|
"""Returns true if the number is a magic number, i.e. prime number > 10."""
|
|
151
166
|
if num < 10:
|
|
@@ -193,18 +208,22 @@ def get_quant_recipe_from_flag(
|
|
|
193
208
|
case QuantizationName.NONE:
|
|
194
209
|
return None
|
|
195
210
|
case QuantizationName.DYNAMIC_INT8:
|
|
196
|
-
return quant_recipes.
|
|
211
|
+
return quant_recipes.full_dynamic_recipe(mcfg=model_config)
|
|
197
212
|
case QuantizationName.WEIGHT_ONLY_INT8:
|
|
198
|
-
return quant_recipes.
|
|
213
|
+
return quant_recipes.full_weight_only_recipe(mcfg=model_config)
|
|
199
214
|
case QuantizationName.FP16:
|
|
200
215
|
return quant_recipes.full_fp16_recipe()
|
|
201
216
|
case QuantizationName.DYNAMIC_INT4_BLOCK32:
|
|
202
|
-
return quant_recipes.
|
|
203
|
-
|
|
217
|
+
return quant_recipes.full_dynamic_recipe(
|
|
218
|
+
mcfg=model_config,
|
|
219
|
+
weight_dtype=quant_attrs.Dtype.INT4,
|
|
220
|
+
granularity=quant_attrs.Granularity.BLOCKWISE_32,
|
|
204
221
|
)
|
|
205
222
|
case QuantizationName.DYNAMIC_INT4_BLOCK128:
|
|
206
|
-
return quant_recipes.
|
|
207
|
-
|
|
223
|
+
return quant_recipes.full_dynamic_recipe(
|
|
224
|
+
mcfg=model_config,
|
|
225
|
+
weight_dtype=quant_attrs.Dtype.INT4,
|
|
226
|
+
granularity=quant_attrs.Granularity.BLOCKWISE_128,
|
|
208
227
|
)
|
|
209
228
|
case _:
|
|
210
229
|
raise ValueError(f'Unsupported quantization flag: {quantize}')
|
|
@@ -263,6 +282,10 @@ def convert_to_tflite(
|
|
|
263
282
|
config: cfg.ModelConfig = None,
|
|
264
283
|
lora_ranks: Optional[list[int]] = None,
|
|
265
284
|
export_config: ExportConfig = None,
|
|
285
|
+
extra_model: torch.nn.Module = None,
|
|
286
|
+
extra_prefill_seq_lens: list[int] = None,
|
|
287
|
+
extra_kv_cache_max_len: int = 0,
|
|
288
|
+
extra_signature_prefix: str = '',
|
|
266
289
|
):
|
|
267
290
|
"""Converts a nn.Module model to multi-signature tflite model.
|
|
268
291
|
|
|
@@ -315,6 +338,15 @@ def convert_to_tflite(
|
|
|
315
338
|
no LoRA signatures will be added.
|
|
316
339
|
export_config (ExportConfig, optional): The export configuration. If None,
|
|
317
340
|
it uses the default export configuration.
|
|
341
|
+
extra_model (torch.nn.Module, optional): PyTorch model to export in
|
|
342
|
+
addition to the pytorch_model. This model can have different
|
|
343
|
+
prefill_seq_lens and kv_cache_max_len.
|
|
344
|
+
extra_prefill_seq_lens (list[int], optional): The prefill sequence
|
|
345
|
+
lengths for extra_model. Meaningful only when extra_model is not None.
|
|
346
|
+
extra_kv_cache_max_len (int, optional): The maximum size of KV cache
|
|
347
|
+
buffer for extra_model. Meaningful only when extra_model is not None.
|
|
348
|
+
extra_signature_prefix (str, optional): The prefix of the extra model
|
|
349
|
+
signatures. Meaningful only when extra_model is not None.
|
|
318
350
|
"""
|
|
319
351
|
# pylint: disable=protected-access
|
|
320
352
|
torch._dynamo.config.cache_size_limit = 64
|
|
@@ -353,32 +385,51 @@ def convert_to_tflite(
|
|
|
353
385
|
)
|
|
354
386
|
output_file = os.path.join(output_path, output_filename)
|
|
355
387
|
|
|
356
|
-
|
|
388
|
+
converter = converter_utils.Converter()
|
|
389
|
+
_add_signatures(
|
|
390
|
+
converter,
|
|
357
391
|
pytorch_model,
|
|
358
|
-
output_file,
|
|
359
392
|
prefill_seq_lens,
|
|
360
393
|
kv_cache_max_len,
|
|
361
394
|
pixel_values_size,
|
|
362
395
|
pixel_seq_len,
|
|
363
|
-
quantize,
|
|
364
396
|
config,
|
|
365
397
|
loras,
|
|
366
398
|
export_config,
|
|
367
399
|
)
|
|
400
|
+
|
|
401
|
+
if extra_model is not None and extra_prefill_seq_lens:
|
|
402
|
+
_add_signatures(
|
|
403
|
+
converter,
|
|
404
|
+
extra_model,
|
|
405
|
+
extra_prefill_seq_lens,
|
|
406
|
+
extra_kv_cache_max_len,
|
|
407
|
+
pixel_values_size,
|
|
408
|
+
pixel_seq_len,
|
|
409
|
+
config,
|
|
410
|
+
loras,
|
|
411
|
+
export_config,
|
|
412
|
+
signature_prefix=extra_signature_prefix,
|
|
413
|
+
)
|
|
414
|
+
|
|
415
|
+
edge_model = converter.convert(
|
|
416
|
+
quant_config=get_quant_recipe_from_flag(quantize, config),
|
|
417
|
+
)
|
|
418
|
+
edge_model.export(output_file)
|
|
368
419
|
return output_file
|
|
369
420
|
|
|
370
421
|
|
|
371
|
-
def
|
|
422
|
+
def _add_signatures(
|
|
423
|
+
converter: converter_utils.Converter,
|
|
372
424
|
pytorch_model: torch.nn.Module,
|
|
373
|
-
output_file: str,
|
|
374
425
|
prefill_seq_lens: list[int],
|
|
375
426
|
kv_cache_max_len: int,
|
|
376
427
|
pixel_values_size: torch.Size,
|
|
377
428
|
pixel_seq_len: int,
|
|
378
|
-
quantize: str,
|
|
379
429
|
config: cfg.ModelConfig,
|
|
380
430
|
loras: list[None | lora_utils.LoRA],
|
|
381
431
|
export_config: ExportConfig,
|
|
432
|
+
signature_prefix: str = '',
|
|
382
433
|
):
|
|
383
434
|
"""Helper function to add a model's prefill/decode signatures to the converter."""
|
|
384
435
|
prefill_tokens_list = []
|
|
@@ -423,17 +474,14 @@ def _export_helper(
|
|
|
423
474
|
kv_layout=export_config.kvcache_layout,
|
|
424
475
|
)
|
|
425
476
|
|
|
426
|
-
quant_config = get_quant_recipe_from_flag(quantize, config)
|
|
427
|
-
|
|
428
477
|
# For export, we create a module that captures any non-exportable,
|
|
429
478
|
# arguments, e.g. the generation config object.
|
|
430
479
|
mod = ExportableModule(pytorch_model, export_config=export_config).eval()
|
|
431
480
|
|
|
432
|
-
converter = converter_utils.Converter()
|
|
433
481
|
for lora in loras:
|
|
434
482
|
for i in range(len(prefill_seq_lens)):
|
|
435
483
|
prefill_seq_len = prefill_seq_lens[i]
|
|
436
|
-
prefill_signature_name = f'prefill_{prefill_seq_len}'
|
|
484
|
+
prefill_signature_name = f'{signature_prefix}prefill_{prefill_seq_len}'
|
|
437
485
|
|
|
438
486
|
sample_kwargs = {
|
|
439
487
|
'tokens': prefill_tokens_list[i],
|
|
@@ -488,17 +536,15 @@ def _export_helper(
|
|
|
488
536
|
if lora is not None:
|
|
489
537
|
sample_kwargs['lora'] = lora
|
|
490
538
|
|
|
539
|
+
decode_signature_name = f'{signature_prefix}decode'
|
|
540
|
+
if lora is not None:
|
|
541
|
+
decode_signature_name += f'_lora_r{lora.get_rank()}'
|
|
491
542
|
converter.add_signature(
|
|
492
|
-
|
|
543
|
+
decode_signature_name,
|
|
493
544
|
mod,
|
|
494
545
|
sample_kwargs=sample_kwargs,
|
|
495
546
|
)
|
|
496
547
|
|
|
497
|
-
edge_model = converter.convert(
|
|
498
|
-
quant_config=quant_config,
|
|
499
|
-
)
|
|
500
|
-
edge_model.export(output_file)
|
|
501
|
-
|
|
502
548
|
|
|
503
549
|
def build_and_convert_to_tflite_from_flags(
|
|
504
550
|
model_builder: Callable[
|
|
@@ -521,11 +567,36 @@ def build_and_convert_to_tflite_from_flags(
|
|
|
521
567
|
get_mask_cache_size_from_flags(),
|
|
522
568
|
)
|
|
523
569
|
|
|
570
|
+
# Extra model for GPU dynamic shape verification if needed.
|
|
571
|
+
extra_model = None
|
|
572
|
+
extra_prefill_seq_lens = None
|
|
573
|
+
extra_kv_cache_max_len = 0
|
|
524
574
|
if flags.FLAGS.gpu_dynamic_shapes:
|
|
525
575
|
prefill_seq_lens = [
|
|
526
576
|
get_magic_number_for(l) for l in flags.FLAGS.prefill_seq_lens
|
|
527
577
|
]
|
|
528
578
|
kv_cache_max_len = get_magic_number_for(flags.FLAGS.kv_cache_max_len)
|
|
579
|
+
|
|
580
|
+
if flags.FLAGS.export_gpu_dynamic_shape_verifications:
|
|
581
|
+
extra_kv_cache_max_len = _CONTEXT_LENGTH_TO_VERIFY_MAGIC_NUMBERS
|
|
582
|
+
if extra_kv_cache_max_len > flags.FLAGS.kv_cache_max_len:
|
|
583
|
+
extra_kv_cache_max_len = flags.FLAGS.kv_cache_max_len
|
|
584
|
+
extra_model = model_builder(
|
|
585
|
+
checkpoint_path,
|
|
586
|
+
loader.maybe_get_custom_loader(
|
|
587
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
|
588
|
+
),
|
|
589
|
+
extra_kv_cache_max_len,
|
|
590
|
+
)
|
|
591
|
+
extra_prefill_seq_lens = []
|
|
592
|
+
if extra_kv_cache_max_len > _SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS:
|
|
593
|
+
extra_prefill_seq_lens.append(
|
|
594
|
+
_SHORT_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS
|
|
595
|
+
)
|
|
596
|
+
if extra_kv_cache_max_len > _LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS:
|
|
597
|
+
extra_prefill_seq_lens.append(
|
|
598
|
+
_LONG_PREFILL_LENGTH_TO_VERIFY_MAGIC_NUMBERS
|
|
599
|
+
)
|
|
529
600
|
else:
|
|
530
601
|
prefill_seq_lens = flags.FLAGS.prefill_seq_lens
|
|
531
602
|
kv_cache_max_len = flags.FLAGS.kv_cache_max_len
|
|
@@ -539,6 +610,10 @@ def build_and_convert_to_tflite_from_flags(
|
|
|
539
610
|
quantize=flags.FLAGS.quantize,
|
|
540
611
|
lora_ranks=flags.FLAGS.lora_ranks,
|
|
541
612
|
export_config=export_config_lib.get_from_flags(),
|
|
613
|
+
extra_model=extra_model,
|
|
614
|
+
extra_prefill_seq_lens=extra_prefill_seq_lens,
|
|
615
|
+
extra_kv_cache_max_len=extra_kv_cache_max_len,
|
|
616
|
+
extra_signature_prefix='test_' if extra_model is not None else '',
|
|
542
617
|
)
|
|
543
618
|
|
|
544
619
|
|
|
@@ -18,16 +18,19 @@
|
|
|
18
18
|
|
|
19
19
|
import os
|
|
20
20
|
import pathlib
|
|
21
|
+
from google.protobuf import text_format
|
|
21
22
|
|
|
22
23
|
try:
|
|
23
24
|
# pylint: disable=g-import-not-at-top
|
|
24
25
|
from ai_edge_litert.internal import llm_metadata_pb2
|
|
25
26
|
from ai_edge_litert.internal import litertlm_builder
|
|
27
|
+
from ai_edge_litert.internal import llm_model_type_pb2
|
|
26
28
|
# pylint: enable=g-import-not-at-top
|
|
27
29
|
|
|
28
30
|
_litertlm_builder_available = True
|
|
29
31
|
except ImportError:
|
|
30
32
|
llm_metadata_pb2 = None
|
|
33
|
+
llm_model_type_pb2 = None
|
|
31
34
|
litertlm_builder = None
|
|
32
35
|
_litertlm_builder_available = False
|
|
33
36
|
|
|
@@ -41,16 +44,19 @@ def build_litertlm(
|
|
|
41
44
|
workdir: str,
|
|
42
45
|
output_path: str,
|
|
43
46
|
context_length: int,
|
|
44
|
-
model_prompt_prefix: str | None,
|
|
45
|
-
model_prompt_suffix: str | None,
|
|
46
|
-
user_prompt_prefix: str | None,
|
|
47
|
-
user_prompt_suffix: str | None,
|
|
48
|
-
tokenizer_model_path: str | None,
|
|
49
|
-
hf_tokenizer_model_path: str | None,
|
|
47
|
+
model_prompt_prefix: str | None = None,
|
|
48
|
+
model_prompt_suffix: str | None = None,
|
|
49
|
+
user_prompt_prefix: str | None = None,
|
|
50
|
+
user_prompt_suffix: str | None = None,
|
|
51
|
+
tokenizer_model_path: str | None = None,
|
|
52
|
+
hf_tokenizer_model_path: str | None = None,
|
|
50
53
|
start_token: str | None = None,
|
|
51
54
|
start_token_id: int | None = None,
|
|
52
55
|
stop_tokens: str | list[str] | None = None,
|
|
53
56
|
stop_token_ids: list[int] | None = None,
|
|
57
|
+
llm_model_type: str = 'generic',
|
|
58
|
+
jinja_prompt_template: str | None = None,
|
|
59
|
+
base_llm_metadata_path: str | None = None,
|
|
54
60
|
**kwargs,
|
|
55
61
|
):
|
|
56
62
|
"""Builds a LiteRT-LM file from a TFlite model."""
|
|
@@ -58,10 +64,22 @@ def build_litertlm(
|
|
|
58
64
|
|
|
59
65
|
if not is_litertlm_builder_available():
|
|
60
66
|
raise ValueError('LiteRT-LM builder is not available.')
|
|
61
|
-
assert llm_metadata_pb2 is not None
|
|
62
67
|
assert litertlm_builder is not None
|
|
68
|
+
assert llm_metadata_pb2 is not None
|
|
69
|
+
assert llm_model_type_pb2 is not None
|
|
63
70
|
|
|
64
71
|
llm_metadata = llm_metadata_pb2.LlmMetadata()
|
|
72
|
+
if base_llm_metadata_path:
|
|
73
|
+
if base_llm_metadata_path.endswith('.pb'):
|
|
74
|
+
with open(base_llm_metadata_path, 'rb') as f:
|
|
75
|
+
llm_metadata.ParseFromString(f.read())
|
|
76
|
+
elif base_llm_metadata_path.endswith('.textproto'):
|
|
77
|
+
with open(base_llm_metadata_path, 'r') as f:
|
|
78
|
+
text_format.Parse(f.read(), llm_metadata, allow_unknown_field=True)
|
|
79
|
+
else:
|
|
80
|
+
raise ValueError(
|
|
81
|
+
'Base LLM metadata path must be a binary or text proto file.'
|
|
82
|
+
)
|
|
65
83
|
|
|
66
84
|
if start_token_id is not None:
|
|
67
85
|
llm_metadata.start_token.token_ids.ids.append(start_token_id)
|
|
@@ -96,7 +114,42 @@ def build_litertlm(
|
|
|
96
114
|
|
|
97
115
|
llm_metadata.max_num_tokens = context_length
|
|
98
116
|
|
|
99
|
-
|
|
117
|
+
mdl_type = llm_metadata.llm_model_type.WhichOneof('model_type')
|
|
118
|
+
if not mdl_type or mdl_type == 'generic_model':
|
|
119
|
+
match llm_model_type:
|
|
120
|
+
case litertlm_builder.LlmModelType.GENERIC:
|
|
121
|
+
llm_metadata.llm_model_type.CopyFrom(
|
|
122
|
+
llm_model_type_pb2.LlmModelType(
|
|
123
|
+
generic_model=llm_model_type_pb2.GenericModel()
|
|
124
|
+
)
|
|
125
|
+
)
|
|
126
|
+
case litertlm_builder.LlmModelType.GEMMA3N:
|
|
127
|
+
llm_metadata.llm_model_type.CopyFrom(
|
|
128
|
+
llm_model_type_pb2.LlmModelType(
|
|
129
|
+
gemma3n=llm_model_type_pb2.Gemma3N()
|
|
130
|
+
)
|
|
131
|
+
)
|
|
132
|
+
case litertlm_builder.LlmModelType.GEMMA3:
|
|
133
|
+
llm_metadata.llm_model_type.CopyFrom(
|
|
134
|
+
llm_model_type_pb2.LlmModelType(gemma3=llm_model_type_pb2.Gemma3())
|
|
135
|
+
)
|
|
136
|
+
case litertlm_builder.LlmModelType.QWEN3:
|
|
137
|
+
llm_metadata.llm_model_type.CopyFrom(
|
|
138
|
+
llm_model_type_pb2.LlmModelType(qwen3=llm_model_type_pb2.Qwen3())
|
|
139
|
+
)
|
|
140
|
+
case litertlm_builder.LlmModelType.QWEN2P5:
|
|
141
|
+
llm_metadata.llm_model_type.CopyFrom(
|
|
142
|
+
llm_model_type_pb2.LlmModelType(
|
|
143
|
+
qwen2p5=llm_model_type_pb2.Qwen2p5()
|
|
144
|
+
)
|
|
145
|
+
)
|
|
146
|
+
case _:
|
|
147
|
+
raise ValueError(f'Unsupported LLM model type: {llm_model_type}')
|
|
148
|
+
|
|
149
|
+
if jinja_prompt_template is not None:
|
|
150
|
+
llm_metadata.jinja_prompt_template = jinja_prompt_template
|
|
151
|
+
|
|
152
|
+
llm_metadata_path = os.path.join(workdir, 'llm_metadata_final.pb')
|
|
100
153
|
with open(llm_metadata_path, 'wb') as f:
|
|
101
154
|
f.write(llm_metadata.SerializeToString())
|
|
102
155
|
|
|
@@ -135,7 +135,8 @@ def load_pytorch_statedict(full_path: str):
|
|
|
135
135
|
|
|
136
136
|
tensors = {}
|
|
137
137
|
for file in files:
|
|
138
|
-
|
|
138
|
+
map_location = "cpu" if not torch.cuda.is_available() else None
|
|
139
|
+
this_file_tensors = torch.load(file, map_location=map_location)
|
|
139
140
|
for k in this_file_tensors:
|
|
140
141
|
assert k not in tensors
|
|
141
142
|
tensors.update(this_file_tensors)
|
|
@@ -80,8 +80,14 @@ def _get_granularity(
|
|
|
80
80
|
return _QuantGranularity.CHANNELWISE
|
|
81
81
|
if granularity == quant_attrs.Granularity.NONE:
|
|
82
82
|
return _QuantGranularity.TENSORWISE
|
|
83
|
-
if granularity == quant_attrs.Granularity.
|
|
84
|
-
return _QuantGranularity.
|
|
83
|
+
if granularity == quant_attrs.Granularity.BLOCKWISE_32:
|
|
84
|
+
return _QuantGranularity.BLOCKWISE_32
|
|
85
|
+
if granularity == quant_attrs.Granularity.BLOCKWISE_64:
|
|
86
|
+
return _QuantGranularity.BLOCKWISE_64
|
|
87
|
+
if granularity == quant_attrs.Granularity.BLOCKWISE_128:
|
|
88
|
+
return _QuantGranularity.BLOCKWISE_128
|
|
89
|
+
if granularity == quant_attrs.Granularity.BLOCKWISE_256:
|
|
90
|
+
return _QuantGranularity.BLOCKWISE_256
|
|
85
91
|
raise ValueError('Unimplemented granularity')
|
|
86
92
|
|
|
87
93
|
|
|
@@ -108,7 +114,6 @@ def _set_quant_config(
|
|
|
108
114
|
symmetric=True,
|
|
109
115
|
granularity=_get_granularity(layer_recipe.granularity),
|
|
110
116
|
dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
|
|
111
|
-
block_size=layer_recipe.block_size,
|
|
112
117
|
),
|
|
113
118
|
compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
|
|
114
119
|
explicit_dequantize=_get_explicit_dequant_from_mode(
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|