ai-edge-torch-nightly 0.6.0.dev20250601__py3-none-any.whl → 0.6.0.dev20250603__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (57)
  1. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +7 -15
  2. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -1
  3. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +3 -1
  4. ai_edge_torch/generative/examples/deepseek/deepseek.py +7 -15
  5. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -1
  6. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -1
  7. ai_edge_torch/generative/examples/gemma/gemma1.py +8 -16
  8. ai_edge_torch/generative/examples/gemma/gemma2.py +24 -24
  9. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -1
  10. ai_edge_torch/generative/examples/gemma3/decoder.py +34 -35
  11. ai_edge_torch/generative/examples/gemma3/gemma3.py +10 -8
  12. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +2 -1
  13. ai_edge_torch/generative/examples/hammer/hammer.py +23 -16
  14. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -2
  15. ai_edge_torch/generative/examples/llama/llama.py +13 -26
  16. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -1
  17. ai_edge_torch/generative/examples/openelm/openelm.py +8 -16
  18. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -1
  19. ai_edge_torch/generative/examples/paligemma/decoder.py +12 -17
  20. ai_edge_torch/generative/examples/paligemma/decoder2.py +12 -17
  21. ai_edge_torch/generative/examples/paligemma/paligemma.py +14 -9
  22. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -1
  23. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +2 -1
  24. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -1
  25. ai_edge_torch/generative/examples/phi/phi2.py +8 -16
  26. ai_edge_torch/generative/examples/phi/phi3.py +8 -16
  27. ai_edge_torch/generative/examples/phi/phi4.py +8 -16
  28. ai_edge_torch/generative/examples/phi/verify_util.py +1 -3
  29. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -1
  30. ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +2 -1
  31. ai_edge_torch/generative/examples/qwen/qwen.py +29 -34
  32. ai_edge_torch/generative/examples/qwen/qwen3.py +29 -35
  33. ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +2 -1
  34. ai_edge_torch/generative/examples/qwen_vl/decoder.py +11 -16
  35. ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +8 -12
  36. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +2 -2
  37. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +2 -1
  38. ai_edge_torch/generative/examples/smollm/smollm.py +15 -30
  39. ai_edge_torch/generative/examples/t5/t5.py +23 -23
  40. ai_edge_torch/generative/examples/t5/t5_attention.py +2 -2
  41. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -1
  42. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +7 -15
  43. ai_edge_torch/generative/layers/kv_cache.py +13 -1
  44. ai_edge_torch/generative/layers/model_config.py +0 -14
  45. ai_edge_torch/generative/test/test_kv_cache.py +14 -24
  46. ai_edge_torch/generative/test/test_lora.py +4 -21
  47. ai_edge_torch/generative/test/test_model_conversion.py +8 -4
  48. ai_edge_torch/generative/test/test_model_conversion_large.py +27 -19
  49. ai_edge_torch/generative/utilities/converter.py +15 -6
  50. ai_edge_torch/generative/utilities/model_builder.py +16 -6
  51. ai_edge_torch/generative/utilities/verifier.py +16 -6
  52. ai_edge_torch/version.py +1 -1
  53. {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/METADATA +1 -1
  54. {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/RECORD +57 -57
  55. {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/LICENSE +0 -0
  56. {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/WHEEL +0 -0
  57. {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/top_level.txt +0 -0
@@ -128,7 +128,7 @@ class T5(nn.Module):

  self.enc_attn_mask_cache = (
  torch.zeros(
- (config.kv_cache_max, config.kv_cache_max),
+ (config.max_seq_len, config.max_seq_len),
  dtype=torch.float32,
  device=torch.device("cpu"),
  )
@@ -137,7 +137,7 @@ class T5(nn.Module):
  )

  self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
- size=config.kv_cache_max,
+ size=config.max_seq_len,
  dtype=torch.float32,
  device=torch.device("cpu"),
  )
@@ -146,16 +146,16 @@ class T5(nn.Module):
  attn_config = config.block_config(0).attn_config
  self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
  bidirectional=True,
- query_length=config.kv_cache_max,
- key_length=config.kv_cache_max,
+ query_length=config.max_seq_len,
+ key_length=config.max_seq_len,
  num_buckets=attn_config.relative_attention_num_buckets,
  max_distance=attn_config.relative_attention_max_distance,
  )

  self.dec_rel_pos_mask = attn_utils.build_relative_position_buckets(
  bidirectional=False,
- query_length=config.kv_cache_max,
- key_length=config.kv_cache_max,
+ query_length=config.max_seq_len,
+ key_length=config.max_seq_len,
  num_buckets=attn_config.relative_attention_num_buckets,
  max_distance=attn_config.relative_attention_max_distance,
  )
@@ -176,20 +176,20 @@ class T5(nn.Module):
  )

  enc_mask = self.enc_attn_mask_cache.index_select(2, input_pos)
- enc_mask = enc_mask[:, :, :, : self.config.kv_cache_max]
+ enc_mask = enc_mask[:, :, :, : self.config.max_seq_len]
  # Mask off any "pad" tokens that shouldn't contribute to self-attention
  enc_mask[:, :, :, :] += pad_mask
  dec_mask = self.dec_attn_mask_cache.index_select(2, decoder_input_pos)
- dec_mask = dec_mask[:, :, :, : self.config.kv_cache_max]
+ dec_mask = dec_mask[:, :, :, : self.config.max_seq_len]
  enc_relative_position = self.enc_rel_pos_mask.index_select(2, input_pos)
  enc_relative_position = enc_relative_position[
- :, :, :, : self.config.kv_cache_max
+ :, :, :, : self.config.max_seq_len
  ]
  dec_relative_position = self.enc_rel_pos_mask.index_select(
  2, decoder_input_pos
  )
  dec_relative_position = dec_relative_position[
- :, :, :, : self.config.kv_cache_max
+ :, :, :, : self.config.max_seq_len
  ]
  enc_attention_mask = self.enc_attn_mask_cache.index_select(
  2, decoder_input_pos
@@ -243,7 +243,7 @@ class T5Encoder(nn.Module):

  self.enc_attn_mask_cache = (
  torch.zeros(
- (config.kv_cache_max, config.kv_cache_max),
+ (config.max_seq_len, config.max_seq_len),
  dtype=torch.float32,
  device=torch.device("cpu"),
  )
@@ -255,8 +255,8 @@ class T5Encoder(nn.Module):
  attn_config = config.block_config(0).attn_config
  self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
  bidirectional=True,
- query_length=config.kv_cache_max,
- key_length=config.kv_cache_max,
+ query_length=config.max_seq_len,
+ key_length=config.max_seq_len,
  num_buckets=attn_config.relative_attention_num_buckets,
  max_distance=attn_config.relative_attention_max_distance,
  )
@@ -275,12 +275,12 @@ class T5Encoder(nn.Module):
  )

  enc_mask = self.enc_attn_mask_cache.index_select(2, input_pos)
- enc_mask = enc_mask[:, :, :, : self.config.kv_cache_max]
+ enc_mask = enc_mask[:, :, :, : self.config.max_seq_len]
  # Mask off any "pad" tokens that shouldn't contribute to self-attention
  enc_mask[:, :, :, :] += pad_mask
  enc_relative_position = self.enc_rel_pos_mask.index_select(2, input_pos)
  enc_relative_position = enc_relative_position[
- :, :, :, : self.config.kv_cache_max
+ :, :, :, : self.config.max_seq_len
  ]

  # Convert encoder inputs in embeddings if needed
@@ -315,7 +315,7 @@ class T5Decoder(nn.Module):

  self.enc_attn_mask_cache = (
  torch.zeros(
- (config.kv_cache_max, config.kv_cache_max),
+ (config.max_seq_len, config.max_seq_len),
  dtype=torch.float32,
  device=torch.device("cpu"),
  )
@@ -327,14 +327,14 @@ class T5Decoder(nn.Module):
  attn_config = config.block_config(0).attn_config
  self.enc_rel_pos_mask = attn_utils.build_relative_position_buckets(
  bidirectional=True,
- query_length=config.kv_cache_max,
- key_length=config.kv_cache_max,
+ query_length=config.max_seq_len,
+ key_length=config.max_seq_len,
  num_buckets=attn_config.relative_attention_num_buckets,
  max_distance=attn_config.relative_attention_max_distance,
  )

  self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
- size=config.kv_cache_max,
+ size=config.max_seq_len,
  )

  @torch.inference_mode
@@ -346,12 +346,12 @@ class T5Decoder(nn.Module):
  pad_mask: torch.Tensor,
  ) -> torch.Tensor:
  dec_mask = self.dec_attn_mask_cache.index_select(2, decoder_input_pos)
- dec_mask = dec_mask[:, :, :, : self.config.kv_cache_max]
+ dec_mask = dec_mask[:, :, :, : self.config.max_seq_len]
  dec_relative_position = self.enc_rel_pos_mask.index_select(
  2, decoder_input_pos
  )
  dec_relative_position = dec_relative_position[
- :, :, :, : self.config.kv_cache_max
+ :, :, :, : self.config.max_seq_len
  ]
  enc_attention_mask = self.enc_attn_mask_cache.index_select(
  2, decoder_input_pos
@@ -603,7 +603,7 @@ def define_and_run_t5(checkpoint_path: str) -> None:

  decode_d_token = torch.tensor([[0]], dtype=torch.int)
  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
- pad_mask = torch.zeros([model.config.kv_cache_max], dtype=torch.float32)
+ pad_mask = torch.zeros([model.config.max_seq_len], dtype=torch.float32)
  pad_mask[77:] = float("-inf")
  lm_logits = model.forward(
  tokens, input_pos, decode_d_token, decode_d_input_pos, pad_mask
@@ -636,7 +636,7 @@ def define_and_run_t5_split(checkpoint_path: str) -> None:
  decode_d_token = torch.tensor([[0]], dtype=torch.int)
  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
  pad_mask = torch.zeros(
- [t5_encoder_model.config.kv_cache_max], dtype=torch.float32
+ [t5_encoder_model.config.max_seq_len], dtype=torch.float32
  )
  pad_mask[77:] = float("-inf")
  hidden_states = t5_encoder_model.forward(tokens, input_pos, pad_mask)
@@ -53,7 +53,7 @@ class EncoderDecoderBlock(nn.Module):
  model_config.embedding_dim,
  config.attn_config,
  config.pre_attention_norm_config,
- model_config.kv_cache_max,
+ model_config.max_seq_len,
  model_config.enable_hlfb,
  has_relative_attention_bias=has_relative_attention_bias,
  )
@@ -64,7 +64,7 @@ class EncoderDecoderBlock(nn.Module):
  model_config.embedding_dim,
  config.attn_config,
  config.pre_attention_norm_config,
- model_config.kv_cache_max,
+ model_config.max_seq_len,
  model_config.enable_hlfb,
  # Cross Attention does not have relative attention bias.
  has_relative_attention_bias=False,
@@ -31,13 +31,14 @@ def main(_):
  custom_loader=loader.maybe_get_custom_loader(
  checkpoint_path, flags.FLAGS.custom_checkpoint_loader
  ),
- kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+ mask_cache_size=converter.get_mask_cache_size_from_flags(),
  )
  converter.convert_to_tflite(
  pytorch_model,
  output_path=flags.FLAGS.output_path,
  output_name_prefix=flags.FLAGS.output_name_prefix,
  prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+ kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
  quantize=flags.FLAGS.quantize,
  lora_ranks=flags.FLAGS.lora_ranks,
  export_config=export_config.get_from_flags(),
@@ -29,16 +29,8 @@ class TinyLlama(model_builder.DecoderOnlyModel):
  pass


- def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
- """Returns the model config for a TinyLlama model.
-
- Args:
- kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
- is 1024.
-
- Returns:
- The model config for a TinyLlama model.
- """
+ def get_model_config() -> cfg.ModelConfig:
+ """Returns the model config for a TinyLlama model."""
  attn_config = cfg.AttentionConfig(
  num_heads=32,
  head_dim=64,
@@ -63,7 +55,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  num_layers=22,
  max_seq_len=2048,
  embedding_dim=2048,
- kv_cache_max_len=kv_cache_max_len,
  block_configs=block_config,
  final_norm_config=norm_config,
  lm_head_share_weight_with_embedding=False,
@@ -71,8 +62,8 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  return config


- def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
- config = get_model_config(**kwargs)
+ def get_fake_model_config() -> cfg.ModelConfig:
+ config = get_model_config()
  config.vocab_size = 128
  config.num_layers = 2
  # TinyLlama has only one block config.
@@ -83,12 +74,13 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
  def build_model(
  checkpoint_path: str,
  custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
- **kwargs
+ mask_cache_size: int = 0,
  ) -> nn.Module:
  return model_builder.build_decoder_only_model(
  checkpoint_path=checkpoint_path,
- config=get_model_config(**kwargs),
+ config=get_model_config(),
  tensor_names=TENSOR_NAMES,
  model_class=TinyLlama,
  custom_loader=custom_loader,
+ mask_cache_size=mask_cache_size,
  )
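With kv_cache_max_len removed from get_model_config, the example hunks above pass a mask_cache_size when building the PyTorch model and hand kv_cache_max_len to the converter instead. A minimal sketch of the new calling convention using the TinyLlama example; the checkpoint path, output path, and sizes below are placeholders, and the remaining convert_to_tflite arguments are left at their defaults:

from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.utilities import converter

# Build the PyTorch model; the mask cache size is now a build-time argument
# rather than a field of the model config.
pytorch_model = tiny_llama.build_model(
    "/path/to/tiny_llama_checkpoint",  # placeholder checkpoint path
    mask_cache_size=1280,
)

# The KV cache length is now supplied at conversion time.
converter.convert_to_tflite(
    pytorch_model,
    output_path="/tmp",  # placeholder output directory
    output_name_prefix="tiny_llama",
    prefill_seq_len=256,
    kv_cache_max_len=1280,
)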
@@ -88,6 +88,12 @@ class KVCacheEntry:
  obj = cls(k_cache=k, v_cache=v, kv_layout=kv_layout)
  return obj

+ def get_max_seq_len(self) -> int:
+ """Get the maximum sequence length in the KV cache."""
+ return self.k_cache.size(
+ self.kv_layout[0].dimensions.index(types.TensorDims.SEQUENCE)
+ )
+

  @dataclasses.dataclass
  class KVCache:
@@ -98,6 +104,7 @@ class KVCache:
  @classmethod
  def from_model_config(
  cls,
+ kv_cache_max: int,
  config: model_config.ModelConfig,
  dtype: torch.dtype = torch.float32,
  device: torch.device | None = None,
@@ -107,6 +114,7 @@ class KVCache:
  """Build an instance of the class based on model config.

  Args:
+ kv_cache_max (int): The maximum sequence length in the KV cache.
  config (ModelConfig): Model config used for building the cache.
  dtype (torch.dtype, optional): The data type of the cache tensor.
  Defaults to torch.float32.
@@ -120,7 +128,7 @@ class KVCache:
  """
  caches = [
  KVCacheEntry.from_model_config(
- config.kv_cache_max
+ kv_cache_max
  if not config.block_config(idx).kv_cache_max_len
  else config.block_config(idx).kv_cache_max_len,
  config.block_config(idx).attn_config,
@@ -139,6 +147,10 @@ class KVCache:
  flattened, _ = _flatten_kvc(self)
  return flattened

+ def get_max_seq_len(self) -> int:
+ """Get the maximum sequence length in the KV cache."""
+ return self.caches[0].get_max_seq_len()
+

  def _flatten_kvc(kvc: KVCache) -> Tuple[List[str], List[str]]:
  flattened = []
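The kv_cache.py hunks above make the cache length an explicit kv_cache_max argument of KVCache.from_model_config (it is no longer read from ModelConfig) and add get_max_seq_len() helpers. A minimal usage sketch mirroring the updated tests; the import forms and the tiny config values are illustrative placeholders, not taken from this diff:

from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.layers import model_config as cfg

# A small test-style config, following _get_test_config in the test hunks below.
attn_config = cfg.AttentionConfig(num_heads=1, head_dim=8, num_query_groups=1)
block_config = cfg.TransformerBlockConfig(attn_config=attn_config, ff_config=None)
config = cfg.ModelConfig(
    vocab_size=128,
    num_layers=2,
    max_seq_len=16,
    embedding_dim=8,
    block_configs=block_config,
)

# The maximum sequence length is now passed explicitly.
kv = kv_utils.KVCache.from_model_config(kv_cache_max=16, config=config)

# The new helper reports the sequence capacity back from the cache tensors.
assert kv.get_max_seq_len() == 16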
@@ -251,9 +251,6 @@ class ModelConfig:
  # Whether to turn on high-level function boundary.
  enable_hlfb: bool = True

- # The maximum sequence length of the KV cache. Should not exceed max_seq_len.
- kv_cache_max_len: int = 0
-
  # Softcap on the model output logits.
  final_logit_softcap: Optional[float] = None

@@ -261,23 +258,12 @@ class ModelConfig:
  # forward pass. Defaults to a standard implementation.
  build_rope: Callable = rotary_position_embedding.build_rope

- # Whether or not to use a mask cache. Mask cache can speed up inference when
- # statically exporting models. However, it is not supported in the dynamic
- # export.
- use_mask_cache: bool = True
-
  # An interleaved sequence of the attention types used in the model.
  # E.g. [AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING,
  # AttentionType.GLOBAL] means that the model has an attention pattern of 2
  # local attentions followed by a global attention in a repeated pattern.
  attention_patterns: Optional[Sequence[AttentionType]] = None

- @property
- def kv_cache_max(self) -> int:
- if self.kv_cache_max_len > 0:
- return self.kv_cache_max_len
- return self.max_seq_len
-
  def block_config(self, idx: int) -> TransformerBlockConfig:
  if isinstance(self.block_configs, TransformerBlockConfig):
  return self.block_configs
@@ -25,9 +25,7 @@ from absl.testing import absltest as googletest

  class TestKVLayers(googletest.TestCase):

- def _get_test_config(
- self, num_layers, head_dim, num_query_groups, kv_cache_max_len
- ):
+ def _get_test_config(self, num_layers, head_dim, num_query_groups):
  attn_config = cfg.AttentionConfig(
  num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
  )
@@ -35,7 +33,6 @@ class TestKVLayers(googletest.TestCase):
  attn_config=attn_config, ff_config=None
  )
  config = cfg.ModelConfig(
- kv_cache_max_len=kv_cache_max_len,
  embedding_dim=head_dim,
  block_configs=block_config,
  num_layers=num_layers,
@@ -50,12 +47,9 @@ class TestKVLayers(googletest.TestCase):
  NUM_QG = 1
  KV_LEN = 4
  config = self._get_test_config(
- num_layers=N,
- head_dim=HEAD_DIM,
- num_query_groups=NUM_QG,
- kv_cache_max_len=KV_LEN,
+ num_layers=N, head_dim=HEAD_DIM, num_query_groups=NUM_QG
  )
- kv = kv_utils.KVCache.from_model_config(config)
+ kv = kv_utils.KVCache.from_model_config(KV_LEN, config)
  entry = kv.caches[0]
  # single-slice update
  input_pos = torch.tensor([1])
@@ -103,12 +97,9 @@ class TestKVLayers(googletest.TestCase):
  NUM_QG = 1
  KV_LEN = 4
  config = self._get_test_config(
- num_layers=N,
- head_dim=HEAD_DIM,
- num_query_groups=NUM_QG,
- kv_cache_max_len=KV_LEN,
+ num_layers=N, head_dim=HEAD_DIM, num_query_groups=NUM_QG
  )
- kv = kv_utils.KVCache.from_model_config(config)
+ kv = kv_utils.KVCache.from_model_config(KV_LEN, config)
  model = TestModel()
  exported_program = torch.export.export(model, (kv,))
  input_specs = exported_program.graph_signature.input_specs
@@ -119,12 +110,11 @@ class TestKVLayers(googletest.TestCase):
  def test_pytree_roundtrip_kv_cache(self):
  NUM_LAYERS = 4
  config = self._get_test_config(
- num_layers=NUM_LAYERS,
- head_dim=2,
- num_query_groups=1,
- kv_cache_max_len=4,
+ num_layers=NUM_LAYERS, head_dim=2, num_query_groups=1
+ )
+ kv = kv_utils.KVCache.from_model_config(
+ kv_cache_max=4, config=config, batch_size=1
  )
- kv = kv_utils.KVCache.from_model_config(config, batch_size=1)
  flat, treespec = pytree.tree_flatten(kv)
  self.assertLen(flat, NUM_LAYERS * 2)
  kv_unflat = pytree.tree_unflatten(flat, treespec)
@@ -133,13 +123,13 @@ class TestKVLayers(googletest.TestCase):
  def test_pytree_roundtrip_kv_cache_derived(self):
  NUM_LAYERS = 4
  config = self._get_test_config(
- num_layers=NUM_LAYERS,
- head_dim=2,
- num_query_groups=1,
- kv_cache_max_len=4,
+ num_layers=NUM_LAYERS, head_dim=2, num_query_groups=1
  )
  kv = kv_utils.KVCache.from_model_config(
- config, batch_size=1, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
+ kv_cache_max=4,
+ config=config,
+ batch_size=1,
+ kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED,
  )
  flat, treespec = pytree.tree_flatten(kv)
  self.assertLen(flat, NUM_LAYERS * 2)
@@ -58,12 +58,7 @@ class TestLora(googletest.TestCase):
  safetensors_file = resource_loader.get_path_to_datafile(
  "fixtures/test_lora_rank16.safetensors"
  )
- config = self._get_test_config(
- num_layers=1,
- head_dim=8,
- num_query_groups=1,
- kv_cache_max_len=16,
- )
+ config = self._get_test_config(num_layers=1, head_dim=8, num_query_groups=1)
  lora = lora_utils.LoRA.from_safetensors(
  safetensors_file,
  scale=1.0,
@@ -84,12 +79,8 @@ class TestLora(googletest.TestCase):
  n = 1
  head_dim = 2
  num_query_groups = 1
- key_length = 4
  config = self._get_test_config(
- num_layers=n,
- head_dim=head_dim,
- num_query_groups=num_query_groups,
- kv_cache_max_len=key_length,
+ num_layers=n, head_dim=head_dim, num_query_groups=num_query_groups
  )
  inputs = torch.zeros((n, 1, head_dim))
  lora = lora_utils.LoRA.zeros(rank=16, config=config)
@@ -111,20 +102,13 @@ class TestLora(googletest.TestCase):

  def test_lora_tflite_serialization(self):
  """Tests the serialization of the LoRA module."""
- config = self._get_test_config(
- num_layers=2,
- head_dim=8,
- num_query_groups=1,
- kv_cache_max_len=16,
- )
+ config = self._get_test_config(num_layers=2, head_dim=8, num_query_groups=1)
  lora = lora_utils.LoRA.random(rank=16, config=config)
  flatbuffer_model = lora.to_tflite()
  recovered_lora = lora_utils.LoRA.from_flatbuffers(flatbuffer_model)
  self.assertEqual(lora, recovered_lora)

- def _get_test_config(
- self, num_layers, head_dim, num_query_groups, kv_cache_max_len
- ):
+ def _get_test_config(self, num_layers, head_dim, num_query_groups):
  """Returns a test model config."""
  attn_config = cfg.AttentionConfig(
  num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
@@ -133,7 +117,6 @@ class TestLora(googletest.TestCase):
  attn_config=attn_config, ff_config=None
  )
  config = cfg.ModelConfig(
- kv_cache_max_len=kv_cache_max_len,
  embedding_dim=head_dim,
  block_configs=block_config,
  num_layers=num_layers,
@@ -47,7 +47,9 @@ class TestModelConversion(googletest.TestCase):
  tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
  [10], dtype=torch.int
  )
- kv = kv_cache.KVCache.from_model_config(config, kv_layout=kv_layout)
+ kv = kv_cache.KVCache.from_model_config(
+ kv_cache_max=config.max_seq_len, config=config, kv_layout=kv_layout
+ )
  kwargs = {
  "tokens": tokens,
  "input_pos": input_pos,
@@ -122,7 +124,9 @@ class TestModelConversion(googletest.TestCase):
  decode_token = torch.tensor([[1]], dtype=torch.int)
  decode_input_pos = torch.tensor([5], dtype=torch.int)

- kv = kv_cache.KVCache.from_model_config(config, kv_layout=kv_layout)
+ kv = kv_cache.KVCache.from_model_config(
+ kv_cache_max=128, config=config, kv_layout=kv_layout
+ )

  edge_model = (
  ai_edge_torch.signature(
@@ -177,12 +181,12 @@ class TestModelConversion(googletest.TestCase):

  def test_tiny_llama_multisig(self):
  config = tiny_llama.get_fake_model_config()
- pytorch_model = tiny_llama.TinyLlama(config).eval()
+ pytorch_model = tiny_llama.TinyLlama(config, mask_cache_size=128).eval()
  self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)

  def test_tiny_llama_multisig_kv_layout_transposed(self):
  config = tiny_llama.get_fake_model_config()
- pytorch_model = tiny_llama.TinyLlama(config).eval()
+ pytorch_model = tiny_llama.TinyLlama(config, mask_cache_size=128).eval()
  self._test_multisig_model(
  config,
  pytorch_model,
@@ -55,6 +55,7 @@ class TestModelConversion(googletest.TestCase):
  experimental_default_delegate_latest_features=True,
  )
  )
+ self._kv_cache_max = 128
  # Default cache_size_limit, 8 is hit and aborts often when the tests are
  # running all together. Doubles it to avoid abortion.
  torch._dynamo.config.cache_size_limit = 16
@@ -64,7 +65,7 @@ class TestModelConversion(googletest.TestCase):
  seq_len = 10
  tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
  input_pos = torch.arange(0, seq_len, dtype=torch.int)
- kv = kv_cache.KVCache.from_model_config(config)
+ kv = kv_cache.KVCache.from_model_config(self._kv_cache_max, config)

  edge_model = ai_edge_torch.signature(
  signature_name,
@@ -95,74 +96,77 @@ class TestModelConversion(googletest.TestCase):

  def test_gemma1(self):
  config = gemma1.get_fake_model_config()
- pytorch_model = gemma1.Gemma1(config).eval()
+ pytorch_model = gemma1.Gemma1(config, self._kv_cache_max).eval()
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

  def test_gemma2(self):
  config = gemma2.get_fake_model_config()
- pytorch_model = gemma2.Gemma2(config).eval()
+ pytorch_model = gemma2.Gemma2(config, self._kv_cache_max).eval()
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

  def test_llama(self):
  config = llama.get_fake_model_config()
- pytorch_model = llama.Llama(config).eval()
+ pytorch_model = llama.Llama(config, self._kv_cache_max).eval()
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

  def test_phi2(self):
  config = phi2.get_fake_model_config()
- pytorch_model = phi2.Phi2(config).eval()
+ pytorch_model = phi2.Phi2(config, self._kv_cache_max).eval()
  # Phi-2 logits are very big, so we need a larger absolute tolerance.
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

  def test_phi3(self):
  config = phi3.get_fake_model_config()
- pytorch_model = phi3.Phi3_5Mini(config).eval()
+ pytorch_model = phi3.Phi3_5Mini(config, self._kv_cache_max).eval()
  self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)

  def test_phi4(self):
  config = phi4.get_fake_model_config()
- pytorch_model = phi4.Phi4Mini(config).eval()
+ pytorch_model = phi4.Phi4Mini(config, self._kv_cache_max).eval()
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

  def test_smollm(self):
  config = smollm.get_fake_model_config()
- pytorch_model = smollm.SmolLM(config).eval()
+ pytorch_model = smollm.SmolLM(config, self._kv_cache_max).eval()
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

  def test_smollm2(self):
  config = smollm.get_fake_model_config_v2()
- pytorch_model = smollm.SmolLM2(config).eval()
+ pytorch_model = smollm.SmolLM2(config, self._kv_cache_max).eval()
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

  def test_openelm(self):
  config = openelm.get_fake_model_config()
- pytorch_model = openelm.OpenELM(config).eval()
+ pytorch_model = openelm.OpenELM(config, self._kv_cache_max).eval()
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

  def test_qwen(self):
  config = qwen.get_fake_model_config()
- pytorch_model = qwen.Qwen(config).eval()
+ pytorch_model = qwen.Qwen(config, self._kv_cache_max).eval()
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

  def test_deepseek(self):
  config = deepseek.get_fake_model_config()
- pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
+ pytorch_model = deepseek.DeepSeekDistillQwen(
+ config, self._kv_cache_max
+ ).eval()
  self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)

  def test_hammer(self):
  config = hammer.get_fake_model_config()
- pytorch_model = hammer.Hammer(config).eval()
+ pytorch_model = hammer.Hammer(config, self._kv_cache_max).eval()
  self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)

-
  def test_amd_llama_135m(self):
  config = amd_llama_135m.get_fake_model_config()
- pytorch_model = amd_llama_135m.AmdLlama(config).eval()
+ pytorch_model = amd_llama_135m.AmdLlama(config, self._kv_cache_max).eval()
  self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)

  def _test_paligemma_model(self, decoder_class, decoder_config, atol, rtol):
  config = paligemma.get_fake_model_config(decoder_config)
- pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()
+ pytorch_model = paligemma.PaliGemma(
+ config, decoder_class, mask_cache_size=self._kv_cache_max
+ ).eval()

  image_config = config.image_encoder_config.image_embedding
  num_patches = (image_config.image_size // image_config.patch_size) ** 2
@@ -171,7 +175,9 @@ class TestModelConversion(googletest.TestCase):
  seq_len = num_patches + 10
  tokens = torch.zeros((1, seq_len), dtype=torch.int)
  input_pos = torch.arange(0, seq_len, dtype=torch.int)
- kv = kv_cache.KVCache.from_model_config(config.decoder_config)
+ kv = kv_cache.KVCache.from_model_config(
+ self._kv_cache_max, config.decoder_config
+ )
  pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32)

  edge_model = ai_edge_torch.signature(
@@ -218,7 +224,7 @@ class TestModelConversion(googletest.TestCase):

  def test_qwen_vl_model(self):
  config = qwen_vl.get_fake_model_config()
- pytorch_model = qwen_vl.QwenVL(config).eval()
+ pytorch_model = qwen_vl.QwenVL(config, self._kv_cache_max).eval()

  grid_thw = pytorch_model.image_encoder.get_grid_thw()
  pixel_values_size = pytorch_model.image_encoder.get_pixel_values_size(
@@ -229,7 +235,9 @@ class TestModelConversion(googletest.TestCase):
  seq_len = pixel_values_size[0] + 10
  tokens = torch.zeros((1, seq_len), dtype=torch.int)
  input_pos = torch.arange(0, seq_len, dtype=torch.int)
- kv = kv_cache.KVCache.from_model_config(config.decoder_config)
+ kv = kv_cache.KVCache.from_model_config(
+ self._kv_cache_max, config.decoder_config
+ )
  pixel_values = torch.zeros(pixel_values_size, dtype=torch.float32)

  edge_model = ai_edge_torch.signature(