ai-edge-torch-nightly 0.6.0.dev20250602__py3-none-any.whl → 0.6.0.dev20250604__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +3 -1
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +7 -15
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +3 -1
- ai_edge_torch/generative/examples/deepseek/deepseek.py +7 -15
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/gemma/gemma1.py +8 -16
- ai_edge_torch/generative/examples/gemma/gemma2.py +24 -24
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +6 -0
- ai_edge_torch/generative/examples/gemma/verify_util.py +3 -0
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/gemma3/decoder.py +34 -35
- ai_edge_torch/generative/examples/gemma3/gemma3.py +10 -8
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/hammer/hammer.py +23 -16
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -2
- ai_edge_torch/generative/examples/llama/llama.py +13 -26
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/openelm/openelm.py +8 -16
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/paligemma/decoder.py +12 -17
- ai_edge_torch/generative/examples/paligemma/decoder2.py +12 -17
- ai_edge_torch/generative/examples/paligemma/paligemma.py +14 -9
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/phi/phi2.py +8 -16
- ai_edge_torch/generative/examples/phi/phi3.py +8 -16
- ai_edge_torch/generative/examples/phi/phi4.py +8 -16
- ai_edge_torch/generative/examples/phi/verify_util.py +1 -3
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/qwen/qwen.py +29 -34
- ai_edge_torch/generative/examples/qwen/qwen3.py +29 -35
- ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +11 -16
- ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +8 -12
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +2 -2
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/smollm/smollm.py +15 -30
- ai_edge_torch/generative/examples/t5/t5.py +23 -23
- ai_edge_torch/generative/examples/t5/t5_attention.py +2 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -1
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +7 -15
- ai_edge_torch/generative/layers/kv_cache.py +13 -1
- ai_edge_torch/generative/layers/model_config.py +0 -14
- ai_edge_torch/generative/test/test_kv_cache.py +14 -24
- ai_edge_torch/generative/test/test_lora.py +4 -21
- ai_edge_torch/generative/test/test_model_conversion.py +8 -4
- ai_edge_torch/generative/test/test_model_conversion_large.py +27 -19
- ai_edge_torch/generative/utilities/converter.py +15 -6
- ai_edge_torch/generative/utilities/model_builder.py +16 -6
- ai_edge_torch/generative/utilities/verifier.py +16 -6
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250604.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250604.dist-info}/RECORD +60 -60
- {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250604.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250604.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250604.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/t5/t5.py

@@ -128,7 +128,7 @@ class T5(nn.Module):
 
     self.enc_attn_mask_cache = (
         torch.zeros(
-            (config.kv_cache_max, config.kv_cache_max),
+            (config.max_seq_len, config.max_seq_len),
             dtype=torch.float32,
             device=torch.device("cpu"),
         )

@@ -137,7 +137,7 @@ class T5(nn.Module):
     )
 
     self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
+        size=config.max_seq_len,
         dtype=torch.float32,
         device=torch.device("cpu"),
     )

@@ -146,16 +146,16 @@ class T5(nn.Module):
     attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
-        query_length=config.kv_cache_max,
-        key_length=config.kv_cache_max,
+        query_length=config.max_seq_len,
+        key_length=config.max_seq_len,
         num_buckets=attn_config.relative_attention_num_buckets,
         max_distance=attn_config.relative_attention_max_distance,
     )
 
     self.dec_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=False,
-        query_length=config.kv_cache_max,
-        key_length=config.kv_cache_max,
+        query_length=config.max_seq_len,
+        key_length=config.max_seq_len,
         num_buckets=attn_config.relative_attention_num_buckets,
         max_distance=attn_config.relative_attention_max_distance,
     )

@@ -176,20 +176,20 @@ class T5(nn.Module):
     )
 
     enc_mask = self.enc_attn_mask_cache.index_select(2, input_pos)
