ai-edge-torch-nightly 0.5.0.dev20250515__py3-none-any.whl → 0.5.0.dev20250517__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
Files changed (64)
  1. ai_edge_torch/__init__.py +1 -0
  2. ai_edge_torch/_convert/conversion.py +24 -0
  3. ai_edge_torch/_convert/converter.py +57 -3
  4. ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
  5. ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py +40 -0
  6. ai_edge_torch/_convert/test/test_convert.py +25 -0
  7. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +10 -6
  8. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +7 -1
  9. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +7 -2
  10. ai_edge_torch/generative/examples/deepseek/deepseek.py +9 -5
  11. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +7 -1
  12. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +7 -1
  13. ai_edge_torch/generative/examples/gemma/gemma1.py +10 -6
  14. ai_edge_torch/generative/examples/gemma/gemma2.py +8 -7
  15. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +5 -14
  16. ai_edge_torch/generative/examples/gemma3/decoder.py +10 -10
  17. ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -3
  18. ai_edge_torch/generative/examples/gemma3/image_encoder.py +1 -4
  19. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +7 -1
  20. ai_edge_torch/generative/examples/hammer/hammer.py +15 -6
  21. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +7 -1
  22. ai_edge_torch/generative/examples/llama/llama.py +26 -10
  23. ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +0 -1
  24. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +7 -1
  25. ai_edge_torch/generative/examples/openelm/openelm.py +9 -3
  26. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +6 -1
  27. ai_edge_torch/generative/examples/paligemma/decoder.py +1 -4
  28. ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -4
  29. ai_edge_torch/generative/examples/paligemma/image_encoder.py +3 -5
  30. ai_edge_torch/generative/examples/paligemma/paligemma.py +12 -5
  31. ai_edge_torch/generative/examples/paligemma/verify.py +27 -5
  32. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +7 -1
  33. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +7 -1
  34. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +7 -1
  35. ai_edge_torch/generative/examples/phi/phi2.py +9 -5
  36. ai_edge_torch/generative/examples/phi/phi3.py +8 -6
  37. ai_edge_torch/generative/examples/phi/phi4.py +8 -6
  38. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +7 -1
  39. ai_edge_torch/generative/examples/qwen/qwen.py +21 -7
  40. ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +6 -1
  41. ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -3
  42. ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +13 -7
  43. ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +12 -4
  44. ai_edge_torch/generative/examples/qwen_vl/verify.py +26 -5
  45. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +7 -2
  46. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +7 -1
  47. ai_edge_torch/generative/examples/smollm/smollm.py +15 -6
  48. ai_edge_torch/generative/examples/smollm/verify.py +2 -2
  49. ai_edge_torch/generative/examples/stable_diffusion/clip.py +8 -5
  50. ai_edge_torch/generative/examples/t5/t5.py +1 -3
  51. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  52. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -2
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +7 -1
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +9 -5
  55. ai_edge_torch/generative/layers/model_config.py +2 -2
  56. ai_edge_torch/generative/utilities/converter.py +18 -5
  57. ai_edge_torch/generative/utilities/loader.py +19 -0
  58. ai_edge_torch/odml_torch/lowerings/utils.py +13 -0
  59. ai_edge_torch/version.py +1 -1
  60. {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/METADATA +1 -1
  61. {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/RECORD +64 -63
  62. {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/LICENSE +0 -0
  63. {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/WHEEL +0 -0
  64. {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/top_level.txt +0 -0
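Two changes recur throughout this release. First, the generative-example builders and converter scripts gain an optional custom_loader parameter of type Callable[[str], Dict[str, torch.Tensor]], threaded through to ModelLoader so callers can supply their own checkpoint-reading function (utilities/loader.py grows a matching maybe_get_custom_loader helper, +19 lines). Second, enable_hlfb=True arguments are removed from NormalizationConfig and ModelConfig call sites, with layers/model_config.py changing in step (+2 -2), which suggests high-level function boundaries are now on by default rather than opted into per config.

A minimal sketch of a callable matching the new custom_loader signature. The safetensors-based body is an illustrative assumption, not code from this package; any function mapping a path to a name-to-tensor dict fits.

from typing import Dict

from safetensors.torch import load_file
import torch


def my_safetensors_loader(checkpoint_path: str) -> Dict[str, torch.Tensor]:
  # Read the checkpoint into a plain {tensor_name: tensor} dict on CPU.
  return load_file(checkpoint_path, device="cpu")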

ai_edge_torch/generative/examples/hammer/hammer.py
@@ -15,8 +15,10 @@
 
 """Example of building Hammer 2.1 models."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
@@ -43,9 +45,7 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=8960,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-06,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -61,7 +61,6 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
@@ -89,19 +88,29 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_1_5b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_1_5b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Hammer,
+      custom_loader=custom_loader,
   )
 
 
-def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_0_5b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_0_5b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Hammer,
+      custom_loader=custom_loader,
   )
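Taking hammer.py as representative of the builder-signature change, a hypothetical call to the updated API (the checkpoint path is made up, and my_safetensors_loader is the sketch above):

from ai_edge_torch.generative.examples.hammer import hammer

model = hammer.build_1_5b_model(
    "/tmp/hammer_2.1_1.5b.safetensors",  # hypothetical path
    custom_loader=my_safetensors_loader,
    kv_cache_max_len=1024,
)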

ai_edge_torch/generative/examples/llama/convert_to_tflite.py
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 
 flags = converter.define_conversion_flags('llama')
@@ -37,8 +38,13 @@ _BUILDER = {
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
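Every convert_to_tflite.py touched in this release follows the same pattern: resolve the checkpoint path once, then let loader.maybe_get_custom_loader decide whether to hand a custom loader to the builder. The helper's body (+19 lines in utilities/loader.py, listed above) is not shown in this diff; the following is only a plausible reconstruction inferred from its call sites, and treating the custom_checkpoint_loader flag as an on/off switch is an assumption.

from typing import Callable, Dict, Optional

from safetensors.torch import load_file
import torch


def maybe_get_custom_loader(
    checkpoint_path: str,
    use_custom_loader: bool,
) -> Optional[Callable[[str], Dict[str, torch.Tensor]]]:
  # Assumed sketch: return None so ModelLoader keeps its default path,
  # otherwise pick a loader based on the checkpoint's file extension.
  if not use_custom_loader:
    return None
  if checkpoint_path.endswith(".safetensors"):
    return lambda path: load_file(path, device="cpu")
  return lambda path: torch.load(path, map_location="cpu")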

ai_edge_torch/generative/examples/llama/llama.py
@@ -17,7 +17,7 @@
 
 from functools import partial
 import math
-from typing import Tuple
+from typing import Callable, Dict, Tuple
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
@@ -121,9 +121,7 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -152,7 +150,6 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
       build_rope=build_rope,
   )
   return config
@@ -180,19 +177,38 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 
 
 def _build_model(
-    checkpoint_path: str, config: cfg.ModelConfig
+    checkpoint_path: str,
+    config: cfg.ModelConfig,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
 ) -> torch.nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=config,
       tensor_names=TENSOR_NAMES,
       model_class=Llama,
+      custom_loader=custom_loader,
   )
 
 
-def build_1b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
-  return _build_model(checkpoint_path, get_1b_model_config(**kwargs))
+def build_1b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> torch.nn.Module:
+  return _build_model(
+      checkpoint_path,
+      get_1b_model_config(**kwargs),
+      custom_loader=custom_loader,
+  )
 
 
-def build_3b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
-  return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
+def build_3b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> torch.nn.Module:
+  return _build_model(
+      checkpoint_path,
+      get_3b_model_config(**kwargs),
+      custom_loader=custom_loader,
+  )

ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py
@@ -22,7 +22,6 @@ from absl import app
 from absl import flags
 import ai_edge_torch
 from ai_edge_torch.generative.examples.moonshine import moonshine
-from ai_edge_torch.generative.utilities import converter
 import torch
 
 _CHECKPOINT_PATH = flags.DEFINE_string(

ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -19,13 +19,19 @@ from absl import app
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("openelm")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = openelm.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,

ai_edge_torch/generative/examples/openelm/openelm.py
@@ -15,9 +15,11 @@
 
 """Example of building an OpenELM model."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
 from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
@@ -51,7 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for an OpenELM model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=True
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
   )
   num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
   num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
@@ -99,7 +101,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
@@ -118,10 +119,15 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=OpenELM,
+      custom_loader=custom_loader,
   )

ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 import torch
 
 flags = converter.define_conversion_flags('paligemma2-3b-224')
@@ -32,9 +33,13 @@ _VERSION = flags.DEFINE_enum(
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = paligemma.build_model(
-      flags.FLAGS.checkpoint_path,
+      checkpoint_path,
       version=int(_VERSION.value),
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
       kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
 

ai_edge_torch/generative/examples/paligemma/decoder.py
@@ -110,9 +110,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=16384,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -131,7 +129,6 @@
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
   )
   return config
 

ai_edge_torch/generative/examples/paligemma/decoder2.py
@@ -93,9 +93,7 @@ def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for the decoder of a PaliGemma 3B model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -139,7 +137,6 @@
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
       final_logit_softcap=30.0,
   )
   return config

ai_edge_torch/generative/examples/paligemma/image_encoder.py
@@ -66,7 +66,8 @@ class SiglipVisionEncoder(nn.Module):
         config.image_embedding.image_size // config.image_embedding.patch_size
     ) ** 2
     self.tok_embedding_position = nn.Parameter(
-        torch.zeros((num_patches, config.embedding_dim))
+        torch.zeros((num_patches, config.embedding_dim)),
+        requires_grad=False,
     )
 
     self.transformer_blocks = nn.ModuleList(
@@ -117,9 +118,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM,
-      epsilon=1e-6,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -136,7 +135,6 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       image_embedding=image_embedding_config,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
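Besides the shared config cleanup, image_encoder.py now constructs the position-embedding parameter with requires_grad=False. The parameter stays registered, so it still appears in state_dict and loads from checkpoints, but autograd ignores it, which suits an export-only graph. A standalone illustration (the shapes here are arbitrary, not the model's):

import torch
from torch import nn

# Frozen parameter: registered on the module (present in state_dict,
# loadable from a checkpoint) but excluded from gradient computation.
pos_embedding = nn.Parameter(torch.zeros(256, 1152), requires_grad=False)
assert not pos_embedding.requires_grad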

ai_edge_torch/generative/examples/paligemma/paligemma.py
@@ -16,7 +16,7 @@
 """Example of building a full-stack of PaliGemma model."""
 
 import dataclasses
-from typing import Optional
+from typing import Callable, Dict, Optional
 
 from ai_edge_torch.generative.examples.paligemma import decoder
 from ai_edge_torch.generative.examples.paligemma import decoder2
@@ -139,7 +139,12 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   )
 
 
-def build_model(checkpoint_path: str, version: int = 2, **kwargs) -> PaliGemma:
+def build_model(
+    checkpoint_path: str,
+    version: int = 2,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs,
+) -> PaliGemma:
   if version == 1:
     decoder_class = decoder.Decoder
     decoder_tensor_names = decoder.TENSOR_NAMES
@@ -153,15 +158,17 @@ def build_model(checkpoint_path: str, version: int = 2, **kwargs) -> PaliGemma:
   model = PaliGemma(config, decoder_class)
   # Load the parameters of image encoder.
   loader = loading_utils.ModelLoader(
-      checkpoint_path, image_encoder.TENSOR_NAMES
+      checkpoint_path, image_encoder.TENSOR_NAMES, custom_loader
   )
   loader.load(model.image_encoder, strict=False)
   # Load the parameters of decoder.
-  loader = loading_utils.ModelLoader(checkpoint_path, decoder_tensor_names)
+  loader = loading_utils.ModelLoader(
+      checkpoint_path, decoder_tensor_names, custom_loader
+  )
   loader.load(model.decoder, strict=False)
 
   # Load the parameters of image projection.
-  loader = loading_utils.ModelLoader(checkpoint_path, None)
+  loader = loading_utils.ModelLoader(checkpoint_path, None, custom_loader)
   state = loader.get_state()
   converted_state = dict()
   converted_state["weight"] = state.pop(f"{PROJECTION_TENSOR_NAME}.weight")
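Because PaliGemma assembles three components from one checkpoint, build_model threads the same callable into each of its ModelLoader instances, so a single custom loader covers the image encoder, the decoder, and the projection tensors. A hedged usage sketch, reusing my_safetensors_loader from the first example (the path is hypothetical):

from ai_edge_torch.generative.examples.paligemma import paligemma

model = paligemma.build_model(
    "/tmp/paligemma2_3b_224.safetensors",  # hypothetical path
    version=2,
    custom_loader=my_safetensors_loader,
    kv_cache_max_len=1024,
)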

ai_edge_torch/generative/examples/paligemma/verify.py
@@ -21,6 +21,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import kagglehub
 from PIL import Image
@@ -39,10 +40,15 @@ _IMAGE_URL = flags.DEFINE_string(
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
     "The image URI to encode.",
 )
-_PROMPTS = flags.DEFINE_string(
-    "prompts",
+_PROMPTS_WITH_IMAGE = flags.DEFINE_string(
+    "prompts_with_image",
     "<image><bos>describe en",
-    "The input prompts to generate answers.",
+    "The input prompts to generate answers with an image.",
+)
+_PROMPTS_TEXT_ONLY = flags.DEFINE_multi_string(
+    "prompts_text_only",
+    "What is the meaning of life?",
+    "The input prompts to generate answers only with text.",
 )
 _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "max_new_tokens",
@@ -84,6 +90,7 @@ def main(_):
   reauthored_model = paligemma.build_model(
       reauthored_checkpoint, version=int(_VERSION.value)
   )
+  wrapped_reauthored_model = ReauthoredPaliGemmaWrapper(reauthored_model)
 
   logging.info("Loading the processor from: %s", checkpoint)
   # It works only when GemmaTokenizerFast is available. In some environments,
@@ -91,9 +98,25 @@ def main(_):
   # sentencepiece model file properly.
   processor = transformers.AutoProcessor.from_pretrained(checkpoint)
 
+  logging.info("Verifying with text-only prompts...")
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=wrapped_reauthored_model,
+      tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
+      generate_prompts=_PROMPTS_TEXT_ONLY.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      verify_inputs=False,  # Numeric check not working. Disable it for now.
+      atol=1e-04,
+  )
+
+  logging.info("Verifying with image input...")
   logging.info("Loading the image from: %s", _IMAGE_URL.value)
   image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
-  inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt")
+  inputs = processor(
+      text=_PROMPTS_WITH_IMAGE.value, images=image, return_tensors="pt"
+  )
 
   logging.info("Verifying the reauthored model with model.forward()...")
   logging.info("Forwarding the original model...")
@@ -104,7 +127,6 @@
   logging.info("outputs_original: %s", outputs_original)
 
   logging.info("Forwarding the reauthored model...")
-  wrapped_reauthored_model = ReauthoredPaliGemmaWrapper(reauthored_model)
   outputs_reauthored = wrapped_reauthored_model.forward(
       tokens=inputs["input_ids"],
       pixel_values=inputs["pixel_values"],

ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
@@ -19,13 +19,19 @@ from absl import app
 from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("phi3")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = phi3.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,

ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py
@@ -19,13 +19,19 @@ from absl import app
 from ai_edge_torch.generative.examples.phi import phi4
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("phi4")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = phi4.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,

ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -20,13 +20,19 @@ from absl import app
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("phi2")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = phi2.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,

ai_edge_torch/generative/examples/phi/phi2.py
@@ -15,9 +15,11 @@
 
 """Example of building a Phi-2 model."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
 from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
@@ -64,9 +66,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=10240,
       use_bias=True,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -83,7 +83,6 @@
       final_norm_config=norm_config,
       lm_head_use_bias=True,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config
 
@@ -98,10 +97,15 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Phi2,
+      custom_loader=custom_loader,
   )

ai_edge_torch/generative/examples/phi/phi3.py
@@ -17,7 +17,7 @@
 
 from functools import partial
 import math
-from typing import Tuple
+from typing import Callable, Dict, Tuple
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
@@ -162,9 +162,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -192,7 +190,6 @@
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
       build_rope=build_rope,
   )
   return config
@@ -208,11 +205,16 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> torch.nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Phi3_5Mini,
+      custom_loader=custom_loader,
   )

ai_edge_torch/generative/examples/phi/phi4.py
@@ -17,7 +17,7 @@
 
 from functools import partial
 import math
-from typing import Tuple
+from typing import Callable, Dict, Tuple
 
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
@@ -112,9 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -141,7 +139,6 @@
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
       build_rope=build_rope,
   )
   return config
@@ -157,11 +154,16 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> torch.nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Phi4Mini,
+      custom_loader=custom_loader,
   )

ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.qwen import qwen
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('qwen')
 
@@ -37,8 +38,13 @@ _BUILDER = {
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,