ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
Files changed (68)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
  4. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
  5. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
  6. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
  7. ai_edge_torch/config.py +4 -1
  8. ai_edge_torch/fx_pass_base.py +101 -0
  9. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
  10. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
  11. ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
  13. ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
  14. ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
  15. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
  16. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
  17. ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
  18. ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
  19. ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
  20. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
  21. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
  22. ai_edge_torch/generative/examples/t5/t5.py +43 -30
  23. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  24. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  25. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
  26. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
  27. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
  28. ai_edge_torch/generative/fx_passes/__init__.py +4 -4
  29. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
  30. ai_edge_torch/generative/layers/attention.py +84 -73
  31. ai_edge_torch/generative/layers/builder.py +38 -14
  32. ai_edge_torch/generative/layers/feed_forward.py +26 -8
  33. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  34. ai_edge_torch/generative/layers/model_config.py +61 -33
  35. ai_edge_torch/generative/layers/normalization.py +158 -0
  36. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  37. ai_edge_torch/generative/quantize/example.py +2 -2
  38. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  39. ai_edge_torch/generative/test/test_loader.py +1 -1
  40. ai_edge_torch/generative/test/test_model_conversion.py +77 -62
  41. ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
  42. ai_edge_torch/generative/test/test_quantize.py +5 -5
  43. ai_edge_torch/generative/test/utils.py +54 -0
  44. ai_edge_torch/generative/utilities/loader.py +28 -15
  45. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  46. ai_edge_torch/odml_torch/export.py +40 -0
  47. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  48. ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
  49. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
  50. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  51. ai_edge_torch/version.py +1 -1
  52. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
  53. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
  54. ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
  55. ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
  56. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  57. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  58. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  59. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  60. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  61. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  62. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  63. /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
  64. /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
  65. /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
  66. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
  67. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
  68. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py
@@ -12,26 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Example of building phi-2 model from the Edge Generative API layers.
-#
-# Note: This is an experimental version of phi2 with external KV cache.
-# Please use with caution.
+
+"""Example of building a Phi-2 model."""
 
 import os
-from pathlib import Path
-from typing import Tuple
+import pathlib
 
+from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import attention
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
 from torch import nn
 
-
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
     ff_down_proj="model.layers.{}.mlp.fc2",
@@ -52,7 +48,6 @@ class Phi2(nn.Module):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
 
-    self.config = config
     # Construct model layers.
     self.lm_head = nn.Linear(
         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
@@ -60,18 +55,20 @@ class Phi2(nn.Module):
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
     )
+    # Phi-2 has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    attn_config = block_config.attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -89,13 +86,17 @@ class Phi2(nn.Module):
       self,
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
-      kv_cache: kv_utils.EKVCache,
-  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
        f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -111,11 +112,11 @@ class Phi2(nn.Module):
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res, updated_kv_cache
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -143,17 +144,20 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      parallel_residual=True,
+  )
   config = cfg.ModelConfig(
       vocab_size=51200,
       num_layers=32,
       max_seq_len=2048,
       kv_cache_max_len=kv_cache_max_len,
       embedding_dim=2560,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
-      parallel_residual=True,
       lm_head_use_bias=True,
       enable_hlfb=True,
   )
@@ -165,43 +169,42 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
-  config.ff_config.intermediate_size = 128
+  # Phi-2 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
   return config
 
 
-def build_model(
-    checkpoint_path: str, test_model: bool = False, **kwargs
-) -> nn.Module:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   """Instantiates the model instance and load checkpoint if provided."""
-  config = (
-      get_fake_model_config(**kwargs)
-      if test_model
-      else get_model_config(**kwargs)
-  )
+  config = get_model_config(**kwargs)
   model = Phi2(config)
-  if checkpoint_path is not None:
-    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-    loader.load(model)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(model)
   model.eval()
   return model
 
 
-def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
+def define_and_run(checkpoint_path: str) -> None:
   """Instantiates and runs a Phi-2 model."""
 
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
   kv_cache_max_len = 1024
-  model = build_model(
-      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-  )
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
-  kv = kv_utils.EKVCache.from_model_config(model.config)
-  print("running an inference")
-  print(model.forward(tokens, input_pos, kv))
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  print("comparing with goldens..")
+  assert torch.allclose(
+      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
+  )
 
 
 if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(Path.home(), "Downloads/phi2")
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/phi2"
+  )
   define_and_run(input_checkpoint_path)
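
Taken together, these hunks change the example's calling convention: forward now takes an explicit kv_utils.KVCache and returns a {"logits", "kv_cache"} dict instead of a (logits, cache) tuple. A minimal prefill-plus-decode sketch against that interface; the checkpoint path is hypothetical, and everything else follows the signatures shown in the hunks above:

import torch
from ai_edge_torch.generative.examples.phi import phi2
from ai_edge_torch.generative.layers import kv_cache as kv_utils

model = phi2.build_model("/path/to/phi2_checkpoint", kv_cache_max_len=1024)
kv = kv_utils.KVCache.from_model_config(model.config)

# Prefill the prompt and keep the functionally-updated cache.
prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.int)
out = model.forward(prompt, torch.arange(0, 4, dtype=torch.int), kv)
kv = out["kv_cache"]

# One decode step: greedy-pick the next token, thread the cache through.
token = out["logits"][0, -1, :].argmax().reshape(1, 1).int()
out = model.forward(token, torch.tensor([4], dtype=torch.int), kv)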
ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py
@@ -12,28 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-#
-# Note: This is an experimental version of TinyLlama with external KV cache.
-# Please use with caution.
 
+"""Example of converting SmolLM model to multi-signature tflite model."""
 
 import os
-from pathlib import Path
+import pathlib
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
 
-def convert_tiny_llama_to_tflite(
+def convert_smollm_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """An example for converting TinyLlama model to multi-signature tflite model.
+  """Converts SmolLM model to multi-signature tflite model.
 
   Args:
     checkpoint_path (str): The filepath to the model checkpoint, or directory
@@ -45,15 +43,15 @@ def convert_tiny_llama_to_tflite(
     quantize (bool, optional): Whether the model should be quanized. Defaults
       to True.
   """
-  pytorch_model = tiny_llama.build_model(
+  pytorch_model = smollm.build_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
   )
   # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
@@ -77,11 +75,12 @@ def convert_tiny_llama_to_tflite(
       )
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/tiny_llama_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+      f'/tmp/smollm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
  )
 
 
 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
-  convert_tiny_llama_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm')
+  convert_smollm_to_tflite(path)
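
The converter now also embeds the quantization mode in the output filename. A short usage sketch against the signature above; the checkpoint directory is hypothetical:

from ai_edge_torch.generative.examples.smollm import convert_to_tflite

# Shorter signatures, no quantization; per the export call above this
# writes /tmp/smollm_f32_seq256_ekv512.tflite.
convert_to_tflite.convert_smollm_to_tflite(
    "/path/to/llm_data/smollm",
    prefill_seq_len=256,
    kv_cache_max_len=512,
    quantize=False,
)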
ai_edge_torch/generative/examples/smollm/smollm.py
@@ -0,0 +1,131 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a SmolLM model."""
+
+import copy
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+from torch import nn
+
+TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
+# SmolLM re-uses the embedding as the head projection layer.
+TENSOR_NAMES.lm_head = None
+
+
+class SmolLM(tiny_llama.TinyLlama):
+  """A SmolLM model built from the Edge Generative API layers.
+
+  SmolLM shares the same architecture as TinyLlama, but with different model
+  sizes.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__(config)
+    # SmolLM re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmolLM 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=9,
+      head_dim=64,
+      num_query_groups=3,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=1536,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=49152,
+      num_layers=30,
+      max_seq_len=2048,
+      embedding_dim=576,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # SmolLM has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = SmolLM(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run(checkpoint_path: str) -> None:
+  """Instantiates and runs a SmolLM model."""
+
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
+  kv_cache_max_len = 1024
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  assert torch.allclose(
+      smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+  )
+
+
+if __name__ == "__main__":
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/smollm"
+  )
+  define_and_run(input_checkpoint_path)
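
Because SmolLM.__init__ rebinds lm_head.weight.data to the embedding table, the two modules end up sharing one tensor, which is why loader.load runs with strict=False. A quick sketch of that tie, using the fake config so no checkpoint is needed; it assumes TinyLlama.__init__ constructs lm_head and tok_embedding the same way as the other examples in this release:

from ai_edge_torch.generative.examples.smollm import smollm

model = smollm.SmolLM(smollm.get_fake_model_config())
# Both modules should point at the same underlying storage.
assert (
    model.lm_head.weight.data_ptr() == model.tok_embedding.weight.data_ptr()
)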
ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -61,8 +61,10 @@ class CLIP(nn.Module):
     )
 
     self.config = config
+    # CLIP has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        TransformerBlock(block_config, config) for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim, config.final_norm_config
@@ -74,7 +76,7 @@ class CLIP(nn.Module):
 
   @torch.inference_mode
   def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
-    tokens = tokens.type(torch.long)
+    tokens = tokens.type(torch.int)
 
     state = self.tok_embedding(tokens) + self.tok_embedding_position
     for layer in self.transformer_blocks:
@@ -112,15 +114,19 @@ def get_model_config() -> cfg.ModelConfig:
 
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
 
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=vocab_size,
       num_layers=num_layers,
       max_seq_len=max_seq_len,
       embedding_dim=embedding_dim,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=block_config,
      final_norm_config=norm_config,
       enable_hlfb=True,
   )
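
This is the same refactor seen in phi2.py and smollm.py above: per-layer knobs (attention, feed-forward, norms, residual wiring) move out of ModelConfig into cfg.TransformerBlockConfig, and models read them back through config.block_config(i). A minimal sketch of the pattern, assuming that a single TransformerBlockConfig passed as block_configs is shared by every layer index:

import ai_edge_torch.generative.layers.model_config as cfg

norm = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
block = cfg.TransformerBlockConfig(
    attn_config=cfg.AttentionConfig(
        num_heads=8, head_dim=64, num_query_groups=8, rotary_percentage=1.0
    ),
    ff_config=cfg.FeedForwardConfig(
        type=cfg.FeedForwardType.GATED,
        activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
        intermediate_size=1024,
    ),
    pre_attention_norm_config=norm,
)
config = cfg.ModelConfig(
    vocab_size=32000,
    num_layers=4,
    max_seq_len=512,
    embedding_dim=512,
    block_configs=block,
    final_norm_config=norm,
)
# Assumption: every layer resolves to the single shared block config.
assert config.block_config(0) is config.block_config(3)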
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
@@ -94,7 +94,7 @@ def convert_stable_diffusion_to_tflite(
   n_tokens = 77
   timestamp = 0
   len_prompt = 1
-  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
+  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.int)
   input_image = torch.full(
       (1, 3, image_height, image_width), 0, dtype=torch.float32
   )
ai_edge_torch/generative/examples/t5/convert_to_tflite.py
@@ -29,24 +29,24 @@ def convert_t5_to_tflite_singlesig(checkpoint_path: str):
 
   # encoder
   seq_len = 512
-  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_e_token = [1, 2, 3, 4, 5, 6]
   prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-      prompt_e_token, dtype=torch.long
+      prompt_e_token, dtype=torch.int
   )
-  prefill_e_input_pos = torch.arange(0, seq_len)
-  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_d_token = [1, 2, 3, 4, 5, 6]
   prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-      prompt_d_token, dtype=torch.long
+      prompt_d_token, dtype=torch.int
   )
-  prefill_d_input_pos = torch.arange(0, seq_len)
+  prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)
 
   # decoder
-  decode_token = torch.tensor([[1]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  decode_d_token = torch.tensor([[1]], dtype=torch.long)
-  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+  decode_token = torch.tensor([[1]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  decode_d_token = torch.tensor([[1]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
 
   # Pad mask for self attention only on "real" tokens.
   # Pad with `-inf` for any tokens indices that aren't desired.
@@ -81,24 +81,24 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):
 
   # encoder
   seq_len = 512
-  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_e_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_e_token = [1, 2, 3, 4, 5, 6]
   prefill_e_tokens[0, : len(prompt_e_token)] = torch.tensor(
-      prompt_e_token, dtype=torch.long
+      prompt_e_token, dtype=torch.int
   )
-  prefill_e_input_pos = torch.arange(0, seq_len)
-  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.long)
+  prefill_e_input_pos = torch.arange(0, seq_len, dtype=torch.int)
+  prefill_d_tokens = torch.full((1, seq_len), 0, dtype=torch.int)
   prompt_d_token = [1, 2, 3, 4, 5, 6]
   prefill_d_tokens[0, : len(prompt_d_token)] = torch.tensor(
-      prompt_d_token, dtype=torch.long
+      prompt_d_token, dtype=torch.int
   )
-  prefill_d_input_pos = torch.arange(0, seq_len)
+  prefill_d_input_pos = torch.arange(0, seq_len, dtype=torch.int)
 
   # decoder
-  decode_token = torch.tensor([[1]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  decode_d_token = torch.tensor([[1]], dtype=torch.long)
-  decode_d_input_pos = torch.tensor([0], dtype=torch.int64)
+  decode_token = torch.tensor([[1]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  decode_d_token = torch.tensor([[1]], dtype=torch.int)
+  decode_d_input_pos = torch.tensor([0], dtype=torch.int)
 
   # Pad mask for self attention only on "real" tokens.
   # Pad with `-inf` for any tokens indices that aren't desired.
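
The recurring torch.long to torch.int change across the stable_diffusion and t5 converters (and the examples above) standardizes the trace tensors on int32, so the exported TFLite signatures take int32 token and position inputs rather than int64. A one-line check of the dtype mapping:

import torch

tokens = torch.full((1, 512), 0, dtype=torch.int)
input_pos = torch.arange(0, 512, dtype=torch.int)
# torch.int is an alias of torch.int32, which corresponds to TFLite's INT32.
assert tokens.dtype == torch.int32 and input_pos.dtype == torch.int32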