xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +15 -34
- xinference/client/restful/restful_client.py +2 -2
- xinference/core/chat_interface.py +45 -10
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +8 -5
- xinference/core/scheduler.py +1 -2
- xinference/core/worker.py +49 -42
- xinference/deploy/cmdline.py +2 -2
- xinference/deploy/test/test_cmdline.py +7 -7
- xinference/model/audio/chattts.py +24 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +23 -1
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +49 -1
- xinference/model/llm/__init__.py +26 -27
- xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
- xinference/model/llm/llm_family.json +606 -1266
- xinference/model/llm/llm_family.py +16 -139
- xinference/model/llm/llm_family_modelscope.json +276 -313
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/memory.py +9 -9
- xinference/model/llm/sglang/core.py +2 -2
- xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
- xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/{pytorch → transformers}/core.py +3 -10
- xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +540 -0
- xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
- xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
- xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
- xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
- xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
- xinference/model/llm/utils.py +85 -70
- xinference/model/llm/vllm/core.py +110 -11
- xinference/model/utils.py +1 -95
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/internvl/__init__.py +0 -0
- xinference/thirdparty/internvl/conversation.py +393 -0
- xinference/thirdparty/omnilmm/model/utils.py +16 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
- xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
- xinference/locale/utils.py +0 -39
- xinference/locale/zh_CN.json +0 -26
- xinference/model/llm/ggml/tools/__init__.py +0 -15
- xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
- xinference/model/llm/ggml/tools/gguf.py +0 -884
- xinference/model/llm/pytorch/__init__.py +0 -13
- xinference/model/llm/pytorch/baichuan.py +0 -81
- xinference/model/llm/pytorch/falcon.py +0 -138
- xinference/model/llm/pytorch/intern_vl.py +0 -352
- xinference/model/llm/pytorch/vicuna.py +0 -69
- xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
- xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
- /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
- /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
- /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+"""FRACTION类
+分数 <=> 中文字符串 方法
+中文字符串 <=> 分数 方法
+"""
+
+__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class Fraction:
+    """
+    FRACTION类
+    """
+
+    def __init__(self, fraction=None, chntext=None):
+        self.fraction = fraction
+        self.chntext = chntext
+
+    def chntext2fraction(self):
+        denominator, numerator = self.chntext.split("分之")
+        return chn2num(numerator) + "/" + chn2num(denominator)
+
+    def fraction2chntext(self):
+        numerator, denominator = self.fraction.split("/")
+        return num2chn(denominator) + "分之" + num2chn(numerator)
+
+
+if __name__ == "__main__":
+
+    # 测试程序
+    print(Fraction(fraction="2135/7230").fraction2chntext())
+    print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction())
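
For orientation, a minimal usage sketch of the Fraction class above, assuming the vendored fish_speech root is on sys.path and that num2chn/chn2num from the newly added basic_util.py convert single digits as their names suggest:

from fish_speech.text.chn_text_norm.fraction import Fraction

# Chinese reads fractions denominator-first ("四分之三" is 3/4), so the class
# swaps the operands when converting in either direction.
print(Fraction(fraction="3/4").fraction2chntext())      # expected: 四分之三
print(Fraction(chntext="四分之三").chntext2fraction())  # expected: 3/4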
xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+"""MONEY类
+金钱 <=> 中文字符串 方法
+中文字符串 <=> 金钱 方法
+"""
+import re
+
+__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
+__data__ = "2019-05-08"
+
+from fish_speech.text.chn_text_norm.cardinal import Cardinal
+
+
+class Money:
+    """
+    MONEY类
+    """
+
+    def __init__(self, money=None, chntext=None):
+        self.money = money
+        self.chntext = chntext
+
+    # def chntext2money(self):
+    #     return self.money
+
+    def money2chntext(self):
+        money = self.money
+        pattern = re.compile(r"(\d+(\.\d+)?)")
+        matchers = pattern.findall(money)
+        if matchers:
+            for matcher in matchers:
+                money = money.replace(
+                    matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext()
+                )
+        self.chntext = money
+        return self.chntext
+
+
+if __name__ == "__main__":
+
+    # 测试
+    print(Money(money="21.5万元").money2chntext())
+    print(Money(money="230块5毛").money2chntext())
xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+"""PERCENTAGE类
+百分数 <=> 中文字符串 方法
+中文字符串 <=> 百分数 方法
+"""
+
+__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
+__data__ = "2019-05-06"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class Percentage:
+    """
+    PERCENTAGE类
+    """
+
+    def __init__(self, percentage=None, chntext=None):
+        self.percentage = percentage
+        self.chntext = chntext
+
+    def chntext2percentage(self):
+        return chn2num(self.chntext.strip().strip("百分之")) + "%"
+
+    def percentage2chntext(self):
+        return "百分之" + num2chn(self.percentage.strip().strip("%"))
+
+
+if __name__ == "__main__":
+
+    # 测试程序
+    print(Percentage(chntext="百分之五十六点零三").chntext2percentage())
+    print(Percentage(percentage="65.3%").percentage2chntext())
xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+"""TELEPHONE类
+电话号码 <=> 中文字符串 方法
+中文字符串 <=> 电话号码 方法
+"""
+
+__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class TelePhone:
+    """
+    TELEPHONE类
+    """
+
+    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
+        self.telephone = telephone
+        self.raw_chntext = raw_chntext
+        self.chntext = chntext
+
+    # def chntext2telephone(self):
+    #     sil_parts = self.raw_chntext.split('<SIL>')
+    #     self.telephone = '-'.join([
+    #         str(chn2num(p)) for p in sil_parts
+    #     ])
+    #     return self.telephone
+
+    def telephone2chntext(self, fixed=False):
+
+        if fixed:
+            sil_parts = self.telephone.split("-")
+            self.raw_chntext = "<SIL>".join(
+                [num2chn(part, alt_two=False, use_units=False) for part in sil_parts]
+            )
+            self.chntext = self.raw_chntext.replace("<SIL>", "")
+        else:
+            sp_parts = self.telephone.strip("+").split()
+            self.raw_chntext = "<SP>".join(
+                [num2chn(part, alt_two=False, use_units=False) for part in sp_parts]
+            )
+            self.chntext = self.raw_chntext.replace("<SP>", "")
+        return self.chntext
+
+
+if __name__ == "__main__":
+
+    # 测试程序
+    print(TelePhone(telephone="0595-23980880").telephone2chntext())
+    # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone())
xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+"""
+TEXT类
+"""
+
+__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
+__data__ = "2019-05-03"
+
+import re
+
+from fish_speech.text.chn_text_norm.cardinal import Cardinal
+from fish_speech.text.chn_text_norm.date import Date
+from fish_speech.text.chn_text_norm.digit import Digit
+from fish_speech.text.chn_text_norm.fraction import Fraction
+from fish_speech.text.chn_text_norm.money import Money
+from fish_speech.text.chn_text_norm.percentage import Percentage
+from fish_speech.text.chn_text_norm.telephone import TelePhone
+
+CURRENCY_NAMES = (
+    "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
+    "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
+)
+CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
+COM_QUANTIFIERS = (
+    "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
+    "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
+    "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
+    "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
+    "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
+    "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)"
+)
+
+
+class Text:
+    """
+    Text类
+    """
+
+    def __init__(self, raw_text, norm_text=None):
+        self.raw_text = "^" + raw_text + "$"
+        self.norm_text = norm_text
+
+    def _particular(self):
+        text = self.norm_text
+        pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('particular')
+            for matcher in matchers:
+                text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
+        self.norm_text = text
+        return self.norm_text
+
+    def normalize(self):
+        text = self.raw_text
+
+        # 规范化日期
+        pattern = re.compile(
+            r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)"
+        )
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('date')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
+
+        # 规范化金钱
+        pattern = re.compile(
+            r"\D+((\d+(\.\d+)?)[多余几]?"
+            + CURRENCY_UNITS
+            + "(\d"
+            + CURRENCY_UNITS
+            + "?)?)"
+        )
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('money')
+            for matcher in matchers:
+                text = text.replace(
+                    matcher[0], Money(money=matcher[0]).money2chntext(), 1
+                )
+
+        # 规范化固话/手机号码
+        # 手机
+        # http://www.jihaoba.com/news/show/13680
+        # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
+        # 联通:130、131、132、156、155、186、185、176
+        # 电信:133、153、189、180、181、177
+        pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('telephone')
+            for matcher in matchers:
+                text = text.replace(
+                    matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
+                )
+        # 固话
+        pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('fixed telephone')
+            for matcher in matchers:
+                text = text.replace(
+                    matcher[0],
+                    TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True),
+                    1,
+                )
+
+        # 规范化分数
+        pattern = re.compile(r"(\d+/\d+)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('fraction')
+            for matcher in matchers:
+                text = text.replace(
+                    matcher, Fraction(fraction=matcher).fraction2chntext(), 1
+                )
+
+        # 规范化百分数
+        text = text.replace("%", "%")
+        pattern = re.compile(r"(\d+(\.\d+)?%)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('percentage')
+            for matcher in matchers:
+                text = text.replace(
+                    matcher[0],
+                    Percentage(percentage=matcher[0]).percentage2chntext(),
+                    1,
+                )
+
+        # 规范化纯数+量词
+        pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('cardinal+quantifier')
+            for matcher in matchers:
+                text = text.replace(
+                    matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
+                )
+
+        # 规范化数字编号
+        pattern = re.compile(r"(\d{4,32})")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('digit')
+            for matcher in matchers:
+                text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
+
+        # 规范化纯数
+        pattern = re.compile(r"(\d+(\.\d+)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('cardinal')
+            for matcher in matchers:
+                text = text.replace(
+                    matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
+                )
+
+        self.norm_text = text
+        self._particular()
+
+        return self.norm_text.lstrip("^").rstrip("$")
+
+
+if __name__ == "__main__":
+
+    # 测试程序
+    print(Text(raw_text="固话:0595-23865596或23880880。").normalize())
+    print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize())
+    print(Text(raw_text="分数:32477/76391。").normalize())
+    print(Text(raw_text="百分数:80.03%。").normalize())
+    print(Text(raw_text="编号:31520181154418。").normalize())
+    print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize())
+    print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize())
+    print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize())
+    print(Text(raw_text="特殊:O2O或B2C。").normalize())
xinference/thirdparty/fish_speech/fish_speech/text/clean.py
@@ -0,0 +1,69 @@
+import itertools
+import re
+
+LANGUAGE_UNICODE_RANGE_MAP = {
+    "ZH": [(0x4E00, 0x9FFF)],
+    "JP": [(0x4E00, 0x9FFF), (0x3040, 0x309F), (0x30A0, 0x30FF), (0x31F0, 0x31FF)],
+    "EN": [(0x0000, 0x007F)],
+}
+
+SYMBOLS_MAPPING = {
+    ":": ",",
+    ";": ",",
+    ",": ",",
+    "。": ".",
+    "!": "!",
+    "?": "?",
+    "\n": ".",
+    "·": ",",
+    "、": ",",
+    "...": "…",
+    "“": "'",
+    "”": "'",
+    "‘": "'",
+    "’": "'",
+    "(": "'",
+    ")": "'",
+    "(": "'",
+    ")": "'",
+    "《": "'",
+    "》": "'",
+    "【": "'",
+    "】": "'",
+    "[": "'",
+    "]": "'",
+    "—": "-",
+    "~": "-",
+    "~": "-",
+    "・": "-",
+    "「": "'",
+    "」": "'",
+    ";": ",",
+    ":": ",",
+}
+
+REPLACE_SYMBOL_REGEX = re.compile(
+    "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
+)
+ALL_KNOWN_UTF8_RANGE = list(
+    itertools.chain.from_iterable(LANGUAGE_UNICODE_RANGE_MAP.values())
+)
+REMOVE_UNKNOWN_SYMBOL_REGEX = re.compile(
+    "[^"
+    + "".join(
+        f"{re.escape(chr(start))}-{re.escape(chr(end))}"
+        for start, end in ALL_KNOWN_UTF8_RANGE
+    )
+    + "]"
+)
+
+
+def clean_text(text):
+    # Clean the text
+    text = text.strip()
+
+    # Replace all chinese symbols with their english counterparts
+    text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
+    text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
+
+    return text
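
A minimal usage sketch for the clean_text helper above, assuming the vendored fish_speech directory is on sys.path so that fish_speech.text.clean is importable (the same layout its own imports rely on):

from fish_speech.text.clean import clean_text

# Whitespace is stripped, mapped punctuation is rewritten via SYMBOLS_MAPPING,
# and any character outside the ZH/JP/EN unicode ranges is dropped.
print(clean_text("  《测试》Hello, world!\n"))  # expected: '测试'Hello, world!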
xinference/thirdparty/fish_speech/fish_speech/text/spliter.py
@@ -0,0 +1,130 @@
+import re
+import string
+
+from fish_speech.text.clean import clean_text
+
+
+def utf_8_len(text):
+    return len(text.encode("utf-8"))
+
+
+def break_text(texts, length, splits: set):
+    for text in texts:
+        if utf_8_len(text) <= length:
+            yield text
+            continue
+
+        curr = ""
+        for char in text:
+            curr += char
+
+            if char in splits:
+                yield curr
+                curr = ""
+
+        if curr:
+            yield curr
+
+
+def break_text_by_length(texts, length):
+    for text in texts:
+        if utf_8_len(text) <= length:
+            yield text
+            continue
+
+        curr = ""
+        for char in text:
+            curr += char
+
+            if utf_8_len(curr) >= length:
+                yield curr
+                curr = ""
+
+        if curr:
+            yield curr
+
+
+def add_cleaned(curr, segments):
+    curr = curr.strip()
+    if curr and not all(c.isspace() or c in string.punctuation for c in curr):
+        segments.append(curr)
+
+
+def protect_float(text):
+    # Turns 3.14 into <3_f_14> to prevent splitting
+    return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
+
+
+def unprotect_float(text):
+    # Turns <3_f_14> into 3.14
+    return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
+
+
+def split_text(text, length):
+    text = clean_text(text)
+
+    # Break the text into pieces with following rules:
+    # 1. Split the text at ".", "!", "?" if text is NOT a float
+    # 2. If the text is longer than length, split at ","
+    # 3. If the text is still longer than length, split at " "
+    # 4. If the text is still longer than length, split at any character to length
+
+    texts = [text]
+    texts = map(protect_float, texts)
+    texts = break_text(texts, length, {".", "!", "?"})
+    texts = map(unprotect_float, texts)
+    texts = break_text(texts, length, {","})
+    texts = break_text(texts, length, {" "})
+    texts = list(break_text_by_length(texts, length))
+
+    # Then, merge the texts into segments with length <= length
+    segments = []
+    curr = ""
+
+    for text in texts:
+        if utf_8_len(curr) + utf_8_len(text) <= length:
+            curr += text
+        else:
+            add_cleaned(curr, segments)
+            curr = text
+
+    if curr:
+        add_cleaned(curr, segments)
+
+    return segments
+
+
+if __name__ == "__main__":
+    # Test the split_text function
+
+    text = "This is a test sentence. This is another test sentence. And a third one."
+
+    assert split_text(text, 50) == [
+        "This is a test sentence.",
+        "This is another test sentence. And a third one.",
+    ]
+    assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
+    assert split_text(" ", 10) == []
+    assert split_text("a", 10) == ["a"]
+
+    text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
+    assert split_text(text, 50) == [
+        "This is a test sentence with only commas,",
+        "and no dots, and no exclamation marks,",
+        "and no question marks, and no newlines.",
+    ]
+
+    text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
+    # First half split at " ", second half split at ","
+    assert split_text(text, 50) == [
+        "This is a test sentence This is a test sentence",
+        "This is a test sentence. This is a test sentence,",
+        "This is a test sentence, This is a test sentence.",
+    ]
+
+    text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
+    assert split_text(text, 50) == [
+        "这是一段很长的中文文本,",
+        "而且没有句号,也没有感叹号,",
+        "也没有问号,也没有换行符.",
+    ]
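
A quick sketch of the float-protection round trip used by split_text above (import path assumed from the vendored layout, as in the clean.py example):

from fish_speech.text.spliter import protect_float, split_text, unprotect_float

# Decimal points are shielded before splitting on ".", then restored,
# so a number like 3.14 is never broken across segments.
assert protect_float("pi is 3.14.") == "pi is <3_f_14>."
assert unprotect_float("pi is <3_f_14>.") == "pi is 3.14."

# Segment budgets are measured in UTF-8 bytes via utf_8_len, not characters.
print(split_text("First sentence. Second sentence, with a tail.", 20))
# expected: ['First sentence.', 'Second sentence,', 'with a tail.']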
xinference/thirdparty/fish_speech/fish_speech/train.py
@@ -0,0 +1,139 @@
+import os
+import sys
+from typing import Optional
+
+import hydra
+import lightning as L
+# import pyrootutils
+import torch
+from lightning import Callback, LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.loggers import Logger
+from lightning.pytorch.strategies import DDPStrategy
+from omegaconf import DictConfig, OmegaConf
+
+os.environ.pop("SLURM_NTASKS", None)
+os.environ.pop("SLURM_JOB_NAME", None)
+os.environ.pop("SLURM_NTASKS_PER_NODE", None)
+
+# register eval resolver and root
+# pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+# Allow TF32 on Ampere GPUs
+torch.set_float32_matmul_precision("high")
+torch.backends.cudnn.allow_tf32 = True
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+import fish_speech.utils as utils
+
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
+@utils.task_wrapper
+def train(cfg: DictConfig) -> tuple[dict, dict]:
+    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
+    training.
+    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+    failure. Useful for multiruns, saving info about the crash, etc.
+    Args:
+        cfg (DictConfig): Configuration composed by Hydra.
+    Returns:
+        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+    """  # noqa: E501
+
+    # set seed for random number generators in pytorch, numpy and python.random
+    if cfg.get("seed"):
+        L.seed_everything(cfg.seed, workers=False)
+
+    if cfg.get("deterministic"):
+        torch.use_deterministic_algorithms(True)
+
+    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+    log.info(f"Instantiating model <{cfg.model._target_}>")
+    model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+    log.info("Instantiating callbacks...")
+    callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
+
+    log.info("Instantiating loggers...")
+    logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+
+    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+    trainer: Trainer = hydra.utils.instantiate(
+        cfg.trainer,
+        callbacks=callbacks,
+        logger=logger,
+    )
+
+    object_dict = {
+        "cfg": cfg,
+        "datamodule": datamodule,
+        "model": model,
+        "callbacks": callbacks,
+        "logger": logger,
+        "trainer": trainer,
+    }
+
+    if logger:
+        log.info("Logging hyperparameters!")
+        utils.log_hyperparameters(object_dict)
+
+    if cfg.get("train"):
+        log.info("Starting training!")
+
+        ckpt_path = cfg.get("ckpt_path")
+        auto_resume = False
+
+        resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
+        if resume_ckpt_path is not None:
+            ckpt_path = resume_ckpt_path
+            auto_resume = True
+
+        if ckpt_path is not None:
+            log.info(f"Resuming from checkpoint: {ckpt_path}")
+
+        # resume weights only is disabled for auto-resume
+        if cfg.get("resume_weights_only") and auto_resume is False:
+            log.info("Resuming weights only!")
+            ckpt = torch.load(ckpt_path, map_location=model.device)
+            if "state_dict" in ckpt:
+                ckpt = ckpt["state_dict"]
+            err = model.load_state_dict(ckpt, strict=False)
+            log.info(f"Error loading state dict: {err}")
+            ckpt_path = None
+
+        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+
+    train_metrics = trainer.callback_metrics
+
+    if cfg.get("test"):
+        log.info("Starting testing!")
+        ckpt_path = trainer.checkpoint_callback.best_model_path
+        if ckpt_path == "":
+            log.warning("Best ckpt not found! Using current weights for testing...")
+            ckpt_path = cfg.get("ckpt_path")
+
+        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+        log.info(f"Best ckpt path: {ckpt_path}")
+
+    test_metrics = trainer.callback_metrics
+
+    # merge train and test metrics
+    metric_dict = {**train_metrics, **test_metrics}
+
+    return metric_dict, object_dict
+
+
+@hydra.main(
+    version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
+)
+def main(cfg: DictConfig) -> Optional[float]:
+    # train the model
+    train(cfg)
+
+
+if __name__ == "__main__":
+    main()