ctranslate2 4.6.3__cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ctranslate2/__init__.py +64 -0
- ctranslate2/_ext.cpython-312-aarch64-linux-gnu.so +0 -0
- ctranslate2/converters/__init__.py +8 -0
- ctranslate2/converters/converter.py +109 -0
- ctranslate2/converters/eole_ct2.py +352 -0
- ctranslate2/converters/fairseq.py +347 -0
- ctranslate2/converters/marian.py +315 -0
- ctranslate2/converters/openai_gpt2.py +95 -0
- ctranslate2/converters/opennmt_py.py +361 -0
- ctranslate2/converters/opennmt_tf.py +455 -0
- ctranslate2/converters/opus_mt.py +44 -0
- ctranslate2/converters/transformers.py +3733 -0
- ctranslate2/converters/utils.py +127 -0
- ctranslate2/extensions.py +589 -0
- ctranslate2/logging.py +45 -0
- ctranslate2/models/__init__.py +18 -0
- ctranslate2/specs/__init__.py +18 -0
- ctranslate2/specs/attention_spec.py +98 -0
- ctranslate2/specs/common_spec.py +66 -0
- ctranslate2/specs/model_spec.py +767 -0
- ctranslate2/specs/transformer_spec.py +797 -0
- ctranslate2/specs/wav2vec2_spec.py +72 -0
- ctranslate2/specs/wav2vec2bert_spec.py +97 -0
- ctranslate2/specs/whisper_spec.py +77 -0
- ctranslate2/version.py +3 -0
- ctranslate2-4.6.3.dist-info/METADATA +178 -0
- ctranslate2-4.6.3.dist-info/RECORD +32 -0
- ctranslate2-4.6.3.dist-info/WHEEL +6 -0
- ctranslate2-4.6.3.dist-info/entry_points.txt +8 -0
- ctranslate2-4.6.3.dist-info/top_level.txt +1 -0
- ctranslate2.libs/libctranslate2-0eb658df.so.4.6.3 +0 -0
- ctranslate2.libs/libgomp-a49a47f9.so.1.0.0 +0 -0
ctranslate2/__init__.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import sys

# On Windows the shared libraries bundled with the package are not on the
# default DLL search path, so they must be located and loaded explicitly
# before the compiled extension is imported below.
if sys.platform == "win32":
    import ctypes
    import glob
    import os

    module_name = sys.modules[__name__].__name__

    # importlib.resources.files is only available on Python >= 3.9.
    try:
        from importlib.resources import files

        # Preferred over the deprecated pkg_resources API.
        package_dir = str(files(module_name))
    except ImportError:
        # Fallback for older Pythons without importlib.resources.files.
        import pkg_resources

        package_dir = pkg_resources.resource_filename(module_name, "")

    # os.add_dll_directory only exists on Windows (Python >= 3.8).
    add_dll_directory = getattr(os, "add_dll_directory", None)
    if add_dll_directory is not None:
        add_dll_directory(package_dir)

    # Eagerly load every DLL shipped inside the package directory.
    for library in glob.glob(os.path.join(package_dir, "*.dll")):
        ctypes.CDLL(library)

try:
    from ctranslate2._ext import (
        AsyncGenerationResult,
        AsyncScoringResult,
        AsyncTranslationResult,
        DataType,
        Device,
        Encoder,
        EncoderForwardOutput,
        ExecutionStats,
        GenerationResult,
        GenerationStepResult,
        Generator,
        MpiInfo,
        ScoringResult,
        StorageView,
        TranslationResult,
        Translator,
        contains_model,
        get_cuda_device_count,
        get_supported_compute_types,
        set_random_seed,
    )
    from ctranslate2.extensions import register_extensions
    from ctranslate2.logging import get_log_level, set_log_level

    # Register the Python-side extensions on the compiled classes, then drop
    # the helper so it does not leak into the public namespace.
    register_extensions()
    del register_extensions
except ImportError as e:
    # Allow using the Python package without the compiled extension.
    if "No module named" in str(e):
        pass
    else:
        raise

