ai-edge-torch-nightly 0.3.0.dev20250130__py3-none-any.whl → 0.3.0.dev20250201__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,14 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,91 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of building decoder for Qwen 2.5 VL models."""
17
+
18
+ import ai_edge_torch.generative.layers.model_config as cfg
19
+ from ai_edge_torch.generative.utilities import model_builder
20
+ from torch import nn
21
+
22
+ TENSOR_NAMES = model_builder.TENSOR_NAMES  # Reused unchanged from model_builder; passed as `tensor_names` in build_decoder().
23
+
24
+
25
class Decoder(model_builder.DecoderOnlyModel):
    """Qwen-VL decoder assembled from the Edge Generative API layers.

    The class adds nothing on top of ``model_builder.DecoderOnlyModel``; it is
    the ``model_class`` that ``build_decoder`` hands to the model builder.
    """
28
+
29
+
30
def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
    """Builds the decoder configuration of a Qwen 2.5 VL 3B model.

    Args:
      kv_cache_max_len: Maximum sequence length of the KV cache. Defaults to
        1024.

    Returns:
      The ``cfg.ModelConfig`` for the Qwen 2.5 VL 3B decoder.
    """
    # A single RMS-norm config object is shared by the pre-attention,
    # post-attention and final normalization slots.
    rms_norm = cfg.NormalizationConfig(
        type=cfg.NormalizationType.RMS_NORM,
        epsilon=1e-06,
    )
    attention = cfg.AttentionConfig(
        num_heads=16,
        head_dim=128,
        num_query_groups=2,
        rotary_base=1000000,
        rotary_percentage=1.0,
        qkv_use_bias=True,
    )
    feed_forward = cfg.FeedForwardConfig(
        type=cfg.FeedForwardType.GATED,
        activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
        intermediate_size=11008,
    )
    transformer_block = cfg.TransformerBlockConfig(
        attn_config=attention,
        ff_config=feed_forward,
        pre_attention_norm_config=rms_norm,
        post_attention_norm_config=rms_norm,
    )
    # One block config is supplied for the whole model (not a per-layer list).
    return cfg.ModelConfig(
        vocab_size=151936,
        num_layers=36,
        max_seq_len=32768,
        embedding_dim=2048,
        kv_cache_max_len=kv_cache_max_len,
        block_configs=transformer_block,
        final_norm_config=rms_norm,
        enable_hlfb=True,
    )
74
+
75
+
76
def get_fake_decoder_config(**kwargs) -> cfg.ModelConfig:
    """Returns a shrunken variant of the Qwen 2.5 VL decoder config.

    Starts from ``get_decoder_config(**kwargs)`` and overrides the vocabulary
    size (128), the number of layers (2) and the feed-forward intermediate size
    of block 0 (64). Presumably intended for lightweight tests; every other
    field keeps its value from ``get_decoder_config``.

    Args:
      **kwargs: Forwarded unchanged to ``get_decoder_config``.

    Returns:
      The modified ``cfg.ModelConfig``.
    """
    fake_config = get_decoder_config(**kwargs)
    fake_config.vocab_size, fake_config.num_layers = 128, 2
    # Decoder has only one block config.
    fake_config.block_config(0).ff_config.intermediate_size = 64
    return fake_config
83
+
84
+
85
def build_decoder(checkpoint_path: str, **kwargs) -> nn.Module:
    """Builds the Qwen 2.5 VL decoder from a checkpoint.

    Args:
      checkpoint_path: Location of the checkpoint, passed through to
        ``model_builder.build_decoder_only_model``.
      **kwargs: Forwarded unchanged to ``get_decoder_config``.

    Returns:
      The model produced by ``model_builder.build_decoder_only_model`` using
      the ``Decoder`` class and the module-level ``TENSOR_NAMES``.
    """
    decoder_config = get_decoder_config(**kwargs)
    return model_builder.build_decoder_only_model(
        checkpoint_path=checkpoint_path,
        config=decoder_config,
        tensor_names=TENSOR_NAMES,
        model_class=Decoder,
    )
@@ -0,0 +1,77 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Verifies the reauthored decoder of Qwen 2.5 VL 3B models."""
17
+
18
+ import logging
19
+ import pathlib
20
+
21
+ from absl import app
22
+ from ai_edge_torch.generative.examples.qwen_vl import decoder
23
+ from ai_edge_torch.generative.utilities import verifier
24
+ import torch
25
+ import transformers
26
+
27
+
28
class DecoderWrapper(verifier.ModelWrapper):
    """Adapts the decoder of a Qwen 2.5 VL model for verification.

    Pairs a decoder module with a separate LM head so that ``forward`` returns
    the head's output rather than the decoder's raw output mapping.
    """

    def __init__(self, model: torch.nn.Module, lm_head: torch.nn.Module):
        """Stores the decoder (via the base class) and its LM head.

        Args:
          model: The decoder module; handed to ``verifier.ModelWrapper``.
          lm_head: Module applied to the decoder's last hidden state.
        """
        super().__init__(model)
        self.lm_head = lm_head

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Runs the decoder on ``tokens`` and applies the LM head.

        Args:
          tokens: Input passed straight to the wrapped decoder's ``forward``.

        Returns:
          ``lm_head`` applied to the decoder output's ``"last_hidden_state"``.
        """
        # self.model is presumably set by verifier.ModelWrapper.__init__.
        last_hidden_state = self.model.forward(tokens)["last_hidden_state"]
        return self.lm_head(last_hidden_state)
38
+
39
+
40
def main(_):
    """Verifies the reauthored Qwen 2.5 VL 3B decoder against the original.

    Loads the original model with ``transformers``, builds the reauthored
    decoder from the directory holding the cached config file, and asks
    ``verifier.verify_with_input_ids`` to check both on a fixed list of input
    IDs with ``atol=1e-04``. An ``AssertionError`` from that check is logged as
    a failure instead of being propagated.

    Args:
      _: Unused positional argument supplied by ``app.run``.
    """
    checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
    logging.info("Loading the original model from: %s", checkpoint)
    original_model_cls = transformers.Qwen2_5_VLForConditionalGeneration
    original_model = original_model_cls.from_pretrained(checkpoint)

    # Locate the cached dir.
    config_name = transformers.utils.CONFIG_NAME
    cached_config_file = transformers.utils.cached_file(checkpoint, config_name)
    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
    logging.info(
        "Building the reauthored model from: %s", reauthored_checkpoint
    )
    reauthored_model = decoder.build_decoder(reauthored_checkpoint)

    # Verify the reauthored model only with input IDs because the original
    # decoder does not support generate() with prompts.
    input_ids = [1, 2, 3, 4]
    try:
        # Wrappers are created inside the try block so that any AssertionError
        # raised while constructing them is handled the same way as one raised
        # by the verification itself.
        wrapped_original = DecoderWrapper(
            original_model.model,
            original_model.lm_head,
        )
        wrapped_reauthored = verifier.ReauthoredModelWrapper(reauthored_model)
        verifier.verify_with_input_ids(
            original_model=wrapped_original,
            reauthored_model=wrapped_reauthored,
            input_ids=input_ids,
            atol=1e-04,
        )
    except AssertionError as e:
        logging.error("*** FAILED *** verify with input IDs: %s", e)
    else:
        logging.info("*** PASSED *** verify with input IDs: %s", input_ids)
74
+
75
+
76
+ if __name__ == "__main__":  # Run the verification only when executed as a script.
77
+ app.run(main)  # Hand main() to absl's app runner.
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20250130"
16
+ __version__ = "0.3.0.dev20250201"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20250130
3
+ Version: 0.3.0.dev20250201
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=vazuF44bm-lx-G2_sGZg9gZb_rU9JOa2jdpJlTzoqwE,706
5
+ ai_edge_torch/version.py,sha256=9qgk7SSH80z-tSMSDRr1M3FiR7a69U40_TeqElVeor0,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -93,6 +93,9 @@ ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY
93
93
  ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=tqvXVGNdDehdak9-5DDisACs9VlTwr8eFwcjQ_kZxgc,2776
94
94
  ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
95
95
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
96
+ ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
97
+ ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=rD_Ch5CzuXeatqv0C3z8vU-zou1z9QDUhoB6V4YTPIg,2829
98
+ ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=FEY_PifD9fQGnERzSOljFLraRIbUVF3XTnCv95A30Cs,2602
96
99
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
97
100
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
98
101
  ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=CjY1i0iCYxFSjhCpQZwxkmVxILgeo0zu1m0oBrHqyDU,2311
@@ -222,8 +225,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
222
225
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
223
226
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
224
227
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
225
- ai_edge_torch_nightly-0.3.0.dev20250130.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
226
- ai_edge_torch_nightly-0.3.0.dev20250130.dist-info/METADATA,sha256=LQAYOwR4xu0c0DXpvjDL0Ph5rP3ZIJCrpaSGllCAkqI,1966
227
- ai_edge_torch_nightly-0.3.0.dev20250130.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
228
- ai_edge_torch_nightly-0.3.0.dev20250130.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
229
- ai_edge_torch_nightly-0.3.0.dev20250130.dist-info/RECORD,,
228
+ ai_edge_torch_nightly-0.3.0.dev20250201.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
229
+ ai_edge_torch_nightly-0.3.0.dev20250201.dist-info/METADATA,sha256=oWpE5lBKL49aWAyb11Xi_eO55pPZZzW2M9gUH_lzVlg,1966
230
+ ai_edge_torch_nightly-0.3.0.dev20250201.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
231
+ ai_edge_torch_nightly-0.3.0.dev20250201.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
232
+ ai_edge_torch_nightly-0.3.0.dev20250201.dist-info/RECORD,,