ai-edge-torch-nightly 0.3.0.dev20250108__py3-none-any.whl → 0.3.0.dev20250109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
- ai_edge_torch/generative/examples/llama/llama.py +29 -25
- ai_edge_torch/generative/examples/phi/phi3.py +26 -23
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
- ai_edge_torch/generative/examples/smollm/verify.py +18 -2
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/layers/attention.py +4 -29
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
- ai_edge_torch/generative/utilities/model_builder.py +16 -12
- ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
- ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
- ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
- ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +22 -21
- {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -15,13 +15,14 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -103,17 +104,12 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # Gemma2 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -140,29 +136,51 @@ class Gemma2(nn.Module):
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+
+    # token embeddings of shape (b, t, n_embd)
+    input_embeds = self.tok_embedding(tokens)
+    # RoPE parameters are the same for all blocks. Use the first layer.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+    mask = [
+        self.get_attention_mask(
+            self.config.block_config(i).attn_config.attn_type, input_pos
+        )
+        for i in range(self.config.num_layers)
+    ]
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+  def _forward_with_embeds(
+      self,
+      input_embeds: torch.Tensor,
+      rope: Tuple[torch.Tensor, torch.Tensor],
+      mask: List[torch.Tensor],
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
     )
 
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
+    if self.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.config.embedding_scale
+    x = input_embeds
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-          block.config.attn_config.attn_type, input_pos
-      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
@@ -228,11 +246,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
 
   num_layers = 26
+  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=2304,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -249,6 +269,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  config.num_layers = 2
  config.max_seq_len = 2 * kv_cache_max_len
  config.embedding_dim = 128
+  config.embedding_scale = config.embedding_dim**0.5
  config.block_configs = config.block_configs[: config.num_layers]
  for block_config in config.block_configs:
    block_config.attn_config.num_heads = 4
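Taken together, the gemma2.py hunks replace the precomputed `self.rope_cache` (indexed by position at run time) with cos/sin tensors built on the fly from `input_pos`. A minimal sketch of the new call, assuming head_dim=256, rotary_percentage=1.0, and base=10000 purely for illustration:

    import torch
    import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb

    head_dim = 256                    # assumed per-head width
    n_elem = int(1.0 * head_dim)      # rotary_percentage * head_dim
    input_pos = torch.arange(0, 16)   # positions of 16 prefill tokens

    # cos/sin now cover exactly the requested positions instead of a
    # kv_cache_max-sized cache; shape is (1, 16, head_dim // 2).
    cos, sin = rotary_pos_emb.build_rope(input_pos, n_elem, head_dim, 10_000)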
ai_edge_torch/generative/examples/llama/llama.py CHANGED
@@ -15,6 +15,7 @@
 
 """Example of building Llama 3.2 models."""
 
+from functools import partial
 import math
 from typing import Tuple
 
@@ -26,8 +27,8 @@ TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def _build_llama3_rope_cache(
-    size: int,
-    dim: int,
+    input_pos: torch.Tensor,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
@@ -36,8 +37,9 @@ def _build_llama3_rope_cache(
     low_freq_factor: float,
     high_freq_factor: float,
     max_seq_len: int,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """
+  """Computes Rotary Positional Embeddings for Llama 3.2 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
@@ -47,13 +49,12 @@ def _build_llama3_rope_cache(
   https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307
 
   Args:
-    size (int): The size of the built cache.
-    dim (int): Each sequence's dimmension.
-    base (int, optional): Rope base value.
-    condense_ratio (int, optional): The ratio by which sequence indicies are
-      condensed.
-    dtype (torch.dtype, optional): Output tensor's data type.
-    device (torch.device, optional): Output tensor's data type.
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
+    base (int): Rope base value.
+    condense_ratio (int): The ratio by which sequence indicies are condensed.
+    dtype (torch.dtype): Output tensor's data type.
+    device (torch.device): Output tensor's data type.
     factor (float): Factor to scale theta down for tokens in long range in the
       sequence.
     low_freq_factor (float): Factor to determine if tokens are in long range
@@ -66,7 +67,7 @@ def _build_llama3_rope_cache(
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   low_freq_wavelen = max_seq_len / low_freq_factor
   high_freq_wavelen = max_seq_len / high_freq_factor
   wavelen = 2 * math.pi / theta
@@ -81,7 +82,7 @@ def _build_llama3_rope_cache(
   is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
   theta = torch.where(is_medium, smoothed_theta, theta)
 
-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
@@ -97,18 +98,6 @@ class Llama(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_llama3_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        factor=32.0,
-        low_freq_factor=1.0,
-        high_freq_factor=4.0,
-        max_seq_len=self.config.max_seq_len,
-    )
 
 
 def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -140,15 +129,30 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+
+  max_seq_len = 8192
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_llama3_rope_cache,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      factor=32.0,
+      low_freq_factor=1.0,
+      high_freq_factor=4.0,
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=128256,
       num_layers=16,
-      max_seq_len=8192,
+      max_seq_len=max_seq_len,
       embedding_dim=2048,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config
 
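The `**kwargs` catch-all and the `partial` binding exist because `DecoderOnlyModel.forward` always invokes `config.build_rope` with the same four keyword arguments. A hedged sketch of that contract (`demo_rope` is a hypothetical stand-in for `_build_llama3_rope_cache` with the frequency-scaling logic omitted; the base value is illustrative):

    from functools import partial

    import torch

    def demo_rope(input_pos, n_elem, base, condense_ratio, **kwargs):
      # **kwargs silently absorbs head_dim, which this builder does not need.
      theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
      idx_theta = torch.outer(input_pos / condense_ratio, theta)
      return torch.cos(idx_theta), torch.sin(idx_theta)

    # Freeze model-specific arguments up front, in the same way
    # get_1b_model_config does with functools.partial above.
    build_rope = partial(demo_rope, condense_ratio=1)

    # What DecoderOnlyModel.forward effectively calls:
    cos, sin = build_rope(
        input_pos=torch.arange(8), n_elem=64, base=500_000, head_dim=64
    )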
ai_edge_torch/generative/examples/phi/phi3.py CHANGED
@@ -15,6 +15,7 @@
 
 """Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""
 
+from functools import partial
 import math
 from typing import Tuple
 
@@ -93,40 +94,41 @@ ROPE_SHORT_FACTOR = [
 ]
 
 
-def _build_rope_cache(
-    size: int,
-    dim: int,
+def _build_phi3_rope(
+    input_pos: int,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
     device: torch.device,
     theta_factors: torch.Tensor,
     scale: float,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """
+  """Computes Rotary Positional Embeddings for Phi-3.5 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Phi-3.5 model. It precompute Rotary Positional Embedding Sin and
   Cos values with scaling factors for quick lookup during the inference.
 
   Args:
-    size (int): The size of the built cache.
-    dim (int): Each sequence's dimmension.
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
     base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
      condensed.
    dtype (torch.dtype, optional): Output tensor's data type.
    device (torch.device, optional): Output tensor's data type.
-    theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
-      scale the theta values.
+    theta_factors (torch.Tensor, optional): A tensor of shape (n_elem,) used
+      to scale the theta values.
    scale (float, optional): A float used to scale the rope values.
 
  Returns:
    Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
  """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   theta = theta / theta_factors
-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
@@ -139,18 +141,6 @@ class Phi3_5Mini(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
-        scale=math.sqrt(
-            1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
-        ),
-    )
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -183,16 +173,29 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  max_seq_len = 4096
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_phi3_rope,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+      scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=32064,
       num_layers=32,
-      max_seq_len=4096,
+      max_seq_len=max_seq_len,
       kv_cache_max_len=kv_cache_max_len,
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config
 
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py ADDED
@@ -0,0 +1,71 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting SmolLM2 model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = smollm.build_model_v2(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'smollm2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
+      quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
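With the flags defined above, a typical invocation would look something like the following (paths are illustrative; per main(), the output name is derived as smollm2_<q8|f32>_ekv<kv_cache_max_len>.tflite):

    python convert_v2_to_tflite.py \
        --checkpoint_path="$HOME/Downloads/llm_data/smollm2" \
        --tflite_path=/tmp/ \
        --kv_cache_max_len=1280 \
        --quantize=true
    # expected output: /tmp/smollm2_q8_ekv1280.tflite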
ai_edge_torch/generative/examples/smollm/smollm.py CHANGED
@@ -85,3 +85,41 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM,
   )
+
+
+class SmolLM2(model_builder.DecoderOnlyModel):
+  """A SmolLM2 model built from the Edge Generative API layers."""
+  pass
+
+
+def get_model_config_v2(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmolLM2 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM2 model.
+  """
+  config = get_model_config(kv_cache_max_len)
+  config.block_config(0).attn_config.rotary_base = 100000
+  return config
+
+
+def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config_v2(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # SmolLM2 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model_v2(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config_v2(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=SmolLM2,
+  )
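The only config delta between SmolLM and SmolLM2 in this file is the RoPE base frequency, so v2 is expressed as a one-field override of the v1 config. A hedged usage sketch (the kv_cache_max_len value is arbitrary):

    from ai_edge_torch.generative.examples.smollm import smollm

    config = smollm.get_model_config_v2(kv_cache_max_len=256)
    # v2 inherits everything from get_model_config except the rotary base.
    assert config.block_config(0).attn_config.rotary_base == 100000
    # model = smollm.build_model_v2(checkpoint_path)  # loads weights into SmolLM2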
ai_edge_torch/generative/examples/smollm/verify.py CHANGED
@@ -36,10 +36,26 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     30,
     "The maximum size of the generated tokens.",
 )
+_MODEL_VERSION = flags.DEFINE_enum(
+    "model_version",
+    "v1",
+    ["v1", "v2"],
+    "The version of SmolLm to verify.",
+)
+_CHECKPOINT = {
+    "v1": "HuggingFaceTB/SmolLM-135M",
+    "v2": "HuggingFaceTB/SmolLM2-135M",
+}
+
+_BUILDER = {
+    "v1": smollm.build_model,
+    "v2": smollm.build_model_v2,
+}
 
 
 def main(_):
-  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  checkpoint = _CHECKPOINT[_MODEL_VERSION.value]
+  builder = _BUILDER[_MODEL_VERSION.value]
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
@@ -49,7 +65,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = smollm.build_model(reauthored_checkpoint)
+  reauthored_model = builder(reauthored_checkpoint)
 
   logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
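Either generation can now be verified from the same entry point, for example (only --model_version is new; --max_new_tokens comes from the flag definition in the context above):

    python verify.py --model_version=v2 --max_new_tokens=30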
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py CHANGED
@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
 
-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
+        updated_kv_entries.append(kv_entry)
 
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
ai_edge_torch/generative/layers/attention.py CHANGED
@@ -27,33 +27,6 @@ import torch
 from torch import nn
 
 
-def _embed_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    n_elem: int,
-    rope: Tuple[torch.Tensor, torch.Tensor],
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Embed rotary positional embedding for query and key.
-
-  Args:
-    q (torch.Tensor): query tensor.
-    k (torch.Tensor): key tensor.
-    n_elem (int): number of elements to embed rotarty positional embedding.
-    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-  """
-  if n_elem > 0:
-    cos, sin = rope
-    q_roped = rotary_pos_emb.apply_rope(
-        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    k_roped = rotary_pos_emb.apply_rope(
-        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
-    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
-  return q, k
-
-
 class TransformerBlock(nn.Module):
 
   def __init__(
@@ -252,7 +225,8 @@ class CausalSelfAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -404,7 +378,8 @@ class CrossAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -17,8 +17,8 @@
 
 import dataclasses
 import enum
-from typing import Optional, Sequence, Union
-
+from typing import Callable, Optional, Sequence, Union
+from ai_edge_torch.generative.layers import rotary_position_embedding
 
 @enum.unique
 class ActivationType(enum.Enum):
@@ -218,6 +218,10 @@ class ModelConfig:
   # Softcap on the model output logits.
   final_logit_softcap: Optional[float] = None
 
+  # The function to call to create the RoPE sin and cos vectors during the
+  # forward pass. Defaults to a standard implementation.
+  build_rope: Callable = rotary_position_embedding.build_rope
+
   @property
   def kv_cache_max(self) -> int:
     if self.kv_cache_max_len > 0:
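Because build_rope is now a plain dataclass field, any callable with a compatible keyword signature can be swapped in without subclassing the model. A minimal sketch, assuming the caller's contract visible in model_builder.py below (input_pos, n_elem, base, and head_dim keywords):

    import torch

    from ai_edge_torch.generative.layers import rotary_position_embedding

    def condensed_rope(input_pos, n_elem, head_dim, base=10_000, **kwargs):
      # Illustrative variant: halve the effective positions, then delegate to
      # the default implementation shipped with the library.
      return rotary_position_embedding.build_rope(
          input_pos / 2, n_elem, head_dim, base
      )

    # config = cfg.ModelConfig(..., build_rope=condensed_rope)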
ai_edge_torch/generative/layers/rotary_position_embedding.py CHANGED
@@ -32,57 +32,63 @@ def apply_rope(
   """
   x = x.transpose(1, 2)
   head_size = x.size(-1)
-  x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
-  x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
-  rotated = torch.cat((-x2, x1), dim=-1)
-  roped = (x * cos) + (rotated * sin)
+  x1, x2 = torch.split(x, head_size // 2, dim=-1)
+  left = x1 * cos - x2 * sin
+  right = x2 * cos + x1 * sin
+  roped = torch.cat([left, right], dim=-1)
   return roped.transpose(1, 2).type_as(x)
 
 
-def apply_rope_inline(
-    q: torch.Tensor,
-    k: torch.Tensor,
+def build_rope(
     input_pos: torch.Tensor,
     n_elem: int,
+    head_dim: int,
     base: int = 10_000,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding inline for a query and key.
+  """Computes rotary positional embedding cosine and sine tensors.
 
   Args:
-    q: the query tensor.
-    k: the key tensor.
     input_pos: the sequence indices for the query and key
     n_elem: number of elements of the head dimension for RoPE computation
+    base: the base of the exponentiated value for RoPE.
 
   Returns:
-    output the RoPE'd query and key.
+    cos, sin tensors
   """
 
   if n_elem <= 0:
-    return q, k
+    return None, None
 
-  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   freq_exponents = (2.0 / n_elem) * torch.arange(
-      q.shape[-1] // 2, dtype=torch.float32
+      head_dim // 2, dtype=torch.float32
   )
   timescale = float(base) ** freq_exponents
   radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
       0
   ).unsqueeze(0)
-  cos = torch.cos(radians)
-  sin = torch.sin(radians)
+  cos = torch.cos(radians)
+  sin = torch.sin(radians)
+  return cos, sin
+
 
-  def apply(x, sin, cos):
-    x = x.transpose(1, 2)
-    b, h, t, d = x.shape
-    assert d % 2 == 0
-    x_split = torch.split(x, d // 2, dim=-1)
-    left = x_split[0] * cos - x_split[1] * sin
-    right = x_split[1] * cos + x_split[0] * sin
-    res = torch.cat([left, right], dim=-1)
-    res = res.transpose(1, 2)
-    return res
+def apply_rope_inline(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Computes rotary positional embedding inline for a query and key.
+
+  Args:
+    q: the query tensor.
+    k: the key tensor.
+    cos: the cosine tensor.
+    sin: the sine tensor.
+
+  Returns:
+    output the RoPE'd query and key.
+  """
 
-  q_roped = apply(q, sin, cos)
-  k_roped = apply(k, sin, cos)
+  q_roped = apply_rope(q, cos, sin)
+  k_roped = apply_rope(k, cos, sin)
   return q_roped, k_roped
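The net effect of this file's changes is an API split: build_rope derives the cos/sin tensors from token positions once per forward pass, and apply_rope_inline merely applies them to q and k. A hedged end-to-end sketch with illustrative sizes, using the (B, T, n_heads, head_dim) layout that apply_rope's transpose implies:

    import torch
    import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb

    B, T, n_heads, head_dim = 1, 8, 4, 64
    q = torch.randn(B, T, n_heads, head_dim)
    k = torch.randn(B, T, n_heads, head_dim)
    input_pos = torch.arange(T)

    # cos/sin: (1, T, head_dim // 2), computed only for the requested positions.
    cos, sin = rotary_pos_emb.build_rope(input_pos, head_dim, head_dim, base=10_000)
    q_roped, k_roped = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
    assert q_roped.shape == q.shape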
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -150,6 +150,16 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
+
+  def test_smollm2(self):
+    config = smollm.get_fake_model_config_v2()
+    pytorch_model = smollm.SmolLM2(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+
   def test_openelm(self):
     config = openelm.get_fake_model_config()
     pytorch_model = openelm.OpenELM(config).eval()
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -25,6 +25,7 @@ from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -87,13 +88,6 @@ class DecoderOnlyModel(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # ROPE parameters for all attn_configs are the same. Take the first one.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -116,8 +110,18 @@ class DecoderOnlyModel(nn.Module):
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
+
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = self.config.build_rope(
+        input_pos=input_pos,
+        n_elem=n_elem,
+        base=attn_config.rotary_base,
+        head_dim=attn_config.head_dim,
+        # input_pos=input_pos, n_elem=n_elem, base=attn_config.rotary_base
+    )
+
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
@@ -145,14 +149,14 @@ class DecoderOnlyModel(nn.Module):
     if self.config.embedding_scale is not None:
       x = x * self.config.embedding_scale
 
-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       lora_adapter = lora.adapters[i] if lora else None
       x, kv_entry = block(x, rope, mask, input_pos, kv_entry, lora_adapter)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
ai_edge_torch/hlfb/mark_pattern/__init__.py CHANGED
@@ -17,7 +17,7 @@ from typing import Any
 import uuid
 
 from ai_edge_torch import lowertools
-from ai_edge_torch.hlfb.mark_pattern import passes
+from ai_edge_torch.hlfb.mark_pattern import fx_utils
 from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
 import torch
 
@@ -87,7 +87,7 @@ def mark_pattern(
       m.meta["ORIGINAL_NODE"] = n
 
   # Sanitize graph_module to match in the same way as pattern's graph_module.
-  graph_module_to_match = passes.remove_clone_ops(graph_module_to_match)
+  graph_module_to_match = fx_utils.remove_clone_ops(graph_module_to_match)
 
   match_with_attrs = pattern.match(graph_module_to_match)
 
@@ -111,13 +111,25 @@ def mark_pattern(
             is_input=True,
         )
 
-        # Only replace input by the marker node for those nodes used in the pattern.
+        # Only replace input by the marker node for those nodes used in the
+        # pattern.
         in_pattern_nodes = set(match.nodes_map.values())
         for user in input_node.users.keys():
-          if user in in_pattern_nodes:
-            user.meta["ORIGINAL_NODE"].replace_input_with(
-                input_node.meta["ORIGINAL_NODE"], new_input_node
-            )
+          if user not in in_pattern_nodes:
+            continue
+
+          user.meta["ORIGINAL_NODE"].replace_input_with(
+              input_node.meta["ORIGINAL_NODE"], new_input_node
+          )
+          # Pattern matching graph sanitization may remove clone ops, which means
+          # the user's input in the original graph may be a clone op. When
+          # replacing the input with the marker node, we need to further try
+          # replacing the input of the clone op that connects to the user.
+          for original_user_input in user.meta["ORIGINAL_NODE"].all_input_nodes:
+            if fx_utils.is_clone_op(original_user_input):
+              original_user_input.replace_input_with(
+                  input_node.meta["ORIGINAL_NODE"], new_input_node
+              )
 
   for i, pattern_output_node in enumerate(pattern.output_nodes):
     output_node = match.nodes_map[pattern_output_node]
ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} CHANGED
@@ -12,11 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""
+"""FX graph utilities for pattern matching clean ups."""
 
 import torch
 
 
+def is_clone_op(node: torch.fx.Node) -> bool:
+  """Checks if the node is a clone op."""
+  return (
+      node.op == "call_function" and node.target == torch.ops.aten.clone.default
+  )
+
+
 def remove_clone_ops(gm: torch.fx.GraphModule):
   """Removes clone ops from the graph.
 
@@ -32,7 +39,7 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
     The graph module with clone ops removed.
   """
   for node in gm.graph.nodes:
-    if node.op == "call_function" and node.target == torch.ops.aten.clone.default:
+    if is_clone_op(node):
       node.replace_all_uses_with(node.args[0])
       gm.graph.erase_node(node)
 
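A hedged sketch of the new helper pair in action, mirroring the clone-op test added below: export a module whose graph contains aten.clone.default, then strip the clones before matching.

    import torch

    from ai_edge_torch.hlfb.mark_pattern import fx_utils

    class CloneModel(torch.nn.Module):

      def forward(self, x):
        return torch.ops.aten.clone.default(x * x) + x

    ep = torch.export.export(CloneModel().eval(), (torch.rand(2, 2),))
    gm = ep.graph_module
    # The exported graph keeps the explicit clone op...
    assert any(fx_utils.is_clone_op(n) for n in gm.graph.nodes)
    fx_utils.remove_clone_ops(gm)
    # ...and remove_clone_ops splices it out, rewiring its users.
    assert not any(fx_utils.is_clone_op(n) for n in gm.graph.nodes)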
ai_edge_torch/hlfb/mark_pattern/pattern.py CHANGED
@@ -18,13 +18,14 @@ import dataclasses
 from typing import Any, Callable, Optional, Union
 
 from ai_edge_torch import fx_pass_base
-from ai_edge_torch.hlfb.mark_pattern import passes
+from ai_edge_torch.hlfb.mark_pattern import fx_utils
 import torch
-from torch.export.graph_signature import TensorArgument
-from torch.fx import Graph
-from torch.fx import GraphModule
-from torch.fx.passes.utils.matcher_utils import InternalMatch
-from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
+
+Graph = torch.fx.Graph
+GraphModule = torch.fx.GraphModule
+TensorArgument = torch.export.graph_signature.TensorArgument
+InternalMatch = torch.fx.passes.utils.matcher_utils.InternalMatch
+SubgraphMatcher = torch.fx.passes.utils.matcher_utils.SubgraphMatcher
 
 
 def _are_equal(x: Any, y: Any) -> bool:
@@ -219,8 +220,8 @@ class Pattern:
     # Sanitize graph_module for more precise pattern matching.
     # The graph_module to match against this pattern should apply equivalent
     # sanitization.
-    self.graph_module = passes.remove_clone_ops(self.graph_module)
-    self.graph_module = passes.remove_dangling_args(self.graph_module)
+    self.graph_module = fx_utils.remove_clone_ops(self.graph_module)
+    self.graph_module = fx_utils.remove_dangling_args(self.graph_module)
 
     # Builds list of ordered input and output nodes.
     self.graph_nodes_map = {}
ai_edge_torch/hlfb/test/test_mark_pattern.py CHANGED
@@ -58,6 +58,32 @@ class TestMarkPattern(googletest.TestCase):
         {"stablehlo.custom_call @mark_tensor": 6},
     )
 
+  def test_mark_pattern_with_clone_inputs(self):
+
+    class TestModel(torch.nn.Module):
+
+      def forward(self, x):
+        return torch.ops.aten.clone.default(x * x) + x
+
+    pattern = pattern_module.Pattern(
+        "test.add",
+        lambda a, b: a + b,
+        export_args=(torch.rand(2, 2), torch.rand(2, 2)),
+    )
+
+    model = TestModel().eval()
+    args = (torch.rand(20, 20),)
+    exported_program = torch.export.export(model, args)
+    mark_pattern.mark_pattern(exported_program.graph_module, pattern)
+    mlir = _export_stablehlo_mlir(exported_program)
+
+    lowertools.assert_string_count(
+        self,
+        mlir,
+        {'stablehlo.composite "test.add"': 1},
+        {"stablehlo.custom_call @mark_tensor": 3},
+    )
+
   def test_mark_pattern_with_attr_builder(self):
     class TestModel(torch.nn.Module):
 
ai_edge_torch/version.py CHANGED
{ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250108
+Version: 0.3.0.dev20250109
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=kM89dmK5VqznvQQJTvtq94oCbRtajNvkLPCCWSJxFSY,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -47,13 +47,13 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=pXilP6DHqVdcFH1TpIAtcwAQZH2_jZ6Tz41ddlXZXMs,10177
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=tMSsqg7LU3LR-PHtKvlWtLCqlk71mfcO9hANU4vnvDM,2734
-ai_edge_torch/generative/examples/llama/llama.py,sha256=
+ai_edge_torch/generative/examples/llama/llama.py,sha256=kWy6-V4bFtE1yguCROLJS5XB0GOJD1-acJWp2dFjB5Q,6606
 ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
 ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
@@ -76,7 +76,7 @@ ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_
 ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=P2K6G7bNespSJLk72qxuCLaCcR_xAPs0Mn1dBZoByhE,2518
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=g-MvEibJT_iIhkec2VGtFFA_iP54VCq9mY4KxwAYF08,2512
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=c6PYCky7yJn6MVIYOCTx8S_CH27kOPmJbRZcI95nbZs,3477
-ai_edge_torch/generative/examples/phi/phi3.py,sha256=
+ai_edge_torch/generative/examples/phi/phi3.py,sha256=SHvJjmi5eIch5cYIWORt6YFmSQx_oCiOk1UbKKGibtk,7119
 ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
 ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
 ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -85,8 +85,9 @@ ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-L
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
-ai_edge_torch/generative/examples/smollm/smollm.py,sha256=
-ai_edge_torch/generative/examples/smollm/verify.py,sha256=
+ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=CjY1i0iCYxFSjhCpQZwxkmVxILgeo0zu1m0oBrHqyDU,2311
+ai_edge_torch/generative/examples/smollm/smollm.py,sha256=3uUltb6D3Q1aHpndcYTJrsWM_RBwLAraKDniH8ZZous,3779
+ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68n0owZMvty5Rr3W7ZNWWSw,2702
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
@@ -109,7 +110,7 @@ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqX
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
-ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=
+ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=C9dzJFK3TybxKpM1vSdLjOKftkJ72DGjr8YR4H7vCe8,4664
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=VU0c5pgvrUtaTboT1xuDBGjpKOM85aqtaB_hYfSBuEk,2544
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
@@ -117,15 +118,15 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8kQ4U3YANfSiTJKn8,13776
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
-ai_edge_torch/generative/layers/model_config.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=9yPEmWNw3-_2wXBmPmZ7RUKcPXHF2ZbJwksyQoXTA6M,7784
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
-ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=
+ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=1L1MEGPYbDELi0zy2OKl7yXyk9FXdBjcXwRZbfiJriU,2619
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
@@ -144,25 +145,25 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bBcey-aD4L_TwKRrrM81bN2VQoJjPPC84Rv4o3WOc34,12491
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=MY8BK29yD-W4v45Xdl_ErbNilipsTlD-4-y9MyBxR5g,7620
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
-ai_edge_torch/generative/utilities/model_builder.py,sha256=
+ai_edge_torch/generative/utilities/model_builder.py,sha256=yAO4VcYex21fDpuApewr0cNqgmxJljxonMd6450kblg,6710
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
 ai_edge_torch/generative/utilities/verifier.py,sha256=6lnBU9Cy5GanB8JWK3-2_VU3PxqunDWGe-SgSLba5Yw,12065
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
-ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=
-ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=
-ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=
+ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=-BYE7MGMxr-VfBy8tAiiOaCqYv8ytJ0w5l2P8B7h3eM,5387
+ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=taWLpF5IVglxlsF9HM2dIoKDXuQREaCRAXtJeG5gKzs,2073
+ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=7bv9XqRkm1pjxiVL4Cm1cArExnolId8hQKFHtvlkCI8,10061
 ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=
+ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=-5UqJyk__1YbUNGuxi4b2sn0CED0W-G337AXwxPGdEs,5567
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
 ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
 ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
@@ -205,8 +206,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/METADATA,sha256=
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20250109.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250109.dist-info/METADATA,sha256=bkCouLqAI9GXCpiduHyj21ZElW42bdt0w6K5gWw1fOE,1966
+ai_edge_torch_nightly-0.3.0.dev20250109.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250109.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250109.dist-info/RECORD,,
{ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE: File without changes
{ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL: File without changes
{ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt: File without changes