ai-edge-torch-nightly 0.5.0.dev20250514__py3-none-any.whl → 0.5.0.dev20250516__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. ai_edge_torch/__init__.py +1 -0
  2. ai_edge_torch/_convert/conversion.py +23 -0
  3. ai_edge_torch/_convert/converter.py +57 -3
  4. ai_edge_torch/_convert/test/test_convert.py +25 -0
  5. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +9 -2
  6. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +7 -1
  7. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +7 -2
  8. ai_edge_torch/generative/examples/deepseek/deepseek.py +8 -1
  9. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +7 -1
  10. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +7 -1
  11. ai_edge_torch/generative/examples/gemma/gemma1.py +9 -1
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +7 -2
  13. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +6 -1
  14. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +7 -1
  15. ai_edge_torch/generative/examples/hammer/hammer.py +14 -2
  16. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +7 -1
  17. ai_edge_torch/generative/examples/llama/llama.py +25 -6
  18. ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +0 -1
  19. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +7 -1
  20. ai_edge_torch/generative/examples/openelm/openelm.py +8 -1
  21. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +6 -1
  22. ai_edge_torch/generative/examples/paligemma/decoder.py +1 -0
  23. ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -0
  24. ai_edge_torch/generative/examples/paligemma/image_encoder.py +2 -1
  25. ai_edge_torch/generative/examples/paligemma/paligemma.py +12 -5
  26. ai_edge_torch/generative/examples/paligemma/verify.py +27 -5
  27. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +7 -1
  28. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +7 -1
  29. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +7 -1
  30. ai_edge_torch/generative/examples/phi/phi2.py +8 -1
  31. ai_edge_torch/generative/examples/phi/phi3.py +7 -2
  32. ai_edge_torch/generative/examples/phi/phi4.py +7 -2
  33. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +7 -1
  34. ai_edge_torch/generative/examples/qwen/qwen.py +20 -3
  35. ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +6 -1
  36. ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -2
  37. ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +12 -4
  38. ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +12 -4
  39. ai_edge_torch/generative/examples/qwen_vl/verify.py +26 -5
  40. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +7 -2
  41. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +7 -1
  42. ai_edge_torch/generative/examples/smollm/smollm.py +14 -2
  43. ai_edge_torch/generative/examples/smollm/verify.py +2 -2
  44. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -1
  45. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +7 -1
  46. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -1
  47. ai_edge_torch/generative/layers/normalization.py +26 -7
  48. ai_edge_torch/generative/layers/normalization_test.py +73 -0
  49. ai_edge_torch/generative/utilities/converter.py +16 -4
  50. ai_edge_torch/generative/utilities/loader.py +45 -0
  51. ai_edge_torch/version.py +1 -1
  52. {ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/METADATA +1 -1
  53. {ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/RECORD +56 -55
  54. {ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/LICENSE +0 -0
  55. {ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/WHEEL +0 -0
  56. {ai_edge_torch_nightly-0.5.0.dev20250514.dist-info → ai_edge_torch_nightly-0.5.0.dev20250516.dist-info}/top_level.txt +0 -0
@@ -66,7 +66,8 @@ class SiglipVisionEncoder(nn.Module):
66
66
  config.image_embedding.image_size // config.image_embedding.patch_size
67
67
  ) ** 2
68
68
  self.tok_embedding_position = nn.Parameter(
69
- torch.zeros((num_patches, config.embedding_dim))
69
+ torch.zeros((num_patches, config.embedding_dim)),
70
+ requires_grad=False,
70
71
  )
71
72
 
72
73
  self.transformer_blocks = nn.ModuleList(
@@ -16,7 +16,7 @@
16
16
  """Example of building a full-stack of PaliGemma model."""
17
17
 
18
18
  import dataclasses
19
- from typing import Optional
19
+ from typing import Callable, Dict, Optional
20
20
 
21
21
  from ai_edge_torch.generative.examples.paligemma import decoder
22
22
  from ai_edge_torch.generative.examples.paligemma import decoder2
@@ -139,7 +139,12 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
139
139
  )
140
140
 
141
141
 
142
- def build_model(checkpoint_path: str, version: int = 2, **kwargs) -> PaliGemma:
142
+ def build_model(
143
+ checkpoint_path: str,
144
+ version: int = 2,
145
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
146
+ **kwargs,
147
+ ) -> PaliGemma:
143
148
  if version == 1:
144
149
  decoder_class = decoder.Decoder
145
150
  decoder_tensor_names = decoder.TENSOR_NAMES
@@ -153,15 +158,17 @@ def build_model(checkpoint_path: str, version: int = 2, **kwargs) -> PaliGemma:
153
158
  model = PaliGemma(config, decoder_class)
154
159
  # Load the parameters of image encoder.
155
160
  loader = loading_utils.ModelLoader(
156
- checkpoint_path, image_encoder.TENSOR_NAMES
161
+ checkpoint_path, image_encoder.TENSOR_NAMES, custom_loader
157
162
  )
158
163
  loader.load(model.image_encoder, strict=False)
159
164
  # Load the parameters of decoder.
160
- loader = loading_utils.ModelLoader(checkpoint_path, decoder_tensor_names)
165
+ loader = loading_utils.ModelLoader(
166
+ checkpoint_path, decoder_tensor_names, custom_loader
167
+ )
161
168
  loader.load(model.decoder, strict=False)
162
169
 
163
170
  # Load the parameters of image projection.
164
- loader = loading_utils.ModelLoader(checkpoint_path, None)
171
+ loader = loading_utils.ModelLoader(checkpoint_path, None, custom_loader)
165
172
  state = loader.get_state()
166
173
  converted_state = dict()
167
174
  converted_state["weight"] = state.pop(f"{PROJECTION_TENSOR_NAME}.weight")
@@ -21,6 +21,7 @@ from absl import app
21
21
  from absl import flags
22
22
  from ai_edge_torch.generative.examples.paligemma import paligemma
23
23
  from ai_edge_torch.generative.layers import kv_cache
24
+ from ai_edge_torch.generative.utilities import transformers_verifier
24
25
  from ai_edge_torch.generative.utilities import verifier
25
26
  import kagglehub
26
27
  from PIL import Image
@@ -39,10 +40,15 @@ _IMAGE_URL = flags.DEFINE_string(
39
40
  "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
40
41
  "The image URI to encode.",
41
42
  )
42
- _PROMPTS = flags.DEFINE_string(
43
- "prompts",
43
+ _PROMPTS_WITH_IMAGE = flags.DEFINE_string(
44
+ "prompts_with_image",
44
45
  "<image><bos>describe en",
45
- "The input prompts to generate answers.",
46
+ "The input prompts to generate answers with an image.",
47
+ )
48
+ _PROMPTS_TEXT_ONLY = flags.DEFINE_multi_string(
49
+ "prompts_text_only",
50
+ "What is the meaning of life?",
51
+ "The input prompts to generate answers only with text.",
46
52
  )
47
53
  _MAX_NEW_TOKENS = flags.DEFINE_integer(
48
54
  "max_new_tokens",
@@ -84,6 +90,7 @@ def main(_):
84
90
  reauthored_model = paligemma.build_model(
85
91
  reauthored_checkpoint, version=int(_VERSION.value)
86
92
  )
93
+ wrapped_reauthored_model = ReauthoredPaliGemmaWrapper(reauthored_model)
87
94
 
88
95
  logging.info("Loading the processor from: %s", checkpoint)
89
96
  # It works only when GemmaTokenizerFast is available. In some environments,
@@ -91,9 +98,25 @@ def main(_):
91
98
  # sentencepiece model file properly.
92
99
  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
93
100
 
101
+ logging.info("Verifying with text-only prompts...")
102
+ verifier.verify_reauthored_model(
103
+ original_model=transformers_verifier.TransformersModelWrapper(
104
+ original_model
105
+ ),
106
+ reauthored_model=wrapped_reauthored_model,
107
+ tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
108
+ generate_prompts=_PROMPTS_TEXT_ONLY.value,
109
+ max_new_tokens=_MAX_NEW_TOKENS.value,
110
+ verify_inputs=False, # Numeric check not working. Disable it for now.
111
+ atol=1e-04,
112
+ )
113
+
114
+ logging.info("Verifying with image input...")
94
115
  logging.info("Loading the image from: %s", _IMAGE_URL.value)
95
116
  image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
96
- inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt")
117
+ inputs = processor(
118
+ text=_PROMPTS_WITH_IMAGE.value, images=image, return_tensors="pt"
119
+ )
97
120
 
98
121
  logging.info("Verifying the reauthored model with model.forward()...")
99
122
  logging.info("Forwarding the original model...")
@@ -104,7 +127,6 @@ def main(_):
104
127
  logging.info("outputs_original: %s", outputs_original)
105
128
 
106
129
  logging.info("Forwarding the reauthored model...")
107
- wrapped_reauthored_model = ReauthoredPaliGemmaWrapper(reauthored_model)
108
130
  outputs_reauthored = wrapped_reauthored_model.forward(
109
131
  tokens=inputs["input_ids"],
110
132
  pixel_values=inputs["pixel_values"],
@@ -19,13 +19,19 @@ from absl import app
19
19
  from ai_edge_torch.generative.examples.phi import phi3
20
20
  from ai_edge_torch.generative.utilities import converter
21
21
  from ai_edge_torch.generative.utilities import export_config
22
+ from ai_edge_torch.generative.utilities import loader
22
23
 
23
24
  flags = converter.define_conversion_flags("phi3")
24
25
 
25
26
 
26
27
  def main(_):
28
+ checkpoint_path = flags.FLAGS.checkpoint_path
27
29
  pytorch_model = phi3.build_model(
28
- flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
30
+ checkpoint_path,
31
+ custom_loader=loader.maybe_get_custom_loader(
32
+ checkpoint_path, flags.FLAGS.custom_checkpoint_loader
33
+ ),
34
+ kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
29
35
  )
30
36
  converter.convert_to_tflite(
31
37
  pytorch_model,
@@ -19,13 +19,19 @@ from absl import app
19
19
  from ai_edge_torch.generative.examples.phi import phi4
20
20
  from ai_edge_torch.generative.utilities import converter
21
21
  from ai_edge_torch.generative.utilities import export_config
22
+ from ai_edge_torch.generative.utilities import loader
22
23
 
23
24
  flags = converter.define_conversion_flags("phi4")
24
25
 
25
26
 
26
27
  def main(_):
28
+ checkpoint_path = flags.FLAGS.checkpoint_path
27
29
  pytorch_model = phi4.build_model(
28
- flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
30
+ checkpoint_path,
31
+ custom_loader=loader.maybe_get_custom_loader(
32
+ checkpoint_path, flags.FLAGS.custom_checkpoint_loader
33
+ ),
34
+ kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
29
35
  )
30
36
  converter.convert_to_tflite(
31
37
  pytorch_model,
@@ -20,13 +20,19 @@ from absl import app
20
20
  from ai_edge_torch.generative.examples.phi import phi2
21
21
  from ai_edge_torch.generative.utilities import converter
22
22
  from ai_edge_torch.generative.utilities import export_config
23
+ from ai_edge_torch.generative.utilities import loader
23
24
 
24
25
  flags = converter.define_conversion_flags("phi2")
25
26
 
26
27
 
27
28
  def main(_):
29
+ checkpoint_path = flags.FLAGS.checkpoint_path
28
30
  pytorch_model = phi2.build_model(
29
- flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
31
+ checkpoint_path,
32
+ custom_loader=loader.maybe_get_custom_loader(
33
+ checkpoint_path, flags.FLAGS.custom_checkpoint_loader
34
+ ),
35
+ kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
30
36
  )
31
37
  converter.convert_to_tflite(
32
38
  pytorch_model,
@@ -15,9 +15,11 @@
15
15
 
16
16
  """Example of building a Phi-2 model."""
17
17
 
18
+ from typing import Callable, Dict
18
19
  import ai_edge_torch.generative.layers.model_config as cfg
19
20
  from ai_edge_torch.generative.utilities import model_builder
20
21
  import ai_edge_torch.generative.utilities.loader as loading_utils
22
+ import torch
21
23
  from torch import nn
22
24
 
23
25
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
@@ -98,10 +100,15 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
98
100
  return config
99
101
 
100
102
 
101
- def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
103
+ def build_model(
104
+ checkpoint_path: str,
105
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
106
+ **kwargs
107
+ ) -> nn.Module:
102
108
  return model_builder.build_decoder_only_model(
103
109
  checkpoint_path=checkpoint_path,
104
110
  config=get_model_config(**kwargs),
105
111
  tensor_names=TENSOR_NAMES,
106
112
  model_class=Phi2,
113
+ custom_loader=custom_loader,
107
114
  )
@@ -17,7 +17,7 @@
17
17
 
18
18
  from functools import partial
19
19
  import math
20
- from typing import Tuple
20
+ from typing import Callable, Dict, Tuple
21
21
 
22
22
  import ai_edge_torch.generative.layers.model_config as cfg
23
23
  from ai_edge_torch.generative.utilities import model_builder
@@ -208,11 +208,16 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
208
208
  return config
209
209
 
210
210
 
211
- def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
211
+ def build_model(
212
+ checkpoint_path: str,
213
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
214
+ **kwargs
215
+ ) -> torch.nn.Module:
212
216
  """Instantiates the model instance and load checkpoint if provided."""
213
217
  return model_builder.build_decoder_only_model(
214
218
  checkpoint_path=checkpoint_path,
215
219
  config=get_model_config(**kwargs),
216
220
  tensor_names=TENSOR_NAMES,
217
221
  model_class=Phi3_5Mini,
222
+ custom_loader=custom_loader,
218
223
  )
@@ -17,7 +17,7 @@
17
17
 
18
18
  from functools import partial
19
19
  import math
20
- from typing import Tuple
20
+ from typing import Callable, Dict, Tuple
21
21
 
22
22
  import ai_edge_torch.generative.layers.model_config as cfg
23
23
  from ai_edge_torch.generative.utilities import model_builder
@@ -157,11 +157,16 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
157
157
  return config
158
158
 
159
159
 
160
- def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
160
+ def build_model(
161
+ checkpoint_path: str,
162
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
163
+ **kwargs
164
+ ) -> torch.nn.Module:
161
165
  """Instantiates the model instance and load checkpoint if provided."""
162
166
  return model_builder.build_decoder_only_model(
163
167
  checkpoint_path=checkpoint_path,
164
168
  config=get_model_config(**kwargs),
165
169
  tensor_names=TENSOR_NAMES,
166
170
  model_class=Phi4Mini,
171
+ custom_loader=custom_loader,
167
172
  )
@@ -19,6 +19,7 @@ from absl import app
19
19
  from ai_edge_torch.generative.examples.qwen import qwen
20
20
  from ai_edge_torch.generative.utilities import converter
21
21
  from ai_edge_torch.generative.utilities import export_config
22
+ from ai_edge_torch.generative.utilities import loader
22
23
 
23
24
  flags = converter.define_conversion_flags('qwen')
24
25
 
@@ -37,8 +38,13 @@ _BUILDER = {
37
38
 
38
39
 
39
40
  def main(_):
41
+ checkpoint_path = flags.FLAGS.checkpoint_path
40
42
  pytorch_model = _BUILDER[_MODEL_SIZE.value](
41
- flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
43
+ checkpoint_path,
44
+ custom_loader=loader.maybe_get_custom_loader(
45
+ checkpoint_path, flags.FLAGS.custom_checkpoint_loader
46
+ ),
47
+ kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
42
48
  )
43
49
  converter.convert_to_tflite(
44
50
  pytorch_model,
@@ -15,8 +15,10 @@
15
15
 
16
16
  """Example of building Qwen 2.5 models."""
17
17
 
18
+ from typing import Callable, Dict
18
19
  import ai_edge_torch.generative.layers.model_config as cfg
19
20
  from ai_edge_torch.generative.utilities import model_builder
21
+ import torch
20
22
  from torch import nn
21
23
 
22
24
  TENSOR_NAMES = model_builder.TENSOR_NAMES
@@ -108,28 +110,43 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
108
110
  return config
109
111
 
110
112
 
111
- def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
113
+ def build_3b_model(
114
+ checkpoint_path: str,
115
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
116
+ **kwargs
117
+ ) -> nn.Module:
112
118
  return model_builder.build_decoder_only_model(
113
119
  checkpoint_path=checkpoint_path,
114
120
  config=get_3b_model_config(**kwargs),
115
121
  tensor_names=TENSOR_NAMES,
116
122
  model_class=Qwen,
123
+ custom_loader=custom_loader,
117
124
  )
118
125
 
119
126
 
120
- def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
127
+ def build_1_5b_model(
128
+ checkpoint_path: str,
129
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
130
+ **kwargs
131
+ ) -> nn.Module:
121
132
  return model_builder.build_decoder_only_model(
122
133
  checkpoint_path=checkpoint_path,
123
134
  config=get_1_5b_model_config(**kwargs),
124
135
  tensor_names=TENSOR_NAMES,
125
136
  model_class=Qwen,
137
+ custom_loader=custom_loader,
126
138
  )
127
139
 
128
140
 
129
- def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
141
+ def build_0_5b_model(
142
+ checkpoint_path: str,
143
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
144
+ **kwargs
145
+ ) -> nn.Module:
130
146
  return model_builder.build_decoder_only_model(
131
147
  checkpoint_path=checkpoint_path,
132
148
  config=get_0_5b_model_config(**kwargs),
133
149
  tensor_names=TENSOR_NAMES,
134
150
  model_class=Qwen,
151
+ custom_loader=custom_loader,
135
152
  )
@@ -19,6 +19,7 @@ from absl import app
19
19
  from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
20
20
  from ai_edge_torch.generative.utilities import converter
21
21
  from ai_edge_torch.generative.utilities import export_config
22
+ from ai_edge_torch.generative.utilities import loader
22
23
 
23
24
  flags = converter.define_conversion_flags('qwen_vl')
24
25
 
@@ -35,8 +36,12 @@ _IMAGE_WIDTH = flags.DEFINE_integer(
35
36
 
36
37
 
37
38
  def main(_):
39
+ checkpoint_path = flags.FLAGS.checkpoint_path
38
40
  pytorch_model = qwen_vl.build_model(
39
- flags.FLAGS.checkpoint_path,
41
+ checkpoint_path,
42
+ custom_loader=loader.maybe_get_custom_loader(
43
+ checkpoint_path, flags.FLAGS.custom_checkpoint_loader
44
+ ),
40
45
  kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
41
46
  image_size=(_IMAGE_HEIGHT.value, _IMAGE_WIDTH.value),
42
47
  )
@@ -97,8 +97,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
97
97
  intermediate_size=11008,
98
98
  )
99
99
  norm_config = cfg.NormalizationConfig(
100
- type=cfg.NormalizationType.RMS_NORM,
101
- epsilon=1e-06,
100
+ type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06, enable_hlfb=True
102
101
  )
103
102
  block_config = cfg.TransformerBlockConfig(
104
103
  attn_config=attn_config,
@@ -16,7 +16,7 @@
16
16
  """Example of building an image encoder of Qwen 2.5 VL model."""
17
17
 
18
18
  import dataclasses
19
- from typing import List, Optional, Tuple
19
+ from typing import Callable, Dict, List, Optional, Tuple
20
20
 
21
21
  from ai_edge_torch.generative.layers import attention
22
22
  from ai_edge_torch.generative.layers import attention_utils
@@ -385,13 +385,21 @@ def build_image_encoder(
385
385
  return encoder
386
386
 
387
387
 
388
- def load_image_encoder(checkpoint_path: str, encoder: QwenVLImageEncoder):
389
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
388
+ def load_image_encoder(
389
+ checkpoint_path: str,
390
+ encoder: QwenVLImageEncoder,
391
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
392
+ ):
393
+ loader = loading_utils.ModelLoader(
394
+ checkpoint_path, TENSOR_NAMES, custom_loader
395
+ )
390
396
  # Loose the strictness because only image encoder is being loaded.
391
397
  loader.load(encoder, strict=False)
392
398
 
393
399
  # Load merger weights.
394
- merger_loader = loading_utils.ModelLoader(checkpoint_path, None)
400
+ merger_loader = loading_utils.ModelLoader(
401
+ checkpoint_path, None, custom_loader
402
+ )
395
403
  state = merger_loader.get_state()
396
404
  w1_state = dict()
397
405
  w1_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.weight")
@@ -16,7 +16,7 @@
16
16
  """Example of building a full-stack of Qwen 2.5 VL model."""
17
17
 
18
18
  import dataclasses
19
- from typing import List, Optional, Tuple
19
+ from typing import Callable, Dict, List, Optional, Tuple
20
20
 
21
21
  from ai_edge_torch.generative.examples.qwen_vl import decoder
22
22
  from ai_edge_torch.generative.examples.qwen_vl import image_encoder
@@ -204,12 +204,20 @@ def get_fake_model_config(**kwargs) -> QwenVLConfig:
204
204
  )
205
205
 
206
206
 
207
- def build_model(checkpoint_path: str, **kwargs) -> QwenVL:
207
+ def build_model(
208
+ checkpoint_path: str,
209
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
210
+ **kwargs
211
+ ) -> QwenVL:
208
212
  config = get_model_config(**kwargs)
209
213
  model = QwenVL(config)
210
- image_encoder.load_image_encoder(checkpoint_path, model.image_encoder)
214
+ image_encoder.load_image_encoder(
215
+ checkpoint_path, model.image_encoder, custom_loader
216
+ )
211
217
  # Load the parameters of decoder.
212
- loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
218
+ loader = loading_utils.ModelLoader(
219
+ checkpoint_path, decoder.TENSOR_NAMES, custom_loader
220
+ )
213
221
  loader.load(model.decoder, strict=False)
214
222
  model.eval()
215
223
  return model
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
24
24
  from ai_edge_torch.generative.layers import kv_cache
25
+ from ai_edge_torch.generative.utilities import transformers_verifier
25
26
  from ai_edge_torch.generative.utilities import verifier
26
27
  from PIL import Image
27
28
  import requests
@@ -33,10 +34,15 @@ _IMAGE_URL = flags.DEFINE_string(
33
34
  "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
34
35
  "The image URI to encode.",
35
36
  )
36
- _PROMPTS = flags.DEFINE_string(
37
- "prompts",
37
+ _PROMPTS_WITH_IMAGE = flags.DEFINE_string(
38
+ "prompts_with_image",
38
39
  "<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>",
39
- "The input prompts to generate answers.",
40
+ "The input prompts to generate answers with an image.",
41
+ )
42
+ _PROMPTS_TEXT_ONLY = flags.DEFINE_multi_string(
43
+ "prompts_text_only",
44
+ "What is the meaning of life?",
45
+ "The input prompts to generate answers only with text.",
40
46
  )
41
47
  _MAX_NEW_TOKENS = flags.DEFINE_integer(
42
48
  "max_new_tokens",
@@ -68,13 +74,29 @@ def main(_):
68
74
  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
69
75
  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
70
76
  reauthored_model = qwen_vl.build_model(str(reauthored_checkpoint))
77
+ wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
71
78
 
72
79
  logging.info("Loading the processor from: %s", checkpoint)
73
80
  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
74
81
 
82
+ logging.info("Verifying with text-only prompts...")
83
+ verifier.verify_reauthored_model(
84
+ original_model=transformers_verifier.TransformersModelWrapper(
85
+ original_model
86
+ ),
87
+ reauthored_model=wrapped_reauthored_model,
88
+ tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
89
+ generate_prompts=_PROMPTS_TEXT_ONLY.value,
90
+ max_new_tokens=_MAX_NEW_TOKENS.value,
91
+ atol=1e-04,
92
+ )
93
+
94
+ logging.info("Verifying with image input...")
75
95
  logging.info("Loading the image from: %s", _IMAGE_URL.value)
76
96
  image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
77
- inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt")
97
+ inputs = processor(
98
+ text=_PROMPTS_WITH_IMAGE.value, images=image, return_tensors="pt"
99
+ )
78
100
 
79
101
  logging.info("Verifying the reauthored model with model.forward()...")
80
102
  logging.info("Forwarding the original model...")
@@ -87,7 +109,6 @@ def main(_):
87
109
  logging.info("outputs_original: %s", outputs_original)
88
110
 
89
111
  logging.info("Forwarding the reauthored model...")
90
- wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
91
112
  grid_thw = inputs["image_grid_thw"].tolist()
92
113
  config = reauthored_model.config.image_encoder_config.image_embedding
93
114
  reauthored_model.image_encoder.set_image_size(
@@ -15,12 +15,12 @@
15
15
 
16
16
  """Example of converting SmolLM model to multi-signature tflite model."""
17
17
 
18
- import os
19
18
  from absl import app
20
19
  from absl import flags
21
20
  from ai_edge_torch.generative.examples.smollm import smollm
22
21
  from ai_edge_torch.generative.utilities import converter
23
22
  from ai_edge_torch.generative.utilities import export_config as export_cfg
23
+ from ai_edge_torch.generative.utilities import loader
24
24
 
25
25
  flags = converter.define_conversion_flags('smollm')
26
26
 
@@ -32,8 +32,13 @@ _DECODE_BATCH_SIZE = flags.DEFINE_integer(
32
32
 
33
33
 
34
34
  def main(_):
35
+ checkpoint_path = flags.FLAGS.checkpoint_path
35
36
  pytorch_model = smollm.build_model(
36
- flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
37
+ checkpoint_path,
38
+ custom_loader=loader.maybe_get_custom_loader(
39
+ checkpoint_path, flags.FLAGS.custom_checkpoint_loader
40
+ ),
41
+ kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
37
42
  )
38
43
 
39
44
  export_config = export_cfg.get_from_flags()
@@ -19,6 +19,7 @@ from absl import app
19
19
  from ai_edge_torch.generative.examples.smollm import smollm
20
20
  from ai_edge_torch.generative.utilities import converter
21
21
  from ai_edge_torch.generative.utilities import export_config as export_cfg
22
+ from ai_edge_torch.generative.utilities import loader
22
23
 
23
24
  flags = converter.define_conversion_flags('smollm2')
24
25
 
@@ -30,8 +31,13 @@ _DECODE_BATCH_SIZE = flags.DEFINE_integer(
30
31
 
31
32
 
32
33
  def main(_):
34
+ checkpoint_path = flags.FLAGS.checkpoint_path
33
35
  pytorch_model = smollm.build_model_v2(
34
- flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
36
+ checkpoint_path,
37
+ custom_loader=loader.maybe_get_custom_loader(
38
+ checkpoint_path, flags.FLAGS.custom_checkpoint_loader
39
+ ),
40
+ kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
35
41
  )
36
42
 
37
43
  export_config = export_cfg.get_from_flags()
@@ -15,8 +15,10 @@
15
15
 
16
16
  """Example of building a SmolLM model."""
17
17
 
18
+ from typing import Callable, Dict
18
19
  import ai_edge_torch.generative.layers.model_config as cfg
19
20
  from ai_edge_torch.generative.utilities import model_builder
21
+ import torch
20
22
  from torch import nn
21
23
 
22
24
  TENSOR_NAMES = model_builder.TENSOR_NAMES
@@ -80,12 +82,17 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
80
82
  return config
81
83
 
82
84
 
83
- def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
85
+ def build_model(
86
+ checkpoint_path: str,
87
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
88
+ **kwargs
89
+ ) -> nn.Module:
84
90
  return model_builder.build_decoder_only_model(
85
91
  checkpoint_path=checkpoint_path,
86
92
  config=get_model_config(**kwargs),
87
93
  tensor_names=TENSOR_NAMES,
88
94
  model_class=SmolLM,
95
+ custom_loader=custom_loader,
89
96
  )
90
97
 
91
98
 
@@ -118,10 +125,15 @@ def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
118
125
  return config
119
126
 
120
127
 
121
- def build_model_v2(checkpoint_path: str, **kwargs) -> nn.Module:
128
+ def build_model_v2(
129
+ checkpoint_path: str,
130
+ custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
131
+ **kwargs
132
+ ) -> nn.Module:
122
133
  return model_builder.build_decoder_only_model(
123
134
  checkpoint_path=checkpoint_path,
124
135
  config=get_model_config_v2(**kwargs),
125
136
  tensor_names=TENSOR_NAMES,
126
137
  model_class=SmolLM2,
138
+ custom_loader=custom_loader,
127
139
  )
@@ -43,8 +43,8 @@ _MODEL_VERSION = flags.DEFINE_enum(
43
43
  "The version of SmolLm to verify.",
44
44
  )
45
45
  _CHECKPOINT = {
46
- "v1": "HuggingFaceTB/SmolLM-135M",
47
- "v2": "HuggingFaceTB/SmolLM2-135M",
46
+ "v1": "HuggingFaceTB/SmolLM-135M-Instruct",
47
+ "v2": "HuggingFaceTB/SmolLM2-135M-Instruct",
48
48
  }
49
49
 
50
50
  _BUILDER = {
@@ -57,7 +57,8 @@ class CLIP(nn.Module):
57
57
  super().__init__()
58
58
  self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
59
59
  self.tok_embedding_position = nn.Parameter(
60
- torch.zeros((config.max_seq_len, config.embedding_dim))
60
+ torch.zeros((config.max_seq_len, config.embedding_dim)),
61
+ requires_grad=False,
61
62
  )
62
63
 
63
64
  self.config = config
@@ -19,13 +19,19 @@ from absl import app
19
19
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
20
20
  from ai_edge_torch.generative.utilities import converter
21
21
  from ai_edge_torch.generative.utilities import export_config
22
+ from ai_edge_torch.generative.utilities import loader
22
23
 
23
24
  flags = converter.define_conversion_flags("tiny_llama")
24
25
 
25
26
 
26
27
  def main(_):
28
+ checkpoint_path = flags.FLAGS.checkpoint_path
27
29
  pytorch_model = tiny_llama.build_model(
28
- flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
30
+ checkpoint_path,
31
+ custom_loader=loader.maybe_get_custom_loader(
32
+ checkpoint_path, flags.FLAGS.custom_checkpoint_loader
33
+ ),
34
+ kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
29
35
  )
30
36
  converter.convert_to_tflite(
31
37
  pytorch_model,