ai-edge-torch-nightly 0.3.0.dev20240928__py3-none-any.whl → 0.3.0.dev20240930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,81 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of converting Qwen 2.5 models to multi-signature tflite model."""
17
+
18
+ import os
19
+ import pathlib
20
+
21
+ from absl import app
22
+ from absl import flags
23
+ from ai_edge_torch.generative.examples.qwen import qwen
24
+ from ai_edge_torch.generative.utilities import converter
25
+
26
# Builder function for each supported --model_size value. Declared before the
# flags so the enum choices below are taken directly from these keys.
_BUILDER = {
    '0.5b': qwen.build_0_5b_model,
    '1.5b': qwen.build_1_5b_model,
    '3b': qwen.build_3b_model,
}

_MODEL_SIZE = flags.DEFINE_enum(
    name='model_size',
    default='3b',
    enum_values=list(_BUILDER),
    help='The size of the model to convert.',
)
_CHECKPOINT_PATH = flags.DEFINE_string(
    name='checkpoint_path',
    default=os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
    help='The path to the model checkpoint, or directory holding the checkpoint.',
)
_TFLITE_PATH = flags.DEFINE_string(
    name='tflite_path',
    default='/tmp/',
    help='The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
    name='prefill_seq_len',
    default=1024,
    help='The maximum size of prefill input tensor.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
    name='kv_cache_max_len',
    default=1280,
    help='The maximum size of KV cache buffer, including both prefill and decode.',
)
_QUANTIZE = flags.DEFINE_bool(
    name='quantize',
    default=True,
    help='Whether the model should be quantized.',
)
63
+
64
+
65
def main(_):
  """Builds the selected Qwen 2.5 model and writes it out as a TFLite file.

  The output file is placed under --tflite_path. Its name encodes the model
  size, the quantization mode ('q8' when --quantize is set, otherwise 'f32'),
  the prefill sequence length and the KV cache size.
  """
  size = _MODEL_SIZE.value
  build_fn = _BUILDER[size]
  pytorch_model = build_fn(
      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )

  if _QUANTIZE.value:
    quant_suffix = 'q8'
  else:
    quant_suffix = 'f32'
  # '0.5b' -> '0_5b' so the size is safe to embed in a file name.
  model_size = size.replace('.', '_')
  output_filename = f'qwen_{model_size}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
  export_path = os.path.join(_TFLITE_PATH.value, output_filename)

  converter.convert_to_tflite(
      pytorch_model,
      tflite_path=export_path,
      prefill_seq_len=_PREFILL_SEQ_LEN.value,
      quantize=_QUANTIZE.value,
  )


if __name__ == '__main__':
  app.run(main)
@@ -0,0 +1,141 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of building Qwen 2.5 models."""
17
+
18
+ import copy
19
+
20
+ from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
21
+ import ai_edge_torch.generative.layers.model_config as cfg
22
+ import ai_edge_torch.generative.utilities.loader as loading_utils
23
+ from torch import nn
24
+
25
# Checkpoint tensor-name mapping, based on TinyLlama's. copy.copy makes a
# shallow copy, so the reassignment below does not touch
# tiny_llama.TENSOR_NAMES itself.
TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
# Qwen re-uses the embedding as the head projection layer, so there is no
# lm_head tensor name (presumably tells the loader not to look for one).
TENSOR_NAMES.lm_head = None
28
+
29
+
30
class Qwen(tiny_llama.TinyLlama):
  """A Qwen model built from the Edge Generative API layers.

  Qwen 2.5 shares the same architecture as TinyLlama. The Qwen-specific
  differences come from the model config (e.g. `qkv_use_bias=True`) and from
  tying the head projection to the token embedding in `__init__`.
  """

  def __init__(self, config: cfg.ModelConfig):
    """Builds the TinyLlama layers from `config` and ties the head weights.

    Args:
      config (cfg.ModelConfig): The model config, e.g. one returned by
        `get_3b_model_config`.
    """
    super().__init__(config)
    # Qwen re-uses the embedding as the head projection layer: the head's
    # weight data is pointed at the embedding's tensor so both share storage.
    self.lm_head.weight.data = self.tok_embedding.weight.data
