xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (191) hide show
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +60 -44
  6. xinference/model/audio/chattts.py +25 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/cosyvoice.py +4 -3
  9. xinference/model/audio/custom.py +4 -5
  10. xinference/model/audio/fish_speech.py +228 -0
  11. xinference/model/audio/model_spec.json +8 -0
  12. xinference/model/embedding/core.py +25 -1
  13. xinference/model/embedding/custom.py +4 -5
  14. xinference/model/flexible/core.py +5 -1
  15. xinference/model/image/custom.py +4 -5
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +66 -3
  19. xinference/model/llm/__init__.py +6 -0
  20. xinference/model/llm/llm_family.json +54 -9
  21. xinference/model/llm/llm_family.py +7 -6
  22. xinference/model/llm/llm_family_modelscope.json +56 -10
  23. xinference/model/llm/lmdeploy/__init__.py +0 -0
  24. xinference/model/llm/lmdeploy/core.py +557 -0
  25. xinference/model/llm/sglang/core.py +7 -1
  26. xinference/model/llm/transformers/cogvlm2.py +4 -45
  27. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  28. xinference/model/llm/transformers/core.py +3 -0
  29. xinference/model/llm/transformers/glm4v.py +2 -23
  30. xinference/model/llm/transformers/intern_vl.py +94 -11
  31. xinference/model/llm/transformers/minicpmv25.py +2 -23
  32. xinference/model/llm/transformers/minicpmv26.py +2 -22
  33. xinference/model/llm/transformers/yi_vl.py +2 -24
  34. xinference/model/llm/utils.py +13 -1
  35. xinference/model/llm/vllm/core.py +1 -34
  36. xinference/model/rerank/custom.py +4 -5
  37. xinference/model/utils.py +41 -1
  38. xinference/model/video/core.py +3 -1
  39. xinference/model/video/diffusers.py +41 -38
  40. xinference/model/video/model_spec.json +24 -1
  41. xinference/model/video/model_spec_modelscope.json +25 -1
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/matcha/__init__.py +0 -0
  137. xinference/thirdparty/matcha/app.py +357 -0
  138. xinference/thirdparty/matcha/cli.py +419 -0
  139. xinference/thirdparty/matcha/data/__init__.py +0 -0
  140. xinference/thirdparty/matcha/data/components/__init__.py +0 -0
  141. xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
  142. xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
  143. xinference/thirdparty/matcha/hifigan/config.py +28 -0
  144. xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
  145. xinference/thirdparty/matcha/hifigan/env.py +17 -0
  146. xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
  147. xinference/thirdparty/matcha/hifigan/models.py +368 -0
  148. xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
  149. xinference/thirdparty/matcha/models/__init__.py +0 -0
  150. xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
  151. xinference/thirdparty/matcha/models/components/__init__.py +0 -0
  152. xinference/thirdparty/matcha/models/components/decoder.py +443 -0
  153. xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
  154. xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
  155. xinference/thirdparty/matcha/models/components/transformer.py +316 -0
  156. xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
  157. xinference/thirdparty/matcha/onnx/__init__.py +0 -0
  158. xinference/thirdparty/matcha/onnx/export.py +181 -0
  159. xinference/thirdparty/matcha/onnx/infer.py +168 -0
  160. xinference/thirdparty/matcha/text/__init__.py +53 -0
  161. xinference/thirdparty/matcha/text/cleaners.py +121 -0
  162. xinference/thirdparty/matcha/text/numbers.py +71 -0
  163. xinference/thirdparty/matcha/text/symbols.py +17 -0
  164. xinference/thirdparty/matcha/train.py +122 -0
  165. xinference/thirdparty/matcha/utils/__init__.py +5 -0
  166. xinference/thirdparty/matcha/utils/audio.py +82 -0
  167. xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
  168. xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
  169. xinference/thirdparty/matcha/utils/instantiators.py +56 -0
  170. xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
  171. xinference/thirdparty/matcha/utils/model.py +90 -0
  172. xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
  173. xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
  174. xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
  175. xinference/thirdparty/matcha/utils/pylogger.py +21 -0
  176. xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
  177. xinference/thirdparty/matcha/utils/utils.py +259 -0
  178. xinference/web/ui/build/asset-manifest.json +3 -3
  179. xinference/web/ui/build/index.html +1 -1
  180. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  181. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  182. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  183. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
  184. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
  185. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  186. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  187. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  188. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
  189. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
  190. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
  191. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,90 @@
