xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (191)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +60 -44
  6. xinference/model/audio/chattts.py +25 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/cosyvoice.py +4 -3
  9. xinference/model/audio/custom.py +4 -5
  10. xinference/model/audio/fish_speech.py +228 -0
  11. xinference/model/audio/model_spec.json +8 -0
  12. xinference/model/embedding/core.py +25 -1
  13. xinference/model/embedding/custom.py +4 -5
  14. xinference/model/flexible/core.py +5 -1
  15. xinference/model/image/custom.py +4 -5
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +66 -3
  19. xinference/model/llm/__init__.py +6 -0
  20. xinference/model/llm/llm_family.json +54 -9
  21. xinference/model/llm/llm_family.py +7 -6
  22. xinference/model/llm/llm_family_modelscope.json +56 -10
  23. xinference/model/llm/lmdeploy/__init__.py +0 -0
  24. xinference/model/llm/lmdeploy/core.py +557 -0
  25. xinference/model/llm/sglang/core.py +7 -1
  26. xinference/model/llm/transformers/cogvlm2.py +4 -45
  27. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  28. xinference/model/llm/transformers/core.py +3 -0
  29. xinference/model/llm/transformers/glm4v.py +2 -23
  30. xinference/model/llm/transformers/intern_vl.py +94 -11
  31. xinference/model/llm/transformers/minicpmv25.py +2 -23
  32. xinference/model/llm/transformers/minicpmv26.py +2 -22
  33. xinference/model/llm/transformers/yi_vl.py +2 -24
  34. xinference/model/llm/utils.py +13 -1
  35. xinference/model/llm/vllm/core.py +1 -34
  36. xinference/model/rerank/custom.py +4 -5
  37. xinference/model/utils.py +41 -1
  38. xinference/model/video/core.py +3 -1
  39. xinference/model/video/diffusers.py +41 -38
  40. xinference/model/video/model_spec.json +24 -1
  41. xinference/model/video/model_spec_modelscope.json +25 -1
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/matcha/__init__.py +0 -0
  137. xinference/thirdparty/matcha/app.py +357 -0
  138. xinference/thirdparty/matcha/cli.py +419 -0
  139. xinference/thirdparty/matcha/data/__init__.py +0 -0
  140. xinference/thirdparty/matcha/data/components/__init__.py +0 -0
  141. xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
  142. xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
  143. xinference/thirdparty/matcha/hifigan/config.py +28 -0
  144. xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
  145. xinference/thirdparty/matcha/hifigan/env.py +17 -0
  146. xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
  147. xinference/thirdparty/matcha/hifigan/models.py +368 -0
  148. xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
  149. xinference/thirdparty/matcha/models/__init__.py +0 -0
  150. xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
  151. xinference/thirdparty/matcha/models/components/__init__.py +0 -0
  152. xinference/thirdparty/matcha/models/components/decoder.py +443 -0
  153. xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
  154. xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
  155. xinference/thirdparty/matcha/models/components/transformer.py +316 -0
  156. xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
  157. xinference/thirdparty/matcha/onnx/__init__.py +0 -0
  158. xinference/thirdparty/matcha/onnx/export.py +181 -0
  159. xinference/thirdparty/matcha/onnx/infer.py +168 -0
  160. xinference/thirdparty/matcha/text/__init__.py +53 -0
  161. xinference/thirdparty/matcha/text/cleaners.py +121 -0
  162. xinference/thirdparty/matcha/text/numbers.py +71 -0
  163. xinference/thirdparty/matcha/text/symbols.py +17 -0
  164. xinference/thirdparty/matcha/train.py +122 -0
  165. xinference/thirdparty/matcha/utils/__init__.py +5 -0
  166. xinference/thirdparty/matcha/utils/audio.py +82 -0
  167. xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
  168. xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
  169. xinference/thirdparty/matcha/utils/instantiators.py +56 -0
  170. xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
  171. xinference/thirdparty/matcha/utils/model.py +90 -0
  172. xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
  173. xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
  174. xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
  175. xinference/thirdparty/matcha/utils/pylogger.py +21 -0
  176. xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
  177. xinference/thirdparty/matcha/utils/utils.py +259 -0
  178. xinference/web/ui/build/asset-manifest.json +3 -3
  179. xinference/web/ui/build/index.html +1 -1
  180. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  181. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  182. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  183. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
  184. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
  185. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  186. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  187. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  188. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
  189. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
  190. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
  191. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
xinference/thirdparty/matcha/hifigan/config.py
@@ -0,0 +1,28 @@
+v1 = {
+    "resblock": "1",
+    "num_gpus": 0,
+    "batch_size": 16,
+    "learning_rate": 0.0004,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.999,
+    "seed": 1234,
+    "upsample_rates": [8, 8, 2, 2],
+    "upsample_kernel_sizes": [16, 16, 4, 4],
+    "upsample_initial_channel": 512,
+    "resblock_kernel_sizes": [3, 7, 11],
+    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+    "resblock_initial_channel": 256,
+    "segment_size": 8192,
+    "num_mels": 80,
+    "num_freq": 1025,
+    "n_fft": 1024,
+    "hop_size": 256,
+    "win_size": 1024,
+    "sampling_rate": 22050,
+    "fmin": 0,
+    "fmax": 8000,
+    "fmax_loss": None,
+    "num_workers": 4,
+    "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1},
+}
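
The hunk above (config.py) holds the default "v1" HiFi-GAN vocoder hyperparameters that the vendored Matcha-TTS code reads. As a minimal sketch of how such a config is consumed, assuming the vendored module layout from the file list above and a hypothetical checkpoint path, it is wrapped in the AttrDict helper and handed to the Generator added later in this diff:

    import torch

    from xinference.thirdparty.matcha.hifigan.config import v1
    from xinference.thirdparty.matcha.hifigan.env import AttrDict
    from xinference.thirdparty.matcha.hifigan.models import Generator

    # Attribute-style access (h.resblock, h.upsample_rates, ...) over the plain dict above.
    h = AttrDict(v1)
    vocoder = Generator(h)

    # "generator_v1.ckpt" is a hypothetical path; upstream HiFi-GAN checkpoints keep the
    # generator weights under the "generator" key.
    state = torch.load("generator_v1.ckpt", map_location="cpu")
    vocoder.load_state_dict(state["generator"])
    vocoder.eval()
    vocoder.remove_weight_norm()
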
xinference/thirdparty/matcha/hifigan/denoiser.py
@@ -0,0 +1,64 @@
+# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py
+
+"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio."""
+import torch
+
+
+class Denoiser(torch.nn.Module):
+    """Removes model bias from audio produced with waveglow"""
+
+    def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"):
+        super().__init__()
+        self.filter_length = filter_length
+        self.hop_length = int(filter_length / n_overlap)
+        self.win_length = win_length
+
+        dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device
+        self.device = device
+        if mode == "zeros":
+            mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device)
+        elif mode == "normal":
+            mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
+        else:
+            raise Exception(f"Mode {mode} if not supported")
+
+        def stft_fn(audio, n_fft, hop_length, win_length, window):
+            spec = torch.stft(
+                audio,
+                n_fft=n_fft,
+                hop_length=hop_length,
+                win_length=win_length,
+                window=window,
+                return_complex=True,
+            )
+            spec = torch.view_as_real(spec)
+            return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0])
+
+        self.stft = lambda x: stft_fn(
+            audio=x,
+            n_fft=self.filter_length,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            window=torch.hann_window(self.win_length, device=device),
+        )
+        self.istft = lambda x, y: torch.istft(
+            torch.complex(x * torch.cos(y), x * torch.sin(y)),
+            n_fft=self.filter_length,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            window=torch.hann_window(self.win_length, device=device),
+        )
+
+        with torch.no_grad():
+            bias_audio = vocoder(mel_input).float().squeeze(0)
+            bias_spec, _ = self.stft(bias_audio)
+
+        self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])
+
+    @torch.inference_mode()
+    def forward(self, audio, strength=0.0005):
+        audio_spec, audio_angles = self.stft(audio)
+        audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength
+        audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
+        audio_denoised = self.istft(audio_spec_denoised, audio_angles)
+        return audio_denoised
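
At construction time the Denoiser above runs the vocoder on an all-zeros (or random) mel input, records the resulting bias spectrum, and at inference subtracts a scaled copy of that bias from generated audio. A hedged usage sketch, assuming vocoder is a loaded HiFi-GAN Generator as in the previous example and using a random mel purely for illustration:

    import torch

    from xinference.thirdparty.matcha.hifigan.denoiser import Denoiser

    denoiser = Denoiser(vocoder, mode="zeros")

    mel = torch.randn(1, 80, 200)            # (batch, n_mels, frames), placeholder values
    with torch.no_grad():
        audio = vocoder(mel).squeeze(1)      # Generator returns (batch, 1, samples)
    clean = denoiser(audio, strength=0.00025).squeeze(0)
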
xinference/thirdparty/matcha/hifigan/env.py
@@ -0,0 +1,17 @@
+""" from https://github.com/jik876/hifi-gan """
+
+import os
+import shutil
+
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
+def build_env(config, config_name, path):
+    t_path = os.path.join(path, config_name)
+    if config != t_path:
+        os.makedirs(path, exist_ok=True)
+        shutil.copyfile(config, os.path.join(path, config_name))
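
A brief sketch of these two helpers (paths are illustrative): AttrDict exposes a config dict through attribute access, and build_env copies the active config file into the output directory for reproducibility.

    from xinference.thirdparty.matcha.hifigan.env import AttrDict, build_env

    h = AttrDict({"sampling_rate": 22050, "hop_size": 256})
    assert h.sampling_rate == h["sampling_rate"] == 22050

    # Illustrative paths: keep a copy of the training config next to the checkpoints.
    build_env("config_v1.json", "config.json", "checkpoints/")
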
xinference/thirdparty/matcha/hifigan/meldataset.py
@@ -0,0 +1,217 @@
+""" from https://github.com/jik876/hifi-gan """
+
+import math
+import os
+import random
+
+import numpy as np
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+from librosa.util import normalize
+from scipy.io.wavfile import read
+
+MAX_WAV_VALUE = 32768.0
+
+
+def load_wav(full_path):
+    sampling_rate, data = read(full_path)
+    return data, sampling_rate
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+    return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+    output = dynamic_range_compression_torch(magnitudes)
+    return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+    output = dynamic_range_decompression_torch(magnitudes)
+    return output
+
+
+mel_basis = {}
+hann_window = {}
+
+
+def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+    if torch.min(y) < -1.0:
+        print("min value is ", torch.min(y))
+    if torch.max(y) > 1.0:
+        print("max value is ", torch.max(y))
+
+    global mel_basis, hann_window  # pylint: disable=global-statement
+    if fmax not in mel_basis:
+        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
+        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+    )
+    y = y.squeeze(1)
+
+    spec = torch.view_as_real(
+        torch.stft(
+            y,
+            n_fft,
+            hop_length=hop_size,
+            win_length=win_size,
+            window=hann_window[str(y.device)],
+            center=center,
+            pad_mode="reflect",
+            normalized=False,
+            onesided=True,
+            return_complex=True,
+        )
+    )
+
+    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+    spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
+    spec = spectral_normalize_torch(spec)
+
+    return spec
+
+
+def get_dataset_filelist(a):
+    with open(a.input_training_file, encoding="utf-8") as fi:
+        training_files = [
+            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
+        ]
+
+    with open(a.input_validation_file, encoding="utf-8") as fi:
+        validation_files = [
+            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
+        ]
+    return training_files, validation_files
+
+
+class MelDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        training_files,
+        segment_size,
+        n_fft,
+        num_mels,
+        hop_size,
+        win_size,
+        sampling_rate,
+        fmin,
+        fmax,
+        split=True,
+        shuffle=True,
+        n_cache_reuse=1,
+        device=None,
+        fmax_loss=None,
+        fine_tuning=False,
+        base_mels_path=None,
+    ):
+        self.audio_files = training_files
+        random.seed(1234)
+        if shuffle:
+            random.shuffle(self.audio_files)
+        self.segment_size = segment_size
+        self.sampling_rate = sampling_rate
+        self.split = split
+        self.n_fft = n_fft
+        self.num_mels = num_mels
+        self.hop_size = hop_size
+        self.win_size = win_size
+        self.fmin = fmin
+        self.fmax = fmax
+        self.fmax_loss = fmax_loss
+        self.cached_wav = None
+        self.n_cache_reuse = n_cache_reuse
+        self._cache_ref_count = 0
+        self.device = device
+        self.fine_tuning = fine_tuning
+        self.base_mels_path = base_mels_path
+
+    def __getitem__(self, index):
+        filename = self.audio_files[index]
+        if self._cache_ref_count == 0:
+            audio, sampling_rate = load_wav(filename)
+            audio = audio / MAX_WAV_VALUE
+            if not self.fine_tuning:
+                audio = normalize(audio) * 0.95
+            self.cached_wav = audio
+            if sampling_rate != self.sampling_rate:
+                raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR")
+            self._cache_ref_count = self.n_cache_reuse
+        else:
+            audio = self.cached_wav
+            self._cache_ref_count -= 1
+
+        audio = torch.FloatTensor(audio)
+        audio = audio.unsqueeze(0)
+
+        if not self.fine_tuning:
+            if self.split:
+                if audio.size(1) >= self.segment_size:
+                    max_audio_start = audio.size(1) - self.segment_size
+                    audio_start = random.randint(0, max_audio_start)
+                    audio = audio[:, audio_start : audio_start + self.segment_size]
+                else:
+                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
+
+            mel = mel_spectrogram(
+                audio,
+                self.n_fft,
+                self.num_mels,
+                self.sampling_rate,
+                self.hop_size,
+                self.win_size,
+                self.fmin,
+                self.fmax,
+                center=False,
+            )
+        else:
+            mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy"))
+            mel = torch.from_numpy(mel)
+
+            if len(mel.shape) < 3:
+                mel = mel.unsqueeze(0)
+
+            if self.split:
+                frames_per_seg = math.ceil(self.segment_size / self.hop_size)
+
+                if audio.size(1) >= self.segment_size:
+                    mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
+                    mel = mel[:, :, mel_start : mel_start + frames_per_seg]
+                    audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size]
+                else:
+                    mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
+                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
+
+        mel_loss = mel_spectrogram(
+            audio,
+            self.n_fft,
+            self.num_mels,
+            self.sampling_rate,
+            self.hop_size,
+            self.win_size,
+            self.fmin,
+            self.fmax_loss,
+            center=False,
+        )
+
+        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
+
+    def __len__(self):
+        return len(self.audio_files)
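
The core of the hunk above (meldataset.py) is mel_spectrogram, which pads the waveform, runs an STFT, applies a cached librosa mel filterbank, and log-compresses the result; note that the positional librosa_mel_fn call assumes an older librosa API. A hedged sketch with the v1 hyperparameters from config.py and a random waveform standing in for real audio:

    import torch

    from xinference.thirdparty.matcha.hifigan.config import v1
    from xinference.thirdparty.matcha.hifigan.meldataset import mel_spectrogram

    wav = torch.rand(1, v1["segment_size"]) * 2 - 1   # fake audio in [-1, 1], shape (1, 8192)
    mel = mel_spectrogram(
        wav,
        n_fft=v1["n_fft"],
        num_mels=v1["num_mels"],
        sampling_rate=v1["sampling_rate"],
        hop_size=v1["hop_size"],
        win_size=v1["win_size"],
        fmin=v1["fmin"],
        fmax=v1["fmax"],
        center=False,
    )
    print(mel.shape)  # (1, 80, 32): segment_size / hop_size frames
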
xinference/thirdparty/matcha/hifigan/models.py
@@ -0,0 +1,368 @@
+""" from https://github.com/jik876/hifi-gan """
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
+from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
+
+from .xutils import get_padding, init_weights
+
+LRELU_SLOPE = 0.1
+
+
+class ResBlock1(torch.nn.Module):
+    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super().__init__()
+        self.h = h
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
+        super().__init__()
+        self.h = h
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+            ]
+        )
+        self.convs.apply(init_weights)
+
+    def forward(self, x):
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+class Generator(torch.nn.Module):
+    def __init__(self, h):
+        super().__init__()
+        self.h = h
+        self.num_kernels = len(h.resblock_kernel_sizes)
+        self.num_upsamples = len(h.upsample_rates)
+        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
+        resblock = ResBlock1 if h.resblock == "1" else ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        h.upsample_initial_channel // (2**i),
+                        h.upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = h.upsample_initial_channel // (2 ** (i + 1))
+            for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+                self.resblocks.append(resblock(h, ch, k, d))
+
+        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+
+    def forward(self, x):
+        x = self.conv_pre(x)
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        print("Removing weight norm...")
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
+
+
+class DiscriminatorP(torch.nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super().__init__()
+        self.period = period
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+            ]
+        )
+        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorP(2),
+                DiscriminatorP(3),
+                DiscriminatorP(5),
+                DiscriminatorP(7),
+                DiscriminatorP(11),
+            ]
+        )
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for _, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorS(torch.nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super().__init__()
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(Conv1d(1, 128, 15, 1, padding=7)),
+                norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+                norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+                norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+                norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+                norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        fmap = []
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiScaleDiscriminator(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorS(use_spectral_norm=True),
+                DiscriminatorS(),
+                DiscriminatorS(),
+            ]
+        )
+        self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            if i != 0:
+                y = self.meanpools[i - 1](y)
+                y_hat = self.meanpools[i - 1](y_hat)
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+def feature_loss(fmap_r, fmap_g):
+    loss = 0
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            loss += torch.mean(torch.abs(rl - gl))
+
+    return loss * 2
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+    loss = 0
+    r_losses = []
+    g_losses = []
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        r_loss = torch.mean((1 - dr) ** 2)
+        g_loss = torch.mean(dg**2)
+        loss += r_loss + g_loss
+        r_losses.append(r_loss.item())
+        g_losses.append(g_loss.item())
+
+    return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs):
+    loss = 0
+    gen_losses = []
+    for dg in disc_outputs:
+        l = torch.mean((1 - dg) ** 2)
+        gen_losses.append(l)
+        loss += l
+
+    return loss, gen_losses
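
Taken together, these classes implement the standard HiFi-GAN setup: the Generator upsamples an 80-bin mel to a waveform (by a factor of prod(upsample_rates) = 256 = hop_size under the v1 config), the multi-period and multi-scale discriminators score real versus generated audio, and the three loss helpers build the discriminator, adversarial, and feature-matching terms. A hedged end-to-end sketch with placeholder tensors:

    import torch

    from xinference.thirdparty.matcha.hifigan.config import v1
    from xinference.thirdparty.matcha.hifigan.env import AttrDict
    from xinference.thirdparty.matcha.hifigan.models import (
        Generator,
        MultiPeriodDiscriminator,
        discriminator_loss,
        feature_loss,
        generator_loss,
    )

    h = AttrDict(v1)
    generator = Generator(h)
    mpd = MultiPeriodDiscriminator()   # MultiScaleDiscriminator is driven the same way

    mel = torch.randn(1, 80, 32)       # placeholder mel batch
    y = torch.randn(1, 1, 32 * 256)    # placeholder "real" audio, frames * hop_size samples
    y_hat = generator(mel)             # (1, 1, 8192), matching y under the v1 upsample rates

    # Discriminator update: score real vs. detached generated audio.
    y_d_r, y_d_g, _, _ = mpd(y, y_hat.detach())
    loss_disc, _, _ = discriminator_loss(y_d_r, y_d_g)

    # Generator update: adversarial term plus feature matching over discriminator activations.
    y_d_r, y_d_g, fmap_r, fmap_g = mpd(y, y_hat)
    loss_gen, _ = generator_loss(y_d_g)
    loss_fm = feature_loss(fmap_r, fmap_g)
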