ai-edge-torch-nightly 0.3.0.dev20250105__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl

Files changed (32)
  1. ai_edge_torch/_config.py +26 -9
  2. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +13 -0
  3. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +36 -0
  4. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  5. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  6. ai_edge_torch/generative/examples/gemma/gemma2.py +25 -43
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  8. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  9. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  10. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  11. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  12. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  13. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  14. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  15. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  16. ai_edge_torch/generative/layers/attention.py +70 -12
  17. ai_edge_torch/generative/layers/lora.py +557 -0
  18. ai_edge_torch/generative/layers/normalization.py +2 -50
  19. ai_edge_torch/generative/layers/rotary_position_embedding.py +27 -34
  20. ai_edge_torch/generative/test/test_lora.py +147 -0
  21. ai_edge_torch/generative/utilities/converter.py +100 -47
  22. ai_edge_torch/generative/utilities/model_builder.py +21 -16
  23. ai_edge_torch/generative/utilities/verifier.py +4 -4
  24. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  25. ai_edge_torch/odml_torch/export.py +6 -2
  26. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  27. ai_edge_torch/version.py +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
  29. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +32 -30
  30. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
  31. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
  32. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'smollm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,20 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = smollm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
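With this change the smollm example names its output from a prefix instead of a full file path, and can optionally be converted with LoRA support. A hypothetical invocation under the new flags (the checkpoint flag name is inferred from the _CHECKPOINT_PATH definition above; values are illustrative, and --lora_ranks repeats once per rank, as usual for absl multi-flags):

    python ai_edge_torch/generative/examples/smollm/convert_to_tflite.py \
        --checkpoint_path="$HOME/Downloads/llm_data/smollm" \
        --output_path=/tmp/ \
        --output_name_prefix=smollm \
        --quantize=true \
        --lora_ranks=4 --lora_ranks=16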
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]

-    updated_kv_entries = []
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entries.append(kv_entry)
+        updated_kv_entires.append(kv_entry)

-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

     if export_config is not None:
       if (
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'tinyllama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,21 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = tiny_llama.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
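Note that the deleted quant_suffix/output_filename lines are not moved into the scripts; the naming responsibility presumably shifts into the updated converter.convert_to_tflite (the utilities/converter.py diff is not excerpted on this page). A minimal sketch of that derivation, assuming the converter keeps the old q8/f32 and ekv naming scheme; build_output_filename is a hypothetical helper, not the library's API:

    import os

    def build_output_filename(
        output_path: str,
        output_name_prefix: str,
        quantize: bool,
        kv_cache_max_len: int,
    ) -> str:
      # Reconstruction of the naming scheme the example scripts used to inline,
      # e.g. 'tinyllama_q8_ekv1024.tflite' for a quantized model.
      quant_suffix = 'q8' if quantize else 'f32'
      name = f'{output_name_prefix}_{quant_suffix}_ekv{kv_cache_max_len}.tflite'
      return os.path.join(output_path, name)

    # build_output_filename('/tmp/', 'tinyllama', True, 1024)
    #   -> '/tmp/tinyllama_q8_ekv1024.tflite'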
ai_edge_torch/generative/layers/attention.py

@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union

 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import lora as lora_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
@@ -26,6 +27,33 @@ import torch
 from torch import nn


+def _embed_rope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    n_elem: int,
+    rope: Tuple[torch.Tensor, torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Embeds rotary positional embeddings into the query and key tensors.
+
+  Args:
+    q (torch.Tensor): query tensor.
+    k (torch.Tensor): key tensor.
+    n_elem (int): number of elements to apply rotary positional embedding to.
+    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+  """
+  if n_elem > 0:
+    cos, sin = rope
+    q_roped = rotary_pos_emb.apply_rope(
+        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    k_roped = rotary_pos_emb.apply_rope(
+        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
+    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
+  return q, k
+
+
 class TransformerBlock(nn.Module):

   def __init__(
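The new module-level _embed_rope helper centralizes the partial-rotary pattern that was previously inlined at each call site: only the first n_elem features of each head are rotated, and the rest pass through unchanged. A self-contained demo of the same slice-rotate-concat pattern; toy_apply_rope is a simplified rotate-half stand-in, not the library's rotary_pos_emb.apply_rope, and the broadcast shapes are chosen for the demo:

    import torch

    def toy_apply_rope(x, cos, sin):
      # Rotate-half stand-in for rotary_pos_emb.apply_rope.
      half = x.shape[-1] // 2
      x1, x2 = x[..., :half], x[..., half:]
      rotated = torch.cat((-x2, x1), dim=-1)
      return x * cos + rotated * sin

    B, T, H, head_dim = 1, 4, 2, 8
    n_elem = int(0.5 * head_dim)  # rotary_percentage = 0.5: rotate 4 of 8 features
    q = torch.randn(B, T, H, head_dim)
    theta = torch.arange(T, dtype=torch.float32)[:, None]       # toy angles, (T, 1)
    cos = torch.cos(theta).repeat(1, n_elem // 2).repeat(1, 2)  # widen to (T, n_elem),
    sin = torch.sin(theta).repeat(1, n_elem // 2).repeat(1, 2)  # as cos.repeat(1, 2) above
    cos, sin = cos[None, :, None, :], sin[None, :, None, :]     # broadcast over B and H
    q_roped = toy_apply_rope(q[..., :n_elem], cos, sin)
    q_out = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
    assert torch.equal(q_out[..., n_elem:], q[..., n_elem:])    # untouched tail survives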
@@ -66,6 +94,7 @@ class TransformerBlock(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: kv_utils.KVCacheEntry = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the TransformerBlock.

@@ -75,6 +104,7 @@ class TransformerBlock(nn.Module):
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
       kv_cache (KVCacheEntry): the optional kv cache entry.
+      lora (LoRAEntry): the optional lora entry.

     Returns:
       output activation from this transformer block, and updated kv cache (if
@@ -83,7 +113,9 @@ class TransformerBlock(nn.Module):
     kv = None
     if self.config.parallel_residual:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -92,7 +124,9 @@ class TransformerBlock(nn.Module):
       output = x + attn_out + ff_out
     else:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -152,6 +186,7 @@ class CausalSelfAttention(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the CausalSelfAttention layer, which can support

@@ -162,7 +197,8 @@ class CausalSelfAttention(nn.Module):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.

     Returns:
       output activation from this self attention layer, and the updated
@@ -201,6 +237,11 @@ class CausalSelfAttention(nn.Module):
         dim=-1,
     )

+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
     q = self.query_norm(q)
     k = self.key_norm(k)

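The new lora.py (+557 lines) is not excerpted on this page, so only apply_lora's call sites are visible: it takes the projection's input, a per-projection entry such as lora.attention.query, and an optional target shape. A minimal sketch consistent with those call sites, assuming the standard low-rank update x @ A @ B; the names ToyLoRAWeight, a, and b are placeholders, not the real LoRAEntry layout:

    import dataclasses
    from typing import Optional, Tuple

    import torch

    @dataclasses.dataclass
    class ToyLoRAWeight:
      a: torch.Tensor  # (in_features, rank) down-projection
      b: torch.Tensor  # (rank, out_features) up-projection

    def toy_apply_lora(
        x: torch.Tensor,
        weight: ToyLoRAWeight,
        shape: Optional[Tuple[int, ...]] = None,
    ) -> torch.Tensor:
      # Low-rank delta added onto the frozen projection's output, reshaped so
      # the `q += ...` / `k += ...` / `v += ...` additions above line up.
      out = x @ weight.a @ weight.b
      return out.reshape(shape) if shape is not None else out

    B, T, D, rank = 1, 8, 64, 4
    x = torch.randn(B, T, D)
    w = ToyLoRAWeight(a=torch.randn(D, rank) * 0.01, b=torch.zeros(rank, D))
    delta = toy_apply_lora(x, w, shape=(B, T, D))  # zeros under the usual B = 0 init
    assert torch.equal(delta, torch.zeros(B, T, D))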
@@ -211,14 +252,13 @@ class CausalSelfAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache

-    y = self.sdpa_func(
+    sdpa_out = self.sdpa_func(
         q,
         k,
         v,
@@ -226,10 +266,13 @@ class CausalSelfAttention(nn.Module):
         mask=mask,
         softcap=self.config.logit_softcap,
     )
-    y = y.reshape(B, T, -1)
+    sdpa_out = sdpa_out.reshape(B, T, -1)

     # Compute the output projection.
-    y = self.output_projection(y)
+    y = self.output_projection(sdpa_out)
+    if lora is not None:
+      y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
+
     return y if kv_cache is None else (y, kv_cache)


@@ -242,6 +285,7 @@ class SelfAttention(CausalSelfAttention):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.

@@ -249,18 +293,23 @@ class SelfAttention(CausalSelfAttention):
       x (torch.Tensor): the input tensor.
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.

     Returns:
       output activation from this self attention layer, and the updated
       KV Cache Entry (if passed in).
     """
     B, T, _ = x.size()
+    assert (
+        kv_cache is None
+    ), "KV cache is not supported in non-causal SelfAttention."
     return super().forward(
         x,
         rope=rope,
         mask=torch.zeros((B, 1, T, T), dtype=torch.float32),
         input_pos=input_pos,
+        lora=lora,
     )

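The new assert makes explicit what the all-zero mask already implied: this non-causal variant attends bidirectionally, so a position-indexed KV cache does not apply. The contrast between the two additive masks, as a quick runnable check:

    import torch

    T = 4
    bidirectional = torch.zeros(1, 1, T, T)  # what SelfAttention passes: see everything
    causal = torch.triu(torch.full((1, 1, T, T), float('-inf')), diagonal=1)
    # Adding `bidirectional` to attention scores is a no-op; adding `causal` sends
    # scores for future positions to -inf, which softmax turns into zero weight.
    probs = torch.softmax(torch.zeros(1, 1, T, T) + causal, dim=-1)
    assert probs[0, 0, 0, 1:].sum() == 0  # token 0 cannot attend to tokens 1..3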
@@ -317,6 +366,7 @@ class CrossAttention(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ):
     """Forward function of the CrossAttention layer.

@@ -327,7 +377,8 @@ class CrossAttention(nn.Module):
       mask (torch.Tensor): the optional mask tensor that can be broadcast to
         shape [B, n_heads, target_seq_len, source_seq_len].
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.

     Returns:
       output activation from this cross attention layer.
@@ -340,6 +391,11 @@ class CrossAttention(nn.Module):
     k = self.k_projection(y)
     v = self.v_projection(y)

+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
     interim_shape = (batch_size, -1, self.n_heads, self.config.head_dim)
     q = q.view(interim_shape)
     k = k.view(interim_shape)
@@ -348,8 +404,7 @@ class CrossAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -363,4 +418,7 @@ class CrossAttention(nn.Module):

     # Compute the output projection.
     y = self.output_projection(y)
+    if lora is not None:
+      y += lora_utils.apply_lora(y, lora.attention.output)
+
     return y if kv_cache is None else (y, kv_cache)
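
Taken together, these hunks thread one optional argument from the block level down to every projection, with None preserving the base model exactly. A stand-in sketch of that plumbing pattern (the classes here are illustrative, not the ai_edge_torch ones; a real caller would pass per-layer LoRAEntry objects built by lora.py):

    from typing import List, Optional

    import torch
    from torch import nn

    class Block(nn.Module):
      # Stand-in for TransformerBlock: lora defaults to None, so existing
      # callers and exports without LoRA are untouched.
      def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

      def forward(self, x: torch.Tensor, lora: Optional[torch.Tensor] = None):
        y = self.proj(x)
        if lora is not None:
          y = y + x @ lora  # stand-in for lora_utils.apply_lora(...)
        return y

    class Model(nn.Module):
      def __init__(self, dim: int, n_layers: int):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(n_layers))

      def forward(self, x, loras: Optional[List[torch.Tensor]] = None):
        for i, block in enumerate(self.blocks):
          x = block(x, lora=loras[i] if loras is not None else None)
        return x

    m = Model(dim=16, n_layers=2).eval()
    x = torch.randn(1, 4, 16)
    with torch.no_grad():
      assert torch.equal(m(x), m(x, loras=None))  # None path is the base model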