xinference 0.14.2__py3-none-any.whl → 0.14.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/core/chat_interface.py +1 -1
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +4 -1
- xinference/core/worker.py +48 -41
- xinference/model/audio/chattts.py +24 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +23 -1
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +49 -1
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +54 -9
- xinference/model/llm/llm_family.py +2 -0
- xinference/model/llm/llm_family_modelscope.json +56 -10
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/transformers/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +94 -11
- xinference/model/llm/transformers/minicpmv25.py +2 -23
- xinference/model/llm/transformers/minicpmv26.py +2 -22
- xinference/model/llm/transformers/yi_vl.py +2 -24
- xinference/model/llm/utils.py +10 -1
- xinference/model/llm/vllm/core.py +1 -1
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
- xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/METADATA +18 -6
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/RECORD +135 -37
- xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
import shutil
|
|
2
|
+
from copy import deepcopy
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import click
|
|
6
|
+
import hydra
|
|
7
|
+
import torch
|
|
8
|
+
from hydra import compose, initialize
|
|
9
|
+
from hydra.utils import instantiate
|
|
10
|
+
from loguru import logger
|
|
11
|
+
|
|
12
|
+
from fish_speech.models.text2semantic.llama import BaseTransformer
|
|
13
|
+
from fish_speech.models.text2semantic.lora import get_merged_state_dict
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _strip_model_prefix(state_dict):
    """Return ``state_dict`` keyed without a leading ``model.`` prefix.

    If no key starts with ``model.`` the dict is returned unchanged. Otherwise
    only the prefixed entries are kept, and exactly the leading prefix is
    removed (later occurrences of ``model.`` inside a key are preserved).
    """
    prefix = "model."
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    return {
        key[len(prefix) :]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


@click.command()
@click.option("--lora-config", type=str, default="r_8_alpha_16")
@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.2-sft")
@click.option("--lora-weight", type=str, required=True)
@click.option("--output", type=str, required=True)
def merge(lora_config, base_weight, lora_weight, output):
    """Merge a LoRA checkpoint into a base model and save the result.

    Args:
        lora_config: Name of the hydra config (under the lora config
            directory) that is composed and instantiated as the LoRA config.
        base_weight: Path passed to ``BaseTransformer.from_pretrained``.
        lora_weight: Path of the LoRA checkpoint read with ``torch.load``.
        output: Directory passed to ``save_pretrained``; ``model.pth`` is
            read back from it for validation.

    Raises:
        AssertionError: If the saved checkpoint's keys differ from the
            non-LoRA keys of the base model.
        SystemExit: With code 1 if every saved tensor is identical to the
            base model (i.e. the merge changed nothing).
    """
    output = Path(output)
    logger.info(
        f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
    )

    with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
        cfg = compose(config_name=lora_config)

    lora_config = instantiate(cfg)
    logger.info(f"Loaded lora model with config {lora_config}")

    llama_model = BaseTransformer.from_pretrained(
        path=base_weight,
        load_weights=True,
        lora_config=lora_config,
    )
    logger.info("Loaded llama model")

    # Keep only the non-LoRA tensors of the base model; the deep copy is the
    # reference the merged checkpoint is compared against at the end.
    llama_state_dict = {
        k: v for k, v in llama_model.state_dict().items() if "lora" not in k
    }
    llama_state_dict_copy = deepcopy(llama_state_dict)

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    lora_state_dict = torch.load(lora_weight, map_location="cpu")

    if "state_dict" in llama_state_dict:
        llama_state_dict = llama_state_dict["state_dict"]

    if "state_dict" in lora_state_dict:
        lora_state_dict = lora_state_dict["state_dict"]

    # remove prefix model.
    llama_state_dict = _strip_model_prefix(llama_state_dict)
    lora_state_dict = _strip_model_prefix(lora_state_dict)

    logger.info(f"Found {len(llama_state_dict)} keys in llama model")
    logger.info(f"Found {len(lora_state_dict)} keys in lora model")

    # Entries from the LoRA checkpoint win on key collisions.
    merged_state_dict = llama_state_dict | lora_state_dict
    llama_model.load_state_dict(merged_state_dict, strict=True)
    logger.info("Merged model loaded")

    # Trigger eval mode to merge lora
    llama_model.eval()
    llama_model.save_pretrained(output, drop_lora=True)
    logger.info(f"Saved merged model to {output}, validating")

    new_state_dict = torch.load(output / "model.pth", map_location="cpu")
    original_keys = set(llama_state_dict_copy.keys())
    merged_keys = set(new_state_dict.keys())

    # Explicit raise so the check still runs under ``python -O``.
    if original_keys != merged_keys:
        raise AssertionError("Keys should be same")

    # The merge must have changed at least one tensor; stop at the first one.
    weights_changed = any(
        (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item() != 0
        for key in original_keys
    )
    if not weights_changed:
        logger.error("Merged model is same as the original model")
        raise SystemExit(1)

    logger.info("Merged model is different from the original model, check passed")
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
if __name__ == "__main__":
|
|
95
|
+
merge()
|
|
@@ -0,0 +1,497 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
import datetime
|
|
4
|
+
import shutil
|
|
5
|
+
|
|
6
|
+
# This source code is licensed under the license found in the
|
|
7
|
+
# LICENSE file in the root directory of this source tree.
|
|
8
|
+
import time
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
import click
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
import torch.nn.functional as F
|
|
15
|
+
|
|
16
|
+
from fish_speech.models.text2semantic.llama import find_multiple
|
|
17
|
+
from tools.llama.generate import load_model
|
|
18
|
+
|
|
19
|
+
##### Quantization Primitives ######
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    """Symmetrically quantize each row of ``x`` with its own scale.

    The per-row scale is the row's largest absolute value divided by half of
    the ``quant_max - quant_min`` range (floored at float32 epsilon). Zero
    points are always zero.

    Args:
        x: Tensor reduced along ``dim=1`` to obtain per-row statistics.
        quant_min: Lower clamp bound of the quantized values.
        quant_max: Upper clamp bound of the quantized values.
        target_dtype: dtype the clamped values are cast to.

    Returns:
        ``(quant, scales, zero_points)`` where ``scales`` has ``x``'s dtype
        and ``zero_points`` is an int64 tensor of zeros, one per row.
    """
    # assumes symmetric quantization, axis == 0 and dense memory format
    smallest_scale = torch.finfo(torch.float32).eps

    row_min, row_max = torch.aminmax(x, dim=1)

    # Force the range of every row to include zero.
    neg_part = torch.min(row_min, torch.zeros_like(row_min))
    pos_part = torch.max(row_max, torch.zeros_like(row_max))

    # Symmetric range: use the larger magnitude of the two sides.
    abs_max = torch.max(-neg_part, pos_part)
    half_range = float(quant_max - quant_min) / 2

    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(abs_max / half_range, min=smallest_scale).to(x.dtype)
    zero_points = torch.zeros(
        neg_part.size(), dtype=torch.int64, device=neg_part.device
    )

    # quantize based on qmin/qmax/scales/zp
    shifted = torch.round(x / scales.unsqueeze(-1)) + zero_points.unsqueeze(-1)
    quant = torch.clamp(shifted, quant_min, quant_max).to(target_dtype)

    return quant, scales, zero_points
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_group_qparams(w, n_bit=4, groupsize=128):
    """Compute per-group scales and zero offsets for a 2-D weight tensor.

    ``w`` is viewed as rows of ``groupsize`` consecutive values. For each
    group the scale is ``(max - min) / (2**n_bit - 1)`` (range floored at
    1e-6) and the zero offset is ``min + scale * 2**(n_bit - 1)``.

    Args:
        w: 2-D tensor whose last dimension is divisible by the group size.
        n_bit: Bit width of the quantized values.
        groupsize: Values per group; capped at ``w.shape[-1]``.

    Returns:
        ``(scales, zeros)`` as bfloat16 tensors reshaped to
        ``(w.shape[0], -1)``.
    """
    width = w.shape[-1]
    # needed for GPTQ with padding
    if groupsize > width:
        groupsize = width
    assert groupsize > 1
    assert width % groupsize == 0
    assert w.dim() == 2

    grouped = w.reshape(-1, groupsize)
    assert torch.isnan(grouped).sum() == 0

    group_max = grouped.amax(dim=1, keepdim=True)
    group_min = grouped.amin(dim=1, keepdim=True)
    n_levels = 2**n_bit - 1
    scales = (group_max - group_min).clamp(min=1e-6) / n_levels
    zeros = group_min + scales * (2 ** (n_bit - 1))

    n_rows = w.shape[0]
    scales_bf16 = scales.to(torch.bfloat16).reshape(n_rows, -1)
    zeros_bf16 = zeros.to(torch.bfloat16).reshape(n_rows, -1)
    return scales_bf16, zeros_bf16
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def pack_scales_and_zeros(scales, zeros):
    """Interleave bfloat16 scales and zeros into one contiguous tensor.

    Both inputs must share a shape ``(a, b)`` and be bfloat16. The result has
    shape ``(b, a, 2)`` with scales at index 0 and zeros at index 1 of the
    last dimension.
    """
    assert scales.shape == zeros.shape
    assert scales.dtype == torch.bfloat16
    assert zeros.dtype == torch.bfloat16

    # Give each tensor a trailing axis of length one, then join on that axis.
    scales_col = scales.reshape(scales.size(0), scales.size(1), 1)
    zeros_col = zeros.reshape(zeros.size(0), zeros.size(1), 1)
    paired = torch.cat([scales_col, zeros_col], 2)

    return paired.transpose(0, 1).contiguous()
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def unpack_scales_and_zeros(scales_and_zeros):
    """Split a packed ``(n, m, 2)`` tensor into separate scales and zeros.

    The first two dimensions are swapped and the tensor is split along the
    last axis into two ``(m, n, 1)`` tensors.

    Both float32 and bfloat16 inputs are accepted: ``pack_scales_and_zeros``
    produces bfloat16, so requiring float32 only would reject its output.

    Args:
        scales_and_zeros: 3-D tensor whose last dimension has length 2.

    Returns:
        Tuple ``(scales, zeros)``.

    Raises:
        AssertionError: If the shape or dtype does not match the above.
    """
    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
    assert scales_and_zeros.dtype in (torch.float, torch.bfloat16)
    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
    """Quantize ``w`` group-wise using precomputed scales and zero offsets.

    Each group of ``groupsize`` consecutive values is mapped to
    ``round((value - min_val) / scale)`` with
    ``min_val = zero - scale * 2**(n_bit - 1)``, clamped to
    ``[0, 2**n_bit - 1]``.

    Args:
        w: 2-D tensor whose last dimension is divisible by the group size.
        scales: Per-group scales (flattened to one value per group).
        zeros: Per-group zero offsets (flattened to one value per group).
        n_bit: Bit width of the quantized values.
        groupsize: Values per group.

    Returns:
        int32 tensor with the same shape as ``w``.
    """
    assert groupsize > 1
    width = w.shape[-1]
    # needed for GPTQ single column quantize
    if groupsize > width and scales.shape[-1] == 1:
        groupsize = width

    assert width % groupsize == 0
    assert w.dim() == 2

    grouped = w.reshape(-1, groupsize)
    assert torch.isnan(grouped).sum() == 0

    scale_col = scales.reshape(-1, 1)
    zero_col = zeros.reshape(-1, 1)
    group_floor = zero_col - scale_col * (2 ** (n_bit - 1))
    top_level = 2**n_bit - 1

    levels = ((grouped - group_floor) / scale_col).round()
    levels = levels.clamp_(0, top_level)
    return levels.to(torch.int32).reshape_as(w)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def group_quantize_tensor(w, n_bit=4, groupsize=128):
    """Group-quantize ``w`` and return it with its packed quantization params.

    Returns:
        ``(w_int32, scales_and_zeros)``: the output of
        ``group_quantize_tensor_from_qparams`` and the output of
        ``pack_scales_and_zeros`` for the parameters derived from ``w``.
    """
    group_scales, group_zeros = get_group_qparams(w, n_bit, groupsize)
    quantized = group_quantize_tensor_from_qparams(
        w, group_scales, group_zeros, n_bit, groupsize
    )
    packed_qparams = pack_scales_and_zeros(group_scales, group_zeros)
    return quantized, packed_qparams
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def group_dequantize_tensor_from_qparams(
    w_int32, scales, zeros, n_bit=4, groupsize=128
):
    """Map group-quantized integers back to real values.

    Each value becomes ``(q - 2**(n_bit - 1)) * scale + zero`` using the
    scale and zero offset of its group.

    Args:
        w_int32: 2-D tensor of quantized values.
        scales: Per-group scales (flattened to one value per group).
        zeros: Per-group zero offsets (flattened to one value per group).
        n_bit: Bit width used at quantization time.
        groupsize: Values per group.

    Returns:
        Tensor with the same shape as ``w_int32``.
    """
    assert groupsize > 1
    width = w_int32.shape[-1]
    # needed for GPTQ single column dequantize
    if groupsize > width and scales.shape[-1] == 1:
        groupsize = width
    assert width % groupsize == 0
    assert w_int32.dim() == 2

    grouped = w_int32.reshape(-1, groupsize)
    scale_col = scales.reshape(-1, 1)
    zero_col = zeros.reshape(-1, 1)

    midpoint = 2 ** (n_bit - 1)
    restored = (grouped - midpoint) * scale_col + zero_col
    return restored.reshape_as(w_int32)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
    """Dequantize ``w_int32`` using packed scales and zeros.

    The packed tensor is split with ``unpack_scales_and_zeros`` and the parts
    are forwarded to ``group_dequantize_tensor_from_qparams``.
    """
    group_scales, group_zeros = unpack_scales_and_zeros(scales_and_zeros)
    dequantized = group_dequantize_tensor_from_qparams(
        w_int32, group_scales, group_zeros, n_bit, groupsize
    )
    return dequantized
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class QuantHandler:
    """Placeholder interface for quantization handlers.

    Stores the wrapped module; both methods are no-ops that return ``None``.
    """

    def __init__(self, mod):
        # Module this handler operates on.
        self.mod = mod

    def create_quantized_state_dict(self) -> "StateDict":
        """Do nothing in this placeholder; returns ``None``."""
        return None

    def convert_for_runtime(self) -> "nn.Module":
        """Do nothing in this placeholder; returns ``None``."""
        return None
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
##### Weight-only int8 per-channel quantized code ######
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def replace_linear_weight_only_int8_per_channel(module):
    """Recursively swap every ``nn.Linear`` under ``module`` in place.

    Each linear child is replaced by a ``WeightOnlyInt8Linear`` built from
    the child's ``in_features`` and ``out_features``; other children are
    descended into.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, nn.Linear):
            replace_linear_weight_only_int8_per_channel(submodule)
            continue

        replacement = WeightOnlyInt8Linear(
            submodule.in_features, submodule.out_features
        )
        setattr(module, child_name, replacement)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class WeightOnlyInt8QuantHandler:
    """Weight-only int8 per-channel quantization of a module's linear layers."""

    def __init__(self, mod):
        # Module whose ``torch.nn.Linear`` layers are quantized.
        self.mod = mod

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Return the module's state dict with linear weights quantized.

        For every ``torch.nn.Linear`` submodule, ``<name>.weight`` is replaced
        by its int8 quantization (range -128..127) and ``<name>.scales`` is
        added, cast to the dtype of the original weight.
        """
        state = self.mod.state_dict()
        for layer_name, layer in self.mod.named_modules():
            if not isinstance(layer, torch.nn.Linear):
                continue

            quantized_weight, channel_scales, _ = dynamically_quantize_per_channel(
                layer.weight.float(), -128, 127, torch.int8
            )
            state[f"{layer_name}.weight"] = quantized_weight
            state[f"{layer_name}.scales"] = channel_scales.to(layer.weight.dtype)

        return state

    def convert_for_runtime(self):
        """Replace linear layers with int8 versions in place; return the module."""
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
class WeightOnlyInt8Linear(torch.nn.Module):
    """Linear layer that stores int8 weights plus per-output-channel scales.

    The forward pass casts the int8 weight to the input dtype, applies a
    bias-free linear map and multiplies the result by ``scales``.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """Allocate the ``weight`` and ``scales`` buffers.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            bias: Accepted for signature compatibility; no bias is created.
            device: Device on which the buffers are allocated (``None`` keeps
                the torch default).
            dtype: Accepted for signature compatibility; the buffers always
                use int8 (``weight``) and bfloat16 (``scales``).
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Buffers (not parameters): filled later from a quantized state dict.
        self.register_buffer(
            "weight",
            torch.empty((out_features, in_features), dtype=torch.int8, device=device),
        )
        self.register_buffer(
            "scales", torch.ones(out_features, dtype=torch.bfloat16, device=device)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``linear(input, weight) * scales`` with weight cast to input dtype."""
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
##### weight only int4 per channel groupwise quantized code ######
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
    """Group-quantize a weight to 4 bits and pack it for the int4 matmul op.

    Returns:
        ``(weight_int4pack, scales_and_zeros)``: the result of
        ``torch.ops.aten._convert_weight_to_int4pack`` on the quantized
        weight, and the packed quantization parameters.
    """
    quantized, scales_and_zeros = group_quantize_tensor(
        weight_bf16, n_bit=4, groupsize=groupsize
    )
    packed = torch.ops.aten._convert_weight_to_int4pack(quantized, inner_k_tiles)
    return packed, scales_and_zeros
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    """Apply the packed int4 matmul to ``x`` over its last dimension.

    ``x`` is flattened to 2-D, multiplied via
    ``torch.ops.aten._weight_int4pack_mm`` and reshaped so that the leading
    dimensions of ``x`` are kept and the last one becomes ``out_features``.
    """
    input_shape = x.size()
    flat_input = x.reshape(-1, input_shape[-1])
    product = torch.ops.aten._weight_int4pack_mm(
        flat_input, weight_int4pack, groupsize, scales_and_zeros
    )
    output_shape = input_shape[:-1] + (out_features,)
    return product.reshape(output_shape)
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
|
|
264
|
+
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
    """Recursively swap ``nn.Linear`` children for ``WeightOnlyInt4Linear``.

    A linear child whose ``in_features`` passes ``_check_linear_int4_k`` is
    replaced with ``padding=False``. One that fails the check is replaced
    with ``padding=True`` when ``padding`` is truthy and left untouched
    otherwise. Non-linear children are descended into.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, nn.Linear):
            replace_linear_int4(submodule, groupsize, inner_k_tiles, padding)
            continue

        fits = _check_linear_int4_k(submodule.in_features, groupsize, inner_k_tiles)
        if not fits and not padding:
            # Incompatible width and padding disabled: keep the original layer.
            continue

        replacement = WeightOnlyInt4Linear(
            submodule.in_features,
            submodule.out_features,
            bias=False,
            groupsize=groupsize,
            inner_k_tiles=inner_k_tiles,
            padding=not fits,
        )
        setattr(module, child_name, replacement)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
class WeightOnlyInt4QuantHandler:
    """Weight-only int4 group-wise quantization of a module's linear layers."""

    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        """Store the module and quantization settings.

        Args:
            mod: Module whose ``torch.nn.Linear`` layers are quantized.
            groupsize: Group size; must be one of 32, 64, 128, 256.
            inner_k_tiles: Packing parameter; must be one of 2, 4, 8.
            padding: Whether layers with an incompatible ``in_features`` are
                zero-padded instead of skipped.

        Raises:
            AssertionError: If ``groupsize`` or ``inner_k_tiles`` is not one
                of the allowed values.
        """
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding
        assert groupsize in [32, 64, 128, 256]
        assert inner_k_tiles in [2, 4, 8]

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Return the module's state dict with linear weights packed as int4.

        For each ``torch.nn.Linear`` submodule, ``<name>.weight`` is replaced
        by the packed int4 weight and ``<name>.scales_and_zeros`` is added.
        Layers whose ``in_features`` fails ``_check_linear_int4_k`` are padded
        to a multiple of 1024 when ``self.padding`` is set, otherwise skipped.
        Quantization itself is run on the ``"cuda"`` device and results are
        moved back to CPU.

        Raises:
            AssertionError: If a linear layer has a bias or its
                ``out_features`` is not divisible by 8.
        """
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if not isinstance(mod, torch.nn.Linear):
                continue

            # ``not mod.bias`` would raise on a multi-element bias tensor
            # (ambiguous truth value); the intent is "layer has no bias".
            assert mod.bias is None, "require bias=False"
            out_features = mod.out_features
            in_features = mod.in_features
            assert out_features % 8 == 0, "require out_features % 8 == 0"
            print(f"linear: {fqn}, in={in_features}, out={out_features}")

            weight = mod.weight.data
            if not _check_linear_int4_k(
                in_features, self.groupsize, self.inner_k_tiles
            ):
                if not self.padding:
                    print(
                        f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
                        + "and that groupsize and inner_k_tiles*16 evenly divide into it"
                    )
                    continue

                import torch.nn.functional as F

                print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
                padded_in_features = find_multiple(in_features, 1024)
                # Zero-pad the input dimension (last axis) on the right.
                weight = F.pad(weight, pad=(0, padded_in_features - in_features))

            (
                weight_int4pack,
                scales_and_zeros,
            ) = prepare_int4_weight_and_scales_and_zeros(
                weight.to(torch.bfloat16).to("cuda"),
                self.groupsize,
                self.inner_k_tiles,
            )
            cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
            cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")

        return cur_state_dict

    def convert_for_runtime(self):
        """Replace linear layers with int4 versions in place; return the module."""
        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
class WeightOnlyInt4Linear(torch.nn.Module):
    """Linear layer holding packed int4 weights and per-group scales/zeros."""

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=True,
        device=None,
        dtype=None,
        groupsize: int = 128,
        inner_k_tiles: int = 8,
        padding: bool = True,
    ) -> None:
        """Allocate the packed ``weight`` and ``scales_and_zeros`` buffers.

        Args:
            in_features: Input width; rounded up to a multiple of 1024 via
                ``find_multiple`` when ``padding`` is set.
            out_features: Output width; must be divisible by 8.
            bias: Must be falsy; a bias is not supported.
            device: Unused.
            dtype: Unused.
            groupsize: Quantization group size.
            inner_k_tiles: Packing parameter; the (possibly padded)
                ``in_features`` must be divisible by ``inner_k_tiles * 16``.
            padding: Whether inputs are zero-padded in ``forward``.

        Raises:
            AssertionError: If ``bias`` is truthy or a divisibility
                requirement is not met.
        """
        super().__init__()
        self.padding = padding
        if padding:
            # Remember the caller-visible width so forward() can pad up to it.
            self.origin_in_features = in_features
            in_features = find_multiple(in_features, 1024)

        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles

        k_tile_width = inner_k_tiles * 16
        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert (
            in_features % k_tile_width == 0
        ), "require in_features % (innerKTiles * 16) == 0"

        packed_weight_shape = (
            out_features // 8,
            in_features // k_tile_width,
            32,
            inner_k_tiles // 2,
        )
        self.register_buffer(
            "weight", torch.empty(packed_weight_shape, dtype=torch.int32)
        )

        qparams_shape = (in_features // groupsize, out_features, 2)
        self.register_buffer(
            "scales_and_zeros", torch.empty(qparams_shape, dtype=torch.bfloat16)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Cast ``input`` to bfloat16, pad if configured, and run the int4 matmul."""
        input = input.to(torch.bfloat16)
        if self.padding:
            import torch.nn.functional as F

            missing = self.in_features - self.origin_in_features
            input = F.pad(input, pad=(0, missing))
        return linear_forward_int4(
            input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
        )
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def generate_folder_name():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
@click.command()
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/fish-speech-1.2-sft",
)
@click.option(
    "--mode", type=str, default="int8", help="type of quantization to perform"
)
@click.option(
    "--groupsize", type=int, default=128, help="Group size for int4 quantization."
)
@click.option("--timestamp", type=str, default="None", help="When to do quantization")
def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
    """Quantize a checkpoint and write it to a new checkpoint directory.

    The source checkpoint directory is copied to
    ``checkpoints/fs-1.2-<mode>-...-<timestamp>``, the VQ generator file is
    removed from the copy if present, and ``model.pth`` in the copy is
    replaced by the quantized state dict.

    Args:
        checkpoint_path: Directory of the checkpoint to quantize.
        mode: ``"int8"`` or ``"int4"``.
        groupsize: Group size used for int4 quantization.
        timestamp: Suffix for the output directory; the literal string
            ``"None"`` means "use the current time".

    Raises:
        ValueError: If ``mode`` is not a supported quantization mode.
    """
    # Reject unsupported modes before paying for the model load.
    if mode not in ("int8", "int4"):
        raise ValueError(
            f"Invalid quantization mode {mode} needs to be one of [int8, int4]"
        )

    device = "cpu"
    precision = torch.bfloat16

    print("Loading model ...")
    t0 = time.time()

    model, _ = load_model(
        checkpoint_path=checkpoint_path,
        device=device,
        precision=precision,
        compile=False,
    )
    vq_model = "firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
    now = timestamp if timestamp != "None" else generate_folder_name()

    if mode == "int8":
        print(
            "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
        )
        quant_handler = WeightOnlyInt8QuantHandler(model)
        dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
    else:
        print(
            "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
        )
        quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
        dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")

    quantized_state_dict = quant_handler.create_quantized_state_dict()

    # Start from a full copy of the source checkpoint directory, then drop the
    # VQ generator file from the copy and overwrite model.pth below.
    shutil.copytree(str(checkpoint_path.resolve()), str(dst_name.resolve()))
    vq_path = dst_name / vq_model
    if vq_path.exists():
        vq_path.unlink()
    quantize_path = dst_name / "model.pth"

    print(f"Writing quantized weights to {quantize_path}")
    quantize_path.unlink(missing_ok=True)  # remove existing file if one already there
    torch.save(quantized_state_dict, quantize_path)
    print(f"Quantization complete took {time.time() - t0:.02f} seconds")
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
if __name__ == "__main__":
|
|
497
|
+
quantize()
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
|
|
2
|
+
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
|
3
|
+
|
|
4
|
+
# Special tokens handed to the trainer so they are part of the vocabulary.
SPECIAL_TOKENS = [
    "<|begin_of_sequence|>",
    "<|end_of_sequence|>",
    "<|im_start|>",
    "<|im_sep|>",  # system, user, assistant, etc.
    "<|im_end|>",
    "<|semantic|>",  # audio features
    "<|pad|>",
]

# Start from an empty BPE model and use byte-level handling at every stage
# (pre-tokenization, decoding and post-processing).
raw_tokenizer = Tokenizer(models.BPE())
raw_tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
raw_tokenizer.decoder = decoders.ByteLevel()
raw_tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

# The trainer is configured with vocab_size=0 and is later fed an empty
# iterator: the intent is to NOT learn anything from data and to rely only on
# the byte-level initial alphabet plus the special tokens.
trainer = trainers.BpeTrainer(
    vocab_size=0,
    min_frequency=2,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    special_tokens=SPECIAL_TOKENS,
)

# Example layouts built from the special tokens:
#   <|im_start|>user<|im_sep|>...<|im_end|>
#   <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
raw_tokenizer.train_from_iterator([], trainer=trainer)

# Smoke test on the raw tokenizer: vocabulary size, then an encode/decode
# round trip over mixed ASCII / CJK / emoji text ending in a special token.
vocab = raw_tokenizer.get_vocab()
print(len(vocab))

sample_text = "Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
x = raw_tokenizer.encode(sample_text).ids
print(x, len(x))
print(raw_tokenizer.decode(x, skip_special_tokens=True))

# Wrap the raw tokenizer in the transformers fast-tokenizer interface and
# declare which special tokens act as pad / bos / eos.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=raw_tokenizer,
    pad_token="<|pad|>",
    bos_token="<|begin_of_sequence|>",
    eos_token="<|end_of_sequence|>",
)

