xinference 0.14.2__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (137)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +48 -41
  6. xinference/model/audio/chattts.py +24 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/fish_speech.py +228 -0
  9. xinference/model/audio/model_spec.json +8 -0
  10. xinference/model/embedding/core.py +23 -1
  11. xinference/model/image/model_spec.json +2 -1
  12. xinference/model/image/model_spec_modelscope.json +2 -1
  13. xinference/model/image/stable_diffusion/core.py +49 -1
  14. xinference/model/llm/__init__.py +6 -0
  15. xinference/model/llm/llm_family.json +54 -9
  16. xinference/model/llm/llm_family.py +2 -0
  17. xinference/model/llm/llm_family_modelscope.json +56 -10
  18. xinference/model/llm/lmdeploy/__init__.py +0 -0
  19. xinference/model/llm/lmdeploy/core.py +557 -0
  20. xinference/model/llm/transformers/cogvlm2.py +4 -45
  21. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  22. xinference/model/llm/transformers/core.py +1 -0
  23. xinference/model/llm/transformers/glm4v.py +2 -23
  24. xinference/model/llm/transformers/intern_vl.py +94 -11
  25. xinference/model/llm/transformers/minicpmv25.py +2 -23
  26. xinference/model/llm/transformers/minicpmv26.py +2 -22
  27. xinference/model/llm/transformers/yi_vl.py +2 -24
  28. xinference/model/llm/utils.py +10 -1
  29. xinference/model/llm/vllm/core.py +1 -1
  30. xinference/thirdparty/fish_speech/__init__.py +0 -0
  31. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  32. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  33. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  34. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  35. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  36. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  37. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  38. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  39. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  40. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  41. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  42. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  43. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  44. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  46. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  48. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  49. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  50. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  51. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  52. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  53. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  54. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  55. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  56. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  57. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  58. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  59. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  60. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  61. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  62. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  63. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  64. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  67. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  68. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  69. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  70. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  71. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  72. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  73. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  74. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  75. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  76. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  77. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  78. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  79. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  83. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  84. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  85. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  86. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  87. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  88. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  89. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  90. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  91. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  92. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  93. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  94. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  95. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  96. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  97. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  98. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  99. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  100. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  101. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  102. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  103. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  104. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  105. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  106. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  107. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  108. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  109. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  110. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  111. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  112. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  113. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  114. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  115. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  116. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  117. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  118. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  119. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  120. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  121. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  122. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  123. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  124. xinference/web/ui/build/asset-manifest.json +3 -3
  125. xinference/web/ui/build/index.html +1 -1
  126. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  127. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  128. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  129. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/METADATA +18 -6
  130. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/RECORD +135 -37
  131. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  132. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  133. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  134. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  135. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  136. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  137. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/tools/llama/merge_lora.py