1
+ """ from https://github.com/jaywalnut310/glow-tts """
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
def sequence_mask(length, max_length=None):
    """Build a boolean mask that is ``True`` for positions below each length.

    :param length: Tensor of sequence lengths (the visible caller passes a 1-D tensor).
    :param max_length: Number of positions per row; defaults to ``length.max()``.
    :return: Boolean tensor whose entry ``[i, j]`` is ``True`` iff ``j < length[i]``.
    """
    width = length.max() if max_length is None else max_length
    # Positions share dtype/device with ``length`` so the comparison needs no casts.
    positions = torch.arange(width, dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]
12
+
13
+
14
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    """Round ``length`` up to the next multiple of ``2 ** num_downsamplings_in_unet``.

    :param length: Length to round up.
    :param num_downsamplings_in_unet: Exponent of the power-of-two multiple.
    :return: A Python ``int`` normally; the rounded tensor itself while an ONNX
        export is in progress.
    """
    multiple = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
    rounded = (length / multiple).ceil() * multiple
    if torch.onnx.is_in_onnx_export():
        return rounded
    return rounded.int().item()
21
+
22
+
23
def convert_pad_shape(pad_shape):
    """Flatten a list of per-dimension pairs, walking the dimensions in reverse order.

    :param pad_shape: Sequence of sequences, e.g. ``[[0, 0], [1, 0], [0, 0]]``.
    :return: One flat list with the last inner sequence first.
    """
    flattened = []
    for pair in reversed(pad_shape):
        flattened.extend(pair)
    return flattened
27
+
28
+
29
def generate_path(duration, mask):
    """Turn per-row durations into a path tensor shaped like ``mask``.

    :param duration: Tensor whose cumulative sum along dim 1 is reshaped to
        ``b * t_x`` entries (so it must hold ``b * t_x`` values).
    :param mask: Tensor of shape ``(b, t_x, t_y)``; its dtype is used for the result.
    :return: ``(b, t_x, t_y)`` tensor, multiplied element-wise by ``mask``.
    """
    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)

    # One mask row per (batch, x) pair: positions below the cumulative duration.
    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)

    # Subtract the rows shifted down by one along t_x so that each row keeps only
    # the positions added by its own duration.
    shifted = torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    return (path - shifted) * mask
42
+
43
+
44
def duration_loss(logw, logw_, lengths):
    """Return the summed squared difference of ``logw`` and ``logw_`` divided by ``sum(lengths)``."""
    squared_error = (logw - logw_).pow(2).sum()
    return squared_error / lengths.sum()
47
+
48
+
49
def normalize(data, mu, std):
    """Compute ``(data - mu) / std``.

    ``mu`` and ``std`` are handled the same way: Python ``float``/``int`` values are
    used as they are; a ``list`` becomes a tensor with ``data``'s dtype and device;
    a ``torch.Tensor`` or ``numpy.ndarray`` is moved to ``data``'s device. Every
    non-scalar statistic then gets a trailing singleton dimension.

    :param data: Tensor to normalise.
    :param mu: Offset to subtract.
    :param std: Scale to divide by.
    :return: The normalised tensor.
    """
    prepared = []
    for stat in (mu, std):
        if not isinstance(stat, (float, int)):
            if isinstance(stat, list):
                stat = torch.tensor(stat, dtype=data.dtype, device=data.device)
            elif isinstance(stat, torch.Tensor):
                stat = stat.to(data.device)
            elif isinstance(stat, np.ndarray):
                stat = torch.from_numpy(stat).to(data.device)
            stat = stat.unsqueeze(-1)
        prepared.append(stat)

    offset, scale = prepared
    return (data - offset) / scale
69
+
70
+
71
def denormalize(data, mu, std):
    """Compute ``data * std + mu``, the inverse of ``normalize``.

    ``mu`` and ``std`` are handled the same way: Python ``float``/``int`` values are
    used as they are; a ``list`` becomes a tensor with ``data``'s dtype and device;
    a ``torch.Tensor`` or ``numpy.ndarray`` is moved to ``data``'s device. Every
    non-scalar statistic then gets a trailing singleton dimension.

    :param data: Normalised tensor.
    :param mu: Offset to add back.
    :param std: Scale to multiply by.
    :return: The de-normalised tensor.
    """

    def _prepare(stat):
        # Plain Python numbers broadcast on their own. ``int`` is accepted here to
        # match ``normalize``; previously an int fell through to ``.unsqueeze``
        # and raised AttributeError.
        if isinstance(stat, (float, int)):
            return stat
        if isinstance(stat, list):
            stat = torch.tensor(stat, dtype=data.dtype, device=data.device)
        elif isinstance(stat, torch.Tensor):
            stat = stat.to(data.device)
        elif isinstance(stat, np.ndarray):
            stat = torch.from_numpy(stat).to(data.device)
        return stat.unsqueeze(-1)

    mu = _prepare(mu)
    std = _prepare(std)
    return data * std + mu