40
+
41
+
42
def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for a Qwen 2.5 3B model.

  The 1.5B, 0.5B and fake configs in this module are derived from this one by
  overriding individual sizes.

  Args:
    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
      is 1024.

  Returns:
    The model config for a Qwen 2.5 3B model.
  """
  attn_config = cfg.AttentionConfig(
      num_heads=16,
      head_dim=128,
      # 16 query heads split into 2 query groups (shared key/value heads).
      num_query_groups=2,
      rotary_base=1000000,
      rotary_percentage=1.0,
      qkv_use_bias=True,
  )
  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
      intermediate_size=11008,
  )
  norm_config = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM,
      epsilon=1e-06,
  )
  block_config = cfg.TransformerBlockConfig(
      attn_config=attn_config,
      ff_config=ff_config,
      pre_attention_norm_config=norm_config,
      post_attention_norm_config=norm_config,
  )
  config = cfg.ModelConfig(
      vocab_size=151936,
      num_layers=36,
      max_seq_len=32768,
      embedding_dim=2048,
      kv_cache_max_len=kv_cache_max_len,
      # A single block config is used for every layer.
      block_configs=block_config,
      final_norm_config=norm_config,
      enable_hlfb=True,
  )
  return config
86
+
87
+
88
def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for a Qwen 2.5 1.5B model.

  Starts from the 3B config and overrides only the sizes that differ. All
  other settings are inherited from the 3B config.

  Args:
    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
      is 1024.

  Returns:
    The model config for a Qwen 2.5 1.5B model.
  """
  config = get_3b_model_config(kv_cache_max_len)
  # Qwen has only one block config, so editing it changes every layer.
  block_config = config.block_config(0)
  block_config.attn_config.num_heads = 12
  block_config.ff_config.intermediate_size = 8960
  config.num_layers = 28
  config.embedding_dim = 1536
  return config
98
+
99
+
100
def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for a Qwen 2.5 0.5B model.

  Derived from the 3B config. Only the attention heads, head size,
  feed-forward width, layer count and embedding width are changed.

  Args:
    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
      is 1024.

  Returns:
    The model config for a Qwen 2.5 0.5B model.
  """
  small = get_3b_model_config(kv_cache_max_len)
  # Qwen has only one block config; it is shared by all layers.
  shared_block = small.block_config(0)
  attn = shared_block.attn_config
  attn.num_heads = 14
  attn.head_dim = 64
  shared_block.ff_config.intermediate_size = 4864
  small.num_layers = 24
  small.embedding_dim = 896
  return small
111
+
112
+
113
def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
  """Returns a tiny Qwen config for tests.

  Based on the 3B config, with a 128-entry vocabulary, 2 layers and a
  feed-forward intermediate size of 64. Other dimensions are unchanged.

  Args:
    **kwargs: Forwarded to `get_3b_model_config` (e.g. `kv_cache_max_len`).

  Returns:
    The fake model config.
  """
  fake = get_3b_model_config(**kwargs)
  fake.vocab_size = 128
  fake.num_layers = 2
  # Qwen has only one block config, shared by all layers.
  shared_block = fake.block_config(0)
  shared_block.ff_config.intermediate_size = 64
  return fake
120
+
121
+
122
def _build_model(checkpoint_path: str, config: cfg.ModelConfig) -> nn.Module:
  """Instantiates a Qwen model for `config` and loads checkpoint weights.

  Args:
    checkpoint_path (str): The model checkpoint, or directory holding it.
    config (cfg.ModelConfig): Config describing the model to build.

  Returns:
    The loaded model, switched to eval mode.
  """
  qwen_model = Qwen(config)
  # Embedding and lm-head share one weight (TENSOR_NAMES.lm_head is None), so
  # loading cannot be strict.
  weight_loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
  weight_loader.load(qwen_model, strict=False)
  qwen_model.eval()
  return qwen_model
130
+
131
+
132
def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
  """Builds a Qwen 2.5 3B model with weights from `checkpoint_path`.

  Args:
    checkpoint_path (str): The model checkpoint, or directory holding it.
    **kwargs: Forwarded to `get_3b_model_config` (e.g. `kv_cache_max_len`).

  Returns:
    The loaded model in eval mode.
  """
  config = get_3b_model_config(**kwargs)
  return _build_model(checkpoint_path, config)
134
+
135
+
136
def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
  """Builds a Qwen 2.5 1.5B model with weights from `checkpoint_path`.

  Args:
    checkpoint_path (str): The model checkpoint, or directory holding it.
    **kwargs: Forwarded to `get_1_5b_model_config` (e.g. `kv_cache_max_len`).

  Returns:
    The loaded model in eval mode.
  """
  config = get_1_5b_model_config(**kwargs)
  return _build_model(checkpoint_path, config)
138
+
139
+
140
def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
  """Builds a Qwen 2.5 0.5B model with weights from `checkpoint_path`.

  Args:
    checkpoint_path (str): The model checkpoint, or directory holding it.
    **kwargs: Forwarded to `get_0_5b_model_config` (e.g. `kv_cache_max_len`).

  Returns:
    The loaded model in eval mode.
  """
  config = get_0_5b_model_config(**kwargs)
  return _build_model(checkpoint_path, config)
@@ -0,0 +1,88 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Verifies the reauthored Qwen 2.5 0.5B, 1.5B, and 3B models."""
17
+
18
+ import logging
19
+ import pathlib
20
+
21
+ from absl import app
22
+ from absl import flags
23
+ from ai_edge_torch.generative.examples.qwen import qwen
24
+ from ai_edge_torch.generative.utilities import transformers_verifier
25
+ from ai_edge_torch.generative.utilities import verifier
26
+ import transformers
27
+
28
+
29
# Hugging Face checkpoint id for each supported --model_size value.
_CHECKPOINT = {
    "0.5b": "Qwen/Qwen2.5-0.5B-Instruct",
    "1.5b": "Qwen/Qwen2.5-1.5B-Instruct",
    "3b": "Qwen/Qwen2.5-3B-Instruct",
}

