ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of ai-edge-torch-nightly has been flagged as potentially problematic; see the registry advisory for details.

Files changed (121)
  1. ai_edge_torch/__init__.py +31 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +400 -0
  5. ai_edge_torch/convert/converter.py +202 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
  9. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +311 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +192 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
  27. ai_edge_torch/convert/to_channel_last_io.py +85 -0
  28. ai_edge_torch/debug/__init__.py +17 -0
  29. ai_edge_torch/debug/culprit.py +464 -0
  30. ai_edge_torch/debug/test/__init__.py +14 -0
  31. ai_edge_torch/debug/test/test_culprit.py +133 -0
  32. ai_edge_torch/debug/test/test_search_model.py +50 -0
  33. ai_edge_torch/debug/utils.py +48 -0
  34. ai_edge_torch/experimental/__init__.py +14 -0
  35. ai_edge_torch/generative/__init__.py +14 -0
  36. ai_edge_torch/generative/examples/__init__.py +14 -0
  37. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  39. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  40. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  42. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
  45. ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
  46. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
  47. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
  48. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
  49. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
  50. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
  51. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  52. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
  54. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
  55. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
  56. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
  57. ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
  58. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  59. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  60. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  61. ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
  62. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  63. ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
  64. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
  65. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  66. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  67. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  68. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  69. ai_edge_torch/generative/fx_passes/__init__.py +31 -0
  70. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
  71. ai_edge_torch/generative/layers/__init__.py +14 -0
  72. ai_edge_torch/generative/layers/attention.py +354 -0
  73. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  74. ai_edge_torch/generative/layers/builder.py +131 -0
  75. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  76. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  77. ai_edge_torch/generative/layers/model_config.py +158 -0
  78. ai_edge_torch/generative/layers/normalization.py +62 -0
  79. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  80. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
  81. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  82. ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
  83. ai_edge_torch/generative/layers/unet/builder.py +47 -0
  84. ai_edge_torch/generative/layers/unet/model_config.py +269 -0
  85. ai_edge_torch/generative/quantize/__init__.py +14 -0
  86. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  87. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
  88. ai_edge_torch/generative/quantize/example.py +45 -0
  89. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  90. ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
  91. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  92. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  93. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  94. ai_edge_torch/generative/test/__init__.py +14 -0
  95. ai_edge_torch/generative/test/loader_test.py +80 -0
  96. ai_edge_torch/generative/test/test_model_conversion.py +235 -0
  97. ai_edge_torch/generative/test/test_quantize.py +162 -0
  98. ai_edge_torch/generative/utilities/__init__.py +15 -0
  99. ai_edge_torch/generative/utilities/loader.py +328 -0
  100. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
  101. ai_edge_torch/generative/utilities/t5_loader.py +483 -0
  102. ai_edge_torch/hlfb/__init__.py +16 -0
  103. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  104. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  105. ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
  106. ai_edge_torch/hlfb/test/__init__.py +14 -0
  107. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  108. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  109. ai_edge_torch/model.py +142 -0
  110. ai_edge_torch/quantize/__init__.py +16 -0
  111. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  112. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  113. ai_edge_torch/quantize/quant_config.py +81 -0
  114. ai_edge_torch/testing/__init__.py +14 -0
  115. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  116. ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
  117. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
  118. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
  119. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
  120. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
  121. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
ai_edge_torch/experimental/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

ai_edge_torch/generative/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

ai_edge_torch/generative/examples/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

ai_edge_torch/generative/examples/gemma/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -0,0 +1,66 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+from pathlib import Path
+
+import torch
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.quantize import quant_recipes
+
+
+def convert_gemma_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """An example method for converting a Gemma 2B model to a multi-signature
+  tflite model.
+
+  Args:
+    checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+    prefill_seq_len (int, optional): The maximum size of the prefill input tensor.
+      Defaults to 512.
+    kv_cache_max_len (int, optional): The maximum size of the KV cache buffer,
+      including both prefill and decode. Defaults to 1024.
+    quantize (bool, optional): Whether the model should be quantized.
+      Defaults to True.
+  """
+  pytorch_model = gemma.build_2b_model(
+      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+  )
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_input_pos = torch.arange(0, prefill_seq_len)
+  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+      )
+      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
+      .convert(quant_config=quant_config)
+  )
+  edge_model.export(f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')
+
+
+if __name__ == '__main__':
+  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
+  convert_gemma_to_tflite(checkpoint_path)
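
The script above emits a two-signature TFLite flatbuffer. A minimal sketch of driving the exported file with TensorFlow Lite's signature runner (illustrative only; the signature input names are assigned at trace time, so they are read back from the runner rather than assumed):

import numpy as np
import tensorflow as tf

# Load the exported flatbuffer and bind the 'prefill' signature.
interpreter = tf.lite.Interpreter(
    model_path='/tmp/gemma_seq512_kv1024.tflite'
)
prefill = interpreter.get_signature_runner('prefill')

# Build zero-filled feeds from the signature's own metadata instead of
# hard-coding tensor names; a real caller would pass tokenized prompt ids.
feeds = {
    name: np.zeros(detail['shape'], dtype=detail['dtype'])
    for name, detail in prefill.get_input_details().items()
}
outputs = prefill(**feeds)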
ai_edge_torch/generative/examples/gemma/gemma.py
@@ -0,0 +1,174 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Example of building a Gemma model.
+
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ai_edge_torch.generative.layers.attention import TransformerBlock
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.builder as builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    ff_gate_proj="model.layers.{}.mlp.gate_proj",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    pre_ff_norm="model.layers.{}.post_attention_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+    lm_head=None,
+)
+
+
+class Gemma(nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    self.config = config
+    # Construct model layers.
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.lm_head = nn.Linear(
+        config.embedding_dim,
+        config.vocab_size,
+        bias=config.lm_head_use_bias,
+    )
+    # Gemma re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+    )
+    self.config = config
+
+  # The model's forward function takes in additional k/v cache tensors
+  # and returns the updated k/v cache tensors to the caller.
+  # This can be eliminated if we handle k/v cache updates inside the model itself.
+  @torch.inference_mode
+  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+    B, T = idx.size()
+    assert (
+        self.config.max_seq_len >= T
+    ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # Token embeddings of shape (b, t, n_embd).
+    x = self.tok_embedding(idx)
+    x = x * (self.config.embedding_dim**0.5)
+
+    for i, block in enumerate(self.transformer_blocks):
+      x = block(x, (cos, sin), mask, input_pos)
+
+    x = self.final_norm(x)
+    res = self.lm_head(x)  # (b, t, vocab_size)
+    return res
+
+
+def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=8,
+      num_query_groups=1,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=16384,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=256000,
+      num_layers=18,
+      max_seq_len=8192,
+      embedding_dim=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      pre_ff_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=False,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config_2b_for_test() -> cfg.ModelConfig:
+  config = get_model_config_2b()
+  config.num_layers = 2
+  return config
+
+
+def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
+  config = get_model_config_2b(**kwargs)
+  model = Gemma(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since the embedding and lm_head share the same weight, strict loading
+  # must be disabled.
+  loader.load(model, strict=False)
+  return model
+
+
+def define_and_run_2b() -> None:
+  kv_cache_max_len = 1024
+  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
+  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  print("running an inference")
+  print(model.forward(tokens, input_pos))
+
+
+if __name__ == "__main__":
+  define_and_run_2b()
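
build_2b_model above needs a real checkpoint on disk, but get_fake_model_config_2b_for_test makes it possible to smoke-test the graph without one. A minimal sketch, assuming only the modules shown in this diff (weights are randomly initialized, so the logits are meaningless):

import torch

from ai_edge_torch.generative.examples.gemma import gemma

# Two transformer layers instead of eighteen; note the 256k-row embedding
# table still dominates memory.
config = gemma.get_fake_model_config_2b_for_test()
model = gemma.Gemma(config)

tokens = torch.zeros((1, 8), dtype=torch.long)
input_pos = torch.arange(0, 8)
logits = model(tokens, input_pos)  # shape (1, 8, config.vocab_size)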
ai_edge_torch/generative/examples/phi2/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/phi2/convert_to_tflite.py
@@ -0,0 +1,64 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+from pathlib import Path
+
+import torch
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.phi2 import phi2
+from ai_edge_torch.generative.quantize import quant_recipes
+
+
+def convert_phi2_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """An example method for converting a Phi-2 model to a multi-signature
+  tflite model.
+
+  Args:
+    checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+    prefill_seq_len (int, optional): The maximum size of the prefill input tensor.
+      Defaults to 512.
+    kv_cache_max_len (int, optional): The maximum size of the KV cache buffer,
+      including both prefill and decode. Defaults to 1024.
+    quantize (bool, optional): Whether the model should be quantized.
+      Defaults to True.
+  """
+  pytorch_model = phi2.build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_input_pos = torch.arange(0, prefill_seq_len)
+  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+      )
+      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
+      .convert(quant_config=quant_config)
+  )
+  edge_model.export(f'/tmp/phi2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')
+
+
+if __name__ == '__main__':
+  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
+  convert_phi2_to_tflite(checkpoint_path)
ai_edge_torch/generative/examples/phi2/phi2.py
@@ -0,0 +1,164 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Example of building a phi-2 model from the Edge Generative API layers.
+
+
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ai_edge_torch.generative.layers.attention import TransformerBlock
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.builder as builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.fc1",
+    ff_down_proj="model.layers.{}.mlp.fc2",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.dense",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.final_layernorm",
+    lm_head="lm_head",
+)
+
+
+class Phi2(nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    self.config = config
+    # Construct model layers.
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+    )
+    self.config = config
+
+  # The model's forward function takes in additional k/v cache tensors
+  # and returns the updated k/v cache tensors to the caller.
+  # This can be eliminated if we handle k/v cache updates inside the model itself.
+  @torch.inference_mode
+  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+    B, T = idx.size()
+    assert (
+        self.config.max_seq_len >= T
+    ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # Forward the model itself.
+    x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
+
+    for i, block in enumerate(self.transformer_blocks):
+      x = block(x, (cos, sin), mask, input_pos)
+
+    x = self.final_norm(x)
+    res = self.lm_head(x)  # (b, t, vocab_size)
+    return res
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=32,
+      num_query_groups=32,
+      rotary_percentage=0.4,
+      qkv_use_bias=True,
+      output_proj_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=10240,
+      use_bias=True,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  config = cfg.ModelConfig(
+      vocab_size=51200,
+      num_layers=32,
+      max_seq_len=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      embedding_dim=2560,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=True,
+      lm_head_use_bias=True,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config_for_test() -> cfg.ModelConfig:
+  config = get_model_config()
+  config.num_layers = 2
+  return config
+
+
+def build_model(checkpoint_path, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = Phi2(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(model)
+  return model
+
+
+def define_and_run() -> None:
+  kv_cache_max_len = 1024
+  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/phi2")
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  print("running an inference")
+  print(model.forward(tokens, input_pos))
+
+
+if __name__ == "__main__":
+  define_and_run()
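
rotary_percentage=0.4 above is Phi-2's partial-RoPE split: only the leading slice of each attention head receives rotary embeddings, as the dim argument passed to build_rope_cache shows. The arithmetic, assuming head_dim is embedding_dim // num_heads:

embedding_dim, num_heads = 2560, 32
head_dim = embedding_dim // num_heads  # 80 dims per head
rotary_dim = int(0.4 * head_dim)       # RoPE rotates the first 32 dims
pass_through = head_dim - rotary_dim   # the remaining 48 dims are untouched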
ai_edge_torch/generative/examples/stable_diffusion/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/stable_diffusion/attention.py
@@ -0,0 +1,106 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import math
+
+import torch
+from torch import _decomp
+from torch import nn
+from torch._prims_common import mask_tensor
+from torch._prims_common.wrappers import out_wrapper
+from torch.nn import functional as F
+
+
+def triu(a):
+  h, w = a.shape[-2:]
+  mask = (
+      torch.arange(w, device=a.device).unsqueeze(-2)
+      - torch.arange(h, device=a.device).unsqueeze(-1)
+  ) >= 1
+  mask = torch.broadcast_to(mask, a.shape)
+  return torch.ops.aten.logical_and(a, mask).contiguous()
+
+
+# _decomp.decomposition_table[torch.ops.aten.triu.default] = triu
+
+
+class SelfAttention(nn.Module):
+
+  def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
+    super().__init__()
+    self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
+    self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
+    self.n_heads = n_heads
+    self.d_head = d_embed // n_heads
+
+  def forward(self, x, causal_mask=False):
+    input_shape = x.shape
+    batch_size, sequence_length, d_embed = input_shape
+    interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)
+
+    q, k, v = self.in_proj(x).chunk(3, dim=-1)
+
+    q = q.view(interim_shape).transpose(1, 2)
+    k = k.view(interim_shape).transpose(1, 2)
+    v = v.view(interim_shape).transpose(1, 2)
+
+    weight = q @ k.transpose(-1, -2)
+    if causal_mask:
+      # mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
+      mask = triu(torch.ones_like(weight, dtype=torch.bool))
+      weight.masked_fill_(mask, -torch.inf)
+    weight /= math.sqrt(self.d_head)
+    weight = F.softmax(weight, dim=-1)
+
+    output = weight @ v
+    output = output.transpose(1, 2)
+    output = output.reshape(input_shape)
+    output = self.out_proj(output)
+    return output
+
+
+class CrossAttention(nn.Module):
+
+  def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
+    super().__init__()
+    self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
+    self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
+    self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
+    self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
+    self.n_heads = n_heads
+    self.d_head = d_embed // n_heads
+
+  def forward(self, x, y):
+    input_shape = x.shape
+    batch_size, sequence_length, d_embed = input_shape
+    interim_shape = (batch_size, -1, self.n_heads, self.d_head)
+
+    q = self.q_proj(x)
+    k = self.k_proj(y)
+    v = self.v_proj(y)
+
+    q = q.view(interim_shape).transpose(1, 2)
+    k = k.view(interim_shape).transpose(1, 2)
+    v = v.view(interim_shape).transpose(1, 2)
+
+    weight = q @ k.transpose(-1, -2)
+    weight /= math.sqrt(self.d_head)
+    weight = F.softmax(weight, dim=-1)
+
+    output = weight @ v
+    output = output.transpose(1, 2).contiguous()
+    output = output.view(input_shape)
+    output = self.out_proj(output)
+    return output
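
The hand-rolled triu above rebuilds the causal mask from arange comparisons so the graph lowers to ops the converter handles; the commented-out line shows how it could be registered as a decomposition for aten.triu. On boolean input it matches the builtin Tensor.triu(diagonal=1), which a quick check confirms:

import torch

from ai_edge_torch.generative.examples.stable_diffusion.attention import triu

# The arange-based mask keeps exactly the strictly-upper-triangular entries,
# matching triu(1) on a batched boolean tensor.
a = torch.ones(2, 3, 4, 4, dtype=torch.bool)
assert torch.equal(triu(a), a.triu(1))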