ai-edge-torch-nightly 0.3.0.dev20240926__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (36)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -8
  3. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +0 -1
  4. ai_edge_torch/generative/examples/gemma/verify_util.py +13 -24
  5. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/llama.py +204 -0
  9. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  10. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  11. ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
  12. ai_edge_torch/generative/examples/openelm/verify.py +14 -7
  13. ai_edge_torch/generative/examples/phi/phi2.py +2 -6
  14. ai_edge_torch/generative/examples/phi/phi3.py +17 -24
  15. ai_edge_torch/generative/examples/phi/verify.py +8 -9
  16. ai_edge_torch/generative/examples/phi/verify_phi3.py +8 -9
  17. ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
  18. ai_edge_torch/generative/examples/smollm/verify.py +14 -6
  19. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  20. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +2 -0
  21. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +2 -0
  22. ai_edge_torch/generative/examples/t5/t5.py +0 -2
  23. ai_edge_torch/generative/examples/test_models/toy_model.py +5 -10
  24. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -5
  25. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
  26. ai_edge_torch/generative/examples/tiny_llama/verify.py +15 -7
  27. ai_edge_torch/generative/layers/model_config.py +2 -0
  28. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  29. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  30. ai_edge_torch/generative/utilities/verifier.py +117 -97
  31. ai_edge_torch/version.py +1 -1
  32. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
  33. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +36 -29
  34. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
  35. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
  36. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ import pathlib
20
20
  from absl import app
21
21
  from absl import flags
22
22
  from ai_edge_torch.generative.examples.openelm import openelm
23
+ from ai_edge_torch.generative.utilities import transformers_verifier
23
24
  from ai_edge_torch.generative.utilities import verifier
24
25
  import transformers
25
26
 
@@ -29,15 +30,18 @@ _PROMPTS = flags.DEFINE_multi_string(
29
30
  "What is the meaning of life?",
30
31
  "The input prompts to generate answers.",
31
32
  )
33
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
34
+ "max_new_tokens",
35
+ 30,
36
+ "The maximum size of the generated tokens.",
37
+ )
32
38
 
33
39
 
34
40
  def main(_):
35
41
  checkpoint = "apple/OpenELM-3B"
36
42
  logging.info("Loading the original model from: %s", checkpoint)