# Reauthored-model builder for each supported --model_size value. Declared
# before the flags so the enum choices below are taken from these keys.
_BUILDER = {
    "0.5b": qwen.build_0_5b_model,
    "1.5b": qwen.build_1_5b_model,
    "3b": qwen.build_3b_model,
}

_MODEL_SIZE = flags.DEFINE_enum(
    name="model_size",
    default="3b",
    enum_values=list(_BUILDER),
    help="The size of the model to verify.",
)
_PROMPTS = flags.DEFINE_multi_string(
    name="prompts",
    default="What is the meaning of life?",
    help="The input prompts to generate answers.",
)
_MAX_NEW_TOKENS = flags.DEFINE_integer(
    name="max_new_tokens",
    default=30,
    help="The maximum size of the generated tokens.",
)
57
+
58
+
59
def main(_):
  """Compares the reauthored Qwen 2.5 model with the transformers original.

  Loads the original model and tokenizer with `from_pretrained`, builds the
  reauthored model from the directory where transformers cached the model's
  config file, and passes all three to `verifier.verify_reauthored_model`.
  """
  size = _MODEL_SIZE.value
  checkpoint = _CHECKPOINT[size]
  logging.info("Loading the original model from: %s", checkpoint)
  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

  # The reauthored model reads its weights from the local cache directory
  # that holds the downloaded config file.
  config_file = transformers.utils.cached_file(
      checkpoint, transformers.utils.CONFIG_NAME
  )
  reauthored_checkpoint = pathlib.Path(config_file).parent
  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
  build_fn = _BUILDER[size]
  reauthored_model = build_fn(reauthored_checkpoint)

  logging.info("Loading the tokenizer from: %s", checkpoint)
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

  wrapped_original = transformers_verifier.TransformersModelWrapper(
      original_model
  )
  wrapped_reauthored = verifier.ReauthoredModelWrapper(reauthored_model)
  verifier.verify_reauthored_model(
      original_model=wrapped_original,
      reauthored_model=wrapped_reauthored,
      tokenizer=verifier.TokenizerWrapper(tokenizer),
      generate_prompts=_PROMPTS.value,
      max_new_tokens=_MAX_NEW_TOKENS.value,
      atol=1e-04,
  )


