bigdl-core-cpp 2.1.0b20230202-py3-none-manylinux2010_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. bigdl/cpp/__init__.py +0 -0
  2. bigdl/cpp/cli/init-llama-cpp +14 -0
  3. bigdl/cpp/cli/init-ollama +8 -0
  4. bigdl/cpp/convert-hf-to-gguf.py +2858 -0
  5. bigdl/cpp/convert.py +1714 -0
  6. bigdl/cpp/gguf-py/__init__.py +0 -0
  7. bigdl/cpp/gguf-py/gguf/__init__.py +7 -0
  8. bigdl/cpp/gguf-py/gguf/constants.py +1033 -0
  9. bigdl/cpp/gguf-py/gguf/gguf.py +15 -0
  10. bigdl/cpp/gguf-py/gguf/gguf_reader.py +296 -0
  11. bigdl/cpp/gguf-py/gguf/gguf_writer.py +554 -0
  12. bigdl/cpp/gguf-py/gguf/lazy.py +236 -0
  13. bigdl/cpp/gguf-py/gguf/py.typed +0 -0
  14. bigdl/cpp/gguf-py/gguf/quants.py +123 -0
  15. bigdl/cpp/gguf-py/gguf/tensor_mapping.py +463 -0
  16. bigdl/cpp/gguf-py/gguf/vocab.py +165 -0
  17. bigdl/cpp/libs/baby-llama +0 -0
  18. bigdl/cpp/libs/batched +0 -0
  19. bigdl/cpp/libs/batched-bench +0 -0
  20. bigdl/cpp/libs/benchmark +0 -0
  21. bigdl/cpp/libs/embedding +0 -0
  22. bigdl/cpp/libs/export-lora +0 -0
  23. bigdl/cpp/libs/finetune +0 -0
  24. bigdl/cpp/libs/gguf +0 -0
  25. bigdl/cpp/libs/gritlm +0 -0
  26. bigdl/cpp/libs/imatrix +0 -0
  27. bigdl/cpp/libs/infill +0 -0
  28. bigdl/cpp/libs/llama-bench +0 -0
  29. bigdl/cpp/libs/llava-cli +0 -0
  30. bigdl/cpp/libs/lookahead +0 -0
  31. bigdl/cpp/libs/lookup +0 -0
  32. bigdl/cpp/libs/ls-sycl-device +0 -0
  33. bigdl/cpp/libs/main +0 -0
  34. bigdl/cpp/libs/ollama +0 -0
  35. bigdl/cpp/libs/parallel +0 -0
  36. bigdl/cpp/libs/perplexity +0 -0
  37. bigdl/cpp/libs/quantize +0 -0
  38. bigdl/cpp/libs/quantize-stats +0 -0
  39. bigdl/cpp/libs/save-load-state +0 -0
  40. bigdl/cpp/libs/server +0 -0
  41. bigdl/cpp/libs/simple +0 -0
  42. bigdl/cpp/libs/speculative +0 -0
  43. bigdl/cpp/libs/tokenize +0 -0
  44. bigdl/cpp/libs/train-text-from-scratch +0 -0
  45. bigdl_core_cpp-2.1.0b20230202.data/scripts/init-llama-cpp +14 -0
  46. bigdl_core_cpp-2.1.0b20230202.data/scripts/init-ollama +8 -0
  47. bigdl_core_cpp-2.1.0b20230202.dist-info/METADATA +18 -0
  48. bigdl_core_cpp-2.1.0b20230202.dist-info/RECORD +50 -0
  49. bigdl_core_cpp-2.1.0b20230202.dist-info/WHEEL +5 -0
  50. bigdl_core_cpp-2.1.0b20230202.dist-info/top_level.txt +1 -0
bigdl/cpp/gguf-py/gguf/constants.py
@@ -0,0 +1,1033 @@
+ from __future__ import annotations
+
+ from enum import Enum, IntEnum, auto
+ from typing import Any
+
+ #
+ # constants
+ #
+
+ GGUF_MAGIC = 0x46554747 # "GGUF"
+ GGUF_VERSION = 3
+ GGUF_DEFAULT_ALIGNMENT = 32
+ GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h
+
+ #
+ # metadata keys
+ #
+
+
+ class Keys:
+     class General:
+         ARCHITECTURE = "general.architecture"
+         QUANTIZATION_VERSION = "general.quantization_version"
+         ALIGNMENT = "general.alignment"
+         NAME = "general.name"
+         AUTHOR = "general.author"
+         VERSION = "general.version"
+         URL = "general.url"
+         DESCRIPTION = "general.description"
+         LICENSE = "general.license"
+         SOURCE_URL = "general.source.url"
+         SOURCE_HF_REPO = "general.source.huggingface.repository"
+         FILE_TYPE = "general.file_type"
+
+     class LLM:
+         VOCAB_SIZE = "{arch}.vocab_size"
+         CONTEXT_LENGTH = "{arch}.context_length"
+         EMBEDDING_LENGTH = "{arch}.embedding_length"
+         BLOCK_COUNT = "{arch}.block_count"
+         FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
+         USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
+         TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
+         EXPERT_COUNT = "{arch}.expert_count"
+         EXPERT_USED_COUNT = "{arch}.expert_used_count"
+         POOLING_TYPE = "{arch}.pooling_type"
+         LOGIT_SCALE = "{arch}.logit_scale"
+
+     class Attention:
+         HEAD_COUNT = "{arch}.attention.head_count"
+         HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
+         MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
+         CLAMP_KQV = "{arch}.attention.clamp_kqv"
+         KEY_LENGTH = "{arch}.attention.key_length"
+         VALUE_LENGTH = "{arch}.attention.value_length"
+         LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
+         LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
+         CAUSAL = "{arch}.attention.causal"
+
+     class Rope:
+         DIMENSION_COUNT = "{arch}.rope.dimension_count"
+         FREQ_BASE = "{arch}.rope.freq_base"
+         SCALING_TYPE = "{arch}.rope.scaling.type"
+         SCALING_FACTOR = "{arch}.rope.scaling.factor"
+         SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
+         SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
+         SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
+
+     class SSM:
+         CONV_KERNEL = "{arch}.ssm.conv_kernel"
+         INNER_SIZE = "{arch}.ssm.inner_size"
+         STATE_SIZE = "{arch}.ssm.state_size"
+         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
+
+     class Tokenizer:
+         MODEL = "tokenizer.ggml.model"
+         PRE = "tokenizer.ggml.pre"
+         LIST = "tokenizer.ggml.tokens"
+         TOKEN_TYPE = "tokenizer.ggml.token_type"
+         TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types
+         SCORES = "tokenizer.ggml.scores"
+         MERGES = "tokenizer.ggml.merges"
+         BOS_ID = "tokenizer.ggml.bos_token_id"
+         EOS_ID = "tokenizer.ggml.eos_token_id"
+         UNK_ID = "tokenizer.ggml.unknown_token_id"
+         SEP_ID = "tokenizer.ggml.seperator_token_id"
+         PAD_ID = "tokenizer.ggml.padding_token_id"
+         CLS_ID = "tokenizer.ggml.cls_token_id"
+         MASK_ID = "tokenizer.ggml.mask_token_id"
+         ADD_BOS = "tokenizer.ggml.add_bos_token"
+         ADD_EOS = "tokenizer.ggml.add_eos_token"
+         ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
+         HF_JSON = "tokenizer.huggingface.json"
+         RWKV = "tokenizer.rwkv.world"
+         CHAT_TEMPLATE = "tokenizer.chat_template"
+         CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}"
+         CHAT_TEMPLATES = "tokenizer.chat_templates"
+         # FIM/Infill special tokens constants
+         PREFIX_ID = "tokenizer.ggml.prefix_token_id"
+         SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
+         MIDDLE_ID = "tokenizer.ggml.middle_token_id"
+         EOT_ID = "tokenizer.ggml.eot_token_id"
+
+
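The keys under Keys.LLM, Keys.Attention, Keys.Rope and Keys.SSM are plain str.format templates over the architecture name. A minimal sketch of how one such key resolves (illustrative only; not part of the packaged file):

# Illustrative sketch, not part of constants.py: the "{arch}"-templated keys
# are expanded with the architecture string before being written to a GGUF file.
context_length_key = "{arch}.context_length".format(arch="llama")
assert context_length_key == "llama.context_length"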
+ #
+ # recommended mapping of model tensor names for storage in gguf
+ #
+
+
+ class MODEL_ARCH(IntEnum):
+     LLAMA = auto()
+     FALCON = auto()
+     BAICHUAN = auto()
+     GROK = auto()
+     GPT2 = auto()
+     GPTJ = auto()
+     GPTNEOX = auto()
+     MPT = auto()
+     STARCODER = auto()
+     REFACT = auto()
+     BERT = auto()
+     NOMIC_BERT = auto()
+     JINA_BERT_V2 = auto()
+     BLOOM = auto()
+     STABLELM = auto()
+     QWEN = auto()
+     QWEN2 = auto()
+     QWEN2MOE = auto()
+     PHI2 = auto()
+     PHI3 = auto()
+     PLAMO = auto()
+     CODESHELL = auto()
+     ORION = auto()
+     INTERNLM2 = auto()
+     MINICPM = auto()
+     GEMMA = auto()
+     STARCODER2 = auto()
+     MAMBA = auto()
+     XVERSE = auto()
+     COMMAND_R = auto()
+     DBRX = auto()
+     OLMO = auto()
+     ARCTIC = auto()
+
+
+ class MODEL_TENSOR(IntEnum):
+     TOKEN_EMBD = auto()
+     TOKEN_EMBD_NORM = auto()
+     TOKEN_TYPES = auto()
+     POS_EMBD = auto()
+     OUTPUT = auto()
+     OUTPUT_NORM = auto()
+     ROPE_FREQS = auto()
+     ROPE_FACTORS_LONG = auto()
+     ROPE_FACTORS_SHORT = auto()
+     ATTN_Q = auto()
+     ATTN_K = auto()
+     ATTN_V = auto()
+     ATTN_QKV = auto()
+     ATTN_OUT = auto()
+     ATTN_NORM = auto()
+     ATTN_NORM_2 = auto()
+     ATTN_OUT_NORM = auto()
+     ATTN_ROT_EMBD = auto()
+     FFN_GATE_INP = auto()
+     FFN_GATE_INP_SHEXP = auto()
+     FFN_NORM = auto()
+     FFN_GATE = auto()
+     FFN_DOWN = auto()
+     FFN_UP = auto()
+     FFN_ACT = auto()
+     FFN_NORM_EXP = auto()
+     FFN_GATE_EXP = auto()
+     FFN_DOWN_EXP = auto()
+     FFN_UP_EXP = auto()
+     FFN_GATE_SHEXP = auto()
+     FFN_DOWN_SHEXP = auto()
+     FFN_UP_SHEXP = auto()
+     ATTN_Q_NORM = auto()
+     ATTN_K_NORM = auto()
+     LAYER_OUT_NORM = auto()
+     SSM_IN = auto()
+     SSM_CONV1D = auto()
+     SSM_X = auto()
+     SSM_DT = auto()
+     SSM_A = auto()
+     SSM_D = auto()
+     SSM_OUT = auto()
+
+
+ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
+     MODEL_ARCH.LLAMA: "llama",
+     MODEL_ARCH.FALCON: "falcon",
+     MODEL_ARCH.BAICHUAN: "baichuan",
+     MODEL_ARCH.GROK: "grok",
+     MODEL_ARCH.GPT2: "gpt2",
+     MODEL_ARCH.GPTJ: "gptj",
+     MODEL_ARCH.GPTNEOX: "gptneox",
+     MODEL_ARCH.MPT: "mpt",
+     MODEL_ARCH.STARCODER: "starcoder",
+     MODEL_ARCH.REFACT: "refact",
+     MODEL_ARCH.BERT: "bert",
+     MODEL_ARCH.NOMIC_BERT: "nomic-bert",
+     MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
+     MODEL_ARCH.BLOOM: "bloom",
+     MODEL_ARCH.STABLELM: "stablelm",
+     MODEL_ARCH.QWEN: "qwen",
+     MODEL_ARCH.QWEN2: "qwen2",
+     MODEL_ARCH.QWEN2MOE: "qwen2moe",
+     MODEL_ARCH.PHI2: "phi2",
+     MODEL_ARCH.PHI3: "phi3",
+     MODEL_ARCH.PLAMO: "plamo",
+     MODEL_ARCH.CODESHELL: "codeshell",
+     MODEL_ARCH.ORION: "orion",
+     MODEL_ARCH.INTERNLM2: "internlm2",
+     MODEL_ARCH.MINICPM: "minicpm",
+     MODEL_ARCH.GEMMA: "gemma",
+     MODEL_ARCH.STARCODER2: "starcoder2",
+     MODEL_ARCH.MAMBA: "mamba",
+     MODEL_ARCH.XVERSE: "xverse",
+     MODEL_ARCH.COMMAND_R: "command-r",
+     MODEL_ARCH.DBRX: "dbrx",
+     MODEL_ARCH.OLMO: "olmo",
+     MODEL_ARCH.ARCTIC: "arctic",
+ }
+
+ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
+     MODEL_TENSOR.TOKEN_EMBD: "token_embd",
+     MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
+     MODEL_TENSOR.TOKEN_TYPES: "token_types",
+     MODEL_TENSOR.POS_EMBD: "position_embd",
+     MODEL_TENSOR.OUTPUT_NORM: "output_norm",
+     MODEL_TENSOR.OUTPUT: "output",
+     MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
+     MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
+     MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
+     MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
+     MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
+     MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
+     MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
+     MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
+     MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
+     MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
+     MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
+     MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
+     MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
+     MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
+     MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
+     MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
+     MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
+     MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
+     MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
+     MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
+     MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp",
+     MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp",
+     MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp",
+     MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn",
+     MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps",
+     MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
+     MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
+     MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
+     MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
+     MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
+     MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
+     MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
+     MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
+     MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
+     MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
+     MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
+ }
+
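Most TENSOR_NAMES entries are likewise templates over the block index bid; a short sketch of how a concrete per-block name is produced (illustrative only, not part of the diff):

# Illustrative sketch: expanding a per-block tensor name for block 3.
attn_q_template = "blk.{bid}.attn_q"  # value of TENSOR_NAMES[MODEL_TENSOR.ATTN_Q]
assert attn_q_template.format(bid=3) == "blk.3.attn_q"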
+ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
+     MODEL_ARCH.LLAMA: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+         MODEL_TENSOR.FFN_GATE_INP,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+         MODEL_TENSOR.FFN_GATE_EXP,
+         MODEL_TENSOR.FFN_DOWN_EXP,
+         MODEL_TENSOR.FFN_UP_EXP,
+     ],
+     MODEL_ARCH.GROK: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+         MODEL_TENSOR.ATTN_OUT_NORM,
+         MODEL_TENSOR.FFN_GATE_INP,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+         MODEL_TENSOR.FFN_GATE_EXP,
+         MODEL_TENSOR.FFN_DOWN_EXP,
+         MODEL_TENSOR.FFN_UP_EXP,
+         MODEL_TENSOR.LAYER_OUT_NORM,
+     ],
+     MODEL_ARCH.GPTNEOX: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_QKV,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.FALCON: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_NORM_2,
+         MODEL_TENSOR.ATTN_QKV,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.BAICHUAN: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.STARCODER: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.POS_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_QKV,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.BERT: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.TOKEN_EMBD_NORM,
+         MODEL_TENSOR.TOKEN_TYPES,
+         MODEL_TENSOR.POS_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.ATTN_OUT_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+         MODEL_TENSOR.LAYER_OUT_NORM,
+     ],
+     MODEL_ARCH.NOMIC_BERT: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.TOKEN_EMBD_NORM,
+         MODEL_TENSOR.TOKEN_TYPES,
+         MODEL_TENSOR.POS_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.ATTN_OUT_NORM,
+         MODEL_TENSOR.ATTN_QKV,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+         MODEL_TENSOR.LAYER_OUT_NORM,
+     ],
+     MODEL_ARCH.JINA_BERT_V2: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.TOKEN_EMBD_NORM,
+         MODEL_TENSOR.TOKEN_TYPES,
+         MODEL_TENSOR.ATTN_OUT_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_Q_NORM,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_K_NORM,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_UP,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.LAYER_OUT_NORM,
+     ],
+     MODEL_ARCH.MPT: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_QKV,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+         MODEL_TENSOR.FFN_ACT,
+         MODEL_TENSOR.ATTN_Q_NORM,
+         MODEL_TENSOR.ATTN_K_NORM,
+         MODEL_TENSOR.POS_EMBD,
+     ],
+     MODEL_ARCH.GPTJ: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.REFACT: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.BLOOM: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.TOKEN_EMBD_NORM,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_QKV,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.STABLELM: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+         MODEL_TENSOR.ATTN_Q_NORM,
+         MODEL_TENSOR.ATTN_K_NORM,
+     ],
+     MODEL_ARCH.QWEN: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_QKV,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.QWEN2: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.QWEN2MOE: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_GATE_INP,
+         MODEL_TENSOR.FFN_GATE_EXP,
+         MODEL_TENSOR.FFN_DOWN_EXP,
+         MODEL_TENSOR.FFN_UP_EXP,
+         MODEL_TENSOR.FFN_GATE_INP_SHEXP,
+         MODEL_TENSOR.FFN_GATE_SHEXP,
+         MODEL_TENSOR.FFN_DOWN_SHEXP,
+         MODEL_TENSOR.FFN_UP_SHEXP,
+     ],
+     MODEL_ARCH.PLAMO: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.GPT2: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.POS_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_QKV,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.PHI2: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_QKV,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.PHI3: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_QKV,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.CODESHELL: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.POS_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_QKV,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.ORION: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.INTERNLM2: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.MINICPM: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+         MODEL_TENSOR.FFN_GATE_INP,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+         MODEL_TENSOR.FFN_GATE_EXP,
+         MODEL_TENSOR.FFN_DOWN_EXP,
+         MODEL_TENSOR.FFN_UP_EXP,
+     ],
+     MODEL_ARCH.GEMMA: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+         MODEL_TENSOR.FFN_NORM,
+     ],
+     MODEL_ARCH.STARCODER2: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.MAMBA: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.SSM_IN,
+         MODEL_TENSOR.SSM_CONV1D,
+         MODEL_TENSOR.SSM_X,
+         MODEL_TENSOR.SSM_DT,
+         MODEL_TENSOR.SSM_A,
+         MODEL_TENSOR.SSM_D,
+         MODEL_TENSOR.SSM_OUT,
+     ],
+     MODEL_ARCH.XVERSE: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.COMMAND_R: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+         MODEL_TENSOR.ATTN_K_NORM,
+         MODEL_TENSOR.ATTN_Q_NORM,
+     ],
+     MODEL_ARCH.DBRX: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_QKV,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.ATTN_OUT_NORM,
+         MODEL_TENSOR.FFN_GATE_INP,
+         MODEL_TENSOR.FFN_GATE_EXP,
+         MODEL_TENSOR.FFN_DOWN_EXP,
+         MODEL_TENSOR.FFN_UP_EXP,
+     ],
+     MODEL_ARCH.OLMO: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+     ],
+     MODEL_ARCH.ARCTIC: [
+         MODEL_TENSOR.TOKEN_EMBD,
+         MODEL_TENSOR.OUTPUT_NORM,
+         MODEL_TENSOR.OUTPUT,
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_NORM,
+         MODEL_TENSOR.ATTN_Q,
+         MODEL_TENSOR.ATTN_K,
+         MODEL_TENSOR.ATTN_V,
+         MODEL_TENSOR.ATTN_OUT,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+         MODEL_TENSOR.FFN_GATE_INP,
+         MODEL_TENSOR.FFN_NORM,
+         MODEL_TENSOR.FFN_GATE,
+         MODEL_TENSOR.FFN_DOWN,
+         MODEL_TENSOR.FFN_UP,
+         MODEL_TENSOR.FFN_NORM_EXP,
+         MODEL_TENSOR.FFN_GATE_EXP,
+         MODEL_TENSOR.FFN_DOWN_EXP,
+         MODEL_TENSOR.FFN_UP_EXP,
+     ],
+     # TODO
+ }
+
+ # tensors that will not be serialized
+ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
+     MODEL_ARCH.LLAMA: [
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+     ],
+     MODEL_ARCH.BAICHUAN: [
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+     ],
+     MODEL_ARCH.QWEN: [
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+     ],
+     MODEL_ARCH.CODESHELL: [
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+     ],
+     MODEL_ARCH.ORION: [
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+     ],
+     MODEL_ARCH.STARCODER2: [
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+     ],
+     MODEL_ARCH.XVERSE: [
+         MODEL_TENSOR.ROPE_FREQS,
+         MODEL_TENSOR.ATTN_ROT_EMBD,
+     ],
+ }
+
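Together, MODEL_TENSORS, TENSOR_NAMES and MODEL_TENSOR_SKIP determine which tensor names an architecture is expected to serialize. A rough sketch of enumerating them for one architecture (illustrative only; it assumes the gguf package shipped under bigdl/cpp/gguf-py is on the import path as gguf):

# Illustrative sketch, not part of the package. Assumes `gguf` is importable.
from gguf.constants import MODEL_ARCH, MODEL_TENSORS, MODEL_TENSOR_SKIP, TENSOR_NAMES

arch = MODEL_ARCH.LLAMA
skipped = set(MODEL_TENSOR_SKIP.get(arch, []))
expected = [TENSOR_NAMES[t] for t in MODEL_TENSORS[arch] if t not in skipped]
# e.g. ["token_embd", "output_norm", "output", "blk.{bid}.attn_norm", ...]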
+ #
+ # types
+ #
+
+
+ class TokenType(IntEnum):
+     NORMAL = 1
+     UNKNOWN = 2
+     CONTROL = 3
+     USER_DEFINED = 4
+     UNUSED = 5
+     BYTE = 6
+
+
+ class RopeScalingType(Enum):
+     NONE = 'none'
+     LINEAR = 'linear'
+     YARN = 'yarn'
+
+
+ class PoolingType(IntEnum):
+     NONE = 0
+     MEAN = 1
+     CLS = 2
+
+
+ class GGMLQuantizationType(IntEnum):
+     F32 = 0
+     F16 = 1
+     Q4_0 = 2
+     Q4_1 = 3
+     Q5_0 = 6
+     Q5_1 = 7
+     Q8_0 = 8
+     Q8_1 = 9
+     Q2_K = 10
+     Q3_K = 11
+     Q4_K = 12
+     Q5_K = 13
+     Q6_K = 14
+     Q8_K = 15
+     IQ2_XXS = 16
+     IQ2_XS = 17
+     IQ3_XXS = 18
+     IQ1_S = 19
+     IQ4_NL = 20
+     IQ3_S = 21
+     IQ2_S = 22
+     IQ4_XS = 23
+     I8 = 24
+     I16 = 25
+     I32 = 26
+     I64 = 27
+     F64 = 28
+     IQ1_M = 29
+     BF16 = 30
+
+
+ # TODO: add GGMLFileType from ggml_ftype in ggml.h
+
+
+ # from llama_ftype in llama.h
+ # ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE.
+ class LlamaFileType(IntEnum):
+     ALL_F32 = 0
+     MOSTLY_F16 = 1 # except 1d tensors
+     MOSTLY_Q4_0 = 2 # except 1d tensors
+     MOSTLY_Q4_1 = 3 # except 1d tensors
+     MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16
+     # MOSTLY_Q4_2 = 5 # support has been removed
+     # MOSTLY_Q4_3 = 6 # support has been removed
+     MOSTLY_Q8_0 = 7 # except 1d tensors
+     MOSTLY_Q5_0 = 8 # except 1d tensors
+     MOSTLY_Q5_1 = 9 # except 1d tensors
+     MOSTLY_Q2_K = 10 # except 1d tensors
+     MOSTLY_Q3_K_S = 11 # except 1d tensors
+     MOSTLY_Q3_K_M = 12 # except 1d tensors
+     MOSTLY_Q3_K_L = 13 # except 1d tensors
+     MOSTLY_Q4_K_S = 14 # except 1d tensors
+     MOSTLY_Q4_K_M = 15 # except 1d tensors
+     MOSTLY_Q5_K_S = 16 # except 1d tensors
+     MOSTLY_Q5_K_M = 17 # except 1d tensors
+     MOSTLY_Q6_K = 18 # except 1d tensors
+     MOSTLY_IQ2_XXS = 19 # except 1d tensors
+     MOSTLY_IQ2_XS = 20 # except 1d tensors
+     MOSTLY_Q2_K_S = 21 # except 1d tensors
+     MOSTLY_IQ3_XS = 22 # except 1d tensors
+     MOSTLY_IQ3_XXS = 23 # except 1d tensors
+     MOSTLY_IQ1_S = 24 # except 1d tensors
+     MOSTLY_IQ4_NL = 25 # except 1d tensors
+     MOSTLY_IQ3_S = 26 # except 1d tensors
+     MOSTLY_IQ3_M = 27 # except 1d tensors
+     MOSTLY_IQ2_S = 28 # except 1d tensors
+     MOSTLY_IQ2_M = 29 # except 1d tensors
+     MOSTLY_IQ4_XS = 30 # except 1d tensors
+     MOSTLY_IQ1_M = 31 # except 1d tensors
+     MOSTLY_BF16 = 32 # except 1d tensors
+
+     GUESSED = 1024 # not specified in the model file
+
+
+ class GGUFEndian(IntEnum):
+     LITTLE = 0
+     BIG = 1
+
+
+ class GGUFValueType(IntEnum):
+     UINT8 = 0
+     INT8 = 1
+     UINT16 = 2
+     INT16 = 3
+     UINT32 = 4
+     INT32 = 5
+     FLOAT32 = 6
+     BOOL = 7
+     STRING = 8
+     ARRAY = 9
+     UINT64 = 10
+     INT64 = 11
+     FLOAT64 = 12
+
+     @staticmethod
+     def get_type(val: Any) -> GGUFValueType:
+         if isinstance(val, (str, bytes, bytearray)):
+             return GGUFValueType.STRING
+         elif isinstance(val, list):
+             return GGUFValueType.ARRAY
+         elif isinstance(val, float):
+             return GGUFValueType.FLOAT32
+         elif isinstance(val, bool):
+             return GGUFValueType.BOOL
+         elif isinstance(val, int):
+             return GGUFValueType.INT32
+         # TODO: need help with 64-bit types in Python
+         else:
+             raise ValueError(f"Unknown type: {type(val)}")
+
+
+ # Items here are (block size, type size)
+ QK_K = 256
+ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
+     GGMLQuantizationType.F32: (1, 4),
+     GGMLQuantizationType.F16: (1, 2),
+     GGMLQuantizationType.Q4_0: (32, 2 + 16),
+     GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
+     GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
+     GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),
+     GGMLQuantizationType.Q8_0: (32, 2 + 32),
+     GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),
+     GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4),
+     GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12),
+     GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12),
+     GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12),
+     GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16),
+     GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8),
+     GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4),
+     GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32),
+     GGMLQuantizationType.IQ3_XXS: (256, 2 + QK_K // 4 + QK_K // 8),
+     GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16),
+     GGMLQuantizationType.IQ4_NL: (32, 2 + 16),
+     GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4),
+     GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16),
+     GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64),
+     GGMLQuantizationType.I8: (1, 1),
+     GGMLQuantizationType.I16: (1, 2),
+     GGMLQuantizationType.I32: (1, 4),
+     GGMLQuantizationType.I64: (1, 8),
+     GGMLQuantizationType.F64: (1, 8),
+     GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
+     GGMLQuantizationType.BF16: (1, 2),
+ }
+
+
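The (block size, type size) pairs above are enough to compute how many bytes a quantized tensor occupies; a minimal worked example (illustrative only, not part of the packaged file):

# Illustrative sketch: a tensor of n elements in a block-quantized type takes
# (n // block_size) * type_size bytes (n must be a multiple of the block size).
block_size, type_size = 32, 2 + 16       # the Q4_0 entry above
n_elements = 4096 * 4096
n_bytes = n_elements // block_size * type_size
assert n_bytes == 9437184                # exactly 9 MiB for a 4096x4096 Q4_0 tensor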
+ # Aliases for backward compatibility.
+
+ # general
+ KEY_GENERAL_ARCHITECTURE = Keys.General.ARCHITECTURE
+ KEY_GENERAL_QUANTIZATION_VERSION = Keys.General.QUANTIZATION_VERSION
+ KEY_GENERAL_ALIGNMENT = Keys.General.ALIGNMENT
+ KEY_GENERAL_NAME = Keys.General.NAME
+ KEY_GENERAL_AUTHOR = Keys.General.AUTHOR
+ KEY_GENERAL_URL = Keys.General.URL
+ KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION
+ KEY_GENERAL_LICENSE = Keys.General.LICENSE
+ KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL
+ KEY_GENERAL_SOURCE_HF_REPO = Keys.General.SOURCE_HF_REPO
+ KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE
+
+ # LLM
+ KEY_VOCAB_SIZE = Keys.LLM.VOCAB_SIZE
+ KEY_CONTEXT_LENGTH = Keys.LLM.CONTEXT_LENGTH
+ KEY_EMBEDDING_LENGTH = Keys.LLM.EMBEDDING_LENGTH
+ KEY_BLOCK_COUNT = Keys.LLM.BLOCK_COUNT
+ KEY_FEED_FORWARD_LENGTH = Keys.LLM.FEED_FORWARD_LENGTH
+ KEY_USE_PARALLEL_RESIDUAL = Keys.LLM.USE_PARALLEL_RESIDUAL
+ KEY_TENSOR_DATA_LAYOUT = Keys.LLM.TENSOR_DATA_LAYOUT
+
+ # attention
+ KEY_ATTENTION_HEAD_COUNT = Keys.Attention.HEAD_COUNT
+ KEY_ATTENTION_HEAD_COUNT_KV = Keys.Attention.HEAD_COUNT_KV
+ KEY_ATTENTION_MAX_ALIBI_BIAS = Keys.Attention.MAX_ALIBI_BIAS
+ KEY_ATTENTION_CLAMP_KQV = Keys.Attention.CLAMP_KQV
+ KEY_ATTENTION_LAYERNORM_EPS = Keys.Attention.LAYERNORM_EPS
+ KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS
+
+ # RoPE
+ KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT
+ KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE
+ KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE
+ KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR
+ KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
+ KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED
+
+ # SSM
+ KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
+ KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
+ KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
+ KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
+
+ # tokenization
+ KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
+ KEY_TOKENIZER_PRE = Keys.Tokenizer.PRE
+ KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST
+ KEY_TOKENIZER_TOKEN_TYPE = Keys.Tokenizer.TOKEN_TYPE
+ KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES
+ KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES
+ KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID
+ KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID
+ KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID
+ KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID
+ KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID
+ KEY_TOKENIZER_CLS_ID = Keys.Tokenizer.CLS_ID
+ KEY_TOKENIZER_MASK_ID = Keys.Tokenizer.MASK_ID
+ KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON
+ KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV
+ KEY_TOKENIZER_PRIFIX_ID = Keys.Tokenizer.PREFIX_ID
+ KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID
+ KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID
+ KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID
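For context (not part of the diff above): GGUF_MAGIC and GGUF_VERSION correspond to the first eight bytes of a little-endian GGUF file, so a quick header check can be written against these constants. A hedged sketch, assuming a hypothetical model.gguf path:

# Illustrative sketch: read and verify the GGUF magic and version.
import struct

with open("model.gguf", "rb") as f:      # hypothetical path
    magic, version = struct.unpack("<II", f.read(8))
assert magic == 0x46554747               # GGUF_MAGIC, the bytes b"GGUF"
assert version == 3                      # GGUF_VERSION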