ctranslate2-4.7.0-cp314-cp314-macosx_11_0_arm64.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- ctranslate2/.dylibs/libctranslate2.4.7.0.dylib +0 -0
- ctranslate2/__init__.py +66 -0
- ctranslate2/_ext.cpython-314-darwin.so +0 -0
- ctranslate2/converters/__init__.py +8 -0
- ctranslate2/converters/converter.py +109 -0
- ctranslate2/converters/eole_ct2.py +353 -0
- ctranslate2/converters/fairseq.py +347 -0
- ctranslate2/converters/marian.py +315 -0
- ctranslate2/converters/openai_gpt2.py +95 -0
- ctranslate2/converters/opennmt_py.py +361 -0
- ctranslate2/converters/opennmt_tf.py +455 -0
- ctranslate2/converters/opus_mt.py +44 -0
- ctranslate2/converters/transformers.py +3721 -0
- ctranslate2/converters/utils.py +127 -0
- ctranslate2/extensions.py +589 -0
- ctranslate2/logging.py +45 -0
- ctranslate2/models/__init__.py +18 -0
- ctranslate2/specs/__init__.py +18 -0
- ctranslate2/specs/attention_spec.py +98 -0
- ctranslate2/specs/common_spec.py +66 -0
- ctranslate2/specs/model_spec.py +767 -0
- ctranslate2/specs/transformer_spec.py +797 -0
- ctranslate2/specs/wav2vec2_spec.py +72 -0
- ctranslate2/specs/wav2vec2bert_spec.py +97 -0
- ctranslate2/specs/whisper_spec.py +77 -0
- ctranslate2/version.py +3 -0
- ctranslate2-4.7.0.dist-info/METADATA +180 -0
- ctranslate2-4.7.0.dist-info/RECORD +31 -0
- ctranslate2-4.7.0.dist-info/WHEEL +6 -0
- ctranslate2-4.7.0.dist-info/entry_points.txt +8 -0
- ctranslate2-4.7.0.dist-info/top_level.txt +1 -0
ctranslate2/converters/fairseq.py
@@ -0,0 +1,347 @@
+import argparse
+import os
+
+from typing import Optional
+
+from ctranslate2.converters import utils
+from ctranslate2.converters.converter import Converter
+from ctranslate2.specs import common_spec, transformer_spec
+
+_SUPPORTED_MODELS = {
+    "bart",
+    "multilingual_transformer",
+    "transformer",
+    "transformer_align",
+    "transformer_lm",
+}
+
+
+_SUPPORTED_ACTIVATIONS = {
+    "gelu": common_spec.Activation.GELU,
+    "gelu_accurate": common_spec.Activation.GELUTanh,
+    "gelu_fast": common_spec.Activation.GELUTanh,
+    "relu": common_spec.Activation.RELU,
+    "swish": common_spec.Activation.SWISH,
+}
+
+
+def _get_model_spec(args):
+    import fairseq
+
+    activation_fn = getattr(args, "activation_fn", "relu")
+    model_name = fairseq.models.ARCH_MODEL_NAME_REGISTRY[args.arch]
+
+    check = utils.ConfigurationChecker()
+    check(
+        model_name in _SUPPORTED_MODELS,
+        "Model '%s' used by architecture '%s' is not supported (supported models are: %s)"
+        % (model_name, args.arch, ", ".join(_SUPPORTED_MODELS)),
+    )
+    check.validate()
+    check(
+        activation_fn in _SUPPORTED_ACTIVATIONS,
+        "Option --activation-fn %s is not supported (supported activations are: %s)"
+        % (activation_fn, ", ".join(_SUPPORTED_ACTIVATIONS.keys())),
+    )
+    check(
+        not getattr(args, "no_token_positional_embeddings", False),
+        "Option --no-token-positional-embeddings is not supported",
+    )
+    check(
+        not getattr(args, "lang_tok_replacing_bos_eos", False),
+        "Option --lang-tok-replacing-bos-eos is not supported",
+    )
+
+    if model_name == "transformer_lm":
+        check(
+            not args.character_embeddings,
+            "Option --character-embeddings is not supported",
+        )
+        check(
+            not args.adaptive_input,
+            "Option --adaptive-input is not supported",
+        )
+        check.validate()
+
+        return transformer_spec.TransformerDecoderModelSpec.from_config(
+            args.decoder_layers,
+            args.decoder_attention_heads,
+            pre_norm=args.decoder_normalize_before,
+            activation=_SUPPORTED_ACTIVATIONS[activation_fn],
+            layernorm_embedding=getattr(args, "layernorm_embedding", False),
+            no_final_norm=args.no_decoder_final_norm,
+            project_in_out=args.decoder_input_dim != args.decoder_embed_dim,
+        )
+
+    else:
+        check(
+            args.encoder_normalize_before == args.decoder_normalize_before,
+            "Options --encoder-normalize-before and --decoder-normalize-before "
+            "must have the same value",
+        )
+        check(
+            args.encoder_attention_heads == args.decoder_attention_heads,
+            "Options --encoder-attention-heads and --decoder-attention-heads "
+            "must have the same value",
+        )
+        check.validate()
+
+        return transformer_spec.TransformerSpec.from_config(
+            (args.encoder_layers, args.decoder_layers),
+            args.encoder_attention_heads,
+            pre_norm=args.encoder_normalize_before,
+            activation=_SUPPORTED_ACTIVATIONS[activation_fn],
+            alignment_layer=getattr(args, "alignment_layer", -1),
+            alignment_heads=getattr(args, "alignment_heads", 0),
+            layernorm_embedding=getattr(args, "layernorm_embedding", False),
+        )
+
+
+def _get_vocab(dictionary):
+    return ["<blank>" if token == "<pad>" else token for token in dictionary.symbols]
+
+
+class FairseqConverter(Converter):
+    """Converts models trained with Fairseq."""
+
+    def __init__(
+        self,
+        model_path: str,
+        data_dir: str,
+        source_lang: Optional[str] = None,
+        target_lang: Optional[str] = None,
+        fixed_dictionary: Optional[str] = None,
+        no_default_special_tokens: bool = False,
+        user_dir: Optional[str] = None,
+    ):
+        """Initializes the Fairseq converter.
+
+        Arguments:
+          model_path: Path to the Fairseq PyTorch model (.pt file).
+          data_dir: Path to the Fairseq data directory containing vocabulary files.
+          source_lang: Source language (may be required if not declared in the model).
+          target_lang: Target language (may be required if not declared in the model).
+          fixed_dictionary: Path to the fixed dictionary for multilingual models.
+          no_default_special_tokens: Require all special tokens to be provided by the user
+            (e.g. encoder end token, decoder start token).
+          user_dir: Path to the user directory containing custom extensions.
+        """
+        self._model_path = model_path
+        self._data_dir = data_dir
+        self._fixed_dictionary = fixed_dictionary
+        self._source_lang = source_lang
+        self._target_lang = target_lang
+        self._no_default_special_tokens = no_default_special_tokens
+        self._user_dir = user_dir
+
+    def _load(self):
+        import fairseq
+        import torch
+
+        from fairseq import checkpoint_utils
+
+        if self._user_dir:
+            from fairseq.utils import import_user_module
+
+            import_user_module(argparse.Namespace(user_dir=self._user_dir))
+
+        with torch.no_grad():
+            checkpoint = torch.load(
+                self._model_path, map_location=torch.device("cpu"), weights_only=False
+            )
+            args = checkpoint["args"] or checkpoint["cfg"]["model"]
+
+            args.data = self._data_dir
+            if self._fixed_dictionary is not None:
+                args.fixed_dictionary = self._fixed_dictionary
+            if hasattr(args, "lang_dict") and args.lang_dict:
+                args.lang_dict = os.path.join(
+                    self._data_dir, os.path.basename(args.lang_dict)
+                )
+
+            if self._source_lang is not None:
+                args.source_lang = self._source_lang
+
+            if self._target_lang is not None:
+                args.target_lang = self._target_lang
+
+            spec = _get_model_spec(args)
+
+            task = fairseq.tasks.setup_task(args)
+            model = fairseq.models.build_model(args, task)
+            model.eval()
+            model.load_state_dict(checkpoint["model"])
+
+            if isinstance(spec, transformer_spec.TransformerDecoderModelSpec):
+                set_transformer_decoder(
+                    spec.decoder,
+                    model.decoder,
+                    with_encoder_attention=False,
+                )
+
+                spec.register_vocabulary(_get_vocab(task.dictionary))
+                if not args.add_bos_token:
+                    spec.config.bos_token = spec.config.eos_token
+
+            else:
+                set_transformer_encoder(spec.encoder, model.encoder)
+                set_transformer_decoder(spec.decoder, model.decoder)
+
+                spec.register_source_vocabulary(_get_vocab(task.source_dictionary))
+                spec.register_target_vocabulary(_get_vocab(task.target_dictionary))
+                if self._no_default_special_tokens:
+                    spec.config.decoder_start_token = None
+                else:
+                    spec.config.decoder_start_token = spec.config.eos_token
+                spec.config.add_source_eos = True
+
+        return spec
+
+
+def set_transformer_encoder(spec, module):
+    set_input_layers(spec, module)
+    for layer_spec, layer in zip(spec.layer, module.layers):
+        set_transformer_encoder_layer(layer_spec, layer)
+    if module.layer_norm is not None:
+        set_layer_norm(spec.layer_norm, module.layer_norm)
+    if module.layernorm_embedding is not None:
+        set_layer_norm(spec.layernorm_embedding, module.layernorm_embedding)
+
+
+def set_transformer_decoder(spec, module, with_encoder_attention=True):
+    set_input_layers(spec, module)
+    set_linear(spec.projection, module.output_projection)
+    for layer_spec, layer in zip(spec.layer, module.layers):
+        set_transformer_decoder_layer(
+            layer_spec,
+            layer,
+            with_encoder_attention=with_encoder_attention,
+        )
+    if module.layer_norm is not None:
+        set_layer_norm(spec.layer_norm, module.layer_norm)
+    if module.layernorm_embedding is not None:
+        set_layer_norm(spec.layernorm_embedding, module.layernorm_embedding)
+    if module.project_in_dim is not None:
+        set_linear(spec.project_in, module.project_in_dim)
+    if module.project_out_dim is not None:
+        set_linear(spec.project_out, module.project_out_dim)
+
+
+def set_input_layers(spec, module):
+    set_position_encodings(spec.position_encodings, module.embed_positions)
+    set_embeddings(
+        spec.embeddings[0] if isinstance(spec.embeddings, list) else spec.embeddings,
+        module.embed_tokens,
+    )
+    spec.scale_embeddings = module.embed_scale
+
+
+def set_transformer_encoder_layer(spec, module):
+    set_ffn(spec.ffn, module)
+    set_multi_head_attention(spec.self_attention, module.self_attn, self_attention=True)
+    set_layer_norm(spec.self_attention.layer_norm, module.self_attn_layer_norm)
+
+
+def set_transformer_decoder_layer(spec, module, with_encoder_attention=True):
+    set_ffn(spec.ffn, module)
+    set_multi_head_attention(spec.self_attention, module.self_attn, self_attention=True)
+    set_layer_norm(spec.self_attention.layer_norm, module.self_attn_layer_norm)
+    if with_encoder_attention:
+        set_multi_head_attention(spec.attention, module.encoder_attn)
+        set_layer_norm(spec.attention.layer_norm, module.encoder_attn_layer_norm)
+
+
+def set_ffn(spec, module):
+    set_layer_norm(spec.layer_norm, module.final_layer_norm)
+    set_linear(spec.linear_0, module.fc1)
+    set_linear(spec.linear_1, module.fc2)
+
+
+def set_multi_head_attention(spec, module, self_attention=False):
+    if self_attention:
+        split_layers = [common_spec.LinearSpec() for _ in range(3)]
+        set_linear(split_layers[0], module.q_proj)
+        set_linear(split_layers[1], module.k_proj)
+        set_linear(split_layers[2], module.v_proj)
+        utils.fuse_linear(spec.linear[0], split_layers)
+    else:
+        set_linear(spec.linear[0], module.q_proj)
+        split_layers = [common_spec.LinearSpec() for _ in range(2)]
+        set_linear(split_layers[0], module.k_proj)
+        set_linear(split_layers[1], module.v_proj)
+        utils.fuse_linear(spec.linear[1], split_layers)
+    set_linear(spec.linear[-1], module.out_proj)
+
+
+def set_layer_norm(spec, module):
+    spec.gamma = module.weight.numpy()
+    spec.beta = module.bias.numpy()
+
+
+def set_linear(spec, module):
+    spec.weight = module.weight.numpy()
+    if module.bias is not None:
+        spec.bias = module.bias.numpy()
+
+
+def set_embeddings(spec, module):
+    spec.weight = module.weight.numpy()
+
+
+def set_position_encodings(spec, module):
+    import torch
+
+    weight = module.weight if isinstance(module, torch.nn.Embedding) else module.weights
+    spec.encodings = weight.numpy()[module.padding_idx + 1 :]
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument("--model_path", required=True, help="Model path.")
+    parser.add_argument(
+        "--data_dir",
+        required=True,
+        help="Data directory containing the source and target vocabularies.",
+    )
+    parser.add_argument(
+        "--user_dir",
+        help="Directory containing custom extensions.",
+    )
+    parser.add_argument(
+        "--fixed_dictionary",
+        help="Fixed dictionary for multilingual models.",
+    )
+    parser.add_argument(
+        "--source_lang",
+        help="Source language. This argument is used to find the dictionary file in `data_dir`.",
+    )
+    parser.add_argument(
+        "--target_lang",
+        help="Target language. This argument is used to find the dictionary file in `data_dir`.",
+    )
+    parser.add_argument(
+        "--no_default_special_tokens",
+        action="store_true",
+        help=(
+            "Require all special tokens to be provided by the user during inference, "
+            "including the decoder start token."
+        ),
+    )
+    Converter.declare_arguments(parser)
+    args = parser.parse_args()
+    converter = FairseqConverter(
+        args.model_path,
+        args.data_dir,
+        source_lang=args.source_lang,
+        target_lang=args.target_lang,
+        fixed_dictionary=args.fixed_dictionary,
+        no_default_special_tokens=args.no_default_special_tokens,
+        user_dir=args.user_dir,
+    )
+    converter.convert_from_args(args)
+
+
+if __name__ == "__main__":
+    main()
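Besides the `ct2-*` console scripts registered in entry_points.txt, the converter above can be driven directly from Python. A minimal sketch, assuming a local Fairseq Transformer checkpoint; the paths "model.pt", "data-bin", and "ct2_model" are placeholders, not files shipped with the package:

from ctranslate2.converters import FairseqConverter

# Placeholder paths: a Fairseq checkpoint and its data directory
# containing the dict.<lang>.txt vocabulary files.
converter = FairseqConverter(
    "model.pt",
    "data-bin",
    source_lang="en",
    target_lang="de",
)
converter.convert("ct2_model", quantization="int8")

The output directory can then be loaded for inference, e.g. with ctranslate2.Translator.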
ctranslate2/converters/marian.py
@@ -0,0 +1,315 @@
+import argparse
+import re
+
+from typing import List
+
+import numpy as np
+import yaml
+
+from ctranslate2.converters import utils
+from ctranslate2.converters.converter import Converter
+from ctranslate2.specs import common_spec, transformer_spec
+
+_SUPPORTED_ACTIVATIONS = {
+    "gelu": common_spec.Activation.GELUSigmoid,
+    "relu": common_spec.Activation.RELU,
+    "swish": common_spec.Activation.SWISH,
+}
+
+_SUPPORTED_POSTPROCESS_EMB = {"", "d", "n", "nd"}
+
+
+class MarianConverter(Converter):
+    """Converts models trained with Marian."""
+
+    def __init__(self, model_path: str, vocab_paths: List[str]):
+        """Initializes the Marian converter.
+
+        Arguments:
+          model_path: Path to the Marian model (.npz file).
+          vocab_paths: Paths to the vocabularies (.yml files).
+        """
+        self._model_path = model_path
+        self._vocab_paths = vocab_paths
+
+    def _load(self):
+        model = np.load(self._model_path)
+        config = _get_model_config(model)
+        vocabs = list(map(load_vocab, self._vocab_paths))
+
+        activation = config["transformer-ffn-activation"]
+        pre_norm = "n" in config["transformer-preprocess"]
+        postprocess_emb = config["transformer-postprocess-emb"]
+
+        check = utils.ConfigurationChecker()
+        check(config["type"] == "transformer", "Option --type must be 'transformer'")
+        check(
+            config["transformer-decoder-autoreg"] == "self-attention",
+            "Option --transformer-decoder-autoreg must be 'self-attention'",
+        )
+        check(
+            not config["transformer-no-projection"],
+            "Option --transformer-no-projection is not supported",
+        )
+        check(
+            activation in _SUPPORTED_ACTIVATIONS,
+            "Option --transformer-ffn-activation %s is not supported "
+            "(supported activations are: %s)"
+            % (activation, ", ".join(_SUPPORTED_ACTIVATIONS.keys())),
+        )
+        check(
+            postprocess_emb in _SUPPORTED_POSTPROCESS_EMB,
+            "Option --transformer-postprocess-emb %s is not supported (supported values are: %s)"
+            % (postprocess_emb, ", ".join(_SUPPORTED_POSTPROCESS_EMB)),
+        )
+
+        if pre_norm:
+            check(
+                config["transformer-preprocess"] == "n"
+                and config["transformer-postprocess"] == "da"
+                and config.get("transformer-postprocess-top", "") == "n",
+                "Unsupported pre-norm Transformer architecture, expected the following "
+                "combination of options: "
+                "--transformer-preprocess n "
+                "--transformer-postprocess da "
+                "--transformer-postprocess-top n",
+            )
+        else:
+            check(
+                config["transformer-preprocess"] == ""
+                and config["transformer-postprocess"] == "dan"
+                and config.get("transformer-postprocess-top", "") == "",
+                "Unsupported post-norm Transformer architecture, expected the following "
+                "combination of options: "
+                "--transformer-preprocess '' "
+                "--transformer-postprocess dan "
+                "--transformer-postprocess-top ''",
+            )
+
+        check.validate()
+
+        alignment_layer = config["transformer-guided-alignment-layer"]
+        alignment_layer = -1 if alignment_layer == "last" else int(alignment_layer) - 1
+        layernorm_embedding = "n" in postprocess_emb
+
+        model_spec = transformer_spec.TransformerSpec.from_config(
+            (config["enc-depth"], config["dec-depth"]),
+            config["transformer-heads"],
+            pre_norm=pre_norm,
+            activation=_SUPPORTED_ACTIVATIONS[activation],
+            alignment_layer=alignment_layer,
+            alignment_heads=1,
+            layernorm_embedding=layernorm_embedding,
+        )
+        set_transformer_spec(model_spec, model)
+        model_spec.register_source_vocabulary(vocabs[0])
+        model_spec.register_target_vocabulary(vocabs[-1])
+        model_spec.config.add_source_eos = True
+        return model_spec
+
+
+def _get_model_config(model):
+    config = model["special:model.yml"]
+    config = config[:-1].tobytes()
+    config = yaml.safe_load(config)
+    return config
+
+
+def load_vocab(path):
+    # pyyaml skips some entries so we manually parse the vocabulary file.
+    with open(path, encoding="utf-8") as vocab:
+        tokens = []
+        token = None
+        idx = None
+        for i, line in enumerate(vocab):
+            line = line.rstrip("\n\r")
+            if not line:
+                continue
+
+            if line.startswith("? "):  # Complex key mapping (key)
+                token = line[2:]
+            elif token is not None:  # Complex key mapping (value)
+                idx = line[2:]
+            else:
+                token, idx = line.rsplit(":", 1)
+
+            if token is not None:
+                if token.startswith('"') and token.endswith('"'):
+                    # Unescape characters and remove quotes.
+                    token = re.sub(r"\\([^x])", r"\1", token)
+                    token = token[1:-1]
+                    if token.startswith("\\x"):
+                        # Convert the digraph \x to the actual escaped sequence.
+                        token = chr(int(token[2:], base=16))
+                elif token.startswith("'") and token.endswith("'"):
+                    token = token[1:-1]
+                    token = token.replace("''", "'")
+
+            if idx is not None:
+                try:
+                    idx = int(idx.strip())
+                except ValueError as e:
+                    raise ValueError(
+                        "Unexpected format at line %d: '%s'" % (i + 1, line)
+                    ) from e
+
+                tokens.append((idx, token))
+
+                token = None
+                idx = None
+
+    return [token for _, token in sorted(tokens, key=lambda item: item[0])]
+
+
+def set_transformer_spec(spec, weights):
+    set_transformer_encoder(spec.encoder, weights, "encoder")
+    set_transformer_decoder(spec.decoder, weights, "decoder")
+
+
+def set_transformer_encoder(spec, weights, scope):
+    set_common_layers(spec, weights, scope)
+    for i, layer_spec in enumerate(spec.layer):
+        set_transformer_encoder_layer(layer_spec, weights, "%s_l%d" % (scope, i + 1))
+
+
+def set_transformer_decoder(spec, weights, scope):
+    spec.start_from_zero_embedding = True
+    set_common_layers(spec, weights, scope)
+    for i, layer_spec in enumerate(spec.layer):
+        set_transformer_decoder_layer(layer_spec, weights, "%s_l%d" % (scope, i + 1))
+
+    set_linear(
+        spec.projection,
+        weights,
+        "%s_ff_logit_out" % scope,
+        reuse_weight=spec.embeddings.weight,
+    )
+
+
+def set_common_layers(spec, weights, scope):
+    embeddings_specs = spec.embeddings
+    if not isinstance(embeddings_specs, list):
+        embeddings_specs = [embeddings_specs]
+
+    set_embeddings(embeddings_specs[0], weights, scope)
+    set_position_encodings(
+        spec.position_encodings, weights, dim=embeddings_specs[0].weight.shape[1]
+    )
+    if hasattr(spec, "layernorm_embedding"):
+        set_layer_norm(
+            spec.layernorm_embedding,
+            weights,
+            "%s_emb" % scope,
+            pre_norm=True,
+        )
+    if hasattr(spec, "layer_norm"):
+        set_layer_norm(spec.layer_norm, weights, "%s_top" % scope)
+
+
+def set_transformer_encoder_layer(spec, weights, scope):
+    set_ffn(spec.ffn, weights, "%s_ffn" % scope)
+    set_multi_head_attention(
+        spec.self_attention, weights, "%s_self" % scope, self_attention=True
+    )
+
+
+def set_transformer_decoder_layer(spec, weights, scope):
+    set_ffn(spec.ffn, weights, "%s_ffn" % scope)
+    set_multi_head_attention(
+        spec.self_attention, weights, "%s_self" % scope, self_attention=True
+    )
+    set_multi_head_attention(spec.attention, weights, "%s_context" % scope)
+
+
+def set_multi_head_attention(spec, weights, scope, self_attention=False):
+    split_layers = [common_spec.LinearSpec() for _ in range(3)]
+    set_linear(split_layers[0], weights, scope, "q")
+    set_linear(split_layers[1], weights, scope, "k")
+    set_linear(split_layers[2], weights, scope, "v")
+
+    if self_attention:
+        utils.fuse_linear(spec.linear[0], split_layers)
+    else:
+        spec.linear[0].weight = split_layers[0].weight
+        spec.linear[0].bias = split_layers[0].bias
+        utils.fuse_linear(spec.linear[1], split_layers[1:])
+
+    set_linear(spec.linear[-1], weights, scope, "o")
+    set_layer_norm_auto(spec.layer_norm, weights, "%s_Wo" % scope)
+
+
+def set_ffn(spec, weights, scope):
+    set_layer_norm_auto(spec.layer_norm, weights, "%s_ffn" % scope)
+    set_linear(spec.linear_0, weights, scope, "1")
+    set_linear(spec.linear_1, weights, scope, "2")
+
+
+def set_layer_norm_auto(spec, weights, scope):
+    try:
+        set_layer_norm(spec, weights, scope, pre_norm=True)
+    except KeyError:
+        set_layer_norm(spec, weights, scope)
+
+
+def set_layer_norm(spec, weights, scope, pre_norm=False):
+    suffix = "_pre" if pre_norm else ""
+    spec.gamma = weights["%s_ln_scale%s" % (scope, suffix)].squeeze()
+    spec.beta = weights["%s_ln_bias%s" % (scope, suffix)].squeeze()
+
+
+def set_linear(spec, weights, scope, suffix="", reuse_weight=None):
+    weight = weights.get("%s_W%s" % (scope, suffix))
+
+    if weight is None:
+        weight = weights.get("%s_Wt%s" % (scope, suffix), reuse_weight)
+    else:
+        weight = weight.transpose()
+
+    spec.weight = weight
+
+    bias = weights.get("%s_b%s" % (scope, suffix))
+    if bias is not None:
+        spec.bias = bias.squeeze()
+
+
+def set_embeddings(spec, weights, scope):
+    spec.weight = weights.get("%s_Wemb" % scope)
+    if spec.weight is None:
+        spec.weight = weights.get("Wemb")
+
+
+def set_position_encodings(spec, weights, dim=None):
+    spec.encodings = weights.get("Wpos", _make_sinusoidal_position_encodings(dim))
+
+
+def _make_sinusoidal_position_encodings(dim, num_positions=2048):
+    positions = np.arange(num_positions)
+    timescales = np.power(10000, 2 * (np.arange(dim) // 2) / dim)
+    position_enc = np.expand_dims(positions, 1) / np.expand_dims(timescales, 0)
+    table = np.zeros_like(position_enc)
+    table[:, : dim // 2] = np.sin(position_enc[:, 0::2])
+    table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
+    return table
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "--model_path", required=True, help="Path to the model .npz file."
+    )
+    parser.add_argument(
+        "--vocab_paths",
+        required=True,
+        nargs="+",
+        help="List of paths to the YAML vocabularies.",
+    )
+    Converter.declare_arguments(parser)
+    args = parser.parse_args()
+    converter = MarianConverter(args.model_path, args.vocab_paths)
+    converter.convert_from_args(args)
+
+
+if __name__ == "__main__":
+    main()
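As with the Fairseq converter, the Marian converter above can also be used from Python. A minimal sketch, assuming an OPUS-style Marian model; "model.npz" and the two vocabulary paths are placeholders:

from ctranslate2.converters import MarianConverter

# Marian stores its configuration inside the .npz archive, so only the
# model file and the source/target vocabularies need to be provided.
converter = MarianConverter("model.npz", ["vocab.src.yml", "vocab.trg.yml"])
converter.convert("ct2_model", quantization="int8")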