ai-edge-torch-nightly 0.6.0.dev20250601__py3-none-any.whl → 0.6.0.dev20250603__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +7 -15
  2. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -1
  3. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +3 -1
  4. ai_edge_torch/generative/examples/deepseek/deepseek.py +7 -15
  5. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -1
  6. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -1
  7. ai_edge_torch/generative/examples/gemma/gemma1.py +8 -16
  8. ai_edge_torch/generative/examples/gemma/gemma2.py +24 -24
  9. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -1
  10. ai_edge_torch/generative/examples/gemma3/decoder.py +34 -35
  11. ai_edge_torch/generative/examples/gemma3/gemma3.py +10 -8
  12. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +2 -1
  13. ai_edge_torch/generative/examples/hammer/hammer.py +23 -16
  14. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -2
  15. ai_edge_torch/generative/examples/llama/llama.py +13 -26
  16. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -1
  17. ai_edge_torch/generative/examples/openelm/openelm.py +8 -16
  18. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -1
  19. ai_edge_torch/generative/examples/paligemma/decoder.py +12 -17
  20. ai_edge_torch/generative/examples/paligemma/decoder2.py +12 -17
  21. ai_edge_torch/generative/examples/paligemma/paligemma.py +14 -9
  22. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -1
  23. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +2 -1
  24. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -1
  25. ai_edge_torch/generative/examples/phi/phi2.py +8 -16
  26. ai_edge_torch/generative/examples/phi/phi3.py +8 -16
  27. ai_edge_torch/generative/examples/phi/phi4.py +8 -16
  28. ai_edge_torch/generative/examples/phi/verify_util.py +1 -3
  29. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -1
  30. ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py +2 -1
  31. ai_edge_torch/generative/examples/qwen/qwen.py +29 -34
  32. ai_edge_torch/generative/examples/qwen/qwen3.py +29 -35
  33. ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +2 -1
  34. ai_edge_torch/generative/examples/qwen_vl/decoder.py +11 -16
  35. ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +8 -12
  36. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +2 -2
  37. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +2 -1
  38. ai_edge_torch/generative/examples/smollm/smollm.py +15 -30
  39. ai_edge_torch/generative/examples/t5/t5.py +23 -23
  40. ai_edge_torch/generative/examples/t5/t5_attention.py +2 -2
  41. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -1
  42. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +7 -15
  43. ai_edge_torch/generative/layers/kv_cache.py +13 -1
  44. ai_edge_torch/generative/layers/model_config.py +0 -14
  45. ai_edge_torch/generative/test/test_kv_cache.py +14 -24
  46. ai_edge_torch/generative/test/test_lora.py +4 -21
  47. ai_edge_torch/generative/test/test_model_conversion.py +8 -4
  48. ai_edge_torch/generative/test/test_model_conversion_large.py +27 -19
  49. ai_edge_torch/generative/utilities/converter.py +15 -6
  50. ai_edge_torch/generative/utilities/model_builder.py +16 -6
  51. ai_edge_torch/generative/utilities/verifier.py +16 -6
  52. ai_edge_torch/version.py +1 -1
  53. {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/METADATA +1 -1
  54. {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/RECORD +57 -57
  55. {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/LICENSE +0 -0
  56. {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/WHEEL +0 -0
  57. {ai_edge_torch_nightly-0.6.0.dev20250601.dist-info → ai_edge_torch_nightly-0.6.0.dev20250603.dist-info}/top_level.txt +0 -0
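Every hunk below applies the same refactoring: kv_cache_max_len moves out of the model configs and builder functions and becomes an argument to converter.convert_to_tflite(), while the builders gain a mask_cache_size parameter fed by the new converter.get_mask_cache_size_from_flags() helper. A minimal sketch of the resulting conversion-script shape, assuming the flag names shown in the hunks; "example" is a hypothetical stand-in for any of the example modules:

from absl import flags
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities import loader

checkpoint_path = "/path/to/checkpoint"  # placeholder


def main(_):
  # "example" is a hypothetical placeholder module, not a real import.
  pytorch_model = example.build_model(
      checkpoint_path,
      custom_loader=loader.maybe_get_custom_loader(
          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
      ),
      # Build time: size the attention mask cache.
      mask_cache_size=converter.get_mask_cache_size_from_flags(),
  )
  converter.convert_to_tflite(
      pytorch_model,
      output_path=flags.FLAGS.output_path,
      output_name_prefix=flags.FLAGS.output_name_prefix,
      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
      # Conversion time: fix the exported KV cache length.
      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
      quantize=flags.FLAGS.quantize,
  )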
--- a/ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
@@ -44,13 +44,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
--- a/ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py
+++ b/ai_edge_torch/generative/examples/qwen/convert_v3_to_tflite.py
@@ -44,13 +44,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
--- a/ai_edge_torch/generative/examples/qwen/qwen.py
+++ b/ai_edge_torch/generative/examples/qwen/qwen.py
@@ -29,16 +29,8 @@ class Qwen(model_builder.DecoderOnlyModel):
   pass


-def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Qwen 2.5 3B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a SmolLM model.
-  """
+def get_3b_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 2.5 3B model."""
   attn_config = cfg.AttentionConfig(
       num_heads=16,
       head_dim=128,
@@ -66,16 +58,15 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=36,
       max_seq_len=32768,
       embedding_dim=2048,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
   )
   return config


-def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+def get_1_5b_model_config() -> cfg.ModelConfig:
   """Returns the model config for a Qwen 2.5 1B model."""
-  config = get_3b_model_config(kv_cache_max_len)
+  config = get_3b_model_config()
   # Qwen has only one block config.
   block_config = config.block_config(0)
   block_config.attn_config.num_heads = 12
@@ -85,9 +76,9 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config


-def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+def get_0_5b_model_config() -> cfg.ModelConfig:
   """Returns the model config for a Qwen 2.5 0.5B model."""
-  config = get_3b_model_config(kv_cache_max_len)
+  config = get_3b_model_config()
   # Qwen has only one block config.
   block_config = config.block_config(0)
   block_config.attn_config.num_heads = 14
@@ -98,8 +89,8 @@ def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config


-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_3b_model_config(**kwargs)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_3b_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   # Qwen has only one block config.
@@ -107,43 +98,47 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config


-def build_3b_model(
+def _build_model(
     checkpoint_path: str,
+    config: cfg.ModelConfig,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_3b_model_config(**kwargs),
+      config=config,
       tensor_names=TENSOR_NAMES,
       model_class=Qwen,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
+  )
+
+
+def build_3b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    mask_cache_size: int = 0,
+) -> nn.Module:
+  return _build_model(
+      checkpoint_path, get_3b_model_config(), custom_loader, mask_cache_size
   )


 def build_1_5b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_1_5b_model_config(**kwargs),
-      tensor_names=TENSOR_NAMES,
-      model_class=Qwen,
-      custom_loader=custom_loader,
+  return _build_model(
+      checkpoint_path, get_1_5b_model_config(), custom_loader, mask_cache_size
   )


 def build_0_5b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_0_5b_model_config(**kwargs),
-      tensor_names=TENSOR_NAMES,
-      model_class=Qwen,
-      custom_loader=custom_loader,
+  return _build_model(
+      checkpoint_path, get_0_5b_model_config(), custom_loader, mask_cache_size
   )
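Taken together, the qwen.py hunks fold three near-identical builders into a private _build_model helper and replace the **kwargs plumbing with an explicit mask_cache_size parameter. A hedged usage sketch; the checkpoint path and cache size are placeholders:

from ai_edge_torch.generative.examples.qwen import qwen

# The KV cache length is no longer part of the model config; only the
# mask cache is sized at build time.
model = qwen.build_0_5b_model(
    "/path/to/qwen2.5-0.5b",  # placeholder checkpoint path
    mask_cache_size=1280,
)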
--- a/ai_edge_torch/generative/examples/qwen/qwen3.py
+++ b/ai_edge_torch/generative/examples/qwen/qwen3.py
@@ -42,20 +42,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(

 class Qwen3(model_builder.DecoderOnlyModel):
   """A Qwen3 model built from the Edge Generative API layers."""
-
   pass


-def get_4b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Qwen 3.0 4B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a SmolLM model.
-  """
+def get_4b_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 3.0 4B model."""
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
@@ -87,16 +78,15 @@ def get_4b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=36,
       max_seq_len=40960,
       embedding_dim=2560,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
   )
   return config


-def get_1_7b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+def get_1_7b_model_config() -> cfg.ModelConfig:
   """Returns the model config for a Qwen 3.0 1.7B model."""
-  config = get_4b_model_config(kv_cache_max_len)
+  config = get_4b_model_config()
   # Qwen has only one block config.
   block_config = config.block_config(0)
   block_config.attn_config.num_heads = 16
@@ -107,9 +97,9 @@ def get_1_7b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config


-def get_0_6b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+def get_0_6b_model_config() -> cfg.ModelConfig:
   """Returns the model config for a Qwen 3.0 0.6B model."""
-  config = get_4b_model_config(kv_cache_max_len)
+  config = get_4b_model_config()
   # Qwen has only one block config.
   block_config = config.block_config(0)
   block_config.attn_config.num_heads = 16
@@ -120,8 +110,8 @@ def get_0_6b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config


-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_4b_model_config(**kwargs)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_4b_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   # Qwen has only one block config.
@@ -129,43 +119,47 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config


-def build_4b_model(
+def _build_model(
     checkpoint_path: str,
+    config: cfg.ModelConfig,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_4b_model_config(**kwargs),
+      config=config,
       tensor_names=TENSOR_NAMES,
       model_class=Qwen3,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
+  )
+
+
+def build_4b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    mask_cache_size: int = 0,
+) -> nn.Module:
+  return _build_model(
+      checkpoint_path, get_4b_model_config(), custom_loader, mask_cache_size
   )


 def build_1_7b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_1_7b_model_config(**kwargs),
-      tensor_names=TENSOR_NAMES,
-      model_class=Qwen3,
-      custom_loader=custom_loader,
+  return _build_model(
+      checkpoint_path, get_1_7b_model_config(), custom_loader, mask_cache_size
   )


 def build_0_6b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_0_6b_model_config(**kwargs),
-      tensor_names=TENSOR_NAMES,
-      model_class=Qwen3,
-      custom_loader=custom_loader,
+  return _build_model(
+      checkpoint_path, get_0_6b_model_config(), custom_loader, mask_cache_size
   )
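qwen3.py gets the identical treatment. With the fake config now argument-free, tiny smoke-test models are cheap to construct; a sketch, assuming DecoderOnlyModel accepts (config, mask_cache_size) as inferred from the decoder.Decoder(...) call in the qwen_vl hunks further down:

from ai_edge_torch.generative.examples.qwen import qwen3

# 2 layers and a 128-token vocab: enough to exercise shapes without
# loading real weights. The (config, mask_cache_size) signature is an
# inference from this diff, not a documented API.
config = qwen3.get_fake_model_config()
model = qwen3.Qwen3(config, mask_cache_size=128)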
--- a/ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py
@@ -42,7 +42,7 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
       image_size=(_IMAGE_HEIGHT.value, _IMAGE_WIDTH.value),
   )

@@ -55,6 +55,7 @@ def main(_):
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       pixel_values_size=(
           pytorch_model.image_encoder.get_pixel_values_size(grid_thw)
       ),
--- a/ai_edge_torch/generative/examples/qwen_vl/decoder.py
+++ b/ai_edge_torch/generative/examples/qwen_vl/decoder.py
@@ -60,8 +60,9 @@ class Decoder(model_builder.DecoderOnlyModel):
     rope = self.config.build_rope(input_pos, n_elem, attn_config.rotary_base)

     if mask is None:
+      assert kv_cache is not None, "KV cache must be provided."
       mask = self.mask_cache.index_select(2, input_pos)
-      mask = mask[:, :, :, : self.config.kv_cache_max]
+      mask = mask[:, :, :, :kv_cache.get_max_seq_len()]

     return self._forward_with_embeds(
         input_embeds,
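The sliced mask above comes from a precomputed causal mask cache; the change makes the column cut follow the KV cache's actual capacity (kv_cache.get_max_seq_len()) instead of the old config.kv_cache_max field. A self-contained sketch of the indexing, with illustrative shapes only:

import torch

mask_cache_size, kv_len = 8, 4
# Causal mask: -inf strictly above the diagonal, 0 on and below it.
mask_cache = torch.triu(
    torch.full((1, 1, mask_cache_size, mask_cache_size), float("-inf")),
    diagonal=1,
)
input_pos = torch.tensor([2, 3])              # positions being decoded
mask = mask_cache.index_select(2, input_pos)  # [1, 1, 2, mask_cache_size]
mask = mask[:, :, :, :kv_len]                 # [1, 1, 2, kv_len]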
@@ -73,16 +74,8 @@ class Decoder(model_builder.DecoderOnlyModel):
     )


-def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Qwen 2.5 VL 3B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Qwen 2.5 VL 3B model.
-  """
+def get_decoder_config() -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 2.5 VL 3B model."""
   attn_config = cfg.AttentionConfig(
       num_heads=16,
       head_dim=128,
@@ -110,15 +103,14 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=36,
       max_seq_len=32768,
       embedding_dim=2048,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
   )
   return config


-def get_fake_decoder_config(**kwargs) -> cfg.ModelConfig:
-  config = get_decoder_config(**kwargs)
+def get_fake_decoder_config() -> cfg.ModelConfig:
+  config = get_decoder_config()
   config.vocab_size = 128
   config.num_layers = 2
   # Decoder has only one block config.
@@ -126,10 +118,13 @@ def get_fake_decoder_config(**kwargs) -> cfg.ModelConfig:
   return config


-def build_decoder(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+def build_decoder(
+    checkpoint_path: str, mask_cache_size: int = 0
+) -> torch.nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_decoder_config(**kwargs),
+      config=get_decoder_config(),
       tensor_names=TENSOR_NAMES,
       model_class=Decoder,
+      mask_cache_size=mask_cache_size,
   )
--- a/ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py
+++ b/ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py
@@ -41,13 +41,13 @@ class QwenVLConfig:
 class QwenVL(nn.Module):
   """Qwen VL model from the Edge Generative API."""

-  def __init__(self, config: QwenVLConfig):
+  def __init__(self, config: QwenVLConfig, mask_cache_size: int = 0):
     super().__init__()

     self.image_encoder = image_encoder.QwenVLImageEncoder(
         config.image_encoder_config
     )
-    self.decoder = decoder.Decoder(config.decoder_config)
+    self.decoder = decoder.Decoder(config.decoder_config, mask_cache_size)
     # The amount of adjustment in input_pos to calculate RoPE properly in
     # forward() calls after image is handled.
     self.rope_pos_adjust = 0
@@ -179,26 +179,21 @@ class QwenVL(nn.Module):


 def get_model_config(
-    kv_cache_max_len: int = 1024,
     image_size: Tuple[int, int] = (34 * 14, 46 * 14),
 ) -> QwenVLConfig:
-  """Returns the model config for a PaliGemma 3B-224 model.
-
-  Returns:
-    The model config for a PaliGemma 3B model.
-  """
+  """Returns the model config for a PaliGemma 3B-224 model."""
   return QwenVLConfig(
       image_encoder_config=image_encoder.get_image_encoder_config(image_size),
-      decoder_config=decoder.get_decoder_config(kv_cache_max_len),
+      decoder_config=decoder.get_decoder_config(),
       image_token_id=151655,
       mrope_section=[16, 24, 24],
   )


-def get_fake_model_config(**kwargs) -> QwenVLConfig:
+def get_fake_model_config() -> QwenVLConfig:
   return QwenVLConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
-      decoder_config=decoder.get_fake_decoder_config(**kwargs),
+      decoder_config=decoder.get_fake_decoder_config(),
       image_token_id=127,
       mrope_section=[16, 24, 24],
   )
@@ -207,10 +202,11 @@ def get_fake_model_config(**kwargs) -> QwenVLConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    mask_cache_size: int = 0,
     **kwargs
 ) -> QwenVL:
   config = get_model_config(**kwargs)
-  model = QwenVL(config)
+  model = QwenVL(config, mask_cache_size)
   image_encoder.load_image_encoder(
       checkpoint_path, model.image_encoder, custom_loader
   )
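In qwen_vl.py the new mask_cache_size travels through the QwenVL constructor into its decoder, while image-encoder options keep flowing through **kwargs into get_model_config(). A hedged usage sketch; the path and sizes are placeholders:

from ai_edge_torch.generative.examples.qwen_vl import qwen_vl

model = qwen_vl.build_model(
    "/path/to/qwen2.5-vl-3b",       # placeholder checkpoint path
    mask_cache_size=1280,           # sizes the decoder's mask cache
    image_size=(34 * 14, 46 * 14),  # forwarded to get_model_config()
)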
--- a/ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
@@ -16,7 +16,6 @@
 """Example of converting SmolLM model to multi-signature tflite model."""

 from absl import app
-from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config as export_cfg
@@ -38,7 +37,7 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )

   export_config = export_cfg.get_from_flags()
@@ -49,6 +48,7 @@ def main(_):
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config,
--- a/ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py
+++ b/ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py
@@ -37,7 +37,7 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )

   export_config = export_cfg.get_from_flags()
@@ -48,6 +48,7 @@ def main(_):
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config,
--- a/ai_edge_torch/generative/examples/smollm/smollm.py
+++ b/ai_edge_torch/generative/examples/smollm/smollm.py
@@ -29,16 +29,8 @@ class SmolLM(model_builder.DecoderOnlyModel):
   pass


-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a SmolLM 135M model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a SmolLM model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a SmolLM 135M model."""
   attn_config = cfg.AttentionConfig(
       num_heads=9,
       head_dim=64,
@@ -63,15 +55,14 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=30,
       max_seq_len=2048,
       embedding_dim=576,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
   )
   return config


-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_model_config(**kwargs)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   # SmolLM has only one block config.
@@ -82,14 +73,15 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(**kwargs),
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )


@@ -98,23 +90,15 @@ class SmolLM2(model_builder.DecoderOnlyModel):
   pass


-def get_model_config_v2(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a SmolLM2 135M model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a SmolLM2 model.
-  """
-  config = get_model_config(kv_cache_max_len)
+def get_model_config_v2() -> cfg.ModelConfig:
+  """Returns the model config for a SmolLM2 135M model."""
+  config = get_model_config()
   config.block_config(0).attn_config.rotary_base = 100000
   return config


-def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
-  config = get_model_config_v2(**kwargs)
+def get_fake_model_config_v2() -> cfg.ModelConfig:
+  config = get_model_config_v2()
   config.vocab_size = 128
   config.num_layers = 2
   # SmolLM2 has only one block config.
@@ -125,12 +109,13 @@ def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
 def build_model_v2(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config_v2(**kwargs),
+      config=get_model_config_v2(),
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM2,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )
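Putting both halves of the refactor together for SmolLM2: the mask cache is sized when the model is built, and the KV cache length is chosen when it is converted. A minimal end-to-end sketch; paths and sizes are placeholders, and the remaining conversion flags (quantize, lora_ranks, export_config) are assumed optional here:

from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import converter

model = smollm.build_model_v2(
    "/path/to/smollm2-135m",  # placeholder checkpoint path
    mask_cache_size=1024,
)
converter.convert_to_tflite(
    model,
    output_path="/tmp",            # placeholder output directory
    output_name_prefix="smollm2",
    prefill_seq_len=[128],         # placeholder prefill lengths
    kv_cache_max_len=1024,         # now a conversion-time argument
)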