-    enc_mask = enc_mask[:, :, :, : self.config.kv_cache_max]
+    enc_mask = enc_mask[:, :, :, : self.config.max_seq_len]
     # Mask off any "pad" tokens that shouldn't contribute to self-attention
     enc_mask[:, :, :, :] += pad_mask
     dec_mask = self.dec_attn_mask_cache.index_select(2, decoder_input_pos)
-    dec_mask = dec_mask[:, :, :, : self.config.kv_cache_max]
+    dec_mask = dec_mask[:, :, :, : self.config.max_seq_len]
     enc_relative_position = self.enc_rel_pos_mask.index_select(2, input_pos)
     enc_relative_position = enc_relative_position[
-        :, :, :, : self.config.kv_cache_max
+        :, :, :, : self.config.max_seq_len
     ]
     dec_relative_position = self.enc_rel_pos_mask.index_select(
         2, decoder_input_pos
     )
     dec_relative_position = dec_relative_position[
-        :, :, :, : self.config.kv_cache_max
+        :, :, :, : self.config.max_seq_len
     ]
     enc_attention_mask = self.enc_attn_mask_cache.index_select(
         2, decoder_input_pos

@@ -243,7 +243,7 @@ class T5Encoder(nn.Module):
 
     self.enc_attn_mask_cache = (
         torch.zeros(
-            (config.kv_cache_max, config.kv_cache_max),
+            (config.max_seq_len, config.max_seq_len),
             dtype=torch.float32,
             device=torch.device("cpu"),
         )

@@ -255,8 +255,8 @@ class T5Encoder(nn.Module):
     attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
-        query_length=config.kv_cache_max,
-        key_length=config.kv_cache_max,
+        query_length=config.max_seq_len,
+        key_length=config.max_seq_len,
         num_buckets=attn_config.relative_attention_num_buckets,
         max_distance=attn_config.relative_attention_max_distance,
     )

@@ -275,12 +275,12 @@ class T5Encoder(nn.Module):
     )
 
     enc_mask = self.enc_attn_mask_cache.index_select(2, input_pos)
-    enc_mask = enc_mask[:, :, :, : self.config.kv_cache_max]
+    enc_mask = enc_mask[:, :, :, : self.config.max_seq_len]
     # Mask off any "pad" tokens that shouldn't contribute to self-attention
     enc_mask[:, :, :, :] += pad_mask
     enc_relative_position = self.enc_rel_pos_mask.index_select(2, input_pos)
     enc_relative_position = enc_relative_position[
-        :, :, :, : self.config.kv_cache_max
+        :, :, :, : self.config.max_seq_len
     ]
 
     # Convert encoder inputs in embeddings if needed

@@ -315,7 +315,7 @@ class T5Decoder(nn.Module):
 
     self.enc_attn_mask_cache = (
         torch.zeros(
-            (config.kv_cache_max, config.kv_cache_max),
+            (config.max_seq_len, config.max_seq_len),
             dtype=torch.float32,
             device=torch.device("cpu"),
         )

@@ -327,14 +327,14 @@ class T5Decoder(nn.Module):
     attn_config = config.block_config(0).attn_config
     self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
         bidirectional=True,
-        query_length=config.kv_cache_max,
-        key_length=config.kv_cache_max,
+        query_length=config.max_seq_len,
+        key_length=config.max_seq_len,
         num_buckets=attn_config.relative_attention_num_buckets,
         max_distance=attn_config.relative_attention_max_distance,
     )
 
     self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
+        size=config.max_seq_len,
     )
 
   @torch.inference_mode

@@ -346,12 +346,12 @@ class T5Decoder(nn.Module):
       pad_mask: torch.Tensor,
   ) -> torch.Tensor:
     dec_mask = self.dec_attn_mask_cache.index_select(2, decoder_input_pos)
