ctranslate2-4.6.1-cp314-cp314-win_amd64.whl → ctranslate2-4.6.3-cp314-cp314-win_amd64.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -89,7 +89,7 @@ class TransformersConverter(Converter):
           copy_files: List of filenames to copy from the Hugging Face model to the
             converted model directory.
           load_as_float16: Load the model weights as float16. More precisely, the model
-            will be loaded with ``from_pretrained(..., torch_dtype=torch.float16)``.
+            will be loaded with ``from_pretrained(..., dtype=torch.float16)``.
           revision: Revision of the model to download from the Hugging Face Hub.
           low_cpu_mem_usage: Enable the flag ``low_cpu_mem_usage`` when loading the model
             with ``from_pretrained``.
@@ -123,10 +123,11 @@ class TransformersConverter(Converter):
             tokenizer_class = transformers.AutoTokenizer

             kwargs = {
-                "torch_dtype": (
+                "dtype": (
                     torch.float16
                     if self._load_as_float16
-                    else getattr(config, "torch_dtype", None)
+                    else getattr(config, "dtype", None)
+                    or getattr(config, "torch_dtype", None)
                 )
             }
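
Note: this rename tracks the Transformers switch from the ``torch_dtype`` argument to ``dtype``, with the old config field kept as a fallback. A minimal sketch of the same resolution order, assuming a recent transformers release that accepts ``dtype`` in ``from_pretrained`` (the model id is illustrative):

    import torch
    import transformers

    config = transformers.AutoConfig.from_pretrained("gpt2")
    load_as_float16 = False
    # Prefer the new config.dtype, fall back to the legacy config.torch_dtype.
    dtype = (
        torch.float16
        if load_as_float16
        else getattr(config, "dtype", None) or getattr(config, "torch_dtype", None)
    )
    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2", dtype=dtype)
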

@@ -235,7 +236,7 @@ class ModelLoader(abc.ABC):

         if isinstance(module, transformers.Conv1D):
             spec.weight = spec.weight.transpose(0, 1)
-        if module.bias is not None:
+        if hasattr(module, "bias") and module.bias is not None:
             spec.bias = module.bias

     def set_embeddings(self, spec, module):
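
Note: the extra ``hasattr`` check guards against modules that do not define a ``bias`` attribute at all, where reading ``module.bias`` would raise ``AttributeError`` instead of returning ``None``. A tiny self-contained sketch of the difference (the ``_NoBias`` class is hypothetical):

    class _NoBias:
        weight = None  # stand-in for a linear-like module without a bias attribute

    module = _NoBias()
    # Old check would raise AttributeError; the new one short-circuits to False.
    has_bias = hasattr(module, "bias") and module.bias is not None
    print(has_bias)  # False
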
@@ -1819,6 +1820,192 @@ class LlamaLoader(ModelLoader):
             gc.collect()


+@register_loader("Gemma3TextConfig")
+@register_loader("Gemma3Config")
+class Gemma3Loader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "Gemma3ForCausalLM"
+
+    def get_model_spec(self, model):
+        num_layers = model.config.num_hidden_layers
+        num_heads = model.config.num_attention_heads
+        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
+        if num_heads_kv == num_heads:
+            num_heads_kv = None
+
+        head_dim = model.config.head_dim
+
+        activation_config = getattr(
+            model.config, "hidden_activation", "gelu_pytorch_tanh"
+        )
+
+        # Get RoPE parameters
+        rope_theta = getattr(model.config, "rope_theta", 1_000_000)  # Global: 1M
+        rope_local_base_freq = getattr(
+            model.config, "rope_local_base_freq", 10_000
+        )  # Local: 10k
+
+        # Get sliding window configuration
+        sliding_window = getattr(model.config, "sliding_window", 1024)
+        layer_types = getattr(model.config, "layer_types", None)
+
+        quantization_config = getattr(model.config, "quantization_config", None)
+        if quantization_config:
+            if quantization_config.quant_method == "awq":
+                quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
+                if quant_type is None:
+                    raise NotImplementedError(
+                        "Quantization type '%s' is not yet implemented."
+                        % quantization_config.quant_method
+                    )
+        else:
+            quant_type = common_spec.Quantization.CT2
+
+        # Create base spec using from_config
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers,
+            num_heads,
+            activation=(
+                common_spec.Activation.GELU
+                if activation_config == "gelu"
+                else common_spec.Activation.GELUTanh
+            ),
+            pre_norm=True,
+            ffn_glu=True,
+            rms_norm=True,
+            rotary_dim=head_dim,
+            rotary_interleave=False,
+            rotary_base=rope_local_base_freq,  # Default to local base freq
+            num_heads_kv=num_heads_kv,
+            head_dim=head_dim,
+            sliding_window=sliding_window,  # Default to local sliding window
+            pre_post_layer_norm=True,
+            qk_norm=True,
+        )
+
+        # Store layer_types for use in set_decoder
+        self._layer_types = layer_types
+
+        # Override per-layer settings for global vs local attention
+        for i, layer_type in enumerate(layer_types):
+            layer = spec.decoder.layer[i]
+            if layer_type == "full_attention":
+                layer.self_attention.rotary_base = np.dtype("float32").type(rope_theta)
+                layer.self_attention.sliding_window = np.dtype("int32").type(0)
+            elif layer_type == "sliding_attention":
+                layer.self_attention.rotary_base = np.dtype("float32").type(
+                    rope_local_base_freq
+                )
+                layer.self_attention.sliding_window = np.dtype("int32").type(
+                    sliding_window
+                )
+
+        self.set_decoder(spec.decoder, model.model, quant_type)
+        self.set_linear(spec.decoder.projection, model.lm_head)
+        return spec
+
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+        if model.config.vocab_size < len(tokens):
+            tokens = tokens[: model.config.vocab_size]
+
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = tokenizer.bos_token
+        config.unk_token = tokenizer.unk_token
+
+        if (
+            hasattr(tokenizer, "chat_template")
+            and isinstance(tokenizer.chat_template, str)
+            and tokenizer.chat_template.strip()
+        ):
+            config.eos_token = "<end_of_turn>"
+        else:
+            config.eos_token = tokenizer.eos_token
+
+    def set_layer_norm(self, spec, layer_norm):
+        spec.gamma = layer_norm.weight + 1.0
+
+    def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
+        spec.scale_embeddings = True
+        spec.start_from_zero_embedding = False
+        self.set_embeddings(spec.embeddings, module.embed_tokens)  # Input
+        self.set_layer_norm(spec.layer_norm, module.norm)  # Output
+
+        for layer_spec, layer in zip(spec.layer, module.layers):
+            self.set_layer_norm(layer_spec.input_layer_norm, layer.input_layernorm)
+
+            self.set_layer_norm(
+                layer_spec.post_attention_layer_norm, layer.post_attention_layernorm
+            )
+
+            self.set_layer_norm(
+                layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
+            )
+
+            self.set_layer_norm(
+                layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
+            )
+
+            # Set QK-norm weights (Gemma 3 uses this instead of soft-capping)
+            self.set_layer_norm(
+                layer_spec.self_attention.q_norm, layer.self_attn.q_norm
+            )
+            self.set_layer_norm(
+                layer_spec.self_attention.k_norm, layer.self_attn.k_norm
+            )
+
+            # Set attention projections
+            split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(
+                split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+            )
+
+            if quant_type == common_spec.Quantization.CT2:
+                utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+            else:
+                cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
+                utils.fuse_linear_prequant(
+                    layer_spec.self_attention.linear[0], split_layers, cc_dim
+                )
+
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+                quant_type=quant_type,
+            )
+
+            # Set FFN weights
+            self.set_linear(
+                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )
+
+            delattr(layer, "self_attn")
+            delattr(layer, "mlp")
+            gc.collect()
+
+
 @register_loader("MistralConfig")
 class MistralLoader(ModelLoader):
     @property
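
Note: with this loader registered, a Gemma 3 text checkpoint converts like the other decoder-only models. A hedged sketch of a conversion, using the documented converter API (model id and output directory are illustrative):

    import ctranslate2

    converter = ctranslate2.converters.TransformersConverter(
        "google/gemma-3-1b-it", load_as_float16=True
    )
    converter.convert("gemma3_ct2", quantization="int8_float16", force=True)
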
@@ -1996,6 +2183,28 @@ class Qwen2Loader(ModelLoader):
             rotary_scaling_type = None
             rotary_scaling_factor = 1

+        # Check for AWQ quantization config
+        quantization_config = getattr(model.config, "quantization_config", None)
+        if quantization_config:
+            quant_type = None
+            if quantization_config.quant_method == "awq":
+                quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
+            if quant_type is None:
+                raise NotImplementedError(
+                    "Quantization type '%s' is not yet implemented. "
+                    "The following Quantization types are currently supported: %s"
+                    % (
+                        quantization_config.quant_method,
+                        ", ".join(_SUPPORTED_QUANTIZATION.keys()),
+                    )
+                )
+            quant_group_size = quantization_config.group_size
+            quant_bits = quantization_config.bits
+        else:
+            quant_type = common_spec.Quantization.CT2
+            quant_group_size = None
+            quant_bits = None
+
         spec = transformer_spec.TransformerDecoderModelSpec.from_config(
             num_layers,
             num_heads,
@@ -2009,9 +2218,12 @@ class Qwen2Loader(ModelLoader):
             rotary_scaling_factor=rotary_scaling_factor,
             rotary_base=getattr(model.config, "rope_theta", 10000),
             num_heads_kv=num_heads_kv,
+            quant_type=quant_type,
+            quant_group_size=quant_group_size,
+            quant_bits=quant_bits,
         )

-        self.set_decoder(spec.decoder, model.model)
+        self.set_decoder(spec.decoder, model.model, quant_type)
         self.set_linear(spec.decoder.projection, model.lm_head)
         return spec

@@ -2041,7 +2253,7 @@ class Qwen2Loader(ModelLoader):
     def set_layer_norm(self, spec, layer_norm):
         spec.gamma = layer_norm.weight

-    def set_decoder(self, spec, module):
+    def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
         spec.scale_embeddings = False
         self.set_embeddings(spec.embeddings, module.embed_tokens)
         self.set_layer_norm(spec.layer_norm, module.norm)
@@ -2055,72 +2267,255 @@ class Qwen2Loader(ModelLoader):
             )

             split_layers = [common_spec.LinearSpec() for _ in range(3)]
-            self.set_linear(split_layers[0], layer.self_attn.q_proj)
-            self.set_linear(split_layers[1], layer.self_attn.k_proj)
-            self.set_linear(split_layers[2], layer.self_attn.v_proj)
+            self.set_linear(
+                split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+            )
+
+            if quant_type == common_spec.Quantization.CT2:
+                utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+            else:
+                cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
+                utils.fuse_linear_prequant(
+                    layer_spec.self_attention.linear[0], split_layers, cc_dim
+                )

-            utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
             self.set_linear(
                 layer_spec.self_attention.linear[1],
                 layer.self_attn.o_proj,
+                quant_type=quant_type,
             )

-            self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
-            self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
-            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
+            self.set_linear(
+                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )

             delattr(layer, "self_attn")
             delattr(layer, "mlp")
             gc.collect()


-@register_loader("MixFormerSequentialConfig")
-class MixFormerSequentialLoader(ModelLoader):
+@register_loader("Qwen3Config")
+class Qwen3Loader(ModelLoader):
     @property
     def architecture_name(self):
-        return "AutoModelForCausalLM"
+        return "Qwen3ForCausalLM"

     def get_model_spec(self, model):
-        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
-            num_layers=model.config.n_layer,
-            num_heads=model.config.n_head,
-            pre_norm=True,
-            activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function],
-            rotary_dim=model.config.rotary_dim,
-            rotary_interleave=False,
-            parallel_residual=True,
-            shared_layer_norm=True,
+        num_layers = model.config.num_hidden_layers
+        num_heads = model.config.num_attention_heads
+        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
+        head_dim = getattr(
+            model.config, "head_dim", model.config.hidden_size // num_heads
         )

-        self.set_decoder(spec.decoder, model.layers)
-        self.set_linear(spec.decoder.projection, model.layers[-1].linear)
-        return spec
-
-    def get_vocabulary(self, model, tokenizer):
-        tokens = super().get_vocabulary(model, tokenizer)
-
-        extra_ids = model.config.vocab_size - len(tokens)
-        for i in range(extra_ids):
-            tokens.append("<extra_id_%d>" % i)
-
-        return tokens
-
-    def set_vocabulary(self, spec, tokens):
-        spec.register_vocabulary(tokens)
-
-    def set_config(self, config, model, tokenizer):
-        config.bos_token = tokenizer.bos_token
-        config.eos_token = tokenizer.eos_token
-        config.unk_token = tokenizer.unk_token
+        if num_heads_kv == num_heads:
+            num_heads_kv = None

-    def set_decoder(self, spec, module):
-        spec.scale_embeddings = False
-        self.set_embeddings(spec.embeddings, module[0].wte)
-        self.set_layer_norm(spec.layer_norm, module[-1].ln)
+        rope_scaling = getattr(model.config, "rope_scaling", None)
+        if rope_scaling:
+            rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
+            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
+            rotary_scaling_factor = rope_scaling["factor"]
+            if rotary_scaling_type is None:
+                raise NotImplementedError(
+                    "RoPE scaling type '%s' is not yet implemented. "
+                    "The following RoPE scaling types are currently supported: %s"
+                    % (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
+                )
+        else:
+            rotary_scaling_type = None
+            rotary_scaling_factor = 1

-        for layer_spec, layer in zip(spec.layer, module[1:-1]):
-            self.set_layer_norm(layer_spec.shared_layer_norm, layer.ln)
-            self.set_linear(layer_spec.self_attention.linear[0], layer.mixer.Wqkv)
+        # Check for AWQ quantization config
+        quantization_config = getattr(model.config, "quantization_config", None)
+        if quantization_config:
+            quant_type = None
+            if quantization_config.quant_method == "awq":
+                quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
+            if quant_type is None:
+                raise NotImplementedError(
+                    "Quantization type '%s' is not yet implemented. "
+                    "The following Quantization types are currently supported: %s"
+                    % (
+                        quantization_config.quant_method,
+                        ", ".join(_SUPPORTED_QUANTIZATION.keys()),
+                    )
+                )
+            quant_group_size = quantization_config.group_size
+            quant_bits = quantization_config.bits
+        else:
+            quant_type = common_spec.Quantization.CT2
+            quant_group_size = None
+            quant_bits = None
+
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers,
+            num_heads,
+            activation=common_spec.Activation.SWISH,
+            pre_norm=True,
+            ffn_glu=True,
+            rms_norm=True,
+            rotary_dim=model.config.head_dim,
+            rotary_interleave=False,
+            rotary_scaling_type=rotary_scaling_type,
+            rotary_scaling_factor=rotary_scaling_factor,
+            rotary_base=getattr(model.config, "rope_theta", 10000),
+            num_heads_kv=num_heads_kv,
+            head_dim=head_dim,
+            qk_norm=True,
+            quant_type=quant_type,
+            quant_group_size=quant_group_size,
+            quant_bits=quant_bits,
+        )
+
+        self.set_decoder(spec.decoder, model.model, quant_type)
+        self.set_linear(spec.decoder.projection, model.lm_head)
+        return spec
+
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = (
+            tokenizer.bos_token
+            if tokenizer.bos_token is not None
+            else tokenizer.pad_token
+        )
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = (
+            tokenizer.unk_token if tokenizer.unk_token is not None else ""
+        )
+        config.layer_norm_epsilon = model.config.rms_norm_eps
+
+    def set_layer_norm(self, spec, layer_norm):
+        spec.gamma = layer_norm.weight
+
+    def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
+        spec.scale_embeddings = False
+        self.set_embeddings(spec.embeddings, module.embed_tokens)
+        self.set_layer_norm(spec.layer_norm, module.norm)
+
+        for layer_idx, (layer_spec, layer) in enumerate(zip(spec.layer, module.layers)):
+            self.set_layer_norm(
+                layer_spec.self_attention.layer_norm, layer.input_layernorm
+            )
+            self.set_layer_norm(
+                layer_spec.ffn.layer_norm, layer.post_attention_layernorm
+            )
+
+            self.set_layer_norm(
+                layer_spec.self_attention.q_norm, layer.self_attn.q_norm
+            )
+            self.set_layer_norm(
+                layer_spec.self_attention.k_norm, layer.self_attn.k_norm
+            )
+
+            split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(
+                split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+            )
+
+            if quant_type == common_spec.Quantization.CT2:
+                utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+            else:
+                cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
+                utils.fuse_linear_prequant(
+                    layer_spec.self_attention.linear[0], split_layers, cc_dim
+                )
+
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+                quant_type=quant_type,
+            )
+
+            self.set_linear(
+                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )
+
+            delattr(layer, "self_attn")
+            delattr(layer, "mlp")
+            gc.collect()
+
+
+@register_loader("MixFormerSequentialConfig")
+class MixFormerSequentialLoader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "AutoModelForCausalLM"
+
+    def get_model_spec(self, model):
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers=model.config.n_layer,
+            num_heads=model.config.n_head,
+            pre_norm=True,
+            activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function],
+            rotary_dim=model.config.rotary_dim,
+            rotary_interleave=False,
+            parallel_residual=True,
+            shared_layer_norm=True,
+        )
+
+        self.set_decoder(spec.decoder, model.layers)
+        self.set_linear(spec.decoder.projection, model.layers[-1].linear)
+        return spec
+
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = tokenizer.bos_token
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = tokenizer.unk_token
+
+    def set_decoder(self, spec, module):
+        spec.scale_embeddings = False
+        self.set_embeddings(spec.embeddings, module[0].wte)
+        self.set_layer_norm(spec.layer_norm, module[-1].ln)
+
+        for layer_spec, layer in zip(spec.layer, module[1:-1]):
+            self.set_layer_norm(layer_spec.shared_layer_norm, layer.ln)
+            self.set_linear(layer_spec.self_attention.linear[0], layer.mixer.Wqkv)
             self.set_linear(layer_spec.self_attention.linear[1], layer.mixer.out_proj)
             self.set_linear(layer_spec.ffn.linear_0, layer.mlp.fc1)
             self.set_linear(layer_spec.ffn.linear_1, layer.mlp.fc2)
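
Note: the new Qwen3 loader and the AWQ branches above share the same entry point; when the checkpoint carries a ``quantization_config``, the converter keeps the pre-quantized weights, so no extra ``quantization`` argument is requested. A hedged sketch (model ids and output directories are illustrative):

    import ctranslate2

    # Regular float checkpoint, quantized at conversion time.
    ctranslate2.converters.TransformersConverter("Qwen/Qwen3-0.6B").convert(
        "qwen3_ct2", quantization="int8", force=True
    )

    # AWQ checkpoint: the converter picks up quantization_config from the model config.
    ctranslate2.converters.TransformersConverter("Qwen/Qwen3-8B-AWQ").convert(
        "qwen3_awq_ct2", force=True
    )
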
@@ -2211,6 +2606,28 @@ class Phi3Loader(ModelLoader):
             rotary_scaling_type = None
             rotary_scaling_factor = 1

+        # Check for AWQ quantization config
+        quantization_config = getattr(model.config, "quantization_config", None)
+        if quantization_config:
+            quant_type = None
+            if quantization_config.quant_method == "awq":
+                quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
+            if quant_type is None:
+                raise NotImplementedError(
+                    "Quantization type '%s' is not yet implemented. "
+                    "The following Quantization types are currently supported: %s"
+                    % (
+                        quantization_config.quant_method,
+                        ", ".join(_SUPPORTED_QUANTIZATION.keys()),
+                    )
+                )
+            quant_group_size = quantization_config.group_size
+            quant_bits = quantization_config.bits
+        else:
+            quant_type = common_spec.Quantization.CT2
+            quant_group_size = None
+            quant_bits = None
+
         spec = transformer_spec.TransformerDecoderModelSpec.from_config(
             num_layers,
             num_heads,
@@ -2226,9 +2643,12 @@ class Phi3Loader(ModelLoader):
             original_max_position_embeddings=original_max_position_embeddings,
             max_position_embeddings=max_position_embeddings,
             num_heads_kv=num_heads_kv,
+            quant_type=quant_type,
+            quant_group_size=quant_group_size,
+            quant_bits=quant_bits,
         )

-        self.set_decoder(spec.decoder, model.model)
+        self.set_decoder(spec.decoder, model.model, quant_type)
         self.set_linear(spec.decoder.projection, model.lm_head)
         return spec

@@ -2262,7 +2682,7 @@ class Phi3Loader(ModelLoader):
             rotary_scaling_short_factor, dtype=torch.float32
         )

-    def set_decoder(self, spec, module):
+    def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
         spec.scale_embeddings = False
         self.set_embeddings(spec.embeddings, module.embed_tokens)
         self.set_layer_norm(spec.layer_norm, module.norm)
@@ -2276,9 +2696,15 @@ class Phi3Loader(ModelLoader):
             )

             self.set_linear(
-                layer_spec.self_attention.linear[0], layer.self_attn.qkv_proj
+                layer_spec.self_attention.linear[0],
+                layer.self_attn.qkv_proj,
+                quant_type=quant_type,
+            )
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+                quant_type=quant_type,
             )
-            self.set_linear(layer_spec.self_attention.linear[1], layer.self_attn.o_proj)
             if (
                 layer.self_attn.rotary_emb.long_factor is not None
                 and layer.self_attn.rotary_emb.short_factor is not None
@@ -2289,10 +2715,30 @@ class Phi3Loader(ModelLoader):
                     layer.self_attn.rotary_emb.short_factor,
                 )

-            gate_proj, up_proj = layer.mlp.gate_up_proj.weight.chunk(2, dim=0)
-            layer_spec.ffn.linear_0.weight = gate_proj
-            layer_spec.ffn.linear_0_noact.weight = up_proj
-            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
+            # Handle gate_up_proj differently for AWQ vs regular models
+            if quant_type == common_spec.Quantization.CT2:
+                gate_proj, up_proj = layer.mlp.gate_up_proj.weight.chunk(2, dim=0)
+                layer_spec.ffn.linear_0.weight = gate_proj
+                layer_spec.ffn.linear_0_noact.weight = up_proj
+            else:
+                # AWQ: chunk qweight, scales, and qzeros
+                gate_qweight, up_qweight = layer.mlp.gate_up_proj.qweight.chunk(
+                    2, dim=1
+                )
+                gate_scales, up_scales = layer.mlp.gate_up_proj.scales.chunk(2, dim=1)
+                gate_qzeros, up_qzeros = layer.mlp.gate_up_proj.qzeros.chunk(2, dim=1)
+
+                layer_spec.ffn.linear_0.weight = gate_qweight
+                layer_spec.ffn.linear_0.weight_scale = gate_scales
+                layer_spec.ffn.linear_0.weight_zero = gate_qzeros
+
+                layer_spec.ffn.linear_0_noact.weight = up_qweight
+                layer_spec.ffn.linear_0_noact.weight_scale = up_scales
+                layer_spec.ffn.linear_0_noact.weight_zero = up_qzeros
+
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )

             delattr(layer, "self_attn")
             delattr(layer, "mlp")
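
Note: for AWQ checkpoints the fused ``gate_up_proj`` has to be split along the packed output dimension of ``qweight``/``scales``/``qzeros`` (``dim=1``), whereas the plain float weight stacks gate and up along ``dim=0``. A standalone sketch with illustrative shapes (4-bit AWQ packs 8 values per int32):

    import torch

    in_features, out_features = 128, 256

    # Float layout: [2 * out_features, in_features], gate and up stacked on dim 0.
    weight = torch.zeros(2 * out_features, in_features)
    gate_w, up_w = weight.chunk(2, dim=0)    # each [out_features, in_features]

    # AWQ layout: input features on dim 0, packed output features on dim 1.
    qweight = torch.zeros(in_features, 2 * out_features // 8, dtype=torch.int32)
    gate_q, up_q = qweight.chunk(2, dim=1)   # each [in_features, out_features // 8]
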
@@ -3022,3 +3468,266 @@ _WHISPER_ALIGNMENT_HEADS = {
         (25, 6),
     ],
 }
+
+
+# Paper: https://arxiv.org/pdf/2504.06225
+@register_loader("T5GemmaConfig")
+class T5GemmaLoader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "T5GemmaForConditionalGeneration"
+
+    def set_layer_norm(self, spec, layer_norm):
+        spec.gamma = layer_norm.weight.data + 1.0
+
+    def get_model_spec(self, model):
+        encoder_config = model.config.encoder
+        decoder_config = model.config.decoder
+        sliding_window = getattr(model.config, "sliding_window", 4096)
+
+        encoder_num_heads = encoder_config.num_attention_heads
+        encoder_num_heads_kv = getattr(
+            encoder_config, "num_key_value_heads", encoder_num_heads
+        )
+        if encoder_num_heads_kv == encoder_num_heads:
+            encoder_num_heads_kv = None
+
+        encoder = transformer_spec.TransformerEncoderSpec(
+            encoder_config.num_hidden_layers,
+            encoder_config.num_attention_heads,
+            pre_norm=True,
+            activation=_SUPPORTED_ACTIVATIONS[encoder_config.hidden_activation],
+            ffn_glu=True,
+            rms_norm=True,
+            rotary_dim=encoder_config.head_dim,
+            rotary_interleave=False,
+            rotary_base=getattr(encoder_config, "rope_theta", 10000),
+            sliding_window=sliding_window,
+            pre_post_layer_norm=True,
+            num_heads_kv=encoder_num_heads_kv,
+            head_dim=encoder_config.head_dim,
+        )
+
+        decoder_num_heads = decoder_config.num_attention_heads
+        decoder_num_heads_kv = getattr(
+            decoder_config, "num_key_value_heads", decoder_num_heads
+        )
+        if decoder_num_heads_kv == decoder_num_heads:
+            decoder_num_heads_kv = None
+
+        decoder = transformer_spec.TransformerDecoderSpec(
+            decoder_config.num_hidden_layers,
+            decoder_config.num_attention_heads,
+            pre_norm=True,
+            activation=_SUPPORTED_ACTIVATIONS[decoder_config.hidden_activation],
+            ffn_glu=True,
+            rms_norm=True,
+            with_encoder_attention=True,
+            rotary_dim=decoder_config.head_dim,
+            rotary_interleave=False,
+            rotary_base=getattr(decoder_config, "rope_theta", 10000),
+            sliding_window=sliding_window,
+            pre_post_layer_norm=True,
+            external_pre_post_encoder_layers=True,
+            num_heads_kv=decoder_num_heads_kv,
+            head_dim=decoder_config.head_dim,
+        )
+
+        spec = transformer_spec.TransformerSpec(encoder, decoder)
+
+        self.set_encoder(spec.encoder, model.model.encoder, encoder_config)
+
+        self.set_decoder(
+            spec.decoder,
+            model.model.decoder,
+            decoder_config,
+            common_spec.Quantization.CT2,
+        )
+
+        # Tie_word_embeddings
+        self.set_linear(spec.decoder.projection, model.model.decoder.embed_tokens)
+        return spec
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_source_vocabulary(tokens)
+        spec.register_target_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = tokenizer.bos_token
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = tokenizer.unk_token
+
+        if hasattr(model.config, "encoder"):
+            config.layer_norm_epsilon = model.config.encoder.rms_norm_eps
+        elif hasattr(model.config, "rms_norm_eps"):
+            config.layer_norm_epsilon = model.config.rms_norm_eps
+        else:
+            config.layer_norm_epsilon = 1e-6
+
+        config.decoder_start_token = tokenizer.bos_token
+
+    def set_encoder(
+        self, spec, encoder, encoder_config, quant_type=common_spec.Quantization.CT2
+    ):
+        spec.scale_embeddings = True
+
+        encoder_emb_spec = (
+            spec.embeddings[0] if isinstance(spec.embeddings, list) else spec.embeddings
+        )
+
+        self.set_embeddings(encoder_emb_spec, encoder.embed_tokens)
+        encoder_emb_spec.multiply_by_sqrt_depth = encoder_config.hidden_size**0.5
+        self.set_layer_norm(spec.layer_norm, encoder.norm)
+
+        module = encoder
+        for i, (layer_spec, layer) in enumerate(zip(spec.layer, module.layers)):
+            self.set_layer_norm(
+                layer_spec.input_layer_norm, layer.pre_self_attn_layernorm
+            )
+            self.set_layer_norm(
+                layer_spec.post_attention_layer_norm, layer.post_self_attn_layernorm
+            )
+
+            # T5GemmaSelfAttention
+            qkv_split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(
+                qkv_split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                qkv_split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                qkv_split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+            )
+            utils.fuse_linear(layer_spec.self_attention.linear[0], qkv_split_layers)
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+                quant_type=quant_type,
+            )
+
+            # T5GemmaRMSNorm
+            self.set_layer_norm(
+                layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
+            )
+            # T5GemmaRMSNorm
+            self.set_layer_norm(
+                layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
+            )
+
+            # T5GemmaMLP
+            self.set_linear(
+                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )
+
+            # Clean up
+            delattr(layer, "self_attn")
+            delattr(layer, "mlp")
+            gc.collect()
+
+    def set_decoder(
+        self, spec, module, decoder_config, quant_type=common_spec.Quantization.CT2
+    ):
+        spec.scale_embeddings = True
+        spec.start_from_zero_embedding = False
+
+        self.set_embeddings(spec.embeddings, module.embed_tokens)
+        spec.embeddings.multiply_by_sqrt_depth = decoder_config.hidden_size**0.5
+        self.set_layer_norm(spec.layer_norm, module.norm)
+
+        for i, (layer_spec, layer) in enumerate(zip(spec.layer, module.layers)):
+            # Self-attention block
+            self.set_layer_norm(
+                layer_spec.input_layer_norm, layer.pre_self_attn_layernorm
+            )
+            self.set_layer_norm(
+                layer_spec.post_attention_layer_norm, layer.post_self_attn_layernorm
+            )
+
+            # T5GemmaSelfAttention - QKV projections
+            qkv_split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(
+                qkv_split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                qkv_split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                qkv_split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+            )
+            utils.fuse_linear(layer_spec.self_attention.linear[0], qkv_split_layers)
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+                quant_type=quant_type,
+            )
+
+            # Pre and post cross-attention layer norm
+            self.set_layer_norm(
+                layer_spec.external_pre_encoder_attention_layer_norm,
+                layer.pre_cross_attn_layernorm,
+            )
+
+            self.set_layer_norm(
+                layer_spec.external_post_encoder_attention_layer_norm,
+                layer.post_cross_attn_layernorm,
+            )
+
+            # Cross-attention Q projection
+            self.set_linear(
+                layer_spec.attention.linear[0],
+                layer.cross_attn.q_proj,
+                quant_type=quant_type,
+            )
+
+            # Cross-attention K+V fused
+            kv_split_layers = [common_spec.LinearSpec() for _ in range(2)]
+            self.set_linear(
+                kv_split_layers[0],
+                layer.cross_attn.k_proj,
+                quant_type=quant_type,
+            )
+            self.set_linear(
+                kv_split_layers[1],
+                layer.cross_attn.v_proj,
+                quant_type=quant_type,
+            )
+            utils.fuse_linear(layer_spec.attention.linear[1], kv_split_layers)
+
+            # Cross-attention output projection
+            self.set_linear(
+                layer_spec.attention.linear[2],
+                layer.cross_attn.o_proj,
+                quant_type=quant_type,
+            )
+
+            # Feed-forward block
+            self.set_layer_norm(
+                layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
+            )
+            self.set_layer_norm(
+                layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
+            )
+
+            # T5GemmaMLP
+            self.set_linear(
+                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )
+
+            # Clean up
+            delattr(layer, "self_attn")
+            delattr(layer, "cross_attn")
+            delattr(layer, "mlp")
+            gc.collect()
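
Note: T5Gemma builds an encoder-decoder spec, so the converted directory is loaded with ``ctranslate2.Translator`` rather than ``Generator``. A hedged sketch (model id and output directory are illustrative):

    import ctranslate2

    ctranslate2.converters.TransformersConverter("google/t5gemma-2b-2b-ul2").convert(
        "t5gemma_ct2", force=True
    )
    translator = ctranslate2.Translator("t5gemma_ct2", device="cpu")
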