ai-edge-torch-nightly 0.2.0.dev20240731__py3-none-any.whl → 0.2.0.dev20240801__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
@@ -26,6 +26,7 @@ import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
 import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
 from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
+from ai_edge_torch.generative.quantize import quant_recipes
 import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
 
 arg_parser = argparse.ArgumentParser()
@@ -60,6 +61,7 @@ def convert_stable_diffusion_to_tflite(
     decoder_ckpt_path: str,
     image_height: int = 512,
     image_width: int = 512,
+    quantize: bool = True,
 ):
 
   clip_model = clip.CLIP(clip.get_model_config())
@@ -105,15 +107,17 @@ def convert_stable_diffusion_to_tflite(
   if not os.path.exists(output_dir):
     Path(output_dir).mkdir(parents=True, exist_ok=True)
 
+  quant_config = quant_recipes.full_int8_weight_only_recipe() if quantize else None
+
   # TODO(yichunk): convert to multi signature tflite model.
   # CLIP text encoder
-  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
-      f'{output_dir}/clip.tflite'
-  )
+  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert(
+      quant_config=quant_config
+  ).export(f'{output_dir}/clip.tflite')
 
   # TODO(yichunk): enable image encoder conversion
   # Image encoder
-  # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
+  # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert(quant_config=quant_config).export(
   #     f'{output_dir}/encoder.tflite'
   # )
 
@@ -122,12 +126,12 @@ def convert_stable_diffusion_to_tflite(
       'diffusion',
       diffusion_model,
       (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
-  ).convert().export(f'{output_dir}/diffusion.tflite')
+  ).convert(quant_config=quant_config).export(f'{output_dir}/diffusion.tflite')
 
   # Image decoder
-  ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert().export(
-      f'{output_dir}/decoder.tflite'
-  )
+  ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert(
+      quant_config=quant_config
+  ).export(f'{output_dir}/decoder.tflite')
 
 
 if __name__ == '__main__':
@@ -139,4 +143,5 @@ if __name__ == '__main__':
      decoder_ckpt_path=args.decoder_ckpt,
      image_height=512,
      image_width=512,
+     quantize=True,
  )
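
The net effect of these hunks: a single quant_config, built once from the new recipe, is threaded through every .convert() call, so the quantize flag toggles the whole export between float and int8 weight-only. A minimal sketch of that pattern, with a hypothetical TinyEncoder standing in for the real CLIP/diffusion/decoder models; every API call mirrors ones that appear in this diff:

import torch
import torch.nn as nn

import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes


class TinyEncoder(nn.Module):
  """Hypothetical stand-in for the real CLIP/diffusion/decoder models."""

  def __init__(self):
    super().__init__()
    self.proj = nn.Linear(16, 16)

  def forward(self, x):
    return self.proj(x)


quantize = True  # mirrors the new `quantize: bool = True` parameter
quant_config = quant_recipes.full_int8_weight_only_recipe() if quantize else None

# Same signature -> convert(quant_config=...) -> export chain as the diff.
ai_edge_torch.signature(
    'encode', TinyEncoder().eval(), (torch.randn(1, 16),)
).convert(quant_config=quant_config).export('/tmp/tiny_encoder.tflite')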

ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -293,6 +293,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
           qkv_fused_interleaved=False,
           rotary_percentage=0.0,
       ),
+      enable_hlfb=False,
   )
 
   mid_block_config = unet_cfg.MidBlock2DConfig(

ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -294,6 +294,7 @@ class Diffusion(nn.Module):
               attention_batch_size=config.transformer_batch_size,
               normalization_config=config.transformer_norm_config,
               attention_config=attention_config,
+              enable_hlfb=False,
           ),
           cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
               query_dim=output_channel,
@@ -301,6 +302,7 @@ class Diffusion(nn.Module):
               attention_batch_size=config.transformer_batch_size,
               normalization_config=config.transformer_norm_config,
               attention_config=attention_config,
+              enable_hlfb=False,
           ),
           pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
           feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
@@ -354,6 +356,7 @@ class Diffusion(nn.Module):
               attention_batch_size=config.transformer_batch_size,
               normalization_config=config.transformer_norm_config,
               attention_config=attention_config,
+              enable_hlfb=False,
           ),
           cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
               query_dim=mid_block_channels,
@@ -361,6 +364,7 @@ class Diffusion(nn.Module):
               attention_batch_size=config.transformer_batch_size,
               normalization_config=config.transformer_norm_config,
               attention_config=attention_config,
+              enable_hlfb=False,
           ),
           pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
           feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
@@ -415,6 +419,7 @@ class Diffusion(nn.Module):
               attention_batch_size=config.transformer_batch_size,
               normalization_config=config.transformer_norm_config,
               attention_config=attention_config,
+              enable_hlfb=False,
           ),
           cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
               query_dim=output_channel,
@@ -422,6 +427,7 @@ class Diffusion(nn.Module):
               attention_batch_size=config.transformer_batch_size,
               normalization_config=config.transformer_norm_config,
               attention_config=attention_config,
+              enable_hlfb=False,
           ),
           pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
           feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(

ai_edge_torch/generative/examples/stable_diffusion/pipeline.py
@@ -65,6 +65,12 @@ arg_parser.add_argument(
     choices=['k_euler', 'k_euler_ancestral', 'k_lms'],
     help='A sampler to be used to denoise the encoded image latents. Can be one of `k_lms`, `k_euler`, or `k_euler_ancestral`.',
 )
+arg_parser.add_argument(
+    '--seed',
+    default=None,
+    type=int,
+    help='A seed to make generation deterministic. A random number is used if unspecified.',
+)
 
 
 class StableDiffusion:
@@ -219,4 +225,5 @@ if __name__ == '__main__':
      output_path=args.output_path,
      sampler=args.sampler,
      n_inference_steps=args.n_inference_steps,
+     seed=args.seed,
  )
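
The diff shows the new --seed flag being parsed and forwarded as seed=args.seed, but not the code that consumes it. A plausible sketch of the downstream behavior, assuming the usual pattern of seeding the latent-noise generator; run_pipeline and the latent shape here are illustrative, not taken from the source:

from typing import Optional

import numpy as np


def run_pipeline(prompt: str, n_inference_steps: int, seed: Optional[int] = None):
  """Illustrative only: how a seed makes latent sampling deterministic."""
  if seed is None:
    # Matches the help text: a random number is used if unspecified.
    seed = int(np.random.randint(0, 2**31))
  rng = np.random.default_rng(seed)
  # The initial latent noise is the stochastic input to the sampler;
  # fixing it (plus prompt and step count) fixes the output image.
  initial_latents = rng.standard_normal((1, 4, 64, 64)).astype(np.float32)
  # ... tokenize prompt, run the diffusion model for n_inference_steps,
  # decode the final latents to an image ...
  return initial_latents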

ai_edge_torch/generative/quantize/quant_recipe_utils.py
@@ -41,6 +41,16 @@ def create_layer_quant_int8_dynamic() -> quant_recipe.LayerQuantRecipe:
   )
 
 
+def create_layer_quant_int8_weight_only() -> quant_recipe.LayerQuantRecipe:
+  return quant_recipe.LayerQuantRecipe(
+      activation_dtype=quant_attrs.Dtype.FP32,
+      weight_dtype=quant_attrs.Dtype.INT8,
+      mode=quant_attrs.Mode.WEIGHT_ONLY,
+      algorithm=quant_attrs.Algorithm.MIN_MAX,
+      granularity=quant_attrs.Granularity.CHANNELWISE,
+  )
+
+
 def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
   return quant_recipe.LayerQuantRecipe(
       activation_dtype=quant_attrs.Dtype.FP32,

ai_edge_torch/generative/quantize/quant_recipes.py
@@ -40,6 +40,14 @@ def full_int8_dynamic_recipe() -> quant_config.QuantConfig:
   )
 
 
+def full_int8_weight_only_recipe() -> quant_config.QuantConfig:
+  return quant_config.QuantConfig(
+      generative_recipe=quant_recipe.GenerativeQuantRecipe(
+          default=quant_recipe_utils.create_layer_quant_int8_weight_only(),
+      )
+  )
+
+
 def full_fp16_recipe() -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
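
For reference, the new one-liner recipe is plain composition of the pieces added in quant_recipe_utils.py. The following is equivalent to calling quant_recipes.full_int8_weight_only_recipe(); every constructor and module path appears in this diff or in the RECORD below:

from ai_edge_torch.generative.quantize import quant_recipe
from ai_edge_torch.generative.quantize import quant_recipe_utils
from ai_edge_torch.quantize import quant_config

# FP32 activations with per-channel, min/max int8 weights, applied as the
# default recipe for every layer of the generative model.
cfg = quant_config.QuantConfig(
    generative_recipe=quant_recipe.GenerativeQuantRecipe(
        default=quant_recipe_utils.create_layer_quant_int8_weight_only(),
    )
)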

ai_edge_torch/generative/test/test_quantize.py
@@ -111,6 +111,7 @@ class TestQuantizeConvert(unittest.TestCase):
         [
             (quant_recipes.full_fp16_recipe()),
             (quant_recipes.full_int8_dynamic_recipe()),
+            (quant_recipes.full_int8_weight_only_recipe()),
             (_attention_int8_dynamic_recipe()),
             (_feedforward_int8_dynamic_recipe()),
         ]
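
The test change simply appends the new recipe to an existing parameterized list. A hedged sketch of what such a parameterized quantize-and-convert test could look like, assuming the parameterized package and a toy model; the real internals of test_quantize.py are not shown in this diff:

import unittest

import torch
import torch.nn as nn
from parameterized import parameterized

import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes


class ToyModel(nn.Module):
  """Hypothetical stand-in; the real test uses a generative test model."""

  def forward(self, x):
    return torch.nn.functional.relu(x)


class TestQuantizeConvertSketch(unittest.TestCase):

  @parameterized.expand([
      (quant_recipes.full_fp16_recipe(),),
      (quant_recipes.full_int8_dynamic_recipe(),),
      (quant_recipes.full_int8_weight_only_recipe(),),  # new in this diff
  ])
  def test_convert_with_recipe(self, quant_config):
    edge_model = ai_edge_torch.convert(
        ToyModel().eval(), (torch.randn(1, 4),), quant_config=quant_config
    )
    self.assertIsNotNone(edge_model)


if __name__ == '__main__':
  unittest.main()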

ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.2.0.dev20240731
+Version: 0.2.0.dev20240801
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/RECORD
@@ -54,11 +54,11 @@ ai_edge_torch/generative/examples/phi2/phi2.py,sha256=KjfTrD2OBzOfq83-XvJ6ZhmXLu
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=P-cUUQaQKGKV2p-7hvLJ--RpCIA7gk8WCDRgg0pNtd0,4331
-ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=XwV1z7cVkQ947k_ERftEeL8n0NUFCJAltLtqDVfzYGI,4704
-ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=xHcmOZaW7hoWlEEEqtB4FWoHMw5AsGHPHXMNiXEfviY,13814
-ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=G-MgiEM_PpegNMePBPuNQDeUfjk42EYrVZAyJHC54AY,28468
+ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=UmKqiUbgte8PR-uslaYln-Z_TNrVWgubq_2nSyy8lQ4,4997
+ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=o-FprmF_LSxte62p0Ud1wZGE9_sC_ClX9PKnDNfJR9E,13839
+ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=MEiVP1x8kDJkvYqimtVVZt_UCTxEjcSd208Lwp8qPvc,28734
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=mgbxkeFDMkNIGmnbcFTIFPu8EWKokghiviYIOB2lE3Q,3437
-ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=AopJ-KE74lzq4QJUP_hYeiXvGth7uWv7nNKqkhtcoF8,8277
+ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=MzK0wMrpOcMhU4wLFwZnmn3eLMy8BjU-mHC_85SKP70,8465
 ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
 ai_edge_torch/generative/examples/stable_diffusion/util.py,sha256=NFpOfA4KN0JpShm5QvuYbQYZ844NzexWD8nV3WjMOZM,2397
 ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py,sha256=uQWKzCD_49ackNFrt50H04dkDXxfAwUCtMWWQre5SVE,830
@@ -97,8 +97,8 @@ ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
 ai_edge_torch/generative/quantize/example.py,sha256=Oy-Ss1oKXMu5RVOGt8QiUwKtrHEfhbVjTXXjxPcOqDA,1536
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
 ai_edge_torch/generative/quantize/quant_recipe.py,sha256=Y8zahKw7b_h7ajPaJZVef4jG-MoqImRCpVSbFtV_i24,5139
-ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=-vd6Qp0BdXJVKg4f0_hhwbKOi3QPIAPVqyXnJ-ZnISQ,1915
-ai_edge_torch/generative/quantize/quant_recipes.py,sha256=4OdKES9BhofzFoHut4qPVh-3ndVL9fu-BNOEEZc_2xE,1781
+ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=5yCOwHTUA-SgWqP27pvCLPBj1z_AcjXCqyPwQFo15O8,2270
+ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
 ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=iTNPrlubmq9ia7C3zHl50J2YEMsc4o33GwL5tr5VkkE,5229
@@ -106,7 +106,7 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
 ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
 ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=qMR0r7Pr_t2bn-cyeA7Qw_Rl94H1NmFcqM2ua8gpDDw,4230
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=LsPTrLC1I4JW2GowTS3V9Eu257vLHr2Yj5f_qaFUX84,7589
-ai_edge_torch/generative/test/test_quantize.py,sha256=QbF7LC9olJFGXqlAVGciac7xXc4rDtCSr71tTIYuqPk,5230
+ai_edge_torch/generative/test/test_quantize.py,sha256=nHzhthe_zcXpdAC6ZyYSW_B-UYuvEHx-5cUMHXyG5Uc,5288
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/loader.py,sha256=NTaCrU2qmeJpqdAau13ZgyeOpwATqhZB68GY0LZjU6A,11438
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=zixjZryUaCSDKmfPkQvYwbPJhUyTmZ4AK_lWN8iFo68,33324
@@ -125,8 +125,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
-ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/METADATA,sha256=B2Nf7g2PWOU-bYTAByfDNV_FAKy3ah88O-Plsk-uW_M,1889
-ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.2.0.dev20240731.dist-info/RECORD,,
+ai_edge_torch_nightly-0.2.0.dev20240801.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.2.0.dev20240801.dist-info/METADATA,sha256=r_k99TD1FMN2U-8-xG1j24NbC6Ynph8lHBqXcY315BI,1889
+ai_edge_torch_nightly-0.2.0.dev20240801.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ai_edge_torch_nightly-0.2.0.dev20240801.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.2.0.dev20240801.dist-info/RECORD,,