ctranslate2-4.6.2-cp313-cp313-win_amd64.whl → ctranslate2-4.7.0-cp313-cp313-win_amd64.whl

ctranslate2/__init__.py CHANGED
@@ -21,6 +21,8 @@ if sys.platform == "win32":
     add_dll_directory = getattr(os, "add_dll_directory", None)
     if add_dll_directory is not None:
         add_dll_directory(package_dir)
+        add_dll_directory(f"{package_dir}/../_rocm_sdk_core/bin")
+        add_dll_directory(f"{package_dir}/../_rocm_sdk_libraries_custom/bin")
 
     for library in glob.glob(os.path.join(package_dir, "*.dll")):
         ctypes.CDLL(library)
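Note: the two new calls register the optional ROCm SDK wheel directories with the Windows DLL loader so that HIP runtime libraries shipped in sibling `_rocm_sdk_*` packages can be resolved. A minimal sketch of the mechanism (the `_rocm_sdk_*` package names are taken from the diff; they are only present when the ROCm wheels are installed):

    import os

    # On Windows, Python 3.8+ only resolves dependent DLLs from directories
    # registered through os.add_dll_directory.
    package_dir = os.path.dirname(os.path.abspath(__file__))
    for rel in ("../_rocm_sdk_core/bin", "../_rocm_sdk_libraries_custom/bin"):
        path = os.path.normpath(os.path.join(package_dir, rel))
        if os.path.isdir(path):  # skip when the optional ROCm wheels are absent
            os.add_dll_directory(path)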
ctranslate2/_ext.cp313-win_amd64.pyd CHANGED
Binary file
ctranslate2/converters/eole_ct2.py CHANGED
@@ -3,7 +3,7 @@ import argparse
 from eole.config.run import PredictConfig
 from eole.constants import PositionEncodingType
 from eole.inputters.inputter import vocabs_to_dict
-from eole.models.model import BaseModel
+from eole.models.model import get_model_class
 
 from ctranslate2.converters import utils
 from ctranslate2.converters.converter import Converter
@@ -164,7 +164,8 @@ class EoleConverter(Converter):
 
         config = PredictConfig(model_path=self._model_path, src="dummy")
 
-        vocabs, model, model_config = BaseModel.load_test_model(config)
+        model_class = get_model_class(config.model)
+        model, vocabs, model_config = model_class.for_inference(config)
         vocabs_dict = vocabs_to_dict(vocabs)
 
         config.model = model_config
ctranslate2/converters/fairseq.py CHANGED
@@ -146,7 +146,9 @@ class FairseqConverter(Converter):
         import_user_module(argparse.Namespace(user_dir=self._user_dir))
 
         with torch.no_grad():
-            checkpoint = checkpoint_utils.load_checkpoint_to_cpu(self._model_path)
+            checkpoint = torch.load(
+                self._model_path, map_location=torch.device("cpu"), weights_only=False
+            )
             args = checkpoint["args"] or checkpoint["cfg"]["model"]
 
             args.data = self._data_dir
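Note: this converter (and the OpenNMT-py converter below) now passes `weights_only=False` explicitly. PyTorch 2.6 flipped the default of `torch.load` to `weights_only=True`, which refuses to unpickle arbitrary objects; these checkpoints embed non-tensor objects (for example an `argparse.Namespace` and vocabulary structures), so full unpickling must be requested. A minimal sketch, assuming a trusted local checkpoint:

    import torch

    # weights_only=True (the new default) only restores tensors and plain
    # containers; checkpoints that pickle config objects need the flag off.
    checkpoint = torch.load("model.pt", map_location="cpu", weights_only=False)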
ctranslate2/converters/opennmt_py.py CHANGED
@@ -174,7 +174,9 @@ class OpenNMTPyConverter(Converter):
     def _load(self):
         import torch
 
-        checkpoint = torch.load(self._model_path, map_location="cpu")
+        checkpoint = torch.load(
+            self._model_path, map_location="cpu", weights_only=False
+        )
 
         src_vocabs, tgt_vocabs = get_vocabs(checkpoint["vocab"])
 
ctranslate2/converters/transformers.py CHANGED
@@ -89,7 +89,7 @@ class TransformersConverter(Converter):
       copy_files: List of filenames to copy from the Hugging Face model to the
         converted model directory.
      load_as_float16: Load the model weights as float16. More precisely, the model
-        will be loaded with ``from_pretrained(..., torch_dtype=torch.float16)``.
+        will be loaded with ``from_pretrained(..., dtype=torch.float16)``.
      revision: Revision of the model to download from the Hugging Face Hub.
      low_cpu_mem_usage: Enable the flag ``low_cpu_mem_usage`` when loading the model
        with ``from_pretrained``.
@@ -123,10 +123,11 @@ class TransformersConverter(Converter):
         tokenizer_class = transformers.AutoTokenizer
 
         kwargs = {
-            "torch_dtype": (
+            "dtype": (
                 torch.float16
                 if self._load_as_float16
-                else getattr(config, "torch_dtype", None)
+                else getattr(config, "dtype", None)
+                or getattr(config, "torch_dtype", None)
             )
         }
 
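Note: recent `transformers` releases renamed the `from_pretrained` argument `torch_dtype` to `dtype` and deprecated the old name; the `getattr` chain keeps working with older config objects that only expose `torch_dtype`. A sketch of the fallback, assuming a recent transformers release (the model id is illustrative):

    import transformers

    config = transformers.AutoConfig.from_pretrained("gpt2")
    dtype = getattr(config, "dtype", None) or getattr(config, "torch_dtype", None)
    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2", dtype=dtype)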
@@ -235,7 +236,7 @@ class ModelLoader(abc.ABC):
 
         if isinstance(module, transformers.Conv1D):
             spec.weight = spec.weight.transpose(0, 1)
-            if module.bias is not None:
+            if hasattr(module, "bias") and module.bias is not None:
                 spec.bias = module.bias
 
     def set_embeddings(self, spec, module):
@@ -252,6 +253,30 @@ class ModelLoader(abc.ABC):
                 "No activation smoothing logic is defined for this model"
             )
 
+    def get_rotary_params(self, config, default_rope_theta):
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling:
+            rope_type = rope_scaling.get("type") or rope_scaling.get("rope_type")
+
+            if rope_type == "default":
+                rotary_scaling_type = None
+            else:
+                rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
+                if rotary_scaling_type is None:
+                    raise NotImplementedError(
+                        "RoPE scaling type '%s' is not yet implemented. "
+                        "The following RoPE scaling types are currently supported: %s"
+                        % (rope_type, ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
+                    )
+            rotary_scaling_factor = rope_scaling.get("factor", 1)
+            rope_theta = rope_scaling.get("rope_theta", default_rope_theta)
+        else:
+            rotary_scaling_type = None
+            rotary_scaling_factor = 1
+            rope_theta = getattr(config, "rope_theta", default_rope_theta)
+
+        return rotary_scaling_type, rotary_scaling_factor, rope_theta
+
 
 @register_loader("BartConfig")
 class BartLoader(ModelLoader):
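Note: `get_rotary_params` centralizes the `rope_scaling` parsing that was previously duplicated in the Llama, Mistral, Qwen2, and Qwen3 loaders, and additionally accepts the `"default"` rope type and a `rope_theta` override inside the scaling dict. An illustrative Llama-3.1 style `config.json` entry it consumes (values as published for Llama-3.1 checkpoints):

    rope_scaling = {
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    }
    # get_rotary_params maps "llama3" through _SUPPORTED_ROPE_SCALING and
    # returns (RotaryScalingType.Llama3, 8.0, default_rope_theta) here.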
@@ -462,7 +487,7 @@ class M2M100Loader(BartLoader):
         if tokens[-1] == tokenizer.unk_token:
             tokens.insert(tokenizer.unk_token_id, tokens.pop())
 
-        for token in tokenizer.additional_special_tokens:
+        for token in tokenizer.special_tokens_map.get("additional_special_tokens", []):
             if token not in tokens:
                 tokens.append(token)
 
@@ -487,7 +512,7 @@ class MBartLoader(BartLoader):
         config.unk_token = tokenizer.unk_token
 
         # MBart-25 passes the language code as the decoder start token.
-        if model.config.tokenizer_class in ("MBartTokenizer", None):
+        if getattr(model.config, "tokenizer_class", None) in ("MBartTokenizer", None):
             config.decoder_start_token = None
         else:
             config.decoder_start_token = tokenizer.eos_token
@@ -927,12 +952,14 @@ class WhisperLoader(BartLoader):
             "<|nocaptions|>",
             "<|notimestamps|>",
         ]
+
+        additional_tokens = getattr(tokenizer, "additional_special_tokens", [])
+        if not additional_tokens:
+            return []
+
         return [
-            token_id
-            for token_id, token in zip(
-                tokenizer.additional_special_tokens_ids,
-                tokenizer.additional_special_tokens,
-            )
+            tokenizer.convert_tokens_to_ids(token)
+            for token in additional_tokens
             if token not in non_lang_special_tokens
         ]
 
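Note: resolving ids with `convert_tokens_to_ids` makes the lookup robust when the tokenizer exposes `additional_special_tokens` but not a matching `additional_special_tokens_ids` list. A short sketch (the model id is illustrative):

    from transformers import WhisperTokenizer

    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
    # Language tokens such as "<|en|>" resolve one at a time.
    token_id = tokenizer.convert_tokens_to_ids("<|en|>")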
@@ -1673,21 +1700,9 @@ class LlamaLoader(ModelLoader):
         if num_heads_kv == num_heads:
             num_heads_kv = None
 
-        rope_scaling = getattr(model.config, "rope_scaling", None)
-        if rope_scaling:
-            rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
-            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
-            rotary_scaling_factor = rope_scaling["factor"]
-
-            if rotary_scaling_type is None:
-                raise NotImplementedError(
-                    "RoPE scaling type '%s' is not yet implemented. "
-                    "The following RoPE scaling types are currently supported: %s"
-                    % (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
-                )
-        else:
-            rotary_scaling_type = None
-            rotary_scaling_factor = 1
+        rotary_scaling_type, rotary_scaling_factor, rope_theta = self.get_rotary_params(
+            model.config, 10_000
+        )
 
         quantization_config = getattr(model.config, "quantization_config", None)
         if quantization_config:
@@ -1721,7 +1736,7 @@ class LlamaLoader(ModelLoader):
             rotary_interleave=False,
             rotary_scaling_type=rotary_scaling_type,
             rotary_scaling_factor=rotary_scaling_factor,
-            rotary_base=getattr(model.config, "rope_theta", 10000),
+            rotary_base=rope_theta,
             num_heads_kv=num_heads_kv,
             quant_type=quant_type,
             quant_group_size=quant_group_size,
@@ -1732,6 +1747,7 @@ class LlamaLoader(ModelLoader):
         self.set_linear(spec.decoder.projection, model.lm_head)
 
         # set extra RoPE parameters for Llama-3.1
+        rope_scaling = getattr(model.config, "rope_scaling", None)
         if rotary_scaling_type == attention_spec.RotaryScalingType.Llama3:
             for layer in spec.decoder.layer:
                 layer.self_attention.rotary_low_freq_factor = rope_scaling[
@@ -1858,8 +1874,12 @@ class Gemma3Loader(ModelLoader):
                     "Quantization type '%s' is not yet implemented."
                     % quantization_config.quant_method
                 )
+            quant_group_size = quantization_config.group_size
+            quant_bits = quantization_config.bits
         else:
             quant_type = common_spec.Quantization.CT2
+            quant_group_size = None
+            quant_bits = None
 
         # Create base spec using from_config
         spec = transformer_spec.TransformerDecoderModelSpec.from_config(
@@ -1880,6 +1900,9 @@ class Gemma3Loader(ModelLoader):
             head_dim=head_dim,
             sliding_window=sliding_window,  # Default to local sliding window
             pre_post_layer_norm=True,
+            quant_type=quant_type,
+            quant_group_size=quant_group_size,
+            quant_bits=quant_bits,
             qk_norm=True,
         )
 
@@ -1932,7 +1955,8 @@ class Gemma3Loader(ModelLoader):
         config.eos_token = tokenizer.eos_token
 
     def set_layer_norm(self, spec, layer_norm):
-        spec.gamma = layer_norm.weight + 1.0
+        spec.gamma = layer_norm.weight
+        spec.layer_norm_use_residual = True
 
     def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
         spec.scale_embeddings = True
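Note: Gemma's RMSNorm scales activations by `(1 + weight)`. The old code baked the offset into the stored gamma at conversion time; the new code keeps the raw checkpoint weight and sets `layer_norm_use_residual` so the runtime applies the `+1` itself, which keeps the stored tensor identical to the checkpoint weight. Equivalence sketch:

    import torch

    x_normed = torch.randn(8)   # RMS-normalized activation (illustrative)
    weight = torch.randn(8)     # stored RMSNorm weight
    old = x_normed * (weight + 1.0)   # offset folded into gamma at conversion
    new = x_normed * (1.0 + weight)   # offset applied at runtime
    assert torch.allclose(old, new)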
@@ -2021,20 +2045,9 @@ class MistralLoader(ModelLoader):
 
         sliding_window = getattr(model.config, "sliding_window", 0)
 
-        rope_scaling = getattr(model.config, "rope_scaling", None)
-        if rope_scaling:
-            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
-            rotary_scaling_factor = rope_scaling["factor"]
-
-            if rotary_scaling_type is None:
-                raise NotImplementedError(
-                    "RoPE scaling type '%s' is not yet implemented. "
-                    "The following RoPE scaling types are currently supported: %s"
-                    % (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
-                )
-        else:
-            rotary_scaling_type = None
-            rotary_scaling_factor = 1
+        rotary_scaling_type, rotary_scaling_factor, rope_theta = self.get_rotary_params(
+            model.config, 10_000
+        )
 
         quantization_config = getattr(model.config, "quantization_config", None)
         if quantization_config:
@@ -2067,7 +2080,7 @@ class MistralLoader(ModelLoader):
             rotary_interleave=False,
             rotary_scaling_type=rotary_scaling_type,
             rotary_scaling_factor=rotary_scaling_factor,
-            rotary_base=getattr(model.config, "rope_theta", 10000),
+            rotary_base=rope_theta,
             num_heads_kv=num_heads_kv,
             sliding_window=sliding_window,
             quant_type=quant_type,
@@ -2166,21 +2179,31 @@ class Qwen2Loader(ModelLoader):
         if num_heads_kv == num_heads:
             num_heads_kv = None
 
-        rope_scaling = getattr(model.config, "rope_scaling", None)
-        if rope_scaling:
-            rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
-            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
-            rotary_scaling_factor = rope_scaling["factor"]
+        rotary_scaling_type, rotary_scaling_factor, rope_theta = self.get_rotary_params(
+            model.config, 10_000
+        )
 
-            if rotary_scaling_type is None:
+        # Check for AWQ quantization config
+        quantization_config = getattr(model.config, "quantization_config", None)
+        if quantization_config:
+            quant_type = None
+            if quantization_config.quant_method == "awq":
+                quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
+            if quant_type is None:
                 raise NotImplementedError(
-                    "RoPE scaling type '%s' is not yet implemented. "
-                    "The following RoPE scaling types are currently supported: %s"
-                    % (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
+                    "Quantization type '%s' is not yet implemented. "
+                    "The following Quantization types are currently supported: %s"
+                    % (
+                        quantization_config.quant_method,
+                        ", ".join(_SUPPORTED_QUANTIZATION.keys()),
+                    )
                 )
+            quant_group_size = quantization_config.group_size
+            quant_bits = quantization_config.bits
         else:
-            rotary_scaling_type = None
-            rotary_scaling_factor = 1
+            quant_type = common_spec.Quantization.CT2
+            quant_group_size = None
+            quant_bits = None
 
         spec = transformer_spec.TransformerDecoderModelSpec.from_config(
             num_layers,
@@ -2193,11 +2216,14 @@ class Qwen2Loader(ModelLoader):
             rotary_interleave=False,
             rotary_scaling_type=rotary_scaling_type,
             rotary_scaling_factor=rotary_scaling_factor,
-            rotary_base=getattr(model.config, "rope_theta", 10000),
+            rotary_base=rope_theta,
             num_heads_kv=num_heads_kv,
+            quant_type=quant_type,
+            quant_group_size=quant_group_size,
+            quant_bits=quant_bits,
         )
 
-        self.set_decoder(spec.decoder, model.model)
+        self.set_decoder(spec.decoder, model.model, quant_type)
         self.set_linear(spec.decoder.projection, model.lm_head)
         return spec
 
@@ -2227,7 +2253,7 @@ class Qwen2Loader(ModelLoader):
     def set_layer_norm(self, spec, layer_norm):
         spec.gamma = layer_norm.weight
 
-    def set_decoder(self, spec, module):
+    def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
         spec.scale_embeddings = False
         self.set_embeddings(spec.embeddings, module.embed_tokens)
         self.set_layer_norm(spec.layer_norm, module.norm)
@@ -2241,19 +2267,39 @@ class Qwen2Loader(ModelLoader):
             )
 
             split_layers = [common_spec.LinearSpec() for _ in range(3)]
-            self.set_linear(split_layers[0], layer.self_attn.q_proj)
-            self.set_linear(split_layers[1], layer.self_attn.k_proj)
-            self.set_linear(split_layers[2], layer.self_attn.v_proj)
+            self.set_linear(
+                split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+            )
+
+            if quant_type == common_spec.Quantization.CT2:
+                utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+            else:
+                cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
+                utils.fuse_linear_prequant(
+                    layer_spec.self_attention.linear[0], split_layers, cc_dim
+                )
 
-            utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
             self.set_linear(
                 layer_spec.self_attention.linear[1],
                 layer.self_attn.o_proj,
+                quant_type=quant_type,
             )
 
-            self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
-            self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
-            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
+            self.set_linear(
+                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )
 
             delattr(layer, "self_attn")
             delattr(layer, "mlp")
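Note: for CT2's own quantization the Q/K/V projections are still fused after loading, but prequantized AWQ tensors must be fused by concatenating the packed weights directly, and the concatenation axis follows the AWQ layout: per the diff, `AWQ_GEMM` packs output features along dim 1 while the GEMV layout uses dim 0. A toy sketch of the axis choice (shapes illustrative):

    import torch

    # Three packed int32 projections of equal shape (in, out_packed).
    q, k, v = (torch.zeros(128, 16, dtype=torch.int32) for _ in range(3))
    cc_dim = 1  # AWQ GEMM: output features live on dim 1
    fused = torch.cat([q, k, v], dim=cc_dim)  # -> shape (128, 48)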
@@ -2277,20 +2323,30 @@ class Qwen3Loader(ModelLoader):
         if num_heads_kv == num_heads:
             num_heads_kv = None
 
-        rope_scaling = getattr(model.config, "rope_scaling", None)
-        if rope_scaling:
-            rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
-            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
-            rotary_scaling_factor = rope_scaling["factor"]
-            if rotary_scaling_type is None:
+        rotary_scaling_type, rotary_scaling_factor, rope_theta = self.get_rotary_params(
+            model.config, 1_000_000
+        )
+        # Check for AWQ quantization config
+        quantization_config = getattr(model.config, "quantization_config", None)
+        if quantization_config:
+            quant_type = None
+            if quantization_config.quant_method == "awq":
+                quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
+            if quant_type is None:
                 raise NotImplementedError(
-                    "RoPE scaling type '%s' is not yet implemented. "
-                    "The following RoPE scaling types are currently supported: %s"
-                    % (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
+                    "Quantization type '%s' is not yet implemented. "
+                    "The following Quantization types are currently supported: %s"
+                    % (
+                        quantization_config.quant_method,
+                        ", ".join(_SUPPORTED_QUANTIZATION.keys()),
+                    )
                 )
+            quant_group_size = quantization_config.group_size
+            quant_bits = quantization_config.bits
         else:
-            rotary_scaling_type = None
-            rotary_scaling_factor = 1
+            quant_type = common_spec.Quantization.CT2
+            quant_group_size = None
+            quant_bits = None
 
         spec = transformer_spec.TransformerDecoderModelSpec.from_config(
             num_layers,
@@ -2303,13 +2359,16 @@ class Qwen3Loader(ModelLoader):
             rotary_interleave=False,
             rotary_scaling_type=rotary_scaling_type,
             rotary_scaling_factor=rotary_scaling_factor,
-            rotary_base=getattr(model.config, "rope_theta", 10000),
+            rotary_base=rope_theta,
             num_heads_kv=num_heads_kv,
             head_dim=head_dim,
             qk_norm=True,
+            quant_type=quant_type,
+            quant_group_size=quant_group_size,
+            quant_bits=quant_bits,
         )
 
-        self.set_decoder(spec.decoder, model.model)
+        self.set_decoder(spec.decoder, model.model, quant_type)
         self.set_linear(spec.decoder.projection, model.lm_head)
         return spec
 
@@ -2338,7 +2397,7 @@ class Qwen3Loader(ModelLoader):
     def set_layer_norm(self, spec, layer_norm):
         spec.gamma = layer_norm.weight
 
-    def set_decoder(self, spec, module):
+    def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
         spec.scale_embeddings = False
         self.set_embeddings(spec.embeddings, module.embed_tokens)
         self.set_layer_norm(spec.layer_norm, module.norm)
@@ -2359,22 +2418,43 @@ class Qwen3Loader(ModelLoader):
             )
 
             split_layers = [common_spec.LinearSpec() for _ in range(3)]
-            self.set_linear(split_layers[0], layer.self_attn.q_proj)
-            self.set_linear(split_layers[1], layer.self_attn.k_proj)
-            self.set_linear(split_layers[2], layer.self_attn.v_proj)
-            utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+            self.set_linear(
+                split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+            )
+
+            if quant_type == common_spec.Quantization.CT2:
+                utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+            else:
+                cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
+                utils.fuse_linear_prequant(
+                    layer_spec.self_attention.linear[0], split_layers, cc_dim
+                )
 
             self.set_linear(
                 layer_spec.self_attention.linear[1],
                 layer.self_attn.o_proj,
+                quant_type=quant_type,
             )
 
-            self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
-            self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
-            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
+            self.set_linear(
+                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )
 
             delattr(layer, "self_attn")
             delattr(layer, "mlp")
+            gc.collect()
 
 
 @register_loader("MixFormerSequentialConfig")
@@ -2514,6 +2594,28 @@ class Phi3Loader(ModelLoader):
             rotary_scaling_type = None
             rotary_scaling_factor = 1
 
+        # Check for AWQ quantization config
+        quantization_config = getattr(model.config, "quantization_config", None)
+        if quantization_config:
+            quant_type = None
+            if quantization_config.quant_method == "awq":
+                quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
+            if quant_type is None:
+                raise NotImplementedError(
+                    "Quantization type '%s' is not yet implemented. "
+                    "The following Quantization types are currently supported: %s"
+                    % (
+                        quantization_config.quant_method,
+                        ", ".join(_SUPPORTED_QUANTIZATION.keys()),
+                    )
+                )
+            quant_group_size = quantization_config.group_size
+            quant_bits = quantization_config.bits
+        else:
+            quant_type = common_spec.Quantization.CT2
+            quant_group_size = None
+            quant_bits = None
+
         spec = transformer_spec.TransformerDecoderModelSpec.from_config(
             num_layers,
             num_heads,
@@ -2529,9 +2631,12 @@ class Phi3Loader(ModelLoader):
             original_max_position_embeddings=original_max_position_embeddings,
             max_position_embeddings=max_position_embeddings,
             num_heads_kv=num_heads_kv,
+            quant_type=quant_type,
+            quant_group_size=quant_group_size,
+            quant_bits=quant_bits,
         )
 
-        self.set_decoder(spec.decoder, model.model)
+        self.set_decoder(spec.decoder, model.model, quant_type)
         self.set_linear(spec.decoder.projection, model.lm_head)
         return spec
 
@@ -2565,7 +2670,7 @@ class Phi3Loader(ModelLoader):
             rotary_scaling_short_factor, dtype=torch.float32
         )
 
-    def set_decoder(self, spec, module):
+    def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
         spec.scale_embeddings = False
         self.set_embeddings(spec.embeddings, module.embed_tokens)
         self.set_layer_norm(spec.layer_norm, module.norm)
@@ -2579,9 +2684,15 @@ class Phi3Loader(ModelLoader):
             )
 
             self.set_linear(
-                layer_spec.self_attention.linear[0], layer.self_attn.qkv_proj
+                layer_spec.self_attention.linear[0],
+                layer.self_attn.qkv_proj,
+                quant_type=quant_type,
+            )
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+                quant_type=quant_type,
             )
-            self.set_linear(layer_spec.self_attention.linear[1], layer.self_attn.o_proj)
             if (
                 layer.self_attn.rotary_emb.long_factor is not None
                 and layer.self_attn.rotary_emb.short_factor is not None
@@ -2592,10 +2703,30 @@ class Phi3Loader(ModelLoader):
                 layer.self_attn.rotary_emb.short_factor,
             )
 
-            gate_proj, up_proj = layer.mlp.gate_up_proj.weight.chunk(2, dim=0)
-            layer_spec.ffn.linear_0.weight = gate_proj
-            layer_spec.ffn.linear_0_noact.weight = up_proj
-            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
+            # Handle gate_up_proj differently for AWQ vs regular models
+            if quant_type == common_spec.Quantization.CT2:
+                gate_proj, up_proj = layer.mlp.gate_up_proj.weight.chunk(2, dim=0)
+                layer_spec.ffn.linear_0.weight = gate_proj
+                layer_spec.ffn.linear_0_noact.weight = up_proj
+            else:
+                # AWQ: chunk qweight, scales, and qzeros
+                gate_qweight, up_qweight = layer.mlp.gate_up_proj.qweight.chunk(
+                    2, dim=1
+                )
+                gate_scales, up_scales = layer.mlp.gate_up_proj.scales.chunk(2, dim=1)
+                gate_qzeros, up_qzeros = layer.mlp.gate_up_proj.qzeros.chunk(2, dim=1)
+
+                layer_spec.ffn.linear_0.weight = gate_qweight
+                layer_spec.ffn.linear_0.weight_scale = gate_scales
+                layer_spec.ffn.linear_0.weight_zero = gate_qzeros
+
+                layer_spec.ffn.linear_0_noact.weight = up_qweight
+                layer_spec.ffn.linear_0_noact.weight_scale = up_scales
+                layer_spec.ffn.linear_0_noact.weight_zero = up_qzeros
+
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )
 
             delattr(layer, "self_attn")
             delattr(layer, "mlp")
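Note: Phi-3 fuses the gate and up projections into a single `gate_up_proj`. With float weights the two halves split along dim 0 (output rows); in the AWQ GEMM layout output features sit on dim 1, so `qweight`, `scales`, and `qzeros` must each be chunked along dim 1 instead. Shape sketch (group size 128 and 4-bit packing, i.e. 8 values per int32, are assumptions for illustration):

    import torch

    in_features, out_features = 3072, 16384  # fused gate+up output width
    qweight = torch.zeros(in_features, out_features // 8, dtype=torch.int32)
    scales = torch.zeros(in_features // 128, out_features, dtype=torch.float16)
    qzeros = torch.zeros(in_features // 128, out_features // 8, dtype=torch.int32)

    # Outputs live on dim 1 in this layout, so each tensor splits there.
    gate_qweight, up_qweight = qweight.chunk(2, dim=1)
    gate_scales, up_scales = scales.chunk(2, dim=1)
    gate_qzeros, up_qzeros = qzeros.chunk(2, dim=1)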
@@ -3325,3 +3456,266 @@ _WHISPER_ALIGNMENT_HEADS = {
         (25, 6),
     ],
 }
+
+
+# Paper: https://arxiv.org/pdf/2504.06225
+@register_loader("T5GemmaConfig")
+class T5GemmaLoader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "T5GemmaForConditionalGeneration"
+
+    def set_layer_norm(self, spec, layer_norm):
+        spec.gamma = layer_norm.weight.data + 1.0
+
+    def get_model_spec(self, model):
+        encoder_config = model.config.encoder
+        decoder_config = model.config.decoder
+        sliding_window = getattr(model.config, "sliding_window", 4096)
+
+        encoder_num_heads = encoder_config.num_attention_heads
+        encoder_num_heads_kv = getattr(
+            encoder_config, "num_key_value_heads", encoder_num_heads
+        )
+        if encoder_num_heads_kv == encoder_num_heads:
+            encoder_num_heads_kv = None
+
+        encoder = transformer_spec.TransformerEncoderSpec(
+            encoder_config.num_hidden_layers,
+            encoder_config.num_attention_heads,
+            pre_norm=True,
+            activation=_SUPPORTED_ACTIVATIONS[encoder_config.hidden_activation],
+            ffn_glu=True,
+            rms_norm=True,
+            rotary_dim=encoder_config.head_dim,
+            rotary_interleave=False,
+            rotary_base=getattr(encoder_config, "rope_theta", 10000),
+            sliding_window=sliding_window,
+            pre_post_layer_norm=True,
+            num_heads_kv=encoder_num_heads_kv,
+            head_dim=encoder_config.head_dim,
+        )
+
+        decoder_num_heads = decoder_config.num_attention_heads
+        decoder_num_heads_kv = getattr(
+            decoder_config, "num_key_value_heads", decoder_num_heads
+        )
+        if decoder_num_heads_kv == decoder_num_heads:
+            decoder_num_heads_kv = None
+
+        decoder = transformer_spec.TransformerDecoderSpec(
+            decoder_config.num_hidden_layers,
+            decoder_config.num_attention_heads,
+            pre_norm=True,
+            activation=_SUPPORTED_ACTIVATIONS[decoder_config.hidden_activation],
+            ffn_glu=True,
+            rms_norm=True,
+            with_encoder_attention=True,
+            rotary_dim=decoder_config.head_dim,
+            rotary_interleave=False,
+            rotary_base=getattr(decoder_config, "rope_theta", 10000),
+            sliding_window=sliding_window,
+            pre_post_layer_norm=True,
+            external_pre_post_encoder_layers=True,
+            num_heads_kv=decoder_num_heads_kv,
+            head_dim=decoder_config.head_dim,
+        )
+
+        spec = transformer_spec.TransformerSpec(encoder, decoder)
+
+        self.set_encoder(spec.encoder, model.model.encoder, encoder_config)
+
+        self.set_decoder(
+            spec.decoder,
+            model.model.decoder,
+            decoder_config,
+            common_spec.Quantization.CT2,
+        )
+
+        # Tie_word_embeddings
+        self.set_linear(spec.decoder.projection, model.model.decoder.embed_tokens)
+        return spec
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_source_vocabulary(tokens)
+        spec.register_target_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = tokenizer.bos_token
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = tokenizer.unk_token
+
+        if hasattr(model.config, "encoder"):
+            config.layer_norm_epsilon = model.config.encoder.rms_norm_eps
+        elif hasattr(model.config, "rms_norm_eps"):
+            config.layer_norm_epsilon = model.config.rms_norm_eps
+        else:
+            config.layer_norm_epsilon = 1e-6
+
+        config.decoder_start_token = tokenizer.bos_token
+
+    def set_encoder(
+        self, spec, encoder, encoder_config, quant_type=common_spec.Quantization.CT2
+    ):
+        spec.scale_embeddings = True
+
+        encoder_emb_spec = (
+            spec.embeddings[0] if isinstance(spec.embeddings, list) else spec.embeddings
+        )
+
+        self.set_embeddings(encoder_emb_spec, encoder.embed_tokens)
+        encoder_emb_spec.multiply_by_sqrt_depth = encoder_config.hidden_size**0.5
+        self.set_layer_norm(spec.layer_norm, encoder.norm)
+
+        module = encoder
+        for i, (layer_spec, layer) in enumerate(zip(spec.layer, module.layers)):
+            self.set_layer_norm(
+                layer_spec.input_layer_norm, layer.pre_self_attn_layernorm
+            )
+            self.set_layer_norm(
+                layer_spec.post_attention_layer_norm, layer.post_self_attn_layernorm
+            )
+
+            # T5GemmaSelfAttention
+            qkv_split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(
+                qkv_split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                qkv_split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                qkv_split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+            )
+            utils.fuse_linear(layer_spec.self_attention.linear[0], qkv_split_layers)
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+                quant_type=quant_type,
+            )
+
+            # T5GemmaRMSNorm
+            self.set_layer_norm(
+                layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
+            )
+            # T5GemmaRMSNorm
+            self.set_layer_norm(
+                layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
+            )
+
+            # T5GemmaMLP
+            self.set_linear(
+                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )
+
+            # Clean up
+            delattr(layer, "self_attn")
+            delattr(layer, "mlp")
+            gc.collect()
+
+    def set_decoder(
+        self, spec, module, decoder_config, quant_type=common_spec.Quantization.CT2
+    ):
+        spec.scale_embeddings = True
+        spec.start_from_zero_embedding = False
+
+        self.set_embeddings(spec.embeddings, module.embed_tokens)
+        spec.embeddings.multiply_by_sqrt_depth = decoder_config.hidden_size**0.5
+        self.set_layer_norm(spec.layer_norm, module.norm)
+
+        for i, (layer_spec, layer) in enumerate(zip(spec.layer, module.layers)):
+            # Self-attention block
+            self.set_layer_norm(
+                layer_spec.input_layer_norm, layer.pre_self_attn_layernorm
+            )
+            self.set_layer_norm(
+                layer_spec.post_attention_layer_norm, layer.post_self_attn_layernorm
+            )
+
+            # T5GemmaSelfAttention - QKV projections
+            qkv_split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(
+                qkv_split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                qkv_split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                qkv_split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+            )
+            utils.fuse_linear(layer_spec.self_attention.linear[0], qkv_split_layers)
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+                quant_type=quant_type,
+            )
+
+            # Pre and post cross-attention layer norm
+            self.set_layer_norm(
+                layer_spec.external_pre_encoder_attention_layer_norm,
+                layer.pre_cross_attn_layernorm,
+            )
+
+            self.set_layer_norm(
+                layer_spec.external_post_encoder_attention_layer_norm,
+                layer.post_cross_attn_layernorm,
+            )
+
+            # Cross-attention Q projection
+            self.set_linear(
+                layer_spec.attention.linear[0],
+                layer.cross_attn.q_proj,
+                quant_type=quant_type,
+            )
+
+            # Cross-attention K+V fused
+            kv_split_layers = [common_spec.LinearSpec() for _ in range(2)]
+            self.set_linear(
+                kv_split_layers[0],
+                layer.cross_attn.k_proj,
+                quant_type=quant_type,
+            )
+            self.set_linear(
+                kv_split_layers[1],
+                layer.cross_attn.v_proj,
+                quant_type=quant_type,
+            )
+            utils.fuse_linear(layer_spec.attention.linear[1], kv_split_layers)
+
+            # Cross-attention output projection
+            self.set_linear(
+                layer_spec.attention.linear[2],
+                layer.cross_attn.o_proj,
+                quant_type=quant_type,
+            )
+
+            # Feed-forward block
+            self.set_layer_norm(
+                layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
+            )
+            self.set_layer_norm(
+                layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
+            )
+
+            # T5GemmaMLP
+            self.set_linear(
+                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )
+
+            # Clean up
+            delattr(layer, "self_attn")
+            delattr(layer, "cross_attn")
+            delattr(layer, "mlp")
+            gc.collect()
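Note: with the loader registered for `T5GemmaConfig`, T5Gemma checkpoints convert through the regular Transformers converter. A usage sketch (the model id is illustrative, not taken from the diff):

    import ctranslate2.converters

    converter = ctranslate2.converters.TransformersConverter(
        "google/t5gemma-2b-2b-ul2"  # hypothetical checkpoint name
    )
    converter.convert("t5gemma-ct2", quantization="int8")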
ctranslate2/ctranslate2.dll CHANGED
Binary file
ctranslate2/cudnn64_9.dll CHANGED
Binary file
ctranslate2/extensions.py CHANGED
@@ -556,12 +556,28 @@ def _process_iterable(process_func, iterables, max_batch_size, batch_type, **kwa
 
 def _batch_iterator(iterable, batch_size, batch_type):
     streams = None
-    cur_batch_size = 0
+    max_length = 0
 
     for example in iterable:
         if not isinstance(example, tuple):
             example = (example,)
 
+        if batch_type == "examples":
+            if streams and len(streams[0]) == batch_size:
+                yield streams
+                streams = None
+
+        elif batch_type == "tokens":
+            max_length = max(max_length, len(example[0]))
+
+            if streams and (len(streams[0]) + 1) * max_length > batch_size:
+                yield streams
+                streams = None
+                max_length = len(example[0])
+
+        else:
+            raise ValueError("Invalid batch type %s" % batch_type)
+
         if streams is None:
             streams = tuple([] for _ in example)
         for batch, element in zip(streams, example):
@@ -569,17 +585,5 @@ def _batch_iterator(iterable, batch_size, batch_type):
                 raise ValueError("Input iterables do not have the same length")
             batch.append(element)
 
-        if batch_type == "examples":
-            cur_batch_size += 1
-        elif batch_type == "tokens":
-            cur_batch_size += len(example[0])
-        else:
-            raise ValueError("Invalid batch type %s" % batch_type)
-
-        if cur_batch_size >= batch_size:
-            yield streams
-            streams = None
-            cur_batch_size = 0
-
     if streams is not None:
         yield streams
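Note: this changes the semantics of `batch_type="tokens"`. The old iterator counted the running sum of token lengths and flushed after appending, so batches could overshoot the budget. The new iterator bounds the padded batch size, flushing before an example would make `(rows + 1) * max_length` exceed `batch_size`, which matches the memory cost of a padded batch. A small sketch of the accounting:

    lengths = [3, 5, 4]   # token lengths of incoming examples
    batch_size = 10       # budget in padded tokens
    # After two examples, max_length == 5 and the batch holds 2 rows.
    # Adding the third would cost (2 + 1) * 5 == 15 > 10, so [ex1, ex2]
    # is yielded first and the third example starts a new batch.
    assert (2 + 1) * max(lengths) > batch_size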
ctranslate2/specs/attention_spec.py CHANGED
@@ -34,10 +34,12 @@ class MultiHeadAttentionSpec(model_spec.LayerSpec):
         sliding_window=None,
         qk_norm=False,
         qk_norm_rms=True,
+        has_norm=True,
     ):
         self.queries_scale = model_spec.OPTIONAL
 
-        self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
+        if has_norm:
+            self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
         self.linear = [
             common_spec.LinearSpec() for _ in range(2 if self_attention else 3)
         ]
ctranslate2/specs/transformer_spec.py CHANGED
@@ -23,6 +23,16 @@ class TransformerEncoderSpec(model_spec.LayerSpec):
         ffn_glu: bool = False,
         rms_norm: bool = False,
         multi_query_attention: bool = False,
+        num_heads_kv: Optional[int] = None,
+        head_dim: Optional[int] = None,
+        rotary_dim: Optional[int] = None,
+        rotary_interleave: bool = True,
+        rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
+        rotary_scaling_factor: float = 1,
+        rotary_base: float = 10000,
+        sliding_window: Optional[int] = None,
+        qk_norm: Optional[bool] = False,
+        pre_post_layer_norm: bool = False,
     ):
         """Initializes a Transformer encoder specification.
 
@@ -43,8 +53,28 @@ class TransformerEncoderSpec(model_spec.LayerSpec):
           ffn_glu: Use gated linear units in the FFN layers as described in
             https://arxiv.org/abs/2002.05202.
          rms_norm: Use the root mean square layer normalization.
-          multi_query_attention: Use multi-query attention.
+          multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
+          num_heads_kv: Number of attention heads for the key and value.
+          head_dim: Number of dimensions per attention head.
+          rotary_dim: Apply rotary embeddings to these first N dimensions. If 0, rotary
+            embeddings are applied to all dimensions.
+          rotary_interleave: Interleave the head dimensions when rotary embeddings are applied.
+            Otherwise the head dimensions are sliced in half.
+          rotary_scaling_type: Type of RoPE scaling.
+          rotary_scaling_factor: Factor used in the RoPE scaling.
+          rotary_base: The base period of the rotary embeddings.
+          sliding_window: Max sequence length to retain in KV Cache.
+          qk_norm: Apply layer normalization to the query and key projections.
+          pre_post_layer_norm: Add post layer norm for each pre norm layer.
        """
+
+        if multi_query_attention:
+            if num_heads_kv is not None and num_heads_kv != 1:
+                raise ValueError(
+                    "Enabling multi_query_attention implies num_heads_kv=1"
+                )
+            num_heads_kv = 1
+
         self.multi_query_attention = multi_query_attention
         self.num_heads = np.dtype("int16").type(num_heads)
         self.pre_norm = pre_norm
@@ -60,13 +90,24 @@ class TransformerEncoderSpec(model_spec.LayerSpec):
         self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
         if layernorm_embedding:
             self.layernorm_embedding = common_spec.LayerNormSpec(rms_norm=rms_norm)
+        if sliding_window is not None:
+            self.sliding_window = np.dtype("int32").type(sliding_window)
+
         self.layer = [
             TransformerEncoderLayerSpec(
                 relative_position=relative_position,
                 relative_attention_bias=relative_attention_bias,
                 ffn_glu=ffn_glu,
                 rms_norm=rms_norm,
-                num_heads_kv=1 if multi_query_attention else None,
+                num_heads_kv=num_heads_kv,
+                head_dim=head_dim,
+                rotary_dim=rotary_dim,
+                rotary_interleave=rotary_interleave,
+                rotary_scaling_type=rotary_scaling_type,
+                rotary_scaling_factor=rotary_scaling_factor,
+                rotary_base=rotary_base,
+                qk_norm=qk_norm,
+                pre_post_layer_norm=pre_post_layer_norm,
             )
             for _ in range(num_layers)
         ]
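Note: `TransformerEncoderSpec` now accepts the same attention options as the decoder spec (grouped-query heads, rotary embeddings, sliding window, pre/post layer norms), which the T5Gemma encoder above relies on; `multi_query_attention` remains as a backward-compatible alias for `num_heads_kv=1`. Sketch (other constructor arguments elided, defaults assumed):

    from ctranslate2.specs import transformer_spec

    # Equivalent under the new constructor:
    spec_a = transformer_spec.TransformerEncoderSpec(6, 8, multi_query_attention=True)
    spec_b = transformer_spec.TransformerEncoderSpec(6, 8, num_heads_kv=1)
    # Passing multi_query_attention=True with num_heads_kv != 1 raises ValueError.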
@@ -109,7 +150,8 @@ class TransformerDecoderSpec(model_spec.LayerSpec):
         quant_type: Optional[common_spec.Quantization] = None,
         quant_group_size: Optional[int] = None,
         quant_bits: Optional[int] = None,
-        qk_norm: Optional[bool] = False,
+        qk_norm: bool = False,
+        external_pre_post_encoder_layers: Optional[bool] = False,
     ):
         """Initializes a Transformer decoder specification.
 
@@ -156,6 +198,8 @@ class TransformerDecoderSpec(model_spec.LayerSpec):
           quant_type: quantization type used (like awq... for lower bit quantization)
          quant_group_size: group size of the lower bit quantization
          quant_bits: number of bits of the quantization (ex: 4bit)
+          external_pre_post_encoder_layers: whether the encoder attention pre and post
+            processing is done outside the attention.
        """
 
         self._config = dict()
@@ -172,12 +216,6 @@ class TransformerDecoderSpec(model_spec.LayerSpec):
             )
             num_heads_kv = 1
 
-        if with_encoder_attention and num_heads_kv not in (None, 1, num_heads):
-            raise ValueError(
-                "num_heads_kv=%d is not supported in the cross-attention layers"
-                % num_heads_kv
-            )
-
         self.num_heads = np.dtype("int16").type(num_heads)
         self.pre_norm = pre_norm
         self.activation = np.dtype("int8").type(activation)
@@ -224,6 +262,7 @@ class TransformerDecoderSpec(model_spec.LayerSpec):
                 head_dim=head_dim,
                 sliding_window=sliding_window,
                 qk_norm=qk_norm,
+                external_pre_post_encoder_layers=external_pre_post_encoder_layers,
             )
             for _ in range(num_layers)
         ]
@@ -236,7 +275,7 @@ class TransformerDecoderSpec(model_spec.LayerSpec):
         self.project_in = common_spec.LinearSpec()
         self.project_out = common_spec.LinearSpec()
 
-        if quant_type is not None:
+        if quant_type:
             self._config["quantization_type"] = quant_type
             self._config["quantization_bits"] = quant_bits
             self._config["quantization_group_size"] = quant_group_size
@@ -254,7 +293,15 @@ class TransformerEncoderLayerSpec(model_spec.LayerSpec):
         ffn_glu=False,
         rms_norm=False,
         num_heads_kv=None,
+        head_dim=None,
         sliding_window=None,
+        rotary_dim: Optional[int] = None,
+        rotary_interleave: bool = True,
+        rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
+        rotary_scaling_factor: float = 1,
+        rotary_base: float = 10000,
+        qk_norm=False,
+        pre_post_layer_norm: bool = False,
     ):
         self.self_attention = attention_spec.MultiHeadAttentionSpec(
             self_attention=True,
@@ -262,10 +309,32 @@ class TransformerEncoderLayerSpec(model_spec.LayerSpec):
             relative_attention_bias=relative_attention_bias,
             rms_norm=rms_norm,
             num_heads_kv=num_heads_kv,
+            head_dim=head_dim,
             sliding_window=sliding_window,
+            rotary_dim=rotary_dim,
+            rotary_interleave=rotary_interleave,
+            rotary_scaling_type=rotary_scaling_type,
+            rotary_scaling_factor=rotary_scaling_factor,
+            rotary_base=rotary_base,
+            qk_norm=qk_norm,
         )
         self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)
 
+        if pre_post_layer_norm:
+            self.input_layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
+            self.post_attention_layer_norm = common_spec.LayerNormSpec(
+                rms_norm=rms_norm
+            )
+            self.pre_feedforward_layer_norm = common_spec.LayerNormSpec(
+                rms_norm=rms_norm
+            )
+            self.post_feedforward_layer_norm = common_spec.LayerNormSpec(
+                rms_norm=rms_norm
+            )
+
+            delattr(self.self_attention, "layer_norm")
+            delattr(self.ffn, "layer_norm")
+
 
 class TransformerDecoderLayerSpec(model_spec.LayerSpec):
     def __init__(
@@ -289,6 +358,7 @@ class TransformerDecoderLayerSpec(model_spec.LayerSpec):
         head_dim=None,
         sliding_window=None,
         qk_norm=False,
+        external_pre_post_encoder_layers=False,
     ):
         self.self_attention = attention_spec.MultiHeadAttentionSpec(
             self_attention=True,
@@ -312,8 +382,10 @@ class TransformerDecoderLayerSpec(model_spec.LayerSpec):
             self.attention = attention_spec.MultiHeadAttentionSpec(
                 rms_norm=rms_norm,
                 num_heads_kv=num_heads_kv,
+                head_dim=head_dim,
                 sliding_window=sliding_window,
                 qk_norm=qk_norm,
+                has_norm=external_pre_post_encoder_layers is False,
             )
 
         self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)
@@ -329,10 +401,21 @@ class TransformerDecoderLayerSpec(model_spec.LayerSpec):
             delattr(self.ffn, "layer_norm")
 
         if pre_post_layer_norm:
+            # Self-attention layer norms
             self.input_layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
             self.post_attention_layer_norm = common_spec.LayerNormSpec(
                 rms_norm=rms_norm
             )
+
+            if with_encoder_attention and external_pre_post_encoder_layers:
+                self.external_post_encoder_attention_layer_norm = (
+                    common_spec.LayerNormSpec(rms_norm=rms_norm)
+                )
+                self.external_pre_encoder_attention_layer_norm = (
+                    common_spec.LayerNormSpec(rms_norm=rms_norm)
+                )
+
+            # Feed-forward layer norms
             self.pre_feedforward_layer_norm = common_spec.LayerNormSpec(
                 rms_norm=rms_norm
             )
@@ -562,7 +645,7 @@ class TransformerDecoderModelSpec(model_spec.LanguageModelSpec):
         quant_type: Optional[common_spec.Quantization] = None,
         quant_group_size: Optional[int] = None,
         quant_bits: Optional[int] = None,
-        qk_norm: Optional[bool] = False,
+        qk_norm: bool = False,
     ):
         """Creates a Transformer decoder model specification.
 
ctranslate2/version.py CHANGED
@@ -1,3 +1,3 @@
 """Version information."""
 
-__version__ = "4.6.2"
+__version__ = "4.7.0"
ctranslate2-4.7.0.dist-info/METADATA CHANGED
@@ -1,9 +1,10 @@
 Metadata-Version: 2.4
 Name: ctranslate2
-Version: 4.6.2
+Version: 4.7.0
 Summary: Fast inference engine for Transformer models
 Home-page: https://opennmt.net
 Author: OpenNMT
+License: MIT
 Project-URL: Documentation, https://opennmt.net/CTranslate2
 Project-URL: Forum, https://forum.opennmt.net
 Project-URL: Gitter, https://gitter.im/OpenNMT/CTranslate2
@@ -13,7 +14,6 @@ Classifier: Development Status :: 5 - Production/Stable
 Classifier: Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.4
 Classifier: Intended Audience :: Developers
 Classifier: Intended Audience :: Science/Research
-Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3 :: Only
 Classifier: Programming Language :: Python :: 3.9
@@ -34,6 +34,7 @@ Dynamic: description
 Dynamic: description-content-type
 Dynamic: home-page
 Dynamic: keywords
+Dynamic: license
 Dynamic: project-url
 Dynamic: requires-dist
 Dynamic: requires-python
@@ -49,7 +50,7 @@ The project implements a custom runtime that applies many performance optimizati
 
 The following model types are currently supported:
 
-* Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper
+* Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper, T5Gemma
 * Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, Mistral, Gemma, CodeGen, GPTBigCode, Falcon, Qwen2
 * Encoder-only models: BERT, DistilBERT, XLM-RoBERTa
 
@@ -99,6 +100,8 @@ generator.generate_batch(start_tokens)
 
 See the [documentation](https://opennmt.net/CTranslate2) for more information and examples.
 
+If you have an AMD ROCm GPU, we provide specific Python wheels on the [releases page](https://github.com/OpenNMT/CTranslate2/releases/).
+
 ## Benchmarks
 
 We translate the En->De test set *newstest2014* with multiple models:
@@ -160,6 +163,16 @@ Executed with 4 threads on a [*c5.2xlarge*](https://aws.amazon.com/ec2/instance-
 
 Executed with CUDA 11 on a [*g5.xlarge*](https://aws.amazon.com/ec2/instance-types/g5/) Amazon EC2 instance equipped with a NVIDIA A10G GPU (driver version: 510.47.03).
 
+## Contributing
+
+CTranslate2 is a community-driven project. We welcome contributions of all kinds:
+* **New Model Support:** Help us implement more Transformer architectures.
+* **Performance:** Propose optimizations for CPU or GPU kernels.
+* **Bug Reports:** Open an issue if you find something not working as expected.
+* **Documentation:** Improve our guides or add new examples.
+
+Check out our [Contributing Guide](CONTRIBUTING.md) to learn how to set up your development environment.
+
 ## Additional resources
 
 * [Documentation](https://opennmt.net/CTranslate2)
ctranslate2-4.7.0.dist-info/RECORD CHANGED
@@ -1,33 +1,33 @@
-ctranslate2/__init__.py,sha256=CGqShDaFxQ-u-aCtVq99T4HKuBdMB8b49l2KSxnQb8M,1735
-ctranslate2/_ext.cp313-win_amd64.pyd,sha256=T7xwyuyjSstcSjc_SHDfLjDQt8z5s_qpSIXkdrR5UAU,715776
-ctranslate2/ctranslate2.dll,sha256=9zIz4dY3yV1kTTKaipyQwjcGDwzZ3OzKiOkNpXdcQ1U,58389504
-ctranslate2/cudnn64_9.dll,sha256=wHzEfy-kpWZZPHr0qn5X7fCamFoP3dFMuNb0VuJSrwU,438840
-ctranslate2/extensions.py,sha256=axO2FI8ddiFmlko2AzQ6VcdtF-3hDA7VmPGnTIkrPkI,21782
+ctranslate2/__init__.py,sha256=LZy5gF-9vTRdcERSnTSP_RrCPDks9UDU7uzxw1-d0aU,1881
+ctranslate2/_ext.cp313-win_amd64.pyd,sha256=bPfgej3CXraCad6brhOxuQFimiFviEjVksLxGSC7Oas,715776
+ctranslate2/ctranslate2.dll,sha256=umrDDC_rg_IbXg2MOi-8jNasZZdjb9b_Io8CWS5_M_U,59823104
+ctranslate2/cudnn64_9.dll,sha256=ntvN_3OwrwcOsWCyzmbln-ygSqAXNR2O7cxejhSZZ9I,266288
+ctranslate2/extensions.py,sha256=kDNt0H9KvfNCc3PrRGzfkj9Fkvna84i2O5Y-rav6UkU,21940
 ctranslate2/libiomp5md.dll,sha256=mCIzNmsK_NoeD1WgsTQJfjW3eWE_VN22nmhebNBrdV8,1614192
 ctranslate2/logging.py,sha256=P9evHdxuMx_iHvwJjEASEq-j5062H64Pl5-fJjxEuHk,1221
-ctranslate2/version.py,sha256=f2Hk9NHTYgXftujV8JVkeOzenykZ9QzbsZ-nIt9U1uc,53
+ctranslate2/version.py,sha256=cWqiIzEeUIcvUfq82ZopTbW1pRWqZkZOW7b6pks8tz8,53
 ctranslate2/converters/__init__.py,sha256=ufYjcXf2sK4fiXAUU6tIJyWmNuLjKFf_KH3GWLXe4ls,507
 ctranslate2/converters/converter.py,sha256=Qkb8NGLLmgqMT6HZkFq61zwbxyq3NlWcaxLZ6Ap-YOQ,3601
-ctranslate2/converters/eole_ct2.py,sha256=RUcDJH_2AUt0jDs5oAqccE6tQPbO9LQ6JmVriC1DTy8,12564
-ctranslate2/converters/fairseq.py,sha256=uQpd-ftYSO4c6WdEwCUyuZWhzWX1UTG7dGOC6EtcDVE,12765
+ctranslate2/converters/eole_ct2.py,sha256=sRXvPark9V-4umXpMxPuJVQekMLstyNZ7xNjyAFthvg,12623
+ctranslate2/converters/fairseq.py,sha256=2vlBk4AVCHwXxKkwPHVmcjyfo1dAV0_DJS1i6q-44NE,12822
 ctranslate2/converters/marian.py,sha256=1_7P3EbIDPOdyJbtb_Lp-LCBPBb9A8E9OhzoyFwTb64,11274
 ctranslate2/converters/openai_gpt2.py,sha256=1rXKM2ZURZHWRv4XZ135fPkVWpM4rTG-q7VR7OD6d-A,3304
-ctranslate2/converters/opennmt_py.py,sha256=Vva60az6tGqlQXs0UgC09r_fCD3u2u6wUJB-8V4OUFQ,13183
+ctranslate2/converters/opennmt_py.py,sha256=zex4TbHiiJMy0tkqQg39oNjxmSZKf8dnRLH3iQ1H4z0,13227
 ctranslate2/converters/opennmt_tf.py,sha256=uBRp2wz5xriSQcA_c0S0ekY7ws6RpRX_0EKeMRdM7-s,16222
 ctranslate2/converters/opus_mt.py,sha256=5KbPaTiBhhorPzMpTugIfIJ8SgcqHfJUbJrWKBN-Djs,1254
-ctranslate2/converters/transformers.py,sha256=zwqUFFFwLpam6z5lpBz2rgfYj065CbsdT9S_xVqPjCk,126110
+ctranslate2/converters/transformers.py,sha256=41E9rMH6Qm77OIfswMVn7esp_NPZn3ZimiLTA6Be_50,141519
 ctranslate2/converters/utils.py,sha256=w7NG39lx-9dOdL57OqKVTdC__opkuP8RACg1TLlUJwM,3817
 ctranslate2/models/__init__.py,sha256=53p98uemtuvVPz8xK7_LbOhBiUJJu-c-NdmOHJgdXus,497
 ctranslate2/specs/__init__.py,sha256=9GabtSyczznYqiqUS6XvULi8pQ3_3RNRogXobGP0G80,653
-ctranslate2/specs/attention_spec.py,sha256=0JhCBrbb20G07UFnUAYIUtfcqn4VtflJHYWGIunwKDw,3442
+ctranslate2/specs/attention_spec.py,sha256=FnaSiQREWQw_cURgsCb9_aIpGOCxyVGTCpIOdd-08v8,3492
 ctranslate2/specs/common_spec.py,sha256=freTDhQMy5PYofBrij4_FDgrKokMYApWSPIpASZIlJc,1608
 ctranslate2/specs/model_spec.py,sha256=atCAYzDEIzyJ1TCayFGZVutHqSWa1ww-vbZ0OiIJqh8,25736
-ctranslate2/specs/transformer_spec.py,sha256=43jOIvCSbAvqZJ1IyvRdGUa4f-zhdKhQBOXvp0T8YLE,30360
+ctranslate2/specs/transformer_spec.py,sha256=s6mY6MMHneraXrWua_531Xjb5MVEJZCUTemUERO11GI,34305
 ctranslate2/specs/wav2vec2_spec.py,sha256=NITsuOuf2F5bU1-aXit8-WEtWV9fH2Eq7A7857UyYho,2106
 ctranslate2/specs/wav2vec2bert_spec.py,sha256=UgtsJWC9mMgJ7bn4T_xg1uXK0rqA4-9tT2KMGVgPKnw,3529
 ctranslate2/specs/whisper_spec.py,sha256=_vm1sc5yOowOJ4iyvcxMXrgt-UcLJrZT8OtPscUXcQQ,2447
-ctranslate2-4.6.2.dist-info/METADATA,sha256=r5HnmZE0BMI60j3N0GmDdM6l7Q7KW3w5nLLOX_AKCRY,10354
-ctranslate2-4.6.2.dist-info/WHEEL,sha256=qV0EIPljj1XC_vuSatRWjn02nZIz3N1t8jsZz7HBr2U,101
-ctranslate2-4.6.2.dist-info/entry_points.txt,sha256=ZHkojut_TmVRHl0bJIGm2b9wqr98GAJqxN9rlJtQshs,466
-ctranslate2-4.6.2.dist-info/top_level.txt,sha256=1hUaWzcFIuSo2BAIUHFA3Osgsu6S1giq0y6Rosv8HOQ,12
-ctranslate2-4.6.2.dist-info/RECORD,,
+ctranslate2-4.7.0.dist-info/METADATA,sha256=Vm9SM5sybdzcJHc6HBek2PgP6nbuDiEHWQFZuJjWDvc,10979
+ctranslate2-4.7.0.dist-info/WHEEL,sha256=-WvvtQtdhM1F5HMi-4hSXLQ_1Tg6qJRWO1HnLNr4mCU,102
+ctranslate2-4.7.0.dist-info/entry_points.txt,sha256=ZHkojut_TmVRHl0bJIGm2b9wqr98GAJqxN9rlJtQshs,466
+ctranslate2-4.7.0.dist-info/top_level.txt,sha256=1hUaWzcFIuSo2BAIUHFA3Osgsu6S1giq0y6Rosv8HOQ,12
+ctranslate2-4.7.0.dist-info/RECORD,,
ctranslate2-4.7.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (80.9.0)
+Generator: setuptools (80.10.2)
 Root-Is-Purelib: false
 Tag: cp313-cp313-win_amd64