ai-edge-torch-nightly 0.3.0.dev20241003__py3-none-any.whl → 0.3.0.dev20241005__py3-none-any.whl
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -3
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +40 -32
- ai_edge_torch/generative/test/test_model_conversion.py +37 -16
- ai_edge_torch/lowertools/translate_recipe.py +2 -2
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/RECORD +10 -10
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/stable_diffusion/clip.py CHANGED
@@ -75,9 +75,7 @@ class CLIP(nn.Module):
     )

   @torch.inference_mode
-  def forward(self, tokens: torch.
-    tokens = tokens.type(torch.int)
-
+  def forward(self, tokens: torch.IntTensor) -> torch.FloatTensor:
     state = self.tok_embedding(tokens) + self.tok_embedding_position
     for layer in self.transformer_blocks:
       state = layer(state, mask=self.mask_cache)
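With this change the text encoder takes integer token IDs directly instead of casting inside forward. A minimal usage sketch (the get_model_config() constructor and the 77-token context length are assumptions for illustration, not part of this diff):

import torch

from ai_edge_torch.generative.examples.stable_diffusion import clip

# Build the example CLIP text encoder and run it on int32 token IDs,
# matching the new forward(tokens: torch.IntTensor) signature.
model = clip.CLIP(clip.get_model_config()).eval()  # get_model_config() is assumed here
tokens = torch.zeros((1, 77), dtype=torch.int)  # 77 = assumed CLIP context length
embeddings = model(tokens)  # FloatTensor of text embeddings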
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py CHANGED
@@ -13,47 +13,54 @@
 # limitations under the License.
 # ==============================================================================

-import argparse
 import os
-
-from typing import Optional
+import pathlib

+from absl import app
+from absl import flags
 import ai_edge_torch
-
-
-
-from ai_edge_torch.generative.examples.stable_diffusion
-import ai_edge_torch.generative.examples.stable_diffusion.util as util
+from ai_edge_torch.generative.examples.stable_diffusion import clip
+from ai_edge_torch.generative.examples.stable_diffusion import decoder
+from ai_edge_torch.generative.examples.stable_diffusion import diffusion
+from ai_edge_torch.generative.examples.stable_diffusion import util
 from ai_edge_torch.generative.quantize import quant_recipes
-
+from ai_edge_torch.generative.utilities import stable_diffusion_loader
 import torch

-
-
-
-    type=str,
+_CLIP_CKPT = flags.DEFINE_string(
+    'clip_ckpt',
+    None,
     help='Path to source CLIP model checkpoint',
     required=True,
 )
-
-
-
+
+_DIFFUSION_CKPT = flags.DEFINE_string(
+    'diffusion_ckpt',
+    None,
     help='Path to source diffusion model checkpoint',
     required=True,
 )
-
-
-
+
+_DECODER_CKPT = flags.DEFINE_string(
+    'decoder_ckpt',
+    None,
     help='Path to source image decoder model checkpoint',
     required=True,
 )
-
-
-
+
+_OUTPUT_DIR = flags.DEFINE_string(
+    'output_dir',
+    None,
     help='Path to the converted TF Lite directory.',
     required=True,
 )

+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    help='Whether to quantize the model during conversion.',
+    default=True,
+)
+

 @torch.inference_mode
 def convert_stable_diffusion_to_tflite(
@@ -111,7 +118,7 @@ def convert_stable_diffusion_to_tflite(
   time_embedding = util.get_time_embedding(timestamp)

   if not os.path.exists(output_dir):
-    Path(output_dir).mkdir(parents=True, exist_ok=True)
+    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

   quant_config = (
       quant_recipes.full_int8_weight_only_recipe() if quantize else None
@@ -142,14 +149,15 @@ def convert_stable_diffusion_to_tflite(
   ).export(f'{output_dir}/decoder.tflite')


-
-  args = arg_parser.parse_args()
+def main(_):
   convert_stable_diffusion_to_tflite(
-      output_dir=
-      clip_ckpt_path=
-      diffusion_ckpt_path=
-      decoder_ckpt_path=
-
-      image_width=512,
-      quantize=True,
+      output_dir=_OUTPUT_DIR.value,
+      clip_ckpt_path=_CLIP_CKPT.value,
+      diffusion_ckpt_path=_DIFFUSION_CKPT.value,
+      decoder_ckpt_path=_DECODER_CKPT.value,
+      quantize=_QUANTIZE.value,
   )
+
+
+if __name__ == '__main__':
+  app.run(main)
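With argparse replaced by absl flags, the converter is now driven entirely by named flags and app.run(main). A typical invocation would look like the following (checkpoint paths are placeholders; per standard absl boolean-flag handling, --noquantize would disable the default weight-only quantization):

python ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py \
  --clip_ckpt=/path/to/clip.ckpt \
  --diffusion_ckpt=/path/to/diffusion.ckpt \
  --decoder_ckpt=/path/to/decoder.ckpt \
  --output_dir=/tmp/stable_diffusion_tflite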
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -43,31 +43,40 @@ class TestModelConversion(googletest.TestCase):
         )
     )

-  def
+  def _get_params(self, enable_hlfb: bool):
+    """Returns a model, edge model and the kwargs to use for testing."""
+    config = toy_model_with_kv_cache.get_model_config()
+    config.enable_hlfb = enable_hlfb
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
     tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int
     )
     kv = kv_cache.KVCache.from_model_config(config)
+    kwargs = {
+        "tokens": tokens,
+        "input_pos": input_pos,
+        "kv_cache": kv,
+    }

     edge_model = ai_edge_torch.convert(
         pytorch_model,
-        sample_kwargs=
-            "tokens": tokens,
-            "input_pos": input_pos,
-            "kv_cache": kv,
-        },
+        sample_kwargs=kwargs,
     )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
+    return pytorch_model, edge_model, kwargs
+
+  def _test_model_with_kv_cache(self, enable_hlfb: bool):
+    pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb)

     self.assertTrue(
         test_utils.compare_tflite_torch(
             edge_model,
             pytorch_model,
-            tokens,
-            input_pos,
-
+            kwargs["tokens"],
+            kwargs["input_pos"],
+            kwargs["kv_cache"],
             signature_name="serving_default",
             atol=1e-5,
             rtol=1e-5,
@@ -79,19 +88,31 @@ class TestModelConversion(googletest.TestCase):
       reason="tests with custom ops are not supported on oss",
   )
   def test_toy_model_with_kv_cache(self):
-
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    self._test_model_with_kv_cache(config, pytorch_model)
+    self._test_model_with_kv_cache(enable_hlfb=False)

   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
   def test_toy_model_with_kv_cache_with_hlfb(self):
-
-
-
-
+    self._test_model_with_kv_cache(enable_hlfb=True)
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_toy_model_has_ekv_op(self):
+    """Tests that the model has the external kv cache op."""
+    _, edge_model, _ = self._get_params(enable_hlfb=True)
+    interpreter_ = interpreter.InterpreterWithCustomOps(
+        custom_op_registerers=["GenAIOpsRegisterer"],
+        model_content=edge_model.tflite_model(),
+        experimental_default_delegate_latest_features=True,
+    )
+
+    # pylint: disable=protected-access
+    op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
+    self.assertIn("odml.update_external_kv_cache", op_names)

   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
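The refactored tests funnel everything through _get_params, which captures the general conversion pattern: convert a PyTorch module with keyword sample inputs, then call the edge model with the same kwargs. A self-contained sketch of that pattern (TinyModel below is an assumption for illustration, not the ToyModelWithKVCache used in the tests):

import torch

import ai_edge_torch


class TinyModel(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(4, 4)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.linear(x)


pytorch_model = TinyModel().eval()
kwargs = {"x": torch.randn(1, 4)}
# Convert with keyword sample inputs, mirroring sample_kwargs in the test above.
edge_model = ai_edge_torch.convert(pytorch_model, sample_kwargs=kwargs)
# The converted model is callable with the same keyword arguments.
edge_output = edge_model(**kwargs)     # TFLite inference
torch_output = pytorch_model(**kwargs)  # reference PyTorch output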
ai_edge_torch/lowertools/translate_recipe.py CHANGED
@@ -156,8 +156,8 @@ def translate_to_ai_edge_recipe(


 def quantize_model(
-    model:
+    model: bytes, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
 ) -> bytearray:
-  qt = quantizer.Quantizer(
+  qt = quantizer.Quantizer(model, recipe)
   result = qt.quantize()
   return result.quantized_model
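quantize_model now takes the serialized TFLite flatbuffer and the ai-edge-quantizer recipe explicitly rather than relying on prior state. A hypothetical helper sketching how it might be called (the file-based wrapper and its paths are assumptions; the recipe object would come from translate_to_ai_edge_recipe above):

from ai_edge_torch.lowertools import translate_recipe


def quantize_tflite_file(src_path, dst_path, ai_edge_recipe):
  # Hypothetical convenience wrapper: read a .tflite flatbuffer, apply the
  # quantization recipe, and write the quantized model back out.
  with open(src_path, 'rb') as f:
    tflite_bytes = f.read()
  quantized = translate_recipe.quantize_model(tflite_bytes, ai_edge_recipe)
  with open(dst_path, 'wb') as f:
    f.write(quantized)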
{ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241003
+Version: 0.3.0.dev20241005
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/RECORD CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=y5TOP0Z8qFsjIuJuJtSmzOUpHyTa9UH46RdJjtRWYQA,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -71,8 +71,8 @@ ai_edge_torch/generative/examples/smollm/smollm.py,sha256=M5qAcSUE5gxOSfq24a8lZk
 ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
-ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=
-ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
+ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=i9mcBITt4jJqKLA4Qdt3uFotCrglv14tPg8VnqsVnaI,5004
 ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=ClXNntmh0PF3s6U3C7SW3tyVrsSSrV2kyz-_RF4BcqA,15715
 ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=mBEAUYjV1qDJy9ZAsHtm9RGce0Mbzv0VoPZpdcQl1mk,33730
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
@@ -122,7 +122,7 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
-ai_edge_torch/generative/test/test_model_conversion.py,sha256
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=a4TzSw8KMxEafirxqkykZi-WgTs5Z7wHp-J1AfjRDzA,6353
 ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bVCm_mubuGszCBON6oRjQXcBgPZqlVmmOaLWwhZJLio,9060
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
@@ -147,7 +147,7 @@ ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjF
 ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
 ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
 ai_edge_torch/lowertools/torch_xla_utils.py,sha256=S7RWzauts-15xP6VYuM3aAd9cyAGHstYD2A4dlv3d30,9059
-ai_edge_torch/lowertools/translate_recipe.py,sha256=
+ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
@@ -180,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/METADATA,sha256=O3P5ofz2aERMO1xbvIC7Z4RWsUNLJOZgn4pxEH3ftRc,1897
+ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/RECORD,,