# Encode a longer sequence through the wrapped tokenizer and print the result
# decoded both with batch_decode and with decode.
sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
encoded = tokenizer(sequence).input_ids

print("Test encoding....")
print(f"\tSentence: {sequence}")
print(f"\tEncoded: {encoded}")
print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
print(f"\tDecoded: {tokenizer.decode(encoded)}")

# Side effect: publishes the tokenizer to the named Hub repository, requesting
# a private repo.
tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
from pydub import AudioSegment
|
|
5
|
+
from tqdm import tqdm
|
|
6
|
+
|
|
7
|
+
from tools.file import AUDIO_EXTENSIONS, list_files
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def merge_and_delete_files(save_dir, original_files):
    """Merge sliced ``.wav`` / ``.lab`` files under *save_dir*, then delete *original_files*.

    A file takes part in merging only if its suffix is ``.wav`` or ``.lab``
    and its stem contains a ``-``. Such files are grouped by the part of the
    stem before the last ``-``, kept relative to *save_dir*. Within a group,
    audio segments are concatenated and label texts are joined with ``", "``,
    both in the order ``list_files`` yields the files. Each group is written
    to ``<group>.wav`` / ``<group>.lab`` under *save_dir*.

    Args:
        save_dir: Directory scanned recursively for slices; merged outputs
            are written beneath it.
        original_files: Iterable of paths removed with ``os.remove`` after
            all merged outputs have been written.
    """
    root = Path(save_dir)
    candidates = list_files(
        path=save_dir, extensions=AUDIO_EXTENSIONS.union([".lab"]), recursive=True
    )

    merged_audio = {}
    merged_labels = {}

    for file_path in tqdm(candidates, desc="Merging audio files"):
        relative = Path(file_path).relative_to(root)
        (root / relative.parent).mkdir(parents=True, exist_ok=True)

        suffix = file_path.suffix
        if suffix not in (".wav", ".lab"):
            # Other audio formats are listed but never merged.
            continue

        # An empty separator means the stem has no "-", i.e. this is not a slice.
        head, dash, _ = file_path.stem.rpartition("-")
        if not dash:
            continue
        group = relative.parent / head

        if suffix == ".wav":
            segment = AudioSegment.from_wav(file_path)
            if group in merged_audio:
                segment = merged_audio[group] + segment
            merged_audio[group] = segment
        else:
            with open(file_path, "r", encoding="utf-8") as src:
                text = src.read()
            if group in merged_labels:
                text = merged_labels[group] + ", " + text
            merged_labels[group] = text

    for group, segment in merged_audio.items():
        segment.export(root / f"{group}.wav", format="wav")

    for group, text in merged_labels.items():
        with open(root / f"{group}.lab", "w", encoding="utf-8") as dst:
            dst.write(text)

    # Cleanup runs only after every merged file has been written.
    for stale_path in original_files:
        os.remove(stale_path)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# Script entry point: runs the merge only when this file is executed directly.
# NOTE(review): the directory argument is a hard-coded path and original_files
# is [__file__]; merge_and_delete_files passes every entry of original_files to
# os.remove, so a run that reaches the cleanup loop would delete this script
# itself. Confirm this is intended.
if __name__ == "__main__":
|
|
55
|
+
merge_and_delete_files("/made/by/spicysama/laziman", [__file__])
|