ai-edge-torch-nightly 0.6.0.dev20250601__py3-none-any.whl → 0.6.0.dev20250603__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +7 -15
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +3 -1
- ai_edge_torch/generative/examples/deepseek/deepseek.py +7 -15
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/gemma/gemma1.py +8 -16
- ai_edge_torch/generative/examples/gemma/gemma2.py +24 -24
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/gemma3/decoder.py +34 -35
- ai_edge_torch/generative/examples/gemma3/gemma3.py +10 -8
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/hammer/hammer.py +23 -16
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -2
- ai_edge_torch/generative/examples/llama/llama.py +13 -26
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/openelm/openelm.py +8 -16
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/paligemma/decoder.py +12 -17
- ai_edge_torch/generative/examples/paligemma/decoder2.py +12 -17
- ai_edge_torch/generative/examples/paligemma/paligemma.py +14 -9
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/phi/phi2.py +8 -16
- ai_edge_torch/generative/examples/phi/phi3.py +8 -16
- ai_edge_torch/generative/examples/phi/phi4.py +8 -16
- ai_edge_torch/generative/examples/phi/verify_util.py +1 -3
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/qwen/qwen.py +29 -34
- ai_edge_torch/generative/examples/qwen/qwen3.py +29 -35
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +11 -16
- ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +8 -12
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +2 -2
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/smollm/smollm.py +15 -30
- ai_edge_torch/generative/examples/t5/t5.py +23 -23
- ai_edge_torch/generative/examples/t5/t5_attention.py +2 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +7 -15
- ai_edge_torch/generative/layers/kv_cache.py +13 -1
- ai_edge_torch/generative/layers/model_config.py +0 -14
- ai_edge_torch/generative/test/test_kv_cache.py +14 -24
- ai_edge_torch/generative/test/test_lora.py +4 -21
- ai_edge_torch/generative/test/test_model_conversion.py +8 -4
- ai_edge_torch/generative/test/test_model_conversion_large.py +27 -19
- ai_edge_torch/generative/utilities/converter.py +15 -6
- ai_edge_torch/generative/utilities/model_builder.py +16 -6
- ai_edge_torch/generative/utilities/verifier.py +16 -6
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/RECORD +57 -57
- {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/llama/convert_to_tflite.py

@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 from ai_edge_torch.generative.utilities import loader
 
-
 flags = converter.define_conversion_flags('llama')
 
 _MODEL_SIZE = flags.DEFINE_enum(
@@ -44,13 +43,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),

ai_edge_torch/generative/examples/llama/llama.py

@@ -93,22 +93,12 @@ class Llama(model_builder.DecoderOnlyModel):
 
   Llama 3.2 shares the same architecture as TinyLlama except ROPE calculation.
   """
+  pass
 
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__(config)
-    attn_config = self.config.block_config(0).attn_config
 
+def get_1b_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Llama 3.2-1B model."""
 
-def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Llama 3.2-1B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a SmolLM model.
-  """
   attn_config = cfg.AttentionConfig(
       num_heads=32,
       head_dim=64,
@@ -147,7 +137,6 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=16,
       max_seq_len=max_seq_len,
       embedding_dim=2048,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       build_rope=build_rope,
@@ -155,9 +144,9 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_3b_model_config(
+def get_3b_model_config() -> cfg.ModelConfig:
   """Returns the model config for a Llama 3.2-3B model."""
-  config = get_1b_model_config(
+  config = get_1b_model_config()
   # Llama 3.2 has only one block config.
   attn_config = config.block_config(0).attn_config
   attn_config.num_heads = 24
@@ -167,8 +156,8 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(
-  config = get_1b_model_config(
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_1b_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   # SmolLM has only one block config.
@@ -180,6 +169,7 @@ def _build_model(
     checkpoint_path: str,
     config: cfg.ModelConfig,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    mask_cache_size: int = 0,
 ) -> torch.nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
@@ -187,28 +177,25 @@ def _build_model(
       tensor_names=TENSOR_NAMES,
       model_class=Llama,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
 
 
 def build_1b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-
+    mask_cache_size: int = 0,
 ) -> torch.nn.Module:
   return _build_model(
-      checkpoint_path,
-      get_1b_model_config(**kwargs),
-      custom_loader=custom_loader,
+      checkpoint_path, get_1b_model_config(), custom_loader, mask_cache_size
   )
 
 
 def build_3b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-
+    mask_cache_size: int = 0,
 ) -> torch.nn.Module:
   return _build_model(
-      checkpoint_path,
-      get_3b_model_config(**kwargs),
-      custom_loader=custom_loader,
+      checkpoint_path, get_3b_model_config(), custom_loader, mask_cache_size
   )
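
Taken together, the two Llama files above show the new calling convention: the mask cache size is passed to the model builder, while the KV cache length moves out of the model config and into the converter call. A minimal sketch under that assumption follows; the checkpoint path, output directory, and sizes are placeholders, and optional arguments such as quantize and lora_ranks are left at their defaults.

from ai_edge_torch.generative.examples.llama import llama
from ai_edge_torch.generative.utilities import converter

# Placeholder path, for illustration only.
checkpoint_path = '/tmp/llama_3.2_1b'

# mask_cache_size is now a builder argument rather than part of ModelConfig.
pytorch_model = llama.build_1b_model(checkpoint_path, mask_cache_size=1280)

# kv_cache_max_len is now passed to the converter instead of the model config.
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/llama_tflite',
    output_name_prefix='llama_3.2_1b',
    prefill_seq_len=1024,
    kv_cache_max_len=1280,
)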

ai_edge_torch/generative/examples/openelm/convert_to_tflite.py

@@ -31,13 +31,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),

ai_edge_torch/generative/examples/openelm/openelm.py

@@ -42,16 +42,8 @@ class OpenELM(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config(
-  """Returns the model config for an OpenELM model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for an OpenELM model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for an OpenELM model."""
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
   )
@@ -98,18 +90,17 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=num_layers,
       max_seq_len=2048,
       embedding_dim=3072,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
   )
   return config
 
 
-def get_fake_model_config(
-  config = get_model_config(
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len =
+  config.max_seq_len = 256
   config.embedding_dim = 128
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
@@ -122,12 +113,13 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=OpenELM,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )

ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py

@@ -40,7 +40,7 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
 
   config = pytorch_model.image_encoder.config.image_embedding
@@ -49,6 +49,7 @@ def main(_):
       output_path=flags.FLAGS.output_path,
       output_name_prefix=f'{flags.FLAGS.output_name_prefix}_{_VERSION.value}',
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       pixel_values_size=torch.Size(
           [1, config.channels, config.image_size, config.image_size]
       ),

ai_edge_torch/generative/examples/paligemma/decoder.py

@@ -73,8 +73,9 @@ class Decoder(model_builder.DecoderOnlyModel):
     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
     if mask is None:
+      assert kv_cache is not None, "KV cache must be provided."
       embeds_len = input_embeds.shape[1]
-      mask = torch.zeros(embeds_len,
+      mask = torch.zeros(embeds_len, kv_cache.get_max_seq_len())
       mask[:, embeds_len:] = attn_config.causal_mask_value
 
     return self._forward_with_embeds(
@@ -87,16 +88,8 @@ class Decoder(model_builder.DecoderOnlyModel):
     )
 
 
-def get_decoder_config(
-  """Returns the model config for the decoder of a PaliGemma 3B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for the decoder of a PaliGemma 3B model.
-  """
+def get_decoder_config() -> cfg.ModelConfig:
+  """Returns the model config for the decoder of a PaliGemma 3B model."""
   attn_config = cfg.AttentionConfig(
       num_heads=8,
       head_dim=256,
@@ -125,7 +118,6 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       max_seq_len=8192,
       embedding_dim=embedding_dim,
       embedding_scale=embedding_dim**0.5,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
@@ -133,22 +125,25 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_decoder_config(
-  config = get_decoder_config(
+def get_fake_decoder_config() -> cfg.ModelConfig:
+  config = get_decoder_config()
   # PaliGemma decoder has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len =
+  config.max_seq_len = 256
   config.embedding_dim = 128
   config.embedding_scale = 128**0.5
   return config
 
 
-def build_decoder(
+def build_decoder(
+    checkpoint_path: str, mask_cache_size: int = 0
+) -> torch.nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_decoder_config(
+      config=get_decoder_config(),
       tensor_names=TENSOR_NAMES,
       model_class=Decoder,
+      mask_cache_size=mask_cache_size,
   )

ai_edge_torch/generative/examples/paligemma/decoder2.py

@@ -73,8 +73,9 @@ class Decoder2(gemma2.Gemma2):
 
     if mask is None:
       # By default, don't mask image embeds with a diagonal causal mask.
+      assert kv_cache is not None, "KV cache must be provided."
       embeds_len = input_embeds.shape[1]
-      mask = torch.zeros(embeds_len,
+      mask = torch.zeros(embeds_len, kv_cache.get_max_seq_len())
       mask[:, embeds_len:] = attn_config.causal_mask_value
 
     return self._forward_with_embeds(
@@ -82,16 +83,8 @@ class Decoder2(gemma2.Gemma2):
     )
 
 
-def get_decoder2_config(
-  """Returns the model config for the decoder of a PaliGemma 3B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for the decoder of a PaliGemma 3B model.
-  """
+def get_decoder2_config() -> cfg.ModelConfig:
+  """Returns the model config for the decoder of a PaliGemma 3B model."""
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
@@ -133,7 +126,6 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       max_seq_len=8192,
       embedding_dim=embedding_dim,
       embedding_scale=embedding_dim**0.5,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
@@ -142,22 +134,25 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_decoder2_config(
-  config = get_decoder2_config(
+def get_fake_decoder2_config() -> cfg.ModelConfig:
+  config = get_decoder2_config()
   # PaliGemma2 decoder has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len =
+  config.max_seq_len = 256
   config.embedding_dim = 128
   config.embedding_scale = 128**0.5
   return config
 
 
-def build_decoder2(
+def build_decoder2(
+    checkpoint_path: str, mask_cache_size: int = 0
+) -> torch.nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_decoder2_config(
+      config=get_decoder2_config(),
       tensor_names=TENSOR_NAMES,
       model_class=Decoder2,
+      mask_cache_size=mask_cache_size,
   )

ai_edge_torch/generative/examples/paligemma/paligemma.py

@@ -45,7 +45,12 @@ class PaliGemmaConfig:
 class PaliGemma(nn.Module):
   """PaliGemma model from the Edge Generative API."""
 
-  def __init__(
+  def __init__(
+      self,
+      config: PaliGemmaConfig,
+      decoder_class: nn.Module,
+      mask_cache_size: int = 0,
+  ):
     super().__init__()
 
     self.image_encoder = image_encoder.SiglipVisionEncoder(
@@ -56,7 +61,7 @@ class PaliGemma(nn.Module):
         config.decoder_config.embedding_dim,
         bias=config.image_projection_use_bias,
     )
-    self.decoder = decoder_class(config.decoder_config)
+    self.decoder = decoder_class(config.decoder_config, mask_cache_size)
     image_embedding_config = config.image_encoder_config.image_embedding
     self.num_patches = (
         image_embedding_config.image_size // image_embedding_config.patch_size
@@ -116,7 +121,7 @@ class PaliGemma(nn.Module):
     )
 
 
-def get_model_config(get_decoder_config
+def get_model_config(get_decoder_config) -> PaliGemmaConfig:
   """Returns the model config for a PaliGemma 3B-224 model.
 
   Returns:
@@ -124,16 +129,16 @@ def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   """
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_image_encoder_config(),
-      decoder_config=get_decoder_config(
+      decoder_config=get_decoder_config(),
       image_token_id=257152,
       image_projection_use_bias=True,
   )
 
 
-def get_fake_model_config(get_decoder_config
+def get_fake_model_config(get_decoder_config) -> PaliGemmaConfig:
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
-      decoder_config=get_decoder_config(
+      decoder_config=get_decoder_config(),
       image_token_id=127,
       image_projection_use_bias=True,
   )
@@ -143,7 +148,7 @@ def build_model(
     checkpoint_path: str,
     version: int = 2,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-
+    mask_cache_size: int = 0,
 ) -> PaliGemma:
   if version == 1:
     decoder_class = decoder.Decoder
@@ -154,8 +159,8 @@ def build_model(
     decoder_tensor_names = decoder2.TENSOR_NAMES
     get_decoder_config = decoder2.get_decoder2_config
 
-  config = get_model_config(get_decoder_config
-  model = PaliGemma(config, decoder_class)
+  config = get_model_config(get_decoder_config)
+  model = PaliGemma(config, decoder_class, mask_cache_size)
   # Load the parameters of image encoder.
   loader = loading_utils.ModelLoader(
       checkpoint_path, image_encoder.TENSOR_NAMES, custom_loader
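
With the constructor change above, PaliGemma forwards mask_cache_size to whichever decoder class it wraps, so the mask cache is sized once at build time. A hedged sketch of the updated entry point; the checkpoint path and cache size are placeholders:

from ai_edge_torch.generative.examples.paligemma import paligemma

# Placeholder arguments, for illustration only.
model = paligemma.build_model(
    '/tmp/paligemma2_checkpoint',
    version=2,             # selects decoder2.Decoder2, per the diff above
    mask_cache_size=1280,
)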

ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py

@@ -31,13 +31,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),

ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py

@@ -31,13 +31,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),

ai_edge_torch/generative/examples/phi/convert_to_tflite.py

@@ -32,13 +32,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),

ai_edge_torch/generative/examples/phi/phi2.py

@@ -41,16 +41,8 @@ class Phi2(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config(
-  """Returns the model config for a Phi-2 model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Phi-2 model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Phi-2 model."""
   attn_config = cfg.AttentionConfig(
       num_heads=32,
       head_dim=80,
@@ -77,7 +69,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       vocab_size=51200,
       num_layers=32,
       max_seq_len=2048,
-      kv_cache_max_len=kv_cache_max_len,
       embedding_dim=2560,
       block_configs=block_config,
       final_norm_config=norm_config,
@@ -87,11 +78,11 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(
-  config = get_model_config(
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len =
+  config.max_seq_len = 256
   # Phi-2 has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   return config
@@ -100,12 +91,13 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=Phi2,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
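
Since get_fake_model_config() no longer takes kv_cache_max_len, a throwaway Phi-2 for smoke tests is built from the fake config plus a mask cache size passed to the model class. This sketch assumes the decoder-only model classes accept (config, mask_cache_size), as implied by the decoder_class(config.decoder_config, mask_cache_size) call in the PaliGemma hunk and the model_builder change in this release:

from ai_edge_torch.generative.examples.phi import phi2

config = phi2.get_fake_model_config()           # no kv_cache_max_len argument anymore
model = phi2.Phi2(config, mask_cache_size=128)  # assumed keyword, mirroring the builder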

ai_edge_torch/generative/examples/phi/phi3.py

@@ -139,16 +139,8 @@ class Phi3_5Mini(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config(
-  """Returns the model config for a Phi-3.5 model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Phi-3.5 model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Phi-3.5 model."""
   attn_config = cfg.AttentionConfig(
       num_heads=32,
       head_dim=96,
@@ -185,7 +177,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       vocab_size=32064,
       num_layers=32,
       max_seq_len=max_seq_len,
-      kv_cache_max_len=kv_cache_max_len,
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
@@ -195,11 +186,11 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(
-  config = get_model_config(
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len =
+  config.max_seq_len = 256
   # Phi-3.5 has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   return config
@@ -208,13 +199,14 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-
+    mask_cache_size: int = 0,
 ) -> torch.nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=Phi3_5Mini,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )

ai_edge_torch/generative/examples/phi/phi4.py

@@ -89,16 +89,8 @@ class Phi4Mini(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config(
-  """Returns the model config for a Phi-4 model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Phi-4 model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Phi-4 model."""
   attn_config = cfg.AttentionConfig(
       num_heads=24,
       head_dim=128,
@@ -135,7 +127,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       vocab_size=200064,
       num_layers=32,
       max_seq_len=max_seq_len,
-      kv_cache_max_len=kv_cache_max_len,
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
@@ -144,11 +135,11 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(
-  config = get_model_config(
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len =
+  config.max_seq_len = 256
   # Phi-4 has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   return config
@@ -157,13 +148,14 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-
+    mask_cache_size: int = 0,
 ) -> torch.nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=Phi4Mini,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )

ai_edge_torch/generative/examples/phi/verify_util.py

@@ -15,7 +15,6 @@
 """Utils for verifying the Phi model."""
 
 import logging
-import os
 import pathlib
 from typing import Callable, Dict
 
@@ -39,7 +38,6 @@ _BUILDER = {
 def verify_phi(
     version: str,
     checkpoint_dir: str,
-    weight_filename: str = "model.safetensors",
    max_new_tokens: int = 30,
     prompts: list[str] | None = None,
     atol: float = 1e-04,
@@ -63,7 +61,7 @@ def verify_phi(
     )
     reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   else:
-    reauthored_checkpoint =
+    reauthored_checkpoint = checkpoint_dir
 
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = _BUILDER[version](