ai-edge-torch-nightly 0.6.0.dev20250601__py3-none-any.whl → 0.6.0.dev20250603__py3-none-any.whl
This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +7 -15
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +3 -1
- ai_edge_torch/generative/examples/deepseek/deepseek.py +7 -15
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/gemma/gemma1.py +8 -16
- ai_edge_torch/generative/examples/gemma/gemma2.py +24 -24
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/gemma3/decoder.py +34 -35
- ai_edge_torch/generative/examples/gemma3/gemma3.py +10 -8
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/hammer/hammer.py +23 -16
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -2
- ai_edge_torch/generative/examples/llama/llama.py +13 -26
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/openelm/openelm.py +8 -16
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/paligemma/decoder.py +12 -17
- ai_edge_torch/generative/examples/paligemma/decoder2.py +12 -17
- ai_edge_torch/generative/examples/paligemma/paligemma.py +14 -9
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/phi/phi2.py +8 -16
- ai_edge_torch/generative/examples/phi/phi3.py +8 -16
- ai_edge_torch/generative/examples/phi/phi4.py +8 -16
- ai_edge_torch/generative/examples/phi/verify_util.py +1 -3
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/qwen/qwen.py +29 -34
- ai_edge_torch/generative/examples/qwen/qwen3.py +29 -35
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +11 -16
- ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +8 -12
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +2 -2
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/smollm/smollm.py +15 -30
- ai_edge_torch/generative/examples/t5/t5.py +23 -23
- ai_edge_torch/generative/examples/t5/t5_attention.py +2 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +7 -15
- ai_edge_torch/generative/layers/kv_cache.py +13 -1
- ai_edge_torch/generative/layers/model_config.py +0 -14
- ai_edge_torch/generative/test/test_kv_cache.py +14 -24
- ai_edge_torch/generative/test/test_lora.py +4 -21
- ai_edge_torch/generative/test/test_model_conversion.py +8 -4
- ai_edge_torch/generative/test/test_model_conversion_large.py +27 -19
- ai_edge_torch/generative/utilities/converter.py +15 -6
- ai_edge_torch/generative/utilities/model_builder.py +16 -6
- ai_edge_torch/generative/utilities/verifier.py +16 -6
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/RECORD +57 -57
- {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py

```diff
@@ -29,16 +29,8 @@ class AmdLlama(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for an AMD-Llama-135m model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for an AMD-Llama-135m model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for an AMD-Llama-135m model."""
   attn_config = cfg.AttentionConfig(
       num_heads=12,
       head_dim=64,
@@ -63,7 +55,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=12,
       max_seq_len=2048,
       embedding_dim=768,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
@@ -71,8 +62,8 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_model_config(
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   config.block_config(0).ff_config.intermediate_size = 64
@@ -82,12 +73,13 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
-
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=AmdLlama,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
```

ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py

```diff
@@ -31,13 +31,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
```

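The two AMD-Llama files above show the pattern repeated across every example model in this release: `get_model_config()` no longer takes `kv_cache_max_len`, the model builders take a `mask_cache_size` argument instead, and the KV-cache length is now passed to `converter.convert_to_tflite()` at export time. A minimal sketch of what a caller might look like after the change; the paths, the length value, and the choice to reuse the same value for `mask_cache_size` are assumptions (the scripts derive it via `converter.get_mask_cache_size_from_flags()`), and the quantization/export-config arguments the scripts pass from flags are omitted:

```python
# Sketch only: mirrors the updated example scripts; all values are placeholders.
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
from ai_edge_torch.generative.utilities import converter

KV_CACHE_MAX_LEN = 1024  # previously a get_model_config(kv_cache_max_len=...) argument

pytorch_model = amd_llama_135m.build_model(
    "/path/to/amd-llama-135m",          # placeholder checkpoint path
    mask_cache_size=KV_CACHE_MAX_LEN,   # new builder argument in this release
)
converter.convert_to_tflite(
    pytorch_model,
    output_path="/tmp",                 # placeholder output directory
    output_name_prefix="amd_llama_135m",
    prefill_seq_len=[256],              # assumed prefill length
    kv_cache_max_len=KV_CACHE_MAX_LEN,  # new convert_to_tflite argument
)
```
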
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py

```diff
@@ -23,6 +23,7 @@ from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('deepseek')
 
+
 def main(_):
   checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = deepseek.build_model(
@@ -30,13 +31,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
```

ai_edge_torch/generative/examples/deepseek/deepseek.py

```diff
@@ -29,16 +29,8 @@ class DeepSeekDistillQwen(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Qwen 2.5 3B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a SmolLM model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 2.5 3B model."""
   attn_config = cfg.AttentionConfig(
       num_heads=12,
       head_dim=128,
@@ -66,7 +58,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=28,
       max_seq_len=4096,
       embedding_dim=1536,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
@@ -74,8 +65,8 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_model_config(
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   # DeepSeek-R1-Distill-Qwen has only one block config.
@@ -86,12 +77,13 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=DeepSeekDistillQwen,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
```

ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py

```diff
@@ -31,13 +31,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
```

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py

```diff
@@ -33,13 +33,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
```

ai_edge_torch/generative/examples/gemma/gemma1.py

```diff
@@ -42,16 +42,8 @@ class Gemma1(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Gemma 2B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Gemma 2B model.
-  """
+def get_model_config_2b() -> cfg.ModelConfig:
+  """Returns the model config for a Gemma 2B model."""
   attn_config = cfg.AttentionConfig(
       num_heads=8,
       head_dim=256,
@@ -80,7 +72,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       max_seq_len=8192,
       embedding_dim=embedding_dim,
       embedding_scale=embedding_dim**0.5,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
@@ -88,25 +79,26 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(
-  config = get_model_config_2b(
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config_2b()
   # Gemma has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len =
+  config.max_seq_len = 256
   return config
 
 
 def build_2b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config_2b(
+      config=get_model_config_2b(),
       tensor_names=TENSOR_NAMES,
       model_class=Gemma1,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
```

ai_edge_torch/generative/examples/gemma/gemma2.py

```diff
@@ -104,7 +104,7 @@ class Gemma2Block(attention.TransformerBlock):
 class Gemma2(nn.Module):
   """A Gemma2 model built from the Edge Generative API layers."""
 
-  def __init__(self, config: cfg.ModelConfig):
+  def __init__(self, config: cfg.ModelConfig, mask_cache_size: int = 0):
     super().__init__()
 
     # Construct model layers.
@@ -126,17 +126,24 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    self.
-
-
+    self.config = config
+    self.build_mask_cache(mask_cache_size)
+
+  def build_mask_cache(self, mask_cache_size: int):
+    assert (
+        mask_cache_size <= self.config.max_seq_len
+    ), "Mask cache size must be less than or equal to the max seq length."
+    if mask_cache_size <= 0:
+      self.mask_cache = None
+      self.sliding_window_mask_cache = None
+      return
+    self.mask_cache = attn_utils.build_causal_mask_cache(mask_cache_size)
     # Gemma2 has same hyper parameters for each layer except for attention
     # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
-        size=
-        window_size=attn_config.sliding_window_size,
+        size=mask_cache_size,
+        window_size=self.config.block_config(0).attn_config.sliding_window_size,
     )
-    self.config = config
 
   def get_attention_mask(
       self, attn_type: cfg.AttentionType, input_pos: torch.Tensor
@@ -167,6 +174,7 @@ class Gemma2(nn.Module):
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
     rope = rotary_pos_emb.build_rope(input_pos, n_elem, attn_config.rotary_base)
     if mask is None:
+      assert self.mask_cache is not None, "Mask cache must be built."
       mask = [
           self.get_attention_mask(
               self.config.block_config(i).attn_config.attn_type, input_pos
@@ -222,16 +230,8 @@ class Gemma2(nn.Module):
     return {"logits": res, "kv_cache": updated_kv_cache}
 
 
-def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Gemma2 2B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Gemma 2B model.
-  """
+def get_model_config_2b() -> cfg.ModelConfig:
+  """Returns the model config for a Gemma2 2B model."""
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
@@ -277,7 +277,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       max_seq_len=8192,
       embedding_dim=embedding_dim,
       embedding_scale=embedding_dim**0.5,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
@@ -286,11 +285,11 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_model_config_2b(
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config_2b()
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len =
+  config.max_seq_len = 256
   config.embedding_dim = 128
   config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
@@ -305,16 +304,17 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 def build_2b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   for tensor_names in TENSOR_NAMES_DICT.values():
     try:
       return model_builder.build_decoder_only_model(
           checkpoint_path=checkpoint_path,
-          config=get_model_config_2b(
+          config=get_model_config_2b(),
           tensor_names=tensor_names,
           model_class=Gemma2,
           custom_loader=custom_loader,
+          mask_cache_size=mask_cache_size,
       )
     except KeyError as _:
       continue
```

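In Gemma2 (and in the Gemma3 decoder below) the mask caches are no longer built eagerly from `config.kv_cache_max`; a `build_mask_cache(mask_cache_size)` method builds them on demand and skips them entirely when `mask_cache_size <= 0`, in which case a mask has to be passed into `forward`. For readers unfamiliar with these buffers, here is a rough standalone illustration of what a causal and a sliding-window mask cache of a given size look like; this is only a sketch of the general idea, not the actual `attn_utils` implementation:

```python
import torch


def build_causal_mask_cache(size: int) -> torch.Tensor:
  # (1, 1, size, size): 0 on and below the diagonal, -inf for future positions.
  mask = torch.triu(torch.full((size, size), float("-inf")), diagonal=1)
  return mask.unsqueeze(0).unsqueeze(0)


def build_sliding_window_mask_cache(size: int, window_size: int) -> torch.Tensor:
  # Causal mask that additionally hides keys more than window_size positions back
  # (the exact window boundary used by the library may differ by one).
  mask = build_causal_mask_cache(size)
  too_old = torch.tril(
      torch.ones((size, size), dtype=torch.bool), diagonal=-window_size
  )
  return mask.masked_fill(too_old, float("-inf"))


# Gemma2.get_attention_mask() then selects rows from one of these caches
# (global vs. sliding-window) depending on each block's attention type.
causal = build_causal_mask_cache(8)
sliding = build_sliding_window_mask_cache(8, window_size=4)
print(causal.shape, sliding.shape)  # torch.Size([1, 1, 8, 8]) for both
```
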
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py

```diff
@@ -40,7 +40,7 @@ def main(_):
         custom_loader=loader.maybe_get_custom_loader(
             checkpoint_path, flags.FLAGS.custom_checkpoint_loader
         ),
-
+        mask_cache_size=converter.get_mask_cache_size_from_flags(),
     )
   else:
     raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
@@ -50,6 +50,7 @@ def main(_):
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
```

ai_edge_torch/generative/examples/gemma3/decoder.py

```diff
@@ -74,6 +74,7 @@ TENSOR_NAMES_DICT = {
 
 
 class DecoderBlock(attention.TransformerBlock):
+  """A Gemma3 decoder block built from the Edge Generative API layers."""
 
   def forward(
       self,
@@ -111,7 +112,7 @@ class DecoderBlock(attention.TransformerBlock):
 class Decoder(nn.Module):
   """A Gemma3 decoder model built from the Edge Generative API layers."""
 
-  def __init__(self, config: cfg.ModelConfig):
+  def __init__(self, config: cfg.ModelConfig, mask_cache_size: int = 0):
     super().__init__()
 
     # Construct model layers.
@@ -130,10 +131,17 @@ class Decoder(nn.Module):
     self.final_norm = builder.build_norm(
         config.embedding_dim, config.final_norm_config
     )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
     self.config = config
+    self.build_mask_cache(mask_cache_size)
+
+  def build_mask_cache(self, mask_cache_size: int):
+    assert (
+        mask_cache_size <= self.config.max_seq_len
+    ), "Mask cache size must be less than or equal to the max seq length."
+    if mask_cache_size <= 0:
+      self.mask_cache = None
+    else:
+      self.mask_cache = attn_utils.build_causal_mask_cache(mask_cache_size)
 
   def get_local_global_attention_mask(
       self,
@@ -205,9 +213,8 @@ class Decoder(nn.Module):
     mask = torch.where(mask, 0, self.config.causal_mask_value)
     return mask
 
-  def build_pixel_mask(self, image_indices: torch.Tensor):
+  def build_pixel_mask(self, image_indices: torch.Tensor, max_seq_len: int):
     pixel_mask = image_indices >= 0
-    max_seq_len = self.config.kv_cache_max
     if pixel_mask.size(1) < max_seq_len:
       pixel_mask = torch.cat(
           [
@@ -234,14 +241,12 @@ class Decoder(nn.Module):
       image_indices: Optional[torch.Tensor] = None,
       export_config: Optional[export_cfg.ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    pixel_mask = None
     if input_embeds is None:
       # token embeddings of shape (b, t, n_embd)
       input_embeds = self.tok_embedding(tokens)
       if self.config.embedding_scale is not None:
         input_embeds = input_embeds * self.config.embedding_scale
-
-      pixel_mask = self.build_pixel_mask(image_indices)
+
     # RoPE parameters are the same for all blocks. Use the first layer.
     attn_config = self.config.block_config(0).attn_config
     # Different rotary base for global and local attention
@@ -254,9 +259,19 @@ class Decoder(nn.Module):
         )
         for i in range(self.config.num_layers)
     ]
+
     if mask is None:
+      assert self.mask_cache is not None, "Mask cache must be built."
+      assert kv_cache is not None, "KV cache must be provided."
+      kv_cache_max_len = kv_cache.get_max_seq_len()
       mask = self.mask_cache.index_select(2, input_pos)
-      mask = mask[:, :, :, :
+      mask = mask[:, :, :, :kv_cache_max_len]
+    else:
+      kv_cache_max_len = mask.size(3)
+
+    pixel_mask = None
+    if image_indices is not None:
+      pixel_mask = self.build_pixel_mask(image_indices, kv_cache_max_len)
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, pixel_mask, export_config
@@ -322,16 +337,8 @@ class Decoder(nn.Module):
     return {"logits": res, "kv_cache": updated_kv_cache}
 
 
-def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
-  """Returns the model config for a Gemma3 1B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 2048.
-
-  Returns:
-    The model config for a Gemma 1B model.
-  """
+def get_decoder_config_1b() -> cfg.ModelConfig:
+  """Returns the model config for a Gemma3 1B model."""
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True,
   )
@@ -376,7 +383,6 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
       max_seq_len=32_768,
       embedding_dim=embedding_dim,
       embedding_scale=embedding_dim**0.5,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
@@ -385,20 +391,12 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_decoder_config_1b(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  """Returns a fake model config for a Gemma3 1B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 128.
-
-  Returns:
-    A fake model config for a Gemma 1B model.
-  """
-  config = get_decoder_config_1b(kv_cache_max_len)
+def get_fake_decoder_config_1b() -> cfg.ModelConfig:
+  """Returns a fake model config for a Gemma3 1B model."""
+  config = get_decoder_config_1b()
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len =
+  config.max_seq_len = 256
   config.embedding_dim = 128
   config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
@@ -413,7 +411,7 @@ def get_fake_decoder_config_1b(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 def build_model_1b(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   # TODO(b/403644647): Better error handling for loading checkpoints with
   # different tensor names.
@@ -421,10 +419,11 @@ def build_model_1b(
     try:
       return model_builder.build_decoder_only_model(
           checkpoint_path=checkpoint_path,
-          config=get_decoder_config_1b(
+          config=get_decoder_config_1b(),
           tensor_names=tensor_names,
           model_class=Decoder,
           custom_loader=custom_loader,
+          mask_cache_size=mask_cache_size,
       )
     except KeyError as ke:
       continue
```

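Because `kv_cache_max` no longer lives on the model config, the Gemma3 decoder now determines the mask width per call: from `kv_cache.get_max_seq_len()` when it slices its own mask cache, or from `mask.size(3)` when the caller provides a mask, and the pixel mask is padded out to that same width. A small standalone sketch of that sizing logic, with made-up shapes and an assumed padding value for the pixel mask:

```python
import torch


def select_mask(mask_cache: torch.Tensor, input_pos: torch.Tensor,
                kv_cache_max_len: int) -> torch.Tensor:
  # Rows for the current token positions, clipped on the key axis to the
  # KV-cache length (mirrors the index_select + slice in Decoder.forward).
  return mask_cache.index_select(2, input_pos)[:, :, :, :kv_cache_max_len]


def pad_pixel_mask(image_indices: torch.Tensor, max_seq_len: int) -> torch.Tensor:
  # True where a position holds an image token, padded with False (an assumed
  # padding value) up to max_seq_len, roughly as Decoder.build_pixel_mask does.
  pixel_mask = image_indices >= 0
  pad = max_seq_len - pixel_mask.size(1)
  if pad > 0:
    pixel_mask = torch.cat(
        [pixel_mask, torch.zeros(pixel_mask.size(0), pad, dtype=torch.bool)],
        dim=1,
    )
  return pixel_mask


mask_cache = torch.triu(torch.full((1, 1, 16, 16), float("-inf")), diagonal=1)
mask = select_mask(mask_cache, torch.tensor([0, 1, 2]), kv_cache_max_len=8)
pixels = pad_pixel_mask(torch.tensor([[0, 1, -1]]), max_seq_len=8)
print(mask.shape, pixels.shape)  # torch.Size([1, 1, 3, 8]) torch.Size([1, 8])
```
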
ai_edge_torch/generative/examples/gemma3/gemma3.py

```diff
@@ -48,13 +48,13 @@ class Gemma3MMConfig:
 class Gemma3MM(nn.Module):
   """A Gemma3 multimodal model built from the Edge Generative API layers."""
 
-  def __init__(self, config: Gemma3MMConfig):
+  def __init__(self, config: Gemma3MMConfig, mask_cache_size: int = 0):
     super().__init__()
 
     self.image_encoder = image_encoder.SiglipVisionEncoderWithExit(
         config.image_encoder_config
     )
-    self.decoder = decoder.Decoder(config.decoder_config)
+    self.decoder = decoder.Decoder(config.decoder_config, mask_cache_size)
     self.mm_norm = builder.build_norm(
         config.image_encoder_config.embedding_dim,
         config.mm_norm_config,
@@ -150,10 +150,10 @@ class Gemma3MM(nn.Module):
   )
 
 
-def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
+def get_fake_model_config() -> Gemma3MMConfig:
   return Gemma3MMConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
-      decoder_config=decoder.get_fake_decoder_config_1b(
+      decoder_config=decoder.get_fake_decoder_config_1b(),
       image_token_id=127,
       image_projection_scale=128**0.5,
       image_projection_use_bias=False,
@@ -167,13 +167,15 @@ def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
 def build_model_1b(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-
+    mask_cache_size: int = 0,
 ) -> decoder.Decoder:
   if checkpoint_path:
-    model = decoder.build_model_1b(
+    model = decoder.build_model_1b(
+        checkpoint_path, custom_loader, mask_cache_size
+    )
   else:
-    config = decoder.get_decoder_config_1b(
-    model = decoder.Decoder(config)
+    config = decoder.get_decoder_config_1b()
+    model = decoder.Decoder(config, mask_cache_size)
   # TODO: Load the parameters of decoder from checkpoint.
   model.eval()
   return model
```

ai_edge_torch/generative/examples/hammer/convert_to_tflite.py

```diff
@@ -43,13 +43,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
```

ai_edge_torch/generative/examples/hammer/hammer.py

```diff
@@ -29,7 +29,7 @@ class Hammer(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+def get_1_5b_model_config() -> cfg.ModelConfig:
   """Returns the model config for a Hammer 2.1 1.5B model."""
   attn_config = cfg.AttentionConfig(
       num_heads=12,
@@ -58,16 +58,15 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=28,
       max_seq_len=32768,
       embedding_dim=1536,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
   )
   return config
 
 
-def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+def get_0_5b_model_config() -> cfg.ModelConfig:
   """Returns the model config for a Hammer 2.1 0.5B model."""
-  config = get_1_5b_model_config(
+  config = get_1_5b_model_config()
   # Hammer has only one block config.
   block_config = config.block_config(0)
   block_config.attn_config.num_heads = 14
@@ -78,8 +77,8 @@ def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_1_5b_model_config(
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_1_5b_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   config.embedding_dim = 16
@@ -88,29 +87,37 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def
+def _build_model(
     checkpoint_path: str,
+    config: cfg.ModelConfig,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=
+      config=config,
       tensor_names=TENSOR_NAMES,
       model_class=Hammer,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
+  )
+
+
+def build_1_5b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    mask_cache_size: int = 0,
+) -> nn.Module:
+  return _build_model(
+      checkpoint_path, get_1_5b_model_config(), custom_loader, mask_cache_size
   )
 
 
 def build_0_5b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-
+    mask_cache_size: int = 0,
 ) -> nn.Module:
-  return
-  checkpoint_path
-      config=get_0_5b_model_config(**kwargs),
-      tensor_names=TENSOR_NAMES,
-      model_class=Hammer,
-      custom_loader=custom_loader,
+  return _build_model(
+      checkpoint_path, get_0_5b_model_config(), custom_loader, mask_cache_size
   )
```