ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl → 0.3.0.dev20240926__py3-none-any.whl

Files changed (169)
  1. ai_edge_torch/__init__.py +5 -4
  2. ai_edge_torch/_convert/conversion.py +112 -0
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +94 -48
  5. ai_edge_torch/_convert/fx_passes/__init__.py +22 -0
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +107 -44
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +23 -20
  8. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +5 -6
  9. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -1
  10. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +39 -9
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +17 -8
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +9 -8
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +31 -18
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +2 -2
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +34 -24
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  19. ai_edge_torch/_convert/signature.py +66 -0
  20. ai_edge_torch/_convert/test/test_convert.py +495 -0
  21. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  22. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  23. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -5
  24. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +10 -3
  25. ai_edge_torch/config.py +27 -0
  26. ai_edge_torch/conftest.py +20 -0
  27. ai_edge_torch/debug/culprit.py +72 -40
  28. ai_edge_torch/debug/test/test_culprit.py +7 -5
  29. ai_edge_torch/debug/test/test_search_model.py +8 -7
  30. ai_edge_torch/debug/utils.py +14 -3
  31. ai_edge_torch/fx_pass_base.py +101 -0
  32. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +68 -0
  33. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +68 -0
  34. ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +69 -55
  35. ai_edge_torch/generative/examples/gemma/gemma2.py +267 -0
  36. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  37. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +57 -0
  38. ai_edge_torch/generative/examples/gemma/verify_util.py +143 -0
  39. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +68 -0
  40. ai_edge_torch/generative/examples/openelm/openelm.py +206 -0
  41. ai_edge_torch/generative/examples/openelm/verify.py +64 -0
  42. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  43. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  44. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +68 -0
  45. ai_edge_torch/generative/examples/{phi2 → phi}/phi2.py +70 -51
  46. ai_edge_torch/generative/examples/phi/phi3.py +286 -0
  47. ai_edge_torch/generative/examples/phi/verify.py +65 -0
  48. ai_edge_torch/generative/examples/phi/verify_phi3.py +70 -0
  49. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  50. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +68 -0
  51. ai_edge_torch/generative/examples/smollm/smollm.py +101 -0
  52. ai_edge_torch/generative/examples/smollm/verify.py +62 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  54. ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -13
  55. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +27 -14
  56. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +74 -9
  57. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +179 -37
  58. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  59. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +83 -58
  60. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  61. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  62. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  63. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  64. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  65. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  66. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +28 -25
  67. ai_edge_torch/generative/examples/t5/t5.py +208 -159
  68. ai_edge_torch/generative/examples/t5/t5_attention.py +45 -30
  69. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  70. ai_edge_torch/generative/examples/test_models/toy_model.py +69 -41
  71. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +50 -64
  72. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  73. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +41 -39
  74. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +67 -54
  75. ai_edge_torch/generative/examples/tiny_llama/verify.py +64 -0
  76. ai_edge_torch/generative/fx_passes/__init__.py +4 -5
  77. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +10 -7
  78. ai_edge_torch/generative/layers/attention.py +141 -102
  79. ai_edge_torch/generative/layers/attention_utils.py +53 -12
  80. ai_edge_torch/generative/layers/builder.py +37 -7
  81. ai_edge_torch/generative/layers/feed_forward.py +39 -14
  82. ai_edge_torch/generative/layers/kv_cache.py +162 -50
  83. ai_edge_torch/generative/layers/model_config.py +84 -30
  84. ai_edge_torch/generative/layers/normalization.py +185 -7
  85. ai_edge_torch/generative/layers/rotary_position_embedding.py +6 -4
  86. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +48 -21
  87. ai_edge_torch/generative/layers/unet/blocks_2d.py +136 -77
  88. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  89. ai_edge_torch/generative/layers/unet/model_config.py +17 -15
  90. ai_edge_torch/generative/quantize/example.py +7 -8
  91. ai_edge_torch/generative/quantize/quant_recipe.py +10 -7
  92. ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -1
  93. ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
  94. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  95. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +9 -7
  96. ai_edge_torch/generative/test/test_model_conversion.py +124 -188
  97. ai_edge_torch/generative/test/test_model_conversion_large.py +251 -0
  98. ai_edge_torch/generative/test/test_quantize.py +76 -60
  99. ai_edge_torch/generative/test/utils.py +54 -0
  100. ai_edge_torch/generative/utilities/converter.py +82 -0
  101. ai_edge_torch/generative/utilities/loader.py +120 -57
  102. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +165 -57
  103. ai_edge_torch/generative/utilities/t5_loader.py +110 -81
  104. ai_edge_torch/generative/utilities/verifier.py +247 -0
  105. ai_edge_torch/hlfb/__init__.py +1 -1
  106. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -7
  107. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  108. ai_edge_torch/hlfb/mark_pattern/pattern.py +39 -30
  109. ai_edge_torch/hlfb/test/test_mark_pattern.py +46 -20
  110. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +24 -11
  111. ai_edge_torch/lowertools/__init__.py +18 -0
  112. ai_edge_torch/lowertools/_shim.py +80 -0
  113. ai_edge_torch/lowertools/common_utils.py +142 -0
  114. ai_edge_torch/lowertools/odml_torch_utils.py +255 -0
  115. ai_edge_torch/lowertools/test_utils.py +60 -0
  116. ai_edge_torch/lowertools/torch_xla_utils.py +284 -0
  117. ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +29 -14
  118. ai_edge_torch/model.py +53 -18
  119. ai_edge_torch/odml_torch/__init__.py +20 -0
  120. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  121. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  122. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  123. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  124. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  125. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  126. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  127. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  128. ai_edge_torch/odml_torch/export.py +357 -0
  129. ai_edge_torch/odml_torch/export_utils.py +168 -0
  130. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  131. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +150 -0
  132. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  133. ai_edge_torch/odml_torch/lowerings/__init__.py +25 -0
  134. ai_edge_torch/odml_torch/lowerings/_basic.py +258 -0
  135. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  136. ai_edge_torch/odml_torch/lowerings/_convolution.py +241 -0
  137. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +252 -0
  138. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  139. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  140. ai_edge_torch/odml_torch/lowerings/registry.py +96 -0
  141. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  142. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  143. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  144. ai_edge_torch/quantize/pt2e_quantizer.py +52 -24
  145. ai_edge_torch/quantize/pt2e_quantizer_utils.py +43 -23
  146. ai_edge_torch/quantize/quant_config.py +13 -9
  147. ai_edge_torch/testing/model_coverage/model_coverage.py +29 -16
  148. ai_edge_torch/version.py +16 -0
  149. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/METADATA +7 -3
  150. ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/RECORD +177 -0
  151. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/WHEEL +1 -1
  152. ai_edge_torch/convert/conversion.py +0 -117
  153. ai_edge_torch/convert/conversion_utils.py +0 -400
  154. ai_edge_torch/convert/fx_passes/__init__.py +0 -59
  155. ai_edge_torch/convert/fx_passes/_pass_base.py +0 -49
  156. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +0 -37
  157. ai_edge_torch/convert/test/test_convert.py +0 -311
  158. ai_edge_torch/convert/test/test_convert_composites.py +0 -192
  159. ai_edge_torch/convert/test/test_convert_multisig.py +0 -139
  160. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +0 -66
  161. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -64
  162. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -161
  163. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  164. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +0 -121
  165. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  166. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  167. /ai_edge_torch/generative/examples/{phi2 → openelm}/__init__.py +0 -0
  168. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/LICENSE +0 -0
  169. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py (new file)
@@ -0,0 +1,68 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting SmolLM model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import converter
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = smollm.build_model(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'smollm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
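
For reference, the same conversion can also be driven programmatically with the two calls this script wraps. A minimal sketch, using only the keyword arguments visible in the diff above; the checkpoint directory and output path are placeholders, not values from the change:

# Sketch only: programmatic equivalent of running convert_to_tflite.py above.
import os

from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import converter

pytorch_model = smollm.build_model(
    '/path/to/smollm/checkpoint',  # placeholder checkpoint directory
    kv_cache_max_len=1280,
)
converter.convert_to_tflite(
    pytorch_model,
    tflite_path=os.path.join('/tmp', 'smollm_q8_seq1024_ekv1280.tflite'),
    prefill_seq_len=1024,
    quantize=True,
)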
ai_edge_torch/generative/examples/smollm/smollm.py (new file)
@@ -0,0 +1,101 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a SmolLM model."""
+
+import copy
+
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+from torch import nn
+
+TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
+# SmolLM re-uses the embedding as the head projection layer.
+TENSOR_NAMES.lm_head = None
+
+
+class SmolLM(tiny_llama.TinyLlama):
+  """A SmolLM model built from the Edge Generative API layers.
+
+  SmolLM shares the same architecture as TinyLlama, but with different model
+  sizes.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__(config)
+    # SmolLM re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmolLM 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=9,
+      head_dim=64,
+      num_query_groups=3,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=1536,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=49152,
+      num_layers=30,
+      max_seq_len=2048,
+      embedding_dim=576,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # SmolLM has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = SmolLM(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
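
The tying of the LM head to the token embedding above is also why the loader runs with strict=False: the checkpoint only stores the embedding tensor. A minimal standalone sketch of the same idea in plain PyTorch (toy sizes, not the library's ModelLoader):

import torch
from torch import nn


class TinyTiedLM(nn.Module):
  """Toy model illustrating embedding/LM-head weight sharing."""

  def __init__(self, vocab_size: int = 16, dim: int = 8):
    super().__init__()
    self.tok_embedding = nn.Embedding(vocab_size, dim)
    self.lm_head = nn.Linear(dim, vocab_size, bias=False)
    # Same move as SmolLM.__init__ above: point the head at the embedding.
    self.lm_head.weight.data = self.tok_embedding.weight.data


model = TinyTiedLM()
# A checkpoint that stores only the embedding has no 'lm_head.weight' entry,
# so loading it requires strict=False, mirroring build_model() above.
state = {'tok_embedding.weight': torch.zeros(16, 8)}
model.load_state_dict(state, strict=False)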
ai_edge_torch/generative/examples/smollm/verify.py (new file)
@@ -0,0 +1,62 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored SmolLM-135M model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  logging.info("Loading the original model from: %s", checkpoint)
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+  )
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = smollm.build_model(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=wrapper_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      generate_prompts=_PROMPTS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/stable_diffusion/attention.py
@@ -73,7 +73,9 @@ class SelfAttention(nn.Module):
 
 class CrossAttention(nn.Module):
 
-  def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
+  def __init__(
+      self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True
+  ):
     super().__init__()
     self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
     self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -13,25 +13,34 @@
 # limitations under the License.
 # ==============================================================================
 
-import torch
-from torch import nn
-
 from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attention_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc1",
-    ff_down_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc2",
+    ff_up_proj=(
+        "cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc1"
+    ),
+    ff_down_proj=(
+        "cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc2"
+    ),
     attn_query_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.q_proj",
     attn_key_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.k_proj",
     attn_value_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.v_proj",
     attn_output_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.out_proj",
-    pre_attn_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm1",
-    pre_ff_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm2",
-    embedding="cond_stage_model.transformer.text_model.embeddings.token_embedding",
+    pre_attn_norm=(
+        "cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm1"
+    ),
+    post_attn_norm=(
+        "cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm2"
+    ),
+    embedding=(
+        "cond_stage_model.transformer.text_model.embeddings.token_embedding"
+    ),
     embedding_position="cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
     final_norm="cond_stage_model.transformer.text_model.final_layer_norm",
     lm_head=None,
@@ -39,7 +48,8 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class CLIP(nn.Module):
-  """CLIP text encoder
+  """CLIP text encoder.
+
   For details, see https://arxiv.org/abs/2103.00020
   """
 
@@ -51,10 +61,14 @@ class CLIP(nn.Module):
     )
 
     self.config = config
+    # CLIP has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        TransformerBlock(block_config, config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim, config.final_norm_config
     )
-    self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config)
 
     self.mask_cache = attention_utils.build_causal_mask_cache(
         size=config.max_seq_len, dtype=torch.float32
@@ -62,7 +76,7 @@ class CLIP(nn.Module):
 
   @torch.inference_mode
   def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
-    tokens = tokens.type(torch.long)
+    tokens = tokens.type(torch.int)
 
     state = self.tok_embedding(tokens) + self.tok_embedding_position
     for layer in self.transformer_blocks:
@@ -72,6 +86,7 @@ class CLIP(nn.Module):
 
 
 def get_model_config() -> cfg.ModelConfig:
+  """Get configs for the CLIP of Stable Diffusion v1.5."""
   max_seq_len = 77
   vocab_size = 49408
   num_layers = 12
@@ -81,6 +96,7 @@ def get_model_config() -> cfg.ModelConfig:
 
   attn_config = cfg.AttentionConfig(
       num_heads=num_heads,
+      head_dim=embedding_dim // num_heads,
       num_query_groups=num_query_groups,
       rotary_percentage=0.0,
       qkv_use_bias=True,
@@ -99,15 +115,69 @@
 
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
 
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=vocab_size,
       num_layers=num_layers,
       max_seq_len=max_seq_len,
       embedding_dim=embedding_dim,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+
+  return config
+
+
+def get_fake_model_config() -> cfg.ModelConfig:
+  """Get fake configs for the CLIP of Stable Diffusion v1.5 for testing."""
+  max_seq_len = 6
+  vocab_size = 100
+  num_layers = 2
+  num_heads = 12
+  num_query_groups = 12
+  embedding_dim = 24
+
+  attn_config = cfg.AttentionConfig(
+      num_heads=num_heads,
+      head_dim=embedding_dim // num_heads,
+      num_query_groups=num_query_groups,
+      rotary_percentage=0.0,
+      qkv_use_bias=True,
+      qkv_transpose_before_split=True,
+      qkv_fused_interleaved=False,
+      output_proj_use_bias=True,
+      enable_kv_cache=False,
+  )
+
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
+      intermediate_size=embedding_dim * 4,
+      use_bias=True,
+  )
+
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+
+  block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
-      pre_ff_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+
+  config = cfg.ModelConfig(
+      vocab_size=vocab_size,
+      num_layers=num_layers,
+      max_seq_len=max_seq_len,
+      embedding_dim=embedding_dim,
+      block_configs=block_config,
      final_norm_config=norm_config,
      enable_hlfb=True,
  )
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
@@ -18,19 +18,22 @@ import os
 from pathlib import Path
 from typing import Optional
 
-import torch
-
 import ai_edge_torch
 import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
 import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
 import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
 from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
+from ai_edge_torch.generative.quantize import quant_recipes
 import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+import torch
 
 arg_parser = argparse.ArgumentParser()
 arg_parser.add_argument(
-    '--clip_ckpt', type=str, help='Path to source CLIP model checkpoint', required=True
+    '--clip_ckpt',
+    type=str,
+    help='Path to source CLIP model checkpoint',
+    required=True,
 )
 arg_parser.add_argument(
     '--diffusion_ckpt',
@@ -60,6 +63,7 @@ def convert_stable_diffusion_to_tflite(
     decoder_ckpt_path: str,
     image_height: int = 512,
     image_width: int = 512,
+    quantize: bool = True,
 ):
 
   clip_model = clip.CLIP(clip.get_model_config())
@@ -90,10 +94,14 @@ def convert_stable_diffusion_to_tflite(
   n_tokens = 77
   timestamp = 0
   len_prompt = 1
-  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
-  input_image = torch.full((1, 3, image_height, image_width), 0, dtype=torch.float32)
+  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.int)
+  input_image = torch.full(
+      (1, 3, image_height, image_width), 0, dtype=torch.float32
+  )
   noise = torch.full(
-      (len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32
+      (len_prompt, 4, image_height // 8, image_width // 8),
+      0,
+      dtype=torch.float32,
   )
 
   input_latents = torch.zeros_like(noise)
@@ -105,15 +113,19 @@ def convert_stable_diffusion_to_tflite(
   if not os.path.exists(output_dir):
     Path(output_dir).mkdir(parents=True, exist_ok=True)
 
+  quant_config = (
+      quant_recipes.full_int8_weight_only_recipe() if quantize else None
+  )
+
   # TODO(yichunk): convert to multi signature tflite model.
   # CLIP text encoder
-  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
-      f'{output_dir}/clip.tflite'
-  )
+  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert(
+      quant_config=quant_config
+  ).export(f'{output_dir}/clip.tflite')
 
   # TODO(yichunk): enable image encoder conversion
   # Image encoder
-  # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
+  # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert(quant_config=quant_config).export(
   #     f'{output_dir}/encoder.tflite'
   # )
 
@@ -122,12 +134,12 @@ def convert_stable_diffusion_to_tflite(
       'diffusion',
       diffusion_model,
       (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
-  ).convert().export(f'{output_dir}/diffusion.tflite')
+  ).convert(quant_config=quant_config).export(f'{output_dir}/diffusion.tflite')
 
   # Image decoder
-  ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert().export(
-      f'{output_dir}/decoder.tflite'
-  )
+  ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert(
+      quant_config=quant_config
+  ).export(f'{output_dir}/decoder.tflite')
 
 
 if __name__ == '__main__':
@@ -139,4 +151,5 @@ if __name__ == '__main__':
       decoder_ckpt_path=args.decoder_ckpt,
       image_height=512,
      image_width=512,
+      quantize=True,
   )
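
The change above threads an optional weight-only int8 recipe through each signature via convert(quant_config=...). As a hedged, minimal sketch of the same call chain on a throwaway module (the module and output path are made up for illustration; this does not claim the recipe is intended for arbitrary non-generative models):

import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes
import torch
from torch import nn

# Throwaway module, used only to show the signature/convert/export chain.
toy_model = nn.Linear(4, 4).eval()
sample_args = (torch.zeros(1, 4),)

quant_config = quant_recipes.full_int8_weight_only_recipe()
ai_edge_torch.signature('forward', toy_model, sample_args).convert(
    quant_config=quant_config
).export('/tmp/toy_weight_only_int8.tflite')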
ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -13,14 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 
-import torch
-from torch import nn
-
 import ai_edge_torch.generative.layers.builder as layers_builder
 import ai_edge_torch.generative.layers.model_config as layers_cfg
-import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
+from ai_edge_torch.generative.layers.unet import blocks_2d
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
-import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+from ai_edge_torch.generative.utilities import stable_diffusion_loader
+import torch
+from torch import nn
 
 TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
     post_quant_conv="first_stage_model.post_quant_conv",
@@ -104,7 +103,9 @@ TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
                     norm_2="first_stage_model.decoder.up.1.block.0.norm2",
                     conv_1="first_stage_model.decoder.up.1.block.0.conv1",
                     conv_2="first_stage_model.decoder.up.1.block.0.conv2",
-                    residual_layer="first_stage_model.decoder.up.1.block.0.nin_shortcut",
+                    residual_layer=(
+                        "first_stage_model.decoder.up.1.block.0.nin_shortcut"
+                    ),
                 ),
                 stable_diffusion_loader.ResidualBlockTensorNames(
                     norm_1="first_stage_model.decoder.up.1.block.1.norm1",
@@ -128,7 +129,9 @@ TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
                     norm_2="first_stage_model.decoder.up.0.block.0.norm2",
                     conv_1="first_stage_model.decoder.up.0.block.0.conv1",
                     conv_2="first_stage_model.decoder.up.0.block.0.conv2",
-                    residual_layer="first_stage_model.decoder.up.0.block.0.nin_shortcut",
+                    residual_layer=(
+                        "first_stage_model.decoder.up.0.block.0.nin_shortcut"
+                    ),
                 ),
                 stable_diffusion_loader.ResidualBlockTensorNames(
                     norm_1="first_stage_model.decoder.up.0.block.1.norm1",
@@ -285,6 +288,63 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
       normalization_config=norm_config,
       attention_config=layers_cfg.AttentionConfig(
          num_heads=1,
+          head_dim=block_out_channels[-1],
+          num_query_groups=1,
+          qkv_use_bias=True,
+          output_proj_use_bias=True,
+          enable_kv_cache=False,
+          qkv_transpose_before_split=True,
+          qkv_fused_interleaved=False,
+          rotary_percentage=0.0,
+      ),
+      enable_hlfb=False,
+  )
+
+  mid_block_config = unet_cfg.MidBlock2DConfig(
+      in_channels=block_out_channels[-1],
+      normalization_config=norm_config,
+      activation_config=layers_cfg.ActivationConfig(
+          layers_cfg.ActivationType.SILU
+      ),
+      num_layers=1,
+      attention_block_config=att_config,
+  )
+
+  config = unet_cfg.AutoEncoderConfig(
+      in_channels=in_channels,
+      latent_channels=latent_channels,
+      out_channels=out_channels,
+      activation_config=layers_cfg.ActivationConfig(
+          layers_cfg.ActivationType.SILU
+      ),
+      block_out_channels=block_out_channels,
+      scaling_factor=scaling_factor,
+      layers_per_block=layers_per_block,
+      normalization_config=norm_config,
+      mid_block_config=mid_block_config,
+  )
+  return config
+
+
+def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
+  """Get fake configs for the Decoder of Stable Diffusion v1.5 for testing."""
+  in_channels = 3
+  latent_channels = 4
+  out_channels = 3
+  block_out_channels = [2, 4]
+  scaling_factor = 0.18215
+  layers_per_block = 2
+
+  norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+  )
+
+  att_config = unet_cfg.AttentionBlock2DConfig(
+      dim=block_out_channels[-1],
+      normalization_config=norm_config,
+      attention_config=layers_cfg.AttentionConfig(
+          num_heads=1,
+          head_dim=block_out_channels[-1],
           num_query_groups=1,
           qkv_use_bias=True,
           output_proj_use_bias=True,
@@ -293,12 +353,15 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
           qkv_fused_interleaved=False,
           rotary_percentage=0.0,
       ),
+      enable_hlfb=False,
   )
 
   mid_block_config = unet_cfg.MidBlock2DConfig(
       in_channels=block_out_channels[-1],
       normalization_config=norm_config,
-      activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
+      activation_config=layers_cfg.ActivationConfig(
+          layers_cfg.ActivationType.SILU
+      ),
       num_layers=1,
       attention_block_config=att_config,
   )
@@ -307,7 +370,9 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
       in_channels=in_channels,
       latent_channels=latent_channels,
       out_channels=out_channels,
-      activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
+      activation_config=layers_cfg.ActivationConfig(
+          layers_cfg.ActivationType.SILU
+      ),
       block_out_channels=block_out_channels,
       scaling_factor=scaling_factor,
       layers_per_block=layers_per_block,