ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250109__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (38)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  3. ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
  4. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  5. ai_edge_torch/generative/examples/llama/llama.py +29 -25
  6. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  7. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  8. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  9. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  10. ai_edge_torch/generative/examples/phi/phi3.py +26 -23
  11. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  12. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  13. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  14. ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
  15. ai_edge_torch/generative/examples/smollm/verify.py +18 -2
  16. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  17. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  18. ai_edge_torch/generative/layers/attention.py +45 -37
  19. ai_edge_torch/generative/layers/lora.py +557 -0
  20. ai_edge_torch/generative/layers/model_config.py +6 -2
  21. ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
  22. ai_edge_torch/generative/test/test_lora.py +147 -0
  23. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  24. ai_edge_torch/generative/utilities/converter.py +100 -47
  25. ai_edge_torch/generative/utilities/model_builder.py +23 -14
  26. ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
  27. ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
  28. ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
  29. ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
  30. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  31. ai_edge_torch/odml_torch/export.py +6 -2
  32. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  33. ai_edge_torch/version.py +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
  35. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +38 -35
  36. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
  37. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
  38. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
  os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
  'The path to the model checkpoint, or directory holding the checkpoint.',
  )
- _TFLITE_PATH = flags.DEFINE_string(
- 'tflite_path',
+ _OUTPUT_PATH = flags.DEFINE_string(
+ 'output_path',
  '/tmp/',
- 'The tflite file path to export.',
+ 'The path to export the tflite model.',
+ )
+ _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+ 'output_name_prefix',
+ 'qwen',
+ 'The prefix of the output tflite model name.',
  )
  _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
  'prefill_seq_lens',
@@ -55,6 +60,12 @@ _QUANTIZE = flags.DEFINE_bool(
  True,
  'Whether the model should be quantized.',
  )
+ _LORA_RANKS = flags.DEFINE_multi_integer(
+ 'lora_ranks',
+ None,
+ 'If set, the model will be converted with the provided list of LoRA ranks.',
+ )
+

  _BUILDER = {
  '0.5b': qwen.build_0_5b_model,
@@ -67,16 +78,13 @@ def main(_):
  pytorch_model = _BUILDER[_MODEL_SIZE.value](
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
- quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
- model_size = _MODEL_SIZE.value.replace('.', '_')
- output_filename = (
- f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
- )
  converter.convert_to_tflite(
  pytorch_model,
- tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+ output_path=_OUTPUT_PATH.value,
+ output_name_prefix=_OUTPUT_NAME_PREFIX.value,
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
  quantize=_QUANTIZE.value,
+ lora_ranks=_LORA_RANKS.value,
  export_config=ExportConfig(),
  )

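For readers tracking the converter API change above: the example scripts no longer assemble the .tflite file name themselves; they pass a directory (output_path), a name prefix (output_name_prefix), and an optional list of LoRA ranks to converter.convert_to_tflite, which presumably derives the final file name itself (converter.py is also updated in this release). Below is a minimal standalone sketch of the new call. The keyword arguments mirror the hunk above; the checkpoint path, kv_cache_max_len, prefill_seq_len, and lora_ranks values are illustrative placeholders, not defaults taken from this diff.

from ai_edge_torch.generative.examples.qwen import qwen
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

# Build the PyTorch model from a local checkpoint (path is a placeholder).
pytorch_model = qwen.build_0_5b_model(
    '/path/to/qwen_checkpoint', kv_cache_max_len=1280
)

# New-style call: pass a directory and name prefix instead of a full tflite_path.
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',
    output_name_prefix='qwen',
    prefill_seq_len=[8, 64, 128, 256, 512, 1024],
    quantize=True,
    lora_ranks=[16],  # optional; pass None (or omit the flag) to skip LoRA signatures
    export_config=ExportConfig(),
)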
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
  os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
  'The path to the model checkpoint, or directory holding the checkpoint.',
  )
- _TFLITE_PATH = flags.DEFINE_string(
- 'tflite_path',
+ _OUTPUT_PATH = flags.DEFINE_string(
+ 'output_path',
  '/tmp/',
- 'The tflite file path to export.',
+ 'The path to export the tflite model.',
+ )
+ _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+ 'output_name_prefix',
+ 'smollm',
+ 'The prefix of the output tflite model name.',
  )
  _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
  'prefill_seq_lens',
@@ -49,20 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
  True,
  'Whether the model should be quantized.',
  )
+ _LORA_RANKS = flags.DEFINE_multi_integer(
+ 'lora_ranks',
+ None,
+ 'If set, the model will be converted with the provided list of LoRA ranks.',
+ )


  def main(_):
  pytorch_model = smollm.build_model(
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
-
- quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
- output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
  converter.convert_to_tflite(
  pytorch_model,
- tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+ output_path=_OUTPUT_PATH.value,
+ output_name_prefix=_OUTPUT_NAME_PREFIX.value,
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
  quantize=_QUANTIZE.value,
+ lora_ranks=_LORA_RANKS.value,
  export_config=ExportConfig(),
  )

ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py (new file)
@@ -0,0 +1,71 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of converting SmolLM2 model to multi-signature tflite model."""
+
+ import os
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.smollm import smollm
+ from ai_edge_torch.generative.utilities import converter
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+ _CHECKPOINT_PATH = flags.DEFINE_string(
+ 'checkpoint_path',
+ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2'),
+ 'The path to the model checkpoint, or directory holding the checkpoint.',
+ )
+ _TFLITE_PATH = flags.DEFINE_string(
+ 'tflite_path',
+ '/tmp/',
+ 'The tflite file path to export.',
+ )
+ _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+ 'prefill_seq_lens',
+ (8, 64, 128, 256, 512, 1024),
+ 'List of the maximum sizes of prefill input tensors.',
+ )
+ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+ 'kv_cache_max_len',
+ 1280,
+ 'The maximum size of KV cache buffer, including both prefill and decode.',
+ )
+ _QUANTIZE = flags.DEFINE_bool(
+ 'quantize',
+ True,
+ 'Whether the model should be quantized.',
+ )
+
+
+ def main(_):
+ pytorch_model = smollm.build_model_v2(
+ _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+ )
+
+ quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+ output_filename = f'smollm2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+ converter.convert_to_tflite(
+ pytorch_model,
+ tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+ prefill_seq_len=_PREFILL_SEQ_LENS.value,
+ quantize=_QUANTIZE.value,
+ export_config=ExportConfig(),
+ )
+
+
+ if __name__ == '__main__':
+ app.run(main)
ai_edge_torch/generative/examples/smollm/smollm.py
@@ -85,3 +85,41 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
  tensor_names=TENSOR_NAMES,
  model_class=SmolLM,
  )
+
+
+ class SmolLM2(model_builder.DecoderOnlyModel):
+ """A SmolLM2 model built from the Edge Generative API layers."""
+ pass
+
+
+ def get_model_config_v2(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+ """Returns the model config for a SmolLM2 135M model.
+
+ Args:
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+ is 1024.
+
+ Returns:
+ The model config for a SmolLM2 model.
+ """
+ config = get_model_config(kv_cache_max_len)
+ config.block_config(0).attn_config.rotary_base = 100000
+ return config
+
+
+ def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
+ config = get_model_config_v2(**kwargs)
+ config.vocab_size = 128
+ config.num_layers = 2
+ # SmolLM2 has only one block config.
+ config.block_config(0).ff_config.intermediate_size = 64
+ return config
+
+
+ def build_model_v2(checkpoint_path: str, **kwargs) -> nn.Module:
+ return model_builder.build_decoder_only_model(
+ checkpoint_path=checkpoint_path,
+ config=get_model_config_v2(**kwargs),
+ tensor_names=TENSOR_NAMES,
+ model_class=SmolLM2,
+ )
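As the hunk above shows, SmolLM2 reuses the SmolLM configuration and only overrides the attention rotary embedding base. A hypothetical sanity check of that relationship (not part of the diff; it assumes get_model_config accepts the same kv_cache_max_len argument that get_model_config_v2 forwards to it) could look like:

from ai_edge_torch.generative.examples.smollm import smollm

# Both configs come from the same base helper; v2 only overrides the RoPE base.
v1_config = smollm.get_model_config(1024)
v2_config = smollm.get_model_config_v2(kv_cache_max_len=1024)
assert v2_config.block_config(0).attn_config.rotary_base == 100000
print(v1_config.block_config(0).attn_config.rotary_base)  # base used by SmolLM v1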
ai_edge_torch/generative/examples/smollm/verify.py
@@ -36,10 +36,26 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
  30,
  "The maximum size of the generated tokens.",
  )
+ _MODEL_VERSION = flags.DEFINE_enum(
+ "model_version",
+ "v1",
+ ["v1", "v2"],
+ "The version of SmolLm to verify.",
+ )
+ _CHECKPOINT = {
+ "v1": "HuggingFaceTB/SmolLM-135M",
+ "v2": "HuggingFaceTB/SmolLM2-135M",
+ }
+
+ _BUILDER = {
+ "v1": smollm.build_model,
+ "v2": smollm.build_model_v2,
+ }


  def main(_):
- checkpoint = "HuggingFaceTB/SmolLM-135M"
+ checkpoint = _CHECKPOINT[_MODEL_VERSION.value]
+ builder = _BUILDER[_MODEL_VERSION.value]
  logging.info("Loading the original model from: %s", checkpoint)
  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

@@ -49,7 +65,7 @@ def main(_):
  )
  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
- reauthored_model = smollm.build_model(reauthored_checkpoint)
+ reauthored_model = builder(reauthored_checkpoint)

  logging.info("Loading the tokenizer from: %s", checkpoint)
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
  mask = self.mask_cache.index_select(2, input_pos)
  mask = mask[:, :, :, : self.config.max_seq_len]

- updated_kv_entires = []
+ updated_kv_entries = []
  for i, block in enumerate(self.transformer_blocks):
  kv_entry = kv_cache.caches[i] if kv_cache else None
  x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
  if kv_entry:
- updated_kv_entires.append(kv_entry)
+ updated_kv_entries.append(kv_entry)

- updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+ updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

  if export_config is not None:
  if (
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
  os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
  'The path to the model checkpoint, or directory holding the checkpoint.',
  )
- _TFLITE_PATH = flags.DEFINE_string(
- 'tflite_path',
+ _OUTPUT_PATH = flags.DEFINE_string(
+ 'output_path',
  '/tmp/',
- 'The tflite file path to export.',
+ 'The path to export the tflite model.',
+ )
+ _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+ 'output_name_prefix',
+ 'tinyllama',
+ 'The prefix of the output tflite model name.',
  )
  _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
  'prefill_seq_lens',
@@ -49,21 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
  True,
  'Whether the model should be quantized.',
  )
+ _LORA_RANKS = flags.DEFINE_multi_integer(
+ 'lora_ranks',
+ None,
+ 'If set, the model will be converted with the provided list of LoRA ranks.',
+ )


  def main(_):
  pytorch_model = tiny_llama.build_model(
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
- quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
- output_filename = (
- f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
- )
  converter.convert_to_tflite(
  pytorch_model,
- tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+ output_path=_OUTPUT_PATH.value,
+ output_name_prefix=_OUTPUT_NAME_PREFIX.value,
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
  quantize=_QUANTIZE.value,
+ lora_ranks=_LORA_RANKS.value,
  export_config=ExportConfig(),
  )

ai_edge_torch/generative/layers/attention.py
@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union

  from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ from ai_edge_torch.generative.layers import lora as lora_utils
  from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
@@ -26,33 +27,6 @@ import torch
  from torch import nn


- def _embed_rope(
- q: torch.Tensor,
- k: torch.Tensor,
- n_elem: int,
- rope: Tuple[torch.Tensor, torch.Tensor],
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Embed rotary positional embedding for query and key.
-
- Args:
- q (torch.Tensor): query tensor.
- k (torch.Tensor): key tensor.
- n_elem (int): number of elements to embed rotarty positional embedding.
- rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
- """
- if n_elem > 0:
- cos, sin = rope
- q_roped = rotary_pos_emb.apply_rope(
- q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
- )
- k_roped = rotary_pos_emb.apply_rope(
- k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
- )
- q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
- k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
- return q, k
-
-
  class TransformerBlock(nn.Module):

  def __init__(
@@ -93,6 +67,7 @@ class TransformerBlock(nn.Module):
  mask: Optional[torch.Tensor] = None,
  input_pos: Optional[torch.Tensor] = None,
  kv_cache: kv_utils.KVCacheEntry = None,
+ lora: Optional[lora_utils.LoRAEntry] = None,
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
  """Forward function of the TransformerBlock.

@@ -102,6 +77,7 @@ class TransformerBlock(nn.Module):
  mask (torch.Tensor): the optional mask tensor.
  input_pos (torch.Tensor): the optional input position tensor.
  kv_cache (KVCacheEntry): the optional kv cache entry.
+ lora (LoRAEntry): the optional lora entry.

  Returns:
  output activation from this transformer block, and updated kv cache (if
@@ -110,7 +86,9 @@ class TransformerBlock(nn.Module):
  kv = None
  if self.config.parallel_residual:
  x_norm = self.pre_atten_norm(x)
- atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+ atten_func_out = self.atten_func(
+ x_norm, rope, mask, input_pos, kv_cache, lora
+ )
  if kv_cache is None:
  attn_out = atten_func_out
  else:
@@ -119,7 +97,9 @@ class TransformerBlock(nn.Module):
  output = x + attn_out + ff_out
  else:
  x_norm = self.pre_atten_norm(x)
- atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+ atten_func_out = self.atten_func(
+ x_norm, rope, mask, input_pos, kv_cache, lora
+ )
  if kv_cache is None:
  attn_out = atten_func_out
  else:
@@ -179,6 +159,7 @@ class CausalSelfAttention(nn.Module):
  mask: Optional[torch.Tensor] = None,
  input_pos: Optional[torch.Tensor] = None,
  kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+ lora: Optional[lora_utils.LoRAEntry] = None,
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
  """Forward function of the CausalSelfAttention layer, which can support

@@ -189,7 +170,8 @@ class CausalSelfAttention(nn.Module):
  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
  mask (torch.Tensor): the optional mask tensor.
  input_pos (torch.Tensor): the optional input position tensor.
- kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+ kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+ lora (LoRAEntry): the optional lora entry.

  Returns:
  output activation from this self attention layer, and the updated
@@ -228,6 +210,11 @@ class CausalSelfAttention(nn.Module):
  dim=-1,
  )

+ if lora is not None:
+ q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+ k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+ v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
  q = self.query_norm(q)
  k = self.key_norm(k)

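The apply_lora calls added above come from the new ai_edge_torch/generative/layers/lora.py (listed in the file summary but not shown in these hunks). The sketch below is only a generic illustration of the low-rank-adaptation pattern those calls appear to follow, with assumed adapter matrices A' (d_model x r) and B' (r x d_out); it is not the library's actual implementation, and all names and shapes beyond what the hunk shows are hypothetical.

import torch

def lora_delta(x: torch.Tensor, a_prime: torch.Tensor, b_prime: torch.Tensor,
               shape=None) -> torch.Tensor:
  """Generic LoRA contribution: (x @ A') @ B', reshaped to match the base
  projection's output so it can be added to q/k/v as in the hunk above."""
  delta = torch.matmul(torch.matmul(x, a_prime), b_prime)
  return delta.reshape(shape) if shape is not None else delta

# Illustrative shapes only: batch=1, seq=4, d_model=8, rank r=2, d_out=8.
x = torch.randn(1, 4, 8)
a_prime = torch.randn(8, 2)
b_prime = torch.randn(2, 8)
q = torch.randn(1, 4, 8)
q = q + lora_delta(x, a_prime, b_prime, shape=q.shape)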
@@ -238,13 +225,14 @@ class CausalSelfAttention(nn.Module):
  if rope is not None:
  # Compute rotary positional embedding for query and key.
  n_elem = int(self.config.rotary_percentage * self.config.head_dim)
- q, k = _embed_rope(q, k, n_elem, rope)
+ cos, sin = rope
+ q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

  if kv_cache is not None:
  kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
  k, v = kv_cache.k_cache, kv_cache.v_cache

- y = self.sdpa_func(
+ sdpa_out = self.sdpa_func(
  q,
  k,
  v,
@@ -252,10 +240,13 @@ class CausalSelfAttention(nn.Module):
  mask=mask,
  softcap=self.config.logit_softcap,
  )
- y = y.reshape(B, T, -1)
+ sdpa_out = sdpa_out.reshape(B, T, -1)

  # Compute the output projection.
- y = self.output_projection(y)
+ y = self.output_projection(sdpa_out)
+ if lora is not None:
+ y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
+
  return y if kv_cache is None else (y, kv_cache)


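The rope handling above (and the matching change in CrossAttention further down) replaces the removed _embed_rope helper with rotary_pos_emb.apply_rope_inline, passing the cos/sin pair directly instead of slicing the first n_elem channels and repeating cos/sin at the call site. For reference, a generic rotate-half rotary embedding has the form sketched below; this is an illustrative sketch only, not the apply_rope_inline implementation, whose handling of rotary_percentage and cos/sin layout lives in the updated rotary_position_embedding.py.

import torch

def rope_rotate_half(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
  """Standard rotate-half rotary embedding; cos/sin are assumed broadcastable
  to x and to cover half of x's last dimension."""
  x1, x2 = x.chunk(2, dim=-1)
  rotated = torch.cat((-x2, x1), dim=-1)
  return x * torch.cat((cos, cos), dim=-1) + rotated * torch.cat((sin, sin), dim=-1)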
@@ -268,6 +259,7 @@ class SelfAttention(CausalSelfAttention):
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  input_pos: Optional[torch.Tensor] = None,
  kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+ lora: Optional[lora_utils.LoRAEntry] = None,
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
  """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.

@@ -275,18 +267,23 @@ class SelfAttention(CausalSelfAttention):
  x (torch.Tensor): the input tensor.
  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
  input_pos (torch.Tensor): the optional input position tensor.
- kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+ kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+ lora (LoRAEntry): the optional lora entry.

  Returns:
  output activation from this self attention layer, and the updated
  KV Cach Entry (if passed in).
  """
  B, T, _ = x.size()
+ assert (
+ kv_cache is None
+ ), "KV cache is not supported in non-causal SelfAttention."
  return super().forward(
  x,
  rope=rope,
  mask=torch.zeros((B, 1, T, T), dtype=torch.float32),
  input_pos=input_pos,
+ lora=lora,
  )


@@ -343,6 +340,7 @@ class CrossAttention(nn.Module):
  mask: Optional[torch.Tensor] = None,
  input_pos: Optional[torch.Tensor] = None,
  kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+ lora: Optional[lora_utils.LoRAEntry] = None,
  ):
  """Forward function of the CrossAttention layer.

@@ -353,7 +351,8 @@ class CrossAttention(nn.Module):
  mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
  [B, n_heads, target_seq_len, source_seq_len].
  input_pos (torch.Tensor): the optional input position tensor.
- kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+ kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+ lora (LoRAEntry): the optional lora entry.

  Returns:
  output activation from this cross attention layer.
@@ -366,6 +365,11 @@ class CrossAttention(nn.Module):
  k = self.k_projection(y)
  v = self.v_projection(y)

+ if lora is not None:
+ q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+ k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+ v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
  interim_shape = (batch_size, -1, self.n_heads, self.config.head_dim)
  q = q.view(interim_shape)
  k = k.view(interim_shape)
@@ -374,7 +378,8 @@ class CrossAttention(nn.Module):
  if rope is not None:
  # Compute rotary positional embedding for query and key.
  n_elem = int(self.config.rotary_percentage * self.config.head_dim)
- q, k = _embed_rope(q, k, n_elem, rope)
+ cos, sin = rope
+ q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

  if kv_cache is not None:
  kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -388,4 +393,7 @@ class CrossAttention(nn.Module):

  # Compute the output projection.
  y = self.output_projection(y)
+ if lora is not None:
+ y += lora_utils.apply_lora(y, lora.attention.output)
+
  return y if kv_cache is None else (y, kv_cache)