37
- wrapper_model = verifier.ModelWrapper(
38
- model=transformers.AutoModelForCausalLM.from_pretrained(
39
- checkpoint, trust_remote_code=True
40
- ),
43
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(
44
+ checkpoint, trust_remote_code=True
41
45
  )
42
46
 
43
47
  # Locate the cached dir.
@@ -53,10 +57,13 @@ def main(_):
53
57
  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
54
58
 
55
59
  verifier.verify_reauthored_model(
56
- original_model=wrapper_model,
57
- reauthored_model=reauthored_model,
58
- tokenizer=tokenizer,
60
+ original_model=transformers_verifier.TransformersModelWrapper(
61
+ original_model
62
+ ),
63
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
64
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
59
65
  generate_prompts=_PROMPTS.value,
66
+ max_new_tokens=_MAX_NEW_TOKENS.value,
60
67
  )
61
68
 
62
69
 
@@ -65,15 +65,10 @@ class Phi2(nn.Module):
65
65
  self.rope_cache = attn_utils.build_rope_cache(
66
66
  size=config.kv_cache_max,
67
67
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
68
- base=10_000,
69
- condense_ratio=1,
70
- dtype=torch.float32,
71
- device=torch.device("cpu"),
68
+ base=attn_config.rotary_base,
72
69
  )
73
70
  self.mask_cache = attn_utils.build_causal_mask_cache(
74
71
  size=config.kv_cache_max,
75
- dtype=torch.float32,
76
- device=torch.device("cpu"),
77
72
  )
78
73
  self.config = config
79
74
 
@@ -129,6 +124,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
129
124
  num_heads=32,
130
125
  head_dim=80,
131
126
  num_query_groups=32,
127
+ rotary_base=10000,
132
128
  rotary_percentage=0.4,
133
129
  qkv_use_bias=True,
134
130
  output_proj_use_bias=True,
@@ -97,15 +97,15 @@ ROPE_SHORT_FACTOR = [
97
97
  ]
98
98
 
99
99
 
100
- def build_rope_cache(
100
+ def _build_rope_cache(
101
101
  size: int,
102
102
  dim: int,
103
- base: int = 10000,
104
- condense_ratio: int = 1,
105
- dtype: torch.dtype = torch.float32,
106
- device: torch.device = None,
107
- theta_factors: torch.Tensor = None,
108
- scale: float = 1.0,
103
+ base: int,
104
+ condense_ratio: int,
105
+ dtype: torch.dtype,
106
+ device: torch.device,
107
+ theta_factors: torch.Tensor,
108
+ scale: float,
109
109
  ) -> Tuple[torch.Tensor, torch.Tensor]:
110
110
  """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
111
111
 
@@ -116,26 +116,20 @@ def build_rope_cache(
116
116
  Args:
117
117
  size (int): The size of the built cache.
118
118
  dim (int): Each sequence's dimension.
119
- base (int, optional): Rope base value. Defaults to 10000.
119
+ base (int): Rope base value.
120
120
  condense_ratio (int): The ratio by which sequence indices are
121
- condensed. Defaults to 1.
122
- dtype (torch.dtype, optional): Output tensor's data type. Defaults to
123
- torch.float32.
124
- device (torch.device, optional): Output tensor's data type. Defaults to
125
- None in which case "cpu" is used.
121
+ condensed.
122
+ dtype (torch.dtype): Output tensor's data type.
123
+ device (torch.device): Device on which the output tensors are placed.
126
124
  theta_factors (torch.Tensor): A tensor of shape (dim,) used to
127
- scale the theta values. Defaults to None.
128
- scale (float, optional): A float used to scale the rope values. Defaults
129
- to 1.0.
125
+ scale the theta values.
126
+ scale (float): A float used to scale the rope values.
130
127
 
131
128
  Returns:
132
129
  Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
133
130
  """
134
- if device is None:
135
- device = torch.device('cpu')
136
131
  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
137
- if theta_factors is not None:
138
- theta = theta / theta_factors
132
+ theta = theta / theta_factors
139
133
  seq_idx = torch.arange(size) / condense_ratio
140
134
  idx_theta = torch.outer(seq_idx, theta)
141
135
  cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
@@ -167,10 +161,10 @@ class Phi3_5Mini(nn.Module):
167
161
  config.final_norm_config,
168
162
  )
169
163
  attn_config = block_config.attn_config
170
- self.rope_cache = build_rope_cache(
164
+ self.rope_cache = _build_rope_cache(
171
165
  size=config.kv_cache_max,
172
166
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
173
- base=10_000,
167
+ base=attn_config.rotary_base,
174
168
  condense_ratio=1,
175
169
  dtype=torch.float32,
176
170
  device=torch.device("cpu"),
@@ -181,8 +175,6 @@ class Phi3_5Mini(nn.Module):
181
175
  )
182
176
  self.mask_cache = attn_utils.build_causal_mask_cache(
183
177
  size=config.kv_cache_max,
184
- dtype=torch.float32,
185
- device=torch.device("cpu"),
186
178
  )
187
179
  self.config = config
188
180
 
@@ -238,6 +230,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
238
230
  num_heads=32,
239
231
  head_dim=96,
240
232
  num_query_groups=32,
233
+ rotary_base=10000,
241
234
  rotary_percentage=1.0,
242
235
  qkv_transpose_before_split=True,
243
236
  )
@@ -19,6 +19,7 @@ import logging
19
19
  from absl import app
20
20
  from absl import flags
21
21
  from ai_edge_torch.generative.examples.phi import phi2
22
+ from ai_edge_torch.generative.utilities import transformers_verifier
22
23
  from ai_edge_torch.generative.utilities import verifier
23
24
  import kagglehub
24
25
  import transformers
@@ -39,12 +40,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
39
40
  def main(_):
40
41
  checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
41
42
  logging.info("Loading the original model from: %s", checkpoint)
42
- generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
43
- generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
44
- wrapper_model = verifier.ModelWrapper(
45
- model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
46
- hf_generation_config=generation_config,
47
- )
43
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
48
44
 
49
45
  logging.info("Building the reauthored model from: %s", checkpoint)
50
46
  reauthored_model = phi2.build_model(checkpoint)
@@ -53,10 +49,13 @@ def main(_):
53
49
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
54
50
 
55
51
  verifier.verify_reauthored_model(
56
- original_model=wrapper_model,
57
- reauthored_model=reauthored_model,
58
- tokenizer=tokenizer,
52
+ original_model=transformers_verifier.TransformersModelWrapper(
53
+ original_model
54
+ ),
55
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
56
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
59
57
  generate_prompts=_PROMPTS.value,
58
+ max_new_tokens=_MAX_NEW_TOKENS.value,
60
59
  atol=1e-03,
61
60
  )
62
61
 
@@ -21,6 +21,7 @@ import pathlib
21
21
  from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.phi import phi3
24
+ from ai_edge_torch.generative.utilities import transformers_verifier
24
25
  from ai_edge_torch.generative.utilities import verifier
25
26
  import transformers
26
27
 
@@ -40,12 +41,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
40
41
  def main(_):
41
42
  checkpoint = "microsoft/Phi-3.5-mini-instruct"
42
43
  logging.info("Loading the original model from: %s", checkpoint)
43
- generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
44
- generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
45
- wrapper_model = verifier.ModelWrapper(
46
- model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
47
- hf_generation_config=generation_config,
48
- )
44
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
49
45
 
50
46
  # Locate the cached dir.
51
47
  cached_config_file = transformers.utils.cached_file(
@@ -59,10 +55,13 @@ def main(_):
59
55
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
60
56
 
61
57
  verifier.verify_reauthored_model(
62
- original_model=wrapper_model,
63
- reauthored_model=reauthored_model,
64
- tokenizer=tokenizer,
58
+ original_model=transformers_verifier.TransformersModelWrapper(
59
+ original_model
60
+ ),
61
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
62
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
65
63
  generate_prompts=_PROMPTS.value,
64
+ max_new_tokens=_MAX_NEW_TOKENS.value,
66
65
  )
67
66
 
68
67
 
@@ -54,6 +54,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
54
54
  num_heads=9,
55
55
  head_dim=64,
56
56
  num_query_groups=3,
57
+ rotary_base=10000,
57
58
  rotary_percentage=1.0,
58
59
  )
59
60
  ff_config = cfg.FeedForwardConfig(
@@ -21,6 +21,7 @@ import pathlib
21
21
  from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.smollm import smollm
24
+ from ai_edge_torch.generative.utilities import transformers_verifier
24
25
  from ai_edge_torch.generative.utilities import verifier
25
26
  import transformers
26
27
 
@@ -30,14 +31,18 @@ _PROMPTS = flags.DEFINE_multi_string(
30
31
  "What is the meaning of life?",
31
32
  "The input prompts to generate answers.",
32
33
  )
34
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
35
+ "max_new_tokens",
36
+ 30,
37
+ "The maximum size of the generated tokens.",
38
+ )
33
39
 
34
40
 
35
41
  def main(_):
36
42
  checkpoint = "HuggingFaceTB/SmolLM-135M"
37
43
  logging.info("Loading the original model from: %s", checkpoint)
38
- wrapper_model = verifier.ModelWrapper(
39
- model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
40
- )
44
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
45
+
41
46
  # Locate the cached dir.
42
47
  cached_config_file = transformers.utils.cached_file(
43
48
  checkpoint, transformers.utils.CONFIG_NAME
@@ -50,10 +55,13 @@ def main(_):
50
55
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
51
56
 
52
57
  verifier.verify_reauthored_model(
53
- original_model=wrapper_model,
54
- reauthored_model=reauthored_model,
55
- tokenizer=tokenizer,
58
+ original_model=transformers_verifier.TransformersModelWrapper(
59
+ original_model
60
+ ),
61
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
62
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
56
63
  generate_prompts=_PROMPTS.value,
64
+ max_new_tokens=_MAX_NEW_TOKENS.value,
57
65
  atol=1e-04,
58
66
  )
59
67
 
@@ -98,6 +98,7 @@ def get_model_config() -> cfg.ModelConfig:
98
98
  num_heads=num_heads,
99
99
  head_dim=embedding_dim // num_heads,
100
100
  num_query_groups=num_query_groups,
101
+ rotary_base=0,
101
102
  rotary_percentage=0.0,
102
103
  qkv_use_bias=True,
103
104
  qkv_transpose_before_split=True,
@@ -148,6 +149,7 @@ def get_fake_model_config() -> cfg.ModelConfig:
148
149
  num_heads=num_heads,
149
150
  head_dim=embedding_dim // num_heads,
150
151
  num_query_groups=num_query_groups,
152
+ rotary_base=0,
151
153
  rotary_percentage=0.0,
152
154
  qkv_use_bias=True,
153
155
  qkv_transpose_before_split=True,
@@ -295,6 +295,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
295
295
  enable_kv_cache=False,
296
296
  qkv_transpose_before_split=True,
297
297
  qkv_fused_interleaved=False,
298
+ rotary_base=0,
298
299
  rotary_percentage=0.0,
299
300
  ),
300
301
  enable_hlfb=False,
@@ -351,6 +352,7 @@ def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
351
352
  enable_kv_cache=False,
352
353
  qkv_transpose_before_split=True,
353
354
  qkv_fused_interleaved=False,
355
+ rotary_base=0,
354
356
  rotary_percentage=0.0,
355
357
  ),
