ai-edge-torch-nightly: 0.3.0.dev20250107-py3-none-any.whl → 0.3.0.dev20250109-py3-none-any.whl

Files changed (38)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  3. ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
  4. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  5. ai_edge_torch/generative/examples/llama/llama.py +29 -25
  6. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  7. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  8. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  9. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  10. ai_edge_torch/generative/examples/phi/phi3.py +26 -23
  11. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  12. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  13. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  14. ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
  15. ai_edge_torch/generative/examples/smollm/verify.py +18 -2
  16. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  17. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  18. ai_edge_torch/generative/layers/attention.py +45 -37
  19. ai_edge_torch/generative/layers/lora.py +557 -0
  20. ai_edge_torch/generative/layers/model_config.py +6 -2
  21. ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
  22. ai_edge_torch/generative/test/test_lora.py +147 -0
  23. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  24. ai_edge_torch/generative/utilities/converter.py +100 -47
  25. ai_edge_torch/generative/utilities/model_builder.py +23 -14
  26. ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
  27. ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
  28. ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
  29. ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
  30. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  31. ai_edge_torch/odml_torch/export.py +6 -2
  32. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  33. ai_edge_torch/version.py +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
  35. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +38 -35
  36. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
  37. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
  38. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py

@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'qwen',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,12 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+
 
 _BUILDER = {
     '0.5b': qwen.build_0_5b_model,
@@ -67,16 +78,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  model_size = _MODEL_SIZE.value.replace('.', '_')
-  output_filename = (
-      f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
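Note: the conversion scripts in this release move from a full tflite_path to an output_path plus output_name_prefix pair, and gain an optional lora_ranks flag. A minimal sketch of calling the updated converter.convert_to_tflite directly, using only the keyword arguments visible in the diff above; the checkpoint path, sequence lengths, and the LoRA ranks are illustrative values, not repo defaults:

# Sketch: invoking the updated conversion API directly (values illustrative).
import os
import pathlib

from ai_edge_torch.generative.examples.qwen import qwen
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

checkpoint = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen')
pytorch_model = qwen.build_0_5b_model(checkpoint, kv_cache_max_len=1280)
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',              # replaces the old tflite_path flag
    output_name_prefix='qwen',        # output file name derived from this
    prefill_seq_len=[8, 64, 128, 256, 512, 1024],
    quantize=True,
    lora_ranks=[16],                  # assumption: any list of ranks, or None
    export_config=ExportConfig(),
)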
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'smollm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,20 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = smollm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py (new file)

@@ -0,0 +1,71 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting SmolLM2 model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = smollm.build_model_v2(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'smollm2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
+      quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
ai_edge_torch/generative/examples/smollm/smollm.py

@@ -85,3 +85,41 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM,
   )
+
+
+class SmolLM2(model_builder.DecoderOnlyModel):
+  """A SmolLM2 model built from the Edge Generative API layers."""
+  pass
+
+
+def get_model_config_v2(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmolLM2 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM2 model.
+  """
+  config = get_model_config(kv_cache_max_len)
+  config.block_config(0).attn_config.rotary_base = 100000
+  return config
+
+
+def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config_v2(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # SmolLM2 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model_v2(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config_v2(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=SmolLM2,
+  )
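Note: the v2 additions reuse the v1 SmolLM configuration and only raise the RoPE base frequency. A short sketch of how the new helpers compose, assuming a local SmolLM2 checkpoint path (illustrative):

# Sketch: building SmolLM2 via the new v2 helpers (path illustrative).
from ai_edge_torch.generative.examples.smollm import smollm

config = smollm.get_model_config_v2(kv_cache_max_len=1280)
# Per this diff, v2 differs from v1 only in the rotary embedding base.
assert config.block_config(0).attn_config.rotary_base == 100000

model = smollm.build_model_v2(
    '/path/to/smollm2_checkpoint', kv_cache_max_len=1280
)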
ai_edge_torch/generative/examples/smollm/verify.py

@@ -36,10 +36,26 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     30,
     "The maximum size of the generated tokens.",
 )
+_MODEL_VERSION = flags.DEFINE_enum(
+    "model_version",
+    "v1",
+    ["v1", "v2"],
+    "The version of SmolLM to verify.",
+)
+_CHECKPOINT = {
+    "v1": "HuggingFaceTB/SmolLM-135M",
+    "v2": "HuggingFaceTB/SmolLM2-135M",
+}
+
+_BUILDER = {
+    "v1": smollm.build_model,
+    "v2": smollm.build_model_v2,
+}
 
 
 def main(_):
-  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  checkpoint = _CHECKPOINT[_MODEL_VERSION.value]
+  builder = _BUILDER[_MODEL_VERSION.value]
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
@@ -49,7 +65,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = smollm.build_model(reauthored_checkpoint)
+  reauthored_model = builder(reauthored_checkpoint)
 
   logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
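Note: with the new model_version flag, the same script verifies either generation, e.g. "verify.py --model_version=v2". A condensed sketch of the v2 path it resolves to (model IDs come from the diff; the cached checkpoint path is illustrative):

# Sketch: what verify.py does for --model_version=v2 (path illustrative).
import transformers
from ai_edge_torch.generative.examples.smollm import smollm

checkpoint = "HuggingFaceTB/SmolLM2-135M"
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
reauthored_model = smollm.build_model_v2("/path/to/cached/smollm2")
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)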
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
 
-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
+        updated_kv_entries.append(kv_entry)
 
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'tinyllama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,21 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = tiny_llama.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/layers/attention.py

@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
 
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import lora as lora_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
@@ -26,33 +27,6 @@ import torch
 from torch import nn
 
 
-def _embed_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    n_elem: int,
-    rope: Tuple[torch.Tensor, torch.Tensor],
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Embed rotary positional embedding for query and key.
-
-  Args:
-    q (torch.Tensor): query tensor.
-    k (torch.Tensor): key tensor.
-    n_elem (int): number of elements to embed rotary positional embedding.
-    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-  """
-  if n_elem > 0:
-    cos, sin = rope
-    q_roped = rotary_pos_emb.apply_rope(
-        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    k_roped = rotary_pos_emb.apply_rope(
-        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
-    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
-  return q, k
-
-
 class TransformerBlock(nn.Module):
 
   def __init__(
@@ -93,6 +67,7 @@
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: kv_utils.KVCacheEntry = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the TransformerBlock.
 
@@ -102,6 +77,7 @@
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
       kv_cache (KVCacheEntry): the optional kv cache entry.
+      lora (LoRAEntry): the optional LoRA entry.
 
     Returns:
       output activation from this transformer block, and updated kv cache (if
@@ -110,7 +86,9 @@
     kv = None
     if self.config.parallel_residual:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -119,7 +97,9 @@
       output = x + attn_out + ff_out
     else:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -179,6 +159,7 @@
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the CausalSelfAttention layer, which can support
 
@@ -189,7 +170,8 @@
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
      mask (torch.Tensor): the optional mask tensor.
      input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional LoRA entry.
 
     Returns:
       output activation from this self attention layer, and the updated
@@ -228,6 +210,11 @@
           dim=-1,
       )
 
+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
     q = self.query_norm(q)
     k = self.key_norm(k)
 
@@ -238,13 +225,14 @@
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
 
-    y = self.sdpa_func(
+    sdpa_out = self.sdpa_func(
         q,
         k,
         v,
@@ -252,10 +240,13 @@
         mask=mask,
         softcap=self.config.logit_softcap,
     )
-    y = y.reshape(B, T, -1)
+    sdpa_out = sdpa_out.reshape(B, T, -1)
 
     # Compute the output projection.
-    y = self.output_projection(y)
+    y = self.output_projection(sdpa_out)
+    if lora is not None:
+      y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
+
     return y if kv_cache is None else (y, kv_cache)
 
 
@@ -268,6 +259,7 @@
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      input_pos: Optional[torch.Tensor] = None,
      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.
 
@@ -275,18 +267,23 @@
       x (torch.Tensor): the input tensor.
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
      input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional LoRA entry.
 
     Returns:
      output activation from this self attention layer, and the updated
      KV Cache Entry (if passed in).
    """
    B, T, _ = x.size()
+    assert (
+        kv_cache is None
+    ), "KV cache is not supported in non-causal SelfAttention."
    return super().forward(
        x,
        rope=rope,
        mask=torch.zeros((B, 1, T, T), dtype=torch.float32),
        input_pos=input_pos,
+        lora=lora,
    )
 
 
@@ -343,6 +340,7 @@
      mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
  ):
    """Forward function of the CrossAttention layer.
 
@@ -353,7 +351,8 @@
      mask (torch.Tensor): the optional mask tensor that can be broadcasted to
        shape [B, n_heads, target_seq_len, source_seq_len].
      input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional LoRA entry.
 
    Returns:
      output activation from this cross attention layer.
@@ -366,6 +365,11 @@
    k = self.k_projection(y)
    v = self.v_projection(y)
 
+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
    interim_shape = (batch_size, -1, self.n_heads, self.config.head_dim)
    q = q.view(interim_shape)
    k = k.view(interim_shape)
@@ -374,7 +378,8 @@
    if rope is not None:
      # Compute rotary positional embedding for query and key.
      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
    if kv_cache is not None:
      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -388,4 +393,7 @@
 
    # Compute the output projection.
    y = self.output_projection(y)
+    if lora is not None:
+      y += lora_utils.apply_lora(y, lora.attention.output)
+
    return y if kv_cache is None else (y, kv_cache)
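
Note: the new lora.py module (557 added lines, listed under "Files changed" but not displayed in this diff) supplies lora_utils.apply_lora and the LoRAEntry container used above. As a mental model only, here is a minimal sketch of the conventional low-rank update that matches these call sites; the field names and the absence of any scaling factor are assumptions about an implementation this diff does not show:

# Sketch of a conventional LoRA delta matching the apply_lora(x, w, shape=...)
# call sites above. This is NOT the shipped lora.py implementation.
import dataclasses
from typing import Optional, Tuple

import torch


@dataclasses.dataclass
class LoRAWeightSketch:
  """Hypothetical low-rank factor pair; field names are assumptions."""

  a: torch.Tensor  # [in_features, rank]
  b: torch.Tensor  # [rank, out_features]


def apply_lora_sketch(
    x: torch.Tensor,
    weight: LoRAWeightSketch,
    shape: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
  """Computes the low-rank delta (x @ A) @ B, optionally reshaped.

  attention.py adds this delta onto the q/k/v projections and onto the
  output projection, which is why the result must match `shape`.
  """
  output = torch.matmul(torch.matmul(x, weight.a), weight.b)
  if shape is not None:
    output = output.reshape(shape)
  return output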