@@ -0,0 +1,22 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from matcha.utils.monotonic_align.core import maximum_path_c
5
+
6
+
7
def maximum_path(value, mask):
    """Cython optimised version.

    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]

    Masks ``value``, hands float32/int32 NumPy copies to ``maximum_path_c`` and
    returns the resulting path array as a tensor with the device and dtype of
    the masked input.
    """
    masked = value * mask
    target_device, target_dtype = masked.device, masked.dtype

    scores = masked.data.cpu().numpy().astype(np.float32)
    mask_np = mask.data.cpu().numpy()
    # Output buffer handed to the Cython routine.
    paths = np.zeros_like(scores, dtype=np.int32)

    # Per-item extents, read from the first column / first row of the mask sums.
    t_x_max = mask_np.sum(1)[:, 0].astype(np.int32)
    t_y_max = mask_np.sum(2)[:, 0].astype(np.int32)
    maximum_path_c(paths, scores, t_x_max, t_y_max)
    return torch.from_numpy(paths).to(device=target_device, dtype=target_dtype)
@@ -0,0 +1,47 @@
1
+ import numpy as np
2
+
3
+ cimport cython
4
+ cimport numpy as np
5
+
6
+ from cython.parallel import prange
7
+
8
+
9
# Computes one item's path in place: `value` is overwritten with accumulated
# scores and `path` receives a 1 in exactly one row per column y < t_y.
# Bounds checking and negative-index wraparound are disabled, so every index
# below must already be valid.
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
    cdef int x
    cdef int y
    cdef float v_prev
    cdef float v_cur
    cdef float tmp  # NOTE(review): declared but never used in this function
    cdef int index = t_x - 1  # backtracking starts from the last row

    # Forward pass: accumulate the best score reaching each (x, y) cell.
    for y in range(t_y):
        # Only rows inside the band bounded by t_x + y - t_y and y are visited.
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            # v_cur: score when the previous column was in the same row x.
            if x == y:
                v_cur = max_neg_val
            else:
                v_cur = value[x, y-1]
            # v_prev: score when the previous column was in row x - 1.
            if x == 0:
                if y == 0:
                    v_prev = 0.
                else:
                    v_prev = max_neg_val
            else:
                v_prev = value[x-1, y-1]
            value[x, y] = max(v_cur, v_prev) + value[x, y]

    # Backward pass: walk the columns from last to first, marking one row each.
    for y in range(t_y - 1, -1, -1):
        path[index, y] = 1
        # Move up a row when on the diagonal or when row index - 1 scored higher
        # in the previous column; the `index != 0` test short-circuits first, so
        # `y-1` is not evaluated once row 0 has been reached.
        if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
            index = index - 1
38
+
39
+
40
# Batch driver: runs maximum_path_each on every item along the first axis of
# `values`, passing that item's slice of `paths`/`values` and its t_x/t_y
# extents. `max_neg_val` defaults to -1e9 and is forwarded unchanged.
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
    cdef int b = values.shape[0]  # number of batch items

    cdef int i
    # prange comes from cython.parallel; presumably this only runs in parallel
    # when the extension is built with OpenMP enabled — TODO confirm build flags.
    for i in prange(b, nogil=True):
        maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
@@ -0,0 +1,7 @@
1
+ # from distutils.core import setup
2
+ # from Cython.Build import cythonize
3
+ # import numpy
4
+
5
+ # setup(name='monotonic_align',
6
+ # ext_modules=cythonize("core.pyx"),
7
+ # include_dirs=[numpy.get_include()])
@@ -0,0 +1,21 @@
1
+ import logging
2
+
3
+ from lightning.pytorch.utilities import rank_zero_only
4
+
5
+
6
def get_pylogger(name: str = __name__) -> logging.Logger:
    """Initializes a multi-GPU-friendly python command line logger.

    Every level method of the logger is replaced by a ``rank_zero_only``-wrapped
    version so that, in a multi-GPU setup, messages are not repeated by each
    GPU process.

    :param name: The name of the logger, defaults to ``__name__``.

    :return: A logger object.
    """
    logger = logging.getLogger(name)

    for level_name in ("debug", "info", "warning", "error", "exception", "fatal", "critical"):
        level_method = getattr(logger, level_name)
        setattr(logger, level_name, rank_zero_only(level_method))

    return logger
