xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_compat.py +51 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +209 -40
- xinference/client/restful/restful_client.py +7 -26
- xinference/conftest.py +1 -1
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +1 -1
- xinference/core/chat_interface.py +8 -14
- xinference/core/event.py +1 -1
- xinference/core/image_interface.py +28 -0
- xinference/core/model.py +110 -31
- xinference/core/scheduler.py +37 -37
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +17 -10
- xinference/core/utils.py +80 -22
- xinference/core/worker.py +17 -16
- xinference/deploy/cmdline.py +8 -16
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/deploy/utils.py +1 -1
- xinference/deploy/worker.py +1 -1
- xinference/model/audio/cosyvoice.py +86 -41
- xinference/model/audio/fish_speech.py +9 -9
- xinference/model/audio/model_spec.json +9 -9
- xinference/model/audio/whisper.py +4 -1
- xinference/model/embedding/core.py +52 -31
- xinference/model/image/core.py +2 -1
- xinference/model/image/model_spec.json +16 -4
- xinference/model/image/model_spec_modelscope.json +16 -4
- xinference/model/image/sdapi.py +136 -0
- xinference/model/image/stable_diffusion/core.py +164 -19
- xinference/model/llm/__init__.py +29 -11
- xinference/model/llm/llama_cpp/core.py +16 -33
- xinference/model/llm/llm_family.json +1011 -1296
- xinference/model/llm/llm_family.py +34 -53
- xinference/model/llm/llm_family_csghub.json +18 -35
- xinference/model/llm/llm_family_modelscope.json +981 -1122
- xinference/model/llm/lmdeploy/core.py +56 -88
- xinference/model/llm/mlx/core.py +46 -69
- xinference/model/llm/sglang/core.py +36 -18
- xinference/model/llm/transformers/chatglm.py +168 -306
- xinference/model/llm/transformers/cogvlm2.py +36 -63
- xinference/model/llm/transformers/cogvlm2_video.py +33 -223
- xinference/model/llm/transformers/core.py +55 -50
- xinference/model/llm/transformers/deepseek_v2.py +340 -0
- xinference/model/llm/transformers/deepseek_vl.py +53 -96
- xinference/model/llm/transformers/glm4v.py +55 -111
- xinference/model/llm/transformers/intern_vl.py +39 -70
- xinference/model/llm/transformers/internlm2.py +32 -54
- xinference/model/llm/transformers/minicpmv25.py +22 -55
- xinference/model/llm/transformers/minicpmv26.py +158 -68
- xinference/model/llm/transformers/omnilmm.py +5 -28
- xinference/model/llm/transformers/qwen2_audio.py +168 -0
- xinference/model/llm/transformers/qwen2_vl.py +234 -0
- xinference/model/llm/transformers/qwen_vl.py +34 -86
- xinference/model/llm/transformers/utils.py +32 -38
- xinference/model/llm/transformers/yi_vl.py +32 -72
- xinference/model/llm/utils.py +280 -554
- xinference/model/llm/vllm/core.py +161 -100
- xinference/model/rerank/core.py +41 -8
- xinference/model/rerank/model_spec.json +7 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -1
- xinference/model/utils.py +1 -31
- xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
- xinference/thirdparty/cosyvoice/cli/model.py +139 -26
- xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
- xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
- xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
- xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
- xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
- xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
- xinference/thirdparty/cosyvoice/utils/common.py +36 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
- xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
- xinference/thirdparty/fish_speech/tools/api.py +79 -134
- xinference/thirdparty/fish_speech/tools/commons.py +35 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
- xinference/thirdparty/fish_speech/tools/file.py +17 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
- xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
- xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
- xinference/thirdparty/fish_speech/tools/webui.py +12 -146
- xinference/thirdparty/matcha/VERSION +1 -0
- xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
- xinference/thirdparty/matcha/hifigan/README.md +101 -0
- xinference/thirdparty/omnilmm/LICENSE +201 -0
- xinference/thirdparty/whisper/__init__.py +156 -0
- xinference/thirdparty/whisper/__main__.py +3 -0
- xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
- xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
- xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
- xinference/thirdparty/whisper/audio.py +157 -0
- xinference/thirdparty/whisper/decoding.py +826 -0
- xinference/thirdparty/whisper/model.py +314 -0
- xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
- xinference/thirdparty/whisper/normalizers/basic.py +76 -0
- xinference/thirdparty/whisper/normalizers/english.json +1741 -0
- xinference/thirdparty/whisper/normalizers/english.py +550 -0
- xinference/thirdparty/whisper/timing.py +386 -0
- xinference/thirdparty/whisper/tokenizer.py +395 -0
- xinference/thirdparty/whisper/transcribe.py +605 -0
- xinference/thirdparty/whisper/triton_ops.py +109 -0
- xinference/thirdparty/whisper/utils.py +316 -0
- xinference/thirdparty/whisper/version.py +1 -0
- xinference/types.py +14 -53
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
- xinference/web/ui/build/static/js/main.754740c0.js +3 -0
- xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
- xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +37 -0
- xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
- xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
- xinference/web/ui/node_modules/nunjucks/package.json +112 -0
- xinference/web/ui/package-lock.json +38 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
- xinference/model/llm/transformers/llama_2.py +0 -108
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
- xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
- xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
- xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
- xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
|
@@ -1,25 +1,26 @@
|
|
|
1
|
-
# A inference only version of the FireflyGAN model
|
|
2
|
-
|
|
3
1
|
import math
|
|
4
2
|
from functools import partial
|
|
5
3
|
from math import prod
|
|
6
4
|
from typing import Callable
|
|
7
5
|
|
|
8
|
-
import numpy as np
|
|
9
6
|
import torch
|
|
10
7
|
import torch.nn.functional as F
|
|
11
8
|
from torch import nn
|
|
12
|
-
from torch.nn import Conv1d
|
|
13
9
|
from torch.nn.utils.parametrizations import weight_norm
|
|
14
10
|
from torch.nn.utils.parametrize import remove_parametrizations
|
|
15
11
|
from torch.utils.checkpoint import checkpoint
|
|
16
12
|
|
|
17
|
-
|
|
13
|
+
|
|
14
|
+
def sequence_mask(length, max_length=None):
|
|
15
|
+
if max_length is None:
|
|
16
|
+
max_length = length.max()
|
|
17
|
+
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
|
18
|
+
return x.unsqueeze(0) < length.unsqueeze(1)
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
def init_weights(m, mean=0.0, std=0.01):
|
|
21
22
|
classname = m.__class__.__name__
|
|
22
|
-
if classname.find("
|
|
23
|
+
if classname.find("Conv1D") != -1:
|
|
23
24
|
m.weight.data.normal_(mean, std)
|
|
24
25
|
|
|
25
26
|
|
|
@@ -27,78 +28,141 @@ def get_padding(kernel_size, dilation=1):
|
|
|
27
28
|
return (kernel_size * dilation - dilation) // 2
|
|
28
29
|
|
|
29
30
|
|
|
31
|
+
def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
|
|
32
|
+
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
|
|
33
|
+
padding_left, padding_right = paddings
|
|
34
|
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
|
35
|
+
assert (padding_left + padding_right) <= x.shape[-1]
|
|
36
|
+
end = x.shape[-1] - padding_right
|
|
37
|
+
return x[..., padding_left:end]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def get_extra_padding_for_conv1d(
|
|
41
|
+
x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
|
42
|
+
) -> int:
|
|
43
|
+
"""See `pad_for_conv1d`."""
|
|
44
|
+
length = x.shape[-1]
|
|
45
|
+
n_frames = (length - kernel_size + padding_total) / stride + 1
|
|
46
|
+
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
|
47
|
+
return ideal_length - length
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def pad1d(
|
|
51
|
+
x: torch.Tensor,
|
|
52
|
+
paddings: tuple[int, int],
|
|
53
|
+
mode: str = "zeros",
|
|
54
|
+
value: float = 0.0,
|
|
55
|
+
):
|
|
56
|
+
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
|
57
|
+
If this is the case, we insert extra 0 padding to the right
|
|
58
|
+
before the reflection happen.
|
|
59
|
+
"""
|
|
60
|
+
length = x.shape[-1]
|
|
61
|
+
padding_left, padding_right = paddings
|
|
62
|
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
|
63
|
+
if mode == "reflect":
|
|
64
|
+
max_pad = max(padding_left, padding_right)
|
|
65
|
+
extra_pad = 0
|
|
66
|
+
if length <= max_pad:
|
|
67
|
+
extra_pad = max_pad - length + 1
|
|
68
|
+
x = F.pad(x, (0, extra_pad))
|
|
69
|
+
padded = F.pad(x, paddings, mode, value)
|
|
70
|
+
end = padded.shape[-1] - extra_pad
|
|
71
|
+
return padded[..., :end]
|
|
72
|
+
else:
|
|
73
|
+
return F.pad(x, paddings, mode, value)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class FishConvNet(nn.Module):
|
|
77
|
+
def __init__(
|
|
78
|
+
self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
|
|
79
|
+
):
|
|
80
|
+
super(FishConvNet, self).__init__()
|
|
81
|
+
self.conv = nn.Conv1d(
|
|
82
|
+
in_channels,
|
|
83
|
+
out_channels,
|
|
84
|
+
kernel_size,
|
|
85
|
+
stride=stride,
|
|
86
|
+
dilation=dilation,
|
|
87
|
+
groups=groups,
|
|
88
|
+
)
|
|
89
|
+
self.stride = stride
|
|
90
|
+
self.kernel_size = (kernel_size - 1) * dilation + 1
|
|
91
|
+
self.dilation = dilation
|
|
92
|
+
|
|
93
|
+
def forward(self, x):
|
|
94
|
+
pad = self.kernel_size - self.stride
|
|
95
|
+
extra_padding = get_extra_padding_for_conv1d(
|
|
96
|
+
x, self.kernel_size, self.stride, pad
|
|
97
|
+
)
|
|
98
|
+
x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
|
|
99
|
+
return self.conv(x).contiguous()
|
|
100
|
+
|
|
101
|
+
def weight_norm(self, name="weight", dim=0):
|
|
102
|
+
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
|
103
|
+
return self
|
|
104
|
+
|
|
105
|
+
def remove_weight_norm(self):
|
|
106
|
+
self.conv = remove_parametrizations(self.conv)
|
|
107
|
+
return self
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class FishTransConvNet(nn.Module):
|
|
111
|
+
def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
|
|
112
|
+
super(FishTransConvNet, self).__init__()
|
|
113
|
+
self.conv = nn.ConvTranspose1d(
|
|
114
|
+
in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
|
|
115
|
+
)
|
|
116
|
+
self.stride = stride
|
|
117
|
+
self.kernel_size = kernel_size
|
|
118
|
+
|
|
119
|
+
def forward(self, x):
|
|
120
|
+
x = self.conv(x)
|
|
121
|
+
pad = self.kernel_size - self.stride
|
|
122
|
+
padding_right = math.ceil(pad)
|
|
123
|
+
padding_left = pad - padding_right
|
|
124
|
+
x = unpad1d(x, (padding_left, padding_right))
|
|
125
|
+
return x.contiguous()
|
|
126
|
+
|
|
127
|
+
def weight_norm(self, name="weight", dim=0):
|
|
128
|
+
self.conv = weight_norm(self.conv, name=name, dim=dim)
|
|
129
|
+
return self
|
|
130
|
+
|
|
131
|
+
def remove_weight_norm(self):
|
|
132
|
+
self.conv = remove_parametrizations(self.conv)
|
|
133
|
+
return self
|
|
134
|
+
|
|
135
|
+
|
|
30
136
|
class ResBlock1(torch.nn.Module):
|
|
31
137
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
|
32
138
|
super().__init__()
|
|
33
139
|
|
|
34
140
|
self.convs1 = nn.ModuleList(
|
|
35
141
|
[
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
),
|
|
46
|
-
weight_norm(
|
|
47
|
-
Conv1d(
|
|
48
|
-
channels,
|
|
49
|
-
channels,
|
|
50
|
-
kernel_size,
|
|
51
|
-
1,
|
|
52
|
-
dilation=dilation[1],
|
|
53
|
-
padding=get_padding(kernel_size, dilation[1]),
|
|
54
|
-
)
|
|
55
|
-
),
|
|
56
|
-
weight_norm(
|
|
57
|
-
Conv1d(
|
|
58
|
-
channels,
|
|
59
|
-
channels,
|
|
60
|
-
kernel_size,
|
|
61
|
-
1,
|
|
62
|
-
dilation=dilation[2],
|
|
63
|
-
padding=get_padding(kernel_size, dilation[2]),
|
|
64
|
-
)
|
|
65
|
-
),
|
|
142
|
+
FishConvNet(
|
|
143
|
+
channels, channels, kernel_size, stride=1, dilation=dilation[0]
|
|
144
|
+
).weight_norm(),
|
|
145
|
+
FishConvNet(
|
|
146
|
+
channels, channels, kernel_size, stride=1, dilation=dilation[1]
|
|
147
|
+
).weight_norm(),
|
|
148
|
+
FishConvNet(
|
|
149
|
+
channels, channels, kernel_size, stride=1, dilation=dilation[2]
|
|
150
|
+
).weight_norm(),
|
|
66
151
|
]
|
|
67
152
|
)
|
|
68
153
|
self.convs1.apply(init_weights)
|
|
69
154
|
|
|
70
155
|
self.convs2 = nn.ModuleList(
|
|
71
156
|
[
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
),
|
|
82
|
-
weight_norm(
|
|
83
|
-
Conv1d(
|
|
84
|
-
channels,
|
|
85
|
-
channels,
|
|
86
|
-
kernel_size,
|
|
87
|
-
1,
|
|
88
|
-
dilation=1,
|
|
89
|
-
padding=get_padding(kernel_size, 1),
|
|
90
|
-
)
|
|
91
|
-
),
|
|
92
|
-
weight_norm(
|
|
93
|
-
Conv1d(
|
|
94
|
-
channels,
|
|
95
|
-
channels,
|
|
96
|
-
kernel_size,
|
|
97
|
-
1,
|
|
98
|
-
dilation=1,
|
|
99
|
-
padding=get_padding(kernel_size, 1),
|
|
100
|
-
)
|
|
101
|
-
),
|
|
157
|
+
FishConvNet(
|
|
158
|
+
channels, channels, kernel_size, stride=1, dilation=dilation[0]
|
|
159
|
+
).weight_norm(),
|
|
160
|
+
FishConvNet(
|
|
161
|
+
channels, channels, kernel_size, stride=1, dilation=dilation[1]
|
|
162
|
+
).weight_norm(),
|
|
163
|
+
FishConvNet(
|
|
164
|
+
channels, channels, kernel_size, stride=1, dilation=dilation[2]
|
|
165
|
+
).weight_norm(),
|
|
102
166
|
]
|
|
103
167
|
)
|
|
104
168
|
self.convs2.apply(init_weights)
|
|
@@ -119,7 +183,7 @@ class ResBlock1(torch.nn.Module):
|
|
|
119
183
|
remove_parametrizations(conv, tensor_name="weight")
|
|
120
184
|
|
|
121
185
|
|
|
122
|
-
class
|
|
186
|
+
class ParallelBlock(nn.Module):
|
|
123
187
|
def __init__(
|
|
124
188
|
self,
|
|
125
189
|
channels: int,
|
|
@@ -153,7 +217,6 @@ class HiFiGANGenerator(nn.Module):
|
|
|
153
217
|
resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
|
|
154
218
|
num_mels: int = 128,
|
|
155
219
|
upsample_initial_channel: int = 512,
|
|
156
|
-
use_template: bool = True,
|
|
157
220
|
pre_conv_kernel_size: int = 7,
|
|
158
221
|
post_conv_kernel_size: int = 7,
|
|
159
222
|
post_activation: Callable = partial(nn.SiLU, inplace=True),
|
|
@@ -164,85 +227,51 @@ class HiFiGANGenerator(nn.Module):
|
|
|
164
227
|
prod(upsample_rates) == hop_length
|
|
165
228
|
), f"hop_length must be {prod(upsample_rates)}"
|
|
166
229
|
|
|
167
|
-
self.conv_pre =
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
padding=get_padding(pre_conv_kernel_size),
|
|
174
|
-
)
|
|
175
|
-
)
|
|
230
|
+
self.conv_pre = FishConvNet(
|
|
231
|
+
num_mels,
|
|
232
|
+
upsample_initial_channel,
|
|
233
|
+
pre_conv_kernel_size,
|
|
234
|
+
stride=1,
|
|
235
|
+
).weight_norm()
|
|
176
236
|
|
|
177
237
|
self.num_upsamples = len(upsample_rates)
|
|
178
238
|
self.num_kernels = len(resblock_kernel_sizes)
|
|
179
239
|
|
|
180
240
|
self.noise_convs = nn.ModuleList()
|
|
181
|
-
self.use_template = use_template
|
|
182
241
|
self.ups = nn.ModuleList()
|
|
183
242
|
|
|
184
243
|
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
|
185
|
-
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
|
186
244
|
self.ups.append(
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
padding=(k - u) // 2,
|
|
194
|
-
)
|
|
195
|
-
)
|
|
245
|
+
FishTransConvNet(
|
|
246
|
+
upsample_initial_channel // (2**i),
|
|
247
|
+
upsample_initial_channel // (2 ** (i + 1)),
|
|
248
|
+
k,
|
|
249
|
+
stride=u,
|
|
250
|
+
).weight_norm()
|
|
196
251
|
)
|
|
197
252
|
|
|
198
|
-
if not use_template:
|
|
199
|
-
continue
|
|
200
|
-
|
|
201
|
-
if i + 1 < len(upsample_rates):
|
|
202
|
-
stride_f0 = np.prod(upsample_rates[i + 1 :])
|
|
203
|
-
self.noise_convs.append(
|
|
204
|
-
Conv1d(
|
|
205
|
-
1,
|
|
206
|
-
c_cur,
|
|
207
|
-
kernel_size=stride_f0 * 2,
|
|
208
|
-
stride=stride_f0,
|
|
209
|
-
padding=stride_f0 // 2,
|
|
210
|
-
)
|
|
211
|
-
)
|
|
212
|
-
else:
|
|
213
|
-
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
|
|
214
|
-
|
|
215
253
|
self.resblocks = nn.ModuleList()
|
|
216
254
|
for i in range(len(self.ups)):
|
|
217
255
|
ch = upsample_initial_channel // (2 ** (i + 1))
|
|
218
256
|
self.resblocks.append(
|
|
219
|
-
|
|
257
|
+
ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
|
|
220
258
|
)
|
|
221
259
|
|
|
222
260
|
self.activation_post = post_activation()
|
|
223
|
-
self.conv_post =
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
1,
|
|
227
|
-
post_conv_kernel_size,
|
|
228
|
-
1,
|
|
229
|
-
padding=get_padding(post_conv_kernel_size),
|
|
230
|
-
)
|
|
231
|
-
)
|
|
261
|
+
self.conv_post = FishConvNet(
|
|
262
|
+
ch, 1, post_conv_kernel_size, stride=1
|
|
263
|
+
).weight_norm()
|
|
232
264
|
self.ups.apply(init_weights)
|
|
233
265
|
self.conv_post.apply(init_weights)
|
|
234
266
|
|
|
235
|
-
def forward(self, x
|
|
267
|
+
def forward(self, x):
|
|
236
268
|
x = self.conv_pre(x)
|
|
237
269
|
|
|
238
270
|
for i in range(self.num_upsamples):
|
|
239
271
|
x = F.silu(x, inplace=True)
|
|
240
272
|
x = self.ups[i](x)
|
|
241
273
|
|
|
242
|
-
if self.
|
|
243
|
-
x = x + self.noise_convs[i](template)
|
|
244
|
-
|
|
245
|
-
if self.training:
|
|
274
|
+
if self.training and self.checkpointing:
|
|
246
275
|
x = checkpoint(
|
|
247
276
|
self.resblocks[i],
|
|
248
277
|
x,
|
|
@@ -364,11 +393,11 @@ class ConvNeXtBlock(nn.Module):
|
|
|
364
393
|
):
|
|
365
394
|
super().__init__()
|
|
366
395
|
|
|
367
|
-
self.dwconv =
|
|
396
|
+
self.dwconv = FishConvNet(
|
|
368
397
|
dim,
|
|
369
398
|
dim,
|
|
370
399
|
kernel_size=kernel_size,
|
|
371
|
-
padding=int(dilation * (kernel_size - 1) / 2),
|
|
400
|
+
# padding=int(dilation * (kernel_size - 1) / 2),
|
|
372
401
|
groups=dim,
|
|
373
402
|
) # depthwise conv
|
|
374
403
|
self.norm = LayerNorm(dim, eps=1e-6)
|
|
@@ -421,12 +450,13 @@ class ConvNeXtEncoder(nn.Module):
|
|
|
421
450
|
|
|
422
451
|
self.downsample_layers = nn.ModuleList()
|
|
423
452
|
stem = nn.Sequential(
|
|
424
|
-
|
|
453
|
+
FishConvNet(
|
|
425
454
|
input_channels,
|
|
426
455
|
dims[0],
|
|
427
|
-
kernel_size=
|
|
428
|
-
padding=
|
|
429
|
-
padding_mode="
|
|
456
|
+
kernel_size=7,
|
|
457
|
+
# padding=3,
|
|
458
|
+
# padding_mode="replicate",
|
|
459
|
+
# padding_mode="zeros",
|
|
430
460
|
),
|
|
431
461
|
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
|
|
432
462
|
)
|
|
@@ -491,6 +521,7 @@ class FireflyArchitecture(nn.Module):
|
|
|
491
521
|
self.head = head
|
|
492
522
|
self.quantizer = quantizer
|
|
493
523
|
self.spec_transform = spec_transform
|
|
524
|
+
self.downsample_factor = math.prod(self.quantizer.downsample_factor)
|
|
494
525
|
|
|
495
526
|
def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
|
|
496
527
|
if self.spec_transform is not None:
|
|
@@ -512,7 +543,7 @@ class FireflyArchitecture(nn.Module):
|
|
|
512
543
|
if x.ndim == 2:
|
|
513
544
|
x = x[:, None, :]
|
|
514
545
|
|
|
515
|
-
if self.
|
|
546
|
+
if self.vq is not None:
|
|
516
547
|
return x, vq_result
|
|
517
548
|
|
|
518
549
|
return x
|
|
@@ -528,25 +559,30 @@ class FireflyArchitecture(nn.Module):
|
|
|
528
559
|
|
|
529
560
|
# Encode
|
|
530
561
|
encoded_features = self.backbone(mels) * mel_masks_float_conv
|
|
531
|
-
feature_lengths = mel_lengths //
|
|
562
|
+
feature_lengths = mel_lengths // self.downsample_factor
|
|
532
563
|
|
|
533
564
|
return self.quantizer.encode(encoded_features), feature_lengths
|
|
534
565
|
|
|
535
566
|
def decode(self, indices, feature_lengths) -> torch.Tensor:
|
|
536
|
-
|
|
537
|
-
|
|
567
|
+
mel_masks = sequence_mask(
|
|
568
|
+
feature_lengths * self.downsample_factor,
|
|
569
|
+
indices.shape[2] * self.downsample_factor,
|
|
570
|
+
)
|
|
538
571
|
mel_masks_float_conv = mel_masks[:, None, :].float()
|
|
572
|
+
audio_lengths = (
|
|
573
|
+
feature_lengths * self.downsample_factor * self.spec_transform.hop_length
|
|
574
|
+
)
|
|
539
575
|
|
|
540
576
|
audio_masks = sequence_mask(
|
|
541
|
-
|
|
542
|
-
indices.shape[2] *
|
|
577
|
+
audio_lengths,
|
|
578
|
+
indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
|
|
543
579
|
)
|
|
544
580
|
audio_masks_float_conv = audio_masks[:, None, :].float()
|
|
545
581
|
|
|
546
582
|
z = self.quantizer.decode(indices) * mel_masks_float_conv
|
|
547
583
|
x = self.head(z) * audio_masks_float_conv
|
|
548
584
|
|
|
549
|
-
return x
|
|
585
|
+
return x, audio_lengths
|
|
550
586
|
|
|
551
587
|
def remove_parametrizations(self):
|
|
552
588
|
if hasattr(self.backbone, "remove_parametrizations"):
|
|
@@ -558,68 +594,3 @@ class FireflyArchitecture(nn.Module):
|
|
|
558
594
|
@property
|
|
559
595
|
def device(self):
|
|
560
596
|
return next(self.parameters()).device
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
class FireflyBase(nn.Module):
|
|
564
|
-
def __init__(self, ckpt_path: str = None, pretrained: bool = True):
|
|
565
|
-
super().__init__()
|
|
566
|
-
|
|
567
|
-
self.backbone = ConvNeXtEncoder(
|
|
568
|
-
input_channels=128,
|
|
569
|
-
depths=[3, 3, 9, 3],
|
|
570
|
-
dims=[128, 256, 384, 512],
|
|
571
|
-
drop_path_rate=0.2,
|
|
572
|
-
kernel_size=7,
|
|
573
|
-
)
|
|
574
|
-
|
|
575
|
-
self.head = HiFiGANGenerator(
|
|
576
|
-
hop_length=512,
|
|
577
|
-
upsample_rates=[8, 8, 2, 2, 2],
|
|
578
|
-
upsample_kernel_sizes=[16, 16, 4, 4, 4],
|
|
579
|
-
resblock_kernel_sizes=[3, 7, 11],
|
|
580
|
-
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
|
581
|
-
num_mels=512,
|
|
582
|
-
upsample_initial_channel=512,
|
|
583
|
-
use_template=False,
|
|
584
|
-
pre_conv_kernel_size=13,
|
|
585
|
-
post_conv_kernel_size=13,
|
|
586
|
-
)
|
|
587
|
-
|
|
588
|
-
if ckpt_path is not None:
|
|
589
|
-
state_dict = torch.load(ckpt_path, map_location="cpu")
|
|
590
|
-
elif pretrained:
|
|
591
|
-
state_dict = torch.hub.load_state_dict_from_url(
|
|
592
|
-
"https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
|
|
593
|
-
map_location="cpu",
|
|
594
|
-
model_dir="checkpoints",
|
|
595
|
-
)
|
|
596
|
-
|
|
597
|
-
if "state_dict" in state_dict:
|
|
598
|
-
state_dict = state_dict["state_dict"]
|
|
599
|
-
|
|
600
|
-
if any("generator." in k for k in state_dict):
|
|
601
|
-
state_dict = {
|
|
602
|
-
k.replace("generator.", ""): v
|
|
603
|
-
for k, v in state_dict.items()
|
|
604
|
-
if "generator." in k
|
|
605
|
-
}
|
|
606
|
-
|
|
607
|
-
self.load_state_dict(state_dict, strict=True)
|
|
608
|
-
self.head.remove_parametrizations()
|
|
609
|
-
|
|
610
|
-
@torch.no_grad()
|
|
611
|
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
612
|
-
x = self.backbone(x)
|
|
613
|
-
x = self.head(x)
|
|
614
|
-
if x.ndim == 2:
|
|
615
|
-
x = x[:, None, :]
|
|
616
|
-
return x
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
if __name__ == "__main__":
|
|
620
|
-
model = FireflyBase()
|
|
621
|
-
model.eval()
|
|
622
|
-
x = torch.randn(1, 128, 128)
|
|
623
|
-
with torch.no_grad():
|
|
624
|
-
y = model(x)
|
|
625
|
-
print(y.shape)
|
|
@@ -6,7 +6,7 @@ import torch.nn.functional as F
|
|
|
6
6
|
from einops import rearrange
|
|
7
7
|
from vector_quantize_pytorch import GroupedResidualFSQ
|
|
8
8
|
|
|
9
|
-
from .firefly import ConvNeXtBlock
|
|
9
|
+
from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
@dataclass
|
|
@@ -20,7 +20,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
|
|
|
20
20
|
def __init__(
|
|
21
21
|
self,
|
|
22
22
|
input_dim: int = 512,
|
|
23
|
-
n_codebooks: int =
|
|
23
|
+
n_codebooks: int = 9,
|
|
24
24
|
n_groups: int = 1,
|
|
25
25
|
levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
|
|
26
26
|
downsample_factor: tuple[int] = (2, 2),
|
|
@@ -46,7 +46,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
|
|
|
46
46
|
self.downsample = nn.Sequential(
|
|
47
47
|
*[
|
|
48
48
|
nn.Sequential(
|
|
49
|
-
|
|
49
|
+
FishConvNet(
|
|
50
50
|
all_dims[idx],
|
|
51
51
|
all_dims[idx + 1],
|
|
52
52
|
kernel_size=factor,
|
|
@@ -61,7 +61,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
|
|
|
61
61
|
self.upsample = nn.Sequential(
|
|
62
62
|
*[
|
|
63
63
|
nn.Sequential(
|
|
64
|
-
|
|
64
|
+
FishTransConvNet(
|
|
65
65
|
all_dims[idx + 1],
|
|
66
66
|
all_dims[idx],
|
|
67
67
|
kernel_size=factor,
|
|
@@ -114,26 +114,3 @@ class DownsampleFiniteScalarQuantize(nn.Module):
|
|
|
114
114
|
z_q = self.residual_fsq.get_output_from_indices(indices)
|
|
115
115
|
z_q = self.upsample(z_q.mT)
|
|
116
116
|
return z_q
|
|
117
|
-
|
|
118
|
-
# def from_latents(self, latents: torch.Tensor):
|
|
119
|
-
# z_q, z_p, codes = super().from_latents(latents)
|
|
120
|
-
# z_q = self.upsample(z_q)
|
|
121
|
-
# return z_q, z_p, codes
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
if __name__ == "__main__":
|
|
125
|
-
rvq = DownsampleFiniteScalarQuantize(
|
|
126
|
-
n_codebooks=1,
|
|
127
|
-
downsample_factor=(2, 2),
|
|
128
|
-
)
|
|
129
|
-
x = torch.randn(16, 512, 80)
|
|
130
|
-
|
|
131
|
-
result = rvq(x)
|
|
132
|
-
print(rvq)
|
|
133
|
-
print(result.latents.shape, result.codes.shape, result.z.shape)
|
|
134
|
-
|
|
135
|
-
# y = rvq.from_codes(result.codes)
|
|
136
|
-
# print(y[0].shape)
|
|
137
|
-
|
|
138
|
-
# y = rvq.from_latents(result.latents)
|
|
139
|
-
# print(y[0].shape)
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
# Byte-compiled / optimized / DLL files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
|
|
6
|
+
# C extensions
|
|
7
|
+
*.so
|
|
8
|
+
|
|
9
|
+
# Distribution / packaging
|
|
10
|
+
.Python
|
|
11
|
+
build/
|
|
12
|
+
develop-eggs/
|
|
13
|
+
dist/
|
|
14
|
+
downloads/
|
|
15
|
+
eggs/
|
|
16
|
+
.eggs/
|
|
17
|
+
lib/
|
|
18
|
+
lib64/
|
|
19
|
+
parts/
|
|
20
|
+
sdist/
|
|
21
|
+
var/
|
|
22
|
+
wheels/
|
|
23
|
+
*.egg-info/
|
|
24
|
+
.installed.cfg
|
|
25
|
+
*.egg
|
|
26
|
+
MANIFEST
|
|
27
|
+
|
|
28
|
+
# PyInstaller
|
|
29
|
+
# Usually these files are written by a python script from a template
|
|
30
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
31
|
+
*.manifest
|
|
32
|
+
*.spec
|
|
33
|
+
|
|
34
|
+
# Installer logs
|
|
35
|
+
pip-log.txt
|
|
36
|
+
pip-delete-this-directory.txt
|
|
37
|
+
|
|
38
|
+
# Unit test / coverage reports
|
|
39
|
+
htmlcov/
|
|
40
|
+
.tox/
|
|
41
|
+
.coverage
|
|
42
|
+
.coverage.*
|
|
43
|
+
.cache
|
|
44
|
+
nosetests.xml
|
|
45
|
+
coverage.xml
|
|
46
|
+
*.cover
|
|
47
|
+
.hypothesis/
|
|
48
|
+
.pytest_cache/
|
|
49
|
+
|
|
50
|
+
# Translations
|
|
51
|
+
*.mo
|
|
52
|
+
*.pot
|
|
53
|
+
|
|
54
|
+
# Django stuff:
|
|
55
|
+
*.log
|
|
56
|
+
local_settings.py
|
|
57
|
+
db.sqlite3
|
|
58
|
+
|
|
59
|
+
# Flask stuff:
|
|
60
|
+
instance/
|
|
61
|
+
.webassets-cache
|
|
62
|
+
|
|
63
|
+
# Scrapy stuff:
|
|
64
|
+
.scrapy
|
|
65
|
+
|
|
66
|
+
# Sphinx documentation
|
|
67
|
+
docs/_build/
|
|
68
|
+
|
|
69
|
+
# PyBuilder
|
|
70
|
+
target/
|
|
71
|
+
|
|
72
|
+
# Jupyter Notebook
|
|
73
|
+
.ipynb_checkpoints
|
|
74
|
+
|
|
75
|
+
# pyenv
|
|
76
|
+
.python-version
|
|
77
|
+
|
|
78
|
+
# celery beat schedule file
|
|
79
|
+
celerybeat-schedule
|
|
80
|
+
|
|
81
|
+
# SageMath parsed files
|
|
82
|
+
*.sage.py
|
|
83
|
+
|
|
84
|
+
# Environments
|
|
85
|
+
.env
|
|
86
|
+
.venv
|
|
87
|
+
env/
|
|
88
|
+
venv/
|
|
89
|
+
ENV/
|
|
90
|
+
env.bak/
|
|
91
|
+
venv.bak/
|
|
92
|
+
|
|
93
|
+
# Spyder project settings
|
|
94
|
+
.spyderproject
|
|
95
|
+
.spyproject
|
|
96
|
+
|
|
97
|
+
# Rope project settings
|
|
98
|
+
.ropeproject
|
|
99
|
+
|
|
100
|
+
# mkdocs documentation
|
|
101
|
+
/site
|
|
102
|
+
|
|
103
|
+
# mypy
|
|
104
|
+
.mypy_cache/
|
|
105
|
+
|
|
106
|
+
# JetBrains PyCharm
|
|
107
|
+
.idea
|
|
108
|
+
|
|
109
|
+
# Customize
|
|
110
|
+
references
|
|
111
|
+
url.txt
|
|
112
|
+
|
|
113
|
+
# Git
|
|
114
|
+
.git
|