ai-edge-torch-nightly 0.3.0.dev20240926__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -8
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +0 -1
- ai_edge_torch/generative/examples/gemma/verify_util.py +13 -24
- ai_edge_torch/generative/examples/llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/llama.py +204 -0
- ai_edge_torch/generative/examples/llama/verify.py +73 -0
- ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
- ai_edge_torch/generative/examples/openelm/verify.py +14 -7
- ai_edge_torch/generative/examples/phi/phi2.py +2 -6
- ai_edge_torch/generative/examples/phi/phi3.py +17 -24
- ai_edge_torch/generative/examples/phi/verify.py +8 -9
- ai_edge_torch/generative/examples/phi/verify_phi3.py +8 -9
- ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
- ai_edge_torch/generative/examples/smollm/verify.py +14 -6
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +2 -0
- ai_edge_torch/generative/examples/t5/t5.py +0 -2
- ai_edge_torch/generative/examples/test_models/toy_model.py +5 -10
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -5
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
- ai_edge_torch/generative/examples/tiny_llama/verify.py +15 -7
- ai_edge_torch/generative/layers/model_config.py +2 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
- ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
- ai_edge_torch/generative/utilities/verifier.py +117 -97
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +36 -29
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ import pathlib
|
|
20
20
|
from absl import app
|
21
21
|
from absl import flags
|
22
22
|
from ai_edge_torch.generative.examples.openelm import openelm
|
23
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
23
24
|
from ai_edge_torch.generative.utilities import verifier
|
24
25
|
import transformers
|
25
26
|
|
@@ -29,15 +30,18 @@ _PROMPTS = flags.DEFINE_multi_string(
|
|
29
30
|
"What is the meaning of life?",
|
30
31
|
"The input prompts to generate answers.",
|
31
32
|
)
|
33
|
+
_MAX_NEW_TOKENS = flags.DEFINE_integer(
|
34
|
+
"max_new_tokens",
|
35
|
+
30,
|
36
|
+
"The maximum size of the generated tokens.",
|
37
|
+
)
|
32
38
|
|
33
39
|
|
34
40
|
def main(_):
|
35
41
|
checkpoint = "apple/OpenELM-3B"
|
36
42
|
logging.info("Loading the original model from: %s", checkpoint)
|
37
|
-
|
38
|
-
|
39
|
-
checkpoint, trust_remote_code=True
|
40
|
-
),
|
43
|
+
original_model = transformers.AutoModelForCausalLM.from_pretrained(
|
44
|
+
checkpoint, trust_remote_code=True
|
41
45
|
)
|
42
46
|
|
43
47
|
# Locate the cached dir.
|
@@ -53,10 +57,13 @@ def main(_):
|
|
53
57
|
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
|
54
58
|
|
55
59
|
verifier.verify_reauthored_model(
|
56
|
-
original_model=
|
57
|
-
|
58
|
-
|
60
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
61
|
+
original_model
|
62
|
+
),
|
63
|
+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
64
|
+
tokenizer=verifier.TokenizerWrapper(tokenizer),
|
59
65
|
generate_prompts=_PROMPTS.value,
|
66
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
60
67
|
)
|
61
68
|
|
62
69
|
|
@@ -65,15 +65,10 @@ class Phi2(nn.Module):
|
|
65
65
|
self.rope_cache = attn_utils.build_rope_cache(
|
66
66
|
size=config.kv_cache_max,
|
67
67
|
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
68
|
-
base=
|
69
|
-
condense_ratio=1,
|
70
|
-
dtype=torch.float32,
|
71
|
-
device=torch.device("cpu"),
|
68
|
+
base=attn_config.rotary_base,
|
72
69
|
)
|
73
70
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
74
71
|
size=config.kv_cache_max,
|
75
|
-
dtype=torch.float32,
|
76
|
-
device=torch.device("cpu"),
|
77
72
|
)
|
78
73
|
self.config = config
|
79
74
|
|
@@ -129,6 +124,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
129
124
|
num_heads=32,
|
130
125
|
head_dim=80,
|
131
126
|
num_query_groups=32,
|
127
|
+
rotary_base=10000,
|
132
128
|
rotary_percentage=0.4,
|
133
129
|
qkv_use_bias=True,
|
134
130
|
output_proj_use_bias=True,
|
@@ -97,15 +97,15 @@ ROPE_SHORT_FACTOR = [
|
|
97
97
|
]
|
98
98
|
|
99
99
|
|
100
|
-
def
|
100
|
+
def _build_rope_cache(
|
101
101
|
size: int,
|
102
102
|
dim: int,
|
103
|
-
base: int
|
104
|
-
condense_ratio: int
|
105
|
-
dtype: torch.dtype
|
106
|
-
device: torch.device
|
107
|
-
theta_factors: torch.Tensor
|
108
|
-
scale: float
|
103
|
+
base: int,
|
104
|
+
condense_ratio: int,
|
105
|
+
dtype: torch.dtype,
|
106
|
+
device: torch.device,
|
107
|
+
theta_factors: torch.Tensor,
|
108
|
+
scale: float,
|
109
109
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
110
110
|
"""Precomputes Rotary Positional Embeddings for Phi-3.5 model.
|
111
111
|
|
@@ -116,26 +116,20 @@ def build_rope_cache(
|
|
116
116
|
Args:
|
117
117
|
size (int): The size of the built cache.
|
118
118
|
dim (int): Each sequence's dimmension.
|
119
|
-
base (int, optional): Rope base value.
|
119
|
+
base (int, optional): Rope base value.
|
120
120
|
condense_ratio (int, optional): The ratio by which sequence indicies are
|
121
|
-
condensed.
|
122
|
-
dtype (torch.dtype, optional): Output tensor's data type.
|
123
|
-
|
124
|
-
device (torch.device, optional): Output tensor's data type. Defaults to
|
125
|
-
None in which case "cpu" is used.
|
121
|
+
condensed.
|
122
|
+
dtype (torch.dtype, optional): Output tensor's data type.
|
123
|
+
device (torch.device, optional): Output tensor's data type.
|
126
124
|
theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
|
127
|
-
scale the theta values.
|
128
|
-
scale (float, optional): A float used to scale the rope values.
|
129
|
-
to 1.0.
|
125
|
+
scale the theta values.
|
126
|
+
scale (float, optional): A float used to scale the rope values.
|
130
127
|
|
131
128
|
Returns:
|
132
129
|
Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
|
133
130
|
"""
|
134
|
-
if device is None:
|
135
|
-
device = torch.device('cpu')
|
136
131
|
theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
137
|
-
|
138
|
-
theta = theta / theta_factors
|
132
|
+
theta = theta / theta_factors
|
139
133
|
seq_idx = torch.arange(size) / condense_ratio
|
140
134
|
idx_theta = torch.outer(seq_idx, theta)
|
141
135
|
cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
|
@@ -167,10 +161,10 @@ class Phi3_5Mini(nn.Module):
|
|
167
161
|
config.final_norm_config,
|
168
162
|
)
|
169
163
|
attn_config = block_config.attn_config
|
170
|
-
self.rope_cache =
|
164
|
+
self.rope_cache = _build_rope_cache(
|
171
165
|
size=config.kv_cache_max,
|
172
166
|
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
173
|
-
base=
|
167
|
+
base=attn_config.rotary_base,
|
174
168
|
condense_ratio=1,
|
175
169
|
dtype=torch.float32,
|
176
170
|
device=torch.device("cpu"),
|
@@ -181,8 +175,6 @@ class Phi3_5Mini(nn.Module):
|
|
181
175
|
)
|
182
176
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
183
177
|
size=config.kv_cache_max,
|
184
|
-
dtype=torch.float32,
|
185
|
-
device=torch.device("cpu"),
|
186
178
|
)
|
187
179
|
self.config = config
|
188
180
|
|
@@ -238,6 +230,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
238
230
|
num_heads=32,
|
239
231
|
head_dim=96,
|
240
232
|
num_query_groups=32,
|
233
|
+
rotary_base=10000,
|
241
234
|
rotary_percentage=1.0,
|
242
235
|
qkv_transpose_before_split=True,
|
243
236
|
)
|
@@ -19,6 +19,7 @@ import logging
|
|
19
19
|
from absl import app
|
20
20
|
from absl import flags
|
21
21
|
from ai_edge_torch.generative.examples.phi import phi2
|
22
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
22
23
|
from ai_edge_torch.generative.utilities import verifier
|
23
24
|
import kagglehub
|
24
25
|
import transformers
|
@@ -39,12 +40,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
|
|
39
40
|
def main(_):
|
40
41
|
checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
|
41
42
|
logging.info("Loading the original model from: %s", checkpoint)
|
42
|
-
|
43
|
-
generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
|
44
|
-
wrapper_model = verifier.ModelWrapper(
|
45
|
-
model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
|
46
|
-
hf_generation_config=generation_config,
|
47
|
-
)
|
43
|
+
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
|
48
44
|
|
49
45
|
logging.info("Building the reauthored model from: %s", checkpoint)
|
50
46
|
reauthored_model = phi2.build_model(checkpoint)
|
@@ -53,10 +49,13 @@ def main(_):
|
|
53
49
|
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
|
54
50
|
|
55
51
|
verifier.verify_reauthored_model(
|
56
|
-
original_model=
|
57
|
-
|
58
|
-
|
52
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
53
|
+
original_model
|
54
|
+
),
|
55
|
+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
56
|
+
tokenizer=verifier.TokenizerWrapper(tokenizer),
|
59
57
|
generate_prompts=_PROMPTS.value,
|
58
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
60
59
|
atol=1e-03,
|
61
60
|
)
|
62
61
|
|
@@ -21,6 +21,7 @@ import pathlib
|
|
21
21
|
from absl import app
|
22
22
|
from absl import flags
|
23
23
|
from ai_edge_torch.generative.examples.phi import phi3
|
24
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
24
25
|
from ai_edge_torch.generative.utilities import verifier
|
25
26
|
import transformers
|
26
27
|
|
@@ -40,12 +41,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
|
|
40
41
|
def main(_):
|
41
42
|
checkpoint = "microsoft/Phi-3.5-mini-instruct"
|
42
43
|
logging.info("Loading the original model from: %s", checkpoint)
|
43
|
-
|
44
|
-
generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
|
45
|
-
wrapper_model = verifier.ModelWrapper(
|
46
|
-
model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
|
47
|
-
hf_generation_config=generation_config,
|
48
|
-
)
|
44
|
+
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
|
49
45
|
|
50
46
|
# Locate the cached dir.
|
51
47
|
cached_config_file = transformers.utils.cached_file(
|
@@ -59,10 +55,13 @@ def main(_):
|
|
59
55
|
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
|
60
56
|
|
61
57
|
verifier.verify_reauthored_model(
|
62
|
-
original_model=
|
63
|
-
|
64
|
-
|
58
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
59
|
+
original_model
|
60
|
+
),
|
61
|
+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
62
|
+
tokenizer=verifier.TokenizerWrapper(tokenizer),
|
65
63
|
generate_prompts=_PROMPTS.value,
|
64
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
66
65
|
)
|
67
66
|
|
68
67
|
|
@@ -21,6 +21,7 @@ import pathlib
|
|
21
21
|
from absl import app
|
22
22
|
from absl import flags
|
23
23
|
from ai_edge_torch.generative.examples.smollm import smollm
|
24
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
24
25
|
from ai_edge_torch.generative.utilities import verifier
|
25
26
|
import transformers
|
26
27
|
|
@@ -30,14 +31,18 @@ _PROMPTS = flags.DEFINE_multi_string(
|
|
30
31
|
"What is the meaning of life?",
|
31
32
|
"The input prompts to generate answers.",
|
32
33
|
)
|
34
|
+
_MAX_NEW_TOKENS = flags.DEFINE_integer(
|
35
|
+
"max_new_tokens",
|
36
|
+
30,
|
37
|
+
"The maximum size of the generated tokens.",
|
38
|
+
)
|
33
39
|
|
34
40
|
|
35
41
|
def main(_):
|
36
42
|
checkpoint = "HuggingFaceTB/SmolLM-135M"
|
37
43
|
logging.info("Loading the original model from: %s", checkpoint)
|
38
|
-
|
39
|
-
|
40
|
-
)
|
44
|
+
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
|
45
|
+
|
41
46
|
# Locate the cached dir.
|
42
47
|
cached_config_file = transformers.utils.cached_file(
|
43
48
|
checkpoint, transformers.utils.CONFIG_NAME
|
@@ -50,10 +55,13 @@ def main(_):
|
|
50
55
|
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
|
51
56
|
|
52
57
|
verifier.verify_reauthored_model(
|
53
|
-
original_model=
|
54
|
-
|
55
|
-
|
58
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
59
|
+
original_model
|
60
|
+
),
|
61
|
+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
62
|
+
tokenizer=verifier.TokenizerWrapper(tokenizer),
|
56
63
|
generate_prompts=_PROMPTS.value,
|
64
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
57
65
|
atol=1e-04,
|
58
66
|
)
|
59
67
|
|
@@ -98,6 +98,7 @@ def get_model_config() -> cfg.ModelConfig:
|
|
98
98
|
num_heads=num_heads,
|
99
99
|
head_dim=embedding_dim // num_heads,
|
100
100
|
num_query_groups=num_query_groups,
|
101
|
+
rotary_base=0,
|
101
102
|
rotary_percentage=0.0,
|
102
103
|
qkv_use_bias=True,
|
103
104
|
qkv_transpose_before_split=True,
|
@@ -148,6 +149,7 @@ def get_fake_model_config() -> cfg.ModelConfig:
|
|
148
149
|
num_heads=num_heads,
|
149
150
|
head_dim=embedding_dim // num_heads,
|
150
151
|
num_query_groups=num_query_groups,
|
152
|
+
rotary_base=0,
|
151
153
|
rotary_percentage=0.0,
|
152
154
|
qkv_use_bias=True,
|
153
155
|
qkv_transpose_before_split=True,
|
@@ -295,6 +295,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
|
|
295
295
|
enable_kv_cache=False,
|
296
296
|
qkv_transpose_before_split=True,
|
297
297
|
qkv_fused_interleaved=False,
|
298
|
+
rotary_base=0,
|
298
299
|
rotary_percentage=0.0,
|
299
300
|
),
|
300
301
|
enable_hlfb=False,
|
@@ -351,6 +352,7 @@ def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
|
|
351
352
|
enable_kv_cache=False,
|
352
353
|
qkv_transpose_before_split=True,
|
353
354
|
qkv_fused_interleaved=False,
|
355
|
+
rotary_base=0,
|
354
356
|
rotary_percentage=0.0,
|
355
357
|
),
|
356
358
|
enable_hlfb=False,
|
@@ -199,6 +199,7 @@ def build_attention_config(
|
|
199
199
|
num_heads,
|
200
200
|
dim,
|
201
201
|
num_query_groups,
|
202
|
+
rotary_base=0,
|
202
203
|
rotary_percentage=0.0,
|
203
204
|
qkv_transpose_before_split=True,
|
204
205
|
qkv_use_bias=False,
|
@@ -211,6 +212,7 @@ def build_attention_config(
|
|
211
212
|
num_heads=num_heads,
|
212
213
|
head_dim=dim // num_heads,
|
213
214
|
num_query_groups=num_query_groups,
|
215
|
+
rotary_base=rotary_base,
|
214
216
|
rotary_percentage=rotary_percentage,
|
215
217
|
qkv_transpose_before_split=qkv_transpose_before_split,
|
216
218
|
qkv_use_bias=qkv_use_bias,
|
@@ -44,13 +44,10 @@ class ToySingleLayerModel(torch.nn.Module):
|
|
44
44
|
self.rope_cache = attn_utils.build_rope_cache(
|
45
45
|
size=config.max_seq_len,
|
46
46
|
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
47
|
-
base=
|
48
|
-
condense_ratio=1,
|
49
|
-
dtype=torch.float32,
|
50
|
-
device=torch.device('cpu'),
|
47
|
+
base=attn_config.rotary_base,
|
51
48
|
)
|
52
49
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
53
|
-
size=config.max_seq_len,
|
50
|
+
size=config.max_seq_len,
|
54
51
|
)
|
55
52
|
self.config = config
|
56
53
|
|
@@ -93,13 +90,10 @@ class ToySingleLayerModelWeightSharing(torch.nn.Module):
|
|
93
90
|
self.rope_cache = attn_utils.build_rope_cache(
|
94
91
|
size=config.max_seq_len,
|
95
92
|
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
96
|
-
base=
|
97
|
-
condense_ratio=1,
|
98
|
-
dtype=torch.float32,
|
99
|
-
device=torch.device('cpu'),
|
93
|
+
base=attn_config.rotary_base,
|
100
94
|
)
|
101
95
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
102
|
-
size=config.max_seq_len,
|
96
|
+
size=config.max_seq_len,
|
103
97
|
)
|
104
98
|
self.config = config
|
105
99
|
|
@@ -124,6 +118,7 @@ def get_model_config() -> cfg.ModelConfig:
|
|
124
118
|
num_heads=32,
|
125
119
|
head_dim=4,
|
126
120
|
num_query_groups=4,
|
121
|
+
rotary_base=10000,
|
127
122
|
rotary_percentage=1.0,
|
128
123
|
enable_kv_cache=False,
|
129
124
|
)
|
@@ -51,13 +51,10 @@ class ToyModelWithKVCache(torch.nn.Module):
|
|
51
51
|
self.rope_cache = attn_utils.build_rope_cache(
|
52
52
|
size=config.max_seq_len,
|
53
53
|
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
54
|
-
base=
|
55
|
-
condense_ratio=1,
|
56
|
-
dtype=torch.float32,
|
57
|
-
device=torch.device('cpu'),
|
54
|
+
base=attn_config.rotary_base,
|
58
55
|
)
|
59
56
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
60
|
-
size=config.max_seq_len,
|
57
|
+
size=config.max_seq_len,
|
61
58
|
)
|
62
59
|
self.config = config
|
63
60
|
|
@@ -91,6 +88,7 @@ def get_model_config() -> cfg.ModelConfig:
|
|
91
88
|
num_heads=32,
|
92
89
|
head_dim=4,
|
93
90
|
num_query_groups=4,
|
91
|
+
rotary_base=10000,
|
94
92
|
rotary_percentage=1.0,
|
95
93
|
)
|
96
94
|
ff_config = cfg.FeedForwardConfig(
|
@@ -67,15 +67,10 @@ class TinyLlama(nn.Module):
|
|
67
67
|
self.rope_cache = attn_utils.build_rope_cache(
|
68
68
|
size=config.kv_cache_max,
|
69
69
|
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
70
|
-
base=
|
71
|
-
condense_ratio=1,
|
72
|
-
dtype=torch.float32,
|
73
|
-
device=torch.device("cpu"),
|
70
|
+
base=attn_config.rotary_base,
|
74
71
|
)
|
75
72
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
76
73
|
size=config.kv_cache_max,
|
77
|
-
dtype=torch.float32,
|
78
|
-
device=torch.device("cpu"),
|
79
74
|
)
|
80
75
|
self.config = config
|
81
76
|
|
@@ -132,6 +127,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
132
127
|
num_heads=32,
|
133
128
|
head_dim=64,
|
134
129
|
num_query_groups=4,
|
130
|
+
rotary_base=10000,
|
135
131
|
rotary_percentage=1.0,
|
136
132
|
)
|
137
133
|
ff_config = cfg.FeedForwardConfig(
|
@@ -21,6 +21,7 @@ import pathlib
|
|
21
21
|
from absl import app
|
22
22
|
from absl import flags
|
23
23
|
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
24
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
24
25
|
from ai_edge_torch.generative.utilities import verifier
|
25
26
|
import transformers
|
26
27
|
|
@@ -30,16 +31,20 @@ _PROMPTS = flags.DEFINE_multi_string(
|
|
30
31
|
"Show me the program to add 2 and 3.",
|
31
32
|
"The input prompts to generate answers.",
|
32
33
|
)
|
34
|
+
_MAX_NEW_TOKENS = flags.DEFINE_integer(
|
35
|
+
"max_new_tokens",
|
36
|
+
30,
|
37
|
+
"The maximum size of the generated tokens.",
|
38
|
+
)
|
33
39
|
|
34
40
|
|
35
41
|
def main(_):
|
36
42
|
checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
37
43
|
logging.info("Loading the original model from: %s", checkpoint)
|
38
|
-
|
39
|
-
|
40
|
-
checkpoint, trust_remote_code=True
|
41
|
-
),
|
44
|
+
original_model = transformers.AutoModelForCausalLM.from_pretrained(
|
45
|
+
checkpoint, trust_remote_code=True
|
42
46
|
)
|
47
|
+
|
43
48
|
# Locate the cached dir.
|
44
49
|
cached_config_file = transformers.utils.cached_file(
|
45
50
|
checkpoint, transformers.utils.CONFIG_NAME
|
@@ -52,10 +57,13 @@ def main(_):
|
|
52
57
|
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
|
53
58
|
|
54
59
|
verifier.verify_reauthored_model(
|
55
|
-
original_model=
|
56
|
-
|
57
|
-
|
60
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
61
|
+
original_model
|
62
|
+
),
|
63
|
+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
64
|
+
tokenizer=verifier.TokenizerWrapper(tokenizer),
|
58
65
|
generate_prompts=_PROMPTS.value,
|
66
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
59
67
|
atol=1e-04,
|
60
68
|
)
|
61
69
|
|
@@ -83,6 +83,8 @@ class AttentionConfig:
|
|
83
83
|
# Used to determine number of groups in grouped query attention (GQA)
|
84
84
|
# https://arxiv.org/pdf/2305.13245.pdf
|
85
85
|
num_query_groups: Optional[int]
|
86
|
+
# Base of rotary positional embedding.
|
87
|
+
rotary_base: int = 10_000
|
86
88
|
# Percentage of Rotary Positional Embedding added Q and K projections.
|
87
89
|
rotary_percentage: Optional[float] = None
|
88
90
|
# Whether to transpose the query groups of qkv bundled tensor before
|
@@ -19,6 +19,7 @@ import ai_edge_torch
|
|
19
19
|
from ai_edge_torch import config as ai_edge_config
|
20
20
|
from ai_edge_torch.generative.examples.gemma import gemma1
|
21
21
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
22
|
+
from ai_edge_torch.generative.examples.llama import llama
|
22
23
|
from ai_edge_torch.generative.examples.openelm import openelm
|
23
24
|
from ai_edge_torch.generative.examples.phi import phi2
|
24
25
|
from ai_edge_torch.generative.examples.phi import phi3
|
@@ -102,6 +103,15 @@ class TestModelConversion(googletest.TestCase):
|
|
102
103
|
pytorch_model = gemma2.Gemma2(config).eval()
|
103
104
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
104
105
|
|
106
|
+
@googletest.skipIf(
|
107
|
+
ai_edge_config.Config.use_torch_xla,
|
108
|
+
reason="tests with custom ops are not supported on oss",
|
109
|
+
)
|
110
|
+
def test_llama(self):
|
111
|
+
config = llama.get_fake_model_config()
|
112
|
+
pytorch_model = llama.Llama(config).eval()
|
113
|
+
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
114
|
+
|
105
115
|
@googletest.skipIf(
|
106
116
|
ai_edge_config.Config.use_torch_xla,
|
107
117
|
reason="tests with custom ops are not supported on oss",
|
@@ -0,0 +1,42 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Utilities for the models predefined in HuggingFace transformers."""
|
17
|
+
|
18
|
+
from typing import cast
|
19
|
+
|
20
|
+
from ai_edge_torch.generative.utilities import verifier
|
21
|
+
import torch
|
22
|
+
import transformers
|
23
|
+
|
24
|
+
|
25
|
+
class TransformersModelWrapper(verifier.ModelWrapper):
|
26
|
+
"""A wrapper for the model predefined in HuggingFace transformers.
|
27
|
+
|
28
|
+
Verifier expects forward() to return logits while Transformers models return
|
29
|
+
an object with `logits` field.
|
30
|
+
|
31
|
+
Transformers models get `max_new_tokens` settings for generate() via
|
32
|
+
GenerationConfig.
|
33
|
+
"""
|
34
|
+
|
35
|
+
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
36
|
+
return self.model.forward(tokens).logits
|
37
|
+
|
38
|
+
def generate(
|
39
|
+
self, inputs: torch.Tensor, max_new_tokens: int
|
40
|
+
) -> torch.IntTensor:
|
41
|
+
gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
|
42
|
+
return self.model.generate(inputs=inputs, generation_config=gen_config)
|