ai-edge-torch-nightly 0.5.0.dev20250515__py3-none-any.whl → 0.5.0.dev20250516__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/__init__.py +1 -0
- ai_edge_torch/_convert/conversion.py +23 -0
- ai_edge_torch/_convert/converter.py +57 -3
- ai_edge_torch/_convert/test/test_convert.py +25 -0
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +9 -2
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +7 -2
- ai_edge_torch/generative/examples/deepseek/deepseek.py +8 -1
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/gemma/gemma1.py +9 -1
- ai_edge_torch/generative/examples/gemma/gemma2.py +7 -2
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +5 -14
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/hammer/hammer.py +14 -2
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/llama/llama.py +25 -6
- ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +0 -1
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/openelm/openelm.py +8 -1
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +6 -1
- ai_edge_torch/generative/examples/paligemma/decoder.py +1 -0
- ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -0
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +2 -1
- ai_edge_torch/generative/examples/paligemma/paligemma.py +12 -5
- ai_edge_torch/generative/examples/paligemma/verify.py +27 -5
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/phi/phi2.py +8 -1
- ai_edge_torch/generative/examples/phi/phi3.py +7 -2
- ai_edge_torch/generative/examples/phi/phi4.py +7 -2
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/qwen/qwen.py +20 -3
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +6 -1
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -2
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +12 -4
- ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +12 -4
- ai_edge_torch/generative/examples/qwen_vl/verify.py +26 -5
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +7 -2
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/smollm/smollm.py +14 -2
- ai_edge_torch/generative/examples/smollm/verify.py +2 -2
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -1
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -1
- ai_edge_torch/generative/utilities/converter.py +16 -4
- ai_edge_torch/generative/utilities/loader.py +19 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/RECORD +54 -54
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/top_level.txt +0 -0
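The common thread in this nightly: the generative example builders and converter scripts gain an optional `custom_loader` argument of type `Callable[[str], Dict[str, torch.Tensor]]`, threaded into `loading_utils.ModelLoader` and selected in the CLI scripts via `loader.maybe_get_custom_loader(checkpoint_path, flags.FLAGS.custom_checkpoint_loader)`. As a minimal sketch of a callable matching that signature, assuming a single-file safetensors checkpoint (the body, name `my_loader`, and file layout are illustrative, not taken from the package):

from typing import Dict

import torch
from safetensors.torch import load_file


def my_loader(checkpoint_path: str) -> Dict[str, torch.Tensor]:
  # Path in, state dict out: the shape the new custom_loader keyword expects.
  return load_file(checkpoint_path)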
ai_edge_torch/generative/examples/paligemma/image_encoder.py
@@ -66,7 +66,8 @@ class SiglipVisionEncoder(nn.Module):
         config.image_embedding.image_size // config.image_embedding.patch_size
     ) ** 2
     self.tok_embedding_position = nn.Parameter(
-        torch.zeros((num_patches, config.embedding_dim))
+        torch.zeros((num_patches, config.embedding_dim)),
+        requires_grad=False,
     )
 
     self.transformer_blocks = nn.ModuleList(
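A side note on the `requires_grad=False` change above (the same edit lands in `stable_diffusion/clip.py` below): the tensor stays a registered `nn.Parameter`, so checkpoints still load into it, but it is excluded from autograd, marking the positional embedding as a constant. A standalone illustration, not package code:

import torch
from torch import nn

# Still a registered parameter (state_dict round-trips), but frozen:
# no gradient is ever tracked for it.
pos = nn.Parameter(torch.zeros((16, 8)), requires_grad=False)
assert pos.requires_grad is False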
ai_edge_torch/generative/examples/paligemma/paligemma.py
@@ -16,7 +16,7 @@
 """Example of building a full-stack of PaliGemma model."""
 
 import dataclasses
-from typing import Optional
+from typing import Callable, Dict, Optional
 
 from ai_edge_torch.generative.examples.paligemma import decoder
 from ai_edge_torch.generative.examples.paligemma import decoder2
@@ -139,7 +139,12 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
 )
 
 
-def build_model(checkpoint_path: str, version: int = 2, **kwargs) -> PaliGemma:
+def build_model(
+    checkpoint_path: str,
+    version: int = 2,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs,
+) -> PaliGemma:
   if version == 1:
     decoder_class = decoder.Decoder
     decoder_tensor_names = decoder.TENSOR_NAMES
@@ -153,15 +158,17 @@ def build_model(checkpoint_path: str, version: int = 2, **kwargs) -> PaliGemma:
   model = PaliGemma(config, decoder_class)
   # Load the parameters of image encoder.
   loader = loading_utils.ModelLoader(
-      checkpoint_path, image_encoder.TENSOR_NAMES
+      checkpoint_path, image_encoder.TENSOR_NAMES, custom_loader
   )
   loader.load(model.image_encoder, strict=False)
   # Load the parameters of decoder.
-  loader = loading_utils.ModelLoader(checkpoint_path, decoder_tensor_names)
+  loader = loading_utils.ModelLoader(
+      checkpoint_path, decoder_tensor_names, custom_loader
+  )
   loader.load(model.decoder, strict=False)
 
   # Load the parameters of image projection.
-  loader = loading_utils.ModelLoader(checkpoint_path, None)
+  loader = loading_utils.ModelLoader(checkpoint_path, None, custom_loader)
   state = loader.get_state()
   converted_state = dict()
   converted_state["weight"] = state.pop(f"{PROJECTION_TENSOR_NAME}.weight")
ai_edge_torch/generative/examples/paligemma/verify.py
@@ -21,6 +21,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import kagglehub
 from PIL import Image
@@ -39,10 +40,15 @@ _IMAGE_URL = flags.DEFINE_string(
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
     "The image URI to encode.",
 )
-_PROMPTS = flags.DEFINE_string(
-    "prompts",
+_PROMPTS_WITH_IMAGE = flags.DEFINE_string(
+    "prompts_with_image",
     "<image><bos>describe en",
-    "The input prompts to generate answers.",
+    "The input prompts to generate answers with an image.",
+)
+_PROMPTS_TEXT_ONLY = flags.DEFINE_multi_string(
+    "prompts_text_only",
+    "What is the meaning of life?",
+    "The input prompts to generate answers only with text.",
 )
 _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "max_new_tokens",
@@ -84,6 +90,7 @@ def main(_):
   reauthored_model = paligemma.build_model(
       reauthored_checkpoint, version=int(_VERSION.value)
   )
+  wrapped_reauthored_model = ReauthoredPaliGemmaWrapper(reauthored_model)
 
   logging.info("Loading the processor from: %s", checkpoint)
   # It works only when GemmaTokenizerFast is available. In some environments,
@@ -91,9 +98,25 @@ def main(_):
   # sentencepiece model file properly.
   processor = transformers.AutoProcessor.from_pretrained(checkpoint)
 
+  logging.info("Verifying with text-only prompts...")
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=wrapped_reauthored_model,
+      tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
+      generate_prompts=_PROMPTS_TEXT_ONLY.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      verify_inputs=False,  # Numeric check not working. Disable it for now.
+      atol=1e-04,
+  )
+
+  logging.info("Verifying with image input...")
   logging.info("Loading the image from: %s", _IMAGE_URL.value)
   image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
-  inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt")
+  inputs = processor(
+      text=_PROMPTS_WITH_IMAGE.value, images=image, return_tensors="pt"
+  )
 
   logging.info("Verifying the reauthored model with model.forward()...")
   logging.info("Forwarding the original model...")
@@ -104,7 +127,6 @@ def main(_):
   logging.info("outputs_original: %s", outputs_original)
 
   logging.info("Forwarding the reauthored model...")
-  wrapped_reauthored_model = ReauthoredPaliGemmaWrapper(reauthored_model)
   outputs_reauthored = wrapped_reauthored_model.forward(
       tokens=inputs["input_ids"],
       pixel_values=inputs["pixel_values"],
|
@@ -19,13 +19,19 @@ from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.phi import phi3
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
|
+
from ai_edge_torch.generative.utilities import loader
|
22
23
|
|
23
24
|
flags = converter.define_conversion_flags("phi3")
|
24
25
|
|
25
26
|
|
26
27
|
def main(_):
|
28
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
27
29
|
pytorch_model = phi3.build_model(
|
28
|
-
|
30
|
+
checkpoint_path,
|
31
|
+
custom_loader=loader.maybe_get_custom_loader(
|
32
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
33
|
+
),
|
34
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
29
35
|
)
|
30
36
|
converter.convert_to_tflite(
|
31
37
|
pytorch_model,
|
@@ -19,13 +19,19 @@ from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.phi import phi4
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
|
+
from ai_edge_torch.generative.utilities import loader
|
22
23
|
|
23
24
|
flags = converter.define_conversion_flags("phi4")
|
24
25
|
|
25
26
|
|
26
27
|
def main(_):
|
28
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
27
29
|
pytorch_model = phi4.build_model(
|
28
|
-
|
30
|
+
checkpoint_path,
|
31
|
+
custom_loader=loader.maybe_get_custom_loader(
|
32
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
33
|
+
),
|
34
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
29
35
|
)
|
30
36
|
converter.convert_to_tflite(
|
31
37
|
pytorch_model,
|
@@ -20,13 +20,19 @@ from absl import app
|
|
20
20
|
from ai_edge_torch.generative.examples.phi import phi2
|
21
21
|
from ai_edge_torch.generative.utilities import converter
|
22
22
|
from ai_edge_torch.generative.utilities import export_config
|
23
|
+
from ai_edge_torch.generative.utilities import loader
|
23
24
|
|
24
25
|
flags = converter.define_conversion_flags("phi2")
|
25
26
|
|
26
27
|
|
27
28
|
def main(_):
|
29
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
28
30
|
pytorch_model = phi2.build_model(
|
29
|
-
|
31
|
+
checkpoint_path,
|
32
|
+
custom_loader=loader.maybe_get_custom_loader(
|
33
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
34
|
+
),
|
35
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
30
36
|
)
|
31
37
|
converter.convert_to_tflite(
|
32
38
|
pytorch_model,
|
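The three phi converter hunks above share one pattern: read `checkpoint_path` from flags once, then hand `loader.maybe_get_custom_loader(...)` to the builder. A hedged sketch of the equivalent direct call follows; only the keyword names are confirmed by the diff, the path is a placeholder, and `my_loader` is the hypothetical callable sketched near the top:

from ai_edge_torch.generative.examples.phi import phi3

# Hypothetical direct call mirroring convert_phi3_to_tflite.py; pass
# custom_loader=None (the default) to keep the stock checkpoint loader.
model = phi3.build_model(
    "/path/to/phi3-checkpoint",
    custom_loader=my_loader,
    kv_cache_max_len=1024,
)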
ai_edge_torch/generative/examples/phi/phi2.py
@@ -15,9 +15,11 @@
 
 """Example of building a Phi-2 model."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
 from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
@@ -98,10 +100,15 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Phi2,
+      custom_loader=custom_loader,
   )
ai_edge_torch/generative/examples/phi/phi3.py
@@ -17,7 +17,7 @@
 
 from functools import partial
 import math
-from typing import Tuple
+from typing import Callable, Dict, Tuple
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
@@ -208,11 +208,16 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> torch.nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Phi3_5Mini,
+      custom_loader=custom_loader,
   )
ai_edge_torch/generative/examples/phi/phi4.py
@@ -17,7 +17,7 @@
 
 from functools import partial
 import math
-from typing import Tuple
+from typing import Callable, Dict, Tuple
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
@@ -157,11 +157,16 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> torch.nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Phi4Mini,
+      custom_loader=custom_loader,
   )
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.qwen import qwen
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('qwen')
 
@@ -37,8 +38,13 @@ _BUILDER = {
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
ai_edge_torch/generative/examples/qwen/qwen.py
@@ -15,8 +15,10 @@
 
 """Example of building Qwen 2.5 models."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
@@ -108,28 +110,43 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_3b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_3b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Qwen,
+      custom_loader=custom_loader,
   )
 
 
-def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_1_5b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_1_5b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Qwen,
+      custom_loader=custom_loader,
   )
 
 
-def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_0_5b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_0_5b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Qwen,
+      custom_loader=custom_loader,
   )
ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('qwen_vl')
 
@@ -35,8 +36,12 @@ _IMAGE_WIDTH = flags.DEFINE_integer(
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = qwen_vl.build_model(
-      flags.FLAGS.checkpoint_path,
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
       kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       image_size=(_IMAGE_HEIGHT.value, _IMAGE_WIDTH.value),
   )
ai_edge_torch/generative/examples/qwen_vl/decoder.py
@@ -97,8 +97,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=11008,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-06,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06, enable_hlfb=True
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
ai_edge_torch/generative/examples/qwen_vl/image_encoder.py
@@ -16,7 +16,7 @@
 """Example of building an image encoder of Qwen 2.5 VL model."""
 
 import dataclasses
-from typing import List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import attention_utils
@@ -385,13 +385,21 @@ def build_image_encoder(
   return encoder
 
 
-def load_image_encoder(checkpoint_path: str, encoder: QwenVLImageEncoder):
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+def load_image_encoder(
+    checkpoint_path: str,
+    encoder: QwenVLImageEncoder,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+):
+  loader = loading_utils.ModelLoader(
+      checkpoint_path, TENSOR_NAMES, custom_loader
+  )
   # Loose the strictness because only image encoder is being loaded.
   loader.load(encoder, strict=False)
 
   # Load merger weights.
-  merger_loader = loading_utils.ModelLoader(checkpoint_path, None)
+  merger_loader = loading_utils.ModelLoader(
+      checkpoint_path, None, custom_loader
+  )
   state = merger_loader.get_state()
   w1_state = dict()
   w1_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.weight")
ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py
@@ -16,7 +16,7 @@
 """Example of building a full-stack of Qwen 2.5 VL model."""
 
 import dataclasses
-from typing import List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 from ai_edge_torch.generative.examples.qwen_vl import decoder
 from ai_edge_torch.generative.examples.qwen_vl import image_encoder
@@ -204,12 +204,20 @@ def get_fake_model_config(**kwargs) -> QwenVLConfig:
 )
 
 
-def build_model(checkpoint_path: str, **kwargs) -> QwenVL:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> QwenVL:
   config = get_model_config(**kwargs)
   model = QwenVL(config)
-  image_encoder.load_image_encoder(checkpoint_path, model.image_encoder)
+  image_encoder.load_image_encoder(
+      checkpoint_path, model.image_encoder, custom_loader
+  )
   # Load the parameters of decoder.
-  loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
+  loader = loading_utils.ModelLoader(
+      checkpoint_path, decoder.TENSOR_NAMES, custom_loader
+  )
   loader.load(model.decoder, strict=False)
   model.eval()
   return model
ai_edge_torch/generative/examples/qwen_vl/verify.py
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
 from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 from PIL import Image
 import requests
@@ -33,10 +34,15 @@ _IMAGE_URL = flags.DEFINE_string(
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
     "The image URI to encode.",
 )
-_PROMPTS = flags.DEFINE_string(
-    "prompts",
+_PROMPTS_WITH_IMAGE = flags.DEFINE_string(
+    "prompts_with_image",
     "<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>",
-    "The input prompts to generate answers.",
+    "The input prompts to generate answers with an image.",
+)
+_PROMPTS_TEXT_ONLY = flags.DEFINE_multi_string(
+    "prompts_text_only",
+    "What is the meaning of life?",
+    "The input prompts to generate answers only with text.",
 )
 _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "max_new_tokens",
@@ -68,13 +74,29 @@ def main(_):
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = qwen_vl.build_model(str(reauthored_checkpoint))
+  wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
 
   logging.info("Loading the processor from: %s", checkpoint)
   processor = transformers.AutoProcessor.from_pretrained(checkpoint)
 
+  logging.info("Verifying with text-only prompts...")
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=wrapped_reauthored_model,
+      tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
+      generate_prompts=_PROMPTS_TEXT_ONLY.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+  logging.info("Verifying with image input...")
   logging.info("Loading the image from: %s", _IMAGE_URL.value)
   image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
-  inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt")
+  inputs = processor(
+      text=_PROMPTS_WITH_IMAGE.value, images=image, return_tensors="pt"
+  )
 
   logging.info("Verifying the reauthored model with model.forward()...")
   logging.info("Forwarding the original model...")
@@ -87,7 +109,6 @@ def main(_):
   logging.info("outputs_original: %s", outputs_original)
 
   logging.info("Forwarding the reauthored model...")
-  wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
   grid_thw = inputs["image_grid_thw"].tolist()
   config = reauthored_model.config.image_encoder_config.image_embedding
   reauthored_model.image_encoder.set_image_size(
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
@@ -15,12 +15,12 @@
 
 """Example of converting SmolLM model to multi-signature tflite model."""
 
-import os
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config as export_cfg
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('smollm')
 
@@ -32,8 +32,13 @@ _DECODE_BATCH_SIZE = flags.DEFINE_integer(
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = smollm.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
 
   export_config = export_cfg.get_from_flags()
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config as export_cfg
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('smollm2')
 
@@ -30,8 +31,13 @@ _DECODE_BATCH_SIZE = flags.DEFINE_integer(
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = smollm.build_model_v2(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
  )
 
   export_config = export_cfg.get_from_flags()
ai_edge_torch/generative/examples/smollm/smollm.py
@@ -15,8 +15,10 @@
 
 """Example of building a SmolLM model."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
@@ -80,12 +82,17 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
      config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM,
+      custom_loader=custom_loader,
   )
 
 
@@ -118,10 +125,15 @@ def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model_v2(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model_v2(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config_v2(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM2,
+      custom_loader=custom_loader,
   )
ai_edge_torch/generative/examples/smollm/verify.py
@@ -43,8 +43,8 @@ _MODEL_VERSION = flags.DEFINE_enum(
     "The version of SmolLm to verify.",
 )
 _CHECKPOINT = {
-    "v1": "HuggingFaceTB/SmolLM-135M",
-    "v2": "HuggingFaceTB/SmolLM2-135M",
+    "v1": "HuggingFaceTB/SmolLM-135M-Instruct",
+    "v2": "HuggingFaceTB/SmolLM2-135M-Instruct",
 }
 
 _BUILDER = {
ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -57,7 +57,8 @@ class CLIP(nn.Module):
     super().__init__()
     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
     self.tok_embedding_position = nn.Parameter(
-        torch.zeros((config.max_seq_len, config.embedding_dim))
+        torch.zeros((config.max_seq_len, config.embedding_dim)),
+        requires_grad=False,
     )
 
     self.config = config