ai-edge-torch-nightly 0.5.0.dev20250508__py3-none-any.whl → 0.5.0.dev20250510__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +3 -1
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +3 -1
- ai_edge_torch/generative/examples/gemma3/decoder.py +8 -33
- ai_edge_torch/generative/quantize/quant_recipe.py +5 -1
- ai_edge_torch/generative/quantize/quant_recipes.py +2 -1
- ai_edge_torch/generative/test/test_quantize.py +0 -1
- ai_edge_torch/generative/utilities/converter.py +102 -13
- ai_edge_torch/generative/utilities/model_builder.py +1 -2
- ai_edge_torch/lowertools/_shim.py +1 -1
- ai_edge_torch/lowertools/translate_recipe.py +14 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250508.dist-info → ai_edge_torch_nightly-0.5.0.dev20250510.dist-info}/METADATA +2 -1
- {ai_edge_torch_nightly-0.5.0.dev20250508.dist-info → ai_edge_torch_nightly-0.5.0.dev20250510.dist-info}/RECORD +16 -16
- {ai_edge_torch_nightly-0.5.0.dev20250508.dist-info → ai_edge_torch_nightly-0.5.0.dev20250510.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250508.dist-info → ai_edge_torch_nightly-0.5.0.dev20250510.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250508.dist-info → ai_edge_torch_nightly-0.5.0.dev20250510.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,9 @@ from ai_edge_torch.generative.examples.gemma import gemma2
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
22
|
|
23
|
-
flags = converter.define_conversion_flags(
|
23
|
+
flags = converter.define_conversion_flags(
|
24
|
+
"gemma2-2b", default_mask_as_input=True, default_transpose_kv_cache=True
|
25
|
+
)
|
24
26
|
|
25
27
|
|
26
28
|
def main(_):
|
@@ -20,7 +20,9 @@ from ai_edge_torch.generative.examples.gemma3 import gemma3
|
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
22
|
|
23
|
-
flags = converter.define_conversion_flags(
|
23
|
+
flags = converter.define_conversion_flags(
|
24
|
+
'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
|
25
|
+
)
|
24
26
|
|
25
27
|
_MODEL_SIZE = flags.DEFINE_string(
|
26
28
|
'model_size',
|
@@ -119,9 +119,7 @@ class Decoder(nn.Module):
|
|
119
119
|
config.vocab_size, config.embedding_dim, padding_idx=0
|
120
120
|
)
|
121
121
|
self.lm_head = nn.Linear(
|
122
|
-
config.embedding_dim,
|
123
|
-
config.vocab_size,
|
124
|
-
bias=config.lm_head_use_bias,
|
122
|
+
config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
|
125
123
|
)
|
126
124
|
# Gemma3 re-uses the embedding as the head projection layer.
|
127
125
|
self.lm_head.weight.data = self.tok_embedding.weight.data
|
@@ -130,30 +128,13 @@ class Decoder(nn.Module):
|
|
130
128
|
for idx in range(config.num_layers)
|
131
129
|
)
|
132
130
|
self.final_norm = builder.build_norm(
|
133
|
-
config.embedding_dim,
|
134
|
-
config.final_norm_config,
|
131
|
+
config.embedding_dim, config.final_norm_config
|
135
132
|
)
|
136
133
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
137
134
|
size=config.kv_cache_max,
|
138
135
|
)
|
139
|
-
# Gemma3 has same hyper parameters for each layer except for attention
|
140
|
-
# types. Use the first layer.
|
141
|
-
attn_config = config.block_config(0).attn_config
|
142
|
-
self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
|
143
|
-
size=config.kv_cache_max,
|
144
|
-
window_size=attn_config.sliding_window_size,
|
145
|
-
)
|
146
136
|
self.config = config
|
147
137
|
|
148
|
-
def get_attention_mask(
|
149
|
-
self,
|
150
|
-
attn_type: cfg.AttentionType,
|
151
|
-
input_pos: torch.Tensor,
|
152
|
-
) -> torch.Tensor:
|
153
|
-
if attn_type == cfg.AttentionType.LOCAL_SLIDING:
|
154
|
-
return self.sliding_window_mask_cache.index_select(2, input_pos)
|
155
|
-
return self.mask_cache.index_select(2, input_pos)
|
156
|
-
|
157
138
|
def get_local_global_attention_mask(
|
158
139
|
self,
|
159
140
|
attention_mask: torch.Tensor,
|
@@ -200,9 +181,7 @@ class Decoder(nn.Module):
|
|
200
181
|
sliding_mask_bool,
|
201
182
|
torch.zeros_like(sliding_mask_bool, dtype=torch.float),
|
202
183
|
torch.full_like(
|
203
|
-
sliding_mask_bool,
|
204
|
-
self.config.causal_mask_value,
|
205
|
-
dtype=torch.float,
|
184
|
+
sliding_mask_bool, self.config.causal_mask_value, dtype=torch.float
|
206
185
|
),
|
207
186
|
)
|
208
187
|
|
@@ -261,7 +240,6 @@ class Decoder(nn.Module):
|
|
261
240
|
pixel_mask = self.build_pixel_mask(image_indices)
|
262
241
|
# RoPE parameters are the same for all blocks. Use the first layer.
|
263
242
|
attn_config = self.config.block_config(0).attn_config
|
264
|
-
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
|
265
243
|
# Different rotary base for global and local attention
|
266
244
|
# based on attention pattern
|
267
245
|
rope = [
|
@@ -273,12 +251,8 @@ class Decoder(nn.Module):
|
|
273
251
|
for i in range(self.config.num_layers)
|
274
252
|
]
|
275
253
|
if mask is None:
|
276
|
-
mask =
|
277
|
-
|
278
|
-
self.config.block_config(i).attn_config.attn_type, input_pos
|
279
|
-
)
|
280
|
-
for i in range(self.config.num_layers)
|
281
|
-
]
|
254
|
+
mask = self.mask_cache.index_select(2, input_pos)
|
255
|
+
mask = mask[:, :, :, : self.config.kv_cache_max]
|
282
256
|
|
283
257
|
return self._forward_with_embeds(
|
284
258
|
input_embeds, rope, mask, input_pos, kv_cache, pixel_mask, export_config
|
@@ -305,7 +279,7 @@ class Decoder(nn.Module):
|
|
305
279
|
if pixel_mask is None:
|
306
280
|
mask = [
|
307
281
|
self.get_local_global_attention_mask(
|
308
|
-
mask,
|
282
|
+
mask[i] if isinstance(mask, list) else mask,
|
309
283
|
self.config.block_config(i).attn_config.attn_type,
|
310
284
|
input_pos,
|
311
285
|
self.config.block_config(i).attn_config.sliding_window_size,
|
@@ -316,7 +290,7 @@ class Decoder(nn.Module):
|
|
316
290
|
pixel_mask = pixel_mask.index_select(2, input_pos)
|
317
291
|
mask = [
|
318
292
|
self.compose_mask(
|
319
|
-
mask[i],
|
293
|
+
mask[i] if isinstance(mask, list) else mask,
|
320
294
|
pixel_mask,
|
321
295
|
self.config.block_config(i).attn_config.attn_type,
|
322
296
|
)
|
@@ -330,6 +304,7 @@ class Decoder(nn.Module):
|
|
330
304
|
if kv_entry:
|
331
305
|
updated_kv_entries.append(kv_entry)
|
332
306
|
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
|
307
|
+
|
333
308
|
if export_config is not None:
|
334
309
|
if (
|
335
310
|
torch.numel(input_pos) > 1
|
@@ -16,9 +16,12 @@
|
|
16
16
|
from dataclasses import dataclass
|
17
17
|
from typing import Optional, Union
|
18
18
|
|
19
|
+
from ai_edge_torch.generative.layers import model_config
|
19
20
|
from ai_edge_torch.generative.quantize import quant_attrs
|
20
21
|
from ai_edge_torch.generative.quantize import supported_schemes
|
21
22
|
|
23
|
+
ModelConfig = model_config.ModelConfig
|
24
|
+
|
22
25
|
|
23
26
|
@dataclass
|
24
27
|
class LayerQuantRecipe:
|
@@ -52,7 +55,7 @@ class LayerQuantRecipe:
|
|
52
55
|
f'w:{self.weight_dtype.name}, '
|
53
56
|
f'{self.mode.name}, '
|
54
57
|
f'{self.algorithm.name}, '
|
55
|
-
f'{self.granularity.name}'
|
58
|
+
f'{self.granularity.name}, '
|
56
59
|
f'{self.block_size}'
|
57
60
|
)
|
58
61
|
return f'{base_str})'
|
@@ -133,6 +136,7 @@ class GenerativeQuantRecipe:
|
|
133
136
|
feedforward: Union[
|
134
137
|
Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
|
135
138
|
] = None
|
139
|
+
_model_config: Optional[ModelConfig] = None
|
136
140
|
|
137
141
|
def __str__(self):
|
138
142
|
return f"""GenerativeQuantRecipe(
|
@@ -63,6 +63,7 @@ def all_supported_int4_dynamic_block_recipe(
|
|
63
63
|
generative_recipe=quant_recipe.GenerativeQuantRecipe(
|
64
64
|
default=quant_recipe_utils.create_layer_quant_int4_dynamic_block(
|
65
65
|
block_size
|
66
|
-
)
|
66
|
+
),
|
67
|
+
embedding=quant_recipe_utils.create_layer_quant_int8_dynamic(),
|
67
68
|
)
|
68
69
|
)
|
@@ -14,7 +14,6 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
import ai_edge_torch
|
17
|
-
from ai_edge_torch import config
|
18
17
|
from ai_edge_torch.generative.examples.test_models import toy_model # NOQA
|
19
18
|
from ai_edge_torch.generative.quantize import quant_recipe
|
20
19
|
from ai_edge_torch.generative.quantize import quant_recipe_utils
|
@@ -15,6 +15,7 @@
|
|
15
15
|
|
16
16
|
"""Common utility functions for model conversion."""
|
17
17
|
|
18
|
+
import enum
|
18
19
|
import os
|
19
20
|
import pathlib
|
20
21
|
from typing import Optional, Union
|
@@ -42,7 +43,32 @@ class ExportableModule(torch.nn.Module):
|
|
42
43
|
return self.module(*export_args, **full_kwargs)
|
43
44
|
|
44
45
|
|
45
|
-
|
46
|
+
class QuantizationName(str, enum.Enum):
|
47
|
+
"""Strings for all supported quantization recipes.
|
48
|
+
|
49
|
+
none: No quantization.
|
50
|
+
dynamic_int8: Dynamic range quantization with int8 weights.
|
51
|
+
weight_only_int8: Weight only quantization with int8 weights.
|
52
|
+
fp16: Float16 quantization.
|
53
|
+
dynamic_int4_block32: Dynamic range quantization with int4 weights and block
|
54
|
+
size of 32, better model quality but slower inference.
|
55
|
+
dynamic_int4_block128: Dynamic range quantization with int4 weights and block
|
56
|
+
size of 128, faster inference but worse model quality.
|
57
|
+
"""
|
58
|
+
|
59
|
+
NONE = 'none'
|
60
|
+
DYNAMIC_INT8 = 'dynamic_int8'
|
61
|
+
WEIGHT_ONLY_INT8 = 'weight_only_int8'
|
62
|
+
FP16 = 'fp16'
|
63
|
+
DYNAMIC_INT4_BLOCK32 = 'dynamic_int4_block32'
|
64
|
+
DYNAMIC_INT4_BLOCK128 = 'dynamic_int4_block128'
|
65
|
+
|
66
|
+
|
67
|
+
def define_conversion_flags(
|
68
|
+
model_name: str,
|
69
|
+
default_mask_as_input: bool = False,
|
70
|
+
default_transpose_kv_cache: bool = False,
|
71
|
+
):
|
46
72
|
"""Defines common flags used for model conversion."""
|
47
73
|
|
48
74
|
flags.DEFINE_string(
|
@@ -70,10 +96,10 @@ def define_conversion_flags(model_name: str):
|
|
70
96
|
1280,
|
71
97
|
'The maximum size of KV cache buffer, including both prefill and decode.',
|
72
98
|
)
|
73
|
-
flags.
|
99
|
+
flags.DEFINE_string(
|
74
100
|
'quantize',
|
75
|
-
|
76
|
-
'
|
101
|
+
'dynamic_int8',
|
102
|
+
'How the model should be quantized.',
|
77
103
|
)
|
78
104
|
flags.DEFINE_multi_integer(
|
79
105
|
'lora_ranks',
|
@@ -83,18 +109,78 @@ def define_conversion_flags(model_name: str):
|
|
83
109
|
)
|
84
110
|
flags.DEFINE_bool(
|
85
111
|
'mask_as_input',
|
86
|
-
|
112
|
+
default_mask_as_input,
|
87
113
|
'If true, the mask will be passed in as input. Otherwise, mask will be '
|
88
114
|
'built by the model internally.',
|
89
115
|
)
|
90
116
|
flags.DEFINE_bool(
|
91
117
|
'transpose_kv_cache',
|
92
|
-
|
118
|
+
default_transpose_kv_cache,
|
93
119
|
'If true, the model will be converted with transposed KV cache.',
|
94
120
|
)
|
95
121
|
return flags
|
96
122
|
|
97
123
|
|
124
|
+
def get_quant_recipe_from_flag(
|
125
|
+
quantize: str,
|
126
|
+
) -> Optional[quant_recipes.QuantizationRecipe]:
|
127
|
+
"""Processes the quantization flag and returns the corresponding recipe.
|
128
|
+
|
129
|
+
Args:
|
130
|
+
quantize: The quantization type.
|
131
|
+
|
132
|
+
Returns:
|
133
|
+
The quantization recipe, or None if no quantization is needed.
|
134
|
+
|
135
|
+
Raises:
|
136
|
+
ValueError: If the quantization type is not supported.
|
137
|
+
"""
|
138
|
+
match quantize:
|
139
|
+
case QuantizationName.NONE:
|
140
|
+
return None
|
141
|
+
case QuantizationName.DYNAMIC_INT8:
|
142
|
+
return quant_recipes.full_int8_dynamic_recipe()
|
143
|
+
case QuantizationName.WEIGHT_ONLY_INT8:
|
144
|
+
return quant_recipes.full_int8_weight_only_recipe()
|
145
|
+
case QuantizationName.FP16:
|
146
|
+
return quant_recipes.full_fp16_recipe()
|
147
|
+
case QuantizationName.DYNAMIC_INT4_BLOCK32:
|
148
|
+
return quant_recipes.full_int4_dynamic_block_recipe(32)
|
149
|
+
case QuantizationName.DYNAMIC_INT4_BLOCK128:
|
150
|
+
return quant_recipes.full_int4_dynamic_block_recipe(128)
|
151
|
+
case _:
|
152
|
+
raise ValueError(f'Unsupported quantization flag: {quantize}')
|
153
|
+
|
154
|
+
|
155
|
+
def create_quantize_suffix(quantize: str) -> str:
|
156
|
+
"""Creates a suffix for the output file name based on the quantization type.
|
157
|
+
|
158
|
+
Args:
|
159
|
+
quantize: The quantization type.
|
160
|
+
|
161
|
+
Returns:
|
162
|
+
A string representing the quantization suffix.
|
163
|
+
|
164
|
+
Raises:
|
165
|
+
ValueError: If the quantization type is not supported.
|
166
|
+
"""
|
167
|
+
match quantize:
|
168
|
+
case QuantizationName.NONE:
|
169
|
+
return 'f32'
|
170
|
+
case QuantizationName.DYNAMIC_INT8:
|
171
|
+
return 'q8'
|
172
|
+
case QuantizationName.WEIGHT_ONLY_INT8:
|
173
|
+
return 'q8_wo'
|
174
|
+
case QuantizationName.FP16:
|
175
|
+
return 'fp16'
|
176
|
+
case QuantizationName.DYNAMIC_INT4_BLOCK32:
|
177
|
+
return 'q4_block32'
|
178
|
+
case QuantizationName.DYNAMIC_INT4_BLOCK128:
|
179
|
+
return 'q4_block128'
|
180
|
+
case _:
|
181
|
+
raise ValueError(f'Unsupported quantization flag: {quantize}')
|
182
|
+
|
183
|
+
|
98
184
|
def _build_mask(mask_len, kv_cache_max_len, causal_mask_value) -> torch.Tensor:
|
99
185
|
if isinstance(mask_len, list):
|
100
186
|
return [
|
@@ -114,7 +200,7 @@ def convert_to_tflite(
|
|
114
200
|
prefill_seq_len: Union[int, list[int]],
|
115
201
|
pixel_values_size: torch.Size = None,
|
116
202
|
pixel_seq_len: int = 0,
|
117
|
-
quantize:
|
203
|
+
quantize: str = 'dynamic_int8',
|
118
204
|
config: cfg.ModelConfig = None,
|
119
205
|
lora_ranks: Optional[list[int]] = None,
|
120
206
|
export_config: ExportConfig = None,
|
@@ -160,8 +246,8 @@ def convert_to_tflite(
|
|
160
246
|
embeddings generated by the image encoder with pixel values. The actual
|
161
247
|
length of prefill_seq_len will be added by pixel_seq_len when pixel
|
162
248
|
values are passed.
|
163
|
-
quantize (
|
164
|
-
|
249
|
+
quantize (str, optional): The quantization type. Defaults to
|
250
|
+
'dynamic_int8'.
|
165
251
|
config (cfg.ModelConfig, optional): The model config used to configure KV
|
166
252
|
cache. If None, it uses the config of the pytorch_model.
|
167
253
|
lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
|
@@ -182,7 +268,7 @@ def convert_to_tflite(
|
|
182
268
|
lora = lora_utils.LoRA.zeros(rank, config)
|
183
269
|
loras.append(lora)
|
184
270
|
|
185
|
-
quant_suffix =
|
271
|
+
quant_suffix = create_quantize_suffix(quantize)
|
186
272
|
kv_size = config.kv_cache_max_len
|
187
273
|
lora_suffix = (
|
188
274
|
'' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
|
@@ -216,7 +302,7 @@ def _export_helper(
|
|
216
302
|
prefill_seq_lens: list[int],
|
217
303
|
pixel_values_size: torch.Size,
|
218
304
|
pixel_seq_len: int,
|
219
|
-
quantize:
|
305
|
+
quantize: str,
|
220
306
|
config: cfg.ModelConfig,
|
221
307
|
loras: list[None | lora_utils.LoRA],
|
222
308
|
export_config: ExportConfig,
|
@@ -265,7 +351,8 @@ def _export_helper(
|
|
265
351
|
kv_layout=export_config.kvcache_layout,
|
266
352
|
)
|
267
353
|
|
268
|
-
quant_config =
|
354
|
+
quant_config = get_quant_recipe_from_flag(quantize)
|
355
|
+
quant_config._model_config = config
|
269
356
|
|
270
357
|
# For export, we create a module that captures any non-exportable,
|
271
358
|
# arugments, e.g. the generation config object.
|
@@ -330,5 +417,7 @@ def _export_helper(
|
|
330
417
|
sample_kwargs=sample_kwargs,
|
331
418
|
)
|
332
419
|
|
333
|
-
edge_model = converter.convert(
|
420
|
+
edge_model = converter.convert(
|
421
|
+
quant_config=quant_config,
|
422
|
+
)
|
334
423
|
edge_model.export(output_file)
|
@@ -75,8 +75,7 @@ class DecoderOnlyModel(nn.Module):
|
|
75
75
|
for idx in range(config.num_layers)
|
76
76
|
)
|
77
77
|
self.final_norm = builder.build_norm(
|
78
|
-
config.embedding_dim,
|
79
|
-
config.final_norm_config,
|
78
|
+
config.embedding_dim, config.final_norm_config
|
80
79
|
)
|
81
80
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
82
81
|
size=config.kv_cache_max,
|
@@ -50,7 +50,7 @@ def exported_programs_to_tflite(
|
|
50
50
|
*,
|
51
51
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
52
52
|
_tfl_converter_flags: Optional[dict[str, Any]] = None,
|
53
|
-
_saved_model_dir: Optional[str] = None
|
53
|
+
_saved_model_dir: Optional[str] = None,
|
54
54
|
):
|
55
55
|
"""Converts a list of ExportedProgram to a TFLite model.
|
56
56
|
|
@@ -29,6 +29,8 @@ _IDX_TRANSFORMER_BLOCKS_REGEX_STR = 'transformer_blocks\[{}\]'
|
|
29
29
|
_ATTENTION_REGEX_STR = 'ai_edge_torch.generative.layers.attention'
|
30
30
|
_FEEDFORWARD_REGEX_STR = 'ai_edge_torch.generative.layers.feed_forward'
|
31
31
|
_EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
|
32
|
+
# TODO: b/415833584 - Improve the regex for pre-softmax layer.
|
33
|
+
_DECODE_LOGITS_REGEX_STR = 'StatefulPartitionedCall'
|
32
34
|
_ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}'
|
33
35
|
|
34
36
|
|
@@ -95,10 +97,11 @@ def _set_quant_config(
|
|
95
97
|
rm: quantizer.recipe_manager.RecipeManager,
|
96
98
|
layer_recipe: quant_recipe.LayerQuantRecipe,
|
97
99
|
regex: str,
|
100
|
+
operation_name: _OpName = _OpName.ALL_SUPPORTED,
|
98
101
|
):
|
99
102
|
rm.add_quantization_config(
|
100
103
|
regex=regex,
|
101
|
-
operation_name=
|
104
|
+
operation_name=operation_name,
|
102
105
|
op_config=_OpQuantConfig(
|
103
106
|
weight_tensor_config=_TensorQuantConfig(
|
104
107
|
num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
|
@@ -126,6 +129,16 @@ def translate_to_ai_edge_recipe(
|
|
126
129
|
|
127
130
|
if recipe.embedding is not None:
|
128
131
|
_set_quant_config(rm, recipe.embedding, _EMBEDDING_REGEX_STR)
|
132
|
+
if (
|
133
|
+
recipe._model_config is not None
|
134
|
+
and recipe._model_config.lm_head_share_weight_with_embedding
|
135
|
+
):
|
136
|
+
_set_quant_config(
|
137
|
+
rm,
|
138
|
+
recipe.embedding,
|
139
|
+
_DECODE_LOGITS_REGEX_STR,
|
140
|
+
_OpName.FULLY_CONNECTED,
|
141
|
+
)
|
129
142
|
|
130
143
|
if recipe.attention is not None:
|
131
144
|
if isinstance(recipe.attention, dict):
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.5.0.
|
3
|
+
Version: 0.5.0.dev20250510
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -22,6 +22,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
22
22
|
Requires-Python: >=3.10
|
23
23
|
Description-Content-Type: text/markdown
|
24
24
|
License-File: LICENSE
|
25
|
+
Requires-Dist: absl-py
|
25
26
|
Requires-Dist: numpy
|
26
27
|
Requires-Dist: scipy
|
27
28
|
Requires-Dist: safetensors
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=03QthwiMre1vVY49We8vVzXhxe0zkOzzsTMQZv3hDrk,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -61,15 +61,15 @@ ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=yhS_i2kR0GJWpWciCt
|
|
61
61
|
ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
|
62
62
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
63
63
|
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=RRilUl2Ui08R9gy1Ua0jnaXNCrIJJb-oztgP62G3mX4,1526
|
64
|
-
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=
|
64
|
+
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=9ozSw2-xuf5Wfh1HeLDTP3wJxxUZmrD3An1njJPMpdI,1594
|
65
65
|
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=6ImjTzJcq6JoKz2Z-z8pjv5BsRu5nUeEsTK3IPs3xgI,3521
|
66
66
|
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=JQLLiHNVBM9jOrZqUF0EmgAwtDD0yTRlmIbLaWM7qTg,11557
|
67
67
|
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
|
68
68
|
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=jhiyinOqPt5ZZjEadDRZt_wY5fiLSCpMo54PcxFaL_Q,1789
|
69
69
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=n7f2nF6Lin_tDvPs0JVldsuaBzo7pAwi5YAHAhlIxQg,6139
|
70
70
|
ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
71
|
-
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=
|
72
|
-
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=
|
71
|
+
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=JLXXn2mFEBs4DlHH_O6hpEG9KInJqsCdWy3DrgUjT1c,1827
|
72
|
+
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=v0ZcKrIAvERQLb1wK1Vc_ewWWVZgJFUdRTyoVY0Lfus,14955
|
73
73
|
ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
|
74
74
|
ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
|
75
75
|
ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
|
@@ -178,9 +178,9 @@ ai_edge_torch/generative/layers/unet/model_config.py,sha256=pPDwLawc23pfMaPVyMJl
|
|
178
178
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
179
179
|
ai_edge_torch/generative/quantize/example.py,sha256=1lfVNUd2cEyRUnoZ7BLbRJ9IN-FTKiWBtZNPFUzAiWE,1747
|
180
180
|
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=plMsd7JBi98r2NHsAdMdvS6TPTXAoRFLCwOXu8H3-24,2004
|
181
|
-
ai_edge_torch/generative/quantize/quant_recipe.py,sha256=
|
181
|
+
ai_edge_torch/generative/quantize/quant_recipe.py,sha256=CEW-ewHxwb59x_GISx4jr7WMihvn-jKWVcBonllzDS4,5724
|
182
182
|
ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=h3k_na6rbR08Ip79-2JbkeH8RDk_rrnEGiytuzFDhqc,2678
|
183
|
-
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=
|
183
|
+
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=5UkUAT0qsWzLtNAeX-M5hEMi-kqoLV70_F76QiXmVZ4,2424
|
184
184
|
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=TwR2FpQuBEORy6FshEyHNBMKARWlA2MVtTfX9tXV5aE,1488
|
185
185
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
186
186
|
ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvICcDWTcIDJ3rociko-wM,3267
|
@@ -189,13 +189,13 @@ ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZj
|
|
189
189
|
ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
|
190
190
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=mhNJikLnGVGi9NKmXB8FhnqeDy9gtrvC3yEbrTABZ4Y,6163
|
191
191
|
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hPmWpg41ZMWwBsngTykRVzRPHtpbkwiLM,12811
|
192
|
-
ai_edge_torch/generative/test/test_quantize.py,sha256=
|
192
|
+
ai_edge_torch/generative/test/test_quantize.py,sha256=kKJ01wscTC2t_Ylr7huO5gNKES01gm3dT1gx52z15PA,7356
|
193
193
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
194
194
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
195
|
-
ai_edge_torch/generative/utilities/converter.py,sha256=
|
195
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=LrBqxXVxkOWh4abcHfY4QXRpYxjjfEYd4ifrpGGbebI,14441
|
196
196
|
ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
|
197
197
|
ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
|
198
|
-
ai_edge_torch/generative/utilities/model_builder.py,sha256=
|
198
|
+
ai_edge_torch/generative/utilities/model_builder.py,sha256=IG-88o7nWI9XrNDnwnQ-MoilsuqJ7KwrnbP3bn2EY9U,6334
|
199
199
|
ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
|
200
200
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
201
201
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
@@ -209,12 +209,12 @@ ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=Ui6BrehF3zJJN7uTxKwbO2yCY9mYjb
|
|
209
209
|
ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
210
210
|
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=5kmOJWCc7sU1Hrqr1y17BtShUrssTfaV1sMyUvdMbsg,5573
|
211
211
|
ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
|
212
|
-
ai_edge_torch/lowertools/_shim.py,sha256=
|
212
|
+
ai_edge_torch/lowertools/_shim.py,sha256=rEZkOdHiz7CPvPL0WATIYnH4K6wF1YBtcv3oFEx2ZeQ,3277
|
213
213
|
ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5zo3MlX1QYW4c,4513
|
214
214
|
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=QRuS7S5lULRWEh3J1sWIsnKh-rbX7rd9tt6JJHbMPfo,8317
|
215
215
|
ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
|
216
216
|
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
|
217
|
-
ai_edge_torch/lowertools/translate_recipe.py,sha256=
|
217
|
+
ai_edge_torch/lowertools/translate_recipe.py,sha256=JNsRc1Jmpj5W6PBww8KRMkbtxcv7ssl8Rr1R3x5_7to,6283
|
218
218
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
219
219
|
ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
|
220
220
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
@@ -251,8 +251,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
251
251
|
ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
|
252
252
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
253
253
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
254
|
-
ai_edge_torch_nightly-0.5.0.
|
255
|
-
ai_edge_torch_nightly-0.5.0.
|
256
|
-
ai_edge_torch_nightly-0.5.0.
|
257
|
-
ai_edge_torch_nightly-0.5.0.
|
258
|
-
ai_edge_torch_nightly-0.5.0.
|
254
|
+
ai_edge_torch_nightly-0.5.0.dev20250510.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
255
|
+
ai_edge_torch_nightly-0.5.0.dev20250510.dist-info/METADATA,sha256=1fA2DwpzLkPWBS-gV86ik7v9m39lO_RUaU4k7qAEvkM,2074
|
256
|
+
ai_edge_torch_nightly-0.5.0.dev20250510.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
257
|
+
ai_edge_torch_nightly-0.5.0.dev20250510.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
258
|
+
ai_edge_torch_nightly-0.5.0.dev20250510.dist-info/RECORD,,
|
File without changes
|
File without changes
|