ai-edge-torch-nightly 0.3.0.dev20240926__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -8
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +0 -1
- ai_edge_torch/generative/examples/gemma/verify_util.py +13 -24
- ai_edge_torch/generative/examples/llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/llama.py +204 -0
- ai_edge_torch/generative/examples/llama/verify.py +73 -0
- ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
- ai_edge_torch/generative/examples/openelm/verify.py +14 -7
- ai_edge_torch/generative/examples/phi/phi2.py +2 -6
- ai_edge_torch/generative/examples/phi/phi3.py +17 -24
- ai_edge_torch/generative/examples/phi/verify.py +8 -9
- ai_edge_torch/generative/examples/phi/verify_phi3.py +8 -9
- ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
- ai_edge_torch/generative/examples/smollm/verify.py +14 -6
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +2 -0
- ai_edge_torch/generative/examples/t5/t5.py +0 -2
- ai_edge_torch/generative/examples/test_models/toy_model.py +5 -10
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -5
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
- ai_edge_torch/generative/examples/tiny_llama/verify.py +15 -7
- ai_edge_torch/generative/layers/model_config.py +2 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
- ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
- ai_edge_torch/generative/utilities/verifier.py +117 -97
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +36 -29
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ import pathlib
|
|
20
20
|
from absl import app
|
21
21
|
from absl import flags
|
22
22
|
from ai_edge_torch.generative.examples.openelm import openelm
|
23
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
23
24
|
from ai_edge_torch.generative.utilities import verifier
|
24
25
|
import transformers
|
25
26
|
|
@@ -29,15 +30,18 @@ _PROMPTS = flags.DEFINE_multi_string(
|
|
29
30
|
"What is the meaning of life?",
|
30
31
|
"The input prompts to generate answers.",
|
31
32
|
)
|
33
|
+
_MAX_NEW_TOKENS = flags.DEFINE_integer(
|
34
|
+
"max_new_tokens",
|
35
|
+
30,
|
36
|
+
"The maximum size of the generated tokens.",
|
37
|
+
)
|
32
38
|
|
33
39
|
|
34
40
|
def main(_):
|
35
41
|
checkpoint = "apple/OpenELM-3B"
|
36
42
|
logging.info("Loading the original model from: %s", checkpoint)
|
37
|
-
|
38
|
-
|
39
|
-
checkpoint, trust_remote_code=True
|
40
|
-
),
|
43
|
+
original_model = transformers.AutoModelForCausalLM.from_pretrained(
|
44
|
+
checkpoint, trust_remote_code=True
|
41
45
|
)
|
42
46
|
|
43
47
|
# Locate the cached dir.
|
@@ -53,10 +57,13 @@ def main(_):
|
|
53
57
|
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
|
54
58
|
|
55
59
|
verifier.verify_reauthored_model(
|
56
|
-
original_model=
|
57
|
-
|
58
|
-
|
60
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
61
|
+
original_model
|
62
|
+
),
|
63
|
+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
64
|
+
tokenizer=verifier.TokenizerWrapper(tokenizer),
|
59
65
|
generate_prompts=_PROMPTS.value,
|
66
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
60
67
|
)
|
61
68
|
|
62
69
|
|
@@ -65,15 +65,10 @@ class Phi2(nn.Module):
|
|
65
65
|
self.rope_cache = attn_utils.build_rope_cache(
|
66
66
|
size=config.kv_cache_max,
|
67
67
|
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
68
|
-
base=
|
69
|
-
condense_ratio=1,
|
70
|
-
dtype=torch.float32,
|
71
|
-
device=torch.device("cpu"),
|
68
|
+
base=attn_config.rotary_base,
|
72
69
|
)
|
73
70
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
74
71
|
size=config.kv_cache_max,
|
75
|
-
dtype=torch.float32,
|
76
|
-
device=torch.device("cpu"),
|
77
72
|
)
|
78
73
|
self.config = config
|
79
74
|
|
@@ -129,6 +124,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
129
124
|
num_heads=32,
|
130
125
|
head_dim=80,
|
131
126
|
num_query_groups=32,
|
127
|
+
rotary_base=10000,
|
132
128
|
rotary_percentage=0.4,
|
133
129
|
qkv_use_bias=True,
|
134
130
|
output_proj_use_bias=True,
|
@@ -97,15 +97,15 @@ ROPE_SHORT_FACTOR = [
|
|
97
97
|
]
|
98
98
|
|
99
99
|
|
100
|
-
def
|
100
|
+
def _build_rope_cache(
|
101
101
|
size: int,
|
102
102
|
dim: int,
|
103
|
-
base: int
|
104
|
-
condense_ratio: int
|
105
|
-
dtype: torch.dtype
|
106
|
-
device: torch.device
|
107
|
-
theta_factors: torch.Tensor
|
108
|
-
scale: float
|
103
|
+
base: int,
|
104
|
+
condense_ratio: int,
|
105
|
+
dtype: torch.dtype,
|
106
|
+
device: torch.device,
|
107
|
+
theta_factors: torch.Tensor,
|
108
|
+
scale: float,
|
109
109
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
110
110
|
"""Precomputes Rotary Positional Embeddings for Phi-3.5 model.
|
111
111
|
|
@@ -116,26 +116,20 @@ def build_rope_cache(
|
|
116
116
|
Args:
|
117
117
|
size (int): The size of the built cache.
|
118
118
|
dim (int): Each sequence's dimmension.
|
119
|
-
base (int, optional): Rope base value.
|
119
|
+
base (int, optional): Rope base value.
|
120
120
|
condense_ratio (int, optional): The ratio by which sequence indicies are
|
121
|
-
condensed.
|
122
|
-
dtype (torch.dtype, optional): Output tensor's data type.
|
123
|
-
|
124
|
-
device (torch.device, optional): Output tensor's data type. Defaults to
|
125
|
-
None in which case "cpu" is used.
|
121
|
+
condensed.
|
122
|
+
dtype (torch.dtype, optional): Output tensor's data type.
|
123
|
+
device (torch.device, optional): Output tensor's data type.
|
126
124
|
theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
|
127
|
-
scale the theta values.
|
128
|
-
scale (float, optional): A float used to scale the rope values.
|
129
|
-
to 1.0.
|
125
|
+
scale the theta values.
|
126
|
+
scale (float, optional): A float used to scale the rope values.
|
130
127
|
|
131
128
|
Returns:
|
132
129
|
Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
|
133
130
|
"""
|
134
|
-
if device is None:
|
135
|
-
device = torch.device('cpu')
|
136
131
|
theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
137
|
-
|
138
|
-
theta = theta / theta_factors
|
132
|
+
theta = theta / theta_factors
|
139
133
|
seq_idx = torch.arange(size) / condense_ratio
|
140
134
|
idx_theta = torch.outer(seq_idx, theta)
|
141
135
|
cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
|
@@ -167,10 +161,10 @@ class Phi3_5Mini(nn.Module):
|
|
167
161
|
config.final_norm_config,
|
168
162
|
)
|
169
163
|
attn_config = block_config.attn_config
|
170
|
-
self.rope_cache =
|
164
|
+
self.rope_cache = _build_rope_cache(
|
171
165
|
size=config.kv_cache_max,
|
172
166
|
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
173
|
-
base=
|
167
|
+
base=attn_config.rotary_base,
|
174
168
|
condense_ratio=1,
|
175
169
|
dtype=torch.float32,
|
176
170
|
device=torch.device("cpu"),
|
@@ -181,8 +175,6 @@ class Phi3_5Mini(nn.Module):
|
|
181
175
|
)
|
182
176
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
183
177
|
size=config.kv_cache_max,
|
184
|
-
dtype=torch.float32,
|
185
|
-
device=torch.device("cpu"),
|
186
178
|
)
|
187
179
|
self.config = config
|
188
180
|
|
@@ -238,6 +230,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
238
230
|
num_heads=32,
|
239
231
|
head_dim=96,
|
240
232
|
num_query_groups=32,
|
233
|
+
rotary_base=10000,
|
241
234
|
rotary_percentage=1.0,
|
242
235
|
qkv_transpose_before_split=True,
|
243
236
|
)
|
@@ -19,6 +19,7 @@ import logging
|
|
19
19
|
from absl import app
|
20
20
|
from absl import flags
|
21
21
|
from ai_edge_torch.generative.examples.phi import phi2
|
22
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
22
23
|
from ai_edge_torch.generative.utilities import verifier
|
23
24
|
import kagglehub
|
24
25
|
import transformers
|
@@ -39,12 +40,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
|
|
39
40
|
def main(_):
|
40
41
|
checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
|
41
42
|
logging.info("Loading the original model from: %s", checkpoint)
|
42
|
-
|
43
|
-
generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
|
44
|
-
wrapper_model = verifier.ModelWrapper(
|
45
|
-
model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
|
46
|
-
hf_generation_config=generation_config,
|
47
|
-
)
|
43
|
+
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
|
48
44
|
|
49
45
|
logging.info("Building the reauthored model from: %s", checkpoint)
|
50
46
|
reauthored_model = phi2.build_model(checkpoint)
|
@@ -53,10 +49,13 @@ def main(_):
|
|
53
49
|
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
|
54
50
|
|
55
51
|
verifier.verify_reauthored_model(
|
56
|
-
original_model=
|
57
|
-
|
58
|
-
|
52
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
53
|
+
original_model
|
54
|
+
),
|
55
|
+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
56
|
+
tokenizer=verifier.TokenizerWrapper(tokenizer),
|
59
57
|
generate_prompts=_PROMPTS.value,
|
58
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
60
59
|
atol=1e-03,
|
61
60
|
)
|
62
61
|
|
@@ -21,6 +21,7 @@ import pathlib
|
|
21
21
|
from absl import app
|
22
22
|
from absl import flags
|
23
23
|
from ai_edge_torch.generative.examples.phi import phi3
|
24
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
24
25
|
from ai_edge_torch.generative.utilities import verifier
|
25
26
|
import transformers
|
26
27
|
|
@@ -40,12 +41,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
|
|
40
41
|
def main(_):
|
41
42
|
checkpoint = "microsoft/Phi-3.5-mini-instruct"
|
42
43
|
logging.info("Loading the original model from: %s", checkpoint)
|
43
|
-
|
44
|
-
generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
|
45
|
-
wrapper_model = verifier.ModelWrapper(
|
46
|
-
model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
|
47
|
-
hf_generation_config=generation_config,
|
48
|
-
)
|
44
|
+
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
|
49
45
|
|
50
46
|
# Locate the cached dir.
|
51
47
|
cached_config_file = transformers.utils.cached_file(
|
@@ -59,10 +55,13 @@ def main(_):
|
|
59
55
|
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
|
60
56
|
|
61
57
|
verifier.verify_reauthored_model(
|
62
|
-
original_model=
|
63
|
-
|
64
|
-
|
58
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
59
|
+
original_model
|
60
|
+
),
|
61
|
+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
62
|
+
tokenizer=verifier.TokenizerWrapper(tokenizer),
|
65
63
|
generate_prompts=_PROMPTS.value,
|
64
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
66
65
|
)
|
67
66
|
|
68
67
|
|
@@ -21,6 +21,7 @@ import pathlib
|
|
21
21
|
from absl import app
|
22
22
|
from absl import flags
|
23
23
|
from ai_edge_torch.generative.examples.smollm import smollm
|
24
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
24
25
|
from ai_edge_torch.generative.utilities import verifier
|
25
26
|
import transformers
|
26
27
|
|
@@ -30,14 +31,18 @@ _PROMPTS = flags.DEFINE_multi_string(
|
|
30
31
|
"What is the meaning of life?",
|
31
32
|
"The input prompts to generate answers.",
|
32
33
|
)
|
34
|
+
_MAX_NEW_TOKENS = flags.DEFINE_integer(
|
35
|
+
"max_new_tokens",
|
36
|
+
30,
|
37
|
+
"The maximum size of the generated tokens.",
|
38
|
+
)
|
33
39
|
|
34
40
|
|
35
41
|
def main(_):
|
36
42
|
checkpoint = "HuggingFaceTB/SmolLM-135M"
|
37
43
|
logging.info("Loading the original model from: %s", checkpoint)
|
38
|
-
|
39
|
-
|
40
|
-
)
|
44
|
+
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
|
45
|
+
|
41
46
|
# Locate the cached dir.
|
42
47
|
cached_config_file = transformers.utils.cached_file(
|
43
48
|
checkpoint, transformers.utils.CONFIG_NAME
|
@@ -50,10 +55,13 @@ def main(_):
|
|
50
55
|
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
|
51
56
|
|
52
57
|
verifier.verify_reauthored_model(
|
53
|
-
original_model=
|
54
|
-
|
55
|
-
|
58
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
59
|
+
original_model
|
60
|
+
),
|
61
|
+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
62
|
+
tokenizer=verifier.TokenizerWrapper(tokenizer),
|
56
63
|
generate_prompts=_PROMPTS.value,
|
64
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
57
65
|
atol=1e-04,
|
58
66
|
)
|
59
67
|
|
@@ -98,6 +98,7 @@ def get_model_config() -> cfg.ModelConfig:
|
|
98
98
|
num_heads=num_heads,
|
99
99
|
head_dim=embedding_dim // num_heads,
|
100
100
|
num_query_groups=num_query_groups,
|
101
|
+
rotary_base=0,
|
101
102
|
rotary_percentage=0.0,
|
102
103
|
qkv_use_bias=True,
|
103
104
|
qkv_transpose_before_split=True,
|
@@ -148,6 +149,7 @@ def get_fake_model_config() -> cfg.ModelConfig:
|
|
148
149
|
num_heads=num_heads,
|
149
150
|
head_dim=embedding_dim // num_heads,
|
150
151
|
num_query_groups=num_query_groups,
|
152
|
+
rotary_base=0,
|
151
153
|
rotary_percentage=0.0,
|
152
154
|
qkv_use_bias=True,
|
153
155
|
qkv_transpose_before_split=True,
|
@@ -295,6 +295,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
|
|
295
295
|
enable_kv_cache=False,
|
296
296
|
qkv_transpose_before_split=True,
|
297
297
|
qkv_fused_interleaved=False,
|
298
|
+
rotary_base=0,
|
298
299
|
rotary_percentage=0.0,
|
299
300
|
),
|
300
301
|
enable_hlfb=False,
|
@@ -351,6 +352,7 @@ def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
|
|
351
352
|
enable_kv_cache=False,
|
352
353
|
qkv_transpose_before_split=True,
|
353
354
|
qkv_fused_interleaved=False,
|
355
|
+
rotary_base=0,
|
354
356
|
rotary_percentage=0.0,
|
355
357
|
),
|
356
358
|
enable_hlfb=False,
|
@@ -199,6 +199,7 @@ def build_attention_config(
|
|
199
199
|
num_heads,
|
200
200
|
dim,
|
201
201
|
num_query_groups,
|
202
|
+
rotary_base=0,
|
202
203
|
rotary_percentage=0.0,
|
203
204
|
qkv_transpose_before_split=True,
|
204
205
|
qkv_use_bias=False,
|
@@ -211,6 +212,7 @@ def build_attention_config(
|
|
211
212
|
num_heads=num_heads,
|
212
213
|
head_dim=dim // num_heads,
|
213
214
|
num_query_groups=num_query_groups,
|
215
|
+
rotary_base=rotary_base,
|
214
216
|
rotary_percentage=rotary_percentage,
|
215
217
|
qkv_transpose_before_split=qkv_transpose_before_split,
|
216
218
|
qkv_use_bias=qkv_use_bias,
|
@@ -44,13 +44,10 @@ class ToySingleLayerModel(torch.nn.Module):
|
|
44
44
|
self.rope_cache = attn_utils.build_rope_cache(
|
45
45
|
size=config.max_seq_len,
|
46
46
|
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
47
|
-
base=
|
48
|
-
condense_ratio=1,
|
49
|
-
dtype=torch.float32,
|
50
|
-
device=torch.device('cpu'),
|
47
|
+
base=attn_config.rotary_base,
|
51
48
|
)
|
52
49
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
53
|
-
size=config.max_seq_len,
|
50
|
+
size=config.max_seq_len,
|
54
51
|
)
|
55
52
|
self.config = config
|
56
53
|
|
@@ -93,13 +90,10 @@ class ToySingleLayerModelWeightSharing(torch.nn.Module):
|
|
93
90
|
self.rope_cache = attn_utils.build_rope_cache(
|
94
91
|
size=config.max_seq_len,
|
95
92
|
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
96
|
-
base=
|
97
|
-
condense_ratio=1,
|
98
|
-
dtype=torch.float32,
|
99
|
-
device=torch.device('cpu'),
|
93
|
+
base=attn_config.rotary_base,
|
100
94
|
)
|
101
95
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
102
|
-
size=config.max_seq_len,
|
96
|
+
size=config.max_seq_len,
|
103
97
|
)
|
104
98
|
self.config = config
|
105
99
|
|
@@ -124,6 +118,7 @@ def get_model_config() -> cfg.ModelConfig:
|
|
124
118
|
num_heads=32,
|
125
119
|
head_dim=4,
|
126
120
|
num_query_groups=4,
|
121
|
+
rotary_base=10000,
|
127
122
|
rotary_percentage=1.0,
|
128
123
|
enable_kv_cache=False,
|
129
124
|
)
|
@@ -51,13 +51,10 @@ class ToyModelWithKVCache(torch.nn.Module):
|
|
51
51
|
self.rope_cache = attn_utils.build_rope_cache(
|
52
52
|
size=config.max_seq_len,
|
53
53
|
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
54
|
-
base=
|
55
|
-
condense_ratio=1,
|
56
|
-
dtype=torch.float32,
|
57
|
-
device=torch.device('cpu'),
|
54
|
+
base=attn_config.rotary_base,
|
58
55
|
)
|
59
56
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
60
|
-
size=config.max_seq_len,
|
57
|
+
size=config.max_seq_len,
|
61
58
|
)
|
62
59
|
self.config = config
|
63
60
|
|
@@ -91,6 +88,7 @@ def get_model_config() -> cfg.ModelConfig:
|
|
91
88
|
num_heads=32,
|
92
89
|
head_dim=4,
|
93
90
|
num_query_groups=4,
|
91
|
+
rotary_base=10000,
|
94
92
|
rotary_percentage=1.0,
|
95
93
|
)
|
96
94
|
ff_config = cfg.FeedForwardConfig(
|
@@ -67,15 +67,10 @@ class TinyLlama(nn.Module):
|
|
67
67
|
self.rope_cache = attn_utils.build_rope_cache(
|
68
68
|
size=config.kv_cache_max,
|
69
69
|
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
70
|
-
base=
|
71
|
-
condense_ratio=1,
|
72
|
-
dtype=torch.float32,
|
73
|
-
device=torch.device("cpu"),
|
70
|
+
base=attn_config.rotary_base,
|
74
71
|
)
|
75
72
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
76
73
|
size=config.kv_cache_max,
|
77
|
-
dtype=torch.float32,
|
78
|
-
device=torch.device("cpu"),
|
79
74
|
)
|
80
75
|
self.config = config
|
81
76
|
|
@@ -132,6 +127,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
132
127
|
num_heads=32,
|
133
128
|
head_dim=64,
|
134
129
|
num_query_groups=4,
|
130
|
+
rotary_base=10000,
|
135
131
|
rotary_percentage=1.0,
|
136
132
|
)
|
137
133
|
ff_config = cfg.FeedForwardConfig(
|
@@ -21,6 +21,7 @@ import pathlib
|
|
21
21
|
from absl import app
|
22
22
|
from absl import flags
|
23
23
|
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
24
|
+
from ai_edge_torch.generative.utilities import transformers_verifier
|
24
25
|
from ai_edge_torch.generative.utilities import verifier
|
25
26
|
import transformers
|
26
27
|
|
@@ -30,16 +31,20 @@ _PROMPTS = flags.DEFINE_multi_string(
|
|
30
31
|
"Show me the program to add 2 and 3.",
|
31
32
|
"The input prompts to generate answers.",
|
32
33
|
)
|
34
|
+
_MAX_NEW_TOKENS = flags.DEFINE_integer(
|
35
|
+
"max_new_tokens",
|
36
|
+
30,
|
37
|
+
"The maximum size of the generated tokens.",
|
38
|
+
)
|
33
39
|
|
34
40
|
|
35
41
|
def main(_):
|
36
42
|
checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
37
43
|
logging.info("Loading the original model from: %s", checkpoint)
|
38
|
-
|
39
|
-
|
40
|
-
checkpoint, trust_remote_code=True
|
41
|
-
),
|
44
|
+
original_model = transformers.AutoModelForCausalLM.from_pretrained(
|
45
|
+
checkpoint, trust_remote_code=True
|
42
46
|
)
|
47
|
+
|
43
48
|
# Locate the cached dir.
|
44
49
|
cached_config_file = transformers.utils.cached_file(
|
45
50
|
checkpoint, transformers.utils.CONFIG_NAME
|
@@ -52,10 +57,13 @@ def main(_):
|
|
52
57
|
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
|
53
58
|
|
54
59
|
verifier.verify_reauthored_model(
|
55
|
-
original_model=
|
56
|
-
|
57
|
-
|
60
|
+
original_model=transformers_verifier.TransformersModelWrapper(
|
61
|
+
original_model
|
62
|
+
),
|
63
|
+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
64
|
+
tokenizer=verifier.TokenizerWrapper(tokenizer),
|
58
65
|
generate_prompts=_PROMPTS.value,
|
66
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
59
67
|
atol=1e-04,
|
60
68
|
)
|
61
69
|
|
@@ -83,6 +83,8 @@ class AttentionConfig:
|
|
83
83
|
# Used to determine number of groups in grouped query attention (GQA)
|
84
84
|
# https://arxiv.org/pdf/2305.13245.pdf
|
85
85
|
num_query_groups: Optional[int]
|
86
|
+
# Base of rotary positional embedding.
|
87
|
+
rotary_base: int = 10_000
|
86
88
|
# Percentage of Rotary Positional Embedding added Q and K projections.
|
87
89
|
rotary_percentage: Optional[float] = None
|
88
90
|
# Whether to transpose the query groups of qkv bundled tensor before
|
@@ -19,6 +19,7 @@ import ai_edge_torch
|
|
19
19
|
from ai_edge_torch import config as ai_edge_config
|
20
20
|
from ai_edge_torch.generative.examples.gemma import gemma1
|
21
21
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
22
|
+
from ai_edge_torch.generative.examples.llama import llama
|
22
23
|
from ai_edge_torch.generative.examples.openelm import openelm
|
23
24
|
from ai_edge_torch.generative.examples.phi import phi2
|
24
25
|
from ai_edge_torch.generative.examples.phi import phi3
|
@@ -102,6 +103,15 @@ class TestModelConversion(googletest.TestCase):
|
|
102
103
|
pytorch_model = gemma2.Gemma2(config).eval()
|
103
104
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
104
105
|
|
106
|
+
@googletest.skipIf(
|
107
|
+
ai_edge_config.Config.use_torch_xla,
|
108
|
+
reason="tests with custom ops are not supported on oss",
|
109
|
+
)
|
110
|
+
def test_llama(self):
|
111
|
+
config = llama.get_fake_model_config()
|
112
|
+
pytorch_model = llama.Llama(config).eval()
|
113
|
+
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
114
|
+
|
105
115
|
@googletest.skipIf(
|
106
116
|
ai_edge_config.Config.use_torch_xla,
|
107
117
|
reason="tests with custom ops are not supported on oss",
|
@@ -0,0 +1,42 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Utilities for the models predefined in HuggingFace transformers."""
|
17
|
+
|
18
|
+
from typing import cast
|
19
|
+
|
20
|
+
from ai_edge_torch.generative.utilities import verifier
|
21
|
+
import torch
|
22
|
+
import transformers
|
23
|
+
|
24
|
+
|
25
|
+
class TransformersModelWrapper(verifier.ModelWrapper):
|
26
|
+
"""A wrapper for the model predefined in HuggingFace transformers.
|
27
|
+
|
28
|
+
Verifier expects forward() to return logits while Transformers models return
|
29
|
+
an object with `logits` field.
|
30
|
+
|
31
|
+
Transformers models get `max_new_tokens` settings for generate() via
|
32
|
+
GenerationConfig.
|
33
|
+
"""
|
34
|
+
|
35
|
+
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
36
|
+
return self.model.forward(tokens).logits
|
37
|
+
|
38
|
+
def generate(
|
39
|
+
self, inputs: torch.Tensor, max_new_tokens: int
|
40
|
+
) -> torch.IntTensor:
|
41
|
+
gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
|
42
|
+
return self.model.generate(inputs=inputs, generation_config=gen_config)
|