xinference 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this release of xinference has been flagged as potentially problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +4 -7
- xinference/client/handlers.py +3 -0
- xinference/core/chat_interface.py +6 -1
- xinference/core/model.py +2 -0
- xinference/core/scheduler.py +4 -7
- xinference/core/supervisor.py +114 -23
- xinference/core/worker.py +70 -4
- xinference/deploy/local.py +2 -1
- xinference/model/audio/core.py +11 -0
- xinference/model/audio/cosyvoice.py +16 -5
- xinference/model/audio/kokoro.py +139 -0
- xinference/model/audio/melotts.py +110 -0
- xinference/model/audio/model_spec.json +80 -0
- xinference/model/audio/model_spec_modelscope.json +18 -0
- xinference/model/audio/whisper.py +35 -10
- xinference/model/llm/llama_cpp/core.py +21 -14
- xinference/model/llm/llm_family.json +527 -1
- xinference/model/llm/llm_family.py +4 -1
- xinference/model/llm/llm_family_modelscope.json +495 -3
- xinference/model/llm/memory.py +1 -1
- xinference/model/llm/mlx/core.py +24 -6
- xinference/model/llm/transformers/core.py +9 -1
- xinference/model/llm/transformers/qwen2_audio.py +3 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -3
- xinference/model/llm/transformers/utils.py +22 -11
- xinference/model/llm/utils.py +115 -1
- xinference/model/llm/vllm/core.py +14 -4
- xinference/model/llm/vllm/xavier/block.py +3 -4
- xinference/model/llm/vllm/xavier/block_tracker.py +71 -58
- xinference/model/llm/vllm/xavier/collective.py +74 -0
- xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
- xinference/model/llm/vllm/xavier/executor.py +18 -16
- xinference/model/llm/vllm/xavier/scheduler.py +79 -63
- xinference/model/llm/vllm/xavier/test/test_xavier.py +60 -35
- xinference/model/llm/vllm/xavier/transfer.py +53 -32
- xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
- xinference/thirdparty/melo/__init__.py +0 -0
- xinference/thirdparty/melo/api.py +135 -0
- xinference/thirdparty/melo/app.py +61 -0
- xinference/thirdparty/melo/attentions.py +459 -0
- xinference/thirdparty/melo/commons.py +160 -0
- xinference/thirdparty/melo/configs/config.json +94 -0
- xinference/thirdparty/melo/data/example/metadata.list +20 -0
- xinference/thirdparty/melo/data_utils.py +413 -0
- xinference/thirdparty/melo/download_utils.py +67 -0
- xinference/thirdparty/melo/infer.py +25 -0
- xinference/thirdparty/melo/init_downloads.py +14 -0
- xinference/thirdparty/melo/losses.py +58 -0
- xinference/thirdparty/melo/main.py +36 -0
- xinference/thirdparty/melo/mel_processing.py +174 -0
- xinference/thirdparty/melo/models.py +1030 -0
- xinference/thirdparty/melo/modules.py +598 -0
- xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
- xinference/thirdparty/melo/monotonic_align/core.py +46 -0
- xinference/thirdparty/melo/preprocess_text.py +135 -0
- xinference/thirdparty/melo/split_utils.py +174 -0
- xinference/thirdparty/melo/text/__init__.py +35 -0
- xinference/thirdparty/melo/text/chinese.py +199 -0
- xinference/thirdparty/melo/text/chinese_bert.py +107 -0
- xinference/thirdparty/melo/text/chinese_mix.py +253 -0
- xinference/thirdparty/melo/text/cleaner.py +36 -0
- xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
- xinference/thirdparty/melo/text/cmudict.rep +129530 -0
- xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
- xinference/thirdparty/melo/text/english.py +284 -0
- xinference/thirdparty/melo/text/english_bert.py +39 -0
- xinference/thirdparty/melo/text/english_utils/__init__.py +0 -0
- xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
- xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
- xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
- xinference/thirdparty/melo/text/es_phonemizer/__init__.py +0 -0
- xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
- xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
- xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
- xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
- xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
- xinference/thirdparty/melo/text/fr_phonemizer/__init__.py +0 -0
- xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
- xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
- xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
- xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
- xinference/thirdparty/melo/text/french.py +94 -0
- xinference/thirdparty/melo/text/french_bert.py +39 -0
- xinference/thirdparty/melo/text/japanese.py +647 -0
- xinference/thirdparty/melo/text/japanese_bert.py +49 -0
- xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
- xinference/thirdparty/melo/text/korean.py +192 -0
- xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
- xinference/thirdparty/melo/text/spanish.py +122 -0
- xinference/thirdparty/melo/text/spanish_bert.py +39 -0
- xinference/thirdparty/melo/text/symbols.py +290 -0
- xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
- xinference/thirdparty/melo/train.py +635 -0
- xinference/thirdparty/melo/train.sh +19 -0
- xinference/thirdparty/melo/transforms.py +209 -0
- xinference/thirdparty/melo/utils.py +424 -0
- xinference/types.py +2 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.1eb206d1.js → main.b0936c54.js} +3 -3
- xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/METADATA +37 -27
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/RECORD +122 -45
- xinference/web/ui/build/static/js/main.1eb206d1.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +0 -1
- /xinference/web/ui/build/static/js/{main.1eb206d1.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/LICENSE +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/WHEEL +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/entry_points.txt +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/top_level.txt +0 -0
xinference/thirdparty/melo/transforms.py
ADDED
@@ -0,0 +1,209 @@
import torch
from torch.nn import functional as F

import numpy as np


DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3


def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    if tails is None:
        spline_fn = rational_quadratic_spline
        spline_kwargs = {}
    else:
        spline_fn = unconstrained_rational_quadratic_spline
        spline_kwargs = {"tails": tails, "tail_bound": tail_bound}

    outputs, logabsdet = spline_fn(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
        **spline_kwargs
    )
    return outputs, logabsdet


def searchsorted(bin_locations, inputs, eps=1e-6):
    bin_locations[..., -1] += eps
    return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1


def unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == "linear":
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant

        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError("{} tails are not implemented.".format(tails))

    (
        outputs[inside_interval_mask],
        logabsdet[inside_interval_mask],
    ) = rational_quadratic_spline(
        inputs=inputs[inside_interval_mask],
        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
        inverse=inverse,
        left=-tail_bound,
        right=tail_bound,
        bottom=-tail_bound,
        top=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )

    return outputs, logabsdet


def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    if inverse:
        bin_idx = searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = searchsorted(cumwidths, inputs)[..., None]

    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        ) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        )
        c = -input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, -logabsdet
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (
            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
        )
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, logabsdet
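As a quick orientation aid (not part of the diff), the sketch below drives the spline transform forward and inverse and checks the round trip. It assumes xinference/thirdparty is on sys.path so the vendored sources import as a top-level melo package, which is how the vendored files import each other; the tensor shapes simply follow the signatures above, and the parameters are random stand-ins for what MeloTTS would normally produce inside its flow modules.

import torch

# Assumption: xinference/thirdparty is on sys.path, so the vendored code imports as `melo`.
from melo.transforms import piecewise_rational_quadratic_transform

batch, features, num_bins = 4, 8, 10
inputs = torch.empty(batch, features).uniform_(-0.9, 0.9)   # inside the tail bound

# Unnormalized spline parameters (random here, purely for illustration).
w = torch.randn(batch, features, num_bins)
h = torch.randn(batch, features, num_bins)
d = torch.randn(batch, features, num_bins - 1)              # "linear" tails pad this to num_bins + 1

y, logabsdet = piecewise_rational_quadratic_transform(
    inputs, w, h, d, inverse=False, tails="linear", tail_bound=1.0
)
x, inv_logabsdet = piecewise_rational_quadratic_transform(
    y, w, h, d, inverse=True, tails="linear", tail_bound=1.0
)
assert torch.allclose(x, inputs, atol=1e-4)                 # forward/inverse round trip
assert torch.allclose(logabsdet + inv_logabsdet, torch.zeros_like(logabsdet), atol=1e-4)
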
xinference/thirdparty/melo/utils.py
ADDED
@@ -0,0 +1,424 @@
import os
import glob
import argparse
import logging
import json
import subprocess
import numpy as np
from scipy.io.wavfile import read
import torch
import torchaudio
import librosa
from melo.text import cleaned_text_to_sequence, get_bert
from melo.text.cleaner import clean_text
from melo import commons

MATPLOTLIB_FLAG = False

logger = logging.getLogger(__name__)


def get_text_for_tts_infer(text, language_str, hps, device, symbol_to_id=None):
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, symbol_to_id)

    if hps.data.add_blank:
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        for i in range(len(word2ph)):
            word2ph[i] = word2ph[i] * 2
        word2ph[0] += 1

    if getattr(hps.data, "disable_bert", False):
        bert = torch.zeros(1024, len(phone))
        ja_bert = torch.zeros(768, len(phone))
    else:
        bert = get_bert(norm_text, word2ph, language_str, device)
        del word2ph
        assert bert.shape[-1] == len(phone), phone

        if language_str == "ZH":
            bert = bert
            ja_bert = torch.zeros(768, len(phone))
        elif language_str in ["JP", "EN", "ZH_MIX_EN", 'KR', 'SP', 'ES', 'FR', 'DE', 'RU']:
            ja_bert = bert
            bert = torch.zeros(1024, len(phone))
        else:
            raise NotImplementedError()

    assert bert.shape[-1] == len(
        phone
    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"

    phone = torch.LongTensor(phone)
    tone = torch.LongTensor(tone)
    language = torch.LongTensor(language)
    return bert, ja_bert, phone, tone, language


def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    iteration = checkpoint_dict.get("iteration", 0)
    learning_rate = checkpoint_dict.get("learning_rate", 0.)
    if (
        optimizer is not None
        and not skip_optimizer
        and checkpoint_dict["optimizer"] is not None
    ):
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    elif optimizer is None and not skip_optimizer:
        # else: Disable this line if Infer and resume checkpoint,then enable the line upper
        new_opt_dict = optimizer.state_dict()
        new_opt_dict_params = new_opt_dict["param_groups"][0]["params"]
        new_opt_dict["param_groups"] = checkpoint_dict["optimizer"]["param_groups"]
        new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params
        optimizer.load_state_dict(new_opt_dict)

    saved_state_dict = checkpoint_dict["model"]
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()

    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            # assert "emb_g" not in k
            new_state_dict[k] = saved_state_dict[k]
            assert saved_state_dict[k].shape == v.shape, (
                saved_state_dict[k].shape,
                v.shape,
            )
        except Exception as e:
            print(e)
            # For upgrading from the old version
            if "ja_bert_proj" in k:
                v = torch.zeros_like(v)
                logger.warn(
                    f"Seems you are using the old version of the model, the {k} is automatically set to zero for backward compatibility"
                )
            else:
                logger.error(f"{k} is not in the checkpoint")

            new_state_dict[k] = v

    if hasattr(model, "module"):
        model.module.load_state_dict(new_state_dict, strict=False)
    else:
        model.load_state_dict(new_state_dict, strict=False)

    logger.info(
        "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
    )

    return model, optimizer, learning_rate, iteration


def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    logger.info(
        "Saving model and optimizer state at iteration {} to {}".format(
            iteration, checkpoint_path
        )
    )
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(
        {
            "model": state_dict,
            "iteration": iteration,
            "optimizer": optimizer.state_dict(),
            "learning_rate": learning_rate,
        },
        checkpoint_path,
    )


def summarize(
    writer,
    global_step,
    scalars={},
    histograms={},
    images={},
    audios={},
    audio_sampling_rate=22050,
):
    for k, v in scalars.items():
        writer.add_scalar(k, v, global_step)
    for k, v in histograms.items():
        writer.add_histogram(k, v, global_step)
    for k, v in images.items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in audios.items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)


def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    f_list = glob.glob(os.path.join(dir_path, regex))
    f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
    x = f_list[-1]
    return x


def plot_spectrogram_to_numpy(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


def plot_alignment_to_numpy(alignment, info=None):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
    )
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()

    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


def load_wav_to_torch(full_path):
    sampling_rate, data = read(full_path)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate


def load_wav_to_torch_new(full_path):
    audio_norm, sampling_rate = torchaudio.load(full_path, frame_offset=0, num_frames=-1, normalize=True, channels_first=True)
    audio_norm = audio_norm.mean(dim=0)
    return audio_norm, sampling_rate


def load_wav_to_torch_librosa(full_path, sr):
    audio_norm, sampling_rate = librosa.load(full_path, sr=sr, mono=True)
    return torch.FloatTensor(audio_norm.astype(np.float32)), sampling_rate


def load_filepaths_and_text(filename, split="|"):
    with open(filename, encoding="utf-8") as f:
        filepaths_and_text = [line.strip().split(split) for line in f]
    return filepaths_and_text


def get_hparams(init=True):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="./configs/base.json",
        help="JSON file for configuration",
    )
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--world-size', type=int, default=1)
    parser.add_argument('--port', type=int, default=10000)
    parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
    parser.add_argument('--pretrain_G', type=str, default=None,
                        help='pretrain model')
    parser.add_argument('--pretrain_D', type=str, default=None,
                        help='pretrain model D')
    parser.add_argument('--pretrain_dur', type=str, default=None,
                        help='pretrain model duration')

    args = parser.parse_args()
    model_dir = os.path.join("./logs", args.model)

    os.makedirs(model_dir, exist_ok=True)

    config_path = args.config
    config_save_path = os.path.join(model_dir, "config.json")
    if init:
        with open(config_path, "r") as f:
            data = f.read()
        with open(config_save_path, "w") as f:
            f.write(data)
    else:
        with open(config_save_path, "r") as f:
            data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    hparams.pretrain_G = args.pretrain_G
    hparams.pretrain_D = args.pretrain_D
    hparams.pretrain_dur = args.pretrain_dur
    hparams.port = args.port
    return hparams


def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
    """Freeing up space by deleting saved ckpts

    Arguments:
    path_to_models -- Path to the model directory
    n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
    sort_by_time -- True -> chronologically delete ckpts
                    False -> lexicographically delete ckpts
    """
    import re

    ckpts_files = [
        f
        for f in os.listdir(path_to_models)
        if os.path.isfile(os.path.join(path_to_models, f))
    ]

    def name_key(_f):
        return int(re.compile("._(\\d+)\\.pth").match(_f).group(1))

    def time_key(_f):
        return os.path.getmtime(os.path.join(path_to_models, _f))

    sort_key = time_key if sort_by_time else name_key

    def x_sorted(_x):
        return sorted(
            [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
            key=sort_key,
        )

    to_del = [
        os.path.join(path_to_models, fn)
        for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
    ]

    def del_info(fn):
        return logger.info(f".. Free up space by deleting ckpt {fn}")

    def del_routine(x):
        return [os.remove(x), del_info(x)]

    [del_routine(fn) for fn in to_del]


def get_hparams_from_dir(model_dir):
    config_save_path = os.path.join(model_dir, "config.json")
    with open(config_save_path, "r", encoding="utf-8") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams


def get_hparams_from_file(config_path):
    with open(config_path, "r", encoding="utf-8") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    return hparams


def check_git_hash(model_dir):
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warn(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir
            )
        )
        return

    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        saved_hash = open(path).read()
        if saved_hash != cur_hash:
            logger.warn(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8], cur_hash[:8]
                )
            )
    else:
        open(path, "w").write(cur_hash)


def get_logger(model_dir, filename="train.log"):
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
    if not os.path.exists(model_dir):
        os.makedirs(model_dir, exist_ok=True)
    h = logging.FileHandler(os.path.join(model_dir, filename))
    h.setLevel(logging.DEBUG)
    h.setFormatter(formatter)
    logger.addHandler(h)
    return logger


class HParams:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if type(v) == dict:
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
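The HParams class at the bottom of this file is the small attribute-style container the vendored melo code reads its JSON configs through. A minimal sketch of how it behaves is shown below; the config values are made up for illustration, and it assumes the same sys.path arrangement as above plus melo's runtime dependencies (librosa, torchaudio, the melo.text stack) being installed so the module imports cleanly.

# Assumption: xinference/thirdparty is on sys.path and melo's dependencies are installed.
from melo.utils import HParams

config = {                                        # hypothetical values, not from the package
    "data": {"sampling_rate": 44100, "add_blank": True},
    "train": {"batch_size": 6, "learning_rate": 3e-4},
}
hps = HParams(**config)

assert hps.data.sampling_rate == 44100            # nested dicts become nested HParams
assert "train" in hps                             # __contains__ checks the top-level keys
hps.model_dir = "./logs/example"                  # extra attributes can be attached freely
print(hps)                                        # __repr__ just prints the underlying dict
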
xinference/types.py
CHANGED
@@ -335,8 +335,10 @@ def get_pydantic_model_from_method(
     exclude_fields: Optional[Iterable[str]] = None,
     include_fields: Optional[Dict[str, Any]] = None,
 ) -> BaseModel:
+    # The validate_arguments set Config.extra = "forbid" by default.
+    f = validate_arguments(meth, config={"arbitrary_types_allowed": True})
     model = f.model
+    model.Config.extra = "ignore"
     model.__fields__.pop("self", None)
     model.__fields__.pop("args", None)
     model.__fields__.pop("kwargs", None)
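The two added lines change the extra-fields policy on the pydantic model that validate_arguments generates for each model method (its default is forbid, as the new comment notes). As a standalone illustration of what that setting controls, here is a plain pydantic v1 sketch, not the xinference code path itself:

from pydantic import BaseModel, Extra, ValidationError   # pydantic v1 API


class Forbidding(BaseModel):
    prompt: str

    class Config:
        extra = Extra.forbid          # validate_arguments' default, per the comment above


class Ignoring(BaseModel):
    prompt: str

    class Config:
        extra = Extra.ignore          # the behaviour the added line opts into


try:
    Forbidding(prompt="hi", unexpected_flag=True)
except ValidationError:
    print("forbid: unexpected fields raise a ValidationError")

print(Ignoring(prompt="hi", unexpected_flag=True).dict())    # {'prompt': 'hi'} — extra field dropped
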
xinference/web/ui/build/asset-manifest.json
CHANGED
@@ -1,14 +1,14 @@
 {
   "files": {
     "main.css": "./static/css/main.51a587ff.css",
-    "main.js": "./static/js/main.1eb206d1.js",
+    "main.js": "./static/js/main.b0936c54.js",
     "static/media/icon.webp": "./static/media/icon.4603d52c63041e5dfbfd.webp",
     "index.html": "./index.html",
     "main.51a587ff.css.map": "./static/css/main.51a587ff.css.map",
-    "main.1eb206d1.js.map": "./static/js/main.1eb206d1.js.map"
+    "main.b0936c54.js.map": "./static/js/main.b0936c54.js.map"
   },
   "entrypoints": [
     "static/css/main.51a587ff.css",
-    "static/js/main.1eb206d1.js"
+    "static/js/main.b0936c54.js"
   ]
 }

xinference/web/ui/build/index.html
CHANGED
@@ -1 +1 @@
-<!doctype html><html lang="en"><head><meta charset="utf-8"/><link rel="icon" href="./favicon.svg"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="Web site created using create-react-app"/><link rel="apple-touch-icon" href="./logo192.png"/><link rel="manifest" href="./manifest.json"/><title>Xinference</title><script defer="defer" src="./static/js/main.1eb206d1.js"></script><link href="./static/css/main.51a587ff.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
+<!doctype html><html lang="en"><head><meta charset="utf-8"/><link rel="icon" href="./favicon.svg"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="Web site created using create-react-app"/><link rel="apple-touch-icon" href="./logo192.png"/><link rel="manifest" href="./manifest.json"/><title>Xinference</title><script defer="defer" src="./static/js/main.b0936c54.js"></script><link href="./static/css/main.51a587ff.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>