ai-edge-torch-nightly 0.3.0.dev20250105__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_config.py +26 -9
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +13 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +36 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +25 -43
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
- ai_edge_torch/generative/layers/attention.py +70 -12
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/layers/normalization.py +2 -50
- ai_edge_torch/generative/layers/rotary_position_embedding.py +27 -34
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/utilities/converter.py +100 -47
- ai_edge_torch/generative/utilities/model_builder.py +21 -16
- ai_edge_torch/generative/utilities/verifier.py +4 -4
- ai_edge_torch/odml_torch/_torch_future.py +13 -0
- ai_edge_torch/odml_torch/export.py +6 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +32 -30
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py:

```diff
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'smollm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,20 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = smollm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

```
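The same restructuring appears in every example converter script in this release: the hard-coded output filename is replaced by `output_path` and `output_name_prefix` flags, and a new optional `lora_ranks` multi-flag is passed through to `converter.convert_to_tflite`. Below is a minimal sketch of the equivalent Python call; the checkpoint path, `kv_cache_max_len`, prefill lengths, and rank values are illustrative placeholders, and the `ExportConfig` import location is an assumption rather than something shown in this diff.

```python
# Hedged sketch of the new convert_to_tflite keyword arguments introduced above.
# Paths, sequence lengths, and LoRA ranks are placeholder values; the
# ExportConfig import path is assumed, not taken from this diff.
from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

pytorch_model = smollm.build_model(
    '/path/to/smollm_checkpoint', kv_cache_max_len=1280
)
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',
    output_name_prefix='smollm',
    prefill_seq_len=[64, 128],
    quantize=True,
    lora_ranks=[16, 32],  # or None to convert without LoRA adapters
    export_config=ExportConfig(),
)
```

The toy-model and tiny_llama diffs below follow the same pattern.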
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py:

```diff
@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]

-
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-
+        updated_kv_entires.append(kv_entry)

-    updated_kv_cache = kv_utils.KVCache(tuple(
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

     if export_config is not None:
       if (
```
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py:

```diff
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'tinyllama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,21 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = tiny_llama.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

```
ai_edge_torch/generative/layers/attention.py:

```diff
@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union

 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import lora as lora_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
@@ -26,6 +27,33 @@ import torch
 from torch import nn


+def _embed_rope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    n_elem: int,
+    rope: Tuple[torch.Tensor, torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Embed rotary positional embedding for query and key.
+
+  Args:
+    q (torch.Tensor): query tensor.
+    k (torch.Tensor): key tensor.
+    n_elem (int): number of elements to embed rotarty positional embedding.
+    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+  """
+  if n_elem > 0:
+    cos, sin = rope
+    q_roped = rotary_pos_emb.apply_rope(
+        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    k_roped = rotary_pos_emb.apply_rope(
+        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
+    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
+  return q, k
+
+
 class TransformerBlock(nn.Module):

   def __init__(
```
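`_embed_rope` centralizes the partial-rotary logic that `CausalSelfAttention` and `CrossAttention` previously computed inline: only the first `n_elem = rotary_percentage * head_dim` channels of the query and key are rotated, and the remaining channels pass through unchanged. The self-contained sketch below reproduces that split-rotate-concatenate pattern; the rotation itself is a standard rotate-half stand-in, since the exact tensor layout expected by `rotary_pos_emb.apply_rope` is not shown in this section.

```python
# Stand-in for the partial-RoPE pattern used by _embed_rope: rotate the first
# n_elem channels, leave the rest untouched. The rotate-half math here is an
# assumption; the real code delegates to rotary_pos_emb.apply_rope.
import torch


def partial_rope(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, n_elem: int
) -> torch.Tensor:
  if n_elem <= 0:
    return x
  x_rot, x_pass = x[..., :n_elem], x[..., n_elem:]
  x1, x2 = x_rot[..., : n_elem // 2], x_rot[..., n_elem // 2 :]
  rotated = torch.cat((-x2, x1), dim=-1)
  x_rot = x_rot * cos + rotated * sin
  return torch.cat((x_rot, x_pass), dim=-1)


# Identity rotation (cos=1, sin=0) leaves the tensor unchanged.
q = torch.randn(1, 4, 2, 8)   # [batch, seq, heads, head_dim]
n_elem = 4                    # e.g. rotary_percentage=0.5, head_dim=8
cos, sin = torch.ones(n_elem), torch.zeros(n_elem)
assert torch.allclose(partial_rope(q, cos, sin, n_elem), q)
```

The remaining hunks in attention.py thread the new `lora` argument through the block and attention forward paths.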
```diff
@@ -66,6 +94,7 @@ class TransformerBlock(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: kv_utils.KVCacheEntry = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the TransformerBlock.

@@ -75,6 +104,7 @@ class TransformerBlock(nn.Module):
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
       kv_cache (KVCacheEntry): the optional kv cache entry.
+      lora (LoRAEntry): the optional lora entry.

     Returns:
       output activation from this transformer block, and updated kv cache (if
@@ -83,7 +113,9 @@ class TransformerBlock(nn.Module):
     kv = None
     if self.config.parallel_residual:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -92,7 +124,9 @@ class TransformerBlock(nn.Module):
       output = x + attn_out + ff_out
     else:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -152,6 +186,7 @@ class CausalSelfAttention(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the CausalSelfAttention layer, which can support

@@ -162,7 +197,8 @@ class CausalSelfAttention(nn.Module):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry):
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.

     Returns:
       output activation from this self attention layer, and the updated
@@ -201,6 +237,11 @@ class CausalSelfAttention(nn.Module):
          dim=-1,
      )

+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
     q = self.query_norm(q)
     k = self.key_norm(k)

```
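`ai_edge_torch/generative/layers/lora.py` (+557 lines) is not included in this section, so the body of `lora_utils.apply_lora` is not visible here. Its call sites above imply the standard LoRA formulation: a low-rank delta computed from the block input `x`, added onto the frozen q/k/v projections and reshaped to the projected tensor's shape. The sketch below is a hypothetical reconstruction under that assumption, not the library's actual implementation.

```python
# Hypothetical apply_lora under the standard x @ A @ B low-rank formulation.
# Names (lora_a, lora_b) and the absence of an alpha/rank scaling factor are
# assumptions; the real implementation lives in the unshown lora.py.
from typing import Optional, Tuple

import torch


def apply_lora_sketch(
    x: torch.Tensor,                       # [B, T, model_dim] block input
    lora_a: torch.Tensor,                  # [model_dim, rank] down-projection
    lora_b: torch.Tensor,                  # [rank, out_dim] up-projection
    shape: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
  """Low-rank delta added onto a frozen projection's output."""
  delta = x @ lora_a @ lora_b              # [B, T, out_dim]
  return delta.reshape(shape) if shape is not None else delta


# Mirrors the call sites above, e.g. q += apply_lora_sketch(x, a_q, b_q, shape=q.shape)
```

The final hunks below switch both attention classes over to `_embed_rope`, rename the SDPA output, and add the LoRA delta on the output projection.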
```diff
@@ -211,14 +252,13 @@ class CausalSelfAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache

-
+    sdpa_out = self.sdpa_func(
        q,
        k,
        v,
@@ -226,10 +266,13 @@ class CausalSelfAttention(nn.Module):
        mask=mask,
        softcap=self.config.logit_softcap,
    )
-
+    sdpa_out = sdpa_out.reshape(B, T, -1)

    # Compute the output projection.
-    y = self.output_projection(
+    y = self.output_projection(sdpa_out)
+    if lora is not None:
+      y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
+
    return y if kv_cache is None else (y, kv_cache)


@@ -242,6 +285,7 @@ class SelfAttention(CausalSelfAttention):
      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      input_pos: Optional[torch.Tensor] = None,
      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
    """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.

@@ -249,18 +293,23 @@ class SelfAttention(CausalSelfAttention):
      x (torch.Tensor): the input tensor.
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
      input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry):
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.

    Returns:
      output activation from this self attention layer, and the updated
      KV Cach Entry (if passed in).
    """
    B, T, _ = x.size()
+    assert (
+        kv_cache is None
+    ), "KV cache is not supported in non-causal SelfAttention."
    return super().forward(
        x,
        rope=rope,
        mask=torch.zeros((B, 1, T, T), dtype=torch.float32),
        input_pos=input_pos,
+        lora=lora,
    )


@@ -317,6 +366,7 @@ class CrossAttention(nn.Module):
      mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
  ):
    """Forward function of the CrossAttention layer.

@@ -327,7 +377,8 @@ class CrossAttention(nn.Module):
      mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
        [B, n_heads, target_seq_len, source_seq_len].
      input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry):
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.

    Returns:
      output activation from this cross attention layer.
@@ -340,6 +391,11 @@ class CrossAttention(nn.Module):
    k = self.k_projection(y)
    v = self.v_projection(y)

+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
    interim_shape = (batch_size, -1, self.n_heads, self.config.head_dim)
    q = q.view(interim_shape)
    k = k.view(interim_shape)
@@ -348,8 +404,7 @@ class CrossAttention(nn.Module):
    if rope is not None:
      # Compute rotary positional embedding for query and key.
      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)

    if kv_cache is not None:
      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -363,4 +418,7 @@ class CrossAttention(nn.Module):

    # Compute the output projection.
    y = self.output_projection(y)
+    if lora is not None:
+      y += lora_utils.apply_lora(y, lora.attention.output)
+
    return y if kv_cache is None else (y, kv_cache)
```