ai-edge-torch-nightly 0.3.0.dev20240926__py3-none-any.whl → 0.3.0.dev20240929__py3-none-any.whl

Files changed (40)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -8
  3. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +0 -1
  4. ai_edge_torch/generative/examples/gemma/verify_util.py +13 -24
  5. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/llama.py +204 -0
  9. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  10. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  11. ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
  12. ai_edge_torch/generative/examples/openelm/verify.py +14 -7
  13. ai_edge_torch/generative/examples/phi/phi2.py +2 -6
  14. ai_edge_torch/generative/examples/phi/phi3.py +17 -24
  15. ai_edge_torch/generative/examples/phi/verify.py +8 -9
  16. ai_edge_torch/generative/examples/phi/verify_phi3.py +8 -9
  17. ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
  18. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +81 -0
  19. ai_edge_torch/generative/examples/qwen/qwen.py +141 -0
  20. ai_edge_torch/generative/examples/qwen/verify.py +88 -0
  21. ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
  22. ai_edge_torch/generative/examples/smollm/verify.py +14 -6
  23. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  24. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +2 -0
  25. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +2 -0
  26. ai_edge_torch/generative/examples/t5/t5.py +0 -2
  27. ai_edge_torch/generative/examples/test_models/toy_model.py +5 -10
  28. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -5
  29. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
  30. ai_edge_torch/generative/examples/tiny_llama/verify.py +15 -7
  31. ai_edge_torch/generative/layers/model_config.py +2 -0
  32. ai_edge_torch/generative/test/test_model_conversion_large.py +20 -0
  33. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  34. ai_edge_torch/generative/utilities/verifier.py +117 -97
  35. ai_edge_torch/version.py +1 -1
  36. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/METADATA +1 -1
  37. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/RECORD +40 -29
  38. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/LICENSE +0 -0
  39. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/WHEEL +0 -0
  40. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/top_level.txt +0 -0
@@ -69,15 +69,10 @@ class Gemma(nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -135,6 +130,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=8,
       head_dim=256,
       num_query_groups=1,
+      rotary_base=10000,
       rotary_percentage=1.0,
   )
   ff_config = cfg.FeedForwardConfig(
@@ -109,21 +109,14 @@ class Gemma2(nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -208,6 +201,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=8,
       head_dim=256,
       num_query_groups=4,
+      rotary_base=10000,
       rotary_percentage=1.0,
       qkv_transpose_before_split=True,
       logit_softcap=50.0,
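
The Gemma hunks above all follow one pattern: the RoPE base moves out of hard-coded call sites (base=10_000) and into the new AttentionConfig.rotary_base field (see layers/model_config.py in the file list), while the dtype, device, and condense_ratio arguments disappear, presumably because attn_utils now supplies defaults for them. A minimal sketch of the resulting call shape, assuming the import paths used elsewhere in these examples; the numeric values simply mirror the 2B config in the hunks:

# Sketch only; not taken verbatim from any single file in this release.
from ai_edge_torch.generative.layers import attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg

attn_config = cfg.AttentionConfig(
    num_heads=8,
    head_dim=256,
    num_query_groups=1,
    rotary_base=10000,  # previously hard-coded as base=10_000 at the call site
    rotary_percentage=1.0,
)
rope_cache = attn_utils.build_rope_cache(
    size=1024,  # config.kv_cache_max in the models above
    dim=int(attn_config.rotary_percentage * attn_config.head_dim),
    base=attn_config.rotary_base,
)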
@@ -20,7 +20,6 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.gemma import verify_util
-from ai_edge_torch.generative.utilities import verifier
 import kagglehub
 
 
@@ -15,7 +15,6 @@
 
 """Utility functions to verify the reauthored Gemma model."""
 
-import dataclasses
 import logging
 import os
 from typing import List, Tuple
@@ -27,26 +26,17 @@ from gemma import model as gemma_model
 import torch
 
 
-@dataclasses.dataclass
-class _Output:
-  logits: torch.Tensor
-
-
 class GemmaWrapper(verifier.ModelWrapper):
   """Gemma model wrapper for verification.
 
   Verifier calls model.forward() with maxium sequence length (1024) expecting
-  the output has 'logits' field while Gemma gets the input tokens with the
-  actual length and returns logits in a tuple.
+  the output is logits while Gemma gets the input tokens with the actual length
+  and returns logits in a tuple.
 
   Verifier runs tokenizer before model.generate() while Gemma runs the tokenizer
   inside model.generate().
   """
 
-  def __init__(self, model: torch.nn.Module, max_new_tokens: int):
-    super().__init__(model)
-    self.max_new_tokens = max_new_tokens
-
   def _get_actual_input_len(self, tokens: torch.Tensor) -> int:
     for i in range(tokens.shape[1]):
       if tokens[0, i] == 0:
@@ -63,7 +53,7 @@ class GemmaWrapper(verifier.ModelWrapper):
         (cache.clone(), cache.clone()) for _ in range(config.num_hidden_layers)
     ]
 
-  def forward(self, tokens: torch.Tensor) -> _Output:
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
     """Forwards the model after reducing input tokens to the actual length."""
     actual_input_len = self._get_actual_input_len(tokens)
     input_pos = torch.arange(0, actual_input_len, dtype=torch.long)
@@ -79,28 +69,26 @@ class GemmaWrapper(verifier.ModelWrapper):
         top_ps=torch.tensor([1.0], dtype=torch.float),
         top_ks=torch.tensor([1], dtype=torch.long),
     )
-    return _Output(logits.float())
+    return logits
 
-  def generate(self, tokens: torch.Tensor) -> torch.Tensor:
+  def generate(
+      self, tokens: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
     """Generates the response after decoding the tokens into a string."""
     prompts = self.model.tokenizer.decode(tokens[0].tolist())
     response = self.model.generate(
-        prompts, device="cpu", output_len=self.max_new_tokens, top_k=1
+        prompts, device="cpu", output_len=max_new_tokens, top_k=1
     )
     return torch.tensor([self.model.tokenizer.encode(prompts + response)])
 
 
-class TokenizerWrapper(torch.nn.Module):
+class GemmaTokenizerWrapper(verifier.TokenizerWrapper):
   """Tokenizer wrapper for verification.
 
   Verifier expects the tokenizer to handle tokens in torch.Tensor while Gemma
   tokenizer expects tokens in a list.
   """
 
-  def __init__(self, tokenizer: torch.nn.Module):
-    super().__init__()
-    self.tokenizer = tokenizer
-
   def encode(self, text: str, **_) -> torch.Tensor:
     """Adds one more dimension to the output of the tokenizer."""
     return torch.tensor([self.tokenizer.encode(text)])
@@ -133,10 +121,11 @@ def verify_reauthored_gemma_model(
   original_model.load_weights(os.path.join(checkpoint, weight_filename))
 
   verifier.verify_reauthored_model(
-      original_model=GemmaWrapper(original_model, max_new_tokens),
-      reauthored_model=reauthored_model,
-      tokenizer=TokenizerWrapper(original_model.tokenizer),
+      original_model=GemmaWrapper(original_model),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=GemmaTokenizerWrapper(original_model.tokenizer),
       generate_prompts=generate_prompts,
+      max_new_tokens=max_new_tokens,
       forward_input_ids=forward_input_ids,
       rtol=rtol,
       atol=atol,
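
These verify_util.py changes track the wider verifier refactor in this release (utilities/verifier.py and the new utilities/transformers_verifier.py in the file list): max_new_tokens is no longer stored on the model wrapper but passed to verify_reauthored_model(), the reauthored model is wrapped in verifier.ReauthoredModelWrapper, and tokenizer wrappers subclass verifier.TokenizerWrapper. A sketch of the resulting call, using only names visible in this diff; the prompt, token count, input ids, and tolerances are placeholder values, not defaults taken from the code:

# Hypothetical invocation assembled from the hunk above; GemmaWrapper and
# GemmaTokenizerWrapper are the wrappers defined in verify_util.py.
verifier.verify_reauthored_model(
    original_model=GemmaWrapper(original_model),
    reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
    tokenizer=GemmaTokenizerWrapper(original_model.tokenizer),
    generate_prompts=["What is the meaning of life?"],
    max_new_tokens=30,  # now a per-call argument instead of wrapper state
    forward_input_ids=[[1, 2, 3, 4]],  # placeholder token ids
    rtol=1e-05,
    atol=1e-05,
)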
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
@@ -0,0 +1,68 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting Llama 3.2 3B model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.llama import llama
+from ai_edge_torch.generative.utilities import converter
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = llama.build_3b_model(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'llama_3b_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
@@ -0,0 +1,68 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting Llama 3.2 1B model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.llama import llama
+from ai_edge_torch.generative.utilities import converter
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = llama.build_model(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'llama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
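
The two conversion scripts are identical except for the builder they call (llama.build_model vs. llama.build_3b_model) and the filename prefix. For reference, a quick standalone illustration of the output naming scheme with the default flag values above (not output captured from an actual conversion run):

# Illustration only: reproduces the f-string from the scripts with default flags.
quantize, prefill_seq_len, kv_cache_max_len = True, 1024, 1280
quant_suffix = 'q8' if quantize else 'f32'
print(f'llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite')
# -> llama_q8_seq1024_ekv1280.tflite (the 3B script uses the 'llama_3b_' prefix instead)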
@@ -0,0 +1,204 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building Llama 3.2 models."""
+
+import copy
+import math
+from typing import Tuple
+
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
+# SmolLM re-uses the embedding as the head projection layer.
+TENSOR_NAMES.lm_head = None
+
+
+def _build_llama3_rope_cache(
+    size: int,
+    dim: int,
+    base: int,
+    condense_ratio: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    factor: float,
+    low_freq_factor: float,
+    high_freq_factor: float,
+    max_seq_len: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Precomputes Rotary Positional Embeddings for Llama 3.2 model.
+
+  It's a modified version of attn_utils.build_rope_cache with additional
+  arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
+  and Cos values with scaling factors for quick lookup during the inference.
+
+  Reference:
+  https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307
+
+  Args:
+    size (int): The size of the built cache.
+    dim (int): Each sequence's dimmension.
+    base (int, optional): Rope base value.
+    condense_ratio (int, optional): The ratio by which sequence indicies are
+      condensed.
+    dtype (torch.dtype, optional): Output tensor's data type.
+    device (torch.device, optional): Output tensor's data type.
+    factor (float): Factor to scale theta down for tokens in long range in the
+      sequence.
+    low_freq_factor (float): Factor to determine if tokens are in long range
+      in the sequence.
+    high_freq_factor (float): Factor to determine if tokens are in short range
+      in the sequence.
+    max_seq_len (int): The original token sequence length before extending
+      ROPE to support longer sequence.
+
+  Returns:
+    Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
+  """
+  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  low_freq_wavelen = max_seq_len / low_freq_factor
+  high_freq_wavelen = max_seq_len / high_freq_factor
+  wavelen = 2 * math.pi / theta
+  # wavelen < high_freq_wavelen: do nothing
+  # wavelen > low_freq_wavelen: divide by factor
+  theta = torch.where(wavelen > low_freq_wavelen, theta / factor, theta)
+  # otherwise: interpolate between the two, using a smooth factor
+  smooth_factor = (max_seq_len / wavelen - low_freq_factor) / (
+      high_freq_factor - low_freq_factor
+  )
+  smoothed_theta = (1 - smooth_factor) * theta / factor + smooth_factor * theta
+  is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+  theta = torch.where(is_medium, smoothed_theta, theta)
+
+  seq_idx = torch.arange(size) / condense_ratio
+  idx_theta = torch.outer(seq_idx, theta)
+  cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
+  sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
+  return cos, sin
+
+
+class Llama(tiny_llama.TinyLlama):
+  """A Llama model built from the Edge Generative API layers.
+
+  Llama 3.2 shares the same architecture as TinyLlama except ROPE calculation.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__(config)
+    # Llama 3.2 re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    # Llama has only one block config.
+    attn_config = self.config.block_config(0).attn_config
+    self.rope_cache = _build_llama3_rope_cache(
+        size=self.config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+        factor=32.0,
+        low_freq_factor=1.0,
+        high_freq_factor=4.0,
+        max_seq_len=self.config.max_seq_len,
+    )
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Llama 3.2-1B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=32,
+      head_dim=64,
+      num_query_groups=8,
+      rotary_base=500000,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=8192,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=128256,
+      num_layers=16,
+      max_seq_len=8192,
+      embedding_dim=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Llama 3.2-3B model."""
+  config = get_model_config(kv_cache_max_len)
+  # Llama 3.2 has only one block config.
+  attn_config = config.block_config(0).attn_config
+  attn_config.num_heads = 24
+  attn_config.head_dim = 128
+  config.num_layers = 28
+  config.embedding_dim = 3072
+  return config
+
+
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # SmolLM has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = Llama(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_3b_model_config(**kwargs)
+  model = Llama(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # Since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
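
The core of llama.py is _build_llama3_rope_cache, which rescales the RoPE inverse frequencies the way Llama 3.1/3.2 does: components with wavelengths shorter than max_seq_len / high_freq_factor are left untouched, those longer than max_seq_len / low_freq_factor are divided by factor, and everything in between is smoothly interpolated. The standalone sketch below applies that same three-branch rule to a few hand-picked inverse frequencies so the effect is easy to see; the sample theta values are illustrative only, and the constants match the call in Llama.__init__ above:

import math
import torch

factor, low_freq_factor, high_freq_factor, max_seq_len = 32.0, 1.0, 4.0, 8192
theta = torch.tensor([1.0, 2e-3, 1e-4])  # short, medium, long wavelength examples
wavelen = 2 * math.pi / theta
low_freq_wavelen = max_seq_len / low_freq_factor    # 8192.0
high_freq_wavelen = max_seq_len / high_freq_factor  # 2048.0
# Branch 1: long wavelengths get divided by `factor`; short ones pass through.
scaled = torch.where(wavelen > low_freq_wavelen, theta / factor, theta)
# Branch 2: medium wavelengths interpolate between scaled and unscaled values.
smooth = (max_seq_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed = (1 - smooth) * theta / factor + smooth * theta
is_medium = ~(wavelen < high_freq_wavelen) & ~(wavelen > low_freq_wavelen)
scaled = torch.where(is_medium, smoothed, scaled)
print(scaled)  # first entry unchanged, last divided by 32, middle interpolated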
@@ -0,0 +1,73 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Llama 3.2-1B model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.llama import llama
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "meta-llama/Llama-3.2-1B-Instruct"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = llama.build_model(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
+  # "PreTrainedTokenizerFast". It works only when the fast tokenizer is
+  # available.
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
@@ -0,0 +1,73 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Llama 3.2-3B model."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.llama import llama
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "meta-llama/Llama-3.2-3B-Instruct"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = llama.build_3b_model(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
+  # "PreTrainedTokenizerFast". It works only when the fast tokenizer is
+  # available.
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
@@ -68,15 +68,10 @@ class OpenELM(nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -154,6 +149,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
         num_heads=num_heads[idx],
         head_dim=128,
         num_query_groups=num_query_groups[idx],
+        rotary_base=10000,
         rotary_percentage=1.0,
         qkv_transpose_before_split=True,
         query_norm_config=norm_config,