ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250109__py3-none-any.whl
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/llama/llama.py +29 -25
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/phi/phi3.py +26 -23
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
- ai_edge_torch/generative/examples/smollm/verify.py +18 -2
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
- ai_edge_torch/generative/layers/attention.py +45 -37
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
- ai_edge_torch/generative/utilities/converter.py +100 -47
- ai_edge_torch/generative/utilities/model_builder.py +23 -14
- ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
- ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
- ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
- ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
- ai_edge_torch/odml_torch/_torch_future.py +13 -0
- ai_edge_torch/odml_torch/export.py +6 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +38 -35
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py

@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'qwen',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,12 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+

 _BUILDER = {
     '0.5b': qwen.build_0_5b_model,
@@ -67,16 +78,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  model_size = _MODEL_SIZE.value.replace('.', '_')
-  output_filename = (
-      f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

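The per-script filename logic moves into the converter utility; call sites now pass a directory and a name prefix instead of a full file path. A minimal sketch of invoking the updated API directly, mirroring main() above (the checkpoint path and the LoRA rank value are illustrative assumptions, not from this diff):

```python
# Sketch only: mirrors the updated main() above. The checkpoint path and the
# chosen LoRA rank are illustrative assumptions.
from ai_edge_torch.generative.examples.qwen import qwen
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

pytorch_model = qwen.build_0_5b_model(
    '/path/to/qwen/checkpoint', kv_cache_max_len=1280
)
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',          # directory; the converter names the file
    output_name_prefix='qwen',    # replaces the old hand-built filename
    prefill_seq_len=[8, 64, 128, 256, 512, 1024],
    quantize=True,
    lora_ranks=[16],              # optional; omit (None) for no LoRA
    export_config=ExportConfig(),
)
```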
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'smollm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,20 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = smollm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

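The removed lines above show the naming scheme each script used to build by hand. Presumably the converter utility now derives an equivalent name from output_name_prefix; a sketch of that convention, reconstructed from the deleted code (how converter.py actually names files, especially with lora_ranks set, is not shown in this diff):

```python
# Reconstructed from the deleted per-script logic; the converter's real
# naming (notably for LoRA variants) may differ.
def output_filename(prefix: str, quantize: bool, kv_cache_max_len: int) -> str:
  quant_suffix = 'q8' if quantize else 'f32'
  return f'{prefix}_{quant_suffix}_ekv{kv_cache_max_len}.tflite'

assert output_filename('smollm', True, 1024) == 'smollm_q8_ekv1024.tflite'
```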
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py (new file)

@@ -0,0 +1,71 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting SmolLM2 model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = smollm.build_model_v2(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'smollm2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
+      quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
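Note that this new script still uses the old-style _TFLITE_PATH flag and inline filename, so with the defaults above it writes /tmp/smollm2_q8_ekv1280.tflite. A quick way to confirm the multi-signature export (the signature key names below are an assumption, not something this diff shows):

```python
# Hedged sanity check: lists the signatures baked into the exported model.
# Expect one prefill entry per --prefill_seq_lens value plus a decode entry;
# the exact signature key names are assumptions.
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='/tmp/smollm2_q8_ekv1280.tflite')
print(interpreter.get_signature_list())
```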
ai_edge_torch/generative/examples/smollm/smollm.py

@@ -85,3 +85,41 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM,
   )
+
+
+class SmolLM2(model_builder.DecoderOnlyModel):
+  """A SmolLM2 model built from the Edge Generative API layers."""
+  pass
+
+
+def get_model_config_v2(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmolLM2 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM2 model.
+  """
+  config = get_model_config(kv_cache_max_len)
+  config.block_config(0).attn_config.rotary_base = 100000
+  return config
+
+
+def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config_v2(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # SmolLM2 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model_v2(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config_v2(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=SmolLM2,
+  )
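SmolLM2 reuses the SmolLM config wholesale and changes a single knob, raising the RoPE base from the conventional 10000 to 100000. The standard inverse-frequency formula below (an assumption; the actual computation lives in rotary_position_embedding.py) shows what that buys: a larger base stretches the longest RoPE wavelengths, which generally behaves better at longer contexts.

```python
# Standard RoPE inverse-frequency formula (assumed, not part of this diff),
# illustrating the effect of SmolLM2's rotary_base = 100000 override.
import torch

def inv_freq(base: float, head_dim: int) -> torch.Tensor:
  return 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))

print(inv_freq(10000.0, 64)[-1])   # lowest frequency at the default base
print(inv_freq(100000.0, 64)[-1])  # roughly 10x lower with SmolLM2's base
```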
ai_edge_torch/generative/examples/smollm/verify.py

@@ -36,10 +36,26 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     30,
     "The maximum size of the generated tokens.",
 )
+_MODEL_VERSION = flags.DEFINE_enum(
+    "model_version",
+    "v1",
+    ["v1", "v2"],
+    "The version of SmolLm to verify.",
+)
+_CHECKPOINT = {
+    "v1": "HuggingFaceTB/SmolLM-135M",
+    "v2": "HuggingFaceTB/SmolLM2-135M",
+}
+
+_BUILDER = {
+    "v1": smollm.build_model,
+    "v2": smollm.build_model_v2,
+}


 def main(_):
-  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  checkpoint = _CHECKPOINT[_MODEL_VERSION.value]
+  builder = _BUILDER[_MODEL_VERSION.value]
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

@@ -49,7 +65,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = smollm.build_model(reauthored_checkpoint)
+  reauthored_model = builder(reauthored_checkpoint)

   logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

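With the dispatch tables in place, verify.py selects both checkpoint and builder from a single --model_version flag. A minimal sketch of the resolved v2 path (network access and HuggingFace checkpoint availability are assumed):

```python
# Sketch of the v2 path after flag resolution; mirrors main() above with
# "v2" chosen. Downloading the checkpoint requires network access.
import transformers
from ai_edge_torch.generative.examples.smollm import smollm

checkpoint = "HuggingFaceTB/SmolLM2-135M"
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
# The reauthored counterpart comes from the matching builder:
# reauthored_model = smollm.build_model_v2(reauthored_checkpoint)
```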
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]

-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
+        updated_kv_entries.append(kv_entry)

-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

     if export_config is not None:
       if (
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'tinyllama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,21 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = tiny_llama.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/layers/attention.py

@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union

 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import lora as lora_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
@@ -26,33 +27,6 @@ import torch
 from torch import nn


-def _embed_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    n_elem: int,
-    rope: Tuple[torch.Tensor, torch.Tensor],
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Embed rotary positional embedding for query and key.
-
-  Args:
-    q (torch.Tensor): query tensor.
-    k (torch.Tensor): key tensor.
-    n_elem (int): number of elements to embed rotarty positional embedding.
-    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-  """
-  if n_elem > 0:
-    cos, sin = rope
-    q_roped = rotary_pos_emb.apply_rope(
-        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    k_roped = rotary_pos_emb.apply_rope(
-        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
-    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
-  return q, k
-
-
 class TransformerBlock(nn.Module):

   def __init__(
@@ -93,6 +67,7 @@ class TransformerBlock(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: kv_utils.KVCacheEntry = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the TransformerBlock.

@@ -102,6 +77,7 @@ class TransformerBlock(nn.Module):
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
       kv_cache (KVCacheEntry): the optional kv cache entry.
+      lora (LoRAEntry): the optional lora entry.

     Returns:
       output activation from this transformer block, and updated kv cache (if
@@ -110,7 +86,9 @@ class TransformerBlock(nn.Module):
     kv = None
     if self.config.parallel_residual:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -119,7 +97,9 @@ class TransformerBlock(nn.Module):
       output = x + attn_out + ff_out
     else:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -179,6 +159,7 @@ class CausalSelfAttention(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the CausalSelfAttention layer, which can support

@@ -189,7 +170,8 @@ class CausalSelfAttention(nn.Module):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry):
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.

     Returns:
       output activation from this self attention layer, and the updated
@@ -228,6 +210,11 @@ class CausalSelfAttention(nn.Module):
         dim=-1,
     )

+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
     q = self.query_norm(q)
     k = self.key_norm(k)

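The three apply_lora calls fold a low-rank delta into the fused QKV projections. lora.py itself (+557 lines) is not part of this excerpt; the sketch below implements the textbook LoRA update that the call shape suggests, with the names and the scaling factor being assumptions:

```python
# Textbook LoRA update (an assumption about apply_lora's internals; the real
# implementation is in ai_edge_torch/generative/layers/lora.py, not shown).
# For a frozen projection y = W x, LoRA adds scale * (x @ A.T) @ B.T with
# A: [r, in_features], B: [out_features, r], and small rank r.
import torch

def apply_lora_sketch(x, a, b, scale=1.0, shape=None):
  delta = (x @ a.T) @ b.T * scale
  return delta.reshape(shape) if shape is not None else delta

x = torch.randn(1, 4, 128)               # [batch, seq, in_features]
a = torch.randn(16, 128)                 # rank r = 16
b = torch.randn(128, 16)
print(apply_lora_sketch(x, a, b).shape)  # torch.Size([1, 4, 128])
```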
@@ -238,13 +225,14 @@ class CausalSelfAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache

-    y = self.sdpa_func(
+    sdpa_out = self.sdpa_func(
         q,
         k,
         v,
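apply_rope_inline supersedes the removed _embed_rope helper, taking cos and sin directly instead of a packed rope tuple. The conventional rotation it presumably performs (rotary_position_embedding.py's +34/-28 change is not in this excerpt, so treat this as a sketch):

```python
# Conventional RoPE rotation; an assumption about apply_rope_inline's
# behavior, since rotary_position_embedding.py is not shown in this diff.
import torch

def apply_rope_sketch(q, k, cos, sin):
  def rotate_half(t):
    t1, t2 = t.chunk(2, dim=-1)
    return torch.cat((-t2, t1), dim=-1)
  return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```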
@@ -252,10 +240,13 @@ class CausalSelfAttention(nn.Module):
         mask=mask,
         softcap=self.config.logit_softcap,
     )
-    y = y.reshape(B, T, -1)
+    sdpa_out = sdpa_out.reshape(B, T, -1)

     # Compute the output projection.
-    y = self.output_projection(y)
+    y = self.output_projection(sdpa_out)
+    if lora is not None:
+      y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
+
     return y if kv_cache is None else (y, kv_cache)


@@ -268,6 +259,7 @@ class SelfAttention(CausalSelfAttention):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.

@@ -275,18 +267,23 @@ class SelfAttention(CausalSelfAttention):
       x (torch.Tensor): the input tensor.
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry):
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.

     Returns:
       output activation from this self attention layer, and the updated
       KV Cach Entry (if passed in).
     """
     B, T, _ = x.size()
+    assert (
+        kv_cache is None
+    ), "KV cache is not supported in non-causal SelfAttention."
     return super().forward(
         x,
         rope=rope,
         mask=torch.zeros((B, 1, T, T), dtype=torch.float32),
         input_pos=input_pos,
+        lora=lora,
     )


@@ -343,6 +340,7 @@ class CrossAttention(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ):
     """Forward function of the CrossAttention layer.

@@ -353,7 +351,8 @@ class CrossAttention(nn.Module):
       mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
         [B, n_heads, target_seq_len, source_seq_len].
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry):
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.

     Returns:
       output activation from this cross attention layer.
@@ -366,6 +365,11 @@ class CrossAttention(nn.Module):
     k = self.k_projection(y)
     v = self.v_projection(y)

+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
     interim_shape = (batch_size, -1, self.n_heads, self.config.head_dim)
     q = q.view(interim_shape)
     k = k.view(interim_shape)
@@ -374,7 +378,8 @@ class CrossAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -388,4 +393,7 @@ class CrossAttention(nn.Module):

     # Compute the output projection.
     y = self.output_projection(y)
+    if lora is not None:
+      y += lora_utils.apply_lora(y, lora.attention.output)
+
     return y if kv_cache is None else (y, kv_cache)
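A note on the output-projection adapters: the two attention variants feed them different activations, so adapter weights trained against one placement are not interchangeable with the other. Side by side, from the hunks above:

```python
# CausalSelfAttention: the output adapter sees the pre-projection SDPA output.
y = self.output_projection(sdpa_out)
if lora is not None:
  y += lora_utils.apply_lora(sdpa_out, lora.attention.output)

# CrossAttention: the output adapter sees the already-projected output.
y = self.output_projection(y)
if lora is not None:
  y += lora_utils.apply_lora(y, lora.attention.output)
```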