from ctranslate2 import converters, models, specs
from ctranslate2.version import __version__
|
|
Binary file
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from ctranslate2.converters.converter import Converter
|
|
2
|
+
from ctranslate2.converters.fairseq import FairseqConverter
|
|
3
|
+
from ctranslate2.converters.marian import MarianConverter
|
|
4
|
+
from ctranslate2.converters.openai_gpt2 import OpenAIGPT2Converter
|
|
5
|
+
from ctranslate2.converters.opennmt_py import OpenNMTPyConverter
|
|
6
|
+
from ctranslate2.converters.opennmt_tf import OpenNMTTFConverter
|
|
7
|
+
from ctranslate2.converters.opus_mt import OpusMTConverter
|
|
8
|
+
from ctranslate2.converters.transformers import TransformersConverter
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import argparse
|
|
3
|
+
import os
|
|
4
|
+
import shutil
|
|
5
|
+
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from ctranslate2.specs.model_spec import ACCEPTED_MODEL_TYPES, ModelSpec
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Converter(abc.ABC):
    """Base class for model converters."""

    @staticmethod
    def declare_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Adds common conversion options to the command line parser.

        Arguments:
            parser: Command line argument parser.

        Returns:
            The same parser, with the shared conversion options added.
        """
        parser.add_argument(
            "--output_dir", required=True, help="Output model directory."
        )
        parser.add_argument(
            "--vocab_mapping", default=None, help="Vocabulary mapping file (optional)."
        )
        parser.add_argument(
            "--quantization",
            default=None,
            # NOTE(review): the allowed quantization values are taken from
            # ACCEPTED_MODEL_TYPES in model_spec — confirm the constant really
            # lists quantization types despite its name.
            choices=ACCEPTED_MODEL_TYPES,
            help="Weight quantization type.",
        )
        parser.add_argument(
            "--force",
            action="store_true",
            help="Force conversion even if the output directory already exists.",
        )
        return parser

    def convert_from_args(self, args: argparse.Namespace) -> str:
        """Helper function to call :meth:`ctranslate2.converters.Converter.convert`
        with the parsed command line options.

        Arguments:
            args: Namespace containing parsed arguments.

        Returns:
            Path to the output directory.
        """
        return self.convert(
            args.output_dir,
            vmap=args.vocab_mapping,
            quantization=args.quantization,
            force=args.force,
        )

    def convert(
        self,
        output_dir: str,
        vmap: Optional[str] = None,
        quantization: Optional[str] = None,
        force: bool = False,
    ) -> str:
        """Converts the model to the CTranslate2 format.

        Arguments:
            output_dir: Output directory where the CTranslate2 model is saved.
            vmap: Optional path to a vocabulary mapping file that will be included
                in the converted model directory.
            quantization: Weight quantization scheme (possible values are: int8, int8_float32,
                int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
            force: Override the output directory if it already exists.

        Returns:
            Path to the output directory.

        Raises:
            RuntimeError: If the output directory already exists and :obj:`force`
                is not set.
            NotImplementedError: If the converter cannot convert this model to the
                CTranslate2 format.
        """
        if os.path.exists(output_dir) and not force:
            raise RuntimeError(
                "output directory %s already exists, use --force to override"
                % output_dir
            )

        # _load returning None is the converters' signal for "unsupported".
        model_spec = self._load()
        if model_spec is None:
            raise NotImplementedError(
                "This model is not supported by CTranslate2 or this converter"
            )
        if vmap is not None:
            model_spec.register_vocabulary_mapping(vmap)

        # Validate and quantize before touching the output directory, so a
        # failure here leaves any pre-existing directory intact.
        model_spec.validate()
        model_spec.optimize(quantization=quantization)

        # Create model directory. Reaching this point with an existing
        # directory implies force=True, so it is safe to wipe it.
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir)
        model_spec.save(output_dir)
        return output_dir

    @abc.abstractmethod
    def _load(self):
        """Loads and returns the model specification, or None if unsupported."""
        raise NotImplementedError()
|
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
|
|
3
|
+
from eole.config.run import PredictConfig
|
|
4
|
+
from eole.constants import PositionEncodingType
|
|
5
|
+
from eole.inputters.inputter import vocabs_to_dict
|
|
6
|
+
from eole.models.model import BaseModel
|
|
7
|
+
|
|
8
|
+
from ctranslate2.converters import utils
|
|
9
|
+
from ctranslate2.converters.converter import Converter
|
|
10
|
+
from ctranslate2.specs import common_spec, transformer_spec
|
|
11
|
+
|
|
12
|
+
# Maps eole activation names to CTranslate2 activation specs. The spec
# builders index this dict directly, so any other activation name raises
# a KeyError during conversion.
_SUPPORTED_ACTIVATIONS = {
    "gelu": common_spec.Activation.GELU,
    "fast_gelu": common_spec.Activation.GELUTanh,
    "relu": common_spec.Activation.RELU,
    "gated-silu": common_spec.Activation.SWISH,
}
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _get_model_spec_seq2seq(
    config, variables, src_vocabs, tgt_vocabs, num_source_embeddings
):
    """Creates an encoder/decoder Transformer specification from the model config.

    Arguments:
        config: Eole model configuration (the ``model`` section).
        variables: Checkpoint state dict mapping variable names to tensors.
        src_vocabs: Source vocabularies to register on the spec.
        tgt_vocabs: Target vocabularies to register on the spec.
        num_source_embeddings: Number of source embedding tables.

    Returns:
        A populated ``transformer_spec.TransformerSpec``.

    Raises:
        ValueError: If the model uses a position encoding or attention option
            that is not supported yet for encoder/decoder models.
    """
    # A single attribute drives all three position-encoding checks; read it once.
    position_encoding_type = getattr(
        config.embeddings, "position_encoding_type", None
    )
    with_relative_position = (
        position_encoding_type == PositionEncodingType.Relative
    )
    if position_encoding_type == PositionEncodingType.Rotary:
        raise ValueError(
            "Rotary embeddings are not supported yet for encoder/decoder models"
        )
    if position_encoding_type == PositionEncodingType.Alibi:
        raise ValueError("Alibi is not supported yet for encoder/decoder models")
    activation_fn = getattr(config, "mlp_activation_fn", "relu")

    # Return the first head of the last layer unless the model was trained with alignments.
    if getattr(config.decoder, "lambda_align", 0) == 0:
        alignment_layer = -1
        alignment_heads = 1
    else:
        alignment_layer = config.decoder.alignment_layer
        alignment_heads = config.decoder.alignment_heads

    num_heads = getattr(config.decoder, "heads", 8)
    ffn_glu = activation_fn == "gated-silu"
    if getattr(config, "sliding_window", 0) != 0:
        raise ValueError(
            "Sliding window is not supported yet for encoder/decoder models"
        )

    # Rotary/ALiBi/GQA/sliding-window options are rejected above, so unlike
    # the decoder-only builder they are not forwarded to the spec.
    model_spec = transformer_spec.TransformerSpec.from_config(
        (config.encoder.layers, config.decoder.layers),
        num_heads,
        with_relative_position=with_relative_position,
        activation=_SUPPORTED_ACTIVATIONS[activation_fn],
        ffn_glu=ffn_glu,
        rms_norm=config.layer_norm == "rms",
        alignment_layer=alignment_layer,
        alignment_heads=alignment_heads,
        num_source_embeddings=num_source_embeddings,
    )

    set_transformer_spec(model_spec, variables)
    for src_vocab in src_vocabs:
        model_spec.register_source_vocabulary(src_vocab)
    for tgt_vocab in tgt_vocabs:
        model_spec.register_target_vocabulary(tgt_vocab)

    return model_spec
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _get_model_spec_lm(
    config, variables, src_vocabs, tgt_vocabs, num_source_embeddings
):
    """Creates a decoder-only (language model) specification from the model config.

    ``src_vocabs`` and ``num_source_embeddings`` are accepted only for
    signature parity with the seq2seq builder and are not used here.
    """
    with_relative_position = (
        getattr(config.embeddings, "position_encoding_type", None)
        == PositionEncodingType.Relative
    )
    with_rotary = (
        getattr(config.embeddings, "position_encoding_type", None)
        == PositionEncodingType.Rotary
    )
    with_alibi = (
        getattr(config.embeddings, "position_encoding_type", None)
        == PositionEncodingType.Alibi
    )
    activation_fn = getattr(config, "mlp_activation_fn", "relu")
    num_heads = getattr(config.decoder, "heads", 8)
    # heads_kv equal to heads (or unset/0) means standard multi-head
    # attention, which the spec encodes as num_heads_kv=None.
    num_kv = getattr(config.decoder, "heads_kv", 0)
    if num_kv == num_heads or num_kv == 0:
        num_kv = None
    # NOTE(review): rotary_dim=0 presumably means "use the full head
    # dimension" on the spec side — confirm against transformer_spec.
    rotary_dim = 0 if with_rotary else None
    # If rope_config is None, getattr still yields the default True.
    rotary_interleave = getattr(config.rope_config, "rotary_interleave", True)
    ffn_glu = activation_fn == "gated-silu"
    sliding_window = getattr(config, "sliding_window", 0)

    model_spec = transformer_spec.TransformerDecoderModelSpec.from_config(
        config.decoder.layers,
        num_heads,
        activation=_SUPPORTED_ACTIVATIONS[activation_fn],
        ffn_glu=ffn_glu,
        with_relative_position=with_relative_position,
        alibi=with_alibi,
        rms_norm=config.layer_norm == "rms",
        rotary_dim=rotary_dim,
        rotary_interleave=rotary_interleave,
        num_heads_kv=num_kv,
        sliding_window=sliding_window,
        # multi_query_attention=getattr(opt, "multiquery", False),
    )

    # Decoder-only models have no encoder attention to load.
    set_transformer_decoder(
        model_spec.decoder,
        variables,
        with_encoder_attention=False,
    )

    for tgt_vocab in tgt_vocabs:
        model_spec.register_vocabulary(tgt_vocab)

    return model_spec
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def get_vocabs(vocab):
    """Return the source and target vocabularies, each wrapped in a list.

    The converter APIs expect one vocabulary list per embedding table, so the
    single "src"/"tgt" entries are returned as one-element lists.
    """
    return [vocab["src"]], [vocab["tgt"]]
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class EoleConverter(Converter):
    """Converts models generated by Eole."""

    def __init__(self, model_path: str):
        """Initializes the Eole converter.

        Arguments:
            model_path: Path to the Eole model (forwarded to ``PredictConfig``).
        """
        self._model_path = model_path

    def _load(self):
        # NOTE(review): torch appears unused in this method; it may be needed
        # as an import side effect for checkpoint loading — confirm before
        # removing.
        import torch

        # "dummy" is a placeholder: PredictConfig requires a src argument but
        # conversion never reads the source file (presumably — verify).
        config = PredictConfig(model_path=self._model_path, src="dummy")

        vocabs, model, model_config = BaseModel.load_test_model(config)
        vocabs_dict = vocabs_to_dict(vocabs)

        config.model = model_config
        src_vocabs, tgt_vocabs = get_vocabs(vocabs_dict)

        # Decoder-only checkpoints become language-model specs; everything
        # else is converted as an encoder/decoder Transformer.
        if config.model.decoder.decoder_type == "transformer_lm":
            spec = _get_model_spec_lm(
                config.model,
                model.state_dict(),
                src_vocabs,
                tgt_vocabs,
                num_source_embeddings=len(src_vocabs),
            )
        else:
            spec = _get_model_spec_seq2seq(
                config.model,
                model.state_dict(),
                src_vocabs,
                tgt_vocabs,
                num_source_embeddings=len(src_vocabs),
            )
        spec.config.decoder_start_token = vocabs["decoder_start_token"]

        spec.config.bos_token = vocabs["specials"]["bos_token"]
        spec.config.eos_token = vocabs["specials"]["eos_token"]
        spec.config.unk_token = vocabs["specials"]["unk_token"]
        # NOTE(review): norm_eps is read from the top-level config rather than
        # config.model — confirm that is where eole stores it.
        spec.config.layer_norm_epsilon = getattr(config, "norm_eps", 1e-6)

        return spec
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def set_transformer_spec(spec, variables):
    """Load a full encoder/decoder checkpoint into a Transformer spec."""
    for sub_spec, loader in (
        (spec.encoder, set_transformer_encoder),
        (spec.decoder, set_transformer_decoder),
    ):
        loader(sub_spec, variables)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def set_transformer_encoder(spec, variables):
    """Load encoder embeddings, final layer norm, and every encoder layer."""
    set_input_layers(spec, variables, "src_emb")
    set_layer_norm(spec.layer_norm, variables, "encoder.layer_norm")
    for index, layer_spec in enumerate(spec.layer):
        layer_scope = "encoder.transformer_layers.%d" % index
        set_transformer_encoder_layer(layer_spec, variables, layer_scope)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def set_transformer_decoder(spec, variables, with_encoder_attention=True):
    """Load decoder embeddings, final norm, layers, and the output projection.

    Pass ``with_encoder_attention=False`` for decoder-only (language) models.
    """
    set_input_layers(spec, variables, "tgt_emb")
    set_layer_norm(spec.layer_norm, variables, "decoder.layer_norm")
    for index, layer_spec in enumerate(spec.layer):
        layer_scope = "decoder.transformer_layers.%d" % index
        set_transformer_decoder_layer(
            layer_spec,
            variables,
            layer_scope,
            with_encoder_attention=with_encoder_attention,
        )

    set_linear(spec.projection, variables, "generator")
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def set_input_layers(spec, variables, scope):
    """Load the token embedding table and, when present, position encodings."""
    if hasattr(spec, "position_encodings"):
        set_position_encodings(spec.position_encodings, variables, "%s.pe" % scope)
    else:
        # No position encodings in this model: disable embedding scaling.
        spec.scale_embeddings = False

    # encoder embeddings are stored in a list(onmt/ct2 legacy with features)
    emb_spec = spec.embeddings
    emb_spec = emb_spec[0] if isinstance(emb_spec, list) else emb_spec
    set_embeddings(emb_spec, variables, "%s.embeddings" % scope)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def set_transformer_encoder_layer(spec, variables, scope):
    """Load one encoder layer: self-attention, its norms, and the MLP."""
    attn_spec = spec.self_attention
    set_multi_head_attention(
        attn_spec,
        variables,
        "%s.self_attn" % scope,
        self_attention=True,
    )
    set_layer_norm(attn_spec.layer_norm, variables, "%s.input_layernorm" % scope)
    set_layer_norm(
        spec.ffn.layer_norm, variables, "%s.post_attention_layernorm" % scope
    )
    set_ffn(spec.ffn, variables, "%s.mlp" % scope)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def set_transformer_decoder_layer(spec, variables, scope, with_encoder_attention=True):
    """Load one decoder layer: self-attention, optional cross-attention, MLP."""
    attn_spec = spec.self_attention
    set_multi_head_attention(
        attn_spec,
        variables,
        "%s.self_attn" % scope,
        self_attention=True,
    )
    set_layer_norm(attn_spec.layer_norm, variables, "%s.input_layernorm" % scope)
    if with_encoder_attention:
        # Decoder-only models skip the encoder (context) attention entirely.
        set_multi_head_attention(spec.attention, variables, "%s.context_attn" % scope)
        set_layer_norm(
            spec.attention.layer_norm, variables, "%s.precontext_layernorm" % scope
        )
    set_layer_norm(
        spec.ffn.layer_norm, variables, "%s.post_attention_layernorm" % scope
    )
    set_ffn(spec.ffn, variables, "%s.mlp" % scope)
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def set_ffn(spec, variables, scope):
    """Load the feed-forward block weights.

    ``gate_up_proj``/``down_proj`` feed the two main projections; when the
    spec exposes a ``linear_0_noact`` branch (GLU-style FFN), it is filled
    from ``up_proj``.
    """
    set_linear(spec.linear_0, variables, "%s.gate_up_proj" % scope)
    set_linear(spec.linear_1, variables, "%s.down_proj" % scope)
    if hasattr(spec, "linear_0_noact"):
        set_linear(spec.linear_0_noact, variables, "%s.up_proj" % scope)
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def set_multi_head_attention(spec, variables, scope, self_attention=False):
    """Load attention projections, fusing Q/K/V (or K/V) into one matrix.

    Self-attention fuses query/keys/values into ``spec.linear[0]``;
    cross-attention keeps the query separate and fuses keys/values into
    ``spec.linear[1]``.
    """
    if self_attention:
        fused_names = ("linear_query", "linear_keys", "linear_values")
        fused_target = spec.linear[0]
    else:
        set_linear(spec.linear[0], variables, "%s.linear_query" % scope)
        fused_names = ("linear_keys", "linear_values")
        fused_target = spec.linear[1]

    split_layers = []
    for proj_name in fused_names:
        part = common_spec.LinearSpec()
        set_linear(part, variables, "%s.%s" % (scope, proj_name))
        split_layers.append(part)
    utils.fuse_linear(fused_target, split_layers)

    set_linear(spec.linear[-1], variables, "%s.final_linear" % scope)

    if hasattr(spec, "relative_position_keys"):
        # Relative-position models share one embedding table for keys/values.
        spec.relative_position_keys = _get_variable(
            variables, "%s.relative_positions_embeddings.weight" % scope
        )
        spec.relative_position_values = spec.relative_position_keys
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def set_layer_norm(spec, variables, scope):
    """Load LayerNorm parameters, with a fallback for legacy checkpoints."""
    try:
        spec.gamma = _get_variable(variables, "%s.weight" % scope)
    except KeyError:
        # Compatibility with older models using a custom LayerNorm module
        # whose parameters were named a_2 (scale) and b_2 (bias).
        spec.gamma = _get_variable(variables, "%s.a_2" % scope)
        spec.beta = _get_variable(variables, "%s.b_2" % scope)
    # A standard ".bias" entry, when present, overrides any legacy b_2 value
    # loaded above; a missing bias is simply skipped.
    try:
        spec.beta = _get_variable(variables, "%s.bias" % scope)
    except KeyError:
        pass
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def set_linear(spec, variables, scope):
    """Load the weight and, when present, the bias of a linear layer."""
    spec.weight = _get_variable(variables, "%s.weight" % scope)
    maybe_bias = variables.get("%s.bias" % scope)
    if maybe_bias is not None:
        spec.bias = maybe_bias
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def set_embeddings(spec, variables, scope):
    """Load the embedding weight matrix into *spec*."""
    spec.weight = _get_variable(variables, "%s.weight" % scope)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def set_position_encodings(spec, variables, scope):
    """Load precomputed position encodings.

    The checkpoint stores the table with extra singleton dimension(s) —
    presumably a leading batch axis — which ``squeeze()`` removes; confirm
    the stored shape against the eole checkpoint format.
    """
    spec.encodings = _get_variable(variables, "%s.pe" % scope).squeeze()
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def _get_variable(variables, name):
    """Return the checkpoint entry *name*; raises KeyError when missing."""
    return variables[name]
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def main():
    """Command-line entry point: parse arguments and run the conversion."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--model_path", required=True, help="Model path.")
    Converter.declare_arguments(parser)
    args = parser.parse_args()

    converter = EoleConverter(args.model_path)
    converter.convert_from_args(args)
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
# Allow running this converter directly as a script.
if __name__ == "__main__":
    main()
|