ai-edge-torch-nightly 0.6.0.dev20250602__py3-none-any.whl → 0.6.0.dev20250603__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (57)
  1. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +7 -15
  2. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -1
  3. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +3 -1
  4. ai_edge_torch/generative/examples/deepseek/deepseek.py +7 -15
  5. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -1
  6. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -1
  7. ai_edge_torch/generative/examples/gemma/gemma1.py +8 -16
  8. ai_edge_torch/generative/examples/gemma/gemma2.py +24 -24
  9. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -1
  10. ai_edge_torch/generative/examples/gemma3/decoder.py +34 -35
  11. ai_edge_torch/generative/examples/gemma3/gemma3.py +10 -8
  12. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +2 -1
  13. ai_edge_torch/generative/examples/hammer/hammer.py +23 -16
  14. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -2
  15. ai_edge_torch/generative/examples/llama/llama.py +13 -26
  16. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -1
  17. ai_edge_torch/generative/examples/openelm/openelm.py +8 -16
  18. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -1
  19. ai_edge_torch/generative/examples/paligemma/decoder.py +12 -17
  20. ai_edge_torch/generative/examples/paligemma/decoder2.py +12 -17
  21. ai_edge_torch/generative/examples/paligemma/paligemma.py +14 -9
  22. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -1
  23. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +2 -1
  24. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -1
  25. ai_edge_torch/generative/examples/phi/phi2.py +8 -16
  26. ai_edge_torch/generative/examples/phi/phi3.py +8 -16
  27. ai_edge_torch/generative/examples/phi/phi4.py +8 -16
  28. ai_edge_torch/generative/examples/phi/verify_util.py +1 -3
  29. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -1
  30. ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +2 -1
  31. ai_edge_torch/generative/examples/qwen/qwen.py +29 -34
  32. ai_edge_torch/generative/examples/qwen/qwen3.py +29 -35
  33. ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +2 -1
  34. ai_edge_torch/generative/examples/qwen_vl/decoder.py +11 -16
  35. ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +8 -12
  36. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +2 -2
  37. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +2 -1
  38. ai_edge_torch/generative/examples/smollm/smollm.py +15 -30
  39. ai_edge_torch/generative/examples/t5/t5.py +23 -23
  40. ai_edge_torch/generative/examples/t5/t5_attention.py +2 -2
  41. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -1
  42. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +7 -15
  43. ai_edge_torch/generative/layers/kv_cache.py +13 -1
  44. ai_edge_torch/generative/layers/model_config.py +0 -14
  45. ai_edge_torch/generative/test/test_kv_cache.py +14 -24
  46. ai_edge_torch/generative/test/test_lora.py +4 -21
  47. ai_edge_torch/generative/test/test_model_conversion.py +8 -4
  48. ai_edge_torch/generative/test/test_model_conversion_large.py +27 -19
  49. ai_edge_torch/generative/utilities/converter.py +15 -6
  50. ai_edge_torch/generative/utilities/model_builder.py +16 -6
  51. ai_edge_torch/generative/utilities/verifier.py +16 -6
  52. ai_edge_torch/version.py +1 -1
  53. {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/METADATA +1 -1
  54. {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/RECORD +57 -57
  55. {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/LICENSE +0 -0
  56. {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/WHEEL +0 -0
  57. {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/top_level.txt +0 -0
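
The recurring change across these example files: kv_cache_max_len is no longer a model-config or build-time argument; build_model() now takes mask_cache_size (wired from converter.get_mask_cache_size_from_flags() in the conversion scripts), and the KV cache length is passed to converter.convert_to_tflite() at export time. Below is a minimal sketch of a direct, non-flag-driven call under the new API; the checkpoint and output paths and the 1280 sizes are placeholders, and the remaining keyword arguments (quantize, lora_ranks, export_config) are assumed to keep their defaults.

# Illustrative sketch only, not taken verbatim from the package; paths and
# sizes are placeholders.
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
from ai_edge_torch.generative.utilities import converter

# build_model() now takes mask_cache_size instead of kv_cache_max_len.
pytorch_model = amd_llama_135m.build_model(
    checkpoint_path="/tmp/amd_llama_135m",  # placeholder checkpoint directory
    mask_cache_size=1280,  # size of the prebuilt attention mask cache
)

# The KV cache length is now an export-time argument of convert_to_tflite().
converter.convert_to_tflite(
    pytorch_model,
    output_path="/tmp/",  # placeholder output directory
    output_name_prefix="amd_llama_135m",
    prefill_seq_len=[256, 1024],
    kv_cache_max_len=1280,
)
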
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py
@@ -29,16 +29,8 @@ class AmdLlama(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for an AMD-Llama-135m model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for an AMD-Llama-135m model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for an AMD-Llama-135m model."""
   attn_config = cfg.AttentionConfig(
       num_heads=12,
       head_dim=64,
@@ -63,7 +55,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=12,
       max_seq_len=2048,
       embedding_dim=768,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
@@ -71,8 +62,8 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_model_config(**kwargs)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   config.block_config(0).ff_config.intermediate_size = 64
@@ -82,12 +73,13 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(**kwargs),
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=AmdLlama,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py
@@ -31,13 +31,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py
@@ -23,6 +23,7 @@ from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('deepseek')
 
+
 def main(_):
   checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = deepseek.build_model(
@@ -30,13 +31,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
ai_edge_torch/generative/examples/deepseek/deepseek.py
@@ -29,16 +29,8 @@ class DeepSeekDistillQwen(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Qwen 2.5 3B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a SmolLM model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 2.5 3B model."""
   attn_config = cfg.AttentionConfig(
       num_heads=12,
       head_dim=128,
@@ -66,7 +58,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=28,
       max_seq_len=4096,
       embedding_dim=1536,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
@@ -74,8 +65,8 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_model_config(**kwargs)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   # DeepSeek-R1-Distill-Qwen has only one block config.
@@ -86,12 +77,13 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(**kwargs),
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=DeepSeekDistillQwen,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py
@@ -31,13 +31,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
@@ -33,13 +33,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
ai_edge_torch/generative/examples/gemma/gemma1.py
@@ -42,16 +42,8 @@ class Gemma1(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Gemma 2B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Gemma 2B model.
-  """
+def get_model_config_2b() -> cfg.ModelConfig:
+  """Returns the model config for a Gemma 2B model."""
   attn_config = cfg.AttentionConfig(
       num_heads=8,
       head_dim=256,
@@ -80,7 +72,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       max_seq_len=8192,
       embedding_dim=embedding_dim,
       embedding_scale=embedding_dim**0.5,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
@@ -88,25 +79,26 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_model_config_2b(kv_cache_max_len)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config_2b()
   # Gemma has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
+  config.max_seq_len = 256
   return config
 
 
 def build_2b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config_2b(**kwargs),
+      config=get_model_config_2b(),
       tensor_names=TENSOR_NAMES,
       model_class=Gemma1,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
ai_edge_torch/generative/examples/gemma/gemma2.py
@@ -104,7 +104,7 @@ class Gemma2Block(attention.TransformerBlock):
 class Gemma2(nn.Module):
   """A Gemma2 model built from the Edge Generative API layers."""
 
-  def __init__(self, config: cfg.ModelConfig):
+  def __init__(self, config: cfg.ModelConfig, mask_cache_size: int = 0):
     super().__init__()
 
     # Construct model layers.
@@ -126,17 +126,24 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
+    self.config = config
+    self.build_mask_cache(mask_cache_size)
+
+  def build_mask_cache(self, mask_cache_size: int):
+    assert (
+        mask_cache_size <= self.config.max_seq_len
+    ), "Mask cache size must be less than or equal to the max seq length."
+    if mask_cache_size <= 0:
+      self.mask_cache = None
+      self.sliding_window_mask_cache = None
+      return
+    self.mask_cache = attn_utils.build_causal_mask_cache(mask_cache_size)
     # Gemma2 has same hyper parameters for each layer except for attention
     # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
-        size=config.kv_cache_max,
-        window_size=attn_config.sliding_window_size,
+        size=mask_cache_size,
+        window_size=self.config.block_config(0).attn_config.sliding_window_size,
     )
-    self.config = config
 
   def get_attention_mask(
       self, attn_type: cfg.AttentionType, input_pos: torch.Tensor
@@ -167,6 +174,7 @@ class Gemma2(nn.Module):
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
     rope = rotary_pos_emb.build_rope(input_pos, n_elem, attn_config.rotary_base)
     if mask is None:
+      assert self.mask_cache is not None, "Mask cache must be built."
       mask = [
           self.get_attention_mask(
               self.config.block_config(i).attn_config.attn_type, input_pos
@@ -222,16 +230,8 @@ class Gemma2(nn.Module):
     return {"logits": res, "kv_cache": updated_kv_cache}
 
 
-def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Gemma2 2B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Gemma 2B model.
-  """
+def get_model_config_2b() -> cfg.ModelConfig:
+  """Returns the model config for a Gemma2 2B model."""
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
@@ -277,7 +277,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       max_seq_len=8192,
       embedding_dim=embedding_dim,
       embedding_scale=embedding_dim**0.5,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
@@ -286,11 +285,11 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_model_config_2b(kv_cache_max_len)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config_2b()
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
+  config.max_seq_len = 256
   config.embedding_dim = 128
   config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
@@ -305,16 +304,17 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 def build_2b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs,
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   for tensor_names in TENSOR_NAMES_DICT.values():
     try:
       return model_builder.build_decoder_only_model(
           checkpoint_path=checkpoint_path,
-          config=get_model_config_2b(**kwargs),
+          config=get_model_config_2b(),
           tensor_names=tensor_names,
           model_class=Gemma2,
          custom_loader=custom_loader,
+          mask_cache_size=mask_cache_size,
      )
    except KeyError as _:
      continue
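
Since Gemma2 now builds its mask caches from an explicit mask_cache_size (and skips them entirely when the size is 0 or negative), the following standalone PyTorch sketch shows what a causal and a sliding-window mask cache contain and how the forward pass indexes into them. The helper names ending in _sketch are illustrative stand-ins, not the library's attn_utils implementation.

import torch

def build_causal_mask_cache_sketch(size: int, mask_value: float = float("-inf")):
  # 0.0 on and below the diagonal (visible), mask_value strictly above (future).
  mask = torch.triu(torch.full((size, size), mask_value), diagonal=1)
  return mask.unsqueeze(0).unsqueeze(0)  # shape (1, 1, size, size)

def build_sliding_window_mask_cache_sketch(
    size: int, window_size: int, mask_value: float = float("-inf")
):
  # Start from the causal mask, then also hide keys that sit more than
  # window_size positions behind the query.
  mask = build_causal_mask_cache_sketch(size, mask_value)
  too_old = torch.tril(
      torch.ones(size, size, dtype=torch.bool), diagonal=-window_size
  )
  return mask.masked_fill(too_old, mask_value)

# Usage mirroring the pattern in the diff: rows are selected with input_pos
# and columns are truncated to the KV cache length.
cache = build_causal_mask_cache_sketch(size=64)
input_pos = torch.arange(8)
mask = cache.index_select(2, input_pos)[:, :, :, :32]
print(mask.shape)  # torch.Size([1, 1, 8, 32])
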
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py
@@ -40,7 +40,7 @@ def main(_):
         custom_loader=loader.maybe_get_custom_loader(
             checkpoint_path, flags.FLAGS.custom_checkpoint_loader
         ),
-        kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+        mask_cache_size=converter.get_mask_cache_size_from_flags(),
     )
   else:
     raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
@@ -50,6 +50,7 @@ def main(_):
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
ai_edge_torch/generative/examples/gemma3/decoder.py
@@ -74,6 +74,7 @@ TENSOR_NAMES_DICT = {
 
 
 class DecoderBlock(attention.TransformerBlock):
+  """A Gemma3 decoder block built from the Edge Generative API layers."""
 
   def forward(
       self,
@@ -111,7 +112,7 @@ class DecoderBlock(attention.TransformerBlock):
 class Decoder(nn.Module):
   """A Gemma3 decoder model built from the Edge Generative API layers."""
 
-  def __init__(self, config: cfg.ModelConfig):
+  def __init__(self, config: cfg.ModelConfig, mask_cache_size: int = 0):
     super().__init__()
 
     # Construct model layers.
@@ -130,10 +131,17 @@ class Decoder(nn.Module):
     self.final_norm = builder.build_norm(
         config.embedding_dim, config.final_norm_config
     )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
     self.config = config
+    self.build_mask_cache(mask_cache_size)
+
+  def build_mask_cache(self, mask_cache_size: int):
+    assert (
+        mask_cache_size <= self.config.max_seq_len
+    ), "Mask cache size must be less than or equal to the max seq length."
+    if mask_cache_size <= 0:
+      self.mask_cache = None
+    else:
+      self.mask_cache = attn_utils.build_causal_mask_cache(mask_cache_size)
 
   def get_local_global_attention_mask(
       self,
@@ -205,9 +213,8 @@ class Decoder(nn.Module):
     mask = torch.where(mask, 0, self.config.causal_mask_value)
     return mask
 
-  def build_pixel_mask(self, image_indices: torch.Tensor):
+  def build_pixel_mask(self, image_indices: torch.Tensor, max_seq_len: int):
     pixel_mask = image_indices >= 0
-    max_seq_len = self.config.kv_cache_max
     if pixel_mask.size(1) < max_seq_len:
       pixel_mask = torch.cat(
           [
@@ -234,14 +241,12 @@ class Decoder(nn.Module):
       image_indices: Optional[torch.Tensor] = None,
       export_config: Optional[export_cfg.ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    pixel_mask = None
     if input_embeds is None:
       # token embeddings of shape (b, t, n_embd)
       input_embeds = self.tok_embedding(tokens)
       if self.config.embedding_scale is not None:
         input_embeds = input_embeds * self.config.embedding_scale
-      if image_indices is not None:
-        pixel_mask = self.build_pixel_mask(image_indices)
+
     # RoPE parameters are the same for all blocks. Use the first layer.
     attn_config = self.config.block_config(0).attn_config
     # Different rotary base for global and local attention
@@ -254,9 +259,19 @@ class Decoder(nn.Module):
         )
         for i in range(self.config.num_layers)
     ]
+
     if mask is None:
+      assert self.mask_cache is not None, "Mask cache must be built."
+      assert kv_cache is not None, "KV cache must be provided."
+      kv_cache_max_len = kv_cache.get_max_seq_len()
       mask = self.mask_cache.index_select(2, input_pos)
-      mask = mask[:, :, :, : self.config.kv_cache_max]
+      mask = mask[:, :, :, :kv_cache_max_len]
+    else:
+      kv_cache_max_len = mask.size(3)
+
+    pixel_mask = None
+    if image_indices is not None:
+      pixel_mask = self.build_pixel_mask(image_indices, kv_cache_max_len)
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, pixel_mask, export_config
@@ -322,16 +337,8 @@ class Decoder(nn.Module):
     return {"logits": res, "kv_cache": updated_kv_cache}
 
 
-def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
-  """Returns the model config for a Gemma3 1B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 2048.
-
-  Returns:
-    The model config for a Gemma 1B model.
-  """
+def get_decoder_config_1b() -> cfg.ModelConfig:
+  """Returns the model config for a Gemma3 1B model."""
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True,
   )
@@ -376,7 +383,6 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
       max_seq_len=32_768,
      embedding_dim=embedding_dim,
      embedding_scale=embedding_dim**0.5,
-      kv_cache_max_len=kv_cache_max_len,
      block_configs=[get_block_config(i) for i in range(num_layers)],
      final_norm_config=norm_config,
      lm_head_use_bias=False,
@@ -385,20 +391,12 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_decoder_config_1b(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  """Returns a fake model config for a Gemma3 1B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 128.
-
-  Returns:
-    A fake model config for a Gemma 1B model.
-  """
-  config = get_decoder_config_1b(kv_cache_max_len)
+def get_fake_decoder_config_1b() -> cfg.ModelConfig:
+  """Returns a fake model config for a Gemma3 1B model."""
+  config = get_decoder_config_1b()
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
+  config.max_seq_len = 256
   config.embedding_dim = 128
   config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
@@ -413,7 +411,7 @@ def get_fake_decoder_config_1b(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 def build_model_1b(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs,
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   # TODO(b/403644647): Better error handling for loading checkpoints with
   # different tensor names.
@@ -421,10 +419,11 @@ def build_model_1b(
     try:
       return model_builder.build_decoder_only_model(
           checkpoint_path=checkpoint_path,
-          config=get_decoder_config_1b(**kwargs),
+          config=get_decoder_config_1b(),
          tensor_names=tensor_names,
          model_class=Decoder,
          custom_loader=custom_loader,
+          mask_cache_size=mask_cache_size,
      )
    except KeyError as ke:
      continue
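
With kv_cache_max removed from the config, the Gemma3 decoder now derives the cache length from kv_cache.get_max_seq_len() (or from mask.size(3) when a mask is supplied) and passes it into build_pixel_mask. Below is a rough sketch of that padding step in plain PyTorch; it is illustrative only, and padding with False is an assumption, since the padded values are not visible in this diff.

import torch

def build_pixel_mask_sketch(image_indices: torch.Tensor, max_seq_len: int) -> torch.Tensor:
  # image_indices: (batch, tokens); entries >= 0 mark image-token positions.
  pixel_mask = image_indices >= 0
  if pixel_mask.size(1) < max_seq_len:
    # Right-pad so the mask spans the full KV cache length (assumed False padding).
    pad = torch.zeros(
        pixel_mask.size(0), max_seq_len - pixel_mask.size(1), dtype=torch.bool
    )
    pixel_mask = torch.cat([pixel_mask, pad], dim=1)
  return pixel_mask

image_indices = torch.tensor([[0, 1, -1, -1]])  # toy input: two image tokens
print(build_pixel_mask_sketch(image_indices, max_seq_len=8))
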
ai_edge_torch/generative/examples/gemma3/gemma3.py
@@ -48,13 +48,13 @@ class Gemma3MMConfig:
 class Gemma3MM(nn.Module):
   """A Gemma3 multimodal model built from the Edge Generative API layers."""
 
-  def __init__(self, config: Gemma3MMConfig):
+  def __init__(self, config: Gemma3MMConfig, mask_cache_size: int = 0):
     super().__init__()
 
     self.image_encoder = image_encoder.SiglipVisionEncoderWithExit(
         config.image_encoder_config
     )
-    self.decoder = decoder.Decoder(config.decoder_config)
+    self.decoder = decoder.Decoder(config.decoder_config, mask_cache_size)
     self.mm_norm = builder.build_norm(
         config.image_encoder_config.embedding_dim,
         config.mm_norm_config,
@@ -150,10 +150,10 @@ class Gemma3MM(nn.Module):
     )
 
 
-def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
+def get_fake_model_config() -> Gemma3MMConfig:
   return Gemma3MMConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
-      decoder_config=decoder.get_fake_decoder_config_1b(**kwargs),
+      decoder_config=decoder.get_fake_decoder_config_1b(),
       image_token_id=127,
       image_projection_scale=128**0.5,
       image_projection_use_bias=False,
@@ -167,13 +167,15 @@ def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
 def build_model_1b(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs,
+    mask_cache_size: int = 0,
 ) -> decoder.Decoder:
   if checkpoint_path:
-    model = decoder.build_model_1b(checkpoint_path, custom_loader, **kwargs)
+    model = decoder.build_model_1b(
+        checkpoint_path, custom_loader, mask_cache_size
+    )
   else:
-    config = decoder.get_decoder_config_1b(**kwargs)
-    model = decoder.Decoder(config)
+    config = decoder.get_decoder_config_1b()
+    model = decoder.Decoder(config, mask_cache_size)
   # TODO: Load the parameters of decoder from checkpoint.
   model.eval()
   return model
ai_edge_torch/generative/examples/hammer/convert_to_tflite.py
@@ -43,13 +43,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
ai_edge_torch/generative/examples/hammer/hammer.py
@@ -29,7 +29,7 @@ class Hammer(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+def get_1_5b_model_config() -> cfg.ModelConfig:
   """Returns the model config for a Hammer 2.1 1.5B model."""
   attn_config = cfg.AttentionConfig(
       num_heads=12,
@@ -58,16 +58,15 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=28,
       max_seq_len=32768,
       embedding_dim=1536,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
   )
   return config
 
 
-def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+def get_0_5b_model_config() -> cfg.ModelConfig:
   """Returns the model config for a Hammer 2.1 0.5B model."""
-  config = get_1_5b_model_config(kv_cache_max_len)
+  config = get_1_5b_model_config()
   # Hammer has only one block config.
   block_config = config.block_config(0)
   block_config.attn_config.num_heads = 14
@@ -78,8 +77,8 @@ def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_1_5b_model_config(**kwargs)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_1_5b_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   config.embedding_dim = 16
@@ -88,29 +87,37 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_1_5b_model(
+def _build_model(
     checkpoint_path: str,
+    config: cfg.ModelConfig,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_1_5b_model_config(**kwargs),
+      config=config,
       tensor_names=TENSOR_NAMES,
       model_class=Hammer,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
+  )
+
+
+def build_1_5b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    mask_cache_size: int = 0,
+) -> nn.Module:
+  return _build_model(
+      checkpoint_path, get_1_5b_model_config(), custom_loader, mask_cache_size
   )
 
 
 def build_0_5b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_0_5b_model_config(**kwargs),
-      tensor_names=TENSOR_NAMES,
-      model_class=Hammer,
-      custom_loader=custom_loader,
+  return _build_model(
+      checkpoint_path, get_0_5b_model_config(), custom_loader, mask_cache_size
   )
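
The Hammer example now routes both public builders through a shared private _build_model() helper. A short usage sketch under the new signatures; the checkpoint paths are placeholders.

from ai_edge_torch.generative.examples.hammer import hammer

# Both builders accept mask_cache_size directly instead of forwarding
# **kwargs into the model config.
model_1_5b = hammer.build_1_5b_model(
    checkpoint_path="/tmp/hammer_2.1_1.5b",  # placeholder path
    mask_cache_size=1024,
)
model_0_5b = hammer.build_0_5b_model(
    checkpoint_path="/tmp/hammer_2.1_0.5b",  # placeholder path
    mask_cache_size=1024,
)
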