ai-edge-torch-nightly 0.3.0.dev20250105__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
Files changed (32)
  1. ai_edge_torch/_config.py +26 -9
  2. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +13 -0
  3. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +36 -0
  4. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  5. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  6. ai_edge_torch/generative/examples/gemma/gemma2.py +25 -43
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  8. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  9. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  10. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  11. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  12. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  13. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  14. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  15. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  16. ai_edge_torch/generative/layers/attention.py +70 -12
  17. ai_edge_torch/generative/layers/lora.py +557 -0
  18. ai_edge_torch/generative/layers/normalization.py +2 -50
  19. ai_edge_torch/generative/layers/rotary_position_embedding.py +27 -34
  20. ai_edge_torch/generative/test/test_lora.py +147 -0
  21. ai_edge_torch/generative/utilities/converter.py +100 -47
  22. ai_edge_torch/generative/utilities/model_builder.py +21 -16
  23. ai_edge_torch/generative/utilities/verifier.py +4 -4
  24. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  25. ai_edge_torch/odml_torch/export.py +6 -2
  26. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  27. ai_edge_torch/version.py +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
  29. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +32 -30
  30. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
  31. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
  32. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'smollm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,20 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = smollm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
 
-    updated_kv_entries = []
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entries.append(kv_entry)
+        updated_kv_entires.append(kv_entry)
 
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'tinyllama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,21 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = tiny_llama.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
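In the two converter scripts shown above, the single tflite_path flag (which forced each script to assemble the output file name by hand from the quantization suffix and KV-cache length) is replaced by an output_path directory plus an output_name_prefix, and an optional lora_ranks list is now threaded through to converter.convert_to_tflite. A minimal usage sketch of the updated entry point, assuming only the keyword arguments visible in the hunks above; the checkpoint path and sequence lengths are placeholders, the ExportConfig import path follows the example scripts, and the final file-naming logic lives inside converter.convert_to_tflite, which is not shown in this diff:

    # Illustrative sketch only: mirrors the keyword arguments visible in the diff above.
    from ai_edge_torch.generative.examples.smollm import smollm
    from ai_edge_torch.generative.utilities import converter
    from ai_edge_torch.generative.utilities.model_builder import ExportConfig

    pytorch_model = smollm.build_model(
        '/path/to/checkpoint',  # placeholder checkpoint path
        kv_cache_max_len=1024,
    )
    converter.convert_to_tflite(
        pytorch_model,
        output_path='/tmp/',          # output directory (formerly tflite_path)
        output_name_prefix='smollm',  # prefix of the generated .tflite file name
        prefill_seq_len=[8, 64, 256],
        quantize=True,
        lora_ranks=[16],              # optional; None skips LoRA variants
        export_config=ExportConfig(),
    )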
@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
 
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import lora as lora_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
@@ -26,6 +27,33 @@ import torch
 from torch import nn
 
 
+def _embed_rope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    n_elem: int,
+    rope: Tuple[torch.Tensor, torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Embed rotary positional embedding for query and key.
+
+  Args:
+    q (torch.Tensor): query tensor.
+    k (torch.Tensor): key tensor.
+    n_elem (int): number of elements to embed rotarty positional embedding.
+    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+  """
+  if n_elem > 0:
+    cos, sin = rope
+    q_roped = rotary_pos_emb.apply_rope(
+        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    k_roped = rotary_pos_emb.apply_rope(
+        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
+    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
+  return q, k
+
+
 class TransformerBlock(nn.Module):
 
   def __init__(
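The new _embed_rope helper centralizes partial rotary embedding for both CausalSelfAttention and CrossAttention: only the first n_elem channels of q and k are rotated (with n_elem = rotary_percentage * head_dim, as computed at the call sites below), and the remaining channels pass through unchanged. A self-contained sketch of that idea in plain torch; rotate_half and apply_partial_rope are illustrative names, not part of ai_edge_torch, and the library's actual rotation is performed by rotary_pos_emb.apply_rope as shown above:

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
      # Common RoPE helper: swap the two halves of the last dim, negating one.
      x1, x2 = x.chunk(2, dim=-1)
      return torch.cat((-x2, x1), dim=-1)

    def apply_partial_rope(q, k, n_elem, cos, sin):
      # Rotate only the first n_elem channels; leave the rest untouched.
      if n_elem <= 0:
        return q, k
      q_rot = q[..., :n_elem] * cos + rotate_half(q[..., :n_elem]) * sin
      k_rot = k[..., :n_elem] * cos + rotate_half(k[..., :n_elem]) * sin
      q = torch.cat((q_rot, q[..., n_elem:]), dim=-1)
      k = torch.cat((k_rot, k[..., n_elem:]), dim=-1)
      return q, k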
@@ -66,6 +94,7 @@ class TransformerBlock(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: kv_utils.KVCacheEntry = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the TransformerBlock.
 
@@ -75,6 +104,7 @@ class TransformerBlock(nn.Module):
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
       kv_cache (KVCacheEntry): the optional kv cache entry.
+      lora (LoRAEntry): the optional lora entry.
 
     Returns:
       output activation from this transformer block, and updated kv cache (if
@@ -83,7 +113,9 @@ class TransformerBlock(nn.Module):
     kv = None
     if self.config.parallel_residual:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -92,7 +124,9 @@ class TransformerBlock(nn.Module):
       output = x + attn_out + ff_out
     else:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -152,6 +186,7 @@ class CausalSelfAttention(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the CausalSelfAttention layer, which can support
 
@@ -162,7 +197,8 @@ class CausalSelfAttention(nn.Module):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.
 
     Returns:
       output activation from this self attention layer, and the updated
@@ -201,6 +237,11 @@ class CausalSelfAttention(nn.Module):
           dim=-1,
       )
 
+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
     q = self.query_norm(q)
     k = self.key_norm(k)
 
@@ -211,14 +252,13 @@ class CausalSelfAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
 
-    y = self.sdpa_func(
+    sdpa_out = self.sdpa_func(
         q,
         k,
         v,
@@ -226,10 +266,13 @@ class CausalSelfAttention(nn.Module):
         mask=mask,
         softcap=self.config.logit_softcap,
     )
-    y = y.reshape(B, T, -1)
+    sdpa_out = sdpa_out.reshape(B, T, -1)
 
     # Compute the output projection.
-    y = self.output_projection(y)
+    y = self.output_projection(sdpa_out)
+    if lora is not None:
+      y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
+
     return y if kv_cache is None else (y, kv_cache)
 
 
@@ -242,6 +285,7 @@ class SelfAttention(CausalSelfAttention):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.
 
@@ -249,18 +293,23 @@ class SelfAttention(CausalSelfAttention):
       x (torch.Tensor): the input tensor.
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.
 
     Returns:
       output activation from this self attention layer, and the updated
       KV Cach Entry (if passed in).
     """
     B, T, _ = x.size()
+    assert (
+        kv_cache is None
+    ), "KV cache is not supported in non-causal SelfAttention."
     return super().forward(
         x,
         rope=rope,
         mask=torch.zeros((B, 1, T, T), dtype=torch.float32),
         input_pos=input_pos,
+        lora=lora,
     )
 
 
@@ -317,6 +366,7 @@ class CrossAttention(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ):
     """Forward function of the CrossAttention layer.
 
@@ -327,7 +377,8 @@ class CrossAttention(nn.Module):
       mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
         [B, n_heads, target_seq_len, source_seq_len].
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.
 
     Returns:
       output activation from this cross attention layer.
@@ -340,6 +391,11 @@ class CrossAttention(nn.Module):
     k = self.k_projection(y)
     v = self.v_projection(y)
 
+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
     interim_shape = (batch_size, -1, self.n_heads, self.config.head_dim)
     q = q.view(interim_shape)
     k = k.view(interim_shape)
@@ -348,8 +404,7 @@ class CrossAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -363,4 +418,7 @@ class CrossAttention(nn.Module):
 
     # Compute the output projection.
     y = self.output_projection(y)
+    if lora is not None:
+      y += lora_utils.apply_lora(y, lora.attention.output)
+
     return y if kv_cache is None else (y, kv_cache)
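The new lora.py module itself (+557 lines, listed under Files changed above) does not appear in these hunks, so the exact implementation of lora_utils.apply_lora is not visible here. The call sites are consistent with the standard LoRA formulation, in which a frozen projection's output W x is augmented with a low-rank delta x A B. A minimal sketch of that delta under that assumption; LoRAWeight and apply_lora_sketch are hypothetical names, not the ai_edge_torch API, and any alpha/rank scaling is assumed to be folded into the factors:

    import dataclasses
    import torch

    @dataclasses.dataclass
    class LoRAWeight:
      # Hypothetical container for one rank-r weight update.
      a_prime: torch.Tensor  # shape (in_features, rank)
      b_prime: torch.Tensor  # shape (rank, out_features)

    def apply_lora_sketch(x: torch.Tensor, w: LoRAWeight, shape=None) -> torch.Tensor:
      # Low-rank delta added on top of the frozen projection's output:
      # delta = x @ A' @ B', optionally reshaped to match the target tensor.
      delta = x @ w.a_prime @ w.b_prime
      return delta.reshape(shape) if shape is not None else delta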