356
358
  enable_hlfb=False,
@@ -199,6 +199,7 @@ def build_attention_config(
199
199
  num_heads,
200
200
  dim,
201
201
  num_query_groups,
202
+ rotary_base=0,
202
203
  rotary_percentage=0.0,
203
204
  qkv_transpose_before_split=True,
204
205
  qkv_use_bias=False,
@@ -211,6 +212,7 @@ def build_attention_config(
211
212
  num_heads=num_heads,
212
213
  head_dim=dim // num_heads,
213
214
  num_query_groups=num_query_groups,
215
+ rotary_base=rotary_base,
214
216
  rotary_percentage=rotary_percentage,
215
217
  qkv_transpose_before_split=qkv_transpose_before_split,
216
218
  qkv_use_bias=qkv_use_bias,
@@ -335,8 +335,6 @@ class T5Decoder(nn.Module):
335
335
 
336
336
  self.dec_attn_mask_cache = attn_utils.build_causal_mask_cache(
337
337
  size=config.kv_cache_max,
338
- dtype=torch.float32,
339
- device=torch.device("cpu"),
340
338
  )
341
339
 
342
340
  @torch.inference_mode
@@ -44,13 +44,10 @@ class ToySingleLayerModel(torch.nn.Module):
44
44
  self.rope_cache = attn_utils.build_rope_cache(
45
45
  size=config.max_seq_len,
46
46
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
47
- base=10_000,
48
- condense_ratio=1,
49
- dtype=torch.float32,
50
- device=torch.device('cpu'),
47
+ base=attn_config.rotary_base,
51
48
  )
52
49
  self.mask_cache = attn_utils.build_causal_mask_cache(
53
- size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
50
+ size=config.max_seq_len,
54
51
  )
55
52
  self.config = config
56
53
 
@@ -93,13 +90,10 @@ class ToySingleLayerModelWeightSharing(torch.nn.Module):
93
90
  self.rope_cache = attn_utils.build_rope_cache(
94
91
  size=config.max_seq_len,
95
92
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
96
- base=10_000,
97
- condense_ratio=1,
98
- dtype=torch.float32,
99
- device=torch.device('cpu'),
93
+ base=attn_config.rotary_base,
100
94
  )
101
95
  self.mask_cache = attn_utils.build_causal_mask_cache(
102
- size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
96
+ size=config.max_seq_len,
103
97
  )
104
98
  self.config = config
105
99
 
@@ -124,6 +118,7 @@ def get_model_config() -> cfg.ModelConfig:
124
118
  num_heads=32,
125
119
  head_dim=4,
126
120
  num_query_groups=4,
121
+ rotary_base=10000,
127
122
  rotary_percentage=1.0,
128
123
  enable_kv_cache=False,
129
124
  )
@@ -51,13 +51,10 @@ class ToyModelWithKVCache(torch.nn.Module):
51
51
  self.rope_cache = attn_utils.build_rope_cache(
52
52
  size=config.max_seq_len,
53
53
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
54
- base=10_000,
55
- condense_ratio=1,
56
- dtype=torch.float32,
57
- device=torch.device('cpu'),
54
+ base=attn_config.rotary_base,
58
55
  )
59
56
  self.mask_cache = attn_utils.build_causal_mask_cache(
60
- size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
57
+ size=config.max_seq_len,
61
58
  )
62
59
  self.config = config
63
60
 
@@ -91,6 +88,7 @@ def get_model_config() -> cfg.ModelConfig:
91
88
  num_heads=32,
92
89
  head_dim=4,
93
90
  num_query_groups=4,
91
+ rotary_base=10000,
94
92
  rotary_percentage=1.0,
95
93
  )
