ai-edge-torch-nightly 0.5.0.dev20250512__py3-none-any.whl → 0.5.0.dev20250514__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +3 -1
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/gemma3/decoder.py +7 -2
- ai_edge_torch/generative/examples/gemma3/gemma3.py +8 -4
- ai_edge_torch/generative/examples/gemma3/verify_util.py +14 -3
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/openelm/openelm.py +1 -1
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -3
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +2 -3
- ai_edge_torch/generative/quantize/quant_recipes.py +17 -4
- ai_edge_torch/generative/utilities/converter.py +21 -7
- ai_edge_torch/generative/utilities/loader.py +12 -2
- ai_edge_torch/generative/utilities/model_builder.py +5 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250512.dist-info → ai_edge_torch_nightly-0.5.0.dev20250514.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250512.dist-info → ai_edge_torch_nightly-0.5.0.dev20250514.dist-info}/RECORD +19 -19
- {ai_edge_torch_nightly-0.5.0.dev20250512.dist-info → ai_edge_torch_nightly-0.5.0.dev20250514.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250512.dist-info → ai_edge_torch_nightly-0.5.0.dev20250514.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250512.dist-info → ai_edge_torch_nightly-0.5.0.dev20250514.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py CHANGED
@@ -49,7 +49,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=2048,
   )
-  norm_config = cfg.NormalizationConfig(
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py CHANGED
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("amd-llama-135m")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )
 
 
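Note: the same two-part change repeats below for openelm, paligemma, and qwen_vl: the script-level ExportConfig alias is dropped, and the export config is now assembled from the command-line flags at the call site. A minimal sketch of the resulting pattern, assuming the ai-edge-torch nightly package is installed and that these scripts use absl (the model name and body are illustrative):

from absl import app

from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities import export_config

# Registers the shared conversion CLI flags (checkpoint path, prefill
# lengths, quantization, LoRA ranks, ...) for this model.
flags = converter.define_conversion_flags("amd-llama-135m")


def main(_):
  # Built from the registered flags, replacing a hand-constructed
  # ExportConfig() instance.
  cfg = export_config.get_from_flags()
  print(cfg)


if __name__ == "__main__":
  app.run(main)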
ai_edge_torch/generative/examples/gemma3/decoder.py CHANGED
@@ -15,7 +15,7 @@
 
 """Example of building a Decoder for Gemma3 model."""
 
-from typing import List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
@@ -410,7 +410,11 @@ def get_fake_decoder_config_1b(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model_1b(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model_1b(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs,
+) -> nn.Module:
   # TODO(b/403644647): Better error handling for loading checkpoints with
   # different tensor names.
   for tensor_names in TENSOR_NAMES_DICT.values():
@@ -420,6 +424,7 @@ def build_model_1b(checkpoint_path: str, **kwargs) -> nn.Module:
         config=get_decoder_config_1b(**kwargs),
         tensor_names=tensor_names,
         model_class=Decoder,
+        custom_loader=custom_loader,
     )
   except KeyError as ke:
     continue
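Note: the new custom_loader parameter, threaded through gemma3.py, verify_util.py, loader.py, and model_builder.py below, is any callable mapping a checkpoint path to a dict of tensors. A minimal sketch of a compatible loader for a plain PyTorch checkpoint (the function name and path are illustrative, not part of the package):

from typing import Dict

import torch


def my_custom_loader(path: str) -> Dict[str, torch.Tensor]:
  # map_location="cpu" keeps loading device-agnostic; weights_only=True
  # restricts unpickling to tensors and plain containers.
  return torch.load(path, map_location="cpu", weights_only=True)


# Usage (illustrative):
# model = build_model_1b("/path/to/ckpt.pt", custom_loader=my_custom_loader)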
ai_edge_torch/generative/examples/gemma3/gemma3.py CHANGED
@@ -16,8 +16,7 @@
 """Example of building a Gemma3 gpu model."""
 
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
-import xmlrpc
+from typing import List, Optional, Tuple, Callable, Dict
 
 from ai_edge_torch.generative.examples.gemma3 import decoder
 from ai_edge_torch.generative.examples.gemma3 import image_encoder
@@ -166,9 +165,14 @@ def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
       mm_extra_tokens=32,
   )
 
-def build_model_1b(checkpoint_path: str, **kwargs) -> decoder.Decoder:
+
+def build_model_1b(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs,
+) -> decoder.Decoder:
   if checkpoint_path:
-    model = decoder.build_model_1b(checkpoint_path, **kwargs)
+    model = decoder.build_model_1b(checkpoint_path, custom_loader, **kwargs)
   else:
     config = decoder.get_decoder_config_1b(**kwargs)
     model = decoder.Decoder(config)
ai_edge_torch/generative/examples/gemma3/verify_util.py CHANGED
@@ -17,7 +17,7 @@
 
 import logging
 import os
-from typing import List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 from ai_edge_torch.generative.examples.gemma3 import gemma3
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
@@ -167,6 +167,7 @@ def verify_reauthored_gemma_model(
     generate_prompts: List[str],
     forward_input_ids: List[List[int]],
    weight_filename: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
     tokenizer_filename: str = "tokenizer.model",
     max_new_tokens: int = 20,
     rtol: float = 1e-05,
@@ -196,7 +197,14 @@ def verify_reauthored_gemma_model(
 
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = gemma_model.GemmaForCausalLM(config).eval()
-  original_model.load_weights(os.path.join(checkpoint, weight_filename))
+  checkpoint_path = os.path.join(checkpoint, weight_filename)
+  if custom_loader is None:
+    original_model.load_weights(checkpoint_path)
+  else:
+    original_model.load_state_dict(
+        custom_loader(checkpoint_path)["model_state_dict"],
+        strict=False,
+    )
 
   return verifier.verify_reauthored_model(
       original_model=GemmaWrapper(original_model),
@@ -216,6 +224,7 @@ def verify_gemma3(
     max_new_tokens: int,
     variant: str,
     weight_filename: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
 ) -> bool:
   """Verifies the reauthored Gemma3 model.
 
@@ -225,6 +234,7 @@ def verify_gemma3(
     max_new_tokens: Maximum number of new tokens to generate.
     variant: Gemma model variant.
     weight_filename: Name of the weight file.
+    custom_loader: A custom loader to load the weights.
 
   Returns:
     True if the verification passes, False otherwise.
@@ -234,7 +244,7 @@ def verify_gemma3(
 
   if variant == "1b":
     reauthored_model = UnifiedGemma3Wrapper(
-        gemma3.build_model_1b(gemma3_model_path)
+        gemma3.build_model_1b(gemma3_model_path, custom_loader)
     )
   else:
     raise ValueError(f"Unsupported Gemma3 variant: {variant}")
@@ -247,5 +257,6 @@ def verify_gemma3(
       forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
       max_new_tokens=max_new_tokens,
       weight_filename=weight_filename,
+      custom_loader=custom_loader,
       atol=1e-04,
   )
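Note the asymmetry introduced here: with no custom_loader, the original Gemma model still loads through its own load_weights; with one, the loader's return value is expected to nest the weights under a "model_state_dict" key and is applied with strict=False. A hedged sketch of a loader matching that expectation (the checkpoint layout is assumed, not confirmed by the diff):

from typing import Any, Dict

import torch


def nested_checkpoint_loader(path: str) -> Dict[str, Any]:
  # The returned dict must contain a "model_state_dict" entry, since
  # verify_reauthored_gemma_model indexes into it before load_state_dict.
  return torch.load(path, map_location="cpu")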
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py CHANGED
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("openelm")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/openelm/openelm.py CHANGED
@@ -51,7 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for an OpenELM model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=True
   )
   num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
   num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py CHANGED
@@ -22,8 +22,6 @@ from ai_edge_torch.generative.utilities import export_config
 import torch
 
 flags = converter.define_conversion_flags('paligemma2-3b-224')
-ExportConfig = export_config.ExportConfig
-
 
 _VERSION = flags.DEFINE_enum(
     'version',
@@ -32,6 +30,7 @@ _VERSION = flags.DEFINE_enum(
     'The version of PaliGemma model to verify.',
 )
 
+
 def main(_):
   pytorch_model = paligemma.build_model(
       flags.FLAGS.checkpoint_path,
@@ -51,7 +50,7 @@ def main(_):
       pixel_seq_len=(config.image_size // config.patch_size) ** 2,
       quantize=flags.FLAGS.quantize,
       config=pytorch_model.config.decoder_config,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py CHANGED
@@ -21,8 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags('qwen_vl')
-ExportConfig = export_config.ExportConfig
-
 
 _IMAGE_HEIGHT = flags.DEFINE_integer(
     'image_height',
@@ -35,6 +33,7 @@ _IMAGE_WIDTH = flags.DEFINE_integer(
     'The width of image.',
 )
 
+
 def main(_):
   pytorch_model = qwen_vl.build_model(
       flags.FLAGS.checkpoint_path,
@@ -60,7 +59,7 @@ def main(_):
       ),
       quantize=flags.FLAGS.quantize,
       config=pytorch_model.config.decoder_config,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/quantize/quant_recipes.py CHANGED
@@ -27,37 +27,49 @@ Typical usage example:
     )
 """
 
+from typing import Optional
+from ai_edge_torch.generative.layers import model_config
 from ai_edge_torch.generative.quantize import quant_recipe
 from ai_edge_torch.generative.quantize import quant_recipe_utils
 from ai_edge_torch.quantize import quant_config
 
 
-def full_int8_dynamic_recipe(
+def full_int8_dynamic_recipe(
+    mcfg: Optional[model_config.ModelConfig] = None,
+) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
           default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+          _model_config=mcfg,
       )
   )
 
 
-def full_int8_weight_only_recipe(
+def full_int8_weight_only_recipe(
+    mcfg: Optional[model_config.ModelConfig] = None,
+) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
           default=quant_recipe_utils.create_layer_quant_int8_weight_only(),
+          _model_config=mcfg,
       )
   )
 
 
-def full_fp16_recipe(
+def full_fp16_recipe(
+    mcfg: Optional[model_config.ModelConfig] = None,
+) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
-          default=quant_recipe_utils.create_layer_quant_fp16()
+          default=quant_recipe_utils.create_layer_quant_fp16(),
+          _model_config=mcfg,
       )
   )
 
 
 def all_supported_int4_dynamic_block_recipe(
     block_size: int,
+    mcfg: Optional[model_config.ModelConfig] = None,
 ) -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
@@ -65,5 +77,6 @@ def all_supported_int4_dynamic_block_recipe(
           block_size
       ),
       embedding=quant_recipe_utils.create_layer_quant_int8_dynamic(),
+      _model_config=mcfg,
   )
 )
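Note: every recipe now accepts an optional mcfg (the model's ModelConfig) and stores it on the generated GenerativeQuantRecipe as _model_config, so callers no longer have to patch the returned QuantConfig afterwards. A usage sketch (the model object is assumed, e.g. a built example model):

from ai_edge_torch.generative.quantize import quant_recipes

# mcfg is optional and defaults to None, so existing zero-argument calls
# keep working; passing it attaches the model config to the recipe.
quant_cfg = quant_recipes.full_int8_dynamic_recipe(mcfg=model.config)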
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -26,6 +26,7 @@ from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.quantize import quant_recipes
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 
 ExportConfig = export_config.ExportConfig
@@ -123,7 +124,8 @@ def define_conversion_flags(
 
 def get_quant_recipe_from_flag(
     quantize: str,
-
+    model_config: cfg.ModelConfig,
+) -> Optional[qcfg.QuantConfig]:
   """Processes the quantization flag and returns the corresponding recipe.
 
   Args:
@@ -139,15 +141,19 @@ def get_quant_recipe_from_flag(
     case QuantizationName.NONE:
       return None
     case QuantizationName.DYNAMIC_INT8:
-      return quant_recipes.full_int8_dynamic_recipe()
+      return quant_recipes.full_int8_dynamic_recipe(mcfg=model_config)
     case QuantizationName.WEIGHT_ONLY_INT8:
-      return quant_recipes.full_int8_weight_only_recipe()
+      return quant_recipes.full_int8_weight_only_recipe(mcfg=model_config)
     case QuantizationName.FP16:
       return quant_recipes.full_fp16_recipe()
     case QuantizationName.DYNAMIC_INT4_BLOCK32:
-      return quant_recipes.
+      return quant_recipes.all_supported_int4_dynamic_block_recipe(
+          32, mcfg=model_config
+      )
     case QuantizationName.DYNAMIC_INT4_BLOCK128:
-      return quant_recipes.
+      return quant_recipes.all_supported_int4_dynamic_block_recipe(
+          128, mcfg=model_config
+      )
     case _:
      raise ValueError(f'Unsupported quantization flag: {quantize}')
 
@@ -274,6 +280,15 @@ def convert_to_tflite(
       '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
   )
 
+  if pixel_values_size is not None:
+    assert pixel_seq_len > 0, 'pixel_seq_len must be greater than 0'
+    max_prefill_seq_len = max(prefill_seq_lens)
+    assert kv_size > max_prefill_seq_len + pixel_seq_len, (
+        f'The KV cache size ({kv_size}) must be greater than the maximum '
+        f'prefill sequence length ({max_prefill_seq_len}) + pixel sequence '
+        f'length ({pixel_seq_len})'
+    )
+
   if export_config is not None:
     if export_config.decode_batch_size > 1:
       output_name_prefix += f'_dbs{export_config.decode_batch_size}'
@@ -351,8 +366,7 @@ def _export_helper(
       kv_layout=export_config.kvcache_layout,
   )
 
-  quant_config = get_quant_recipe_from_flag(quantize)
-  quant_config._model_config = config
+  quant_config = get_quant_recipe_from_flag(quantize, config)
 
   # For export, we create a module that captures any non-exportable,
   # arugments, e.g. the generation config object.
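Note: two behavioral points in this file. First, the new guard in convert_to_tflite makes the multimodal sizing constraint explicit: the KV cache must hold the longest prefill plus the image tokens. A worked check with illustrative numbers (not taken from the diff):

# The assertion requires kv_size > max(prefill_seq_lens) + pixel_seq_len.
prefill_seq_lens = [256, 512]
pixel_seq_len = 576
kv_size = 1280
assert kv_size > max(prefill_seq_lens) + pixel_seq_len  # 1280 > 1088, passes

Second, _export_helper now passes the model config into get_quant_recipe_from_flag instead of assigning quant_config._model_config after the fact; this also appears to remove a latent crash, since the old code would dereference None whenever the NONE quantization case returned no recipe.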
ai_edge_torch/generative/utilities/loader.py CHANGED
@@ -117,7 +117,12 @@ class ModelLoader:
     final_norm: str = None
     lm_head: str = None
 
-  def __init__(
+  def __init__(
+      self,
+      file_name: str,
+      names: TensorNames,
+      custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+  ) -> None:
     """ModelLoader constructor.
 
     Can be used to load multiple models of the same type.
@@ -126,10 +131,15 @@ class ModelLoader:
       file_name (str): Path to the checkpoint. Can be a directory or an exact
         file.
       names (TensorNames): An instance of `TensorNames` to determine mappings.
+      custom_loader (Callable[[str], Dict[str, torch.Tensor]]): A custom
+        loader to be used. If not provided, the class will determine a proper
+        loader.
     """
     self._file_name = file_name
     self._names = names
-    self._loader =
+    self._loader = (
+        custom_loader if custom_loader is not None else self._get_loader()
+    )
 
   def get_state(self) -> Dict[str, torch.Tensor]:
     return self._loader(self._file_name)
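Note: with the widened constructor, ModelLoader keeps its previous behavior (a loader inferred from the checkpoint's file type) whenever custom_loader is omitted. A usage sketch, with the import path taken from how model_builder.py references this module and TENSOR_NAMES as a placeholder:

from ai_edge_torch.generative.utilities import loader as loading_utils

# TENSOR_NAMES and the checkpoint path are placeholders for illustration.
model_loader = loading_utils.ModelLoader(
    "/path/to/model.safetensors",
    TENSOR_NAMES,
    custom_loader=None,  # None selects the built-in loader for the file type
)
state = model_loader.get_state()  # delegates to the selected loader callable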
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -16,6 +16,7 @@
 """Utilities to be used for re-authoring transformer models."""
 
 import copy
+from typing import Callable, Dict
 from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
@@ -160,9 +161,12 @@ def build_decoder_only_model(
     config: cfg.ModelConfig,
     tensor_names: loading_utils.ModelLoader.TensorNames,
     model_class: type[nn.Module] = DecoderOnlyModel,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
 ) -> nn.Module:
   transformer = model_class(config)
-  loader = loading_utils.ModelLoader(
+  loader = loading_utils.ModelLoader(
+      checkpoint_path, tensor_names, custom_loader
+  )
   loader.load(
       transformer, strict=not config.lm_head_share_weight_with_embedding
   )
ai_edge_torch/version.py CHANGED
{ai_edge_torch_nightly-0.5.0.dev20250512.dist-info → ai_edge_torch_nightly-0.5.0.dev20250514.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250512
+Version: 0.5.0.dev20250514
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.5.0.dev20250512.dist-info → ai_edge_torch_nightly-0.5.0.dev20250514.dist-info}/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=ZvSDZpKkUslpMEN4pPp4xI6n8g3mHZMdfIcYeWth5Dg,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -52,8 +52,8 @@ ai_edge_torch/generative/custom_ops/bmm_4d.py,sha256=JmVbZCujG_wuBchma8QF3DSBfVc
 ai_edge_torch/generative/custom_ops/dynamic_update_slice.py,sha256=ZGAq2CfWZsfef5mHulsWmyUx0dDWJX6J6xPjhBrjQdM,2097
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=
-ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=XsDXx6k0kE_OYu_dr7GEC26jCepV1Kv39iH-kpuqA4M,2794
+ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=hiuMFJ8QPymGMM6PiSQqQrfR4M1mblpPuDfjjabcr_w,1560
 ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=o13NkFlBgawBsjdJup05VMUjAPvDRAmig6VyEkX8q6U,2426
 ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=l0OrPGmX8WscuG9MIgtd0sqR4BeReNAu7fADzyPbnZw,1580
@@ -69,11 +69,11 @@ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=jhiyinOqPt5ZZjEa
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=n7f2nF6Lin_tDvPs0JVldsuaBzo7pAwi5YAHAhlIxQg,6139
 ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=JLXXn2mFEBs4DlHH_O6hpEG9KInJqsCdWy3DrgUjT1c,1827
-ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=
-ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=
+ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=shdgLzKDUi0vyNOAsrIVAEFb3Adltsri6Rx1-wxzVf4,15089
+ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=ZorRtnbElWsctcA0nEbfwjx0C578voF7fjFEvWSR5Ck,6582
 ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
 ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
-ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=
+ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=1vfAtayH_I_qTpqhzu6n9xnCuvhgTzhS8IzZviW2dJQ,9418
 ai_edge_torch/generative/examples/hammer/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/hammer/convert_to_tflite.py,sha256=9r8LXyaoBXYIIhhe1WQgEIjaxALQPE1dO2N6qopyWCk,1753
 ai_edge_torch/generative/examples/hammer/hammer.py,sha256=76INcjffvaNCQ02fzXcxJUW_6EKHs4sg3q1nDBbEpHE,3431
@@ -86,11 +86,11 @@ ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6X
 ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
 ai_edge_torch/generative/examples/moonshine/moonshine.py,sha256=nZ2b8u4TmsB5sgdClgAuH8E78bcTv9RCnF9666HqP2M,3394
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/openelm/openelm.py,sha256=
+ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=wRdT7bWbCX8g4TbzKbjcLx6vmKtuT5-g-ipg19hJW-M,1525
+ai_edge_torch/generative/examples/openelm/openelm.py,sha256=hPcXYHj-nBP56TOeQQejB3HRzv6yHSftHOx0OEPP5M8,4574
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=4W26ZtPF5Cb9mpHYuRM4b2QB_4W76zf4WV36KzexVjs,2446
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=fkP-mWrih1s-vgJ41fLt8v5JE-UOs8Zrngh6ElQ6PMw,1997
 ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=-EYUZp55dfRY1E-N0Pr3b9i5c7Tt1XvYxvsRixguVS8,5527
 ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=WB8r-e_Crog1ItBq3Zse_nUG-foFyBcJsuEG26r_Ji8,6076
 ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=SvuR97sjkBtfkerH7Hu1UXB8kCFLpEATNbPfCbNAyfo,5614
@@ -114,7 +114,7 @@ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=eOpv3scJr4mVs
 ai_edge_torch/generative/examples/qwen/qwen.py,sha256=m8APYzo9N0SXsdvCxC8HtCcbN3W7gLKkRBL-Tg0BWXU,4223
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
-ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=4Gntv6LBIxd0CaKkb-koLzGTdBEOGgVf3ob99lAuvuY,2196
 ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=7RFM25tDj_b0FkpSv8RUWir8K8v9p2jMtwZmP4VAUhw,4474
 ai_edge_torch/generative/examples/qwen_vl/image_encoder.py,sha256=nHzBe_YSPnUe1d5i09v4bePQomVifzJNeUjRfprmxC0,14878
 ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py,sha256=mfLFrT8NPEPh9CqlJYHwh-I2y6ST7hH_vEmbZYartHQ,7764
@@ -180,7 +180,7 @@ ai_edge_torch/generative/quantize/example.py,sha256=1lfVNUd2cEyRUnoZ7BLbRJ9IN-FT
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=plMsd7JBi98r2NHsAdMdvS6TPTXAoRFLCwOXu8H3-24,2004
 ai_edge_torch/generative/quantize/quant_recipe.py,sha256=CEW-ewHxwb59x_GISx4jr7WMihvn-jKWVcBonllzDS4,5724
 ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=h3k_na6rbR08Ip79-2JbkeH8RDk_rrnEGiytuzFDhqc,2678
-ai_edge_torch/generative/quantize/quant_recipes.py,sha256=
+ai_edge_torch/generative/quantize/quant_recipes.py,sha256=45DJfcQXZ1FA1qI4LgYoYE4UD4yvfIYoY9LgYTeKFVw,2845
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=TwR2FpQuBEORy6FshEyHNBMKARWlA2MVtTfX9tXV5aE,1488
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvICcDWTcIDJ3rociko-wM,3267
@@ -192,10 +192,10 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hP
 ai_edge_torch/generative/test/test_quantize.py,sha256=kKJ01wscTC2t_Ylr7huO5gNKES01gm3dT1gx52z15PA,7356
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=
+ai_edge_torch/generative/utilities/converter.py,sha256=4zcDlhgCQQyLylH8NLgVjnelou2pW6HWJHBFYsFyHuw,15020
 ai_edge_torch/generative/utilities/export_config.py,sha256=5IvR3grlMd4mWO5c_Y4x9Fk1b1xa57MzlYNE8XUaN28,2049
-ai_edge_torch/generative/utilities/loader.py,sha256=
-ai_edge_torch/generative/utilities/model_builder.py,sha256=
+ai_edge_torch/generative/utilities/loader.py,sha256=nw2REQ9sGWDwphShfRqNFICYmwIjqLp6bDcwVmsNTtg,14067
+ai_edge_torch/generative/utilities/model_builder.py,sha256=tBfOcsI_NcneggHqkCSydYN3ZgmkzPc6nW0AJrA81wI,6461
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
@@ -251,8 +251,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.5.0.dev20250512.dist-info/LICENSE,sha256=
-ai_edge_torch_nightly-0.5.0.dev20250512.dist-info/METADATA,sha256=
-ai_edge_torch_nightly-0.5.0.dev20250512.dist-info/WHEEL,sha256=
-ai_edge_torch_nightly-0.5.0.dev20250512.dist-info/top_level.txt,sha256=
-ai_edge_torch_nightly-0.5.0.dev20250512.dist-info/RECORD,,
+ai_edge_torch_nightly-0.5.0.dev20250514.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.5.0.dev20250514.dist-info/METADATA,sha256=4_d1LvNhvXOHKlqYZDcBYSLdYDmoGvWMgCK5PJasNiU,2074
+ai_edge_torch_nightly-0.5.0.dev20250514.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.5.0.dev20250514.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.5.0.dev20250514.dist-info/RECORD,,

File without changes
File without changes