@@ -0,0 +1,101 @@
1
+ from pathlib import Path
2
+ from typing import Sequence
3
+
4
+ import rich
5
+ import rich.syntax
6
+ import rich.tree
7
+ from hydra.core.hydra_config import HydraConfig
8
+ from lightning.pytorch.utilities import rank_zero_only
9
+ from omegaconf import DictConfig, OmegaConf, open_dict
10
+ from rich.prompt import Prompt
11
+
12
+ from matcha.utils import pylogger
13
+
14
+ log = pylogger.get_pylogger(__name__)
15
+
16
+
17
@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "data",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Prints the contents of a DictConfig as a tree structure using the Rich library.

    :param cfg: A DictConfig composed by Hydra.
    :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
        "callbacks", "logger", "trainer", "paths", "extras")``.
    :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
    :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
    """
    dim = "dim"
    root = rich.tree.Tree("CONFIG", style=dim, guide_style=dim)

    # Fields named in `print_order` come first; missing ones are reported and skipped.
    ordered_fields = []
    for name in print_order:
        if name in cfg:
            ordered_fields.append(name)
        else:
            log.warning(f"Field '{name}' not found in config. Skipping '{name}' config printing...")

    # Remaining config fields follow in the config's own order.
    ordered_fields += [name for name in cfg if name not in ordered_fields]

    # One branch per field: nested configs are rendered as YAML, anything else via str().
    for name in ordered_fields:
        node = root.add(name, style=dim, guide_style=dim)
        section = cfg[name]
        rendered = (
            OmegaConf.to_yaml(section, resolve=resolve)
            if isinstance(section, DictConfig)
            else str(section)
        )
        node.add(rich.syntax.Syntax(rendered, "yaml"))

    rich.print(root)

    if save_to_file:
        log_path = Path(cfg.paths.output_dir, "config_tree.log")
        with open(log_path, "w") as file:
            rich.print(root, file=file)
77
+
78
+
79
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    :param cfg: A DictConfig composed by Hydra.
    :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
    :raises ValueError: If no tags are configured and the Hydra job config contains an ``id`` key.
    """
    if not cfg.get("tags"):
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        raw_tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        # Strip before filtering so whitespace-only entries (e.g. "a, ,b") are
        # dropped instead of being stored as empty-string tags.
        tags = [tag for tag in (part.strip() for part in raw_tags.split(",")) if tag]

        # open_dict lets the new `tags` key be added to the config.
        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
            rich.print(cfg.tags, file=file)
@@ -0,0 +1,259 @@
1
+ import os
2
+ import sys
3
+ import warnings
4
+ from importlib.util import find_spec
5
+ from math import ceil
6
+ from pathlib import Path
7
+ from typing import Any, Callable, Dict, Tuple
8
+
9
+ import gdown
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import torch
13
+ import wget
14
+ from omegaconf import DictConfig
15
+
16
+ from matcha.utils import pylogger, rich_utils
17
+
18
+ log = pylogger.get_pylogger(__name__)
19
+
20
+
21
def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities (each enabled by a flag under ``cfg.extras``):
        - ``ignore_warnings``: ignore python warnings
        - ``enforce_tags``: set tags from command line
        - ``print_config``: Rich config printing

    :param cfg: A DictConfig object containing the config tree.
    """
    # Nothing to do without an `extras` section.
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    options = cfg.extras

    if options.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # Ask for tags on the command line when the config has none.
    if options.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    if options.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
50
+
51
+
52
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
        - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
    ```

    :param task_func: The task function to be wrapped.

    :return: The wrapped task function.
    """
    # Imported locally so the module's import block does not need to change.
    from functools import wraps

    @wraps(task_func)  # keep the task's __name__/__doc__ on the wrapper
    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise  # bare raise re-raises the active exception unchanged

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
105
+
106
+
107
def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
    """Safely retrieves value of the metric logged in LightningModule.

    :param metric_dict: A dict containing metric values.
    :param metric_name: The name of the metric to retrieve.
    :return: The value of the metric, or ``None`` when ``metric_name`` is empty/``None``.
    :raises ValueError: If ``metric_name`` is not a key of ``metric_dict``.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name in metric_dict:
        # The stored object is unwrapped to a plain value with `.item()`.
        value = metric_dict[metric_name].item()
        log.info(f"Retrieved metric value! <{metric_name}={value}>")
        return value

    raise ValueError(
        f"Metric value not found! <metric_name={metric_name}>\n"
        "Make sure metric name logged in LightningModule is correct!\n"
        "Make sure `optimized_metric` name in `hparams_search` config is correct!"
    )
129
+
130
+
131
def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after the elements of ``lst``.

    Used to add the blank symbol around every entry; the result has
    ``2 * len(lst) + 1`` elements.
    """
    padded = [item]
    for element in lst:
        padded.append(element)
        padded.append(item)
    return padded
