ai-edge-torch-nightly 0.2.0.dev20240806__py3-none-any.whl → 0.2.0.dev20240808__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of ai-edge-torch-nightly has been flagged as a potentially problematic release.

Files changed (104)
  1. ai_edge_torch/__init__.py +5 -5
  2. ai_edge_torch/{convert → _convert}/conversion.py +40 -50
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +83 -43
  5. ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
  8. ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
  9. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
  10. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
  19. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  20. ai_edge_torch/_convert/signature.py +100 -0
  21. ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
  22. ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
  23. ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
  24. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
  25. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
  26. ai_edge_torch/config.py +24 -0
  27. ai_edge_torch/conftest.py +20 -0
  28. ai_edge_torch/debug/culprit.py +22 -22
  29. ai_edge_torch/debug/test/test_culprit.py +4 -3
  30. ai_edge_torch/debug/test/test_search_model.py +5 -5
  31. ai_edge_torch/debug/utils.py +11 -2
  32. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
  33. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
  34. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
  35. ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
  36. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
  39. ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
  40. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
  41. ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
  42. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
  44. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
  45. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
  46. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  47. ai_edge_torch/generative/examples/t5/t5.py +2 -2
  48. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
  49. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  50. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
  52. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
  55. ai_edge_torch/generative/fx_passes/__init__.py +2 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
  57. ai_edge_torch/generative/layers/attention.py +35 -26
  58. ai_edge_torch/generative/layers/attention_utils.py +23 -12
  59. ai_edge_torch/generative/layers/builder.py +0 -1
  60. ai_edge_torch/generative/layers/feed_forward.py +6 -10
  61. ai_edge_torch/generative/layers/kv_cache.py +0 -1
  62. ai_edge_torch/generative/layers/model_config.py +2 -5
  63. ai_edge_torch/generative/layers/normalization.py +5 -7
  64. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
  65. ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
  66. ai_edge_torch/generative/layers/unet/model_config.py +14 -15
  67. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
  68. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
  69. ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
  70. ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
  71. ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
  72. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
  73. ai_edge_torch/generative/test/test_model_conversion.py +24 -25
  74. ai_edge_torch/generative/test/test_quantize.py +10 -5
  75. ai_edge_torch/generative/utilities/loader.py +12 -12
  76. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
  77. ai_edge_torch/generative/utilities/t5_loader.py +12 -13
  78. ai_edge_torch/hlfb/__init__.py +1 -1
  79. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
  80. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  81. ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
  82. ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
  83. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
  84. ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
  85. ai_edge_torch/lowertools/_shim.py +80 -0
  86. ai_edge_torch/lowertools/common_utils.py +89 -0
  87. ai_edge_torch/lowertools/odml_torch_utils.py +211 -0
  88. ai_edge_torch/lowertools/torch_xla_utils.py +273 -0
  89. ai_edge_torch/model.py +14 -9
  90. ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
  91. ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
  92. ai_edge_torch/quantize/quant_config.py +7 -7
  93. ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
  94. ai_edge_torch/version.py +1 -1
  95. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/METADATA +1 -1
  96. ai_edge_torch_nightly-0.2.0.dev20240808.dist-info/RECORD +141 -0
  97. ai_edge_torch/convert/conversion_utils.py +0 -439
  98. ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/RECORD +0 -133
  99. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  100. /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
  101. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  102. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/LICENSE +0 -0
  103. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/WHEEL +0 -0
  104. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/top_level.txt +0 -0
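
The most visible structural change in this release is the rename of the `ai_edge_torch/convert/` package to the private `ai_edge_torch/_convert/`, alongside the new `ai_edge_torch/lowertools/` layer. As a hedged sketch only: assuming the public top-level API re-exported from `ai_edge_torch/__init__.py` is unchanged by the rename, a typical conversion call would still go through the package root rather than the now-private module. The model and output path below are placeholders, not part of this diff.

    import torch
    import torchvision

    import ai_edge_torch

    # Placeholder model and sample input; any exportable torch.nn.Module works.
    model = torchvision.models.resnet18().eval()
    sample_args = (torch.randn(1, 3, 224, 224),)

    # The top-level convert() API is re-exported from the now-private _convert
    # package; user code should not import ai_edge_torch._convert directly.
    edge_model = ai_edge_torch.convert(model, sample_args)
    edge_model.export("/tmp/resnet18.tflite")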
@@ -33,17 +33,17 @@ def convert_phi2_to_tflite(
     quantize: bool = True,
 ):
   """An example method for converting a Phi-2 model to multi-signature
-  tflite model.
 
+  tflite model.
   Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or
-      directory holding the checkpoint.
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
       Defaults to 512.
     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
       including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized.
-      Defaults to True.
+    quantize (bool, optional): Whether the model should be quanized. Defaults
+      to True.
   """
   pytorch_model = phi2.build_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
@@ -68,7 +68,9 @@ class Phi2(nn.Module):
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -118,6 +120,7 @@ class Phi2(nn.Module):
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
+      head_dim=80,
       num_query_groups=32,
       rotary_percentage=0.4,
       qkv_use_bias=True,
@@ -21,7 +21,7 @@ import os
 from pathlib import Path
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama  # NOQA
+from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
@@ -33,8 +33,7 @@ def convert_tiny_llama_to_tflite(
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """An example method for converting TinyLlama model to multi-signature
-  tflite model.
+  """An example for converting TinyLlama model to multi-signature tflite model.
 
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
@@ -43,8 +42,8 @@ def convert_tiny_llama_to_tflite(
       Defaults to 512.
     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
       including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized.
-      Defaults to True.
+    quantize (bool, optional): Whether the model should be quanized. Defaults
+      to True.
   """
   pytorch_model = tiny_llama.build_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
@@ -70,7 +70,9 @@ class TinyLLamma(nn.Module):
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -121,6 +123,7 @@ class TinyLLamma(nn.Module):
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
+      head_dim=64,
       num_query_groups=4,
       rotary_percentage=1.0,
   )
@@ -28,17 +28,17 @@ def convert_gemma_to_tflite(
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """An example method for converting a Gemma 2B model to multi-signature
-  tflite model.
+  """Converts a Gemma 2B model to multi-signature tflite model.
 
   Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
       Defaults to 512.
     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
       including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized.
-      Defaults to True.
+    quantize (bool, optional): Whether the model should be quanized. Defaults
+      to True.
   """
   pytorch_model = gemma.build_2b_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
@@ -68,7 +68,9 @@ class Gemma(nn.Module):
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -113,6 +115,7 @@ class Gemma(nn.Module):
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=8,
+      head_dim=256,
       num_query_groups=1,
       rotary_percentage=1.0,
   )
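
Across the Gemma, Phi-2, TinyLlama, CLIP, T5, and toy-model examples in this diff, `cfg.AttentionConfig` gains an explicit `head_dim` field, and the RoPE cache dimension is now derived from `config.attn_config.head_dim` rather than a `head_dim` attribute on the model config. A minimal standalone sketch of that derivation follows; it uses a simplified stand-in dataclass rather than the real `ai_edge_torch.generative.layers.model_config` types, with the values copied from the Gemma 2B hunk above.

    from dataclasses import dataclass


    @dataclass
    class AttentionConfig:
      # Simplified stand-in for cfg.AttentionConfig; only the fields used here.
      num_heads: int
      head_dim: int
      num_query_groups: int
      rotary_percentage: float


    # Values from the Gemma 2B attention config in the hunk above.
    attn_config = AttentionConfig(
        num_heads=8, head_dim=256, num_query_groups=1, rotary_percentage=1.0
    )

    # The RoPE cache dimension is now computed from the explicit head_dim on
    # the attention config instead of a head_dim field on the model config.
    rope_dim = int(attn_config.rotary_percentage * attn_config.head_dim)
    assert rope_dim == 256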
@@ -28,17 +28,17 @@ def convert_phi2_to_tflite(
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """An example method for converting a Phi-2 model to multi-signature
-  tflite model.
+  """Converts a Phi-2 model to multi-signature tflite model.
 
   Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
       Defaults to 512.
     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
       including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized.
-      Defaults to True.
+    quantize (bool, optional): Whether the model should be quanized. Defaults
+      to True.
   """
   pytorch_model = phi2.build_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
@@ -63,7 +63,9 @@ class Phi2(nn.Module):
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -107,6 +109,7 @@ class Phi2(nn.Module):
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
+      head_dim=80,
       num_query_groups=32,
       rotary_percentage=0.4,
       qkv_use_bias=True,
@@ -49,6 +49,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 class CLIP(nn.Module):
   """CLIP text encoder
+
   For details, see https://arxiv.org/abs/2103.00020
   """
 
@@ -92,6 +93,7 @@ def get_model_config() -> cfg.ModelConfig:
 
   attn_config = cfg.AttentionConfig(
       num_heads=num_heads,
+      head_dim=embedding_dim // num_heads,
       num_query_groups=num_query_groups,
       rotary_percentage=0.0,
       qkv_use_bias=True,
@@ -15,9 +15,9 @@
 
 import ai_edge_torch.generative.layers.builder as layers_builder
 import ai_edge_torch.generative.layers.model_config as layers_cfg
-import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
+from ai_edge_torch.generative.layers.unet import blocks_2d
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
-import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+from ai_edge_torch.generative.utilities import stable_diffusion_loader
 import torch
 from torch import nn
 
@@ -288,6 +288,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
       normalization_config=norm_config,
       attention_config=layers_cfg.AttentionConfig(
           num_heads=1,
+          head_dim=block_out_channels[-1],
          num_query_groups=1,
          qkv_use_bias=True,
          output_proj_use_bias=True,
@@ -15,9 +15,9 @@
 
 import ai_edge_torch.generative.layers.builder as layers_builder
 import ai_edge_torch.generative.layers.model_config as layers_cfg
-import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
+from ai_edge_torch.generative.layers.unet import blocks_2d
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
-import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+from ai_edge_torch.generative.utilities import stable_diffusion_loader
 import torch
 from torch import nn
 
@@ -195,6 +195,31 @@ TENSOR_NAMES = stable_diffusion_loader.DiffusionModelLoader.TensorNames(
 )
 
 
+def build_attention_config(
+    num_heads,
+    dim,
+    num_query_groups,
+    rotary_percentage=0.0,
+    qkv_transpose_before_split=True,
+    qkv_use_bias=False,
+    output_proj_use_bias=True,
+    enable_kv_cache=False,
+    qkv_fused_interleaved=False,
+):
+
+  return layers_cfg.AttentionConfig(
+      num_heads=num_heads,
+      head_dim=dim // num_heads,
+      num_query_groups=num_query_groups,
+      rotary_percentage=rotary_percentage,
+      qkv_transpose_before_split=qkv_transpose_before_split,
+      qkv_use_bias=qkv_use_bias,
+      output_proj_use_bias=output_proj_use_bias,
+      enable_kv_cache=enable_kv_cache,
+      qkv_fused_interleaved=qkv_fused_interleaved,
+  )
+
+
 class TimeEmbedding(nn.Module):
 
   def __init__(self, in_dim, out_dim):
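
The new `build_attention_config` helper above replaces the single shared `attention_config` that the `Diffusion` constructor previously built once, so every self- and cross-attention block now gets a `head_dim` computed from its own channel width (`dim // num_heads`), as the following hunks show. A small hypothetical illustration of that arithmetic, where the channel widths and head count are placeholders rather than the actual Stable Diffusion v1.5 configuration:

    # Placeholder values, purely to illustrate head_dim = dim // num_heads per block.
    num_attention_heads = 8
    for output_channel in (320, 640, 1280):
      head_dim = output_channel // num_attention_heads
      print(f"block dim={output_channel} -> head_dim={head_dim}")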
@@ -267,17 +292,6 @@ class Diffusion(nn.Module):
         config.in_channels, block_out_channels[0], kernel_size=3, padding=1
     )
 
-    attention_config = layers_cfg.AttentionConfig(
-        num_heads=config.transformer_num_attention_heads,
-        num_query_groups=config.transformer_num_attention_heads,
-        rotary_percentage=0.0,
-        qkv_transpose_before_split=True,
-        qkv_use_bias=False,
-        output_proj_use_bias=True,
-        enable_kv_cache=False,
-        qkv_fused_interleaved=False,
-    )
-
     # Down encoders.
     down_encoders = []
     output_channel = block_out_channels[0]
@@ -312,7 +326,11 @@ class Diffusion(nn.Module):
                 dim=output_channel,
                 attention_batch_size=config.transformer_batch_size,
                 normalization_config=config.transformer_norm_config,
-                attention_config=attention_config,
+                attention_config=build_attention_config(
+                    num_heads=config.transformer_num_attention_heads,
+                    dim=output_channel,
+                    num_query_groups=config.transformer_num_attention_heads,
+                ),
                 enable_hlfb=False,
             ),
             cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
@@ -320,7 +338,11 @@ class Diffusion(nn.Module):
                 cross_dim=config.transformer_cross_attention_dim,
                 attention_batch_size=config.transformer_batch_size,
                 normalization_config=config.transformer_norm_config,
-                attention_config=attention_config,
+                attention_config=build_attention_config(
+                    num_heads=config.transformer_num_attention_heads,
+                    dim=output_channel,
+                    num_query_groups=config.transformer_num_attention_heads,
+                ),
                 enable_hlfb=False,
             ),
             pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
@@ -374,7 +396,11 @@ class Diffusion(nn.Module):
                 dim=mid_block_channels,
                 attention_batch_size=config.transformer_batch_size,
                 normalization_config=config.transformer_norm_config,
-                attention_config=attention_config,
+                attention_config=build_attention_config(
+                    num_heads=config.transformer_num_attention_heads,
+                    dim=mid_block_channels,
+                    num_query_groups=config.transformer_num_attention_heads,
+                ),
                 enable_hlfb=False,
             ),
             cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
@@ -382,7 +408,11 @@ class Diffusion(nn.Module):
                 cross_dim=config.transformer_cross_attention_dim,
                 attention_batch_size=config.transformer_batch_size,
                 normalization_config=config.transformer_norm_config,
-                attention_config=attention_config,
+                attention_config=build_attention_config(
+                    num_heads=config.transformer_num_attention_heads,
+                    dim=mid_block_channels,
+                    num_query_groups=config.transformer_num_attention_heads,
+                ),
                 enable_hlfb=False,
             ),
             pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
@@ -437,7 +467,11 @@ class Diffusion(nn.Module):
                 dim=output_channel,
                 attention_batch_size=config.transformer_batch_size,
                 normalization_config=config.transformer_norm_config,
-                attention_config=attention_config,
+                attention_config=build_attention_config(
+                    num_heads=config.transformer_num_attention_heads,
+                    dim=output_channel,
+                    num_query_groups=config.transformer_num_attention_heads,
+                ),
                 enable_hlfb=False,
             ),
             cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
@@ -445,7 +479,11 @@ class Diffusion(nn.Module):
                 cross_dim=config.transformer_cross_attention_dim,
                 attention_batch_size=config.transformer_batch_size,
                 normalization_config=config.transformer_norm_config,
-                attention_config=attention_config,
+                attention_config=build_attention_config(
+                    num_heads=config.transformer_num_attention_heads,
+                    dim=output_channel,
+                    num_query_groups=config.transformer_num_attention_heads,
+                ),
                 enable_hlfb=False,
             ),
             pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
@@ -543,7 +581,6 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
 
   Retruns:
     The configuration of diffusion model of Stable Diffusion v1.5.
-
   """
   in_channels = 4
   out_channels = 4
@@ -127,7 +127,9 @@ def run_tflite_pipeline(
     input_image: Optional[Image.Image] = None,
 ):
   """Run stable diffusion pipeline with tflite model.
+
   model:
+
     StableDiffsuion model.
   prompt:
     The prompt to guide the image generation.
@@ -136,27 +138,36 @@ def run_tflite_pipeline(
   uncond_prompt:
     The prompt not to guide the image generation.
   cfg_scale:
-    Guidance scale of classifier-free guidance. Higher guidance scale encourages to generate
-    images that are closely linked to the text `prompt`, usually at the expense of lower
+    Guidance scale of classifier-free guidance. Higher guidance scale encourages
+    to generate
+    images that are closely linked to the text `prompt`, usually at the expense
+    of lower
     image quality.
   height:
     The height in pixels of the generated image.
   width:
     The width in pixels of the generated image.
   sampler:
-    A sampler to be used to denoise the encoded image latents. Can be one of `k_lms, `k_euler`,
+    A sampler to be used to denoise the encoded image latents. Can be one of
+    `k_lms, `k_euler`,
     or `k_euler_ancestral`.
   n_inference_steps:
-    The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+    The number of denoising steps. More denoising steps usually lead to a higher
+    quality image at the
     expense of slower inference. This parameter will be modulated by `strength`.
   seed:
     A seed to make generation deterministic.
   strength:
-    Conceptually, indicates how much to transform the reference `input_image`. Must be between 0 and 1.
-    `input_image` will be used as a starting point, adding more noise to it the larger the `strength`.
-    The number of denoising steps depends on the amount of noise initially added. When `strength` is 1,
-    added noise will be maximum and the denoising process will run for the full number of iterations
-    specified in `n_inference_steps`. A value of 1, therefore, essentially ignores `input_image`.
+    Conceptually, indicates how much to transform the reference `input_image`.
+    Must be between 0 and 1.
+    `input_image` will be used as a starting point, adding more noise to it the
+    larger the `strength`.
+    The number of denoising steps depends on the amount of noise initially
+    added. When `strength` is 1,
+    added noise will be maximum and the denoising process will run for the full
+    number of iterations
+    specified in `n_inference_steps`. A value of 1, therefore, essentially
+    ignores `input_image`.
   input_image:
     Image which is served as the starting point for the image generation.
   """
@@ -28,6 +28,7 @@ class SamplerInterface(abc.ABC):
   @abc.abstractmethod
   def set_strength(self, strength: float = 1) -> None:
     """Set the strength of initial step.
+
     Conceptually, indicates how much to transform the reference `input_images`.
     """
     return NotImplemented
@@ -17,14 +17,13 @@
 import copy
 import os
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional
 
 from ai_edge_torch.generative.examples.t5.t5_attention import EncoderDecoderBlock  # NOQA
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.t5_loader as loading_utils
-import numpy as np
 import torch
 import torch.nn as nn
 
@@ -371,6 +370,7 @@ class T5Decoder(nn.Module):
 def get_model_config_t5() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=12,
+      head_dim=64,
       num_query_groups=12,
       qkv_use_bias=False,
       relative_attention_num_buckets=32,
@@ -37,10 +37,10 @@ class EncoderDecoderBlock(nn.Module):
     """Initialize an instance of the EncoderDecoderBlock.
 
     Args:
-      config (cfg.ModelConfig): the configuration object
-        for this transformer block.
-      has_relative_attention_bias (bool): whether the
-        self attention block has relative bias.
+      config (cfg.ModelConfig): the configuration object for this transformer
+        block.
+      has_relative_attention_bias (bool): whether the self attention block has
+        relative bias.
     """
 
     super().__init__()
@@ -143,8 +143,10 @@ class T5Attention(CrossAttention):
     Args:
       dim (int): causal attention's input/output dimmension.
       config (cfg.AttentionConfig): attention specific configurations.
-      norm_config (cfg.NormalizationConfig): normalization configure before attention.
-      kv_cache_max (int): determines the size of the KV Cache buffer, if enabled.
+      norm_config (cfg.NormalizationConfig): normalization configure before
+        attention.
+      kv_cache_max (int): determines the size of the KV Cache buffer, if
+        enabled.
       enable_hlfb (bool): whether hlfb is enabled or not.
       has_relative_attention_bias (bool): whether we compute relative bias.
     """
@@ -185,7 +187,7 @@ class T5Attention(CrossAttention):
     )  # batch size, sequence length, embedding dimensionality (n_embd)
     query_states = self.q_projection(x)
     query_states = query_states.reshape(
-        B, T, -1, self.head_dim
+        B, T, -1, self.config.head_dim
     )  # (B, T, nh_q, hs)
 
     if key_value_states is not None:
@@ -198,13 +200,13 @@ class T5Attention(CrossAttention):
       )  # batch size, sequence length, embedding dimensionality (n_embd)
       key_states = self.k_projection(key_value_states)
       value_states = self.v_projection(key_value_states)
-      key_states = key_states.reshape(kvB, kvT, -1, self.head_dim)
-      value_states = value_states.reshape(kvB, kvT, -1, self.head_dim)
+      key_states = key_states.reshape(kvB, kvT, -1, self.config.head_dim)
+      value_states = value_states.reshape(kvB, kvT, -1, self.config.head_dim)
     else:
       key_states = self.k_projection(x)
       value_states = self.v_projection(x)
-      key_states = key_states.reshape(B, T, -1, self.head_dim)
-      value_states = value_states.reshape(B, T, -1, self.head_dim)
+      key_states = key_states.reshape(B, T, -1, self.config.head_dim)
+      value_states = value_states.reshape(B, T, -1, self.config.head_dim)
 
     if key_value_states is None and self.kv_cache is not None:
       key_states, value_states = self.kv_cache.update_cache(
@@ -221,7 +223,7 @@ class T5Attention(CrossAttention):
           0
       )  # shape (1, num_heads, query_length, key_length)
     else:
-      # position_bias = torch.zeros(B, self.n_heads, T, self.head_dim, dtype=torch.float32)
+      # position_bias = torch.zeros(B, self.n_heads, T, self.config.head_dim, dtype=torch.float32)
       position_bias = torch.zeros_like(mask, dtype=torch.float32)
 
     mask = mask + position_bias
@@ -229,7 +231,7 @@ class T5Attention(CrossAttention):
         query_states,
         key_states,
         value_states,
-        self.head_dim,
+        self.config.head_dim,
         mask=mask,
         scale=1.0,
     )
@@ -43,7 +43,9 @@ class ToySingleLayerModel(torch.nn.Module):
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -72,6 +74,7 @@ class ToySingleLayerModel(torch.nn.Module):
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
+      head_dim=4,
       num_query_groups=4,
       rotary_percentage=1.0,
       enable_kv_cache=False,
@@ -17,6 +17,7 @@
 from typing import Tuple
 
 import ai_edge_torch
+from ai_edge_torch import lowertools
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
@@ -24,7 +25,6 @@ from ai_edge_torch.generative.layers.experimental.attention import TransformerBl
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 import torch.nn as nn
-import torch_xla
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -46,7 +46,9 @@ class ToyModelWithExternalKV(torch.nn.Module):
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -84,13 +86,12 @@ class ToyModelWithExternalKV(torch.nn.Module):
 
 def _export_stablehlo_mlir(model, args):
   ep = torch.export.export(model, args)
-  stablehlo_gm = torch_xla.stablehlo.exported_program_to_stablehlo(ep)
-  return stablehlo_gm.get_stablehlo_text()
+  return lowertools.exported_program_to_mlir_text(ep)
 
 
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
-      num_heads=32, num_query_groups=4, rotary_percentage=1.0
+      num_heads=32, head_dim=4, num_query_groups=4, rotary_percentage=1.0
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
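
The toy-model hunks above and below swap the direct `torch_xla.stablehlo` export for the new `ai_edge_torch.lowertools` layer introduced in this release (`_shim.py`, `torch_xla_utils.py`, and `odml_torch_utils.py` in the file list), which the file names suggest hides the lowering backend behind `exported_program_to_mlir_text`. A minimal sketch of that call, assuming this nightly build is installed; the toy module and input shape below are placeholders, not part of the package:

    import torch

    from ai_edge_torch import lowertools


    class TinyModule(torch.nn.Module):

      def forward(self, x):
        return torch.nn.functional.relu(x) + 1.0


    # Export a trivial module, then lower it to MLIR text via the new shim.
    ep = torch.export.export(TinyModule().eval(), (torch.randn(1, 4),))
    mlir_text = lowertools.exported_program_to_mlir_text(ep)
    print(mlir_text[:200])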
@@ -13,17 +13,16 @@
 # limitations under the License.
 # ==============================================================================
 # A toy example which has basic transformer block (w/ KV-Cache).
-from typing import List, Tuple
+from typing import Tuple
 
 import ai_edge_torch
+from ai_edge_torch import lowertools
 from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
-import numpy as np
 import torch
 import torch.nn as nn
-import torch_xla
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -45,7 +44,9 @@ class ToyModelWithKV(torch.nn.Module):
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -72,13 +73,12 @@ class ToyModelWithKV(torch.nn.Module):
 
 def _export_stablehlo_mlir(model, args):
   ep = torch.export.export(model, args)
-  stablehlo_gm = torch_xla.stablehlo.exported_program_to_stablehlo(ep)
-  return stablehlo_gm.get_stablehlo_text()
+  return lowertools.exported_program_to_mlir_text(ep)
 
 
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
-      num_heads=32, num_query_groups=4, rotary_percentage=1.0
+      num_heads=32, head_dim=4, num_query_groups=4, rotary_percentage=1.0
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================