if __name__ == "__main__":
  app.run(main)
@@ -23,6 +23,7 @@ from ai_edge_torch.generative.examples.llama import llama
23
23
  from ai_edge_torch.generative.examples.openelm import openelm
24
24
  from ai_edge_torch.generative.examples.phi import phi2
25
25
  from ai_edge_torch.generative.examples.phi import phi3
26
+ from ai_edge_torch.generative.examples.qwen import qwen
26
27
  from ai_edge_torch.generative.examples.smollm import smollm
27
28
  from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
28
29
  from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
@@ -152,6 +153,15 @@ class TestModelConversion(googletest.TestCase):
152
153
  pytorch_model = openelm.OpenELM(config).eval()
153
154
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
154
155
 
156
  @googletest.skipIf(
      ai_edge_config.Config.use_torch_xla,
      reason="tests with custom ops are not supported on oss",
  )
  def test_qwen(self):
    """Runs `_test_model` on a fake (tiny) Qwen config for "prefill"."""
    config = qwen.get_fake_model_config()
    pytorch_model = qwen.Qwen(config).eval()
    # NOTE(review): atol=1e-3 is looser than the 1e-4 used for OpenELM above;
    # confirm this tolerance is intentional for Qwen.
    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
164
+
155
165
  @googletest.skipIf(
156
166
  ai_edge_config.Config.use_torch_xla,
157
167
  reason="tests with custom ops are not supported on oss",
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240928"
16
+ __version__ = "0.3.0.dev20240930"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240928
3
+ Version: 0.3.0.dev20240930
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=YiCjdglLzSPYyRq64U8zJSgWDFqJs-t2JSzuA0bYYzA,706
6
+ ai_edge_torch/version.py,sha256=sc79FL0Fo_EhFaJ-4XlyIRT3Kjmn_oUXzMtvcUeXZDo,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -63,6 +63,10 @@ ai_edge_torch/generative/examples/phi/phi2.py,sha256=82SEKRwtKfT9VcNQaykGmemiov_
63
63
  ai_edge_torch/generative/examples/phi/phi3.py,sha256=Xh-l7TQdXYZJ9PViRVk2_y91Ec7Yntn0UpkuzRIG3T8,9231
64
64
  ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
65
65
  ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
66
+ ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
67
+ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=QAAVoSKDVf2rHAChzumGloVCWIU0Oe5UYKgv3T192Iw,2496
68
+ ai_edge_torch/generative/examples/qwen/qwen.py,sha256=b03q1On6JzPhJzTs1dQwT_tJjO7C9NYmyzrzV2kQ_yo,4579
69
+ ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
66
70
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
67
71
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
68
72
  ai_edge_torch/generative/examples/smollm/smollm.py,sha256=dal8vnZjQd6vR7sc76-FYGDKUlVjOlfUALV-pwbXJGc,3264
@@ -121,7 +125,7 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
121
125
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
122
126
  ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
123
127
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
124
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=kCm-L3rWbPj25E_QEbkSLiaCk3y23SjKJs-MG-EwKug,8545
128
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=SBGHbY8-k7kSEEv-WQQlxGIYtJEVBIbjJPygGdDg9Qg,8921
125
129
  ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
126
130
  ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
127
131
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -177,8 +181,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
177
181
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
178
182
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
179
183
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
180
- ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
181
- ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/METADATA,sha256=3HuAFZTfvmU787dVypwpmUvo4DdZSekGsqGimO-oPfM,1897
182
- ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
183
- ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
184
- ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/RECORD,,
184
+ ai_edge_torch_nightly-0.3.0.dev20240930.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
185
+ ai_edge_torch_nightly-0.3.0.dev20240930.dist-info/METADATA,sha256=sFwzdW-SmFri3T15doM8okOIAWKw_GX4xWzAlO6GK7o,1897
186
+ ai_edge_torch_nightly-0.3.0.dev20240930.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
187
+ ai_edge_torch_nightly-0.3.0.dev20240930.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
188
+ ai_edge_torch_nightly-0.3.0.dev20240930.dist-info/RECORD,,