136
+
137
+
138
def save_figure_to_numpy(fig):
    """Return the rendered canvas of ``fig`` as a ``(height, width, 3)`` uint8 RGB array.

    The canvas must already have been drawn by the caller.

    :param fig: Figure-like object exposing ``fig.canvas``.
    :return: A writable ``numpy.ndarray`` of dtype ``uint8``.
    """
    canvas = fig.canvas
    if hasattr(canvas, "tostring_rgb"):
        # ``np.frombuffer`` replaces binary-mode ``np.fromstring``, which is
        # deprecated and rejected by recent NumPy releases.
        width, height = canvas.get_width_height()
        rgb = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8).reshape((height, width, 3))
    else:
        # Canvases without ``tostring_rgb`` (removed in newer Matplotlib) expose
        # an RGBA buffer instead; drop the alpha channel.
        rgb = np.asarray(canvas.buffer_rgba())[..., :3]
    # ``frombuffer``/``asarray`` return views on foreign memory; copy so the
    # caller gets an independent, writable array as ``fromstring`` used to return.
    return rgb.copy()
142
+
143
+
144
def plot_tensor(tensor):
    """Render ``tensor`` as an image plot and return the drawn figure as an array.

    :param tensor: 2-D data accepted by ``Axes.imshow``.
    :return: Whatever ``save_figure_to_numpy`` returns for the drawn figure.
    """
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    # Render before reading the pixel buffer.
    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    # Close the figure created here rather than whichever figure is current.
    plt.close(fig)
    return data
154
+
155
+
156
def save_plot(tensor, savepath):
    """Render ``tensor`` as an image plot and write it to ``savepath``.

    :param tensor: 2-D data accepted by ``Axes.imshow``.
    :param savepath: Destination passed to ``plt.savefig``.
    """
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    fig.canvas.draw()
    plt.savefig(savepath)
    # Close the figure created here rather than whichever figure is current.
    plt.close(fig)