96
94
  ff_config = cfg.FeedForwardConfig(
@@ -67,15 +67,10 @@ class TinyLlama(nn.Module):
67
67
  self.rope_cache = attn_utils.build_rope_cache(
68
68
  size=config.kv_cache_max,
69
69
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
70
- base=10_000,
71
- condense_ratio=1,
72
- dtype=torch.float32,
73
- device=torch.device("cpu"),
70
+ base=attn_config.rotary_base,
74
71
  )
75
72
  self.mask_cache = attn_utils.build_causal_mask_cache(
76
73
  size=config.kv_cache_max,
77
- dtype=torch.float32,
78
- device=torch.device("cpu"),
79
74
  )
80
75
  self.config = config
81
76
 
@@ -132,6 +127,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
132
127
  num_heads=32,
133
128
  head_dim=64,
134
129
  num_query_groups=4,
130
+ rotary_base=10000,
135
131
  rotary_percentage=1.0,
136
132
  )
137
133
  ff_config = cfg.FeedForwardConfig(
@@ -21,6 +21,7 @@ import pathlib
21
21
  from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
24
+ from ai_edge_torch.generative.utilities import transformers_verifier
24
25
  from ai_edge_torch.generative.utilities import verifier
25
26
  import transformers
26
27
 
@@ -30,16 +31,20 @@ _PROMPTS = flags.DEFINE_multi_string(
30
31
  "Show me the program to add 2 and 3.",
31
32
  "The input prompts to generate answers.",
32
33
  )
34
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
35
+ "max_new_tokens",
36
+ 30,
37
+ "The maximum size of the generated tokens.",
38
+ )
33
39
 
