ai-edge-torch-nightly 0.3.0.dev20250108__py3-none-any.whl → 0.3.0.dev20250110__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/gemma2.py +54 -24
- ai_edge_torch/generative/examples/llama/llama.py +29 -25
- ai_edge_torch/generative/examples/paligemma/decoder.py +4 -2
- ai_edge_torch/generative/examples/paligemma/decoder2.py +16 -11
- ai_edge_torch/generative/examples/paligemma/paligemma.py +3 -0
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/phi/phi3.py +26 -23
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
- ai_edge_torch/generative/examples/smollm/verify.py +18 -2
- ai_edge_torch/generative/examples/test_models/toy_model.py +16 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -5
- ai_edge_torch/generative/layers/attention.py +4 -29
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
- ai_edge_torch/generative/utilities/model_builder.py +20 -14
- ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
- ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
- ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
- ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250110.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250110.dist-info}/RECORD +27 -26
- {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250110.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250110.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250110.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -15,13 +15,14 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -103,17 +104,12 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # Gemma2 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -133,6 +129,7 @@ class Gemma2(nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      mask: Optional[torch.Tensor] = None,
       export_config: Optional[model_builder.ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     _, seq_len = tokens.size()
@@ -140,29 +137,59 @@ class Gemma2(nn.Module):
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+
+    # token embeddings of shape (b, t, n_embd)
+    input_embeds = self.tok_embedding(tokens)
+    # RoPE parameters are the same for all blocks. Use the first layer.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+    mask = [
+        self.get_attention_mask(
+            self.config.block_config(i).attn_config.attn_type, input_pos
+        )
+        for i in range(self.config.num_layers)
+    ]
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+  def _forward_with_embeds(
+      self,
+      input_embeds: torch.Tensor,
+      rope: Tuple[torch.Tensor, torch.Tensor],
+      mask: List[torch.Tensor],
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
     )
 
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
+    if self.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.config.embedding_scale
+    x = input_embeds
+    updated_kv_entries = []
+    mask_input = mask is not None
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-          block.config.attn_config.attn_type, input_pos
+      mask = (
+          mask
+          if mask_input
+          else self.get_attention_mask(
+              block.config.attn_config.attn_type, input_pos
+          )
       )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
@@ -228,11 +255,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
 
   num_layers = 26
+  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=2304,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
      final_norm_config=norm_config,
@@ -249,6 +278,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
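
The gemma2.py changes above replace a RoPE cache precomputed once in `__init__` with cos/sin values computed per forward call from the current `input_pos`. A minimal standalone sketch of that pattern, with illustrative names and shapes (this is not the library implementation):

```python
import torch

def rope_from_positions(input_pos: torch.Tensor, head_dim: int, base: int = 10_000):
  # One frequency per pair of head-dimension channels.
  freq_exponents = (2.0 / head_dim) * torch.arange(head_dim // 2, dtype=torch.float32)
  timescale = float(base) ** freq_exponents                  # (head_dim // 2,)
  radians = input_pos.float()[:, None] / timescale[None, :]  # (seq, head_dim // 2)
  return torch.cos(radians), torch.sin(radians)

# Only the positions actually being decoded are materialized, so no
# (max_seq_len, n_elem) buffer has to live in the exported model.
cos, sin = rope_from_positions(torch.arange(8), head_dim=64)
```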
ai_edge_torch/generative/examples/llama/llama.py CHANGED
@@ -15,6 +15,7 @@
 
 """Example of building Llama 3.2 models."""
 
+from functools import partial
 import math
 from typing import Tuple
 
@@ -26,8 +27,8 @@ TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def _build_llama3_rope_cache(
-    size: int,
-    dim: int,
+    input_pos: torch.Tensor,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
@@ -36,8 +37,9 @@ def _build_llama3_rope_cache(
     low_freq_factor: float,
     high_freq_factor: float,
     max_seq_len: int,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """
+  """Computes Rotary Positional Embeddings for Llama 3.2 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
@@ -47,13 +49,12 @@ def _build_llama3_rope_cache(
   https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307
 
   Args:
-
-
-    base (int
-    condense_ratio (int
-
-
-    device (torch.device, optional): Output tensor's data type.
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
+    base (int): Rope base value.
+    condense_ratio (int): The ratio by which sequence indicies are condensed.
+    dtype (torch.dtype): Output tensor's data type.
+    device (torch.device): Output tensor's data type.
     factor (float): Factor to scale theta down for tokens in long range in the
       sequence.
     low_freq_factor (float): Factor to determine if tokens are in long range
@@ -66,7 +67,7 @@ def _build_llama3_rope_cache(
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   low_freq_wavelen = max_seq_len / low_freq_factor
   high_freq_wavelen = max_seq_len / high_freq_factor
   wavelen = 2 * math.pi / theta
@@ -81,7 +82,7 @@ def _build_llama3_rope_cache(
   is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
   theta = torch.where(is_medium, smoothed_theta, theta)
 
-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
@@ -97,18 +98,6 @@ class Llama(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_llama3_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        factor=32.0,
-        low_freq_factor=1.0,
-        high_freq_factor=4.0,
-        max_seq_len=self.config.max_seq_len,
-    )
 
 
 def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -140,15 +129,30 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+
+  max_seq_len = 8192
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_llama3_rope_cache,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      factor=32.0,
+      low_freq_factor=1.0,
+      high_freq_factor=4.0,
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=128256,
       num_layers=16,
-      max_seq_len=8192,
+      max_seq_len=max_seq_len,
      embedding_dim=2048,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config
 
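
The llama.py rewrite keeps the Llama 3.2 frequency-scaling logic but moves it behind a callable: `functools.partial` pins the model-specific hyperparameters once, and the bound function is stored on the config and later invoked with only the per-call arguments. A toy sketch of the binding pattern (`toy_rope` is a stand-in, not `_build_llama3_rope_cache`, and its scaling is deliberately simplified):

```python
from functools import partial
import torch

def toy_rope(input_pos: torch.Tensor, n_elem: int, base: int, factor: float):
  # Illustrative frequency computation only; the real Llama 3.2 version
  # rescales theta per wavelength band as shown in the diff above.
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem)) / factor
  idx_theta = torch.outer(input_pos.float(), theta)
  return torch.cos(idx_theta), torch.sin(idx_theta)

build_rope = partial(toy_rope, base=500_000, factor=32.0)  # bound once, in the config
cos, sin = build_rope(torch.arange(4), n_elem=64)          # per-call arguments only
```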
ai_edge_torch/generative/examples/paligemma/decoder.py CHANGED
@@ -54,6 +54,7 @@ class Decoder(model_builder.DecoderOnlyModel):
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
       input_embeds: torch.Tensor = None,
+      mask: Optional[torch.Tensor] = None,
       export_config: Optional[model_builder.ExportConfig] = None,
       called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
@@ -73,8 +74,9 @@ class Decoder(model_builder.DecoderOnlyModel):
     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
     embeds_len = input_embeds.shape[1]
-    mask = torch.zeros(embeds_len, self.config.kv_cache_max)
-    mask[:, embeds_len:] = float("-inf")
+    if mask is None:
+      mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+      mask[:, embeds_len:] = float("-inf")
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache
ai_edge_torch/generative/examples/paligemma/decoder2.py CHANGED
@@ -57,6 +57,7 @@ class Decoder2(gemma2.Gemma2):
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
       input_embeds: torch.Tensor = None,
+      mask: Optional[torch.Tensor] = None,
       export_config: Optional[model_builder.ExportConfig] = None,
       called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
@@ -73,17 +74,21 @@ class Decoder2(gemma2.Gemma2):
         repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
     )
 
-    if called_by_generate:
-
-
-
-
-
-
-
-
-
-
+    if mask is None:
+      if called_by_generate:
+        # PaliGemma2 generate() use a diagonal causal mask even with image embeds.
+        mask = [
+            self.get_attention_mask(
+                self.config.block_config(i).attn_config.attn_type, input_pos
+            )
+            for i in range(self.config.num_layers)
+        ]
+      else:
+        # By default, don't mask image embeds with a diagonal causal mask.
+        embeds_len = input_embeds.shape[1]
+        mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+        mask[:, embeds_len:] = float("-inf")
+        mask = [mask] * self.config.num_layers
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
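
For reference, the default mask these decoders build when none is passed in lets every image-prefix position attend to every slot inside the prefix, while all KV slots past the prefix are blocked. A self-contained sketch with illustrative sizes:

```python
import torch

embeds_len, kv_cache_max = 4, 8       # illustrative sizes
mask = torch.zeros(embeds_len, kv_cache_max)
mask[:, embeds_len:] = float("-inf")  # block everything beyond the prefix
# Row i can attend to slots 0..embeds_len-1 and nothing else.
```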
ai_edge_torch/generative/examples/paligemma/paligemma.py CHANGED
@@ -70,6 +70,7 @@ class PaliGemma(nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      mask: Optional[torch.Tensor] = None,
       pixel_values: torch.Tensor = None,
       export_config: Optional[model_builder.ExportConfig] = None,
       called_by_generate: bool = True,
@@ -79,6 +80,7 @@ class PaliGemma(nn.Module):
         tokens=tokens,
         input_pos=input_pos,
         kv_cache=kv_cache,
+        mask=mask,
         input_embeds=None,
         export_config=export_config,
         called_by_generate=called_by_generate,
@@ -111,6 +113,7 @@ class PaliGemma(nn.Module):
         tokens=None,
         input_pos=input_pos,
         kv_cache=kv_cache,
+        mask=mask,
         input_embeds=input_embeds,
         export_config=export_config,
         called_by_generate=called_by_generate,
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py CHANGED
@@ -26,7 +26,7 @@ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
 _OUTPUT_PATH = flags.DEFINE_string(
ai_edge_torch/generative/examples/phi/phi3.py CHANGED
@@ -15,6 +15,7 @@
 
 """Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""
 
+from functools import partial
 import math
 from typing import Tuple
 
@@ -93,40 +94,41 @@ ROPE_SHORT_FACTOR = [
 ]
 
 
-def _build_rope_cache(
-    size: int,
-    dim: int,
+def _build_phi3_rope(
+    input_pos: int,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
     device: torch.device,
     theta_factors: torch.Tensor,
     scale: float,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """
+  """Computes Rotary Positional Embeddings for Phi-3.5 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Phi-3.5 model. It precompute Rotary Positional Embedding Sin and
   Cos values with scaling factors for quick lookup during the inference.
 
   Args:
-
-
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
     base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
       condensed.
     dtype (torch.dtype, optional): Output tensor's data type.
     device (torch.device, optional): Output tensor's data type.
-    theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
-      scale the theta values.
+    theta_factors (torch.Tensor, optional): A tensor of shape (n_elem,) used
+      to scale the theta values.
     scale (float, optional): A float used to scale the rope values.
 
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   theta = theta / theta_factors
-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
@@ -139,18 +141,6 @@ class Phi3_5Mini(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
-        scale=math.sqrt(
-            1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
-        ),
-    )
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -183,16 +173,29 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  max_seq_len = 4096
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_phi3_rope,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+      scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=32064,
       num_layers=32,
-      max_seq_len=4096,
+      max_seq_len=max_seq_len,
       kv_cache_max_len=kv_cache_max_len,
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config
 
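
Both `_build_llama3_rope_cache` and `_build_phi3_rope` gain a trailing `**kwargs`. That is what lets the shared decoder call every configured RoPE builder with one fixed keyword set (`input_pos`, `n_elem`, `head_dim`, `base`) even when a given builder does not use all of them. A small illustration of the convention (function names here are hypothetical):

```python
def standard_rope(input_pos, n_elem, head_dim, base):
  ...  # uses every argument

def phi3_style_rope(input_pos, n_elem, base, theta_factors=None, **kwargs):
  ...  # `head_dim` simply lands in kwargs and is ignored

for build in (standard_rope, phi3_style_rope):
  build(input_pos=[0, 1], n_elem=64, head_dim=64, base=10_000)  # both accept the call
```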
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py ADDED
@@ -0,0 +1,71 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting SmolLM2 model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = smollm.build_model_v2(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'smollm2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
+      quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
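
A programmatic equivalent of this script's `main()`, useful when driving the conversion from another Python entry point; the checkpoint path is illustrative and must hold SmolLM2 weights:

```python
from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

model = smollm.build_model_v2("/path/to/smollm2", kv_cache_max_len=1280)
converter.convert_to_tflite(
    model,
    tflite_path="/tmp/smollm2_q8_ekv1280.tflite",
    prefill_seq_len=(8, 64, 128, 256, 512, 1024),
    quantize=True,
    export_config=ExportConfig(),
)
```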
ai_edge_torch/generative/examples/smollm/smollm.py CHANGED
@@ -85,3 +85,41 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM,
   )
+
+
+class SmolLM2(model_builder.DecoderOnlyModel):
+  """A SmolLM2 model built from the Edge Generative API layers."""
+  pass
+
+
+def get_model_config_v2(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmolLM2 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM2 model.
+  """
+  config = get_model_config(kv_cache_max_len)
+  config.block_config(0).attn_config.rotary_base = 100000
+  return config
+
+
+def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config_v2(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # SmolLM2 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model_v2(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config_v2(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=SmolLM2,
+  )
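
As the diff shows, SmolLM2 reuses the v1 architecture wholesale; the only config-level difference is the RoPE base. A short usage sketch (checkpoint handling omitted):

```python
from ai_edge_torch.generative.examples.smollm import smollm

v2 = smollm.get_model_config_v2(kv_cache_max_len=1024)
assert v2.block_config(0).attn_config.rotary_base == 100000
# Everything else (layers, dims, norms) is inherited from the v1 config.
```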
ai_edge_torch/generative/examples/smollm/verify.py CHANGED
@@ -36,10 +36,26 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     30,
     "The maximum size of the generated tokens.",
 )
+_MODEL_VERSION = flags.DEFINE_enum(
+    "model_version",
+    "v1",
+    ["v1", "v2"],
+    "The version of SmolLm to verify.",
+)
+_CHECKPOINT = {
+    "v1": "HuggingFaceTB/SmolLM-135M",
+    "v2": "HuggingFaceTB/SmolLM2-135M",
+}
+
+_BUILDER = {
+    "v1": smollm.build_model,
+    "v2": smollm.build_model_v2,
+}
 
 
 def main(_):
-  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  checkpoint = _CHECKPOINT[_MODEL_VERSION.value]
+  builder = _BUILDER[_MODEL_VERSION.value]
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
@@ -49,7 +65,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = smollm.build_model(reauthored_checkpoint)
+  reauthored_model = builder(reauthored_checkpoint)
 
   logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
ai_edge_torch/generative/examples/test_models/toy_model.py CHANGED
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 # A toy example which has a single-layer transformer block.
-from typing import Tuple
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers.attention import TransformerBlock
@@ -52,14 +52,20 @@ class ToySingleLayerModel(torch.nn.Module):
     self.config = config
 
   @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+  def forward(
+      self,
+      idx: torch.Tensor,
+      input_pos: torch.Tensor,
+      mask: Optional[torch.Tensor] = None,
+  ) -> torch.Tensor:
     x = self.tok_embedding(idx)
     cos, sin = self.rope_cache
 
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.max_seq_len]
+    if mask is None:
+      mask = self.mask_cache.index_select(2, input_pos)
+      mask = mask[:, :, :, : self.config.max_seq_len]
 
     x = self.transformer_block(x, (cos, sin), mask, input_pos)
     x = self.final_norm(x)
@@ -98,7 +104,12 @@ class ToySingleLayerModelWeightSharing(torch.nn.Module):
     self.config = config
 
   @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+  def forward(
+      self,
+      idx: torch.Tensor,
+      input_pos: torch.Tensor,
+      mask: Optional[torch.Tensor] = None,
+  ) -> torch.Tensor:
     x = self.tok_embedding(idx)
     cos, sin = self.rope_cache
 
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py CHANGED
@@ -63,23 +63,25 @@ class ToyModelWithKVCache(torch.nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      mask: Optional[torch.Tensor] = None,
       export_config: Optional[ExportConfig] = None,
   ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
     x = self.tok_embedding(tokens)
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
     sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.max_seq_len]
+    if mask is None:
+      mask = self.mask_cache.index_select(2, input_pos)
+      mask = mask[:, :, :, : self.config.max_seq_len]
 
-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
+        updated_kv_entries.append(kv_entry)
 
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
ai_edge_torch/generative/layers/attention.py CHANGED
@@ -27,33 +27,6 @@ import torch
 from torch import nn
 
 
-def _embed_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    n_elem: int,
-    rope: Tuple[torch.Tensor, torch.Tensor],
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Embed rotary positional embedding for query and key.
-
-  Args:
-    q (torch.Tensor): query tensor.
-    k (torch.Tensor): key tensor.
-    n_elem (int): number of elements to embed rotarty positional embedding.
-    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-  """
-  if n_elem > 0:
-    cos, sin = rope
-    q_roped = rotary_pos_emb.apply_rope(
-        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    k_roped = rotary_pos_emb.apply_rope(
-        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
-    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
-  return q, k
-
-
 class TransformerBlock(nn.Module):
 
   def __init__(
@@ -252,7 +225,8 @@ class CausalSelfAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -404,7 +378,8 @@ class CrossAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
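
The deleted `_embed_rope` helper existed to support partial rotary embeddings: only the first `n_elem` channels of q/k were rotated when `rotary_percentage < 1.0`, with the rest passed through untouched. A standalone reconstruction of that idea (toy `rotate`, not the library function):

```python
import torch

def rotate(x, cos, sin):
  x1, x2 = torch.split(x, x.size(-1) // 2, dim=-1)
  return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

def partial_rope(q, n_elem, cos, sin):
  q_roped = rotate(q[..., :n_elem], cos, sin)           # rotate first n_elem channels
  return torch.cat((q_roped, q[..., n_elem:]), dim=-1)  # pass the rest through

q = torch.randn(1, 8, 4, 64)  # (B, T, n_heads, head_dim)
cos, sin = torch.ones(8, 1, 16), torch.zeros(8, 1, 16)
assert partial_rope(q, 32, cos, sin).shape == q.shape
```

With the new inline path, `n_elem` is still computed at the call site, but the rotation itself is applied by `rotary_pos_emb.apply_rope_inline` on the full tensors.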
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -17,8 +17,8 @@
 
 import dataclasses
 import enum
-from typing import Optional, Sequence, Union
-
+from typing import Callable, Optional, Sequence, Union
+from ai_edge_torch.generative.layers import rotary_position_embedding
 
 @enum.unique
 class ActivationType(enum.Enum):
@@ -218,6 +218,10 @@ class ModelConfig:
   # Softcap on the model output logits.
   final_logit_softcap: Optional[float] = None
 
+  # The function to call to create the RoPE sin and cos vectors during the
+  # forward pass. Defaults to a standard implementation.
+  build_rope: Callable = rotary_position_embedding.build_rope
+
   @property
   def kv_cache_max(self) -> int:
     if self.kv_cache_max_len > 0:
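
Because `ModelConfig` now carries a `build_rope` callable, a model can swap in custom RoPE frequency handling without subclassing the decoder. A hedged sketch: `my_build_rope` is hypothetical, and the field defaults to `rotary_position_embedding.build_rope` when left unset.

```python
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb

def my_build_rope(input_pos, n_elem, head_dim, base, **kwargs):
  # A real override could rescale frequencies here before delegating.
  return rotary_pos_emb.build_rope(input_pos, n_elem, head_dim, base)

# config: any cfg.ModelConfig built as in the model examples above.
# config.build_rope = my_build_rope
```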
ai_edge_torch/generative/layers/rotary_position_embedding.py CHANGED
@@ -32,57 +32,63 @@ def apply_rope(
   """
   x = x.transpose(1, 2)
   head_size = x.size(-1)
-  x1 = x[..., : head_size // 2]
-  x2 = x[..., head_size // 2 :]
-  rotated = torch.cat((-x2, x1), dim=-1)
-  roped = (x * cos) + (rotated * sin)
+  x1, x2 = torch.split(x, head_size // 2, dim=-1)
+  left = x1 * cos - x2 * sin
+  right = x2 * cos + x1 * sin
+  roped = torch.cat([left, right], dim=-1)
   return roped.transpose(1, 2).type_as(x)
 
 
-def apply_rope_inline(
-    q: torch.Tensor,
-    k: torch.Tensor,
+def build_rope(
     input_pos: torch.Tensor,
     n_elem: int,
+    head_dim: int,
     base: int = 10_000,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding
+  """Computes rotary positional embedding cosine and sine tensors.
 
   Args:
-    q: the query tensor.
-    k: the key tensor.
     input_pos: the sequence indices for the query and key
     n_elem: number of elements of the head dimension for RoPE computation
+    base: the base of the exponentiated value for RoPE.
 
   Returns:
-
+    cos, sin tensors
   """
 
   if n_elem <= 0:
-    return q, k
+    return None, None
 
-  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   freq_exponents = (2.0 / n_elem) * torch.arange(
-      q.shape[-1] // 2, dtype=torch.float32
+      head_dim // 2, dtype=torch.float32
   )
   timescale = float(base) ** freq_exponents
   radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
       0
   ).unsqueeze(0)
-  cos = torch.cos(radians)
-  sin = torch.sin(radians)
+  cos = torch.cos(radians)
+  sin = torch.sin(radians)
+  return cos, sin
+
 
-
-
-
-
-
-
-
-
-
-
+def apply_rope_inline(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Computes rotary positional embedding inline for a query and key.
+
+  Args:
+    q: the query tensor.
+    k: the key tensor.
+    cos: the cosine tensor.
+    sin: the sine tensor.
+
+  Returns:
+    output the RoPE'd query and key.
+  """
 
-  q_roped =
-  k_roped =
+  q_roped = apply_rope(q, cos, sin)
+  k_roped = apply_rope(k, cos, sin)
   return q_roped, k_roped
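
A self-contained mirror of the refactored flow, for intuition (this reimplements the two functions above rather than importing them): `build_rope` derives cos/sin from positions, and the rotation is then applied inline to q and k.

```python
import torch

def build_rope(input_pos, head_dim, base=10_000):
  freq_exponents = (2.0 / head_dim) * torch.arange(head_dim // 2, dtype=torch.float32)
  timescale = float(base) ** freq_exponents
  radians = input_pos[None, :, None].float() / timescale[None, None, :]
  return torch.cos(radians), torch.sin(radians)  # each (1, T, head_dim // 2)

def apply_rope(x, cos, sin):
  x = x.transpose(1, 2)                          # (B, n_heads, T, head_dim)
  x1, x2 = torch.split(x, x.size(-1) // 2, dim=-1)
  roped = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
  return roped.transpose(1, 2)

q = torch.randn(2, 5, 4, 32)                     # (B, T, n_heads, head_dim)
cos, sin = build_rope(torch.arange(5), head_dim=32)
assert apply_rope(q, cos, sin).shape == q.shape
```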
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -150,6 +150,16 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
   )
+
+  def test_smollm2(self):
+    config = smollm.get_fake_model_config_v2()
+    pytorch_model = smollm.SmolLM2(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+
   def test_openelm(self):
     config = openelm.get_fake_model_config()
     pytorch_model = openelm.OpenELM(config).eval()
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -25,6 +25,7 @@ from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -87,13 +88,6 @@ class DecoderOnlyModel(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # ROPE parameters for all attn_configs are the same. Take the first one.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -105,6 +99,7 @@ class DecoderOnlyModel(nn.Module):
       tokens: torch.Tensor,
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
+      mask: Optional[torch.Tensor] = None,
       lora: Optional[lora_utils.LoRA] = None,
       export_config: Optional[ExportConfig] = None,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
@@ -116,10 +111,21 @@ class DecoderOnlyModel(nn.Module):
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = self.config.build_rope(
+        input_pos=input_pos,
+        n_elem=n_elem,
+        base=attn_config.rotary_base,
+        head_dim=attn_config.head_dim,
+        # input_pos=input_pos, n_elem=n_elem, base=attn_config.rotary_base
+    )
+
+    if mask is None:
+      mask = self.mask_cache.index_select(2, input_pos)
+      mask = mask[:, :, :, : self.config.kv_cache_max]
 
     return self.forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, lora, export_config
@@ -145,14 +151,14 @@ class DecoderOnlyModel(nn.Module):
     if self.config.embedding_scale is not None:
       x = x * self.config.embedding_scale
 
-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       lora_adapter = lora.adapters[i] if lora else None
       x, kv_entry = block(x, rope, mask, input_pos, kv_entry, lora_adapter)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
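
Every `forward()` in this release gains an optional `mask`; when the caller passes none, the old behavior is reproduced by slicing the causal mask cache at the current positions. A standalone sketch of what that default computes (the `triu` construction is a stand-in for `attn_utils.build_causal_mask_cache`):

```python
import torch

kv_cache_max = 8
# Causal cache: position i may attend to keys 0..i; future slots get -inf.
mask_cache = torch.triu(
    torch.full((kv_cache_max, kv_cache_max), float("-inf")), diagonal=1
)[None, None]                         # (1, 1, S, S)

input_pos = torch.tensor([3, 4])      # decoding two positions
mask = mask_cache.index_select(2, input_pos)[:, :, :, :kv_cache_max]
print(mask.shape)                     # torch.Size([1, 1, 2, 8])
```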
ai_edge_torch/hlfb/mark_pattern/__init__.py CHANGED
@@ -17,7 +17,7 @@ from typing import Any
 import uuid
 
 from ai_edge_torch import lowertools
-from ai_edge_torch.hlfb.mark_pattern import passes
+from ai_edge_torch.hlfb.mark_pattern import fx_utils
 from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
 import torch
 
@@ -87,7 +87,7 @@ def mark_pattern(
     m.meta["ORIGINAL_NODE"] = n
 
   # Sanitize graph_module to match in the same way as pattern's graph_module.
-  graph_module_to_match = passes.remove_clone_ops(graph_module_to_match)
+  graph_module_to_match = fx_utils.remove_clone_ops(graph_module_to_match)
 
   match_with_attrs = pattern.match(graph_module_to_match)
 
@@ -111,13 +111,25 @@ def mark_pattern(
         is_input=True,
     )
 
-    # Only replace input by the marker node for those nodes used in the pattern.
+    # Only replace input by the marker node for those nodes used in the
+    # pattern.
     in_pattern_nodes = set(match.nodes_map.values())
     for user in input_node.users.keys():
-      if user in in_pattern_nodes:
-        user.meta["ORIGINAL_NODE"].replace_input_with(
-            input_node.meta["ORIGINAL_NODE"], new_input_node
-        )
+      if user not in in_pattern_nodes:
+        continue
+
+      user.meta["ORIGINAL_NODE"].replace_input_with(
+          input_node.meta["ORIGINAL_NODE"], new_input_node
+      )
+      # Pattern matching graph sanitization may remove clone ops, which means
+      # the user's input in the original graph may be a clone op. When
+      # replacing the input with the marker node, we need to further try
+      # replacing the input of the clone op that connects to the user.
+      for original_user_input in user.meta["ORIGINAL_NODE"].all_input_nodes:
+        if fx_utils.is_clone_op(original_user_input):
+          original_user_input.replace_input_with(
+              input_node.meta["ORIGINAL_NODE"], new_input_node
+          )
 
   for i, pattern_output_node in enumerate(pattern.output_nodes):
     output_node = match.nodes_map[pattern_output_node]
ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} RENAMED
@@ -12,11 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""
+"""FX graph utilities for pattern matching clean ups."""
 
 import torch
 
 
+def is_clone_op(node: torch.fx.Node) -> bool:
+  """Checks if the node is a clone op."""
+  return (
+      node.op == "call_function" and node.target == torch.ops.aten.clone.default
+  )
+
+
 def remove_clone_ops(gm: torch.fx.GraphModule):
   """Removes clone ops from the graph.
 
@@ -32,7 +39,7 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
     The graph module with clone ops removed.
   """
   for node in gm.graph.nodes:
-    if node
+    if is_clone_op(node):
       node.replace_all_uses_with(node.args[0])
       gm.graph.erase_node(node)
 
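
A quick demonstration of what `is_clone_op` matches, mirroring the new test later in this diff: exporting a model that clones an intermediate yields an `aten.clone.default` call_function node in the FX graph.

```python
import torch

class M(torch.nn.Module):
  def forward(self, x):
    return torch.ops.aten.clone.default(x * x) + x

ep = torch.export.export(M().eval(), (torch.rand(2, 2),))
clones = [
    n for n in ep.graph_module.graph.nodes
    if n.op == "call_function" and n.target == torch.ops.aten.clone.default
]
print(len(clones))  # expect 1 before remove_clone_ops, 0 after
```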
ai_edge_torch/hlfb/mark_pattern/pattern.py CHANGED
@@ -18,13 +18,14 @@ import dataclasses
 from typing import Any, Callable, Optional, Union
 
 from ai_edge_torch import fx_pass_base
-from ai_edge_torch.hlfb.mark_pattern import passes
+from ai_edge_torch.hlfb.mark_pattern import fx_utils
 import torch
-
-
-
-
-
+
+Graph = torch.fx.Graph
+GraphModule = torch.fx.GraphModule
+TensorArgument = torch.export.graph_signature.TensorArgument
+InternalMatch = torch.fx.passes.utils.matcher_utils.InternalMatch
+SubgraphMatcher = torch.fx.passes.utils.matcher_utils.SubgraphMatcher
 
 
 def _are_equal(x: Any, y: Any) -> bool:
@@ -219,8 +220,8 @@ class Pattern:
     # Sanitize graph_module for more precise pattern matching.
     # The graph_module to match against this pattern should apply equivalent
     # sanitization.
-    self.graph_module = passes.remove_clone_ops(self.graph_module)
-    self.graph_module = passes.remove_dangling_args(self.graph_module)
+    self.graph_module = fx_utils.remove_clone_ops(self.graph_module)
+    self.graph_module = fx_utils.remove_dangling_args(self.graph_module)
 
     # Builds list of ordered input and output nodes.
     self.graph_nodes_map = {}
ai_edge_torch/hlfb/test/test_mark_pattern.py CHANGED
@@ -58,6 +58,32 @@ class TestMarkPattern(googletest.TestCase):
         {"stablehlo.custom_call @mark_tensor": 6},
     )
 
+  def test_mark_pattern_with_clone_inputs(self):
+
+    class TestModel(torch.nn.Module):
+
+      def forward(self, x):
+        return torch.ops.aten.clone.default(x * x) + x
+
+    pattern = pattern_module.Pattern(
+        "test.add",
+        lambda a, b: a + b,
+        export_args=(torch.rand(2, 2), torch.rand(2, 2)),
+    )
+
+    model = TestModel().eval()
+    args = (torch.rand(20, 20),)
+    exported_program = torch.export.export(model, args)
+    mark_pattern.mark_pattern(exported_program.graph_module, pattern)
+    mlir = _export_stablehlo_mlir(exported_program)
+
+    lowertools.assert_string_count(
+        self,
+        mlir,
+        {'stablehlo.composite "test.add"': 1},
+        {"stablehlo.custom_call @mark_tensor": 3},
+    )
+
   def test_mark_pattern_with_attr_builder(self):
     class TestModel(torch.nn.Module):
 
ai_edge_torch/version.py CHANGED

{ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250110.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20250108
+Version: 0.3.0.dev20250110
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20250108.dist-info → ai_edge_torch_nightly-0.3.0.dev20250110.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=VHqAyYw4u6BgyQ6v7Xp08Jqb0cnzIVGsulfnclxgY5c,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -47,13 +47,13 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=e9HfiHr4FkQZwVBYdDUZGzOjB5TqY2LqtVTHEzwVkQY,10428
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=tMSsqg7LU3LR-PHtKvlWtLCqlk71mfcO9hANU4vnvDM,2734
-ai_edge_torch/generative/examples/llama/llama.py,sha256=
+ai_edge_torch/generative/examples/llama/llama.py,sha256=kWy6-V4bFtE1yguCROLJS5XB0GOJD1-acJWp2dFjB5Q,6606
 ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
 ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
@@ -64,19 +64,19 @@ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2u
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=scLsguzzuHfKYDWUd2uZkKYVRzdAbQHLd-kPam8QwvM,3004
-ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=
-ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=
+ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=NJGhfPxVQjHDqea_lYGffjihOBdIYiXftiFTM6ccrwM,5475
+ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=L6F6KWHqxdnGQTOp9P3c8r_K1Uxet0ZCcbdvmjWtIos,6513
 ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
-ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=
+ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=CEMG9gh51ev1KXPew927a6nfampiXX9bL6m-25tNYN8,6340
 ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
 ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
 ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=CaI_-Vtd0j9FoWIDd8q5z4CFsGYUhTwEWGvMGaXICuU,2514
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=g-MvEibJT_iIhkec2VGtFFA_iP54VCq9mY4KxwAYF08,2512
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=c6PYCky7yJn6MVIYOCTx8S_CH27kOPmJbRZcI95nbZs,3477
-ai_edge_torch/generative/examples/phi/phi3.py,sha256=
+ai_edge_torch/generative/examples/phi/phi3.py,sha256=SHvJjmi5eIch5cYIWORt6YFmSQx_oCiOk1UbKKGibtk,7119
 ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
 ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
 ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -85,8 +85,9 @@ ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-L
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
-ai_edge_torch/generative/examples/smollm/smollm.py,sha256=
-ai_edge_torch/generative/examples/smollm/verify.py,sha256=
+ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=CjY1i0iCYxFSjhCpQZwxkmVxILgeo0zu1m0oBrHqyDU,2311
+ai_edge_torch/generative/examples/smollm/smollm.py,sha256=3uUltb6D3Q1aHpndcYTJrsWM_RBwLAraKDniH8ZZous,3779
+ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68n0owZMvty5Rr3W7ZNWWSw,2702
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
@@ -108,8 +109,8 @@ ai_edge_torch/generative/examples/t5/t5.py,sha256=gFTmPi-xB8pcPRgoF3DJxvH_fT-KWT
 ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqXquaFQPvCFBFF5zOnmGVb3Hg,8731
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
-ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=
-ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=
+ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=Crpj-vOwSViHpblXOrRJmsIn4DrHyuB3XZ8kHifb7LA,5203
+ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=Ab_N9xc-4DImA-Pvevr-nnnslBXScXVo4Pw7L3_OlhI,4732
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=VU0c5pgvrUtaTboT1xuDBGjpKOM85aqtaB_hYfSBuEk,2544
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
@@ -117,15 +118,15 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8kQ4U3YANfSiTJKn8,13776
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
-ai_edge_torch/generative/layers/model_config.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=9yPEmWNw3-_2wXBmPmZ7RUKcPXHF2ZbJwksyQoXTA6M,7784
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
-ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=
+ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=1L1MEGPYbDELi0zy2OKl7yXyk9FXdBjcXwRZbfiJriU,2619
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
@@ -144,25 +145,25 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bBcey-aD4L_TwKRrrM81bN2VQoJjPPC84Rv4o3WOc34,12491
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=MY8BK29yD-W4v45Xdl_ErbNilipsTlD-4-y9MyBxR5g,7620
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
-ai_edge_torch/generative/utilities/model_builder.py,sha256=
+ai_edge_torch/generative/utilities/model_builder.py,sha256=6OBKyOmbg5Sap_np1wnajpCQ1fh8P0eONqNls9eHAX4,6778
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
 ai_edge_torch/generative/utilities/verifier.py,sha256=6lnBU9Cy5GanB8JWK3-2_VU3PxqunDWGe-SgSLba5Yw,12065
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
-ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=
-ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=
-ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=
+ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=-BYE7MGMxr-VfBy8tAiiOaCqYv8ytJ0w5l2P8B7h3eM,5387
+ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=taWLpF5IVglxlsF9HM2dIoKDXuQREaCRAXtJeG5gKzs,2073
+ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=7bv9XqRkm1pjxiVL4Cm1cArExnolId8hQKFHtvlkCI8,10061
 ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=
+ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=-5UqJyk__1YbUNGuxi4b2sn0CED0W-G337AXwxPGdEs,5567
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
 ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
 ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
@@ -205,8 +206,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/METADATA,sha256=
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20250108.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20250110.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250110.dist-info/METADATA,sha256=D_Vexo_GTTaYsb6IqB5rLrD-mos2YWze1Xcj3IFDgKE,1966
+ai_edge_torch_nightly-0.3.0.dev20250110.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250110.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250110.dist-info/RECORD,,