ai-edge-torch-nightly 0.3.0.dev20250129__py3-none-any.whl → 0.3.0.dev20250131__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/examples/qwen_vl/__init__.py +14 -0
- ai_edge_torch/generative/examples/qwen_vl/decoder.py +91 -0
- ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py +77 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250129.dist-info → ai_edge_torch_nightly-0.3.0.dev20250131.dist-info}/METADATA +2 -2
- {ai_edge_torch_nightly-0.3.0.dev20250129.dist-info → ai_edge_torch_nightly-0.3.0.dev20250131.dist-info}/RECORD +9 -6
- {ai_edge_torch_nightly-0.3.0.dev20250129.dist-info → ai_edge_torch_nightly-0.3.0.dev20250131.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250129.dist-info → ai_edge_torch_nightly-0.3.0.dev20250131.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250129.dist-info → ai_edge_torch_nightly-0.3.0.dev20250131.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,14 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
@@ -0,0 +1,91 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Example of building decoder for Qwen 2.5 VL models."""
|
17
|
+
|
18
|
+
import ai_edge_torch.generative.layers.model_config as cfg
|
19
|
+
from ai_edge_torch.generative.utilities import model_builder
|
20
|
+
from torch import nn
|
21
|
+
|
22
|
+
# Tensor-name mapping reused unchanged from model_builder; build_decoder passes
# it to model_builder.build_decoder_only_model as `tensor_names`.
TENSOR_NAMES = model_builder.TENSOR_NAMES
|
23
|
+
|
24
|
+
|
25
|
+
class Decoder(model_builder.DecoderOnlyModel):
  """A decoder for Qwen-VL model built from the Edge Generative API layers.

  This subclass adds no behavior of its own; it serves as the distinct
  `model_class` handed to `model_builder.build_decoder_only_model` by
  `build_decoder`.
  """
|
28
|
+
|
29
|
+
|
30
|
+
def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Builds the model config for the Qwen 2.5 VL 3B text decoder.

  Args:
    kv_cache_max_len (int): Maximum sequence length held by the KV cache.
      Defaults to 1024.

  Returns:
    The model config for a Qwen 2.5 VL 3B model.
  """
  # A single RMSNorm config object is shared by the pre-attention,
  # post-attention and final normalization layers.
  rms_norm = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM,
      epsilon=1e-06,
  )
  # 16 attention heads of width 128, arranged in 2 query groups, with biased
  # QKV projections and rotary embeddings over the full head dimension.
  attention = cfg.AttentionConfig(
      num_heads=16,
      head_dim=128,
      num_query_groups=2,
      rotary_base=1000000,
      rotary_percentage=1.0,
      qkv_use_bias=True,
  )
  feed_forward = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
      intermediate_size=11008,
  )
  return cfg.ModelConfig(
      vocab_size=151936,
      num_layers=36,
      max_seq_len=32768,
      embedding_dim=2048,
      kv_cache_max_len=kv_cache_max_len,
      # One block config describes every transformer layer.
      block_configs=cfg.TransformerBlockConfig(
          attn_config=attention,
          ff_config=feed_forward,
          pre_attention_norm_config=rms_norm,
          post_attention_norm_config=rms_norm,
      ),
      final_norm_config=rms_norm,
      enable_hlfb=True,
  )
|
74
|
+
|
75
|
+
|
76
|
+
def get_fake_decoder_config(**kwargs) -> cfg.ModelConfig:
  """Returns a shrunken copy of the Qwen 2.5 VL decoder config.

  Args:
    **kwargs: Forwarded unchanged to `get_decoder_config`.

  Returns:
    The `get_decoder_config` result with a 128-entry vocabulary, 2 layers and
    a feed-forward intermediate size of 64; all other fields are left as-is.
  """
  small = get_decoder_config(**kwargs)
  small.vocab_size = 128
  small.num_layers = 2
  # Decoder has only one block config, so resizing block 0 is sufficient.
  small.block_config(0).ff_config.intermediate_size = 64
  return small
|
83
|
+
|
84
|
+
|
85
|
+
def build_decoder(checkpoint_path: str, **kwargs) -> nn.Module:
  """Builds the Qwen 2.5 VL decoder from the checkpoint at `checkpoint_path`.

  Args:
    checkpoint_path: Checkpoint location handed to
      `model_builder.build_decoder_only_model`.
    **kwargs: Forwarded to `get_decoder_config` (e.g. `kv_cache_max_len`).

  Returns:
    The module produced by `model_builder.build_decoder_only_model` with
    `Decoder` as its model class.
  """
  decoder_config = get_decoder_config(**kwargs)
  return model_builder.build_decoder_only_model(
      checkpoint_path=checkpoint_path,
      config=decoder_config,
      tensor_names=TENSOR_NAMES,
      model_class=Decoder,
  )