@@ -0,0 +1,95 @@
+ import shutil
+ from copy import deepcopy
+ from pathlib import Path
+
+ import click
+ import hydra
+ import torch
+ from hydra import compose, initialize
+ from hydra.utils import instantiate
+ from loguru import logger
+
+ from fish_speech.models.text2semantic.llama import BaseTransformer
+ from fish_speech.models.text2semantic.lora import get_merged_state_dict
+
+
+ @click.command()
+ @click.option("--lora-config", type=str, default="r_8_alpha_16")
+ @click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.2-sft")
+ @click.option("--lora-weight", type=str, required=True)
+ @click.option("--output", type=str, required=True)
+ def merge(lora_config, base_weight, lora_weight, output):
+     output = Path(output)
+     logger.info(
+         f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
+     )
+
+     with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
+         cfg = compose(config_name=lora_config)
+
+     lora_config = instantiate(cfg)
+     logger.info(f"Loaded lora model with config {lora_config}")
+
+     llama_model = BaseTransformer.from_pretrained(
+         path=base_weight,
+         load_weights=True,
+         lora_config=lora_config,
+     )
+     logger.info(f"Loaded llama model")
+
+     llama_state_dict = llama_model.state_dict()
+     llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
+     llama_state_dict_copy = deepcopy(llama_state_dict)
+     lora_state_dict = torch.load(lora_weight, map_location="cpu")
+
+     if "state_dict" in llama_state_dict:
+         llama_state_dict = llama_state_dict["state_dict"]
+
+     if "state_dict" in lora_state_dict:
+         lora_state_dict = lora_state_dict["state_dict"]
+
+     # remove prefix model.
+     if any(k.startswith("model.") for k in llama_state_dict.keys()):
+         llama_state_dict = {
+             k.replace("model.", ""): v
+             for k, v in llama_state_dict.items()
+             if k.startswith("model.")
+         }
+     if any(k.startswith("model.") for k in lora_state_dict.keys()):
+         lora_state_dict = {
+             k.replace("model.", ""): v
+             for k, v in lora_state_dict.items()
+             if k.startswith("model.")
+         }
+
+     logger.info(f"Found {len(llama_state_dict)} keys in llama model")
+     logger.info(f"Found {len(lora_state_dict)} keys in lora model")
+
+     merged_state_dict = llama_state_dict | lora_state_dict
+     llama_model.load_state_dict(merged_state_dict, strict=True)
+     logger.info(f"Merged model loaded")
+
+     # Trigger eval mode to merge lora
+     llama_model.eval()
+     llama_model.save_pretrained(output, drop_lora=True)
+     logger.info(f"Saved merged model to {output}, validating")
+
+     new_state_dict = torch.load(output / "model.pth", map_location="cpu")
+     original_keys = set(llama_state_dict_copy.keys())
+     merged_keys = set(new_state_dict.keys())
+
+     assert original_keys == merged_keys, "Keys should be same"
+
+     for key in original_keys:
+         diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
+         if diff_l1 != 0:
+             break
+     else:
+         logger.error("Merged model is same as the original model")
+         exit(1)
+
+     logger.info("Merged model is different from the original model, check passed")
+
+
+ if __name__ == "__main__":
+     merge()
xinference/thirdparty/fish_speech/tools/llama/quantize.py
@@ -0,0 +1,497 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ import datetime
+ import shutil
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ import time
+ from pathlib import Path
+
+ import click
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from fish_speech.models.text2semantic.llama import find_multiple
+ from tools.llama.generate import load_model
+
+ ##### Quantization Primitives ######
+
+
+ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
+     # assumes symmetric quantization
+     # assumes axis == 0
+     # assumes dense memory format
+     # TODO(future): relax ^ as needed
+
+     # default setup for affine quantization of activations
+     eps = torch.finfo(torch.float32).eps
+
+     # get min and max
+     min_val, max_val = torch.aminmax(x, dim=1)
+
+     # calculate scales and zero_points based on min and max
+     # reference: https://fburl.com/code/srbiybme
+     min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+     max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+     device = min_val_neg.device
+
+     # reference: https://fburl.com/code/4wll53rk
+     max_val_pos = torch.max(-min_val_neg, max_val_pos)
+     scales = max_val_pos / (float(quant_max - quant_min) / 2)
+     # ensure scales is the same dtype as the original tensor
+     scales = torch.clamp(scales, min=eps).to(x.dtype)
+     zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
+
+     # quantize based on qmin/qmax/scales/zp
+     # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
+     x_div = x / scales.unsqueeze(-1)
+     x_round = torch.round(x_div)
+     x_zp = x_round + zero_points.unsqueeze(-1)
+     quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
+
+     return quant, scales, zero_points
+
+
+ def get_group_qparams(w, n_bit=4, groupsize=128):
+     # needed for GPTQ with padding
+     if groupsize > w.shape[-1]:
+         groupsize = w.shape[-1]
+     assert groupsize > 1
+     assert w.shape[-1] % groupsize == 0
+     assert w.dim() == 2
+
+     to_quant = w.reshape(-1, groupsize)
+     assert torch.isnan(to_quant).sum() == 0
+
+     max_val = to_quant.amax(dim=1, keepdim=True)
+     min_val = to_quant.amin(dim=1, keepdim=True)
+     max_int = 2**n_bit - 1
+     scales = (max_val - min_val).clamp(min=1e-6) / max_int
+     zeros = min_val + scales * (2 ** (n_bit - 1))
+     return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
+         torch.bfloat16
+     ).reshape(w.shape[0], -1)
+
+
+ def pack_scales_and_zeros(scales, zeros):
+     assert scales.shape == zeros.shape
+     assert scales.dtype == torch.bfloat16
+     assert zeros.dtype == torch.bfloat16
+     return (
+         torch.cat(
+             [
+                 scales.reshape(scales.size(0), scales.size(1), 1),
+                 zeros.reshape(zeros.size(0), zeros.size(1), 1),
+             ],
+             2,
+         )
+         .transpose(0, 1)
+         .contiguous()
+     )
+
+
+ def unpack_scales_and_zeros(scales_and_zeros):
+     assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
+     assert scales_and_zeros.dtype == torch.float
+     return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
+
+
+ def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
+     assert groupsize > 1
+     # needed for GPTQ single column quantize
+     if groupsize > w.shape[-1] and scales.shape[-1] == 1:
+         groupsize = w.shape[-1]
+
+     assert w.shape[-1] % groupsize == 0
+     assert w.dim() == 2
+
+     to_quant = w.reshape(-1, groupsize)
+     assert torch.isnan(to_quant).sum() == 0
+
+     scales = scales.reshape(-1, 1)
+     zeros = zeros.reshape(-1, 1)
+     min_val = zeros - scales * (2 ** (n_bit - 1))
+     max_int = 2**n_bit - 1
+     min_int = 0
+     w_int32 = (
+         to_quant.sub(min_val)
+         .div(scales)
+         .round()
+         .clamp_(min_int, max_int)
+         .to(torch.int32)
+         .reshape_as(w)
+     )
+
+     return w_int32
+
+
+ def group_quantize_tensor(w, n_bit=4, groupsize=128):
+     scales, zeros = get_group_qparams(w, n_bit, groupsize)
+     w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
+     scales_and_zeros = pack_scales_and_zeros(scales, zeros)
+     return w_int32, scales_and_zeros
+
+
+ def group_dequantize_tensor_from_qparams(
+     w_int32, scales, zeros, n_bit=4, groupsize=128
+ ):
+     assert groupsize > 1
+     # needed for GPTQ single column dequantize
+     if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
+         groupsize = w_int32.shape[-1]
+     assert w_int32.shape[-1] % groupsize == 0
+     assert w_int32.dim() == 2
+
+     w_int32_grouped = w_int32.reshape(-1, groupsize)
+     scales = scales.reshape(-1, 1)
+     zeros = zeros.reshape(-1, 1)
+
+     w_dq = (
+         w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
+     )
+     return w_dq
+
+
+ def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
+     scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
+     return group_dequantize_tensor_from_qparams(
+         w_int32, scales, zeros, n_bit, groupsize
+     )
+
+
+ class QuantHandler:
+     def __init__(self, mod):
+         self.mod = mod
+
+     def create_quantized_state_dict(self) -> "StateDict":
+         pass
+
+     def convert_for_runtime(self) -> "nn.Module":
+         pass
+
+
+ ##### Weight-only int8 per-channel quantized code ######
+
+
+ def replace_linear_weight_only_int8_per_channel(module):
+     for name, child in module.named_children():
+         if isinstance(child, nn.Linear):
+             setattr(
+                 module,
+                 name,
+                 WeightOnlyInt8Linear(child.in_features, child.out_features),
+             )
+         else:
+             replace_linear_weight_only_int8_per_channel(child)
+
+
+ class WeightOnlyInt8QuantHandler:
+     def __init__(self, mod):
+         self.mod = mod
+
+     @torch.no_grad()
+     def create_quantized_state_dict(self):
+         cur_state_dict = self.mod.state_dict()
+         for fqn, mod in self.mod.named_modules():
+             if isinstance(mod, torch.nn.Linear):
+                 int8_weight, scales, _ = dynamically_quantize_per_channel(
+                     mod.weight.float(), -128, 127, torch.int8
+                 )
+                 cur_state_dict[f"{fqn}.weight"] = int8_weight
+                 cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
+
+         return cur_state_dict
+
+     def convert_for_runtime(self):
+         replace_linear_weight_only_int8_per_channel(self.mod)
+         return self.mod
+
+
+ class WeightOnlyInt8Linear(torch.nn.Module):
+     __constants__ = ["in_features", "out_features"]
+     in_features: int
+     out_features: int
+     weight: torch.Tensor
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = True,
+         device=None,
+         dtype=None,
+     ) -> None:
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.register_buffer(
+             "weight", torch.empty((out_features, in_features), dtype=torch.int8)
+         )
+         self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+
+
+ ##### weight only int4 per channel groupwise quantized code ######
+
+
+ def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
+     weight_int32, scales_and_zeros = group_quantize_tensor(
+         weight_bf16, n_bit=4, groupsize=groupsize
+     )
+     weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+         weight_int32, inner_k_tiles
+     )
+     return weight_int4pack, scales_and_zeros
+
+
+ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
+     origin_x_size = x.size()
+     x = x.reshape(-1, origin_x_size[-1])
+     c = torch.ops.aten._weight_int4pack_mm(
+         x, weight_int4pack, groupsize, scales_and_zeros
+     )
+     new_shape = origin_x_size[:-1] + (out_features,)
+     c = c.reshape(new_shape)
+     return c
+
+
+ def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
+     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
+
+
+ def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
+     for name, child in module.named_children():
+         if isinstance(child, nn.Linear):
+             if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+                 setattr(
+                     module,
+                     name,
+                     WeightOnlyInt4Linear(
+                         child.in_features,
+                         child.out_features,
+                         bias=False,
+                         groupsize=groupsize,
+                         inner_k_tiles=inner_k_tiles,
+                         padding=False,
+                     ),
+                 )
+             elif padding:
+                 setattr(
+                     module,
+                     name,
+                     WeightOnlyInt4Linear(
+                         child.in_features,
+                         child.out_features,
+                         bias=False,
+                         groupsize=groupsize,
+                         inner_k_tiles=inner_k_tiles,
+                         padding=True,
+                     ),
+                 )
+         else:
+             replace_linear_int4(child, groupsize, inner_k_tiles, padding)
+
+
+ class WeightOnlyInt4QuantHandler:
+     def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+         self.mod = mod
+         self.groupsize = groupsize
+         self.inner_k_tiles = inner_k_tiles
+         self.padding = padding
+         assert groupsize in [32, 64, 128, 256]
+         assert inner_k_tiles in [2, 4, 8]
+
+     @torch.no_grad()
+     def create_quantized_state_dict(self):
+         cur_state_dict = self.mod.state_dict()
+         for fqn, mod in self.mod.named_modules():
+             if isinstance(mod, torch.nn.Linear):
+                 assert not mod.bias
+                 out_features = mod.out_features
+                 in_features = mod.in_features
+                 assert out_features % 8 == 0, "require out_features % 8 == 0"
+                 print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                 weight = mod.weight.data
+                 if not _check_linear_int4_k(
+                     in_features, self.groupsize, self.inner_k_tiles
+                 ):
+                     if self.padding:
+                         import torch.nn.functional as F
+
+                         print(
+                             f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
+                         )
+                         padded_in_features = find_multiple(in_features, 1024)
+                         weight = F.pad(
+                             weight, pad=(0, padded_in_features - in_features)
+                         )
+                     else:
+                         print(
+                             f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+                             + "and that groupsize and inner_k_tiles*16 evenly divide into it"
+                         )
+                         continue
+                 (
+                     weight_int4pack,
+                     scales_and_zeros,
+                 ) = prepare_int4_weight_and_scales_and_zeros(
+                     weight.to(torch.bfloat16).to("cuda"),
+                     self.groupsize,
+                     self.inner_k_tiles,
+                 )
+                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
+                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
+
+         return cur_state_dict
+
+     def convert_for_runtime(self):
+         replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
+         return self.mod
+
+
+ class WeightOnlyInt4Linear(torch.nn.Module):
+     __constants__ = ["in_features", "out_features"]
+     in_features: int
+     out_features: int
+     weight: torch.Tensor
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias=True,
+         device=None,
+         dtype=None,
+         groupsize: int = 128,
+         inner_k_tiles: int = 8,
+         padding: bool = True,
+     ) -> None:
+         super().__init__()
+         self.padding = padding
+         if padding:
+             self.origin_in_features = in_features
+             in_features = find_multiple(in_features, 1024)
+
+         self.in_features = in_features
+         self.out_features = out_features
+         assert not bias, "require bias=False"
+         self.groupsize = groupsize
+         self.inner_k_tiles = inner_k_tiles
+
+         assert out_features % 8 == 0, "require out_features % 8 == 0"
+         assert (
+             in_features % (inner_k_tiles * 16) == 0
+         ), "require in_features % (innerKTiles * 16) == 0"
+         self.register_buffer(
+             "weight",
+             torch.empty(
+                 (
+                     out_features // 8,
+                     in_features // (inner_k_tiles * 16),
+                     32,
+                     inner_k_tiles // 2,
+                 ),
+                 dtype=torch.int32,
+             ),
+         )
+         self.register_buffer(
+             "scales_and_zeros",
+             torch.empty(
+                 (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
+             ),
+         )
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         input = input.to(torch.bfloat16)
+         if self.padding:
+             import torch.nn.functional as F
+
+             input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+         return linear_forward_int4(
+             input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
+         )
+
+
+ def generate_folder_name():
+     now = datetime.datetime.now()
+     folder_name = now.strftime("%Y%m%d_%H%M%S")
+     return folder_name
+
+
+ @click.command()
+ @click.option(
+     "--checkpoint-path",
+     type=click.Path(path_type=Path, exists=True),
+     default="checkpoints/fish-speech-1.2-sft",
+ )
+ @click.option(
+     "--mode", type=str, default="int8", help="type of quantization to perform"
+ )
+ @click.option(
+     "--groupsize", type=int, default=128, help="Group size for int4 quantization."
+ )
+ @click.option("--timestamp", type=str, default="None", help="When to do quantization")
+ def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
+
+     device = "cpu"
+     precision = torch.bfloat16
+
+     print("Loading model ...")
+     t0 = time.time()
+
+     model, _ = load_model(
+         checkpoint_path=checkpoint_path,
+         device=device,
+         precision=precision,
+         compile=False,
+     )
+     vq_model = "firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
+     now = timestamp if timestamp != "None" else generate_folder_name()
+
+     if mode == "int8":
+         print(
+             "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
+         )
+         quant_handler = WeightOnlyInt8QuantHandler(model)
+         quantized_state_dict = quant_handler.create_quantized_state_dict()
+
+         dir_name = checkpoint_path
+         dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
+         shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
+         if (dst_name / vq_model).exists():
+             (dst_name / vq_model).unlink()
+         quantize_path = dst_name / "model.pth"
+
+     elif mode == "int4":
+         print(
+             "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
+         )
+         quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
+         quantized_state_dict = quant_handler.create_quantized_state_dict()
+
+         dir_name = checkpoint_path
+         dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")
+         shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
+         if (dst_name / vq_model).exists():
+             (dst_name / vq_model).unlink()
+         quantize_path = dst_name / "model.pth"
+
+     else:
+         raise ValueError(
+             f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]"
+         )
+
+     print(f"Writing quantized weights to {quantize_path}")
+     quantize_path.unlink(missing_ok=True) # remove existing file if one already there
+     torch.save(quantized_state_dict, quantize_path)
+     print(f"Quantization complete took {time.time() - t0:.02f} seconds")
+
+
+ if __name__ == "__main__":
+     quantize()
xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py
@@ -0,0 +1,57 @@
+ from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+
+ # Initialize a tokenizer
+ tokenizer = Tokenizer(models.BPE())
+
+ # Customize pre-tokenization and decoding
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
+ tokenizer.decoder = decoders.ByteLevel()
+ tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
+
+ # Don't train the tokenizer
+ trainer = trainers.BpeTrainer(
+     vocab_size=0,
+     min_frequency=2,
+     initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
+     special_tokens=[
+         "<|begin_of_sequence|>",
+         "<|end_of_sequence|>",
+         "<|im_start|>",
+         "<|im_sep|>",  # system, user, assistant, etc.
+         "<|im_end|>",
+         "<|semantic|>",  # audio features
+         "<|pad|>",
+     ],
+ )
+
+ # <|im_start|>user<|im_sep|>...<|im_end|>
+ # <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
+ tokenizer.train_from_iterator([], trainer=trainer)
+
+ print(len(tokenizer.get_vocab()))
+ x = tokenizer.encode(
+     "Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
+ ).ids
+ print(x, len(x))
+ print(tokenizer.decode(x, skip_special_tokens=True))
+
+
+ tokenizer = PreTrainedTokenizerFast(
+     tokenizer_object=tokenizer,
+     pad_token="<|pad|>",
+     bos_token="<|begin_of_sequence|>",
+     eos_token="<|end_of_sequence|>",
+ )
+
+ # Try tokenizing a new sequence
+ sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
+ encoded = tokenizer(sequence).input_ids
+
+ print("Test encoding....")
+ print(f"\tSentence: {sequence}")
+ print(f"\tEncoded: {encoded}")
+ print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
+ print(f"\tDecoded: {tokenizer.decode(encoded)}")
+
+ tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
xinference/thirdparty/fish_speech/tools/merge_asr_files.py
@@ -0,0 +1,55 @@
+ import os
+ from pathlib import Path
+
+ from pydub import AudioSegment
+ from tqdm import tqdm
+
+ from tools.file import AUDIO_EXTENSIONS, list_files
+
+
+ def merge_and_delete_files(save_dir, original_files):
+     save_path = Path(save_dir)
+     audio_slice_files = list_files(
+         path=save_dir, extensions=AUDIO_EXTENSIONS.union([".lab"]), recursive=True
+     )
+     audio_files = {}
+     label_files = {}
+     for file_path in tqdm(audio_slice_files, desc="Merging audio files"):
+         rel_path = Path(file_path).relative_to(save_path)
+         (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
+         if file_path.suffix == ".wav":
+             prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
+             if prefix == rel_path.parent / file_path.stem:
+                 continue
+             audio = AudioSegment.from_wav(file_path)
+             if prefix in audio_files.keys():
+                 audio_files[prefix] = audio_files[prefix] + audio
+             else:
+                 audio_files[prefix] = audio
+
+         elif file_path.suffix == ".lab":
+             prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
+             if prefix == rel_path.parent / file_path.stem:
+                 continue
+             with open(file_path, "r", encoding="utf-8") as f:
+                 label = f.read()
+             if prefix in label_files.keys():
+                 label_files[prefix] = label_files[prefix] + ", " + label
+             else:
+                 label_files[prefix] = label
+
+     for prefix, audio in audio_files.items():
+         output_audio_path = save_path / f"{prefix}.wav"
+         audio.export(output_audio_path, format="wav")
+
+     for prefix, label in label_files.items():
+         output_label_path = save_path / f"{prefix}.lab"
+         with open(output_label_path, "w", encoding="utf-8") as f:
+             f.write(label)
+
+     for file_path in original_files:
+         os.remove(file_path)
+
+
+ if __name__ == "__main__":
+     merge_and_delete_files("/made/by/spicysama/laziman", [__file__])