ctranslate2-4.7.0-cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ctranslate2/.dylibs/libctranslate2.4.7.0.dylib +0 -0
- ctranslate2/__init__.py +66 -0
- ctranslate2/_ext.cpython-314-darwin.so +0 -0
- ctranslate2/converters/__init__.py +8 -0
- ctranslate2/converters/converter.py +109 -0
- ctranslate2/converters/eole_ct2.py +353 -0
- ctranslate2/converters/fairseq.py +347 -0
- ctranslate2/converters/marian.py +315 -0
- ctranslate2/converters/openai_gpt2.py +95 -0
- ctranslate2/converters/opennmt_py.py +361 -0
- ctranslate2/converters/opennmt_tf.py +455 -0
- ctranslate2/converters/opus_mt.py +44 -0
- ctranslate2/converters/transformers.py +3721 -0
- ctranslate2/converters/utils.py +127 -0
- ctranslate2/extensions.py +589 -0
- ctranslate2/logging.py +45 -0
- ctranslate2/models/__init__.py +18 -0
- ctranslate2/specs/__init__.py +18 -0
- ctranslate2/specs/attention_spec.py +98 -0
- ctranslate2/specs/common_spec.py +66 -0
- ctranslate2/specs/model_spec.py +767 -0
- ctranslate2/specs/transformer_spec.py +797 -0
- ctranslate2/specs/wav2vec2_spec.py +72 -0
- ctranslate2/specs/wav2vec2bert_spec.py +97 -0
- ctranslate2/specs/whisper_spec.py +77 -0
- ctranslate2/version.py +3 -0
- ctranslate2-4.7.0.dist-info/METADATA +180 -0
- ctranslate2-4.7.0.dist-info/RECORD +31 -0
- ctranslate2-4.7.0.dist-info/WHEEL +6 -0
- ctranslate2-4.7.0.dist-info/entry_points.txt +8 -0
- ctranslate2-4.7.0.dist-info/top_level.txt +1 -0
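The converters listed above all share the driver defined in converters/converter.py: construct a framework-specific Converter subclass and call convert() (the console scripts declared in entry_points.txt wrap the same logic via convert_from_args()). As orientation, a minimal sketch of the Python-side workflow; the model name and output directory are illustrative, and the quantization argument is optional:

import ctranslate2
from ctranslate2.converters import TransformersConverter

# Convert a Hugging Face checkpoint to the CTranslate2 format.
# "facebook/m2m100_418M" and "m2m100_ct2" are illustrative values.
converter = TransformersConverter("facebook/m2m100_418M")
output_dir = converter.convert(
    "m2m100_ct2",         # output directory for the converted model
    quantization="int8",  # optional weight quantization
    force=True,           # overwrite the output directory if it exists
)
print(output_dir)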
ctranslate2/converters/openai_gpt2.py
@@ -0,0 +1,95 @@

import argparse
import json
import os

from ctranslate2.converters.converter import Converter
from ctranslate2.specs import common_spec, model_spec, transformer_spec


class OpenAIGPT2Converter(Converter):
    """Converts GPT-2 models from https://github.com/openai/gpt-2."""

    def __init__(self, model_dir: str):
        """Initializes the OpenAI GPT-2 converter.

        Arguments:
          model_dir: Path to the OpenAI GPT-2 model directory.
        """
        self._model_dir = model_dir

    def _load(self):
        import tensorflow as tf

        reader = tf.train.load_checkpoint(self._model_dir)
        weights = {
            name: reader.get_tensor(name)
            for name in reader.get_variable_to_shape_map().keys()
        }

        with open(os.path.join(self._model_dir, "hparams.json")) as hparams_file:
            hparams = json.load(hparams_file)
        with open(os.path.join(self._model_dir, "encoder.json")) as vocab_file:
            vocab = json.load(vocab_file)
        vocab = [
            token
            for token, index in sorted(vocab.items(), key=lambda item: item[1])
        ]

        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
            hparams["n_layer"],
            hparams["n_head"],
            pre_norm=True,
            activation=common_spec.Activation.GELUTanh,
        )
        set_decoder(spec.decoder, weights, "model")
        spec.unk_token = "<|endoftext|>"
        spec.bos_token = "<|endoftext|>"
        spec.eos_token = "<|endoftext|>"
        spec.register_vocabulary(vocab)
        return spec


def set_decoder(spec, weights, scope):
    spec.embeddings.weight = weights["%s/wte" % scope]
    spec.position_encodings.encodings = weights["%s/wpe" % scope]
    spec.scale_embeddings = False
    spec.projection.weight = spec.embeddings.weight
    set_layer_norm(spec.layer_norm, weights, "%s/ln_f" % scope)
    for i, layer_spec in enumerate(spec.layer):
        set_layer(layer_spec, weights, "%s/h%d" % (scope, i))


def set_layer_norm(spec, weights, scope):
    spec.gamma = weights["%s/g" % scope]
    spec.beta = weights["%s/b" % scope]


def set_linear(spec, weights, scope):
    spec.weight = weights["%s/w" % scope].squeeze().transpose()
    spec.bias = weights["%s/b" % scope]


def set_layer(spec, weights, scope):
    set_layer_norm(spec.self_attention.layer_norm, weights, "%s/ln_1" % scope)
    set_linear(spec.self_attention.linear[0], weights, "%s/attn/c_attn" % scope)
    set_linear(spec.self_attention.linear[1], weights, "%s/attn/c_proj" % scope)
    set_layer_norm(spec.ffn.layer_norm, weights, "%s/ln_2" % scope)
    set_linear(spec.ffn.linear_0, weights, "%s/mlp/c_fc" % scope)
    set_linear(spec.ffn.linear_1, weights, "%s/mlp/c_proj" % scope)


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--model_dir", required=True, help="Path to the model directory."
    )
    Converter.declare_arguments(parser)
    args = parser.parse_args()
    converter = OpenAIGPT2Converter(args.model_dir)
    converter.convert_from_args(args)


if __name__ == "__main__":
    main()
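A hedged usage sketch for the converter above: the checkpoint directory must contain the TensorFlow checkpoint together with the hparams.json and encoder.json files read by _load(); the paths shown are hypothetical. The same conversion is presumably also reachable through the ct2-openai-gpt2-converter console script declared in entry_points.txt.

from ctranslate2.converters.openai_gpt2 import OpenAIGPT2Converter

# "gpt2/models/124M" and "gpt2_ct2" are hypothetical paths.
converter = OpenAIGPT2Converter("gpt2/models/124M")
converter.convert("gpt2_ct2", force=True)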
ctranslate2/converters/opennmt_py.py
@@ -0,0 +1,361 @@

import argparse

from ctranslate2.converters import utils
from ctranslate2.converters.converter import Converter
from ctranslate2.specs import common_spec, transformer_spec

_SUPPORTED_ACTIVATIONS = {
    "gelu": common_spec.Activation.GELU,
    "fast_gelu": common_spec.Activation.GELUTanh,
    "relu": common_spec.Activation.RELU,
    "silu": common_spec.Activation.SWISH,
}

_SUPPORTED_FEATURES_MERGE = {
    "concat": common_spec.EmbeddingsMerge.CONCAT,
    "sum": common_spec.EmbeddingsMerge.ADD,
}


def check_opt(opt, num_source_embeddings):
    with_relative_position = getattr(opt, "max_relative_positions", 0) > 0
    with_rotary = getattr(opt, "max_relative_positions", 0) == -1
    with_alibi = getattr(opt, "max_relative_positions", 0) == -2
    activation_fn = getattr(opt, "pos_ffn_activation_fn", "relu")
    feat_merge = getattr(opt, "feat_merge", "concat")
    self_attn_type = getattr(opt, "self_attn_type", "scaled-dot")

    check = utils.ConfigurationChecker()
    check(
        opt.encoder_type == opt.decoder_type
        and opt.decoder_type in {"transformer", "transformer_lm"},
        "Options --encoder_type and --decoder_type must be"
        " 'transformer' or 'transformer_lm'",
    )
    check(
        self_attn_type == "scaled-dot",
        "Option --self_attn_type %s is not supported (supported values are: scaled-dot)"
        % self_attn_type,
    )
    check(
        activation_fn in _SUPPORTED_ACTIVATIONS,
        "Option --pos_ffn_activation_fn %s is not supported (supported activations are: %s)"
        % (activation_fn, ", ".join(_SUPPORTED_ACTIVATIONS.keys())),
    )
    check(
        opt.position_encoding != (with_relative_position or with_rotary or with_alibi),
        "Options --position_encoding and --max_relative_positions cannot be both enabled "
        "or both disabled",
    )
    check(
        num_source_embeddings == 1 or feat_merge in _SUPPORTED_FEATURES_MERGE,
        "Option --feat_merge %s is not supported (supported merge modes are: %s)"
        % (feat_merge, " ".join(_SUPPORTED_FEATURES_MERGE.keys())),
    )
    check.validate()


def _get_model_spec_seq2seq(
    opt, variables, src_vocabs, tgt_vocabs, num_source_embeddings
):
    """Creates a model specification from the model options."""
    with_relative_position = getattr(opt, "max_relative_positions", 0) > 0
    activation_fn = getattr(opt, "pos_ffn_activation_fn", "relu")
    feat_merge = getattr(opt, "feat_merge", "concat")

    # Return the first head of the last layer unless the model was trained with alignments.
    if getattr(opt, "lambda_align", 0) == 0:
        alignment_layer = -1
        alignment_heads = 1
    else:
        alignment_layer = opt.alignment_layer
        alignment_heads = opt.alignment_heads

    num_heads = getattr(opt, "heads", 8)

    model_spec = transformer_spec.TransformerSpec.from_config(
        (opt.enc_layers, opt.dec_layers),
        num_heads,
        with_relative_position=with_relative_position,
        activation=_SUPPORTED_ACTIVATIONS[activation_fn],
        alignment_layer=alignment_layer,
        alignment_heads=alignment_heads,
        num_source_embeddings=num_source_embeddings,
        embeddings_merge=_SUPPORTED_FEATURES_MERGE[feat_merge],
        multi_query_attention=getattr(opt, "multiquery", False),
    )

    model_spec.config.decoder_start_token = getattr(opt, "decoder_start_token", "<s>")

    set_transformer_spec(model_spec, variables)
    for src_vocab in src_vocabs:
        model_spec.register_source_vocabulary(src_vocab)
    for tgt_vocab in tgt_vocabs:
        model_spec.register_target_vocabulary(tgt_vocab)

    return model_spec


def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embeddings):
    """Creates a model specification from the model options."""
    with_relative_position = getattr(opt, "max_relative_positions", 0) > 0
    with_rotary = getattr(opt, "max_relative_positions", 0) == -1
    with_alibi = getattr(opt, "max_relative_positions", 0) == -2
    activation_fn = getattr(opt, "pos_ffn_activation_fn", "relu")
    num_heads = getattr(opt, "heads", 8)
    num_kv = getattr(opt, "num_kv", 0)
    if num_kv == num_heads or num_kv == 0:
        num_kv = None
    rotary_dim = 0 if with_rotary else None
    rotary_interleave = getattr(opt, "rotary_interleave", True)
    ffn_glu = activation_fn == "silu"
    sliding_window = getattr(opt, "sliding_window", 0)

    model_spec = transformer_spec.TransformerDecoderModelSpec.from_config(
        opt.dec_layers,
        num_heads,
        activation=_SUPPORTED_ACTIVATIONS[activation_fn],
        ffn_glu=ffn_glu,
        with_relative_position=with_relative_position,
        alibi=with_alibi,
        rms_norm=opt.layer_norm == "rms",
        rotary_dim=rotary_dim,
        rotary_interleave=rotary_interleave,
        multi_query_attention=getattr(opt, "multiquery", False),
        num_heads_kv=num_kv,
        sliding_window=sliding_window,
    )

    model_spec.config.layer_norm_epsilon = getattr(opt, "norm_eps", 1e-6)

    set_transformer_decoder(
        model_spec.decoder,
        variables,
        with_encoder_attention=False,
    )

    for tgt_vocab in tgt_vocabs:
        model_spec.register_vocabulary(tgt_vocab)

    return model_spec


def get_vocabs(vocab):
    if isinstance(vocab, dict) and "src" in vocab:
        if isinstance(vocab["src"], list):
            src_vocabs = [vocab["src"]]
            tgt_vocabs = [vocab["tgt"]]

            src_feats = vocab.get("src_feats")
            if src_feats is not None:
                src_vocabs.extend(src_feats.values())
        else:
            src_vocabs = [field[1].vocab.itos for field in vocab["src"].fields]
            tgt_vocabs = [field[1].vocab.itos for field in vocab["tgt"].fields]
    else:
        # Compatibility with older models.
        src_vocabs = [vocab[0][1].itos]
        tgt_vocabs = [vocab[1][1].itos]

    return src_vocabs, tgt_vocabs


class OpenNMTPyConverter(Converter):
    """Converts models generated by OpenNMT-py."""

    def __init__(self, model_path: str):
        """Initializes the OpenNMT-py converter.

        Arguments:
          model_path: Path to the OpenNMT-py PyTorch model (.pt file).
        """
        self._model_path = model_path

    def _load(self):
        import torch

        checkpoint = torch.load(
            self._model_path, map_location="cpu", weights_only=False
        )

        src_vocabs, tgt_vocabs = get_vocabs(checkpoint["vocab"])

        check_opt(checkpoint["opt"], num_source_embeddings=len(src_vocabs))

        variables = checkpoint["model"]
        variables.update(
            {
                "generator.%s" % key: value
                for key, value in checkpoint["generator"].items()
            }
        )

        if checkpoint["opt"].decoder_type == "transformer_lm":
            return _get_model_spec_lm(
                checkpoint["opt"],
                variables,
                src_vocabs,
                tgt_vocabs,
                num_source_embeddings=len(src_vocabs),
            )
        else:
            return _get_model_spec_seq2seq(
                checkpoint["opt"],
                variables,
                src_vocabs,
                tgt_vocabs,
                num_source_embeddings=len(src_vocabs),
            )


def set_transformer_spec(spec, variables):
    set_transformer_encoder(spec.encoder, variables)
    set_transformer_decoder(spec.decoder, variables)


def set_transformer_encoder(spec, variables):
    set_input_layers(spec, variables, "encoder")
    set_layer_norm(spec.layer_norm, variables, "encoder.layer_norm")
    for i, layer in enumerate(spec.layer):
        set_transformer_encoder_layer(layer, variables, "encoder.transformer.%d" % i)


def set_transformer_decoder(spec, variables, with_encoder_attention=True):
    set_input_layers(spec, variables, "decoder")
    set_layer_norm(spec.layer_norm, variables, "decoder.layer_norm")
    for i, layer in enumerate(spec.layer):
        set_transformer_decoder_layer(
            layer,
            variables,
            "decoder.transformer_layers.%d" % i,
            with_encoder_attention=with_encoder_attention,
        )

    try:
        set_linear(spec.projection, variables, "generator")
    except KeyError:
        # Compatibility when the generator was a nn.Sequential module.
        set_linear(spec.projection, variables, "generator.0")


def set_input_layers(spec, variables, scope):
    if hasattr(spec, "position_encodings"):
        set_position_encodings(
            spec.position_encodings,
            variables,
            "%s.embeddings.make_embedding.pe" % scope,
        )
    else:
        # See https://github.com/OpenNMT/OpenNMT-py/issues/1722
        spec.scale_embeddings = False

    embeddings_specs = spec.embeddings
    if not isinstance(embeddings_specs, list):
        embeddings_specs = [embeddings_specs]

    for i, embeddings_spec in enumerate(embeddings_specs):
        set_embeddings(
            embeddings_spec,
            variables,
            "%s.embeddings.make_embedding.emb_luts.%d" % (scope, i),
        )


def set_transformer_encoder_layer(spec, variables, scope):
    set_ffn(spec.ffn, variables, "%s.feed_forward" % scope)
    set_multi_head_attention(
        spec.self_attention,
        variables,
        "%s.self_attn" % scope,
        self_attention=True,
    )
    set_layer_norm(spec.self_attention.layer_norm, variables, "%s.layer_norm" % scope)


def set_transformer_decoder_layer(spec, variables, scope, with_encoder_attention=True):
    set_ffn(spec.ffn, variables, "%s.feed_forward" % scope)
    set_multi_head_attention(
        spec.self_attention,
        variables,
        "%s.self_attn" % scope,
        self_attention=True,
    )
    set_layer_norm(spec.self_attention.layer_norm, variables, "%s.layer_norm_1" % scope)
    if with_encoder_attention:
        set_multi_head_attention(spec.attention, variables, "%s.context_attn" % scope)
        set_layer_norm(spec.attention.layer_norm, variables, "%s.layer_norm_2" % scope)


def set_ffn(spec, variables, scope):
    set_layer_norm(spec.layer_norm, variables, "%s.layer_norm" % scope)
    set_linear(spec.linear_0, variables, "%s.w_1" % scope)
    set_linear(spec.linear_1, variables, "%s.w_2" % scope)
    if hasattr(spec, "linear_0_noact"):
        set_linear(spec.linear_0_noact, variables, "%s.w_3" % scope)


def set_multi_head_attention(spec, variables, scope, self_attention=False):
    if self_attention:
        split_layers = [common_spec.LinearSpec() for _ in range(3)]
        set_linear(split_layers[0], variables, "%s.linear_query" % scope)
        set_linear(split_layers[1], variables, "%s.linear_keys" % scope)
        set_linear(split_layers[2], variables, "%s.linear_values" % scope)
        utils.fuse_linear(spec.linear[0], split_layers)
    else:
        set_linear(spec.linear[0], variables, "%s.linear_query" % scope)
        split_layers = [common_spec.LinearSpec() for _ in range(2)]
        set_linear(split_layers[0], variables, "%s.linear_keys" % scope)
        set_linear(split_layers[1], variables, "%s.linear_values" % scope)
        utils.fuse_linear(spec.linear[1], split_layers)
    set_linear(spec.linear[-1], variables, "%s.final_linear" % scope)
    if hasattr(spec, "relative_position_keys"):
        spec.relative_position_keys = _get_variable(
            variables, "%s.relative_positions_embeddings.weight" % scope
        )
        spec.relative_position_values = spec.relative_position_keys


def set_layer_norm(spec, variables, scope):
    try:
        spec.gamma = _get_variable(variables, "%s.weight" % scope)
    except KeyError:
        # Compatibility with older models using a custom LayerNorm module.
        spec.gamma = _get_variable(variables, "%s.a_2" % scope)
        spec.beta = _get_variable(variables, "%s.b_2" % scope)
    try:
        spec.beta = _get_variable(variables, "%s.bias" % scope)
    except KeyError:
        pass


def set_linear(spec, variables, scope):
    spec.weight = _get_variable(variables, "%s.weight" % scope)
    bias = variables.get("%s.bias" % scope)
    if bias is not None:
        spec.bias = bias


def set_embeddings(spec, variables, scope):
    spec.weight = _get_variable(variables, "%s.weight" % scope)


def set_position_encodings(spec, variables, scope):
    spec.encodings = _get_variable(variables, "%s.pe" % scope).squeeze()


def _get_variable(variables, name):
    return variables[name]


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--model_path", required=True, help="Model path.")
    Converter.declare_arguments(parser)
    args = parser.parse_args()
    OpenNMTPyConverter(args.model_path).convert_from_args(args)


if __name__ == "__main__":
    main()
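To close the loop, a sketch of converting an OpenNMT-py checkpoint and running the result with the CTranslate2 runtime. The checkpoint and output names are hypothetical, and the input must already be tokenized into the model's source vocabulary (SentencePiece-style pieces shown here):

import ctranslate2

from ctranslate2.converters.opennmt_py import OpenNMTPyConverter

# "averaged-10-epoch.pt" and "ende_ct2" are hypothetical names.
OpenNMTPyConverter("averaged-10-epoch.pt").convert("ende_ct2", quantization="int8")

# Load and run the converted model with the runtime API.
translator = ctranslate2.Translator("ende_ct2", device="cpu")
results = translator.translate_batch([["▁Hello", "▁world"]])
print(results[0].hypotheses[0])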