ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their respective public registries, and is provided for informational purposes only.
Files changed (48)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
  7. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  8. ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
  10. ai_edge_torch/generative/examples/t5/t5.py +35 -22
  11. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  12. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  13. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
  14. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  15. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
  16. ai_edge_torch/generative/layers/attention.py +77 -73
  17. ai_edge_torch/generative/layers/builder.py +5 -3
  18. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  19. ai_edge_torch/generative/layers/model_config.py +38 -19
  20. ai_edge_torch/generative/layers/normalization.py +158 -0
  21. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  22. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  23. ai_edge_torch/generative/test/test_loader.py +1 -1
  24. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  25. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  26. ai_edge_torch/generative/test/utils.py +54 -0
  27. ai_edge_torch/generative/utilities/loader.py +15 -15
  28. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  29. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  30. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  31. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  32. ai_edge_torch/version.py +1 -1
  33. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +39 -45
  35. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  36. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  38. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  39. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  40. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  42. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  43. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  44. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  45. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  46. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
  47. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -20,7 +20,6 @@ from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
-import numpy as np
 import torch
 import torch.nn as nn
 
@@ -36,16 +35,16 @@ class ToySingleLayerModel(torch.nn.Module):
         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
     )
     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
-    self.transformer_block = TransformerBlock(config)
+    self.transformer_block = TransformerBlock(config.block_config(0), config)
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    # Toy model has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -85,16 +84,16 @@ class ToySingleLayerModelWeightSharing(torch.nn.Module):
         bias=config.lm_head_use_bias,
     )
     self.lm_head.weight.data = self.tok_embedding.weight.data
-    self.transformer_block = TransformerBlock(config)
+    self.transformer_block = TransformerBlock(config.block_config(0), config)
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
    )
+    # Toy model has only one block config.
+    attn_config = config.block_config(0).attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -135,15 +134,18 @@ def get_model_config() -> cfg.ModelConfig:
       intermediate_size=256,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
   config = cfg.ModelConfig(
       vocab_size=400,
       num_layers=1,
       max_seq_len=KV_CACHE_MAX_LEN,
       embedding_dim=128,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
   )
   return config
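
The hunks above capture the release's central config migration: per-layer attention, feed-forward, and normalization settings move out of ModelConfig into a TransformerBlockConfig passed via block_configs, and layers read them back with config.block_config(i). A minimal sketch of the new construction pattern, assuming the cfg field names visible in these hunks; the ActivationConfig line follows the library's usual config style and is not visible in this diff, and all values are illustrative.

import ai_edge_torch.generative.layers.model_config as cfg

# Per-layer settings now live in a single TransformerBlockConfig ...
attn_config = cfg.AttentionConfig(
    num_heads=32, head_dim=4, num_query_groups=4, rotary_percentage=1.0
)
ff_config = cfg.FeedForwardConfig(
    type=cfg.FeedForwardType.GATED,
    activation=cfg.ActivationConfig(cfg.ActivationType.SILU),  # assumed field
    intermediate_size=256,
)
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
block_config = cfg.TransformerBlockConfig(
    attn_config=attn_config,
    ff_config=ff_config,
    pre_attention_norm_config=norm_config,
    post_attention_norm_config=norm_config,
)

# ... and ModelConfig receives them through `block_configs`.
config = cfg.ModelConfig(
    vocab_size=400,
    num_layers=1,
    max_seq_len=128,
    embedding_dim=128,
    block_configs=block_config,
    final_norm_config=norm_config,
)

# Layers now look their settings up per index, which also allows
# heterogeneous stacks (e.g. gemma2) where blocks differ per layer.
assert config.block_config(0).attn_config.num_heads == 32
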
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# A toy example which has basic transformer block (w/ KV-Cache).
+
+"""A toy example which has basic transformer block (w/ externalized KV-Cache)."""
+
 from typing import Tuple
 
 import ai_edge_torch
 from ai_edge_torch import lowertools
-from ai_edge_torch.generative.layers.attention import TransformerBlock
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 import torch.nn as nn
@@ -27,7 +30,7 @@ import torch.nn as nn
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class ToyModelWithKV(torch.nn.Module):
+class ToyModelWithKVCache(torch.nn.Module):
 
   def __init__(self, config: cfg.ModelConfig) -> None:
     super().__init__()
@@ -35,18 +38,20 @@ class ToyModelWithKV(torch.nn.Module):
         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
     )
     self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    # Toy model has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    attn_config = block_config.attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -57,18 +62,29 @@ class ToyModelWithKV(torch.nn.Module):
     )
     self.config = config
 
-  @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    x = self.tok_embedding(idx)
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
+    x = self.tok_embedding(tokens)
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
-      x = block(x, (cos, sin), mask, input_pos)
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+
     x = self.final_norm(x)
-    return self.lm_head(x)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
 def _export_stablehlo_mlir(model, args):
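
With the KV cache externalized, the caller owns the cache: forward takes a kv_cache argument and returns the updated cache alongside the logits. A minimal caller sketch under that contract, assuming the renamed ToyModelWithKVCache class and the KVCache.from_model_config helper shown above; token values are illustrative.

import torch
from ai_edge_torch.generative.layers import kv_cache as kv_utils

config = get_model_config()
model = ToyModelWithKVCache(config)
model.eval()

# The caller allocates the cache and threads it through every call.
kv = kv_utils.KVCache.from_model_config(config)

# Prefill: run the whole prompt in one pass.
tokens = torch.unsqueeze(torch.arange(0, 16), 0)
input_pos = torch.arange(0, 16)
out = model.forward(tokens, input_pos, kv)
kv = out['kv_cache']  # keep the updated cache for the next step

# Decode: one token at a time, feeding the cache back in.
next_token = torch.argmax(out['logits'][:, -1, :], dim=-1, keepdim=True)
out = model.forward(next_token, torch.tensor([16]), kv)
kv = out['kv_cache']
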
@@ -78,7 +94,10 @@ def _export_stablehlo_mlir(model, args):
 
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
-      num_heads=32, head_dim=4, num_query_groups=4, rotary_percentage=1.0
+      num_heads=32,
+      head_dim=4,
+      num_query_groups=4,
+      rotary_percentage=1.0,
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -86,15 +105,18 @@ def get_model_config() -> cfg.ModelConfig:
       intermediate_size=256,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
-  config = cfg.ModelConfig(
-      vocab_size=150,
-      num_layers=2,
-      max_seq_len=500,
-      embedding_dim=128,
+  block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=150,
+      num_layers=2,
+      max_seq_len=100,
+      embedding_dim=128,
+      block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
   )
@@ -102,40 +124,59 @@ def get_model_config() -> cfg.ModelConfig:
 
 
 def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.unsqueeze(torch.arange(0, 100), 0)
+  tokens = torch.unsqueeze(torch.arange(0, 100), 0)
   input_pos = torch.arange(0, 100)
-  return idx, input_pos
+  return tokens, input_pos
 
 
 def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
-  idx = torch.tensor([[1]], dtype=torch.long)
-  input_pos = torch.tensor([10], dtype=torch.int64)
-  return idx, input_pos
+  tokens = torch.tensor([[1]], dtype=torch.long)
+  input_pos = torch.tensor([10])
+  return tokens, input_pos
 
 
 def define_and_run() -> None:
   dump_mlir = False
 
   config = get_model_config()
-  model = ToyModelWithKV(config)
+  model = ToyModelWithExternalKV(config)
+  model.eval()
   print('running an inference')
-  idx, input_pos = get_sample_prefill_inputs()
-  decode_idx, decode_input_pos = get_sample_decode_inputs()
-  print(model.forward(idx, input_pos))
+  kv = kv_utils.KVCache.from_model_config(config)
+
+  tokens, input_pos = get_sample_prefill_inputs()
+  decode_token, decode_input_pos = get_sample_decode_inputs()
+  print(model.forward(tokens, input_pos, kv))
 
   if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (idx, input_pos))
-    with open('/tmp/toy_model_with_kv.stablehlo.mlir', 'w') as f:
+    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
       f.write(mlir_text)
 
   # Convert model to tflite with 2 signatures (prefill + decode).
   print('converting toy model to tflite with 2 signatures (prefill + decode)')
   edge_model = (
-      ai_edge_torch.signature('prefill', model, (idx, input_pos))
-      .signature('decode', model, (decode_idx, decode_input_pos))
+      ai_edge_torch.signature(
+          'prefill',
+          model,
+          sample_kwargs={
+              'tokens': tokens,
+              'input_pos': input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
      .convert()
   )
-  edge_model.export('/tmp/toy_kv_cache.tflite')
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')
 
 
 if __name__ == '__main__':
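
Both signatures now take keyword samples (sample_kwargs) instead of positional tuples, so the exported TFLite model exposes named inputs. A hypothetical smoke test of the exported file, assuming TensorFlow is installed; the flattened names of the KV-cache inputs are an assumption, so the sketch prints the signatures' input details instead of hard-coding them.

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='/tmp/toy_external_kv_cache.tflite')
prefill_runner = interpreter.get_signature_runner('prefill')
decode_runner = interpreter.get_signature_runner('decode')

# Discover the named inputs ('tokens', 'input_pos', plus the flattened
# KV-cache tensors) rather than assuming their exact names.
print(prefill_runner.get_input_details().keys())
print(decode_runner.get_input_details().keys())
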
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -13,11 +13,14 @@
 # limitations under the License.
 # ==============================================================================
 
+"""Example of converting TinyLlama model to multi-signature tflite model."""
+
 import os
-from pathlib import Path
+import pathlib
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
@@ -48,20 +51,36 @@ def convert_tiny_llama_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+      f'/tmp/tiny_llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
-  convert_tiny_llama_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')
+  convert_tiny_llama_to_tflite(path)
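
For reference, a hedged usage sketch of the converter above; the keyword parameters are inferred from the function body and the export filename (the full signature falls outside this hunk), and the checkpoint path is the same placeholder the script itself uses.

import os
import pathlib

from ai_edge_torch.generative.examples.tiny_llama import convert_to_tflite

checkpoint = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')
# With quantize=True, this should write
# /tmp/tiny_llama_q8_seq512_ekv1024.tflite per the naming scheme above.
convert_to_tflite.convert_tiny_llama_to_tflite(
    checkpoint, prefill_seq_len=512, kv_cache_max_len=1024, quantize=True
)
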
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Example of building a TinyLlama model from the Edge Generative API layers.
+
+"""Example of building a TinyLlama model."""
 
 import os
-from pathlib import Path
+import pathlib
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -42,13 +44,12 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class TinyLLamma(nn.Module):
+class TinyLlama(nn.Module):
   """A TinyLlama model built from the Edge Generative API layers."""
 
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
 
-    self.config = config
     # Construct model layers.
     self.lm_head = nn.Linear(
         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
@@ -56,18 +57,20 @@ class TinyLLamma(nn.Module):
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
     )
+    # TinyLlama has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    attn_config = block_config.attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -80,16 +83,22 @@ class TinyLLamma(nn.Module):
     )
     self.config = config
 
-  # The model's forward function takes in additional k/v cache tensors
-  # and returns the updated k/v cache tensors to the caller.
-  # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    _, seq_len = idx.size()
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -97,16 +106,20 @@ class TinyLLamma(nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
-    # forward the model itself
-    x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
 
-    for _, block in enumerate(self.transformer_blocks):
-      x = block(x, (cos, sin), mask, input_pos)
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
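
Since forward now returns a {"logits", "kv_cache"} dict, autoregressive decoding threads the returned cache into the next step. A greedy-decode sketch against that contract, assuming a built TinyLlama model; the checkpoint path is a placeholder.

import torch
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import kv_cache as kv_utils

model = tiny_llama.build_model("/path/to/tiny_llama", kv_cache_max_len=1024)
kv = kv_utils.KVCache.from_model_config(model.config)

# Prefill the prompt, then decode greedily one token at a time.
prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
out = model.forward(prompt, torch.arange(0, prompt.shape[1]), kv)
kv = out["kv_cache"]

pos = prompt.shape[1]
token = torch.argmax(out["logits"][:, -1, :], dim=-1, keepdim=True)
for _ in range(8):
  out = model.forward(token, torch.tensor([pos]), kv)
  kv = out["kv_cache"]
  token = torch.argmax(out["logits"][:, -1, :], dim=-1, keepdim=True)
  pos += 1
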
@@ -131,55 +144,63 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=5632,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
   config = cfg.ModelConfig(
       vocab_size=32000,
       num_layers=22,
       max_seq_len=2048,
       embedding_dim=2048,
       kv_cache_max_len=kv_cache_max_len,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=block_config,
      final_norm_config=norm_config,
       enable_hlfb=True,
   )
   return config
 
 
-def get_fake_model_config() -> cfg.ModelConfig:
-  config = get_model_config()
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
   config.vocab_size = 128
   config.num_layers = 2
-  config.ff_config.intermediate_size = 64
+  # TinyLlama has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
   return config
 
 
 def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config(**kwargs)
-  model = TinyLLamma(config)
+  model = TinyLlama(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   loader.load(model)
+  model.eval()
   return model
 
 
-def define_and_run() -> None:
+def define_and_run(checkpoint_path: str) -> None:
   """Instantiates and runs a TinyLlama model."""
 
-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/tiny_llama")
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-  lm_logits = model.forward(tokens, input_pos)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
   assert torch.allclose(
-      tiny_llama_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+      tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
   )
 
 
 if __name__ == "__main__":
-  define_and_run()
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
+  )
+  define_and_run(input_checkpoint_path)
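
The new **kwargs pass-through on get_fake_model_config makes the shrunken config usable in tests that also need a custom kv_cache_max_len. A small sanity-check sketch, assuming the helpers defined above; the model runs with random weights, which is enough to exercise shapes.

import torch
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import kv_cache as kv_utils

config = tiny_llama.get_fake_model_config(kv_cache_max_len=128)
model = tiny_llama.TinyLlama(config)
model.eval()

tokens = torch.zeros((1, 8), dtype=torch.long)
input_pos = torch.arange(0, 8)
kv = kv_utils.KVCache.from_model_config(config)
out = model.forward(tokens, input_pos, kv)
print(out["logits"].shape)  # torch.Size([1, 8, 128]); the fake vocab_size is 128
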