ai-edge-torch-nightly 0.2.0.dev20240718__py3-none-any.whl → 0.2.0.dev20240720__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.

Files changed (23)
  1. ai_edge_torch/convert/conversion_utils.py +39 -18
  2. ai_edge_torch/convert/test/test_convert.py +106 -0
  3. ai_edge_torch/generative/examples/experimental/__init__.py +14 -0
  4. ai_edge_torch/generative/examples/experimental/gemma/__init__.py +14 -0
  5. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +87 -0
  6. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +195 -0
  7. ai_edge_torch/generative/examples/experimental/phi/__init__.py +14 -0
  8. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +84 -0
  9. ai_edge_torch/generative/examples/experimental/phi/phi2.py +184 -0
  10. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +14 -0
  11. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +89 -0
  12. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +185 -0
  13. ai_edge_torch/generative/examples/gemma/gemma.py +6 -2
  14. ai_edge_torch/generative/examples/phi2/phi2.py +5 -2
  15. ai_edge_torch/generative/examples/t5/t5.py +5 -2
  16. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +42 -27
  17. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +6 -2
  18. ai_edge_torch/generative/test/test_experimental_ekv.py +122 -0
  19. {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/METADATA +1 -1
  20. {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/RECORD +23 -12
  21. {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/LICENSE +0 -0
  22. {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/WHEEL +0 -0
  23. {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240720.dist-info}/top_level.txt +0 -0
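The bulk of this release adds experimental Gemma, Phi-2, and TinyLlama examples whose prefill and decode signatures pass the KV cache around explicitly (EKVCache) instead of mutating buffers inside the model. As a rough usage sketch based on the code added in this diff (the checkpoint location is a placeholder and must point to a real TinyLlama checkpoint), the new experimental TinyLlama converter can be driven like this:

import os
from pathlib import Path

from ai_edge_torch.generative.examples.experimental.tiny_llama import convert_to_tflite

# Placeholder path; replace with the directory holding your TinyLlama checkpoint.
checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')

# Exports /tmp/tiny_llama_seq512_ekv1024.tflite with 'prefill' and 'decode'
# signatures, each taking the external KV cache as an explicit input.
convert_to_tflite.convert_tiny_llama_to_tflite(
    checkpoint_path,
    prefill_seq_len=512,
    kv_cache_max_len=1024,
    quantize=True,
)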
ai_edge_torch/generative/examples/experimental/phi/phi2.py
@@ -0,0 +1,184 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Example of building phi-2 model from the Edge Generative API layers.
+ #
+ # Note: This is an experimental version of phi2 with external KV cache.
+ # Please use with caution.
+
+
+ import os
+ from pathlib import Path
+ from typing import Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ import ai_edge_torch.generative.layers.builder as builder
+ from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+ from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+     ff_up_proj="model.layers.{}.mlp.fc1",
+     ff_down_proj="model.layers.{}.mlp.fc2",
+     attn_query_proj="model.layers.{}.self_attn.q_proj",
+     attn_key_proj="model.layers.{}.self_attn.k_proj",
+     attn_value_proj="model.layers.{}.self_attn.v_proj",
+     attn_output_proj="model.layers.{}.self_attn.dense",
+     pre_attn_norm="model.layers.{}.input_layernorm",
+     embedding="model.embed_tokens",
+     final_norm="model.final_layernorm",
+     lm_head="lm_head",
+ )
+
+
+ class Phi2(nn.Module):
+
+   def __init__(self, config: cfg.ModelConfig):
+     super().__init__()
+
+     self.config = config
+     # Construct model layers.
+     self.lm_head = nn.Linear(
+         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+     )
+     self.tok_embedding = nn.Embedding(
+         config.vocab_size, config.embedding_dim, padding_idx=0
+     )
+     self.transformer_blocks = nn.ModuleList(
+         TransformerBlock(config) for _ in range(config.num_layers)
+     )
+     self.final_norm = builder.build_norm(
+         config.embedding_dim,
+         config.final_norm_config,
+     )
+     self.rope_cache = attn_utils.build_rope_cache(
+         size=config.kv_cache_max,
+         dim=int(config.attn_config.rotary_percentage * config.head_dim),
+         base=10_000,
+         condense_ratio=1,
+         dtype=torch.float32,
+         device=torch.device("cpu"),
+     )
+     self.mask_cache = attn_utils.build_causal_mask_cache(
+         size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+     )
+     self.config = config
+
+   @torch.inference_mode
+   def forward(
+       self,
+       tokens: torch.Tensor,
+       input_pos: torch.Tensor,
+       kv_cache: kv_utils.EKVCache,
+   ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
+     B, T = tokens.size()
+     assert (
+         self.config.max_seq_len >= T
+     ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+
+     cos, sin = self.rope_cache
+     cos = cos.index_select(0, input_pos)
+     sin = sin.index_select(0, input_pos)
+     mask = self.mask_cache.index_select(2, input_pos)
+     mask = mask[:, :, :, : self.config.kv_cache_max]
+
+     x = self.tok_embedding(tokens)
+
+     updated_kv_entires = []
+     for i, block in enumerate(self.transformer_blocks):
+       kv_entry = kv_cache.caches[i] if kv_cache else None
+       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+       if kv_entry:
+         updated_kv_entires.append(kv_entry)
+     updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
+
+     x = self.final_norm(x)
+     res = self.lm_head(x) # (b, t, vocab_size)
+     return res, updated_kv_cache
+
+
+ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+   attn_config = cfg.AttentionConfig(
+       num_heads=32,
+       num_query_groups=32,
+       rotary_percentage=0.4,
+       qkv_use_bias=True,
+       output_proj_use_bias=True,
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.SEQUENTIAL,
+       activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+       intermediate_size=10240,
+       use_bias=True,
+   )
+   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+   config = cfg.ModelConfig(
+       vocab_size=51200,
+       num_layers=32,
+       max_seq_len=2048,
+       kv_cache_max_len=kv_cache_max_len,
+       embedding_dim=2560,
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       final_norm_config=norm_config,
+       parallel_residual=True,
+       lm_head_use_bias=True,
+       enable_hlfb=True,
+   )
+   return config
+
+
+ def get_fake_model_config_for_test(**kwargs) -> cfg.ModelConfig:
+   config = get_model_config(**kwargs)
+   config.num_layers = 2
+   return config
+
+
+ def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
+   config = (
+       get_fake_model_config_for_test(**kwargs)
+       if test_model
+       else get_model_config(**kwargs)
+   )
+   model = Phi2(config)
+   if checkpoint_path is not None:
+     loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+     loader.load(model)
+   model.eval()
+   return model
+
+
+ def define_and_run(checkpoint_path, test_model=False) -> None:
+   kv_cache_max_len = 1024
+   model = build_model(
+       checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
+   )
+   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+   tokens[0, :4] = idx
+   input_pos = torch.arange(0, kv_cache_max_len)
+   kv = kv_utils.EKVCache.from_model_config(model.config)
+   print("running an inference")
+   print(model.forward(tokens, input_pos, kv))
+
+
+ if __name__ == "__main__":
+   checkpoint_path = os.path.join(Path.home(), "Downloads/phi2")
+   define_and_run(checkpoint_path)
ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py
@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py
@@ -0,0 +1,89 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ #
+ # Note: This is an experimental version of TinyLlama with external KV cache.
+ # Please use with caution.
+
+
+ import os
+ from pathlib import Path
+
+ import torch
+
+ import ai_edge_torch
+ from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama # NOQA
+ from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+ from ai_edge_torch.generative.quantize import quant_recipes
+
+
+ def convert_tiny_llama_to_tflite(
+     checkpoint_path: str,
+     prefill_seq_len: int = 512,
+     kv_cache_max_len: int = 1024,
+     quantize: bool = True,
+ ):
+   """An example method for converting TinyLlama model to multi-signature
+   tflite model.
+
+   Args:
+       checkpoint_path (str): The filepath to the model checkpoint, or directory
+         holding the checkpoint.
+       prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+         Defaults to 512.
+       kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+         including both prefill and decode. Defaults to 1024.
+       quantize (bool, optional): Whether the model should be quanized.
+         Defaults to True.
+   """
+   pytorch_model = tiny_llama.build_model(
+       checkpoint_path, kv_cache_max_len=kv_cache_max_len
+   )
+   # Tensors used to trace the model graph during conversion.
+   prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+   prefill_input_pos = torch.arange(0, prefill_seq_len)
+   decode_token = torch.tensor([[0]], dtype=torch.long)
+   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+   kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
+
+   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+   edge_model = (
+       ai_edge_torch.signature(
+           'prefill',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': prefill_tokens,
+               'input_pos': prefill_input_pos,
+               'kv_cache': kv,
+           },
+       )
+       .signature(
+           'decode',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': decode_token,
+               'input_pos': decode_input_pos,
+               'kv_cache': kv,
+           },
+       )
+       .convert(quant_config=quant_config)
+   )
+   edge_model.export(
+       f'/tmp/tiny_llama_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+   )
+
+
+ if __name__ == '__main__':
+   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
+   convert_tiny_llama_to_tflite(checkpoint_path)
ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py
@@ -0,0 +1,185 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Example of building a TinyLlama model from the Edge Generative API layers.
+ #
+ # Note: This is an experimental version of TinyLlama with external KV cache.
+ # Please use with caution.
+
+
+ import os
+ from pathlib import Path
+ from typing import Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ import ai_edge_torch.generative.layers.builder as builder
+ from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+ from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+     ff_up_proj="model.layers.{}.mlp.up_proj",
+     ff_down_proj="model.layers.{}.mlp.down_proj",
+     ff_gate_proj="model.layers.{}.mlp.gate_proj",
+     attn_query_proj="model.layers.{}.self_attn.q_proj",
+     attn_key_proj="model.layers.{}.self_attn.k_proj",
+     attn_value_proj="model.layers.{}.self_attn.v_proj",
+     attn_output_proj="model.layers.{}.self_attn.o_proj",
+     pre_attn_norm="model.layers.{}.input_layernorm",
+     pre_ff_norm="model.layers.{}.post_attention_layernorm",
+     embedding="model.embed_tokens",
+     final_norm="model.norm",
+     lm_head="lm_head",
+ )
+
+
+ class TinyLLamma(nn.Module):
+
+   def __init__(self, config: cfg.ModelConfig):
+     super().__init__()
+
+     self.config = config
+     # Construct model layers.
+     self.lm_head = nn.Linear(
+         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+     )
+     self.tok_embedding = nn.Embedding(
+         config.vocab_size, config.embedding_dim, padding_idx=0
+     )
+     self.transformer_blocks = nn.ModuleList(
+         TransformerBlock(config) for _ in range(config.num_layers)
+     )
+     self.final_norm = builder.build_norm(
+         config.embedding_dim,
+         config.final_norm_config,
+     )
+     self.rope_cache = attn_utils.build_rope_cache(
+         size=config.kv_cache_max,
+         dim=int(config.attn_config.rotary_percentage * config.head_dim),
+         base=10_000,
+         condense_ratio=1,
+         dtype=torch.float32,
+         device=torch.device("cpu"),
+     )
+     self.mask_cache = attn_utils.build_causal_mask_cache(
+         size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+     )
+     self.config = config
+
+   @torch.inference_mode
+   def forward(
+       self,
+       tokens: torch.Tensor,
+       input_pos: torch.Tensor,
+       kv_cache: kv_utils.EKVCache,
+   ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
+     B, T = tokens.size()
+     assert (
+         self.config.max_seq_len >= T
+     ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+
+     cos, sin = self.rope_cache
+     cos = cos.index_select(0, input_pos)
+     sin = sin.index_select(0, input_pos)
+     mask = self.mask_cache.index_select(2, input_pos)
+     mask = mask[:, :, :, : self.config.kv_cache_max]
+
+     # token embeddings of shape (b, t, n_embd)
+     x = self.tok_embedding(tokens)
+
+     updated_kv_entires = []
+     for i, block in enumerate(self.transformer_blocks):
+       kv_entry = kv_cache.caches[i] if kv_cache else None
+       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+       if kv_entry:
+         updated_kv_entires.append(kv_entry)
+     updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
+
+     x = self.final_norm(x)
+     res = self.lm_head(x) # (b, t, vocab_size)
+     return res, updated_kv_cache
+
+
+ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+   attn_config = cfg.AttentionConfig(
+       num_heads=32,
+       num_query_groups=4,
+       rotary_percentage=1.0,
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.GATED,
+       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+       intermediate_size=5632,
+   )
+   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+   config = cfg.ModelConfig(
+       vocab_size=32000,
+       num_layers=22,
+       max_seq_len=2048,
+       embedding_dim=2048,
+       kv_cache_max_len=kv_cache_max_len,
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       pre_ff_norm_config=norm_config,
+       final_norm_config=norm_config,
+       enable_hlfb=True,
+   )
+   return config
+
+
+ def get_fake_model_config_for_test(**kwargs) -> cfg.ModelConfig:
+   config = get_model_config(**kwargs)
+   config.vocab_size = 128
+   config.num_layers = 2
+   config.ff_config.intermediate_size = 256
+   return config
+
+
+ def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
+   config = (
+       get_fake_model_config_for_test(**kwargs)
+       if test_model
+       else get_model_config(**kwargs)
+   )
+   model = TinyLLamma(config)
+   if checkpoint_path is not None:
+     loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+     loader.load(model)
+   model.eval()
+   return model
+
+
+ def define_and_run(checkpoint_path, test_model=False) -> None:
+   kv_cache_max_len = 1024
+   model = build_model(
+       checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
+   )
+   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+   tokens[0, :4] = idx
+   input_pos = torch.arange(0, kv_cache_max_len)
+   kv = kv_utils.EKVCache.from_model_config(model.config)
+   print("running an inference")
+   print(model.forward(tokens, input_pos, kv))
+
+
+ if __name__ == "__main__":
+   checkpoint_path = os.path.join(Path.home(), "Downloads/tiny_llama")
+   define_and_run(checkpoint_path)
ai_edge_torch/generative/examples/gemma/gemma.py
@@ -159,6 +159,9 @@ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:


  def define_and_run_2b() -> None:
+   current_dir = Path(__file__).parent.resolve()
+   gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
+
    kv_cache_max_len = 1024
    checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
    model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
@@ -166,8 +169,9 @@ def define_and_run_2b() -> None:
    tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
    tokens[0, :4] = idx
    input_pos = torch.arange(0, kv_cache_max_len)
-   print("running an inference")
-   print(model.forward(tokens, input_pos))
+   lm_logits = model.forward(tokens, input_pos)
+   print("comparing with goldens..")
+   assert torch.allclose(gemma_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05)


  if __name__ == "__main__":
ai_edge_torch/generative/examples/phi2/phi2.py
@@ -149,6 +149,8 @@ def build_model(checkpoint_path, **kwargs) -> nn.Module:


  def define_and_run() -> None:
+   current_dir = Path(__file__).parent.resolve()
+   phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
    kv_cache_max_len = 1024
    checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/phi2")
    model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
@@ -156,8 +158,9 @@ def define_and_run() -> None:
    tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
    tokens[0, :4] = idx
    input_pos = torch.arange(0, kv_cache_max_len)
-   print("running an inference")
-   print(model.forward(tokens, input_pos))
+   lm_logits = model.forward(tokens, input_pos)
+   print("comparing with goldens..")
+   assert torch.allclose(phi2_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05)


  if __name__ == "__main__":
ai_edge_torch/generative/examples/t5/t5.py
@@ -557,7 +557,8 @@ def get_sample_encoder_input_ids() -> torch.Tensor:


  def define_and_run_t5(checkpoint_path: str) -> None:
-   t5_goldens = torch.load("t5_lm_logits.pt")
+   current_dir = Path(__file__).parent.resolve()
+   t5_goldens = torch.load(current_dir / "t5_lm_logits.pt")

    model = build_t5_model(checkpoint_path)

@@ -579,7 +580,9 @@ def define_and_run_t5(checkpoint_path: str) -> None:

  # TODO(haoliang): Move those tests.
  def define_and_run_t5_split(checkpoint_path: str) -> None:
-   t5_goldens = torch.load("t5_lm_logits.pt")
+   current_dir = Path(__file__).parent.resolve()
+   t5_goldens = torch.load(current_dir / "t5_lm_logits.pt")
+
    config = get_model_config_t5()
    embedding_layer = nn.Embedding(config.vocab_size, config.embedding_dim, padding_idx=0)
    t5_encoder_model = build_t5_encoder_model(config, embedding_layer, checkpoint_path)
ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py
@@ -14,9 +14,8 @@
  # ==============================================================================
  # A toy example which has basic transformer block (w/ externalized KV-Cache).

- from typing import List, Tuple
+ from typing import Tuple

- import numpy as np
  import torch
  import torch.nn as nn
  import torch_xla
@@ -24,6 +23,7 @@ import torch_xla
  import ai_edge_torch
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.builder as builder
+ from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
  from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
  import ai_edge_torch.generative.layers.model_config as cfg

@@ -60,27 +60,27 @@ class ToyModelWithExternalKV(torch.nn.Module):

    def forward(
        self,
-       idx: torch.Tensor,
+       tokens: torch.Tensor,
        input_pos: torch.Tensor,
-       k_caches: torch.Tensor,
-       v_caches: torch.Tensor,
-   ) -> (torch.Tensor, torch.Tensor, torch.Tensor):
-     x = self.tok_embedding(idx)
+       kv_cache: kv_utils.EKVCache,
+   ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
+     x = self.tok_embedding(tokens)
      cos, sin = self.rope_cache
      cos = cos.index_select(0, input_pos)
      sin = sin.index_select(0, input_pos)
      mask = self.mask_cache.index_select(2, input_pos)
      mask = mask[:, :, :, : self.config.max_seq_len]

+     updated_kv_entires = []
      for i, block in enumerate(self.transformer_blocks):
-       input_k, input_v = k_caches[i], v_caches[i]
-       x, (updated_k, updated_v) = block(
-           x, (cos, sin), mask, input_pos, (input_k, input_v)
-       )
-       k_caches[i], v_caches[i] = updated_k, updated_v
+       kv_entry = kv_cache.caches[i] if kv_cache else None
+       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+       if kv_entry:
+         updated_kv_entires.append(kv_entry)

      x = self.final_norm(x)
-     return self.lm_head(x), k_caches, v_caches
+     updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
+     return self.lm_head(x), updated_kv_cache


  def _export_stablehlo_mlir(model, args):
@@ -115,15 +115,15 @@ def get_model_config() -> cfg.ModelConfig:


  def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-   idx = torch.unsqueeze(torch.arange(0, 100), 0)
+   tokens = torch.unsqueeze(torch.arange(0, 100), 0)
    input_pos = torch.arange(0, 100)
-   return idx, input_pos
+   return tokens, input_pos


  def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-   idx = torch.tensor([[1]], dtype=torch.long)
+   tokens = torch.tensor([[1]], dtype=torch.long)
    input_pos = torch.tensor([10])
-   return idx, input_pos
+   return tokens, input_pos


  def define_and_run() -> None:
@@ -131,16 +131,16 @@ def define_and_run() -> None:

    config = get_model_config()
    model = ToyModelWithExternalKV(config)
+   model.eval()
    print('running an inference')
-   k_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)
-   v_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)
+   kv = kv_utils.EKVCache.from_model_config(config)

-   idx, input_pos = get_sample_prefill_inputs()
-   decode_idx, decode_input_pos = get_sample_decode_inputs()
-   print(model.forward(idx, input_pos, k_caches, v_caches))
+   tokens, input_pos = get_sample_prefill_inputs()
+   decode_token, decode_input_pos = get_sample_decode_inputs()
+   print(model.forward(tokens, input_pos, kv))

    if dump_mlir:
-     mlir_text = _export_stablehlo_mlir(model, (idx, input_pos, k_caches, v_caches))
+     mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
      with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
        f.write(mlir_text)

@@ -149,13 +149,28 @@ def define_and_run() -> None:
    # in dynamic update slice op.
    print('converting toy model to tflite with 2 signatures (prefill + decode)')
    edge_model = (
-       ai_edge_torch.signature('prefill', model, (idx, input_pos, k_caches, v_caches))
-       .signature('decode', model, (decode_idx, decode_input_pos, k_caches, v_caches))
+       ai_edge_torch.signature(
+           'prefill',
+           model,
+           sample_kwargs={
+               'tokens': tokens,
+               'input_pos': input_pos,
+               'kv_cache': kv,
+           },
+       )
+       .signature(
+           'decode',
+           model,
+           sample_kwargs={
+               'tokens': decode_token,
+               'input_pos': decode_input_pos,
+               'kv_cache': kv,
+           },
+       )
      .convert()
    )
    edge_model.export('/tmp/toy_external_kv_cache.tflite')


  if __name__ == '__main__':
-   with torch.inference_mode():
-     define_and_run()
+   define_and_run()