ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
  7. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  8. ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
  10. ai_edge_torch/generative/examples/t5/t5.py +35 -22
  11. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  12. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  13. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
  14. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  15. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
  16. ai_edge_torch/generative/layers/attention.py +77 -73
  17. ai_edge_torch/generative/layers/builder.py +5 -3
  18. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  19. ai_edge_torch/generative/layers/model_config.py +38 -19
  20. ai_edge_torch/generative/layers/normalization.py +158 -0
  21. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  22. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  23. ai_edge_torch/generative/test/test_loader.py +1 -1
  24. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  25. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  26. ai_edge_torch/generative/test/utils.py +54 -0
  27. ai_edge_torch/generative/utilities/loader.py +15 -15
  28. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  29. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  30. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  31. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  32. ai_edge_torch/version.py +1 -1
  33. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +39 -45
  35. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  36. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  38. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  39. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  40. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  42. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  43. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  44. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  45. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  46. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
  47. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
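
For orientation, here is a hedged sketch of the import-path changes implied by the renames above: the experimental example packages move to stable locations (e.g. examples/experimental/phi becomes examples/phi), and the experimental external KV cache is folded into the regular layers package (kv_cache.py). The module paths below come from the file list; the build_model entry point and the KVCache class name are assumptions carried over from the removed experimental code shown below, not verified against the new release.

# Hypothetical migration sketch (Python); not part of the diff itself.
# Old imports, as used by the removed experimental examples:
#   from ai_edge_torch.generative.examples.experimental.phi import phi2
#   from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
# New locations in 0.3.0.dev20240913, per the renames listed above:
from ai_edge_torch.generative.examples.phi import phi2
from ai_edge_torch.generative.layers import kv_cache as kv_utils

pytorch_model = phi2.build_model("/path/to/phi2", kv_cache_max_len=1024)  # entry point assumed
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)  # class name assumed; was EKVCache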
ai_edge_torch/generative/examples/experimental/gemma/gemma.py (deleted)
@@ -1,219 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- # Example of building a Gemma model.
- #
- # Note: This is an experimental version of Gemma with external KV cache.
- # Please use with caution.
-
- import os
- from pathlib import Path
- from typing import Tuple
-
- from ai_edge_torch.generative.layers import builder
- import ai_edge_torch.generative.layers.attention_utils as attn_utils
- from ai_edge_torch.generative.layers.experimental import attention
- from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
- import ai_edge_torch.generative.layers.model_config as cfg
- import ai_edge_torch.generative.utilities.loader as loading_utils
- import numpy as np
- import torch
- from torch import nn
-
-
- TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-     ff_up_proj="model.layers.{}.mlp.up_proj",
-     ff_down_proj="model.layers.{}.mlp.down_proj",
-     ff_gate_proj="model.layers.{}.mlp.gate_proj",
-     attn_query_proj="model.layers.{}.self_attn.q_proj",
-     attn_key_proj="model.layers.{}.self_attn.k_proj",
-     attn_value_proj="model.layers.{}.self_attn.v_proj",
-     attn_output_proj="model.layers.{}.self_attn.o_proj",
-     pre_attn_norm="model.layers.{}.input_layernorm",
-     post_attn_norm="model.layers.{}.post_attention_layernorm",
-     embedding="model.embed_tokens",
-     final_norm="model.norm",
-     lm_head=None,
- )
-
-
- class Gemma(nn.Module):
-   """A Gemma model built from the Edge Generative API layers."""
-
-   def __init__(self, config: cfg.ModelConfig):
-     super().__init__()
-
-     self.config = config
-     # Construct model layers.
-     self.tok_embedding = nn.Embedding(
-         config.vocab_size, config.embedding_dim, padding_idx=0
-     )
-     self.lm_head = nn.Linear(
-         config.embedding_dim,
-         config.vocab_size,
-         bias=config.lm_head_use_bias,
-     )
-     # Gemma re-uses the embedding as the head projection layer.
-     self.lm_head.weight.data = self.tok_embedding.weight.data
-     self.transformer_blocks = nn.ModuleList(
-         attention.TransformerBlock(config) for _ in range(config.num_layers)
-     )
-     self.final_norm = builder.build_norm(
-         config.embedding_dim,
-         config.final_norm_config,
-     )
-     self.rope_cache = attn_utils.build_rope_cache(
-         size=config.kv_cache_max,
-         dim=int(
-             config.attn_config.rotary_percentage * config.attn_config.head_dim
-         ),
-         base=10_000,
-         condense_ratio=1,
-         dtype=torch.float32,
-         device=torch.device("cpu"),
-     )
-     self.mask_cache = attn_utils.build_causal_mask_cache(
-         size=config.kv_cache_max,
-         dtype=torch.float32,
-         device=torch.device("cpu"),
-     )
-     self.config = config
-
-   @torch.inference_mode
-   def forward(
-       self,
-       tokens: torch.Tensor,
-       input_pos: torch.Tensor,
-       kv_cache: kv_utils.EKVCache,
-   ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
-     _, seq_len = tokens.size()
-     assert self.config.max_seq_len >= seq_len, (
-         f"Cannot forward sequence of length {seq_len}, max seq length is only"
-         f" {self.config.max_seq_len}"
-     )
-
-     cos, sin = self.rope_cache
-     cos = cos.index_select(0, input_pos)
-     sin = sin.index_select(0, input_pos)
-     mask = self.mask_cache.index_select(2, input_pos)
-     mask = mask[:, :, :, : self.config.kv_cache_max]
-
-     # token embeddings of shape (b, t, n_embd)
-     x = self.tok_embedding(tokens)
-     x = x * (self.config.embedding_dim**0.5)
-
-     updated_kv_entires = []
-     for i, block in enumerate(self.transformer_blocks):
-       kv_entry = kv_cache.caches[i] if kv_cache else None
-       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-       if kv_entry:
-         updated_kv_entires.append(kv_entry)
-     updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
-
-     x = self.final_norm(x)
-     res = self.lm_head(x)  # (b, t, vocab_size)
-     return res, updated_kv_cache
-
-
- def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-   """Returns the model config for a Gemma 2B model.
-
-   Args:
-     kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-       is 1024.
-
-   Returns:
-     The model config for a Gemma 2B model.
-   """
-   attn_config = cfg.AttentionConfig(
-       num_heads=8,
-       head_dim=256,
-       num_query_groups=1,
-       rotary_percentage=1.0,
-   )
-   ff_config = cfg.FeedForwardConfig(
-       type=cfg.FeedForwardType.GATED,
-       activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-       intermediate_size=16384,
-   )
-   norm_config = cfg.NormalizationConfig(
-       type=cfg.NormalizationType.RMS_NORM,
-       epsilon=1e-6,
-       zero_centered=True,
-   )
-   config = cfg.ModelConfig(
-       vocab_size=256000,
-       num_layers=18,
-       max_seq_len=8192,
-       embedding_dim=2048,
-       kv_cache_max_len=kv_cache_max_len,
-       attn_config=attn_config,
-       ff_config=ff_config,
-       pre_attention_norm_config=norm_config,
-       post_attention_norm_config=norm_config,
-       final_norm_config=norm_config,
-       parallel_residual=False,
-       lm_head_use_bias=False,
-       enable_hlfb=True,
-   )
-   return config
-
-
- def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-   config = get_model_config_2b(kv_cache_max_len)
-   config.ff_config.intermediate_size = 128
-   config.vocab_size = 128
-   config.num_layers = 2
-   config.max_seq_len = 2 * kv_cache_max_len
-   return config
-
-
- def build_2b_model(
-     checkpoint_path: str, test_model: bool = False, **kwargs
- ) -> nn.Module:
-   """Instantiates the model instance and load checkpoint if provided."""
-   config = (
-       get_fake_model_config(**kwargs)
-       if test_model
-       else get_model_config_2b(**kwargs)
-   )
-   model = Gemma(config)
-   if checkpoint_path is not None:
-     loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-     # since embedding and lm-head use the same weight, we need to set strict
-     # to False.
-     loader.load(model, strict=False)
-   model.eval()
-   return model
-
-
- def define_and_run_2b(checkpoint_path: str, test_model: bool = False) -> None:
-   """Instantiates and runs a Gemma 2B model."""
-
-   kv_cache_max_len = 1024
-   model = build_2b_model(
-       checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-   )
-   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
-   tokens[0, :4] = idx
-   input_pos = torch.arange(0, kv_cache_max_len)
-   kv = kv_utils.EKVCache.from_model_config(model.config)
-   print("running an inference")
-   print(model.forward(tokens, input_pos, kv))
-
-
- if __name__ == "__main__":
-   input_checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
-   define_and_run_2b(input_checkpoint_path)
ai_edge_torch/generative/examples/experimental/phi/__init__.py (deleted)
@@ -1,14 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py (deleted)
@@ -1,14 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py (deleted)
@@ -1,87 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- #
- # Note: This is an experimental version of TinyLlama with external KV cache.
- # Please use with caution.
-
-
- import os
- from pathlib import Path
-
- import ai_edge_torch
- from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama
- from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
- from ai_edge_torch.generative.quantize import quant_recipes
- import torch
-
-
- def convert_tiny_llama_to_tflite(
-     checkpoint_path: str,
-     prefill_seq_len: int = 512,
-     kv_cache_max_len: int = 1024,
-     quantize: bool = True,
- ):
-   """An example for converting TinyLlama model to multi-signature tflite model.
-
-   Args:
-     checkpoint_path (str): The filepath to the model checkpoint, or directory
-       holding the checkpoint.
-     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-       Defaults to 512.
-     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-       including both prefill and decode. Defaults to 1024.
-     quantize (bool, optional): Whether the model should be quanized. Defaults
-       to True.
-   """
-   pytorch_model = tiny_llama.build_model(
-       checkpoint_path, kv_cache_max_len=kv_cache_max_len
-   )
-   # Tensors used to trace the model graph during conversion.
-   prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-   prefill_input_pos = torch.arange(0, prefill_seq_len)
-   decode_token = torch.tensor([[0]], dtype=torch.long)
-   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-   kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
-
-   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-   edge_model = (
-       ai_edge_torch.signature(
-           'prefill',
-           pytorch_model,
-           sample_kwargs={
-               'tokens': prefill_tokens,
-               'input_pos': prefill_input_pos,
-               'kv_cache': kv,
-           },
-       )
-       .signature(
-           'decode',
-           pytorch_model,
-           sample_kwargs={
-               'tokens': decode_token,
-               'input_pos': decode_input_pos,
-               'kv_cache': kv,
-           },
-       )
-       .convert(quant_config=quant_config)
-   )
-   edge_model.export(
-       f'/tmp/tiny_llama_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
-   )
-
-
- if __name__ == '__main__':
-   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
-   convert_tiny_llama_to_tflite(checkpoint_path)
ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py (deleted)
@@ -1,205 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- # Example of building a TinyLlama model from the Edge Generative API layers.
- #
- # Note: This is an experimental version of TinyLlama with external KV cache.
- # Please use with caution.
-
- import os
- from pathlib import Path
- from typing import Tuple
-
- from ai_edge_torch.generative.layers import builder
- import ai_edge_torch.generative.layers.attention_utils as attn_utils
- from ai_edge_torch.generative.layers.experimental import attention
- from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
- import ai_edge_torch.generative.layers.model_config as cfg
- import ai_edge_torch.generative.utilities.loader as loading_utils
- import numpy as np
- import torch
- from torch import nn
-
-
- TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-     ff_up_proj="model.layers.{}.mlp.up_proj",
-     ff_down_proj="model.layers.{}.mlp.down_proj",
-     ff_gate_proj="model.layers.{}.mlp.gate_proj",
-     attn_query_proj="model.layers.{}.self_attn.q_proj",
-     attn_key_proj="model.layers.{}.self_attn.k_proj",
-     attn_value_proj="model.layers.{}.self_attn.v_proj",
-     attn_output_proj="model.layers.{}.self_attn.o_proj",
-     pre_attn_norm="model.layers.{}.input_layernorm",
-     post_attn_norm="model.layers.{}.post_attention_layernorm",
-     embedding="model.embed_tokens",
-     final_norm="model.norm",
-     lm_head="lm_head",
- )
-
-
- class TinyLLamma(nn.Module):
-   """A TinyLlama model built from the Edge Generative API layers."""
-
-   def __init__(self, config: cfg.ModelConfig):
-     super().__init__()
-
-     self.config = config
-     # Construct model layers.
-     self.lm_head = nn.Linear(
-         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
-     )
-     self.tok_embedding = nn.Embedding(
-         config.vocab_size, config.embedding_dim, padding_idx=0
-     )
-     self.transformer_blocks = nn.ModuleList(
-         attention.TransformerBlock(config) for _ in range(config.num_layers)
-     )
-     self.final_norm = builder.build_norm(
-         config.embedding_dim,
-         config.final_norm_config,
-     )
-     self.rope_cache = attn_utils.build_rope_cache(
-         size=config.kv_cache_max,
-         dim=int(
-             config.attn_config.rotary_percentage * config.attn_config.head_dim
-         ),
-         base=10_000,
-         condense_ratio=1,
-         dtype=torch.float32,
-         device=torch.device("cpu"),
-     )
-     self.mask_cache = attn_utils.build_causal_mask_cache(
-         size=config.kv_cache_max,
-         dtype=torch.float32,
-         device=torch.device("cpu"),
-     )
-     self.config = config
-
-   @torch.inference_mode
-   def forward(
-       self,
-       tokens: torch.Tensor,
-       input_pos: torch.Tensor,
-       kv_cache: kv_utils.EKVCache,
-   ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
-     _, seq_len = tokens.size()
-     assert self.config.max_seq_len >= seq_len, (
-         f"Cannot forward sequence of length {seq_len}, max seq length is only"
-         f" {self.config.max_seq_len}"
-     )
-
-     cos, sin = self.rope_cache
-     cos = cos.index_select(0, input_pos)
-     sin = sin.index_select(0, input_pos)
-     mask = self.mask_cache.index_select(2, input_pos)
-     mask = mask[:, :, :, : self.config.kv_cache_max]
-
-     # token embeddings of shape (b, t, n_embd)
-     x = self.tok_embedding(tokens)
-
-     updated_kv_entires = []
-     for i, block in enumerate(self.transformer_blocks):
-       kv_entry = kv_cache.caches[i] if kv_cache else None
-       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-       if kv_entry:
-         updated_kv_entires.append(kv_entry)
-     updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
-
-     x = self.final_norm(x)
-     res = self.lm_head(x)  # (b, t, vocab_size)
-     return res, updated_kv_cache
-
-
- def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-   """Returns the model config for a TinyLlama model.
-
-   Args:
-     kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-       is 1024.
-
-   Returns:
-     The model config for a TinyLlama model.
-   """
-   attn_config = cfg.AttentionConfig(
-       num_heads=32,
-       head_dim=64,
-       num_query_groups=4,
-       rotary_percentage=1.0,
-   )
-   ff_config = cfg.FeedForwardConfig(
-       type=cfg.FeedForwardType.GATED,
-       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
-       intermediate_size=5632,
-   )
-   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
-   config = cfg.ModelConfig(
-       vocab_size=32000,
-       num_layers=22,
-       max_seq_len=2048,
-       embedding_dim=2048,
-       kv_cache_max_len=kv_cache_max_len,
-       attn_config=attn_config,
-       ff_config=ff_config,
-       pre_attention_norm_config=norm_config,
-       post_attention_norm_config=norm_config,
-       final_norm_config=norm_config,
-       enable_hlfb=True,
-   )
-   return config
-
-
- def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-   config = get_model_config(**kwargs)
-   config.vocab_size = 128
-   config.num_layers = 2
-   config.ff_config.intermediate_size = 256
-   return config
-
-
- def build_model(
-     checkpoint_path: str, test_model: bool = False, **kwargs
- ) -> nn.Module:
-   """Instantiates the model instance and load checkpoint if provided."""
-   config = (
-       get_fake_model_config(**kwargs)
-       if test_model
-       else get_model_config(**kwargs)
-   )
-   model = TinyLLamma(config)
-   if checkpoint_path is not None:
-     loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-     loader.load(model)
-   model.eval()
-   return model
-
-
- def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
-   """Instantiates and runs a TinyLlama model."""
-
-   kv_cache_max_len = 1024
-   model = build_model(
-       checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-   )
-   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
-   tokens[0, :4] = idx
-   input_pos = torch.arange(0, kv_cache_max_len)
-   kv = kv_utils.EKVCache.from_model_config(model.config)
-   print("running an inference")
-   print(model.forward(tokens, input_pos, kv))
-
-
- if __name__ == "__main__":
-   input_checkpoint_path = os.path.join(Path.home(), "Downloads/tiny_llama")
-   define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/phi2/__init__.py (deleted)
@@ -1,14 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
ai_edge_torch/generative/examples/phi2/convert_to_tflite.py (deleted)
@@ -1,67 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- import os
- from pathlib import Path
-
- import ai_edge_torch
- from ai_edge_torch.generative.examples.phi2 import phi2
- from ai_edge_torch.generative.quantize import quant_recipes
- import torch
-
-
- def convert_phi2_to_tflite(
-     checkpoint_path: str,
-     prefill_seq_len: int = 512,
-     kv_cache_max_len: int = 1024,
-     quantize: bool = True,
- ):
-   """Converts a Phi-2 model to multi-signature tflite model.
-
-   Args:
-     checkpoint_path (str): The filepath to the model checkpoint, or directory
-       holding the checkpoint.
-     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-       Defaults to 512.
-     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-       including both prefill and decode. Defaults to 1024.
-     quantize (bool, optional): Whether the model should be quanized. Defaults
-       to True.
-   """
-   pytorch_model = phi2.build_model(
-       checkpoint_path, kv_cache_max_len=kv_cache_max_len
-   )
-   # Tensors used to trace the model graph during conversion.
-   prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-   prefill_input_pos = torch.arange(0, prefill_seq_len)
-   decode_token = torch.tensor([[0]], dtype=torch.long)
-   decode_input_pos = torch.tensor([0], dtype=torch.int64)
-
-   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-   edge_model = (
-       ai_edge_torch.signature(
-           'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
-       )
-       .signature('decode', pytorch_model, (decode_token, decode_input_pos))
-       .convert(quant_config=quant_config)
-   )
-   edge_model.export(
-       f'/tmp/phi2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
-   )
-
-
- if __name__ == '__main__':
-   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
-   convert_phi2_to_tflite(checkpoint_path)