ai-edge-torch-nightly 0.5.0.dev20250515__py3-none-any.whl → 0.5.0.dev20250517__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/__init__.py +1 -0
- ai_edge_torch/_convert/conversion.py +24 -0
- ai_edge_torch/_convert/converter.py +57 -3
- ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
- ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py +40 -0
- ai_edge_torch/_convert/test/test_convert.py +25 -0
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +10 -6
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +7 -2
- ai_edge_torch/generative/examples/deepseek/deepseek.py +9 -5
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/gemma/gemma1.py +10 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +8 -7
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +5 -14
- ai_edge_torch/generative/examples/gemma3/decoder.py +10 -10
- ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -3
- ai_edge_torch/generative/examples/gemma3/image_encoder.py +1 -4
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/hammer/hammer.py +15 -6
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/llama/llama.py +26 -10
- ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +0 -1
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/openelm/openelm.py +9 -3
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +6 -1
- ai_edge_torch/generative/examples/paligemma/decoder.py +1 -4
- ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -4
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +3 -5
- ai_edge_torch/generative/examples/paligemma/paligemma.py +12 -5
- ai_edge_torch/generative/examples/paligemma/verify.py +27 -5
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/phi/phi2.py +9 -5
- ai_edge_torch/generative/examples/phi/phi3.py +8 -6
- ai_edge_torch/generative/examples/phi/phi4.py +8 -6
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/qwen/qwen.py +21 -7
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +6 -1
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -3
- ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +13 -7
- ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +12 -4
- ai_edge_torch/generative/examples/qwen_vl/verify.py +26 -5
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +7 -2
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/smollm/smollm.py +15 -6
- ai_edge_torch/generative/examples/smollm/verify.py +2 -2
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +8 -5
- ai_edge_torch/generative/examples/t5/t5.py +1 -3
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +7 -1
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +9 -5
- ai_edge_torch/generative/layers/model_config.py +2 -2
- ai_edge_torch/generative/utilities/converter.py +18 -5
- ai_edge_torch/generative/utilities/loader.py +19 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +13 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/RECORD +64 -63
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/top_level.txt +0 -0
@@ -15,8 +15,10 @@
|
|
15
15
|
|
16
16
|
"""Example of building Qwen 2.5 models."""
|
17
17
|
|
18
|
+
from typing import Callable, Dict
|
18
19
|
import ai_edge_torch.generative.layers.model_config as cfg
|
19
20
|
from ai_edge_torch.generative.utilities import model_builder
|
21
|
+
import torch
|
20
22
|
from torch import nn
|
21
23
|
|
22
24
|
TENSOR_NAMES = model_builder.TENSOR_NAMES
|
@@ -51,9 +53,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
51
53
|
intermediate_size=11008,
|
52
54
|
)
|
53
55
|
norm_config = cfg.NormalizationConfig(
|
54
|
-
type=cfg.NormalizationType.RMS_NORM,
|
55
|
-
epsilon=1e-06,
|
56
|
-
enable_hlfb=True,
|
56
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
|
57
57
|
)
|
58
58
|
block_config = cfg.TransformerBlockConfig(
|
59
59
|
attn_config=attn_config,
|
@@ -69,7 +69,6 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
69
69
|
kv_cache_max_len=kv_cache_max_len,
|
70
70
|
block_configs=block_config,
|
71
71
|
final_norm_config=norm_config,
|
72
|
-
enable_hlfb=True,
|
73
72
|
)
|
74
73
|
return config
|
75
74
|
|
@@ -108,28 +107,43 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
|
|
108
107
|
return config
|
109
108
|
|
110
109
|
|
111
|
-
def build_3b_model(
|
110
|
+
def build_3b_model(
|
111
|
+
checkpoint_path: str,
|
112
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
113
|
+
**kwargs
|
114
|
+
) -> nn.Module:
|
112
115
|
return model_builder.build_decoder_only_model(
|
113
116
|
checkpoint_path=checkpoint_path,
|
114
117
|
config=get_3b_model_config(**kwargs),
|
115
118
|
tensor_names=TENSOR_NAMES,
|
116
119
|
model_class=Qwen,
|
120
|
+
custom_loader=custom_loader,
|
117
121
|
)
|
118
122
|
|
119
123
|
|
120
|
-
def build_1_5b_model(
|
124
|
+
def build_1_5b_model(
|
125
|
+
checkpoint_path: str,
|
126
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
127
|
+
**kwargs
|
128
|
+
) -> nn.Module:
|
121
129
|
return model_builder.build_decoder_only_model(
|
122
130
|
checkpoint_path=checkpoint_path,
|
123
131
|
config=get_1_5b_model_config(**kwargs),
|
124
132
|
tensor_names=TENSOR_NAMES,
|
125
133
|
model_class=Qwen,
|
134
|
+
custom_loader=custom_loader,
|
126
135
|
)
|
127
136
|
|
128
137
|
|
129
|
-
def build_0_5b_model(
|
138
|
+
def build_0_5b_model(
|
139
|
+
checkpoint_path: str,
|
140
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
141
|
+
**kwargs
|
142
|
+
) -> nn.Module:
|
130
143
|
return model_builder.build_decoder_only_model(
|
131
144
|
checkpoint_path=checkpoint_path,
|
132
145
|
config=get_0_5b_model_config(**kwargs),
|
133
146
|
tensor_names=TENSOR_NAMES,
|
134
147
|
model_class=Qwen,
|
148
|
+
custom_loader=custom_loader,
|
135
149
|
)
|
@@ -19,6 +19,7 @@ from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
|
+
from ai_edge_torch.generative.utilities import loader
|
22
23
|
|
23
24
|
flags = converter.define_conversion_flags('qwen_vl')
|
24
25
|
|
@@ -35,8 +36,12 @@ _IMAGE_WIDTH = flags.DEFINE_integer(
|
|
35
36
|
|
36
37
|
|
37
38
|
def main(_):
|
39
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
38
40
|
pytorch_model = qwen_vl.build_model(
|
39
|
-
|
41
|
+
checkpoint_path,
|
42
|
+
custom_loader=loader.maybe_get_custom_loader(
|
43
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
44
|
+
),
|
40
45
|
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
41
46
|
image_size=(_IMAGE_HEIGHT.value, _IMAGE_WIDTH.value),
|
42
47
|
)
|
@@ -97,8 +97,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
97
97
|
intermediate_size=11008,
|
98
98
|
)
|
99
99
|
norm_config = cfg.NormalizationConfig(
|
100
|
-
type=cfg.NormalizationType.RMS_NORM,
|
101
|
-
epsilon=1e-06,
|
100
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
|
102
101
|
)
|
103
102
|
block_config = cfg.TransformerBlockConfig(
|
104
103
|
attn_config=attn_config,
|
@@ -114,7 +113,6 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
114
113
|
kv_cache_max_len=kv_cache_max_len,
|
115
114
|
block_configs=block_config,
|
116
115
|
final_norm_config=norm_config,
|
117
|
-
enable_hlfb=True,
|
118
116
|
)
|
119
117
|
return config
|
120
118
|
|
@@ -16,7 +16,7 @@
|
|
16
16
|
"""Example of building an image encoder of Qwen 2.5 VL model."""
|
17
17
|
|
18
18
|
import dataclasses
|
19
|
-
from typing import List, Optional, Tuple
|
19
|
+
from typing import Callable, Dict, List, Optional, Tuple
|
20
20
|
|
21
21
|
from ai_edge_torch.generative.layers import attention
|
22
22
|
from ai_edge_torch.generative.layers import attention_utils
|
@@ -332,8 +332,7 @@ def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
|
|
332
332
|
use_bias=True,
|
333
333
|
)
|
334
334
|
norm_config = cfg.NormalizationConfig(
|
335
|
-
type=cfg.NormalizationType.RMS_NORM,
|
336
|
-
epsilon=1e-6,
|
335
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
|
337
336
|
)
|
338
337
|
block_config = cfg.TransformerBlockConfig(
|
339
338
|
attn_config=attn_config,
|
@@ -359,7 +358,6 @@ def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
|
|
359
358
|
window_size=112,
|
360
359
|
spatial_merge_size=2,
|
361
360
|
full_atten_block_indexes=[7, 15, 23, 31],
|
362
|
-
enable_hlfb=True,
|
363
361
|
)
|
364
362
|
return config
|
365
363
|
|
@@ -385,13 +383,21 @@ def build_image_encoder(
|
|
385
383
|
return encoder
|
386
384
|
|
387
385
|
|
388
|
-
def load_image_encoder(
|
389
|
-
|
386
|
+
def load_image_encoder(
|
387
|
+
checkpoint_path: str,
|
388
|
+
encoder: QwenVLImageEncoder,
|
389
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
390
|
+
):
|
391
|
+
loader = loading_utils.ModelLoader(
|
392
|
+
checkpoint_path, TENSOR_NAMES, custom_loader
|
393
|
+
)
|
390
394
|
# Loose the strictness because only image encoder is being loaded.
|
391
395
|
loader.load(encoder, strict=False)
|
392
396
|
|
393
397
|
# Load merger weights.
|
394
|
-
merger_loader = loading_utils.ModelLoader(
|
398
|
+
merger_loader = loading_utils.ModelLoader(
|
399
|
+
checkpoint_path, None, custom_loader
|
400
|
+
)
|
395
401
|
state = merger_loader.get_state()
|
396
402
|
w1_state = dict()
|
397
403
|
w1_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.weight")
|
@@ -16,7 +16,7 @@
|
|
16
16
|
"""Example of building a full-stack of Qwen 2.5 VL model."""
|
17
17
|
|
18
18
|
import dataclasses
|
19
|
-
from typing import List, Optional, Tuple
|
19
|
+
from typing import Callable, Dict, List, Optional, Tuple
|
20
20
|
|
21
21
|
from ai_edge_torch.generative.examples.qwen_vl import decoder
|
22
22
|
from ai_edge_torch.generative.examples.qwen_vl import image_encoder
|
@@ -204,12 +204,20 @@ def get_fake_model_config(**kwargs) -> QwenVLConfig:
|
|
204
204
|
)
|
205
205
|
|
206
206
|
|
207
|
-
def build_model(
|
207
|
+
def build_model(
|
208
|
+
checkpoint_path: str,
|
209
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
210
|
+
**kwargs
|
211
|
+
) -> QwenVL:
|
208
212
|
config = get_model_config(**kwargs)
|
209
213
|
model = QwenVL(config)
|
210
|
-
image_encoder.load_image_encoder(
|
214
|
+
image_encoder.load_image_encoder(
|
215
|
+
checkpoint_path, model.image_encoder, custom_loader
|
216
|
+
)
|
211
217
|
# Load the parameters of decoder.
|
212
|
-
loader = loading_utils.ModelLoader(
|
218
|
+
loader = loading_utils.ModelLoader(
|
219
|
+
checkpoint_path, decoder.TENSOR_NAMES, custom_loader
|
220
|
+
)
|
213
221
|
loader.load(model.decoder, strict=False)
|
214
222
|
model.eval()
|
215
223
|
return model
|
@@ -22,6 +22,7 @@ from absl import app
|
|
22
22
|
from absl import flags
|
23
23
|
from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
|
24
24
|
from ai_edge_torch.generative.layers import kv_cache
|
25
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
25
26
|
from ai_edge_torch.generative.utilities import verifier
|
26
27
|
from PIL import Image
|
27
28
|
import requests
|
@@ -33,10 +34,15 @@ _IMAGE_URL = flags.DEFINE_string(
|
|
33
34
|
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
|
34
35
|
"The image URI to encode.",
|
35
36
|
)
|
36
|
-
|
37
|
-
"
|
37
|
+
_PROMPTS_WITH_IMAGE = flags.DEFINE_string(
|
38
|
+
"prompts_with_image",
|
38
39
|
"<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>",
|
39
|
-
"The input prompts to generate answers.",
|
40
|
+
"The input prompts to generate answers with an image.",
|
41
|
+
)
|
42
|
+
_PROMPTS_TEXT_ONLY = flags.DEFINE_multi_string(
|
43
|
+
"prompts_text_only",
|
44
|
+
"What is the meaning of life?",
|
45
|
+
"The input prompts to generate answers only with text.",
|
40
46
|
)
|
41
47
|
_MAX_NEW_TOKENS = flags.DEFINE_integer(
|
42
48
|
"max_new_tokens",
|
@@ -68,13 +74,29 @@ def main(_):
|
|
68
74
|
reauthored_checkpoint = pathlib.Path(cached_config_file).parent
|
69
75
|
logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
|
70
76
|
reauthored_model = qwen_vl.build_model(str(reauthored_checkpoint))
|
77
|
+
wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
|
71
78
|
|
72
79
|
logging.info("Loading the processor from: %s", checkpoint)
|
73
80
|
processor = transformers.AutoProcessor.from_pretrained(checkpoint)
|
74
81
|
|
82
|
+
logging.info("Verifying with text-only prompts...")
|
83
|
+
verifier.verify_reauthored_model(
|
84
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
85
|
+
original_model
|
86
|
+
),
|
87
|
+
reauthored_model=wrapped_reauthored_model,
|
88
|
+
tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
|
89
|
+
generate_prompts=_PROMPTS_TEXT_ONLY.value,
|
90
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
91
|
+
atol=1e-04,
|
92
|
+
)
|
93
|
+
|
94
|
+
logging.info("Verifying with image input...")
|
75
95
|
logging.info("Loading the image from: %s", _IMAGE_URL.value)
|
76
96
|
image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
|
77
|
-
inputs = processor(
|
97
|
+
inputs = processor(
|
98
|
+
text=_PROMPTS_WITH_IMAGE.value, images=image, return_tensors="pt"
|
99
|
+
)
|
78
100
|
|
79
101
|
logging.info("Verifying the reauthored model with model.forward()...")
|
80
102
|
logging.info("Forwarding the original model...")
|
@@ -87,7 +109,6 @@ def main(_):
|
|
87
109
|
logging.info("outputs_original: %s", outputs_original)
|
88
110
|
|
89
111
|
logging.info("Forwarding the reauthored model...")
|
90
|
-
wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
|
91
112
|
grid_thw = inputs["image_grid_thw"].tolist()
|
92
113
|
config = reauthored_model.config.image_encoder_config.image_embedding
|
93
114
|
reauthored_model.image_encoder.set_image_size(
|
@@ -15,12 +15,12 @@
|
|
15
15
|
|
16
16
|
"""Example of converting SmolLM model to multi-signature tflite model."""
|
17
17
|
|
18
|
-
import os
|
19
18
|
from absl import app
|
20
19
|
from absl import flags
|
21
20
|
from ai_edge_torch.generative.examples.smollm import smollm
|
22
21
|
from ai_edge_torch.generative.utilities import converter
|
23
22
|
from ai_edge_torch.generative.utilities import export_config as export_cfg
|
23
|
+
from ai_edge_torch.generative.utilities import loader
|
24
24
|
|
25
25
|
flags = converter.define_conversion_flags('smollm')
|
26
26
|
|
@@ -32,8 +32,13 @@ _DECODE_BATCH_SIZE = flags.DEFINE_integer(
|
|
32
32
|
|
33
33
|
|
34
34
|
def main(_):
|
35
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
35
36
|
pytorch_model = smollm.build_model(
|
36
|
-
|
37
|
+
checkpoint_path,
|
38
|
+
custom_loader=loader.maybe_get_custom_loader(
|
39
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
40
|
+
),
|
41
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
37
42
|
)
|
38
43
|
|
39
44
|
export_config = export_cfg.get_from_flags()
|
@@ -19,6 +19,7 @@ from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.smollm import smollm
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config as export_cfg
|
22
|
+
from ai_edge_torch.generative.utilities import loader
|
22
23
|
|
23
24
|
flags = converter.define_conversion_flags('smollm2')
|
24
25
|
|
@@ -30,8 +31,13 @@ _DECODE_BATCH_SIZE = flags.DEFINE_integer(
|
|
30
31
|
|
31
32
|
|
32
33
|
def main(_):
|
34
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
33
35
|
pytorch_model = smollm.build_model_v2(
|
34
|
-
|
36
|
+
checkpoint_path,
|
37
|
+
custom_loader=loader.maybe_get_custom_loader(
|
38
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
39
|
+
),
|
40
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
35
41
|
)
|
36
42
|
|
37
43
|
export_config = export_cfg.get_from_flags()
|
@@ -15,8 +15,10 @@
|
|
15
15
|
|
16
16
|
"""Example of building a SmolLM model."""
|
17
17
|
|
18
|
+
from typing import Callable, Dict
|
18
19
|
import ai_edge_torch.generative.layers.model_config as cfg
|
19
20
|
from ai_edge_torch.generative.utilities import model_builder
|
21
|
+
import torch
|
20
22
|
from torch import nn
|
21
23
|
|
22
24
|
TENSOR_NAMES = model_builder.TENSOR_NAMES
|
@@ -49,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
49
51
|
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
|
50
52
|
intermediate_size=1536,
|
51
53
|
)
|
52
|
-
norm_config = cfg.NormalizationConfig(
|
53
|
-
type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
|
54
|
-
)
|
54
|
+
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
|
55
55
|
block_config = cfg.TransformerBlockConfig(
|
56
56
|
attn_config=attn_config,
|
57
57
|
ff_config=ff_config,
|
@@ -66,7 +66,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
66
66
|
kv_cache_max_len=kv_cache_max_len,
|
67
67
|
block_configs=block_config,
|
68
68
|
final_norm_config=norm_config,
|
69
|
-
enable_hlfb=True,
|
70
69
|
)
|
71
70
|
return config
|
72
71
|
|
@@ -80,12 +79,17 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
|
|
80
79
|
return config
|
81
80
|
|
82
81
|
|
83
|
-
def build_model(
|
82
|
+
def build_model(
|
83
|
+
checkpoint_path: str,
|
84
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
85
|
+
**kwargs
|
86
|
+
) -> nn.Module:
|
84
87
|
return model_builder.build_decoder_only_model(
|
85
88
|
checkpoint_path=checkpoint_path,
|
86
89
|
config=get_model_config(**kwargs),
|
87
90
|
tensor_names=TENSOR_NAMES,
|
88
91
|
model_class=SmolLM,
|
92
|
+
custom_loader=custom_loader,
|
89
93
|
)
|
90
94
|
|
91
95
|
|
@@ -118,10 +122,15 @@ def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
|
|
118
122
|
return config
|
119
123
|
|
120
124
|
|
121
|
-
def build_model_v2(
|
125
|
+
def build_model_v2(
|
126
|
+
checkpoint_path: str,
|
127
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
128
|
+
**kwargs
|
129
|
+
) -> nn.Module:
|
122
130
|
return model_builder.build_decoder_only_model(
|
123
131
|
checkpoint_path=checkpoint_path,
|
124
132
|
config=get_model_config_v2(**kwargs),
|
125
133
|
tensor_names=TENSOR_NAMES,
|
126
134
|
model_class=SmolLM2,
|
135
|
+
custom_loader=custom_loader,
|
127
136
|
)
|
@@ -43,8 +43,8 @@ _MODEL_VERSION = flags.DEFINE_enum(
|
|
43
43
|
"The version of SmolLm to verify.",
|
44
44
|
)
|
45
45
|
_CHECKPOINT = {
|
46
|
-
"v1": "HuggingFaceTB/SmolLM-135M",
|
47
|
-
"v2": "HuggingFaceTB/SmolLM2-135M",
|
46
|
+
"v1": "HuggingFaceTB/SmolLM-135M-Instruct",
|
47
|
+
"v2": "HuggingFaceTB/SmolLM2-135M-Instruct",
|
48
48
|
}
|
49
49
|
|
50
50
|
_BUILDER = {
|
@@ -57,7 +57,8 @@ class CLIP(nn.Module):
|
|
57
57
|
super().__init__()
|
58
58
|
self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
|
59
59
|
self.tok_embedding_position = nn.Parameter(
|
60
|
-
torch.zeros((config.max_seq_len, config.embedding_dim))
|
60
|
+
torch.zeros((config.max_seq_len, config.embedding_dim)),
|
61
|
+
requires_grad=False,
|
61
62
|
)
|
62
63
|
|
63
64
|
self.config = config
|
@@ -112,7 +113,9 @@ def get_model_config() -> cfg.ModelConfig:
|
|
112
113
|
use_bias=True,
|
113
114
|
)
|
114
115
|
|
115
|
-
norm_config = cfg.NormalizationConfig(
|
116
|
+
norm_config = cfg.NormalizationConfig(
|
117
|
+
type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=False
|
118
|
+
)
|
116
119
|
|
117
120
|
block_config = cfg.TransformerBlockConfig(
|
118
121
|
attn_config=attn_config,
|
@@ -128,7 +131,6 @@ def get_model_config() -> cfg.ModelConfig:
|
|
128
131
|
embedding_dim=embedding_dim,
|
129
132
|
block_configs=block_config,
|
130
133
|
final_norm_config=norm_config,
|
131
|
-
enable_hlfb=True,
|
132
134
|
)
|
133
135
|
|
134
136
|
return config
|
@@ -163,7 +165,9 @@ def get_fake_model_config() -> cfg.ModelConfig:
|
|
163
165
|
use_bias=True,
|
164
166
|
)
|
165
167
|
|
166
|
-
norm_config = cfg.NormalizationConfig(
|
168
|
+
norm_config = cfg.NormalizationConfig(
|
169
|
+
type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=False
|
170
|
+
)
|
167
171
|
|
168
172
|
block_config = cfg.TransformerBlockConfig(
|
169
173
|
attn_config=attn_config,
|
@@ -179,7 +183,6 @@ def get_fake_model_config() -> cfg.ModelConfig:
|
|
179
183
|
embedding_dim=embedding_dim,
|
180
184
|
block_configs=block_config,
|
181
185
|
final_norm_config=norm_config,
|
182
|
-
enable_hlfb=True,
|
183
186
|
)
|
184
187
|
|
185
188
|
return config
|
@@ -393,8 +393,7 @@ def get_model_config_t5() -> cfg.ModelConfig:
|
|
393
393
|
)
|
394
394
|
# T5 Confirmed as RMS Norm and eps = 1e-6 TJA.
|
395
395
|
norm_config = cfg.NormalizationConfig(
|
396
|
-
type=cfg.NormalizationType.RMS_NORM,
|
397
|
-
epsilon=1e-6,
|
396
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=False
|
398
397
|
)
|
399
398
|
block_config = cfg.TransformerBlockConfig(
|
400
399
|
attn_config=attn_config,
|
@@ -411,7 +410,6 @@ def get_model_config_t5() -> cfg.ModelConfig:
|
|
411
410
|
block_configs=block_config,
|
412
411
|
final_norm_config=norm_config,
|
413
412
|
lm_head_use_bias=False,
|
414
|
-
enable_hlfb=True,
|
415
413
|
)
|
416
414
|
return config
|
417
415
|
|
@@ -138,7 +138,9 @@ def get_model_config() -> cfg.ModelConfig:
|
|
138
138
|
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
|
139
139
|
intermediate_size=256,
|
140
140
|
)
|
141
|
-
norm_config = cfg.NormalizationConfig(
|
141
|
+
norm_config = cfg.NormalizationConfig(
|
142
|
+
type=cfg.NormalizationType.RMS_NORM, enable_hlfb=False
|
143
|
+
)
|
142
144
|
block_config = cfg.TransformerBlockConfig(
|
143
145
|
attn_config=attn_config,
|
144
146
|
ff_config=ff_config,
|
@@ -152,5 +154,6 @@ def get_model_config() -> cfg.ModelConfig:
|
|
152
154
|
embedding_dim=128,
|
153
155
|
block_configs=block_config,
|
154
156
|
final_norm_config=norm_config,
|
157
|
+
enable_hlfb=False,
|
155
158
|
)
|
156
159
|
return config
|
@@ -108,7 +108,9 @@ def get_model_config() -> cfg.ModelConfig:
|
|
108
108
|
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
|
109
109
|
intermediate_size=256,
|
110
110
|
)
|
111
|
-
norm_config = cfg.NormalizationConfig(
|
111
|
+
norm_config = cfg.NormalizationConfig(
|
112
|
+
type=cfg.NormalizationType.RMS_NORM, enable_hlfb=False
|
113
|
+
)
|
112
114
|
block_config = cfg.TransformerBlockConfig(
|
113
115
|
attn_config=attn_config,
|
114
116
|
ff_config=ff_config,
|
@@ -122,7 +124,6 @@ def get_model_config() -> cfg.ModelConfig:
|
|
122
124
|
embedding_dim=128,
|
123
125
|
block_configs=block_config,
|
124
126
|
final_norm_config=norm_config,
|
125
|
-
enable_hlfb=True,
|
126
127
|
)
|
127
128
|
return config
|
128
129
|
|
@@ -19,13 +19,19 @@ from absl import app
|
|
19
19
|
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
20
20
|
from ai_edge_torch.generative.utilities import converter
|
21
21
|
from ai_edge_torch.generative.utilities import export_config
|
22
|
+
from ai_edge_torch.generative.utilities import loader
|
22
23
|
|
23
24
|
flags = converter.define_conversion_flags("tiny_llama")
|
24
25
|
|
25
26
|
|
26
27
|
def main(_):
|
28
|
+
checkpoint_path = flags.FLAGS.checkpoint_path
|
27
29
|
pytorch_model = tiny_llama.build_model(
|
28
|
-
|
30
|
+
checkpoint_path,
|
31
|
+
custom_loader=loader.maybe_get_custom_loader(
|
32
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
33
|
+
),
|
34
|
+
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
|
29
35
|
)
|
30
36
|
converter.convert_to_tflite(
|
31
37
|
pytorch_model,
|
@@ -15,8 +15,10 @@
|
|
15
15
|
|
16
16
|
"""Example of building a TinyLlama model."""
|
17
17
|
|
18
|
+
from typing import Callable, Dict
|
18
19
|
import ai_edge_torch.generative.layers.model_config as cfg
|
19
20
|
from ai_edge_torch.generative.utilities import model_builder
|
21
|
+
import torch
|
20
22
|
from torch import nn
|
21
23
|
|
22
24
|
TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
|
@@ -49,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
49
51
|
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
|
50
52
|
intermediate_size=5632,
|
51
53
|
)
|
52
|
-
norm_config = cfg.NormalizationConfig(
|
53
|
-
type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
|
54
|
-
)
|
54
|
+
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
|
55
55
|
block_config = cfg.TransformerBlockConfig(
|
56
56
|
attn_config=attn_config,
|
57
57
|
ff_config=ff_config,
|
@@ -67,7 +67,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
67
67
|
block_configs=block_config,
|
68
68
|
final_norm_config=norm_config,
|
69
69
|
lm_head_share_weight_with_embedding=False,
|
70
|
-
enable_hlfb=True,
|
71
70
|
)
|
72
71
|
return config
|
73
72
|
|
@@ -81,10 +80,15 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
|
|
81
80
|
return config
|
82
81
|
|
83
82
|
|
84
|
-
def build_model(
|
83
|
+
def build_model(
|
84
|
+
checkpoint_path: str,
|
85
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
86
|
+
**kwargs
|
87
|
+
) -> nn.Module:
|
85
88
|
return model_builder.build_decoder_only_model(
|
86
89
|
checkpoint_path=checkpoint_path,
|
87
90
|
config=get_model_config(**kwargs),
|
88
91
|
tensor_names=TENSOR_NAMES,
|
89
92
|
model_class=TinyLlama,
|
93
|
+
custom_loader=custom_loader,
|
90
94
|
)
|
@@ -66,7 +66,7 @@ class NormalizationConfig:
|
|
66
66
|
"""Normalizater parameters."""
|
67
67
|
|
68
68
|
type: NormalizationType = NormalizationType.NONE
|
69
|
-
enable_hlfb: bool =
|
69
|
+
enable_hlfb: bool = True
|
70
70
|
epsilon: float = 1e-5
|
71
71
|
zero_centered: bool = False
|
72
72
|
# Number of groups used in group normalization.
|
@@ -218,7 +218,7 @@ class ModelConfig:
|
|
218
218
|
lm_head_share_weight_with_embedding: bool = True
|
219
219
|
|
220
220
|
# Whether to turn on high-level function boundary.
|
221
|
-
enable_hlfb: bool =
|
221
|
+
enable_hlfb: bool = True
|
222
222
|
|
223
223
|
# The maximum sequence length of the KV cache. Should not exceed max_seq_len.
|
224
224
|
kv_cache_max_len: int = 0
|
@@ -100,7 +100,8 @@ def define_conversion_flags(
|
|
100
100
|
flags.DEFINE_string(
|
101
101
|
'quantize',
|
102
102
|
'dynamic_int8',
|
103
|
-
'How the model should be quantized.'
|
103
|
+
'How the model should be quantized. Set to "none" to disable'
|
104
|
+
' quantization. See `QuantizationName` for supported quantization types.',
|
104
105
|
)
|
105
106
|
flags.DEFINE_multi_integer(
|
106
107
|
'lora_ranks',
|
@@ -119,6 +120,12 @@ def define_conversion_flags(
|
|
119
120
|
default_transpose_kv_cache,
|
120
121
|
'If true, the model will be converted with transposed KV cache.',
|
121
122
|
)
|
123
|
+
flags.DEFINE_bool(
|
124
|
+
'custom_checkpoint_loader',
|
125
|
+
False,
|
126
|
+
'If true, the conversion script will use a custom checkpoint loader which'
|
127
|
+
' will read a checkpoint from a remote source.',
|
128
|
+
)
|
122
129
|
return flags
|
123
130
|
|
124
131
|
|
@@ -397,13 +404,19 @@ def _export_helper(
|
|
397
404
|
)
|
398
405
|
|
399
406
|
if prefill_pixel_values is not None:
|
400
|
-
|
401
|
-
|
402
|
-
|
407
|
+
sample_pixel_kwargs = {
|
408
|
+
'tokens': prefill_tokens_list_with_pixel[i],
|
409
|
+
'input_pos': prefill_input_pos_list_with_pixel[i],
|
410
|
+
'kv_cache': prefill_kv,
|
411
|
+
'pixel_values': prefill_pixel_values,
|
412
|
+
}
|
413
|
+
# mask should be built internally when pixel values are passed.
|
414
|
+
if lora is not None:
|
415
|
+
sample_pixel_kwargs['lora'] = lora
|
403
416
|
converter.add_signature(
|
404
417
|
prefill_signature_name + '_pixel',
|
405
418
|
mod,
|
406
|
-
sample_kwargs=
|
419
|
+
sample_kwargs=sample_pixel_kwargs,
|
407
420
|
)
|
408
421
|
|
409
422
|
sample_kwargs = {
|
@@ -49,6 +49,25 @@ def get_custom_loader(
|
|
49
49
|
raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")
|
50
50
|
|
51
51
|
|
52
|
+
def maybe_get_custom_loader(
|
53
|
+
checkpoint_path: str,
|
54
|
+
use_custom_loader: bool = False,
|
55
|
+
) -> Callable[[str], Dict[str, torch.Tensor]] | None:
|
56
|
+
"""Returns a custom loader for the given checkpoint path.
|
57
|
+
|
58
|
+
If use_custom_loader is True, the function will return a custom loader.
|
59
|
+
Otherwise, it will return None.
|
60
|
+
|
61
|
+
Args:
|
62
|
+
checkpoint_path (string): The path to the checkpoint.
|
63
|
+
use_custom_loader (bool): Whether to use a custom loader.
|
64
|
+
|
65
|
+
Returns:
|
66
|
+
Callable[[str], Dict[str, torch.Tensor]] | None: The custom loader.
|
67
|
+
"""
|
68
|
+
return get_custom_loader(checkpoint_path) if use_custom_loader else None
|
69
|
+
|
70
|
+
|
52
71
|
def load_safetensors(full_path: str):
|
53
72
|
"""Loads safetensors into a single state dictionary.
|
54
73
|
|