xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (194)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +15 -34
  3. xinference/client/restful/restful_client.py +2 -2
  4. xinference/core/chat_interface.py +45 -10
  5. xinference/core/image_interface.py +9 -0
  6. xinference/core/model.py +8 -5
  7. xinference/core/scheduler.py +1 -2
  8. xinference/core/worker.py +49 -42
  9. xinference/deploy/cmdline.py +2 -2
  10. xinference/deploy/test/test_cmdline.py +7 -7
  11. xinference/model/audio/chattts.py +24 -9
  12. xinference/model/audio/core.py +8 -2
  13. xinference/model/audio/fish_speech.py +228 -0
  14. xinference/model/audio/model_spec.json +8 -0
  15. xinference/model/embedding/core.py +23 -1
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +49 -1
  19. xinference/model/llm/__init__.py +26 -27
  20. xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
  21. xinference/model/llm/llm_family.json +606 -1266
  22. xinference/model/llm/llm_family.py +16 -139
  23. xinference/model/llm/llm_family_modelscope.json +276 -313
  24. xinference/model/llm/lmdeploy/__init__.py +0 -0
  25. xinference/model/llm/lmdeploy/core.py +557 -0
  26. xinference/model/llm/memory.py +9 -9
  27. xinference/model/llm/sglang/core.py +2 -2
  28. xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
  29. xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
  30. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  31. xinference/model/llm/{pytorch → transformers}/core.py +3 -10
  32. xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
  33. xinference/model/llm/transformers/intern_vl.py +540 -0
  34. xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
  35. xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
  36. xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
  37. xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
  38. xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
  39. xinference/model/llm/utils.py +85 -70
  40. xinference/model/llm/vllm/core.py +110 -11
  41. xinference/model/utils.py +1 -95
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/internvl/__init__.py +0 -0
  137. xinference/thirdparty/internvl/conversation.py +393 -0
  138. xinference/thirdparty/omnilmm/model/utils.py +16 -1
  139. xinference/web/ui/build/asset-manifest.json +3 -3
  140. xinference/web/ui/build/index.html +1 -1
  141. xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
  142. xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  144. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
  145. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
  146. xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
  147. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
  148. xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
  149. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
  150. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
  151. xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
  153. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
  154. xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
  155. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
  156. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
  157. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
  158. xinference/locale/utils.py +0 -39
  159. xinference/locale/zh_CN.json +0 -26
  160. xinference/model/llm/ggml/tools/__init__.py +0 -15
  161. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
  162. xinference/model/llm/ggml/tools/gguf.py +0 -884
  163. xinference/model/llm/pytorch/__init__.py +0 -13
  164. xinference/model/llm/pytorch/baichuan.py +0 -81
  165. xinference/model/llm/pytorch/falcon.py +0 -138
  166. xinference/model/llm/pytorch/intern_vl.py +0 -352
  167. xinference/model/llm/pytorch/vicuna.py +0 -69
  168. xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
  169. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
  170. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  171. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
  172. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
  173. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
  174. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
  175. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
  176. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
  177. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
  178. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
  179. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
  180. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
  181. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
  182. /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
  183. /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
  184. /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
  185. /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
  186. /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
  187. /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
  188. /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
  189. /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
  190. /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  191. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
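Reading the list above: besides the new bundled fish_speech and internvl third-party trees, this release reorganizes xinference's internal LLM backends — xinference/model/llm/pytorch/* moves to xinference/model/llm/transformers/* (entries 28-38 and 184-189), and xinference/model/llm/ggml/llamacpp.py becomes xinference/model/llm/llama_cpp/core.py (entry 20). For anyone who imported these internals directly, a hedged sketch of a version-tolerant import follows; PytorchModel is an assumed example symbol used only for illustration, while the module paths are taken from the file list itself:

    # Hedged sketch, not part of this diff: bridging the internal module rename.
    # `PytorchModel` is an assumed symbol for illustration; only the module
    # paths come from the file list above.
    try:
        # layout in 0.14.3: xinference/model/llm/transformers/
        from xinference.model.llm.transformers.core import PytorchModel
    except ImportError:
        # layout in 0.14.1.post1 and earlier: xinference/model/llm/pytorch/
        from xinference.model.llm.pytorch.core import PytorchModel

The largest single deletion is the bundled GGML→GGUF converter, shown below.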
xinference/model/llm/ggml/tools/gguf.py
@@ -1,884 +0,0 @@
- #!/usr/bin/env python3
- # copied from llama.cpp
- from __future__ import annotations
-
- import json
- import os
- import shutil
- import struct
- import sys
- import tempfile
- from enum import IntEnum, auto
- from io import BufferedWriter
- from pathlib import Path
- from typing import Any, BinaryIO, Callable, Sequence
-
- import numpy as np
-
- #
- # constants
- #
-
- GGUF_MAGIC = 0x46554747
- GGUF_VERSION = 2
- GGUF_DEFAULT_ALIGNMENT = 32
-
- # general
- KEY_GENERAL_ARCHITECTURE = "general.architecture"
- KEY_GENERAL_QUANTIZATION_VERSION = "general.quantization_version"
- KEY_GENERAL_ALIGNMENT = "general.alignment"
- KEY_GENERAL_NAME = "general.name"
- KEY_GENERAL_AUTHOR = "general.author"
- KEY_GENERAL_URL = "general.url"
- KEY_GENERAL_DESCRIPTION = "general.description"
- KEY_GENERAL_LICENSE = "general.license"
- KEY_GENERAL_SOURCE_URL = "general.source.url"
- KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository"
- KEY_GENERAL_FILE_TYPE = "general.file_type"
-
- # LLM
- KEY_CONTEXT_LENGTH = "{arch}.context_length"
- KEY_EMBEDDING_LENGTH = "{arch}.embedding_length"
- KEY_BLOCK_COUNT = "{arch}.block_count"
- KEY_FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
- KEY_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
- KEY_TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
-
- # attention
- KEY_ATTENTION_HEAD_COUNT = "{arch}.attention.head_count"
- KEY_ATTENTION_HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
- KEY_ATTENTION_MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
- KEY_ATTENTION_CLAMP_KQV = "{arch}.attention.clamp_kqv"
- KEY_ATTENTION_LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
- KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
-
- # RoPE
- KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count"
- KEY_ROPE_FREQ_BASE = "{arch}.rope.freq_base"
- KEY_ROPE_SCALE_LINEAR = "{arch}.rope.scale_linear"
-
- # tokenization
- KEY_TOKENIZER_MODEL = "tokenizer.ggml.model"
- KEY_TOKENIZER_LIST = "tokenizer.ggml.tokens"
- KEY_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type"
- KEY_TOKENIZER_SCORES = "tokenizer.ggml.scores"
- KEY_TOKENIZER_MERGES = "tokenizer.ggml.merges"
- KEY_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id"
- KEY_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id"
- KEY_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id"
- KEY_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id"
- KEY_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id"
- KEY_TOKENIZER_HF_JSON = "tokenizer.huggingface.json"
- KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world"
-
-
- #
- # recommended mapping of model tensor names for storage in gguf
- #
-
-
- class MODEL_ARCH(IntEnum):
-     LLAMA: int = auto()
-     FALCON: int = auto()
-     GPT2: int = auto()
-     GPTJ: int = auto()
-     GPTNEOX: int = auto()
-     MPT: int = auto()
-
-
- class MODEL_TENSOR(IntEnum):
-     TOKEN_EMBD: int = auto()
-     POS_EMBD: int = auto()
-     OUTPUT: int = auto()
-     OUTPUT_NORM: int = auto()
-     ROPE_FREQS: int = auto()
-     ATTN_Q: int = auto()
-     ATTN_K: int = auto()
-     ATTN_V: int = auto()
-     ATTN_QKV: int = auto()
-     ATTN_OUT: int = auto()
-     ATTN_NORM: int = auto()
-     ATTN_NORM_2: int = auto()
-     ATTN_ROT_EMBD: int = auto()
-     FFN_GATE: int = auto()
-     FFN_DOWN: int = auto()
-     FFN_UP: int = auto()
-     FFN_NORM: int = auto()
-
-
- MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
-     MODEL_ARCH.LLAMA: "llama",
-     MODEL_ARCH.FALCON: "falcon",
-     MODEL_ARCH.GPT2: "gpt2",
-     MODEL_ARCH.GPTJ: "gptj",
-     MODEL_ARCH.GPTNEOX: "gptneox",
-     MODEL_ARCH.MPT: "mpt",
- }
-
- MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
-     MODEL_ARCH.LLAMA: {
-         MODEL_TENSOR.TOKEN_EMBD: "token_embd",
-         MODEL_TENSOR.OUTPUT_NORM: "output_norm",
-         MODEL_TENSOR.OUTPUT: "output",
-         MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
-         MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
-         MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
-         MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
-         MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
-         MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
-         MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
-         MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
-         MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
-         MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
-         MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
-     },
-     MODEL_ARCH.GPTNEOX: {
-         MODEL_TENSOR.TOKEN_EMBD: "token_embd",
-         MODEL_TENSOR.OUTPUT_NORM: "output_norm",
-         MODEL_TENSOR.OUTPUT: "output",
-         MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
-         MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
-         MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
-         MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
-         MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
-         MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
-     },
-     MODEL_ARCH.FALCON: {
-         MODEL_TENSOR.TOKEN_EMBD: "token_embd",
-         MODEL_TENSOR.OUTPUT_NORM: "output_norm",
-         MODEL_TENSOR.OUTPUT: "output",
-         MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
-         MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
-         MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
-         MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
-         MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
-         MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
-     },
-     MODEL_ARCH.GPT2: {
-         # TODO
-     },
-     # TODO
- }
-
- # tensors that will not be serialized
- MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
-     MODEL_ARCH.LLAMA: [
-         MODEL_TENSOR.ROPE_FREQS,
-         MODEL_TENSOR.ATTN_ROT_EMBD,
-     ],
- }
-
-
- class TensorNameMap:
-     mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
-         # Token embeddings
-         MODEL_TENSOR.TOKEN_EMBD: (
-             "gpt_neox.embed_in", # gptneox
-             "transformer.wte", # gpt2 mpt
-             "transformer.word_embeddings", # falcon
-             "model.embed_tokens", # llama-hf
-             "tok_embeddings", # llama-pth
-         ),
-         # Position embeddings
-         MODEL_TENSOR.POS_EMBD: ("transformer.wpe",), # gpt2
-         # Output
-         MODEL_TENSOR.OUTPUT: (
-             "embed_out", # gptneox
-             "lm_head", # gpt2 mpt falcon llama-hf
-             "output", # llama-pth
-         ),
-         # Output norm
-         MODEL_TENSOR.OUTPUT_NORM: (
-             "gpt_neox.final_layer_norm", # gptneox
-             "transformer.ln_f", # gpt2 falcon
-             "model.norm", # llama-hf
-             "norm", # llama-pth
-         ),
-         # Rope frequencies
-         MODEL_TENSOR.ROPE_FREQS: ("rope.freqs",), # llama-pth
-     }
-
-     block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
-         # Attention norm
-         MODEL_TENSOR.ATTN_NORM: (
-             "gpt_neox.layers.{bid}.input_layernorm", # gptneox
-             "transformer.h.{bid}.ln_1", # gpt2
-             "transformer.blocks.{bid}.norm_1", # mpt
-             "transformer.h.{bid}.input_layernorm", # falcon7b
-             "transformer.h.{bid}.ln_mlp", # falcon40b
-             "model.layers.{bid}.input_layernorm", # llama-hf
-             "layers.{bid}.attention_norm", # llama-pth
-         ),
-         # Attention norm 2
-         MODEL_TENSOR.ATTN_NORM_2: ("transformer.h.{bid}.ln_attn",), # falcon40b
-         # Attention query-key-value
-         MODEL_TENSOR.ATTN_QKV: (
-             "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
-             "transformer.h.{bid}.attn.c_attn", # gpt2
-             "transformer.blocks.{bid}.attn.Wqkv", # mpt
-             "transformer.h.{bid}.self_attention.query_key_value", # falcon
-         ),
-         # Attention query
-         MODEL_TENSOR.ATTN_Q: (
-             "model.layers.{bid}.self_attn.q_proj", # llama-hf
-             "layers.{bid}.attention.wq", # llama-pth
-         ),
-         # Attention key
-         MODEL_TENSOR.ATTN_K: (
-             "model.layers.{bid}.self_attn.k_proj", # llama-hf
-             "layers.{bid}.attention.wk", # llama-pth
-         ),
-         # Attention value
-         MODEL_TENSOR.ATTN_V: (
-             "model.layers.{bid}.self_attn.v_proj", # llama-hf
-             "layers.{bid}.attention.wv", # llama-pth
-         ),
-         # Attention output
-         MODEL_TENSOR.ATTN_OUT: (
-             "gpt_neox.layers.{bid}.attention.dense", # gptneox
-             "transformer.h.{bid}.attn.c_proj", # gpt2
-             "transformer.blocks.{bid}.attn.out_proj", # mpt
-             "transformer.h.{bid}.self_attention.dense", # falcon
-             "model.layers.{bid}.self_attn.o_proj", # llama-hf
-             "layers.{bid}.attention.wo", # llama-pth
-         ),
-         # Rotary embeddings
-         MODEL_TENSOR.ATTN_ROT_EMBD: (
-             "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
-             "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
-         ),
-         # Feed-forward norm
-         MODEL_TENSOR.FFN_NORM: (
-             "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
-             "transformer.h.{bid}.ln_2", # gpt2
-             "transformer.blocks.{bid}.norm_2", # mpt
-             "model.layers.{bid}.post_attention_layernorm", # llama-hf
-             "layers.{bid}.ffn_norm", # llama-pth
-         ),
-         # Feed-forward up
-         MODEL_TENSOR.FFN_UP: (
-             "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
-             "transformer.h.{bid}.mlp.c_fc", # gpt2
-             "transformer.blocks.{bid}.ffn.up_proj", # mpt
-             "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
-             "model.layers.{bid}.mlp.up_proj", # llama-hf
-             "layers.{bid}.feed_forward.w3", # llama-pth
-         ),
-         # Feed-forward gate
-         MODEL_TENSOR.FFN_GATE: (
-             "model.layers.{bid}.mlp.gate_proj", # llama-hf
-             "layers.{bid}.feed_forward.w1", # llama-pth
-         ),
-         # Feed-forward down
-         MODEL_TENSOR.FFN_DOWN: (
-             "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
-             "transformer.h.{bid}.mlp.c_proj", # gpt2
-             "transformer.blocks.{bid}.ffn.down_proj", # mpt
-             "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
-             "model.layers.{bid}.mlp.down_proj", # llama-hf
-             "layers.{bid}.feed_forward.w2", # llama-pth
-         ),
-     }
-
-     mapping: dict[str, tuple[MODEL_TENSOR, str]]
-
-     tensor_names: dict[MODEL_TENSOR, str]
-
-     def __init__(self, arch: MODEL_ARCH, n_blocks: int):
-         mapping = self.mapping = {}
-         tensor_names = self.tensor_names = MODEL_TENSOR_NAMES[arch]
-         for tensor, keys in self.mappings_cfg.items():
-             tensor_name = tensor_names.get(tensor)
-             if tensor_name is None:
-                 continue
-             for key in keys:
-                 mapping[key] = (tensor, tensor_name)
-         for bid in range(n_blocks):
-             for tensor, keys in self.block_mappings_cfg.items():
-                 tensor_name = tensor_names.get(tensor)
-                 if tensor_name is None:
-                     continue
-                 tensor_name = tensor_name.format(bid=bid)
-                 for key in keys:
-                     key = key.format(bid=bid)
-                     mapping[key] = (tensor, tensor_name)
-
-     def get_type_and_name(
-         self, key: str, try_suffixes: Sequence[str]
-     ) -> tuple[MODEL_TENSOR, str] | None:
-         result = self.mapping.get(key)
-         if result is not None:
-             return result
-         for suffix in try_suffixes:
-             if key.endswith(suffix):
-                 result = self.mapping.get(key[: -len(suffix)])
-                 if result is not None:
-                     return (result[0], result[1] + suffix)
-         return None
-
-     def get_name(self, key: str, try_suffixes: Sequence[str]) -> str | None:
-         result = self.get_type_and_name(key, try_suffixes=try_suffixes)
-         if result is None:
-             return None
-         return result[1]
-
-     def get_type(self, key: str, try_suffixes: Sequence[str]) -> MODEL_TENSOR | None:
-         result = self.get_type_and_name(key, try_suffixes=try_suffixes)
-         if result is None:
-             return None
-         return result[0]
-
-     def __getitem__(self, key: str) -> str:
-         try:
-             return self.mapping[key][1]
-         except KeyError:
-             raise KeyError(key)
-
-     def __contains__(self, key: str) -> bool:
-         return key in self.mapping
-
-     def __repr__(self) -> str:
-         return repr(self.mapping)
-
-
- def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> TensorNameMap:
-     return TensorNameMap(arch, n_blocks)
-
-
- class TokenType(IntEnum):
-     NORMAL = 1
-     UNKNOWN = 2
-     CONTROL = 3
-     USER_DEFINED = 4
-     UNUSED = 5
-     BYTE = 6
-
-
- #
- # implementation
- #
-
-
- class GGMLQuantizationType(IntEnum):
-     F32 = 0
-     F16 = 1
-     Q4_0 = 2
-     Q4_1 = 3
-     Q5_0 = 6
-     Q5_1 = 7
-     Q8_0 = 8
-     Q8_1 = 9
-     Q2_K = 10
-     Q3_K = 11
-     Q4_K = 12
-     Q5_K = 13
-     Q6_K = 14
-     Q8_K = 15
-
-
- class GGUFValueType(IntEnum):
-     UINT8 = 0
-     INT8 = 1
-     UINT16 = 2
-     INT16 = 3
-     UINT32 = 4
-     INT32 = 5
-     FLOAT32 = 6
-     BOOL = 7
-     STRING = 8
-     ARRAY = 9
-     UINT64 = 10
-     INT64 = 11
-     FLOAT64 = 12
-
-     @staticmethod
-     def get_type(val):
-         if isinstance(val, str) or isinstance(val, bytes) or isinstance(val, bytearray):
-             return GGUFValueType.STRING
-         elif isinstance(val, list):
-             return GGUFValueType.ARRAY
-         elif isinstance(val, float):
-             return GGUFValueType.FLOAT32
-         elif isinstance(val, bool):
-             return GGUFValueType.BOOL
-         elif isinstance(val, int):
-             return GGUFValueType.INT32
-         # TODO: need help with 64-bit types in Python
-         else:
-             print("Unknown type: " + str(type(val)))
-             sys.exit()
-
-
- class GGUFWriter:
-     fout: BufferedWriter
-     arch: str
-     offset_tensor = 0
-     data_alignment = GGUF_DEFAULT_ALIGNMENT
-     kv_data = b""
-     kv_data_count = 0
-     ti_data = b""
-     ti_data_count = 0
-     use_temp_file: bool
-     temp_file: tempfile.SpooledTemporaryFile[bytes] | None = None
-     tensors: list[tuple[np.ndarray[Any, Any], int]]
-
-     def __init__(self, path: os.PathLike[str] | str, arch: str, use_temp_file=True):
-         self.fout = open(path, "wb")
-         self.arch = arch
-         self.add_architecture()
-         self.use_temp_file = use_temp_file
-         self.tensors = []
-
-     def write_header_to_file(self):
-         self.fout.write(struct.pack("<I", GGUF_MAGIC))
-         self.fout.write(struct.pack("<I", GGUF_VERSION))
-         self.fout.write(struct.pack("<Q", self.ti_data_count))
-         self.fout.write(struct.pack("<Q", self.kv_data_count))
-         self.flush()
-
-     # print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count))
-
-     def write_kv_data_to_file(self):
-         self.fout.write(self.kv_data)
-         self.flush()
-
-     def write_ti_data_to_file(self):
-         self.fout.write(self.ti_data)
-         self.flush()
-
-     def add_key(self, key: str):
-         self.add_val(key, GGUFValueType.STRING, add_vtype=False)
-
-     def add_uint8(self, key: str, val: int):
-         self.add_key(key)
-         self.add_val(val, GGUFValueType.UINT8)
-
-     def add_int8(self, key: str, val: int):
-         self.add_key(key)
-         self.add_val(val, GGUFValueType.INT8)
-
-     def add_uint16(self, key: str, val: int):
-         self.add_key(key)
-         self.add_val(val, GGUFValueType.UINT16)
-
-     def add_int16(self, key: str, val: int):
-         self.add_key(key)
-         self.add_val(val, GGUFValueType.INT16)
-
-     def add_uint32(self, key: str, val: int):
-         self.add_key(key)
-         self.add_val(val, GGUFValueType.UINT32)
-
-     def add_int32(self, key: str, val: int):
-         self.add_key(key)
-         self.add_val(val, GGUFValueType.INT32)
-
-     def add_float32(self, key: str, val: float):
-         self.add_key(key)
-         self.add_val(val, GGUFValueType.FLOAT32)
-
-     def add_uint64(self, key: str, val: int):
-         self.add_key(key)
-         self.add_val(val, GGUFValueType.UINT64)
-
-     def add_int64(self, key: str, val: int):
-         self.add_key(key)
-         self.add_val(val, GGUFValueType.INT64)
-
-     def add_float64(self, key: str, val: float):
-         self.add_key(key)
-         self.add_val(val, GGUFValueType.FLOAT64)
-
-     def add_bool(self, key: str, val: bool):
-         self.add_key(key)
-         self.add_val(val, GGUFValueType.BOOL)
-
-     def add_string(self, key: str, val: str):
-         if len(val) == 0:
-             return
-         self.add_key(key)
-         self.add_val(val, GGUFValueType.STRING)
-
-     def add_array(self, key: str, val: Sequence[Any]):
-         if not isinstance(val, Sequence):
-             raise ValueError("Value must be a sequence for array type")
-
-         self.add_key(key)
-         self.add_val(val, GGUFValueType.ARRAY)
-
-     _simple_value_packing = {
-         GGUFValueType.UINT8: "<B",
-         GGUFValueType.INT8: "<b",
-         GGUFValueType.UINT16: "<H",
-         GGUFValueType.INT16: "<h",
-         GGUFValueType.UINT32: "<I",
-         GGUFValueType.INT32: "<i",
-         GGUFValueType.FLOAT32: "<f",
-         GGUFValueType.UINT64: "<Q",
-         GGUFValueType.INT64: "<q",
-         GGUFValueType.FLOAT64: "<d",
-         GGUFValueType.BOOL: "?",
-     }
-
-     def add_val(
-         self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True
-     ):
-         if vtype is None:
-             vtype = GGUFValueType.get_type(val)
-
-         if add_vtype:
-             self.kv_data += struct.pack("<I", vtype)
-             self.kv_data_count += 1
-
-         pack_fmt = self._simple_value_packing.get(vtype)
-         if pack_fmt is not None:
-             self.kv_data += struct.pack(pack_fmt, val)
-         elif vtype == GGUFValueType.STRING:
-             encoded_val = val.encode("utf8") if isinstance(val, str) else val
-             self.kv_data += struct.pack("<Q", len(encoded_val))
-             self.kv_data += encoded_val
-         elif (
-             vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and len(val) > 0
-         ):
-             ltype = GGUFValueType.get_type(val[0])
-             if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
-                 raise ValueError("All items in a GGUF array should be of the same type")
-             self.kv_data += struct.pack("<I", ltype)
-             self.kv_data += struct.pack("<Q", len(val))
-             for item in val:
-                 self.add_val(item, add_vtype=False)
-         else:
-             raise ValueError("Invalid GGUF metadata value type or value")
-
-     @staticmethod
-     def ggml_pad(x: int, n: int) -> int:
-         return ((x + n - 1) // n) * n
-
-     def add_tensor_info(
-         self,
-         name: str,
-         tensor_shape: Sequence[int],
-         tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32],
-         tensor_nbytes: int,
-         raw_dtype: GGMLQuantizationType | None = None,
-     ):
-         assert raw_dtype is not None or tensor_dtype in (
-             np.float32,
-             np.float16,
-         ), "Only F32 and F16 tensors are supported for now"
-
-         encoded_name = name.encode("utf8")
-         self.ti_data += struct.pack("<Q", len(encoded_name))
-         self.ti_data += encoded_name
-         n_dims = len(tensor_shape)
-         self.ti_data += struct.pack("<I", n_dims)
-         for i in range(n_dims):
-             self.ti_data += struct.pack("<Q", tensor_shape[n_dims - 1 - i])
-         if raw_dtype is None:
-             dtype = (
-                 GGMLQuantizationType.F32
-                 if tensor_dtype == np.float32
-                 else GGMLQuantizationType.F16
-             )
-         else:
-             dtype = raw_dtype
-         self.ti_data += struct.pack("<I", dtype)
-         self.ti_data += struct.pack("<Q", self.offset_tensor)
-         self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
-         self.ti_data_count += 1
-
-     def add_tensor(
-         self,
-         name: str,
-         tensor: np.ndarray[Any, Any],
-         raw_shape: Sequence[int] | None = None,
-         raw_dtype: GGMLQuantizationType | None = None,
-     ):
-         if self.use_temp_file and self.temp_file is None:
-             fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024)
-             fp.seek(0)
-             self.temp_file = fp
-
-         shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
-         self.add_tensor_info(
-             name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype
-         )
-
-         pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
-
-         if self.temp_file is None:
-             self.tensors.append((tensor, pad))
-             return
-
-         tensor.tofile(self.temp_file)
-
-         if pad != 0:
-             self.temp_file.write(bytes([0] * pad))
-
-     def write_padding(self, fp: BinaryIO, n: int, align: int | None = None):
-         pad = (
-             GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment)
-             - n
-         )
-         if pad != 0:
-             fp.write(bytes([0] * pad))
-
-     def write_tensor_data(self, tensor: np.ndarray[Any, Any]):
-         self.write_padding(self.fout, self.fout.tell())
-         tensor.tofile(self.fout)
-         self.write_padding(self.fout, tensor.nbytes)
-
-     def write_tensors_to_file(self):
-         self.write_ti_data_to_file()
-
-         self.write_padding(self.fout, self.fout.tell())
-
-         if self.temp_file is None:
-             for currtensor, currpad in self.tensors:
-                 currtensor.tofile(self.fout)
-                 if currpad != 0:
-                     self.fout.write(bytes([0] * currpad))
-             return
-
-         self.temp_file.seek(0)
-
-         shutil.copyfileobj(self.temp_file, self.fout)
-         self.flush()
-         self.temp_file.close()
-
-     def flush(self):
-         self.fout.flush()
-
-     def close(self):
-         self.fout.close()
-
-     def add_architecture(self):
-         self.add_string(KEY_GENERAL_ARCHITECTURE, self.arch)
-
-     def add_author(self, author: str):
-         self.add_string(KEY_GENERAL_AUTHOR, author)
-
-     def add_tensor_data_layout(self, layout: str):
-         self.add_string(KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
-
-     def add_url(self, url: str):
-         self.add_string(KEY_GENERAL_URL, url)
-
-     def add_description(self, description: str):
-         self.add_string(KEY_GENERAL_DESCRIPTION, description)
-
-     def add_source_url(self, url: str):
-         self.add_string(KEY_GENERAL_SOURCE_URL, url)
-
-     def add_source_hf_repo(self, repo: str):
-         self.add_string(KEY_GENERAL_SOURCE_HF_REPO, repo)
-
-     def add_file_type(self, ftype: int):
-         self.add_uint32(KEY_GENERAL_FILE_TYPE, ftype)
-
-     def add_name(self, name: str):
-         self.add_string(KEY_GENERAL_NAME, name)
-
-     def add_quantization_version(self, quantization_version: GGMLQuantizationType):
-         self.add_uint32(KEY_GENERAL_QUANTIZATION_VERSION, quantization_version)
-
-     def add_custom_alignment(self, alignment: int):
-         self.data_alignment = alignment
-         self.add_uint32(KEY_GENERAL_ALIGNMENT, alignment)
-
-     def add_context_length(self, length: int):
-         self.add_uint32(KEY_CONTEXT_LENGTH.format(arch=self.arch), length)
-
-     def add_embedding_length(self, length: int):
-         self.add_uint32(KEY_EMBEDDING_LENGTH.format(arch=self.arch), length)
-
-     def add_block_count(self, length: int):
-         self.add_uint32(KEY_BLOCK_COUNT.format(arch=self.arch), length)
-
-     def add_feed_forward_length(self, length: int):
-         self.add_uint32(KEY_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
-
-     def add_parallel_residual(self, use: bool):
-         self.add_bool(KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
-
-     def add_head_count(self, count: int):
-         self.add_uint32(KEY_ATTENTION_HEAD_COUNT.format(arch=self.arch), count)
-
-     def add_head_count_kv(self, count: int):
-         self.add_uint32(KEY_ATTENTION_HEAD_COUNT_KV.format(arch=self.arch), count)
-
-     def add_max_alibi_bias(self, bias: float):
-         self.add_float32(KEY_ATTENTION_MAX_ALIBI_BIAS.format(arch=self.arch), bias)
-
-     def add_clamp_kqv(self, value: float):
-         self.add_float32(KEY_ATTENTION_CLAMP_KQV.format(arch=self.arch), value)
-
-     def add_layer_norm_eps(self, value: float):
-         self.add_float32(KEY_ATTENTION_LAYERNORM_EPS.format(arch=self.arch), value)
-
-     def add_layer_norm_rms_eps(self, value: float):
-         self.add_float32(KEY_ATTENTION_LAYERNORM_RMS_EPS.format(arch=self.arch), value)
-
-     def add_rope_dimension_count(self, count: int):
-         self.add_uint32(KEY_ROPE_DIMENSION_COUNT.format(arch=self.arch), count)
-
-     def add_rope_freq_base(self, value: float):
-         self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)
-
-     def add_rope_scale_linear(self, value: float):
-         self.add_float32(KEY_ROPE_SCALE_LINEAR.format(arch=self.arch), value)
-
-     def add_tokenizer_model(self, model: str):
-         self.add_string(KEY_TOKENIZER_MODEL, model)
-
-     def add_token_list(
-         self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]
-     ):
-         self.add_array(KEY_TOKENIZER_LIST, tokens)
-
-     def add_token_merges(
-         self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]
-     ):
-         self.add_array(KEY_TOKENIZER_MERGES, merges)
-
-     def add_token_types(self, types: Sequence[TokenType] | Sequence[int]):
-         self.add_array(KEY_TOKENIZER_TOKEN_TYPE, types)
-
-     def add_token_scores(self, scores: Sequence[float]):
-         self.add_array(KEY_TOKENIZER_SCORES, scores)
-
-     def add_bos_token_id(self, id: int):
-         self.add_uint32(KEY_TOKENIZER_BOS_ID, id)
-
-     def add_eos_token_id(self, id: int):
-         self.add_uint32(KEY_TOKENIZER_EOS_ID, id)
-
-     def add_unk_token_id(self, id: int):
-         self.add_uint32(KEY_TOKENIZER_UNK_ID, id)
-
-     def add_sep_token_id(self, id: int):
-         self.add_uint32(KEY_TOKENIZER_SEP_ID, id)
-
-     def add_pad_token_id(self, id: int):
-         self.add_uint32(KEY_TOKENIZER_PAD_ID, id)
-
-
- class SpecialVocab:
-     load_merges: bool = False
-     merges: list[str] = []
-     special_token_types: tuple[str, ...] = ("bos", "eos", "unk", "sep", "pad")
-     special_token_ids: dict[str, int] = {}
-
-     def __init__(
-         self,
-         path: Path,
-         load_merges: bool = False,
-         special_token_types: tuple[str, ...] | None = None,
-     ):
-         self.special_token_ids = {}
-         self.load_merges = load_merges
-         if special_token_types is not None:
-             self.special_token_types = special_token_types
-         self.load(path)
-
-     def load(self, path: Path):
-         if not self.try_load_from_tokenizer_json(path):
-             self.try_load_from_config_json(path)
-
-     def try_load_from_tokenizer_json(self, path: Path) -> bool:
-         tokenizer_file = path / "tokenizer.json"
-         if not tokenizer_file.is_file():
-             return False
-         with open(tokenizer_file, "r", encoding="utf-8") as f:
-             tokenizer = json.load(f)
-         if self.load_merges:
-             merges = tokenizer.get("model", {}).get("merges")
-             if (
-                 isinstance(merges, list)
-                 and len(merges) > 0
-                 and isinstance(merges[0], str)
-             ):
-                 self.merges = merges
-         tokenizer_config_file = path / "tokenizer_config.json"
-         added_tokens = tokenizer.get("added_tokens")
-         if added_tokens is None or not tokenizer_config_file.is_file():
-             return True
-         with open(tokenizer_config_file, "r", encoding="utf-8") as f:
-             tokenizer_config = json.load(f)
-         for typ in self.special_token_types:
-             entry = tokenizer_config.get(f"{typ}_token")
-             if isinstance(entry, str):
-                 tc_content = entry
-             elif isinstance(entry, dict):
-                 entry_content = entry.get("content")
-                 if not isinstance(entry_content, str):
-                     continue
-                 tc_content = entry_content
-             else:
-                 continue
-             for maybe_token_id in (
-                 atok.get("id")
-                 for atok in added_tokens
-                 if atok.get("content") == tc_content
-             ):
-                 if isinstance(maybe_token_id, int) and maybe_token_id >= 0:
-                     self.special_token_ids[typ] = maybe_token_id
-                 break
-         return True
-
-     def try_load_from_config_json(self, path: Path) -> bool:
-         config_file = path / "config.json"
-         if not config_file.is_file():
-             return False
-         with open(config_file, "r", encoding="utf-8") as f:
-             config = json.load(f)
-         for typ in self.special_token_types:
-             maybe_token_id = config.get(f"{typ}_token_id")
-             if isinstance(maybe_token_id, int) and maybe_token_id >= 0:
-                 self.special_token_ids[typ] = maybe_token_id
-         return True
-
-     def add_to_gguf(self, gw: GGUFWriter):
-         if len(self.merges) > 0:
-             print(f"gguf: Adding {len(self.merges)} merge(s).")
-             gw.add_token_merges(self.merges)
-         for typ, tokid in self.special_token_ids.items():
-             handler: Callable[[int], None] | None = getattr(
-                 gw, f"add_{typ}_token_id", None
-             )
-             if handler is None:
-                 print(
-                     f"gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping"
-                 )
-                 continue
-             print(f"gguf: Setting special token type {typ} to {tokid}")
-             handler(tokid)
-
-     def __repr__(self):
-         return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids if self.special_token_ids else "unset"}>'
-
-
- # Example usage:
- if __name__ == "__main__":
-     # Example usage with a file
-     gguf_writer = GGUFWriter("example.gguf", "llama")
-
-     gguf_writer.add_architecture()
-     gguf_writer.add_block_count(12)
-     gguf_writer.add_uint32("answer", 42) # Write a 32-bit integer
-     gguf_writer.add_float32("answer_in_float", 42.0) # Write a 32-bit float
-     gguf_writer.add_custom_alignment(64)
-
-     tensor1 = np.ones((32,), dtype=np.float32) * 100.0
-     tensor2 = np.ones((64,), dtype=np.float32) * 101.0
-     tensor3 = np.ones((96,), dtype=np.float32) * 102.0
-
-     gguf_writer.add_tensor("tensor1", tensor1)
-     gguf_writer.add_tensor("tensor2", tensor2)
-     gguf_writer.add_tensor("tensor3", tensor3)
-
-     gguf_writer.write_header_to_file()
-     gguf_writer.write_kv_data_to_file()
-     gguf_writer.write_tensors_to_file()
-
-     gguf_writer.close()
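
The removed converter above also documents the GGUF v2 file header: write_header_to_file emits four little-endian fields — a uint32 magic (0x46554747, the bytes b"GGUF" read as a little-endian uint32), a uint32 version, a uint64 tensor-info count, and a uint64 key/value count. A minimal reader sketch for just that fixed 24-byte header, written here for illustration and not part of either package version:

    import struct

    def read_gguf_header(path):
        # Mirrors GGUFWriter.write_header_to_file: "<I" magic, "<I" version,
        # then "<Q" tensor-info count and "<Q" key/value count, little-endian.
        with open(path, "rb") as f:
            magic, version = struct.unpack("<II", f.read(8))
            n_tensors, n_kv = struct.unpack("<QQ", f.read(16))
        if magic != 0x46554747:  # b"GGUF" as a little-endian uint32
            raise ValueError(f"not a GGUF file: magic=0x{magic:08X}")
        return {"version": version, "n_tensors": n_tensors, "n_kv": n_kv}

Running the module's own __main__ demo and then read_gguf_header("example.gguf") should report version 2, matching GGUF_VERSION. Note the call order in the demo — all add_* calls happen before write_header_to_file — because the counts written into the header are only final once every key/value pair and tensor has been registered.

TensorNameMap, similarly, resolves checkpoint tensor names to GGUF names; a usage sketch derived from the tables above:

    tmap = get_tensor_name_map(MODEL_ARCH.LLAMA, n_blocks=2)
    # "model.layers.0.self_attn.q_proj.weight" -> "blk.0.attn_q.weight"
    name = tmap.get_name(
        "model.layers.0.self_attn.q_proj.weight", try_suffixes=(".weight", ".bias")
    )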