ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl → 0.3.0.dev20240926__py3-none-any.whl

Files changed (169)
  1. ai_edge_torch/__init__.py +5 -4
  2. ai_edge_torch/_convert/conversion.py +112 -0
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +94 -48
  5. ai_edge_torch/_convert/fx_passes/__init__.py +22 -0
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +107 -44
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +23 -20
  8. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +5 -6
  9. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -1
  10. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +39 -9
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +17 -8
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +9 -8
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +31 -18
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +2 -2
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +34 -24
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  19. ai_edge_torch/_convert/signature.py +66 -0
  20. ai_edge_torch/_convert/test/test_convert.py +495 -0
  21. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  22. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  23. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -5
  24. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +10 -3
  25. ai_edge_torch/config.py +27 -0
  26. ai_edge_torch/conftest.py +20 -0
  27. ai_edge_torch/debug/culprit.py +72 -40
  28. ai_edge_torch/debug/test/test_culprit.py +7 -5
  29. ai_edge_torch/debug/test/test_search_model.py +8 -7
  30. ai_edge_torch/debug/utils.py +14 -3
  31. ai_edge_torch/fx_pass_base.py +101 -0
  32. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +68 -0
  33. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +68 -0
  34. ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +69 -55
  35. ai_edge_torch/generative/examples/gemma/gemma2.py +267 -0
  36. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  37. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +57 -0
  38. ai_edge_torch/generative/examples/gemma/verify_util.py +143 -0
  39. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +68 -0
  40. ai_edge_torch/generative/examples/openelm/openelm.py +206 -0
  41. ai_edge_torch/generative/examples/openelm/verify.py +64 -0
  42. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  43. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  44. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +68 -0
  45. ai_edge_torch/generative/examples/{phi2 → phi}/phi2.py +70 -51
  46. ai_edge_torch/generative/examples/phi/phi3.py +286 -0
  47. ai_edge_torch/generative/examples/phi/verify.py +65 -0
  48. ai_edge_torch/generative/examples/phi/verify_phi3.py +70 -0
  49. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  50. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +68 -0
  51. ai_edge_torch/generative/examples/smollm/smollm.py +101 -0
  52. ai_edge_torch/generative/examples/smollm/verify.py +62 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  54. ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -13
  55. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +27 -14
  56. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +74 -9
  57. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +179 -37
  58. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  59. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +83 -58
  60. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  61. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  62. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  63. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  64. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  65. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  66. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +28 -25
  67. ai_edge_torch/generative/examples/t5/t5.py +208 -159
  68. ai_edge_torch/generative/examples/t5/t5_attention.py +45 -30
  69. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  70. ai_edge_torch/generative/examples/test_models/toy_model.py +69 -41
  71. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +50 -64
  72. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  73. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +41 -39
  74. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +67 -54
  75. ai_edge_torch/generative/examples/tiny_llama/verify.py +64 -0
  76. ai_edge_torch/generative/fx_passes/__init__.py +4 -5
  77. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +10 -7
  78. ai_edge_torch/generative/layers/attention.py +141 -102
  79. ai_edge_torch/generative/layers/attention_utils.py +53 -12
  80. ai_edge_torch/generative/layers/builder.py +37 -7
  81. ai_edge_torch/generative/layers/feed_forward.py +39 -14
  82. ai_edge_torch/generative/layers/kv_cache.py +162 -50
  83. ai_edge_torch/generative/layers/model_config.py +84 -30
  84. ai_edge_torch/generative/layers/normalization.py +185 -7
  85. ai_edge_torch/generative/layers/rotary_position_embedding.py +6 -4
  86. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +48 -21
  87. ai_edge_torch/generative/layers/unet/blocks_2d.py +136 -77
  88. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  89. ai_edge_torch/generative/layers/unet/model_config.py +17 -15
  90. ai_edge_torch/generative/quantize/example.py +7 -8
  91. ai_edge_torch/generative/quantize/quant_recipe.py +10 -7
  92. ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -1
  93. ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
  94. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  95. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +9 -7
  96. ai_edge_torch/generative/test/test_model_conversion.py +124 -188
  97. ai_edge_torch/generative/test/test_model_conversion_large.py +251 -0
  98. ai_edge_torch/generative/test/test_quantize.py +76 -60
  99. ai_edge_torch/generative/test/utils.py +54 -0
  100. ai_edge_torch/generative/utilities/converter.py +82 -0
  101. ai_edge_torch/generative/utilities/loader.py +120 -57
  102. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +165 -57
  103. ai_edge_torch/generative/utilities/t5_loader.py +110 -81
  104. ai_edge_torch/generative/utilities/verifier.py +247 -0
  105. ai_edge_torch/hlfb/__init__.py +1 -1
  106. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -7
  107. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  108. ai_edge_torch/hlfb/mark_pattern/pattern.py +39 -30
  109. ai_edge_torch/hlfb/test/test_mark_pattern.py +46 -20
  110. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +24 -11
  111. ai_edge_torch/lowertools/__init__.py +18 -0
  112. ai_edge_torch/lowertools/_shim.py +80 -0
  113. ai_edge_torch/lowertools/common_utils.py +142 -0
  114. ai_edge_torch/lowertools/odml_torch_utils.py +255 -0
  115. ai_edge_torch/lowertools/test_utils.py +60 -0
  116. ai_edge_torch/lowertools/torch_xla_utils.py +284 -0
  117. ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +29 -14
  118. ai_edge_torch/model.py +53 -18
  119. ai_edge_torch/odml_torch/__init__.py +20 -0
  120. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  121. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  122. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  123. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  124. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  125. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  126. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  127. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  128. ai_edge_torch/odml_torch/export.py +357 -0
  129. ai_edge_torch/odml_torch/export_utils.py +168 -0
  130. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  131. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +150 -0
  132. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  133. ai_edge_torch/odml_torch/lowerings/__init__.py +25 -0
  134. ai_edge_torch/odml_torch/lowerings/_basic.py +258 -0
  135. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  136. ai_edge_torch/odml_torch/lowerings/_convolution.py +241 -0
  137. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +252 -0
  138. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  139. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  140. ai_edge_torch/odml_torch/lowerings/registry.py +96 -0
  141. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  142. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  143. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  144. ai_edge_torch/quantize/pt2e_quantizer.py +52 -24
  145. ai_edge_torch/quantize/pt2e_quantizer_utils.py +43 -23
  146. ai_edge_torch/quantize/quant_config.py +13 -9
  147. ai_edge_torch/testing/model_coverage/model_coverage.py +29 -16
  148. ai_edge_torch/version.py +16 -0
  149. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/METADATA +7 -3
  150. ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/RECORD +177 -0
  151. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/WHEEL +1 -1
  152. ai_edge_torch/convert/conversion.py +0 -117
  153. ai_edge_torch/convert/conversion_utils.py +0 -400
  154. ai_edge_torch/convert/fx_passes/__init__.py +0 -59
  155. ai_edge_torch/convert/fx_passes/_pass_base.py +0 -49
  156. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +0 -37
  157. ai_edge_torch/convert/test/test_convert.py +0 -311
  158. ai_edge_torch/convert/test/test_convert_composites.py +0 -192
  159. ai_edge_torch/convert/test/test_convert_multisig.py +0 -139
  160. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +0 -66
  161. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -64
  162. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -161
  163. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  164. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +0 -121
  165. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  166. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  167. /ai_edge_torch/generative/examples/{phi2 → openelm}/__init__.py +0 -0
  168. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/LICENSE +0 -0
  169. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/verify_util.py
@@ -0,0 +1,143 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Utility functions to verify the reauthored Gemma model."""
+
+ import dataclasses
+ import logging
+ import os
+ from typing import List, Tuple
+
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ from ai_edge_torch.generative.utilities import verifier
+ from gemma import config as gemma_config
+ from gemma import model as gemma_model
+ import torch
+
+
+ @dataclasses.dataclass
+ class _Output:
+   logits: torch.Tensor
+
+
+ class GemmaWrapper(verifier.ModelWrapper):
+   """Gemma model wrapper for verification.
+
+   Verifier calls model.forward() with the maximum sequence length (1024),
+   expecting the output to have a 'logits' field, while Gemma takes input
+   tokens of the actual length and returns logits in a tuple.
+
+   Verifier runs the tokenizer before model.generate(), while Gemma runs the
+   tokenizer inside model.generate().
+   """
+
+   def __init__(self, model: torch.nn.Module, max_new_tokens: int):
+     super().__init__(model)
+     self.max_new_tokens = max_new_tokens
+
+   def _get_actual_input_len(self, tokens: torch.Tensor) -> int:
+     for i in range(tokens.shape[1]):
+       if tokens[0, i] == 0:
+         return i
+     return tokens.shape[1]
+
+   def _get_kv_caches(
+       self, max_seq_len: int
+   ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+     config = self.model.config
+     cache_size = (1, max_seq_len, config.num_key_value_heads, config.head_dim)
+     cache = torch.zeros(cache_size)
+     return [
+         (cache.clone(), cache.clone()) for _ in range(config.num_hidden_layers)
+     ]
+
+   def forward(self, tokens: torch.Tensor) -> _Output:
+     """Forwards the model after reducing input tokens to the actual length."""
+     actual_input_len = self._get_actual_input_len(tokens)
+     input_pos = torch.arange(0, actual_input_len, dtype=torch.long)
+     mask_cache = attn_utils.build_causal_mask_cache(tokens.shape[1])
+     _, logits = self.model.forward(
+         input_token_ids=tokens[0, :actual_input_len].unsqueeze(0),
+         input_positions=input_pos,
+         kv_write_indices=None,
+         kv_caches=self._get_kv_caches(tokens.shape[1]),
+         mask=mask_cache.index_select(2, input_pos),
+         output_positions=input_pos,
+         temperatures=None,
+         top_ps=torch.tensor([1.0], dtype=torch.float),
+         top_ks=torch.tensor([1], dtype=torch.long),
+     )
+     return _Output(logits.float())
+
+   def generate(self, tokens: torch.Tensor) -> torch.Tensor:
+     """Generates the response after decoding the tokens into a string."""
+     prompts = self.model.tokenizer.decode(tokens[0].tolist())
+     response = self.model.generate(
+         prompts, device="cpu", output_len=self.max_new_tokens, top_k=1
+     )
+     return torch.tensor([self.model.tokenizer.encode(prompts + response)])
+
+
+ class TokenizerWrapper(torch.nn.Module):
+   """Tokenizer wrapper for verification.
+
+   Verifier expects the tokenizer to handle tokens as a torch.Tensor, while
+   the Gemma tokenizer expects tokens in a list.
+   """
+
+   def __init__(self, tokenizer: torch.nn.Module):
+     super().__init__()
+     self.tokenizer = tokenizer
+
+   def encode(self, text: str, **_) -> torch.Tensor:
+     """Adds one more dimension to the output of the tokenizer."""
+     return torch.tensor([self.tokenizer.encode(text)])
+
+   def decode(self, tokens: torch.Tensor) -> str:
+     """Decodes the token sequence after converting it to a list."""
+     return self.tokenizer.decode(tokens.tolist())
+
+
+ def verify_reauthored_gemma_model(
+     checkpoint: str,
+     variant: str,
+     reauthored_model: torch.nn.Module,
+     generate_prompts: List[str],
+     forward_input_ids: List[List[int]],
+     weight_filename: str = "model.ckpt",
+     tokenizer_filename: str = "tokenizer.model",
+     max_new_tokens: int = 20,
+     rtol: float = 1e-05,
+     atol: float = 1e-05,
+ ):
+   """Verifies the reauthored Gemma model against the original model."""
+   config = gemma_config.get_model_config(variant)
+   config.tokenizer = os.path.join(checkpoint, tokenizer_filename)
+   # Use float32 to be compatible with the reauthored model.
+   config.dtype = torch.float32
+
+   logging.info("Loading the original model from: %s", checkpoint)
+   original_model = gemma_model.GemmaForCausalLM(config).eval()
+   original_model.load_weights(os.path.join(checkpoint, weight_filename))
+
+   verifier.verify_reauthored_model(
+       original_model=GemmaWrapper(original_model, max_new_tokens),
+       reauthored_model=reauthored_model,
+       tokenizer=TokenizerWrapper(original_model.tokenizer),
+       generate_prompts=generate_prompts,
+       forward_input_ids=forward_input_ids,
+       rtol=rtol,
+       atol=atol,
+   )
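
Note: a minimal driver for verify_reauthored_gemma_model() might look like the sketch below. The gemma1.build_2b_model() helper, the "2b" variant string, and the token ids are illustrative assumptions, not confirmed by this diff; see verify_gemma1.py in this release for the canonical entry point.

# Sketch (assumptions noted above): verifying a reauthored Gemma 1 2B model.
import pathlib

from ai_edge_torch.generative.examples.gemma import gemma1  # Reauthoring module from this release.
from ai_edge_torch.generative.examples.gemma import verify_util

checkpoint = str(pathlib.Path.home() / "Downloads/llm_data/gemma-2b")
verify_util.verify_reauthored_gemma_model(
    checkpoint=checkpoint,
    variant="2b",  # Assumed variant string accepted by gemma_config.get_model_config().
    reauthored_model=gemma1.build_2b_model(checkpoint),  # Hypothetical builder name.
    generate_prompts=["What is the meaning of life?"],
    forward_input_ids=[[1, 2, 3, 4]],  # Arbitrary example token ids.
    max_new_tokens=30,
)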
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -0,0 +1,68 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of converting an OpenELM model to a multi-signature tflite model."""
+
+ import os
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.openelm import openelm
+ from ai_edge_torch.generative.utilities import converter
+
+ _CHECKPOINT_PATH = flags.DEFINE_string(
+     'checkpoint_path',
+     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
+     'The path to the model checkpoint, or directory holding the checkpoint.',
+ )
+ _TFLITE_PATH = flags.DEFINE_string(
+     'tflite_path',
+     '/tmp/',
+     'The tflite file path to export.',
+ )
+ _PREFILL_SEQ_LEN = flags.DEFINE_integer(
+     'prefill_seq_len',
+     1024,
+     'The maximum size of prefill input tensor.',
+ )
+ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+     'kv_cache_max_len',
+     1280,
+     'The maximum size of KV cache buffer, including both prefill and decode.',
+ )
+ _QUANTIZE = flags.DEFINE_bool(
+     'quantize',
+     True,
+     'Whether the model should be quantized.',
+ )
+
+
+ def main(_):
+   pytorch_model = openelm.build_model(
+       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+   )
+   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+   output_filename = f'openelm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+   converter.convert_to_tflite(
+       pytorch_model,
+       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+       prefill_seq_len=_PREFILL_SEQ_LEN.value,
+       quantize=_QUANTIZE.value,
+   )
+
+
+ if __name__ == '__main__':
+   app.run(main)
ai_edge_torch/generative/examples/openelm/openelm.py
@@ -0,0 +1,206 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of building an OpenELM model."""
+
+ from ai_edge_torch.generative.layers import attention
+ from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+ import torch
+ from torch import nn
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+     ff_up_proj="transformer.layers.{}.ffn.proj_1",
+     ff_down_proj="transformer.layers.{}.ffn.proj_2",
+     attn_fused_qkv_proj="transformer.layers.{}.attn.qkv_proj",
+     attn_query_norm="transformer.layers.{}.attn.q_norm",
+     attn_key_norm="transformer.layers.{}.attn.k_norm",
+     attn_output_proj="transformer.layers.{}.attn.out_proj",
+     pre_attn_norm="transformer.layers.{}.attn_norm",
+     pre_ff_norm="transformer.layers.{}.ffn_norm",
+     embedding="transformer.token_embeddings",
+     final_norm="transformer.norm",
+     lm_head=None,
+ )
+
+
+ class OpenELM(nn.Module):
+   """An OpenELM model built from the Edge Generative API layers."""
+
+   def __init__(self, config: cfg.ModelConfig):
+     super().__init__()
+
+     # Construct model layers.
+     self.tok_embedding = nn.Embedding(
+         config.vocab_size, config.embedding_dim, padding_idx=0
+     )
+     self.lm_head = nn.Linear(
+         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+     )
+     # OpenELM re-uses the embedding as the head projection layer.
+     self.lm_head.weight.data = self.tok_embedding.weight.data
+     self.transformer_blocks = nn.ModuleList(
+         attention.TransformerBlock(config.block_config(idx), config)
+         for idx in range(config.num_layers)
+     )
+     self.final_norm = builder.build_norm(
+         config.embedding_dim,
+         config.final_norm_config,
+     )
+     # OpenELM uses the same rotary_percentage and head_dim hyperparameters in
+     # every layer block, so use the first block's values.
+     attn_config = config.block_config(0).attn_config
+     self.rope_cache = attn_utils.build_rope_cache(
+         size=config.kv_cache_max,
+         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+         base=10_000,
+         condense_ratio=1,
+         dtype=torch.float32,
+         device=torch.device("cpu"),
+     )
+     self.mask_cache = attn_utils.build_causal_mask_cache(
+         size=config.kv_cache_max,
+         dtype=torch.float32,
+         device=torch.device("cpu"),
+     )
+     self.config = config
+
+   @torch.inference_mode
+   def forward(
+       self,
+       tokens: torch.Tensor,
+       input_pos: torch.Tensor,
+       kv_cache: kv_utils.KVCache,
+   ) -> dict[torch.Tensor, kv_utils.KVCache]:
+     _, seq_len = tokens.size()
+     assert self.config.max_seq_len >= seq_len, (
+         f"Cannot forward sequence of length {seq_len}, max seq length is only"
+         f" {self.config.max_seq_len}"
+     )
+     assert len(self.transformer_blocks) == len(kv_cache.caches), (
+         "The number of transformer blocks and the number of KV cache entries"
+         " must be the same."
+     )
+
+     cos, sin = self.rope_cache
+     cos = cos.index_select(0, input_pos)
+     sin = sin.index_select(0, input_pos)
+     mask = self.mask_cache.index_select(2, input_pos)
+     mask = mask[:, :, :, : self.config.kv_cache_max]
+
+     # Token embeddings of shape (b, t, n_embd).
+     x = self.tok_embedding(tokens)
+
+     updated_kv_entries = []
+     for i, block in enumerate(self.transformer_blocks):
+       kv_entry = kv_cache.caches[i] if kv_cache else None
+       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+       if kv_entry:
+         updated_kv_entries.append(kv_entry)
+     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+
+     x = self.final_norm(x)
+     logits = self.lm_head(x)  # (b, t, vocab_size)
+     return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+   """Returns the model config for an OpenELM model.
+
+   Args:
+     kv_cache_max_len (int): The maximum sequence length of the KV cache.
+       Default is 1024.
+
+   Returns:
+     The model config for an OpenELM model.
+   """
+   norm_config = cfg.NormalizationConfig(
+       type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
+   )
+   num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
+   num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
+
+   def make_divisible(v, d):
+     """Ensures that all layers have a channel number divisible by d."""
+     new_v = int(v + d / 2) // d * d
+     # Make sure that rounding down does not go down by more than 10%.
+     if new_v < 0.9 * v:
+       new_v += d
+     return new_v
+
+   # The intermediate size is computed as in
+   # https://huggingface.co/apple/OpenELM-3B/blob/main/modeling_openelm.py
+   def get_intermediate_size(idx: int) -> int:
+     return make_divisible((0.5 + 0.1 * idx) * 3072, 256)
+
+   def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+     return cfg.TransformerBlockConfig(
+         attn_config=cfg.AttentionConfig(
+             num_heads=num_heads[idx],
+             head_dim=128,
+             num_query_groups=num_query_groups[idx],
+             rotary_percentage=1.0,
+             qkv_transpose_before_split=True,
+             query_norm_config=norm_config,
+             key_norm_config=norm_config,
+         ),
+         ff_config=cfg.FeedForwardConfig(
+             type=cfg.FeedForwardType.SEQUENTIAL,
+             activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
+             intermediate_size=get_intermediate_size(idx),
+             pre_ff_norm_config=norm_config,
+         ),
+         pre_attention_norm_config=norm_config,
+     )
+
+   num_layers = 36
+   config = cfg.ModelConfig(
+       vocab_size=32000,
+       num_layers=num_layers,
+       max_seq_len=2048,
+       embedding_dim=3072,
+       kv_cache_max_len=kv_cache_max_len,
+       block_configs=[get_block_config(i) for i in range(num_layers)],
+       final_norm_config=norm_config,
+   )
+   return config
+
+
+ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+   config = get_model_config(kv_cache_max_len)
+   config.vocab_size = 128
+   config.num_layers = 2
+   config.max_seq_len = 2 * kv_cache_max_len
+   config.embedding_dim = 128
+   config.block_configs = config.block_configs[: config.num_layers]
+   for block_config in config.block_configs:
+     block_config.attn_config.num_heads = 3
+     block_config.attn_config.head_dim = 64
+     block_config.ff_config.intermediate_size = 128
+   return config
+
+
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+   config = get_model_config(**kwargs)
+   model = OpenELM(config)
+   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+   # Since the embedding and lm_head share the same weight, strict loading
+   # must be disabled.
+   loader.load(model, strict=False)
+   model.eval()
+   return model
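
Note: to illustrate the forward() contract above, a smoke test might look like this sketch. It assumes kv_utils.KVCache.from_model_config() builds a zeroed cache matching the config, in line with the reworked kv_cache.py (+162 -50) in this release; the model here is randomly initialized, so only shapes are checked.

# Sketch (assumptions noted above): exercising OpenELM.forward() with the test config.
import torch

from ai_edge_torch.generative.examples.openelm import openelm
from ai_edge_torch.generative.layers import kv_cache as kv_utils

config = openelm.get_fake_model_config()
model = openelm.OpenELM(config).eval()

tokens = torch.zeros((1, 10), dtype=torch.int)  # Zero-padded prompt of length 10.
input_pos = torch.arange(0, 10, dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(config)  # Assumed constructor; see kv_cache.py.

output = model.forward(tokens, input_pos, kv)
print(output["logits"].shape)  # Expected: (1, 10, config.vocab_size).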
ai_edge_torch/generative/examples/openelm/verify.py
@@ -0,0 +1,64 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored OpenELM-3B model."""
+
+ import logging
+ import pathlib
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.openelm import openelm
+ from ai_edge_torch.generative.utilities import verifier
+ import transformers
+
+
+ _PROMPTS = flags.DEFINE_multi_string(
+     "prompts",
+     "What is the meaning of life?",
+     "The input prompts to generate answers.",
+ )
+
+
+ def main(_):
+   checkpoint = "apple/OpenELM-3B"
+   logging.info("Loading the original model from: %s", checkpoint)
+   wrapper_model = verifier.ModelWrapper(
+       model=transformers.AutoModelForCausalLM.from_pretrained(
+           checkpoint, trust_remote_code=True
+       ),
+   )
+
+   # Locate the cached dir.
+   cached_config_file = transformers.utils.cached_file(
+       checkpoint, transformers.utils.CONFIG_NAME
+   )
+   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+   reauthored_model = openelm.build_model(reauthored_checkpoint)
+
+   tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
+   logging.info("Loading the tokenizer from: %s", tokenizer_checkpoint)
+   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
+
+   verifier.verify_reauthored_model(
+       original_model=wrapper_model,
+       reauthored_model=reauthored_model,
+       tokenizer=tokenizer,
+       generate_prompts=_PROMPTS.value,
+   )
+
+
+ if __name__ == "__main__":
+   app.run(main)
ai_edge_torch/generative/examples/phi/__init__.py
@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
@@ -0,0 +1,68 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of converting a Phi-3.5 model to a multi-signature tflite model."""
+
+ import os
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.phi import phi3
+ from ai_edge_torch.generative.utilities import converter
+
+ _CHECKPOINT_PATH = flags.DEFINE_string(
+     'checkpoint_path',
+     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
+     'The path to the model checkpoint, or directory holding the checkpoint.',
+ )
+ _TFLITE_PATH = flags.DEFINE_string(
+     'tflite_path',
+     '/tmp/',
+     'The tflite file path to export.',
+ )
+ _PREFILL_SEQ_LEN = flags.DEFINE_integer(
+     'prefill_seq_len',
+     1024,
+     'The maximum size of prefill input tensor.',
+ )
+ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+     'kv_cache_max_len',
+     1280,
+     'The maximum size of KV cache buffer, including both prefill and decode.',
+ )
+ _QUANTIZE = flags.DEFINE_bool(
+     'quantize',
+     True,
+     'Whether the model should be quantized.',
+ )
+
+
+ def main(_):
+   pytorch_model = phi3.build_model(
+       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+   )
+   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+   output_filename = f'phi3_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+   converter.convert_to_tflite(
+       pytorch_model,
+       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+       prefill_seq_len=_PREFILL_SEQ_LEN.value,
+       quantize=_QUANTIZE.value,
+   )
+
+
+ if __name__ == '__main__':
+   app.run(main)
ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -0,0 +1,68 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of converting a Phi-2 model to a multi-signature tflite model."""
+
+ import os
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.phi import phi2
+ from ai_edge_torch.generative.utilities import converter
+
+ _CHECKPOINT_PATH = flags.DEFINE_string(
+     'checkpoint_path',
+     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
+     'The path to the model checkpoint, or directory holding the checkpoint.',
+ )
+ _TFLITE_PATH = flags.DEFINE_string(
+     'tflite_path',
+     '/tmp/',
+     'The tflite file path to export.',
+ )
+ _PREFILL_SEQ_LEN = flags.DEFINE_integer(
+     'prefill_seq_len',
+     1024,
+     'The maximum size of prefill input tensor.',
+ )
+ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+     'kv_cache_max_len',
+     1280,
+     'The maximum size of KV cache buffer, including both prefill and decode.',
+ )
+ _QUANTIZE = flags.DEFINE_bool(
+     'quantize',
+     True,
+     'Whether the model should be quantized.',
+ )
+
+
+ def main(_):
+   pytorch_model = phi2.build_model(
+       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+   )
+   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+   output_filename = f'phi2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+   converter.convert_to_tflite(
+       pytorch_model,
+       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+       prefill_seq_len=_PREFILL_SEQ_LEN.value,
+       quantize=_QUANTIZE.value,
+   )
+
+
+ if __name__ == '__main__':
+   app.run(main)
+ app.run(main)