ai-edge-torch-nightly 0.5.0.dev20250425__py3-none-any.whl → 0.5.0.dev20250427__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +2 -36
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -38
- ai_edge_torch/generative/examples/hammer/__init__.py +14 -0
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +92 -0
- ai_edge_torch/generative/examples/hammer/hammer.py +107 -0
- ai_edge_torch/generative/examples/hammer/verify.py +86 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +1 -3
- ai_edge_torch/generative/examples/llama/llama.py +3 -1
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/phi/phi2.py +1 -1
- ai_edge_torch/generative/examples/phi/phi3.py +3 -1
- ai_edge_torch/generative/examples/phi/phi4.py +3 -1
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +1 -37
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +5 -3
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/smollm/smollm.py +3 -1
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +3 -1
- ai_edge_torch/generative/layers/kv_cache.py +2 -4
- ai_edge_torch/generative/test/test_model_conversion_large.py +7 -0
- ai_edge_torch/generative/utilities/converter.py +7 -2
- ai_edge_torch/generative/utilities/export_config.py +30 -0
- ai_edge_torch/model.py +2 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250425.dist-info → ai_edge_torch_nightly-0.5.0.dev20250427.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250425.dist-info → ai_edge_torch_nightly-0.5.0.dev20250427.dist-info}/RECORD +31 -27
- {ai_edge_torch_nightly-0.5.0.dev20250425.dist-info → ai_edge_torch_nightly-0.5.0.dev20250427.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250425.dist-info → ai_edge_torch_nightly-0.5.0.dev20250427.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250425.dist-info → ai_edge_torch_nightly-0.5.0.dev20250427.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py
CHANGED
@@ -19,41 +19,9 @@ from absl import app
 from ai_edge_torch.generative.examples.deepseek import deepseek
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
-from ai_edge_torch.generative.utilities import export_config as export_cfg
-import torch
+from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags('deepseek')
-ExportConfig = export_cfg.ExportConfig
-
-
-def _create_mask(mask_len, kv_cache_max_len):
-  mask = torch.full(
-      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-  return mask
-
-
-def _create_export_config(
-    prefill_seq_lens: list[int], kv_cache_max_len: int
-) -> ExportConfig:
-  """Creates the export config for the model."""
-  export_config = ExportConfig()
-  if isinstance(prefill_seq_lens, list):
-    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
-  else:
-    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
-
-  export_config.prefill_mask = prefill_mask
-
-  decode_mask = torch.full(
-      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-  export_config.decode_mask = decode_mask
-  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
-  return export_config
-
 
 def main(_):
   pytorch_model = deepseek.build_model(
@@ -66,9 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=_create_export_config(
-          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
-      ),
+      export_config=export_config.get_from_flags(),
   )
 
 
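The helpers removed here (and, below, from several other conversion scripts, in favor of the centralized export_config.get_from_flags()) build an ordinary causal attention mask. A minimal runnable sketch of the same computation, with illustrative sizes:

import torch

# -inf above the diagonal blocks attention to future positions; the lower
# triangle (the positions a token may attend to) is zeroed by triu(diagonal=1).
mask_len, kv_cache_max_len = 4, 8
mask = torch.full(
    (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
)
mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
print(mask.shape)  # torch.Size([1, 1, 4, 8])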
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py
CHANGED
@@ -17,14 +17,10 @@
 
 from absl import app
 from ai_edge_torch.generative.examples.gemma3 import gemma3
-from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
-import torch
 
 flags = converter.define_conversion_flags('gemma3-1b')
-ExportConfig = export_config.ExportConfig
-
 
 _MODEL_SIZE = flags.DEFINE_string(
     'model_size',
@@ -33,55 +29,23 @@ _MODEL_SIZE = flags.DEFINE_string(
 )
 
 
-def _create_mask(mask_len, kv_cache_max_len):
-  mask = torch.full(
-      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-  return mask
-
-
-def _create_export_config(
-    prefill_seq_lens: list[int], kv_cache_max_len: int
-) -> ExportConfig:
-  """Creates the export config for the model."""
-  export_config = ExportConfig()
-  if isinstance(prefill_seq_lens, list):
-    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
-  else:
-    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
-
-  export_config.prefill_mask = prefill_mask
-
-  decode_mask = torch.full(
-      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-  export_config.decode_mask = decode_mask
-  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
-  return export_config
-
-
 def main(_):
   if _MODEL_SIZE.value == '1b':
     pytorch_model = gemma3.build_model_1b(
         flags.FLAGS.checkpoint_path,
         kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
     )
-    config = pytorch_model.config
   else:
     raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
+
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
-      config=config,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=_create_export_config(
-          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
-      ),
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/hammer/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/hammer/convert_to_tflite.py
ADDED
@@ -0,0 +1,92 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting hammer 2.1 models to multi-signature tflite model."""
+
+from absl import app
+from ai_edge_torch.generative.examples.hammer import hammer
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities import export_config as export_cfg
+import torch
+
+
+flags = converter.define_conversion_flags('hammer')
+ExportConfig = export_cfg.ExportConfig
+
+
+_MODEL_SIZE = flags.DEFINE_enum(
+    'model_size',
+    '1.5b',
+    ['0.5b', '1.5b'],
+    'The size of the model to convert.',
+)
+
+_BUILDER = {
+    '0.5b': hammer.build_0_5b_model,
+    '1.5b': hammer.build_1_5b_model,
+}
+
+
+def _create_mask(mask_len, kv_cache_max_len):
+  mask = torch.full(
+      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  return mask
+
+
+def _create_export_config(
+    prefill_seq_lens: list[int], kv_cache_max_len: int
+) -> ExportConfig:
+  """Creates the export config for the model."""
+  export_config = ExportConfig()
+  if isinstance(prefill_seq_lens, list):
+    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
+  else:
+    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
+
+  export_config.prefill_mask = prefill_mask
+
+  decode_mask = torch.full(
+      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  export_config.decode_mask = decode_mask
+  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
+  return export_config
+
+
+def main(_):
+  pytorch_model = _BUILDER[_MODEL_SIZE.value](
+      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+  )
+  converter.convert_to_tflite(
+      pytorch_model,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
+      export_config=_create_export_config(
+          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
+      )
+      if flags.FLAGS.transpose_kv_cache
+      else ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
ai_edge_torch/generative/examples/hammer/hammer.py
ADDED
@@ -0,0 +1,107 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building Hammer 2.1 models."""
+
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
+
+TENSOR_NAMES = model_builder.TENSOR_NAMES
+
+
+class Hammer(model_builder.DecoderOnlyModel):
+  """A Hammer model built from the Edge Generative API layers."""
+  pass
+
+
+def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Hammer 2.1 1.5B model."""
+  attn_config = cfg.AttentionConfig(
+      num_heads=12,
+      head_dim=128,
+      num_query_groups=2,
+      rotary_base=1000000,
+      rotary_percentage=1.0,
+      qkv_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=8960,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-06,
+      enable_hlfb=True,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=151665,
+      num_layers=28,
+      max_seq_len=32768,
+      embedding_dim=1536,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Hammer 2.1 0.5B model."""
+  config = get_1_5b_model_config(kv_cache_max_len)
+  # Hammer has only one block config.
+  block_config = config.block_config(0)
+  block_config.attn_config.num_heads = 14
+  block_config.attn_config.head_dim = 64
+  block_config.ff_config.intermediate_size = 4864
+  config.num_layers = 24
+  config.embedding_dim = 896
+  return config
+
+
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_1_5b_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.embedding_dim = 16
+  # Hammer has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_1_5b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Hammer,
+  )
+
+
+def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_0_5b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Hammer,
+  )
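A minimal usage sketch for the builders above; the checkpoint paths are placeholders:

from ai_edge_torch.generative.examples.hammer import hammer

# Build the reauthored Hammer 2.1 models from local checkpoints.
model_1_5b = hammer.build_1_5b_model(
    '/path/to/Hammer2.1-1.5b', kv_cache_max_len=1024
)
model_0_5b = hammer.build_0_5b_model(
    '/path/to/Hammer2.1-0.5b', kv_cache_max_len=1024
)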
ai_edge_torch/generative/examples/hammer/verify.py
ADDED
@@ -0,0 +1,86 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Hammer 2.1 0.5B and 1.5B models."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.hammer import hammer
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_MODEL_SIZE = flags.DEFINE_enum(
+    "model_size",
+    "0.5b",
+    ["0.5b", "1.5b"],
+    "The size of the model to verify.",
+)
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+_CHECKPOINT = {
+    "0.5b": "MadeAgents/Hammer2.1-0.5b",
+    "1.5b": "MadeAgents/Hammer2.1-1.5b",
+}
+
+_BUILDER = {
+    "0.5b": hammer.build_0_5b_model,
+    "1.5b": hammer.build_1_5b_model,
+}
+
+
+def main(_):
+  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
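A hypothetical invocation of this verifier, assuming the package is installed so the module is importable (flag values are examples):

python -m ai_edge_torch.generative.examples.hammer.verify --model_size=1.5b --max_new_tokens=30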
ai_edge_torch/generative/examples/llama/convert_to_tflite.py
CHANGED
@@ -22,8 +22,6 @@ from ai_edge_torch.generative.utilities import export_config
 
 
 flags = converter.define_conversion_flags('llama')
-ExportConfig = export_config.ExportConfig
-
 
 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
@@ -49,7 +47,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/llama/llama.py
CHANGED
@@ -121,7 +121,9 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
CHANGED
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("phi3")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py
CHANGED
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("phi4")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/phi/convert_to_tflite.py
CHANGED
@@ -22,7 +22,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("phi2")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -36,7 +35,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/phi/phi2.py
CHANGED
@@ -65,7 +65,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM,
+      type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=True
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
ai_edge_torch/generative/examples/phi/phi3.py
CHANGED
@@ -162,7 +162,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
ai_edge_torch/generative/examples/phi/phi4.py
CHANGED
@@ -112,7 +112,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
CHANGED
@@ -17,13 +17,10 @@
 
 from absl import app
 from ai_edge_torch.generative.examples.qwen import qwen
-from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
-import torch
 
 flags = converter.define_conversion_flags('qwen')
-ExportConfig = export_config.ExportConfig
 
 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
@@ -39,35 +36,6 @@ _BUILDER = {
 }
 
 
-def _create_mask(mask_len, kv_cache_max_len):
-  mask = torch.full(
-      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-  return mask
-
-
-def _create_export_config(
-    prefill_seq_lens: list[int], kv_cache_max_len: int
-) -> ExportConfig:
-  """Creates the export config for the model."""
-  export_config = ExportConfig()
-  if isinstance(prefill_seq_lens, list):
-    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
-  else:
-    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
-
-  export_config.prefill_mask = prefill_mask
-
-  decode_mask = torch.full(
-      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-  export_config.decode_mask = decode_mask
-  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
-  return export_config
-
-
 def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
@@ -79,11 +47,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=_create_export_config(
-          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
-      )
-      if flags.FLAGS.transpose_kv_cache
-      else ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
CHANGED
@@ -35,6 +35,10 @@ def main(_):
   pytorch_model = smollm.build_model(
       flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
+
+  export_config = export_cfg.get_from_flags()
+  export_config.decode_batch_size = _DECODE_BATCH_SIZE.value
+
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
@@ -42,9 +46,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=export_cfg.ExportConfig(
-          decode_batch_size=_DECODE_BATCH_SIZE.value
-      ),
+      export_config=export_config,
   )
 
 
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py
CHANGED
@@ -34,6 +34,9 @@ def main(_):
       flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
   )
 
+  export_config = export_cfg.get_from_flags()
+  export_config.decode_batch_size = _DECODE_BATCH_SIZE.value
+
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
@@ -41,9 +44,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=export_cfg.ExportConfig(
-          decode_batch_size=_DECODE_BATCH_SIZE.value
-      ),
+      export_config=export_config,
   )
 
 
ai_edge_torch/generative/examples/smollm/smollm.py
CHANGED
@@ -49,7 +49,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=1536,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
CHANGED
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("tiny_llama")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
CHANGED
@@ -49,7 +49,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=5632,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
ai_edge_torch/generative/layers/kv_cache.py
CHANGED
@@ -51,10 +51,7 @@ class KVCacheEntry:
       config: model_config.AttentionConfig,
       batch_size: int,
   ) -> List[int]:
-    """
-
-    the specified layout.
-    """
+    """Construct the shape of KV cache entry based on the specified layout."""
     output_shape = []
     for dim_spec in shape_spec:
       if dim_spec is types.TensorDims.BATCH:
@@ -213,6 +210,7 @@ pytree.register_pytree_node(
     serialized_type_name="",
 )
 
+
 def update(
     cache: KVCacheEntry,
     input_pos: torch.Tensor,
ai_edge_torch/generative/test/test_model_conversion_large.py
CHANGED
@@ -20,6 +20,7 @@ from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.examples.deepseek import deepseek
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.hammer import hammer
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.paligemma import decoder
@@ -148,6 +149,12 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
+  def test_hammer(self):
+    config = hammer.get_fake_model_config()
+    pytorch_model = hammer.Hammer(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
+
+
   def test_amd_llama_135m(self):
     config = amd_llama_135m.get_fake_model_config()
     pytorch_model = amd_llama_135m.AmdLlama(config).eval()
ai_edge_torch/generative/utilities/converter.py
CHANGED
@@ -81,12 +81,17 @@ def define_conversion_flags(model_name: str):
       'If set, the model will be converted with the provided list of LoRA'
       ' ranks.',
   )
+  flags.DEFINE_bool(
+      'mask_as_input',
+      False,
+      'If true, the mask will be passed in as input. Otherwise, mask will be '
+      'built by the model internally.',
+  )
   flags.DEFINE_bool(
       'transpose_kv_cache',
       False,
-      'If
+      'If true, the model will be converted with transposed KV cache.',
   )
-
   return flags
 
 
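With the two flags registered, any of the example conversion scripts can opt in from the command line; a hypothetical run (script name and paths are placeholders):

python convert_to_tflite.py \
    --checkpoint_path=/path/to/checkpoint \
    --output_path=/tmp \
    --mask_as_input \
    --transpose_kv_cache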
ai_edge_torch/generative/utilities/export_config.py
CHANGED
@@ -14,8 +14,11 @@
 # ==============================================================================
 
 """Config for customizing model export process."""
+
 import dataclasses
 from typing import List, Optional
+
+from absl import flags
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import torch
 
@@ -38,3 +41,30 @@ class ExportConfig:
   kvcache_cls: type = kv_utils.KVCache
   # The batch size of the decode signature.
   decode_batch_size: int = 1
+
+
+def _build_mask(mask_len, kv_cache_max_len) -> torch.Tensor:
+  if isinstance(mask_len, list):
+    return [_build_mask(i, kv_cache_max_len) for i in mask_len]
+
+  mask = torch.full(
+      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  return mask
+
+
+def get_from_flags() -> ExportConfig:
+  """Builds an export config according to the commandline flags."""
+  export_config = ExportConfig()
+
+  if flags.FLAGS.mask_as_input:
+    export_config.prefill_mask = _build_mask(
+        flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
+    )
+    export_config.decode_mask = _build_mask(1, flags.FLAGS.kv_cache_max_len)
+
+  if flags.FLAGS.transpose_kv_cache:
+    export_config.kvcache_layout = kv_utils.KV_LAYOUT_TRANSPOSED
+
+  return export_config
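Unlike the per-example helpers it replaces, _build_mask also accepts a list of prefill lengths and returns one mask per entry. A sketch of the resulting config, assuming the flags were parsed from a command line like --mask_as_input --prefill_seq_lens=8,64 --kv_cache_max_len=128 (values illustrative):

config = get_from_flags()
# config.prefill_mask is then a list of two tensors, shaped
# (1, 1, 8, 128) and (1, 1, 64, 128); config.decode_mask is a single
# tensor shaped (1, 1, 1, 128). Without --transpose_kv_cache, the
# kvcache_layout field keeps its default value.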
ai_edge_torch/model.py
CHANGED
@@ -22,6 +22,7 @@ from __future__ import annotations
 
 import abc
 import re
+import os
 from typing import Callable
 
 import numpy.typing as npt
@@ -154,6 +155,7 @@ class TfLiteModel(Model):
     Args:
       path: The path to file to which the model is serialized.
     """
+    os.makedirs(os.path.dirname(path), exist_ok=True)
     with open(path, 'wb') as file_handle:
       file_handle.write(self._tflite_model)
 
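A small sketch of the effect of the added os.makedirs call: export() can now write into a directory that does not exist yet (model and paths are placeholders):

import ai_edge_torch
import torch

model = torch.nn.Linear(4, 2).eval()
edge_model = ai_edge_torch.convert(model, (torch.randn(1, 4),))
# Parent directories are created on demand before the file is written.
edge_model.export('/tmp/exports/nested/dir/model.tflite')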
ai_edge_torch/version.py
CHANGED
-__version__ = "0.5.0.dev20250425"
+__version__ = "0.5.0.dev20250427"
{ai_edge_torch_nightly-0.5.0.dev20250425.dist-info → ai_edge_torch_nightly-0.5.0.dev20250427.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250425
+Version: 0.5.0.dev20250427
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.5.0.dev20250425.dist-info → ai_edge_torch_nightly-0.5.0.dev20250427.dist-info}/RECORD
CHANGED
@@ -1,8 +1,8 @@
 ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,1208
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
-ai_edge_torch/model.py,sha256=
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
+ai_edge_torch/version.py,sha256=RhNMNIs4sG78K3SOLk6zxuILeS_S2vhG7FJJOrV4cLM,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -53,7 +53,7 @@ ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif8
 ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=z5MWiZLnsQzhNYMiQbcI9i0ki-dtkbimCptkiTFZxwo,1586
 ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=o13NkFlBgawBsjdJup05VMUjAPvDRAmig6VyEkX8q6U,2426
 ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
-ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=l0OrPGmX8WscuG9MIgtd0sqR4BeReNAu7fADzyPbnZw,1580
 ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=yhS_i2kR0GJWpWciCt4p9Z9nHYh6A5uJ8Ycy2ebFN9w,2909
 ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -65,15 +65,19 @@ ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSd
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
-ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=MjkQDVynaw9C5z9ODzKfb85xW5JfxHUWBJ_Aco05FHo,1760
 ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=eXWE5CSX0KeUMsPevgsYOfvyajl9F1RFF4DCWhHcYPA,15646
 ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
 ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
 ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
 ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=nEv0qQ0l6gSXKxP5mNwkd2lRGxpFfD4e7FNV3V76zhw,8915
+ai_edge_torch/generative/examples/hammer/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/hammer/convert_to_tflite.py,sha256=946mchDmvUhMsv1kzslp4LHtCIuHn4qjimHYQ-XnxMo,2962
+ai_edge_torch/generative/examples/hammer/hammer.py,sha256=76INcjffvaNCQ02fzXcxJUW_6EKHs4sg3q1nDBbEpHE,3431
+ai_edge_torch/generative/examples/hammer/verify.py,sha256=MkzAGkbPy4LKRhyCDm1cw-9jUt4VUxLPdwK_25fCGSE,2705
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/llama/llama.py,sha256=
+ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=nz5h4m8bVnw8P7OEtqhA_fKfvaRzxhT2_75vkFCqHmU,1735
+ai_edge_torch/generative/examples/llama/llama.py,sha256=H7I5iNhIJ55gb0-9k7g-FPcG2IlthnA9XMR8qd__5bQ,6621
 ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
 ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
@@ -93,17 +97,17 @@ ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4I
 ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
 ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/phi/phi2.py,sha256=
-ai_edge_torch/generative/examples/phi/phi3.py,sha256=
-ai_edge_torch/generative/examples/phi/phi4.py,sha256=
+ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=k-0ZC-_zZZmkdcc6dr1QGXfX9lDZZXRQSuc6wT0n3Is,1514
+ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py,sha256=5KSJRySjSc89FriCOnfBabD8zRLUcGAw3L0VInuJFUY,1512
+ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=wVIdGenHTi9xUffYddN_uXWMBO2tgo1e_hU4OG_NmHA,1513
+ai_edge_torch/generative/examples/phi/phi2.py,sha256=X9MfjK8rmyRSrfNzIaKQNSgqLM5_CBH-BrLFX_7BWL8,3494
+ai_edge_torch/generative/examples/phi/phi3.py,sha256=65Dbv8cA4WFdluflHQHzgDmDFjdmc6rxMO4hQukaxKU,6978
+ai_edge_torch/generative/examples/phi/phi4.py,sha256=y3CCZCW4MnvX74d4MNERRuQBE0p5dquC2M9vDXXqnZI,5760
 ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
 ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
 ai_edge_torch/generative/examples/phi/verify_phi4.py,sha256=BoCa5kUBRHtMQ-5ql6yD4pG4xHJMyUiQlpMOWVx-JgY,2356
 ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256
+ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=eOpv3scJr4mVsJ9Obl7PBhMgd3a0T1t8dqoPp_VzZaQ,1776
 ai_edge_torch/generative/examples/qwen/qwen.py,sha256=m8APYzo9N0SXsdvCxC8HtCcbN3W7gLKkRBL-Tg0BWXU,4223
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
@@ -115,9 +119,9 @@ ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=JUwHoC_zvcC3RC3wZ3e3e
 ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
 ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=PZ392nDoJG2OmHZ_7Jet3Zu1JkN6QErxKcDc7a-PPds,3126
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/smollm/smollm.py,sha256=
+ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=jTM_tndbDqzq19uLz2n71S7M81L1Y6R7oVBPsMcYGzk,1785
+ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=wU72MzpUIi2mQ8ZODW1x4L5KZPWvuXyB-_Eqo-RKqFw,1757
+ai_edge_torch/generative/examples/smollm/smollm.py,sha256=SFE8fIJx7Y_oan0vXSmhEmI0Ib2HD3k9cyKLU_4MxfI,3807
 ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68n0owZMvty5Rr3W7ZNWWSw,2702
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
@@ -143,8 +147,8 @@ ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNH
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=Crpj-vOwSViHpblXOrRJmsIn4DrHyuB3XZ8kHifb7LA,5203
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=-z5tkQzGHbo37eAl9sDAJuT1Egxm8xI9CZmYLcmqIfU,4761
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=XM-dCBW2HG6FlwwPjlJi0I_TEaVqdv7aWpFEv-XUdLc,1539
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=6Qhml-XB8_RjQdYN948OaSsPJNrfi-Mr7PFB73C79Ug,2828
 ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1tO2i0nUCqe-VkRgboA10VZ7KNg,2431
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx_bLdWjN01iyjPM5jE,1094
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
@@ -153,7 +157,7 @@ ai_edge_torch/generative/layers/attention.py,sha256=uK1ih2kxPZherwi-pGSm8B--NNWn
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
-ai_edge_torch/generative/layers/kv_cache.py,sha256=
+ai_edge_torch/generative/layers/kv_cache.py,sha256=dDeirtuo9AnlN1tYoLbFi_pKhIDmn35FQY1m6X28hSY,8468
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
 ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
@@ -179,12 +183,12 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=1sXN2RPntq0PP3IEy0NkvIbzQ0
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=mhNJikLnGVGi9NKmXB8FhnqeDy9gtrvC3yEbrTABZ4Y,6163
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hPmWpg41ZMWwBsngTykRVzRPHtpbkwiLM,12811
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=
-ai_edge_torch/generative/utilities/export_config.py,sha256=
+ai_edge_torch/generative/utilities/converter.py,sha256=4RNNl7vk3WN_JG5EZajofiRSqtPnUNCYosxTacdEOto,10948
+ai_edge_torch/generative/utilities/export_config.py,sha256=maUVt0T5FsLpHO5H-BZ-O0FRBZO_ejKwGhPR9Qq8ViM,2490
 ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
 ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
@@ -242,8 +246,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
+ai_edge_torch_nightly-0.5.0.dev20250427.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.5.0.dev20250427.dist-info/METADATA,sha256=g53PvQrw8WP7McVXcoMYSEF9lmh7VWexPnfQLGOTVJg,2051
+ai_edge_torch_nightly-0.5.0.dev20250427.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.5.0.dev20250427.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.5.0.dev20250427.dist-info/RECORD,,
{ai_edge_torch_nightly-0.5.0.dev20250425.dist-info → ai_edge_torch_nightly-0.5.0.dev20250427.dist-info}/LICENSE
File without changes
{ai_edge_torch_nightly-0.5.0.dev20250425.dist-info → ai_edge_torch_nightly-0.5.0.dev20250427.dist-info}/WHEEL
File without changes
{ai_edge_torch_nightly-0.5.0.dev20250425.dist-info → ai_edge_torch_nightly-0.5.0.dev20250427.dist-info}/top_level.txt
File without changes