ai-edge-torch-nightly 0.3.0.dev20250130__py3-none-any.whl → 0.3.0.dev20250201__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,14 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,91 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of building decoder for Qwen 2.5 VL models."""
17
+
18
+ import ai_edge_torch.generative.layers.model_config as cfg
19
+ from ai_edge_torch.generative.utilities import model_builder
20
+ from torch import nn
21
+
22
# Tensor names handed to model_builder.build_decoder_only_model (see
# build_decoder) when loading the checkpoint; reused from model_builder as-is.
TENSOR_NAMES = model_builder.TENSOR_NAMES
23
+
24
+
25
class Decoder(model_builder.DecoderOnlyModel):
  """Qwen 2.5 VL decoder assembled from Edge Generative API layers.

  Defines nothing of its own; all behavior is inherited from
  `model_builder.DecoderOnlyModel`.
  """
28
+
29
+
30
def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Builds the decoder config of the Qwen 2.5 VL 3B model.

  The decoder has 36 layers with embedding dim 2048, all described by a single
  block config:
    * attention: 16 query heads sharing 2 KV groups, head dim 128, biased QKV
      projections, rotary base 1e6 applied to the full head dim;
    * feed-forward: gated SiLU with intermediate size 11008;
    * normalization: RMSNorm (eps 1e-6) for the pre-/post-attention norms and
      the final norm.

  NOTE(review): only standard rotary settings are configured here; confirm
  this matches the original model's position encoding for the inputs used.

  Args:
    kv_cache_max_len: Maximum sequence length held by the KV cache. Defaults
      to 1024.

  Returns:
    A `cfg.ModelConfig` for the Qwen 2.5 VL 3B decoder.
  """
  rms_norm = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM,
      epsilon=1e-06,
  )
  attention = cfg.AttentionConfig(
      num_heads=16,
      head_dim=128,
      num_query_groups=2,
      rotary_base=1000000,
      rotary_percentage=1.0,
      qkv_use_bias=True,
  )
  feed_forward = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
      intermediate_size=11008,
  )
  # A single block config describes every layer, and the same norm config
  # object serves both attention norms and the final norm.
  return cfg.ModelConfig(
      vocab_size=151936,
      num_layers=36,
      max_seq_len=32768,
      embedding_dim=2048,
      kv_cache_max_len=kv_cache_max_len,
      block_configs=cfg.TransformerBlockConfig(
          attn_config=attention,
          ff_config=feed_forward,
          pre_attention_norm_config=rms_norm,
          post_attention_norm_config=rms_norm,
      ),
      final_norm_config=rms_norm,
      enable_hlfb=True,
  )
74
+
75
+
76
def get_fake_decoder_config(**kwargs) -> cfg.ModelConfig:
  """Returns a reduced-size variant of the Qwen 2.5 VL decoder config.

  Starts from `get_decoder_config(**kwargs)` and shrinks the vocabulary, the
  layer count and the feed-forward width.

  Args:
    **kwargs: Forwarded unchanged to `get_decoder_config`.

  Returns:
    The shrunken model config.
  """
  fake = get_decoder_config(**kwargs)
  fake.vocab_size = 128
  fake.num_layers = 2
  # The decoder config carries only one block config, so block 0 is the only
  # one whose feed-forward needs resizing.
  fake.block_config(0).ff_config.intermediate_size = 64
  return fake
83
+
84
+
85
def build_decoder(checkpoint_path: str, **kwargs) -> nn.Module:
  """Builds the Qwen 2.5 VL decoder through `model_builder`.

  Args:
    checkpoint_path: Checkpoint location passed to
      `model_builder.build_decoder_only_model`.
    **kwargs: Forwarded to `get_decoder_config` (e.g. `kv_cache_max_len`).

  Returns:
    The module produced by `model_builder.build_decoder_only_model` with
    `model_class=Decoder`.
  """
  decoder_config = get_decoder_config(**kwargs)
  return model_builder.build_decoder_only_model(
      checkpoint_path=checkpoint_path,
      config=decoder_config,
      tensor_names=TENSOR_NAMES,
      model_class=Decoder,
  )
