ctranslate2 4.7.0__cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ctranslate2/.dylibs/libctranslate2.4.7.0.dylib +0 -0
- ctranslate2/__init__.py +66 -0
- ctranslate2/_ext.cpython-314-darwin.so +0 -0
- ctranslate2/converters/__init__.py +8 -0
- ctranslate2/converters/converter.py +109 -0
- ctranslate2/converters/eole_ct2.py +353 -0
- ctranslate2/converters/fairseq.py +347 -0
- ctranslate2/converters/marian.py +315 -0
- ctranslate2/converters/openai_gpt2.py +95 -0
- ctranslate2/converters/opennmt_py.py +361 -0
- ctranslate2/converters/opennmt_tf.py +455 -0
- ctranslate2/converters/opus_mt.py +44 -0
- ctranslate2/converters/transformers.py +3721 -0
- ctranslate2/converters/utils.py +127 -0
- ctranslate2/extensions.py +589 -0
- ctranslate2/logging.py +45 -0
- ctranslate2/models/__init__.py +18 -0
- ctranslate2/specs/__init__.py +18 -0
- ctranslate2/specs/attention_spec.py +98 -0
- ctranslate2/specs/common_spec.py +66 -0
- ctranslate2/specs/model_spec.py +767 -0
- ctranslate2/specs/transformer_spec.py +797 -0
- ctranslate2/specs/wav2vec2_spec.py +72 -0
- ctranslate2/specs/wav2vec2bert_spec.py +97 -0
- ctranslate2/specs/whisper_spec.py +77 -0
- ctranslate2/version.py +3 -0
- ctranslate2-4.7.0.dist-info/METADATA +180 -0
- ctranslate2-4.7.0.dist-info/RECORD +31 -0
- ctranslate2-4.7.0.dist-info/WHEEL +6 -0
- ctranslate2-4.7.0.dist-info/entry_points.txt +8 -0
- ctranslate2-4.7.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def fuse_linear(spec, layers):
    """Fuses several linear layers into a single linear spec.

    The layer weights are concatenated along the output dimension. If at
    least one layer has a bias, a fused bias is also built, padding the
    bias-less layers with zeros of the matching dtype.

    Args:
      spec: Target spec object receiving the fused ``weight`` (and possibly
        ``bias``) attributes.
      layers: Non-empty sequence of layer specs exposing ``weight``,
        ``has_bias()`` and ``bias``.

    Raises:
      ValueError: If ``layers`` is empty.
    """
    if not layers:
        raise ValueError("Cannot fuse linear layers: at least one layer is required")

    # Weights are either all NumPy arrays or all Torch tensors; pick the
    # matching concatenation/zeros functions based on the first layer.
    if isinstance(layers[0].weight, np.ndarray):
        concatenate, zeros = np.concatenate, np.zeros
    else:
        import torch

        concatenate, zeros = torch.cat, torch.zeros

    spec.weight = concatenate([layer.weight for layer in layers])

    # dtype of the first bias found, or None when no layer has a bias.
    bias_dtype = next(
        (layer.bias.dtype for layer in layers if layer.has_bias()), None
    )

    if bias_dtype is not None:
        fused_bias = []
        for layer in layers:
            if layer.has_bias():
                fused_bias.append(layer.bias)
            else:
                # Pad with zeros so every layer contributes one bias slot
                # per output row.
                fused_bias.append(zeros([layer.weight.shape[0]], dtype=bias_dtype))
        spec.bias = concatenate(fused_bias)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def fuse_linear_prequant(spec, layers, axis):
    """Fuses pre-quantized linear layers into a single linear spec.

    Concatenates the quantized weights and their quantization parameters
    (``weight_scale``, ``weight_zero``) along ``axis`` and stores the
    results on ``spec`` under the same attribute names.

    Args:
      spec: Target spec object receiving the fused attributes.
      layers: Non-empty sequence of layer specs exposing ``weight``,
        ``weight_scale`` and ``weight_zero``.
      axis: Concatenation axis.

    Raises:
      ValueError: If ``layers`` is empty.
    """
    if not layers:
        raise ValueError("Cannot fuse linear layers: at least one layer is required")

    # Weights are either all NumPy arrays or all Torch tensors.
    if isinstance(layers[0].weight, np.ndarray):
        concatenate = np.concatenate
    else:
        import torch

        concatenate = torch.cat

    for name in ("weight", "weight_scale", "weight_zero"):
        fused = concatenate([getattr(layer, name) for layer in layers], axis=axis)
        setattr(spec, name, fused)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def permute_for_sliced_rotary(weight, num_heads, rotary_dim=None):
    """Permutes the weight to use the sliced rotary implementation.

    Without ``rotary_dim``, every head's rows are deinterleaved: the rows
    are viewed as (pairs, 2) per head and the two axes are swapped.
    With ``rotary_dim``, only the first ``rotary_dim`` rows of each head
    are permuted (written back in place into the reshaped view) and the
    remaining rows are left untouched.

    Args:
      weight: 2D weight matrix of shape (num_heads * head_dim, hidden).
      num_heads: Number of attention heads.
      rotary_dim: Optional number of rotary rows per head.

    Returns:
      The permuted 2D weight matrix.
    """
    if rotary_dim is None:
        rows = weight.shape[0]
        cols = weight.shape[1]
        pairs_per_head = rows // num_heads // 2
        return (
            weight.reshape(num_heads, pairs_per_head, 2, cols)
            .swapaxes(1, 2)
            .reshape(rows, cols)
        )

    # View the weight per head so the rotary slice can be addressed.
    per_head = weight.reshape(num_heads, weight.shape[0] // num_heads, -1)

    sliced = permute_for_sliced_rotary(
        per_head[:, :rotary_dim].reshape(num_heads * rotary_dim, -1), num_heads
    )

    # Write the permuted rotary rows back in place; non-rotary rows keep
    # their original order.
    per_head[:, :rotary_dim] = sliced.reshape(num_heads, rotary_dim, -1)

    return per_head.reshape(-1, per_head.shape[-1])
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def smooth_activation(layer_norm, linear, activation_scales):
    """Applies the activation smoothing technique described in
    https://github.com/mit-han-lab/smoothquant.

    The per-channel smoothing factor is folded into the preceding layer
    norm (divided out of gamma/beta) and into the linear weight
    (multiplied per input column), all updated in place.
    """
    is_numpy = isinstance(linear.weight, np.ndarray)
    if is_numpy:
        weight_np = linear.weight
    else:
        # Torch tensors: do the scale math in NumPy, convert back below.
        weight_np = linear.weight.numpy()
        activation_scales = activation_scales.numpy()

    # Per input-channel maximum absolute weight, floored to avoid
    # division by (near) zero.
    per_column_max = np.maximum(np.amax(np.absolute(weight_np), axis=0), 1e-5)

    smoothing = np.sqrt(activation_scales.astype(per_column_max.dtype) / per_column_max)
    smoothing = np.maximum(smoothing, 1e-5)

    if not is_numpy:
        import torch

        smoothing = torch.from_numpy(smoothing)

    # In-place updates: the smoothing factor moves from the activations
    # (via the layer norm parameters) into the linear weight.
    layer_norm.gamma /= smoothing
    layer_norm.beta /= smoothing

    linear.weight *= smoothing.reshape(1, -1)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def raise_unsupported(reasons):
    """Raises a ValueError describing why the model cannot be converted.

    Args:
      reasons: Iterable of human-readable reason strings; each is rendered
        as a "- reason" bullet on its own line.

    Raises:
      ValueError: Always, with the generic header followed by the bullets.
    """
    message = (
        "The model you are trying to convert is not supported by CTranslate2. "
        "We identified the following reasons:\n"
    )
    # Build the bullet list with a single join instead of repeated
    # string concatenation in a loop.
    message += "".join("\n- " + reason for reason in reasons)
    raise ValueError(message)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class ConfigurationChecker:
    """Accumulates unsupported-configuration errors and reports them in one batch.

    Each call records an error message when its condition is falsy;
    ``validate()`` raises once with every recorded reason so the user sees
    all problems at the same time.
    """

    def __init__(self):
        # Messages for every failed check, in call order.
        self._unsupported_reasons = []

    def __call__(self, assert_condition, error_message):
        """Records ``error_message`` when ``assert_condition`` is falsy."""
        if assert_condition:
            return
        self._unsupported_reasons.append(error_message)

    def validate(self):
        """Raises a single ValueError listing every recorded reason, if any."""
        if self._unsupported_reasons:
            raise_unsupported(self._unsupported_reasons)
|