ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.

Files changed (91)
  1. ai_edge_torch/__init__.py +30 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +330 -0
  5. ai_edge_torch/convert/converter.py +171 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
  9. ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +273 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +171 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/debug/__init__.py +16 -0
  27. ai_edge_torch/debug/culprit.py +423 -0
  28. ai_edge_torch/debug/test/__init__.py +14 -0
  29. ai_edge_torch/debug/test/test_culprit.py +133 -0
  30. ai_edge_torch/debug/utils.py +48 -0
  31. ai_edge_torch/experimental/__init__.py +14 -0
  32. ai_edge_torch/generative/__init__.py +14 -0
  33. ai_edge_torch/generative/examples/__init__.py +14 -0
  34. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  35. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  36. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  37. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  39. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  40. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  42. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  43. ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
  44. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
  46. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  47. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  48. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  49. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  50. ai_edge_torch/generative/layers/__init__.py +14 -0
  51. ai_edge_torch/generative/layers/attention.py +288 -0
  52. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  53. ai_edge_torch/generative/layers/builder.py +103 -0
  54. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  55. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  56. ai_edge_torch/generative/layers/model_config.py +135 -0
  57. ai_edge_torch/generative/layers/normalization.py +62 -0
  58. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  59. ai_edge_torch/generative/quantize/__init__.py +14 -0
  60. ai_edge_torch/generative/quantize/example.py +45 -0
  61. ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
  62. ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
  63. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  64. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  65. ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
  66. ai_edge_torch/generative/test/__init__.py +14 -0
  67. ai_edge_torch/generative/test/test_model_conversion.py +201 -0
  68. ai_edge_torch/generative/test/test_quantize.py +109 -0
  69. ai_edge_torch/generative/utilities/__init__.py +15 -0
  70. ai_edge_torch/generative/utilities/loader.py +290 -0
  71. ai_edge_torch/generative/utilities/t5_loader.py +467 -0
  72. ai_edge_torch/hlfb/__init__.py +16 -0
  73. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  74. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  75. ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
  76. ai_edge_torch/hlfb/test/__init__.py +14 -0
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  79. ai_edge_torch/model.py +134 -0
  80. ai_edge_torch/quantize/__init__.py +16 -0
  81. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  83. ai_edge_torch/quantize/quant_config.py +85 -0
  84. ai_edge_torch/testing/__init__.py +14 -0
  85. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  86. ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
  87. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
  88. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
  89. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
  90. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
  91. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -0,0 +1,66 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+from pathlib import Path
+
+import torch
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.quantize import quant_recipes
+
+
+def convert_gemma_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """An example method for converting a Gemma 2B model to a multi-signature
+  TFLite model.
+
+  Args:
+    checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+      Defaults to 512.
+    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+      including both prefill and decode. Defaults to 1024.
+    quantize (bool, optional): Whether the model should be quantized.
+      Defaults to True.
+  """
+  pytorch_model = gemma.build_2b_model(
+      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+  )
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_input_pos = torch.arange(0, prefill_seq_len)
+  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+
+  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+      )
+      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
+      .convert(quant_config=quant_config)
+  )
+  edge_model.export(f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')
+
+
+if __name__ == '__main__':
+  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
+  convert_gemma_to_tflite(checkpoint_path)
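
Both signatures land in a single .tflite file. The sketch below is illustrative and not part of this package: it drives the exported model with TensorFlow's tf.lite.Interpreter signature runners, and since the converter assigns the signature input names, it matches arguments to inputs by shape and dtype rather than hard-coding names. The int64 dtypes are an assumption and may need to be int32 depending on how torch.long is lowered.

# A minimal sketch (not from this package) of running the exported
# multi-signature model with TensorFlow's TFLite interpreter.
import numpy as np
import tensorflow as tf

def run_signature(runner, arrays):
  # Map each array to the signature input whose shape and dtype match,
  # since converter-assigned input names are not guaranteed.
  kwargs = {}
  for name, detail in runner.get_input_details().items():
    for a in arrays:
      if tuple(detail['shape']) == a.shape and np.dtype(detail['dtype']) == a.dtype:
        kwargs[name] = a
  return runner(**kwargs)

interpreter = tf.lite.Interpreter(
    model_path='/tmp/gemma_seq512_kv1024.tflite'  # path produced above
)
print(interpreter.get_signature_list())  # expect 'prefill' and 'decode'

prefill = interpreter.get_signature_runner('prefill')
decode = interpreter.get_signature_runner('decode')

# Shapes mirror the trace tensors above: prefill the whole prompt window,
# then decode one token at the next position.
tokens = np.zeros((1, 512), dtype=np.int64)
positions = np.arange(512, dtype=np.int64)
run_signature(prefill, [tokens, positions])

next_token = np.array([[0]], dtype=np.int64)
next_pos = np.array([512], dtype=np.int64)
logits = run_signature(decode, [next_token, next_pos])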
ai_edge_torch/generative/examples/gemma/gemma.py
@@ -0,0 +1,174 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Example of building a Gemma model.
+
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ai_edge_torch.generative.layers.attention import TransformerBlock
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.builder as builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    ff_gate_proj="model.layers.{}.mlp.gate_proj",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    pre_ff_norm="model.layers.{}.post_attention_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+    lm_head=None,
+)
+
+
+class Gemma(nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    self.config = config
+    # Construct model layers.
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.lm_head = nn.Linear(
+        config.embedding_dim,
+        config.vocab_size,
+        bias=config.lm_head_use_bias,
+    )
+    # Gemma re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+    )
+    self.config = config
+
+  # The model's forward function takes in additional k/v cache tensors
+  # and returns the updated k/v cache tensors to the caller.
+  # This can be eliminated if we handle k/v cache updates inside the model itself.
+  @torch.inference_mode
+  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+    B, T = idx.size()
+    assert (
+        self.config.max_seq_len >= T
+    ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(idx)
+    x = x * (self.config.embedding_dim**0.5)
+
+    for i, block in enumerate(self.transformer_blocks):
+      x = block(x, (cos, sin), mask, input_pos)
+
+    x = self.final_norm(x)
+    res = self.lm_head(x)  # (b, t, vocab_size)
+    return res
+
+
+def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=8,
+      num_query_groups=1,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationType.GELU_TANH,
+      intermediate_size=16384,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=256000,
+      num_layers=18,
+      max_seq_len=8192,
+      embedding_dim=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      pre_ff_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=False,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config_2b_for_test() -> cfg.ModelConfig:
+  config = get_model_config_2b()
+  config.num_layers = 2
+  return config
+
+
+def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
+  config = get_model_config_2b(**kwargs)
+  model = Gemma(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  return model
+
+
+def define_and_run_2b() -> None:
+  kv_cache_max_len = 1024
+  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
+  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  print("running an inference")
+  print(model.forward(tokens, input_pos))
+
+
+if __name__ == "__main__":
+  define_and_run_2b()
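
One detail worth noting in gemma.py is the weight tying: assigning self.lm_head.weight.data from the embedding makes the output projection share storage with the token table, which is also why build_2b_model loads with strict=False. A hypothetical smoke test (not shipped with the package) that verifies this without a checkpoint; it still allocates the full 256000x2048 embedding, so it needs a few GB of RAM:

# Hypothetical smoke test: build the 2-layer fake config, confirm the tied
# weights share storage, and run one randomly initialized forward pass.
import torch
from ai_edge_torch.generative.examples.gemma import gemma

config = gemma.get_fake_model_config_2b_for_test()
model = gemma.Gemma(config)

# The embedding and lm_head point at the same underlying storage.
assert model.lm_head.weight.data_ptr() == model.tok_embedding.weight.data_ptr()

tokens = torch.zeros((1, 8), dtype=torch.long)
input_pos = torch.arange(0, 8)
logits = model(tokens, input_pos)
print(logits.shape)  # torch.Size([1, 8, 256000])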
ai_edge_torch/generative/examples/phi2/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/phi2/convert_to_tflite.py
@@ -0,0 +1,64 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+from pathlib import Path
+
+import torch
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.phi2 import phi2
+from ai_edge_torch.generative.quantize import quant_recipes
+
+
+def convert_phi2_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """An example method for converting a Phi-2 model to a multi-signature
+  TFLite model.
+
+  Args:
+    checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+      Defaults to 512.
+    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+      including both prefill and decode. Defaults to 1024.
+    quantize (bool, optional): Whether the model should be quantized.
+      Defaults to True.
+  """
+  pytorch_model = phi2.build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_input_pos = torch.arange(0, prefill_seq_len)
+  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+
+  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+      )
+      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
+      .convert(quant_config=quant_config)
+  )
+  edge_model.export(f'/tmp/phi2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')
+
+
+if __name__ == '__main__':
+  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
+  convert_phi2_to_tflite(checkpoint_path)
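
As with the Gemma script, the entry point is a plain function, so it can be invoked with non-default sizes. An illustrative call (the checkpoint path below is hypothetical):

# Illustrative invocation; the checkpoint directory is a placeholder.
# quantize=False skips the dynamic-int8 recipe and keeps float32 weights.
from ai_edge_torch.generative.examples.phi2 import convert_to_tflite

convert_to_tflite.convert_phi2_to_tflite(
    '/path/to/phi2',       # hypothetical local checkpoint directory
    prefill_seq_len=256,   # smaller prefill signature
    kv_cache_max_len=512,  # combined prefill + decode budget
    quantize=False,
)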
ai_edge_torch/generative/examples/phi2/phi2.py
@@ -0,0 +1,164 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Example of building phi-2 model from the Edge Generative API layers.
+
+
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ai_edge_torch.generative.layers.attention import TransformerBlock
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.builder as builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.fc1",
+    ff_down_proj="model.layers.{}.mlp.fc2",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.dense",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.final_layernorm",
+    lm_head="lm_head",
+)
+
+
+class Phi2(nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    self.config = config
+    # Construct model layers.
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+    )
+    self.config = config
+
+  # The model's forward function takes in additional k/v cache tensors
+  # and returns the updated k/v cache tensors to the caller.
+  # This can be eliminated if we handle k/v cache updates inside the model itself.
+  @torch.inference_mode
+  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+    B, T = idx.size()
+    assert (
+        self.config.max_seq_len >= T
+    ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # forward the model itself
+    x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
+
+    for i, block in enumerate(self.transformer_blocks):
+      x = block(x, (cos, sin), mask, input_pos)
+
+    x = self.final_norm(x)
+    res = self.lm_head(x)  # (b, t, vocab_size)
+    return res
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=32,
+      num_query_groups=32,
+      rotary_percentage=0.4,
+      qkv_use_bias=True,
+      output_proj_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationType.GELU_TANH,
+      intermediate_size=10240,
+      use_bias=True,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  config = cfg.ModelConfig(
+      vocab_size=51200,
+      num_layers=32,
+      max_seq_len=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      embedding_dim=2560,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=True,
+      lm_head_use_bias=True,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config_for_test() -> cfg.ModelConfig:
+  config = get_model_config()
+  config.num_layers = 2
+  return config
+
+
+def build_model(checkpoint_path, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = Phi2(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(model)
+  return model
+
+
+def define_and_run() -> None:
+  kv_cache_max_len = 1024
+  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/phi2")
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  print("running an inference")
+  print(model.forward(tokens, input_pos))
+
+
+if __name__ == "__main__":
+  define_and_run()
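
Unlike Gemma's rotary_percentage=1.0, Phi-2 rotates only 40% of each attention head. A quick check of the arithmetic behind the rope cache's dim argument, assuming head_dim is derived as embedding_dim // num_heads (the config above doesn't show the derivation explicitly):

# Worked check of the partial-rotary dimensions used in get_model_config().
embedding_dim, num_heads, rotary_percentage = 2560, 32, 0.4
head_dim = embedding_dim // num_heads          # 80
rope_dim = int(rotary_percentage * head_dim)   # 32 of the 80 dims get RoPE
print(head_dim, rope_dim)                      # 80 32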
ai_edge_torch/generative/examples/t5/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/t5/convert_to_tflite.py
@@ -0,0 +1,135 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.t5 import t5
+from ai_edge_torch.generative.quantize import quant_recipes
+
+
+# TODO(haoliang): Clean this up once the 2-sig model is validated e2e.
+def convert_t5_to_tflite_singlesig(checkpoint_path: str):
+  pytorch_model = t5.build_t5_model(checkpoint_path)
+
+  # encoder
+  seq_len = 512
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prompt_e_token = [1, 2, 3, 4, 5, 6]
+  prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
+      prompt_e_token, dtype=torch.long
+  )
+  prefill_e_input_pos = torch.arange(0, seq_len)
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prompt_d_token = [1, 2, 3, 4, 5, 6]
+  prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
+      prompt_d_token, dtype=torch.long
+  )
+  prefill_d_input_pos = torch.arange(0, seq_len)
+
+  # decoder
+  decode_token = torch.tensor([[1]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  decode_d_token = torch.tensor([[1]], dtype=torch.long)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+
+  # Pad mask for self attention only on "real" tokens.
+  # Pad with `-inf` for any token indices that aren't desired.
+  pad_mask = torch.zeros([seq_len], dtype=torch.float32)
+
+  edge_model = ai_edge_torch.signature(
+      'decode',
+      pytorch_model,
+      (
+          prefill_e_tokens,
+          prefill_e_input_pos,
+          decode_d_token,
+          decode_d_input_pos,
+          pad_mask,
+      ),
+  ).convert()
+
+  edge_model.export('/tmp/t5_encode_decode.tflite')
+
+
+def convert_t5_to_tflite_multisig(checkpoint_path: str):
+  config = t5.get_model_config_t5()
+  embedding_layer = torch.nn.Embedding(
+      config.vocab_size, config.embedding_dim, padding_idx=0
+  )
+  t5_encoder_model = t5.build_t5_encoder_model(config, embedding_layer, checkpoint_path)
+  t5_decoder_model = t5.build_t5_decoder_model(config, embedding_layer, checkpoint_path)
+
+  # encoder
+  seq_len = 512
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prompt_e_token = [1, 2, 3, 4, 5, 6]
+  prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
+      prompt_e_token, dtype=torch.long
+  )
+  prefill_e_input_pos = torch.arange(0, seq_len)
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prompt_d_token = [1, 2, 3, 4, 5, 6]
+  prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
+      prompt_d_token, dtype=torch.long
+  )
+  prefill_d_input_pos = torch.arange(0, seq_len)
+
+  # decoder
+  decode_token = torch.tensor([[1]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  decode_d_token = torch.tensor([[1]], dtype=torch.long)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+
+  # Pad mask for self attention only on "real" tokens.
+  # Pad with `-inf` for any token indices that aren't desired.
+  pad_mask = torch.zeros([seq_len], dtype=torch.float32)
+  hidden_states = torch.zeros((1, 512, 768), dtype=torch.float32)
+  quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
+
+  edge_model = (
+      ai_edge_torch.signature(
+          'encode',
+          t5_encoder_model,
+          (
+              prefill_e_tokens,
+              prefill_e_input_pos,
+              pad_mask,
+          ),
+      )
+      .signature(
+          'decode',
+          t5_decoder_model,
+          (
+              hidden_states,
+              decode_d_token,
+              decode_d_input_pos,
+              pad_mask,
+          ),
+      )
+      .convert(quant_config=quant_config)
+  )
+
+  edge_model.export('/tmp/t5_encode_decode_2_sigs.tflite')
+
+
+if __name__ == '__main__':
+  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/t5')
+  # convert_t5_to_tflite_singlesig(checkpoint_path)
+  convert_t5_to_tflite_multisig(checkpoint_path)
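
The multi-signature export makes the encoder/decoder handoff explicit: 'encode' returns the (1, 512, 768) hidden states traced above, and 'decode' takes them back as its first argument. A hypothetical runtime sketch (not part of this package), reusing the run_signature shape/dtype-matching helper from the Gemma example earlier; as before, the int64 dtypes are an assumption about how torch.long is lowered:

# Hypothetical driver for the two-signature T5 export.
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='/tmp/t5_encode_decode_2_sigs.tflite')
encode = interpreter.get_signature_runner('encode')
decode = interpreter.get_signature_runner('decode')

enc_tokens = np.zeros((1, 512), dtype=np.int64)
enc_positions = np.arange(512, dtype=np.int64)
pad_mask = np.zeros((512,), dtype=np.float32)

# The encoder's single output is the (1, 512, 768) hidden_states tensor.
enc_out = run_signature(encode, [enc_tokens, enc_positions, pad_mask])
hidden_states = next(iter(enc_out.values()))

dec_token = np.array([[1]], dtype=np.int64)
dec_position = np.array([0], dtype=np.int64)
logits = run_signature(decode, [hidden_states, dec_token, dec_position, pad_mask])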