@@ -0,0 +1,77 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Verifies the reauthored decoder of Qwen 2.5 VL 3B models."""
17
+
18
+ import logging
19
+ import pathlib
20
+
21
+ from absl import app
22
+ from ai_edge_torch.generative.examples.qwen_vl import decoder
23
+ from ai_edge_torch.generative.utilities import verifier
24
+ import torch
25
+ import transformers
26
+
27
+
28
class DecoderWrapper(verifier.ModelWrapper):
  """Wraps the decoder of Qwen 2.5 VL models for verification.

  Pairs the original decoder backbone with its LM head, so that verification
  compares the LM-head outputs of the original model against the reauthored
  one.
  """

  def __init__(self, model: torch.nn.Module, lm_head: torch.nn.Module):
    """Initializes the wrapper.

    Args:
      model: The original decoder backbone; its output must expose
        `"last_hidden_state"`.
      lm_head: Module applied to the backbone's last hidden state.
    """
    super().__init__(model)
    self.lm_head = lm_head

  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
    """Runs the backbone on `tokens` and applies the LM head.

    Args:
      tokens: Token IDs fed to the wrapped decoder.

    Returns:
      The LM head applied to the decoder's last hidden state.
    """
    # Call the module instance rather than `.forward()` so any forward hooks
    # registered on the wrapped model still run.
    output = self.model(tokens)
    return self.lm_head(output["last_hidden_state"])
38
+
39
+
40
def main(_):
  """Compares the reauthored Qwen 2.5 VL decoder with the transformers model.

  Loads the original model, builds the reauthored decoder from the same
  locally cached checkpoint, and compares both on a fixed list of input IDs.
  The outcome is logged. A verification failure (AssertionError) is logged as
  FAILED rather than propagated.
  """
  checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
  logging.info("Loading the original model from: %s", checkpoint)
  vl_model_cls = transformers.Qwen2_5_VLForConditionalGeneration
  hf_model = vl_model_cls.from_pretrained(checkpoint)

  # The reauthored decoder is built from the directory that holds the locally
  # cached copy of the checkpoint's config file.
  config_file = transformers.utils.cached_file(
      checkpoint, transformers.utils.CONFIG_NAME
  )
  checkpoint_dir = pathlib.Path(config_file).parent
  logging.info("Building the reauthored model from: %s", checkpoint_dir)
  reauthored = decoder.build_decoder(checkpoint_dir)

  # Compare on input IDs only: the original decoder does not support
  # generate() with prompts.
  input_ids = [1, 2, 3, 4]
  try:
    original = DecoderWrapper(hf_model.model, hf_model.lm_head)
    candidate = verifier.ReauthoredModelWrapper(reauthored)
    verifier.verify_with_input_ids(
        original_model=original,
        reauthored_model=candidate,
        input_ids=input_ids,
        atol=1e-04,
    )
  except AssertionError as err:
    logging.error("*** FAILED *** verify with input IDs: %s", err)
    return
  logging.info("*** PASSED *** verify with input IDs: %s", input_ids)
74
+
75
+
76
if __name__ == "__main__":
  # Launch main() through absl's application entry point.
  app.run(main)
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20250130"
16
+ __version__ = "0.3.0.dev20250201"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20250130
3
+ Version: 0.3.0.dev20250201
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=vazuF44bm-lx-G2_sGZg9gZb_rU9JOa2jdpJlTzoqwE,706
5
+ ai_edge_torch/version.py,sha256=9qgk7SSH80z-tSMSDRr1M3FiR7a69U40_TeqElVeor0,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -93,6 +93,9 @@ ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY
93
93
  ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=tqvXVGNdDehdak9-5DDisACs9VlTwr8eFwcjQ_kZxgc,2776
94
94
  ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
95
95
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
96
+ ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
97
+ ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=rD_Ch5CzuXeatqv0C3z8vU-zou1z9QDUhoB6V4YTPIg,2829
98
+ ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=FEY_PifD9fQGnERzSOljFLraRIbUVF3XTnCv95A30Cs,2602
96
99
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
97
100
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
98
101
  ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=CjY1i0iCYxFSjhCpQZwxkmVxILgeo0zu1m0oBrHqyDU,2311
@@ -222,8 +225,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
222
225
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
223
226
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
224
227
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
225
- ai_edge_torch_nightly-0.3.0.dev20250130.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
226
- ai_edge_torch_nightly-0.3.0.dev20250130.dist-info/METADATA,sha256=LQAYOwR4xu0c0DXpvjDL0Ph5rP3ZIJCrpaSGllCAkqI,1966
227
- ai_edge_torch_nightly-0.3.0.dev20250130.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
228
- ai_edge_torch_nightly-0.3.0.dev20250130.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
229
- ai_edge_torch_nightly-0.3.0.dev20250130.dist-info/RECORD,,
228
+ ai_edge_torch_nightly-0.3.0.dev20250201.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
229
+ ai_edge_torch_nightly-0.3.0.dev20250201.dist-info/METADATA,sha256=oWpE5lBKL49aWAyb11Xi_eO55pPZZzW2M9gUH_lzVlg,1966
230
+ ai_edge_torch_nightly-0.3.0.dev20250201.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
231
+ ai_edge_torch_nightly-0.3.0.dev20250201.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
232
+ ai_edge_torch_nightly-0.3.0.dev20250201.dist-info/RECORD,,