ai-edge-torch-nightly 0.2.0.dev20240719__py3-none-any.whl → 0.2.0.dev20240720__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (21)
  1. ai_edge_torch/generative/examples/experimental/__init__.py +14 -0
  2. ai_edge_torch/generative/examples/experimental/gemma/__init__.py +14 -0
  3. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +87 -0
  4. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +195 -0
  5. ai_edge_torch/generative/examples/experimental/phi/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +84 -0
  7. ai_edge_torch/generative/examples/experimental/phi/phi2.py +184 -0
  8. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +14 -0
  9. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +89 -0
  10. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +185 -0
  11. ai_edge_torch/generative/examples/gemma/gemma.py +6 -2
  12. ai_edge_torch/generative/examples/phi2/phi2.py +5 -2
  13. ai_edge_torch/generative/examples/t5/t5.py +5 -2
  14. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +42 -27
  15. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +6 -2
  16. ai_edge_torch/generative/test/test_experimental_ekv.py +122 -0
  17. {ai_edge_torch_nightly-0.2.0.dev20240719.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/METADATA +1 -1
  18. {ai_edge_torch_nightly-0.2.0.dev20240719.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/RECORD +21 -10
  19. {ai_edge_torch_nightly-0.2.0.dev20240719.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/LICENSE +0 -0
  20. {ai_edge_torch_nightly-0.2.0.dev20240719.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/WHEEL +0 -0
  21. {ai_edge_torch_nightly-0.2.0.dev20240719.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/top_level.txt +0 -0
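
The new experimental examples (items 1-10 above) mirror the existing Gemma, Phi-2, and TinyLlama examples, but thread the KV cache through an external EKVCache object so that the exported TFLite model exposes separate 'prefill' and 'decode' signatures. A minimal usage sketch of the new Gemma converter follows; the checkpoint path is illustrative (the example scripts default to directories under ~/Downloads/llm_data/), and the rest follows the function signature added in this release.

    from ai_edge_torch.generative.examples.experimental.gemma import convert_to_tflite

    # Builds the experimental Gemma 2B model, traces 'prefill' and 'decode'
    # signatures against an external KV cache, and writes
    # /tmp/gemma_seq512_ekv1024.tflite.
    convert_to_tflite.convert_gemma_to_tflite(
        checkpoint_path='/path/to/gemma-2b',  # illustrative checkpoint location
        prefill_seq_len=512,
        kv_cache_max_len=1024,
        quantize=True,
    )
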
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,87 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Note: This is an experimental version of Gemma with external KV cache.
17
+ # Please use with caution.
18
+
19
+
20
+ import os
21
+ from pathlib import Path
22
+
23
+ import torch
24
+
25
+ import ai_edge_torch
26
+ from ai_edge_torch.generative.examples.experimental.gemma import gemma
27
+ from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
28
+ from ai_edge_torch.generative.quantize import quant_recipes
29
+
30
+
31
+ def convert_gemma_to_tflite(
32
+ checkpoint_path: str,
33
+ prefill_seq_len: int = 512,
34
+ kv_cache_max_len: int = 1024,
35
+ quantize: bool = True,
36
+ ):
37
+ """An example method for converting a Gemma 2B model to multi-signature
38
+ tflite model.
39
+
40
+ Args:
41
+ checkpoint_path (str): The filepath to the model checkpoint, or directory
42
+ holding the checkpoint.
43
+ prefill_seq_len (int, optional): The maximum size of prefill input tensor.
44
+ Defaults to 512.
45
+ kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
46
+ including both prefill and decode. Defaults to 1024.
47
+ quantize (bool, optional): Whether the model should be quantized.
48
+ Defaults to True.
49
+ """
50
+ pytorch_model = gemma.build_2b_model(
51
+ checkpoint_path, kv_cache_max_len=kv_cache_max_len
52
+ )
53
+ # Tensors used to trace the model graph during conversion.
54
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
55
+ prefill_input_pos = torch.arange(0, prefill_seq_len)
56
+ decode_token = torch.tensor([[0]], dtype=torch.long)
57
+ decode_input_pos = torch.tensor([0], dtype=torch.int64)
58
+ kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
59
+
60
+ quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
61
+ edge_model = (
62
+ ai_edge_torch.signature(
63
+ 'prefill',
64
+ pytorch_model,
65
+ sample_kwargs={
66
+ 'tokens': prefill_tokens,
67
+ 'input_pos': prefill_input_pos,
68
+ 'kv_cache': kv,
69
+ },
70
+ )
71
+ .signature(
72
+ 'decode',
73
+ pytorch_model,
74
+ sample_kwargs={
75
+ 'tokens': decode_token,
76
+ 'input_pos': decode_input_pos,
77
+ 'kv_cache': kv,
78
+ },
79
+ )
80
+ .convert(quant_config=quant_config)
81
+ )
82
+ edge_model.export(f'/tmp/gemma_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite')
83
+
84
+
85
+ if __name__ == '__main__':
86
+ checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
87
+ convert_gemma_to_tflite(checkpoint_path)
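
Once exported, the multi-signature model can be inspected with the standard TensorFlow Lite Python interpreter. The sketch below is only a loading-and-inspection example: it assumes TensorFlow is installed and that the file was produced by the converter above. The names of the flattened KV-cache inputs are not fixed by this script (the new unit test further down suggests names like kv_k_0 and kv_v_0 for a one-layer model), so check get_input_details() before invoking the signature runners.

    import tensorflow as tf

    # Load the model exported by convert_gemma_to_tflite() above.
    interpreter = tf.lite.Interpreter(
        model_path='/tmp/gemma_seq512_ekv1024.tflite'
    )
    # Expect two entries: 'prefill' and 'decode'.
    print(interpreter.get_signature_list())

    # Each signature gets its own runner; its input details list the actual
    # tensor names (likely 'tokens', 'input_pos', plus flattened KV-cache tensors).
    prefill = interpreter.get_signature_runner('prefill')
    print(prefill.get_input_details())
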
@@ -0,0 +1,195 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Example of building a Gemma model.
16
+ #
17
+ # Note: This is an experimental version of Gemma with external KV cache.
18
+ # Please use with caution.
19
+
20
+ import os
21
+ from pathlib import Path
22
+ from typing import Tuple
23
+
24
+ import numpy as np
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
29
+ import ai_edge_torch.generative.layers.builder as builder
30
+ from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
31
+ from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
32
+ import ai_edge_torch.generative.layers.model_config as cfg
33
+ import ai_edge_torch.generative.utilities.loader as loading_utils
34
+
35
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
36
+ ff_up_proj="model.layers.{}.mlp.up_proj",
37
+ ff_down_proj="model.layers.{}.mlp.down_proj",
38
+ ff_gate_proj="model.layers.{}.mlp.gate_proj",
39
+ attn_query_proj="model.layers.{}.self_attn.q_proj",
40
+ attn_key_proj="model.layers.{}.self_attn.k_proj",
41
+ attn_value_proj="model.layers.{}.self_attn.v_proj",
42
+ attn_output_proj="model.layers.{}.self_attn.o_proj",
43
+ pre_attn_norm="model.layers.{}.input_layernorm",
44
+ pre_ff_norm="model.layers.{}.post_attention_layernorm",
45
+ embedding="model.embed_tokens",
46
+ final_norm="model.norm",
47
+ lm_head=None,
48
+ )
49
+
50
+
51
+ class Gemma(nn.Module):
52
+
53
+ def __init__(self, config: cfg.ModelConfig):
54
+ super().__init__()
55
+
56
+ self.config = config
57
+ # Construct model layers.
58
+ self.tok_embedding = nn.Embedding(
59
+ config.vocab_size, config.embedding_dim, padding_idx=0
60
+ )
61
+ self.lm_head = nn.Linear(
62
+ config.embedding_dim,
63
+ config.vocab_size,
64
+ bias=config.lm_head_use_bias,
65
+ )
66
+ # Gemma re-uses the embedding as the head projection layer.
67
+ self.lm_head.weight.data = self.tok_embedding.weight.data
68
+ self.transformer_blocks = nn.ModuleList(
69
+ TransformerBlock(config) for _ in range(config.num_layers)
70
+ )
71
+ self.final_norm = builder.build_norm(
72
+ config.embedding_dim,
73
+ config.final_norm_config,
74
+ )
75
+ self.rope_cache = attn_utils.build_rope_cache(
76
+ size=config.kv_cache_max,
77
+ dim=int(config.attn_config.rotary_percentage * config.head_dim),
78
+ base=10_000,
79
+ condense_ratio=1,
80
+ dtype=torch.float32,
81
+ device=torch.device("cpu"),
82
+ )
83
+ self.mask_cache = attn_utils.build_causal_mask_cache(
84
+ size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
85
+ )
86
+ self.config = config
87
+
88
+ @torch.inference_mode
89
+ def forward(
90
+ self,
91
+ tokens: torch.Tensor,
92
+ input_pos: torch.Tensor,
93
+ kv_cache: kv_utils.EKVCache,
94
+ ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
95
+ B, T = tokens.size()
96
+ assert (
97
+ self.config.max_seq_len >= T
98
+ ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
99
+
100
+ cos, sin = self.rope_cache
101
+ cos = cos.index_select(0, input_pos)
102
+ sin = sin.index_select(0, input_pos)
103
+ mask = self.mask_cache.index_select(2, input_pos)
104
+ mask = mask[:, :, :, : self.config.kv_cache_max]
105
+
106
+ # token embeddings of shape (b, t, n_embd)
107
+ x = self.tok_embedding(tokens)
108
+ x = x * (self.config.embedding_dim**0.5)
109
+
110
+ updated_kv_entries = []
111
+ for i, block in enumerate(self.transformer_blocks):
112
+ kv_entry = kv_cache.caches[i] if kv_cache else None
113
+ x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
114
+ if kv_entry:
115
+ updated_kv_entries.append(kv_entry)
116
+ updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entries))
117
+
118
+ x = self.final_norm(x)
119
+ res = self.lm_head(x) # (b, t, vocab_size)
120
+ return res, updated_kv_cache
121
+
122
+
123
+ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
124
+ attn_config = cfg.AttentionConfig(
125
+ num_heads=8,
126
+ num_query_groups=1,
127
+ rotary_percentage=1.0,
128
+ )
129
+ ff_config = cfg.FeedForwardConfig(
130
+ type=cfg.FeedForwardType.GATED,
131
+ activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
132
+ intermediate_size=16384,
133
+ )
134
+ norm_config = cfg.NormalizationConfig(
135
+ type=cfg.NormalizationType.RMS_NORM,
136
+ epsilon=1e-6,
137
+ zero_centered=True,
138
+ )
139
+ config = cfg.ModelConfig(
140
+ vocab_size=256000,
141
+ num_layers=18,
142
+ max_seq_len=8192,
143
+ embedding_dim=2048,
144
+ kv_cache_max_len=kv_cache_max_len,
145
+ attn_config=attn_config,
146
+ ff_config=ff_config,
147
+ pre_attention_norm_config=norm_config,
148
+ pre_ff_norm_config=norm_config,
149
+ final_norm_config=norm_config,
150
+ parallel_residual=False,
151
+ lm_head_use_bias=False,
152
+ enable_hlfb=True,
153
+ )
154
+ return config
155
+
156
+
157
+ def get_fake_model_config_2b_for_test(**kwargs) -> cfg.ModelConfig:
158
+ config = get_model_config_2b(**kwargs)
159
+ config.num_layers = 2
160
+ return config
161
+
162
+
163
+ def build_2b_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
164
+ config = (
165
+ get_fake_model_config_2b_for_test(**kwargs)
166
+ if test_model
167
+ else get_model_config_2b(**kwargs)
168
+ )
169
+ model = Gemma(config)
170
+ if checkpoint_path is not None:
171
+ loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
172
+ # since embedding and lm-head use the same weight, we need to set strict
173
+ # to False.
174
+ loader.load(model, strict=False)
175
+ model.eval()
176
+ return model
177
+
178
+
179
+ def define_and_run_2b(checkpoint_path, test_model=False) -> None:
180
+ kv_cache_max_len = 1024
181
+ model = build_2b_model(
182
+ checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
183
+ )
184
+ idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
185
+ tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
186
+ tokens[0, :4] = idx
187
+ input_pos = torch.arange(0, kv_cache_max_len)
188
+ kv = kv_utils.EKVCache.from_model_config(model.config)
189
+ print("running an inference")
190
+ print(model.forward(tokens, input_pos, kv))
191
+
192
+
193
+ if __name__ == "__main__":
194
+ checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
195
+ define_and_run_2b(checkpoint_path)
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,84 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Note: This is an experimental version of phi2 with external KV cache.
17
+ # Please use with caution.
18
+
19
+ import os
20
+ from pathlib import Path
21
+
22
+ import torch
23
+
24
+ import ai_edge_torch
25
+ from ai_edge_torch.generative.examples.experimental.phi import phi2
26
+ from ai_edge_torch.generative.layers.experimental import ekv_cache
27
+ from ai_edge_torch.generative.quantize import quant_recipes
28
+
29
+
30
+ def convert_phi2_to_tflite(
31
+ checkpoint_path: str,
32
+ prefill_seq_len: int = 512,
33
+ kv_cache_max_len: int = 1024,
34
+ quantize: bool = True,
35
+ ):
36
+ """An example method for converting a Phi-2 model to multi-signature
37
+ tflite model.
38
+
39
+ Args:
40
+ checkpoint_path (str): The filepath to the model checkpoint, or
41
+ directory holding the checkpoint.
42
+ prefill_seq_len (int, optional): The maximum size of prefill input tensor.
43
+ Defaults to 512.
44
+ kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
45
+ including both prefill and decode. Defaults to 1024.
46
+ quantize (bool, optional): Whether the model should be quantized.
47
+ Defaults to True.
48
+ """
49
+ pytorch_model = phi2.build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
50
+ # Tensors used to trace the model graph during conversion.
51
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
52
+ prefill_input_pos = torch.arange(0, prefill_seq_len)
53
+ decode_token = torch.tensor([[0]], dtype=torch.long)
54
+ decode_input_pos = torch.tensor([0], dtype=torch.int64)
55
+ kv = ekv_cache.EKVCache.from_model_config(pytorch_model.config)
56
+
57
+ quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
58
+ edge_model = (
59
+ ai_edge_torch.signature(
60
+ 'prefill',
61
+ pytorch_model,
62
+ sample_kwargs={
63
+ 'tokens': prefill_tokens,
64
+ 'input_pos': prefill_input_pos,
65
+ 'kv_cache': kv,
66
+ },
67
+ )
68
+ .signature(
69
+ 'decode',
70
+ pytorch_model,
71
+ sample_kwargs={
72
+ 'tokens': decode_token,
73
+ 'input_pos': decode_input_pos,
74
+ 'kv_cache': kv,
75
+ },
76
+ )
77
+ .convert(quant_config=quant_config)
78
+ )
79
+ edge_model.export(f'/tmp/phi2_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite')
80
+
81
+
82
+ if __name__ == '__main__':
83
+ checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/phi2')
84
+ convert_phi2_to_tflite(checkpoint_path)
@@ -0,0 +1,184 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Example of building phi-2 model from the Edge Generative API layers.
16
+ #
17
+ # Note: This is an experimental version of phi2 with external KV cache.
18
+ # Please use with caution.
19
+
20
+
21
+ import os
22
+ from pathlib import Path
23
+ from typing import Tuple
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+
29
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
30
+ import ai_edge_torch.generative.layers.builder as builder
31
+ from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
32
+ from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
33
+ import ai_edge_torch.generative.layers.model_config as cfg
34
+ import ai_edge_torch.generative.utilities.loader as loading_utils
35
+
36
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
37
+ ff_up_proj="model.layers.{}.mlp.fc1",
38
+ ff_down_proj="model.layers.{}.mlp.fc2",
39
+ attn_query_proj="model.layers.{}.self_attn.q_proj",
40
+ attn_key_proj="model.layers.{}.self_attn.k_proj",
41
+ attn_value_proj="model.layers.{}.self_attn.v_proj",
42
+ attn_output_proj="model.layers.{}.self_attn.dense",
43
+ pre_attn_norm="model.layers.{}.input_layernorm",
44
+ embedding="model.embed_tokens",
45
+ final_norm="model.final_layernorm",
46
+ lm_head="lm_head",
47
+ )
48
+
49
+
50
+ class Phi2(nn.Module):
51
+
52
+ def __init__(self, config: cfg.ModelConfig):
53
+ super().__init__()
54
+
55
+ self.config = config
56
+ # Construct model layers.
57
+ self.lm_head = nn.Linear(
58
+ config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
59
+ )
60
+ self.tok_embedding = nn.Embedding(
61
+ config.vocab_size, config.embedding_dim, padding_idx=0
62
+ )
63
+ self.transformer_blocks = nn.ModuleList(
64
+ TransformerBlock(config) for _ in range(config.num_layers)
65
+ )
66
+ self.final_norm = builder.build_norm(
67
+ config.embedding_dim,
68
+ config.final_norm_config,
69
+ )
70
+ self.rope_cache = attn_utils.build_rope_cache(
71
+ size=config.kv_cache_max,
72
+ dim=int(config.attn_config.rotary_percentage * config.head_dim),
73
+ base=10_000,
74
+ condense_ratio=1,
75
+ dtype=torch.float32,
76
+ device=torch.device("cpu"),
77
+ )
78
+ self.mask_cache = attn_utils.build_causal_mask_cache(
79
+ size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
80
+ )
81
+ self.config = config
82
+
83
+ @torch.inference_mode
84
+ def forward(
85
+ self,
86
+ tokens: torch.Tensor,
87
+ input_pos: torch.Tensor,
88
+ kv_cache: kv_utils.EKVCache,
89
+ ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
90
+ B, T = tokens.size()
91
+ assert (
92
+ self.config.max_seq_len >= T
93
+ ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
94
+
95
+ cos, sin = self.rope_cache
96
+ cos = cos.index_select(0, input_pos)
97
+ sin = sin.index_select(0, input_pos)
98
+ mask = self.mask_cache.index_select(2, input_pos)
99
+ mask = mask[:, :, :, : self.config.kv_cache_max]
100
+
101
+ x = self.tok_embedding(tokens)
102
+
103
+ updated_kv_entries = []
104
+ for i, block in enumerate(self.transformer_blocks):
105
+ kv_entry = kv_cache.caches[i] if kv_cache else None
106
+ x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
107
+ if kv_entry:
108
+ updated_kv_entries.append(kv_entry)
109
+ updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entries))
110
+
111
+ x = self.final_norm(x)
112
+ res = self.lm_head(x) # (b, t, vocab_size)
113
+ return res, updated_kv_cache
114
+
115
+
116
+ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
117
+ attn_config = cfg.AttentionConfig(
118
+ num_heads=32,
119
+ num_query_groups=32,
120
+ rotary_percentage=0.4,
121
+ qkv_use_bias=True,
122
+ output_proj_use_bias=True,
123
+ )
124
+ ff_config = cfg.FeedForwardConfig(
125
+ type=cfg.FeedForwardType.SEQUENTIAL,
126
+ activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
127
+ intermediate_size=10240,
128
+ use_bias=True,
129
+ )
130
+ norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
131
+ config = cfg.ModelConfig(
132
+ vocab_size=51200,
133
+ num_layers=32,
134
+ max_seq_len=2048,
135
+ kv_cache_max_len=kv_cache_max_len,
136
+ embedding_dim=2560,
137
+ attn_config=attn_config,
138
+ ff_config=ff_config,
139
+ pre_attention_norm_config=norm_config,
140
+ final_norm_config=norm_config,
141
+ parallel_residual=True,
142
+ lm_head_use_bias=True,
143
+ enable_hlfb=True,
144
+ )
145
+ return config
146
+
147
+
148
+ def get_fake_model_config_for_test(**kwargs) -> cfg.ModelConfig:
149
+ config = get_model_config(**kwargs)
150
+ config.num_layers = 2
151
+ return config
152
+
153
+
154
+ def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
155
+ config = (
156
+ get_fake_model_config_for_test(**kwargs)
157
+ if test_model
158
+ else get_model_config(**kwargs)
159
+ )
160
+ model = Phi2(config)
161
+ if checkpoint_path is not None:
162
+ loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
163
+ loader.load(model)
164
+ model.eval()
165
+ return model
166
+
167
+
168
+ def define_and_run(checkpoint_path, test_model=False) -> None:
169
+ kv_cache_max_len = 1024
170
+ model = build_model(
171
+ checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
172
+ )
173
+ idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
174
+ tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
175
+ tokens[0, :4] = idx
176
+ input_pos = torch.arange(0, kv_cache_max_len)
177
+ kv = kv_utils.EKVCache.from_model_config(model.config)
178
+ print("running an inference")
179
+ print(model.forward(tokens, input_pos, kv))
180
+
181
+
182
+ if __name__ == "__main__":
183
+ checkpoint_path = os.path.join(Path.home(), "Downloads/phi2")
184
+ define_and_run(checkpoint_path)
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,89 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Note: This is an experimental version of TinyLlama with external KV cache.
17
+ # Please use with caution.
18
+
19
+
20
+ import os
21
+ from pathlib import Path
22
+
23
+ import torch
24
+
25
+ import ai_edge_torch
26
+ from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama # NOQA
27
+ from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
28
+ from ai_edge_torch.generative.quantize import quant_recipes
29
+
30
+
31
+ def convert_tiny_llama_to_tflite(
32
+ checkpoint_path: str,
33
+ prefill_seq_len: int = 512,
34
+ kv_cache_max_len: int = 1024,
35
+ quantize: bool = True,
36
+ ):
37
+ """An example method for converting TinyLlama model to multi-signature
38
+ tflite model.
39
+
40
+ Args:
41
+ checkpoint_path (str): The filepath to the model checkpoint, or directory
42
+ holding the checkpoint.
43
+ prefill_seq_len (int, optional): The maximum size of prefill input tensor.
44
+ Defaults to 512.
45
+ kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
46
+ including both prefill and decode. Defaults to 1024.
47
+ quantize (bool, optional): Whether the model should be quantized.
48
+ Defaults to True.
49
+ """
50
+ pytorch_model = tiny_llama.build_model(
51
+ checkpoint_path, kv_cache_max_len=kv_cache_max_len
52
+ )
53
+ # Tensors used to trace the model graph during conversion.
54
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
55
+ prefill_input_pos = torch.arange(0, prefill_seq_len)
56
+ decode_token = torch.tensor([[0]], dtype=torch.long)
57
+ decode_input_pos = torch.tensor([0], dtype=torch.int64)
58
+ kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
59
+
60
+ quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
61
+ edge_model = (
62
+ ai_edge_torch.signature(
63
+ 'prefill',
64
+ pytorch_model,
65
+ sample_kwargs={
66
+ 'tokens': prefill_tokens,
67
+ 'input_pos': prefill_input_pos,
68
+ 'kv_cache': kv,
69
+ },
70
+ )
71
+ .signature(
72
+ 'decode',
73
+ pytorch_model,
74
+ sample_kwargs={
75
+ 'tokens': decode_token,
76
+ 'input_pos': decode_input_pos,
77
+ 'kv_cache': kv,
78
+ },
79
+ )
80
+ .convert(quant_config=quant_config)
81
+ )
82
+ edge_model.export(
83
+ f'/tmp/tiny_llama_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
84
+ )
85
+
86
+
87
+ if __name__ == '__main__':
88
+ checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
89
+ convert_tiny_llama_to_tflite(checkpoint_path)
@@ -0,0 +1,185 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Example of building a TinyLlama model from the Edge Generative API layers.
16
+ #
17
+ # Note: This is an experimental version of TinyLlama with external KV cache.
18
+ # Please use with caution.
19
+
20
+
21
+ import os
22
+ from pathlib import Path
23
+ from typing import Tuple
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+
29
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
30
+ import ai_edge_torch.generative.layers.builder as builder
31
+ from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
32
+ from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
33
+ import ai_edge_torch.generative.layers.model_config as cfg
34
+ import ai_edge_torch.generative.utilities.loader as loading_utils
35
+
36
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
37
+ ff_up_proj="model.layers.{}.mlp.up_proj",
38
+ ff_down_proj="model.layers.{}.mlp.down_proj",
39
+ ff_gate_proj="model.layers.{}.mlp.gate_proj",
40
+ attn_query_proj="model.layers.{}.self_attn.q_proj",
41
+ attn_key_proj="model.layers.{}.self_attn.k_proj",
42
+ attn_value_proj="model.layers.{}.self_attn.v_proj",
43
+ attn_output_proj="model.layers.{}.self_attn.o_proj",
44
+ pre_attn_norm="model.layers.{}.input_layernorm",
45
+ pre_ff_norm="model.layers.{}.post_attention_layernorm",
46
+ embedding="model.embed_tokens",
47
+ final_norm="model.norm",
48
+ lm_head="lm_head",
49
+ )
50
+
51
+
52
+ class TinyLLamma(nn.Module):
53
+
54
+ def __init__(self, config: cfg.ModelConfig):
55
+ super().__init__()
56
+
57
+ self.config = config
58
+ # Construct model layers.
59
+ self.lm_head = nn.Linear(
60
+ config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
61
+ )
62
+ self.tok_embedding = nn.Embedding(
63
+ config.vocab_size, config.embedding_dim, padding_idx=0
64
+ )
65
+ self.transformer_blocks = nn.ModuleList(
66
+ TransformerBlock(config) for _ in range(config.num_layers)
67
+ )
68
+ self.final_norm = builder.build_norm(
69
+ config.embedding_dim,
70
+ config.final_norm_config,
71
+ )
72
+ self.rope_cache = attn_utils.build_rope_cache(
73
+ size=config.kv_cache_max,
74
+ dim=int(config.attn_config.rotary_percentage * config.head_dim),
75
+ base=10_000,
76
+ condense_ratio=1,
77
+ dtype=torch.float32,
78
+ device=torch.device("cpu"),
79
+ )
80
+ self.mask_cache = attn_utils.build_causal_mask_cache(
81
+ size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
82
+ )
83
+ self.config = config
84
+
85
+ @torch.inference_mode
86
+ def forward(
87
+ self,
88
+ tokens: torch.Tensor,
89
+ input_pos: torch.Tensor,
90
+ kv_cache: kv_utils.EKVCache,
91
+ ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
92
+ B, T = tokens.size()
93
+ assert (
94
+ self.config.max_seq_len >= T
95
+ ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
96
+
97
+ cos, sin = self.rope_cache
98
+ cos = cos.index_select(0, input_pos)
99
+ sin = sin.index_select(0, input_pos)
100
+ mask = self.mask_cache.index_select(2, input_pos)
101
+ mask = mask[:, :, :, : self.config.kv_cache_max]
102
+
103
+ # token embeddings of shape (b, t, n_embd)
104
+ x = self.tok_embedding(tokens)
105
+
106
+ updated_kv_entries = []
107
+ for i, block in enumerate(self.transformer_blocks):
108
+ kv_entry = kv_cache.caches[i] if kv_cache else None
109
+ x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
110
+ if kv_entry:
111
+ updated_kv_entries.append(kv_entry)
112
+ updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entries))
113
+
114
+ x = self.final_norm(x)
115
+ res = self.lm_head(x) # (b, t, vocab_size)
116
+ return res, updated_kv_cache
117
+
118
+
119
+ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
120
+ attn_config = cfg.AttentionConfig(
121
+ num_heads=32,
122
+ num_query_groups=4,
123
+ rotary_percentage=1.0,
124
+ )
125
+ ff_config = cfg.FeedForwardConfig(
126
+ type=cfg.FeedForwardType.GATED,
127
+ activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
128
+ intermediate_size=5632,
129
+ )
130
+ norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
131
+ config = cfg.ModelConfig(
132
+ vocab_size=32000,
133
+ num_layers=22,
134
+ max_seq_len=2048,
135
+ embedding_dim=2048,
136
+ kv_cache_max_len=kv_cache_max_len,
137
+ attn_config=attn_config,
138
+ ff_config=ff_config,
139
+ pre_attention_norm_config=norm_config,
140
+ pre_ff_norm_config=norm_config,
141
+ final_norm_config=norm_config,
142
+ enable_hlfb=True,
143
+ )
144
+ return config
145
+
146
+
147
+ def get_fake_model_config_for_test(**kwargs) -> cfg.ModelConfig:
148
+ config = get_model_config(**kwargs)
149
+ config.vocab_size = 128
150
+ config.num_layers = 2
151
+ config.ff_config.intermediate_size = 256
152
+ return config
153
+
154
+
155
+ def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
156
+ config = (
157
+ get_fake_model_config_for_test(**kwargs)
158
+ if test_model
159
+ else get_model_config(**kwargs)
160
+ )
161
+ model = TinyLLamma(config)
162
+ if checkpoint_path is not None:
163
+ loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
164
+ loader.load(model)
165
+ model.eval()
166
+ return model
167
+
168
+
169
+ def define_and_run(checkpoint_path, test_model=False) -> None:
170
+ kv_cache_max_len = 1024
171
+ model = build_model(
172
+ checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
173
+ )
174
+ idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
175
+ tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
176
+ tokens[0, :4] = idx
177
+ input_pos = torch.arange(0, kv_cache_max_len)
178
+ kv = kv_utils.EKVCache.from_model_config(model.config)
179
+ print("running an inference")
180
+ print(model.forward(tokens, input_pos, kv))
181
+
182
+
183
+ if __name__ == "__main__":
184
+ checkpoint_path = os.path.join(Path.home(), "Downloads/tiny_llama")
185
+ define_and_run(checkpoint_path)
@@ -159,6 +159,9 @@ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
159
159
 
160
160
 
161
161
  def define_and_run_2b() -> None:
162
+ current_dir = Path(__file__).parent.resolve()
163
+ gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
164
+
162
165
  kv_cache_max_len = 1024
163
166
  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
164
167
  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
@@ -166,8 +169,9 @@ def define_and_run_2b() -> None:
166
169
  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
167
170
  tokens[0, :4] = idx
168
171
  input_pos = torch.arange(0, kv_cache_max_len)
169
- print("running an inference")
170
- print(model.forward(tokens, input_pos))
172
+ lm_logits = model.forward(tokens, input_pos)
173
+ print("comparing with goldens..")
174
+ assert torch.allclose(gemma_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05)
171
175
 
172
176
 
173
177
  if __name__ == "__main__":
@@ -149,6 +149,8 @@ def build_model(checkpoint_path, **kwargs) -> nn.Module:
149
149
 
150
150
 
151
151
  def define_and_run() -> None:
152
+ current_dir = Path(__file__).parent.resolve()
153
+ phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
152
154
  kv_cache_max_len = 1024
153
155
  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/phi2")
154
156
  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
@@ -156,8 +158,9 @@ def define_and_run() -> None:
156
158
  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
157
159
  tokens[0, :4] = idx
158
160
  input_pos = torch.arange(0, kv_cache_max_len)
159
- print("running an inference")
160
- print(model.forward(tokens, input_pos))
161
+ lm_logits = model.forward(tokens, input_pos)
162
+ print("comparing with goldens..")
163
+ assert torch.allclose(phi2_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05)
161
164
 
162
165
 
163
166
  if __name__ == "__main__":
@@ -557,7 +557,8 @@ def get_sample_encoder_input_ids() -> torch.Tensor:
557
557
 
558
558
 
559
559
  def define_and_run_t5(checkpoint_path: str) -> None:
560
- t5_goldens = torch.load("t5_lm_logits.pt")
560
+ current_dir = Path(__file__).parent.resolve()
561
+ t5_goldens = torch.load(current_dir / "t5_lm_logits.pt")
561
562
 
562
563
  model = build_t5_model(checkpoint_path)
563
564
 
@@ -579,7 +580,9 @@ def define_and_run_t5(checkpoint_path: str) -> None:
579
580
 
580
581
  # TODO(haoliang): Move those tests.
581
582
  def define_and_run_t5_split(checkpoint_path: str) -> None:
582
- t5_goldens = torch.load("t5_lm_logits.pt")
583
+ current_dir = Path(__file__).parent.resolve()
584
+ t5_goldens = torch.load(current_dir / "t5_lm_logits.pt")
585
+
583
586
  config = get_model_config_t5()
584
587
  embedding_layer = nn.Embedding(config.vocab_size, config.embedding_dim, padding_idx=0)
585
588
  t5_encoder_model = build_t5_encoder_model(config, embedding_layer, checkpoint_path)
@@ -14,9 +14,8 @@
14
14
  # ==============================================================================
15
15
  # A toy example which has basic transformer block (w/ externalized KV-Cache).
16
16
 
17
- from typing import List, Tuple
17
+ from typing import Tuple
18
18
 
19
- import numpy as np
20
19
  import torch
21
20
  import torch.nn as nn
22
21
  import torch_xla
@@ -24,6 +23,7 @@ import torch_xla
24
23
  import ai_edge_torch
25
24
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
26
25
  import ai_edge_torch.generative.layers.builder as builder
26
+ from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
27
27
  from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
28
28
  import ai_edge_torch.generative.layers.model_config as cfg
29
29
 
@@ -60,27 +60,27 @@ class ToyModelWithExternalKV(torch.nn.Module):
60
60
 
61
61
  def forward(
62
62
  self,
63
- idx: torch.Tensor,
63
+ tokens: torch.Tensor,
64
64
  input_pos: torch.Tensor,
65
- k_caches: torch.Tensor,
66
- v_caches: torch.Tensor,
67
- ) -> (torch.Tensor, torch.Tensor, torch.Tensor):
68
- x = self.tok_embedding(idx)
65
+ kv_cache: kv_utils.EKVCache,
66
+ ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
67
+ x = self.tok_embedding(tokens)
69
68
  cos, sin = self.rope_cache
70
69
  cos = cos.index_select(0, input_pos)
71
70
  sin = sin.index_select(0, input_pos)
72
71
  mask = self.mask_cache.index_select(2, input_pos)
73
72
  mask = mask[:, :, :, : self.config.max_seq_len]
74
73
 
74
+ updated_kv_entries = []
75
75
  for i, block in enumerate(self.transformer_blocks):
76
- input_k, input_v = k_caches[i], v_caches[i]
77
- x, (updated_k, updated_v) = block(
78
- x, (cos, sin), mask, input_pos, (input_k, input_v)
79
- )
80
- k_caches[i], v_caches[i] = updated_k, updated_v
76
+ kv_entry = kv_cache.caches[i] if kv_cache else None
77
+ x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
78
+ if kv_entry:
79
+ updated_kv_entries.append(kv_entry)
81
80
 
82
81
  x = self.final_norm(x)
83
- return self.lm_head(x), k_caches, v_caches
82
+ updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entries))
83
+ return self.lm_head(x), updated_kv_cache
84
84
 
85
85
 
86
86
  def _export_stablehlo_mlir(model, args):
@@ -115,15 +115,15 @@ def get_model_config() -> cfg.ModelConfig:
115
115
 
116
116
 
117
117
  def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
118
- idx = torch.unsqueeze(torch.arange(0, 100), 0)
118
+ tokens = torch.unsqueeze(torch.arange(0, 100), 0)
119
119
  input_pos = torch.arange(0, 100)
120
- return idx, input_pos
120
+ return tokens, input_pos
121
121
 
122
122
 
123
123
  def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
124
- idx = torch.tensor([[1]], dtype=torch.long)
124
+ tokens = torch.tensor([[1]], dtype=torch.long)
125
125
  input_pos = torch.tensor([10])
126
- return idx, input_pos
126
+ return tokens, input_pos
127
127
 
128
128
 
129
129
  def define_and_run() -> None:
@@ -131,16 +131,16 @@ def define_and_run() -> None:
131
131
 
132
132
  config = get_model_config()
133
133
  model = ToyModelWithExternalKV(config)
134
+ model.eval()
134
135
  print('running an inference')
135
- k_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)
136
- v_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)
136
+ kv = kv_utils.EKVCache.from_model_config(config)
137
137
 
138
- idx, input_pos = get_sample_prefill_inputs()
139
- decode_idx, decode_input_pos = get_sample_decode_inputs()
140
- print(model.forward(idx, input_pos, k_caches, v_caches))
138
+ tokens, input_pos = get_sample_prefill_inputs()
139
+ decode_token, decode_input_pos = get_sample_decode_inputs()
140
+ print(model.forward(tokens, input_pos, kv))
141
141
 
142
142
  if dump_mlir:
143
- mlir_text = _export_stablehlo_mlir(model, (idx, input_pos, k_caches, v_caches))
143
+ mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
144
144
  with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
145
145
  f.write(mlir_text)
146
146
 
@@ -149,13 +149,28 @@ def define_and_run() -> None:
149
149
  # in dynamic update slice op.
150
150
  print('converting toy model to tflite with 2 signatures (prefill + decode)')
151
151
  edge_model = (
152
- ai_edge_torch.signature('prefill', model, (idx, input_pos, k_caches, v_caches))
153
- .signature('decode', model, (decode_idx, decode_input_pos, k_caches, v_caches))
152
+ ai_edge_torch.signature(
153
+ 'prefill',
154
+ model,
155
+ sample_kwargs={
156
+ 'tokens': tokens,
157
+ 'input_pos': input_pos,
158
+ 'kv_cache': kv,
159
+ },
160
+ )
161
+ .signature(
162
+ 'decode',
163
+ model,
164
+ sample_kwargs={
165
+ 'tokens': decode_token,
166
+ 'input_pos': decode_input_pos,
167
+ 'kv_cache': kv,
168
+ },
169
+ )
154
170
  .convert()
155
171
  )
156
172
  edge_model.export('/tmp/toy_external_kv_cache.tflite')
157
173
 
158
174
 
159
175
  if __name__ == '__main__':
160
- with torch.inference_mode():
161
- define_and_run()
176
+ define_and_run()
@@ -149,6 +149,8 @@ def build_model(checkpoint_path, **kwargs) -> nn.Module:
149
149
 
150
150
 
151
151
  def define_and_run() -> None:
152
+ current_dir = Path(__file__).parent.resolve()
153
+ tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
152
154
  kv_cache_max_len = 1024
153
155
  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/tiny_llama")
154
156
  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
@@ -156,8 +158,10 @@ def define_and_run() -> None:
156
158
  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
157
159
  tokens[0, :4] = idx
158
160
  input_pos = torch.arange(0, kv_cache_max_len)
159
- print("running an inference")
160
- print(model.forward(tokens, input_pos))
161
+ lm_logits = model.forward(tokens, input_pos)
162
+ assert torch.allclose(
163
+ tiny_llama_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
164
+ )
161
165
 
162
166
 
163
167
  if __name__ == "__main__":
@@ -0,0 +1,122 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # A suite of tests to validate experimental external KV Cache layers and models.
16
+
17
+ import unittest
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from ai_edge_torch.generative.examples.experimental.gemma import gemma
23
+ from ai_edge_torch.generative.examples.experimental.phi import phi2
24
+ from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama # NOQA
25
+ from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
26
+ import ai_edge_torch.generative.layers.model_config as cfg
27
+
28
+
29
+ class TestExternalKVLayers(unittest.TestCase):
30
+
31
+ def _get_test_config(self, num_layers, head_dim, num_query_groups, kv_cache_max_len):
32
+ attn_config = cfg.AttentionConfig(num_heads=1, num_query_groups=num_query_groups)
33
+ config = cfg.ModelConfig(
34
+ kv_cache_max_len=kv_cache_max_len,
35
+ embedding_dim=head_dim,
36
+ attn_config=attn_config,
37
+ num_layers=num_layers,
38
+ max_seq_len=None,
39
+ vocab_size=None,
40
+ ff_config=None,
41
+ )
42
+ return config
43
+
44
+ def test_cache_update(self):
45
+ N = 1
46
+ HEAD_DIM = 2
47
+ NUM_QG = 1
48
+ KV_LEN = 4
49
+ config = self._get_test_config(
50
+ num_layers=N,
51
+ head_dim=HEAD_DIM,
52
+ num_query_groups=NUM_QG,
53
+ kv_cache_max_len=KV_LEN,
54
+ )
55
+ kv = kv_utils.EKVCache.from_model_config(config)
56
+ entry = kv.caches[0]
57
+ # single-slice update
58
+ input_pos = torch.tensor([1])
59
+ k_slice = v_slice = torch.full((1, 1, NUM_QG, HEAD_DIM), 5, dtype=torch.float)
60
+ updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
61
+ self.assertEqual(
62
+ updated_entry.k_cache.numpy().flatten().tolist(), [0, 0, 5, 5, 0, 0, 0, 0]
63
+ )
64
+ self.assertEqual(
65
+ updated_entry.v_cache.numpy().flatten().tolist(), [0, 0, 5, 5, 0, 0, 0, 0]
66
+ )
67
+ # multi-slice update
68
+ input_pos = torch.tensor([0, 3])
69
+ k_slice = v_slice = torch.full((1, 2, NUM_QG, HEAD_DIM), 7, dtype=torch.float)
70
+ updated_entry = kv_utils.update(entry, input_pos, k_slice, v_slice)
71
+ self.assertEqual(
72
+ updated_entry.k_cache.numpy().flatten().tolist(), [7, 7, 0, 0, 0, 0, 7, 7]
73
+ )
74
+ self.assertEqual(
75
+ updated_entry.v_cache.numpy().flatten().tolist(), [7, 7, 0, 0, 0, 0, 7, 7]
76
+ )
77
+
78
+ def test_serialization(self):
79
+ class TestModel(torch.nn.Module):
80
+
81
+ def forward(self, kv: kv_utils.EKVCache) -> kv_utils.EKVCache:
82
+ updated_kv_entries = [
83
+ kv_utils.KVCacheEntry(
84
+ torch.zeros_like(entry.k_cache), torch.zeros_like(entry.v_cache)
85
+ )
86
+ for entry in kv.caches
87
+ ]
88
+ return kv_utils.EKVCache(updated_kv_entries)
89
+
90
+ N = 1
91
+ HEAD_DIM = 2
92
+ NUM_QG = 1
93
+ KV_LEN = 4
94
+ config = self._get_test_config(
95
+ num_layers=N,
96
+ head_dim=HEAD_DIM,
97
+ num_query_groups=NUM_QG,
98
+ kv_cache_max_len=KV_LEN,
99
+ )
100
+ kv = kv_utils.EKVCache.from_model_config(config)
101
+ model = TestModel()
102
+ exported_program = torch.export.export(model, (kv,))
103
+ input_specs = exported_program.graph_signature.input_specs
104
+ self.assertEqual(len(input_specs), 2)
105
+ self.assertEqual(input_specs[0].arg.name, "kv_k_0")
106
+ self.assertEqual(input_specs[1].arg.name, "kv_v_0")
107
+
108
+
109
+ class TestExternalKVModels(unittest.TestCase):
110
+
111
+ def test_can_build_gemma(self):
112
+ gemma.define_and_run_2b(checkpoint_path=None, test_model=True)
113
+
114
+ def test_can_build_phi2(self):
115
+ phi2.define_and_run(checkpoint_path=None, test_model=True)
116
+
117
+ def test_can_build_tinyllama(self):
118
+ tiny_llama.define_and_run(checkpoint_path=None, test_model=True)
119
+
120
+
121
+ if __name__ == "__main__":
122
+ unittest.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.2.0.dev20240719
3
+ Version: 0.2.0.dev20240720
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -35,12 +35,22 @@ ai_edge_torch/debug/test/test_search_model.py,sha256=0guAEon5cvwBpPXk6J0wVOKj7TX
35
35
  ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
36
36
  ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
37
37
  ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
38
+ ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
39
+ ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
40
+ ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=bW_KOj_3fcZAggfST3zHWcMcNJs70b0pld-vvauAOgo,3076
41
+ ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=u4DNsZRnN7whDoK8yQet9Yahb01ToVqTuFQmWV1__1g,6606
42
+ ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
43
+ ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=sLU35tpQ-PEbhZbLfC1vSqM-HamKREVBpIoywWh9O3M,3036
44
+ ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=zgxB2JSFAevyS28C6-wIBaQeeKTUejUJY4dnR4BqRBI,6150
45
+ ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
46
+ ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=PChEhBotZ8k6GZiq9e_AYnn3RyhNIVm_U96QhVjx3jY,3126
47
+ ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=1vL0u6Pkd8SV8uei9BGzSAIokclT_RaE3K0IczoPfeI,6291
38
48
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
39
49
  ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=UMEZGDGhFvAX4eT5KHAE1Xbxw-qtQWEMxgvB8cSH6wY,2531
40
- ai_edge_torch/generative/examples/gemma/gemma.py,sha256=1lZfXGHmbII4rFu0U2B9NzlJCRhphxtmQtkCHQ39_uw,5935
50
+ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=YyGGsgEByIg_tIysMBqaBztf_csthZIjah8mmH5o7UA,6144
41
51
  ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
42
52
  ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=uF1A2EX8xYie30-T2Z7s1WZCtFhp5CEwRV8SCd7Umrc,2505
43
- ai_edge_torch/generative/examples/phi2/phi2.py,sha256=PMhKC6JCAMYSj2F3UmWHWK4rTcXD-B6PuehaoDccRqk,5562
53
+ ai_edge_torch/generative/examples/phi2/phi2.py,sha256=KjfTrD2OBzOfq83-XvJ6ZhmXLuP_VqugSOwyj-M5YY4,5767
44
54
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
45
55
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
46
56
  ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=P-cUUQaQKGKV2p-7hvLJ--RpCIA7gk8WCDRgg0pNtd0,4331
@@ -58,15 +68,15 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=iPYX
58
68
  ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5iRfU5MO6GR6K3WrdddIU_9U7ZZGEEb7zGKVY1WFl-8,1340
59
69
  ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
60
70
  ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=7RwaZQaKhFt3zKAUbFjq95CSYhL1nd9BVSbSRNJp4-4,4529
61
- ai_edge_torch/generative/examples/t5/t5.py,sha256=L6YrVzUEzP-Imb8W28LdukFGrx1aWSzz1kyYK_9RFZM,21087
71
+ ai_edge_torch/generative/examples/t5/t5.py,sha256=fVtJ0S8v2bMtvEuDqD6Orw7CTyXqnRIqZfKcz7DBeJc,21212
62
72
  ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=KaGzCAViNOpJIQbRF-ItouuVPqI9nroWRRGN-KFYKZs,8357
63
73
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
64
74
  ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=Sf3ZMYv-iuMRKAKLow47qth8vTF1zl6i8TxJ9uT_StU,3885
65
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=zwCmCnhr-vhBwHqv9i7xMasdBGVNqAGxZvWsncsJn58,5543
75
+ ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=jmucKpWY_nHEOAh7G62IxpReNmrKWo4PxfELul_h9xQ,5796
66
76
  ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=lfYUiem_Pbn3vGgPx84BeI8n7rN3-1fImwCLm8Eo2U8,4853
67
77
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
68
78
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=nT7Fh-f5ZdwaK3dPoCvZflpJ4fRHjLdFMjk1_uw3-b8,2559
69
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=IFRLPG9wz_aLl_zV_6CETCjSM03ukA6bZqqyDLVACuw,5651
79
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=to9IlF-X_uIJvO-roZOW1ZMUhmkYbvFjc-tUVaQr6TE,5848
70
80
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=aXvYiaHDvETIrh0Q9DDZA_ZBiazGk80DT6nt7lLtC1o,1172
71
81
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=BCAcc_OcEjvbaXQSbc8vlKeMad7E3gCA4BNsUdWRwBI,1966
72
82
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -94,6 +104,7 @@ ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=47DE
94
104
  ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=iTNPrlubmq9ia7C3zHl50J2YEMsc4o33GwL5tr5VkkE,5229
95
105
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
96
106
  ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
107
+ ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=qMR0r7Pr_t2bn-cyeA7Qw_Rl94H1NmFcqM2ua8gpDDw,4230
97
108
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=LsPTrLC1I4JW2GowTS3V9Eu257vLHr2Yj5f_qaFUX84,7589
98
109
  ai_edge_torch/generative/test/test_quantize.py,sha256=QbF7LC9olJFGXqlAVGciac7xXc4rDtCSr71tTIYuqPk,5230
99
110
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -114,8 +125,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
114
125
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
115
126
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
116
127
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
117
- ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
118
- ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/METADATA,sha256=X9TaI_3Rxn0rk89P3ZcXJlNtEIUBOhOIIKAncN3Xpos,1745
119
- ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
120
- ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
121
- ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/RECORD,,
128
+ ai_edge_torch_nightly-0.2.0.dev20240720.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
129
+ ai_edge_torch_nightly-0.2.0.dev20240720.dist-info/METADATA,sha256=xkXzcnmvTzJRRNOJ2c8JnWS1ZCofdlZiKsW5sa5sDyM,1745
130
+ ai_edge_torch_nightly-0.2.0.dev20240720.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
131
+ ai_edge_torch_nightly-0.2.0.dev20240720.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
132
+ ai_edge_torch_nightly-0.2.0.dev20240720.dist-info/RECORD,,