-    dec_mask = dec_mask[:, :, :, : self.config.kv_cache_max]
+    dec_mask = dec_mask[:, :, :, : self.config.max_seq_len]
     dec_relative_position = self.enc_rel_pos_mask.index_select(
         2, decoder_input_pos
     )
     dec_relative_position = dec_relative_position[
-        :, :, :, : self.config.kv_cache_max
+        :, :, :, : self.config.max_seq_len
     ]
     enc_attention_mask = self.enc_attn_mask_cache.index_select(
         2, decoder_input_pos

@@ -603,7 +603,7 @@ def define_and_run_t5(checkpoint_path: str) -> None:
 
   decode_d_token = torch.tensor([[0]], dtype=torch.int)
   decode_d_input_pos = torch.tensor([0], dtype=torch.int)
-  pad_mask = torch.zeros([model.config.kv_cache_max], dtype=torch.float32)
+  pad_mask = torch.zeros([model.config.max_seq_len], dtype=torch.float32)
   pad_mask[77:] = float("-inf")
   lm_logits = model.forward(
       tokens, input_pos, decode_d_token, decode_d_input_pos, pad_mask

@@ -636,7 +636,7 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:
   decode_d_token = torch.tensor([[0]], dtype=torch.int)
   decode_d_input_pos = torch.tensor([0], dtype=torch.int)
   pad_mask = torch.zeros(
-      [t5_encoder_model.config.kv_cache_max], dtype=torch.float32
+      [t5_encoder_model.config.max_seq_len], dtype=torch.float32
   )
   pad_mask[77:] = float("-inf")
   hidden_states = t5_encoder_model.forward(tokens, input_pos, pad_mask)
ai_edge_torch/generative/examples/t5/t5_attention.py

@@ -53,7 +53,7 @@ class EncoderDecoderBlock(nn.Module):
         model_config.embedding_dim,
         config.attn_config,
         config.pre_attention_norm_config,
-        model_config.kv_cache_max,
+        model_config.max_seq_len,
         model_config.enable_hlfb,
         has_relative_attention_bias=has_relative_attention_bias,
     )

@@ -64,7 +64,7 @@ class EncoderDecoderBlock(nn.Module):
         model_config.embedding_dim,
         config.attn_config,
         config.pre_attention_norm_config,
-        model_config.kv_cache_max,
+        model_config.max_seq_len,
         model_config.enable_hlfb,
         # Cross Attention does not have relative attention bias.
         has_relative_attention_bias=False,
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py

@@ -31,13 +31,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py

@@ -29,16 +29,8 @@ class TinyLlama(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a TinyLlama model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a TinyLlama model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a TinyLlama model."""
   attn_config = cfg.AttentionConfig(
       num_heads=32,
       head_dim=64,

@@ -63,7 +55,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=22,
       max_seq_len=2048,
       embedding_dim=2048,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,

@@ -71,8 +62,8 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_model_config(**kwargs)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   # TinyLlama has only one block config.

@@ -83,12 +74,13 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(**kwargs),
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=TinyLlama,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
ai_edge_torch/generative/layers/kv_cache.py

@@ -88,6 +88,12 @@ class KVCacheEntry:
     obj = cls(k_cache=k, v_cache=v, kv_layout=kv_layout)
     return obj
 
+  def get_max_seq_len(self) -> int:
+    """Get the maximum sequence length in the KV cache."""
+    return self.k_cache.size(
+        self.kv_layout[0].dimensions.index(types.TensorDims.SEQUENCE)
+    )
+
 
 @dataclasses.dataclass
 class KVCache:

@@ -98,6 +104,7 @@ class KVCache:
   @classmethod
   def from_model_config(
       cls,
+      kv_cache_max: int,
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device | None = None,

@@ -107,6 +114,7 @@ class KVCache:
     """Build an instance of the class based on model config.
 
     Args:
+      kv_cache_max (int): The maximum sequence length in the KV cache.
      config (ModelConfig): Model config used for building the cache.
      dtype (torch.dtype, optional): The data type of the cache tensor.
        Defaults to torch.float32.

@@ -120,7 +128,7 @@ class KVCache:
     """
     caches = [
         KVCacheEntry.from_model_config(
-            config.kv_cache_max
+            kv_cache_max
             if not config.block_config(idx).kv_cache_max_len
             else config.block_config(idx).kv_cache_max_len,
             config.block_config(idx).attn_config,

@@ -139,6 +147,10 @@ class KVCache:
     flattened, _ = _flatten_kvc(self)
     return flattened
 
+  def get_max_seq_len(self) -> int:
+    """Get the maximum sequence length in the KV cache."""
+    return self.caches[0].get_max_seq_len()
+
 
 def _flatten_kvc(kvc: KVCache) -> Tuple[List[str], List[str]]:
   flattened = []
ai_edge_torch/generative/layers/model_config.py

@@ -251,9 +251,6 @@ class ModelConfig:
   # Whether to turn on high-level function boundary.
   enable_hlfb: bool = True
 
-  # The maximum sequence length of the KV cache. Should not exceed max_seq_len.
-  kv_cache_max_len: int = 0
-
   # Softcap on the model output logits.
   final_logit_softcap: Optional[float] = None
 

@@ -261,23 +258,12 @@ class ModelConfig:
   # forward pass. Defaults to a standard implementation.
   build_rope: Callable = rotary_position_embedding.build_rope
 
-  # Whether or not to use a mask cache. Mask cache can speed up inference when
-  # statically exporting models. However, it is not supported in the dynamic
-  # export.
-  use_mask_cache: bool = True
-
   # An interleaved sequence of the attention types used in the model.
   # E.g. [AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING,
   # AttentionType.GLOBAL] means that the model has an attention pattern of 2
   # local attentions followed by a global attention in a repeated pattern.
   attention_patterns: Optional[Sequence[AttentionType]] = None
 
-  @property
-  def kv_cache_max(self) -> int:
-    if self.kv_cache_max_len > 0:
-      return self.kv_cache_max_len
-    return self.max_seq_len
-
   def block_config(self, idx: int) -> TransformerBlockConfig:
     if isinstance(self.block_configs, TransformerBlockConfig):
       return self.block_configs
ai_edge_torch/generative/test/test_kv_cache.py

@@ -25,9 +25,7 @@ from absl.testing import absltest as googletest
 
 class TestKVLayers(googletest.TestCase):
 
-  def _get_test_config(
-      self, num_layers, head_dim, num_query_groups, kv_cache_max_len
-  ):
+  def _get_test_config(self, num_layers, head_dim, num_query_groups):
     attn_config = cfg.AttentionConfig(
         num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
     )

@@ -35,7 +33,6 @@ class TestKVLayers(googletest.TestCase):
         attn_config=attn_config, ff_config=None
     )
     config = cfg.ModelConfig(
-        kv_cache_max_len=kv_cache_max_len,
         embedding_dim=head_dim,
         block_configs=block_config,
         num_layers=num_layers,

@@ -50,12 +47,9 @@ class TestKVLayers(googletest.TestCase):
     NUM_QG = 1
     KV_LEN = 4
     config = self._get_test_config(
-        num_layers=N,
-        head_dim=HEAD_DIM,
-        num_query_groups=NUM_QG,
-        kv_cache_max_len=KV_LEN,
+        num_layers=N, head_dim=HEAD_DIM, num_query_groups=NUM_QG
     )
-    kv = kv_utils.KVCache.from_model_config(config)
+    kv = kv_utils.KVCache.from_model_config(KV_LEN, config)
     entry = kv.caches[0]
     # single-slice update
     input_pos = torch.tensor([1])

@@ -103,12 +97,9 @@ class TestKVLayers(googletest.TestCase):
     NUM_QG = 1
     KV_LEN = 4
     config = self._get_test_config(
-        num_layers=N,
-        head_dim=HEAD_DIM,
-        num_query_groups=NUM_QG,
-        kv_cache_max_len=KV_LEN,
+        num_layers=N, head_dim=HEAD_DIM, num_query_groups=NUM_QG
     )
-    kv = kv_utils.KVCache.from_model_config(config)
+    kv = kv_utils.KVCache.from_model_config(KV_LEN, config)
     model = TestModel()
     exported_program = torch.export.export(model, (kv,))
     input_specs = exported_program.graph_signature.input_specs

@@ -119,12 +110,11 @@ class TestKVLayers(googletest.TestCase):
   def test_pytree_roundtrip_kv_cache(self):
     NUM_LAYERS = 4
     config = self._get_test_config(
-        num_layers=NUM_LAYERS,
-        head_dim=2,
-        num_query_groups=1,
-        kv_cache_max_len=4,
+        num_layers=NUM_LAYERS, head_dim=2, num_query_groups=1
+    )
+    kv = kv_utils.KVCache.from_model_config(
+        kv_cache_max=4, config=config, batch_size=1
     )
-    kv = kv_utils.KVCache.from_model_config(config, batch_size=1)
     flat, treespec = pytree.tree_flatten(kv)
     self.assertLen(flat, NUM_LAYERS * 2)
     kv_unflat = pytree.tree_unflatten(flat, treespec)

@@ -133,13 +123,13 @@ class TestKVLayers(googletest.TestCase):
   def test_pytree_roundtrip_kv_cache_derived(self):
     NUM_LAYERS = 4
     config = self._get_test_config(
-        num_layers=NUM_LAYERS,
-        head_dim=2,
-        num_query_groups=1,
-        kv_cache_max_len=4,
+        num_layers=NUM_LAYERS, head_dim=2, num_query_groups=1
     )
     kv = kv_utils.KVCache.from_model_config(
-        config, batch_size=1, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
+        kv_cache_max=4,
+        config=config,
+        batch_size=1,
+        kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED,
     )
     flat, treespec = pytree.tree_flatten(kv)
     self.assertLen(flat, NUM_LAYERS * 2)
ai_edge_torch/generative/test/test_lora.py

@@ -58,12 +58,7 @@ class TestLora(googletest.TestCase):
     safetensors_file = resource_loader.get_path_to_datafile(
         "fixtures/test_lora_rank16.safetensors"
     )
-    config = self._get_test_config(
-        num_layers=1,
-        head_dim=8,
-        num_query_groups=1,
-        kv_cache_max_len=16,
-    )
+    config = self._get_test_config(num_layers=1, head_dim=8, num_query_groups=1)
     lora = lora_utils.LoRA.from_safetensors(
         safetensors_file,
         scale=1.0,

@@ -84,12 +79,8 @@ class TestLora(googletest.TestCase):
     n = 1
     head_dim = 2
     num_query_groups = 1
-    key_length = 4
     config = self._get_test_config(
-        num_layers=n,
-        head_dim=head_dim,
-        num_query_groups=num_query_groups,
-        kv_cache_max_len=key_length,
+        num_layers=n, head_dim=head_dim, num_query_groups=num_query_groups
     )
     inputs = torch.zeros((n, 1, head_dim))
     lora = lora_utils.LoRA.zeros(rank=16, config=config)

@@ -111,20 +102,13 @@ class TestLora(googletest.TestCase):
 
   def test_lora_tflite_serialization(self):
     """Tests the serialization of the LoRA module."""
-    config = self._get_test_config(
-        num_layers=2,
-        head_dim=8,
-        num_query_groups=1,
-        kv_cache_max_len=16,
-    )
+    config = self._get_test_config(num_layers=2, head_dim=8, num_query_groups=1)
     lora = lora_utils.LoRA.random(rank=16, config=config)
     flatbuffer_model = lora.to_tflite()
     recovered_lora = lora_utils.LoRA.from_flatbuffers(flatbuffer_model)
     self.assertEqual(lora, recovered_lora)
 
-  def _get_test_config(
-      self, num_layers, head_dim, num_query_groups, kv_cache_max_len
-  ):
+  def _get_test_config(self, num_layers, head_dim, num_query_groups):
     """Returns a test model config."""
     attn_config = cfg.AttentionConfig(
         num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups

@@ -133,7 +117,6 @@ class TestLora(googletest.TestCase):
         attn_config=attn_config, ff_config=None
     )
     config = cfg.ModelConfig(
-        kv_cache_max_len=kv_cache_max_len,
         embedding_dim=head_dim,
         block_configs=block_config,
         num_layers=num_layers,
ai_edge_torch/generative/test/test_model_conversion.py

@@ -47,7 +47,9 @@ class TestModelConversion(googletest.TestCase):
     tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int
     )
-    kv = kv_cache.KVCache.from_model_config(config, kv_layout=kv_layout)
+    kv = kv_cache.KVCache.from_model_config(
+        kv_cache_max=config.max_seq_len, config=config, kv_layout=kv_layout
+    )
     kwargs = {
         "tokens": tokens,
         "input_pos": input_pos,

@@ -122,7 +124,9 @@ class TestModelConversion(googletest.TestCase):
     decode_token = torch.tensor([[1]], dtype=torch.int)
     decode_input_pos = torch.tensor([5], dtype=torch.int)
 
-    kv = kv_cache.KVCache.from_model_config(config, kv_layout=kv_layout)
+    kv = kv_cache.KVCache.from_model_config(
+        kv_cache_max=128, config=config, kv_layout=kv_layout
+    )
 
     edge_model = (
         ai_edge_torch.signature(

@@ -177,12 +181,12 @@ class TestModelConversion(googletest.TestCase):
 
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
-    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    pytorch_model = tiny_llama.TinyLlama(config, mask_cache_size=128).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
 
   def test_tiny_llama_multisig_kv_layout_transposed(self):
     config = tiny_llama.get_fake_model_config()
-    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    pytorch_model = tiny_llama.TinyLlama(config, mask_cache_size=128).eval()
     self._test_multisig_model(
         config,
         pytorch_model,
ai_edge_torch/generative/test/test_model_conversion_large.py

@@ -55,6 +55,7 @@ class TestModelConversion(googletest.TestCase):
             experimental_default_delegate_latest_features=True,
         )
     )
+    self._kv_cache_max = 128
     # Default cache_size_limit, 8 is hit and aborts often when the tests are
     # running all together. Doubles it to avoid abortion.
     torch._dynamo.config.cache_size_limit = 16

@@ -64,7 +65,7 @@ class TestModelConversion(googletest.TestCase):
     seq_len = 10
     tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
     input_pos = torch.arange(0, seq_len, dtype=torch.int)
-    kv = kv_cache.KVCache.from_model_config(config)
+    kv = kv_cache.KVCache.from_model_config(self._kv_cache_max, config)
 
     edge_model = ai_edge_torch.signature(
         signature_name,

@@ -95,74 +96,77 @@ class TestModelConversion(googletest.TestCase):
 
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
-    pytorch_model = gemma1.Gemma1(config).eval()
+    pytorch_model = gemma1.Gemma1(config, self._kv_cache_max).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   def test_gemma2(self):
     config = gemma2.get_fake_model_config()
-    pytorch_model = gemma2.Gemma2(config).eval()
+    pytorch_model = gemma2.Gemma2(config, self._kv_cache_max).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   def test_llama(self):
     config = llama.get_fake_model_config()
-    pytorch_model = llama.Llama(config).eval()
+    pytorch_model = llama.Llama(config, self._kv_cache_max).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   def test_phi2(self):
     config = phi2.get_fake_model_config()
-    pytorch_model = phi2.Phi2(config).eval()
+    pytorch_model = phi2.Phi2(config, self._kv_cache_max).eval()
     # Phi-2 logits are very big, so we need a larger absolute tolerance.
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   def test_phi3(self):
     config = phi3.get_fake_model_config()
-    pytorch_model = phi3.Phi3_5Mini(config).eval()
+    pytorch_model = phi3.Phi3_5Mini(config, self._kv_cache_max).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
   def test_phi4(self):
     config = phi4.get_fake_model_config()
-    pytorch_model = phi4.Phi4Mini(config).eval()
+    pytorch_model = phi4.Phi4Mini(config, self._kv_cache_max).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   def test_smollm(self):
     config = smollm.get_fake_model_config()
-    pytorch_model = smollm.SmolLM(config).eval()
+    pytorch_model = smollm.SmolLM(config, self._kv_cache_max).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   def test_smollm2(self):
     config = smollm.get_fake_model_config_v2()
-    pytorch_model = smollm.SmolLM2(config).eval()
+    pytorch_model = smollm.SmolLM2(config, self._kv_cache_max).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   def test_openelm(self):
     config = openelm.get_fake_model_config()
-    pytorch_model = openelm.OpenELM(config).eval()
+    pytorch_model = openelm.OpenELM(config, self._kv_cache_max).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   def test_qwen(self):
     config = qwen.get_fake_model_config()
-    pytorch_model = qwen.Qwen(config).eval()
+    pytorch_model = qwen.Qwen(config, self._kv_cache_max).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
   def test_deepseek(self):
     config = deepseek.get_fake_model_config()
-    pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
+    pytorch_model = deepseek.DeepSeekDistillQwen(
+        config, self._kv_cache_max
+    ).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
   def test_hammer(self):
     config = hammer.get_fake_model_config()
-    pytorch_model = hammer.Hammer(config).eval()
+    pytorch_model = hammer.Hammer(config, self._kv_cache_max).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
-
   def test_amd_llama_135m(self):
     config = amd_llama_135m.get_fake_model_config()
-    pytorch_model = amd_llama_135m.AmdLlama(config).eval()
+    pytorch_model = amd_llama_135m.AmdLlama(config, self._kv_cache_max).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
   def _test_paligemma_model(self, decoder_class, decoder_config, atol, rtol):
     config = paligemma.get_fake_model_config(decoder_config)
-    pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()
+    pytorch_model = paligemma.PaliGemma(
+        config, decoder_class, mask_cache_size=self._kv_cache_max
+    ).eval()
 
     image_config = config.image_encoder_config.image_embedding
     num_patches = (image_config.image_size // image_config.patch_size) ** 2

@@ -171,7 +175,9 @@ class TestModelConversion(googletest.TestCase):
     seq_len = num_patches + 10
     tokens = torch.zeros((1, seq_len), dtype=torch.int)
     input_pos = torch.arange(0, seq_len, dtype=torch.int)
-    kv = kv_cache.KVCache.from_model_config(config.decoder_config)
+    kv = kv_cache.KVCache.from_model_config(
+        self._kv_cache_max, config.decoder_config
+    )
     pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32)
 
     edge_model = ai_edge_torch.signature(

@@ -218,7 +224,7 @@ class TestModelConversion(googletest.TestCase):
 
   def test_qwen_vl_model(self):
     config = qwen_vl.get_fake_model_config()
-    pytorch_model = qwen_vl.QwenVL(config).eval()
+    pytorch_model = qwen_vl.QwenVL(config, self._kv_cache_max).eval()
 
     grid_thw = pytorch_model.image_encoder.get_grid_thw()
     pixel_values_size = pytorch_model.image_encoder.get_pixel_values_size(

@@ -229,7 +235,9 @@ class TestModelConversion(googletest.TestCase):
     seq_len = pixel_values_size[0] + 10
     tokens = torch.zeros((1, seq_len), dtype=torch.int)
     input_pos = torch.arange(0, seq_len, dtype=torch.int)
-    kv = kv_cache.KVCache.from_model_config(config.decoder_config)
+    kv = kv_cache.KVCache.from_model_config(
+        self._kv_cache_max, config.decoder_config
+    )
     pixel_values = torch.zeros(pixel_values_size, dtype=torch.float32)
 
     edge_model = ai_edge_torch.signature(