ai-edge-torch-nightly 0.6.0.dev20250602__py3-none-any.whl → 0.6.0.dev20250603__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +7 -15
  2. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -1
  3. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +3 -1
  4. ai_edge_torch/generative/examples/deepseek/deepseek.py +7 -15
  5. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -1
  6. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -1
  7. ai_edge_torch/generative/examples/gemma/gemma1.py +8 -16
  8. ai_edge_torch/generative/examples/gemma/gemma2.py +24 -24
  9. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -1
  10. ai_edge_torch/generative/examples/gemma3/decoder.py +34 -35
  11. ai_edge_torch/generative/examples/gemma3/gemma3.py +10 -8
  12. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +2 -1
  13. ai_edge_torch/generative/examples/hammer/hammer.py +23 -16
  14. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -2
  15. ai_edge_torch/generative/examples/llama/llama.py +13 -26
  16. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -1
  17. ai_edge_torch/generative/examples/openelm/openelm.py +8 -16
  18. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -1
  19. ai_edge_torch/generative/examples/paligemma/decoder.py +12 -17
  20. ai_edge_torch/generative/examples/paligemma/decoder2.py +12 -17
  21. ai_edge_torch/generative/examples/paligemma/paligemma.py +14 -9
  22. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -1
  23. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +2 -1
  24. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -1
  25. ai_edge_torch/generative/examples/phi/phi2.py +8 -16
  26. ai_edge_torch/generative/examples/phi/phi3.py +8 -16
  27. ai_edge_torch/generative/examples/phi/phi4.py +8 -16
  28. ai_edge_torch/generative/examples/phi/verify_util.py +1 -3
  29. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -1
  30. ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +2 -1
  31. ai_edge_torch/generative/examples/qwen/qwen.py +29 -34
  32. ai_edge_torch/generative/examples/qwen/qwen3.py +29 -35
  33. ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +2 -1
  34. ai_edge_torch/generative/examples/qwen_vl/decoder.py +11 -16
  35. ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +8 -12
  36. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +2 -2
  37. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +2 -1
  38. ai_edge_torch/generative/examples/smollm/smollm.py +15 -30
  39. ai_edge_torch/generative/examples/t5/t5.py +23 -23
  40. ai_edge_torch/generative/examples/t5/t5_attention.py +2 -2
  41. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -1
  42. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +7 -15
  43. ai_edge_torch/generative/layers/kv_cache.py +13 -1
  44. ai_edge_torch/generative/layers/model_config.py +0 -14
  45. ai_edge_torch/generative/test/test_kv_cache.py +14 -24
  46. ai_edge_torch/generative/test/test_lora.py +4 -21
  47. ai_edge_torch/generative/test/test_model_conversion.py +8 -4
  48. ai_edge_torch/generative/test/test_model_conversion_large.py +27 -19
  49. ai_edge_torch/generative/utilities/converter.py +15 -6
  50. ai_edge_torch/generative/utilities/model_builder.py +16 -6
  51. ai_edge_torch/generative/utilities/verifier.py +16 -6
  52. ai_edge_torch/version.py +1 -1
  53. {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/METADATA +1 -1
  54. {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/RECORD +57 -57
  55. {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/LICENSE +0 -0
  56. {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/WHEEL +0 -0
  57. {ai_edge_torch_nightly-0.6.0.dev20250602.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/top_level.txt +0 -0
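Most of the hunks below follow one pattern: kv_cache_max_len is removed from the model configs and builder functions, the builders instead take a mask_cache_size argument (fed from converter.get_mask_cache_size_from_flags()), and kv_cache_max_len is passed directly to converter.convert_to_tflite. A minimal sketch of the updated conversion flow, using the Llama example shown in the hunks below; the checkpoint_path flag and the absl entry-point boilerplate are assumptions, not part of this diff:

from absl import app

from ai_edge_torch.generative.examples.llama import llama
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities import export_config
from ai_edge_torch.generative.utilities import loader

flags = converter.define_conversion_flags('llama')


def main(_):
  # Assumed flag name; define_conversion_flags() is expected to register the
  # standard conversion flags used throughout these example scripts.
  checkpoint_path = flags.FLAGS.checkpoint_path
  pytorch_model = llama.build_1b_model(
      checkpoint_path,
      custom_loader=loader.maybe_get_custom_loader(
          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
      ),
      # New in this release: the mask cache size comes from the conversion
      # flags instead of being baked into the model config.
      mask_cache_size=converter.get_mask_cache_size_from_flags(),
  )
  converter.convert_to_tflite(
      pytorch_model,
      output_path=flags.FLAGS.output_path,
      output_name_prefix=flags.FLAGS.output_name_prefix,
      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
      # kv_cache_max_len is now a converter argument, not a model-config field.
      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
      quantize=flags.FLAGS.quantize,
      lora_ranks=flags.FLAGS.lora_ranks,
      export_config=export_config.get_from_flags(),
  )


if __name__ == '__main__':
  app.run(main)

The builder signatures change accordingly: build_1b_model(checkpoint_path, custom_loader, mask_cache_size) replaces the old **kwargs passthrough, so callers that previously forwarded kv_cache_max_len through the builder now pass it to the converter instead.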
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 from ai_edge_torch.generative.utilities import loader

-
 flags = converter.define_conversion_flags('llama')

 _MODEL_SIZE = flags.DEFINE_enum(
@@ -44,13 +43,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
@@ -93,22 +93,12 @@ class Llama(model_builder.DecoderOnlyModel):

   Llama 3.2 shares the same architecture as TinyLlama except ROPE calculation.
   """
+  pass

-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__(config)
-    attn_config = self.config.block_config(0).attn_config

+def get_1b_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Llama 3.2-1B model."""

-def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Llama 3.2-1B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a SmolLM model.
-  """
   attn_config = cfg.AttentionConfig(
       num_heads=32,
       head_dim=64,
@@ -147,7 +137,6 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=16,
       max_seq_len=max_seq_len,
       embedding_dim=2048,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       build_rope=build_rope,
@@ -155,9 +144,9 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config


-def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+def get_3b_model_config() -> cfg.ModelConfig:
   """Returns the model config for a Llama 3.2-3B model."""
-  config = get_1b_model_config(kv_cache_max_len)
+  config = get_1b_model_config()
   # Llama 3.2 has only one block config.
   attn_config = config.block_config(0).attn_config
   attn_config.num_heads = 24
@@ -167,8 +156,8 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config


-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_1b_model_config(**kwargs)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_1b_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   # SmolLM has only one block config.
@@ -180,6 +169,7 @@ def _build_model(
     checkpoint_path: str,
     config: cfg.ModelConfig,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    mask_cache_size: int = 0,
 ) -> torch.nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
@@ -187,28 +177,25 @@ def _build_model(
       tensor_names=TENSOR_NAMES,
       model_class=Llama,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )


 def build_1b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> torch.nn.Module:
   return _build_model(
-      checkpoint_path,
-      get_1b_model_config(**kwargs),
-      custom_loader=custom_loader,
+      checkpoint_path, get_1b_model_config(), custom_loader, mask_cache_size
   )


 def build_3b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> torch.nn.Module:
   return _build_model(
-      checkpoint_path,
-      get_3b_model_config(**kwargs),
-      custom_loader=custom_loader,
+      checkpoint_path, get_3b_model_config(), custom_loader, mask_cache_size
   )
@@ -31,13 +31,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
@@ -42,16 +42,8 @@ class OpenELM(model_builder.DecoderOnlyModel):
   pass


-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for an OpenELM model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for an OpenELM model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for an OpenELM model."""
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
   )
@@ -98,18 +90,17 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=num_layers,
       max_seq_len=2048,
       embedding_dim=3072,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
   )
   return config


-def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_model_config(kv_cache_max_len)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
+  config.max_seq_len = 256
   config.embedding_dim = 128
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
@@ -122,12 +113,13 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(**kwargs),
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=OpenELM,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
@@ -40,7 +40,7 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )

   config = pytorch_model.image_encoder.config.image_embedding
@@ -49,6 +49,7 @@ def main(_):
       output_path=flags.FLAGS.output_path,
       output_name_prefix=f'{flags.FLAGS.output_name_prefix}_{_VERSION.value}',
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       pixel_values_size=torch.Size(
           [1, config.channels, config.image_size, config.image_size]
       ),
@@ -73,8 +73,9 @@ class Decoder(model_builder.DecoderOnlyModel):
     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
     if mask is None:
+      assert kv_cache is not None, "KV cache must be provided."
       embeds_len = input_embeds.shape[1]
-      mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+      mask = torch.zeros(embeds_len, kv_cache.get_max_seq_len())
       mask[:, embeds_len:] = attn_config.causal_mask_value

     return self._forward_with_embeds(
@@ -87,16 +88,8 @@ class Decoder(model_builder.DecoderOnlyModel):
     )


-def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for the decoder of a PaliGemma 3B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for the decoder of a PaliGemma 3B model.
-  """
+def get_decoder_config() -> cfg.ModelConfig:
+  """Returns the model config for the decoder of a PaliGemma 3B model."""
   attn_config = cfg.AttentionConfig(
       num_heads=8,
       head_dim=256,
@@ -125,7 +118,6 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       max_seq_len=8192,
       embedding_dim=embedding_dim,
       embedding_scale=embedding_dim**0.5,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
@@ -133,22 +125,25 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config


-def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_decoder_config(kv_cache_max_len)
+def get_fake_decoder_config() -> cfg.ModelConfig:
+  config = get_decoder_config()
   # PaliGemma decoder has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
+  config.max_seq_len = 256
   config.embedding_dim = 128
   config.embedding_scale = 128**0.5
   return config


-def build_decoder(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+def build_decoder(
+    checkpoint_path: str, mask_cache_size: int = 0
+) -> torch.nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_decoder_config(**kwargs),
+      config=get_decoder_config(),
       tensor_names=TENSOR_NAMES,
       model_class=Decoder,
+      mask_cache_size=mask_cache_size,
   )
@@ -73,8 +73,9 @@ class Decoder2(gemma2.Gemma2):

     if mask is None:
       # By default, don't mask image embeds with a diagonal causal mask.
+      assert kv_cache is not None, "KV cache must be provided."
       embeds_len = input_embeds.shape[1]
-      mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+      mask = torch.zeros(embeds_len, kv_cache.get_max_seq_len())
       mask[:, embeds_len:] = attn_config.causal_mask_value

     return self._forward_with_embeds(
@@ -82,16 +83,8 @@ class Decoder2(gemma2.Gemma2):
     )


-def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for the decoder of a PaliGemma 3B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for the decoder of a PaliGemma 3B model.
-  """
+def get_decoder2_config() -> cfg.ModelConfig:
+  """Returns the model config for the decoder of a PaliGemma 3B model."""
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
@@ -133,7 +126,6 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       max_seq_len=8192,
       embedding_dim=embedding_dim,
       embedding_scale=embedding_dim**0.5,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
@@ -142,22 +134,25 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config


-def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_decoder2_config(kv_cache_max_len)
+def get_fake_decoder2_config() -> cfg.ModelConfig:
+  config = get_decoder2_config()
   # PaliGemma2 decoder has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
+  config.max_seq_len = 256
   config.embedding_dim = 128
   config.embedding_scale = 128**0.5
   return config


-def build_decoder2(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+def build_decoder2(
+    checkpoint_path: str, mask_cache_size: int = 0
+) -> torch.nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_decoder2_config(**kwargs),
+      config=get_decoder2_config(),
       tensor_names=TENSOR_NAMES,
       model_class=Decoder2,
+      mask_cache_size=mask_cache_size,
   )
@@ -45,7 +45,12 @@ class PaliGemmaConfig:
 class PaliGemma(nn.Module):
   """PaliGemma model from the Edge Generative API."""

-  def __init__(self, config: PaliGemmaConfig, decoder_class: nn.Module):
+  def __init__(
+      self,
+      config: PaliGemmaConfig,
+      decoder_class: nn.Module,
+      mask_cache_size: int = 0,
+  ):
     super().__init__()

     self.image_encoder = image_encoder.SiglipVisionEncoder(
@@ -56,7 +61,7 @@ class PaliGemma(nn.Module):
         config.decoder_config.embedding_dim,
         bias=config.image_projection_use_bias,
     )
-    self.decoder = decoder_class(config.decoder_config)
+    self.decoder = decoder_class(config.decoder_config, mask_cache_size)
     image_embedding_config = config.image_encoder_config.image_embedding
     self.num_patches = (
         image_embedding_config.image_size // image_embedding_config.patch_size
@@ -116,7 +121,7 @@ class PaliGemma(nn.Module):
     )


-def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
+def get_model_config(get_decoder_config) -> PaliGemmaConfig:
   """Returns the model config for a PaliGemma 3B-224 model.

   Returns:
@@ -124,16 +129,16 @@ def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   """
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_image_encoder_config(),
-      decoder_config=get_decoder_config(**kwargs),
+      decoder_config=get_decoder_config(),
       image_token_id=257152,
       image_projection_use_bias=True,
   )


-def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
+def get_fake_model_config(get_decoder_config) -> PaliGemmaConfig:
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
-      decoder_config=get_decoder_config(**kwargs),
+      decoder_config=get_decoder_config(),
       image_token_id=127,
       image_projection_use_bias=True,
   )
@@ -143,7 +148,7 @@ def build_model(
     checkpoint_path: str,
     version: int = 2,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs,
+    mask_cache_size: int = 0,
 ) -> PaliGemma:
   if version == 1:
     decoder_class = decoder.Decoder
@@ -154,8 +159,8 @@ def build_model(
     decoder_tensor_names = decoder2.TENSOR_NAMES
     get_decoder_config = decoder2.get_decoder2_config

-  config = get_model_config(get_decoder_config, **kwargs)
-  model = PaliGemma(config, decoder_class)
+  config = get_model_config(get_decoder_config)
+  model = PaliGemma(config, decoder_class, mask_cache_size)
   # Load the parameters of image encoder.
   loader = loading_utils.ModelLoader(
       checkpoint_path, image_encoder.TENSOR_NAMES, custom_loader
@@ -31,13 +31,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
@@ -31,13 +31,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
@@ -32,13 +32,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
@@ -41,16 +41,8 @@ class Phi2(model_builder.DecoderOnlyModel):
   pass


-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Phi-2 model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Phi-2 model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Phi-2 model."""
   attn_config = cfg.AttentionConfig(
       num_heads=32,
       head_dim=80,
@@ -77,7 +69,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       vocab_size=51200,
       num_layers=32,
       max_seq_len=2048,
-      kv_cache_max_len=kv_cache_max_len,
       embedding_dim=2560,
       block_configs=block_config,
       final_norm_config=norm_config,
@@ -87,11 +78,11 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config


-def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_model_config(kv_cache_max_len)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
+  config.max_seq_len = 256
   # Phi-2 has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   return config
@@ -100,12 +91,13 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(**kwargs),
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=Phi2,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
@@ -139,16 +139,8 @@ class Phi3_5Mini(model_builder.DecoderOnlyModel):
   pass


-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Phi-3.5 model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Phi-3.5 model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Phi-3.5 model."""
   attn_config = cfg.AttentionConfig(
       num_heads=32,
       head_dim=96,
@@ -185,7 +177,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       vocab_size=32064,
       num_layers=32,
       max_seq_len=max_seq_len,
-      kv_cache_max_len=kv_cache_max_len,
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
@@ -195,11 +186,11 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config


-def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_model_config(kv_cache_max_len)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
+  config.max_seq_len = 256
   # Phi-3.5 has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   return config
@@ -208,13 +199,14 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> torch.nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(**kwargs),
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=Phi3_5Mini,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
@@ -89,16 +89,8 @@ class Phi4Mini(model_builder.DecoderOnlyModel):
   pass


-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Phi-4 model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Phi-4 model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Phi-4 model."""
   attn_config = cfg.AttentionConfig(
       num_heads=24,
       head_dim=128,
@@ -135,7 +127,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       vocab_size=200064,
       num_layers=32,
       max_seq_len=max_seq_len,
-      kv_cache_max_len=kv_cache_max_len,
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
@@ -144,11 +135,11 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config


-def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_model_config(kv_cache_max_len)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
+  config.max_seq_len = 256
   # Phi-4 has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   return config
@@ -157,13 +148,14 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> torch.nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(**kwargs),
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=Phi4Mini,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
@@ -15,7 +15,6 @@
 """Utils for verifying the Phi model."""

 import logging
-import os
 import pathlib
 from typing import Callable, Dict

@@ -39,7 +38,6 @@ _BUILDER = {
 def verify_phi(
     version: str,
     checkpoint_dir: str,
-    weight_filename: str = "model.safetensors",
     max_new_tokens: int = 30,
     prompts: list[str] | None = None,
     atol: float = 1e-04,
@@ -63,7 +61,7 @@ def verify_phi(
     )
     reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   else:
-    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+    reauthored_checkpoint = checkpoint_dir

   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = _BUILDER[version](