xinference 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (124) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +4 -7
  3. xinference/client/handlers.py +3 -0
  4. xinference/core/chat_interface.py +6 -1
  5. xinference/core/model.py +2 -0
  6. xinference/core/scheduler.py +4 -7
  7. xinference/core/supervisor.py +114 -23
  8. xinference/core/worker.py +70 -4
  9. xinference/deploy/local.py +2 -1
  10. xinference/model/audio/core.py +11 -0
  11. xinference/model/audio/cosyvoice.py +16 -5
  12. xinference/model/audio/kokoro.py +139 -0
  13. xinference/model/audio/melotts.py +110 -0
  14. xinference/model/audio/model_spec.json +80 -0
  15. xinference/model/audio/model_spec_modelscope.json +18 -0
  16. xinference/model/audio/whisper.py +35 -10
  17. xinference/model/llm/llama_cpp/core.py +21 -14
  18. xinference/model/llm/llm_family.json +527 -1
  19. xinference/model/llm/llm_family.py +4 -1
  20. xinference/model/llm/llm_family_modelscope.json +495 -3
  21. xinference/model/llm/memory.py +1 -1
  22. xinference/model/llm/mlx/core.py +24 -6
  23. xinference/model/llm/transformers/core.py +9 -1
  24. xinference/model/llm/transformers/qwen2_audio.py +3 -1
  25. xinference/model/llm/transformers/qwen2_vl.py +20 -3
  26. xinference/model/llm/transformers/utils.py +22 -11
  27. xinference/model/llm/utils.py +115 -1
  28. xinference/model/llm/vllm/core.py +14 -4
  29. xinference/model/llm/vllm/xavier/block.py +3 -4
  30. xinference/model/llm/vllm/xavier/block_tracker.py +71 -58
  31. xinference/model/llm/vllm/xavier/collective.py +74 -0
  32. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  33. xinference/model/llm/vllm/xavier/executor.py +18 -16
  34. xinference/model/llm/vllm/xavier/scheduler.py +79 -63
  35. xinference/model/llm/vllm/xavier/test/test_xavier.py +60 -35
  36. xinference/model/llm/vllm/xavier/transfer.py +53 -32
  37. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  38. xinference/thirdparty/melo/__init__.py +0 -0
  39. xinference/thirdparty/melo/api.py +135 -0
  40. xinference/thirdparty/melo/app.py +61 -0
  41. xinference/thirdparty/melo/attentions.py +459 -0
  42. xinference/thirdparty/melo/commons.py +160 -0
  43. xinference/thirdparty/melo/configs/config.json +94 -0
  44. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  45. xinference/thirdparty/melo/data_utils.py +413 -0
  46. xinference/thirdparty/melo/download_utils.py +67 -0
  47. xinference/thirdparty/melo/infer.py +25 -0
  48. xinference/thirdparty/melo/init_downloads.py +14 -0
  49. xinference/thirdparty/melo/losses.py +58 -0
  50. xinference/thirdparty/melo/main.py +36 -0
  51. xinference/thirdparty/melo/mel_processing.py +174 -0
  52. xinference/thirdparty/melo/models.py +1030 -0
  53. xinference/thirdparty/melo/modules.py +598 -0
  54. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  55. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  56. xinference/thirdparty/melo/preprocess_text.py +135 -0
  57. xinference/thirdparty/melo/split_utils.py +174 -0
  58. xinference/thirdparty/melo/text/__init__.py +35 -0
  59. xinference/thirdparty/melo/text/chinese.py +199 -0
  60. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  61. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  62. xinference/thirdparty/melo/text/cleaner.py +36 -0
  63. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  64. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  65. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  66. xinference/thirdparty/melo/text/english.py +284 -0
  67. xinference/thirdparty/melo/text/english_bert.py +39 -0
  68. xinference/thirdparty/melo/text/english_utils/__init__.py +0 -0
  69. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  70. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  71. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  72. xinference/thirdparty/melo/text/es_phonemizer/__init__.py +0 -0
  73. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  74. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  75. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  76. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  77. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  78. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  79. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  80. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  81. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  82. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  83. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  84. xinference/thirdparty/melo/text/fr_phonemizer/__init__.py +0 -0
  85. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  86. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  87. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  88. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  89. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  90. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  91. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  92. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  93. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  94. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  95. xinference/thirdparty/melo/text/french.py +94 -0
  96. xinference/thirdparty/melo/text/french_bert.py +39 -0
  97. xinference/thirdparty/melo/text/japanese.py +647 -0
  98. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  99. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  100. xinference/thirdparty/melo/text/korean.py +192 -0
  101. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  102. xinference/thirdparty/melo/text/spanish.py +122 -0
  103. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  104. xinference/thirdparty/melo/text/symbols.py +290 -0
  105. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  106. xinference/thirdparty/melo/train.py +635 -0
  107. xinference/thirdparty/melo/train.sh +19 -0
  108. xinference/thirdparty/melo/transforms.py +209 -0
  109. xinference/thirdparty/melo/utils.py +424 -0
  110. xinference/types.py +2 -0
  111. xinference/web/ui/build/asset-manifest.json +3 -3
  112. xinference/web/ui/build/index.html +1 -1
  113. xinference/web/ui/build/static/js/{main.1eb206d1.js → main.b0936c54.js} +3 -3
  114. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  115. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  116. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/METADATA +37 -27
  117. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/RECORD +122 -45
  118. xinference/web/ui/build/static/js/main.1eb206d1.js.map +0 -1
  119. xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +0 -1
  120. /xinference/web/ui/build/static/js/{main.1eb206d1.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  121. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/LICENSE +0 -0
  122. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/WHEEL +0 -0
  123. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/entry_points.txt +0 -0
  124. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,598 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from torch.nn import Conv1d
7
+ from torch.nn.utils import weight_norm, remove_weight_norm
8
+
9
+ from . import commons
10
+ from .commons import init_weights, get_padding
11
+ from .transforms import piecewise_rational_quadratic_transform
12
+ from .attentions import Encoder
13
+
14
+ LRELU_SLOPE = 0.1
15
+
16
+
17
class LayerNorm(nn.Module):
    """Layer normalisation over dimension 1 of the input.

    ``F.layer_norm`` normalises the trailing dimension, so ``forward`` swaps
    dimension 1 with the last dimension, normalises, and swaps back.

    Args:
        channels: size of dimension 1 of the tensors passed to ``forward``.
        eps: numerical-stability term forwarded to ``F.layer_norm``.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        # Learnable scale (initialised to 1) and shift (initialised to 0).
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        """Return ``x`` normalised along dimension 1, same shape as the input."""
        channels_last = x.transpose(1, -1)
        normalised = F.layer_norm(
            channels_last,
            (self.channels,),
            weight=self.gamma,
            bias=self.beta,
            eps=self.eps,
        )
        return normalised.transpose(1, -1)
30
+
31
+
32
class ConvReluNorm(nn.Module):
    """Stack of Conv1d -> LayerNorm -> ReLU -> Dropout layers with a residual output.

    The final 1x1 projection is zero-initialised, so a freshly constructed
    module returns ``x * x_mask``.

    Args:
        in_channels: channels of the input fed to the first convolution.
        hidden_channels: channels used by every convolution after the first.
        out_channels: channels produced by the final 1x1 projection. The
            projection is added to the original input in ``forward``, so the
            two must be addable.
        kernel_size: convolution kernel width; padding is ``kernel_size // 2``.
        n_layers: number of conv/norm layers; must be greater than 1.
        p_dropout: dropout probability applied after each ReLU.

    Raises:
        AssertionError: if ``n_layers`` is not greater than 1.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # The message now states the bound that is actually enforced.
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        # First layer maps in_channels -> hidden_channels.
        self.conv_layers.append(
            nn.Conv1d(
                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
            )
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        # Remaining layers keep the hidden width.
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        # Zero-initialised projection: the block starts out as an identity.
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        """Apply the layer stack and return ``(x + proj(stack(x))) * x_mask``."""
        x_org = x
        for i in range(self.n_layers):
            # Re-apply the mask before every convolution.
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
82
+
83
+
84
class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution

    Each of the ``n_layers`` layers is a depthwise (``groups=channels``)
    dilated Conv1d followed by a 1x1 Conv1d, each followed by LayerNorm and
    GELU. The layer output is added back to its input. The dilation of layer
    ``i`` is ``kernel_size ** i``.
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for layer_idx in range(n_layers):
            rate = kernel_size**layer_idx
            # Padding chosen so the dilated convolution keeps the length.
            pad = (kernel_size * rate - rate) // 2
            depthwise = nn.Conv1d(
                channels,
                channels,
                kernel_size,
                groups=channels,
                dilation=rate,
                padding=pad,
            )
            self.convs_sep.append(depthwise)
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        """Run the residual layer stack; ``g`` (if given) is added to ``x`` first."""
        if g is not None:
            x = x + g
        stack = zip(self.convs_sep, self.convs_1x1, self.norms_1, self.norms_2)
        for depthwise, pointwise, norm_a, norm_b in stack:
            y = F.gelu(norm_a(depthwise(x * x_mask)))
            y = F.gelu(norm_b(pointwise(y)))
            x = x + self.drop(y)
        return x * x_mask
131
+
132
+
133
class WN(torch.nn.Module):
    """Stack of weight-normalised dilated convolutions with gated activations.

    Each layer passes its input through a dilated Conv1d producing
    ``2 * hidden_channels`` channels, combines it with a per-layer slice of the
    optional conditioning signal through
    ``commons.fused_add_tanh_sigmoid_multiply``, and feeds the result to a 1x1
    "res/skip" convolution. For every layer but the last, that convolution's
    output is split into a residual update to ``x`` and a skip contribution.
    The last layer contributes to the skip sum only. The masked skip sum is
    the module output.

    Args:
        hidden_channels: channel count of the input and output.
        kernel_size: odd kernel width of the dilated convolutions.
        dilation_rate: base of the per-layer dilation ``dilation_rate ** i``.
        n_layers: number of layers.
        gin_channels: channels of the conditioning input ``g``; ``0`` means no
            conditioning layer is created.
        p_dropout: dropout probability applied to the gated activations.

    Raises:
        AssertionError: if ``kernel_size`` is even.
    """

    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        p_dropout=0,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        # Stored as the plain value: the original trailing comma made this a
        # 1-tuple, unlike every other module in this file.
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            # One 1x1 conv yields the conditioning slices for all layers.
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

        for i in range(n_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # The last layer only feeds the skip sum, so it needs no
            # residual half.
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        """Return the masked sum of the per-layer skip contributions.

        ``g`` may only be supplied when the module was built with
        ``gin_channels != 0``, because ``cond_layer`` exists only then.
        Extra keyword arguments are accepted and ignored.
        """
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                # Pick this layer's 2*hidden_channels slice of the conditioning.
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                # First half updates the residual stream, second half is skip.
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        """Strip the weight-norm reparameterisation from every wrapped conv."""
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for layer in self.in_layers:
            torch.nn.utils.remove_weight_norm(layer)
        for layer in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(layer)
219
+
220
+
221
class ResBlock1(torch.nn.Module):
    """Residual block of three (dilated conv, undilated conv) pairs.

    ``convs1`` holds three weight-normalised Conv1d layers using the first
    three entries of ``dilation``; ``convs2`` holds three with dilation 1.
    Both lists are initialised with ``init_weights``.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()

        def build_conv(rate):
            # Stride-1 conv padded via get_padding for the given dilation.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )

        rates = (dilation[0], dilation[1], dilation[2])
        self.convs1 = nn.ModuleList([build_conv(rate) for rate in rates])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([build_conv(1) for _ in rates])
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply the three residual pairs; masking is skipped when ``x_mask`` is None."""
        use_mask = x_mask is not None
        for dilated, plain in zip(self.convs1, self.convs2):
            hidden = F.leaky_relu(x, LRELU_SLOPE)
            if use_mask:
                hidden = hidden * x_mask
            hidden = F.leaky_relu(dilated(hidden), LRELU_SLOPE)
            if use_mask:
                hidden = hidden * x_mask
            x = plain(hidden) + x
        if use_mask:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparameterisation from all six convolutions."""
        for conv in self.convs1:
            remove_weight_norm(conv)
        for conv in self.convs2:
            remove_weight_norm(conv)
316
+
317
+
318
class ResBlock2(torch.nn.Module):
    """Residual block of two weight-normalised dilated convolutions.

    The two Conv1d layers use the first two entries of ``dilation`` and are
    initialised with ``init_weights``.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()

        def build_conv(rate):
            # Stride-1 conv padded via get_padding for the given dilation.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )

        rates = (dilation[0], dilation[1])
        self.convs = nn.ModuleList([build_conv(rate) for rate in rates])
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply both residual convolutions; masking is skipped when ``x_mask`` is None."""
        use_mask = x_mask is not None
        for conv in self.convs:
            hidden = F.leaky_relu(x, LRELU_SLOPE)
            if use_mask:
                hidden = hidden * x_mask
            x = conv(hidden) + x
        if use_mask:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparameterisation from both convolutions."""
        for conv in self.convs:
            remove_weight_norm(conv)
361
+
362
+
363
class Log(nn.Module):
    """Element-wise log flow.

    Forward returns ``(log(clamp_min(x, 1e-5)) * x_mask, logdet)`` where
    ``logdet`` is the sum of ``-y`` over dimensions 1 and 2. Reverse returns
    ``exp(x) * x_mask`` only.
    """

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if reverse:
            return torch.exp(x) * x_mask

        # Clamp first so non-positive inputs do not produce -inf/NaN.
        y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
        return y, torch.sum(-y, [1, 2])
372
+
373
+
374
class Flip(nn.Module):
    """Flow step that reverses the order of dimension 1.

    Forward returns ``(flipped, logdet)`` with an all-zero ``logdet`` of
    length ``x.size(0)``. Reverse returns the flipped tensor only. Extra
    positional and keyword arguments are accepted and ignored.
    """

    def forward(self, x, *args, reverse=False, **kwargs):
        flipped = torch.flip(x, [1])
        if reverse:
            return flipped

        zero_logdet = torch.zeros(flipped.size(0)).to(
            dtype=flipped.dtype, device=flipped.device
        )
        return flipped, zero_logdet
382
+
383
+
384
class ElementwiseAffine(nn.Module):
    """Per-channel affine flow ``y = m + exp(logs) * x``.

    ``m`` and ``logs`` are learnable ``(channels, 1)`` parameters, both
    initialised to zero.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        """Forward returns ``(y, logdet)``; reverse returns the inverted tensor only."""
        if reverse:
            return (x - self.m) * torch.exp(-self.logs) * x_mask

        y = (self.m + torch.exp(self.logs) * x) * x_mask
        # Log-determinant: logs summed over dimensions 1 and 2 after masking.
        logdet = torch.sum(self.logs * x_mask, [1, 2])
        return y, logdet
400
+
401
+
402
class ResidualCouplingLayer(nn.Module):
    """Affine coupling layer whose statistics come from a ``WN`` stack.

    The input is split in two along dimension 1. The first half passes
    through unchanged and drives ``pre -> WN -> post`` to produce a shift
    ``m`` and, unless ``mean_only``, a log-scale ``logs`` for the second half.
    ``post`` is zero-initialised.

    Raises:
        AssertionError: if ``channels`` is odd.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        p_dropout=0,
        gin_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        half = channels // 2
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = half
        self.mean_only = mean_only

        self.pre = nn.Conv1d(half, hidden_channels, 1)
        self.enc = WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            p_dropout=p_dropout,
            gin_channels=gin_channels,
        )
        # Emits m only (mean_only) or m and logs stacked along channels.
        self.post = nn.Conv1d(hidden_channels, half * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Forward returns ``(x, logdet)``; reverse returns the inverted ``x`` only."""
        half = self.half_channels
        x0, x1 = torch.split(x, [half, half], 1)
        hidden = self.enc(self.pre(x0) * x_mask, x_mask, g=g)
        stats = self.post(hidden) * x_mask
        if self.mean_only:
            m = stats
            logs = torch.zeros_like(m)
        else:
            m, logs = torch.split(stats, [half, half], 1)

        if reverse:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            return torch.cat([x0, x1], 1)

        x1 = m + x1 * torch.exp(logs) * x_mask
        coupled = torch.cat([x0, x1], 1)
        return coupled, torch.sum(logs, [1, 2])
457
+
458
+
459
class ConvFlow(nn.Module):
    """Coupling layer using a piecewise rational-quadratic spline transform.

    The first half of the channels drives ``pre -> DDSConv -> proj`` to
    produce ``num_bins * 3 - 1`` spline parameters per element of the second
    half, which is then transformed by
    ``piecewise_rational_quadratic_transform``. ``proj`` is zero-initialised.
    """

    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        n_layers,
        num_bins=10,
        tail_bound=5.0,
    ):
        super().__init__()
        half = in_channels // 2
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = half

        self.pre = nn.Conv1d(half, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
        self.proj = nn.Conv1d(filter_channels, half * (num_bins * 3 - 1), 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Forward returns ``(x, logdet)``; reverse returns the transformed ``x`` only."""
        half = self.half_channels
        x0, x1 = torch.split(x, [half, half], 1)
        h = self.proj(self.convs(self.pre(x0), x_mask, g=g)) * x_mask

        batch, chans, steps = x0.shape
        # [b, c * params, t] -> [b, c, t, params]
        h = h.reshape(batch, chans, -1, steps).permute(0, 1, 3, 2)

        bins = self.num_bins
        scale = math.sqrt(self.filter_channels)
        unnormalized_widths = h[..., :bins] / scale
        unnormalized_heights = h[..., bins : 2 * bins] / scale
        unnormalized_derivatives = h[..., 2 * bins :]

        x1, logabsdet = piecewise_rational_quadratic_transform(
            x1,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            inverse=reverse,
            tails="linear",
            tail_bound=self.tail_bound,
        )

        x = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if reverse:
            return x
        return x, logdet
517
+
518
+
519
class TransformerCouplingLayer(nn.Module):
    """Affine coupling layer whose statistics come from an ``Encoder``.

    The input is split in two along dimension 1. The first half passes
    through unchanged and drives ``pre -> enc -> post`` to produce a shift
    ``m`` and, unless ``mean_only``, a log-scale ``logs`` for the second half.
    ``post`` is zero-initialised.

    Args:
        channels: even channel count of the input.
        hidden_channels: width of the encoder.
        kernel_size: kernel size forwarded to ``Encoder``.
        n_layers: encoder depth; must be exactly 3.
        n_heads: attention heads forwarded to ``Encoder``.
        p_dropout: dropout probability forwarded to ``Encoder``.
        filter_channels: filter width forwarded to ``Encoder``.
        mean_only: when true, only a shift is predicted and ``logs`` is zero.
        wn_sharing_parameter: pre-built module used as ``enc`` in place of a
            freshly constructed ``Encoder`` when not ``None``.
        gin_channels: conditioning channels forwarded to ``Encoder``.

    Raises:
        AssertionError: if ``n_layers != 3`` or ``channels`` is odd.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        n_layers,
        n_heads,
        p_dropout=0,
        filter_channels=0,
        mean_only=False,
        wn_sharing_parameter=None,
        gin_channels=0,
    ):
        assert n_layers == 3, n_layers
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = (
            Encoder(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=gin_channels,
            )
            if wn_sharing_parameter is None
            else wn_sharing_parameter
        )
        # Emits m only (mean_only) or m and logs stacked along channels.
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Forward returns ``(x, logdet)``; reverse returns the inverted ``x`` only."""
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet

        # The unreachable spline-transform code that followed both return
        # paths (it referenced undefined names) has been removed.
        x1 = (x1 - m) * torch.exp(-logs) * x_mask
        x = torch.cat([x0, x1], 1)
        return x
@@ -0,0 +1,16 @@
1
+ from numpy import zeros, int32, float32
2
+ from torch import from_numpy
3
+
4
+ from .core import maximum_path_jit
5
+
6
+
7
def maximum_path(neg_cent, mask):
    """Run ``maximum_path_jit`` on torch inputs and return the path as a tensor.

    Args:
        neg_cent: 3-D score tensor; it is copied to a float32 NumPy array, so
            the caller's tensor is not modified.
        mask: 3-D mask with the same layout as ``neg_cent``. Per-item lengths
            are read from ``mask.sum(1)[:, 0]`` and ``mask.sum(2)[:, 0]``.
            (presumably a 0/1 mask of valid positions; confirm against the caller)

    Returns:
        A tensor shaped like ``neg_cent``, on its device and with its dtype,
        holding the path written by ``maximum_path_jit``.
    """
    device = neg_cent.device
    dtype = neg_cent.dtype
    # maximum_path_jit is compiled for C-contiguous arrays only. The default
    # astype order ("K") would keep a non-contiguous layout, so force "C".
    neg_cent = neg_cent.detach().cpu().numpy().astype(float32, order="C")
    path = zeros(neg_cent.shape, dtype=int32)

    t_t_max = mask.sum(1)[:, 0].detach().cpu().numpy().astype(int32)
    t_s_max = mask.sum(2)[:, 0].detach().cpu().numpy().astype(int32)
    maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
    return from_numpy(path).to(device=device, dtype=dtype)
@@ -0,0 +1,46 @@
1
+ import numba
2
+
3
+
4
@numba.jit(
    numba.void(
        numba.int32[:, :, ::1],
        numba.float32[:, :, ::1],
        numba.int32[::1],
        numba.int32[::1],
    ),
    nopython=True,
    nogil=True,
)
def maximum_path_jit(paths, values, t_ys, t_xs):
    """Dynamic-programming path search over a batch, writing results in place.

    For each batch item, ``values[i]`` is overwritten with cumulative scores
    over its first ``t_ys[i]`` rows and ``t_xs[i]`` columns. A single path is
    then traced backwards from the last row and column, setting the visited
    cells of ``paths[i]`` to 1. Nothing is returned.
    """
    floor_val = -1e9
    n_items = paths.shape[0]
    for item in range(int(n_items)):
        path = paths[item]
        value = values[item]
        t_y = t_ys[item]
        t_x = t_xs[item]

        v_prev = v_cur = 0.0
        index = t_x - 1

        # Forward pass: accumulate the best score reaching each cell.
        for y in range(t_y):
            x_lo = max(0, t_x + y - t_y)
            x_hi = min(t_x, y + 1)
            for x in range(x_lo, x_hi):
                # Staying in column x is impossible on the diagonal.
                if x == y:
                    v_cur = floor_val
                else:
                    v_cur = value[y - 1, x]
                # Arriving from column x - 1 is only possible when x > 0;
                # cell (0, 0) is the start and gets a zero predecessor.
                if x != 0:
                    v_prev = value[y - 1, x - 1]
                elif y == 0:
                    v_prev = 0.0
                else:
                    v_prev = floor_val
                value[y, x] += max(v_prev, v_cur)

        # Backward pass: mark the path from the last row up to row 0.
        for y in range(t_y - 1, -1, -1):
            path[y, index] = 1
            if index != 0 and (
                index == y or value[y - 1, index] < value[y - 1, index - 1]
            ):
                index = index - 1