165
+
166
+
167
def to_numpy(tensor):
    """Convert ``tensor`` to a ``numpy.ndarray``.

    :param tensor: A ``numpy.ndarray`` (returned as is), a ``torch.Tensor``
        (detached and moved to CPU first) or a ``list``.
    :return: The data as a NumPy array.
    :raises TypeError: For any other input type.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    if isinstance(tensor, list):
        return np.array(tensor)
    raise TypeError("Unsupported type for conversion to numpy array")
176
+
177
+
178
def get_user_data_dir(appname="matcha_tts"):
    """Return (and create if needed) the per-user data directory for ``appname``.

    The base directory is, in order of precedence: the ``MATCHA_HOME``
    environment variable; on Windows the "Local AppData" shell folder read from
    the registry; on macOS ``~/Library/Application Support/``; otherwise
    ``~/.local/share``.

    Args:
        appname (str): Name of application

    Returns:
        Path: path to user data directory
    """

    MATCHA_HOME = os.environ.get("MATCHA_HOME")
    if MATCHA_HOME is not None:
        ans = Path(MATCHA_HOME).expanduser().resolve(strict=False)
    elif sys.platform == "win32":
        import winreg  # pylint: disable=import-outside-toplevel

        # The context manager closes the registry handle once the value is read.
        with winreg.OpenKey(
            winreg.HKEY_CURRENT_USER,
            r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
        ) as key:
            dir_, _ = winreg.QueryValueEx(key, "Local AppData")
        ans = Path(dir_).resolve(strict=False)
    elif sys.platform == "darwin":
        ans = Path("~/Library/Application Support/").expanduser()
    else:
        ans = Path.home().joinpath(".local/share")

    final_path = ans.joinpath(appname)
    final_path.mkdir(parents=True, exist_ok=True)
    return final_path
207
+
208
+
209
def assert_model_downloaded(checkpoint_path, url, use_wget=True):
    """Download the checkpoint from ``url`` unless ``checkpoint_path`` already exists.

    :param checkpoint_path: Where the checkpoint is expected / will be written.
    :param url: Download source.
    :param use_wget: Use ``wget.download`` when true, ``gdown.download`` otherwise.
    """
    if Path(checkpoint_path).exists():
        present_msg = f"[+] Model already present at {checkpoint_path}!"
        log.debug(present_msg)
        print(present_msg)
        return

    missing_msg = f"[-] Model not found at {checkpoint_path}! Will download it"
    log.info(missing_msg)
    print(missing_msg)

    # Both download calls are given the target as a plain string.
    target = str(checkpoint_path)
    if use_wget:
        wget.download(url=url, out=target)
    else:
        gdown.download(url=url, output=target, quiet=False, fuzzy=True)
221
+
222
+
223
def get_phoneme_durations(durations, phones):
    """Merge interleaved durations into one start/end/duration record per phone.

    ``durations`` is read as ``[d0, p1, d2, p3, d4, ...]``: every odd index is
    paired with a phone and the even-indexed entries around it are shared out
    between neighbours (the following entry is split in half, rounded up, except
    for the last phone, which takes it whole).

    :param durations: Sequence of ``2 * len(phones) + 1`` durations.
    :param phones: Phone labels, one per odd index of ``durations``.
    :return: List of ``{phone: {"starttime", "endtime", "duration"}}`` dicts with
        cumulative times.
    """
    carry = durations[0]
    last_odd_index = len(durations) - 2
    merged_durations = []
    # Walk the odd (phone) positions.
    for idx in range(1, len(durations), 2):
        following = durations[idx + 1]
        # The last phone takes the whole trailing entry; others take the upper half.
        taken = following if idx == last_odd_index else ceil(following / 2)
        merged_durations.append(carry + durations[idx] + taken)
        carry = following - taken

    assert len(phones) == len(merged_durations)
    assert len(merged_durations) == (len(durations) - 1) // 2

    end_times = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long).tolist()
    duration_json = []
    start = 0
    for position, end in enumerate(end_times):
        duration_json.append(
            {
                phones[position]: {
                    "starttime": start,
                    "endtime": end,
                    "duration": end - start,
                }
            }
        )
        start = end

    # The final end time must account for every input duration.
    assert list(duration_json[-1].values())[0]["endtime"] == sum(
        durations
    ), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
    return duration_json
@@ -1,14 +1,14 @@
1
1
  {
2
2
  "files": {
3
3
  "main.css": "./static/css/main.4bafd904.css",
4
- "main.js": "./static/js/main.ffc26121.js",
4
+ "main.js": "./static/js/main.661c7b0a.js",
5
5
  "static/media/icon.webp": "./static/media/icon.4603d52c63041e5dfbfd.webp",
6
6
  "index.html": "./index.html",
7
7
  "main.4bafd904.css.map": "./static/css/main.4bafd904.css.map",
8
- "main.ffc26121.js.map": "./static/js/main.ffc26121.js.map"
8
+ "main.661c7b0a.js.map": "./static/js/main.661c7b0a.js.map"
9
9
  },
10
10
  "entrypoints": [
11
11
  "static/css/main.4bafd904.css",
12
- "static/js/main.ffc26121.js"
12
+ "static/js/main.661c7b0a.js"
13
13
  ]
14
14
  }
@@ -1 +1 @@
1
- <!doctype html><html lang="en"><head><meta charset="utf-8"/><link rel="icon" href="./favicon.svg"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="Web site created using create-react-app"/><link rel="apple-touch-icon" href="./logo192.png"/><link rel="manifest" href="./manifest.json"/><title>Xinference</title><script defer="defer" src="./static/js/main.ffc26121.js"></script><link href="./static/css/main.4bafd904.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
1
+ <!doctype html><html lang="en"><head><meta charset="utf-8"/><link rel="icon" href="./favicon.svg"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="Web site created using create-react-app"/><link rel="apple-touch-icon" href="./logo192.png"/><link rel="manifest" href="./manifest.json"/><title>Xinference</title><script defer="defer" src="./static/js/main.661c7b0a.js"></script><link href="./static/css/main.4bafd904.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>