ai-edge-torch-nightly 0.6.0.dev20250602__py3-none-any.whl → 0.6.0.dev20250604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +3 -1
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +7 -15
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +3 -1
- ai_edge_torch/generative/examples/deepseek/deepseek.py +7 -15
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/gemma/gemma1.py +8 -16
- ai_edge_torch/generative/examples/gemma/gemma2.py +24 -24
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +6 -0
- ai_edge_torch/generative/examples/gemma/verify_util.py +3 -0
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/gemma3/decoder.py +34 -35
- ai_edge_torch/generative/examples/gemma3/gemma3.py +10 -8
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/hammer/hammer.py +23 -16
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -2
- ai_edge_torch/generative/examples/llama/llama.py +13 -26
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/openelm/openelm.py +8 -16
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/paligemma/decoder.py +12 -17
- ai_edge_torch/generative/examples/paligemma/decoder2.py +12 -17
- ai_edge_torch/generative/examples/paligemma/paligemma.py +14 -9
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/phi/phi2.py +8 -16
- ai_edge_torch/generative/examples/phi/phi3.py +8 -16
- ai_edge_torch/generative/examples/phi/phi4.py +8 -16
- ai_edge_torch/generative/examples/phi/verify_util.py +1 -3
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/qwen/qwen.py +29 -34
- ai_edge_torch/generative/examples/qwen/qwen3.py +29 -35
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +11 -16
- ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +8 -12
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +2 -2
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/smollm/smollm.py +15 -30
- ai_edge_torch/generative/examples/t5/t5.py +23 -23
- ai_edge_torch/generative/examples/t5/t5_attention.py +2 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +7 -15
- ai_edge_torch/generative/layers/kv_cache.py +13 -1
- ai_edge_torch/generative/layers/model_config.py +0 -14
- ai_edge_torch/generative/test/test_kv_cache.py +14 -24
- ai_edge_torch/generative/test/test_lora.py +4 -21
- ai_edge_torch/generative/test/test_model_conversion.py +8 -4
- ai_edge_torch/generative/test/test_model_conversion_large.py +27 -19
- ai_edge_torch/generative/utilities/converter.py +15 -6
- ai_edge_torch/generative/utilities/model_builder.py +16 -6
- ai_edge_torch/generative/utilities/verifier.py +16 -6
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250604.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250604.dist-info}/RECORD +60 -60
- {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250604.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250604.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250604.dist-info}/top_level.txt +0 -0
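Taken together, the hunks below show one refactor applied across the generative examples: kv_cache_max_len is removed from ModelConfig and from the per-model get_*_model_config()/build_*_model() helpers, builders instead accept a mask_cache_size, and the KV cache length moves to converter.convert_to_tflite(). A minimal before/after sketch of the conversion flow, using SmolLM as the example; the paths, the 1024 value, and the prefill lengths are illustrative only, and other convert_to_tflite arguments (quantize, lora_ranks, export_config) are omitted:

from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import converter

# Before (0.6.0.dev20250602): the KV cache length was baked into the model
# config at build time, e.g. smollm.build_model(path, kv_cache_max_len=1024).
# After (0.6.0.dev20250604): the builder sizes only the attention-mask cache,
# and the KV cache length is a conversion-time argument.
pytorch_model = smollm.build_model(
    "/tmp/smollm-135m",          # hypothetical checkpoint path
    mask_cache_size=1024,
)
converter.convert_to_tflite(
    pytorch_model,
    output_path="/tmp/out",      # hypothetical output directory
    output_name_prefix="smollm",
    prefill_seq_len=[128, 256],
    kv_cache_max_len=1024,       # moved here from the model config
)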
ai_edge_torch/generative/examples/phi/phi4.py

@@ -89,16 +89,8 @@ class Phi4Mini(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Phi-4 model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Phi-4 model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Phi-4 model."""
   attn_config = cfg.AttentionConfig(
       num_heads=24,
       head_dim=128,
@@ -135,7 +127,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       vocab_size=200064,
       num_layers=32,
       max_seq_len=max_seq_len,
-      kv_cache_max_len=kv_cache_max_len,
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
@@ -144,11 +135,11 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_model_config(kv_cache_max_len)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
+  config.max_seq_len = 256
   # Phi-4 has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   return config
@@ -157,13 +148,14 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> torch.nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(**kwargs),
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=Phi4Mini,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
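After this change the Phi-4 config getters take no arguments; a short usage sketch under the new signatures shown above (the checkpoint path is hypothetical):

from ai_edge_torch.generative.examples.phi import phi4

config = phi4.get_model_config()      # was get_model_config(kv_cache_max_len=1024)
fake = phi4.get_fake_model_config()   # the fake config now pins max_seq_len to 256
model = phi4.build_model("/tmp/phi-4-mini", mask_cache_size=1024)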
ai_edge_torch/generative/examples/phi/verify_util.py

@@ -15,7 +15,6 @@
 """Utils for verifying the Phi model."""
 
 import logging
-import os
 import pathlib
 from typing import Callable, Dict
 
@@ -39,7 +38,6 @@ _BUILDER = {
 def verify_phi(
     version: str,
     checkpoint_dir: str,
-    weight_filename: str = "model.safetensors",
     max_new_tokens: int = 30,
     prompts: list[str] | None = None,
     atol: float = 1e-04,
@@ -63,7 +61,7 @@
     )
     reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   else:
-    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+    reauthored_checkpoint = checkpoint_dir
 
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = _BUILDER[version](
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py

@@ -44,13 +44,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py

@@ -44,13 +44,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
ai_edge_torch/generative/examples/qwen/qwen.py

@@ -29,16 +29,8 @@ class Qwen(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Qwen 2.5 3B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a SmolLM model.
-  """
+def get_3b_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 2.5 3B model."""
   attn_config = cfg.AttentionConfig(
       num_heads=16,
       head_dim=128,
@@ -66,16 +58,15 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=36,
       max_seq_len=32768,
       embedding_dim=2048,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
   )
   return config
 
 
-def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+def get_1_5b_model_config() -> cfg.ModelConfig:
   """Returns the model config for a Qwen 2.5 1B model."""
-  config = get_3b_model_config(kv_cache_max_len)
+  config = get_3b_model_config()
   # Qwen has only one block config.
   block_config = config.block_config(0)
   block_config.attn_config.num_heads = 12
@@ -85,9 +76,9 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+def get_0_5b_model_config() -> cfg.ModelConfig:
   """Returns the model config for a Qwen 2.5 0.5B model."""
-  config = get_3b_model_config(kv_cache_max_len)
+  config = get_3b_model_config()
   # Qwen has only one block config.
   block_config = config.block_config(0)
   block_config.attn_config.num_heads = 14
@@ -98,8 +89,8 @@ def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_3b_model_config(**kwargs)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_3b_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   # Qwen has only one block config.
@@ -107,43 +98,47 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_3b_model(
+def _build_model(
     checkpoint_path: str,
+    config: cfg.ModelConfig,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_3b_model_config(**kwargs),
+      config=config,
       tensor_names=TENSOR_NAMES,
       model_class=Qwen,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
+  )
+
+
+def build_3b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    mask_cache_size: int = 0,
+) -> nn.Module:
+  return _build_model(
+      checkpoint_path, get_3b_model_config(), custom_loader, mask_cache_size
   )
 
 
 def build_1_5b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_1_5b_model_config(**kwargs),
-      tensor_names=TENSOR_NAMES,
-      model_class=Qwen,
-      custom_loader=custom_loader,
+  return _build_model(
+      checkpoint_path, get_1_5b_model_config(), custom_loader, mask_cache_size
   )
 
 
 def build_0_5b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_0_5b_model_config(**kwargs),
-      tensor_names=TENSOR_NAMES,
-      model_class=Qwen,
-      custom_loader=custom_loader,
+  return _build_model(
+      checkpoint_path, get_0_5b_model_config(), custom_loader, mask_cache_size
  )
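The three public Qwen builders now delegate to a shared private _build_model() that takes the config explicitly, which keeps the mask_cache_size plumbing in one place. A hypothetical call under the new signature:

from ai_edge_torch.generative.examples.qwen import qwen

# kv_cache_max_len=... is no longer accepted here via **kwargs;
# the checkpoint path is hypothetical.
model = qwen.build_0_5b_model("/tmp/qwen2.5-0.5b", mask_cache_size=1024)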
ai_edge_torch/generative/examples/qwen/qwen3.py

@@ -42,20 +42,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 class Qwen3(model_builder.DecoderOnlyModel):
   """A Qwen3 model built from the Edge Generative API layers."""
-
   pass
 
 
-def get_4b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Qwen 3.0 4B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a SmolLM model.
-  """
+def get_4b_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 3.0 4B model."""
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
@@ -87,16 +78,15 @@ def get_4b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=36,
       max_seq_len=40960,
       embedding_dim=2560,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
   )
   return config
 
 
-def get_1_7b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+def get_1_7b_model_config() -> cfg.ModelConfig:
   """Returns the model config for a Qwen 3.0 1.7B model."""
-  config = get_4b_model_config(kv_cache_max_len)
+  config = get_4b_model_config()
   # Qwen has only one block config.
   block_config = config.block_config(0)
   block_config.attn_config.num_heads = 16
@@ -107,9 +97,9 @@ def get_1_7b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_0_6b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+def get_0_6b_model_config() -> cfg.ModelConfig:
   """Returns the model config for a Qwen 3.0 0.6B model."""
-  config = get_4b_model_config(kv_cache_max_len)
+  config = get_4b_model_config()
   # Qwen has only one block config.
   block_config = config.block_config(0)
   block_config.attn_config.num_heads = 16
@@ -120,8 +110,8 @@ def get_0_6b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_4b_model_config(**kwargs)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_4b_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   # Qwen has only one block config.
@@ -129,43 +119,47 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_4b_model(
+def _build_model(
     checkpoint_path: str,
+    config: cfg.ModelConfig,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_4b_model_config(**kwargs),
+      config=config,
       tensor_names=TENSOR_NAMES,
       model_class=Qwen3,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
+  )
+
+
+def build_4b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    mask_cache_size: int = 0,
+) -> nn.Module:
+  return _build_model(
+      checkpoint_path, get_4b_model_config(), custom_loader, mask_cache_size
   )
 
 
 def build_1_7b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_1_7b_model_config(**kwargs),
-      tensor_names=TENSOR_NAMES,
-      model_class=Qwen3,
-      custom_loader=custom_loader,
+  return _build_model(
+      checkpoint_path, get_1_7b_model_config(), custom_loader, mask_cache_size
   )
 
 
 def build_0_6b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_0_6b_model_config(**kwargs),
-      tensor_names=TENSOR_NAMES,
-      model_class=Qwen3,
-      custom_loader=custom_loader,
+  return _build_model(
+      checkpoint_path, get_0_6b_model_config(), custom_loader, mask_cache_size
   )
ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py

@@ -42,7 +42,7 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
       image_size=(_IMAGE_HEIGHT.value, _IMAGE_WIDTH.value),
   )
 
@@ -55,6 +55,7 @@ def main(_):
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       pixel_values_size=(
           pytorch_model.image_encoder.get_pixel_values_size(grid_thw)
       ),
ai_edge_torch/generative/examples/qwen_vl/decoder.py

@@ -60,8 +60,9 @@ class Decoder(model_builder.DecoderOnlyModel):
     rope = self.config.build_rope(input_pos, n_elem, attn_config.rotary_base)
 
     if mask is None:
+      assert kv_cache is not None, "KV cache must be provided."
       mask = self.mask_cache.index_select(2, input_pos)
-      mask = mask[:, :, :, :self.config.kv_cache_max_len]
+      mask = mask[:, :, :, :kv_cache.get_max_seq_len()]
 
     return self._forward_with_embeds(
         input_embeds,
@@ -73,16 +74,8 @@ class Decoder(model_builder.DecoderOnlyModel):
   )
 
 
-def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Qwen 2.5 VL 3B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Qwen 2.5 VL 3B model.
-  """
+def get_decoder_config() -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 2.5 VL 3B model."""
   attn_config = cfg.AttentionConfig(
       num_heads=16,
       head_dim=128,
@@ -110,15 +103,14 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=36,
       max_seq_len=32768,
       embedding_dim=2048,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
   )
   return config
 
 
-def get_fake_decoder_config(**kwargs) -> cfg.ModelConfig:
-  config = get_decoder_config(**kwargs)
+def get_fake_decoder_config() -> cfg.ModelConfig:
+  config = get_decoder_config()
   config.vocab_size = 128
   config.num_layers = 2
   # Decoder has only one block config.
@@ -126,10 +118,13 @@ def get_fake_decoder_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_decoder(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+def build_decoder(
+    checkpoint_path: str, mask_cache_size: int = 0
+) -> torch.nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_decoder_config(**kwargs),
+      config=get_decoder_config(),
       tensor_names=TENSOR_NAMES,
       model_class=Decoder,
+      mask_cache_size=mask_cache_size,
   )
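The decoder no longer reads the mask width from its config; it trims the cached mask to the runtime KV cache length via the new KVCache.get_max_seq_len() accessor (added in kv_cache.py, +13 -1 above). A self-contained sketch of that slicing contract; the helper function below is illustrative, not the library API:

import torch

def slice_mask(
    mask_cache: torch.Tensor,  # [1, 1, mask_cache_size, mask_cache_size]
    input_pos: torch.Tensor,   # positions of the tokens being processed
    kv_max_seq_len: int,       # what kv_cache.get_max_seq_len() would return
) -> torch.Tensor:
  """Selects mask rows for the current positions, trimmed to the KV length."""
  mask = mask_cache.index_select(2, input_pos)
  return mask[:, :, :, :kv_max_seq_len]

# Mask cache built for mask_cache_size=1024; KV cache exported with max len 512.
mask_cache = torch.full((1, 1, 1024, 1024), float("-inf")).triu(1)
mask = slice_mask(mask_cache, torch.tensor([0, 1, 2]), 512)
print(mask.shape)  # torch.Size([1, 1, 3, 512])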
ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py

@@ -41,13 +41,13 @@ class QwenVLConfig:
 class QwenVL(nn.Module):
   """Qwen VL model from the Edge Generative API."""
 
-  def __init__(self, config: QwenVLConfig):
+  def __init__(self, config: QwenVLConfig, mask_cache_size: int = 0):
     super().__init__()
 
     self.image_encoder = image_encoder.QwenVLImageEncoder(
         config.image_encoder_config
     )
-    self.decoder = decoder.Decoder(config.decoder_config)
+    self.decoder = decoder.Decoder(config.decoder_config, mask_cache_size)
     # The amount of adjustment in input_pos to calculate RoPE properly in
     # forward() calls after image is handled.
     self.rope_pos_adjust = 0
@@ -179,26 +179,21 @@ class QwenVL(nn.Module):
 
 
 def get_model_config(
-    kv_cache_max_len: int = 1024,
     image_size: Tuple[int, int] = (34 * 14, 46 * 14),
 ) -> QwenVLConfig:
-  """Returns the model config for a PaliGemma 3B-224 model.
-
-  Returns:
-    The model config for a PaliGemma 3B model.
-  """
+  """Returns the model config for a PaliGemma 3B-224 model."""
   return QwenVLConfig(
       image_encoder_config=image_encoder.get_image_encoder_config(image_size),
-      decoder_config=decoder.get_decoder_config(kv_cache_max_len),
+      decoder_config=decoder.get_decoder_config(),
       image_token_id=151655,
       mrope_section=[16, 24, 24],
   )
 
 
-def get_fake_model_config(**kwargs) -> QwenVLConfig:
+def get_fake_model_config() -> QwenVLConfig:
   return QwenVLConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
-      decoder_config=decoder.get_fake_decoder_config(**kwargs),
+      decoder_config=decoder.get_fake_decoder_config(),
       image_token_id=127,
       mrope_section=[16, 24, 24],
   )
@@ -207,10 +202,11 @@ def get_fake_model_config(**kwargs) -> QwenVLConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    mask_cache_size: int = 0,
     **kwargs
 ) -> QwenVL:
   config = get_model_config(**kwargs)
-  model = QwenVL(config)
+  model = QwenVL(config, mask_cache_size)
   image_encoder.load_image_encoder(
       checkpoint_path, model.image_encoder, custom_loader
   )
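QwenVL now threads mask_cache_size through to its decoder instead of relying on a config field; a minimal construction sketch under the new signatures (the 512 value is illustrative):

from ai_edge_torch.generative.examples.qwen_vl import qwen_vl

config = qwen_vl.get_fake_model_config()   # no longer takes **kwargs
model = qwen_vl.QwenVL(config, mask_cache_size=512)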
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py

@@ -16,7 +16,6 @@
 """Example of converting SmolLM model to multi-signature tflite model."""
 
 from absl import app
-from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config as export_cfg
@@ -38,7 +37,7 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
 
   export_config = export_cfg.get_from_flags()
@@ -49,6 +48,7 @@ def main(_):
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config,
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py

@@ -37,7 +37,7 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
 
   export_config = export_cfg.get_from_flags()
@@ -48,6 +48,7 @@ def main(_):
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config,
ai_edge_torch/generative/examples/smollm/smollm.py

@@ -29,16 +29,8 @@ class SmolLM(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a SmolLM 135M model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a SmolLM model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a SmolLM 135M model."""
   attn_config = cfg.AttentionConfig(
       num_heads=9,
       head_dim=64,
@@ -63,15 +55,14 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=30,
       max_seq_len=2048,
       embedding_dim=576,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
   )
   return config
 
 
-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_model_config(**kwargs)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   # SmolLM has only one block config.
@@ -82,14 +73,15 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(**kwargs),
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
 
 
@@ -98,23 +90,15 @@ class SmolLM2(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config_v2(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a SmolLM2 135M model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a SmolLM2 model.
-  """
-  config = get_model_config(kv_cache_max_len)
+def get_model_config_v2() -> cfg.ModelConfig:
+  """Returns the model config for a SmolLM2 135M model."""
+  config = get_model_config()
   config.block_config(0).attn_config.rotary_base = 100000
   return config
 
 
-def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
-  config = get_model_config_v2(**kwargs)
+def get_fake_model_config_v2() -> cfg.ModelConfig:
+  config = get_model_config_v2()
   config.vocab_size = 128
   config.num_layers = 2
   # SmolLM2 has only one block config.
@@ -125,12 +109,13 @@ def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
 def build_model_v2(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config_v2(**kwargs),
+      config=get_model_config_v2(),
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM2,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )