ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240912__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (42)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
  7. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  8. ai_edge_torch/generative/examples/smallm/smallm.py +119 -0
  9. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
  10. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  11. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +40 -24
  12. ai_edge_torch/generative/layers/attention.py +60 -63
  13. ai_edge_torch/generative/layers/builder.py +4 -2
  14. ai_edge_torch/generative/layers/kv_cache.py +160 -51
  15. ai_edge_torch/generative/layers/model_config.py +1 -0
  16. ai_edge_torch/generative/layers/normalization.py +158 -0
  17. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  18. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
  19. ai_edge_torch/generative/test/test_loader.py +1 -1
  20. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  21. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  22. ai_edge_torch/generative/test/utils.py +54 -0
  23. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  24. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  25. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  26. ai_edge_torch/version.py +1 -1
  27. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/METADATA +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/RECORD +33 -39
  29. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  30. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  31. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  32. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  33. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  34. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  35. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  36. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  37. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  38. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  39. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  40. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/LICENSE +0 -0
  41. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/WHEEL +0 -0
  42. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/smallm/smallm.py
@@ -0,0 +1,119 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of building a SmalLM model."""
+
+ import copy
+ import os
+ import pathlib
+
+ from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+ import numpy as np
+ import torch
+ from torch import nn
+
+ TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
+ # SmalLM re-uses the embedding as the head projection layer.
+ TENSOR_NAMES.lm_head = None
+
+
+ class SmalLM(tiny_llama.TinyLlama):
+   """A SmalLM model built from the Edge Generative API layers.
+
+   SmalLM shares the same architecture as TinyLlama, but with different model
+   sizes.
+   """
+
+   def __init__(self, config: cfg.ModelConfig):
+     super().__init__(config)
+     # SmalLM re-uses the embedding as the head projection layer.
+     self.lm_head.weight.data = self.tok_embedding.weight.data
+
+
+ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+   """Returns the model config for a SmalLM 135M model.
+
+   Args:
+     kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+       is 1024.
+
+   Returns:
+     The model config for a SmalLM model.
+   """
+   attn_config = cfg.AttentionConfig(
+       num_heads=9,
+       head_dim=64,
+       num_query_groups=3,
+       rotary_percentage=1.0,
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.GATED,
+       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+       intermediate_size=1536,
+   )
+   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+   config = cfg.ModelConfig(
+       vocab_size=49152,
+       num_layers=30,
+       max_seq_len=2048,
+       embedding_dim=576,
+       kv_cache_max_len=kv_cache_max_len,
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       post_attention_norm_config=norm_config,
+       final_norm_config=norm_config,
+       enable_hlfb=True,
+   )
+   return config
+
+
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+   config = get_model_config(**kwargs)
+   model = SmalLM(config)
+   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+   # Since embedding and lm-head use the same weight, we need to set strict
+   # to False.
+   loader.load(model, strict=False)
+   model.eval()
+   return model
+
+
+ def define_and_run(checkpoint_path: str) -> None:
+   """Instantiates and runs a SmalLM model."""
+
+   current_dir = pathlib.Path(__file__).parent.resolve()
+   smallm_goldens = torch.load(current_dir / "smallm_lm_logits.pt")
+   kv_cache_max_len = 1024
+   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+   tokens[0, :4] = idx
+   input_pos = torch.arange(0, kv_cache_max_len)
+   kv = kv_utils.KVCache.from_model_config(model.config)
+   output = model.forward(tokens, input_pos, kv)
+   assert torch.allclose(
+       smallm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+   )
+
+
+ if __name__ == "__main__":
+   input_checkpoint_path = os.path.join(
+       pathlib.Path.home(), "Downloads/llm_data/smallm"
+   )
+   define_and_run(input_checkpoint_path)
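The one behavioral change relative to TinyLlama is the weight tying above: `TENSOR_NAMES.lm_head` is cleared because the checkpoint carries no separate head tensor, `lm_head.weight.data` aliases the embedding table, and `loader.load` therefore runs with `strict=False`. A small sketch (not part of the release) of checking that tie after loading, using the same placeholder checkpoint path as the example:

    # Sketch: after build_model, the head projection and the token embedding
    # should share one underlying tensor, which is why there is no lm_head
    # weight to load and strict=False is required.
    import pathlib

    from ai_edge_torch.generative.examples.smallm import smallm

    ckpt = pathlib.Path.home() / "Downloads/llm_data/smallm"  # placeholder
    model = smallm.build_model(str(ckpt), kv_cache_max_len=1024)
    assert (
        model.lm_head.weight.data.data_ptr()
        == model.tok_embedding.weight.data.data_ptr()
    )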
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -12,14 +12,17 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # A toy example which has basic transformer block (w/ KV-Cache).
+
+ """A toy example which has basic transformer block (w/ externalized KV-Cache)."""
+
  from typing import Tuple

  import ai_edge_torch
  from ai_edge_torch import lowertools
- from ai_edge_torch.generative.layers.attention import TransformerBlock
+ from ai_edge_torch.generative.layers import attention
+ from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
- import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
  import torch
  import torch.nn as nn
@@ -27,7 +30,7 @@ import torch.nn as nn
  RoPECache = Tuple[torch.Tensor, torch.Tensor]


- class ToyModelWithKV(torch.nn.Module):
+ class ToyModelWithKVCache(torch.nn.Module):

    def __init__(self, config: cfg.ModelConfig) -> None:
      super().__init__()
@@ -36,7 +39,7 @@ class ToyModelWithKV(torch.nn.Module):
      )
      self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
      self.transformer_blocks = nn.ModuleList(
-         TransformerBlock(config) for _ in range(config.num_layers)
+         attention.TransformerBlock(config) for _ in range(config.num_layers)
      )
      self.final_norm = builder.build_norm(
          config.embedding_dim,
@@ -57,18 +60,29 @@ class ToyModelWithKV(torch.nn.Module):
      )
      self.config = config

-   @torch.inference_mode
-   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-     x = self.tok_embedding(idx)
+   def forward(
+       self,
+       tokens: torch.Tensor,
+       input_pos: torch.Tensor,
+       kv_cache: kv_utils.KVCache,
+   ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
+     x = self.tok_embedding(tokens)
      cos, sin = self.rope_cache
      cos = cos.index_select(0, input_pos)
      sin = sin.index_select(0, input_pos)
      mask = self.mask_cache.index_select(2, input_pos)
      mask = mask[:, :, :, : self.config.max_seq_len]
+
+     updated_kv_entires = []
      for i, block in enumerate(self.transformer_blocks):
-       x = block(x, (cos, sin), mask, input_pos)
+       kv_entry = kv_cache.caches[i] if kv_cache else None
+       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+       if kv_entry:
+         updated_kv_entires.append(kv_entry)
+
      x = self.final_norm(x)
-     return self.lm_head(x)
+     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+     return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}


  def _export_stablehlo_mlir(model, args):
@@ -89,7 +103,7 @@ def get_model_config() -> cfg.ModelConfig:
    config = cfg.ModelConfig(
        vocab_size=150,
        num_layers=2,
-       max_seq_len=500,
+       max_seq_len=100,
        embedding_dim=128,
        attn_config=attn_config,
        ff_config=ff_config,
@@ -102,40 +116,59 @@ def get_model_config() -> cfg.ModelConfig:


  def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-   idx = torch.unsqueeze(torch.arange(0, 100), 0)
+   tokens = torch.unsqueeze(torch.arange(0, 100), 0)
    input_pos = torch.arange(0, 100)
-   return idx, input_pos
+   return tokens, input_pos


  def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-   idx = torch.tensor([[1]], dtype=torch.long)
-   input_pos = torch.tensor([10], dtype=torch.int64)
-   return idx, input_pos
+   tokens = torch.tensor([[1]], dtype=torch.long)
+   input_pos = torch.tensor([10])
+   return tokens, input_pos


  def define_and_run() -> None:
    dump_mlir = False

    config = get_model_config()
-   model = ToyModelWithKV(config)
+   model = ToyModelWithExternalKV(config)
+   model.eval()
    print('running an inference')
-   idx, input_pos = get_sample_prefill_inputs()
-   decode_idx, decode_input_pos = get_sample_decode_inputs()
-   print(model.forward(idx, input_pos))
+   kv = kv_utils.KVCache.from_model_config(config)
+
+   tokens, input_pos = get_sample_prefill_inputs()
+   decode_token, decode_input_pos = get_sample_decode_inputs()
+   print(model.forward(tokens, input_pos, kv))

    if dump_mlir:
-     mlir_text = _export_stablehlo_mlir(model, (idx, input_pos))
-     with open('/tmp/toy_model_with_kv.stablehlo.mlir', 'w') as f:
+     mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+     with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
        f.write(mlir_text)

    # Convert model to tflite with 2 signatures (prefill + decode).
    print('converting toy model to tflite with 2 signatures (prefill + decode)')
    edge_model = (
-       ai_edge_torch.signature('prefill', model, (idx, input_pos))
-       .signature('decode', model, (decode_idx, decode_input_pos))
+       ai_edge_torch.signature(
+           'prefill',
+           model,
+           sample_kwargs={
+               'tokens': tokens,
+               'input_pos': input_pos,
+               'kv_cache': kv,
+           },
+       )
+       .signature(
+           'decode',
+           model,
+           sample_kwargs={
+               'tokens': decode_token,
+               'input_pos': decode_input_pos,
+               'kv_cache': kv,
+           },
+       )
        .convert()
    )
-   edge_model.export('/tmp/toy_kv_cache.tflite')
+   edge_model.export('/tmp/toy_external_kv_cache.tflite')


  if __name__ == '__main__':
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -13,11 +13,14 @@
  # limitations under the License.
  # ==============================================================================

+ """Example of converting TinyLlama model to multi-signature tflite model."""
+
  import os
- from pathlib import Path
+ import pathlib

  import ai_edge_torch
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  from ai_edge_torch.generative.quantize import quant_recipes
  import torch

@@ -48,20 +51,36 @@ def convert_tiny_llama_to_tflite(
    prefill_input_pos = torch.arange(0, prefill_seq_len)
    decode_token = torch.tensor([[0]], dtype=torch.long)
    decode_input_pos = torch.tensor([0], dtype=torch.int64)
+   kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
    edge_model = (
        ai_edge_torch.signature(
-           'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+           'prefill',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': prefill_tokens,
+               'input_pos': prefill_input_pos,
+               'kv_cache': kv,
+           },
+       )
+       .signature(
+           'decode',
+           pytorch_model,
+           sample_kwargs={
+               'tokens': decode_token,
+               'input_pos': decode_input_pos,
+               'kv_cache': kv,
+           },
        )
-       .signature('decode', pytorch_model, (decode_token, decode_input_pos))
        .convert(quant_config=quant_config)
    )
+   quant_suffix = 'q8' if quantize else 'f32'
    edge_model.export(
-       f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+       f'/tmp/tiny_llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
    )


  if __name__ == '__main__':
-   checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
-   convert_tiny_llama_to_tflite(checkpoint_path)
+   path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')
+   convert_tiny_llama_to_tflite(path)
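Because the cache is passed through `sample_kwargs` as plain tensors, the exported file exposes `prefill` and `decode` as named TFLite signatures. A sketch of driving them with TensorFlow Lite's signature runner; the filename is the export pattern above with assumed default sizes, and the converter-generated names of the flattened `kv_cache` inputs are read from `get_input_details()` rather than guessed:

    # Sketch: run the exported two-signature model with the TFLite interpreter.
    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(
        model_path='/tmp/tiny_llama_f32_seq512_ekv1024.tflite'  # assumed name
    )
    prefill = interpreter.get_signature_runner('prefill')

    # Zero-fill every input (tokens, input_pos, and the flattened KV-cache
    # tensors) from the shapes and dtypes the converter recorded.
    sample = {
        name: np.zeros(detail['shape'], dtype=detail['dtype'])
        for name, detail in prefill.get_input_details().items()
    }
    outputs = prefill(**sample)
    print(sorted(outputs.keys()))  # logits plus the updated KV-cache tensors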
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -12,13 +12,15 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # Example of building a TinyLlama model from the Edge Generative API layers.
+
+ """Example of building a TinyLlama model."""

  import os
- from pathlib import Path
+ import pathlib

  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -42,7 +44,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  )


- class TinyLLamma(nn.Module):
+ class TinyLlama(nn.Module):
    """A TinyLlama model built from the Edge Generative API layers."""

    def __init__(self, config: cfg.ModelConfig):
@@ -80,16 +82,22 @@ class TinyLLamma(nn.Module):
      )
      self.config = config

-   # The model's forward function takes in additional k/v cache tensors
-   # and returns the updated k/v cache tensors to the caller.
-   # This can be eliminated if we handle k/v cache updates inside the model itself.
    @torch.inference_mode
-   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-     _, seq_len = idx.size()
+   def forward(
+       self,
+       tokens: torch.Tensor,
+       input_pos: torch.Tensor,
+       kv_cache: kv_utils.KVCache,
+   ) -> dict[torch.Tensor, kv_utils.KVCache]:
+     _, seq_len = tokens.size()
      assert self.config.max_seq_len >= seq_len, (
          f"Cannot forward sequence of length {seq_len}, max seq length is only"
          f" {self.config.max_seq_len}"
      )
+     assert len(self.transformer_blocks) == len(kv_cache.caches), (
+         "The number of transformer blocks and the number of KV cache entries"
+         " must be the same."
+     )

      cos, sin = self.rope_cache
      cos = cos.index_select(0, input_pos)
@@ -97,16 +105,20 @@ class TinyLLamma(nn.Module):
      mask = self.mask_cache.index_select(2, input_pos)
      mask = mask[:, :, :, : self.config.kv_cache_max]

-     # forward the model itself
-     x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
+     # token embeddings of shape (b, t, n_embd)
+     x = self.tok_embedding(tokens)

-     for _, block in enumerate(self.transformer_blocks):
-       x = block(x, (cos, sin), mask, input_pos)
+     updated_kv_entires = []
+     for i, block in enumerate(self.transformer_blocks):
+       kv_entry = kv_cache.caches[i] if kv_cache else None
+       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+       if kv_entry:
+         updated_kv_entires.append(kv_entry)
+     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

      x = self.final_norm(x)
-
-     res = self.lm_head(x)  # (b, t, vocab_size)
-     return res
+     logits = self.lm_head(x)  # (b, t, vocab_size)
+     return {"logits": logits, "kv_cache": updated_kv_cache}


  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -147,8 +159,8 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
    return config


- def get_fake_model_config() -> cfg.ModelConfig:
-   config = get_model_config()
+ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+   config = get_model_config(**kwargs)
    config.vocab_size = 128
    config.num_layers = 2
    config.ff_config.intermediate_size = 64
@@ -157,29 +169,33 @@ def get_fake_model_config() -> cfg.ModelConfig:

  def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
    config = get_model_config(**kwargs)
-   model = TinyLLamma(config)
+   model = TinyLlama(config)
    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
    loader.load(model)
+   model.eval()
    return model


- def define_and_run() -> None:
+ def define_and_run(checkpoint_path: str) -> None:
    """Instantiates and runs a TinyLlama model."""

-   current_dir = Path(__file__).parent.resolve()
+   current_dir = pathlib.Path(__file__).parent.resolve()
    tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
    kv_cache_max_len = 1024
-   checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/tiny_llama")
    model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
    tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
    tokens[0, :4] = idx
    input_pos = torch.arange(0, kv_cache_max_len)
-   lm_logits = model.forward(tokens, input_pos)
+   kv = kv_utils.KVCache.from_model_config(model.config)
+   output = model.forward(tokens, input_pos, kv)
    assert torch.allclose(
-       tiny_llama_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+       tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
    )


  if __name__ == "__main__":
-   define_and_run()
+   input_checkpoint_path = os.path.join(
+       pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
+   )
+   define_and_run(input_checkpoint_path)
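kv_cache.py itself changed by +160/-51 in this release but its hunk is not shown above; the call sites still pin down its interface: `KVCache.from_model_config` builds one entry per layer, `kv_cache.caches[i]` indexes those entries, and `KVCache(tuple(entries))` rebuilds the container. A rough reconstruction of that interface, offered as an assumption for illustration rather than the released source:

    # Assumed shape of the KVCache container implied by the call sites above;
    # the released kv_cache.py is the authoritative definition.
    import dataclasses
    from typing import Tuple

    import torch


    @dataclasses.dataclass
    class KVCacheEntry:
      """Per-layer key/value tensors, sized for kv_cache_max_len."""

      k_cache: torch.Tensor
      v_cache: torch.Tensor


    @dataclasses.dataclass
    class KVCache:
      caches: Tuple[KVCacheEntry, ...]

      @classmethod
      def from_model_config(cls, config) -> "KVCache":
        # Layout assumed: (batch, kv_cache_max_len, num_query_groups, head_dim).
        shape = (
            1,
            config.kv_cache_max_len,
            config.attn_config.num_query_groups,
            config.attn_config.head_dim,
        )
        entries = tuple(
            KVCacheEntry(torch.zeros(shape), torch.zeros(shape))
            for _ in range(config.num_layers)
        )
        return cls(entries)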