ai-edge-torch-nightly 0.3.0.dev20250105__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/_config.py +26 -9
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +13 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +36 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +25 -43
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
- ai_edge_torch/generative/layers/attention.py +70 -12
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/layers/normalization.py +2 -50
- ai_edge_torch/generative/layers/rotary_position_embedding.py +27 -34
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/utilities/converter.py +100 -47
- ai_edge_torch/generative/utilities/model_builder.py +21 -16
- ai_edge_torch/generative/utilities/verifier.py +4 -4
- ai_edge_torch/odml_torch/_torch_future.py +13 -0
- ai_edge_torch/odml_torch/export.py +6 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +32 -30
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
|
|
29
29
|
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
|
30
30
|
'The path to the model checkpoint, or directory holding the checkpoint.',
|
31
31
|
)
|
32
|
-
|
33
|
-
'
|
32
|
+
_OUTPUT_PATH = flags.DEFINE_string(
|
33
|
+
'output_path',
|
34
34
|
'/tmp/',
|
35
|
-
'The
|
35
|
+
'The path to export the tflite model.',
|
36
|
+
)
|
37
|
+
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
|
38
|
+
'output_name_prefix',
|
39
|
+
'smollm',
|
40
|
+
'The prefix of the output tflite model name.',
|
36
41
|
)
|
37
42
|
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
|
38
43
|
'prefill_seq_lens',
|
@@ -49,20 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
|
|
49
54
|
True,
|
50
55
|
'Whether the model should be quantized.',
|
51
56
|
)
|
57
|
+
_LORA_RANKS = flags.DEFINE_multi_integer(
|
58
|
+
'lora_ranks',
|
59
|
+
None,
|
60
|
+
'If set, the model will be converted with the provided list of LoRA ranks.',
|
61
|
+
)
|
52
62
|
|
53
63
|
|
54
64
|
def main(_):
|
55
65
|
pytorch_model = smollm.build_model(
|
56
66
|
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
57
67
|
)
|
58
|
-
|
59
|
-
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
|
60
|
-
output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
|
61
68
|
converter.convert_to_tflite(
|
62
69
|
pytorch_model,
|
63
|
-
|
70
|
+
output_path=_OUTPUT_PATH.value,
|
71
|
+
output_name_prefix=_OUTPUT_NAME_PREFIX.value,
|
64
72
|
prefill_seq_len=_PREFILL_SEQ_LENS.value,
|
65
73
|
quantize=_QUANTIZE.value,
|
74
|
+
lora_ranks=_LORA_RANKS.value,
|
66
75
|
export_config=ExportConfig(),
|
67
76
|
)
|
68
77
|
|
@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
|
|
72
72
|
mask = self.mask_cache.index_select(2, input_pos)
|
73
73
|
mask = mask[:, :, :, : self.config.max_seq_len]
|
74
74
|
|
75
|
-
|
75
|
+
updated_kv_entires = []
|
76
76
|
for i, block in enumerate(self.transformer_blocks):
|
77
77
|
kv_entry = kv_cache.caches[i] if kv_cache else None
|
78
78
|
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
|
79
79
|
if kv_entry:
|
80
|
-
|
80
|
+
updated_kv_entires.append(kv_entry)
|
81
81
|
|
82
|
-
updated_kv_cache = kv_utils.KVCache(tuple(
|
82
|
+
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
|
83
83
|
|
84
84
|
if export_config is not None:
|
85
85
|
if (
|
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
|
|
29
29
|
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
|
30
30
|
'The path to the model checkpoint, or directory holding the checkpoint.',
|
31
31
|
)
|
32
|
-
|
33
|
-
'
|
32
|
+
_OUTPUT_PATH = flags.DEFINE_string(
|
33
|
+
'output_path',
|
34
34
|
'/tmp/',
|
35
|
-
'The
|
35
|
+
'The path to export the tflite model.',
|
36
|
+
)
|
37
|
+
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
|
38
|
+
'output_name_prefix',
|
39
|
+
'tinyllama',
|
40
|
+
'The prefix of the output tflite model name.',
|
36
41
|
)
|
37
42
|
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
|
38
43
|
'prefill_seq_lens',
|
@@ -49,21 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
|
|
49
54
|
True,
|
50
55
|
'Whether the model should be quantized.',
|
51
56
|
)
|
57
|
+
_LORA_RANKS = flags.DEFINE_multi_integer(
|
58
|
+
'lora_ranks',
|
59
|
+
None,
|
60
|
+
'If set, the model will be converted with the provided list of LoRA ranks.',
|
61
|
+
)
|
52
62
|
|
53
63
|
|
54
64
|
def main(_):
|
55
65
|
pytorch_model = tiny_llama.build_model(
|
56
66
|
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
57
67
|
)
|
58
|
-
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
|
59
|
-
output_filename = (
|
60
|
-
f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
|
61
|
-
)
|
62
68
|
converter.convert_to_tflite(
|
63
69
|
pytorch_model,
|
64
|
-
|
70
|
+
output_path=_OUTPUT_PATH.value,
|
71
|
+
output_name_prefix=_OUTPUT_NAME_PREFIX.value,
|
65
72
|
prefill_seq_len=_PREFILL_SEQ_LENS.value,
|
66
73
|
quantize=_QUANTIZE.value,
|
74
|
+
lora_ranks=_LORA_RANKS.value,
|
67
75
|
export_config=ExportConfig(),
|
68
76
|
)
|
69
77
|
|
@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
|
|
19
19
|
|
20
20
|
from ai_edge_torch.generative.layers import builder
|
21
21
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
22
|
+
from ai_edge_torch.generative.layers import lora as lora_utils
|
22
23
|
from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
|
23
24
|
import ai_edge_torch.generative.layers.model_config as cfg
|
24
25
|
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
|
@@ -26,6 +27,33 @@ import torch
|
|
26
27
|
from torch import nn
|
27
28
|
|
28
29
|
|
30
|
+
def _embed_rope(
|
31
|
+
q: torch.Tensor,
|
32
|
+
k: torch.Tensor,
|
33
|
+
n_elem: int,
|
34
|
+
rope: Tuple[torch.Tensor, torch.Tensor],
|
35
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
36
|
+
"""Embed rotary positional embedding for query and key.
|
37
|
+
|
38
|
+
Args:
|
39
|
+
q (torch.Tensor): query tensor.
|
40
|
+
k (torch.Tensor): key tensor.
|
41
|
+
n_elem (int): number of elements to embed rotarty positional embedding.
|
42
|
+
rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
|
43
|
+
"""
|
44
|
+
if n_elem > 0:
|
45
|
+
cos, sin = rope
|
46
|
+
q_roped = rotary_pos_emb.apply_rope(
|
47
|
+
q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
|
48
|
+
)
|
49
|
+
k_roped = rotary_pos_emb.apply_rope(
|
50
|
+
k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
|
51
|
+
)
|
52
|
+
q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
|
53
|
+
k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
|
54
|
+
return q, k
|
55
|
+
|
56
|
+
|
29
57
|
class TransformerBlock(nn.Module):
|
30
58
|
|
31
59
|
def __init__(
|
@@ -66,6 +94,7 @@ class TransformerBlock(nn.Module):
|
|
66
94
|
mask: Optional[torch.Tensor] = None,
|
67
95
|
input_pos: Optional[torch.Tensor] = None,
|
68
96
|
kv_cache: kv_utils.KVCacheEntry = None,
|
97
|
+
lora: Optional[lora_utils.LoRAEntry] = None,
|
69
98
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
|
70
99
|
"""Forward function of the TransformerBlock.
|
71
100
|
|
@@ -75,6 +104,7 @@ class TransformerBlock(nn.Module):
|
|
75
104
|
mask (torch.Tensor): the optional mask tensor.
|
76
105
|
input_pos (torch.Tensor): the optional input position tensor.
|
77
106
|
kv_cache (KVCacheEntry): the optional kv cache entry.
|
107
|
+
lora (LoRAEntry): the optional lora entry.
|
78
108
|
|
79
109
|
Returns:
|
80
110
|
output activation from this transformer block, and updated kv cache (if
|
@@ -83,7 +113,9 @@ class TransformerBlock(nn.Module):
|
|
83
113
|
kv = None
|
84
114
|
if self.config.parallel_residual:
|
85
115
|
x_norm = self.pre_atten_norm(x)
|
86
|
-
atten_func_out = self.atten_func(
|
116
|
+
atten_func_out = self.atten_func(
|
117
|
+
x_norm, rope, mask, input_pos, kv_cache, lora
|
118
|
+
)
|
87
119
|
if kv_cache is None:
|
88
120
|
attn_out = atten_func_out
|
89
121
|
else:
|
@@ -92,7 +124,9 @@ class TransformerBlock(nn.Module):
|
|
92
124
|
output = x + attn_out + ff_out
|
93
125
|
else:
|
94
126
|
x_norm = self.pre_atten_norm(x)
|
95
|
-
atten_func_out = self.atten_func(
|
127
|
+
atten_func_out = self.atten_func(
|
128
|
+
x_norm, rope, mask, input_pos, kv_cache, lora
|
129
|
+
)
|
96
130
|
if kv_cache is None:
|
97
131
|
attn_out = atten_func_out
|
98
132
|
else:
|
@@ -152,6 +186,7 @@ class CausalSelfAttention(nn.Module):
|
|
152
186
|
mask: Optional[torch.Tensor] = None,
|
153
187
|
input_pos: Optional[torch.Tensor] = None,
|
154
188
|
kv_cache: Optional[kv_utils.KVCacheEntry] = None,
|
189
|
+
lora: Optional[lora_utils.LoRAEntry] = None,
|
155
190
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
|
156
191
|
"""Forward function of the CausalSelfAttention layer, which can support
|
157
192
|
|
@@ -162,7 +197,8 @@ class CausalSelfAttention(nn.Module):
|
|
162
197
|
rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
|
163
198
|
mask (torch.Tensor): the optional mask tensor.
|
164
199
|
input_pos (torch.Tensor): the optional input position tensor.
|
165
|
-
kv_cache (KVCacheEntry):
|
200
|
+
kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
|
201
|
+
lora (LoRAEntry): the optional lora entry.
|
166
202
|
|
167
203
|
Returns:
|
168
204
|
output activation from this self attention layer, and the updated
|
@@ -201,6 +237,11 @@ class CausalSelfAttention(nn.Module):
|
|
201
237
|
dim=-1,
|
202
238
|
)
|
203
239
|
|
240
|
+
if lora is not None:
|
241
|
+
q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
|
242
|
+
k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
|
243
|
+
v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
|
244
|
+
|
204
245
|
q = self.query_norm(q)
|
205
246
|
k = self.key_norm(k)
|
206
247
|
|
@@ -211,14 +252,13 @@ class CausalSelfAttention(nn.Module):
|
|
211
252
|
if rope is not None:
|
212
253
|
# Compute rotary positional embedding for query and key.
|
213
254
|
n_elem = int(self.config.rotary_percentage * self.config.head_dim)
|
214
|
-
|
215
|
-
q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
|
255
|
+
q, k = _embed_rope(q, k, n_elem, rope)
|
216
256
|
|
217
257
|
if kv_cache is not None:
|
218
258
|
kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
|
219
259
|
k, v = kv_cache.k_cache, kv_cache.v_cache
|
220
260
|
|
221
|
-
|
261
|
+
sdpa_out = self.sdpa_func(
|
222
262
|
q,
|
223
263
|
k,
|
224
264
|
v,
|
@@ -226,10 +266,13 @@ class CausalSelfAttention(nn.Module):
|
|
226
266
|
mask=mask,
|
227
267
|
softcap=self.config.logit_softcap,
|
228
268
|
)
|
229
|
-
|
269
|
+
sdpa_out = sdpa_out.reshape(B, T, -1)
|
230
270
|
|
231
271
|
# Compute the output projection.
|
232
|
-
y = self.output_projection(
|
272
|
+
y = self.output_projection(sdpa_out)
|
273
|
+
if lora is not None:
|
274
|
+
y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
|
275
|
+
|
233
276
|
return y if kv_cache is None else (y, kv_cache)
|
234
277
|
|
235
278
|
|
@@ -242,6 +285,7 @@ class SelfAttention(CausalSelfAttention):
|
|
242
285
|
rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
243
286
|
input_pos: Optional[torch.Tensor] = None,
|
244
287
|
kv_cache: Optional[kv_utils.KVCacheEntry] = None,
|
288
|
+
lora: Optional[lora_utils.LoRAEntry] = None,
|
245
289
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
|
246
290
|
"""Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.
|
247
291
|
|
@@ -249,18 +293,23 @@ class SelfAttention(CausalSelfAttention):
|
|
249
293
|
x (torch.Tensor): the input tensor.
|
250
294
|
rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
|
251
295
|
input_pos (torch.Tensor): the optional input position tensor.
|
252
|
-
kv_cache (KVCacheEntry):
|
296
|
+
kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
|
297
|
+
lora (LoRAEntry): the optional lora entry.
|
253
298
|
|
254
299
|
Returns:
|
255
300
|
output activation from this self attention layer, and the updated
|
256
301
|
KV Cach Entry (if passed in).
|
257
302
|
"""
|
258
303
|
B, T, _ = x.size()
|
304
|
+
assert (
|
305
|
+
kv_cache is None
|
306
|
+
), "KV cache is not supported in non-causal SelfAttention."
|
259
307
|
return super().forward(
|
260
308
|
x,
|
261
309
|
rope=rope,
|
262
310
|
mask=torch.zeros((B, 1, T, T), dtype=torch.float32),
|
263
311
|
input_pos=input_pos,
|
312
|
+
lora=lora,
|
264
313
|
)
|
265
314
|
|
266
315
|
|
@@ -317,6 +366,7 @@ class CrossAttention(nn.Module):
|
|
317
366
|
mask: Optional[torch.Tensor] = None,
|
318
367
|
input_pos: Optional[torch.Tensor] = None,
|
319
368
|
kv_cache: Optional[kv_utils.KVCacheEntry] = None,
|
369
|
+
lora: Optional[lora_utils.LoRAEntry] = None,
|
320
370
|
):
|
321
371
|
"""Forward function of the CrossAttention layer.
|
322
372
|
|
@@ -327,7 +377,8 @@ class CrossAttention(nn.Module):
|
|
327
377
|
mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
|
328
378
|
[B, n_heads, target_seq_len, source_seq_len].
|
329
379
|
input_pos (torch.Tensor): the optional input position tensor.
|
330
|
-
kv_cache (KVCacheEntry):
|
380
|
+
kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
|
381
|
+
lora (LoRAEntry): the optional lora entry.
|
331
382
|
|
332
383
|
Returns:
|
333
384
|
output activation from this cross attention layer.
|
@@ -340,6 +391,11 @@ class CrossAttention(nn.Module):
|
|
340
391
|
k = self.k_projection(y)
|
341
392
|
v = self.v_projection(y)
|
342
393
|
|
394
|
+
if lora is not None:
|
395
|
+
q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
|
396
|
+
k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
|
397
|
+
v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
|
398
|
+
|
343
399
|
interim_shape = (batch_size, -1, self.n_heads, self.config.head_dim)
|
344
400
|
q = q.view(interim_shape)
|
345
401
|
k = k.view(interim_shape)
|
@@ -348,8 +404,7 @@ class CrossAttention(nn.Module):
|
|
348
404
|
if rope is not None:
|
349
405
|
# Compute rotary positional embedding for query and key.
|
350
406
|
n_elem = int(self.config.rotary_percentage * self.config.head_dim)
|
351
|
-
|
352
|
-
q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
|
407
|
+
q, k = _embed_rope(q, k, n_elem, rope)
|
353
408
|
|
354
409
|
if kv_cache is not None:
|
355
410
|
kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
|
@@ -363,4 +418,7 @@ class CrossAttention(nn.Module):
|
|
363
418
|
|
364
419
|
# Compute the output projection.
|
365
420
|
y = self.output_projection(y)
|
421
|
+
if lora is not None:
|
422
|
+
y += lora_utils.apply_lora(y, lora.attention.output)
|
423
|
+
|
366
424
|
return y if kv_cache is None else (y, kv_cache)
|