ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/llama/llama.py +29 -25
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/phi/phi3.py +26 -23
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
- ai_edge_torch/generative/examples/smollm/verify.py +18 -2
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
- ai_edge_torch/generative/layers/attention.py +45 -37
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
- ai_edge_torch/generative/utilities/converter.py +100 -47
- ai_edge_torch/generative/utilities/model_builder.py +23 -14
- ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
- ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
- ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
- ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
- ai_edge_torch/odml_torch/_torch_future.py +13 -0
- ai_edge_torch/odml_torch/export.py +6 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +38 -35
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py

@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'qwen',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,12 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+

 _BUILDER = {
     '0.5b': qwen.build_0_5b_model,
@@ -67,16 +78,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  model_size = _MODEL_SIZE.value.replace('.', '_')
-  output_filename = (
-      f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
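Taken together, these hunks drop the hand-assembled output filename: `convert_to_tflite` now receives `output_path` and `output_name_prefix` and derives the file name itself, and an optional list of LoRA ranks is threaded through via the new `lora_ranks` flag. Below is a minimal sketch of driving the updated converter API directly, assuming the modules shown above are importable; the checkpoint path matches the flag default in the diff, while the cache size, prefill lengths, and ranks are illustrative values rather than defaults taken from this diff.

```python
# Sketch only: mirrors the rewritten main() above with hard-coded values.
import os
import pathlib

from ai_edge_torch.generative.examples.qwen import qwen
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

checkpoint = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen')
pytorch_model = qwen.build_0_5b_model(checkpoint, kv_cache_max_len=1280)

converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',         # replaces the old single tflite_path argument
    output_name_prefix='qwen',   # the converter derives the file name from this
    prefill_seq_len=[8, 64, 128, 256, 512, 1024],  # illustrative prefill lengths
    quantize=True,
    lora_ranks=[4, 16],          # new: convert with the provided list of LoRA ranks
    export_config=ExportConfig(),
)
```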
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'smollm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,20 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = smollm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py (new file)

@@ -0,0 +1,71 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting SmolLM2 model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
+    'prefill_seq_lens',
+    (8, 64, 128, 256, 512, 1024),
+    'List of the maximum sizes of prefill input tensors.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = smollm.build_model_v2(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'smollm2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
+      quantize=_QUANTIZE.value,
+      export_config=ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
ai_edge_torch/generative/examples/smollm/smollm.py

@@ -85,3 +85,41 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
       tensor_names=TENSOR_NAMES,
       model_class=SmolLM,
   )
+
+
+class SmolLM2(model_builder.DecoderOnlyModel):
+  """A SmolLM2 model built from the Edge Generative API layers."""
+  pass
+
+
+def get_model_config_v2(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmolLM2 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM2 model.
+  """
+  config = get_model_config(kv_cache_max_len)
+  config.block_config(0).attn_config.rotary_base = 100000
+  return config
+
+
+def get_fake_model_config_v2(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config_v2(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  # SmolLM2 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_model_v2(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_model_config_v2(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=SmolLM2,
+  )
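Since `get_model_config_v2` only overrides the rotary base on top of the existing SmolLM config, building the new SmolLM2 variant is a one-liner. A sketch of using the additions above, assuming a local checkpoint at the same default location the new convert_v2 script uses:

```python
# Sketch: build the new SmolLM2 variant and check the config tweak shown above.
import os
import pathlib

from ai_edge_torch.generative.examples.smollm import smollm

# The v2 config is the v1 config with a different rotary embedding base.
config = smollm.get_model_config_v2(kv_cache_max_len=1280)
assert config.block_config(0).attn_config.rotary_base == 100000

# Assumed checkpoint location, matching convert_v2_to_tflite.py's default flag.
checkpoint = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2')
model = smollm.build_model_v2(checkpoint, kv_cache_max_len=1280)
```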
ai_edge_torch/generative/examples/smollm/verify.py

@@ -36,10 +36,26 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     30,
     "The maximum size of the generated tokens.",
 )
+_MODEL_VERSION = flags.DEFINE_enum(
+    "model_version",
+    "v1",
+    ["v1", "v2"],
+    "The version of SmolLm to verify.",
+)
+_CHECKPOINT = {
+    "v1": "HuggingFaceTB/SmolLM-135M",
+    "v2": "HuggingFaceTB/SmolLM2-135M",
+}
+
+_BUILDER = {
+    "v1": smollm.build_model,
+    "v2": smollm.build_model_v2,
+}


 def main(_):
-  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  checkpoint = _CHECKPOINT[_MODEL_VERSION.value]
+  builder = _BUILDER[_MODEL_VERSION.value]
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

@@ -49,7 +65,7 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = smollm.build_model(reauthored_checkpoint)
+  reauthored_model = builder(reauthored_checkpoint)

   logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]

-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
+        updated_kv_entries.append(kv_entry)

-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

     if export_config is not None:
       if (
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'tinyllama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,21 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = tiny_llama.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/layers/attention.py

@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union

 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import lora as lora_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
@@ -26,33 +27,6 @@ import torch
 from torch import nn


-def _embed_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    n_elem: int,
-    rope: Tuple[torch.Tensor, torch.Tensor],
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Embed rotary positional embedding for query and key.
-
-  Args:
-    q (torch.Tensor): query tensor.
-    k (torch.Tensor): key tensor.
-    n_elem (int): number of elements to embed rotarty positional embedding.
-    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-  """
-  if n_elem > 0:
-    cos, sin = rope
-    q_roped = rotary_pos_emb.apply_rope(
-        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    k_roped = rotary_pos_emb.apply_rope(
-        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
-    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
-  return q, k
-
-
 class TransformerBlock(nn.Module):

   def __init__(
@@ -93,6 +67,7 @@ class TransformerBlock(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: kv_utils.KVCacheEntry = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the TransformerBlock.

@@ -102,6 +77,7 @@ class TransformerBlock(nn.Module):
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
       kv_cache (KVCacheEntry): the optional kv cache entry.
+      lora (LoRAEntry): the optional lora entry.

     Returns:
       output activation from this transformer block, and updated kv cache (if
@@ -110,7 +86,9 @@ class TransformerBlock(nn.Module):
     kv = None
     if self.config.parallel_residual:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -119,7 +97,9 @@ class TransformerBlock(nn.Module):
       output = x + attn_out + ff_out
     else:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -179,6 +159,7 @@ class CausalSelfAttention(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the CausalSelfAttention layer, which can support

@@ -189,7 +170,8 @@ class CausalSelfAttention(nn.Module):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry):
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.

     Returns:
       output activation from this self attention layer, and the updated
@@ -228,6 +210,11 @@ class CausalSelfAttention(nn.Module):
         dim=-1,
     )

+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
     q = self.query_norm(q)
     k = self.key_norm(k)
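These hunks only show the new call sites; `lora_utils.apply_lora` itself is defined in the new `ai_edge_torch/generative/layers/lora.py` (557 lines, not reproduced here). For orientation, the hooks correspond to the standard low-rank adaptation pattern, where a rank-r correction computed from the layer input is added on top of the frozen projection output. The snippet below is a generic illustration of that formulation under assumed shapes, not the library's implementation:

```python
# Generic LoRA update, for illustration only: delta = (x @ A) @ B, reshaped to
# the projected tensor's layout. The real logic lives in layers/lora.py.
import torch


def lora_delta(x, lora_a, lora_b, shape):
  # x: (B, T, embed_dim); lora_a: (embed_dim, rank); lora_b: (rank, out_dim)
  return ((x @ lora_a) @ lora_b).reshape(shape)


batch, seq, embed_dim, n_heads, head_dim, rank = 1, 8, 64, 4, 16, 4
x = torch.randn(batch, seq, embed_dim)
lora_a = torch.randn(embed_dim, rank)
lora_b = torch.randn(rank, n_heads * head_dim)

q = torch.randn(batch, seq, n_heads, head_dim)  # stand-in for the frozen query projection output
q = q + lora_delta(x, lora_a, lora_b, q.shape)  # analogous to q += apply_lora(x, ..., shape=q.shape)
```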
@@ -238,13 +225,14 @@ class CausalSelfAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache

-    y = self.sdpa_func(
+    sdpa_out = self.sdpa_func(
         q,
         k,
         v,
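With `_embed_rope` removed, the attention layers no longer repeat `cos`/`sin` or slice out a partial head dimension themselves; the caller is expected to hand in precomputed `(cos, sin)` tensors, which `apply_rope_inline` applies directly to `q` and `k`. As a reference point, the usual rotate-half formulation such an inline application is based on looks like the sketch below (generic; the library's `rotary_position_embedding.apply_rope_inline` may differ in layout details):

```python
# Generic rotate-half RoPE, shown for orientation only.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
  x1, x2 = x.chunk(2, dim=-1)
  return torch.cat((-x2, x1), dim=-1)


def apply_rope_sketch(q, k, cos, sin):
  # cos/sin are assumed to broadcast over the batch and head dimensions of q and k.
  return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```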
@@ -252,10 +240,13 @@ class CausalSelfAttention(nn.Module):
         mask=mask,
         softcap=self.config.logit_softcap,
     )
-    y = y.reshape(B, T, -1)
+    sdpa_out = sdpa_out.reshape(B, T, -1)

     # Compute the output projection.
-    y = self.output_projection(y)
+    y = self.output_projection(sdpa_out)
+    if lora is not None:
+      y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
+
     return y if kv_cache is None else (y, kv_cache)


@@ -268,6 +259,7 @@ class SelfAttention(CausalSelfAttention):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.

@@ -275,18 +267,23 @@ class SelfAttention(CausalSelfAttention):
       x (torch.Tensor): the input tensor.
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry):
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.

     Returns:
       output activation from this self attention layer, and the updated
       KV Cach Entry (if passed in).
     """
     B, T, _ = x.size()
+    assert (
+        kv_cache is None
+    ), "KV cache is not supported in non-causal SelfAttention."
     return super().forward(
         x,
         rope=rope,
         mask=torch.zeros((B, 1, T, T), dtype=torch.float32),
         input_pos=input_pos,
+        lora=lora,
     )


@@ -343,6 +340,7 @@ class CrossAttention(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ):
     """Forward function of the CrossAttention layer.

@@ -353,7 +351,8 @@ class CrossAttention(nn.Module):
       mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
         [B, n_heads, target_seq_len, source_seq_len].
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry):
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.

     Returns:
       output activation from this cross attention layer.
@@ -366,6 +365,11 @@ class CrossAttention(nn.Module):
     k = self.k_projection(y)
     v = self.v_projection(y)

+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
     interim_shape = (batch_size, -1, self.n_heads, self.config.head_dim)
     q = q.view(interim_shape)
     k = k.view(interim_shape)
@@ -374,7 +378,8 @@ class CrossAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -388,4 +393,7 @@ class CrossAttention(nn.Module):

     # Compute the output projection.
     y = self.output_projection(y)
+    if lora is not None:
+      y += lora_utils.apply_lora(y, lora.attention.output)
+
     return y if kv_cache is None else (y, kv_cache)