ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/test/test_convert.py +35 -13
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
- ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
- ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
- ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
- ai_edge_torch/generative/examples/t5/t5.py +35 -22
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
- ai_edge_torch/generative/layers/attention.py +77 -73
- ai_edge_torch/generative/layers/builder.py +5 -3
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +38 -19
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +72 -34
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +15 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +41 -47
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
@@ -12,16 +12,15 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
16
|
-
|
15
|
+
|
16
|
+
"""Testing model conversion for a few gen-ai models."""
|
17
17
|
|
18
18
|
import ai_edge_torch
|
19
19
|
from ai_edge_torch import config as ai_edge_config
|
20
|
-
from ai_edge_torch.generative.examples.
|
21
|
-
from ai_edge_torch.generative.examples.phi2 import phi2
|
22
|
-
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
|
20
|
+
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
|
23
21
|
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
24
|
-
from ai_edge_torch.
|
22
|
+
from ai_edge_torch.generative.layers import kv_cache
|
23
|
+
from ai_edge_torch.generative.test import utils as test_utils
|
25
24
|
import numpy as np
|
26
25
|
import torch
|
27
26
|
|
@@ -49,22 +48,32 @@ class TestModelConversion(googletest.TestCase):
|
|
49
48
|
)
|
50
49
|
def test_toy_model_with_kv_cache(self):
|
51
50
|
config = toy_model_with_kv_cache.get_model_config()
|
52
|
-
pytorch_model = toy_model_with_kv_cache.
|
53
|
-
|
51
|
+
pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
|
52
|
+
tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
|
54
53
|
[10], dtype=torch.int64
|
55
54
|
)
|
56
|
-
|
57
|
-
|
55
|
+
kv = kv_cache.KVCache.from_model_config(config)
|
56
|
+
|
57
|
+
edge_model = ai_edge_torch.convert(
|
58
|
+
pytorch_model,
|
59
|
+
sample_kwargs={
|
60
|
+
"tokens": tokens,
|
61
|
+
"input_pos": input_pos,
|
62
|
+
"kv_cache": kv,
|
63
|
+
},
|
64
|
+
)
|
58
65
|
edge_model.set_interpreter_builder(
|
59
66
|
self._interpreter_builder(edge_model.tflite_model())
|
60
67
|
)
|
61
68
|
|
62
69
|
self.assertTrue(
|
63
|
-
|
70
|
+
test_utils.compare_tflite_torch(
|
64
71
|
edge_model,
|
65
72
|
pytorch_model,
|
66
|
-
|
67
|
-
|
73
|
+
tokens,
|
74
|
+
input_pos,
|
75
|
+
kv,
|
76
|
+
signature_name="serving_default",
|
68
77
|
atol=1e-5,
|
69
78
|
rtol=1e-5,
|
70
79
|
)
|
@@ -77,22 +86,32 @@ class TestModelConversion(googletest.TestCase):
|
|
77
86
|
def test_toy_model_with_kv_cache_with_hlfb(self):
|
78
87
|
config = toy_model_with_kv_cache.get_model_config()
|
79
88
|
config.enable_hlfb = True
|
80
|
-
pytorch_model = toy_model_with_kv_cache.
|
81
|
-
|
89
|
+
pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
|
90
|
+
tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
|
82
91
|
[10], dtype=torch.int64
|
83
92
|
)
|
84
|
-
|
85
|
-
|
93
|
+
kv = kv_cache.KVCache.from_model_config(config)
|
94
|
+
|
95
|
+
edge_model = ai_edge_torch.convert(
|
96
|
+
pytorch_model,
|
97
|
+
sample_kwargs={
|
98
|
+
"tokens": tokens,
|
99
|
+
"input_pos": input_pos,
|
100
|
+
"kv_cache": kv,
|
101
|
+
},
|
102
|
+
)
|
86
103
|
edge_model.set_interpreter_builder(
|
87
104
|
self._interpreter_builder(edge_model.tflite_model())
|
88
105
|
)
|
89
106
|
|
90
107
|
self.assertTrue(
|
91
|
-
|
108
|
+
test_utils.compare_tflite_torch(
|
92
109
|
edge_model,
|
93
110
|
pytorch_model,
|
94
|
-
|
95
|
-
|
111
|
+
tokens,
|
112
|
+
input_pos,
|
113
|
+
kv,
|
114
|
+
signature_name="serving_default",
|
96
115
|
atol=1e-5,
|
97
116
|
rtol=1e-5,
|
98
117
|
)
|
@@ -104,7 +123,7 @@ class TestModelConversion(googletest.TestCase):
|
|
104
123
|
)
|
105
124
|
def test_tiny_llama_multisig(self):
|
106
125
|
config = tiny_llama.get_fake_model_config()
|
107
|
-
pytorch_model = tiny_llama.
|
126
|
+
pytorch_model = tiny_llama.TinyLlama(config).eval()
|
108
127
|
|
109
128
|
# prefill
|
110
129
|
seq_len = 10
|
@@ -117,37 +136,56 @@ class TestModelConversion(googletest.TestCase):
|
|
117
136
|
decode_token = torch.tensor([[1]], dtype=torch.long)
|
118
137
|
decode_input_pos = torch.tensor([5], dtype=torch.int64)
|
119
138
|
|
139
|
+
kv = kv_cache.KVCache.from_model_config(config)
|
140
|
+
|
120
141
|
edge_model = (
|
121
142
|
ai_edge_torch.signature(
|
122
|
-
"prefill",
|
143
|
+
"prefill",
|
144
|
+
pytorch_model,
|
145
|
+
sample_kwargs={
|
146
|
+
"tokens": prefill_tokens,
|
147
|
+
"input_pos": prefill_input_pos,
|
148
|
+
"kv_cache": kv,
|
149
|
+
},
|
150
|
+
)
|
151
|
+
.signature(
|
152
|
+
"decode",
|
153
|
+
pytorch_model,
|
154
|
+
sample_kwargs={
|
155
|
+
"tokens": decode_token,
|
156
|
+
"input_pos": decode_input_pos,
|
157
|
+
"kv_cache": kv,
|
158
|
+
},
|
123
159
|
)
|
124
|
-
.signature("decode", pytorch_model, (decode_token, decode_input_pos))
|
125
160
|
.convert()
|
126
161
|
)
|
127
162
|
edge_model.set_interpreter_builder(
|
128
163
|
self._interpreter_builder(edge_model.tflite_model())
|
129
164
|
)
|
130
165
|
|
131
|
-
copied_model = copy.deepcopy(pytorch_model)
|
132
|
-
copied_edge = copy.deepcopy(edge_model)
|
133
|
-
|
134
166
|
self.assertTrue(
|
135
|
-
|
167
|
+
test_utils.compare_tflite_torch(
|
136
168
|
edge_model,
|
137
169
|
pytorch_model,
|
138
|
-
|
170
|
+
prefill_tokens,
|
171
|
+
prefill_input_pos,
|
172
|
+
kv,
|
139
173
|
signature_name="prefill",
|
140
|
-
|
174
|
+
atol=1e-5,
|
175
|
+
rtol=1e-5,
|
141
176
|
)
|
142
177
|
)
|
143
178
|
|
144
179
|
self.assertTrue(
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
180
|
+
test_utils.compare_tflite_torch(
|
181
|
+
edge_model,
|
182
|
+
pytorch_model,
|
183
|
+
decode_token,
|
184
|
+
decode_input_pos,
|
185
|
+
kv,
|
149
186
|
signature_name="decode",
|
150
|
-
|
187
|
+
atol=1e-5,
|
188
|
+
rtol=1e-5,
|
151
189
|
)
|
152
190
|
)
|
153
191
|
|
@@ -12,16 +12,16 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
16
|
-
|
15
|
+
|
16
|
+
"""Testing model conversion for a few gen-ai models."""
|
17
17
|
|
18
18
|
import ai_edge_torch
|
19
19
|
from ai_edge_torch import config as ai_edge_config
|
20
|
-
from ai_edge_torch.generative.examples.gemma import gemma
|
21
|
-
from ai_edge_torch.generative.examples.
|
22
|
-
from ai_edge_torch.generative.examples.
|
23
|
-
from ai_edge_torch.generative.
|
24
|
-
from ai_edge_torch.
|
20
|
+
from ai_edge_torch.generative.examples.gemma import gemma
|
21
|
+
from ai_edge_torch.generative.examples.gemma import gemma2
|
22
|
+
from ai_edge_torch.generative.examples.phi import phi2
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache
|
24
|
+
from ai_edge_torch.generative.test import utils as test_utils
|
25
25
|
import numpy as np
|
26
26
|
import torch
|
27
27
|
|
@@ -55,18 +55,28 @@ class TestModelConversion(googletest.TestCase):
|
|
55
55
|
tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
|
56
56
|
tokens[0, :4] = idx
|
57
57
|
input_pos = torch.arange(0, 10)
|
58
|
-
|
59
|
-
|
58
|
+
kv = kv_cache.KVCache.from_model_config(config)
|
59
|
+
|
60
|
+
edge_model = ai_edge_torch.convert(
|
61
|
+
model,
|
62
|
+
sample_kwargs={
|
63
|
+
"tokens": tokens,
|
64
|
+
"input_pos": input_pos,
|
65
|
+
"kv_cache": kv,
|
66
|
+
},
|
67
|
+
)
|
60
68
|
edge_model.set_interpreter_builder(
|
61
69
|
self._interpreter_builder(edge_model.tflite_model())
|
62
70
|
)
|
63
71
|
|
64
72
|
self.assertTrue(
|
65
|
-
|
73
|
+
test_utils.compare_tflite_torch(
|
66
74
|
edge_model,
|
67
75
|
model,
|
68
|
-
|
69
|
-
|
76
|
+
tokens,
|
77
|
+
input_pos,
|
78
|
+
kv,
|
79
|
+
signature_name="serving_default",
|
70
80
|
atol=1e-2,
|
71
81
|
rtol=1e-5,
|
72
82
|
)
|
@@ -85,23 +95,31 @@ class TestModelConversion(googletest.TestCase):
|
|
85
95
|
prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
|
86
96
|
prefill_tokens[0, :4] = idx
|
87
97
|
prefill_input_pos = torch.arange(0, 10)
|
98
|
+
kv = kv_cache.KVCache.from_model_config(config)
|
88
99
|
|
89
100
|
edge_model = ai_edge_torch.signature(
|
90
|
-
"prefill",
|
101
|
+
"prefill",
|
102
|
+
model,
|
103
|
+
sample_kwargs={
|
104
|
+
"tokens": prefill_tokens,
|
105
|
+
"input_pos": prefill_input_pos,
|
106
|
+
"kv_cache": kv,
|
107
|
+
},
|
91
108
|
).convert()
|
92
109
|
edge_model.set_interpreter_builder(
|
93
110
|
self._interpreter_builder(edge_model.tflite_model())
|
94
111
|
)
|
95
112
|
|
96
113
|
self.assertTrue(
|
97
|
-
|
114
|
+
test_utils.compare_tflite_torch(
|
98
115
|
edge_model,
|
99
116
|
model,
|
100
|
-
|
117
|
+
prefill_tokens,
|
118
|
+
prefill_input_pos,
|
119
|
+
kv,
|
101
120
|
signature_name="prefill",
|
102
|
-
|
103
|
-
|
104
|
-
rtol=1e-5,
|
121
|
+
atol=1e-1,
|
122
|
+
rtol=1e-3,
|
105
123
|
)
|
106
124
|
)
|
107
125
|
|
@@ -117,18 +135,28 @@ class TestModelConversion(googletest.TestCase):
|
|
117
135
|
tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
|
118
136
|
tokens[0, :4] = idx
|
119
137
|
input_pos = torch.arange(0, 10)
|
120
|
-
|
121
|
-
|
138
|
+
kv = kv_cache.KVCache.from_model_config(config)
|
139
|
+
|
140
|
+
edge_model = ai_edge_torch.convert(
|
141
|
+
pytorch_model,
|
142
|
+
sample_kwargs={
|
143
|
+
"tokens": tokens,
|
144
|
+
"input_pos": input_pos,
|
145
|
+
"kv_cache": kv,
|
146
|
+
},
|
147
|
+
)
|
122
148
|
edge_model.set_interpreter_builder(
|
123
149
|
self._interpreter_builder(edge_model.tflite_model())
|
124
150
|
)
|
125
151
|
|
126
152
|
self.assertTrue(
|
127
|
-
|
153
|
+
test_utils.compare_tflite_torch(
|
128
154
|
edge_model,
|
129
155
|
pytorch_model,
|
130
|
-
|
131
|
-
|
156
|
+
tokens,
|
157
|
+
input_pos,
|
158
|
+
kv,
|
159
|
+
signature_name="serving_default",
|
132
160
|
atol=1e-3,
|
133
161
|
rtol=1e-3,
|
134
162
|
)
|
@@ -0,0 +1,54 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Common utils for testing."""
|
17
|
+
|
18
|
+
from ai_edge_torch import model
|
19
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
20
|
+
from ai_edge_torch.lowertools import common_utils
|
21
|
+
import numpy as np
|
22
|
+
import torch
|
23
|
+
from torch.utils import _pytree as pytree
|
24
|
+
|
25
|
+
|
26
|
+
def compare_tflite_torch(
|
27
|
+
edge_model: model.Model,
|
28
|
+
torch_model: torch.nn.Module,
|
29
|
+
tokens: torch.Tensor,
|
30
|
+
input_pos: torch.Tensor,
|
31
|
+
kv_cache: kv_utils.KVCache,
|
32
|
+
signature_name: str,
|
33
|
+
atol: float = 1e-5,
|
34
|
+
rtol: float = 1e-5,
|
35
|
+
):
|
36
|
+
"""Compares torch models and TFLite models."""
|
37
|
+
values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
|
38
|
+
flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
|
39
|
+
torch_output = torch_model(tokens, input_pos, kv_cache)
|
40
|
+
|
41
|
+
input_kv_flatten = {k: v.numpy() for k, v in zip(flat_names, values)}
|
42
|
+
edge_output = edge_model(
|
43
|
+
signature_name=signature_name,
|
44
|
+
tokens=tokens.numpy(),
|
45
|
+
input_pos=input_pos.numpy(),
|
46
|
+
**input_kv_flatten,
|
47
|
+
)
|
48
|
+
|
49
|
+
return np.allclose(
|
50
|
+
edge_output["logits"],
|
51
|
+
torch_output["logits"].detach().numpy(),
|
52
|
+
atol=atol,
|
53
|
+
rtol=rtol,
|
54
|
+
)
|
@@ -221,7 +221,8 @@ class ModelLoader:
|
|
221
221
|
converted_state: Dict[str, torch.Tensor],
|
222
222
|
):
|
223
223
|
prefix = f"transformer_blocks.{idx}"
|
224
|
-
|
224
|
+
ff_config = config.block_config(idx).ff_config
|
225
|
+
if ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
|
225
226
|
ff_up_proj_name = self._names.ff_up_proj.format(idx)
|
226
227
|
ff_down_proj_name = self._names.ff_down_proj.format(idx)
|
227
228
|
converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
|
@@ -230,7 +231,7 @@ class ModelLoader:
|
|
230
231
|
converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
|
231
232
|
f"{ff_down_proj_name}.weight"
|
232
233
|
)
|
233
|
-
if
|
234
|
+
if ff_config.use_bias:
|
234
235
|
converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
|
235
236
|
f"{ff_up_proj_name}.bias"
|
236
237
|
)
|
@@ -250,7 +251,7 @@ class ModelLoader:
|
|
250
251
|
converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
|
251
252
|
f"{ff_gate_proj_name}.weight"
|
252
253
|
)
|
253
|
-
if
|
254
|
+
if ff_config.use_bias:
|
254
255
|
converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
|
255
256
|
f"{ff_up_proj_name}.bias"
|
256
257
|
)
|
@@ -289,6 +290,7 @@ class ModelLoader:
|
|
289
290
|
converted_state: Dict[str, torch.Tensor],
|
290
291
|
):
|
291
292
|
prefix = f"transformer_blocks.{idx}"
|
293
|
+
attn_config = config.block_config(idx).attn_config
|
292
294
|
if self._names.attn_fused_qkv_proj:
|
293
295
|
fused_qkv_name = self._names.attn_fused_qkv_proj.format(idx)
|
294
296
|
converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = state.pop(
|
@@ -300,13 +302,13 @@ class ModelLoader:
|
|
300
302
|
v_name = self._names.attn_value_proj.format(idx)
|
301
303
|
converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = (
|
302
304
|
self._fuse_qkv(
|
303
|
-
|
305
|
+
attn_config,
|
304
306
|
state.pop(f"{q_name}.weight"),
|
305
307
|
state.pop(f"{k_name}.weight"),
|
306
308
|
state.pop(f"{v_name}.weight"),
|
307
309
|
)
|
308
310
|
)
|
309
|
-
if
|
311
|
+
if attn_config.qkv_use_bias:
|
310
312
|
if self._names.attn_fused_qkv_proj:
|
311
313
|
converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = state.pop(
|
312
314
|
f"{fused_qkv_name}.bias"
|
@@ -314,7 +316,7 @@ class ModelLoader:
|
|
314
316
|
else:
|
315
317
|
converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = (
|
316
318
|
self._fuse_qkv(
|
317
|
-
|
319
|
+
attn_config,
|
318
320
|
state.pop(f"{q_name}.bias"),
|
319
321
|
state.pop(f"{k_name}.bias"),
|
320
322
|
state.pop(f"{v_name}.bias"),
|
@@ -325,7 +327,7 @@ class ModelLoader:
|
|
325
327
|
converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
|
326
328
|
state.pop(f"{o_name}.weight")
|
327
329
|
)
|
328
|
-
if
|
330
|
+
if attn_config.output_proj_use_bias:
|
329
331
|
converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
|
330
332
|
state.pop(f"{o_name}.bias")
|
331
333
|
)
|
@@ -360,18 +362,16 @@ class ModelLoader:
|
|
360
362
|
|
361
363
|
def _fuse_qkv(
|
362
364
|
self,
|
363
|
-
|
365
|
+
attn_config: model_config.AttentionConfig,
|
364
366
|
q: torch.Tensor,
|
365
367
|
k: torch.Tensor,
|
366
368
|
v: torch.Tensor,
|
367
369
|
) -> torch.Tensor:
|
368
|
-
if
|
369
|
-
q_per_kv =
|
370
|
-
|
371
|
-
)
|
372
|
-
|
373
|
-
ks = torch.split(k, config.attn_config.head_dim)
|
374
|
-
vs = torch.split(v, config.attn_config.head_dim)
|
370
|
+
if attn_config.qkv_fused_interleaved:
|
371
|
+
q_per_kv = attn_config.num_heads // attn_config.num_query_groups
|
372
|
+
qs = torch.split(q, attn_config.head_dim * q_per_kv)
|
373
|
+
ks = torch.split(k, attn_config.head_dim)
|
374
|
+
vs = torch.split(v, attn_config.head_dim)
|
375
375
|
cycled = [t for group in zip(qs, ks, vs) for t in group]
|
376
376
|
return torch.cat(cycled)
|
377
377
|
else:
|
@@ -279,7 +279,8 @@ class ModelLoader:
|
|
279
279
|
prefix = additional_prefix + f"transformer_blocks.{idx}"
|
280
280
|
if names.ff_up_proj is None or names.ff_down_proj is None:
|
281
281
|
return
|
282
|
-
|
282
|
+
ff_config = config.block_config(idx).ff_config
|
283
|
+
if ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
|
283
284
|
ff_up_proj_name = names.ff_up_proj.format(idx)
|
284
285
|
ff_down_proj_name = names.ff_down_proj.format(idx)
|
285
286
|
converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
|
@@ -288,7 +289,7 @@ class ModelLoader:
|
|
288
289
|
converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
|
289
290
|
f"{ff_down_proj_name}.weight"
|
290
291
|
)
|
291
|
-
if
|
292
|
+
if ff_config.use_bias:
|
292
293
|
converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
|
293
294
|
f"{ff_up_proj_name}.bias"
|
294
295
|
)
|
@@ -309,7 +310,7 @@ class ModelLoader:
|
|
309
310
|
converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
|
310
311
|
f"{ff_gate_proj_name}.weight"
|
311
312
|
)
|
312
|
-
if
|
313
|
+
if ff_config.use_bias:
|
313
314
|
converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
|
314
315
|
f"{ff_up_proj_name}.bias"
|
315
316
|
)
|
@@ -337,20 +338,21 @@ class ModelLoader:
|
|
337
338
|
):
|
338
339
|
return
|
339
340
|
prefix = additional_prefix + f"transformer_blocks.{idx}"
|
341
|
+
attn_config = config.block_config(idx).attn_config
|
340
342
|
q_name = names.attn_query_proj.format(idx)
|
341
343
|
k_name = names.attn_key_proj.format(idx)
|
342
344
|
v_name = names.attn_value_proj.format(idx)
|
343
345
|
# model.encoder.transformer_blocks[0].atten_func.q_projection.weight
|
344
346
|
if fuse_attention:
|
345
347
|
converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
|
346
|
-
|
348
|
+
attn_config,
|
347
349
|
state.pop(f"{q_name}.weight"),
|
348
350
|
state.pop(f"{k_name}.weight"),
|
349
351
|
state.pop(f"{v_name}.weight"),
|
350
352
|
)
|
351
|
-
if
|
353
|
+
if attn_config.qkv_use_bias:
|
352
354
|
converted_state[f"{prefix}.atten_func.attn.bias"] = self._fuse_qkv(
|
353
|
-
|
355
|
+
attn_config,
|
354
356
|
state.pop(f"{q_name}.bias"),
|
355
357
|
state.pop(f"{k_name}.bias"),
|
356
358
|
state.pop(f"{v_name}.bias"),
|
@@ -365,7 +367,7 @@ class ModelLoader:
|
|
365
367
|
converted_state[f"{prefix}.atten_func.v_projection.weight"] = state.pop(
|
366
368
|
f"{v_name}.weight"
|
367
369
|
)
|
368
|
-
if
|
370
|
+
if attn_config.qkv_use_bias:
|
369
371
|
converted_state[f"{prefix}.atten_func.q_projection.bias"] = state.pop(
|
370
372
|
f"{q_name}.bias"
|
371
373
|
)
|
@@ -380,7 +382,7 @@ class ModelLoader:
|
|
380
382
|
converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
|
381
383
|
state.pop(f"{o_name}.weight")
|
382
384
|
)
|
383
|
-
if
|
385
|
+
if attn_config.output_proj_use_bias:
|
384
386
|
converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
|
385
387
|
state.pop(f"{o_name}.bias")
|
386
388
|
)
|
@@ -402,6 +404,7 @@ class ModelLoader:
|
|
402
404
|
):
|
403
405
|
return
|
404
406
|
prefix = additional_prefix + f"transformer_blocks.{idx}"
|
407
|
+
attn_config = config.block_config(idx).attn_config
|
405
408
|
q_name = names.cross_attn_query_proj.format(idx)
|
406
409
|
k_name = names.cross_attn_key_proj.format(idx)
|
407
410
|
v_name = names.cross_attn_value_proj.format(idx)
|
@@ -409,16 +412,16 @@ class ModelLoader:
|
|
409
412
|
if fuse_attention:
|
410
413
|
converted_state[f"{prefix}.cross_atten_func.attn.weight"] = (
|
411
414
|
self._fuse_qkv(
|
412
|
-
|
415
|
+
attn_config,
|
413
416
|
state.pop(f"{q_name}.weight"),
|
414
417
|
state.pop(f"{k_name}.weight"),
|
415
418
|
state.pop(f"{v_name}.weight"),
|
416
419
|
)
|
417
420
|
)
|
418
|
-
if
|
421
|
+
if attn_config.qkv_use_bias:
|
419
422
|
converted_state[f"{prefix}.cross_atten_func.attn.bias"] = (
|
420
423
|
self._fuse_qkv(
|
421
|
-
|
424
|
+
attn_config,
|
422
425
|
state.pop(f"{q_name}.bias"),
|
423
426
|
state.pop(f"{k_name}.bias"),
|
424
427
|
state.pop(f"{v_name}.bias"),
|
@@ -434,7 +437,7 @@ class ModelLoader:
|
|
434
437
|
converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = (
|
435
438
|
state.pop(f"{v_name}.weight")
|
436
439
|
)
|
437
|
-
if
|
440
|
+
if attn_config.qkv_use_bias:
|
438
441
|
converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = (
|
439
442
|
state.pop(f"{q_name}.bias")
|
440
443
|
)
|
@@ -449,7 +452,7 @@ class ModelLoader:
|
|
449
452
|
converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = (
|
450
453
|
state.pop(f"{o_name}.weight")
|
451
454
|
)
|
452
|
-
if
|
455
|
+
if attn_config.output_proj_use_bias:
|
453
456
|
converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = (
|
454
457
|
state.pop(f"{o_name}.bias")
|
455
458
|
)
|
@@ -496,16 +499,14 @@ class ModelLoader:
|
|
496
499
|
|
497
500
|
def _fuse_qkv(
|
498
501
|
self,
|
499
|
-
|
502
|
+
attn_config: model_config.AttentionConfig,
|
500
503
|
q: torch.Tensor,
|
501
504
|
k: torch.Tensor,
|
502
505
|
v: torch.Tensor,
|
503
506
|
) -> torch.Tensor:
|
504
|
-
q_per_kv =
|
505
|
-
|
506
|
-
)
|
507
|
-
|
508
|
-
ks = torch.split(k, config.attn_config.head_dim)
|
509
|
-
vs = torch.split(v, config.attn_config.head_dim)
|
507
|
+
q_per_kv = attn_config.num_heads // attn_config.num_query_groups
|
508
|
+
qs = torch.split(q, attn_config.head_dim * q_per_kv)
|
509
|
+
ks = torch.split(k, attn_config.head_dim)
|
510
|
+
vs = torch.split(v, attn_config.head_dim)
|
510
511
|
cycled = [t for group in zip(qs, ks, vs) for t in group]
|
511
512
|
return torch.cat(cycled)
|