ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240911__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. ai_edge_torch/_convert/test/test_convert.py +35 -13
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  3. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  4. ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  7. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
  8. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
  9. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  10. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
  11. ai_edge_torch/generative/layers/attention.py +60 -63
  12. ai_edge_torch/generative/layers/kv_cache.py +160 -51
  13. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
  14. ai_edge_torch/generative/test/test_model_conversion.py +71 -33
  15. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  16. ai_edge_torch/generative/test/utils.py +54 -0
  17. ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
  18. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  19. ai_edge_torch/version.py +1 -1
  20. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
  21. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +25 -35
  22. ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
  23. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
  24. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  25. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  26. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  27. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  28. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  29. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  30. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  31. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  32. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  33. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  34. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
  35. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
  36. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py

@@ -12,26 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Example of building phi-2 model from the Edge Generative API layers.
-#
-# Note: This is an experimental version of phi2 with external KV cache.
-# Please use with caution.
+
+"""Example of building a Phi-2 model."""
 
 import os
-from pathlib import Path
-from typing import Tuple
+import pathlib
 
+from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import attention
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
 from torch import nn
 
-
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
     ff_down_proj="model.layers.{}.mlp.fc2",

@@ -89,13 +85,17 @@ class Phi2(nn.Module):
       self,
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
-      kv_cache: kv_utils.EKVCache,
-  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)

@@ -111,11 +111,11 @@ class Phi2(nn.Module):
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res, updated_kv_cache
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:

@@ -169,39 +169,37 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
-    checkpoint_path: str, test_model: bool = False, **kwargs
-) -> nn.Module:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
-  config = (
-      get_fake_model_config(**kwargs)
-      if test_model
-      else get_model_config(**kwargs)
-  )
+  config = get_model_config(**kwargs)
   model = Phi2(config)
-  if checkpoint_path is not None:
-    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-    loader.load(model)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(model)
   model.eval()
   return model
 
 
-def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
+def define_and_run(checkpoint_path: str) -> None:
   """Instantiates and runs a Phi-2 model."""
 
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
   kv_cache_max_len = 1024
-  model = build_model(
-      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-  )
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-  kv = kv_utils.EKVCache.from_model_config(model.config)
-  print("running an inference")
-  print(model.forward(tokens, input_pos, kv))
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  print("comparing with goldens..")
+  assert torch.allclose(
+      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
+  )
 
 
 if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(Path.home(), "Downloads/phi2")
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/phi2"
+  )
   define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# A toy example which has basic transformer block (w/ KV-Cache).
+
+"""A toy example which has basic transformer block (w/ externalized KV-Cache)."""
+
 from typing import Tuple
 
 import ai_edge_torch
 from ai_edge_torch import lowertools
-from ai_edge_torch.generative.layers.attention import TransformerBlock
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 import torch.nn as nn

@@ -27,7 +30,7 @@ import torch.nn as nn
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class ToyModelWithKV(torch.nn.Module):
+class ToyModelWithKVCache(torch.nn.Module):
 
   def __init__(self, config: cfg.ModelConfig) -> None:
     super().__init__()

@@ -36,7 +39,7 @@ class ToyModelWithKV(torch.nn.Module):
     )
     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(config) for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,

@@ -57,18 +60,29 @@ class ToyModelWithKV(torch.nn.Module):
     )
     self.config = config
 
-  @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    x = self.tok_embedding(idx)
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
+    x = self.tok_embedding(tokens)
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
-      x = block(x, (cos, sin), mask, input_pos)
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+
     x = self.final_norm(x)
-    return self.lm_head(x)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
 def _export_stablehlo_mlir(model, args):

