ai-edge-torch-nightly 0.2.0.dev20240610__py3-none-any.whl → 0.2.0.dev20240617__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of ai-edge-torch-nightly has been flagged as potentially problematic.

Files changed (30)
  1. ai_edge_torch/convert/conversion_utils.py +17 -5
  2. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
  3. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
  4. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
  5. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
  6. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
  7. ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
  8. ai_edge_torch/generative/layers/attention.py +154 -26
  9. ai_edge_torch/generative/layers/model_config.py +4 -0
  10. ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
  11. ai_edge_torch/generative/layers/unet/builder.py +20 -2
  12. ai_edge_torch/generative/layers/unet/model_config.py +157 -5
  13. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  14. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +164 -0
  15. ai_edge_torch/generative/quantize/quant_attrs.py +2 -0
  16. ai_edge_torch/generative/quantize/quant_recipe.py +49 -4
  17. ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -2
  18. ai_edge_torch/generative/quantize/quant_recipes.py +3 -3
  19. ai_edge_torch/generative/quantize/supported_schemes.py +2 -1
  20. ai_edge_torch/generative/test/test_model_conversion.py +24 -0
  21. ai_edge_torch/generative/test/test_quantize.py +75 -20
  22. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
  23. ai_edge_torch/generative/utilities/t5_loader.py +33 -17
  24. ai_edge_torch/quantize/quant_config.py +11 -15
  25. {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/METADATA +1 -1
  26. {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/RECORD +29 -27
  27. ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
  28. {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/LICENSE +0 -0
  29. {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/WHEEL +0 -0
  30. {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/top_level.txt +0 -0
@@ -318,7 +318,7 @@ class ModelLoader:
     q_name = names.attn_query_proj.format(idx)
     k_name = names.attn_key_proj.format(idx)
     v_name = names.attn_value_proj.format(idx)
-    # model.encoder.transformer_blocks[0].atten_func.q.weight
+    # model.encoder.transformer_blocks[0].atten_func.q_projection.weight
     if fuse_attention:
       converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
           config,
@@ -334,18 +334,34 @@ class ModelLoader:
           state.pop(f"{v_name}.bias"),
       )
     else:
-      converted_state[f"{prefix}.atten_func.q.weight"] = state.pop(f"{q_name}.weight")
-      converted_state[f"{prefix}.atten_func.k.weight"] = state.pop(f"{k_name}.weight")
-      converted_state[f"{prefix}.atten_func.v.weight"] = state.pop(f"{v_name}.weight")
+      converted_state[f"{prefix}.atten_func.q_projection.weight"] = state.pop(
+          f"{q_name}.weight"
+      )
+      converted_state[f"{prefix}.atten_func.k_projection.weight"] = state.pop(
+          f"{k_name}.weight"
+      )
+      converted_state[f"{prefix}.atten_func.v_projection.weight"] = state.pop(
+          f"{v_name}.weight"
+      )
       if config.attn_config.qkv_use_bias:
-        converted_state[f"{prefix}.atten_func.q.bias"] = state.pop(f"{q_name}.bias")
-        converted_state[f"{prefix}.atten_func.k.bias"] = state.pop(f"{k_name}.bias")
-        converted_state[f"{prefix}.atten_func.v.bias"] = state.pop(f"{v_name}.bias")
+        converted_state[f"{prefix}.atten_func.q_projection.bias"] = state.pop(
+            f"{q_name}.bias"
+        )
+        converted_state[f"{prefix}.atten_func.k_projection.bias"] = state.pop(
+            f"{k_name}.bias"
+        )
+        converted_state[f"{prefix}.atten_func.v_projection.bias"] = state.pop(
+            f"{v_name}.bias"
+        )

     o_name = names.attn_output_proj.format(idx)
-    converted_state[f"{prefix}.atten_func.proj.weight"] = state.pop(f"{o_name}.weight")
+    converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
+        f"{o_name}.weight"
+    )
     if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.atten_func.proj.bias"] = state.pop(f"{o_name}.bias")
+      converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
+          f"{o_name}.bias"
+      )

   def _map_cross_attention(
       self,
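The mapping above follows one pattern throughout: each tensor is popped out of the source checkpoint's state dict and re-keyed under the layer's new attribute name, so anything left in `state` afterwards is, by construction, unmapped. A minimal, self-contained sketch of that pattern; the checkpoint keys and prefix below are made up for illustration, not taken from a real T5 checkpoint:

import torch

state = {
    "encoder.block.0.SelfAttention.q.weight": torch.zeros(4, 4),
    "encoder.block.0.SelfAttention.k.weight": torch.zeros(4, 4),
    "encoder.block.0.SelfAttention.v.weight": torch.zeros(4, 4),
}
converted_state = {}
prefix = "encoder.transformer_blocks.0"
src = "encoder.block.0.SelfAttention"

# pop() both moves the tensor and shrinks `state`; leftovers signal an
# incomplete mapping, which the loader treats as an error in strict mode.
for ours, theirs in (("q_projection", "q"), ("k_projection", "k"), ("v_projection", "v")):
  converted_state[f"{prefix}.atten_func.{ours}.weight"] = state.pop(f"{src}.{theirs}.weight")

assert not state  # every source tensor found a home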
@@ -383,32 +399,32 @@ class ModelLoader:
           state.pop(f"{v_name}.bias"),
       )
     else:
-      converted_state[f"{prefix}.cross_atten_func.q.weight"] = state.pop(
+      converted_state[f"{prefix}.cross_atten_func.q_projection.weight"] = state.pop(
           f"{q_name}.weight"
       )
-      converted_state[f"{prefix}.cross_atten_func.k.weight"] = state.pop(
+      converted_state[f"{prefix}.cross_atten_func.k_projection.weight"] = state.pop(
           f"{k_name}.weight"
       )
-      converted_state[f"{prefix}.cross_atten_func.v.weight"] = state.pop(
+      converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = state.pop(
           f"{v_name}.weight"
       )
       if config.attn_config.qkv_use_bias:
-        converted_state[f"{prefix}.cross_atten_func.q.bias"] = state.pop(
+        converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = state.pop(
             f"{q_name}.bias"
         )
-        converted_state[f"{prefix}.cross_atten_func.k.bias"] = state.pop(
+        converted_state[f"{prefix}.cross_atten_func.k_projection.bias"] = state.pop(
             f"{k_name}.bias"
         )
-        converted_state[f"{prefix}.cross_atten_func.v.bias"] = state.pop(
+        converted_state[f"{prefix}.cross_atten_func.v_projection.bias"] = state.pop(
             f"{v_name}.bias"
         )

     o_name = names.cross_attn_output_proj.format(idx)
-    converted_state[f"{prefix}.cross_atten_func.proj.weight"] = state.pop(
+    converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = state.pop(
         f"{o_name}.weight"
     )
     if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.cross_atten_func.proj.bias"] = state.pop(
+      converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = state.pop(
          f"{o_name}.bias"
      )

@@ -32,27 +32,26 @@ class QuantConfig:
     pt2e_quantizer: The instance of PT2EQuantizer used to quantize the model
       with PT2E quantization. This method of quantization is not applicable to
       models created with the Edge Generative API.
-    transformer_recipe: Quantization recipe to be applied on a model created
+    generative_recipe: Quantization recipe to be applied on a model created
       with the Edge Generative API.
   """

   pt2e_quantizer: pt2eq.PT2EQuantizer = None
-  transformer_recipe: quant_recipe.TransformerQuantRecipe = None
+  generative_recipe: quant_recipe.GenerativeQuantRecipe = None

   @enum.unique
   class _QuantizerMode(enum.Enum):
     NONE = enum.auto()
     PT2E_DYNAMIC = enum.auto()
     PT2E_STATIC = enum.auto()
-    TFLITE_DYNAMIC = enum.auto()
-    TFLITE_FP16 = enum.auto()
+    AI_EDGE_QUANTIZER = enum.auto()

   _quantizer_mode: _QuantizerMode = _QuantizerMode.NONE

   def __init__(
       self,
       pt2e_quantizer: Optional[pt2eq.PT2EQuantizer] = None,
-      transformer_recipe: Optional[quant_recipe.TransformerQuantRecipe] = None,
+      generative_recipe: Optional[quant_recipe.GenerativeQuantRecipe] = None,
   ):
     """Initializes some internal states based on selected quantization method.

@@ -61,8 +60,8 @@ class QuantConfig:
     is properly setup. Additionally sets up an utility enum _quantizer_mode to
     guide certain conversion processes.
     """
-    if pt2e_quantizer is not None and transformer_recipe is not None:
-      raise ValueError('Cannot set both pt2e_quantizer and transformer_recipe.')
+    if pt2e_quantizer is not None and generative_recipe is not None:
+      raise ValueError('Cannot set both pt2e_quantizer and generative_recipe.')
     elif pt2e_quantizer is not None:
       object.__setattr__(self, 'pt2e_quantizer', pt2e_quantizer)
       object.__setattr__(
@@ -74,12 +73,9 @@ class QuantConfig:
               else self._QuantizerMode.PT2E_STATIC
           ),
       )
-    elif transformer_recipe is not None:
-      transformer_recipe.verify()
-      object.__setattr__(self, 'transformer_recipe', transformer_recipe)
-      if self.transformer_recipe.default.mode == quant_attrs.Mode.DYNAMIC_RANGE:
-        object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.TFLITE_DYNAMIC)
-      elif self.transformer_recipe.default.weight_dtype == quant_attrs.Dtype.FP16:
-        object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.TFLITE_FP16)
+    elif generative_recipe is not None:
+      generative_recipe.verify()
+      object.__setattr__(self, 'generative_recipe', generative_recipe)
+      object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.AI_EDGE_QUANTIZER)
     else:
-      raise ValueError('Either pt2e_quantizer or transformer_recipe must be set.')
+      raise ValueError('Either pt2e_quantizer or generative_recipe must be set.')
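The effect of this refactor is easiest to see as a stripped-down version of the dispatch: the old recipe-level branching (dynamic-range vs fp16) disappears, and any generative recipe now routes to the single AI_EDGE_QUANTIZER mode. A standalone sketch of that contract follows; it is not the library code itself, and the helper name select_mode is invented for illustration:

import enum
from typing import Optional

class QuantizerMode(enum.Enum):
  NONE = enum.auto()
  PT2E_DYNAMIC = enum.auto()
  PT2E_STATIC = enum.auto()
  AI_EDGE_QUANTIZER = enum.auto()

def select_mode(pt2e_quantizer: Optional[object], generative_recipe: Optional[object]) -> QuantizerMode:
  if pt2e_quantizer is not None and generative_recipe is not None:
    raise ValueError('Cannot set both pt2e_quantizer and generative_recipe.')
  elif pt2e_quantizer is not None:
    # In the real class, dynamic vs static still depends on the quantizer's
    # global config; collapsed here for brevity.
    return QuantizerMode.PT2E_DYNAMIC
  elif generative_recipe is not None:
    # Every recipe now routes to the single AI_EDGE_QUANTIZER mode; the
    # per-recipe dynamic-range vs fp16 branching from the old code is gone.
    return QuantizerMode.AI_EDGE_QUANTIZER
  else:
    raise ValueError('Either pt2e_quantizer or generative_recipe must be set.')

print(select_mode(None, object()))  # QuantizerMode.AI_EDGE_QUANTIZER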
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.2.0.dev20240610
+Version: 0.2.0.dev20240617
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=FPMmuFU3pyMREtjB_san1fy_0PFtAsgA0VZfOYvDrb4,100
 ai_edge_torch/model.py,sha256=kmcgELjsYl8YzF8nUF6P7q4i8MWS-pLGpfsy-yTUXmE,4243
 ai_edge_torch/convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/convert/conversion.py,sha256=GN2Js232u_5Y118wg3qIfEoYewxbxLl3TpSnO6osi8c,4029
-ai_edge_torch/convert/conversion_utils.py,sha256=NpVm3Ms81_cIW5IYgGsr0BVganJJgBKWVBDe5h_ZaGE,11021
+ai_edge_torch/convert/conversion_utils.py,sha256=9BqCL38DErv1vEVGtT3BIJVhdwZjw2EQ-_m5UpvVVYE,11294
 ai_edge_torch/convert/converter.py,sha256=bjj5TV5_g4sGyuSh8ThEDydlNMqhkGSY4SzXK6vwhqI,6927
 ai_edge_torch/convert/fx_passes/__init__.py,sha256=EPs4PSIDLuRH5EBETi6deaOvaaf_Q4xD3_9NVcR7x8o,2810
 ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
@@ -15,11 +15,11 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=uXCHC23pWN-3JmDtAErWbSUnL8jjlQgUAy4gqtfDsQU,1560
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=RAgU31B98PQmXEIM3GOjgS0q9aRe2whJhGXpW2EjoqY,12438
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=tCx7J-WIFnxFCeRBtqJ159jWLgK9_9DCJrR4mkeBuYE,982
-ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=wr59GFss8fP8Vy--BaBj34Bto0N16gXxQj6OuTXH8cE,10030
+ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=cfY6RTWQTGXNoQxKHaDcBYR9QdkVQXOWjKhuxvglocw,10383
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=bItkXVaPA9THcFypAmqldpkLuD8WpOFmKlhVbBJJkPk,2076
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=B-zisphkH7aRCUOJNdwHnTA0fQXuDpN08q3Qjy5bL6E,715
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
-ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=FlNKt2EhIKnlVEeUWTiv5sz446YKU6Yy1H0Gd6VRgkU,6432
+ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=lklGxE1R32vsjFbhLLBDEFL4pfLi_iTgI9Ftb6Grezk,7156
 ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/convert/test/test_convert.py,sha256=2qPmmGqnfV_o1gfsSdjGq3-JR1b323ligiy5MdAv9NA,8021
 ai_edge_torch/convert/test/test_convert_composites.py,sha256=_Ojc-H6GOS5s8ek3_8eRBL_AiCs-k3srziPJ2R4Ulrg,7255
@@ -41,9 +41,9 @@ ai_edge_torch/generative/examples/phi2/phi2.py,sha256=PMhKC6JCAMYSj2F3UmWHWK4rTc
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=qU1wVEcn_biwCuDguZljhlLGzpLIqgqC31Dh_lXquQc,3720
-ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=NmgDo5uAefrhMUbYku0TKHlqzO0NVWI_M1ue8tddQR4,4024
-ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=meW8t-3BDdjFs5vCAf76cn6lGx49a_GcEvnVa9R5if4,11106
-ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=_gEeUxa9Xyd3iLb_fyeUefHKuELVDorDlQs8e7wdXKg,7878
+ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=wVEjsKd5JCIiYf5GF19rOXs2NHscZh0D69mxaS4f0Sk,4182
+ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=RgxedILk7iNMb0mhE4VkCs6d7BnFzYhR3vspUkC0-1o,11425
+ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=sRevfsmCun7zbceJbOstLKNUsLwzQDsGm7Mi2JmlREg,26021
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=mgbxkeFDMkNIGmnbcFTIFPu8EWKokghiviYIOB2lE3Q,3437
 ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
 ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
@@ -56,7 +56,7 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5i
 ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
 ai_edge_torch/generative/examples/t5/t5.py,sha256=L6YrVzUEzP-Imb8W28LdukFGrx1aWSzz1kyYK_9RFZM,21087
-ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rRgwCEdVtzcJEaGbbBjw8HxCxrCX3pXA5nelawdYiME,9036
+ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rkMwi-NJGBXHm5S57Rsj1LbcoVdyRkS7GmIBuU6F_2E,8274
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=CUXsQ_IU96NaCg9jyfeKI0Zz2iWDkJUsPJyPR1Pgz7I,3813
 ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=zwCmCnhr-vhBwHqv9i7xMasdBGVNqAGxZvWsncsJn58,5543
@@ -65,34 +65,36 @@ ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TI
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=E4I5OlC4zyl5cxiiu7uTED-zcwYRu210lP1zuT3xLBE,2566
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=IFRLPG9wz_aLl_zV_6CETCjSM03ukA6bZqqyDLVACuw,5651
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=Z8gXHYs6h8gaRiYAdvYUbHzg_2EmqfxiChsf_SYraAc,7902
+ai_edge_torch/generative/layers/attention.py,sha256=AW0Qo3uOIe6p1rJNJ6zR_r4fqL2y-6QJHh0yUd-5Yb0,11966
 ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
 ai_edge_torch/generative/layers/builder.py,sha256=jAyrR5hsSI0aimKZumyvxdJ1GovERIfsK0g-dezX2gs,4163
 ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
 ai_edge_torch/generative/layers/kv_cache.py,sha256=4uiZLO3om5G3--kT04Jt0esEYznbkJ7QLzSHfb8mjc4,3090
-ai_edge_torch/generative/layers/model_config.py,sha256=g_XJXcQOCkE-mt58fSH4-T4GY_uLeMilg6mxwDMCfz4,4557
+ai_edge_torch/generative/layers/model_config.py,sha256=aQLtOPdGpehfnb4aGO-iILLAsRU5t7j6opyezPEUY_w,4673
 ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=7mHyJYq9lq5zVYp4mEz-R8Az3FFngi711YC20KP6ED8,10066
-ai_edge_torch/generative/layers/unet/builder.py,sha256=iH0_nuY9TF2ap5h1JbGNCOonPTfrXQHcF8U0slrIREM,1210
-ai_edge_torch/generative/layers/unet/model_config.py,sha256=sbtbDEHmMV9GLKngwjsNvqm8wovLxnlidkQbXdXkXKs,4060
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=H45wsXA6iJi_Mjd66NiQrh7i1fx05r9o_FI-fSnhVts,26538
+ai_edge_torch/generative/layers/unet/builder.py,sha256=NmJiZ2-e1wbv9jnvI3VCyUJlONV5ZAOz-RTc7ipAZ5U,1872
+ai_edge_torch/generative/layers/unet/model_config.py,sha256=FrIO-CR8aRIV2i8aFqom_4S7WCEDLMyYwo6U0oFyn7A,9097
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/quantize/example.py,sha256=t-YwyKSPAG-OZC1DfH-0vfie2RHHpTSQjxUY-tmhu5g,1543
-ai_edge_torch/generative/quantize/quant_attrs.py,sha256=ffBALrrbrfiG_mrOr-f3B1Gc6PlAma9gtvVnfP7SDzI,1862
-ai_edge_torch/generative/quantize/quant_recipe.py,sha256=BOk4E0FW-_YD8Y-oPVmIDsgXx_bPtvzsP_V1av5DvgU,3327
-ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=9ktL7fT8C5j1dnY_7fkiFL4oWNLVs1dMWXkS_EuyA3Y,1913
-ai_edge_torch/generative/quantize/quant_recipes.py,sha256=CRA2ENevS-3usHqidWDe2wrf_epILE_7Hx-XfZQ9buk,1798
-ai_edge_torch/generative/quantize/supported_schemes.py,sha256=OQ4ghQXknA1PPjuY-xBgAmOpaIBgYFM8F2YAIot06hE,1345
+ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
+ai_edge_torch/generative/quantize/quant_recipe.py,sha256=Y8zahKw7b_h7ajPaJZVef4jG-MoqImRCpVSbFtV_i24,5139
+ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=-vd6Qp0BdXJVKg4f0_hhwbKOi3QPIAPVqyXnJ-ZnISQ,1915
+ai_edge_torch/generative/quantize/quant_recipes.py,sha256=9ItD70jQRXMEhWod-nUfEeoWGJUUu6V9YOffF07VU9g,1795
+ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
+ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=qUB4f2DoB14dLkNPWf6TZodpT81mfAJeWM-lCAmkuHY,5735
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=i_SAW-hD8SaHuopMZI9IuXXDFn5uSTJa1nKZhaC3dAQ,6811
-ai_edge_torch/generative/test/test_quantize.py,sha256=f70sH1ZFzdCwYj0MG-eg54WOC4LasR0D8CTUYpjxZYM,3728
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=LsPTrLC1I4JW2GowTS3V9Eu257vLHr2Yj5f_qaFUX84,7589
+ai_edge_torch/generative/test/test_quantize.py,sha256=IjCbCPWzIgXk3s7y7SJsg2usIxhOqs3PuhFvEYR4Sdw,5388
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/autoencoder_loader.py,sha256=G2Nosy33JzkjGALPR4JjvffdFX1JWOj2zjbbuaDJEgg,10065
 ai_edge_torch/generative/utilities/loader.py,sha256=Hs92478j1g4jQGvbdP1aWvOy907HjwqQZE-NFy6HELo,11326
-ai_edge_torch/generative/utilities/t5_loader.py,sha256=guDTv-12UUvJGl4eDvvZX3t4rRKewfXO8SpcYXM6gbc,16156
+ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=7ChqrnthD7I-Be6vkRvYTRhbGQ3tqMbikLpjY5HpSzE,30890
+ai_edge_torch/generative/utilities/t5_loader.py,sha256=h1FQzt4x8wiQMX4NzYNVIaJGLr_YKH0sojBvy0amexM,16503
 ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=YV2YKBkh7y7j7sd7EA81vf_1hUKUvTRiy1pfqZustXc,1539
@@ -103,12 +105,12 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=aUAPKnH4_Jxpp
 ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
 ai_edge_torch/quantize/pt2e_quantizer.py,sha256=ye1f5vAZ0Vr4RWAtfrgU1o3JLs03Sa4inHRq3YxJDGo,15602
 ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=yjzKoptnfEeW_sN7sODUfj3nCtUMXVzq3vHKxblsd5Y,36046
-ai_edge_torch/quantize/quant_config.py,sha256=ExThdTXqnWmGC3-F6sdXbXr8nYzkEe_qCziCfhsoMPA,3435
+ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDdN5XtvHwjc,3148
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
-ai_edge_torch_nightly-0.2.0.dev20240610.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.2.0.dev20240610.dist-info/METADATA,sha256=6hL5PV3S56VU2l6xqS-YrmzMZeajtXsikIdR7kDYcWE,1748
-ai_edge_torch_nightly-0.2.0.dev20240610.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-ai_edge_torch_nightly-0.2.0.dev20240610.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.2.0.dev20240610.dist-info/RECORD,,
+ai_edge_torch_nightly-0.2.0.dev20240617.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.2.0.dev20240617.dist-info/METADATA,sha256=Z9rUO2CabVbBpydpRk8OxNlwK4yznGCb2QHGlJhqRsM,1748
+ai_edge_torch_nightly-0.2.0.dev20240617.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ai_edge_torch_nightly-0.2.0.dev20240617.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.2.0.dev20240617.dist-info/RECORD,,
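Each RECORD line follows the standard wheel format path,sha256=<digest>,<size>, where the digest is the unpadded urlsafe-base64 SHA-256 of the file (PEP 376/PEP 427), so any entry above can be checked against an installed copy:

import base64
import hashlib
from pathlib import Path

def record_digest(path: str) -> str:
  # Unpadded urlsafe base64 of the file's sha256, as used in wheel RECORDs.
  digest = hashlib.sha256(Path(path).read_bytes()).digest()
  return base64.urlsafe_b64encode(digest).rstrip(b"=").decode()

# For this release, record_digest("ai_edge_torch/quantize/quant_config.py"),
# run inside site-packages, should return
# "eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDdN5XtvHwjc".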
@@ -1,298 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# Common utility functions for data loading etc.
-from dataclasses import dataclass
-from typing import Dict, List, Tuple
-
-import torch
-
-import ai_edge_torch.generative.layers.model_config as layers_config
-import ai_edge_torch.generative.layers.unet.model_config as unet_config
-import ai_edge_torch.generative.utilities.loader as loader
-
-
-@dataclass
-class ResidualBlockTensorNames:
-  norm_1: str = None
-  conv_1: str = None
-  norm_2: str = None
-  conv_2: str = None
-  residual_layer: str = None
-
-
-@dataclass
-class AttnetionBlockTensorNames:
-  norm: str = None
-  fused_qkv_proj: str = None
-  output_proj: str = None
-
-
-@dataclass
-class MidBlockTensorNames:
-  residual_block_tensor_names: List[ResidualBlockTensorNames]
-  attention_block_tensor_names: List[AttnetionBlockTensorNames]
-
-
-@dataclass
-class UpDecoderBlockTensorNames:
-  residual_block_tensor_names: List[ResidualBlockTensorNames]
-  upsample_conv: str = None
-
-
-def _map_to_converted_state(
-    state: Dict[str, torch.Tensor],
-    state_param: str,
-    converted_state: Dict[str, torch.Tensor],
-    converted_state_param: str,
-):
-  converted_state[f"{converted_state_param}.weight"] = state.pop(
-      f"{state_param}.weight"
-  )
-  if f"{state_param}.bias" in state:
-    converted_state[f"{converted_state_param}.bias"] = state.pop(f"{state_param}.bias")
-
-
-class AutoEncoderModelLoader(loader.ModelLoader):
-
-  @dataclass
-  class TensorNames:
-    quant_conv: str = None
-    post_quant_conv: str = None
-    conv_in: str = None
-    conv_out: str = None
-    final_norm: str = None
-    mid_block_tensor_names: MidBlockTensorNames = None
-    up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
-
-  def __init__(self, file_name: str, names: TensorNames):
-    """AutoEncoderModelLoader constructor. Can be used to load encoder and decoder models.
-
-    Args:
-        file_name (str): Path to the checkpoint. Can be a directory or an
-          exact file.
-        names (TensorNames): An instance of `TensorNames` to determine mappings.
-    """
-    self._file_name = file_name
-    self._names = names
-    self._loader = self._get_loader()
-
-  def load(
-      self, model: torch.nn.Module, strict: bool = True
-  ) -> Tuple[List[str], List[str]]:
-    """Load the model from the checkpoint.
-
-    Args:
-        model (torch.nn.Module): The pytorch model that needs to be loaded.
-        strict (bool, optional): Whether the converted keys are strictly
-          matched. Defaults to True.
-
-    Returns:
-        missing_keys (List[str]): a list of str containing the missing keys.
-        unexpected_keys (List[str]): a list of str containing the unexpected keys.
-
-    Raises:
-        ValueError: If conversion results in unmapped tensors and strict mode is
-          enabled.
-    """
-    state = self._loader(self._file_name)
-    converted_state = dict()
-    if self._names.quant_conv is not None:
-      _map_to_converted_state(
-          state, self._names.quant_conv, converted_state, "quant_conv"
-      )
-    if self._names.post_quant_conv is not None:
-      _map_to_converted_state(
-          state, self._names.post_quant_conv, converted_state, "post_quant_conv"
-      )
-    if self._names.conv_in is not None:
-      _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
-    if self._names.conv_out is not None:
-      _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
-    if self._names.final_norm is not None:
-      _map_to_converted_state(
-          state, self._names.final_norm, converted_state, "final_norm"
-      )
-    self._map_mid_block(
-        state,
-        converted_state,
-        model.config.mid_block_config,
-        self._names.mid_block_tensor_names,
-    )
-
-    reversed_block_out_channels = list(reversed(model.config.block_out_channels))
-    block_out_channels = reversed_block_out_channels[0]
-    for i, out_channels in enumerate(reversed_block_out_channels):
-      prev_output_channel = block_out_channels
-      block_out_channels = out_channels
-      not_final_block = i < len(reversed_block_out_channels) - 1
-      self._map_up_decoder_block(
-          state,
-          converted_state,
-          f"up_decoder_blocks.{i}",
-          unet_config.UpDecoderBlock2DConfig(
-              in_channels=prev_output_channel,
-              out_channels=block_out_channels,
-              normalization_config=model.config.normalization_config,
-              activation_type=model.config.activation_type,
-              num_layers=model.config.layers_per_block,
-              add_upsample=not_final_block,
-              upsample_conv=True,
-          ),
-          self._names.up_decoder_blocks_tensor_names[i],
-      )
-    if strict and state:
-      raise ValueError(
-          f"Failed to map all tensor. Remaing tensor are: {list(state.keys())}"
-      )
-    return model.load_state_dict(converted_state, strict=strict)
-
-  def _map_residual_block(
-      self,
-      state: Dict[str, torch.Tensor],
-      converted_state: Dict[str, torch.Tensor],
-      tensor_names: ResidualBlockTensorNames,
-      converted_state_param_prefix: str,
-      config: unet_config.ResidualBlock2DConfig,
-  ):
-    _map_to_converted_state(
-        state,
-        tensor_names.norm_1,
-        converted_state,
-        f"{converted_state_param_prefix}.norm_1",
-    )
-    _map_to_converted_state(
-        state,
-        tensor_names.conv_1,
-        converted_state,
-        f"{converted_state_param_prefix}.conv_1",
-    )
-    _map_to_converted_state(
-        state,
-        tensor_names.norm_2,
-        converted_state,
-        f"{converted_state_param_prefix}.norm_2",
-    )
-    _map_to_converted_state(
-        state,
-        tensor_names.conv_2,
-        converted_state,
-        f"{converted_state_param_prefix}.conv_2",
-    )
-    if config.in_channels != config.out_channels:
-      _map_to_converted_state(
-          state,
-          tensor_names.residual_layer,
-          converted_state,
-          f"{converted_state_param_prefix}.residual_layer",
-      )
-
-  def _map_attention_block(
-      self,
-      state: Dict[str, torch.Tensor],
-      converted_state: Dict[str, torch.Tensor],
-      tensor_names: AttnetionBlockTensorNames,
-      converted_state_param_prefix: str,
-      config: unet_config.AttentionBlock2DConfig,
-  ):
-    if config.normalization_config.type != layers_config.NormalizationType.NONE:
-      _map_to_converted_state(
-          state,
-          tensor_names.norm,
-          converted_state,
-          f"{converted_state_param_prefix}.norm",
-      )
-    attention_layer_prefix = f"{converted_state_param_prefix}.attention"
-    _map_to_converted_state(
-        state,
-        tensor_names.fused_qkv_proj,
-        converted_state,
-        f"{attention_layer_prefix}.qkv_projection",
-    )
-    _map_to_converted_state(
-        state,
-        tensor_names.output_proj,
-        converted_state,
-        f"{attention_layer_prefix}.output_projection",
-    )
-
-  def _map_mid_block(
-      self,
-      state: Dict[str, torch.Tensor],
-      converted_state: Dict[str, torch.Tensor],
-      config: unet_config.MidBlock2DConfig,
-      tensor_names: MidBlockTensorNames,
-  ):
-    converted_state_param_prefix = "mid_block"
-    residual_block_config = unet_config.ResidualBlock2DConfig(
-        in_channels=config.in_channels,
-        out_channels=config.in_channels,
-        time_embedding_channels=config.time_embedding_channels,
-        normalization_config=config.normalization_config,
-        activation_type=config.activation_type,
-    )
-    self._map_residual_block(
-        state,
-        converted_state,
-        tensor_names.residual_block_tensor_names[0],
-        f"{converted_state_param_prefix}.resnets.0",
-        residual_block_config,
-    )
-    for i in range(config.num_layers):
-      if config.attention_block_config:
-        self._map_attention_block(
-            state,
-            converted_state,
-            tensor_names.attention_block_tensor_names[i],
-            f"{converted_state_param_prefix}.attentions.{i}",
-            config.attention_block_config,
-        )
-      self._map_residual_block(
-          state,
-          converted_state,
-          tensor_names.residual_block_tensor_names[i + 1],
-          f"{converted_state_param_prefix}.resnets.{i+1}",
-          residual_block_config,
-      )
-
-  def _map_up_decoder_block(
-      self,
-      state: Dict[str, torch.Tensor],
-      converted_state: Dict[str, torch.Tensor],
-      converted_state_param_prefix: str,
-      config: unet_config.UpDecoderBlock2DConfig,
-      tensor_names: UpDecoderBlockTensorNames,
-  ):
-    for i in range(config.num_layers):
-      input_channels = config.in_channels if i == 0 else config.out_channels
-      self._map_residual_block(
-          state,
-          converted_state,
-          tensor_names.residual_block_tensor_names[i],
-          f"{converted_state_param_prefix}.resnets.{i}",
-          unet_config.ResidualBlock2DConfig(
-              in_channels=input_channels,
-              out_channels=config.out_channels,
-              time_embedding_channels=config.time_embedding_channels,
-              normalization_config=config.normalization_config,
-              activation_type=config.activation_type,
-          ),
-      )
-    if config.add_upsample and config.upsample_conv:
-      _map_to_converted_state(
-          state,
-          tensor_names.upsample_conv,
-          converted_state,
-          f"{converted_state_param_prefix}.upsample_conv",
-      )
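The deleted module's core piece is the _map_to_converted_state helper, apparently superseded by the new stable_diffusion_loader.py added in this release. A toy run of the helper as it stood, to show what it did: move <name>.weight unconditionally and <name>.bias only when present. The checkpoint keys below are invented for illustration:

from typing import Dict

import torch

def _map_to_converted_state(
    state: Dict[str, torch.Tensor],
    state_param: str,
    converted_state: Dict[str, torch.Tensor],
    converted_state_param: str,
):
  # Weight is mandatory; bias is copied only if the source layer had one.
  converted_state[f"{converted_state_param}.weight"] = state.pop(f"{state_param}.weight")
  if f"{state_param}.bias" in state:
    converted_state[f"{converted_state_param}.bias"] = state.pop(f"{state_param}.bias")

state = {
    "decoder.conv_in.weight": torch.zeros(3, 3, 1, 1),
    "decoder.conv_in.bias": torch.zeros(3),
}
converted_state: Dict[str, torch.Tensor] = {}
_map_to_converted_state(state, "decoder.conv_in", converted_state, "conv_in")
assert sorted(converted_state) == ["conv_in.bias", "conv_in.weight"] and not state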