xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (194)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +15 -34
  3. xinference/client/restful/restful_client.py +2 -2
  4. xinference/core/chat_interface.py +45 -10
  5. xinference/core/image_interface.py +9 -0
  6. xinference/core/model.py +8 -5
  7. xinference/core/scheduler.py +1 -2
  8. xinference/core/worker.py +49 -42
  9. xinference/deploy/cmdline.py +2 -2
  10. xinference/deploy/test/test_cmdline.py +7 -7
  11. xinference/model/audio/chattts.py +24 -9
  12. xinference/model/audio/core.py +8 -2
  13. xinference/model/audio/fish_speech.py +228 -0
  14. xinference/model/audio/model_spec.json +8 -0
  15. xinference/model/embedding/core.py +23 -1
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +49 -1
  19. xinference/model/llm/__init__.py +26 -27
  20. xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
  21. xinference/model/llm/llm_family.json +606 -1266
  22. xinference/model/llm/llm_family.py +16 -139
  23. xinference/model/llm/llm_family_modelscope.json +276 -313
  24. xinference/model/llm/lmdeploy/__init__.py +0 -0
  25. xinference/model/llm/lmdeploy/core.py +557 -0
  26. xinference/model/llm/memory.py +9 -9
  27. xinference/model/llm/sglang/core.py +2 -2
  28. xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
  29. xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
  30. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  31. xinference/model/llm/{pytorch → transformers}/core.py +3 -10
  32. xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
  33. xinference/model/llm/transformers/intern_vl.py +540 -0
  34. xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
  35. xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
  36. xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
  37. xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
  38. xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
  39. xinference/model/llm/utils.py +85 -70
  40. xinference/model/llm/vllm/core.py +110 -11
  41. xinference/model/utils.py +1 -95
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/internvl/__init__.py +0 -0
  137. xinference/thirdparty/internvl/conversation.py +393 -0
  138. xinference/thirdparty/omnilmm/model/utils.py +16 -1
  139. xinference/web/ui/build/asset-manifest.json +3 -3
  140. xinference/web/ui/build/index.html +1 -1
  141. xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
  142. xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  144. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
  145. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
  146. xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
  147. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
  148. xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
  149. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
  150. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
  151. xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
  153. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
  154. xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
  155. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
  156. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
  157. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
  158. xinference/locale/utils.py +0 -39
  159. xinference/locale/zh_CN.json +0 -26
  160. xinference/model/llm/ggml/tools/__init__.py +0 -15
  161. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
  162. xinference/model/llm/ggml/tools/gguf.py +0 -884
  163. xinference/model/llm/pytorch/__init__.py +0 -13
  164. xinference/model/llm/pytorch/baichuan.py +0 -81
  165. xinference/model/llm/pytorch/falcon.py +0 -138
  166. xinference/model/llm/pytorch/intern_vl.py +0 -352
  167. xinference/model/llm/pytorch/vicuna.py +0 -69
  168. xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
  169. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
  170. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  171. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
  172. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
  173. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
  174. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
  175. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
  176. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
  177. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
  178. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
  179. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
  180. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
  181. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
  182. /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
  183. /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
  184. /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
  185. /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
  186. /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
  187. /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
  188. /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
  189. /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
  190. /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  191. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py
@@ -0,0 +1,23 @@
+ from .braceexpand import braceexpand
+ from .context import autocast_exclude_mps
+ from .file import get_latest_checkpoint
+ from .instantiators import instantiate_callbacks, instantiate_loggers
+ from .logger import RankedLogger
+ from .logging_utils import log_hyperparameters
+ from .rich_utils import enforce_tags, print_config_tree
+ from .utils import extras, get_metric_value, task_wrapper
+
+ __all__ = [
+     "enforce_tags",
+     "extras",
+     "get_metric_value",
+     "RankedLogger",
+     "instantiate_callbacks",
+     "instantiate_loggers",
+     "log_hyperparameters",
+     "print_config_tree",
+     "task_wrapper",
+     "braceexpand",
+     "get_latest_checkpoint",
+     "autocast_exclude_mps",
+ ]
xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py
@@ -0,0 +1,217 @@
+ """
+ Bash-style brace expansion
+ Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
+ License: MIT
+ """
+
+ import re
+ import string
+ from itertools import chain, product
+ from typing import Iterable, Iterator, Optional
+
+ __all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]
+
+
+ class UnbalancedBracesError(ValueError):
+     pass
+
+
+ alphabet = string.ascii_uppercase + string.ascii_lowercase
+
+ int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
+ char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
+ escape_re = re.compile(r"\\(.)")
+
+
+ def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
+     """braceexpand(pattern) -> iterator over generated strings
+
+     Returns an iterator over the strings resulting from brace expansion
+     of pattern. This function implements Brace Expansion as described in
+     bash(1), with the following limitations:
+
+     * A pattern containing unbalanced braces will raise an
+       UnbalancedBracesError exception. In bash, unbalanced braces will either
+       be partly expanded or ignored.
+
+     * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
+       include the characters '[]^_`' between 'Z' and 'a'.
+
+     When escape is True (the default), characters in pattern can be
+     prefixed with a backslash to cause them not to be interpreted as
+     special characters for brace expansion (such as '{', '}', ',').
+     To pass through a a literal backslash, double it ('\\\\').
+
+     When escape is False, backslashes in pattern have no special
+     meaning and will be preserved in the output.
+
+     Examples:
+
+     >>> from braceexpand import braceexpand
+
+     # Integer range
+     >>> list(braceexpand('item{1..3}'))
+     ['item1', 'item2', 'item3']
+
+     # Character range
+     >>> list(braceexpand('{a..c}'))
+     ['a', 'b', 'c']
+
+     # Sequence
+     >>> list(braceexpand('index.html{,.backup}'))
+     ['index.html', 'index.html.backup']
+
+     # Nested patterns
+     >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
+     ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']
+
+     # Prefixing an integer with zero causes all numbers to be padded to
+     # the same width.
+     >>> list(braceexpand('{07..10}'))
+     ['07', '08', '09', '10']
+
+     # An optional increment can be specified for ranges.
+     >>> list(braceexpand('{a..g..2}'))
+     ['a', 'c', 'e', 'g']
+
+     # Ranges can go in both directions.
+     >>> list(braceexpand('{4..1}'))
+     ['4', '3', '2', '1']
+
+     # Numbers can be negative
+     >>> list(braceexpand('{2..-1}'))
+     ['2', '1', '0', '-1']
+
+     # Unbalanced braces raise an exception.
+     >>> list(braceexpand('{1{2,3}'))
+     Traceback (most recent call last):
+         ...
+     UnbalancedBracesError: Unbalanced braces: '{1{2,3}'
+
+     # By default, the backslash is the escape character.
+     >>> list(braceexpand(r'{1\\{2,3}'))
+     ['1{2', '3']
+
+     # Setting 'escape' to False disables backslash escaping.
+     >>> list(braceexpand(r'\\{1,2}', escape=False))
+     ['\\\\1', '\\\\2']
+
+     """
+     return (
+         escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
+     )
+
+
+ def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
+     start = 0
+     pos = 0
+     bracketdepth = 0
+     items: list[Iterable[str]] = []
+
+     # print 'pattern:', pattern
+     while pos < len(pattern):
+         if escape and pattern[pos] == "\\":
+             pos += 2
+             continue
+         elif pattern[pos] == "{":
+             if bracketdepth == 0 and pos > start:
+                 # print 'literal:', pattern[start:pos]
+                 items.append([pattern[start:pos]])
+                 start = pos
+             bracketdepth += 1
+         elif pattern[pos] == "}":
+             bracketdepth -= 1
+             if bracketdepth == 0:
+                 # print 'expression:', pattern[start+1:pos]
+                 expr = pattern[start + 1 : pos]
+                 item = parse_expression(expr, escape)
+                 if item is None:  # not a range or sequence
+                     items.extend([["{"], parse_pattern(expr, escape), ["}"]])
+                 else:
+                     items.append(item)
+                 start = pos + 1  # skip the closing brace
+         pos += 1
+
+     if bracketdepth != 0:  # unbalanced braces
+         raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)
+
+     if start < pos:
+         items.append([pattern[start:]])
+
+     return ("".join(item) for item in product(*items))
+
+
+ def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
+     int_range_match = int_range_re.match(expr)
+     if int_range_match:
+         return make_int_range(*int_range_match.groups())
+
+     char_range_match = char_range_re.match(expr)
+     if char_range_match:
+         return make_char_range(*char_range_match.groups())
+
+     return parse_sequence(expr, escape)
+
+
+ def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
+     # sequence -> chain(*sequence_items)
+     start = 0
+     pos = 0
+     bracketdepth = 0
+     items: list[Iterable[str]] = []
+
+     # print 'sequence:', seq
+     while pos < len(seq):
+         if escape and seq[pos] == "\\":
+             pos += 2
+             continue
+         elif seq[pos] == "{":
+             bracketdepth += 1
+         elif seq[pos] == "}":
+             bracketdepth -= 1
+         elif seq[pos] == "," and bracketdepth == 0:
+             items.append(parse_pattern(seq[start:pos], escape))
+             start = pos + 1  # skip the comma
+         pos += 1
+
+     if bracketdepth != 0:
+         raise UnbalancedBracesError
+     if not items:
+         return None
+
+     # part after the last comma (may be the empty string)
+     items.append(parse_pattern(seq[start:], escape))
+     return chain(*items)
+
+
+ def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
+     if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]):
+         padding = max(len(left), len(right))
+     else:
+         padding = 0
+     step = (int(incr) or 1) if incr else 1
+     start = int(left)
+     end = int(right)
+     r = range(start, end + 1, step) if start < end else range(start, end - 1, -step)
+     fmt = "%0{}d".format(padding)
+     return (fmt % i for i in r)
+
+
+ def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
+     step = (int(incr) or 1) if incr else 1
+     start = alphabet.index(left)
+     end = alphabet.index(right)
+     if start < end:
+         return alphabet[start : end + 1 : step]
+     else:
+         end = end or -len(alphabet)
+         return alphabet[start : end - 1 : -step]
+
+
+ if __name__ == "__main__":
+     import doctest
+     import sys
+
+     failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
+     if failed:
+         sys.exit(1)
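The vendored module is self-contained, so it can be exercised directly. A minimal sketch, assuming the wheel is installed and the thirdparty path above is importable; the expected outputs are taken from the doctests in the file:

from xinference.thirdparty.fish_speech.fish_speech.utils.braceexpand import braceexpand

# Bash-style expansion of an integer range and a comma sequence.
print(list(braceexpand("item{1..3}")))            # ['item1', 'item2', 'item3']
print(list(braceexpand("index.html{,.backup}")))  # ['index.html', 'index.html.backup']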
xinference/thirdparty/fish_speech/fish_speech/utils/context.py
@@ -0,0 +1,13 @@
+ from contextlib import nullcontext
+
+ import torch
+
+
+ def autocast_exclude_mps(
+     device_type: str, dtype: torch.dtype
+ ) -> nullcontext | torch.autocast:
+     return (
+         nullcontext()
+         if torch.backends.mps.is_available()
+         else torch.autocast(device_type, dtype)
+     )
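A short usage sketch of the helper above: on hosts where the MPS backend is available it degrades to a plain nullcontext, otherwise it opens a torch.autocast region. The device and dtype choices below are illustrative assumptions, not part of the diff:

import torch
from xinference.thirdparty.fish_speech.fish_speech.utils.context import autocast_exclude_mps

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.bfloat16  # illustrative choice
x = torch.randn(4, 8, device=device)

# No-op context on MPS machines, autocast everywhere else.
with autocast_exclude_mps(device_type=device, dtype=dtype):
    y = x @ x.T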
xinference/thirdparty/fish_speech/fish_speech/utils/file.py
@@ -0,0 +1,16 @@
+ import os
+ from pathlib import Path
+
+
+ def get_latest_checkpoint(path: Path | str) -> Path | None:
+     # Find the latest checkpoint
+     ckpt_dir = Path(path)
+
+     if ckpt_dir.exists() is False:
+         return None
+
+     ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
+     if len(ckpts) == 0:
+         return None
+
+     return ckpts[-1]
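Usage is a one-liner; a small sketch with a hypothetical checkpoint directory (the helper returns None when the directory is missing or holds no *.ckpt files):

from xinference.thirdparty.fish_speech.fish_speech.utils.file import get_latest_checkpoint

# "results/run1/checkpoints" is a made-up path for illustration only.
ckpt = get_latest_checkpoint("results/run1/checkpoints")
if ckpt is not None:
    print(f"resuming from {ckpt}")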
xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py
@@ -0,0 +1,50 @@
+ from typing import List
+
+ import hydra
+ from omegaconf import DictConfig
+ from pytorch_lightning import Callback
+ from pytorch_lightning.loggers import Logger
+
+ from .logger import RankedLogger
+
+ log = RankedLogger(__name__, rank_zero_only=True)
+
+
+ def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+     """Instantiates callbacks from config."""
+
+     callbacks: List[Callback] = []
+
+     if not callbacks_cfg:
+         log.warning("No callback configs found! Skipping..")
+         return callbacks
+
+     if not isinstance(callbacks_cfg, DictConfig):
+         raise TypeError("Callbacks config must be a DictConfig!")
+
+     for _, cb_conf in callbacks_cfg.items():
+         if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+             log.info(f"Instantiating callback <{cb_conf._target_}>")
+             callbacks.append(hydra.utils.instantiate(cb_conf))
+
+     return callbacks
+
+
+ def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+     """Instantiates loggers from config."""
+
+     logger: List[Logger] = []
+
+     if not logger_cfg:
+         log.warning("No logger configs found! Skipping...")
+         return logger
+
+     if not isinstance(logger_cfg, DictConfig):
+         raise TypeError("Logger config must be a DictConfig!")
+
+     for _, lg_conf in logger_cfg.items():
+         if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+             log.info(f"Instantiating logger <{lg_conf._target_}>")
+             logger.append(hydra.utils.instantiate(lg_conf))
+
+     return logger
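Both helpers are driven by Hydra-style configs whose entries carry a `_target_` key. A hedged sketch, assuming hydra and pytorch_lightning are installed; the ModelCheckpoint entry is only an example target, not something the diff prescribes:

from omegaconf import OmegaConf
from lightning_utilities.core.rank_zero import rank_zero_only
from xinference.thirdparty.fish_speech.fish_speech.utils.instantiators import instantiate_callbacks

# The module-level RankedLogger needs a rank outside a Lightning run (see logger.py below).
rank_zero_only.rank = getattr(rank_zero_only, "rank", 0)

callbacks_cfg = OmegaConf.create(
    {"checkpoint": {"_target_": "pytorch_lightning.callbacks.ModelCheckpoint", "save_top_k": 1}}
)
callbacks = instantiate_callbacks(callbacks_cfg)  # -> [ModelCheckpoint(save_top_k=1)]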
xinference/thirdparty/fish_speech/fish_speech/utils/logger.py
@@ -0,0 +1,55 @@
+ import logging
+ from typing import Mapping, Optional
+
+ from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+
+ class RankedLogger(logging.LoggerAdapter):
+     """A multi-GPU-friendly python command line logger."""
+
+     def __init__(
+         self,
+         name: str = __name__,
+         rank_zero_only: bool = True,
+         extra: Optional[Mapping[str, object]] = None,
+     ) -> None:
+         """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+         with their rank prefixed in the log message.
+
+         :param name: The name of the logger. Default is ``__name__``.
+         :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
+         :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+         """
+         logger = logging.getLogger(name)
+         super().__init__(logger=logger, extra=extra)
+         self.rank_zero_only = rank_zero_only
+
+     def log(
+         self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
+     ) -> None:
+         """Delegate a log call to the underlying logger, after prefixing its message with the rank
+         of the process it's being logged from. If `'rank'` is provided, then the log will only
+         occur on that rank/process.
+
+         :param level: The level to log at. Look at `logging.__init__.py` for more information.
+         :param msg: The message to log.
+         :param rank: The rank to log at.
+         :param args: Additional args to pass to the underlying logging function.
+         :param kwargs: Any additional keyword args to pass to the underlying logging function.
+         """
+         if self.isEnabledFor(level):
+             msg, kwargs = self.process(msg, kwargs)
+             current_rank = getattr(rank_zero_only, "rank", None)
+             if current_rank is None:
+                 raise RuntimeError(
+                     "The `rank_zero_only.rank` needs to be set before use"
+                 )
+             msg = rank_prefixed_message(msg, current_rank)
+             if self.rank_zero_only:
+                 if current_rank == 0:
+                     self.logger.log(level, msg, *args, **kwargs)
+             else:
+                 if rank is None:
+                     self.logger.log(level, msg, *args, **kwargs)
+                 elif current_rank == rank:
+                     self.logger.log(level, msg, *args, **kwargs)
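The adapter refuses to log until `rank_zero_only.rank` has been set, which Lightning normally does during Trainer setup. A minimal standalone sketch; setting the rank by hand is an assumption made for the example, not something the diff does:

import logging
from lightning_utilities.core.rank_zero import rank_zero_only
from xinference.thirdparty.fish_speech.fish_speech.utils.logger import RankedLogger

rank_zero_only.rank = getattr(rank_zero_only, "rank", 0)  # avoid the RuntimeError above
logging.basicConfig(level=logging.INFO)

log = RankedLogger(__name__, rank_zero_only=True)
log.info("checkpoint loaded")  # emitted as "[rank: 0] checkpoint loaded", and only on rank 0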
xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py
@@ -0,0 +1,48 @@
+ from lightning.pytorch.utilities import rank_zero_only
+
+ from fish_speech.utils import logger as log
+
+
+ @rank_zero_only
+ def log_hyperparameters(object_dict: dict) -> None:
+     """Controls which config parts are saved by lightning loggers.
+
+     Additionally saves:
+     - Number of model parameters
+     """
+
+     hparams = {}
+
+     cfg = object_dict["cfg"]
+     model = object_dict["model"]
+     trainer = object_dict["trainer"]
+
+     if not trainer.logger:
+         log.warning("Logger not found! Skipping hyperparameter logging...")
+         return
+
+     hparams["model"] = cfg["model"]
+
+     # save number of model parameters
+     hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+     hparams["model/params/trainable"] = sum(
+         p.numel() for p in model.parameters() if p.requires_grad
+     )
+     hparams["model/params/non_trainable"] = sum(
+         p.numel() for p in model.parameters() if not p.requires_grad
+     )
+
+     hparams["data"] = cfg["data"]
+     hparams["trainer"] = cfg["trainer"]
+
+     hparams["callbacks"] = cfg.get("callbacks")
+     hparams["extras"] = cfg.get("extras")
+
+     hparams["task_name"] = cfg.get("task_name")
+     hparams["tags"] = cfg.get("tags")
+     hparams["ckpt_path"] = cfg.get("ckpt_path")
+     hparams["seed"] = cfg.get("seed")
+
+     # send hparams to all loggers
+     for logger in trainer.loggers:
+         logger.log_hyperparams(hparams)
xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py
@@ -0,0 +1,100 @@
+ from pathlib import Path
+ from typing import Sequence
+
+ import rich
+ import rich.syntax
+ import rich.tree
+ from hydra.core.hydra_config import HydraConfig
+ from lightning.pytorch.utilities import rank_zero_only
+ from omegaconf import DictConfig, OmegaConf, open_dict
+ from rich.prompt import Prompt
+
+ from fish_speech.utils import logger as log
+
+
+ @rank_zero_only
+ def print_config_tree(
+     cfg: DictConfig,
+     print_order: Sequence[str] = (
+         "data",
+         "model",
+         "callbacks",
+         "logger",
+         "trainer",
+         "paths",
+         "extras",
+     ),
+     resolve: bool = False,
+     save_to_file: bool = False,
+ ) -> None:
+     """Prints content of DictConfig using Rich library and its tree structure.
+
+     Args:
+         cfg (DictConfig): Configuration composed by Hydra.
+         print_order (Sequence[str], optional): Determines in what order config components are printed.
+         resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+         save_to_file (bool, optional): Whether to export config to the hydra output folder.
+     """  # noqa: E501
+
+     style = "dim"
+     tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+     queue = []
+
+     # add fields from `print_order` to queue
+     for field in print_order:
+         (
+             queue.append(field)
+             if field in cfg
+             else log.warning(
+                 f"Field '{field}' not found in config. "
+                 + f"Skipping '{field}' config printing..."
+             )
+         )
+
+     # add all the other fields to queue (not specified in `print_order`)
+     for field in cfg:
+         if field not in queue:
+             queue.append(field)
+
+     # generate config tree from queue
+     for field in queue:
+         branch = tree.add(field, style=style, guide_style=style)
+
+         config_group = cfg[field]
+         if isinstance(config_group, DictConfig):
+             branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+         else:
+             branch_content = str(config_group)
+
+         branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+     # print config tree
+     rich.print(tree)
+
+     # save config tree to file
+     if save_to_file:
+         with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+             rich.print(tree, file=file)
+
+
+ @rank_zero_only
+ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+     """Prompts user to input tags from command line if no tags are provided in config."""  # noqa: E501
+
+     if not cfg.get("tags"):
+         if "id" in HydraConfig().cfg.hydra.job:
+             raise ValueError("Specify tags before launching a multirun!")
+
+         log.warning("No tags provided in config. Prompting user to input tags...")
+         tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+         tags = [t.strip() for t in tags.split(",") if t != ""]
+
+         with open_dict(cfg):
+             cfg.tags = tags
+
+         log.info(f"Tags: {cfg.tags}")
+
+     if save_to_file:
+         with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+             rich.print(cfg.tags, file=file)
xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py
@@ -0,0 +1,122 @@
+ import torch
+ import torchaudio.functional as F
+ from torch import Tensor, nn
+ from torchaudio.transforms import MelScale
+
+
+ class LinearSpectrogram(nn.Module):
+     def __init__(
+         self,
+         n_fft=2048,
+         win_length=2048,
+         hop_length=512,
+         center=False,
+         mode="pow2_sqrt",
+     ):
+         super().__init__()
+
+         self.n_fft = n_fft
+         self.win_length = win_length
+         self.hop_length = hop_length
+         self.center = center
+         self.mode = mode
+
+         self.register_buffer("window", torch.hann_window(win_length), persistent=False)
+
+     def forward(self, y: Tensor) -> Tensor:
+         if y.ndim == 3:
+             y = y.squeeze(1)
+
+         y = torch.nn.functional.pad(
+             y.unsqueeze(1),
+             (
+                 (self.win_length - self.hop_length) // 2,
+                 (self.win_length - self.hop_length + 1) // 2,
+             ),
+             mode="reflect",
+         ).squeeze(1)
+
+         spec = torch.stft(
+             y,
+             self.n_fft,
+             hop_length=self.hop_length,
+             win_length=self.win_length,
+             window=self.window,
+             center=self.center,
+             pad_mode="reflect",
+             normalized=False,
+             onesided=True,
+             return_complex=True,
+         )
+
+         spec = torch.view_as_real(spec)
+
+         if self.mode == "pow2_sqrt":
+             spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+         return spec
+
+
+ class LogMelSpectrogram(nn.Module):
+     def __init__(
+         self,
+         sample_rate=44100,
+         n_fft=2048,
+         win_length=2048,
+         hop_length=512,
+         n_mels=128,
+         center=False,
+         f_min=0.0,
+         f_max=None,
+     ):
+         super().__init__()
+
+         self.sample_rate = sample_rate
+         self.n_fft = n_fft
+         self.win_length = win_length
+         self.hop_length = hop_length
+         self.center = center
+         self.n_mels = n_mels
+         self.f_min = f_min
+         self.f_max = f_max or float(sample_rate // 2)
+
+         self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
+
+         fb = F.melscale_fbanks(
+             n_freqs=self.n_fft // 2 + 1,
+             f_min=self.f_min,
+             f_max=self.f_max,
+             n_mels=self.n_mels,
+             sample_rate=self.sample_rate,
+             norm="slaney",
+             mel_scale="slaney",
+         )
+         self.register_buffer(
+             "fb",
+             fb,
+             persistent=False,
+         )
+
+     def compress(self, x: Tensor) -> Tensor:
+         return torch.log(torch.clamp(x, min=1e-5))
+
+     def decompress(self, x: Tensor) -> Tensor:
+         return torch.exp(x)
+
+     def apply_mel_scale(self, x: Tensor) -> Tensor:
+         return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
+
+     def forward(
+         self, x: Tensor, return_linear: bool = False, sample_rate: int = None
+     ) -> Tensor:
+         if sample_rate is not None and sample_rate != self.sample_rate:
+             x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
+
+         linear = self.spectrogram(x)
+         x = self.apply_mel_scale(linear)
+         x = self.compress(x)
+
+         if return_linear:
+             return x, self.compress(linear)
+
+         return x
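To sanity-check the new front-end, the modules above can be run on a dummy waveform; with the defaults (44.1 kHz, hop 512, 128 mels) one second of audio yields 86 frames. The random input is a stand-in for real audio, not part of the diff:

import torch
from xinference.thirdparty.fish_speech.fish_speech.utils.spectrogram import LogMelSpectrogram

mel_transform = LogMelSpectrogram()  # defaults: sample_rate=44100, n_fft=2048, n_mels=128
wav = torch.randn(1, 44100)          # (batch, samples) of fake audio
mel = mel_transform(wav)             # log-compressed mel features
print(mel.shape)                     # torch.Size([1, 128, 86])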