ai-edge-torch-nightly 0.5.0.dev20250515__py3-none-any.whl → 0.5.0.dev20250517__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in that registry.
- ai_edge_torch/__init__.py +1 -0
- ai_edge_torch/_convert/conversion.py +24 -0
- ai_edge_torch/_convert/converter.py +57 -3
- ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
- ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py +40 -0
- ai_edge_torch/_convert/test/test_convert.py +25 -0
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +10 -6
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +7 -2
- ai_edge_torch/generative/examples/deepseek/deepseek.py +9 -5
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/gemma/gemma1.py +10 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +8 -7
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +5 -14
- ai_edge_torch/generative/examples/gemma3/decoder.py +10 -10
- ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -3
- ai_edge_torch/generative/examples/gemma3/image_encoder.py +1 -4
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/hammer/hammer.py +15 -6
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/llama/llama.py +26 -10
- ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +0 -1
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/openelm/openelm.py +9 -3
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +6 -1
- ai_edge_torch/generative/examples/paligemma/decoder.py +1 -4
- ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -4
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +3 -5
- ai_edge_torch/generative/examples/paligemma/paligemma.py +12 -5
- ai_edge_torch/generative/examples/paligemma/verify.py +27 -5
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/phi/phi2.py +9 -5
- ai_edge_torch/generative/examples/phi/phi3.py +8 -6
- ai_edge_torch/generative/examples/phi/phi4.py +8 -6
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/qwen/qwen.py +21 -7
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +6 -1
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -3
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +13 -7
- ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +12 -4
- ai_edge_torch/generative/examples/qwen_vl/verify.py +26 -5
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +7 -2
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/smollm/smollm.py +15 -6
- ai_edge_torch/generative/examples/smollm/verify.py +2 -2
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +8 -5
- ai_edge_torch/generative/examples/t5/t5.py +1 -3
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +9 -5
- ai_edge_torch/generative/layers/model_config.py +2 -2
- ai_edge_torch/generative/utilities/converter.py +18 -5
- ai_edge_torch/generative/utilities/loader.py +19 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +13 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/RECORD +64 -63
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/top_level.txt +0 -0
--- a/ai_edge_torch/generative/examples/hammer/hammer.py
+++ b/ai_edge_torch/generative/examples/hammer/hammer.py
@@ -15,8 +15,10 @@
 
 """Example of building Hammer 2.1 models."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
@@ -43,9 +45,7 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=8960,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-06,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -61,7 +61,6 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
@@ -89,19 +88,29 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_1_5b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_1_5b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Hammer,
+      custom_loader=custom_loader,
   )
 
 
-def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_0_5b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_0_5b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Hammer,
+      custom_loader=custom_loader,
   )
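The recurring change in this release is that every example `build_*_model` helper gains an optional `custom_loader` argument — a callable mapping a checkpoint path to a `Dict[str, torch.Tensor]` — which is forwarded to `model_builder.build_decoder_only_model`. A minimal usage sketch against the new Hammer signature; the safetensors-based loader and the checkpoint path below are hypothetical illustrations, not code from the package:

    import torch
    from safetensors.torch import load_file
    from ai_edge_torch.generative.examples.hammer import hammer

    # Hypothetical loader: resolve a checkpoint path to a state dict.
    def safetensors_loader(path: str) -> dict[str, torch.Tensor]:
      return load_file(path)

    model = hammer.build_1_5b_model(
        "/path/to/hammer_2.1_1.5b.safetensors",  # hypothetical path
        custom_loader=safetensors_loader,
        kv_cache_max_len=1024,
    )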
--- a/ai_edge_torch/generative/examples/llama/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/llama/convert_to_tflite.py
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 
 flags = converter.define_conversion_flags('llama')
@@ -37,8 +38,13 @@ _BUILDER = {
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
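Each converter entry point now resolves `checkpoint_path` once and passes `loader.maybe_get_custom_loader(checkpoint_path, flags.FLAGS.custom_checkpoint_loader)` to the model builder. The implementation of `maybe_get_custom_loader` lives in `ai_edge_torch/generative/utilities/loader.py` (+19 lines in this release) and is not shown in this excerpt; the following is only a plausible sketch of its contract, inferred from the call sites above, not the package's actual code:

    from typing import Callable, Dict, Optional
    import torch

    # Inferred behavior (assumption): return a path -> state-dict callable
    # when the user opts in via --custom_checkpoint_loader, otherwise None
    # so the builder falls back to the default ModelLoader path handling.
    def maybe_get_custom_loader(
        checkpoint_path: str,
        use_custom_loader: bool,
    ) -> Optional[Callable[[str], Dict[str, torch.Tensor]]]:
      if not use_custom_loader:
        return None
      return lambda path: torch.load(path, map_location="cpu")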
--- a/ai_edge_torch/generative/examples/llama/llama.py
+++ b/ai_edge_torch/generative/examples/llama/llama.py
@@ -17,7 +17,7 @@
 
 from functools import partial
 import math
-from typing import Tuple
+from typing import Callable, Dict, Tuple
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
@@ -121,9 +121,7 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -152,7 +150,6 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
       build_rope=build_rope,
   )
   return config
@@ -180,19 +177,38 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 
 
 def _build_model(
-    checkpoint_path: str, config: cfg.ModelConfig
+    checkpoint_path: str,
+    config: cfg.ModelConfig,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
 ) -> torch.nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=config,
       tensor_names=TENSOR_NAMES,
       model_class=Llama,
+      custom_loader=custom_loader,
   )
 
 
-def build_1b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
-  return _build_model(checkpoint_path, get_1b_model_config(**kwargs))
+def build_1b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> torch.nn.Module:
+  return _build_model(
+      checkpoint_path,
+      get_1b_model_config(**kwargs),
+      custom_loader=custom_loader,
+  )
 
 
-def build_3b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
-  return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
+def build_3b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> torch.nn.Module:
+  return _build_model(
+      checkpoint_path,
+      get_3b_model_config(**kwargs),
+      custom_loader=custom_loader,
+  )
--- a/ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -19,13 +19,19 @@ from absl import app
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("openelm")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = openelm.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
--- a/ai_edge_torch/generative/examples/openelm/openelm.py
+++ b/ai_edge_torch/generative/examples/openelm/openelm.py
@@ -15,9 +15,11 @@
 
 """Example of building an OpenELM model."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
 from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
@@ -51,7 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for an OpenELM model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=True
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
   )
   num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
   num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
@@ -99,7 +101,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
@@ -118,10 +119,15 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=OpenELM,
+      custom_loader=custom_loader,
   )
--- a/ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 import torch
 
 flags = converter.define_conversion_flags('paligemma2-3b-224')
@@ -32,9 +33,13 @@ _VERSION = flags.DEFINE_enum(
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = paligemma.build_model(
-      flags.FLAGS.checkpoint_path,
+      checkpoint_path,
       version=int(_VERSION.value),
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
       kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
 
--- a/ai_edge_torch/generative/examples/paligemma/decoder.py
+++ b/ai_edge_torch/generative/examples/paligemma/decoder.py
@@ -110,9 +110,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=16384,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -131,7 +129,6 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
   )
   return config
 
--- a/ai_edge_torch/generative/examples/paligemma/decoder2.py
+++ b/ai_edge_torch/generative/examples/paligemma/decoder2.py
@@ -93,9 +93,7 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for the decoder of a PaliGemma 3B model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -139,7 +137,6 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
       final_logit_softcap=30.0,
   )
   return config
--- a/ai_edge_torch/generative/examples/paligemma/image_encoder.py
+++ b/ai_edge_torch/generative/examples/paligemma/image_encoder.py
@@ -66,7 +66,8 @@ class SiglipVisionEncoder(nn.Module):
         config.image_embedding.image_size // config.image_embedding.patch_size
     ) ** 2
     self.tok_embedding_position = nn.Parameter(
-        torch.zeros((num_patches, config.embedding_dim))
+        torch.zeros((num_patches, config.embedding_dim)),
+        requires_grad=False,
     )
 
     self.transformer_blocks = nn.ModuleList(
@@ -117,9 +118,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM,
-      epsilon=1e-6,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -136,7 +135,6 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       image_embedding=image_embedding_config,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
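Aside from the `enable_hlfb` cleanup, the `SiglipVisionEncoder` change above freezes the positional embedding by constructing it as `nn.Parameter(..., requires_grad=False)`. A self-contained sketch of what that flag does (illustrative module, not from the package):

    import torch
    from torch import nn

    class Demo(nn.Module):
      def __init__(self):
        super().__init__()
        # Trainable by default: included in gradient updates.
        self.trainable = nn.Parameter(torch.zeros(4, 8))
        # Frozen: still appears in state_dict and follows .to(device),
        # but optimizers that filter on requires_grad skip it.
        self.frozen = nn.Parameter(torch.zeros(4, 8), requires_grad=False)

    m = Demo()
    print([name for name, p in m.named_parameters() if p.requires_grad])
    # ['trainable']

Keeping the embedding as a frozen `nn.Parameter` (rather than a registered buffer) preserves its name in the checkpoint state dict while ruling out accidental training-time updates.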
--- a/ai_edge_torch/generative/examples/paligemma/paligemma.py
+++ b/ai_edge_torch/generative/examples/paligemma/paligemma.py
@@ -16,7 +16,7 @@
 """Example of building a full-stack of PaliGemma model."""
 
 import dataclasses
-from typing import Optional
+from typing import Callable, Dict, Optional
 
 from ai_edge_torch.generative.examples.paligemma import decoder
 from ai_edge_torch.generative.examples.paligemma import decoder2
@@ -139,7 +139,12 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   )
 
 
-def build_model(checkpoint_path: str, version: int = 2, **kwargs) -> PaliGemma:
+def build_model(
+    checkpoint_path: str,
+    version: int = 2,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs,
+) -> PaliGemma:
   if version == 1:
     decoder_class = decoder.Decoder
     decoder_tensor_names = decoder.TENSOR_NAMES
@@ -153,15 +158,17 @@ def build_model(checkpoint_path: str, version: int = 2, **kwargs) -> PaliGemma:
   model = PaliGemma(config, decoder_class)
   # Load the parameters of image encoder.
   loader = loading_utils.ModelLoader(
-      checkpoint_path, image_encoder.TENSOR_NAMES
+      checkpoint_path, image_encoder.TENSOR_NAMES, custom_loader
   )
   loader.load(model.image_encoder, strict=False)
   # Load the parameters of decoder.
-  loader = loading_utils.ModelLoader(checkpoint_path, decoder_tensor_names)
+  loader = loading_utils.ModelLoader(
+      checkpoint_path, decoder_tensor_names, custom_loader
+  )
   loader.load(model.decoder, strict=False)
 
   # Load the parameters of image projection.
-  loader = loading_utils.ModelLoader(checkpoint_path, None)
+  loader = loading_utils.ModelLoader(checkpoint_path, None, custom_loader)
   state = loader.get_state()
   converted_state = dict()
   converted_state["weight"] = state.pop(f"{PROJECTION_TENSOR_NAME}.weight")
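The full-stack PaliGemma builder threads the same `custom_loader` through all three `ModelLoader` instances (image encoder, decoder, and projection), so one path-to-state-dict callable serves every component. A hedged call sketch, reusing a loader like the one sketched after the Hammer section; the checkpoint path is hypothetical:

    from ai_edge_torch.generative.examples.paligemma import paligemma

    model = paligemma.build_model(
        "/path/to/paligemma2_checkpoint",  # hypothetical path
        version=2,
        custom_loader=safetensors_loader,  # any path -> state-dict callable
        kv_cache_max_len=1024,
    )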
--- a/ai_edge_torch/generative/examples/paligemma/verify.py
+++ b/ai_edge_torch/generative/examples/paligemma/verify.py
@@ -21,6 +21,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import kagglehub
 from PIL import Image
@@ -39,10 +40,15 @@ _IMAGE_URL = flags.DEFINE_string(
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
     "The image URI to encode.",
 )
-_PROMPTS = flags.DEFINE_string(
-    "prompts",
+_PROMPTS_WITH_IMAGE = flags.DEFINE_string(
+    "prompts_with_image",
     "<image><bos>describe en",
-    "The input prompts to generate answers.",
+    "The input prompts to generate answers with an image.",
+)
+_PROMPTS_TEXT_ONLY = flags.DEFINE_multi_string(
+    "prompts_text_only",
+    "What is the meaning of life?",
+    "The input prompts to generate answers only with text.",
 )
 _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "max_new_tokens",
@@ -84,6 +90,7 @@ def main(_):
   reauthored_model = paligemma.build_model(
       reauthored_checkpoint, version=int(_VERSION.value)
   )
+  wrapped_reauthored_model = ReauthoredPaliGemmaWrapper(reauthored_model)
 
   logging.info("Loading the processor from: %s", checkpoint)
   # It works only when GemmaTokenizerFast is available. In some environments,
@@ -91,9 +98,25 @@ def main(_):
   # sentencepiece model file properly.
   processor = transformers.AutoProcessor.from_pretrained(checkpoint)
 
+  logging.info("Verifying with text-only prompts...")
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=wrapped_reauthored_model,
+      tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
+      generate_prompts=_PROMPTS_TEXT_ONLY.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      verify_inputs=False,  # Numeric check not working. Disable it for now.
+      atol=1e-04,
+  )
+
+  logging.info("Verifying with image input...")
   logging.info("Loading the image from: %s", _IMAGE_URL.value)
   image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
-  inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt")
+  inputs = processor(
+      text=_PROMPTS_WITH_IMAGE.value, images=image, return_tensors="pt"
+  )
 
   logging.info("Verifying the reauthored model with model.forward()...")
   logging.info("Forwarding the original model...")
@@ -104,7 +127,6 @@ def main(_):
   logging.info("outputs_original: %s", outputs_original)
 
   logging.info("Forwarding the reauthored model...")
-  wrapped_reauthored_model = ReauthoredPaliGemmaWrapper(reauthored_model)
   outputs_reauthored = wrapped_reauthored_model.forward(
       tokens=inputs["input_ids"],
       pixel_values=inputs["pixel_values"],
--- a/ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
+++ b/ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
@@ -19,13 +19,19 @@ from absl import app
 from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("phi3")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = phi3.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
|
@@ -19,13 +19,19 @@ from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.phi import phi4
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
|
+
from ai_edge_torch.generative.utilities import loader
|
22
23
|
|
23
24
|
flags = converter.define_conversion_flags("phi4")
|
24
25
|
|
25
26
|
|
26
27
|
def main(_):
|
28
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
27
29
|
pytorch_model = phi4.build_model(
|
28
|
-
|
30
|
+
checkpoint_path,
|
31
|
+
custom_loader=loader.maybe_get_custom_loader(
|
32
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
33
|
+
),
|
34
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
29
35
|
)
|
30
36
|
converter.convert_to_tflite(
|
31
37
|
pytorch_model,
|
@@ -20,13 +20,19 @@ from absl import app
|
|
20
20
|
from ai_edge_torch.generative.examples.phi import phi2
|
21
21
|
from ai_edge_torch.generative.utilities import converter
|
22
22
|
from ai_edge_torch.generative.utilities import export_config
|
23
|
+
from ai_edge_torch.generative.utilities import loader
|
23
24
|
|
24
25
|
flags = converter.define_conversion_flags("phi2")
|
25
26
|
|
26
27
|
|
27
28
|
def main(_):
|
29
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
28
30
|
pytorch_model = phi2.build_model(
|
29
|
-
|
31
|
+
checkpoint_path,
|
32
|
+
custom_loader=loader.maybe_get_custom_loader(
|
33
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
34
|
+
),
|
35
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
30
36
|
)
|
31
37
|
converter.convert_to_tflite(
|
32
38
|
pytorch_model,
|
@@ -15,9 +15,11 @@
|
|
15
15
|
|
16
16
|
"""Example of building a Phi-2 model."""
|
17
17
|
|
18
|
+
from typing import Callable, Dict
|
18
19
|
import ai_edge_torch.generative.layers.model_config as cfg
|
19
20
|
from ai_edge_torch.generative.utilities import model_builder
|
20
21
|
import ai_edge_torch.generative.utilities.loader as loading_utils
|
22
|
+
import torch
|
21
23
|
from torch import nn
|
22
24
|
|
23
25
|
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
@@ -64,9 +66,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
64
66
|
intermediate_size=10240,
|
65
67
|
use_bias=True,
|
66
68
|
)
|
67
|
-
norm_config = cfg.NormalizationConfig(
|
68
|
-
type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=True
|
69
|
-
)
|
69
|
+
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
|
70
70
|
block_config = cfg.TransformerBlockConfig(
|
71
71
|
attn_config=attn_config,
|
72
72
|
ff_config=ff_config,
|
@@ -83,7 +83,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
83
83
|
final_norm_config=norm_config,
|
84
84
|
lm_head_use_bias=True,
|
85
85
|
lm_head_share_weight_with_embedding=False,
|
86
|
-
enable_hlfb=True,
|
87
86
|
)
|
88
87
|
return config
|
89
88
|
|
@@ -98,10 +97,15 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
|
98
97
|
return config
|
99
98
|
|
100
99
|
|
101
|
-
def build_model(
|
100
|
+
def build_model(
|
101
|
+
checkpoint_path: str,
|
102
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
103
|
+
**kwargs
|
104
|
+
) -> nn.Module:
|
102
105
|
return model_builder.build_decoder_only_model(
|
103
106
|
checkpoint_path=checkpoint_path,
|
104
107
|
config=get_model_config(**kwargs),
|
105
108
|
tensor_names=TENSOR_NAMES,
|
106
109
|
model_class=Phi2,
|
110
|
+
custom_loader=custom_loader,
|
107
111
|
)
|
--- a/ai_edge_torch/generative/examples/phi/phi3.py
+++ b/ai_edge_torch/generative/examples/phi/phi3.py
@@ -17,7 +17,7 @@
 
 from functools import partial
 import math
-from typing import Tuple
+from typing import Callable, Dict, Tuple
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
@@ -162,9 +162,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -192,7 +190,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
       build_rope=build_rope,
   )
   return config
@@ -208,11 +205,16 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> torch.nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Phi3_5Mini,
+      custom_loader=custom_loader,
   )
--- a/ai_edge_torch/generative/examples/phi/phi4.py
+++ b/ai_edge_torch/generative/examples/phi/phi4.py
@@ -17,7 +17,7 @@
 
 from functools import partial
 import math
-from typing import Tuple
+from typing import Callable, Dict, Tuple
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
@@ -112,9 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -141,7 +139,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
       build_rope=build_rope,
   )
   return config
@@ -157,11 +154,16 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> torch.nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Phi4Mini,
+      custom_loader=custom_loader,
   )
--- a/ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.qwen import qwen
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('qwen')
 
@@ -37,8 +38,13 @@ _BUILDER = {
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,