34
40
 
35
41
  def main(_):
36
42
  checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
37
43
  logging.info("Loading the original model from: %s", checkpoint)
38
- wrapper_model = verifier.ModelWrapper(
39
- model=transformers.AutoModelForCausalLM.from_pretrained(
40
- checkpoint, trust_remote_code=True
41
- ),
44
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(
45
+ checkpoint, trust_remote_code=True
42
46
  )
47
+
43
48
  # Locate the cached dir.
44
49
  cached_config_file = transformers.utils.cached_file(
45
50
  checkpoint, transformers.utils.CONFIG_NAME
@@ -52,10 +57,13 @@ def main(_):
52
57
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
53
58
 
54
59
  verifier.verify_reauthored_model(
55
- original_model=wrapper_model,
56
- reauthored_model=reauthored_model,
57
- tokenizer=tokenizer,
60
+ original_model=transformers_verifier.TransformersModelWrapper(
61
+ original_model
62
+ ),
63
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
64
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
58
65
  generate_prompts=_PROMPTS.value,
66
+ max_new_tokens=_MAX_NEW_TOKENS.value,
59
67
  atol=1e-04,
60
68
  )
61
69
 
@@ -83,6 +83,8 @@ class AttentionConfig:
83
83
  # Used to determine number of groups in grouped query attention (GQA)
84
84
  # https://arxiv.org/pdf/2305.13245.pdf
85
85
  num_query_groups: Optional[int]
86
+ # Base of rotary positional embedding.
87
+ rotary_base: int = 10_000
86
88
  # Percentage of Rotary Positional Embedding added to Q and K projections.
87
89
  rotary_percentage: Optional[float] = None
88
90
  # Whether to transpose the query groups of qkv bundled tensor before
@@ -19,6 +19,7 @@ import ai_edge_torch
19
19
  from ai_edge_torch import config as ai_edge_config
20
20
  from ai_edge_torch.generative.examples.gemma import gemma1
21
21
  from ai_edge_torch.generative.examples.gemma import gemma2
22
+ from ai_edge_torch.generative.examples.llama import llama
22
23
  from ai_edge_torch.generative.examples.openelm import openelm
23
24
  from ai_edge_torch.generative.examples.phi import phi2
24
25
  from ai_edge_torch.generative.examples.phi import phi3
@@ -102,6 +103,15 @@ class TestModelConversion(googletest.TestCase):
102
103
  pytorch_model = gemma2.Gemma2(config).eval()
103
104
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
104
105
 
106
+ @googletest.skipIf(
107
+ ai_edge_config.Config.use_torch_xla,
108
+ reason="tests with custom ops are not supported on oss",
109
+ )
110
+ def test_llama(self):
111
+ config = llama.get_fake_model_config()
112
+ pytorch_model = llama.Llama(config).eval()
113
+ self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
114
+
105
115
  @googletest.skipIf(
106
116
  ai_edge_config.Config.use_torch_xla,
107
117
  reason="tests with custom ops are not supported on oss",
@@ -0,0 +1,42 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Utilities for the models predefined in HuggingFace transformers."""
17
+
18
+ from typing import cast
19
+
20
+ from ai_edge_torch.generative.utilities import verifier
21
+ import torch
22
+ import transformers
23
+
24
+
25
+ class TransformersModelWrapper(verifier.ModelWrapper):
26
+ """A wrapper for the model predefined in HuggingFace transformers.
27
+
28
+ Verifier expects forward() to return logits while Transformers models return
29
+ an object with `logits` field.
30
+
31
+ Transformers models get `max_new_tokens` settings for generate() via
32
+ GenerationConfig.
33
+ """
34
+
35
+ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
36
+ return self.model.forward(tokens).logits
37
+
38
+ def generate(
39
+ self, inputs: torch.Tensor, max_new_tokens: int
40
+ ) -> torch.IntTensor:
41
+ gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
42
+ return self.model.generate(inputs=inputs, generation_config=gen_config)