|
@@ -0,0 +1,77 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Verifies the reauthored decoder of Qwen 2.5 VL 3B models."""
|
17
|
+
|
18
|
+
import logging
|
19
|
+
import pathlib
|
20
|
+
|
21
|
+
from absl import app
|
22
|
+
from ai_edge_torch.generative.examples.qwen_vl import decoder
|
23
|
+
from ai_edge_torch.generative.utilities import verifier
|
24
|
+
import torch
|
25
|
+
import transformers
|
26
|
+
|
27
|
+
|
28
|
+
class DecoderWrapper(verifier.ModelWrapper):
  """Wraps the decoder of Qwen 2.5 VL models for verification.

  The wrapped decoder's hidden states are projected through `lm_head`, so the
  wrapper's output is the LM-head projection rather than raw hidden states.
  """

  def __init__(self, model: torch.nn.Module, lm_head: torch.nn.Module):
    super().__init__(model)
    self.lm_head = lm_head

  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
    """Runs the decoder on `tokens` and applies `lm_head` to its output.

    Args:
      tokens: Input token IDs passed straight to the wrapped decoder.

    Returns:
      `lm_head` applied to the decoder output's "last_hidden_state" entry.
    """
    hidden_states = self.model.forward(tokens)["last_hidden_state"]
    return self.lm_head(hidden_states)
|
38
|
+
|
39
|
+
|
40
|
+
def main(_):
  """Verifies the reauthored Qwen 2.5 VL 3B decoder against the original.

  Loads the original model through transformers, builds the reauthored decoder
  from the same locally cached checkpoint directory, then compares the two on
  a fixed list of input IDs and logs whether verification passed.
  """
  checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
  logging.info("Loading the original model from: %s", checkpoint)
  hf_model_class = transformers.Qwen2_5_VLForConditionalGeneration
  original_model = hf_model_class.from_pretrained(checkpoint)

  # The reauthored model is built from the directory holding the cached config
  # file of the checkpoint loaded above.
  config_path = transformers.utils.cached_file(
      checkpoint, transformers.utils.CONFIG_NAME
  )
  reauthored_checkpoint = pathlib.Path(config_path).parent
  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
  reauthored_model = decoder.build_decoder(reauthored_checkpoint)

  # Only raw input IDs are compared because the original decoder does not
  # support generate() with prompts.
  input_ids = [1, 2, 3, 4]
  try:
    original_wrapper = DecoderWrapper(
        original_model.model,
        original_model.lm_head,
    )
    reauthored_wrapper = verifier.ReauthoredModelWrapper(reauthored_model)
    verifier.verify_with_input_ids(
        original_model=original_wrapper,
        reauthored_model=reauthored_wrapper,
        input_ids=input_ids,
        atol=1e-04,
    )
  except AssertionError as e:
    logging.error("*** FAILED *** verify with input IDs: %s", e)
  else:
    logging.info("*** PASSED *** verify with input IDs: %s", input_ids)


if __name__ == "__main__":
  app.run(main)
|
ai_edge_torch_nightly-0.3.0.dev20250131.dist-info/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20250129
|
3
|
+
Version: 0.3.0.dev20250131
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -27,7 +27,7 @@ Requires-Dist: scipy
|
|
27
27
|
Requires-Dist: safetensors
|
28
28
|
Requires-Dist: tabulate
|
29
29
|
Requires-Dist: torch>=2.4.0
|
30
|
-
Requires-Dist: tf-nightly>=2.19.0.
|
30
|
+
Requires-Dist: tf-nightly>=2.19.0.dev20250101
|
31
31
|
Requires-Dist: ai-edge-litert-nightly
|
32
32
|
Requires-Dist: ai-edge-quantizer-nightly
|
33
33
|
Requires-Dist: jax
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=Id3c7ukf4DhncTCYK8zP6N-fcPUnLevwA72T7fLSC0s,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -93,6 +93,9 @@ ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY
|
|
93
93
|
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=tqvXVGNdDehdak9-5DDisACs9VlTwr8eFwcjQ_kZxgc,2776
|
94
94
|
ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
|
95
95
|
ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
|
96
|
+
ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
97
|
+
ai_edge_torch/generative/examples/qwen_vl/decoder.py,sha256=rD_Ch5CzuXeatqv0C3z8vU-zou1z9QDUhoB6V4YTPIg,2829
|
98
|
+
ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=FEY_PifD9fQGnERzSOljFLraRIbUVF3XTnCv95A30Cs,2602
|
96
99
|
ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
97
100
|
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
|
98
101
|
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=CjY1i0iCYxFSjhCpQZwxkmVxILgeo0zu1m0oBrHqyDU,2311
|
@@ -222,8 +225,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
222
225
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
223
226
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
224
227
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
225
|
-
ai_edge_torch_nightly-0.3.0.
|
226
|
-
ai_edge_torch_nightly-0.3.0.
|
227
|
-
ai_edge_torch_nightly-0.3.0.
|
228
|
-
ai_edge_torch_nightly-0.3.0.
|
229
|
-
ai_edge_torch_nightly-0.3.0.
|
228
|
+
ai_edge_torch_nightly-0.3.0.dev20250131.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
229
|
+
ai_edge_torch_nightly-0.3.0.dev20250131.dist-info/METADATA,sha256=2UObtm18rivzomhCIe2uZA6GiNmNiQgx1eYHsL9fucM,1966
|
230
|
+
ai_edge_torch_nightly-0.3.0.dev20250131.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
231
|
+
ai_edge_torch_nightly-0.3.0.dev20250131.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
232
|
+
ai_edge_torch_nightly-0.3.0.dev20250131.dist-info/RECORD,,
|
File without changes
|
File without changes
|