@@ -89,7 +103,7 @@ def get_model_config() -> cfg.ModelConfig:
   config = cfg.ModelConfig(
       vocab_size=150,
       num_layers=2,
-      max_seq_len=500,
+      max_seq_len=100,
       embedding_dim=128,
       attn_config=attn_config,
       ff_config=ff_config,

@@ -102,40 +116,59 @@ def get_model_config() -> cfg.ModelConfig:
 
 
 def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.unsqueeze(torch.arange(0, 100), 0)
+  tokens = torch.unsqueeze(torch.arange(0, 100), 0)
   input_pos = torch.arange(0, 100)
-  return idx, input_pos
+  return tokens, input_pos
 
 
 def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.tensor([[1]], dtype=torch.long)
-  input_pos = torch.tensor([10], dtype=torch.int64)
-  return idx, input_pos
+  tokens = torch.tensor([[1]], dtype=torch.long)
+  input_pos = torch.tensor([10])
+  return tokens, input_pos
 
 
 def define_and_run() -> None:
   dump_mlir = False
 
   config = get_model_config()
-  model = ToyModelWithKV(config)
+  model = ToyModelWithExternalKV(config)
+  model.eval()
   print('running an inference')
-  idx, input_pos = get_sample_prefill_inputs()
-  decode_idx, decode_input_pos = get_sample_decode_inputs()
-  print(model.forward(idx, input_pos))
+  kv = kv_utils.KVCache.from_model_config(config)
+
+  tokens, input_pos = get_sample_prefill_inputs()
+  decode_token, decode_input_pos = get_sample_decode_inputs()
+  print(model.forward(tokens, input_pos, kv))
 
   if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (idx, input_pos))
-    with open('/tmp/toy_model_with_kv.stablehlo.mlir', 'w') as f:
+    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
       f.write(mlir_text)
 
   # Convert model to tflite with 2 signatures (prefill + decode).
   print('converting toy model to tflite with 2 signatures (prefill + decode)')
   edge_model = (
-      ai_edge_torch.signature('prefill', model, (idx, input_pos))
-      .signature('decode', model, (decode_idx, decode_input_pos))
+      ai_edge_torch.signature(
+          'prefill',
+          model,
+          sample_kwargs={
+              'tokens': tokens,
+              'input_pos': input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
      .convert()
   )
-  edge_model.export('/tmp/toy_kv_cache.tflite')
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')
 
 
 if __name__ == '__main__':
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py

@@ -13,11 +13,14 @@
 # limitations under the License.
 # ==============================================================================
 
+"""Example of converting TinyLlama model to multi-signature tflite model."""
+
 import os
-from pathlib import Path
+import pathlib
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 

@@ -48,20 +51,36 @@ def convert_tiny_llama_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+      f'/tmp/tiny_llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
-  convert_tiny_llama_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')
+  convert_tiny_llama_to_tflite(path)
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py

@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Example of building a TinyLlama model from the Edge Generative API layers.
+
+"""Example of building a TinyLlama model."""
 
 import os
-from pathlib import Path
+import pathlib
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils

@@ -80,16 +82,22 @@ class TinyLLamma(nn.Module):
     )
     self.config = config
 
-  # The model's forward function takes in additional k/v cache tensors
-  # and returns the updated k/v cache tensors to the caller.
-  # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    _, seq_len = idx.size()
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)

@@ -97,16 +105,20 @@ class TinyLLamma(nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
-    # forward the model itself
-    x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
 
-    for _, block in enumerate(self.transformer_blocks):
-      x = block(x, (cos, sin), mask, input_pos)
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:

@@ -147,8 +159,8 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config() -> cfg.ModelConfig:
-  config = get_model_config()
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
   config.vocab_size = 128
   config.num_layers = 2
   config.ff_config.intermediate_size = 64

@@ -160,26 +172,30 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   model = TinyLLamma(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   loader.load(model)
+  model.eval()
   return model
 
 
-def define_and_run() -> None:
+def define_and_run(checkpoint_path: str) -> None:
   """Instantiates and runs a TinyLlama model."""
 
-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/tiny_llama")
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-  lm_logits = model.forward(tokens, input_pos)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
   assert torch.allclose(
-      tiny_llama_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+      tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
   )
 
 
 if __name__ == "__main__":
-  define_and_run()
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
+  )
+  define_and_run(input_checkpoint_path)