ctranslate2-4.6.1-cp314-cp314-win_amd64.whl → ctranslate2-4.6.3-cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ctranslate2/__init__.py +11 -3
- ctranslate2/_ext.cp314-win_amd64.pyd +0 -0
- ctranslate2/converters/fairseq.py +3 -1
- ctranslate2/converters/opennmt_py.py +3 -1
- ctranslate2/converters/transformers.py +769 -60
- ctranslate2/ctranslate2.dll +0 -0
- ctranslate2/cudnn64_9.dll +0 -0
- ctranslate2/extensions.py +17 -13
- ctranslate2/specs/attention_spec.py +9 -1
- ctranslate2/specs/transformer_spec.py +98 -8
- ctranslate2/version.py +1 -1
- {ctranslate2-4.6.1.dist-info → ctranslate2-4.6.3.dist-info}/METADATA +14 -3
- {ctranslate2-4.6.1.dist-info → ctranslate2-4.6.3.dist-info}/RECORD +16 -16
- {ctranslate2-4.6.1.dist-info → ctranslate2-4.6.3.dist-info}/WHEEL +0 -0
- {ctranslate2-4.6.1.dist-info → ctranslate2-4.6.3.dist-info}/entry_points.txt +0 -0
- {ctranslate2-4.6.1.dist-info → ctranslate2-4.6.3.dist-info}/top_level.txt +0 -0
ctranslate2/converters/transformers.py

@@ -89,7 +89,7 @@ class TransformersConverter(Converter):
           copy_files: List of filenames to copy from the Hugging Face model to the
             converted model directory.
           load_as_float16: Load the model weights as float16. More precisely, the model
-            will be loaded with ``from_pretrained(...,
+            will be loaded with ``from_pretrained(..., dtype=torch.float16)``.
           revision: Revision of the model to download from the Hugging Face Hub.
           low_cpu_mem_usage: Enable the flag ``low_cpu_mem_usage`` when loading the model
             with ``from_pretrained``.
@@ -123,10 +123,11 @@ class TransformersConverter(Converter):
             tokenizer_class = transformers.AutoTokenizer

             kwargs = {
-                "
+                "dtype": (
                     torch.float16
                     if self._load_as_float16
-                    else getattr(config, "
+                    else getattr(config, "dtype", None)
+                    or getattr(config, "torch_dtype", None)
                 )
             }

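The two hunks above move the converter from the `torch_dtype` keyword to the newer `dtype` keyword while still honoring a checkpoint's legacy `torch_dtype` config field. A minimal sketch of that resolution order, assuming only that `torch` is installed; `resolve_dtype` and `cfg` are illustrative names, not part of the package:

    import torch

    def resolve_dtype(cfg, load_as_float16):
        # float16 wins when requested explicitly; otherwise prefer the new
        # `dtype` config field and fall back to the legacy `torch_dtype`.
        return (
            torch.float16
            if load_as_float16
            else getattr(cfg, "dtype", None) or getattr(cfg, "torch_dtype", None)
        )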
@@ -235,7 +236,7 @@ class ModelLoader(abc.ABC):

         if isinstance(module, transformers.Conv1D):
             spec.weight = spec.weight.transpose(0, 1)
-        if module.bias is not None:
+        if hasattr(module, "bias") and module.bias is not None:
             spec.bias = module.bias

     def set_embeddings(self, spec, module):
@@ -1819,6 +1820,192 @@ class LlamaLoader(ModelLoader):
             gc.collect()


+@register_loader("Gemma3TextConfig")
+@register_loader("Gemma3Config")
+class Gemma3Loader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "Gemma3ForCausalLM"
+
+    def get_model_spec(self, model):
+        num_layers = model.config.num_hidden_layers
+        num_heads = model.config.num_attention_heads
+        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
+        if num_heads_kv == num_heads:
+            num_heads_kv = None
+
+        head_dim = model.config.head_dim
+
+        activation_config = getattr(
+            model.config, "hidden_activation", "gelu_pytorch_tanh"
+        )
+
+        # Get RoPE parameters
+        rope_theta = getattr(model.config, "rope_theta", 1_000_000)  # Global: 1M
+        rope_local_base_freq = getattr(
+            model.config, "rope_local_base_freq", 10_000
+        )  # Local: 10k
+
+        # Get sliding window configuration
+        sliding_window = getattr(model.config, "sliding_window", 1024)
+        layer_types = getattr(model.config, "layer_types", None)
+
+        quantization_config = getattr(model.config, "quantization_config", None)
+        if quantization_config:
+            if quantization_config.quant_method == "awq":
+                quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
+                if quant_type is None:
+                    raise NotImplementedError(
+                        "Quantization type '%s' is not yet implemented."
+                        % quantization_config.quant_method
+                    )
+        else:
+            quant_type = common_spec.Quantization.CT2
+
+        # Create base spec using from_config
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers,
+            num_heads,
+            activation=(
+                common_spec.Activation.GELU
+                if activation_config == "gelu"
+                else common_spec.Activation.GELUTanh
+            ),
+            pre_norm=True,
+            ffn_glu=True,
+            rms_norm=True,
+            rotary_dim=head_dim,
+            rotary_interleave=False,
+            rotary_base=rope_local_base_freq,  # Default to local base freq
+            num_heads_kv=num_heads_kv,
+            head_dim=head_dim,
+            sliding_window=sliding_window,  # Default to local sliding window
+            pre_post_layer_norm=True,
+            qk_norm=True,
+        )
+
+        # Store layer_types for use in set_decoder
+        self._layer_types = layer_types
+
+        # Override per-layer settings for global vs local attention
+        for i, layer_type in enumerate(layer_types):
+            layer = spec.decoder.layer[i]
+            if layer_type == "full_attention":
+                layer.self_attention.rotary_base = np.dtype("float32").type(rope_theta)
+                layer.self_attention.sliding_window = np.dtype("int32").type(0)
+            elif layer_type == "sliding_attention":
+                layer.self_attention.rotary_base = np.dtype("float32").type(
+                    rope_local_base_freq
+                )
+                layer.self_attention.sliding_window = np.dtype("int32").type(
+                    sliding_window
+                )
+
+        self.set_decoder(spec.decoder, model.model, quant_type)
+        self.set_linear(spec.decoder.projection, model.lm_head)
+        return spec
+
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+        if model.config.vocab_size < len(tokens):
+            tokens = tokens[: model.config.vocab_size]
+
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = tokenizer.bos_token
+        config.unk_token = tokenizer.unk_token
+
+        if (
+            hasattr(tokenizer, "chat_template")
+            and isinstance(tokenizer.chat_template, str)
+            and tokenizer.chat_template.strip()
+        ):
+            config.eos_token = "<end_of_turn>"
+        else:
+            config.eos_token = tokenizer.eos_token
+
+    def set_layer_norm(self, spec, layer_norm):
+        spec.gamma = layer_norm.weight + 1.0
+
+    def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
+        spec.scale_embeddings = True
+        spec.start_from_zero_embedding = False
+        self.set_embeddings(spec.embeddings, module.embed_tokens)  # Input
+        self.set_layer_norm(spec.layer_norm, module.norm)  # Output
+
+        for layer_spec, layer in zip(spec.layer, module.layers):
+            self.set_layer_norm(layer_spec.input_layer_norm, layer.input_layernorm)
+
+            self.set_layer_norm(
+                layer_spec.post_attention_layer_norm, layer.post_attention_layernorm
+            )
+
+            self.set_layer_norm(
+                layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
+            )
+
+            self.set_layer_norm(
+                layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
+            )
+
+            # Set QK-norm weights (Gemma 3 uses this instead of soft-capping)
+            self.set_layer_norm(
+                layer_spec.self_attention.q_norm, layer.self_attn.q_norm
+            )
+            self.set_layer_norm(
+                layer_spec.self_attention.k_norm, layer.self_attn.k_norm
+            )
+
+            # Set attention projections
+            split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(
+                split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+            )
+
+            if quant_type == common_spec.Quantization.CT2:
+                utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+            else:
+                cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
+                utils.fuse_linear_prequant(
+                    layer_spec.self_attention.linear[0], split_layers, cc_dim
+                )
+
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+                quant_type=quant_type,
+            )
+
+            # Set FFN weights
+            self.set_linear(
+                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )
+
+            delattr(layer, "self_attn")
+            delattr(layer, "mlp")
+            gc.collect()
+
+
 @register_loader("MistralConfig")
 class MistralLoader(ModelLoader):
     @property
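With `Gemma3Loader` registered for both `Gemma3Config` and `Gemma3TextConfig`, Gemma 3 checkpoints can be converted through the usual converter API. A hedged sketch; the model id is only an example and the output directory name is arbitrary:

    from ctranslate2.converters import TransformersConverter

    # Example model id; any Gemma3ForCausalLM checkpoint should follow the same path.
    converter = TransformersConverter("google/gemma-3-1b-it", load_as_float16=True)
    converter.convert("gemma3_ct2", quantization="float16", force=True)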
@@ -1996,6 +2183,28 @@ class Qwen2Loader(ModelLoader):
             rotary_scaling_type = None
             rotary_scaling_factor = 1

+        # Check for AWQ quantization config
+        quantization_config = getattr(model.config, "quantization_config", None)
+        if quantization_config:
+            quant_type = None
+            if quantization_config.quant_method == "awq":
+                quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
+            if quant_type is None:
+                raise NotImplementedError(
+                    "Quantization type '%s' is not yet implemented. "
+                    "The following Quantization types are currently supported: %s"
+                    % (
+                        quantization_config.quant_method,
+                        ", ".join(_SUPPORTED_QUANTIZATION.keys()),
+                    )
+                )
+            quant_group_size = quantization_config.group_size
+            quant_bits = quantization_config.bits
+        else:
+            quant_type = common_spec.Quantization.CT2
+            quant_group_size = None
+            quant_bits = None
+
         spec = transformer_spec.TransformerDecoderModelSpec.from_config(
             num_layers,
             num_heads,
@@ -2009,9 +2218,12 @@ class Qwen2Loader(ModelLoader):
             rotary_scaling_factor=rotary_scaling_factor,
             rotary_base=getattr(model.config, "rope_theta", 10000),
             num_heads_kv=num_heads_kv,
+            quant_type=quant_type,
+            quant_group_size=quant_group_size,
+            quant_bits=quant_bits,
         )

-        self.set_decoder(spec.decoder, model.model)
+        self.set_decoder(spec.decoder, model.model, quant_type)
         self.set_linear(spec.decoder.projection, model.lm_head)
         return spec

@@ -2041,7 +2253,7 @@ class Qwen2Loader(ModelLoader):
     def set_layer_norm(self, spec, layer_norm):
         spec.gamma = layer_norm.weight

-    def set_decoder(self, spec, module):
+    def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
         spec.scale_embeddings = False
         self.set_embeddings(spec.embeddings, module.embed_tokens)
         self.set_layer_norm(spec.layer_norm, module.norm)
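The three hunks above thread an AWQ checkpoint's `quantization_config` (method, version, `group_size`, `bits`) into the Qwen2 model spec and pass `quant_type` down to `set_decoder`, so pre-quantized weights are copied as-is instead of being re-quantized. A hedged sketch of the conversion call; the model id is only an example of an AWQ export:

    from ctranslate2.converters import TransformersConverter

    # The loader reads quant_method/version/group_size/bits from the checkpoint's
    # quantization_config, so no extra quantization argument is needed here.
    TransformersConverter("Qwen/Qwen2.5-7B-Instruct-AWQ").convert("qwen2_awq_ct2", force=True)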
@@ -2055,72 +2267,255 @@ class Qwen2Loader(ModelLoader):
             )

             split_layers = [common_spec.LinearSpec() for _ in range(3)]
-            self.set_linear(
-
-
+            self.set_linear(
+                split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+            )
+
+            if quant_type == common_spec.Quantization.CT2:
+                utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+            else:
+                cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
+                utils.fuse_linear_prequant(
+                    layer_spec.self_attention.linear[0], split_layers, cc_dim
+                )

-            utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
             self.set_linear(
                 layer_spec.self_attention.linear[1],
                 layer.self_attn.o_proj,
+                quant_type=quant_type,
             )

-            self.set_linear(
-
-
+            self.set_linear(
+                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )

             delattr(layer, "self_attn")
             delattr(layer, "mlp")
             gc.collect()


-@register_loader("
-class
+@register_loader("Qwen3Config")
+class Qwen3Loader(ModelLoader):
     @property
     def architecture_name(self):
-        return "
+        return "Qwen3ForCausalLM"

     def get_model_spec(self, model):
-
-
-
-
-
-            rotary_dim=model.config.rotary_dim,
-            rotary_interleave=False,
-            parallel_residual=True,
-            shared_layer_norm=True,
+        num_layers = model.config.num_hidden_layers
+        num_heads = model.config.num_attention_heads
+        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
+        head_dim = getattr(
+            model.config, "head_dim", model.config.hidden_size // num_heads
         )

-
-
-        return spec
-
-    def get_vocabulary(self, model, tokenizer):
-        tokens = super().get_vocabulary(model, tokenizer)
-
-        extra_ids = model.config.vocab_size - len(tokens)
-        for i in range(extra_ids):
-            tokens.append("<extra_id_%d>" % i)
-
-        return tokens
-
-    def set_vocabulary(self, spec, tokens):
-        spec.register_vocabulary(tokens)
-
-    def set_config(self, config, model, tokenizer):
-        config.bos_token = tokenizer.bos_token
-        config.eos_token = tokenizer.eos_token
-        config.unk_token = tokenizer.unk_token
+        if num_heads_kv == num_heads:
+            num_heads_kv = None

-
-
-
-
+        rope_scaling = getattr(model.config, "rope_scaling", None)
+        if rope_scaling:
+            rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
+            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
+            rotary_scaling_factor = rope_scaling["factor"]
+            if rotary_scaling_type is None:
+                raise NotImplementedError(
+                    "RoPE scaling type '%s' is not yet implemented. "
+                    "The following RoPE scaling types are currently supported: %s"
+                    % (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
+                )
+        else:
+            rotary_scaling_type = None
+            rotary_scaling_factor = 1

-
-
-
+        # Check for AWQ quantization config
+        quantization_config = getattr(model.config, "quantization_config", None)
+        if quantization_config:
+            quant_type = None
+            if quantization_config.quant_method == "awq":
+                quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
+            if quant_type is None:
+                raise NotImplementedError(
+                    "Quantization type '%s' is not yet implemented. "
+                    "The following Quantization types are currently supported: %s"
+                    % (
+                        quantization_config.quant_method,
+                        ", ".join(_SUPPORTED_QUANTIZATION.keys()),
+                    )
+                )
+            quant_group_size = quantization_config.group_size
+            quant_bits = quantization_config.bits
+        else:
+            quant_type = common_spec.Quantization.CT2
+            quant_group_size = None
+            quant_bits = None
+
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers,
+            num_heads,
+            activation=common_spec.Activation.SWISH,
+            pre_norm=True,
+            ffn_glu=True,
+            rms_norm=True,
+            rotary_dim=model.config.head_dim,
+            rotary_interleave=False,
+            rotary_scaling_type=rotary_scaling_type,
+            rotary_scaling_factor=rotary_scaling_factor,
+            rotary_base=getattr(model.config, "rope_theta", 10000),
+            num_heads_kv=num_heads_kv,
+            head_dim=head_dim,
+            qk_norm=True,
+            quant_type=quant_type,
+            quant_group_size=quant_group_size,
+            quant_bits=quant_bits,
+        )
+
+        self.set_decoder(spec.decoder, model.model, quant_type)
+        self.set_linear(spec.decoder.projection, model.lm_head)
+        return spec
+
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = (
+            tokenizer.bos_token
+            if tokenizer.bos_token is not None
+            else tokenizer.pad_token
+        )
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = (
+            tokenizer.unk_token if tokenizer.unk_token is not None else ""
+        )
+        config.layer_norm_epsilon = model.config.rms_norm_eps
+
+    def set_layer_norm(self, spec, layer_norm):
+        spec.gamma = layer_norm.weight
+
+    def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
+        spec.scale_embeddings = False
+        self.set_embeddings(spec.embeddings, module.embed_tokens)
+        self.set_layer_norm(spec.layer_norm, module.norm)
+
+        for layer_idx, (layer_spec, layer) in enumerate(zip(spec.layer, module.layers)):
+            self.set_layer_norm(
+                layer_spec.self_attention.layer_norm, layer.input_layernorm
+            )
+            self.set_layer_norm(
+                layer_spec.ffn.layer_norm, layer.post_attention_layernorm
+            )
+
+            self.set_layer_norm(
+                layer_spec.self_attention.q_norm, layer.self_attn.q_norm
+            )
+            self.set_layer_norm(
+                layer_spec.self_attention.k_norm, layer.self_attn.k_norm
+            )
+
+            split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(
+                split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+            )
+
+            if quant_type == common_spec.Quantization.CT2:
+                utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+            else:
+                cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
+                utils.fuse_linear_prequant(
+                    layer_spec.self_attention.linear[0], split_layers, cc_dim
+                )
+
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+                quant_type=quant_type,
+            )
+
+            self.set_linear(
+                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )
+
+            delattr(layer, "self_attn")
+            delattr(layer, "mlp")
+            gc.collect()
+
+
+@register_loader("MixFormerSequentialConfig")
+class MixFormerSequentialLoader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "AutoModelForCausalLM"
+
+    def get_model_spec(self, model):
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers=model.config.n_layer,
+            num_heads=model.config.n_head,
+            pre_norm=True,
+            activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function],
+            rotary_dim=model.config.rotary_dim,
+            rotary_interleave=False,
+            parallel_residual=True,
+            shared_layer_norm=True,
+        )
+
+        self.set_decoder(spec.decoder, model.layers)
+        self.set_linear(spec.decoder.projection, model.layers[-1].linear)
+        return spec
+
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = tokenizer.bos_token
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = tokenizer.unk_token
+
+    def set_decoder(self, spec, module):
+        spec.scale_embeddings = False
+        self.set_embeddings(spec.embeddings, module[0].wte)
+        self.set_layer_norm(spec.layer_norm, module[-1].ln)
+
+        for layer_spec, layer in zip(spec.layer, module[1:-1]):
+            self.set_layer_norm(layer_spec.shared_layer_norm, layer.ln)
+            self.set_linear(layer_spec.self_attention.linear[0], layer.mixer.Wqkv)
             self.set_linear(layer_spec.self_attention.linear[1], layer.mixer.out_proj)
             self.set_linear(layer_spec.ffn.linear_0, layer.mlp.fc1)
             self.set_linear(layer_spec.ffn.linear_1, layer.mlp.fc2)
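This hunk rewrites `Qwen2Loader.set_decoder` to route `quant_type` through every projection, adds a `Qwen3Loader` (registered for `Qwen3Config`, with QK-norm and an explicit `head_dim`), and moves `MixFormerSequentialLoader` below it. A hedged end-to-end sketch for Qwen3; the model id is only an example:

    import ctranslate2
    import transformers
    from ctranslate2.converters import TransformersConverter

    TransformersConverter("Qwen/Qwen3-0.6B", load_as_float16=True).convert(
        "qwen3_ct2", force=True
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
    generator = ctranslate2.Generator("qwen3_ct2", device="cpu")
    prompt_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode("Hello, world"))
    result = generator.generate_batch([prompt_tokens], max_length=32)
    print(tokenizer.decode(tokenizer.convert_tokens_to_ids(result[0].sequences[0])))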
@@ -2211,6 +2606,28 @@ class Phi3Loader(ModelLoader):
             rotary_scaling_type = None
             rotary_scaling_factor = 1

+        # Check for AWQ quantization config
+        quantization_config = getattr(model.config, "quantization_config", None)
+        if quantization_config:
+            quant_type = None
+            if quantization_config.quant_method == "awq":
+                quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
+            if quant_type is None:
+                raise NotImplementedError(
+                    "Quantization type '%s' is not yet implemented. "
+                    "The following Quantization types are currently supported: %s"
+                    % (
+                        quantization_config.quant_method,
+                        ", ".join(_SUPPORTED_QUANTIZATION.keys()),
+                    )
+                )
+            quant_group_size = quantization_config.group_size
+            quant_bits = quantization_config.bits
+        else:
+            quant_type = common_spec.Quantization.CT2
+            quant_group_size = None
+            quant_bits = None
+
         spec = transformer_spec.TransformerDecoderModelSpec.from_config(
             num_layers,
             num_heads,
@@ -2226,9 +2643,12 @@ class Phi3Loader(ModelLoader):
             original_max_position_embeddings=original_max_position_embeddings,
             max_position_embeddings=max_position_embeddings,
             num_heads_kv=num_heads_kv,
+            quant_type=quant_type,
+            quant_group_size=quant_group_size,
+            quant_bits=quant_bits,
         )

-        self.set_decoder(spec.decoder, model.model)
+        self.set_decoder(spec.decoder, model.model, quant_type)
         self.set_linear(spec.decoder.projection, model.lm_head)
         return spec

@@ -2262,7 +2682,7 @@ class Phi3Loader(ModelLoader):
             rotary_scaling_short_factor, dtype=torch.float32
         )

-    def set_decoder(self, spec, module):
+    def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
         spec.scale_embeddings = False
         self.set_embeddings(spec.embeddings, module.embed_tokens)
         self.set_layer_norm(spec.layer_norm, module.norm)
@@ -2276,9 +2696,15 @@ class Phi3Loader(ModelLoader):
             )

             self.set_linear(
-                layer_spec.self_attention.linear[0],
+                layer_spec.self_attention.linear[0],
+                layer.self_attn.qkv_proj,
+                quant_type=quant_type,
+            )
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+                quant_type=quant_type,
             )
-            self.set_linear(layer_spec.self_attention.linear[1], layer.self_attn.o_proj)
             if (
                 layer.self_attn.rotary_emb.long_factor is not None
                 and layer.self_attn.rotary_emb.short_factor is not None
@@ -2289,10 +2715,30 @@ class Phi3Loader(ModelLoader):
                     layer.self_attn.rotary_emb.short_factor,
                 )

-
-
-
-
+            # Handle gate_up_proj differently for AWQ vs regular models
+            if quant_type == common_spec.Quantization.CT2:
+                gate_proj, up_proj = layer.mlp.gate_up_proj.weight.chunk(2, dim=0)
+                layer_spec.ffn.linear_0.weight = gate_proj
+                layer_spec.ffn.linear_0_noact.weight = up_proj
+            else:
+                # AWQ: chunk qweight, scales, and qzeros
+                gate_qweight, up_qweight = layer.mlp.gate_up_proj.qweight.chunk(
+                    2, dim=1
+                )
+                gate_scales, up_scales = layer.mlp.gate_up_proj.scales.chunk(2, dim=1)
+                gate_qzeros, up_qzeros = layer.mlp.gate_up_proj.qzeros.chunk(2, dim=1)
+
+                layer_spec.ffn.linear_0.weight = gate_qweight
+                layer_spec.ffn.linear_0.weight_scale = gate_scales
+                layer_spec.ffn.linear_0.weight_zero = gate_qzeros
+
+                layer_spec.ffn.linear_0_noact.weight = up_qweight
+                layer_spec.ffn.linear_0_noact.weight_scale = up_scales
+                layer_spec.ffn.linear_0_noact.weight_zero = up_qzeros
+
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )

             delattr(layer, "self_attn")
             delattr(layer, "mlp")
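The Phi-3 hunks above split the fused `gate_up_proj` weight into the spec's `linear_0` (gate) and `linear_0_noact` (up) halves, chunking along dim 0 for plain weights and along dim 1 for AWQ's packed `qweight`/`scales`/`qzeros`. A toy sketch of the dim-0 split on a dummy tensor; the sizes below are made up:

    import torch

    hidden_size, intermediate_size = 8, 16  # toy sizes, not Phi-3's real dimensions
    fused = torch.randn(2 * intermediate_size, hidden_size)  # stand-in for gate_up_proj.weight
    gate_proj, up_proj = fused.chunk(2, dim=0)
    assert gate_proj.shape == up_proj.shape == (intermediate_size, hidden_size)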
@@ -3022,3 +3468,266 @@ _WHISPER_ALIGNMENT_HEADS = {
         (25, 6),
     ],
 }
+
+
+# Paper: https://arxiv.org/pdf/2504.06225
+@register_loader("T5GemmaConfig")
+class T5GemmaLoader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "T5GemmaForConditionalGeneration"
+
+    def set_layer_norm(self, spec, layer_norm):
+        spec.gamma = layer_norm.weight.data + 1.0
+
+    def get_model_spec(self, model):
+        encoder_config = model.config.encoder
+        decoder_config = model.config.decoder
+        sliding_window = getattr(model.config, "sliding_window", 4096)
+
+        encoder_num_heads = encoder_config.num_attention_heads
+        encoder_num_heads_kv = getattr(
+            encoder_config, "num_key_value_heads", encoder_num_heads
+        )
+        if encoder_num_heads_kv == encoder_num_heads:
+            encoder_num_heads_kv = None
+
+        encoder = transformer_spec.TransformerEncoderSpec(
+            encoder_config.num_hidden_layers,
+            encoder_config.num_attention_heads,
+            pre_norm=True,
+            activation=_SUPPORTED_ACTIVATIONS[encoder_config.hidden_activation],
+            ffn_glu=True,
+            rms_norm=True,
+            rotary_dim=encoder_config.head_dim,
+            rotary_interleave=False,
+            rotary_base=getattr(encoder_config, "rope_theta", 10000),
+            sliding_window=sliding_window,
+            pre_post_layer_norm=True,
+            num_heads_kv=encoder_num_heads_kv,
+            head_dim=encoder_config.head_dim,
+        )
+
+        decoder_num_heads = decoder_config.num_attention_heads
+        decoder_num_heads_kv = getattr(
+            decoder_config, "num_key_value_heads", decoder_num_heads
+        )
+        if decoder_num_heads_kv == decoder_num_heads:
+            decoder_num_heads_kv = None
+
+        decoder = transformer_spec.TransformerDecoderSpec(
+            decoder_config.num_hidden_layers,
+            decoder_config.num_attention_heads,
+            pre_norm=True,
+            activation=_SUPPORTED_ACTIVATIONS[decoder_config.hidden_activation],
+            ffn_glu=True,
+            rms_norm=True,
+            with_encoder_attention=True,
+            rotary_dim=decoder_config.head_dim,
+            rotary_interleave=False,
+            rotary_base=getattr(decoder_config, "rope_theta", 10000),
+            sliding_window=sliding_window,
+            pre_post_layer_norm=True,
+            external_pre_post_encoder_layers=True,
+            num_heads_kv=decoder_num_heads_kv,
+            head_dim=decoder_config.head_dim,
+        )
+
+        spec = transformer_spec.TransformerSpec(encoder, decoder)
+
+        self.set_encoder(spec.encoder, model.model.encoder, encoder_config)
+
+        self.set_decoder(
+            spec.decoder,
+            model.model.decoder,
+            decoder_config,
+            common_spec.Quantization.CT2,
+        )
+
+        # Tie_word_embeddings
+        self.set_linear(spec.decoder.projection, model.model.decoder.embed_tokens)
+        return spec
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_source_vocabulary(tokens)
+        spec.register_target_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = tokenizer.bos_token
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = tokenizer.unk_token
+
+        if hasattr(model.config, "encoder"):
+            config.layer_norm_epsilon = model.config.encoder.rms_norm_eps
+        elif hasattr(model.config, "rms_norm_eps"):
+            config.layer_norm_epsilon = model.config.rms_norm_eps
+        else:
+            config.layer_norm_epsilon = 1e-6
+
+        config.decoder_start_token = tokenizer.bos_token
+
+    def set_encoder(
+        self, spec, encoder, encoder_config, quant_type=common_spec.Quantization.CT2
+    ):
+        spec.scale_embeddings = True
+
+        encoder_emb_spec = (
+            spec.embeddings[0] if isinstance(spec.embeddings, list) else spec.embeddings
+        )
+
+        self.set_embeddings(encoder_emb_spec, encoder.embed_tokens)
+        encoder_emb_spec.multiply_by_sqrt_depth = encoder_config.hidden_size**0.5
+        self.set_layer_norm(spec.layer_norm, encoder.norm)
+
+        module = encoder
+        for i, (layer_spec, layer) in enumerate(zip(spec.layer, module.layers)):
+            self.set_layer_norm(
+                layer_spec.input_layer_norm, layer.pre_self_attn_layernorm
+            )
+            self.set_layer_norm(
+                layer_spec.post_attention_layer_norm, layer.post_self_attn_layernorm
+            )
+
+            # T5GemmaSelfAttention
+            qkv_split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(
+                qkv_split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                qkv_split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                qkv_split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+            )
+            utils.fuse_linear(layer_spec.self_attention.linear[0], qkv_split_layers)
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+                quant_type=quant_type,
+            )
+
+            # T5GemmaRMSNorm
+            self.set_layer_norm(
+                layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
+            )
+            # T5GemmaRMSNorm
+            self.set_layer_norm(
+                layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
+            )
+
+            # T5GemmaMLP
+            self.set_linear(
+                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )
+
+            # Clean up
+            delattr(layer, "self_attn")
+            delattr(layer, "mlp")
+            gc.collect()
+
+    def set_decoder(
+        self, spec, module, decoder_config, quant_type=common_spec.Quantization.CT2
+    ):
+        spec.scale_embeddings = True
+        spec.start_from_zero_embedding = False
+
+        self.set_embeddings(spec.embeddings, module.embed_tokens)
+        spec.embeddings.multiply_by_sqrt_depth = decoder_config.hidden_size**0.5
+        self.set_layer_norm(spec.layer_norm, module.norm)
+
+        for i, (layer_spec, layer) in enumerate(zip(spec.layer, module.layers)):
+            # Self-attention block
+            self.set_layer_norm(
+                layer_spec.input_layer_norm, layer.pre_self_attn_layernorm
+            )
+            self.set_layer_norm(
+                layer_spec.post_attention_layer_norm, layer.post_self_attn_layernorm
+            )
+
+            # T5GemmaSelfAttention - QKV projections
+            qkv_split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(
+                qkv_split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                qkv_split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                qkv_split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+            )
+            utils.fuse_linear(layer_spec.self_attention.linear[0], qkv_split_layers)
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+                quant_type=quant_type,
+            )
+
+            # Pre and post cross-attention layer norm
+            self.set_layer_norm(
+                layer_spec.external_pre_encoder_attention_layer_norm,
+                layer.pre_cross_attn_layernorm,
+            )
+
+            self.set_layer_norm(
+                layer_spec.external_post_encoder_attention_layer_norm,
+                layer.post_cross_attn_layernorm,
+            )
+
+            # Cross-attention Q projection
+            self.set_linear(
+                layer_spec.attention.linear[0],
+                layer.cross_attn.q_proj,
+                quant_type=quant_type,
+            )
+
+            # Cross-attention K+V fused
+            kv_split_layers = [common_spec.LinearSpec() for _ in range(2)]
+            self.set_linear(
+                kv_split_layers[0],
+                layer.cross_attn.k_proj,
+                quant_type=quant_type,
+            )
+            self.set_linear(
+                kv_split_layers[1],
+                layer.cross_attn.v_proj,
+                quant_type=quant_type,
+            )
+            utils.fuse_linear(layer_spec.attention.linear[1], kv_split_layers)
+
+            # Cross-attention output projection
+            self.set_linear(
+                layer_spec.attention.linear[2],
+                layer.cross_attn.o_proj,
+                quant_type=quant_type,
+            )
+
+            # Feed-forward block
+            self.set_layer_norm(
+                layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
+            )
+            self.set_layer_norm(
+                layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
+            )
+
+            # T5GemmaMLP
+            self.set_linear(
+                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )
+
+            # Clean up
+            delattr(layer, "self_attn")
+            delattr(layer, "cross_attn")
+            delattr(layer, "mlp")
+            gc.collect()