ai-edge-torch-nightly 0.3.0.dev20240928__py3-none-any.whl → 0.3.0.dev20240930__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,81 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of converting Qwen 2.5 models to multi-signature tflite model."""
17
+
18
+ import os
19
+ import pathlib
20
+
21
+ from absl import app
22
+ from absl import flags
23
+ from ai_edge_torch.generative.examples.qwen import qwen
24
+ from ai_edge_torch.generative.utilities import converter
25
+
26
+ _MODEL_SIZE = flags.DEFINE_enum(
27
+ 'model_size',
28
+ '3b',
29
+ ['0.5b', '1.5b', '3b'],
30
+ 'The size of the model to convert.',
31
+ )
32
+ _CHECKPOINT_PATH = flags.DEFINE_string(
33
+ 'checkpoint_path',
34
+ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
35
+ 'The path to the model checkpoint, or directory holding the checkpoint.',
36
+ )
37
+ _TFLITE_PATH = flags.DEFINE_string(
38
+ 'tflite_path',
39
+ '/tmp/',
40
+ 'The tflite file path to export.',
41
+ )
42
+ _PREFILL_SEQ_LEN = flags.DEFINE_integer(
43
+ 'prefill_seq_len',
44
+ 1024,
45
+ 'The maximum size of prefill input tensor.',
46
+ )
47
+ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
48
+ 'kv_cache_max_len',
49
+ 1280,
50
+ 'The maximum size of KV cache buffer, including both prefill and decode.',
51
+ )
52
+ _QUANTIZE = flags.DEFINE_bool(
53
+ 'quantize',
54
+ True,
55
+ 'Whether the model should be quantized.',
56
+ )
57
+
58
+ _BUILDER = {
59
+ '0.5b': qwen.build_0_5b_model,
60
+ '1.5b': qwen.build_1_5b_model,
61
+ '3b': qwen.build_3b_model,
62
+ }
63
+
64
+
65
+ def main(_):
66
+ pytorch_model = _BUILDER[_MODEL_SIZE.value](
67
+ _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
68
+ )
69
+ quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
70
+ model_size = _MODEL_SIZE.value.replace('.', '_')
71
+ output_filename = f'qwen_{model_size}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
72
+ converter.convert_to_tflite(
73
+ pytorch_model,
74
+ tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
75
+ prefill_seq_len=_PREFILL_SEQ_LEN.value,
76
+ quantize=_QUANTIZE.value,
77
+ )
78
+
79
+
80
+ if __name__ == '__main__':
81
+ app.run(main)
@@ -0,0 +1,141 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of building Qwen 2.5 models."""
17
+
18
+ import copy
19
+
20
+ from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
21
+ import ai_edge_torch.generative.layers.model_config as cfg
22
+ import ai_edge_torch.generative.utilities.loader as loading_utils
23
+ from torch import nn
24
+
25
+ TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
26
+ # Qwen re-uses the embedding as the head projection layer.
27
+ TENSOR_NAMES.lm_head = None
28
+
29
+
30
+ class Qwen(tiny_llama.TinyLlama):
31
+ """A Qwen model built from the Edge Generative API layers.
32
+
33
+ Qwen 2.5 shares the same architecture as TinyLlama.
34
+ """
35
+
36
+ def __init__(self, config: cfg.ModelConfig):
37
+ super().__init__(config)
38
+ # Qwen re-uses the embedding as the head projection layer.
39
+ self.lm_head.weight.data = self.tok_embedding.weight.data
40
+
41
+
42
+ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
43
+ """Returns the model config for a Qwen 2.5 3B model.
44
+
45
+ Args:
46
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
47
+ is 1024.
48
+
49
+ Returns:
50
+ The model config for a Qwen 2.5 3B model.
51
+ """
52
+ attn_config = cfg.AttentionConfig(
53
+ num_heads=16,
54
+ head_dim=128,
55
+ num_query_groups=2,
56
+ rotary_base=1000000,
57
+ rotary_percentage=1.0,
58
+ qkv_use_bias=True,
59
+ )
60
+ ff_config = cfg.FeedForwardConfig(
61
+ type=cfg.FeedForwardType.GATED,
62
+ activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
63
+ intermediate_size=11008,
64
+ )
65
+ norm_config = cfg.NormalizationConfig(
66
+ type=cfg.NormalizationType.RMS_NORM,
67
+ epsilon=1e-06,
68
+ )
69
+ block_config = cfg.TransformerBlockConfig(
70
+ attn_config=attn_config,
71
+ ff_config=ff_config,
72
+ pre_attention_norm_config=norm_config,
73
+ post_attention_norm_config=norm_config,
74
+ )
75
+ config = cfg.ModelConfig(
76
+ vocab_size=151936,
77
+ num_layers=36,
78
+ max_seq_len=32768,
79
+ embedding_dim=2048,
80
+ kv_cache_max_len=kv_cache_max_len,
81
+ block_configs=block_config,
82
+ final_norm_config=norm_config,
83
+ enable_hlfb=True,
84
+ )
85
+ return config
86
+
87
+
88
+ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
89
+ """Returns the model config for a Qwen 2.5 1.5B model."""
90
+ config = get_3b_model_config(kv_cache_max_len)
91
+ # Qwen has only one block config.
92
+ block_config = config.block_config(0)
93
+ block_config.attn_config.num_heads = 12
94
+ block_config.ff_config.intermediate_size = 8960
95
+ config.num_layers = 28
96
+ config.embedding_dim = 1536
97
+ return config
98
+
99
+
100
+ def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
101
+ """Returns the model config for a Qwen 2.5 0.5B model."""
102
+ config = get_3b_model_config(kv_cache_max_len)
103
+ # Qwen has only one block config.
104
+ block_config = config.block_config(0)
105
+ block_config.attn_config.num_heads = 14
106
+ block_config.attn_config.head_dim = 64
107
+ block_config.ff_config.intermediate_size = 4864
108
+ config.num_layers = 24
109
+ config.embedding_dim = 896
110
+ return config
111
+
112
+
113
+ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
114
+ config = get_3b_model_config(**kwargs)
115
+ config.vocab_size = 128
116
+ config.num_layers = 2
117
+ # Qwen has only one block config.
118
+ config.block_config(0).ff_config.intermediate_size = 64
119
+ return config
120
+
121
+
122
+ def _build_model(checkpoint_path: str, config: cfg.ModelConfig) -> nn.Module:
123
+ model = Qwen(config)
124
+ loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
125
+ # Since embedding and lm-head use the same weight, we need to set strict
126
+ # to False.
127
+ loader.load(model, strict=False)
128
+ model.eval()
129
+ return model
130
+
131
+
132
+ def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
133
+ return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
134
+
135
+
136
+ def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
137
+ return _build_model(checkpoint_path, get_1_5b_model_config(**kwargs))
138
+
139
+
140
+ def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
141
+ return _build_model(checkpoint_path, get_0_5b_model_config(**kwargs))
@@ -0,0 +1,88 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Verifies the reauthored Qwen 2.5 0.5B, 1.5B, and 3B models."""
17
+
18
+ import logging
19
+ import pathlib
20
+
21
+ from absl import app
22
+ from absl import flags
23
+ from ai_edge_torch.generative.examples.qwen import qwen
24
+ from ai_edge_torch.generative.utilities import transformers_verifier
25
+ from ai_edge_torch.generative.utilities import verifier
26
+ import transformers
27
+
28
+
29
+ _MODEL_SIZE = flags.DEFINE_enum(
30
+ "model_size",
31
+ "3b",
32
+ ["0.5b", "1.5b", "3b"],
33
+ "The size of the model to verify.",
34
+ )
35
+ _PROMPTS = flags.DEFINE_multi_string(
36
+ "prompts",
37
+ "What is the meaning of life?",
38
+ "The input prompts to generate answers.",
39
+ )
40
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
41
+ "max_new_tokens",
42
+ 30,
43
+ "The maximum size of the generated tokens.",
44
+ )
45
+
46
+ _CHECKPOINT = {
47
+ "0.5b": "Qwen/Qwen2.5-0.5B-Instruct",
48
+ "1.5b": "Qwen/Qwen2.5-1.5B-Instruct",
49
+ "3b": "Qwen/Qwen2.5-3B-Instruct",
50
+ }
51
+
52
+ _BUILDER = {
53
+ "0.5b": qwen.build_0_5b_model,
54
+ "1.5b": qwen.build_1_5b_model,
55
+ "3b": qwen.build_3b_model,
56
+ }
57
+
58
+
59
+ def main(_):
60
+ checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
61
+ logging.info("Loading the original model from: %s", checkpoint)
62
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
63
+
64
+ # Locate the cached dir.
65
+ cached_config_file = transformers.utils.cached_file(
66
+ checkpoint, transformers.utils.CONFIG_NAME
67
+ )
68
+ reauthored_checkpoint = pathlib.Path(cached_config_file).parent
69
+ logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
70
+ reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
71
+
72
+ logging.info("Loading the tokenizer from: %s", checkpoint)
73
+ tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
74
+
75
+ verifier.verify_reauthored_model(
76
+ original_model=transformers_verifier.TransformersModelWrapper(
77
+ original_model
78
+ ),
79
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
80
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
81
+ generate_prompts=_PROMPTS.value,
82
+ max_new_tokens=_MAX_NEW_TOKENS.value,
83
+ atol=1e-04,
84
+ )
85
+
86
+
87
+ if __name__ == "__main__":
88
+ app.run(main)
@@ -23,6 +23,7 @@ from ai_edge_torch.generative.examples.llama import llama
23
23
  from ai_edge_torch.generative.examples.openelm import openelm
24
24
  from ai_edge_torch.generative.examples.phi import phi2
25
25
  from ai_edge_torch.generative.examples.phi import phi3
26
+ from ai_edge_torch.generative.examples.qwen import qwen
26
27
  from ai_edge_torch.generative.examples.smollm import smollm
27
28
  from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
28
29
  from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
@@ -152,6 +153,15 @@ class TestModelConversion(googletest.TestCase):
152
153
  pytorch_model = openelm.OpenELM(config).eval()
153
154
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
154
155
 
156
+ @googletest.skipIf(
157
+ ai_edge_config.Config.use_torch_xla,
158
+ reason="tests with custom ops are not supported on oss",
159
+ )
160
+ def test_qwen(self):
161
+ config = qwen.get_fake_model_config()
162
+ pytorch_model = qwen.Qwen(config).eval()
163
+ self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
164
+
155
165
  @googletest.skipIf(
156
166
  ai_edge_config.Config.use_torch_xla,
157
167
  reason="tests with custom ops are not supported on oss",
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240928"
16
+ __version__ = "0.3.0.dev20240930"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240928
3
+ Version: 0.3.0.dev20240930
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=YiCjdglLzSPYyRq64U8zJSgWDFqJs-t2JSzuA0bYYzA,706
6
+ ai_edge_torch/version.py,sha256=sc79FL0Fo_EhFaJ-4XlyIRT3Kjmn_oUXzMtvcUeXZDo,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -63,6 +63,10 @@ ai_edge_torch/generative/examples/phi/phi2.py,sha256=82SEKRwtKfT9VcNQaykGmemiov_
63
63
  ai_edge_torch/generative/examples/phi/phi3.py,sha256=Xh-l7TQdXYZJ9PViRVk2_y91Ec7Yntn0UpkuzRIG3T8,9231
64
64
  ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
65
65
  ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
66
+ ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
67
+ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=QAAVoSKDVf2rHAChzumGloVCWIU0Oe5UYKgv3T192Iw,2496
68
+ ai_edge_torch/generative/examples/qwen/qwen.py,sha256=b03q1On6JzPhJzTs1dQwT_tJjO7C9NYmyzrzV2kQ_yo,4579
69
+ ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
66
70
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
67
71
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
68
72
  ai_edge_torch/generative/examples/smollm/smollm.py,sha256=dal8vnZjQd6vR7sc76-FYGDKUlVjOlfUALV-pwbXJGc,3264
@@ -121,7 +125,7 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
121
125
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
122
126
  ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
123
127
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
124
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=kCm-L3rWbPj25E_QEbkSLiaCk3y23SjKJs-MG-EwKug,8545
128
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=SBGHbY8-k7kSEEv-WQQlxGIYtJEVBIbjJPygGdDg9Qg,8921
125
129
  ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
126
130
  ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
127
131
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -177,8 +181,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
177
181
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
178
182
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
179
183
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
180
- ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
181
- ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/METADATA,sha256=3HuAFZTfvmU787dVypwpmUvo4DdZSekGsqGimO-oPfM,1897
182
- ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
183
- ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
184
- ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/RECORD,,
184
+ ai_edge_torch_nightly-0.3.0.dev20240930.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
185
+ ai_edge_torch_nightly-0.3.0.dev20240930.dist-info/METADATA,sha256=sFwzdW-SmFri3T15doM8okOIAWKw_GX4xWzAlO6GK7o,1897
186
+ ai_edge_torch_nightly-0.3.0.dev20240930.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
187
+ ai_edge_torch_nightly-0.3.0.dev20240930.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
188
+ ai_edge_torch_nightly-0.3.0.dev20240930.dist-info/RECORD,,