ctranslate2 4.7.0__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3721 @@
1
+ import abc
2
+ import argparse
3
+ import gc
4
+ import itertools
5
+ import os
6
+
7
+ from typing import List, Optional
8
+
9
+ import numpy as np
10
+
11
+ try:
12
+ import huggingface_hub
13
+ import torch
14
+ import transformers
15
+ except ImportError:
16
+ pass
17
+
18
+ from ctranslate2.converters import utils
19
+ from ctranslate2.converters.converter import Converter
20
+ from ctranslate2.specs import (
21
+ attention_spec,
22
+ common_spec,
23
+ model_spec,
24
+ transformer_spec,
25
+ wav2vec2_spec,
26
+ wav2vec2bert_spec,
27
+ whisper_spec,
28
+ )
29
+
30
+ _SUPPORTED_ACTIVATIONS = {
31
+ "gelu": common_spec.Activation.GELU,
32
+ "gelu_fast": common_spec.Activation.GELUTanh,
33
+ "gelu_new": common_spec.Activation.GELUTanh,
34
+ "gelu_python": common_spec.Activation.GELU,
35
+ "gelu_pytorch_tanh": common_spec.Activation.GELUTanh,
36
+ "quick_gelu": common_spec.Activation.GELUSigmoid,
37
+ "relu": common_spec.Activation.RELU,
38
+ "silu": common_spec.Activation.SWISH,
39
+ "swish": common_spec.Activation.SWISH,
40
+ }
41
+
42
+ _SUPPORTED_ROPE_SCALING = {
43
+ "linear": attention_spec.RotaryScalingType.Linear,
44
+ "su": attention_spec.RotaryScalingType.Su,
45
+ "llama3": attention_spec.RotaryScalingType.Llama3,
46
+ "longrope": attention_spec.RotaryScalingType.Su,
47
+ }
48
+
49
+ _SUPPORTED_QUANTIZATION = {
50
+ "gemm": common_spec.Quantization.AWQ_GEMM,
51
+ "gemv": common_spec.Quantization.AWQ_GEMV,
52
+ }
53
+
54
+ _MODEL_LOADERS = {}
55
+
56
+
57
+ def register_loader(config_name):
58
+ """Registers a model loader for this configuration name."""
59
+
60
+ def decorator(cls):
61
+ _MODEL_LOADERS[config_name] = cls()
62
+ return cls
63
+
64
+ return decorator
65
+
66
+
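For orientation: the registry above keys on the name of the Hugging Face config class, so the converter can look up `config.__class__.__name__` and dispatch to the matching loader. A minimal sketch of how a new architecture would plug in; `FooConfig`, `FooLoader` and `FooForCausalLM` are made-up names for illustration only:

```python
# Hypothetical example only: "FooConfig" / "FooForCausalLM" are placeholder names.
@register_loader("FooConfig")
class FooLoader(ModelLoader):
    @property
    def architecture_name(self):
        # Name of the class resolved on the transformers module to load the checkpoint.
        return "FooForCausalLM"

    def get_model_spec(self, model):
        # Build and fill a CTranslate2 spec from the loaded PyTorch model here.
        ...
```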
67
+ class TransformersConverter(Converter):
68
+ """Converts models from Hugging Face Transformers."""
69
+
70
+ def __init__(
71
+ self,
72
+ model_name_or_path: str,
73
+ activation_scales: Optional[str] = None,
74
+ copy_files: Optional[List[str]] = None,
75
+ load_as_float16: bool = False,
76
+ revision: Optional[str] = None,
77
+ low_cpu_mem_usage: bool = False,
78
+ trust_remote_code: bool = False,
79
+ ):
80
+ """Initializes the converter.
81
+
82
+ Arguments:
83
+ model_name_or_path: Name of the pretrained model to download, or path to the
84
+ directory containing the pretrained model.
85
+ activation_scales: Path to the pre-computed activation scales. Models may
86
+ use them to rescale some weights to smooth the intermediate activations
87
+ and improve the quantization accuracy. See
88
+ https://github.com/mit-han-lab/smoothquant.
89
+ copy_files: List of filenames to copy from the Hugging Face model to the
90
+ converted model directory.
91
+ load_as_float16: Load the model weights as float16. More precisely, the model
92
+ will be loaded with ``from_pretrained(..., dtype=torch.float16)``.
93
+ revision: Revision of the model to download from the Hugging Face Hub.
94
+ low_cpu_mem_usage: Enable the flag ``low_cpu_mem_usage`` when loading the model
95
+ with ``from_pretrained``.
96
+ trust_remote_code: Allow converting models using custom code.
97
+ """
98
+ self._model_name_or_path = model_name_or_path
99
+ self._activation_scales = activation_scales
100
+ self._copy_files = copy_files
101
+ self._load_as_float16 = load_as_float16
102
+ self._revision = revision
103
+ self._low_cpu_mem_usage = low_cpu_mem_usage
104
+ self._trust_remote_code = trust_remote_code
105
+
106
+ def _load(self):
107
+ with torch.no_grad():
108
+ config = transformers.AutoConfig.from_pretrained(
109
+ self._model_name_or_path, trust_remote_code=self._trust_remote_code
110
+ )
111
+
112
+ config_name = config.__class__.__name__
113
+ loader = _MODEL_LOADERS.get(config_name)
114
+
115
+ if loader is None:
116
+ raise ValueError(
117
+ "No conversion is registered for the model configuration %s "
118
+ "(supported configurations are: %s)"
119
+ % (config_name, ", ".join(sorted(_MODEL_LOADERS.keys())))
120
+ )
121
+
122
+ model_class = getattr(transformers, loader.architecture_name)
123
+ tokenizer_class = transformers.AutoTokenizer
124
+
125
+ kwargs = {
126
+ "dtype": (
127
+ torch.float16
128
+ if self._load_as_float16
129
+ else getattr(config, "dtype", None)
130
+ or getattr(config, "torch_dtype", None)
131
+ )
132
+ }
133
+
134
+ if self._revision:
135
+ kwargs["revision"] = self._revision
136
+ if self._low_cpu_mem_usage:
137
+ kwargs["low_cpu_mem_usage"] = self._low_cpu_mem_usage
138
+ if self._trust_remote_code:
139
+ kwargs["trust_remote_code"] = self._trust_remote_code
140
+
141
+ model = self.load_model(model_class, self._model_name_or_path, **kwargs)
142
+
143
+ tokenizer_kwargs = {}
144
+ if self._trust_remote_code:
145
+ tokenizer_kwargs["trust_remote_code"] = self._trust_remote_code
146
+
147
+ tokenizer = self.load_tokenizer(
148
+ tokenizer_class, self._model_name_or_path, **tokenizer_kwargs
149
+ )
150
+
151
+ spec = loader(model, tokenizer)
152
+
153
+ if self._activation_scales:
154
+ activation_scales = torch.load(
155
+ self._activation_scales, map_location="cpu"
156
+ )
157
+ loader.smooth_activation(spec, activation_scales)
158
+
159
+ if self._copy_files:
160
+ for filename in self._copy_files:
161
+ spec.register_file(self.get_model_file(filename))
162
+
163
+ return spec
164
+
165
+ def load_model(self, model_class, model_name_or_path, **kwargs):
166
+ return model_class.from_pretrained(model_name_or_path, **kwargs)
167
+
168
+ def load_tokenizer(self, tokenizer_class, model_name_or_path, **kwargs):
169
+ return tokenizer_class.from_pretrained(model_name_or_path, **kwargs)
170
+
171
+ def get_model_file(self, filename):
172
+ if os.path.isdir(self._model_name_or_path):
173
+ path = os.path.join(self._model_name_or_path, filename)
174
+ else:
175
+ try:
176
+ path = huggingface_hub.hf_hub_download(
177
+ repo_id=self._model_name_or_path, filename=filename
178
+ )
179
+ except huggingface_hub.utils.EntryNotFoundError:
180
+ path = None
181
+
182
+ if path is None or not os.path.isfile(path):
183
+ raise ValueError(
184
+ "File %s does not exist in model %s"
185
+ % (filename, self._model_name_or_path)
186
+ )
187
+
188
+ return path
189
+
190
+
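For context, the class above is normally driven through the `convert` method it inherits from `Converter` (or via the `ct2-transformers-converter` command line that wraps it). A minimal usage sketch; the model name and output directory are placeholders:

```python
# Sketch: convert a Transformers checkpoint into a CTranslate2 model directory.
from ctranslate2.converters import TransformersConverter

converter = TransformersConverter("facebook/m2m100_418M", load_as_float16=True)
converter.convert("m2m100_418M_ct2", quantization="int8", force=True)
```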
191
+ class ModelLoader(abc.ABC):
192
+ """Base class for loading Transformers models into a CTranslate2 model specification."""
193
+
194
+ @property
195
+ def architecture_name(self):
196
+ return None
197
+
198
+ @abc.abstractmethod
199
+ def get_model_spec(self, model):
200
+ raise NotImplementedError()
201
+
202
+ def __call__(self, model, tokenizer):
203
+ spec = self.get_model_spec(model)
204
+ self.set_config(spec.config, model, tokenizer)
205
+
206
+ tokens = self.get_vocabulary(model, tokenizer)
207
+ self.set_vocabulary(spec, tokens)
208
+
209
+ return spec
210
+
211
+ def get_vocabulary(self, model, tokenizer):
212
+ return [
213
+ token
214
+ for token, _ in sorted(
215
+ tokenizer.get_vocab().items(), key=lambda item: item[1]
216
+ )
217
+ ]
218
+
219
+ def set_vocabulary(self, spec, tokens):
220
+ pass
221
+
222
+ def set_config(self, config, model, tokenizer):
223
+ pass
224
+
225
+ def set_layer_norm(self, spec, module):
226
+ spec.gamma = module.weight
227
+ spec.beta = module.bias
228
+
229
+ def set_linear(self, spec, module, quant_type=common_spec.Quantization.CT2):
230
+ if quant_type == common_spec.Quantization.CT2:
231
+ spec.weight = module.weight
232
+ else:
233
+ spec.weight = module.qweight
234
+ spec.weight_scale = module.scales
235
+ spec.weight_zero = module.qzeros
236
+
237
+ if isinstance(module, transformers.Conv1D):
238
+ spec.weight = spec.weight.transpose(0, 1)
239
+ if hasattr(module, "bias") and module.bias is not None:
240
+ spec.bias = module.bias
241
+
242
+ def set_embeddings(self, spec, module):
243
+ spec.weight = module.weight
244
+
245
+ def set_position_encodings(self, spec, module):
246
+ spec.encodings = module.weight
247
+ offset = getattr(module, "offset", 0)
248
+ if offset > 0:
249
+ spec.encodings = spec.encodings[offset:]
250
+
251
+ def smooth_activation(self, spec, activation_scales):
252
+ raise NotImplementedError(
253
+ "No activation smoothing logic is defined for this model"
254
+ )
255
+
256
+ def get_rotary_params(self, config, default_rope_theta):
257
+ rope_scaling = getattr(config, "rope_scaling", None)
258
+ if rope_scaling:
259
+ rope_type = rope_scaling.get("type") or rope_scaling.get("rope_type")
260
+
261
+ if rope_type == "default":
262
+ rotary_scaling_type = None
263
+ else:
264
+ rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
265
+ if rotary_scaling_type is None:
266
+ raise NotImplementedError(
267
+ "RoPE scaling type '%s' is not yet implemented. "
268
+ "The following RoPE scaling types are currently supported: %s"
269
+ % (rope_type, ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
270
+ )
271
+ rotary_scaling_factor = rope_scaling.get("factor", 1)
272
+ rope_theta = rope_scaling.get("rope_theta", default_rope_theta)
273
+ else:
274
+ rotary_scaling_type = None
275
+ rotary_scaling_factor = 1
276
+ rope_theta = getattr(config, "rope_theta", default_rope_theta)
277
+
278
+ return rotary_scaling_type, rotary_scaling_factor, rope_theta
279
+
280
+
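As a quick illustration of the helper above: when the config carries a `rope_scaling` dict, its type is mapped through `_SUPPORTED_ROPE_SCALING`, and the theta is read from that dict (not from `config.rope_theta`), falling back to the default otherwise. A toy check using the `BartLoader` defined just below (any concrete loader works), assuming this module is importable as `ctranslate2.converters.transformers`:

```python
# Toy check of get_rotary_params(); SimpleNamespace stands in for a HF config object.
from types import SimpleNamespace

from ctranslate2.converters.transformers import BartLoader

loader = BartLoader()
cfg = SimpleNamespace(rope_scaling={"type": "linear", "factor": 4.0})
print(loader.get_rotary_params(cfg, 10000))
# (RotaryScalingType.Linear, 4.0, 10000): theta falls back to the default because
# "rope_theta" is not present inside the rope_scaling dict itself.
```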
281
+ @register_loader("BartConfig")
282
+ class BartLoader(ModelLoader):
283
+ @property
284
+ def architecture_name(self):
285
+ return "BartForConditionalGeneration"
286
+
287
+ def get_model_spec(self, model):
288
+ spec = transformer_spec.TransformerSpec.from_config(
289
+ (model.config.encoder_layers, model.config.decoder_layers),
290
+ model.config.encoder_attention_heads,
291
+ pre_norm=model.config.normalize_before,
292
+ activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function],
293
+ layernorm_embedding=getattr(model.config, "normalize_embedding", True),
294
+ )
295
+
296
+ self.set_encoder(spec.encoder, model.model.encoder)
297
+ self.set_decoder(spec.decoder, model.model.decoder)
298
+ self.set_linear(spec.decoder.projection, model.lm_head)
299
+
300
+ final_logits_bias = getattr(model, "final_logits_bias", None)
301
+ if final_logits_bias is not None and final_logits_bias.nonzero().numel() != 0:
302
+ spec.decoder.projection.bias = final_logits_bias.squeeze()
303
+
304
+ return spec
305
+
306
+ def get_vocabulary(self, model, tokenizer):
307
+ tokens = super().get_vocabulary(model, tokenizer)
308
+ if model.config.vocab_size < len(tokens):
309
+ tokens = tokens[: model.config.vocab_size]
310
+ return tokens
311
+
312
+ def set_vocabulary(self, spec, tokens):
313
+ spec.register_source_vocabulary(tokens)
314
+ spec.register_target_vocabulary(tokens)
315
+
316
+ def set_config(self, config, model, tokenizer):
317
+ config.bos_token = tokenizer.bos_token
318
+ config.eos_token = tokenizer.eos_token
319
+ config.unk_token = tokenizer.unk_token
320
+ config.decoder_start_token = tokenizer.convert_ids_to_tokens(
321
+ model.config.decoder_start_token_id
322
+ )
323
+
324
+ def set_encoder(self, spec, encoder):
325
+ self.set_common_layers(spec, encoder)
326
+
327
+ for layer_spec, layer in zip(spec.layer, encoder.layers):
328
+ self.set_attention(
329
+ layer_spec.self_attention,
330
+ layer.self_attn,
331
+ self_attention=True,
332
+ )
333
+ self.set_layer_norm(
334
+ layer_spec.self_attention.layer_norm,
335
+ layer.self_attn_layer_norm,
336
+ )
337
+
338
+ self.set_linear(layer_spec.ffn.linear_0, layer.fc1)
339
+ self.set_linear(layer_spec.ffn.linear_1, layer.fc2)
340
+ self.set_layer_norm(layer_spec.ffn.layer_norm, layer.final_layer_norm)
341
+
342
+ def set_decoder(self, spec, decoder):
343
+ self.set_common_layers(spec, decoder)
344
+
345
+ for layer_spec, layer in zip(spec.layer, decoder.layers):
346
+ self.set_attention(
347
+ layer_spec.self_attention,
348
+ layer.self_attn,
349
+ self_attention=True,
350
+ )
351
+ self.set_layer_norm(
352
+ layer_spec.self_attention.layer_norm,
353
+ layer.self_attn_layer_norm,
354
+ )
355
+
356
+ if hasattr(layer, "encoder_attn"):
357
+ self.set_attention(
358
+ layer_spec.attention,
359
+ layer.encoder_attn,
360
+ self_attention=False,
361
+ )
362
+ self.set_layer_norm(
363
+ layer_spec.attention.layer_norm,
364
+ layer.encoder_attn_layer_norm,
365
+ )
366
+
367
+ self.set_linear(layer_spec.ffn.linear_0, layer.fc1)
368
+ self.set_linear(layer_spec.ffn.linear_1, layer.fc2)
369
+ self.set_layer_norm(layer_spec.ffn.layer_norm, layer.final_layer_norm)
370
+
371
+ def set_attention(self, spec, attention, self_attention=False):
372
+ split_layers = [common_spec.LinearSpec() for _ in range(3)]
373
+ self.set_linear(split_layers[0], attention.q_proj)
374
+ self.set_linear(split_layers[1], attention.k_proj)
375
+ self.set_linear(split_layers[2], attention.v_proj)
376
+
377
+ if self_attention:
378
+ utils.fuse_linear(spec.linear[0], split_layers)
379
+ else:
380
+ utils.fuse_linear(spec.linear[0], split_layers[:1])
381
+ utils.fuse_linear(spec.linear[1], split_layers[1:])
382
+
383
+ self.set_linear(spec.linear[-1], attention.out_proj)
384
+
385
+ def set_common_layers(self, spec, module):
386
+ import math
387
+
388
+ if not hasattr(module, "embed_scale"):
389
+ embed_scale = (
390
+ math.sqrt(module.config.d_model)
391
+ if module.config.scale_embedding
392
+ else 1.0
393
+ )
394
+ else:
395
+ embed_scale = module.embed_scale
396
+ spec.scale_embeddings = embed_scale
397
+ self.set_position_encodings(spec.position_encodings, module.embed_positions)
398
+ self.set_embeddings(
399
+ (
400
+ spec.embeddings[0]
401
+ if isinstance(spec.embeddings, list)
402
+ else spec.embeddings
403
+ ),
404
+ module.embed_tokens,
405
+ )
406
+
407
+ if hasattr(module, "layer_norm"):
408
+ self.set_layer_norm(spec.layer_norm, module.layer_norm)
409
+ if hasattr(module, "layernorm_embedding"):
410
+ self.set_layer_norm(spec.layernorm_embedding, module.layernorm_embedding)
411
+
412
+
413
+ @register_loader("MarianConfig")
414
+ class MarianMTLoader(BartLoader):
415
+ @property
416
+ def architecture_name(self):
417
+ return "MarianMTModel"
418
+
419
+ def get_model_spec(self, model):
420
+ model.config.normalize_before = False
421
+ model.config.normalize_embedding = False
422
+ spec = super().get_model_spec(model)
423
+ self._remove_pad_weights(spec)
424
+ return spec
425
+
426
+ def set_config(self, config, model, tokenizer):
427
+ config.eos_token = tokenizer.eos_token
428
+ config.unk_token = tokenizer.unk_token
429
+
430
+ # The decoder start token can be any token because the decoder always starts
431
+ # from a zero embedding.
432
+ config.decoder_start_token = tokenizer.eos_token
433
+
434
+ def set_decoder(self, spec, decoder):
435
+ spec.start_from_zero_embedding = True
436
+ super().set_decoder(spec, decoder)
437
+
438
+ def get_vocabulary(self, model, tokenizer):
439
+ # The <pad> token is added by Transformers to start the decoder from a zero embedding,
440
+ # but we already have a dedicated option "start_from_zero_embedding". We remove this token
441
+ # to match the original Marian vocabulary and prevent this token from being generated.
442
+ tokens = super().get_vocabulary(model, tokenizer)
443
+ if tokens[-1] == "<pad>":
444
+ tokens.pop()
445
+ return tokens
446
+
447
+ def _remove_pad_weights(self, spec):
448
+ vocab_specs = [
449
+ spec.encoder.embeddings[0],
450
+ spec.decoder.embeddings,
451
+ spec.decoder.projection,
452
+ ]
453
+
454
+ # Weights may be shared so we check against the expected size to prevent
455
+ # updating the same weight multiple times.
456
+ new_vocab_size = vocab_specs[0].weight.shape[0] - 1
457
+
458
+ for vocab_spec in vocab_specs:
459
+ if vocab_spec.weight.shape[0] == new_vocab_size + 1:
460
+ vocab_spec.weight = vocab_spec.weight[:-1]
461
+ if (
462
+ isinstance(vocab_spec, common_spec.LinearSpec)
463
+ and vocab_spec.has_bias()
464
+ and vocab_spec.bias.shape[0] == new_vocab_size + 1
465
+ ):
466
+ vocab_spec.bias = vocab_spec.bias[:-1]
467
+
468
+
469
+ @register_loader("M2M100Config")
470
+ class M2M100Loader(BartLoader):
471
+ @property
472
+ def architecture_name(self):
473
+ return "M2M100ForConditionalGeneration"
474
+
475
+ def get_model_spec(self, model):
476
+ model.config.normalize_before = True
477
+ model.config.normalize_embedding = False
478
+ return super().get_model_spec(model)
479
+
480
+ def set_position_encodings(self, spec, module):
481
+ spec.encodings = module.weights[module.offset :]
482
+
483
+ def get_vocabulary(self, model, tokenizer):
484
+ tokens = super().get_vocabulary(model, tokenizer)
485
+
486
+ # Workaround for issue https://github.com/OpenNMT/CTranslate2/issues/1039.
487
+ if tokens[-1] == tokenizer.unk_token:
488
+ tokens.insert(tokenizer.unk_token_id, tokens.pop())
489
+
490
+ for token in tokenizer.special_tokens_map.get("additional_special_tokens", []):
491
+ if token not in tokens:
492
+ tokens.append(token)
493
+
494
+ num_madeup_words = getattr(
495
+ tokenizer, "num_madeup_words", model.config.vocab_size - len(tokens)
496
+ )
497
+ if num_madeup_words > 0:
498
+ tokens += ["madeupword%d" % i for i in range(num_madeup_words)]
499
+
500
+ return tokens
501
+
502
+
503
+ @register_loader("MBartConfig")
504
+ class MBartLoader(BartLoader):
505
+ @property
506
+ def architecture_name(self):
507
+ return "MBartForConditionalGeneration"
508
+
509
+ def set_config(self, config, model, tokenizer):
510
+ config.bos_token = tokenizer.bos_token
511
+ config.eos_token = tokenizer.eos_token
512
+ config.unk_token = tokenizer.unk_token
513
+
514
+ # MBart-25 passes the language code as the decoder start token.
515
+ if getattr(model.config, "tokenizer_class", None) in ("MBartTokenizer", None):
516
+ config.decoder_start_token = None
517
+ else:
518
+ config.decoder_start_token = tokenizer.eos_token
519
+
520
+
521
+ @register_loader("PegasusConfig")
522
+ class PegasusLoader(BartLoader):
523
+ @property
524
+ def architecture_name(self):
525
+ return "PegasusForConditionalGeneration"
526
+
527
+ def set_config(self, config, model, tokenizer):
528
+ config.bos_token = tokenizer.pad_token
529
+ config.eos_token = tokenizer.eos_token
530
+ config.unk_token = tokenizer.unk_token
531
+ config.decoder_start_token = tokenizer.pad_token
532
+
533
+
534
+ @register_loader("OPTConfig")
535
+ class OPTLoader(BartLoader):
536
+ @property
537
+ def architecture_name(self):
538
+ return "OPTForCausalLM"
539
+
540
+ def get_model_spec(self, model):
541
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
542
+ model.config.num_hidden_layers,
543
+ model.config.num_attention_heads,
544
+ pre_norm=model.config.do_layer_norm_before,
545
+ activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function],
546
+ project_in_out=model.config.word_embed_proj_dim != model.config.hidden_size,
547
+ )
548
+
549
+ self.set_decoder(spec.decoder, model.model.decoder)
550
+ self.set_linear(spec.decoder.projection, model.lm_head)
551
+ return spec
552
+
553
+ def smooth_activation(self, spec, activation_scales):
554
+ for i, layer in enumerate(spec.decoder.layer):
555
+ layer_scope = "model.decoder.layers.%d" % i
556
+
557
+ utils.smooth_activation(
558
+ layer.self_attention.layer_norm,
559
+ layer.self_attention.linear[0],
560
+ activation_scales["%s.self_attn.q_proj" % layer_scope],
561
+ )
562
+
563
+ utils.smooth_activation(
564
+ layer.ffn.layer_norm,
565
+ layer.ffn.linear_0,
566
+ activation_scales["%s.fc1" % layer_scope],
567
+ )
568
+
569
+ def set_vocabulary(self, spec, tokens):
570
+ spec.register_vocabulary(tokens)
571
+
572
+ def set_config(self, config, model, tokenizer):
573
+ config.bos_token = tokenizer.bos_token
574
+ config.eos_token = tokenizer.eos_token
575
+ config.unk_token = tokenizer.unk_token
576
+
577
+ def set_decoder(self, spec, decoder):
578
+ super().set_decoder(spec, decoder)
579
+
580
+ if decoder.project_in is not None:
581
+ self.set_linear(spec.project_in, decoder.project_in)
582
+ if decoder.project_out is not None:
583
+ self.set_linear(spec.project_out, decoder.project_out)
584
+ if decoder.final_layer_norm is not None:
585
+ self.set_layer_norm(spec.layer_norm, decoder.final_layer_norm)
586
+
587
+ def set_common_layers(self, spec, module):
588
+ spec.scale_embeddings = False
589
+ self.set_position_encodings(spec.position_encodings, module.embed_positions)
590
+ self.set_embeddings(spec.embeddings, module.embed_tokens)
591
+
592
+ def get_vocabulary(self, model, tokenizer):
593
+ tokens = super().get_vocabulary(model, tokenizer)
594
+
595
+ i = 0
596
+ while len(tokens) % 8 != 0:
597
+ symbol = "madeupword{:04d}".format(i)
598
+ if symbol not in tokens:
599
+ tokens.append(symbol)
600
+ i += 1
601
+
602
+ return tokens
603
+
604
+
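The `get_vocabulary` override above pads the token list to a multiple of 8 with `madeupword` symbols, presumably to match the padded embedding matrix of the original OPT checkpoints. Run on a made-up vocabulary size, the loop behaves like this:

```python
# The padding loop above, applied to a made-up vocabulary of 50261 tokens.
tokens = ["tok%d" % i for i in range(50261)]
i = 0
while len(tokens) % 8 != 0:
    symbol = "madeupword{:04d}".format(i)
    if symbol not in tokens:
        tokens.append(symbol)
    i += 1
print(len(tokens))  # 50264
print(tokens[-3:])  # ['madeupword0000', 'madeupword0001', 'madeupword0002']
```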
605
+ @register_loader("GPTBigCodeConfig")
606
+ class GPTBigCodeMHALoader(ModelLoader):
607
+ @property
608
+ def architecture_name(self):
609
+ return "GPTBigCodeForCausalLM"
610
+
611
+ def get_model_spec(self, model):
612
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
613
+ model.config.n_layer,
614
+ model.config.n_head,
615
+ pre_norm=True,
616
+ activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function],
617
+ multi_query_attention=True,
618
+ )
619
+
620
+ self.set_decoder(spec.decoder, model.transformer)
621
+ self.set_linear(spec.decoder.projection, model.lm_head)
622
+ return spec
623
+
624
+ def set_vocabulary(self, spec, tokens):
625
+ spec.register_vocabulary(tokens)
626
+
627
+ def get_vocabulary(self, model, tokenizer):
628
+ tokens = super().get_vocabulary(model, tokenizer)
629
+
630
+ extra_ids = model.config.vocab_size - len(tokens)
631
+ for i in range(extra_ids):
632
+ tokens.append("<extra_id_%d>" % i)
633
+
634
+ return tokens
635
+
636
+ def set_config(self, config, model, tokenizer):
637
+ config.bos_token = tokenizer.bos_token
638
+ config.eos_token = tokenizer.eos_token
639
+ config.unk_token = tokenizer.unk_token
640
+
641
+ def set_decoder(self, spec, module):
642
+ spec.scale_embeddings = False
643
+ self.set_embeddings(spec.embeddings, module.wte)
644
+ self.set_position_encodings(spec.position_encodings, module.wpe)
645
+ self.set_layer_norm(spec.layer_norm, module.ln_f)
646
+
647
+ for layer_spec, layer in zip(spec.layer, module.h):
648
+ self.set_layer_norm(layer_spec.self_attention.layer_norm, layer.ln_1)
649
+ self.set_linear(layer_spec.self_attention.linear[0], layer.attn.c_attn)
650
+ self.set_linear(layer_spec.self_attention.linear[1], layer.attn.c_proj)
651
+ self.set_layer_norm(layer_spec.ffn.layer_norm, layer.ln_2)
652
+ self.set_linear(layer_spec.ffn.linear_0, layer.mlp.c_fc)
653
+ self.set_linear(layer_spec.ffn.linear_1, layer.mlp.c_proj)
654
+
655
+
656
+ @register_loader("GPT2Config")
657
+ class GPT2Loader(ModelLoader):
658
+ @property
659
+ def architecture_name(self):
660
+ return "GPT2LMHeadModel"
661
+
662
+ def get_model_spec(self, model):
663
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
664
+ model.config.n_layer,
665
+ model.config.n_head,
666
+ pre_norm=True,
667
+ activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function],
668
+ )
669
+
670
+ self.set_decoder(spec.decoder, model.transformer)
671
+ self.set_linear(spec.decoder.projection, model.lm_head)
672
+ return spec
673
+
674
+ def set_vocabulary(self, spec, tokens):
675
+ spec.register_vocabulary(tokens)
676
+
677
+ def set_config(self, config, model, tokenizer):
678
+ config.bos_token = tokenizer.bos_token
679
+ config.eos_token = tokenizer.eos_token
680
+ config.unk_token = tokenizer.unk_token
681
+
682
+ def set_decoder(self, spec, module):
683
+ spec.scale_embeddings = False
684
+ self.set_embeddings(spec.embeddings, module.wte)
685
+ self.set_position_encodings(spec.position_encodings, module.wpe)
686
+ self.set_layer_norm(spec.layer_norm, module.ln_f)
687
+
688
+ for layer_spec, layer in zip(spec.layer, module.h):
689
+ self.set_layer_norm(layer_spec.self_attention.layer_norm, layer.ln_1)
690
+ self.set_linear(layer_spec.self_attention.linear[0], layer.attn.c_attn)
691
+ self.set_linear(layer_spec.self_attention.linear[1], layer.attn.c_proj)
692
+ self.set_layer_norm(layer_spec.ffn.layer_norm, layer.ln_2)
693
+ self.set_linear(layer_spec.ffn.linear_0, layer.mlp.c_fc)
694
+ self.set_linear(layer_spec.ffn.linear_1, layer.mlp.c_proj)
695
+
696
+
697
+ @register_loader("GPTJConfig")
698
+ class GPTJLoader(ModelLoader):
699
+ @property
700
+ def architecture_name(self):
701
+ return "GPTJForCausalLM"
702
+
703
+ def get_model_spec(self, model):
704
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
705
+ model.config.n_layer,
706
+ model.config.n_head,
707
+ pre_norm=True,
708
+ activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function],
709
+ rotary_dim=model.config.rotary_dim,
710
+ rotary_interleave=False,
711
+ parallel_residual=True,
712
+ shared_layer_norm=True,
713
+ )
714
+
715
+ self.set_decoder(
716
+ spec.decoder,
717
+ model.transformer,
718
+ model.config.rotary_dim,
719
+ model.config.n_head,
720
+ )
721
+ self.set_linear(spec.decoder.projection, model.lm_head)
722
+ return spec
723
+
724
+ def set_vocabulary(self, spec, tokens):
725
+ spec.register_vocabulary(tokens)
726
+
727
+ def set_config(self, config, model, tokenizer):
728
+ config.bos_token = tokenizer.bos_token
729
+ config.eos_token = tokenizer.eos_token
730
+ config.unk_token = tokenizer.unk_token
731
+
732
+ def set_decoder(self, spec, module, rotary_dim, num_heads):
733
+ spec.scale_embeddings = False
734
+ self.set_embeddings(spec.embeddings, module.wte)
735
+ self.set_layer_norm(spec.layer_norm, module.ln_f)
736
+
737
+ for layer_spec, layer in zip(spec.layer, module.h):
738
+ self.set_layer_norm(layer_spec.shared_layer_norm, layer.ln_1)
739
+
740
+ qw = layer.attn.q_proj.weight
741
+ kw = layer.attn.k_proj.weight
742
+ vw = layer.attn.v_proj.weight
743
+
744
+ qw = utils.permute_for_sliced_rotary(qw, num_heads, rotary_dim)
745
+ kw = utils.permute_for_sliced_rotary(kw, num_heads, rotary_dim)
746
+
747
+ layer_spec.self_attention.linear[0].weight = torch.cat((qw, kw, vw))
748
+ self.set_linear(layer_spec.self_attention.linear[1], layer.attn.out_proj)
749
+
750
+ self.set_linear(layer_spec.ffn.linear_0, layer.mlp.fc_in)
751
+ self.set_linear(layer_spec.ffn.linear_1, layer.mlp.fc_out)
752
+
753
+
754
+ @register_loader("CodeGenConfig")
755
+ class CodeGenLoader(ModelLoader):
756
+ @property
757
+ def architecture_name(self):
758
+ return "CodeGenForCausalLM"
759
+
760
+ def get_model_spec(self, model):
761
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
762
+ model.config.n_layer,
763
+ model.config.n_head,
764
+ pre_norm=True,
765
+ activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function],
766
+ rotary_dim=model.config.rotary_dim,
767
+ rotary_interleave=False,
768
+ parallel_residual=True,
769
+ shared_layer_norm=True,
770
+ )
771
+
772
+ mp_num = 4
773
+ if hasattr(model.config, "head_dim") and model.config.head_dim in [128, 256]:
774
+ # Models forked from "Salesforce/codegen2-1B" and "Salesforce/codegen2-3_7B"
775
+ # use a special setting of mp_num=8; all others use 4.
776
+ # These configs are recognizable by their special head_dim values.
777
+ mp_num = 8
778
+
779
+ self.set_decoder(
780
+ spec.decoder,
781
+ model.transformer,
782
+ model.config.rotary_dim,
783
+ model.config.n_head,
784
+ model.config.n_embd,
785
+ mp_num=mp_num,
786
+ )
787
+ self.set_linear(spec.decoder.projection, model.lm_head)
788
+ return spec
789
+
790
+ def get_vocabulary(self, model, tokenizer):
791
+ tokens = super().get_vocabulary(model, tokenizer)
792
+
793
+ extra_ids = model.config.vocab_size - len(tokens)
794
+ for i in range(extra_ids):
795
+ # Pad the vocabulary up to the model's vocab_size with placeholder tokens (see the GPTNeoX converter).
796
+ tokens.append("<extra_id_%d>" % i)
797
+
798
+ return tokens
799
+
800
+ def set_vocabulary(self, spec, tokens):
801
+ spec.register_vocabulary(tokens)
802
+
803
+ def set_config(self, config, model, tokenizer):
804
+ config.bos_token = tokenizer.bos_token
805
+ config.eos_token = tokenizer.eos_token
806
+ config.unk_token = tokenizer.unk_token
807
+
808
+ def set_decoder(self, spec, module, rotary_dim, num_heads, embed_dim, mp_num):
809
+ spec.scale_embeddings = False
810
+ self.set_embeddings(spec.embeddings, module.wte)
811
+ self.set_layer_norm(spec.layer_norm, module.ln_f)
812
+
813
+ base_permutation = np.arange(0, mp_num * 3).reshape(-1, 3).T.flatten().tolist()
814
+ local_dim = embed_dim // mp_num
815
+ permutation = torch.cat(
816
+ [torch.arange(i * local_dim, (i + 1) * local_dim) for i in base_permutation]
817
+ )
818
+
819
+ for layer_spec, layer in zip(spec.layer, module.h):
820
+ self.set_layer_norm(layer_spec.shared_layer_norm, layer.ln_1)
821
+ # [start convert CodeGen to GPT-J format]
822
+ # see https://github.com/fauxpilot/fauxpilot/blob/fb4073a9078dd001ebeb7dfefb8cb2ecc8a88f4b/converter/codegen_gptj_convert.py # noqa
823
+ qkv_proj = layer.attn.qkv_proj.weight
824
+
825
+ # GPT-J and CodeGen slice up the qkv projection slightly differently.
826
+ # The following permutation brings the CodeGen 'qkv_proj' rows
827
+ # into the GPT-J ordering of qw, vw, kw.
828
+ # We permute the *rows* because the linear layer computes x @ A.T.
829
+ new_qkv_proj = qkv_proj[permutation, :]
830
+ # The name QKV is misleading here: the chunks are actually stored in Q, V, K order.
831
+ qw, vw, kw = new_qkv_proj.chunk(3, dim=0)
832
+ # [end convert CodeGen to GPT-J.]
833
+
834
+ qw = utils.permute_for_sliced_rotary(qw, num_heads, rotary_dim)
835
+ kw = utils.permute_for_sliced_rotary(kw, num_heads, rotary_dim)
836
+
837
+ layer_spec.self_attention.linear[0].weight = torch.cat((qw, kw, vw))
838
+ self.set_linear(layer_spec.self_attention.linear[1], layer.attn.out_proj)
839
+
840
+ self.set_linear(layer_spec.ffn.linear_0, layer.mlp.fc_in)
841
+ self.set_linear(layer_spec.ffn.linear_1, layer.mlp.fc_out)
842
+
843
+
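To make the row permutation in `set_decoder` above concrete, here is what `base_permutation` evaluates to for the common `mp_num=4` case (a pure NumPy computation, independent of any model):

```python
import numpy as np

# base_permutation for mp_num=4: the interleaved row blocks of 'qkv_proj' are
# regrouped so that .chunk(3) later yields the q, v and k projections contiguously.
base_permutation = np.arange(0, 4 * 3).reshape(-1, 3).T.flatten().tolist()
print(base_permutation)  # [0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11]
```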
844
+ @register_loader("GPTNeoXConfig")
845
+ class GPTNeoXLoader(ModelLoader):
846
+ @property
847
+ def architecture_name(self):
848
+ return "GPTNeoXForCausalLM"
849
+
850
+ def get_model_spec(self, model):
851
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
852
+ model.config.num_hidden_layers,
853
+ model.config.num_attention_heads,
854
+ pre_norm=True,
855
+ activation=_SUPPORTED_ACTIVATIONS[model.config.hidden_act],
856
+ rotary_dim=int(
857
+ model.config.rotary_pct
858
+ * (model.config.hidden_size // model.config.num_attention_heads)
859
+ ),
860
+ rotary_interleave=False,
861
+ parallel_residual=model.config.use_parallel_residual,
862
+ shared_layer_norm=False,
863
+ )
864
+
865
+ self.set_decoder(spec.decoder, model.gpt_neox, model.config.num_attention_heads)
866
+ self.set_linear(spec.decoder.projection, model.embed_out)
867
+ return spec
868
+
869
+ def get_vocabulary(self, model, tokenizer):
870
+ tokens = super().get_vocabulary(model, tokenizer)
871
+
872
+ extra_ids = model.config.vocab_size - len(tokens)
873
+ for i in range(extra_ids):
874
+ tokens.append("<extra_id_%d>" % i)
875
+
876
+ return tokens
877
+
878
+ def set_vocabulary(self, spec, tokens):
879
+ spec.register_vocabulary(tokens)
880
+
881
+ def set_config(self, config, model, tokenizer):
882
+ config.bos_token = tokenizer.bos_token
883
+ config.eos_token = tokenizer.eos_token
884
+ config.unk_token = tokenizer.unk_token
885
+
886
+ def set_decoder(self, spec, module, num_heads):
887
+ spec.scale_embeddings = False
888
+ self.set_embeddings(spec.embeddings, module.embed_in)
889
+ self.set_layer_norm(spec.layer_norm, module.final_layer_norm)
890
+
891
+ for layer_spec, layer in zip(spec.layer, module.layers):
892
+ if hasattr(layer_spec, "input_layer_norm"): # Use parallel residual.
893
+ self.set_layer_norm(layer_spec.input_layer_norm, layer.input_layernorm)
894
+ self.set_layer_norm(
895
+ layer_spec.post_attention_layer_norm, layer.post_attention_layernorm
896
+ )
897
+ else:
898
+ self.set_layer_norm(
899
+ layer_spec.self_attention.layer_norm, layer.input_layernorm
900
+ )
901
+ self.set_layer_norm(
902
+ layer_spec.ffn.layer_norm, layer.post_attention_layernorm
903
+ )
904
+
905
+ qkv_w = layer.attention.query_key_value.weight
906
+ qkv_b = layer.attention.query_key_value.bias
907
+
908
+ qkv_w = (
909
+ qkv_w.reshape(num_heads, 3, -1, qkv_w.shape[-1])
910
+ .swapaxes(0, 1)
911
+ .reshape(-1, qkv_w.shape[-1])
912
+ )
913
+ qkv_b = qkv_b.reshape(num_heads, 3, -1).swapaxes(0, 1).reshape(-1)
914
+
915
+ layer_spec.self_attention.linear[0].weight = qkv_w
916
+ layer_spec.self_attention.linear[0].bias = qkv_b
917
+
918
+ self.set_linear(layer_spec.self_attention.linear[1], layer.attention.dense)
919
+
920
+ self.set_linear(layer_spec.ffn.linear_0, layer.mlp.dense_h_to_4h)
921
+ self.set_linear(layer_spec.ffn.linear_1, layer.mlp.dense_4h_to_h)
922
+
923
+
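The reshape/swapaxes in `set_decoder` above turns GPT-NeoX's per-head interleaved `query_key_value` rows into the fused layout expected by the spec: all query rows first, then all key rows, then all value rows. A tiny check with made-up sizes (2 heads, head_dim 1, hidden size 4):

```python
import torch

# Rows are labeled 0..5 so the reordering is visible: [q0, k0, v0, q1, k1, v1].
num_heads, hidden = 2, 4
qkv_w = torch.arange(num_heads * 3).float()[:, None].expand(-1, hidden)
out = (
    qkv_w.reshape(num_heads, 3, -1, qkv_w.shape[-1])
    .swapaxes(0, 1)
    .reshape(-1, qkv_w.shape[-1])
)
print(out[:, 0].tolist())  # [0.0, 3.0, 1.0, 4.0, 2.0, 5.0] -> [q0, q1, k0, k1, v0, v1]
```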
924
+ @register_loader("WhisperConfig")
925
+ class WhisperLoader(BartLoader):
926
+ @property
927
+ def architecture_name(self):
928
+ return "WhisperForConditionalGeneration"
929
+
930
+ def get_model_spec(self, model):
931
+ spec = whisper_spec.WhisperSpec(
932
+ model.config.encoder_layers,
933
+ model.config.encoder_attention_heads,
934
+ model.config.decoder_layers,
935
+ model.config.decoder_attention_heads,
936
+ )
937
+
938
+ self.set_encoder(spec.encoder, model.model.encoder)
939
+ self.set_decoder(spec.decoder, model.model.decoder)
940
+ self.set_linear(spec.decoder.projection, model.proj_out)
941
+
942
+ return spec
943
+
944
+ def _get_lang_ids_from_tokenizer(self, tokenizer):
945
+ non_lang_special_tokens = [
946
+ "<|endoftext|>",
947
+ "<|startoftranscript|>",
948
+ "<|translate|>",
949
+ "<|transcribe|>",
950
+ "<|startoflm|>",
951
+ "<|startofprev|>",
952
+ "<|nocaptions|>",
953
+ "<|notimestamps|>",
954
+ ]
955
+
956
+ additional_tokens = getattr(tokenizer, "additional_special_tokens", [])
957
+ if not additional_tokens:
958
+ return []
959
+
960
+ return [
961
+ tokenizer.convert_tokens_to_ids(token)
962
+ for token in additional_tokens
963
+ if token not in non_lang_special_tokens
964
+ ]
965
+
966
+ def set_config(self, config, model, tokenizer):
967
+ gen_config = getattr(model, "generation_config", None)
968
+
969
+ if gen_config is not None:
970
+ config.suppress_ids = gen_config.suppress_tokens
971
+ config.suppress_ids_begin = gen_config.begin_suppress_tokens
972
+ if hasattr(gen_config, "alignment_heads"):
973
+ config.alignment_heads = gen_config.alignment_heads
974
+ if hasattr(gen_config, "lang_to_id"):
975
+ config.lang_ids = sorted(gen_config.lang_to_id.values())
976
+ else:
977
+ config.suppress_ids = model.config.suppress_tokens
978
+ config.suppress_ids_begin = model.config.begin_suppress_tokens
979
+ config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
980
+
981
+ if getattr(config, "lang_ids", None) is None:
982
+ config.lang_ids = self._get_lang_ids_from_tokenizer(tokenizer)
983
+
984
+ if config.alignment_heads is None:
985
+ # By default, use every head in the last half of the decoder layers for alignment.
986
+ num_layers = model.config.decoder_layers
987
+ num_heads = model.config.decoder_attention_heads
988
+ config.alignment_heads = list(
989
+ itertools.product(
990
+ range(num_layers // 2, num_layers),
991
+ range(num_heads),
992
+ )
993
+ )
994
+
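When neither the generation config nor the `_WHISPER_ALIGNMENT_HEADS` table provides alignment heads, the default computed above is simply every head of the last half of the decoder layers. For a toy decoder with 4 layers and 2 heads:

```python
import itertools

# Default alignment heads for a made-up decoder with 4 layers and 2 heads.
num_layers, num_heads = 4, 2
print(list(itertools.product(range(num_layers // 2, num_layers), range(num_heads))))
# [(2, 0), (2, 1), (3, 0), (3, 1)]
```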
995
+ def get_vocabulary(self, model, tokenizer):
996
+ tokens = super().get_vocabulary(model, tokenizer)
997
+
998
+ # Add timestamp tokens.
999
+ tokens.extend(
1000
+ "<|%.2f|>" % (i * 0.02)
1001
+ for i in range(model.config.vocab_size - len(tokens))
1002
+ )
1003
+
1004
+ return tokens
1005
+
1006
+ def set_vocabulary(self, spec, tokens):
1007
+ spec.register_vocabulary(tokens)
1008
+
1009
+ def set_encoder(self, spec, encoder):
1010
+ self.set_conv1d(spec.conv1, encoder.conv1)
1011
+ self.set_conv1d(spec.conv2, encoder.conv2)
1012
+ super().set_encoder(spec, encoder)
1013
+
1014
+ def set_decoder(self, spec, decoder):
1015
+ self.set_embeddings(spec.embeddings, decoder.embed_tokens)
1016
+ super().set_decoder(spec, decoder)
1017
+
1018
+ def set_common_layers(self, spec, module):
1019
+ self.set_position_encodings(spec.position_encodings, module.embed_positions)
1020
+ self.set_layer_norm(spec.layer_norm, module.layer_norm)
1021
+
1022
+ def set_conv1d(self, spec, module):
1023
+ spec.weight = module.weight
1024
+ spec.bias = module.bias
1025
+
1026
+
1027
+ @register_loader("Wav2Vec2Config")
1028
+ class Wav2Vec2Loader(BartLoader):
1029
+ @property
1030
+ def architecture_name(self):
1031
+ return "Wav2Vec2ForCTC"
1032
+
1033
+ def get_model_spec(self, model):
1034
+ return_hidden = getattr(model.wav2vec2.config, "return_hidden", False)
1035
+ spec = wav2vec2_spec.Wav2Vec2Spec(
1036
+ model.wav2vec2.config.num_feat_extract_layers,
1037
+ model.wav2vec2.encoder.config.num_hidden_layers,
1038
+ model.wav2vec2.encoder.config.num_attention_heads,
1039
+ model.lm_head.weight.shape[0],
1040
+ return_hidden,
1041
+ )
1042
+
1043
+ # Alias the encoder submodules to the attribute names expected by BartLoader (aliases only, no weights are duplicated).
1044
+ for layer in model.wav2vec2.encoder.layers:
1045
+ layer.self_attn = layer.attention
1046
+ layer.self_attn_layer_norm = layer.layer_norm
1047
+ layer.activation_fn = layer.feed_forward.intermediate_act_fn
1048
+ layer.fc1 = layer.feed_forward.intermediate_dense
1049
+ layer.fc2 = layer.feed_forward.output_dense
1050
+
1051
+ self.set_encoder(spec.encoder, model, model.wav2vec2.config)
1052
+ return spec
1053
+
1054
+ def set_config(self, config, model, tokenizer):
1055
+ return
1056
+
1057
+ def get_vocabulary(self, model, tokenizer):
1058
+ return tokenizer.get_vocab()
1059
+
1060
+ def set_vocabulary(self, spec, tokens):
1061
+ spec.register_vocabulary(tokens)
1062
+
1063
+ def set_feature_extractor(self, spec, feature_extractor):
1064
+ spec.feat_layer0.conv.weight = feature_extractor.conv_layers[0].conv.weight
1065
+ spec.feat_layer0.conv.bias = feature_extractor.conv_layers[0].conv.bias
1066
+ self.set_layer_norm(
1067
+ spec.feat_layer0.layer_norm, feature_extractor.conv_layers[0].layer_norm
1068
+ )
1069
+ for spec_layer, module_layer in zip(
1070
+ spec.feat_layer, feature_extractor.conv_layers[1:]
1071
+ ):
1072
+ spec_layer.conv.weight = module_layer.conv.weight
1073
+ spec_layer.conv.bias = module_layer.conv.bias
1074
+ self.set_layer_norm(spec_layer.layer_norm, module_layer.layer_norm)
1075
+
1076
+ def set_feature_projection(self, spec, feature_projection):
1077
+ self.set_layer_norm(spec.fp_layer_norm, feature_projection.layer_norm)
1078
+ self.set_linear(spec.fp_projection, feature_projection.projection)
1079
+
1080
+ def set_pos_conv_embed(self, spec, encoder, config):
1081
+ # Force the parameters to be materialized because some transformers versions initialize them with garbage values.
1082
+ # The conv parameters may be float16, so cast them to float32 for loading.
1083
+ encoder.pos_conv_embed.conv.weight.data = (
1084
+ encoder.pos_conv_embed.conv.weight.data.float()
1085
+ )
1086
+ encoder.pos_conv_embed.conv.bias.data = encoder.pos_conv_embed.conv.bias.float()
1087
+ for param in encoder.pos_conv_embed.parameters():
1088
+ param.data = param.data.float()
1089
+ encoder.pos_conv_embed(torch.randn((1, 1, config.hidden_size)))
1090
+ spec.pos_conv_embed.conv.weight = encoder.pos_conv_embed.conv.weight
1091
+ spec.pos_conv_embed.conv.bias = encoder.pos_conv_embed.conv.bias
1092
+
1093
+ def set_encoder(self, spec, model, config):
1094
+ self.set_feature_extractor(spec, model.wav2vec2.feature_extractor)
1095
+ self.set_feature_projection(spec, model.wav2vec2.feature_projection)
1096
+ self.set_pos_conv_embed(spec, model.wav2vec2.encoder, config)
1097
+ super().set_encoder(spec, model.wav2vec2.encoder)
1098
+ return_hidden = getattr(model.wav2vec2.config, "return_hidden", False)
1099
+ if not return_hidden:
1100
+ self.set_linear(spec.lm_head, model.lm_head)
1101
+
1102
+ def set_common_layers(self, spec, module):
1103
+ self.set_layer_norm(spec.layer_norm, module.layer_norm)
1104
+
1105
+
1106
+ @register_loader("Wav2Vec2BertConfig")
1107
+ class Wav2Vec2BertLoader(BartLoader):
1108
+ @property
1109
+ def architecture_name(self):
1110
+ return "Wav2Vec2BertForCTC"
1111
+
1112
+ def get_model_spec(self, model):
1113
+ return_hidden = getattr(model.wav2vec2_bert.config, "return_hidden", False)
1114
+ spec = wav2vec2bert_spec.Wav2Vec2BertSpec(
1115
+ model.wav2vec2_bert.config.num_adapter_layers,
1116
+ model.wav2vec2_bert.config.num_hidden_layers,
1117
+ model.lm_head.weight.shape[0],
1118
+ return_hidden,
1119
+ )
1120
+ self.set_encoder(spec.encoder, model)
1121
+ return spec
1122
+
1123
+ def set_config(self, config, model, tokenizer):
1124
+ return
1125
+
1126
+ def get_vocabulary(self, model, tokenizer):
1127
+ return tokenizer.get_vocab()
1128
+
1129
+ def set_vocabulary(self, spec, tokens):
1130
+ spec.register_vocabulary(tokens)
1131
+
1132
+ def set_feature_projection(self, spec, feature_projection):
1133
+ self.set_layer_norm(spec.fp_layer_norm, feature_projection.layer_norm)
1134
+ self.set_linear(spec.fp_projection, feature_projection.projection)
1135
+
1136
+ def set_attention(
1137
+ self, spec, attention, left_max_position=None, right_max_position=None
1138
+ ):
1139
+ split_layers = [common_spec.LinearSpec() for _ in range(3)]
1140
+ self.set_linear(split_layers[0], attention.linear_q)
1141
+ self.set_linear(split_layers[1], attention.linear_k)
1142
+ self.set_linear(split_layers[2], attention.linear_v)
1143
+ utils.fuse_linear(spec.linear[0], split_layers)
1144
+ self.set_linear(spec.linear[-1], attention.linear_out)
1145
+ if left_max_position or right_max_position:
1146
+ spec.relative_asymmetric_position_keys = attention.distance_embedding.weight
1147
+ spec.relative_left_max_position = np.dtype("int32").type(left_max_position)
1148
+ spec.relative_right_max_position = np.dtype("int32").type(
1149
+ right_max_position
1150
+ )
1151
+
1152
+ def set_wav2vec2bert_encoder(
1153
+ self, spec_layers, layers, left_max_position, right_max_position
1154
+ ):
1155
+ for slayer, layer in zip(spec_layers, layers):
1156
+ self.set_layer_norm(slayer.enc_ffn1_layer_norm, layer.ffn1_layer_norm)
1157
+ self.set_linear(slayer.enc_ffn1.linear_0, layer.ffn1.intermediate_dense)
1158
+ self.set_linear(slayer.enc_ffn1.linear_1, layer.ffn1.output_dense)
1159
+ self.set_attention(
1160
+ slayer.enc_attn, layer.self_attn, left_max_position, right_max_position
1161
+ )
1162
+ self.set_layer_norm(slayer.enc_attn_layer_norm, layer.self_attn_layer_norm)
1163
+ self.set_layer_norm(
1164
+ slayer.enc_conv_layer_norm, layer.conv_module.layer_norm
1165
+ )
1166
+ self.set_conv1d(
1167
+ slayer.enc_conv_pointwise_conv1, layer.conv_module.pointwise_conv1
1168
+ )
1169
+ self.set_conv1d(
1170
+ slayer.enc_conv_depthwise_conv, layer.conv_module.depthwise_conv
1171
+ )
1172
+ self.set_layer_norm(
1173
+ slayer.enc_conv_depthwise_layer_norm,
1174
+ layer.conv_module.depthwise_layer_norm,
1175
+ )
1176
+ self.set_conv1d(
1177
+ slayer.enc_conv_pointwise_conv2, layer.conv_module.pointwise_conv2
1178
+ )
1179
+ self.set_layer_norm(slayer.enc_ffn2_layer_norm, layer.ffn2_layer_norm)
1180
+ self.set_linear(slayer.enc_ffn2.linear_0, layer.ffn2.intermediate_dense)
1181
+ self.set_linear(slayer.enc_ffn2.linear_1, layer.ffn2.output_dense)
1182
+ self.set_layer_norm(slayer.enc_final_layer_norm, layer.final_layer_norm)
1183
+
1184
+ def set_wav2vec2bert_adapter(self, spec_layers, layers):
1185
+ for slayer, layer in zip(spec_layers, layers):
1186
+ self.set_layer_norm(
1187
+ slayer.adpt_residual_layer_norm, layer.residual_layer_norm
1188
+ )
1189
+ self.set_conv1d(slayer.adpt_residual_conv, layer.residual_conv)
1190
+ self.set_layer_norm(slayer.adpt_attn_layer_norm, layer.self_attn_layer_norm)
1191
+ self.set_conv1d(slayer.adpt_attn_conv, layer.self_attn_conv)
1192
+ self.set_attention(slayer.adpt_attn_layer, layer.self_attn)
1193
+ self.set_layer_norm(slayer.adpt_ffn_layer_norm, layer.ffn_layer_norm)
1194
+ self.set_linear(slayer.adpt_ffn.linear_0, layer.ffn.intermediate_dense)
1195
+ self.set_linear(slayer.adpt_ffn.linear_1, layer.ffn.output_dense)
1196
+
1197
+ def set_encoder(self, spec, model):
1198
+ self.set_feature_projection(spec, model.wav2vec2_bert.feature_projection)
1199
+ self.set_wav2vec2bert_encoder(
1200
+ spec.encoder_layers,
1201
+ model.wav2vec2_bert.encoder.layers,
1202
+ model.wav2vec2_bert.config.left_max_position_embeddings,
1203
+ model.wav2vec2_bert.config.right_max_position_embeddings,
1204
+ )
1205
+ self.set_wav2vec2bert_adapter(
1206
+ spec.adapter_layers, model.wav2vec2_bert.adapter.layers
1207
+ )
1208
+ return_hidden = getattr(model.wav2vec2_bert.config, "return_hidden", False)
1209
+ if not return_hidden:
1210
+ self.set_linear(spec.lm_head, model.lm_head)
1211
+
1212
+ def set_conv1d(self, spec, module):
1213
+ spec.weight = module.weight
1214
+ if module.bias is not None:
1215
+ spec.bias = module.bias
1216
+
1217
+ def set_layer_norm(self, spec, module):
1218
+ spec.gamma = module.weight
1219
+ if module.bias is not None:
1220
+ spec.beta = module.bias
1221
+
1222
+
1223
+ @register_loader("T5Config")
1224
+ class T5Loader(ModelLoader):
1225
+ @property
1226
+ def architecture_name(self):
1227
+ return "T5ForConditionalGeneration"
1228
+
1229
+ def get_model_spec(self, model):
1230
+ spec = transformer_spec.TransformerSpec.from_config(
1231
+ (model.config.num_layers, model.config.num_decoder_layers),
1232
+ model.config.num_heads,
1233
+ pre_norm=True,
1234
+ activation=_SUPPORTED_ACTIVATIONS[model.config.dense_act_fn],
1235
+ ffn_glu=model.config.is_gated_act,
1236
+ relative_attention_bias=True,
1237
+ rms_norm=True,
1238
+ )
1239
+
1240
+ self.set_stack(spec.encoder, model.encoder)
1241
+ self.set_stack(spec.decoder, model.decoder, is_decoder=True)
1242
+ self.set_linear(spec.decoder.projection, model.lm_head)
1243
+
1244
+ if model.config.tie_word_embeddings:
1245
+ spec.decoder.scale_outputs = model.config.d_model**-0.5
1246
+
1247
+ return spec
1248
+
1249
+ def get_vocabulary(self, model, tokenizer):
1250
+ tokens = super().get_vocabulary(model, tokenizer)
1251
+
1252
+ extra_ids = model.config.vocab_size - len(tokens)
1253
+ for i in range(extra_ids):
1254
+ tokens.append("<extra_id_%d>" % i)
1255
+
1256
+ return tokens
1257
+
1258
+ def set_vocabulary(self, spec, tokens):
1259
+ spec.register_source_vocabulary(tokens)
1260
+ spec.register_target_vocabulary(tokens)
1261
+
1262
+ def set_config(self, config, model, tokenizer):
1263
+ config.bos_token = tokenizer.pad_token
1264
+ config.eos_token = tokenizer.eos_token
1265
+ config.unk_token = tokenizer.unk_token
1266
+ if hasattr(model.config, "decoder_start_token_id"):
1267
+ config.decoder_start_token = tokenizer.convert_ids_to_tokens(
1268
+ model.config.decoder_start_token_id
1269
+ )
1270
+ else:
1271
+ config.decoder_start_token = tokenizer.pad_token
1272
+
1273
+ def set_stack(self, spec, module, is_decoder=False):
1274
+ self.set_layer_norm(spec.layer_norm, module.final_layer_norm)
1275
+ self.set_embeddings(
1276
+ (
1277
+ spec.embeddings[0]
1278
+ if isinstance(spec.embeddings, list)
1279
+ else spec.embeddings
1280
+ ),
1281
+ module.embed_tokens,
1282
+ )
1283
+
1284
+ spec.scale_embeddings = False
1285
+
1286
+ for i, (layer_spec, block) in enumerate(zip(spec.layer, module.block)):
1287
+ self.set_self_attention(layer_spec.self_attention, block.layer[0])
1288
+
1289
+ if i > 0:
1290
+ # Reuse relative attention bias from the first layer.
1291
+ first_self_attention = spec.layer[0].self_attention
1292
+ layer_spec.self_attention.relative_attention_bias = (
1293
+ first_self_attention.relative_attention_bias
1294
+ )
1295
+ layer_spec.self_attention.relative_attention_max_distance = (
1296
+ first_self_attention.relative_attention_max_distance
1297
+ )
1298
+
1299
+ if is_decoder:
1300
+ self.set_cross_attention(layer_spec.attention, block.layer[1])
1301
+
1302
+ self.set_ffn(layer_spec.ffn, block.layer[-1])
1303
+
1304
+ def set_ffn(self, spec, module):
1305
+ if hasattr(spec, "linear_0_noact"):
1306
+ self.set_linear(spec.linear_0, module.DenseReluDense.wi_0)
1307
+ self.set_linear(spec.linear_0_noact, module.DenseReluDense.wi_1)
1308
+ else:
1309
+ self.set_linear(spec.linear_0, module.DenseReluDense.wi)
1310
+
1311
+ self.set_linear(spec.linear_1, module.DenseReluDense.wo)
1312
+ self.set_layer_norm(spec.layer_norm, module.layer_norm)
1313
+
1314
+ def set_self_attention(self, spec, module):
1315
+ self.set_attention(spec, module.SelfAttention, self_attention=True)
1316
+ self.set_layer_norm(spec.layer_norm, module.layer_norm)
1317
+
1318
+ def set_cross_attention(self, spec, module):
1319
+ self.set_attention(spec, module.EncDecAttention)
1320
+ self.set_layer_norm(spec.layer_norm, module.layer_norm)
1321
+
1322
+ def set_attention(self, spec, attention, self_attention=False):
1323
+ spec.queries_scale = 1.0
1324
+
1325
+ split_layers = [common_spec.LinearSpec() for _ in range(3)]
1326
+ self.set_linear(split_layers[0], attention.q)
1327
+ self.set_linear(split_layers[1], attention.k)
1328
+ self.set_linear(split_layers[2], attention.v)
1329
+
1330
+ if self_attention:
1331
+ utils.fuse_linear(spec.linear[0], split_layers)
1332
+ else:
1333
+ utils.fuse_linear(spec.linear[0], split_layers[:1])
1334
+ utils.fuse_linear(spec.linear[1], split_layers[1:])
1335
+
1336
+ self.set_linear(spec.linear[-1], attention.o)
1337
+
1338
+ if attention.has_relative_attention_bias:
1339
+ spec.relative_attention_bias = attention.relative_attention_bias.weight
1340
+ spec.relative_attention_max_distance = np.dtype("int32").type(
1341
+ attention.relative_attention_max_distance
1342
+ )
1343
+
1344
+ def set_layer_norm(self, spec, layer_norm):
1345
+ spec.gamma = layer_norm.weight
1346
+
1347
+
1348
+ @register_loader("MT5Config")
1349
+ class MT5Loader(T5Loader):
1350
+ @property
1351
+ def architecture_name(self):
1352
+ return "MT5ForConditionalGeneration"
1353
+
1354
+
1355
+ @register_loader("BloomConfig")
1356
+ class BloomLoader(ModelLoader):
1357
+ @property
1358
+ def architecture_name(self):
1359
+ return "BloomForCausalLM"
1360
+
1361
+ def get_model_spec(self, model):
1362
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
1363
+ model.config.n_layer,
1364
+ model.config.n_head,
1365
+ pre_norm=True,
1366
+ activation=common_spec.Activation.GELUTanh,
1367
+ layernorm_embedding=True,
1368
+ alibi=True,
1369
+ alibi_use_positive_positions=True,
1370
+ )
1371
+
1372
+ self.set_decoder(spec.decoder, model.transformer)
1373
+ self.set_linear(spec.decoder.projection, model.lm_head)
1374
+ return spec
1375
+
1376
+ def get_vocabulary(self, model, tokenizer):
1377
+ tokens = super().get_vocabulary(model, tokenizer)
1378
+
1379
+ extra_ids = model.config.vocab_size - len(tokens)
1380
+ for i in range(extra_ids):
1381
+ tokens.append("<extra_id_%d>" % i)
1382
+
1383
+ return tokens
1384
+
1385
+ def set_vocabulary(self, spec, tokens):
1386
+ spec.register_vocabulary(tokens)
1387
+
1388
+ def set_config(self, config, model, tokenizer):
1389
+ config.bos_token = tokenizer.bos_token
1390
+ config.eos_token = tokenizer.eos_token
1391
+ config.unk_token = tokenizer.unk_token
1392
+
1393
+ def set_decoder(self, spec, module):
1394
+ spec.scale_embeddings = False
1395
+ self.set_embeddings(spec.embeddings, module.word_embeddings)
1396
+ self.set_layer_norm(spec.layernorm_embedding, module.word_embeddings_layernorm)
1397
+ self.set_layer_norm(spec.layer_norm, module.ln_f)
1398
+
1399
+ for layer_spec, layer in zip(spec.layer, module.h):
1400
+ self.set_layer_norm(
1401
+ layer_spec.self_attention.layer_norm, layer.input_layernorm
1402
+ )
1403
+ self.set_qkv_linear(
1404
+ layer_spec.self_attention.linear[0],
1405
+ layer.self_attention.query_key_value,
1406
+ layer.self_attention.num_heads,
1407
+ )
1408
+ self.set_linear(
1409
+ layer_spec.self_attention.linear[1], layer.self_attention.dense
1410
+ )
1411
+
1412
+ self.set_layer_norm(
1413
+ layer_spec.ffn.layer_norm, layer.post_attention_layernorm
1414
+ )
1415
+ self.set_linear(layer_spec.ffn.linear_0, layer.mlp.dense_h_to_4h)
1416
+ self.set_linear(layer_spec.ffn.linear_1, layer.mlp.dense_4h_to_h)
1417
+
1418
+ def set_qkv_linear(self, spec, module, num_heads):
1419
+ weight = module.weight
1420
+ weight = weight.reshape(num_heads, 3, -1, weight.shape[-1])
1421
+ weight = weight.transpose(0, 1)
1422
+ weight = weight.reshape(-1, weight.shape[-1])
1423
+
1424
+ bias = module.bias
1425
+ bias = bias.reshape(num_heads, 3, -1)
1426
+ bias = bias.transpose(0, 1)
1427
+ bias = bias.reshape(-1)
1428
+
1429
+ spec.weight = weight
1430
+ spec.bias = bias
1431
+
1432
+
1433
+ @register_loader("MPTConfig")
1434
+ class MPTLoader(ModelLoader):
1435
+ @property
1436
+ def architecture_name(self):
1437
+ return "AutoModelForCausalLM"
1438
+
1439
+ def get_model_spec(self, model):
1440
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
1441
+ model.config.n_layers,
1442
+ model.config.n_heads,
1443
+ pre_norm=True,
1444
+ activation=common_spec.Activation.GELU,
1445
+ alibi=True,
1446
+ )
1447
+
1448
+ self.set_decoder(spec.decoder, model.transformer)
1449
+ return spec
1450
+
1451
+ def get_vocabulary(self, model, tokenizer):
1452
+ tokens = super().get_vocabulary(model, tokenizer)
1453
+
1454
+ extra_ids = model.config.vocab_size - len(tokens)
1455
+ for i in range(extra_ids):
1456
+ tokens.append("<extra_id_%d>" % i)
1457
+
1458
+ return tokens
1459
+
1460
+ def set_vocabulary(self, spec, tokens):
1461
+ spec.register_vocabulary(tokens)
1462
+
1463
+ def set_config(self, config, model, tokenizer):
1464
+ config.bos_token = tokenizer.bos_token
1465
+ config.eos_token = tokenizer.eos_token
1466
+ config.unk_token = tokenizer.unk_token
1467
+
1468
+ def set_decoder(self, spec, module):
1469
+ self.set_embeddings(spec.embeddings, module.wte)
1470
+ self.set_layer_norm(spec.layer_norm, module.norm_f)
1471
+
1472
+ spec.scale_embeddings = False
1473
+ spec.projection.weight = spec.embeddings.weight
1474
+
1475
+ for layer_spec, layer in zip(spec.layer, module.blocks):
1476
+ self.set_layer_norm(layer_spec.self_attention.layer_norm, layer.norm_1)
1477
+ self.set_linear(layer_spec.self_attention.linear[0], layer.attn.Wqkv)
1478
+ self.set_linear(layer_spec.self_attention.linear[1], layer.attn.out_proj)
1479
+
1480
+ self.set_layer_norm(layer_spec.ffn.layer_norm, layer.norm_2)
1481
+ self.set_linear(layer_spec.ffn.linear_0, layer.ffn.up_proj)
1482
+ self.set_linear(layer_spec.ffn.linear_1, layer.ffn.down_proj)
1483
+
1484
+ def set_layer_norm(self, spec, module):
1485
+ spec.gamma = module.weight
1486
+ spec.beta = torch.zeros_like(spec.gamma)
1487
+
1488
+
1489
+ @register_loader("GemmaConfig")
1490
+ class GemmaLoader(ModelLoader):
1491
+ @property
1492
+ def architecture_name(self):
1493
+ return "GemmaForCausalLM"
1494
+
1495
+ def get_model_spec(self, model):
1496
+ num_layers = model.config.num_hidden_layers
1497
+
1498
+ num_heads = model.config.num_attention_heads
1499
+ num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
1500
+ if num_heads_kv == num_heads:
1501
+ num_heads_kv = None
1502
+
1503
+ activation_config = getattr(
1504
+ model.config, "hidden_activation", "gelu_pytorch_tanh"
1505
+ )
1506
+
1507
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
1508
+ num_layers,
1509
+ num_heads,
1510
+ activation=(
1511
+ common_spec.Activation.GELU
1512
+ if activation_config == "gelu"
1513
+ else common_spec.Activation.GELUTanh
1514
+ ),
1515
+ pre_norm=True,
1516
+ ffn_glu=True,
1517
+ rms_norm=True,
1518
+ rotary_dim=0,
1519
+ rotary_interleave=False,
1520
+ rotary_base=getattr(model.config, "rope_theta", 10000),
1521
+ num_heads_kv=num_heads_kv,
1522
+ head_dim=model.config.head_dim,
1523
+ )
1524
+
1525
+ self.set_decoder(spec.decoder, model.model)
1526
+ self.set_linear(spec.decoder.projection, model.lm_head)
1527
+ spec.decoder.embeddings.multiply_by_sqrt_depth = model.config.hidden_size**0.5
1528
+ return spec
1529
+
1530
+ def get_vocabulary(self, model, tokenizer):
1531
+ tokens = super().get_vocabulary(model, tokenizer)
1532
+
1533
+ extra_ids = model.config.vocab_size - len(tokens)
1534
+ for i in range(extra_ids):
1535
+ tokens.append("<extra_id_%d>" % i)
1536
+ if model.config.vocab_size < len(tokens):
1537
+ tokens = tokens[: model.config.vocab_size]
1538
+
1539
+ return tokens
1540
+
1541
+ def set_vocabulary(self, spec, tokens):
1542
+ spec.register_vocabulary(tokens)
1543
+
1544
+ def set_config(self, config, model, tokenizer):
1545
+ config.bos_token = tokenizer.bos_token
1546
+ config.eos_token = tokenizer.eos_token
1547
+ config.unk_token = tokenizer.unk_token
1548
+ config.layer_norm_epsilon = model.config.rms_norm_eps
1549
+
1550
+ def set_layer_norm(self, spec, layer_norm):
1551
+ spec.gamma = layer_norm.weight
1552
+ spec.layer_norm_use_residual = True
1553
+
1554
+ def set_decoder(self, spec, module):
1555
+ spec.scale_embeddings = True
1556
+ spec.start_from_zero_embedding = False
1557
+ self.set_embeddings(spec.embeddings, module.embed_tokens)
1558
+ self.set_layer_norm(spec.layer_norm, module.norm)
1559
+
1560
+ for layer_spec, layer in zip(spec.layer, module.layers):
1561
+ self.set_layer_norm(
1562
+ layer_spec.self_attention.layer_norm, layer.input_layernorm
1563
+ )
1564
+ self.set_layer_norm(
1565
+ layer_spec.ffn.layer_norm, layer.post_attention_layernorm
1566
+ )
1567
+
1568
+ wq = layer.self_attn.q_proj.weight
1569
+ wk = layer.self_attn.k_proj.weight
1570
+ wv = layer.self_attn.v_proj.weight
1571
+ wo = layer.self_attn.o_proj.weight
1572
+
1573
+ layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv])
1574
+ layer_spec.self_attention.linear[1].weight = wo
1575
+
1576
+ self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
1577
+ self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
1578
+ self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
1579
+
1580
+ delattr(layer, "self_attn")
1581
+ delattr(layer, "mlp")
1582
+ gc.collect()
1583
+
1584
+
1585
+ @register_loader("Gemma2Config")
1586
+ class Gemma2Loader(ModelLoader):
1587
+ @property
1588
+ def architecture_name(self):
1589
+ return "Gemma2ForCausalLM"
1590
+
1591
+ def get_model_spec(self, model):
1592
+ num_layers = model.config.num_hidden_layers
1593
+
1594
+ num_heads = model.config.num_attention_heads
1595
+ num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
1596
+ if num_heads_kv == num_heads:
1597
+ num_heads_kv = None
1598
+
1599
+ activation_config = getattr(
1600
+ model.config, "hidden_activation", "gelu_pytorch_tanh"
1601
+ )
1602
+
1603
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
1604
+ num_layers,
1605
+ num_heads,
1606
+ activation=(
1607
+ common_spec.Activation.GELU
1608
+ if activation_config == "gelu"
1609
+ else common_spec.Activation.GELUTanh
1610
+ ),
1611
+ pre_norm=True,
1612
+ ffn_glu=True,
1613
+ rms_norm=True,
1614
+ rotary_dim=0,
1615
+ rotary_interleave=False,
1616
+ rotary_base=getattr(model.config, "rope_theta", 10000),
1617
+ num_heads_kv=num_heads_kv,
1618
+ head_dim=model.config.head_dim,
1619
+ pre_post_layer_norm=True,
1620
+ )
1621
+
1622
+ self.set_decoder(spec.decoder, model.model)
1623
+ self.set_linear(spec.decoder.projection, model.lm_head)
1624
+ spec.decoder.embeddings.multiply_by_sqrt_depth = model.config.hidden_size**0.5
1625
+ return spec
1626
+
1627
+ def get_vocabulary(self, model, tokenizer):
1628
+ tokens = super().get_vocabulary(model, tokenizer)
1629
+
1630
+ extra_ids = model.config.vocab_size - len(tokens)
1631
+ for i in range(extra_ids):
1632
+ tokens.append("<extra_id_%d>" % i)
1633
+ if model.config.vocab_size < len(tokens):
1634
+ tokens = tokens[: model.config.vocab_size]
1635
+
1636
+ return tokens
1637
+
1638
+ def set_vocabulary(self, spec, tokens):
1639
+ spec.register_vocabulary(tokens)
1640
+
1641
+ def set_config(self, config, model, tokenizer):
1642
+ config.bos_token = tokenizer.bos_token
1643
+ config.eos_token = tokenizer.eos_token
1644
+ config.unk_token = tokenizer.unk_token
1645
+ config.layer_norm_epsilon = model.config.rms_norm_eps
1646
+
1647
+ def set_layer_norm(self, spec, layer_norm):
1648
+ spec.gamma = layer_norm.weight
1649
+ spec.layer_norm_use_residual = True
1650
+
1651
+ def set_decoder(self, spec, module):
1652
+ spec.scale_embeddings = True
1653
+ spec.start_from_zero_embedding = False
1654
+ self.set_embeddings(spec.embeddings, module.embed_tokens)
1655
+ self.set_layer_norm(spec.layer_norm, module.norm)
1656
+
1657
+ for layer_spec, layer in zip(spec.layer, module.layers):
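+ # Gemma 2 applies four norms per layer: input, post-attention, pre-feed-forward, and post-feed-forward.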
1658
+ self.set_layer_norm(layer_spec.input_layer_norm, layer.input_layernorm)
1659
+
1660
+ self.set_layer_norm(
1661
+ layer_spec.post_attention_layer_norm, layer.post_attention_layernorm
1662
+ )
1663
+
1664
+ self.set_layer_norm(
1665
+ layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
1666
+ )
1667
+
1668
+ self.set_layer_norm(
1669
+ layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
1670
+ )
1671
+
1672
+ wq = layer.self_attn.q_proj.weight
1673
+ wk = layer.self_attn.k_proj.weight
1674
+ wv = layer.self_attn.v_proj.weight
1675
+ wo = layer.self_attn.o_proj.weight
1676
+
1677
+ layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv])
1678
+ layer_spec.self_attention.linear[1].weight = wo
1679
+
1680
+ self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
1681
+ self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
1682
+ self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
1683
+
1684
+ delattr(layer, "self_attn")
1685
+ delattr(layer, "mlp")
1686
+ gc.collect()
1687
+
1688
+
1689
+ @register_loader("LlamaConfig")
1690
+ class LlamaLoader(ModelLoader):
1691
+ @property
1692
+ def architecture_name(self):
1693
+ return "LlamaForCausalLM"
1694
+
1695
+ def get_model_spec(self, model):
1696
+ num_layers = model.config.num_hidden_layers
1697
+
1698
+ num_heads = model.config.num_attention_heads
1699
+ num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
1700
+ if num_heads_kv == num_heads:
1701
+ num_heads_kv = None
1702
+
1703
+ rotary_scaling_type, rotary_scaling_factor, rope_theta = self.get_rotary_params(
1704
+ model.config, 10_000
1705
+ )
1706
+
1707
+ quantization_config = getattr(model.config, "quantization_config", None)
1708
+ if quantization_config:
1709
+ quant_type = None
1710
+ if quantization_config.quant_method == "awq":
1711
+ quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
1712
+ if quant_type is None:
1713
+ raise NotImplementedError(
1714
+ "Quantization type '%s' is not yet implemented. "
1715
+ "The following Quantization types are currently supported: %s"
1716
+ % (
1717
+ quantization_config.quant_method,
1718
+ ", ".join(_SUPPORTED_QUANTIZATION.keys()),
1719
+ )
1720
+ )
1721
+ quant_group_size = quantization_config.group_size
1722
+ quant_bits = quantization_config.bits
1723
+ else:
1724
+ quant_type = common_spec.Quantization.CT2
1725
+ quant_group_size = None
1726
+ quant_bits = None
1727
+
1728
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
1729
+ num_layers,
1730
+ num_heads,
1731
+ activation=common_spec.Activation.SWISH,
1732
+ pre_norm=True,
1733
+ ffn_glu=True,
1734
+ rms_norm=True,
1735
+ rotary_dim=0,
1736
+ rotary_interleave=False,
1737
+ rotary_scaling_type=rotary_scaling_type,
1738
+ rotary_scaling_factor=rotary_scaling_factor,
1739
+ rotary_base=rope_theta,
1740
+ num_heads_kv=num_heads_kv,
1741
+ quant_type=quant_type,
1742
+ quant_group_size=quant_group_size,
1743
+ quant_bits=quant_bits,
1744
+ )
1745
+
1746
+ self.set_decoder(spec.decoder, model.model, quant_type)
1747
+ self.set_linear(spec.decoder.projection, model.lm_head)
1748
+
1749
+ # Set extra RoPE parameters for Llama-3.1
1750
+ rope_scaling = getattr(model.config, "rope_scaling", None)
1751
+ if rotary_scaling_type == attention_spec.RotaryScalingType.Llama3:
1752
+ for layer in spec.decoder.layer:
1753
+ layer.self_attention.rotary_low_freq_factor = rope_scaling[
1754
+ "low_freq_factor"
1755
+ ]
1756
+ layer.self_attention.rotary_high_freq_factor = rope_scaling[
1757
+ "high_freq_factor"
1758
+ ]
1759
+ return spec
1760
+
1761
+ def get_vocabulary(self, model, tokenizer):
1762
+ tokens = super().get_vocabulary(model, tokenizer)
1763
+
1764
+ extra_ids = model.config.vocab_size - len(tokens)
1765
+ for i in range(extra_ids):
1766
+ tokens.append("<extra_id_%d>" % i)
1767
+ if model.config.vocab_size < len(tokens):
1768
+ tokens = tokens[: model.config.vocab_size]
1769
+
1770
+ return tokens
1771
+
1772
+ def set_vocabulary(self, spec, tokens):
1773
+ spec.register_vocabulary(tokens)
1774
+
1775
+ def set_config(self, config, model, tokenizer):
1776
+ config.bos_token = tokenizer.bos_token
1777
+ config.eos_token = tokenizer.eos_token
1778
+ config.unk_token = (
1779
+ tokenizer.unk_token if tokenizer.unk_token is not None else ""
1780
+ )
1781
+ config.layer_norm_epsilon = model.config.rms_norm_eps
1782
+
1783
+ def set_layer_norm(self, spec, layer_norm):
1784
+ spec.gamma = layer_norm.weight
1785
+
1786
+ def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
1787
+ spec.scale_embeddings = False
1788
+ self.set_embeddings(spec.embeddings, module.embed_tokens)
1789
+ self.set_layer_norm(spec.layer_norm, module.norm)
1790
+
1791
+ for layer_spec, layer in zip(spec.layer, module.layers):
1792
+ self.set_layer_norm(
1793
+ layer_spec.self_attention.layer_norm, layer.input_layernorm
1794
+ )
1795
+ self.set_layer_norm(
1796
+ layer_spec.ffn.layer_norm, layer.post_attention_layernorm
1797
+ )
1798
+
1799
+ split_layers = [common_spec.LinearSpec() for _ in range(3)]
1800
+ self.set_linear(
1801
+ split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
1802
+ )
1803
+ self.set_linear(
1804
+ split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
1805
+ )
1806
+ self.set_linear(
1807
+ split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
1808
+ )
1809
+
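+ # Fuse Q/K/V into a single projection: plain weights go through fuse_linear, AWQ-packed weights through fuse_linear_prequant along the packed dimension.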
1810
+ if quant_type == common_spec.Quantization.CT2:
1811
+ utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
1812
+ else:
1813
+ cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
1814
+ utils.fuse_linear_prequant(
1815
+ layer_spec.self_attention.linear[0], split_layers, cc_dim
1816
+ )
1817
+ self.set_linear(
1818
+ layer_spec.self_attention.linear[1],
1819
+ layer.self_attn.o_proj,
1820
+ quant_type=quant_type,
1821
+ )
1822
+
1823
+ self.set_linear(
1824
+ layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
1825
+ )
1826
+ self.set_linear(
1827
+ layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
1828
+ )
1829
+ self.set_linear(
1830
+ layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
1831
+ )
1832
+
1833
+ delattr(layer, "self_attn")
1834
+ delattr(layer, "mlp")
1835
+ gc.collect()
1836
+
1837
+
1838
+ @register_loader("Gemma3TextConfig")
1839
+ @register_loader("Gemma3Config")
1840
+ class Gemma3Loader(ModelLoader):
1841
+ @property
1842
+ def architecture_name(self):
1843
+ return "Gemma3ForCausalLM"
1844
+
1845
+ def get_model_spec(self, model):
1846
+ num_layers = model.config.num_hidden_layers
1847
+ num_heads = model.config.num_attention_heads
1848
+ num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
1849
+ if num_heads_kv == num_heads:
1850
+ num_heads_kv = None
1851
+
1852
+ head_dim = model.config.head_dim
1853
+
1854
+ activation_config = getattr(
1855
+ model.config, "hidden_activation", "gelu_pytorch_tanh"
1856
+ )
1857
+
1858
+ # Get RoPE parameters
1859
+ rope_theta = getattr(model.config, "rope_theta", 1_000_000) # Global: 1M
1860
+ rope_local_base_freq = getattr(
1861
+ model.config, "rope_local_base_freq", 10_000
1862
+ ) # Local: 10k
1863
+
1864
+ # Get sliding window configuration
1865
+ sliding_window = getattr(model.config, "sliding_window", 1024)
1866
+ layer_types = getattr(model.config, "layer_types", None)
1867
+
1868
+ quantization_config = getattr(model.config, "quantization_config", None)
1869
+ if quantization_config:
+ quant_type = None
1870
+ if quantization_config.quant_method == "awq":
1871
+ quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
1872
+ if quant_type is None:
1873
+ raise NotImplementedError(
1874
+ "Quantization type '%s' is not yet implemented."
1875
+ % quantization_config.quant_method
1876
+ )
1877
+ quant_group_size = quantization_config.group_size
1878
+ quant_bits = quantization_config.bits
1879
+ else:
1880
+ quant_type = common_spec.Quantization.CT2
1881
+ quant_group_size = None
1882
+ quant_bits = None
1883
+
1884
+ # Create base spec using from_config
1885
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
1886
+ num_layers,
1887
+ num_heads,
1888
+ activation=(
1889
+ common_spec.Activation.GELU
1890
+ if activation_config == "gelu"
1891
+ else common_spec.Activation.GELUTanh
1892
+ ),
1893
+ pre_norm=True,
1894
+ ffn_glu=True,
1895
+ rms_norm=True,
1896
+ rotary_dim=head_dim,
1897
+ rotary_interleave=False,
1898
+ rotary_base=rope_local_base_freq, # Default to local base freq
1899
+ num_heads_kv=num_heads_kv,
1900
+ head_dim=head_dim,
1901
+ sliding_window=sliding_window, # Default to local sliding window
1902
+ pre_post_layer_norm=True,
1903
+ quant_type=quant_type,
1904
+ quant_group_size=quant_group_size,
1905
+ quant_bits=quant_bits,
1906
+ qk_norm=True,
1907
+ )
1908
+
1909
+ # Keep the per-layer attention types on the loader instance
1910
+ self._layer_types = layer_types
1911
+
1912
+ # Override per-layer settings for global vs local attention
1913
+ for i, layer_type in enumerate(layer_types or []):
1914
+ layer = spec.decoder.layer[i]
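+ # Full-attention layers use the global RoPE base with no sliding window; sliding-attention layers use the local base and window.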
1915
+ if layer_type == "full_attention":
1916
+ layer.self_attention.rotary_base = np.dtype("float32").type(rope_theta)
1917
+ layer.self_attention.sliding_window = np.dtype("int32").type(0)
1918
+ elif layer_type == "sliding_attention":
1919
+ layer.self_attention.rotary_base = np.dtype("float32").type(
1920
+ rope_local_base_freq
1921
+ )
1922
+ layer.self_attention.sliding_window = np.dtype("int32").type(
1923
+ sliding_window
1924
+ )
1925
+
1926
+ self.set_decoder(spec.decoder, model.model, quant_type)
1927
+ self.set_linear(spec.decoder.projection, model.lm_head)
1928
+ return spec
1929
+
1930
+ def get_vocabulary(self, model, tokenizer):
1931
+ tokens = super().get_vocabulary(model, tokenizer)
1932
+
1933
+ extra_ids = model.config.vocab_size - len(tokens)
1934
+ for i in range(extra_ids):
1935
+ tokens.append("<extra_id_%d>" % i)
1936
+ if model.config.vocab_size < len(tokens):
1937
+ tokens = tokens[: model.config.vocab_size]
1938
+
1939
+ return tokens
1940
+
1941
+ def set_vocabulary(self, spec, tokens):
1942
+ spec.register_vocabulary(tokens)
1943
+
1944
+ def set_config(self, config, model, tokenizer):
1945
+ config.bos_token = tokenizer.bos_token
1946
+ config.unk_token = tokenizer.unk_token
1947
+
1948
+ if (
1949
+ hasattr(tokenizer, "chat_template")
1950
+ and isinstance(tokenizer.chat_template, str)
1951
+ and tokenizer.chat_template.strip()
1952
+ ):
1953
+ config.eos_token = "<end_of_turn>"
1954
+ else:
1955
+ config.eos_token = tokenizer.eos_token
1956
+
1957
+ def set_layer_norm(self, spec, layer_norm):
1958
+ spec.gamma = layer_norm.weight
1959
+ spec.layer_norm_use_residual = True
1960
+
1961
+ def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
1962
+ spec.scale_embeddings = True
1963
+ spec.start_from_zero_embedding = False
1964
+ self.set_embeddings(spec.embeddings, module.embed_tokens) # Input
1965
+ self.set_layer_norm(spec.layer_norm, module.norm) # Output
1966
+
1967
+ for layer_spec, layer in zip(spec.layer, module.layers):
1968
+ self.set_layer_norm(layer_spec.input_layer_norm, layer.input_layernorm)
1969
+
1970
+ self.set_layer_norm(
1971
+ layer_spec.post_attention_layer_norm, layer.post_attention_layernorm
1972
+ )
1973
+
1974
+ self.set_layer_norm(
1975
+ layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
1976
+ )
1977
+
1978
+ self.set_layer_norm(
1979
+ layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
1980
+ )
1981
+
1982
+ # Set QK-norm weights (Gemma 3 uses this instead of soft-capping)
1983
+ self.set_layer_norm(
1984
+ layer_spec.self_attention.q_norm, layer.self_attn.q_norm
1985
+ )
1986
+ self.set_layer_norm(
1987
+ layer_spec.self_attention.k_norm, layer.self_attn.k_norm
1988
+ )
1989
+
1990
+ # Set attention projections
1991
+ split_layers = [common_spec.LinearSpec() for _ in range(3)]
1992
+ self.set_linear(
1993
+ split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
1994
+ )
1995
+ self.set_linear(
1996
+ split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
1997
+ )
1998
+ self.set_linear(
1999
+ split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
2000
+ )
2001
+
2002
+ if quant_type == common_spec.Quantization.CT2:
2003
+ utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
2004
+ else:
2005
+ cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
2006
+ utils.fuse_linear_prequant(
2007
+ layer_spec.self_attention.linear[0], split_layers, cc_dim
2008
+ )
2009
+
2010
+ self.set_linear(
2011
+ layer_spec.self_attention.linear[1],
2012
+ layer.self_attn.o_proj,
2013
+ quant_type=quant_type,
2014
+ )
2015
+
2016
+ # Set FFN weights
2017
+ self.set_linear(
2018
+ layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
2019
+ )
2020
+ self.set_linear(
2021
+ layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
2022
+ )
2023
+ self.set_linear(
2024
+ layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
2025
+ )
2026
+
2027
+ delattr(layer, "self_attn")
2028
+ delattr(layer, "mlp")
2029
+ gc.collect()
2030
+
2031
+
2032
+ @register_loader("MistralConfig")
2033
+ class MistralLoader(ModelLoader):
2034
+ @property
2035
+ def architecture_name(self):
2036
+ return "MistralForCausalLM"
2037
+
2038
+ def get_model_spec(self, model):
2039
+ num_layers = model.config.num_hidden_layers
2040
+
2041
+ num_heads = model.config.num_attention_heads
2042
+ num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
2043
+ if num_heads_kv == num_heads:
2044
+ num_heads_kv = None
2045
+
2046
+ sliding_window = getattr(model.config, "sliding_window", 0)
2047
+
2048
+ rotary_scaling_type, rotary_scaling_factor, rope_theta = self.get_rotary_params(
2049
+ model.config, 10_000
2050
+ )
2051
+
2052
+ quantization_config = getattr(model.config, "quantization_config", None)
2053
+ if quantization_config:
+ quant_type = None
2054
+ if quantization_config.quant_method == "awq":
2055
+ quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
2056
+ if quant_type is None:
2057
+ raise NotImplementedError(
2058
+ "Quantization type '%s' is not yet implemented. "
2059
+ "The following Quantization types are currently supported: %s"
2060
+ % (
2061
+ quantization_config.quant_method,
2062
+ ", ".join(_SUPPORTED_QUANTIZATION.keys()),
2063
+ )
2064
+ )
2065
+ quant_group_size = quantization_config.group_size
2066
+ quant_bits = quantization_config.bits
2067
+ else:
2068
+ quant_type = common_spec.Quantization.CT2
2069
+ quant_group_size = None
2070
+ quant_bits = None
2071
+
2072
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
2073
+ num_layers,
2074
+ num_heads,
2075
+ activation=common_spec.Activation.SWISH,
2076
+ pre_norm=True,
2077
+ ffn_glu=True,
2078
+ rms_norm=True,
2079
+ rotary_dim=0,
2080
+ rotary_interleave=False,
2081
+ rotary_scaling_type=rotary_scaling_type,
2082
+ rotary_scaling_factor=rotary_scaling_factor,
2083
+ rotary_base=rope_theta,
2084
+ num_heads_kv=num_heads_kv,
2085
+ sliding_window=sliding_window,
2086
+ quant_type=quant_type,
2087
+ quant_group_size=quant_group_size,
2088
+ quant_bits=quant_bits,
2089
+ head_dim=model.config.head_dim,
2090
+ )
2091
+
2092
+ self.set_decoder(spec.decoder, model.model, quant_type=quant_type)
2093
+ self.set_linear(spec.decoder.projection, model.lm_head)
2094
+ return spec
2095
+
2096
+ def get_vocabulary(self, model, tokenizer):
2097
+ tokens = super().get_vocabulary(model, tokenizer)
2098
+
2099
+ extra_ids = model.config.vocab_size - len(tokens)
2100
+ for i in range(extra_ids):
2101
+ tokens.append("<extra_id_%d>" % i)
2102
+
2103
+ return tokens
2104
+
2105
+ def set_vocabulary(self, spec, tokens):
2106
+ spec.register_vocabulary(tokens)
2107
+
2108
+ def set_config(self, config, model, tokenizer):
2109
+ config.bos_token = tokenizer.bos_token
2110
+ config.eos_token = tokenizer.eos_token
2111
+ config.unk_token = tokenizer.unk_token
2112
+ config.layer_norm_epsilon = model.config.rms_norm_eps
2113
+
2114
+ def set_layer_norm(self, spec, layer_norm):
2115
+ spec.gamma = layer_norm.weight
2116
+
2117
+ def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
2118
+ spec.scale_embeddings = False
2119
+ self.set_embeddings(spec.embeddings, module.embed_tokens)
2120
+ self.set_layer_norm(spec.layer_norm, module.norm)
2121
+
2122
+ for layer_spec, layer in zip(spec.layer, module.layers):
2123
+ self.set_layer_norm(
2124
+ layer_spec.self_attention.layer_norm, layer.input_layernorm
2125
+ )
2126
+ self.set_layer_norm(
2127
+ layer_spec.ffn.layer_norm, layer.post_attention_layernorm
2128
+ )
2129
+ split_layers = [common_spec.LinearSpec() for _ in range(3)]
2130
+ self.set_linear(
2131
+ split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
2132
+ )
2133
+ self.set_linear(
2134
+ split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
2135
+ )
2136
+ self.set_linear(
2137
+ split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
2138
+ )
2139
+
2140
+ if quant_type == common_spec.Quantization.CT2:
2141
+ utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
2142
+ else:
2143
+ cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
2144
+ utils.fuse_linear_prequant(
2145
+ layer_spec.self_attention.linear[0], split_layers, cc_dim
2146
+ )
2147
+ self.set_linear(
2148
+ layer_spec.self_attention.linear[1],
2149
+ layer.self_attn.o_proj,
2150
+ quant_type=quant_type,
2151
+ )
2152
+
2153
+ self.set_linear(
2154
+ layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
2155
+ )
2156
+ self.set_linear(
2157
+ layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
2158
+ )
2159
+ self.set_linear(
2160
+ layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
2161
+ )
2162
+
2163
+ delattr(layer, "self_attn")
2164
+ delattr(layer, "mlp")
2165
+ gc.collect()
2166
+
2167
+
2168
+ @register_loader("Qwen2Config")
2169
+ class Qwen2Loader(ModelLoader):
2170
+ @property
2171
+ def architecture_name(self):
2172
+ return "Qwen2ForCausalLM"
2173
+
2174
+ def get_model_spec(self, model):
2175
+ num_layers = model.config.num_hidden_layers
2176
+
2177
+ num_heads = model.config.num_attention_heads
2178
+ num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
2179
+ if num_heads_kv == num_heads:
2180
+ num_heads_kv = None
2181
+
2182
+ rotary_scaling_type, rotary_scaling_factor, rope_theta = self.get_rotary_params(
2183
+ model.config, 10_000
2184
+ )
2185
+
2186
+ # Check for AWQ quantization config
2187
+ quantization_config = getattr(model.config, "quantization_config", None)
2188
+ if quantization_config:
2189
+ quant_type = None
2190
+ if quantization_config.quant_method == "awq":
2191
+ quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
2192
+ if quant_type is None:
2193
+ raise NotImplementedError(
2194
+ "Quantization type '%s' is not yet implemented. "
2195
+ "The following Quantization types are currently supported: %s"
2196
+ % (
2197
+ quantization_config.quant_method,
2198
+ ", ".join(_SUPPORTED_QUANTIZATION.keys()),
2199
+ )
2200
+ )
2201
+ quant_group_size = quantization_config.group_size
2202
+ quant_bits = quantization_config.bits
2203
+ else:
2204
+ quant_type = common_spec.Quantization.CT2
2205
+ quant_group_size = None
2206
+ quant_bits = None
2207
+
2208
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
2209
+ num_layers,
2210
+ num_heads,
2211
+ activation=common_spec.Activation.SWISH,
2212
+ pre_norm=True,
2213
+ ffn_glu=True,
2214
+ rms_norm=True,
2215
+ rotary_dim=0,
2216
+ rotary_interleave=False,
2217
+ rotary_scaling_type=rotary_scaling_type,
2218
+ rotary_scaling_factor=rotary_scaling_factor,
2219
+ rotary_base=rope_theta,
2220
+ num_heads_kv=num_heads_kv,
2221
+ quant_type=quant_type,
2222
+ quant_group_size=quant_group_size,
2223
+ quant_bits=quant_bits,
2224
+ )
2225
+
2226
+ self.set_decoder(spec.decoder, model.model, quant_type)
2227
+ self.set_linear(spec.decoder.projection, model.lm_head)
2228
+ return spec
2229
+
2230
+ def get_vocabulary(self, model, tokenizer):
2231
+ tokens = super().get_vocabulary(model, tokenizer)
2232
+
2233
+ extra_ids = model.config.vocab_size - len(tokens)
2234
+ for i in range(extra_ids):
2235
+ tokens.append("<extra_id_%d>" % i)
2236
+ return tokens
2237
+
2238
+ def set_vocabulary(self, spec, tokens):
2239
+ spec.register_vocabulary(tokens)
2240
+
2241
+ def set_config(self, config, model, tokenizer):
2242
+ config.bos_token = (
2243
+ tokenizer.bos_token
2244
+ if tokenizer.bos_token is not None
2245
+ else tokenizer.pad_token
2246
+ )
2247
+ config.eos_token = tokenizer.eos_token
2248
+ config.unk_token = (
2249
+ tokenizer.unk_token if tokenizer.unk_token is not None else ""
2250
+ )
2251
+ config.layer_norm_epsilon = model.config.rms_norm_eps
2252
+
2253
+ def set_layer_norm(self, spec, layer_norm):
2254
+ spec.gamma = layer_norm.weight
2255
+
2256
+ def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
2257
+ spec.scale_embeddings = False
2258
+ self.set_embeddings(spec.embeddings, module.embed_tokens)
2259
+ self.set_layer_norm(spec.layer_norm, module.norm)
2260
+
2261
+ for layer_spec, layer in zip(spec.layer, module.layers):
2262
+ self.set_layer_norm(
2263
+ layer_spec.self_attention.layer_norm, layer.input_layernorm
2264
+ )
2265
+ self.set_layer_norm(
2266
+ layer_spec.ffn.layer_norm, layer.post_attention_layernorm
2267
+ )
2268
+
2269
+ split_layers = [common_spec.LinearSpec() for _ in range(3)]
2270
+ self.set_linear(
2271
+ split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
2272
+ )
2273
+ self.set_linear(
2274
+ split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
2275
+ )
2276
+ self.set_linear(
2277
+ split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
2278
+ )
2279
+
2280
+ if quant_type == common_spec.Quantization.CT2:
2281
+ utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
2282
+ else:
2283
+ cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
2284
+ utils.fuse_linear_prequant(
2285
+ layer_spec.self_attention.linear[0], split_layers, cc_dim
2286
+ )
2287
+
2288
+ self.set_linear(
2289
+ layer_spec.self_attention.linear[1],
2290
+ layer.self_attn.o_proj,
2291
+ quant_type=quant_type,
2292
+ )
2293
+
2294
+ self.set_linear(
2295
+ layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
2296
+ )
2297
+ self.set_linear(
2298
+ layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
2299
+ )
2300
+ self.set_linear(
2301
+ layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
2302
+ )
2303
+
2304
+ delattr(layer, "self_attn")
2305
+ delattr(layer, "mlp")
2306
+ gc.collect()
2307
+
2308
+
2309
+ @register_loader("Qwen3Config")
2310
+ class Qwen3Loader(ModelLoader):
2311
+ @property
2312
+ def architecture_name(self):
2313
+ return "Qwen3ForCausalLM"
2314
+
2315
+ def get_model_spec(self, model):
2316
+ num_layers = model.config.num_hidden_layers
2317
+ num_heads = model.config.num_attention_heads
2318
+ num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
2319
+ head_dim = getattr(
2320
+ model.config, "head_dim", model.config.hidden_size // num_heads
2321
+ )
2322
+
2323
+ if num_heads_kv == num_heads:
2324
+ num_heads_kv = None
2325
+
2326
+ rotary_scaling_type, rotary_scaling_factor, rope_theta = self.get_rotary_params(
2327
+ model.config, 1_000_000
2328
+ )
2329
+ # Check for AWQ quantization config
2330
+ quantization_config = getattr(model.config, "quantization_config", None)
2331
+ if quantization_config:
2332
+ quant_type = None
2333
+ if quantization_config.quant_method == "awq":
2334
+ quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
2335
+ if quant_type is None:
2336
+ raise NotImplementedError(
2337
+ "Quantization type '%s' is not yet implemented. "
2338
+ "The following Quantization types are currently supported: %s"
2339
+ % (
2340
+ quantization_config.quant_method,
2341
+ ", ".join(_SUPPORTED_QUANTIZATION.keys()),
2342
+ )
2343
+ )
2344
+ quant_group_size = quantization_config.group_size
2345
+ quant_bits = quantization_config.bits
2346
+ else:
2347
+ quant_type = common_spec.Quantization.CT2
2348
+ quant_group_size = None
2349
+ quant_bits = None
2350
+
2351
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
2352
+ num_layers,
2353
+ num_heads,
2354
+ activation=common_spec.Activation.SWISH,
2355
+ pre_norm=True,
2356
+ ffn_glu=True,
2357
+ rms_norm=True,
2358
+ rotary_dim=head_dim,
2359
+ rotary_interleave=False,
2360
+ rotary_scaling_type=rotary_scaling_type,
2361
+ rotary_scaling_factor=rotary_scaling_factor,
2362
+ rotary_base=rope_theta,
2363
+ num_heads_kv=num_heads_kv,
2364
+ head_dim=head_dim,
2365
+ qk_norm=True,
2366
+ quant_type=quant_type,
2367
+ quant_group_size=quant_group_size,
2368
+ quant_bits=quant_bits,
2369
+ )
2370
+
2371
+ self.set_decoder(spec.decoder, model.model, quant_type)
2372
+ self.set_linear(spec.decoder.projection, model.lm_head)
2373
+ return spec
2374
+
2375
+ def get_vocabulary(self, model, tokenizer):
2376
+ tokens = super().get_vocabulary(model, tokenizer)
2377
+ extra_ids = model.config.vocab_size - len(tokens)
2378
+ for i in range(extra_ids):
2379
+ tokens.append("<extra_id_%d>" % i)
2380
+ return tokens
2381
+
2382
+ def set_vocabulary(self, spec, tokens):
2383
+ spec.register_vocabulary(tokens)
2384
+
2385
+ def set_config(self, config, model, tokenizer):
2386
+ config.bos_token = (
2387
+ tokenizer.bos_token
2388
+ if tokenizer.bos_token is not None
2389
+ else tokenizer.pad_token
2390
+ )
2391
+ config.eos_token = tokenizer.eos_token
2392
+ config.unk_token = (
2393
+ tokenizer.unk_token if tokenizer.unk_token is not None else ""
2394
+ )
2395
+ config.layer_norm_epsilon = model.config.rms_norm_eps
2396
+
2397
+ def set_layer_norm(self, spec, layer_norm):
2398
+ spec.gamma = layer_norm.weight
2399
+
2400
+ def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
2401
+ spec.scale_embeddings = False
2402
+ self.set_embeddings(spec.embeddings, module.embed_tokens)
2403
+ self.set_layer_norm(spec.layer_norm, module.norm)
2404
+
2405
+ for layer_idx, (layer_spec, layer) in enumerate(zip(spec.layer, module.layers)):
2406
+ self.set_layer_norm(
2407
+ layer_spec.self_attention.layer_norm, layer.input_layernorm
2408
+ )
2409
+ self.set_layer_norm(
2410
+ layer_spec.ffn.layer_norm, layer.post_attention_layernorm
2411
+ )
2412
+
2413
+ self.set_layer_norm(
2414
+ layer_spec.self_attention.q_norm, layer.self_attn.q_norm
2415
+ )
2416
+ self.set_layer_norm(
2417
+ layer_spec.self_attention.k_norm, layer.self_attn.k_norm
2418
+ )
2419
+
2420
+ split_layers = [common_spec.LinearSpec() for _ in range(3)]
2421
+ self.set_linear(
2422
+ split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
2423
+ )
2424
+ self.set_linear(
2425
+ split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
2426
+ )
2427
+ self.set_linear(
2428
+ split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
2429
+ )
2430
+
2431
+ if quant_type == common_spec.Quantization.CT2:
2432
+ utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
2433
+ else:
2434
+ cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
2435
+ utils.fuse_linear_prequant(
2436
+ layer_spec.self_attention.linear[0], split_layers, cc_dim
2437
+ )
2438
+
2439
+ self.set_linear(
2440
+ layer_spec.self_attention.linear[1],
2441
+ layer.self_attn.o_proj,
2442
+ quant_type=quant_type,
2443
+ )
2444
+
2445
+ self.set_linear(
2446
+ layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
2447
+ )
2448
+ self.set_linear(
2449
+ layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
2450
+ )
2451
+ self.set_linear(
2452
+ layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
2453
+ )
2454
+
2455
+ delattr(layer, "self_attn")
2456
+ delattr(layer, "mlp")
2457
+ gc.collect()
2458
+
2459
+
2460
+ @register_loader("MixFormerSequentialConfig")
2461
+ class MixFormerSequentialLoader(ModelLoader):
2462
+ @property
2463
+ def architecture_name(self):
2464
+ return "AutoModelForCausalLM"
2465
+
2466
+ def get_model_spec(self, model):
2467
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
2468
+ num_layers=model.config.n_layer,
2469
+ num_heads=model.config.n_head,
2470
+ pre_norm=True,
2471
+ activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function],
2472
+ rotary_dim=model.config.rotary_dim,
2473
+ rotary_interleave=False,
2474
+ parallel_residual=True,
2475
+ shared_layer_norm=True,
2476
+ )
2477
+
2478
+ self.set_decoder(spec.decoder, model.layers)
2479
+ self.set_linear(spec.decoder.projection, model.layers[-1].linear)
2480
+ return spec
2481
+
2482
+ def get_vocabulary(self, model, tokenizer):
2483
+ tokens = super().get_vocabulary(model, tokenizer)
2484
+
2485
+ extra_ids = model.config.vocab_size - len(tokens)
2486
+ for i in range(extra_ids):
2487
+ tokens.append("<extra_id_%d>" % i)
2488
+
2489
+ return tokens
2490
+
2491
+ def set_vocabulary(self, spec, tokens):
2492
+ spec.register_vocabulary(tokens)
2493
+
2494
+ def set_config(self, config, model, tokenizer):
2495
+ config.bos_token = tokenizer.bos_token
2496
+ config.eos_token = tokenizer.eos_token
2497
+ config.unk_token = tokenizer.unk_token
2498
+
2499
+ def set_decoder(self, spec, module):
2500
+ spec.scale_embeddings = False
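+ # The model is a flat sequential stack: block 0 holds the embeddings, the last block the output head, and the blocks in between are decoder layers.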
2501
+ self.set_embeddings(spec.embeddings, module[0].wte)
2502
+ self.set_layer_norm(spec.layer_norm, module[-1].ln)
2503
+
2504
+ for layer_spec, layer in zip(spec.layer, module[1:-1]):
2505
+ self.set_layer_norm(layer_spec.shared_layer_norm, layer.ln)
2506
+ self.set_linear(layer_spec.self_attention.linear[0], layer.mixer.Wqkv)
2507
+ self.set_linear(layer_spec.self_attention.linear[1], layer.mixer.out_proj)
2508
+ self.set_linear(layer_spec.ffn.linear_0, layer.mlp.fc1)
2509
+ self.set_linear(layer_spec.ffn.linear_1, layer.mlp.fc2)
2510
+
2511
+
2512
+ @register_loader("PhiConfig")
2513
+ class PhiLoader(ModelLoader):
2514
+ @property
2515
+ def architecture_name(self):
2516
+ return "AutoModelForCausalLM"
2517
+
2518
+ def get_model_spec(self, model):
2519
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
2520
+ num_layers=model.config.n_layer,
2521
+ num_heads=model.config.n_head,
2522
+ pre_norm=True,
2523
+ activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function],
2524
+ rotary_dim=model.config.rotary_dim,
2525
+ rotary_interleave=False,
2526
+ parallel_residual=True,
2527
+ shared_layer_norm=True,
2528
+ )
2529
+
2530
+ self.set_decoder(spec.decoder, model.transformer)
2531
+ self.set_linear(spec.decoder.projection, model.lm_head.linear)
2532
+ self.set_layer_norm(spec.decoder.layer_norm, model.lm_head.ln)
2533
+ return spec
2534
+
2535
+ def get_vocabulary(self, model, tokenizer):
2536
+ tokens = super().get_vocabulary(model, tokenizer)
2537
+
2538
+ extra_ids = model.config.vocab_size - len(tokens)
2539
+ for i in range(extra_ids):
2540
+ tokens.append("<extra_id_%d>" % i)
2541
+
2542
+ return tokens
2543
+
2544
+ def set_vocabulary(self, spec, tokens):
2545
+ spec.register_vocabulary(tokens)
2546
+
2547
+ def set_config(self, config, model, tokenizer):
2548
+ config.bos_token = tokenizer.bos_token
2549
+ config.eos_token = tokenizer.eos_token
2550
+ config.unk_token = tokenizer.unk_token
2551
+
2552
+ def set_decoder(self, spec, module):
2553
+ spec.scale_embeddings = False
2554
+ self.set_embeddings(spec.embeddings, module.embd.wte)
2555
+
2556
+ for layer_spec, layer in zip(spec.layer, module.h):
2557
+ self.set_layer_norm(layer_spec.shared_layer_norm, layer.ln)
2558
+ self.set_linear(layer_spec.self_attention.linear[0], layer.mixer.Wqkv)
2559
+ self.set_linear(layer_spec.self_attention.linear[1], layer.mixer.out_proj)
2560
+ self.set_linear(layer_spec.ffn.linear_0, layer.mlp.fc1)
2561
+ self.set_linear(layer_spec.ffn.linear_1, layer.mlp.fc2)
2562
+
2563
+
2564
+ @register_loader("Phi3Config")
2565
+ class Phi3Loader(ModelLoader):
2566
+ @property
2567
+ def architecture_name(self):
2568
+ return "AutoModelForCausalLM"
2569
+
2570
+ def get_model_spec(self, model):
2571
+ num_layers = model.config.num_hidden_layers
2572
+
2573
+ num_heads = model.config.num_attention_heads
2574
+ num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
2575
+ if num_heads_kv == num_heads:
2576
+ num_heads_kv = None
2577
+
2578
+ original_max_position_embeddings = getattr(
2579
+ model.config, "original_max_position_embeddings", 0
2580
+ )
2581
+ max_position_embeddings = getattr(model.config, "max_position_embeddings", 0)
2582
+ rope_scaling = getattr(model.config, "rope_scaling", None)
2583
+ if rope_scaling:
2584
+ rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
2585
+ rotary_scaling_factor = rope_scaling.get("factor", 1)
2586
+
2587
+ if rotary_scaling_type is None:
2588
+ raise NotImplementedError(
2589
+ "RoPE scaling type '%s' is not yet implemented. "
2590
+ "The following RoPE scaling types are currently supported: %s"
2591
+ % (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
2592
+ )
2593
+ else:
2594
+ rotary_scaling_type = None
2595
+ rotary_scaling_factor = 1
2596
+
2597
+ # Check for AWQ quantization config
2598
+ quantization_config = getattr(model.config, "quantization_config", None)
2599
+ if quantization_config:
2600
+ quant_type = None
2601
+ if quantization_config.quant_method == "awq":
2602
+ quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
2603
+ if quant_type is None:
2604
+ raise NotImplementedError(
2605
+ "Quantization type '%s' is not yet implemented. "
2606
+ "The following Quantization types are currently supported: %s"
2607
+ % (
2608
+ quantization_config.quant_method,
2609
+ ", ".join(_SUPPORTED_QUANTIZATION.keys()),
2610
+ )
2611
+ )
2612
+ quant_group_size = quantization_config.group_size
2613
+ quant_bits = quantization_config.bits
2614
+ else:
2615
+ quant_type = common_spec.Quantization.CT2
2616
+ quant_group_size = None
2617
+ quant_bits = None
2618
+
2619
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
2620
+ num_layers,
2621
+ num_heads,
2622
+ activation=common_spec.Activation.SWISH,
2623
+ pre_norm=True,
2624
+ ffn_glu=True,
2625
+ rms_norm=True,
2626
+ rotary_dim=0,
2627
+ rotary_interleave=False,
2628
+ rotary_scaling_type=rotary_scaling_type,
2629
+ rotary_scaling_factor=rotary_scaling_factor,
2630
+ rotary_base=getattr(model.config, "rope_theta", 10000),
2631
+ original_max_position_embeddings=original_max_position_embeddings,
2632
+ max_position_embeddings=max_position_embeddings,
2633
+ num_heads_kv=num_heads_kv,
2634
+ quant_type=quant_type,
2635
+ quant_group_size=quant_group_size,
2636
+ quant_bits=quant_bits,
2637
+ )
2638
+
2639
+ self.set_decoder(spec.decoder, model.model, quant_type)
2640
+ self.set_linear(spec.decoder.projection, model.lm_head)
2641
+ return spec
2642
+
2643
+ def get_vocabulary(self, model, tokenizer):
2644
+ tokens = super().get_vocabulary(model, tokenizer)
2645
+
2646
+ extra_ids = model.config.vocab_size - len(tokens)
2647
+ for i in range(extra_ids):
2648
+ tokens.append("<extra_id_%d>" % i)
2649
+
2650
+ return tokens
2651
+
2652
+ def set_vocabulary(self, spec, tokens):
2653
+ spec.register_vocabulary(tokens)
2654
+
2655
+ def set_config(self, config, model, tokenizer):
2656
+ config.bos_token = tokenizer.bos_token
2657
+ config.eos_token = tokenizer.eos_token
2658
+ config.unk_token = tokenizer.unk_token
2659
+
2660
+ def set_layer_norm(self, spec, layer_norm):
2661
+ spec.gamma = layer_norm.weight
2662
+
2663
+ def set_rotary_embeddings(
2664
+ self, spec, rotary_scaling_long_factor, rotary_scaling_short_factor
2665
+ ):
2666
+ spec.rotary_scaling_long_factor = torch.tensor(
2667
+ rotary_scaling_long_factor, dtype=torch.float32
2668
+ )
2669
+ spec.rotary_scaling_short_factor = torch.tensor(
2670
+ rotary_scaling_short_factor, dtype=torch.float32
2671
+ )
2672
+
2673
+ def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
2674
+ spec.scale_embeddings = False
2675
+ self.set_embeddings(spec.embeddings, module.embed_tokens)
2676
+ self.set_layer_norm(spec.layer_norm, module.norm)
2677
+
2678
+ for layer_spec, layer in zip(spec.layer, module.layers):
2679
+ self.set_layer_norm(
2680
+ layer_spec.self_attention.layer_norm, layer.input_layernorm
2681
+ )
2682
+ self.set_layer_norm(
2683
+ layer_spec.ffn.layer_norm, layer.post_attention_layernorm
2684
+ )
2685
+
2686
+ self.set_linear(
2687
+ layer_spec.self_attention.linear[0],
2688
+ layer.self_attn.qkv_proj,
2689
+ quant_type=quant_type,
2690
+ )
2691
+ self.set_linear(
2692
+ layer_spec.self_attention.linear[1],
2693
+ layer.self_attn.o_proj,
2694
+ quant_type=quant_type,
2695
+ )
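+ # Phi-3 LongRoPE keeps long/short scaling factors on the rotary embedding module; copy them into the spec when present.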
2696
+ if (
2697
+ layer.self_attn.rotary_emb.long_factor is not None
2698
+ and layer.self_attn.rotary_emb.short_factor is not None
2699
+ ):
2700
+ self.set_rotary_embeddings(
2701
+ layer_spec.self_attention,
2702
+ layer.self_attn.rotary_emb.long_factor,
2703
+ layer.self_attn.rotary_emb.short_factor,
2704
+ )
2705
+
2706
+ # Handle gate_up_proj differently for AWQ vs regular models
2707
+ if quant_type == common_spec.Quantization.CT2:
2708
+ gate_proj, up_proj = layer.mlp.gate_up_proj.weight.chunk(2, dim=0)
2709
+ layer_spec.ffn.linear_0.weight = gate_proj
2710
+ layer_spec.ffn.linear_0_noact.weight = up_proj
2711
+ else:
2712
+ # AWQ: chunk qweight, scales, and qzeros
2713
+ gate_qweight, up_qweight = layer.mlp.gate_up_proj.qweight.chunk(
2714
+ 2, dim=1
2715
+ )
2716
+ gate_scales, up_scales = layer.mlp.gate_up_proj.scales.chunk(2, dim=1)
2717
+ gate_qzeros, up_qzeros = layer.mlp.gate_up_proj.qzeros.chunk(2, dim=1)
2718
+
2719
+ layer_spec.ffn.linear_0.weight = gate_qweight
2720
+ layer_spec.ffn.linear_0.weight_scale = gate_scales
2721
+ layer_spec.ffn.linear_0.weight_zero = gate_qzeros
2722
+
2723
+ layer_spec.ffn.linear_0_noact.weight = up_qweight
2724
+ layer_spec.ffn.linear_0_noact.weight_scale = up_scales
2725
+ layer_spec.ffn.linear_0_noact.weight_zero = up_qzeros
2726
+
2727
+ self.set_linear(
2728
+ layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
2729
+ )
2730
+
2731
+ delattr(layer, "self_attn")
2732
+ delattr(layer, "mlp")
2733
+ gc.collect()
2734
+
2735
+
2736
+ @register_loader("RWConfig")
2737
+ class RWLoader(ModelLoader):
2738
+ @property
2739
+ def architecture_name(self):
2740
+ return "AutoModelForCausalLM"
2741
+
2742
+ def get_falcon_spec(self, model):
2743
+ self._num_layers = model.config.n_layer
2744
+ self._num_heads = model.config.n_head
2745
+ self._num_heads_kv = getattr(model.config, "n_head_kv", None)
2746
+ self._num_kv_attr = "num_kv"
2747
+
2748
+ def get_model_spec(self, model):
2749
+ self.get_falcon_spec(model)
2750
+
2751
+ if getattr(model.config, "multi_query", False):
2752
+ num_heads_kv = 1
2753
+ else:
2754
+ num_heads_kv = self._num_heads_kv
2755
+
2756
+ spec = transformer_spec.TransformerDecoderModelSpec.from_config(
2757
+ self._num_layers,
2758
+ self._num_heads,
2759
+ pre_norm=True,
2760
+ activation=common_spec.Activation.GELU,
2761
+ alibi=model.config.alibi,
2762
+ alibi_use_positive_positions=True,
2763
+ scale_alibi=True,
2764
+ rotary_dim=0 if model.config.rotary else None,
2765
+ rotary_interleave=False,
2766
+ parallel_residual=model.config.parallel_attn,
2767
+ shared_layer_norm=num_heads_kv == 1,
2768
+ num_heads_kv=num_heads_kv,
2769
+ )
2770
+
2771
+ self.set_decoder(spec.decoder, model.transformer)
2772
+ self.set_linear(spec.decoder.projection, model.lm_head)
2773
+ return spec
2774
+
2775
+ def get_vocabulary(self, model, tokenizer):
2776
+ tokens = super().get_vocabulary(model, tokenizer)
2777
+
2778
+ extra_ids = model.config.vocab_size - len(tokens)
2779
+ for i in range(extra_ids):
2780
+ tokens.append("<extra_id_%d>" % i)
2781
+
2782
+ return tokens
2783
+
2784
+ def set_vocabulary(self, spec, tokens):
2785
+ spec.register_vocabulary(tokens)
2786
+
2787
+ def set_config(self, config, model, tokenizer):
2788
+ config.bos_token = tokenizer.eos_token
2789
+ config.eos_token = tokenizer.eos_token
2790
+ config.unk_token = tokenizer.eos_token
2791
+
2792
+ def set_decoder(self, spec, module):
2793
+ spec.scale_embeddings = False
2794
+ self.set_embeddings(spec.embeddings, module.word_embeddings)
2795
+ self.set_layer_norm(spec.layer_norm, module.ln_f)
2796
+
2797
+ for layer_spec, layer in zip(spec.layer, module.h):
2798
+ if hasattr(layer, "ln_attn"):
2799
+ self.set_layer_norm(layer_spec.input_layer_norm, layer.ln_attn)
2800
+ self.set_layer_norm(layer_spec.post_attention_layer_norm, layer.ln_mlp)
2801
+ elif hasattr(layer_spec, "shared_layer_norm"):
2802
+ self.set_layer_norm(layer_spec.shared_layer_norm, layer.input_layernorm)
2803
+ else:
2804
+ self.set_layer_norm(
2805
+ layer_spec.self_attention.layer_norm, layer.input_layernorm
2806
+ )
2807
+ self.set_layer_norm(
2808
+ layer_spec.ffn.layer_norm, layer.post_attention_layernorm
2809
+ )
2810
+
2811
+ num_kv = getattr(layer.self_attention, self._num_kv_attr)
2812
+ if num_kv == 1:
2813
+ self.set_linear(
2814
+ layer_spec.self_attention.linear[0],
2815
+ layer.self_attention.query_key_value,
2816
+ )
2817
+ else:
2818
+ self.set_qkv_linear(
2819
+ layer_spec.self_attention.linear[0],
2820
+ layer.self_attention.query_key_value,
2821
+ layer.self_attention.num_heads,
2822
+ num_kv if num_kv < layer.self_attention.num_heads else None,
2823
+ )
2824
+
2825
+ self.set_linear(
2826
+ layer_spec.self_attention.linear[1], layer.self_attention.dense
2827
+ )
2828
+
2829
+ self.set_linear(layer_spec.ffn.linear_0, layer.mlp.dense_h_to_4h)
2830
+ self.set_linear(layer_spec.ffn.linear_1, layer.mlp.dense_4h_to_h)
2831
+
2832
+ def set_qkv_linear(self, spec, module, num_heads, num_kv=None):
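+ # The fused query_key_value weight stores heads interleaved; reorder it into contiguous Q, K and V blocks.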
2833
+ weight = module.weight
2834
+
2835
+ if num_kv is None:
2836
+ weight = weight.reshape(num_heads, 3, -1, weight.shape[-1])
2837
+ weight = weight.transpose(0, 1)
2838
+ weight = weight.reshape(-1, weight.shape[-1])
2839
+ else:
2840
+ head_dim = weight.shape[0] // (num_heads + num_kv * 2)
2841
+ weight = weight.reshape(
2842
+ -1, num_heads // num_kv + 2, head_dim, weight.shape[-1]
2843
+ )
2844
+ q, k, v = weight.split([num_heads // num_kv, 1, 1], dim=1)
2845
+ weight = torch.cat(
2846
+ [
2847
+ q.reshape(num_heads * head_dim, -1),
2848
+ k.reshape(num_kv * head_dim, -1),
2849
+ v.reshape(num_kv * head_dim, -1),
2850
+ ]
2851
+ )
2852
+
2853
+ spec.weight = weight
2854
+
2855
+ if module.bias is not None:
2856
+ bias = module.bias
2857
+
2858
+ if num_kv is None:
2859
+ bias = bias.reshape(num_heads, 3, -1)
2860
+ bias = bias.transpose(0, 1)
2861
+ bias = bias.reshape(-1)
2862
+ else:
2863
+ bias = bias.reshape(-1, num_heads // num_kv + 2, head_dim)
2864
+ q, k, v = bias.split([num_heads // num_kv, 1, 1], dim=1)
2865
+ bias = torch.cat(
2866
+ [
2867
+ q.reshape(num_heads * head_dim),
2868
+ k.reshape(num_kv * head_dim),
2869
+ v.reshape(num_kv * head_dim),
2870
+ ]
2871
+ )
2872
+
2873
+ spec.bias = bias
2874
+
2875
+
2876
+ @register_loader("FalconConfig")
2877
+ class FalconLoader(RWLoader):
2878
+ def get_falcon_spec(self, model):
2879
+ self._num_layers = model.config.num_hidden_layers
2880
+ self._num_heads = model.config.num_attention_heads
2881
+ self._num_heads_kv = getattr(model.config, "num_kv_heads", None)
2882
+ self._num_kv_attr = "num_kv_heads"
2883
+
2884
+
2885
+ @register_loader("DistilBertConfig")
2886
+ class DistilBertLoader(ModelLoader):
2887
+ @property
2888
+ def architecture_name(self):
2889
+ return "DistilBertModel"
2890
+
2891
+ def get_model_spec(self, model):
2892
+ encoder_spec = transformer_spec.TransformerEncoderSpec(
2893
+ model.config.n_layers,
2894
+ model.config.n_heads,
2895
+ pre_norm=False,
2896
+ activation=_SUPPORTED_ACTIVATIONS[model.config.activation],
2897
+ layernorm_embedding=True,
2898
+ )
2899
+ spec = transformer_spec.TransformerEncoderModelSpec(
2900
+ encoder_spec,
2901
+ )
2902
+
2903
+ spec.encoder.scale_embeddings = False
2904
+
2905
+ self.set_embeddings(
2906
+ spec.encoder.embeddings[0], model.embeddings.word_embeddings
2907
+ )
2908
+ self.set_position_encodings(
2909
+ spec.encoder.position_encodings, model.embeddings.position_embeddings
2910
+ )
2911
+ self.set_layer_norm(
2912
+ spec.encoder.layernorm_embedding, model.embeddings.LayerNorm
2913
+ )
2914
+
2915
+ for layer_spec, layer in zip(spec.encoder.layer, model.transformer.layer):
2916
+ split_layers = [common_spec.LinearSpec() for _ in range(3)]
2917
+ self.set_linear(split_layers[0], layer.attention.q_lin)
2918
+ self.set_linear(split_layers[1], layer.attention.k_lin)
2919
+ self.set_linear(split_layers[2], layer.attention.v_lin)
2920
+ utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
2921
+
2922
+ self.set_linear(
2923
+ layer_spec.self_attention.linear[1], layer.attention.out_lin
2924
+ )
2925
+ self.set_layer_norm(
2926
+ layer_spec.self_attention.layer_norm, layer.sa_layer_norm
2927
+ )
2928
+
2929
+ self.set_linear(layer_spec.ffn.linear_0, layer.ffn.lin1)
2930
+ self.set_linear(layer_spec.ffn.linear_1, layer.ffn.lin2)
2931
+ self.set_layer_norm(layer_spec.ffn.layer_norm, layer.output_layer_norm)
2932
+
2933
+ return spec
2934
+
2935
+ def set_vocabulary(self, spec, tokens):
2936
+ spec.register_vocabulary(tokens)
2937
+
2938
+ def set_config(self, config, model, tokenizer):
2939
+ config.unk_token = tokenizer.unk_token
2940
+ config.layer_norm_epsilon = 1e-12
2941
+
2942
+
2943
+ @register_loader("BertConfig")
2944
+ class BertLoader(ModelLoader):
2945
+ @property
2946
+ def architecture_name(self):
2947
+ return "BertModel"
2948
+
2949
+ def get_model_spec(self, model):
2950
+ assert model.config.position_embedding_type == "absolute"
2951
+
2952
+ encoder_spec = transformer_spec.TransformerEncoderSpec(
2953
+ model.config.num_hidden_layers,
2954
+ model.config.num_attention_heads,
2955
+ pre_norm=False,
2956
+ activation=_SUPPORTED_ACTIVATIONS[model.config.hidden_act],
2957
+ layernorm_embedding=True,
2958
+ num_source_embeddings=2,
2959
+ embeddings_merge=common_spec.EmbeddingsMerge.ADD,
2960
+ )
2961
+
2962
+ spec = transformer_spec.TransformerEncoderModelSpec(
2963
+ encoder_spec,
2964
+ pooling_layer=True,
2965
+ pooling_activation=common_spec.Activation.Tanh,
2966
+ )
2967
+
2968
+ spec.encoder.scale_embeddings = False
2969
+
2970
+ self.set_embeddings(
2971
+ spec.encoder.embeddings[0], model.embeddings.word_embeddings
2972
+ )
2973
+ self.set_embeddings(
2974
+ spec.encoder.embeddings[1], model.embeddings.token_type_embeddings
2975
+ )
2976
+ self.set_position_encodings(
2977
+ spec.encoder.position_encodings, model.embeddings.position_embeddings
2978
+ )
2979
+ self.set_layer_norm(
2980
+ spec.encoder.layernorm_embedding, model.embeddings.LayerNorm
2981
+ )
2982
+
2983
+ self.set_linear(spec.pooler_dense, model.pooler.dense)
2984
+
2985
+ for layer_spec, layer in zip(spec.encoder.layer, model.encoder.layer):
2986
+ split_layers = [common_spec.LinearSpec() for _ in range(3)]
2987
+ self.set_linear(split_layers[0], layer.attention.self.query)
2988
+ self.set_linear(split_layers[1], layer.attention.self.key)
2989
+ self.set_linear(split_layers[2], layer.attention.self.value)
2990
+ utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
2991
+
2992
+ self.set_linear(
2993
+ layer_spec.self_attention.linear[1], layer.attention.output.dense
2994
+ )
2995
+ self.set_layer_norm(
2996
+ layer_spec.self_attention.layer_norm, layer.attention.output.LayerNorm
2997
+ )
2998
+
2999
+ self.set_linear(layer_spec.ffn.linear_0, layer.intermediate.dense)
3000
+ self.set_linear(layer_spec.ffn.linear_1, layer.output.dense)
3001
+ self.set_layer_norm(layer_spec.ffn.layer_norm, layer.output.LayerNorm)
3002
+
3003
+ return spec
3004
+
3005
+ def get_vocabulary(self, model, tokenizer):
3006
+ tokens = super().get_vocabulary(model, tokenizer)
3007
+
3008
+ extra_ids = model.config.vocab_size - len(tokens)
3009
+ for i in range(extra_ids):
3010
+ tokens.append("<extra_id_%d>" % i)
3011
+
3012
+ return tokens
3013
+
3014
+ def set_vocabulary(self, spec, tokens):
3015
+ spec.register_vocabulary(tokens)
3016
+
3017
+ def set_config(self, config, model, tokenizer):
3018
+ config.unk_token = tokenizer.unk_token
3019
+ config.layer_norm_epsilon = model.config.layer_norm_eps
3020
+
3021
+
3022
+ @register_loader("XLMRobertaConfig")
3023
+ class XLMRobertaLoader(ModelLoader):
3024
+ @property
3025
+ def architecture_name(self):
3026
+ return "XLMRobertaForSequenceClassification"
3027
+
3028
+ def get_model_spec(self, model):
3029
+ assert model.config.position_embedding_type == "absolute"
3030
+
3031
+ encoder_spec = transformer_spec.TransformerEncoderSpec(
3032
+ model.config.num_hidden_layers,
3033
+ model.config.num_attention_heads,
3034
+ pre_norm=False,
3035
+ activation=_SUPPORTED_ACTIVATIONS[model.config.hidden_act],
3036
+ layernorm_embedding=True,
3037
+ num_source_embeddings=2,
3038
+ embeddings_merge=common_spec.EmbeddingsMerge.ADD,
3039
+ )
3040
+
3041
+ if model.roberta.pooler is None:
3042
+ pooling_layer = False
3043
+ else:
3044
+ pooling_layer = True
3045
+
3046
+ spec = transformer_spec.TransformerEncoderModelSpec(
3047
+ encoder_spec,
3048
+ pooling_layer=pooling_layer,
3049
+ pooling_activation=common_spec.Activation.Tanh,
3050
+ )
3051
+
3052
+ spec.encoder.scale_embeddings = False
3053
+
3054
+ self.set_embeddings(
3055
+ spec.encoder.embeddings[0], model.roberta.embeddings.word_embeddings
3056
+ )
3057
+ self.set_embeddings(
3058
+ spec.encoder.embeddings[1], model.roberta.embeddings.token_type_embeddings
3059
+ )
3060
+ self.set_position_encodings(
3061
+ spec.encoder.position_encodings,
3062
+ model.roberta.embeddings.position_embeddings,
3063
+ )
3064
+ self.set_layer_norm(
3065
+ spec.encoder.layernorm_embedding, model.roberta.embeddings.LayerNorm
3066
+ )
3067
+ if pooling_layer:
3068
+ self.set_linear(spec.pooler_dense, model.roberta.pooler.dense)
3069
+
3070
+ for layer_spec, layer in zip(spec.encoder.layer, model.roberta.encoder.layer):
3071
+ split_layers = [common_spec.LinearSpec() for _ in range(3)]
3072
+ self.set_linear(split_layers[0], layer.attention.self.query)
3073
+ self.set_linear(split_layers[1], layer.attention.self.key)
3074
+ self.set_linear(split_layers[2], layer.attention.self.value)
3075
+ utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
3076
+
3077
+ self.set_linear(
3078
+ layer_spec.self_attention.linear[1], layer.attention.output.dense
3079
+ )
3080
+ self.set_layer_norm(
3081
+ layer_spec.self_attention.layer_norm, layer.attention.output.LayerNorm
3082
+ )
3083
+
3084
+ self.set_linear(layer_spec.ffn.linear_0, layer.intermediate.dense)
3085
+ self.set_linear(layer_spec.ffn.linear_1, layer.output.dense)
3086
+ self.set_layer_norm(layer_spec.ffn.layer_norm, layer.output.LayerNorm)
3087
+
3088
+ return spec
3089
+
3090
+ def set_vocabulary(self, spec, tokens):
3091
+ spec.register_vocabulary(tokens)
3092
+
3093
+ def set_config(self, config, model, tokenizer):
3094
+ config.unk_token = tokenizer.unk_token
3095
+ config.layer_norm_epsilon = model.config.layer_norm_eps
3096
+
3097
+ def set_position_encodings(self, spec, module):
3098
+ spec.encodings = module.weight
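+ # RoBERTa-style position embeddings are offset by padding_idx + 1, so drop the unused leading rows.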
3099
+ offset = getattr(module, "padding_idx", 0)
3100
+ if offset > 0:
3101
+ spec.encodings = spec.encodings[offset + 1 :]
3102
+
3103
+
3104
+ @register_loader("RobertaConfig")
3105
+ class RobertaLoader(ModelLoader):
3106
+ @property
3107
+ def architecture_name(self):
3108
+ return "RobertaModel"
3109
+
3110
+ def get_model_spec(self, model):
3111
+ assert model.config.position_embedding_type == "absolute"
3112
+
3113
+ encoder_spec = transformer_spec.TransformerEncoderSpec(
3114
+ model.config.num_hidden_layers,
3115
+ model.config.num_attention_heads,
3116
+ pre_norm=False,
3117
+ activation=_SUPPORTED_ACTIVATIONS[model.config.hidden_act],
3118
+ layernorm_embedding=True,
3119
+ num_source_embeddings=2,
3120
+ embeddings_merge=common_spec.EmbeddingsMerge.ADD,
3121
+ )
3122
+
3123
+ if model.pooler is None:
3124
+ pooling_layer = False
3125
+ else:
3126
+ pooling_layer = True
3127
+
3128
+ spec = transformer_spec.TransformerEncoderModelSpec(
3129
+ encoder_spec,
3130
+ pooling_layer=pooling_layer,
3131
+ pooling_activation=common_spec.Activation.Tanh,
3132
+ )
3133
+
3134
+ spec.encoder.scale_embeddings = False
3135
+
3136
+ self.set_embeddings(
3137
+ spec.encoder.embeddings[0], model.embeddings.word_embeddings
3138
+ )
3139
+ self.set_embeddings(
3140
+ spec.encoder.embeddings[1], model.embeddings.token_type_embeddings
3141
+ )
3142
+ self.set_position_encodings(
3143
+ spec.encoder.position_encodings,
3144
+ model.embeddings.position_embeddings,
3145
+ )
3146
+ self.set_layer_norm(
3147
+ spec.encoder.layernorm_embedding, model.embeddings.LayerNorm
3148
+ )
3149
+ if pooling_layer:
3150
+ self.set_linear(spec.pooler_dense, model.pooler.dense)
3151
+
3152
+ for layer_spec, layer in zip(spec.encoder.layer, model.encoder.layer):
3153
+ split_layers = [common_spec.LinearSpec() for _ in range(3)]
3154
+ self.set_linear(split_layers[0], layer.attention.self.query)
3155
+ self.set_linear(split_layers[1], layer.attention.self.key)
3156
+ self.set_linear(split_layers[2], layer.attention.self.value)
3157
+ utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
3158
+
3159
+ self.set_linear(
3160
+ layer_spec.self_attention.linear[1], layer.attention.output.dense
3161
+ )
3162
+ self.set_layer_norm(
3163
+ layer_spec.self_attention.layer_norm, layer.attention.output.LayerNorm
3164
+ )
3165
+
3166
+ self.set_linear(layer_spec.ffn.linear_0, layer.intermediate.dense)
3167
+ self.set_linear(layer_spec.ffn.linear_1, layer.output.dense)
3168
+ self.set_layer_norm(layer_spec.ffn.layer_norm, layer.output.LayerNorm)
3169
+
3170
+ return spec
3171
+
3172
+ def set_vocabulary(self, spec, tokens):
3173
+ spec.register_vocabulary(tokens)
3174
+
3175
+ def set_config(self, config, model, tokenizer):
3176
+ config.unk_token = tokenizer.unk_token
3177
+ config.layer_norm_epsilon = model.config.layer_norm_eps
3178
+
3179
+ def set_position_encodings(self, spec, module):
3180
+ spec.encodings = module.weight
3181
+ offset = getattr(module, "padding_idx", 0)
3182
+ if offset > 0:
3183
+ spec.encodings = spec.encodings[offset + 1 :]
3184
+
3185
+
3186
+ @register_loader("CamembertConfig")
3187
+ class CamembertLoader(ModelLoader):
3188
+ @property
3189
+ def architecture_name(self):
3190
+ return "CamembertModel"
3191
+
3192
+ def get_model_spec(self, model):
3193
+ assert model.config.position_embedding_type == "absolute"
3194
+
3195
+ encoder_spec = transformer_spec.TransformerEncoderSpec(
3196
+ model.config.num_hidden_layers,
3197
+ model.config.num_attention_heads,
3198
+ pre_norm=False,
3199
+ activation=_SUPPORTED_ACTIVATIONS[model.config.hidden_act],
3200
+ layernorm_embedding=True,
3201
+ num_source_embeddings=2,
3202
+ embeddings_merge=common_spec.EmbeddingsMerge.ADD,
3203
+ )
3204
+
3205
+ if model.pooler is None:
3206
+ pooling_layer = False
3207
+ else:
3208
+ pooling_layer = True
3209
+
3210
+ spec = transformer_spec.TransformerEncoderModelSpec(
3211
+ encoder_spec,
3212
+ pooling_layer=pooling_layer,
3213
+ pooling_activation=common_spec.Activation.Tanh,
3214
+ )
3215
+
3216
+ spec.encoder.scale_embeddings = False
3217
+
3218
+ self.set_embeddings(
3219
+ spec.encoder.embeddings[0], model.embeddings.word_embeddings
3220
+ )
3221
+ self.set_embeddings(
3222
+ spec.encoder.embeddings[1], model.embeddings.token_type_embeddings
3223
+ )
3224
+ self.set_position_encodings(
3225
+ spec.encoder.position_encodings,
3226
+ model.embeddings.position_embeddings,
3227
+ )
3228
+ self.set_layer_norm(
3229
+ spec.encoder.layernorm_embedding, model.embeddings.LayerNorm
3230
+ )
3231
+ if pooling_layer:
3232
+ self.set_linear(spec.pooler_dense, model.pooler.dense)
3233
+
3234
+ for layer_spec, layer in zip(spec.encoder.layer, model.encoder.layer):
3235
+ split_layers = [common_spec.LinearSpec() for _ in range(3)]
3236
+ self.set_linear(split_layers[0], layer.attention.self.query)
3237
+ self.set_linear(split_layers[1], layer.attention.self.key)
3238
+ self.set_linear(split_layers[2], layer.attention.self.value)
3239
+ utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
3240
+
3241
+ self.set_linear(
3242
+ layer_spec.self_attention.linear[1], layer.attention.output.dense
3243
+ )
3244
+ self.set_layer_norm(
3245
+ layer_spec.self_attention.layer_norm, layer.attention.output.LayerNorm
3246
+ )
3247
+
3248
+ self.set_linear(layer_spec.ffn.linear_0, layer.intermediate.dense)
3249
+ self.set_linear(layer_spec.ffn.linear_1, layer.output.dense)
3250
+ self.set_layer_norm(layer_spec.ffn.layer_norm, layer.output.LayerNorm)
3251
+
3252
+ return spec
3253
+
3254
+ def set_vocabulary(self, spec, tokens):
3255
+ spec.register_vocabulary(tokens)
3256
+
3257
+ def set_config(self, config, model, tokenizer):
3258
+ config.unk_token = tokenizer.unk_token
3259
+ config.layer_norm_epsilon = model.config.layer_norm_eps
3260
+
3261
+ def set_position_encodings(self, spec, module):
3262
+ spec.encodings = module.weight
3263
+ offset = getattr(module, "padding_idx", 0)
3264
+ if offset > 0:
3265
+ spec.encodings = spec.encodings[offset + 1 :]
3266
+
3267
+
3268
+ def main():
3269
+ parser = argparse.ArgumentParser(
3270
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
3271
+ )
3272
+ parser.add_argument(
3273
+ "--model",
3274
+ required=True,
3275
+ help=(
3276
+ "Name of the pretrained model to download, "
3277
+ "or path to a directory containing the pretrained model."
3278
+ ),
3279
+ )
3280
+ parser.add_argument(
3281
+ "--activation_scales",
3282
+ help=(
3283
+ "Path to the pre-computed activation scales. Models may "
3284
+ "use them to rescale some weights to smooth the intermediate activations "
3285
+ "and improve the quantization accuracy. See "
3286
+ "https://github.com/mit-han-lab/smoothquant."
3287
+ ),
3288
+ )
3289
+ parser.add_argument(
3290
+ "--copy_files",
3291
+ nargs="+",
3292
+ help=(
3293
+ "List of filenames to copy from the Hugging Face model to the converted "
3294
+ "model directory."
3295
+ ),
3296
+ )
3297
+ parser.add_argument(
3298
+ "--revision",
3299
+ help="Revision of the model to download from the Hugging Face Hub.",
3300
+ )
3301
+ parser.add_argument(
3302
+ "--low_cpu_mem_usage",
3303
+ action="store_true",
3304
+ help="Enable the flag low_cpu_mem_usage when loading the model with from_pretrained.",
3305
+ )
3306
+ parser.add_argument(
3307
+ "--trust_remote_code",
3308
+ action="store_true",
3309
+ help="Allow converting models using custom code.",
3310
+ )
3311
+
3312
+ Converter.declare_arguments(parser)
3313
+ args = parser.parse_args()
3314
+ converter = TransformersConverter(
3315
+ args.model,
3316
+ activation_scales=args.activation_scales,
3317
+ copy_files=args.copy_files,
3318
+ load_as_float16=args.quantization in ("float16", "int8_float16"),
3319
+ revision=args.revision,
3320
+ low_cpu_mem_usage=args.low_cpu_mem_usage,
3321
+ trust_remote_code=args.trust_remote_code,
3322
+ )
3323
+ converter.convert_from_args(args)
3324
+
3325
+
3326
+ if __name__ == "__main__":
3327
+ main()
3328
+
3329
+
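+ # Usage sketch (illustrative only): assuming this file ships as
+ # ctranslate2/converters/transformers.py and that Converter.declare_arguments
+ # adds the usual --output_dir/--quantization options, a conversion could look like:
+ #
+ #   python -m ctranslate2.converters.transformers \
+ #       --model openai/whisper-tiny --output_dir whisper-tiny-ct2 --quantization int8
+ #
+ # or, programmatically:
+ #
+ #   converter = TransformersConverter("openai/whisper-tiny", load_as_float16=True)
+ #   converter.convert("whisper-tiny-ct2", quantization="int8")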
3330
+ # Cross-attention heads that are highly correlated with the word-level timing,
3331
+ # i.e. the alignment between audio and text tokens.
3332
+ # Obtained from https://github.com/openai/whisper/blob/v20231106/whisper/__init__.py#L32-L47
3333
+ _WHISPER_ALIGNMENT_HEADS = {
3334
+ "openai/whisper-tiny.en": [
3335
+ (1, 0),
3336
+ (2, 0),
3337
+ (2, 5),
3338
+ (3, 0),
3339
+ (3, 1),
3340
+ (3, 2),
3341
+ (3, 3),
3342
+ (3, 4),
3343
+ ],
3344
+ "openai/whisper-tiny": [(2, 2), (3, 0), (3, 2), (3, 3), (3, 4), (3, 5)],
3345
+ "openai/whisper-base.en": [(3, 3), (4, 7), (5, 1), (5, 5), (5, 7)],
3346
+ "openai/whisper-base": [
3347
+ (3, 1),
3348
+ (4, 2),
3349
+ (4, 3),
3350
+ (4, 7),
3351
+ (5, 1),
3352
+ (5, 2),
3353
+ (5, 4),
3354
+ (5, 6),
3355
+ ],
3356
+ "openai/whisper-small.en": [
3357
+ (6, 6),
3358
+ (7, 0),
3359
+ (7, 3),
3360
+ (7, 8),
3361
+ (8, 2),
3362
+ (8, 5),
3363
+ (8, 7),
3364
+ (9, 0),
3365
+ (9, 4),
3366
+ (9, 8),
3367
+ (9, 10),
3368
+ (10, 0),
3369
+ (10, 1),
3370
+ (10, 2),
3371
+ (10, 3),
3372
+ (10, 6),
3373
+ (10, 11),
3374
+ (11, 2),
3375
+ (11, 4),
3376
+ ],
3377
+ "openai/whisper-small": [
3378
+ (5, 3),
3379
+ (5, 9),
3380
+ (8, 0),
3381
+ (8, 4),
3382
+ (8, 7),
3383
+ (8, 8),
3384
+ (9, 0),
3385
+ (9, 7),
3386
+ (9, 9),
3387
+ (10, 5),
3388
+ ],
3389
+ "openai/whisper-medium.en": [
3390
+ (11, 4),
3391
+ (14, 1),
3392
+ (14, 12),
3393
+ (14, 14),
3394
+ (15, 4),
3395
+ (16, 0),
3396
+ (16, 4),
3397
+ (16, 9),
3398
+ (17, 12),
3399
+ (17, 14),
3400
+ (18, 7),
3401
+ (18, 10),
3402
+ (18, 15),
3403
+ (20, 0),
3404
+ (20, 3),
3405
+ (20, 9),
3406
+ (20, 14),
3407
+ (21, 12),
3408
+ ],
3409
+ "openai/whisper-medium": [(13, 15), (15, 4), (15, 15), (16, 1), (20, 0), (23, 4)],
3410
+ "openai/whisper-large": [
3411
+ (9, 19),
3412
+ (11, 2),
3413
+ (11, 4),
3414
+ (11, 17),
3415
+ (22, 7),
3416
+ (22, 11),
3417
+ (22, 17),
3418
+ (23, 2),
3419
+ (23, 15),
3420
+ ],
3421
+ "openai/whisper-large-v2": [
3422
+ (10, 12),
3423
+ (13, 17),
3424
+ (16, 11),
3425
+ (16, 12),
3426
+ (16, 13),
3427
+ (17, 15),
3428
+ (17, 16),
3429
+ (18, 4),
3430
+ (18, 11),
3431
+ (18, 19),
3432
+ (19, 11),
3433
+ (21, 2),
3434
+ (21, 3),
3435
+ (22, 3),
3436
+ (22, 9),
3437
+ (22, 12),
3438
+ (23, 5),
3439
+ (23, 7),
3440
+ (23, 13),
3441
+ (25, 5),
3442
+ (26, 1),
3443
+ (26, 12),
3444
+ (27, 15),
3445
+ ],
3446
+ "openai/whisper-large-v3": [
3447
+ (7, 0),
3448
+ (10, 17),
3449
+ (12, 18),
3450
+ (13, 12),
3451
+ (16, 1),
3452
+ (17, 14),
3453
+ (19, 11),
3454
+ (21, 4),
3455
+ (24, 1),
3456
+ (25, 6),
3457
+ ],
3458
+ }
3459
+
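+ # Illustrative lookup: the Whisper loader defined elsewhere in this module is
+ # expected to attach these (layer, head) pairs to the converted model spec, e.g.:
+ #
+ #   _WHISPER_ALIGNMENT_HEADS.get("openai/whisper-tiny")
+ #   # -> [(2, 2), (3, 0), (3, 2), (3, 3), (3, 4), (3, 5)]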
3460
+
3461
+ # Paper: https://arxiv.org/pdf/2504.06225
3462
+ @register_loader("T5GemmaConfig")
3463
+ class T5GemmaLoader(ModelLoader):
3464
+ @property
3465
+ def architecture_name(self):
3466
+ return "T5GemmaForConditionalGeneration"
3467
+
3468
+ def set_layer_norm(self, spec, layer_norm):
3469
+ spec.gamma = layer_norm.weight.data + 1.0
3470
+
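+ # Note on the +1.0 above: T5Gemma follows the Gemma-style RMSNorm, which stores
+ # its scale as an offset from one; a rough reference of the forward pass:
+ #
+ #   y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * (1.0 + weight)
+ #
+ # so the CTranslate2 gamma is recovered as weight + 1.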
3471
+ def get_model_spec(self, model):
3472
+ encoder_config = model.config.encoder
3473
+ decoder_config = model.config.decoder
3474
+ sliding_window = getattr(model.config, "sliding_window", 4096)
3475
+
3476
+ encoder_num_heads = encoder_config.num_attention_heads
3477
+ encoder_num_heads_kv = getattr(
3478
+ encoder_config, "num_key_value_heads", encoder_num_heads
3479
+ )
3480
+ if encoder_num_heads_kv == encoder_num_heads:
3481
+ encoder_num_heads_kv = None  # same K/V head count as query heads: standard multi-head attention
3482
+
3483
+ encoder = transformer_spec.TransformerEncoderSpec(
3484
+ encoder_config.num_hidden_layers,
3485
+ encoder_config.num_attention_heads,
3486
+ pre_norm=True,
3487
+ activation=_SUPPORTED_ACTIVATIONS[encoder_config.hidden_activation],
3488
+ ffn_glu=True,
3489
+ rms_norm=True,
3490
+ rotary_dim=encoder_config.head_dim,
3491
+ rotary_interleave=False,
3492
+ rotary_base=getattr(encoder_config, "rope_theta", 10000),
3493
+ sliding_window=sliding_window,
3494
+ pre_post_layer_norm=True,
3495
+ num_heads_kv=encoder_num_heads_kv,
3496
+ head_dim=encoder_config.head_dim,
3497
+ )
3498
+
3499
+ decoder_num_heads = decoder_config.num_attention_heads
3500
+ decoder_num_heads_kv = getattr(
3501
+ decoder_config, "num_key_value_heads", decoder_num_heads
3502
+ )
3503
+ if decoder_num_heads_kv == decoder_num_heads:
3504
+ decoder_num_heads_kv = None
3505
+
3506
+ decoder = transformer_spec.TransformerDecoderSpec(
3507
+ decoder_config.num_hidden_layers,
3508
+ decoder_config.num_attention_heads,
3509
+ pre_norm=True,
3510
+ activation=_SUPPORTED_ACTIVATIONS[decoder_config.hidden_activation],
3511
+ ffn_glu=True,
3512
+ rms_norm=True,
3513
+ with_encoder_attention=True,
3514
+ rotary_dim=decoder_config.head_dim,
3515
+ rotary_interleave=False,
3516
+ rotary_base=getattr(decoder_config, "rope_theta", 10000),
3517
+ sliding_window=sliding_window,
3518
+ pre_post_layer_norm=True,
3519
+ external_pre_post_encoder_layers=True,
3520
+ num_heads_kv=decoder_num_heads_kv,
3521
+ head_dim=decoder_config.head_dim,
3522
+ )
3523
+
3524
+ spec = transformer_spec.TransformerSpec(encoder, decoder)
3525
+
3526
+ self.set_encoder(spec.encoder, model.model.encoder, encoder_config)
3527
+
3528
+ self.set_decoder(
3529
+ spec.decoder,
3530
+ model.model.decoder,
3531
+ decoder_config,
3532
+ common_spec.Quantization.CT2,
3533
+ )
3534
+
3535
+ # Tied word embeddings: reuse the decoder token embeddings as the output projection.
3536
+ self.set_linear(spec.decoder.projection, model.model.decoder.embed_tokens)
3537
+ return spec
3538
+
3539
+ def set_vocabulary(self, spec, tokens):
3540
+ spec.register_source_vocabulary(tokens)
3541
+ spec.register_target_vocabulary(tokens)
3542
+
3543
+ def set_config(self, config, model, tokenizer):
3544
+ config.bos_token = tokenizer.bos_token
3545
+ config.eos_token = tokenizer.eos_token
3546
+ config.unk_token = tokenizer.unk_token
3547
+
3548
+ if hasattr(model.config, "encoder"):
3549
+ config.layer_norm_epsilon = model.config.encoder.rms_norm_eps
3550
+ elif hasattr(model.config, "rms_norm_eps"):
3551
+ config.layer_norm_epsilon = model.config.rms_norm_eps
3552
+ else:
3553
+ config.layer_norm_epsilon = 1e-6
3554
+
3555
+ config.decoder_start_token = tokenizer.bos_token
3556
+
3557
+ def set_encoder(
3558
+ self, spec, encoder, encoder_config, quant_type=common_spec.Quantization.CT2
3559
+ ):
3560
+ spec.scale_embeddings = True
3561
+
3562
+ encoder_emb_spec = (
3563
+ spec.embeddings[0] if isinstance(spec.embeddings, list) else spec.embeddings
3564
+ )
3565
+
3566
+ self.set_embeddings(encoder_emb_spec, encoder.embed_tokens)
3567
+ encoder_emb_spec.multiply_by_sqrt_depth = encoder_config.hidden_size**0.5
3568
+ self.set_layer_norm(spec.layer_norm, encoder.norm)
3569
+
3570
+ module = encoder
3571
+ for i, (layer_spec, layer) in enumerate(zip(spec.layer, module.layers)):
3572
+ self.set_layer_norm(
3573
+ layer_spec.input_layer_norm, layer.pre_self_attn_layernorm
3574
+ )
3575
+ self.set_layer_norm(
3576
+ layer_spec.post_attention_layer_norm, layer.post_self_attn_layernorm
3577
+ )
3578
+
3579
+ # T5GemmaSelfAttention
3580
+ qkv_split_layers = [common_spec.LinearSpec() for _ in range(3)]
3581
+ self.set_linear(
3582
+ qkv_split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
3583
+ )
3584
+ self.set_linear(
3585
+ qkv_split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
3586
+ )
3587
+ self.set_linear(
3588
+ qkv_split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
3589
+ )
3590
+ utils.fuse_linear(layer_spec.self_attention.linear[0], qkv_split_layers)
3591
+ self.set_linear(
3592
+ layer_spec.self_attention.linear[1],
3593
+ layer.self_attn.o_proj,
3594
+ quant_type=quant_type,
3595
+ )
3596
+
3597
+ # T5GemmaRMSNorm
3598
+ self.set_layer_norm(
3599
+ layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
3600
+ )
3601
+ # T5GemmaRMSNorm
3602
+ self.set_layer_norm(
3603
+ layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
3604
+ )
3605
+
3606
+ # T5GemmaMLP
3607
+ self.set_linear(
3608
+ layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
3609
+ )
3610
+ self.set_linear(
3611
+ layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
3612
+ )
3613
+ self.set_linear(
3614
+ layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
3615
+ )
3616
+
3617
+ # Clean up
3618
+ delattr(layer, "self_attn")
3619
+ delattr(layer, "mlp")
3620
+ gc.collect()
3621
+
3622
+ def set_decoder(
3623
+ self, spec, module, decoder_config, quant_type=common_spec.Quantization.CT2
3624
+ ):
3625
+ spec.scale_embeddings = True
3626
+ spec.start_from_zero_embedding = False
3627
+
3628
+ self.set_embeddings(spec.embeddings, module.embed_tokens)
3629
+ spec.embeddings.multiply_by_sqrt_depth = decoder_config.hidden_size**0.5
3630
+ self.set_layer_norm(spec.layer_norm, module.norm)
3631
+
3632
+ for i, (layer_spec, layer) in enumerate(zip(spec.layer, module.layers)):
3633
+ # Self-attention block
3634
+ self.set_layer_norm(
3635
+ layer_spec.input_layer_norm, layer.pre_self_attn_layernorm
3636
+ )
3637
+ self.set_layer_norm(
3638
+ layer_spec.post_attention_layer_norm, layer.post_self_attn_layernorm
3639
+ )
3640
+
3641
+ # T5GemmaSelfAttention - QKV projections
3642
+ qkv_split_layers = [common_spec.LinearSpec() for _ in range(3)]
3643
+ self.set_linear(
3644
+ qkv_split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
3645
+ )
3646
+ self.set_linear(
3647
+ qkv_split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
3648
+ )
3649
+ self.set_linear(
3650
+ qkv_split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
3651
+ )
3652
+ utils.fuse_linear(layer_spec.self_attention.linear[0], qkv_split_layers)
3653
+ self.set_linear(
3654
+ layer_spec.self_attention.linear[1],
3655
+ layer.self_attn.o_proj,
3656
+ quant_type=quant_type,
3657
+ )
3658
+
3659
+ # Pre and post cross-attention layer norm
3660
+ self.set_layer_norm(
3661
+ layer_spec.external_pre_encoder_attention_layer_norm,
3662
+ layer.pre_cross_attn_layernorm,
3663
+ )
3664
+
3665
+ self.set_layer_norm(
3666
+ layer_spec.external_post_encoder_attention_layer_norm,
3667
+ layer.post_cross_attn_layernorm,
3668
+ )
3669
+
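+ # Cross-attention wiring in the spec (as set below): linear[0] holds the query
+ # projection, linear[1] the fused key/value projection over the encoder output,
+ # and linear[2] the output projection.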
3670
+ # Cross-attention Q projection
3671
+ self.set_linear(
3672
+ layer_spec.attention.linear[0],
3673
+ layer.cross_attn.q_proj,
3674
+ quant_type=quant_type,
3675
+ )
3676
+
3677
+ # Cross-attention K+V fused
3678
+ kv_split_layers = [common_spec.LinearSpec() for _ in range(2)]
3679
+ self.set_linear(
3680
+ kv_split_layers[0],
3681
+ layer.cross_attn.k_proj,
3682
+ quant_type=quant_type,
3683
+ )
3684
+ self.set_linear(
3685
+ kv_split_layers[1],
3686
+ layer.cross_attn.v_proj,
3687
+ quant_type=quant_type,
3688
+ )
3689
+ utils.fuse_linear(layer_spec.attention.linear[1], kv_split_layers)
3690
+
3691
+ # Cross-attention output projection
3692
+ self.set_linear(
3693
+ layer_spec.attention.linear[2],
3694
+ layer.cross_attn.o_proj,
3695
+ quant_type=quant_type,
3696
+ )
3697
+
3698
+ # Feed-forward block
3699
+ self.set_layer_norm(
3700
+ layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
3701
+ )
3702
+ self.set_layer_norm(
3703
+ layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
3704
+ )
3705
+
3706
+ # T5GemmaMLP
3707
+ self.set_linear(
3708
+ layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
3709
+ )
3710
+ self.set_linear(
3711
+ layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
3712
+ )
3713
+ self.set_linear(
3714
+ layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
3715
+ )
3716
+
3717
+ # Clean up
3718
+ delattr(layer, "self_attn")
3719
+ delattr(layer, "cross_attn")
3720
+ delattr(layer, "mlp")
3721
+ gc.collect()