ai-edge-torch-nightly 0.5.0.dev20250425__py3-none-any.whl → 0.5.0.dev20250426__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +2 -36
  2. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -38
  3. ai_edge_torch/generative/examples/hammer/__init__.py +14 -0
  4. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +92 -0
  5. ai_edge_torch/generative/examples/hammer/hammer.py +107 -0
  6. ai_edge_torch/generative/examples/hammer/verify.py +86 -0
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +1 -3
  8. ai_edge_torch/generative/examples/llama/llama.py +3 -1
  9. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -2
  10. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +1 -2
  11. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +1 -2
  12. ai_edge_torch/generative/examples/phi/phi2.py +1 -1
  13. ai_edge_torch/generative/examples/phi/phi3.py +3 -1
  14. ai_edge_torch/generative/examples/phi/phi4.py +3 -1
  15. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +1 -37
  16. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +5 -3
  17. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +4 -3
  18. ai_edge_torch/generative/examples/smollm/smollm.py +3 -1
  19. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +1 -2
  20. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +3 -1
  21. ai_edge_torch/generative/layers/kv_cache.py +2 -4
  22. ai_edge_torch/generative/test/test_model_conversion_large.py +7 -0
  23. ai_edge_torch/generative/utilities/converter.py +7 -2
  24. ai_edge_torch/generative/utilities/export_config.py +30 -0
  25. ai_edge_torch/model.py +2 -0
  26. ai_edge_torch/version.py +1 -1
  27. {ai_edge_torch_nightly-0.5.0.dev20250425.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/METADATA +1 -1
  28. {ai_edge_torch_nightly-0.5.0.dev20250425.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/RECORD +31 -27
  29. {ai_edge_torch_nightly-0.5.0.dev20250425.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/LICENSE +0 -0
  30. {ai_edge_torch_nightly-0.5.0.dev20250425.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/WHEEL +0 -0
  31. {ai_edge_torch_nightly-0.5.0.dev20250425.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/top_level.txt +0 -0
@@ -19,41 +19,9 @@ from absl import app
19
19
  from ai_edge_torch.generative.examples.deepseek import deepseek
20
20
  from ai_edge_torch.generative.layers import kv_cache
21
21
  from ai_edge_torch.generative.utilities import converter
22
- from ai_edge_torch.generative.utilities.model_builder import export_cfg
23
- import torch
22
+ from ai_edge_torch.generative.utilities import export_config
24
23
 
25
24
  flags = converter.define_conversion_flags('deepseek')
26
- ExportConfig = export_cfg.ExportConfig
27
-
28
-
29
- def _create_mask(mask_len, kv_cache_max_len):
30
- mask = torch.full(
31
- (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
32
- )
33
- mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
34
- return mask
35
-
36
-
37
- def _create_export_config(
38
- prefill_seq_lens: list[int], kv_cache_max_len: int
39
- ) -> ExportConfig:
40
- """Creates the export config for the model."""
41
- export_config = ExportConfig()
42
- if isinstance(prefill_seq_lens, list):
43
- prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
44
- else:
45
- prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
46
-
47
- export_config.prefill_mask = prefill_mask
48
-
49
- decode_mask = torch.full(
50
- (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
51
- )
52
- decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
53
- export_config.decode_mask = decode_mask
54
- export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
55
- return export_config
56
-
57
25
 
58
26
  def main(_):
59
27
  pytorch_model = deepseek.build_model(
@@ -66,9 +34,7 @@ def main(_):
66
34
  prefill_seq_len=flags.FLAGS.prefill_seq_lens,
67
35
  quantize=flags.FLAGS.quantize,
68
36
  lora_ranks=flags.FLAGS.lora_ranks,
69
- export_config=_create_export_config(
70
- flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
71
- ),
37
+ export_config=export_config.get_from_flags(),
72
38
  )
73
39
 
74
40
 
@@ -17,14 +17,10 @@
17
17
 
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.gemma3 import gemma3
20
- from ai_edge_torch.generative.layers import kv_cache
21
20
  from ai_edge_torch.generative.utilities import converter
22
21
  from ai_edge_torch.generative.utilities import export_config
23
- import torch
24
22
 
25
23
  flags = converter.define_conversion_flags('gemma3-1b')
26
- ExportConfig = export_config.ExportConfig
27
-
28
24
 
29
25
  _MODEL_SIZE = flags.DEFINE_string(
30
26
  'model_size',
@@ -33,55 +29,23 @@ _MODEL_SIZE = flags.DEFINE_string(
33
29
  )
34
30
 
35
31
 
36
- def _create_mask(mask_len, kv_cache_max_len):
37
- mask = torch.full(
38
- (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
39
- )
40
- mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
41
- return mask
42
-
43
-
44
- def _create_export_config(
45
- prefill_seq_lens: list[int], kv_cache_max_len: int
46
- ) -> ExportConfig:
47
- """Creates the export config for the model."""
48
- export_config = ExportConfig()
49
- if isinstance(prefill_seq_lens, list):
50
- prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
51
- else:
52
- prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
53
-
54
- export_config.prefill_mask = prefill_mask
55
-
56
- decode_mask = torch.full(
57
- (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
58
- )
59
- decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
60
- export_config.decode_mask = decode_mask
61
- export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
62
- return export_config
63
-
64
-
65
32
  def main(_):
66
33
  if _MODEL_SIZE.value == '1b':
67
34
  pytorch_model = gemma3.build_model_1b(
68
35
  flags.FLAGS.checkpoint_path,
69
36
  kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
70
37
  )
71
- config = pytorch_model.config
72
38
  else:
73
39
  raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
40
+
74
41
  converter.convert_to_tflite(
75
42
  pytorch_model,
76
43
  output_path=flags.FLAGS.output_path,
77
44
  output_name_prefix=flags.FLAGS.output_name_prefix,
78
45
  prefill_seq_len=flags.FLAGS.prefill_seq_lens,
79
46
  quantize=flags.FLAGS.quantize,
80
- config=config,
81
47
  lora_ranks=flags.FLAGS.lora_ranks,
82
- export_config=_create_export_config(
83
- flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
84
- ),
48
+ export_config=export_config.get_from_flags(),
85
49
  )
86
50
 
87
51
 
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,54 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of converting Hammer 2.1 models to multi-signature tflite model."""
17
+
18
+ from absl import app
19
+ from ai_edge_torch.generative.examples.hammer import hammer
20
+ from ai_edge_torch.generative.utilities import converter
21
+ from ai_edge_torch.generative.utilities import export_config
22
+
23
+ flags = converter.define_conversion_flags('hammer')
24
+
25
+ _MODEL_SIZE = flags.DEFINE_enum(
26
+ 'model_size',
27
+ '1.5b',
28
+ ['0.5b', '1.5b'],
29
+ 'The size of the model to convert.',
30
+ )
31
+
32
+ _BUILDER = {
33
+ '0.5b': hammer.build_0_5b_model,
34
+ '1.5b': hammer.build_1_5b_model,
35
+ }
36
+
37
+
38
+ def main(_):
39
+ pytorch_model = _BUILDER[_MODEL_SIZE.value](
40
+ flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
41
+ )
42
+ converter.convert_to_tflite(
43
+ pytorch_model,
44
+ output_path=flags.FLAGS.output_path,
45
+ output_name_prefix=flags.FLAGS.output_name_prefix,
46
+ prefill_seq_len=flags.FLAGS.prefill_seq_lens,
47
+ quantize=flags.FLAGS.quantize,
48
+ lora_ranks=flags.FLAGS.lora_ranks,
49
+ export_config=export_config.get_from_flags(),
50
+ )
51
+
52
+
53
+ if __name__ == '__main__':
54
+ app.run(main)
@@ -0,0 +1,107 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of building Hammer 2.1 models."""
17
+
18
+ import ai_edge_torch.generative.layers.model_config as cfg
19
+ from ai_edge_torch.generative.utilities import model_builder
20
+ from torch import nn
21
+
22
+ TENSOR_NAMES = model_builder.TENSOR_NAMES
23
+
24
+
25
+ class Hammer(model_builder.DecoderOnlyModel):
26
+ """A Hammer model built from the Edge Generative API layers."""
27
+ pass
28
+
29
+
30
+ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
31
+ """Returns the model config for a Hammer 2.1 1.5B model."""
32
+ attn_config = cfg.AttentionConfig(
33
+ num_heads=12,
34
+ head_dim=128,
35
+ num_query_groups=2,
36
+ rotary_base=1000000,
37
+ rotary_percentage=1.0,
38
+ qkv_use_bias=True,
39
+ )
40
+ ff_config = cfg.FeedForwardConfig(
41
+ type=cfg.FeedForwardType.GATED,
42
+ activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
43
+ intermediate_size=8960,
44
+ )
45
+ norm_config = cfg.NormalizationConfig(
46
+ type=cfg.NormalizationType.RMS_NORM,
47
+ epsilon=1e-06,
48
+ enable_hlfb=True,
49
+ )
50
+ block_config = cfg.TransformerBlockConfig(
51
+ attn_config=attn_config,
52
+ ff_config=ff_config,
53
+ pre_attention_norm_config=norm_config,
54
+ post_attention_norm_config=norm_config,
55
+ )
56
+ config = cfg.ModelConfig(
57
+ vocab_size=151665,
58
+ num_layers=28,
59
+ max_seq_len=32768,
60
+ embedding_dim=1536,
61
+ kv_cache_max_len=kv_cache_max_len,
62
+ block_configs=block_config,
63
+ final_norm_config=norm_config,
64
+ enable_hlfb=True,
65
+ )
66
+ return config
67
+
68
+
69
+ def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
70
+ """Returns the model config for a Hammer 2.1 0.5B model."""
71
+ config = get_1_5b_model_config(kv_cache_max_len)
72
+ # Hammer has only one block config.
73
+ block_config = config.block_config(0)
74
+ block_config.attn_config.num_heads = 14
75
+ block_config.attn_config.head_dim = 64
76
+ block_config.ff_config.intermediate_size = 4864
77
+ config.num_layers = 24
78
+ config.embedding_dim = 896
79
+ return config
80
+
81
+
82
+ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
83
+ config = get_1_5b_model_config(**kwargs)
84
+ config.vocab_size = 128
85
+ config.num_layers = 2
86
+ config.embedding_dim = 16
87
+ # Hammer has only one block config.
88
+ config.block_config(0).ff_config.intermediate_size = 64
89
+ return config
90
+
91
+
92
+ def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
93
+ return model_builder.build_decoder_only_model(
94
+ checkpoint_path=checkpoint_path,
95
+ config=get_1_5b_model_config(**kwargs),
96
+ tensor_names=TENSOR_NAMES,
97
+ model_class=Hammer,
98
+ )
99
+
100
+
101
+ def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
102
+ return model_builder.build_decoder_only_model(
103
+ checkpoint_path=checkpoint_path,
104
+ config=get_0_5b_model_config(**kwargs),
105
+ tensor_names=TENSOR_NAMES,
106
+ model_class=Hammer,
107
+ )
@@ -0,0 +1,86 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Verifies the reauthored Hammer 2.1 0.5B and 1.5B models."""
17
+
18
+ import logging
19
+ import pathlib
20
+
21
+ from absl import app
22
+ from absl import flags
23
+ from ai_edge_torch.generative.examples.hammer import hammer
24
+ from ai_edge_torch.generative.utilities import transformers_verifier
25
+ from ai_edge_torch.generative.utilities import verifier
26
+ import transformers
27
+
28
+
29
+ _MODEL_SIZE = flags.DEFINE_enum(
30
+ "model_size",
31
+ "0.5b",
32
+ ["0.5b", "1.5b"],
33
+ "The size of the model to verify.",
34
+ )
35
+ _PROMPTS = flags.DEFINE_multi_string(
36
+ "prompts",
37
+ "What is the meaning of life?",
38
+ "The input prompts to generate answers.",
39
+ )
40
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
41
+ "max_new_tokens",
42
+ 30,
43
+ "The maximum size of the generated tokens.",
44
+ )
45
+
46
+ _CHECKPOINT = {
47
+ "0.5b": "MadeAgents/Hammer2.1-0.5b",
48
+ "1.5b": "MadeAgents/Hammer2.1-1.5b",
49
+ }
50
+
51
+ _BUILDER = {
52
+ "0.5b": hammer.build_0_5b_model,
53
+ "1.5b": hammer.build_1_5b_model,
54
+ }
55
+
56
+
57
+ def main(_):
58
+ checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
59
+ logging.info("Loading the original model from: %s", checkpoint)
60
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
61
+
62
+ # Locate the cached dir.
63
+ cached_config_file = transformers.utils.cached_file(
64
+ checkpoint, transformers.utils.CONFIG_NAME
65
+ )
66
+ reauthored_checkpoint = pathlib.Path(cached_config_file).parent
67
+ logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
68
+ reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
69
+
70
+ logging.info("Loading the tokenizer from: %s", checkpoint)
71
+ tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
72
+
73
+ verifier.verify_reauthored_model(
74
+ original_model=transformers_verifier.TransformersModelWrapper(
75
+ original_model
76
+ ),
77
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
78
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
79
+ generate_prompts=_PROMPTS.value,
80
+ max_new_tokens=_MAX_NEW_TOKENS.value,
81
+ atol=1e-04,
82
+ )
83
+
84
+
85
+ if __name__ == "__main__":
86
+ app.run(main)
@@ -22,8 +22,6 @@ from ai_edge_torch.generative.utilities import export_config
22
22
 
23
23
 
24
24
  flags = converter.define_conversion_flags('llama')
25
- ExportConfig = export_config.ExportConfig
26
-
27
25
 
28
26
  _MODEL_SIZE = flags.DEFINE_enum(
29
27
  'model_size',
@@ -49,7 +47,7 @@ def main(_):
49
47
  prefill_seq_len=flags.FLAGS.prefill_seq_lens,
50
48
  quantize=flags.FLAGS.quantize,
51
49
  lora_ranks=flags.FLAGS.lora_ranks,
52
- export_config=ExportConfig(),
50
+ export_config=export_config.get_from_flags(),
53
51
  )
54
52
 
55
53
 
@@ -121,7 +121,9 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
121
121
  activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
122
122
  intermediate_size=8192,
123
123
  )
124
- norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
124
+ norm_config = cfg.NormalizationConfig(
125
+ type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
126
+ )
125
127
  block_config = cfg.TransformerBlockConfig(
126
128
  attn_config=attn_config,
127
129
  ff_config=ff_config,
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
21
21
  from ai_edge_torch.generative.utilities import export_config
22
22
 
23
23
  flags = converter.define_conversion_flags("phi3")
24
- ExportConfig = export_config.ExportConfig
25
24
 
26
25
 
27
26
  def main(_):
@@ -35,7 +34,7 @@ def main(_):
35
34
  prefill_seq_len=flags.FLAGS.prefill_seq_lens,
36
35
  quantize=flags.FLAGS.quantize,
37
36
  lora_ranks=flags.FLAGS.lora_ranks,
38
- export_config=ExportConfig(),
37
+ export_config=export_config.get_from_flags(),
39
38
  )
40
39
 
41
40
 
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
21
21
  from ai_edge_torch.generative.utilities import export_config
22
22
 
23
23
  flags = converter.define_conversion_flags("phi4")
24
- ExportConfig = export_config.ExportConfig
25
24
 
26
25
 
27
26
  def main(_):
@@ -35,7 +34,7 @@ def main(_):
35
34
  prefill_seq_len=flags.FLAGS.prefill_seq_lens,
36
35
  quantize=flags.FLAGS.quantize,
37
36
  lora_ranks=flags.FLAGS.lora_ranks,
38
- export_config=ExportConfig(),
37
+ export_config=export_config.get_from_flags(),
39
38
  )
40
39
 
41
40
 
@@ -22,7 +22,6 @@ from ai_edge_torch.generative.utilities import converter
22
22
  from ai_edge_torch.generative.utilities import export_config
23
23
 
24
24
  flags = converter.define_conversion_flags("phi2")
25
- ExportConfig = export_config.ExportConfig
26
25
 
27
26
 
28
27
  def main(_):
@@ -36,7 +35,7 @@ def main(_):
36
35
  prefill_seq_len=flags.FLAGS.prefill_seq_lens,
37
36
  quantize=flags.FLAGS.quantize,
38
37
  lora_ranks=flags.FLAGS.lora_ranks,
39
- export_config=ExportConfig(),
38
+ export_config=export_config.get_from_flags(),
40
39
  )
41
40
 
42
41
 
@@ -65,7 +65,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
65
65
  use_bias=True,
66
66
  )
67
67
  norm_config = cfg.NormalizationConfig(
68
- type=cfg.NormalizationType.LAYER_NORM,
68
+ type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=True
69
69
  )
70
70
  block_config = cfg.TransformerBlockConfig(
71
71
  attn_config=attn_config,
@@ -162,7 +162,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
162
162
  activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
163
163
  intermediate_size=8192,
164
164
  )
165
- norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
165
+ norm_config = cfg.NormalizationConfig(
166
+ type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
167
+ )
166
168
  block_config = cfg.TransformerBlockConfig(
167
169
  attn_config=attn_config,
168
170
  ff_config=ff_config,
@@ -112,7 +112,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
112
112
  activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
113
113
  intermediate_size=8192,
114
114
  )
115
- norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
115
+ norm_config = cfg.NormalizationConfig(
116
+ type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
117
+ )
116
118
  block_config = cfg.TransformerBlockConfig(
117
119
  attn_config=attn_config,
118
120
  ff_config=ff_config,
@@ -17,13 +17,10 @@
17
17
 
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.qwen import qwen
20
- from ai_edge_torch.generative.layers import kv_cache
21
20
  from ai_edge_torch.generative.utilities import converter
22
21
  from ai_edge_torch.generative.utilities import export_config
23
- import torch
24
22
 
25
23
  flags = converter.define_conversion_flags('qwen')
26
- ExportConfig = export_config.ExportConfig
27
24
 
28
25
  _MODEL_SIZE = flags.DEFINE_enum(
29
26
  'model_size',
@@ -39,35 +36,6 @@ _BUILDER = {
39
36
  }
40
37
 
41
38
 
42
- def _create_mask(mask_len, kv_cache_max_len):
43
- mask = torch.full(
44
- (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
45
- )
46
- mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
47
- return mask
48
-
49
-
50
- def _create_export_config(
51
- prefill_seq_lens: list[int], kv_cache_max_len: int
52
- ) -> ExportConfig:
53
- """Creates the export config for the model."""
54
- export_config = ExportConfig()
55
- if isinstance(prefill_seq_lens, list):
56
- prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
57
- else:
58
- prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
59
-
60
- export_config.prefill_mask = prefill_mask
61
-
62
- decode_mask = torch.full(
63
- (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
64
- )
65
- decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
66
- export_config.decode_mask = decode_mask
67
- export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
68
- return export_config
69
-
70
-
71
39
  def main(_):
72
40
  pytorch_model = _BUILDER[_MODEL_SIZE.value](
73
41
  flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
@@ -79,11 +47,7 @@ def main(_):
79
47
  prefill_seq_len=flags.FLAGS.prefill_seq_lens,
80
48
  quantize=flags.FLAGS.quantize,
81
49
  lora_ranks=flags.FLAGS.lora_ranks,
82
- export_config=_create_export_config(
83
- flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
84
- )
85
- if flags.FLAGS.transpose_kv_cache
86
- else ExportConfig(),
50
+ export_config=export_config.get_from_flags(),
87
51
  )
88
52
 
89
53
 
@@ -35,6 +35,10 @@ def main(_):
35
35
  pytorch_model = smollm.build_model(
36
36
  flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
37
37
  )
38
+
39
+ export_config = export_cfg.get_from_flags()
40
+ export_config.decode_batch_size = _DECODE_BATCH_SIZE.value
41
+
38
42
  converter.convert_to_tflite(
39
43
  pytorch_model,
40
44
  output_path=flags.FLAGS.output_path,
@@ -42,9 +46,7 @@ def main(_):
42
46
  prefill_seq_len=flags.FLAGS.prefill_seq_lens,
43
47
  quantize=flags.FLAGS.quantize,
44
48
  lora_ranks=flags.FLAGS.lora_ranks,
45
- export_config=export_cfg.ExportConfig(
46
- decode_batch_size=_DECODE_BATCH_SIZE.value
47
- ),
49
+ export_config=export_config,
48
50
  )
49
51
 
50
52
 
@@ -34,6 +34,9 @@ def main(_):
34
34
  flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
35
35
  )
36
36
 
37
+ export_config = export_cfg.get_from_flags()
38
+ export_config.decode_batch_size = _DECODE_BATCH_SIZE.value
39
+
37
40
  converter.convert_to_tflite(
38
41
  pytorch_model,
39
42
  output_path=flags.FLAGS.output_path,
@@ -41,9 +44,7 @@ def main(_):
41
44
  prefill_seq_len=flags.FLAGS.prefill_seq_lens,
42
45
  quantize=flags.FLAGS.quantize,
43
46
  lora_ranks=flags.FLAGS.lora_ranks,
44
- export_config=export_cfg.ExportConfig(
45
- decode_batch_size=_DECODE_BATCH_SIZE.value
46
- ),
47
+ export_config=export_config,
47
48
  )
48
49
 
49
50
 
@@ -49,7 +49,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
49
49
  activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
50
50
  intermediate_size=1536,
51
51
  )
52
- norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
52
+ norm_config = cfg.NormalizationConfig(
53
+ type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
54
+ )
53
55
  block_config = cfg.TransformerBlockConfig(
54
56
  attn_config=attn_config,
55
57
  ff_config=ff_config,
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
21
21
  from ai_edge_torch.generative.utilities import export_config
22
22
 
23
23
  flags = converter.define_conversion_flags("tiny_llama")
24
- ExportConfig = export_config.ExportConfig
25
24
 
26
25
 
27
26
  def main(_):
@@ -35,7 +34,7 @@ def main(_):
35
34
  prefill_seq_len=flags.FLAGS.prefill_seq_lens,
36
35
  quantize=flags.FLAGS.quantize,
37
36
  lora_ranks=flags.FLAGS.lora_ranks,
38
- export_config=ExportConfig(),
37
+ export_config=export_config.get_from_flags(),
39
38
  )
40
39
 
41
40
 
@@ -49,7 +49,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
49
49
  activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
50
50
  intermediate_size=5632,
51
51
  )
52
- norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
52
+ norm_config = cfg.NormalizationConfig(
53
+ type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
54
+ )
53
55
  block_config = cfg.TransformerBlockConfig(
54
56
  attn_config=attn_config,
55
57
  ff_config=ff_config,
@@ -51,10 +51,7 @@ class KVCacheEntry:
51
51
  config: model_config.AttentionConfig,
52
52
  batch_size: int,
53
53
  ) -> List[int]:
54
- """Constructs the shape of the key or value cache entry based on
55
-
56
- the specified layout.
57
- """
54
+ """Construct the shape of KV cache entry based on the specified layout."""
58
55
  output_shape = []
59
56
  for dim_spec in shape_spec:
60
57
  if dim_spec is types.TensorDims.BATCH:
@@ -213,6 +210,7 @@ pytree.register_pytree_node(
213
210
  serialized_type_name="",
214
211
  )
215
212
 
213
+
216
214
  def update(
217
215
  cache: KVCacheEntry,
218
216
  input_pos: torch.Tensor,
@@ -20,6 +20,7 @@ from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
20
20
  from ai_edge_torch.generative.examples.deepseek import deepseek
21
21
  from ai_edge_torch.generative.examples.gemma import gemma1
22
22
  from ai_edge_torch.generative.examples.gemma import gemma2
23
+ from ai_edge_torch.generative.examples.hammer import hammer
23
24
  from ai_edge_torch.generative.examples.llama import llama
24
25
  from ai_edge_torch.generative.examples.openelm import openelm
25
26
  from ai_edge_torch.generative.examples.paligemma import decoder
@@ -148,6 +149,12 @@ class TestModelConversion(googletest.TestCase):
148
149
  pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
149
150
  self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
150
151
 
152
+ def test_hammer(self):
153
+ config = hammer.get_fake_model_config()
154
+ pytorch_model = hammer.Hammer(config).eval()
155
+ self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
156
+
157
+
151
158
  def test_amd_llama_135m(self):
152
159
  config = amd_llama_135m.get_fake_model_config()
153
160
  pytorch_model = amd_llama_135m.AmdLlama(config).eval()
@@ -81,12 +81,17 @@ def define_conversion_flags(model_name: str):
81
81
  'If set, the model will be converted with the provided list of LoRA'
82
82
  ' ranks.',
83
83
  )
84
+ flags.DEFINE_bool(
85
+ 'mask_as_input',
86
+ False,
87
+ 'If true, the mask will be passed in as input. Otherwise, mask will be '
88
+ 'built by the model internally.',
89
+ )
84
90
  flags.DEFINE_bool(
85
91
  'transpose_kv_cache',
86
92
  False,
87
- 'If set, the model will be converted with transposed KV cache.',
93
+ 'If true, the model will be converted with transposed KV cache.',
88
94
  )
89
-
90
95
  return flags
91
96
 
92
97
 
@@ -14,8 +14,11 @@
14
14
  # ==============================================================================
15
15
 
16
16
  """Config for customizing model export process."""
17
+
17
18
  import dataclasses
18
19
  from typing import List, Optional
20
+
21
+ from absl import flags
19
22
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
20
23
  import torch
21
24
 
@@ -38,3 +41,30 @@ class ExportConfig:
38
41
  kvcache_cls: type = kv_utils.KVCache
39
42
  # The batch size of the decode signature.
40
43
  decode_batch_size: int = 1
44
+
45
+
46
+ def _build_mask(mask_len, kv_cache_max_len) -> torch.Tensor:
47
+ if isinstance(mask_len, list):
48
+ return [_build_mask(i, kv_cache_max_len) for i in mask_len]
49
+
50
+ mask = torch.full(
51
+ (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
52
+ )
53
+ mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
54
+ return mask
55
+
56
+
57
+ def get_from_flags() -> ExportConfig:
58
+ """Builds an export config according to the commandline flags."""
59
+ export_config = ExportConfig()
60
+
61
+ if flags.FLAGS.mask_as_input:
62
+ export_config.prefill_mask = _build_mask(
63
+ flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
64
+ )
65
+ export_config.decode_mask = _build_mask(1, flags.FLAGS.kv_cache_max_len)
66
+
67
+ if flags.FLAGS.transpose_kv_cache:
68
+ export_config.kvcache_layout = kv_utils.KV_LAYOUT_TRANSPOSED
69
+
70
+ return export_config
ai_edge_torch/model.py CHANGED
@@ -22,6 +22,7 @@ from __future__ import annotations
22
22
 
23
23
  import abc
24
24
  import re
25
+ import os
25
26
  from typing import Callable
26
27
 
27
28
  import numpy.typing as npt
@@ -154,6 +155,7 @@ class TfLiteModel(Model):
154
155
  Args:
155
156
  path: The path to file to which the model is serialized.
156
157
  """
158
+ os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
157
159
  with open(path, 'wb') as file_handle:
158
160
  file_handle.write(self._tflite_model)
159
161
 
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.5.0.dev20250425"
16
+ __version__ = "0.5.0.dev20250426"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.5.0.dev20250425
3
+ Version: 0.5.0.dev20250426
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -1,8 +1,8 @@
1
1
  ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,1208
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
- ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=_aF64u6MXH8zPBTEg6odQq2WazbUIxQYlfJNXzfkMdM,706
4
+ ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
5
+ ai_edge_torch/version.py,sha256=6qv9zJ0Z2J_RJ-E0S1o1-u2sbxvuuPUWnJcxWhmQEWg,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -53,7 +53,7 @@ ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif8
53
53
  ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=z5MWiZLnsQzhNYMiQbcI9i0ki-dtkbimCptkiTFZxwo,1586
54
54
  ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=o13NkFlBgawBsjdJup05VMUjAPvDRAmig6VyEkX8q6U,2426
55
55
  ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
56
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=1wz4h3bjyX2qMRZ310UKGNYTORegzxinVFmYz2Fupm4,2666
56
+ ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=l0OrPGmX8WscuG9MIgtd0sqR4BeReNAu7fADzyPbnZw,1580
57
57
  ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=yhS_i2kR0GJWpWciCt4p9Z9nHYh6A5uJ8Ycy2ebFN9w,2909
58
58
  ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
59
59
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -65,15 +65,19 @@ ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSd
65
65
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
66
66
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
67
67
  ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
68
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=szssSBrIUYdNIoU7LHdAq7wCqgjaY6qbV8yvTgg796Q,2945
68
+ ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=MjkQDVynaw9C5z9ODzKfb85xW5JfxHUWBJ_Aco05FHo,1760
69
69
  ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=eXWE5CSX0KeUMsPevgsYOfvyajl9F1RFF4DCWhHcYPA,15646
70
70
  ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
71
71
  ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
72
72
  ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
73
73
  ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=nEv0qQ0l6gSXKxP5mNwkd2lRGxpFfD4e7FNV3V76zhw,8915
74
+ ai_edge_torch/generative/examples/hammer/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
75
+ ai_edge_torch/generative/examples/hammer/convert_to_tflite.py,sha256=946mchDmvUhMsv1kzslp4LHtCIuHn4qjimHYQ-XnxMo,2962
76
+ ai_edge_torch/generative/examples/hammer/hammer.py,sha256=76INcjffvaNCQ02fzXcxJUW_6EKHs4sg3q1nDBbEpHE,3431
77
+ ai_edge_torch/generative/examples/hammer/verify.py,sha256=MkzAGkbPy4LKRhyCDm1cw-9jUt4VUxLPdwK_25fCGSE,2705
74
78
  ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
75
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=A4uLUdqvU1NKo3seqZlWSS3fqYahnEKqNBQBJO6yXvE,1762
76
- ai_edge_torch/generative/examples/llama/llama.py,sha256=UKvMO85_5z1vEY5MVu6QBW_vpQYA8LWHbJI4Yx6BrCc,6592
79
+ ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=nz5h4m8bVnw8P7OEtqhA_fKfvaRzxhT2_75vkFCqHmU,1735
80
+ ai_edge_torch/generative/examples/llama/llama.py,sha256=H7I5iNhIJ55gb0-9k7g-FPcG2IlthnA9XMR8qd__5bQ,6621
77
81
  ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
78
82
  ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
79
83
  ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
@@ -93,17 +97,17 @@ ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4I
93
97
  ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
94
98
  ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
95
99
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
96
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=Y2qaObMJeh9UABkUI7FBm4sCGi2YMQhsj0CSOS2fYek,1540
97
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py,sha256=TuGW_FPMs0pV7ZBe46FfaDrlfte4Dz75vGHmBOCFfww,1538
98
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=VZe7OQ54dgOGWe74XT2W7zZBm5uJaeIF8ZuNakkL0iA,1539
99
- ai_edge_torch/generative/examples/phi/phi2.py,sha256=c6PYCky7yJn6MVIYOCTx8S_CH27kOPmJbRZcI95nbZs,3477
100
- ai_edge_torch/generative/examples/phi/phi3.py,sha256=ddo52Inl5ub81q460cEyKhnsC3txellRErut-_qtBbM,6949
101
- ai_edge_torch/generative/examples/phi/phi4.py,sha256=OkMwLGe8l2JEAgOFi19AdbNBl1xp1djZBZo8MJP58ho,5732
100
+ ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=k-0ZC-_zZZmkdcc6dr1QGXfX9lDZZXRQSuc6wT0n3Is,1514
101
+ ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py,sha256=5KSJRySjSc89FriCOnfBabD8zRLUcGAw3L0VInuJFUY,1512
102
+ ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=wVIdGenHTi9xUffYddN_uXWMBO2tgo1e_hU4OG_NmHA,1513
103
+ ai_edge_torch/generative/examples/phi/phi2.py,sha256=X9MfjK8rmyRSrfNzIaKQNSgqLM5_CBH-BrLFX_7BWL8,3494
104
+ ai_edge_torch/generative/examples/phi/phi3.py,sha256=65Dbv8cA4WFdluflHQHzgDmDFjdmc6rxMO4hQukaxKU,6978
105
+ ai_edge_torch/generative/examples/phi/phi4.py,sha256=y3CCZCW4MnvX74d4MNERRuQBE0p5dquC2M9vDXXqnZI,5760
102
106
  ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
103
107
  ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
104
108
  ai_edge_torch/generative/examples/phi/verify_phi4.py,sha256=BoCa5kUBRHtMQ-5ql6yD4pG4xHJMyUiQlpMOWVx-JgY,2356
105
109
  ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
106
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=-Xe5koexhNUkWjS2XgS9Ggg7XOQAlMO8QcBJRTNjJa4,2972
110
+ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=eOpv3scJr4mVsJ9Obl7PBhMgd3a0T1t8dqoPp_VzZaQ,1776
107
111
  ai_edge_torch/generative/examples/qwen/qwen.py,sha256=m8APYzo9N0SXsdvCxC8HtCcbN3W7gLKkRBL-Tg0BWXU,4223
108
112
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
109
113
  ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
@@ -115,9 +119,9 @@ ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=JUwHoC_zvcC3RC3wZ3e3e
115
119
  ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
116
120
  ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=PZ392nDoJG2OmHZ_7Jet3Zu1JkN6QErxKcDc7a-PPds,3126
117
121
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
118
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=IjV0jriRKlF9aV5yLjtONjACb4_VxNIAGk9w1sr_hmc,1748
119
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=Wa_3OWXcM389iySwS5E47uCYZaTj6h-4RTP_Xi2-1aE,1721
120
- ai_edge_torch/generative/examples/smollm/smollm.py,sha256=3uUltb6D3Q1aHpndcYTJrsWM_RBwLAraKDniH8ZZous,3779
122
+ ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=jTM_tndbDqzq19uLz2n71S7M81L1Y6R7oVBPsMcYGzk,1785
123
+ ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=wU72MzpUIi2mQ8ZODW1x4L5KZPWvuXyB-_Eqo-RKqFw,1757
124
+ ai_edge_torch/generative/examples/smollm/smollm.py,sha256=SFE8fIJx7Y_oan0vXSmhEmI0Ib2HD3k9cyKLU_4MxfI,3807
121
125
  ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68n0owZMvty5Rr3W7ZNWWSw,2702
122
126
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
123
127
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
@@ -143,8 +147,8 @@ ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNH
143
147
  ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=Crpj-vOwSViHpblXOrRJmsIn4DrHyuB3XZ8kHifb7LA,5203
144
148
  ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=-z5tkQzGHbo37eAl9sDAJuT1Egxm8xI9CZmYLcmqIfU,4761
145
149
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
146
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=LPxg7mAJ_aAUIx6eE5bxixPA8Ep9Vul0CWJoNcrD5oE,1565
147
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
150
+ ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=XM-dCBW2HG6FlwwPjlJi0I_TEaVqdv7aWpFEv-XUdLc,1539
151
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=6Qhml-XB8_RjQdYN948OaSsPJNrfi-Mr7PFB73C79Ug,2828
148
152
  ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1tO2i0nUCqe-VkRgboA10VZ7KNg,2431
149
153
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx_bLdWjN01iyjPM5jE,1094
150
154
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
@@ -153,7 +157,7 @@ ai_edge_torch/generative/layers/attention.py,sha256=uK1ih2kxPZherwi-pGSm8B--NNWn
153
157
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
154
158
  ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
155
159
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
156
- ai_edge_torch/generative/layers/kv_cache.py,sha256=WNH_Ab29eXKXs8HAm3Wmdv_LBzO6PQW5d34Eo6Yzgd0,8492
160
+ ai_edge_torch/generative/layers/kv_cache.py,sha256=dDeirtuo9AnlN1tYoLbFi_pKhIDmn35FQY1m6X28hSY,8468
157
161
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
158
162
  ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
159
163
  ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
@@ -179,12 +183,12 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=1sXN2RPntq0PP3IEy0NkvIbzQ0
179
183
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
180
184
  ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
181
185
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=mhNJikLnGVGi9NKmXB8FhnqeDy9gtrvC3yEbrTABZ4Y,6163
182
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=6LkLnFOvlnt7JVVDYKMaZClPRBEvdjq6xnSjIFYNdI8,12554
186
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hPmWpg41ZMWwBsngTykRVzRPHtpbkwiLM,12811
183
187
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
184
188
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
185
189
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
186
- ai_edge_torch/generative/utilities/converter.py,sha256=z3CvNJxKzglu1BU_5ri91RUeGHh7urhoWFbk0oq7i2M,10768
187
- ai_edge_torch/generative/utilities/export_config.py,sha256=8-795nyd3M34LkGhgW7hwHlJyTc2Oz1iipHK8yBhdFs,1633
190
+ ai_edge_torch/generative/utilities/converter.py,sha256=4RNNl7vk3WN_JG5EZajofiRSqtPnUNCYosxTacdEOto,10948
191
+ ai_edge_torch/generative/utilities/export_config.py,sha256=maUVt0T5FsLpHO5H-BZ-O0FRBZO_ejKwGhPR9Qq8ViM,2490
188
192
  ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
189
193
  ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
190
194
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
@@ -242,8 +246,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
242
246
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
243
247
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
244
248
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
245
- ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
246
- ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/METADATA,sha256=owGeoLcv0XFf4tXFatFjXLSisoaRBBwrtyLx3LFq8PM,2051
247
- ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
248
- ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
249
- ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/RECORD,,
249
+ ai_edge_torch_nightly-0.5.0.dev20250426.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
250
+ ai_edge_torch_nightly-0.5.0.dev20250426.dist-info/METADATA,sha256=y_g3V3S_WlYlEmSNZWmP4kV5f_A1Nynk77VwS8qL_X0,2051
251
+ ai_edge_torch_nightly-0.5.0.dev20250426.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
252
+ ai_edge_torch_nightly-0.5.0.dev20250426.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
253
+ ai_edge_torch_nightly-0.5.0.dev20250426.dist-info/RECORD,,