ai-edge-torch-nightly 0.3.0.dev20240926__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -8
  3. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +0 -1
  4. ai_edge_torch/generative/examples/gemma/verify_util.py +13 -24
  5. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/llama.py +204 -0
  9. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  10. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  11. ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
  12. ai_edge_torch/generative/examples/openelm/verify.py +14 -7
  13. ai_edge_torch/generative/examples/phi/phi2.py +2 -6
  14. ai_edge_torch/generative/examples/phi/phi3.py +17 -24
  15. ai_edge_torch/generative/examples/phi/verify.py +8 -9
  16. ai_edge_torch/generative/examples/phi/verify_phi3.py +8 -9
  17. ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
  18. ai_edge_torch/generative/examples/smollm/verify.py +14 -6
  19. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  20. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +2 -0
  21. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +2 -0
  22. ai_edge_torch/generative/examples/t5/t5.py +0 -2
  23. ai_edge_torch/generative/examples/test_models/toy_model.py +5 -10
  24. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -5
  25. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
  26. ai_edge_torch/generative/examples/tiny_llama/verify.py +15 -7
  27. ai_edge_torch/generative/layers/model_config.py +2 -0
  28. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  29. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  30. ai_edge_torch/generative/utilities/verifier.py +117 -97
  31. ai_edge_torch/version.py +1 -1
  32. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
  33. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +36 -29
  34. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
  35. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
  36. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/openelm/verify.py

@@ -20,6 +20,7 @@ import pathlib
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
@@ -29,15 +30,18 @@ _PROMPTS = flags.DEFINE_multi_string(
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "apple/OpenELM-3B"
   logging.info("Loading the original model from: %s", checkpoint)
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(
-          checkpoint, trust_remote_code=True
-      ),
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )
 
   # Locate the cached dir.
@@ -53,10 +57,13 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
   )
 
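The verify scripts touched in this diff all converge on the same pattern: load the original HuggingFace model, build the reauthored AI Edge model, then hand both to verify_reauthored_model behind explicit wrapper types. A condensed sketch of that pattern, pieced together from the phi2 verify script shown further below (flag plumbing omitted); treat it as illustrative rather than a copy of any one script:

    import kagglehub
    import transformers
    from ai_edge_torch.generative.examples.phi import phi2
    from ai_edge_torch.generative.utilities import transformers_verifier
    from ai_edge_torch.generative.utilities import verifier

    # Original checkpoint and the reauthored AI Edge model built from it.
    checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
    original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
    reauthored_model = phi2.build_model(checkpoint)
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

    # Each piece is wrapped so the verifier sees one uniform interface.
    verifier.verify_reauthored_model(
        original_model=transformers_verifier.TransformersModelWrapper(original_model),
        reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
        tokenizer=verifier.TokenizerWrapper(tokenizer),
        generate_prompts=["What is the meaning of life?"],
        max_new_tokens=30,
        atol=1e-03,
    )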
 

ai_edge_torch/generative/examples/phi/phi2.py

@@ -65,15 +65,10 @@ class Phi2(nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -129,6 +124,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=32,
       head_dim=80,
       num_query_groups=32,
+      rotary_base=10000,
       rotary_percentage=0.4,
       qkv_use_bias=True,
       output_proj_use_bias=True,
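
For orientation, the numbers feeding Phi-2's rope cache after this change: head_dim=80 and rotary_percentage=0.4 give a rotary dimension of int(0.4 * 80) = 32, and rotary_base stays at 10000, the value previously hard-coded. A minimal sketch of the tables the cache builder precomputes, mirroring the RoPE math visible in the phi3.py hunks below (attn_utils' own dtype/device defaults are assumed to be float32 on CPU):

    import torch

    size, dim, base = 1024, int(0.4 * 80), 10_000   # kv_cache_max, rotary dim, rotary_base

    theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # (16,) inverse frequencies
    idx_theta = torch.outer(torch.arange(size).float(), theta)       # (1024, 16) position x frequency
    cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)            # the cached RoPE tables
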
ai_edge_torch/generative/examples/phi/phi3.py

@@ -97,15 +97,15 @@ ROPE_SHORT_FACTOR = [
 ]
 
 
-def build_rope_cache(
+def _build_rope_cache(
     size: int,
     dim: int,
-    base: int = 10000,
-    condense_ratio: int = 1,
-    dtype: torch.dtype = torch.float32,
-    device: torch.device = None,
-    theta_factors: torch.Tensor = None,
-    scale: float = 1.0,
+    base: int,
+    condense_ratio: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    theta_factors: torch.Tensor,
+    scale: float,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
   """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
 
@@ -116,26 +116,20 @@ def build_rope_cache(
   Args:
     size (int): The size of the built cache.
     dim (int): Each sequence's dimmension.
-    base (int, optional): Rope base value. Defaults to 10000.
+    base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
-      condensed. Defaults to 1.
-    dtype (torch.dtype, optional): Output tensor's data type. Defaults to
-      torch.float32.
-    device (torch.device, optional): Output tensor's data type. Defaults to
-      None in which case "cpu" is used.
+      condensed.
+    dtype (torch.dtype, optional): Output tensor's data type.
+    device (torch.device, optional): Output tensor's data type.
     theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
-      scale the theta values. Defaults to None.
-    scale (float, optional): A float used to scale the rope values. Defaults
-      to 1.0.
+      scale the theta values.
+    scale (float, optional): A float used to scale the rope values.
 
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  if device is None:
-    device = torch.device('cpu')
   theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-  if theta_factors is not None:
-    theta = theta / theta_factors
+  theta = theta / theta_factors
   seq_idx = torch.arange(size) / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
@@ -167,10 +161,10 @@ class Phi3_5Mini(nn.Module):
         config.final_norm_config,
     )
     attn_config = block_config.attn_config
-    self.rope_cache = build_rope_cache(
+    self.rope_cache = _build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
+        base=attn_config.rotary_base,
         condense_ratio=1,
         dtype=torch.float32,
         device=torch.device("cpu"),
@@ -181,8 +175,6 @@ class Phi3_5Mini(nn.Module):
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -238,6 +230,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=32,
       head_dim=96,
       num_query_groups=32,
+      rotary_base=10000,
       rotary_percentage=1.0,
       qkv_transpose_before_split=True,
   )
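
Net effect of the phi3.py hunks: the rope helper becomes private, every argument must be passed explicitly, and the theta_factors division is applied unconditionally. A self-contained sketch of that computation with toy values (the real call in Phi3_5Mini uses dim = 96, ROPE_SHORT_FACTOR-derived factors, and attn_config.rotary_base):

    import torch

    size, dim, base, condense_ratio, scale = 8, 6, 10_000, 1, 1.0
    theta_factors = torch.ones(dim // 2)        # neutral stand-in for ROPE_SHORT_FACTOR

    theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    theta = theta / theta_factors               # no longer guarded by an if-None check
    idx_theta = torch.outer(torch.arange(size) / condense_ratio, theta)
    cos = torch.cos(idx_theta) * scale          # (size, dim // 2)
    sin = torch.sin(idx_theta) * scale
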
ai_edge_torch/generative/examples/phi/verify.py

@@ -19,6 +19,7 @@ import logging
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import kagglehub
 import transformers
@@ -39,12 +40,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 def main(_):
   checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
   logging.info("Loading the original model from: %s", checkpoint)
-  generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
-  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-      hf_generation_config=generation_config,
-  )
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
   logging.info("Building the reauthored model from: %s", checkpoint)
   reauthored_model = phi2.build_model(checkpoint)
@@ -53,10 +49,13 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-03,
   )
 
ai_edge_torch/generative/examples/phi/verify_phi3.py

@@ -21,6 +21,7 @@ import pathlib
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi3
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
@@ -40,12 +41,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 def main(_):
   checkpoint = "microsoft/Phi-3.5-mini-instruct"
   logging.info("Loading the original model from: %s", checkpoint)
-  generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
-  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-      hf_generation_config=generation_config,
-  )
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
@@ -59,10 +55,13 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
   )
 
ai_edge_torch/generative/examples/smollm/smollm.py

@@ -54,6 +54,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_heads=9,
       head_dim=64,
       num_query_groups=3,
+      rotary_base=10000,
       rotary_percentage=1.0,
   )
   ff_config = cfg.FeedForwardConfig(

ai_edge_torch/generative/examples/smollm/verify.py

@@ -21,6 +21,7 @@ import pathlib
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
@@ -30,14 +31,18 @@ _PROMPTS = flags.DEFINE_multi_string(
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "HuggingFaceTB/SmolLM-135M"
   logging.info("Loading the original model from: %s", checkpoint)
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-  )
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
@@ -50,10 +55,13 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-04,
   )
 
ai_edge_torch/generative/examples/stable_diffusion/clip.py

@@ -98,6 +98,7 @@ def get_model_config() -> cfg.ModelConfig:
       num_heads=num_heads,
       head_dim=embedding_dim // num_heads,
       num_query_groups=num_query_groups,
+      rotary_base=0,
       rotary_percentage=0.0,
       qkv_use_bias=True,
       qkv_transpose_before_split=True,
@@ -148,6 +149,7 @@ def get_fake_model_config() -> cfg.ModelConfig:
       num_heads=num_heads,
       head_dim=embedding_dim // num_heads,
       num_query_groups=num_query_groups,
+      rotary_base=0,
       rotary_percentage=0.0,
       qkv_use_bias=True,
       qkv_transpose_before_split=True,

ai_edge_torch/generative/examples/stable_diffusion/decoder.py

@@ -295,6 +295,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
           enable_kv_cache=False,
           qkv_transpose_before_split=True,
           qkv_fused_interleaved=False,
+          rotary_base=0,
           rotary_percentage=0.0,
       ),
       enable_hlfb=False,
@@ -351,6 +352,7 @@ def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
           enable_kv_cache=False,
           qkv_transpose_before_split=True,
           qkv_fused_interleaved=False,
+          rotary_base=0,
           rotary_percentage=0.0,
       ),
       enable_hlfb=False,

ai_edge_torch/generative/examples/stable_diffusion/diffusion.py

@@ -199,6 +199,7 @@ def build_attention_config(
     num_heads,
     dim,
     num_query_groups,
+    rotary_base=0,
     rotary_percentage=0.0,
     qkv_transpose_before_split=True,
     qkv_use_bias=False,
@@ -211,6 +212,7 @@ def build_attention_config(
       num_heads=num_heads,
       head_dim=dim // num_heads,
       num_query_groups=num_query_groups,
+      rotary_base=rotary_base,
       rotary_percentage=rotary_percentage,
       qkv_transpose_before_split=qkv_transpose_before_split,
       qkv_use_bias=qkv_use_bias,
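
Both diffusion.py hunks simply thread the new field through the local build_attention_config helper. A sketch of a call with illustrative sizes; because these attention blocks set rotary_percentage=0.0, the rotary_base=0 value is effectively a "RoPE disabled" marker:

    # Illustrative values only; the real UNet configs pass their own head counts and dims.
    attn_config = build_attention_config(
        num_heads=8,
        dim=512,
        num_query_groups=8,
        rotary_base=0,           # no rotary embedding is used in these blocks
        rotary_percentage=0.0,
    )
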

ai_edge_torch/generative/examples/t5/t5.py

@@ -335,8 +335,6 @@ class T5Decoder(nn.Module):
 
     self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
 
   @torch.inference_mode
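
The dtype/device arguments dropped here (and in the model hunks above) rely on build_causal_mask_cache's own defaults, presumably float32 on CPU. For reference, a causal mask cache of this shape can be built as a minimal sketch (not the attn_utils implementation itself):

    import torch

    size = 4                                        # stands in for config.kv_cache_max
    mask = torch.full((size, size), float("-inf")).triu(diagonal=1)
    # 0 on and below the diagonal, -inf strictly above: each position may only
    # attend to itself and to earlier positions.
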

ai_edge_torch/generative/examples/test_models/toy_model.py

@@ -44,13 +44,10 @@ class ToySingleLayerModel(torch.nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device('cpu'),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+        size=config.max_seq_len,
     )
     self.config = config
 
@@ -93,13 +90,10 @@ class ToySingleLayerModelWeightSharing(torch.nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device('cpu'),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+        size=config.max_seq_len,
     )
     self.config = config
 
@@ -124,6 +118,7 @@ def get_model_config() -> cfg.ModelConfig:
       num_heads=32,
       head_dim=4,
       num_query_groups=4,
+      rotary_base=10000,
       rotary_percentage=1.0,
       enable_kv_cache=False,
   )

ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

@@ -51,13 +51,10 @@ class ToyModelWithKVCache(torch.nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device('cpu'),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+        size=config.max_seq_len,
     )
     self.config = config
 
@@ -91,6 +88,7 @@ def get_model_config() -> cfg.ModelConfig:
       num_heads=32,
       head_dim=4,
       num_query_groups=4,
+      rotary_base=10000,
       rotary_percentage=1.0,
   )
   ff_config = cfg.FeedForwardConfig(

ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py

@@ -67,15 +67,10 @@ class TinyLlama(nn.Module):
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
+        base=attn_config.rotary_base,
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
     )
     self.config = config
 
@@ -132,6 +127,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
      num_heads=32,
      head_dim=64,
      num_query_groups=4,
+      rotary_base=10000,
      rotary_percentage=1.0,
  )
  ff_config = cfg.FeedForwardConfig(

ai_edge_torch/generative/examples/tiny_llama/verify.py

@@ -21,6 +21,7 @@ import pathlib
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
@@ -30,16 +31,20 @@ _PROMPTS = flags.DEFINE_multi_string(
     "Show me the program to add 2 and 3.",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
   logging.info("Loading the original model from: %s", checkpoint)
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(
-          checkpoint, trust_remote_code=True
-      ),
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )
+
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
@@ -52,10 +57,13 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-04,
   )
 

ai_edge_torch/generative/layers/model_config.py

@@ -83,6 +83,8 @@ class AttentionConfig:
   # Used to determine number of groups in grouped query attention (GQA)
   # https://arxiv.org/pdf/2305.13245.pdf
   num_query_groups: Optional[int]
+  # Base of rotary positional embedding.
+  rotary_base: int = 10_000
   # Percentage of Rotary Positional Embedding added Q and K projections.
   rotary_percentage: Optional[float] = None
   # Whether to transpose the query groups of qkv bundled tensor before
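
With rotary_base now a field on AttentionConfig (default 10_000), the per-model get_model_config functions above set it explicitly and the model classes read attn_config.rotary_base instead of hard-coding the base. A model whose checkpoint uses a different base would simply override the field; the values below are illustrative and not taken from any config in this diff:

    from ai_edge_torch.generative.layers import model_config as cfg

    attn_config = cfg.AttentionConfig(
        num_heads=32,
        head_dim=64,
        num_query_groups=4,
        rotary_base=500_000,   # hypothetical larger base; configs in this diff use 10000 or 0
        rotary_percentage=1.0,
    )
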
ai_edge_torch/generative/test/test_model_conversion_large.py

@@ -19,6 +19,7 @@ import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
@@ -102,6 +103,15 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = gemma2.Gemma2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_llama(self):
+    config = llama.get_fake_model_config()
+    pytorch_model = llama.Llama(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",

ai_edge_torch/generative/utilities/transformers_verifier.py (new file)

@@ -0,0 +1,42 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities for the models predefined in HuggingFace transformers."""
+
+from typing import cast
+
+from ai_edge_torch.generative.utilities import verifier
+import torch
+import transformers
+
+
+class TransformersModelWrapper(verifier.ModelWrapper):
+  """A wrapper for the model predefined in HuggingFace transformers.
+
+  Verifier expects forward() to return logits while Transformers models return
+  an object with `logits` field.
+
+  Transformers models get `max_new_tokens` settings for generate() via
+  GenerationConfig.
+  """
+
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+    return self.model.forward(tokens).logits
+
+  def generate(
+      self, inputs: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
+    gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
+    return self.model.generate(inputs=inputs, generation_config=gen_config)
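
This new wrapper replaces the hf_generation_config plumbing removed from the verify scripts above: max_new_tokens is turned into a GenerationConfig per generate() call, and forward() unwraps the HuggingFace output object into a plain logits tensor. A usage sketch, assuming the verifier.ModelWrapper base class simply stores the wrapped model as self.model (as the overrides above imply):

    import torch
    import transformers
    from ai_edge_torch.generative.utilities import transformers_verifier

    model = transformers.AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")
    wrapper = transformers_verifier.TransformersModelWrapper(model)

    tokens = torch.tensor([[1, 2, 3]])              # toy token ids, for illustration only
    logits = wrapper.forward(tokens)                # plain tensor, not a ModelOutput
    output_ids = wrapper.generate(tokens, max_new_tokens=30)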