xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (191)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +60 -44
  6. xinference/model/audio/chattts.py +25 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/cosyvoice.py +4 -3
  9. xinference/model/audio/custom.py +4 -5
  10. xinference/model/audio/fish_speech.py +228 -0
  11. xinference/model/audio/model_spec.json +8 -0
  12. xinference/model/embedding/core.py +25 -1
  13. xinference/model/embedding/custom.py +4 -5
  14. xinference/model/flexible/core.py +5 -1
  15. xinference/model/image/custom.py +4 -5
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +66 -3
  19. xinference/model/llm/__init__.py +6 -0
  20. xinference/model/llm/llm_family.json +54 -9
  21. xinference/model/llm/llm_family.py +7 -6
  22. xinference/model/llm/llm_family_modelscope.json +56 -10
  23. xinference/model/llm/lmdeploy/__init__.py +0 -0
  24. xinference/model/llm/lmdeploy/core.py +557 -0
  25. xinference/model/llm/sglang/core.py +7 -1
  26. xinference/model/llm/transformers/cogvlm2.py +4 -45
  27. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  28. xinference/model/llm/transformers/core.py +3 -0
  29. xinference/model/llm/transformers/glm4v.py +2 -23
  30. xinference/model/llm/transformers/intern_vl.py +94 -11
  31. xinference/model/llm/transformers/minicpmv25.py +2 -23
  32. xinference/model/llm/transformers/minicpmv26.py +2 -22
  33. xinference/model/llm/transformers/yi_vl.py +2 -24
  34. xinference/model/llm/utils.py +13 -1
  35. xinference/model/llm/vllm/core.py +1 -34
  36. xinference/model/rerank/custom.py +4 -5
  37. xinference/model/utils.py +41 -1
  38. xinference/model/video/core.py +3 -1
  39. xinference/model/video/diffusers.py +41 -38
  40. xinference/model/video/model_spec.json +24 -1
  41. xinference/model/video/model_spec_modelscope.json +25 -1
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/matcha/__init__.py +0 -0
  137. xinference/thirdparty/matcha/app.py +357 -0
  138. xinference/thirdparty/matcha/cli.py +419 -0
  139. xinference/thirdparty/matcha/data/__init__.py +0 -0
  140. xinference/thirdparty/matcha/data/components/__init__.py +0 -0
  141. xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
  142. xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
  143. xinference/thirdparty/matcha/hifigan/config.py +28 -0
  144. xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
  145. xinference/thirdparty/matcha/hifigan/env.py +17 -0
  146. xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
  147. xinference/thirdparty/matcha/hifigan/models.py +368 -0
  148. xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
  149. xinference/thirdparty/matcha/models/__init__.py +0 -0
  150. xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
  151. xinference/thirdparty/matcha/models/components/__init__.py +0 -0
  152. xinference/thirdparty/matcha/models/components/decoder.py +443 -0
  153. xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
  154. xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
  155. xinference/thirdparty/matcha/models/components/transformer.py +316 -0
  156. xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
  157. xinference/thirdparty/matcha/onnx/__init__.py +0 -0
  158. xinference/thirdparty/matcha/onnx/export.py +181 -0
  159. xinference/thirdparty/matcha/onnx/infer.py +168 -0
  160. xinference/thirdparty/matcha/text/__init__.py +53 -0
  161. xinference/thirdparty/matcha/text/cleaners.py +121 -0
  162. xinference/thirdparty/matcha/text/numbers.py +71 -0
  163. xinference/thirdparty/matcha/text/symbols.py +17 -0
  164. xinference/thirdparty/matcha/train.py +122 -0
  165. xinference/thirdparty/matcha/utils/__init__.py +5 -0
  166. xinference/thirdparty/matcha/utils/audio.py +82 -0
  167. xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
  168. xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
  169. xinference/thirdparty/matcha/utils/instantiators.py +56 -0
  170. xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
  171. xinference/thirdparty/matcha/utils/model.py +90 -0
  172. xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
  173. xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
  174. xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
  175. xinference/thirdparty/matcha/utils/pylogger.py +21 -0
  176. xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
  177. xinference/thirdparty/matcha/utils/utils.py +259 -0
  178. xinference/web/ui/build/asset-manifest.json +3 -3
  179. xinference/web/ui/build/index.html +1 -1
  180. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  181. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  182. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  183. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
  184. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
  185. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  186. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  187. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  188. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
  189. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
  190. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
  191. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
xinference/thirdparty/matcha/train.py
@@ -0,0 +1,122 @@
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import hydra
+ import lightning as L
+ import rootutils
+ from lightning import Callback, LightningDataModule, LightningModule, Trainer
+ from lightning.pytorch.loggers import Logger
+ from omegaconf import DictConfig
+
+ from matcha import utils
+
+ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+ # ------------------------------------------------------------------------------------ #
+ # the setup_root above is equivalent to:
+ # - adding project root dir to PYTHONPATH
+ #       (so you don't need to force user to install project as a package)
+ #       (necessary before importing any local modules e.g. `from src import utils`)
+ # - setting up PROJECT_ROOT environment variable
+ #       (which is used as a base for paths in "configs/paths/default.yaml")
+ #       (this way all filepaths are the same no matter where you run the code)
+ # - loading environment variables from ".env" in root dir
+ #
+ # you can remove it if you:
+ # 1. either install project as a package or move entry files to project root dir
+ # 2. set `root_dir` to "." in "configs/paths/default.yaml"
+ #
+ # more info: https://github.com/ashleve/rootutils
+ # ------------------------------------------------------------------------------------ #
+
+
+ log = utils.get_pylogger(__name__)
+
+
+ @utils.task_wrapper
+ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
+     training.
+
+     This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+     failure. Useful for multiruns, saving info about the crash, etc.
+
+     :param cfg: A DictConfig configuration composed by Hydra.
+     :return: A tuple with metrics and dict with all instantiated objects.
+     """
+     # set seed for random number generators in pytorch, numpy and python.random
+     if cfg.get("seed"):
+         L.seed_everything(cfg.seed, workers=True)
+
+     log.info(f"Instantiating datamodule <{cfg.data._target_}>")  # pylint: disable=protected-access
+     datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+     log.info(f"Instantiating model <{cfg.model._target_}>")  # pylint: disable=protected-access
+     model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+     log.info("Instantiating callbacks...")
+     callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
+
+     log.info("Instantiating loggers...")
+     logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+
+     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")  # pylint: disable=protected-access
+     trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
+
+     object_dict = {
+         "cfg": cfg,
+         "datamodule": datamodule,
+         "model": model,
+         "callbacks": callbacks,
+         "logger": logger,
+         "trainer": trainer,
+     }
+
+     if logger:
+         log.info("Logging hyperparameters!")
+         utils.log_hyperparameters(object_dict)
+
+     if cfg.get("train"):
+         log.info("Starting training!")
+         trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
+
+     train_metrics = trainer.callback_metrics
+
+     if cfg.get("test"):
+         log.info("Starting testing!")
+         ckpt_path = trainer.checkpoint_callback.best_model_path
+         if ckpt_path == "":
+             log.warning("Best ckpt not found! Using current weights for testing...")
+             ckpt_path = None
+         trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+         log.info(f"Best ckpt path: {ckpt_path}")
+
+     test_metrics = trainer.callback_metrics
+
+     # merge train and test metrics
+     metric_dict = {**train_metrics, **test_metrics}
+
+     return metric_dict, object_dict
+
+
+ @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
+ def main(cfg: DictConfig) -> Optional[float]:
+     """Main entry point for training.
+
+     :param cfg: DictConfig configuration composed by Hydra.
+     :return: Optional[float] with optimized metric value.
+     """
+     # apply extra utilities
+     # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+     utils.extras(cfg)
+
+     # train the model
+     metric_dict, _ = train(cfg)
+
+     # safely retrieve metric value for hydra-based hyperparameter optimization
+     metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric"))
+
+     # return optimized metric
+     return metric_value
+
+
+ if __name__ == "__main__":
+     main()  # pylint: disable=no-value-for-parameter
xinference/thirdparty/matcha/utils/__init__.py
@@ -0,0 +1,5 @@
+ from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
+ from matcha.utils.logging_utils import log_hyperparameters
+ from matcha.utils.pylogger import get_pylogger
+ from matcha.utils.rich_utils import enforce_tags, print_config_tree
+ from matcha.utils.utils import extras, get_metric_value, task_wrapper
xinference/thirdparty/matcha/utils/audio.py
@@ -0,0 +1,82 @@
+ import numpy as np
+ import torch
+ import torch.utils.data
+ from librosa.filters import mel as librosa_mel_fn
+ from scipy.io.wavfile import read
+
+ MAX_WAV_VALUE = 32768.0
+
+
+ def load_wav(full_path):
+     sampling_rate, data = read(full_path)
+     return data, sampling_rate
+
+
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+     return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+ def dynamic_range_decompression(x, C=1):
+     return np.exp(x) / C
+
+
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression_torch(x, C=1):
+     return torch.exp(x) / C
+
+
+ def spectral_normalize_torch(magnitudes):
+     output = dynamic_range_compression_torch(magnitudes)
+     return output
+
+
+ def spectral_de_normalize_torch(magnitudes):
+     output = dynamic_range_decompression_torch(magnitudes)
+     return output
+
+
+ mel_basis = {}
+ hann_window = {}
+
+
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+     if torch.min(y) < -1.0:
+         print("min value is ", torch.min(y))
+     if torch.max(y) > 1.0:
+         print("max value is ", torch.max(y))
+
+     global mel_basis, hann_window  # pylint: disable=global-statement
+     if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
+         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+         mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+     y = torch.nn.functional.pad(
+         y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+     )
+     y = y.squeeze(1)
+
+     spec = torch.view_as_real(
+         torch.stft(
+             y,
+             n_fft,
+             hop_length=hop_size,
+             win_length=win_size,
+             window=hann_window[str(y.device)],
+             center=center,
+             pad_mode="reflect",
+             normalized=False,
+             onesided=True,
+             return_complex=True,
+         )
+     )
+
+     spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+     spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
+     spec = spectral_normalize_torch(spec)
+
+     return spec
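For context, a minimal sketch of how this vendored helper might be called. The parameter values (22.05 kHz audio, 80 mel bands, HiFi-GAN-style window and hop sizes) are illustrative assumptions, not taken from this diff, and the import path assumes the vendored matcha package is importable, as the module's own imports do.

import torch
from matcha.utils.audio import mel_spectrogram  # vendored under xinference.thirdparty.matcha (assumed on path)

# one second of silence at 22.05 kHz, batch-first (illustrative values only)
wav = torch.zeros(1, 22050)
mel = mel_spectrogram(
    wav,
    n_fft=1024,
    num_mels=80,
    sampling_rate=22050,
    hop_size=256,
    win_size=1024,
    fmin=0,
    fmax=8000,
)
print(mel.shape)  # -> torch.Size([1, 80, n_frames])
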
xinference/thirdparty/matcha/utils/generate_data_statistics.py
@@ -0,0 +1,112 @@
+ r"""
+ The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
+ when needed.
+
+ Parameters from hparam.py will be used
+ """
+ import argparse
+ import json
+ import os
+ import sys
+ from pathlib import Path
+
+ import rootutils
+ import torch
+ from hydra import compose, initialize
+ from omegaconf import open_dict
+ from tqdm.auto import tqdm
+
+ from matcha.data.text_mel_datamodule import TextMelDataModule
+ from matcha.utils.logging_utils import pylogger
+
+ log = pylogger.get_pylogger(__name__)
+
+
+ def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
+     """Generate data mean and standard deviation helpful in data normalisation
+
+     Args:
+         data_loader (torch.utils.data.Dataloader): _description_
+         out_channels (int): mel spectrogram channels
+     """
+     total_mel_sum = 0
+     total_mel_sq_sum = 0
+     total_mel_len = 0
+
+     for batch in tqdm(data_loader, leave=False):
+         mels = batch["y"]
+         mel_lengths = batch["y_lengths"]
+
+         total_mel_len += torch.sum(mel_lengths)
+         total_mel_sum += torch.sum(mels)
+         total_mel_sq_sum += torch.sum(torch.pow(mels, 2))
+
+     data_mean = total_mel_sum / (total_mel_len * out_channels)
+     data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2))
+
+     return {"mel_mean": data_mean.item(), "mel_std": data_std.item()}
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "-i",
+         "--input-config",
+         type=str,
+         default="vctk.yaml",
+         help="The name of the yaml config file under configs/data",
+     )
+
+     parser.add_argument(
+         "-b",
+         "--batch-size",
+         type=int,
+         default="256",
+         help="Can have increased batch size for faster computation",
+     )
+
+     parser.add_argument(
+         "-f",
+         "--force",
+         action="store_true",
+         default=False,
+         required=False,
+         help="force overwrite the file",
+     )
+     args = parser.parse_args()
+     output_file = Path(args.input_config).with_suffix(".json")
+
+     if os.path.exists(output_file) and not args.force:
+         print("File already exists. Use -f to force overwrite")
+         sys.exit(1)
+
+     with initialize(version_base="1.3", config_path="../../configs/data"):
+         cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
+
+     root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
+
+     with open_dict(cfg):
+         del cfg["hydra"]
+         del cfg["_target_"]
+         cfg["data_statistics"] = None
+         cfg["seed"] = 1234
+         cfg["batch_size"] = args.batch_size
+         cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
+         cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+         cfg["load_durations"] = False
+
+     text_mel_datamodule = TextMelDataModule(**cfg)
+     text_mel_datamodule.setup()
+     data_loader = text_mel_datamodule.train_dataloader()
+     log.info("Dataloader loaded! Now computing stats...")
+     params = compute_data_statistics(data_loader, cfg["n_feats"])
+     print(params)
+     json.dump(
+         params,
+         open(output_file, "w"),
+     )
+
+
+ if __name__ == "__main__":
+     main()
xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py
@@ -0,0 +1,195 @@
+ r"""
+ The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
+ when needed.
+
+ Parameters from hparam.py will be used
+ """
+ import argparse
+ import json
+ import os
+ import sys
+ from pathlib import Path
+
+ import lightning
+ import numpy as np
+ import rootutils
+ import torch
+ from hydra import compose, initialize
+ from omegaconf import open_dict
+ from torch import nn
+ from tqdm.auto import tqdm
+
+ from matcha.cli import get_device
+ from matcha.data.text_mel_datamodule import TextMelDataModule
+ from matcha.models.matcha_tts import MatchaTTS
+ from matcha.utils.logging_utils import pylogger
+ from matcha.utils.utils import get_phoneme_durations
+
+ log = pylogger.get_pylogger(__name__)
+
+
+ def save_durations_to_folder(
+     attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
+ ):
+     durations = attn.squeeze().sum(1)[:x_length].numpy()
+     durations_json = get_phoneme_durations(durations, text)
+     output = output_folder / Path(filepath).name.replace(".wav", ".npy")
+     with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
+         json.dump(durations_json, f, indent=4, ensure_ascii=False)
+
+     np.save(output, durations)
+
+
+ @torch.inference_mode()
+ def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
+     """Generate durations from the model for each datapoint and save it in a folder
+
+     Args:
+         data_loader (torch.utils.data.DataLoader): Dataloader
+         model (nn.Module): MatchaTTS model
+         device (torch.device): GPU or CPU
+     """
+
+     for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"):
+         x, x_lengths = batch["x"], batch["x_lengths"]
+         y, y_lengths = batch["y"], batch["y_lengths"]
+         spks = batch["spks"]
+         x = x.to(device)
+         y = y.to(device)
+         x_lengths = x_lengths.to(device)
+         y_lengths = y_lengths.to(device)
+         spks = spks.to(device) if spks is not None else None
+
+         _, _, _, attn = model(
+             x=x,
+             x_lengths=x_lengths,
+             y=y,
+             y_lengths=y_lengths,
+             spks=spks,
+         )
+         attn = attn.cpu()
+         for i in range(attn.shape[0]):
+             save_durations_to_folder(
+                 attn[i],
+                 x_lengths[i].item(),
+                 y_lengths[i].item(),
+                 batch["filepaths"][i],
+                 output_folder,
+                 batch["x_texts"][i],
+             )
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "-i",
+         "--input-config",
+         type=str,
+         default="ljspeech.yaml",
+         help="The name of the yaml config file under configs/data",
+     )
+
+     parser.add_argument(
+         "-b",
+         "--batch-size",
+         type=int,
+         default="32",
+         help="Can have increased batch size for faster computation",
+     )
+
+     parser.add_argument(
+         "-f",
+         "--force",
+         action="store_true",
+         default=False,
+         required=False,
+         help="force overwrite the file",
+     )
+     parser.add_argument(
+         "-c",
+         "--checkpoint_path",
+         type=str,
+         required=True,
+         help="Path to the checkpoint file to load the model from",
+     )
+
+     parser.add_argument(
+         "-o",
+         "--output-folder",
+         type=str,
+         default=None,
+         help="Output folder to save the data statistics",
+     )
+
+     parser.add_argument(
+         "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)"
+     )
+
+     args = parser.parse_args()
+
+     with initialize(version_base="1.3", config_path="../../configs/data"):
+         cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
+
+     root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
+
+     with open_dict(cfg):
+         del cfg["hydra"]
+         del cfg["_target_"]
+         cfg["seed"] = 1234
+         cfg["batch_size"] = args.batch_size
+         cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
+         cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+         cfg["load_durations"] = False
+
+     if args.output_folder is not None:
+         output_folder = Path(args.output_folder)
+     else:
+         output_folder = Path(cfg["train_filelist_path"]).parent / "durations"
+
+     print(f"Output folder set to: {output_folder}")
+
+     if os.path.exists(output_folder) and not args.force:
+         print("Folder already exists. Use -f to force overwrite")
+         sys.exit(1)
+
+     output_folder.mkdir(parents=True, exist_ok=True)
+
+     print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}")
+     print("Loading model...")
+     device = get_device(args)
+     model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)
+
+     text_mel_datamodule = TextMelDataModule(**cfg)
+     text_mel_datamodule.setup()
+     try:
+         print("Computing stats for training set if exists...")
+         train_dataloader = text_mel_datamodule.train_dataloader()
+         compute_durations(train_dataloader, model, device, output_folder)
+     except lightning.fabric.utilities.exceptions.MisconfigurationException:
+         print("No training set found")
+
+     try:
+         print("Computing stats for validation set if exists...")
+         val_dataloader = text_mel_datamodule.val_dataloader()
+         compute_durations(val_dataloader, model, device, output_folder)
+     except lightning.fabric.utilities.exceptions.MisconfigurationException:
+         print("No validation set found")
+
+     try:
+         print("Computing stats for test set if exists...")
+         test_dataloader = text_mel_datamodule.test_dataloader()
+         compute_durations(test_dataloader, model, device, output_folder)
+     except lightning.fabric.utilities.exceptions.MisconfigurationException:
+         print("No test set found")
+
+     print(f"[+] Done! Data statistics saved to: {output_folder}")
+
+
+ if __name__ == "__main__":
+     # Helps with generating durations for the dataset to train other architectures
+     # that cannot learn to align due to limited size of dataset
+     # Example usage:
+     # python python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model
+     # This will create a folder in data/processed_data/durations/ljspeech with the durations
+     main()
xinference/thirdparty/matcha/utils/instantiators.py
@@ -0,0 +1,56 @@
+ from typing import List
+
+ import hydra
+ from lightning import Callback
+ from lightning.pytorch.loggers import Logger
+ from omegaconf import DictConfig
+
+ from matcha.utils import pylogger
+
+ log = pylogger.get_pylogger(__name__)
+
+
+ def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+     """Instantiates callbacks from config.
+
+     :param callbacks_cfg: A DictConfig object containing callback configurations.
+     :return: A list of instantiated callbacks.
+     """
+     callbacks: List[Callback] = []
+
+     if not callbacks_cfg:
+         log.warning("No callback configs found! Skipping..")
+         return callbacks
+
+     if not isinstance(callbacks_cfg, DictConfig):
+         raise TypeError("Callbacks config must be a DictConfig!")
+
+     for _, cb_conf in callbacks_cfg.items():
+         if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+             log.info(f"Instantiating callback <{cb_conf._target_}>")  # pylint: disable=protected-access
+             callbacks.append(hydra.utils.instantiate(cb_conf))
+
+     return callbacks
+
+
+ def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+     """Instantiates loggers from config.
+
+     :param logger_cfg: A DictConfig object containing logger configurations.
+     :return: A list of instantiated loggers.
+     """
+     logger: List[Logger] = []
+
+     if not logger_cfg:
+         log.warning("No logger configs found! Skipping...")
+         return logger
+
+     if not isinstance(logger_cfg, DictConfig):
+         raise TypeError("Logger config must be a DictConfig!")
+
+     for _, lg_conf in logger_cfg.items():
+         if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+             log.info(f"Instantiating logger <{lg_conf._target_}>")  # pylint: disable=protected-access
+             logger.append(hydra.utils.instantiate(lg_conf))
+
+     return logger
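As a rough illustration of the `_target_`-style config these helpers consume, a sketch is shown below; the callback and its settings are made-up examples, not taken from this diff. Each entry carrying a `_target_` key is handed to hydra.utils.instantiate.

from omegaconf import OmegaConf

from matcha.utils.instantiators import instantiate_callbacks  # vendored path assumed importable

# hypothetical callbacks section of a Hydra config
callbacks_cfg = OmegaConf.create(
    {
        "model_checkpoint": {
            "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
            "monitor": "val_loss",
            "save_top_k": 1,
        }
    }
)
callbacks = instantiate_callbacks(callbacks_cfg)  # -> [ModelCheckpoint(...)]
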
xinference/thirdparty/matcha/utils/logging_utils.py
@@ -0,0 +1,53 @@
+ from typing import Any, Dict
+
+ from lightning.pytorch.utilities import rank_zero_only
+ from omegaconf import OmegaConf
+
+ from matcha.utils import pylogger
+
+ log = pylogger.get_pylogger(__name__)
+
+
+ @rank_zero_only
+ def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
+     """Controls which config parts are saved by Lightning loggers.
+
+     Additionally saves:
+         - Number of model parameters
+
+     :param object_dict: A dictionary containing the following objects:
+         - `"cfg"`: A DictConfig object containing the main config.
+         - `"model"`: The Lightning model.
+         - `"trainer"`: The Lightning trainer.
+     """
+     hparams = {}
+
+     cfg = OmegaConf.to_container(object_dict["cfg"])
+     model = object_dict["model"]
+     trainer = object_dict["trainer"]
+
+     if not trainer.logger:
+         log.warning("Logger not found! Skipping hyperparameter logging...")
+         return
+
+     hparams["model"] = cfg["model"]
+
+     # save number of model parameters
+     hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+     hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
+
+     hparams["data"] = cfg["data"]
+     hparams["trainer"] = cfg["trainer"]
+
+     hparams["callbacks"] = cfg.get("callbacks")
+     hparams["extras"] = cfg.get("extras")
+
+     hparams["task_name"] = cfg.get("task_name")
+     hparams["tags"] = cfg.get("tags")
+     hparams["ckpt_path"] = cfg.get("ckpt_path")
+     hparams["seed"] = cfg.get("seed")
+
+     # send hparams to all loggers
+     for logger in trainer.loggers:
+         logger.log_hyperparams(hparams)