xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.
Files changed (194)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +15 -34
  3. xinference/client/restful/restful_client.py +2 -2
  4. xinference/core/chat_interface.py +45 -10
  5. xinference/core/image_interface.py +9 -0
  6. xinference/core/model.py +8 -5
  7. xinference/core/scheduler.py +1 -2
  8. xinference/core/worker.py +49 -42
  9. xinference/deploy/cmdline.py +2 -2
  10. xinference/deploy/test/test_cmdline.py +7 -7
  11. xinference/model/audio/chattts.py +24 -9
  12. xinference/model/audio/core.py +8 -2
  13. xinference/model/audio/fish_speech.py +228 -0
  14. xinference/model/audio/model_spec.json +8 -0
  15. xinference/model/embedding/core.py +23 -1
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +49 -1
  19. xinference/model/llm/__init__.py +26 -27
  20. xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
  21. xinference/model/llm/llm_family.json +606 -1266
  22. xinference/model/llm/llm_family.py +16 -139
  23. xinference/model/llm/llm_family_modelscope.json +276 -313
  24. xinference/model/llm/lmdeploy/__init__.py +0 -0
  25. xinference/model/llm/lmdeploy/core.py +557 -0
  26. xinference/model/llm/memory.py +9 -9
  27. xinference/model/llm/sglang/core.py +2 -2
  28. xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
  29. xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
  30. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  31. xinference/model/llm/{pytorch → transformers}/core.py +3 -10
  32. xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
  33. xinference/model/llm/transformers/intern_vl.py +540 -0
  34. xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
  35. xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
  36. xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
  37. xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
  38. xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
  39. xinference/model/llm/utils.py +85 -70
  40. xinference/model/llm/vllm/core.py +110 -11
  41. xinference/model/utils.py +1 -95
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/internvl/__init__.py +0 -0
  137. xinference/thirdparty/internvl/conversation.py +393 -0
  138. xinference/thirdparty/omnilmm/model/utils.py +16 -1
  139. xinference/web/ui/build/asset-manifest.json +3 -3
  140. xinference/web/ui/build/index.html +1 -1
  141. xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
  142. xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  144. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
  145. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
  146. xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
  147. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
  148. xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
  149. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
  150. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
  151. xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
  153. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
  154. xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
  155. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
  156. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
  157. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
  158. xinference/locale/utils.py +0 -39
  159. xinference/locale/zh_CN.json +0 -26
  160. xinference/model/llm/ggml/tools/__init__.py +0 -15
  161. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
  162. xinference/model/llm/ggml/tools/gguf.py +0 -884
  163. xinference/model/llm/pytorch/__init__.py +0 -13
  164. xinference/model/llm/pytorch/baichuan.py +0 -81
  165. xinference/model/llm/pytorch/falcon.py +0 -138
  166. xinference/model/llm/pytorch/intern_vl.py +0 -352
  167. xinference/model/llm/pytorch/vicuna.py +0 -69
  168. xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
  169. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
  170. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  171. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
  172. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
  173. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
  174. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
  175. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
  176. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
  177. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
  178. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
  179. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
  180. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
  181. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
  182. /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
  183. /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
  184. /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
  185. /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
  186. /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
  187. /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
  188. /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
  189. /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
  190. /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  191. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py
@@ -0,0 +1,35 @@
+ # -*- coding: utf-8 -*-
+ """FRACTION class
+ fraction <=> Chinese string methods
+ Chinese string <=> fraction methods
+ """
+
+ __author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
+ __data__ = "2019-05-03"
+
+ from fish_speech.text.chn_text_norm.basic_util import *
+
+
+ class Fraction:
+     """
+     FRACTION class
+     """
+
+     def __init__(self, fraction=None, chntext=None):
+         self.fraction = fraction
+         self.chntext = chntext
+
+     def chntext2fraction(self):
+         denominator, numerator = self.chntext.split("分之")
+         return chn2num(numerator) + "/" + chn2num(denominator)
+
+     def fraction2chntext(self):
+         numerator, denominator = self.fraction.split("/")
+         return num2chn(denominator) + "分之" + num2chn(numerator)
+
+
+ if __name__ == "__main__":
+
+     # Test program
+     print(Fraction(fraction="2135/7230").fraction2chntext())
+     print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction())
xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py
@@ -0,0 +1,43 @@
+ # -*- coding: utf-8 -*-
+ """MONEY class
+ money <=> Chinese string methods
+ Chinese string <=> money methods
+ """
+ import re
+
+ __author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
+ __data__ = "2019-05-08"
+
+ from fish_speech.text.chn_text_norm.cardinal import Cardinal
+
+
+ class Money:
+     """
+     MONEY class
+     """
+
+     def __init__(self, money=None, chntext=None):
+         self.money = money
+         self.chntext = chntext
+
+     # def chntext2money(self):
+     #     return self.money
+
+     def money2chntext(self):
+         money = self.money
+         pattern = re.compile(r"(\d+(\.\d+)?)")
+         matchers = pattern.findall(money)
+         if matchers:
+             for matcher in matchers:
+                 money = money.replace(
+                     matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext()
+                 )
+         self.chntext = money
+         return self.chntext
+
+
+ if __name__ == "__main__":
+
+     # Test
+     print(Money(money="21.5万元").money2chntext())
+     print(Money(money="230块5毛").money2chntext())
xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py
@@ -0,0 +1,33 @@
+ # -*- coding: utf-8 -*-
+ """PERCENTAGE class
+ percentage <=> Chinese string methods
+ Chinese string <=> percentage methods
+ """
+
+ __author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
+ __data__ = "2019-05-06"
+
+ from fish_speech.text.chn_text_norm.basic_util import *
+
+
+ class Percentage:
+     """
+     PERCENTAGE class
+     """
+
+     def __init__(self, percentage=None, chntext=None):
+         self.percentage = percentage
+         self.chntext = chntext
+
+     def chntext2percentage(self):
+         return chn2num(self.chntext.strip().strip("百分之")) + "%"
+
+     def percentage2chntext(self):
+         return "百分之" + num2chn(self.percentage.strip().strip("%"))
+
+
+ if __name__ == "__main__":
+
+     # Test program
+     print(Percentage(chntext="百分之五十六点零三").chntext2percentage())
+     print(Percentage(percentage="65.3%").percentage2chntext())
xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py
@@ -0,0 +1,51 @@
+ # -*- coding: utf-8 -*-
+ """TELEPHONE class
+ telephone number <=> Chinese string methods
+ Chinese string <=> telephone number methods
+ """
+
+ __author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
+ __data__ = "2019-05-03"
+
+ from fish_speech.text.chn_text_norm.basic_util import *
+
+
+ class TelePhone:
+     """
+     TELEPHONE class
+     """
+
+     def __init__(self, telephone=None, raw_chntext=None, chntext=None):
+         self.telephone = telephone
+         self.raw_chntext = raw_chntext
+         self.chntext = chntext
+
+     # def chntext2telephone(self):
+     #     sil_parts = self.raw_chntext.split('<SIL>')
+     #     self.telephone = '-'.join([
+     #         str(chn2num(p)) for p in sil_parts
+     #     ])
+     #     return self.telephone
+
+     def telephone2chntext(self, fixed=False):
+
+         if fixed:
+             sil_parts = self.telephone.split("-")
+             self.raw_chntext = "<SIL>".join(
+                 [num2chn(part, alt_two=False, use_units=False) for part in sil_parts]
+             )
+             self.chntext = self.raw_chntext.replace("<SIL>", "")
+         else:
+             sp_parts = self.telephone.strip("+").split()
+             self.raw_chntext = "<SP>".join(
+                 [num2chn(part, alt_two=False, use_units=False) for part in sp_parts]
+             )
+             self.chntext = self.raw_chntext.replace("<SP>", "")
+         return self.chntext
+
+
+ if __name__ == "__main__":
+
+     # Test program
+     print(TelePhone(telephone="0595-23980880").telephone2chntext())
+     # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone())
xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py
@@ -0,0 +1,177 @@
+ # -*- coding: utf-8 -*-
+ """
+ TEXT class
+ """
+
+ __author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
+ __data__ = "2019-05-03"
+
+ import re
+
+ from fish_speech.text.chn_text_norm.cardinal import Cardinal
+ from fish_speech.text.chn_text_norm.date import Date
+ from fish_speech.text.chn_text_norm.digit import Digit
+ from fish_speech.text.chn_text_norm.fraction import Fraction
+ from fish_speech.text.chn_text_norm.money import Money
+ from fish_speech.text.chn_text_norm.percentage import Percentage
+ from fish_speech.text.chn_text_norm.telephone import TelePhone
+
+ CURRENCY_NAMES = (
+     "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
+     "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
+ )
+ CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
+ COM_QUANTIFIERS = (
+     "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
+     "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
+     "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
+     "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
+     "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
+     "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)"
+ )
+
+
+ class Text:
+     """
+     Text class
+     """
+
+     def __init__(self, raw_text, norm_text=None):
+         self.raw_text = "^" + raw_text + "$"
+         self.norm_text = norm_text
+
+     def _particular(self):
+         text = self.norm_text
+         pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
+         matchers = pattern.findall(text)
+         if matchers:
+             # print('particular')
+             for matcher in matchers:
+                 text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
+         self.norm_text = text
+         return self.norm_text
+
+     def normalize(self):
+         text = self.raw_text
+
+         # Normalize dates
+         pattern = re.compile(
+             r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)"
+         )
+         matchers = pattern.findall(text)
+         if matchers:
+             # print('date')
+             for matcher in matchers:
+                 text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
+
+         # Normalize money amounts
+         pattern = re.compile(
+             r"\D+((\d+(\.\d+)?)[多余几]?"
+             + CURRENCY_UNITS
+             + r"(\d"
+             + CURRENCY_UNITS
+             + "?)?)"
+         )
+         matchers = pattern.findall(text)
+         if matchers:
+             # print('money')
+             for matcher in matchers:
+                 text = text.replace(
+                     matcher[0], Money(money=matcher[0]).money2chntext(), 1
+                 )
+
+         # Normalize landline/mobile phone numbers
+         # Mobile
+         # http://www.jihaoba.com/news/show/13680
+         # China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
+         # China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
+         # China Telecom: 133, 153, 189, 180, 181, 177
+         pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
+         matchers = pattern.findall(text)
+         if matchers:
+             # print('telephone')
+             for matcher in matchers:
+                 text = text.replace(
+                     matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
+                 )
+         # Landline
+         pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
+         matchers = pattern.findall(text)
+         if matchers:
+             # print('fixed telephone')
+             for matcher in matchers:
+                 text = text.replace(
+                     matcher[0],
+                     TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True),
+                     1,
+                 )
+
+         # Normalize fractions
+         pattern = re.compile(r"(\d+/\d+)")
+         matchers = pattern.findall(text)
+         if matchers:
+             # print('fraction')
+             for matcher in matchers:
+                 text = text.replace(
+                     matcher, Fraction(fraction=matcher).fraction2chntext(), 1
+                 )
+
+         # Normalize percentages
+         text = text.replace("%", "%")
+         pattern = re.compile(r"(\d+(\.\d+)?%)")
+         matchers = pattern.findall(text)
+         if matchers:
+             # print('percentage')
+             for matcher in matchers:
+                 text = text.replace(
+                     matcher[0],
+                     Percentage(percentage=matcher[0]).percentage2chntext(),
+                     1,
+                 )
+
+         # Normalize number + quantifier
+         pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
+         matchers = pattern.findall(text)
+         if matchers:
+             # print('cardinal+quantifier')
+             for matcher in matchers:
+                 text = text.replace(
+                     matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
+                 )
+
+         # Normalize digit sequences
+         pattern = re.compile(r"(\d{4,32})")
+         matchers = pattern.findall(text)
+         if matchers:
+             # print('digit')
+             for matcher in matchers:
+                 text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
+
+         # Normalize plain numbers
+         pattern = re.compile(r"(\d+(\.\d+)?)")
+         matchers = pattern.findall(text)
+         if matchers:
+             # print('cardinal')
+             for matcher in matchers:
+                 text = text.replace(
+                     matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
+                 )
+
+         self.norm_text = text
+         self._particular()
+
+         return self.norm_text.lstrip("^").rstrip("$")
+
+
+ if __name__ == "__main__":
+
+     # Test program
+     print(Text(raw_text="固话:0595-23865596或23880880。").normalize())
+     print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize())
+     print(Text(raw_text="分数:32477/76391。").normalize())
+     print(Text(raw_text="百分数:80.03%。").normalize())
+     print(Text(raw_text="编号:31520181154418。").normalize())
+     print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize())
+     print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize())
+     print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize())
+     print(Text(raw_text="特殊:O2O或B2C。").normalize())
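
Note on the hunk above: the plain-number pass in normalize() turns a digit sandwiched between Latin letters (e.g. "O2O") into "O二O", so _particular() runs last to put the digit back. A minimal standalone sketch of just that step (the sample string is illustrative, not taken from the diff):

    import re

    # Same pattern as Text._particular: a Chinese 二 between Latin letters.
    pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
    text = "O二O或B二C"
    for m in pattern.findall(text):
        text = text.replace(m[0], m[1] + "2" + m[2], 1)
    print(text)  # O2O或B2C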
xinference/thirdparty/fish_speech/fish_speech/text/clean.py
@@ -0,0 +1,69 @@
+ import itertools
+ import re
+
+ LANGUAGE_UNICODE_RANGE_MAP = {
+     "ZH": [(0x4E00, 0x9FFF)],
+     "JP": [(0x4E00, 0x9FFF), (0x3040, 0x309F), (0x30A0, 0x30FF), (0x31F0, 0x31FF)],
+     "EN": [(0x0000, 0x007F)],
+ }
+
+ SYMBOLS_MAPPING = {
+     ":": ",",
+     ";": ",",
+     ",": ",",
+     "。": ".",
+     "!": "!",
+     "?": "?",
+     "\n": ".",
+     "·": ",",
+     "、": ",",
+     "...": "…",
+     "“": "'",
+     "”": "'",
+     "‘": "'",
+     "’": "'",
+     "(": "'",
+     ")": "'",
+     "(": "'",
+     ")": "'",
+     "《": "'",
+     "》": "'",
+     "【": "'",
+     "】": "'",
+     "[": "'",
+     "]": "'",
+     "—": "-",
+     "~": "-",
+     "~": "-",
+     "・": "-",
+     "「": "'",
+     "」": "'",
+     ";": ",",
+     ":": ",",
+ }
+
+ REPLACE_SYMBOL_REGEX = re.compile(
+     "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
+ )
+ ALL_KNOWN_UTF8_RANGE = list(
+     itertools.chain.from_iterable(LANGUAGE_UNICODE_RANGE_MAP.values())
+ )
+ REMOVE_UNKNOWN_SYMBOL_REGEX = re.compile(
+     "[^"
+     + "".join(
+         f"{re.escape(chr(start))}-{re.escape(chr(end))}"
+         for start, end in ALL_KNOWN_UTF8_RANGE
+     )
+     + "]"
+ )
+
+
+ def clean_text(text):
+     # Clean the text
+     text = text.strip()
+
+     # Replace all chinese symbols with their english counterparts
+     text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
+     text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
+
+     return text
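
For orientation, a minimal usage sketch of clean_text based on the tables above (assumes the vendored fish_speech package is importable; the sample strings are illustrative):

    from fish_speech.text.clean import clean_text

    # Full-width punctuation is rewritten via SYMBOLS_MAPPING; characters
    # outside the ZH/JP/EN Unicode ranges are dropped.
    print(clean_text("你好,世界!"))  # -> 你好,世界!
    print(clean_text("《标题》😀"))   # -> '标题'  (emoji removed, brackets become quotes)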
xinference/thirdparty/fish_speech/fish_speech/text/spliter.py
@@ -0,0 +1,130 @@
+ import re
+ import string
+
+ from fish_speech.text.clean import clean_text
+
+
+ def utf_8_len(text):
+     return len(text.encode("utf-8"))
+
+
+ def break_text(texts, length, splits: set):
+     for text in texts:
+         if utf_8_len(text) <= length:
+             yield text
+             continue
+
+         curr = ""
+         for char in text:
+             curr += char
+
+             if char in splits:
+                 yield curr
+                 curr = ""
+
+         if curr:
+             yield curr
+
+
+ def break_text_by_length(texts, length):
+     for text in texts:
+         if utf_8_len(text) <= length:
+             yield text
+             continue
+
+         curr = ""
+         for char in text:
+             curr += char
+
+             if utf_8_len(curr) >= length:
+                 yield curr
+                 curr = ""
+
+         if curr:
+             yield curr
+
+
+ def add_cleaned(curr, segments):
+     curr = curr.strip()
+     if curr and not all(c.isspace() or c in string.punctuation for c in curr):
+         segments.append(curr)
+
+
+ def protect_float(text):
+     # Turns 3.14 into <3_f_14> to prevent splitting
+     return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
+
+
+ def unprotect_float(text):
+     # Turns <3_f_14> into 3.14
+     return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
+
+
+ def split_text(text, length):
+     text = clean_text(text)
+
+     # Break the text into pieces with following rules:
+     # 1. Split the text at ".", "!", "?" if text is NOT a float
+     # 2. If the text is longer than length, split at ","
+     # 3. If the text is still longer than length, split at " "
+     # 4. If the text is still longer than length, split at any character to length
+
+     texts = [text]
+     texts = map(protect_float, texts)
+     texts = break_text(texts, length, {".", "!", "?"})
+     texts = map(unprotect_float, texts)
+     texts = break_text(texts, length, {","})
+     texts = break_text(texts, length, {" "})
+     texts = list(break_text_by_length(texts, length))
+
+     # Then, merge the texts into segments with length <= length
+     segments = []
+     curr = ""
+
+     for text in texts:
+         if utf_8_len(curr) + utf_8_len(text) <= length:
+             curr += text
+         else:
+             add_cleaned(curr, segments)
+             curr = text
+
+     if curr:
+         add_cleaned(curr, segments)
+
+     return segments
+
+
+ if __name__ == "__main__":
+     # Test the split_text function
+
+     text = "This is a test sentence. This is another test sentence. And a third one."
+
+     assert split_text(text, 50) == [
+         "This is a test sentence.",
+         "This is another test sentence. And a third one.",
+     ]
+     assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
+     assert split_text(" ", 10) == []
+     assert split_text("a", 10) == ["a"]
+
+     text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
+     assert split_text(text, 50) == [
+         "This is a test sentence with only commas,",
+         "and no dots, and no exclamation marks,",
+         "and no question marks, and no newlines.",
+     ]
+
+     text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
+     # First half split at " ", second half split at ","
+     assert split_text(text, 50) == [
+         "This is a test sentence This is a test sentence",
+         "This is a test sentence. This is a test sentence,",
+         "This is a test sentence, This is a test sentence.",
+     ]
+
+     text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
+     assert split_text(text, 50) == [
+         "这是一段很长的中文文本,",
+         "而且没有句号,也没有感叹号,",
+         "也没有问号,也没有换行符.",
+     ]
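
One detail worth noting: the length budget in split_text is measured in UTF-8 bytes via utf_8_len, not in characters, which is why the Chinese test above splits after far fewer characters than the English ones. A quick illustration:

    def utf_8_len(text):
        return len(text.encode("utf-8"))

    print(utf_8_len("abc"))   # 3 (one byte per ASCII character)
    print(utf_8_len("中文"))  # 6 (three bytes per CJK character)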
xinference/thirdparty/fish_speech/fish_speech/train.py
@@ -0,0 +1,139 @@
+ import os
+ import sys
+ from typing import Optional
+
+ import hydra
+ import lightning as L
+ # import pyrootutils
+ import torch
+ from lightning import Callback, LightningDataModule, LightningModule, Trainer
+ from lightning.pytorch.loggers import Logger
+ from lightning.pytorch.strategies import DDPStrategy
+ from omegaconf import DictConfig, OmegaConf
+
+ os.environ.pop("SLURM_NTASKS", None)
+ os.environ.pop("SLURM_JOB_NAME", None)
+ os.environ.pop("SLURM_NTASKS_PER_NODE", None)
+
+ # register eval resolver and root
+ # pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+ # Allow TF32 on Ampere GPUs
+ torch.set_float32_matmul_precision("high")
+ torch.backends.cudnn.allow_tf32 = True
+
+ # register eval resolver
+ OmegaConf.register_new_resolver("eval", eval)
+
+ import fish_speech.utils as utils
+
+ log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
+ @utils.task_wrapper
+ def train(cfg: DictConfig) -> tuple[dict, dict]:
+     """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
+     training.
+     This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+     failure. Useful for multiruns, saving info about the crash, etc.
+     Args:
+         cfg (DictConfig): Configuration composed by Hydra.
+     Returns:
+         Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+     """  # noqa: E501
+
+     # set seed for random number generators in pytorch, numpy and python.random
+     if cfg.get("seed"):
+         L.seed_everything(cfg.seed, workers=False)
+
+     if cfg.get("deterministic"):
+         torch.use_deterministic_algorithms(True)
+
+     log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+     datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+     log.info(f"Instantiating model <{cfg.model._target_}>")
+     model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+     log.info("Instantiating callbacks...")
+     callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
+
+     log.info("Instantiating loggers...")
+     logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+
+     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+     trainer: Trainer = hydra.utils.instantiate(
+         cfg.trainer,
+         callbacks=callbacks,
+         logger=logger,
+     )
+
+     object_dict = {
+         "cfg": cfg,
+         "datamodule": datamodule,
+         "model": model,
+         "callbacks": callbacks,
+         "logger": logger,
+         "trainer": trainer,
+     }
+
+     if logger:
+         log.info("Logging hyperparameters!")
+         utils.log_hyperparameters(object_dict)
+
+     if cfg.get("train"):
+         log.info("Starting training!")
+
+         ckpt_path = cfg.get("ckpt_path")
+         auto_resume = False
+
+         resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
+         if resume_ckpt_path is not None:
+             ckpt_path = resume_ckpt_path
+             auto_resume = True
+
+         if ckpt_path is not None:
+             log.info(f"Resuming from checkpoint: {ckpt_path}")
+
+         # resume weights only is disabled for auto-resume
+         if cfg.get("resume_weights_only") and auto_resume is False:
+             log.info("Resuming weights only!")
+             ckpt = torch.load(ckpt_path, map_location=model.device)
+             if "state_dict" in ckpt:
+                 ckpt = ckpt["state_dict"]
+             err = model.load_state_dict(ckpt, strict=False)
+             log.info(f"Error loading state dict: {err}")
+             ckpt_path = None
+
+         trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+
+     train_metrics = trainer.callback_metrics
+
+     if cfg.get("test"):
+         log.info("Starting testing!")
+         ckpt_path = trainer.checkpoint_callback.best_model_path
+         if ckpt_path == "":
+             log.warning("Best ckpt not found! Using current weights for testing...")
+             ckpt_path = cfg.get("ckpt_path")
+
+         trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+         log.info(f"Best ckpt path: {ckpt_path}")
+
+     test_metrics = trainer.callback_metrics
+
+     # merge train and test metrics
+     metric_dict = {**train_metrics, **test_metrics}
+
+     return metric_dict, object_dict
+
+
+ @hydra.main(
+     version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
+ )
+ def main(cfg: DictConfig) -> Optional[float]:
+     # train the model
+     train(cfg)
+
+
+ if __name__ == "__main__":
+     main()
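
The OmegaConf.register_new_resolver("eval", eval) call above is what enables ${eval:...} interpolations in the Hydra configs this script loads. A minimal standalone OmegaConf sketch of that mechanism (the config keys here are made up for illustration, not taken from this diff):

    from omegaconf import OmegaConf

    OmegaConf.register_new_resolver("eval", eval, replace=True)
    cfg = OmegaConf.create({"base_lr": 3e-4, "warmup_lr": "${eval:'${base_lr} * 0.1'}"})
    print(cfg.warmup_lr)  # ~3e-05, evaluated lazily on access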