ai-edge-torch-nightly 0.5.0.dev20250515__py3-none-any.whl → 0.5.0.dev20250517__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (64)
  1. ai_edge_torch/__init__.py +1 -0
  2. ai_edge_torch/_convert/conversion.py +24 -0
  3. ai_edge_torch/_convert/converter.py +57 -3
  4. ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
  5. ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py +40 -0
  6. ai_edge_torch/_convert/test/test_convert.py +25 -0
  7. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +10 -6
  8. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +7 -1
  9. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +7 -2
  10. ai_edge_torch/generative/examples/deepseek/deepseek.py +9 -5
  11. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +7 -1
  12. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +7 -1
  13. ai_edge_torch/generative/examples/gemma/gemma1.py +10 -6
  14. ai_edge_torch/generative/examples/gemma/gemma2.py +8 -7
  15. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +5 -14
  16. ai_edge_torch/generative/examples/gemma3/decoder.py +10 -10
  17. ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -3
  18. ai_edge_torch/generative/examples/gemma3/image_encoder.py +1 -4
  19. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +7 -1
  20. ai_edge_torch/generative/examples/hammer/hammer.py +15 -6
  21. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +7 -1
  22. ai_edge_torch/generative/examples/llama/llama.py +26 -10
  23. ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +0 -1
  24. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +7 -1
  25. ai_edge_torch/generative/examples/openelm/openelm.py +9 -3
  26. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +6 -1
  27. ai_edge_torch/generative/examples/paligemma/decoder.py +1 -4
  28. ai_edge_torch/generative/examples/paligemma/decoder2.py +1 -4
  29. ai_edge_torch/generative/examples/paligemma/image_encoder.py +3 -5
  30. ai_edge_torch/generative/examples/paligemma/paligemma.py +12 -5
  31. ai_edge_torch/generative/examples/paligemma/verify.py +27 -5
  32. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +7 -1
  33. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +7 -1
  34. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +7 -1
  35. ai_edge_torch/generative/examples/phi/phi2.py +9 -5
  36. ai_edge_torch/generative/examples/phi/phi3.py +8 -6
  37. ai_edge_torch/generative/examples/phi/phi4.py +8 -6
  38. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +7 -1
  39. ai_edge_torch/generative/examples/qwen/qwen.py +21 -7
  40. ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py +6 -1
  41. ai_edge_torch/generative/examples/qwen_vl/decoder.py +1 -3
  42. ai_edge_torch/generative/examples/qwen_vl/image_encoder.py +13 -7
  43. ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py +12 -4
  44. ai_edge_torch/generative/examples/qwen_vl/verify.py +26 -5
  45. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +7 -2
  46. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +7 -1
  47. ai_edge_torch/generative/examples/smollm/smollm.py +15 -6
  48. ai_edge_torch/generative/examples/smollm/verify.py +2 -2
  49. ai_edge_torch/generative/examples/stable_diffusion/clip.py +8 -5
  50. ai_edge_torch/generative/examples/t5/t5.py +1 -3
  51. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  52. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -2
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +7 -1
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +9 -5
  55. ai_edge_torch/generative/layers/model_config.py +2 -2
  56. ai_edge_torch/generative/utilities/converter.py +18 -5
  57. ai_edge_torch/generative/utilities/loader.py +19 -0
  58. ai_edge_torch/odml_torch/lowerings/utils.py +13 -0
  59. ai_edge_torch/version.py +1 -1
  60. {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/METADATA +1 -1
  61. {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/RECORD +64 -63
  62. {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/LICENSE +0 -0
  63. {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/WHEEL +0 -0
  64. {ai_edge_torch_nightly-0.5.0.dev20250515.dist-info → ai_edge_torch_nightly-0.5.0.dev20250517.dist-info}/top_level.txt +0 -0
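Much of this nightly's delta is two coordinated changes, visible in the hunks below. First, an optional `custom_loader` parameter is threaded through the example model builders and conversion scripts, driven by a new `custom_checkpoint_loader` conversion flag and a `loader.maybe_get_custom_loader` helper. Second, the default of `enable_hlfb` flips to `True` in `ai_edge_torch/generative/layers/model_config.py`, so redundant per-model `enable_hlfb=True` arguments are removed and models that must keep HLFB off (CLIP, T5, the toy test models) now opt out explicitly.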
--- a/ai_edge_torch/generative/examples/qwen/qwen.py
+++ b/ai_edge_torch/generative/examples/qwen/qwen.py
@@ -15,8 +15,10 @@
 
 """Example of building Qwen 2.5 models."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
@@ -51,9 +53,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=11008,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-06,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -69,7 +69,6 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
@@ -108,28 +107,43 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_3b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_3b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Qwen,
+      custom_loader=custom_loader,
   )
 
 
-def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_1_5b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_1_5b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Qwen,
+      custom_loader=custom_loader,
   )
 
 
-def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_0_5b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_0_5b_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Qwen,
+      custom_loader=custom_loader,
   )
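All three Qwen builders now accept an optional `custom_loader` callable that is forwarded to `model_builder.build_decoder_only_model`, letting callers swap in their own checkpoint-reading function. Below is a minimal sketch of the calling pattern; the `read_state_dict` helper and the checkpoint path are hypothetical stand-ins, and only the `Callable[[str], Dict[str, torch.Tensor]]` shape comes from the diff:

    from typing import Dict

    import torch

    from ai_edge_torch.generative.examples.qwen import qwen


    def read_state_dict(path: str) -> Dict[str, torch.Tensor]:
      # Hypothetical loader: fetch the checkpoint from anywhere and return a
      # name -> tensor state dict. torch.load stands in for a remote fetch.
      return torch.load(path, map_location="cpu")


    # custom_loader defaults to None, so existing call sites are unaffected.
    model = qwen.build_0_5b_model(
        "/path/to/qwen_checkpoint",  # placeholder path
        custom_loader=read_state_dict,
    )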
--- a/ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('qwen_vl')
 
@@ -35,8 +36,12 @@ _IMAGE_WIDTH = flags.DEFINE_integer(
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = qwen_vl.build_model(
-      flags.FLAGS.checkpoint_path,
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
       kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       image_size=(_IMAGE_HEIGHT.value, _IMAGE_WIDTH.value),
   )
--- a/ai_edge_torch/generative/examples/qwen_vl/decoder.py
+++ b/ai_edge_torch/generative/examples/qwen_vl/decoder.py
@@ -97,8 +97,7 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=11008,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-06,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -114,7 +113,6 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
--- a/ai_edge_torch/generative/examples/qwen_vl/image_encoder.py
+++ b/ai_edge_torch/generative/examples/qwen_vl/image_encoder.py
@@ -16,7 +16,7 @@
 """Example of building an image encoder of Qwen 2.5 VL model."""
 
 import dataclasses
-from typing import List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import attention_utils
@@ -332,8 +332,7 @@ def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
       use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -359,7 +358,6 @@ def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
       window_size=112,
       spatial_merge_size=2,
       full_atten_block_indexes=[7, 15, 23, 31],
-      enable_hlfb=True,
   )
   return config
 
@@ -385,13 +383,21 @@ def build_image_encoder(
   return encoder
 
 
-def load_image_encoder(checkpoint_path: str, encoder: QwenVLImageEncoder):
-  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+def load_image_encoder(
+    checkpoint_path: str,
+    encoder: QwenVLImageEncoder,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+):
+  loader = loading_utils.ModelLoader(
+      checkpoint_path, TENSOR_NAMES, custom_loader
+  )
   # Loose the strictness because only image encoder is being loaded.
   loader.load(encoder, strict=False)
 
   # Load merger weights.
-  merger_loader = loading_utils.ModelLoader(checkpoint_path, None)
+  merger_loader = loading_utils.ModelLoader(
+      checkpoint_path, None, custom_loader
+  )
   state = merger_loader.get_state()
   w1_state = dict()
   w1_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.weight")
--- a/ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py
+++ b/ai_edge_torch/generative/examples/qwen_vl/qwen_vl.py
@@ -16,7 +16,7 @@
 """Example of building a full-stack of Qwen 2.5 VL model."""
 
 import dataclasses
-from typing import List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 from ai_edge_torch.generative.examples.qwen_vl import decoder
 from ai_edge_torch.generative.examples.qwen_vl import image_encoder
@@ -204,12 +204,20 @@ def get_fake_model_config(**kwargs) -> QwenVLConfig:
   )
 
 
-def build_model(checkpoint_path: str, **kwargs) -> QwenVL:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> QwenVL:
   config = get_model_config(**kwargs)
   model = QwenVL(config)
-  image_encoder.load_image_encoder(checkpoint_path, model.image_encoder)
+  image_encoder.load_image_encoder(
+      checkpoint_path, model.image_encoder, custom_loader
+  )
   # Load the parameters of decoder.
-  loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
+  loader = loading_utils.ModelLoader(
+      checkpoint_path, decoder.TENSOR_NAMES, custom_loader
+  )
   loader.load(model.decoder, strict=False)
   model.eval()
   return model
--- a/ai_edge_torch/generative/examples/qwen_vl/verify.py
+++ b/ai_edge_torch/generative/examples/qwen_vl/verify.py
@@ -22,6 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
 from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 from PIL import Image
 import requests
@@ -33,10 +34,15 @@ _IMAGE_URL = flags.DEFINE_string(
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
     "The image URI to encode.",
 )
-_PROMPTS = flags.DEFINE_string(
-    "prompts",
+_PROMPTS_WITH_IMAGE = flags.DEFINE_string(
+    "prompts_with_image",
     "<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>",
-    "The input prompts to generate answers.",
+    "The input prompts to generate answers with an image.",
+)
+_PROMPTS_TEXT_ONLY = flags.DEFINE_multi_string(
+    "prompts_text_only",
+    "What is the meaning of life?",
+    "The input prompts to generate answers only with text.",
 )
 _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "max_new_tokens",
@@ -68,13 +74,29 @@ def main(_):
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = qwen_vl.build_model(str(reauthored_checkpoint))
+  wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
 
   logging.info("Loading the processor from: %s", checkpoint)
   processor = transformers.AutoProcessor.from_pretrained(checkpoint)
 
+  logging.info("Verifying with text-only prompts...")
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=wrapped_reauthored_model,
+      tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
+      generate_prompts=_PROMPTS_TEXT_ONLY.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+  logging.info("Verifying with image input...")
   logging.info("Loading the image from: %s", _IMAGE_URL.value)
   image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
-  inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt")
+  inputs = processor(
+      text=_PROMPTS_WITH_IMAGE.value, images=image, return_tensors="pt"
+  )
 
   logging.info("Verifying the reauthored model with model.forward()...")
   logging.info("Forwarding the original model...")
@@ -87,7 +109,6 @@
   logging.info("outputs_original: %s", outputs_original)
 
   logging.info("Forwarding the reauthored model...")
-  wrapped_reauthored_model = ReauthoredQwenVLWrapper(reauthored_model)
   grid_thw = inputs["image_grid_thw"].tolist()
   config = reauthored_model.config.image_encoder_config.image_embedding
   reauthored_model.image_encoder.set_image_size(
--- a/ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
@@ -15,12 +15,12 @@
 
 """Example of converting SmolLM model to multi-signature tflite model."""
 
-import os
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config as export_cfg
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('smollm')
 
@@ -32,8 +32,13 @@ _DECODE_BATCH_SIZE = flags.DEFINE_integer(
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = smollm.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
 
   export_config = export_cfg.get_from_flags()
--- a/ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py
+++ b/ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py
@@ -19,6 +19,7 @@ from absl import app
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config as export_cfg
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('smollm2')
 
@@ -30,8 +31,13 @@ _DECODE_BATCH_SIZE = flags.DEFINE_integer(
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = smollm.build_model_v2(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
 
   export_config = export_cfg.get_from_flags()
--- a/ai_edge_torch/generative/examples/smollm/smollm.py
+++ b/ai_edge_torch/generative/examples/smollm/smollm.py
@@ -15,8 +15,10 @@
 
 """Example of building a SmolLM model."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES
@@ -49,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=1536,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -66,7 +66,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
@@ -80,12 +79,17 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM,
+      custom_loader=custom_loader,
   )
 
 
@@ -118,10 +122,15 @@ def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model_v2(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model_v2(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config_v2(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM2,
+      custom_loader=custom_loader,
   )
--- a/ai_edge_torch/generative/examples/smollm/verify.py
+++ b/ai_edge_torch/generative/examples/smollm/verify.py
@@ -43,8 +43,8 @@ _MODEL_VERSION = flags.DEFINE_enum(
     "The version of SmolLm to verify.",
 )
 _CHECKPOINT = {
-    "v1": "HuggingFaceTB/SmolLM-135M",
-    "v2": "HuggingFaceTB/SmolLM2-135M",
+    "v1": "HuggingFaceTB/SmolLM-135M-Instruct",
+    "v2": "HuggingFaceTB/SmolLM2-135M-Instruct",
 }
 
 _BUILDER = {
--- a/ai_edge_torch/generative/examples/stable_diffusion/clip.py
+++ b/ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -57,7 +57,8 @@ class CLIP(nn.Module):
     super().__init__()
     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
     self.tok_embedding_position = nn.Parameter(
-        torch.zeros((config.max_seq_len, config.embedding_dim))
+        torch.zeros((config.max_seq_len, config.embedding_dim)),
+        requires_grad=False,
     )
 
     self.config = config
@@ -112,7 +113,9 @@ def get_model_config() -> cfg.ModelConfig:
       use_bias=True,
   )
 
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=False
+  )
 
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -128,7 +131,6 @@ def get_model_config() -> cfg.ModelConfig:
       embedding_dim=embedding_dim,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
 
   return config
@@ -163,7 +165,9 @@ def get_fake_model_config() -> cfg.ModelConfig:
       use_bias=True,
   )
 
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=False
+  )
 
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -179,7 +183,6 @@ def get_fake_model_config() -> cfg.ModelConfig:
       embedding_dim=embedding_dim,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
 
   return config
--- a/ai_edge_torch/generative/examples/t5/t5.py
+++ b/ai_edge_torch/generative/examples/t5/t5.py
@@ -393,8 +393,7 @@ def get_model_config_t5() -> cfg.ModelConfig:
   )
   # T5 Confirmed as RMS Norm and eps = 1e-6 TJA.
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=False
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -411,7 +410,6 @@ def get_model_config_t5() -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
   )
   return config
 
--- a/ai_edge_torch/generative/examples/test_models/toy_model.py
+++ b/ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -138,7 +138,9 @@ def get_model_config() -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=256,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=False
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -152,5 +154,6 @@ def get_model_config() -> cfg.ModelConfig:
       embedding_dim=128,
       block_configs=block_config,
       final_norm_config=norm_config,
+      enable_hlfb=False,
   )
   return config
--- a/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
+++ b/ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -108,7 +108,9 @@ def get_model_config() -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=256,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=False
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -122,7 +124,6 @@ def get_model_config() -> cfg.ModelConfig:
       embedding_dim=128,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
 
--- a/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -19,13 +19,19 @@ from absl import app
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("tiny_llama")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = tiny_llama.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
--- a/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
+++ b/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -15,8 +15,10 @@
 
 """Example of building a TinyLlama model."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
@@ -49,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=5632,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -67,7 +67,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config
 
@@ -81,10 +80,15 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=TinyLlama,
+      custom_loader=custom_loader,
   )
--- a/ai_edge_torch/generative/layers/model_config.py
+++ b/ai_edge_torch/generative/layers/model_config.py
@@ -66,7 +66,7 @@ class NormalizationConfig:
   """Normalizater parameters."""
 
   type: NormalizationType = NormalizationType.NONE
-  enable_hlfb: bool = False
+  enable_hlfb: bool = True
   epsilon: float = 1e-5
   zero_centered: bool = False
   # Number of groups used in group normalization.
@@ -218,7 +218,7 @@ class ModelConfig:
   lm_head_share_weight_with_embedding: bool = True
 
   # Whether to turn on high-level function boundary.
-  enable_hlfb: bool = False
+  enable_hlfb: bool = True
 
   # The maximum sequence length of the KV cache. Should not exceed max_seq_len.
   kv_cache_max_len: int = 0
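Both defaults flip from `False` to `True`, which is why the per-model `enable_hlfb=True` arguments disappear throughout this diff while CLIP, T5, and the toy test models gain explicit `enable_hlfb=False` opt-outs. A minimal sketch of the spellings under the new defaults:

    import ai_edge_torch.generative.layers.model_config as cfg

    # HLFB is now on by default, so these two configs behave the same.
    norm_default = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
    norm_explicit = cfg.NormalizationConfig(
        type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
    )

    # Models that cannot use high-level function boundaries opt out, as the
    # CLIP and T5 examples above now do.
    norm_opt_out = cfg.NormalizationConfig(
        type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=False
    )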
--- a/ai_edge_torch/generative/utilities/converter.py
+++ b/ai_edge_torch/generative/utilities/converter.py
@@ -100,7 +100,8 @@ def define_conversion_flags(
   flags.DEFINE_string(
       'quantize',
       'dynamic_int8',
-      'How the model should be quantized.',
+      'How the model should be quantized. Set to "none" to disable'
+      ' quantization. See `QuantizationName` for supported quantization types.',
   )
   flags.DEFINE_multi_integer(
       'lora_ranks',
@@ -119,6 +120,12 @@
       default_transpose_kv_cache,
       'If true, the model will be converted with transposed KV cache.',
   )
+  flags.DEFINE_bool(
+      'custom_checkpoint_loader',
+      False,
+      'If true, the conversion script will use a custom checkpoint loader which'
+      ' will read a checkpoint from a remote source.',
+  )
   return flags
 
 
@@ -397,13 +404,19 @@ def _export_helper(
     )
 
     if prefill_pixel_values is not None:
-      sample_kwargs['tokens'] = prefill_tokens_list_with_pixel[i]
-      sample_kwargs['input_pos'] = prefill_input_pos_list_with_pixel[i]
-      sample_kwargs['pixel_values'] = prefill_pixel_values
+      sample_pixel_kwargs = {
+          'tokens': prefill_tokens_list_with_pixel[i],
+          'input_pos': prefill_input_pos_list_with_pixel[i],
+          'kv_cache': prefill_kv,
+          'pixel_values': prefill_pixel_values,
+      }
+      # mask should be built internally when pixel values are passed.
+      if lora is not None:
+        sample_pixel_kwargs['lora'] = lora
       converter.add_signature(
           prefill_signature_name + '_pixel',
           mod,
-          sample_kwargs=sample_kwargs,
+          sample_kwargs=sample_pixel_kwargs,
       )
 
     sample_kwargs = {
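Note the behavioral fix in this hunk: the pixel-values prefill signature previously reused and mutated the shared `sample_kwargs` dict in place; it now builds a dedicated `sample_pixel_kwargs` per signature so that every key (`tokens`, `input_pos`, `kv_cache`, `pixel_values`, and the optional `lora`) is explicit and each `add_signature` call stays self-contained.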
--- a/ai_edge_torch/generative/utilities/loader.py
+++ b/ai_edge_torch/generative/utilities/loader.py
@@ -49,6 +49,25 @@ def get_custom_loader(
   raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")
 
 
+def maybe_get_custom_loader(
+    checkpoint_path: str,
+    use_custom_loader: bool = False,
+) -> Callable[[str], Dict[str, torch.Tensor]] | None:
+  """Returns a custom loader for the given checkpoint path.
+
+  If use_custom_loader is True, the function will return a custom loader.
+  Otherwise, it will return None.
+
+  Args:
+    checkpoint_path (string): The path to the checkpoint.
+    use_custom_loader (bool): Whether to use a custom loader.
+
+  Returns:
+    Callable[[str], Dict[str, torch.Tensor]] | None: The custom loader.
+  """
+  return get_custom_loader(checkpoint_path) if use_custom_loader else None
+
+
 def load_safetensors(full_path: str):
   """Loads safetensors into a single state dictionary.
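`maybe_get_custom_loader` is the glue between the new `custom_checkpoint_loader` conversion flag and the `custom_loader` parameters threaded through the builders above: it returns `get_custom_loader(checkpoint_path)` when the flag is set and `None` otherwise, leaving the default file-based loading untouched. A minimal sketch of the flag-driven call site, with a placeholder checkpoint path:

    from ai_edge_torch.generative.examples.smollm import smollm
    from ai_edge_torch.generative.utilities import loader

    ckpt = "/tmp/smollm_checkpoint"  # placeholder path

    # With use_custom_loader=False this returns None and build_model falls
    # back to the built-in loaders; with True it returns
    # get_custom_loader(ckpt), which picks a loader for the checkpoint format.
    model = smollm.build_model(
        ckpt,
        custom_loader=loader.maybe_get_custom_loader(
            ckpt, use_custom_loader=True
        ),
    )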