ai-edge-torch-nightly 0.3.0.dev20241003__py3-none-any.whl → 0.3.0.dev20241005__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -3
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +40 -32
- ai_edge_torch/generative/test/test_model_conversion.py +37 -16
- ai_edge_torch/lowertools/translate_recipe.py +2 -2
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/RECORD +10 -10
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/stable_diffusion/clip.py
CHANGED
@@ -75,9 +75,7 @@ class CLIP(nn.Module):
     )
 
   @torch.inference_mode
-  def forward(self, tokens: torch.
-    tokens = tokens.type(torch.int)
-
+  def forward(self, tokens: torch.IntTensor) -> torch.FloatTensor:
     state = self.tok_embedding(tokens) + self.tok_embedding_position
     for layer in self.transformer_blocks:
       state = layer(state, mask=self.mask_cache)
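The new signature drops the internal cast to int, so callers are expected to hand forward() an integer tensor already. A minimal sketch of the calling convention; the token values and the constructed clip_model instance below are placeholders, not values from this package:

import torch

# Token IDs must already be an integer tensor; forward() no longer casts them.
tokens = torch.tensor([[49406, 320, 1125, 49407]], dtype=torch.int)

# clip_model is assumed to be an already-constructed CLIP instance from clip.py.
with torch.inference_mode():
  embeddings = clip_model(tokens)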
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
CHANGED
@@ -13,47 +13,54 @@
 # limitations under the License.
 # ==============================================================================
 
-import argparse
 import os
-
-from typing import Optional
+import pathlib
 
+from absl import app
+from absl import flags
 import ai_edge_torch
-
-
-
-from ai_edge_torch.generative.examples.stable_diffusion
-import ai_edge_torch.generative.examples.stable_diffusion.util as util
+from ai_edge_torch.generative.examples.stable_diffusion import clip
+from ai_edge_torch.generative.examples.stable_diffusion import decoder
+from ai_edge_torch.generative.examples.stable_diffusion import diffusion
+from ai_edge_torch.generative.examples.stable_diffusion import util
 from ai_edge_torch.generative.quantize import quant_recipes
-
+from ai_edge_torch.generative.utilities import stable_diffusion_loader
 import torch
 
-
-
-
-    type=str,
+_CLIP_CKPT = flags.DEFINE_string(
+    'clip_ckpt',
+    None,
     help='Path to source CLIP model checkpoint',
     required=True,
 )
-
-
-
+
+_DIFFUSION_CKPT = flags.DEFINE_string(
+    'diffusion_ckpt',
+    None,
     help='Path to source diffusion model checkpoint',
     required=True,
 )
-
-
-
+
+_DECODER_CKPT = flags.DEFINE_string(
+    'decoder_ckpt',
+    None,
     help='Path to source image decoder model checkpoint',
     required=True,
 )
-
-
-
+
+_OUTPUT_DIR = flags.DEFINE_string(
+    'output_dir',
+    None,
     help='Path to the converted TF Lite directory.',
     required=True,
 )
 
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    help='Whether to quantize the model during conversion.',
+    default=True,
+)
+
 
 @torch.inference_mode
 def convert_stable_diffusion_to_tflite(
@@ -111,7 +118,7 @@ def convert_stable_diffusion_to_tflite(
   time_embedding = util.get_time_embedding(timestamp)
 
   if not os.path.exists(output_dir):
-    Path(output_dir).mkdir(parents=True, exist_ok=True)
+    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
 
   quant_config = (
       quant_recipes.full_int8_weight_only_recipe() if quantize else None
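Only the qualified name changes here, from the previously imported Path to pathlib.Path. As a side note, mkdir(parents=True, exist_ok=True) already tolerates an existing directory, so the os.path.exists guard is belt-and-braces; a minimal sketch with a placeholder path:

import pathlib

# Creates the directory tree if needed and succeeds silently if it already exists.
pathlib.Path('/tmp/sd_tflite').mkdir(parents=True, exist_ok=True)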
@@ -142,14 +149,15 @@ def convert_stable_diffusion_to_tflite(
   ).export(f'{output_dir}/decoder.tflite')
 
 
-
-  args = arg_parser.parse_args()
+def main(_):
   convert_stable_diffusion_to_tflite(
-      output_dir=
-      clip_ckpt_path=
-      diffusion_ckpt_path=
-      decoder_ckpt_path=
-
-      image_width=512,
-      quantize=True,
+      output_dir=_OUTPUT_DIR.value,
+      clip_ckpt_path=_CLIP_CKPT.value,
+      diffusion_ckpt_path=_DIFFUSION_CKPT.value,
+      decoder_ckpt_path=_DECODER_CKPT.value,
+      quantize=_QUANTIZE.value,
   )
+
+
+if __name__ == '__main__':
+  app.run(main)
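The script now follows the absl entry-point pattern instead of argparse: flags are declared at module scope, main(_) reads their .value attributes, and app.run(main) parses argv before calling main. A minimal, self-contained sketch of that pattern (the flag names below are illustrative, not the converter's full set):

from absl import app
from absl import flags

# A required string flag (default None) and an optional boolean flag.
_OUTPUT_DIR = flags.DEFINE_string(
    'output_dir', None, help='Where to write outputs.', required=True
)
_QUANTIZE = flags.DEFINE_bool(
    'quantize', default=True, help='Whether to quantize.'
)


def main(_):
  # Flag values are only valid after app.run() has parsed the command line.
  print(_OUTPUT_DIR.value, _QUANTIZE.value)


if __name__ == '__main__':
  app.run(main)

With the flags defined in this file, the converter would be invoked along the lines of: python convert_to_tflite.py --clip_ckpt=... --diffusion_ckpt=... --decoder_ckpt=... --output_dir=..., with --noquantize available to skip quantization.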
ai_edge_torch/generative/test/test_model_conversion.py
CHANGED
@@ -43,31 +43,40 @@ class TestModelConversion(googletest.TestCase):
       )
     )
 
-  def
+  def _get_params(self, enable_hlfb: bool):
+    """Returns a model, edge model and the kwargs to use for testing."""
+    config = toy_model_with_kv_cache.get_model_config()
+    config.enable_hlfb = enable_hlfb
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
     tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int
     )
     kv = kv_cache.KVCache.from_model_config(config)
+    kwargs = {
+        "tokens": tokens,
+        "input_pos": input_pos,
+        "kv_cache": kv,
+    }
 
     edge_model = ai_edge_torch.convert(
         pytorch_model,
-        sample_kwargs=
-            "tokens": tokens,
-            "input_pos": input_pos,
-            "kv_cache": kv,
-        },
+        sample_kwargs=kwargs,
     )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
+    return pytorch_model, edge_model, kwargs
+
+  def _test_model_with_kv_cache(self, enable_hlfb: bool):
+    pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb)
 
     self.assertTrue(
         test_utils.compare_tflite_torch(
            edge_model,
            pytorch_model,
-           tokens,
-           input_pos,
-
+           kwargs["tokens"],
+           kwargs["input_pos"],
+           kwargs["kv_cache"],
           signature_name="serving_default",
           atol=1e-5,
           rtol=1e-5,
@@ -79,19 +88,31 @@ class TestModelConversion(googletest.TestCase):
       reason="tests with custom ops are not supported on oss",
   )
   def test_toy_model_with_kv_cache(self):
-
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    self._test_model_with_kv_cache(config, pytorch_model)
+    self._test_model_with_kv_cache(enable_hlfb=False)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
   def test_toy_model_with_kv_cache_with_hlfb(self):
-
-
-
-
+    self._test_model_with_kv_cache(enable_hlfb=True)
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_toy_model_has_ekv_op(self):
+    """Tests that the model has the external kv cache op."""
+    _, edge_model, _ = self._get_params(enable_hlfb=True)
+    interpreter_ = interpreter.InterpreterWithCustomOps(
+        custom_op_registerers=["GenAIOpsRegisterer"],
+        model_content=edge_model.tflite_model(),
+        experimental_default_delegate_latest_features=True,
+    )
+
+    # pylint: disable=protected-access
+    op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
+    self.assertIn("odml.update_external_kv_cache", op_names)
 
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
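The new test inspects the converted flatbuffer for the external KV-cache custom op. For models without custom ops, the same kind of check can be done with the stock TF Lite interpreter; a rough sketch, noting that _get_ops_details() is a private API and edge_model is assumed to come from a prior ai_edge_torch.convert() call:

import tensorflow as tf

interp = tf.lite.Interpreter(model_content=edge_model.tflite_model())
# _get_ops_details() is private/experimental; it returns one dict per op in the graph.
op_names = [op["op_name"] for op in interp._get_ops_details()]
print(op_names)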
ai_edge_torch/lowertools/translate_recipe.py
CHANGED
@@ -156,8 +156,8 @@ def translate_to_ai_edge_recipe(
 
 
 def quantize_model(
-    model:
+    model: bytes, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
 ) -> bytearray:
-  qt = quantizer.Quantizer(
+  qt = quantizer.Quantizer(model, recipe)
   result = qt.quantize()
   return result.quantized_model
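After this change, quantize_model() takes the serialized float model and the translated recipe directly. A hypothetical call, where recipe is assumed to come from translate_to_ai_edge_recipe() and the model path is a placeholder:

with open('model_float.tflite', 'rb') as f:
  float_model = f.read()

# Returns a bytearray holding the quantized flatbuffer.
quantized_model = quantize_model(float_model, recipe)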
ai_edge_torch/version.py
CHANGED

{ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241003
+Version: 0.3.0.dev20241005
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/RECORD
CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=y5TOP0Z8qFsjIuJuJtSmzOUpHyTa9UH46RdJjtRWYQA,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -71,8 +71,8 @@ ai_edge_torch/generative/examples/smollm/smollm.py,sha256=M5qAcSUE5gxOSfq24a8lZk
 ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
-ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=
-ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
+ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=i9mcBITt4jJqKLA4Qdt3uFotCrglv14tPg8VnqsVnaI,5004
 ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=ClXNntmh0PF3s6U3C7SW3tyVrsSSrV2kyz-_RF4BcqA,15715
 ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=mBEAUYjV1qDJy9ZAsHtm9RGce0Mbzv0VoPZpdcQl1mk,33730
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
@@ -122,7 +122,7 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
-ai_edge_torch/generative/test/test_model_conversion.py,sha256
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=a4TzSw8KMxEafirxqkykZi-WgTs5Z7wHp-J1AfjRDzA,6353
 ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bVCm_mubuGszCBON6oRjQXcBgPZqlVmmOaLWwhZJLio,9060
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
@@ -147,7 +147,7 @@ ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjF
 ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
 ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
 ai_edge_torch/lowertools/torch_xla_utils.py,sha256=S7RWzauts-15xP6VYuM3aAd9cyAGHstYD2A4dlv3d30,9059
-ai_edge_torch/lowertools/translate_recipe.py,sha256=
+ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
@@ -180,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/METADATA,sha256=O3P5ofz2aERMO1xbvIC7Z4RWsUNLJOZgn4pxEH3ftRc,1897
+ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/RECORD,,
File without changes
File without changes