ai-edge-torch-nightly 0.2.0.dev20240731__py3-none-any.whl → 0.2.0.dev20240801__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +13 -8
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +1 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +7 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +10 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
- ai_edge_torch/generative/test/test_quantize.py +1 -0
- {ai_edge_torch_nightly-0.2.0.dev20240731.dist-info → ai_edge_torch_nightly-0.2.0.dev20240801.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240731.dist-info → ai_edge_torch_nightly-0.2.0.dev20240801.dist-info}/RECORD +12 -12
- {ai_edge_torch_nightly-0.2.0.dev20240731.dist-info → ai_edge_torch_nightly-0.2.0.dev20240801.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240731.dist-info → ai_edge_torch_nightly-0.2.0.dev20240801.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240731.dist-info → ai_edge_torch_nightly-0.2.0.dev20240801.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

@@ -26,6 +26,7 @@ import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
 import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
 from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
+from ai_edge_torch.generative.quantize import quant_recipes
 import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
 
 arg_parser = argparse.ArgumentParser()
@@ -60,6 +61,7 @@ def convert_stable_diffusion_to_tflite(
     decoder_ckpt_path: str,
     image_height: int = 512,
     image_width: int = 512,
+    quantize: bool = True,
 ):
 
   clip_model = clip.CLIP(clip.get_model_config())
@@ -105,15 +107,17 @@ def convert_stable_diffusion_to_tflite(
   if not os.path.exists(output_dir):
     Path(output_dir).mkdir(parents=True, exist_ok=True)
 
+  quant_config = quant_recipes.full_int8_weight_only_recipe() if quantize else None
+
   # TODO(yichunk): convert to multi signature tflite model.
   # CLIP text encoder
-  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
-      f'{output_dir}/clip.tflite'
-  )
+  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert(
+      quant_config=quant_config
+  ).export(f'{output_dir}/clip.tflite')
 
   # TODO(yichunk): enable image encoder conversion
   # Image encoder
-  # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
+  # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert(quant_config=quant_config).export(
   #     f'{output_dir}/encoder.tflite'
   # )
 
@@ -122,12 +126,12 @@ def convert_stable_diffusion_to_tflite(
       'diffusion',
       diffusion_model,
       (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
-  ).convert().export(f'{output_dir}/diffusion.tflite')
+  ).convert(quant_config=quant_config).export(f'{output_dir}/diffusion.tflite')
 
   # Image decoder
-  ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert().export(
-      f'{output_dir}/decoder.tflite'
-  )
+  ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert(
+      quant_config=quant_config
+  ).export(f'{output_dir}/decoder.tflite')
 
 
 if __name__ == '__main__':
@@ -139,4 +143,5 @@ if __name__ == '__main__':
      decoder_ckpt_path=args.decoder_ckpt,
      image_height=512,
      image_width=512,
+     quantize=True,
   )
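
All of the conversions above follow the same pattern: build a quant config from the new weight-only recipe and pass it to convert(). A minimal, self-contained sketch of that pattern applied to a toy module (the nn.Linear stand-in and the output path are placeholders, not part of the package):

import torch
import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes

# Placeholder module standing in for one of the Stable Diffusion sub-models.
model = torch.nn.Linear(4, 4).eval()
sample_args = (torch.randn(1, 4),)

# Same pattern the script now uses: pass a recipe-backed quant_config,
# or None to keep the float32 weights.
quant_config = quant_recipes.full_int8_weight_only_recipe()
edge_model = ai_edge_torch.convert(model, sample_args, quant_config=quant_config)
edge_model.export('/tmp/linear_int8_weight_only.tflite')

Weight-only int8 shrinks the stored weights while activations stay in float32, so no calibration data is involved.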

ai_edge_torch/generative/examples/stable_diffusion/diffusion.py

@@ -294,6 +294,7 @@ class Diffusion(nn.Module):
     attention_batch_size=config.transformer_batch_size,
     normalization_config=config.transformer_norm_config,
     attention_config=attention_config,
+    enable_hlfb=False,
 ),
 cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
     query_dim=output_channel,
@@ -301,6 +302,7 @@ class Diffusion(nn.Module):
     attention_batch_size=config.transformer_batch_size,
     normalization_config=config.transformer_norm_config,
     attention_config=attention_config,
+    enable_hlfb=False,
 ),
 pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
 feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
@@ -354,6 +356,7 @@ class Diffusion(nn.Module):
     attention_batch_size=config.transformer_batch_size,
     normalization_config=config.transformer_norm_config,
     attention_config=attention_config,
+    enable_hlfb=False,
 ),
 cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
     query_dim=mid_block_channels,
@@ -361,6 +364,7 @@ class Diffusion(nn.Module):
     attention_batch_size=config.transformer_batch_size,
     normalization_config=config.transformer_norm_config,
     attention_config=attention_config,
+    enable_hlfb=False,
 ),
 pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
 feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
@@ -415,6 +419,7 @@ class Diffusion(nn.Module):
     attention_batch_size=config.transformer_batch_size,
     normalization_config=config.transformer_norm_config,
     attention_config=attention_config,
+    enable_hlfb=False,
 ),
 cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
     query_dim=output_channel,
@@ -422,6 +427,7 @@ class Diffusion(nn.Module):
     attention_batch_size=config.transformer_batch_size,
     normalization_config=config.transformer_norm_config,
     attention_config=attention_config,
+    enable_hlfb=False,
 ),
 pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
 feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(

ai_edge_torch/generative/examples/stable_diffusion/pipeline.py

@@ -65,6 +65,12 @@ arg_parser.add_argument(
     choices=['k_euler', 'k_euler_ancestral', 'k_lms'],
     help='A sampler to be used to denoise the encoded image latents. Can be one of `k_lms, `k_euler`, or `k_euler_ancestral`.',
 )
+arg_parser.add_argument(
+    '--seed',
+    default=None,
+    type=int,
+    help='A seed to make generation deterministic. A random number is used if unspecified.',
+)
 
 
 class StableDiffusion:
@@ -219,4 +225,5 @@ if __name__ == '__main__':
      output_path=args.output_path,
      sampler=args.sampler,
      n_inference_steps=args.n_inference_steps,
+     seed=args.seed,
   )
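
The diff only shows the new --seed flag being threaded through to the pipeline entry point; how the pipeline consumes it is not part of this diff. As a purely illustrative sketch of the usual idea — seeding the initial latent noise so the denoising trajectory is reproducible — with made-up shapes and function name:

import numpy as np

def make_initial_latents(height: int, width: int, seed: int | None = None) -> np.ndarray:
  # A seeded Generator makes the starting noise, and therefore the generated
  # image, reproducible; seed=None falls back to fresh OS entropy.
  rng = np.random.default_rng(seed)
  return rng.standard_normal((1, 4, height // 8, width // 8)).astype(np.float32)

latents = make_initial_latents(512, 512, seed=42)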

ai_edge_torch/generative/quantize/quant_recipe_utils.py

@@ -41,6 +41,16 @@ def create_layer_quant_int8_dynamic() -> quant_recipe.LayerQuantRecipe:
   )
 
 
+def create_layer_quant_int8_weight_only() -> quant_recipe.LayerQuantRecipe:
+  return quant_recipe.LayerQuantRecipe(
+      activation_dtype=quant_attrs.Dtype.FP32,
+      weight_dtype=quant_attrs.Dtype.INT8,
+      mode=quant_attrs.Mode.WEIGHT_ONLY,
+      algorithm=quant_attrs.Algorithm.MIN_MAX,
+      granularity=quant_attrs.Granularity.CHANNELWISE,
+  )
+
+
 def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
   return quant_recipe.LayerQuantRecipe(
       activation_dtype=quant_attrs.Dtype.FP32,
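
Beyond the full_int8_weight_only_recipe() added in quant_recipes.py just below, a layer recipe like this can in principle be applied selectively. A sketch, assuming GenerativeQuantRecipe accepts per-component fields such as feedforward (suggested by the _feedforward_int8_dynamic_recipe helper referenced in test_quantize.py, but not shown in this diff):

from ai_edge_torch.generative.quantize import quant_recipe, quant_recipe_utils
from ai_edge_torch.quantize import quant_config

# 'feedforward' is an assumed field name: quantize only the feed-forward
# weights to int8 (weight-only) and leave everything else in float32.
mixed_config = quant_config.QuantConfig(
    generative_recipe=quant_recipe.GenerativeQuantRecipe(
        feedforward=quant_recipe_utils.create_layer_quant_int8_weight_only(),
    )
)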

ai_edge_torch/generative/quantize/quant_recipes.py

@@ -40,6 +40,14 @@ def full_int8_dynamic_recipe() -> quant_config.QuantConfig:
   )
 
 
+def full_int8_weight_only_recipe() -> quant_config.QuantConfig:
+  return quant_config.QuantConfig(
+      generative_recipe=quant_recipe.GenerativeQuantRecipe(
+          default=quant_recipe_utils.create_layer_quant_int8_weight_only(),
+      )
+  )
+
+
 def full_fp16_recipe() -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(

ai_edge_torch/generative/test/test_quantize.py

@@ -111,6 +111,7 @@ class TestQuantizeConvert(unittest.TestCase):
      [
          (quant_recipes.full_fp16_recipe()),
          (quant_recipes.full_int8_dynamic_recipe()),
+         (quant_recipes.full_int8_weight_only_recipe()),
          (_attention_int8_dynamic_recipe()),
          (_feedforward_int8_dynamic_recipe()),
      ]

{ai_edge_torch_nightly-0.2.0.dev20240731.dist-info → ai_edge_torch_nightly-0.2.0.dev20240801.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.2.0.dev20240731
+Version: 0.2.0.dev20240801
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.2.0.dev20240731.dist-info → ai_edge_torch_nightly-0.2.0.dev20240801.dist-info}/RECORD

@@ -54,11 +54,11 @@ ai_edge_torch/generative/examples/phi2/phi2.py,sha256=KjfTrD2OBzOfq83-XvJ6ZhmXLu
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=P-cUUQaQKGKV2p-7hvLJ--RpCIA7gk8WCDRgg0pNtd0,4331
-ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=
-ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=
+ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=UmKqiUbgte8PR-uslaYln-Z_TNrVWgubq_2nSyy8lQ4,4997
+ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=o-FprmF_LSxte62p0Ud1wZGE9_sC_ClX9PKnDNfJR9E,13839
+ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=MEiVP1x8kDJkvYqimtVVZt_UCTxEjcSd208Lwp8qPvc,28734
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=mgbxkeFDMkNIGmnbcFTIFPu8EWKokghiviYIOB2lE3Q,3437
-ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=
+ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=MzK0wMrpOcMhU4wLFwZnmn3eLMy8BjU-mHC_85SKP70,8465
 ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
 ai_edge_torch/generative/examples/stable_diffusion/util.py,sha256=NFpOfA4KN0JpShm5QvuYbQYZ844NzexWD8nV3WjMOZM,2397
 ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py,sha256=uQWKzCD_49ackNFrt50H04dkDXxfAwUCtMWWQre5SVE,830
@@ -97,8 +97,8 @@ ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
 ai_edge_torch/generative/quantize/example.py,sha256=Oy-Ss1oKXMu5RVOGt8QiUwKtrHEfhbVjTXXjxPcOqDA,1536
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
 ai_edge_torch/generative/quantize/quant_recipe.py,sha256=Y8zahKw7b_h7ajPaJZVef4jG-MoqImRCpVSbFtV_i24,5139
-ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=
-ai_edge_torch/generative/quantize/quant_recipes.py,sha256=
+ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=5yCOwHTUA-SgWqP27pvCLPBj1z_AcjXCqyPwQFo15O8,2270
+ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
 ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=iTNPrlubmq9ia7C3zHl50J2YEMsc4o33GwL5tr5VkkE,5229
@@ -106,7 +106,7 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
 ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
 ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=qMR0r7Pr_t2bn-cyeA7Qw_Rl94H1NmFcqM2ua8gpDDw,4230
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=LsPTrLC1I4JW2GowTS3V9Eu257vLHr2Yj5f_qaFUX84,7589
-ai_edge_torch/generative/test/test_quantize.py,sha256=
+ai_edge_torch/generative/test/test_quantize.py,sha256=nHzhthe_zcXpdAC6ZyYSW_B-UYuvEHx-5cUMHXyG5Uc,5288
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/loader.py,sha256=NTaCrU2qmeJpqdAau13ZgyeOpwATqhZB68GY0LZjU6A,11438
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=zixjZryUaCSDKmfPkQvYwbPJhUyTmZ4AK_lWN8iFo68,33324
@@ -125,8 +125,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
-ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/LICENSE
-ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/METADATA
-ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/WHEEL
-ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/top_level.txt
-ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/RECORD
+ai_edge_torch_nightly-0.2.0.dev20240801.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.2.0.dev20240801.dist-info/METADATA,sha256=r_k99TD1FMN2U-8-xG1j24NbC6Ynph8lHBqXcY315BI,1889
+ai_edge_torch_nightly-0.2.0.dev20240801.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ai_edge_torch_nightly-0.2.0.dev20240801.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.2.0.dev20240801.dist-info/RECORD,,