xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/core/chat_interface.py +1 -1
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +4 -1
- xinference/core/worker.py +60 -44
- xinference/model/audio/chattts.py +25 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/cosyvoice.py +4 -3
- xinference/model/audio/custom.py +4 -5
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +25 -1
- xinference/model/embedding/custom.py +4 -5
- xinference/model/flexible/core.py +5 -1
- xinference/model/image/custom.py +4 -5
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +66 -3
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +54 -9
- xinference/model/llm/llm_family.py +7 -6
- xinference/model/llm/llm_family_modelscope.json +56 -10
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/sglang/core.py +7 -1
- xinference/model/llm/transformers/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/transformers/core.py +3 -0
- xinference/model/llm/transformers/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +94 -11
- xinference/model/llm/transformers/minicpmv25.py +2 -23
- xinference/model/llm/transformers/minicpmv26.py +2 -22
- xinference/model/llm/transformers/yi_vl.py +2 -24
- xinference/model/llm/utils.py +13 -1
- xinference/model/llm/vllm/core.py +1 -34
- xinference/model/rerank/custom.py +4 -5
- xinference/model/utils.py +41 -1
- xinference/model/video/core.py +3 -1
- xinference/model/video/diffusers.py +41 -38
- xinference/model/video/model_spec.json +24 -1
- xinference/model/video/model_spec_modelscope.json +25 -1
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/matcha/__init__.py +0 -0
- xinference/thirdparty/matcha/app.py +357 -0
- xinference/thirdparty/matcha/cli.py +419 -0
- xinference/thirdparty/matcha/data/__init__.py +0 -0
- xinference/thirdparty/matcha/data/components/__init__.py +0 -0
- xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
- xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
- xinference/thirdparty/matcha/hifigan/config.py +28 -0
- xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
- xinference/thirdparty/matcha/hifigan/env.py +17 -0
- xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
- xinference/thirdparty/matcha/hifigan/models.py +368 -0
- xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
- xinference/thirdparty/matcha/models/__init__.py +0 -0
- xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
- xinference/thirdparty/matcha/models/components/__init__.py +0 -0
- xinference/thirdparty/matcha/models/components/decoder.py +443 -0
- xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
- xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
- xinference/thirdparty/matcha/models/components/transformer.py +316 -0
- xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
- xinference/thirdparty/matcha/onnx/__init__.py +0 -0
- xinference/thirdparty/matcha/onnx/export.py +181 -0
- xinference/thirdparty/matcha/onnx/infer.py +168 -0
- xinference/thirdparty/matcha/text/__init__.py +53 -0
- xinference/thirdparty/matcha/text/cleaners.py +121 -0
- xinference/thirdparty/matcha/text/numbers.py +71 -0
- xinference/thirdparty/matcha/text/symbols.py +17 -0
- xinference/thirdparty/matcha/train.py +122 -0
- xinference/thirdparty/matcha/utils/__init__.py +5 -0
- xinference/thirdparty/matcha/utils/audio.py +82 -0
- xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
- xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
- xinference/thirdparty/matcha/utils/instantiators.py +56 -0
- xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
- xinference/thirdparty/matcha/utils/model.py +90 -0
- xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
- xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
- xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
- xinference/thirdparty/matcha/utils/pylogger.py +21 -0
- xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
- xinference/thirdparty/matcha/utils/utils.py +259 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
- xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
- xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
New file xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py (+139 lines):

```diff
@@ -0,0 +1,139 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from vector_quantize_pytorch import GroupedResidualFSQ
+
+from .firefly import ConvNeXtBlock
+
+
+@dataclass
+class FSQResult:
+    z: torch.Tensor
+    codes: torch.Tensor
+    latents: torch.Tensor
+
+
+class DownsampleFiniteScalarQuantize(nn.Module):
+    def __init__(
+        self,
+        input_dim: int = 512,
+        n_codebooks: int = 1,
+        n_groups: int = 1,
+        levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10
+        downsample_factor: tuple[int] = (2, 2),
+        downsample_dims: tuple[int] | None = None,
+    ):
+        super().__init__()
+
+        if downsample_dims is None:
+            downsample_dims = [input_dim for _ in range(len(downsample_factor))]
+
+        all_dims = (input_dim,) + tuple(downsample_dims)
+
+        self.residual_fsq = GroupedResidualFSQ(
+            dim=all_dims[-1],
+            levels=levels,
+            num_quantizers=n_codebooks,
+            groups=n_groups,
+        )
+
+        self.downsample_factor = downsample_factor
+        self.downsample_dims = downsample_dims
+
+        self.downsample = nn.Sequential(
+            *[
+                nn.Sequential(
+                    nn.Conv1d(
+                        all_dims[idx],
+                        all_dims[idx + 1],
+                        kernel_size=factor,
+                        stride=factor,
+                    ),
+                    ConvNeXtBlock(dim=all_dims[idx + 1]),
+                )
+                for idx, factor in enumerate(downsample_factor)
+            ]
+        )
+
+        self.upsample = nn.Sequential(
+            *[
+                nn.Sequential(
+                    nn.ConvTranspose1d(
+                        all_dims[idx + 1],
+                        all_dims[idx],
+                        kernel_size=factor,
+                        stride=factor,
+                    ),
+                    ConvNeXtBlock(dim=all_dims[idx]),
+                )
+                for idx, factor in reversed(list(enumerate(downsample_factor)))
+            ]
+        )
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv1d, nn.Linear)):
+            nn.init.trunc_normal_(m.weight, std=0.02)
+            nn.init.constant_(m.bias, 0)
+
+    def forward(self, z) -> FSQResult:
+        original_shape = z.shape
+        z = self.downsample(z)
+        quantized, indices = self.residual_fsq(z.mT)
+        result = FSQResult(
+            z=quantized.mT,
+            codes=indices.mT,
+            latents=z,
+        )
+        result.z = self.upsample(result.z)
+
+        # Pad or crop z to match original shape
+        diff = original_shape[-1] - result.z.shape[-1]
+        left = diff // 2
+        right = diff - left
+
+        if diff > 0:
+            result.z = F.pad(result.z, (left, right))
+        elif diff < 0:
+            result.z = result.z[..., left:-right]
+
+        return result
+
+    def encode(self, z):
+        z = self.downsample(z)
+        _, indices = self.residual_fsq(z.mT)
+        indices = rearrange(indices, "g b l r -> b (g r) l")
+        return indices
+
+    def decode(self, indices: torch.Tensor):
+        indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
+        z_q = self.residual_fsq.get_output_from_indices(indices)
+        z_q = self.upsample(z_q.mT)
+        return z_q
+
+    # def from_latents(self, latents: torch.Tensor):
+    #     z_q, z_p, codes = super().from_latents(latents)
+    #     z_q = self.upsample(z_q)
+    #     return z_q, z_p, codes
+
+
+if __name__ == "__main__":
+    rvq = DownsampleFiniteScalarQuantize(
+        n_codebooks=1,
+        downsample_factor=(2, 2),
+    )
+    x = torch.randn(16, 512, 80)
+
+    result = rvq(x)
+    print(rvq)
+    print(result.latents.shape, result.codes.shape, result.z.shape)
+
+    # y = rvq.from_codes(result.codes)
+    # print(y[0].shape)
+
+    # y = rvq.from_latents(result.latents)
+    # print(y[0].shape)
```
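For orientation, here is a minimal usage sketch (not part of the diff) of the encode/decode round trip that `DownsampleFiniteScalarQuantize` exposes. The import path simply mirrors the wheel layout above; it assumes `vector_quantize_pytorch` and the vendored `firefly` module are importable, and all shapes are illustrative.

```python
# Sketch only: round-trip continuous features through discrete FSQ codes.
import torch

from xinference.thirdparty.fish_speech.fish_speech.models.vqgan.modules.fsq import (
    DownsampleFiniteScalarQuantize,
)

quantizer = DownsampleFiniteScalarQuantize(input_dim=512, downsample_factor=(2, 2))
features = torch.randn(2, 512, 80)  # (batch, channels, frames)

codes = quantizer.encode(features)  # discrete indices, (batch, n_groups * n_codebooks, frames / 4)
recon = quantizer.decode(codes)     # continuous features recovered from the codes
print(codes.shape, recon.shape)
```

With `downsample_factor=(2, 2)` the code sequence is 4× shorter than the input, and `forward` pads or crops the reconstruction back to the original length.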
New file xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py (+115 lines):

```diff
@@ -0,0 +1,115 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from fish_speech.utils import autocast_exclude_mps
+
+from .wavenet import WaveNet
+
+
+class ReferenceEncoder(WaveNet):
+    def __init__(
+        self,
+        input_channels: Optional[int] = None,
+        output_channels: Optional[int] = None,
+        residual_channels: int = 512,
+        residual_layers: int = 20,
+        dilation_cycle: Optional[int] = 4,
+        num_heads: int = 8,
+        latent_len: int = 4,
+    ):
+        super().__init__(
+            input_channels=input_channels,
+            residual_channels=residual_channels,
+            residual_layers=residual_layers,
+            dilation_cycle=dilation_cycle,
+        )
+
+        self.head_dim = residual_channels // num_heads
+        self.num_heads = num_heads
+
+        self.latent_len = latent_len
+        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels))
+
+        self.q = nn.Linear(residual_channels, residual_channels, bias=True)
+        self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True)
+        self.q_norm = nn.LayerNorm(self.head_dim)
+        self.k_norm = nn.LayerNorm(self.head_dim)
+        self.proj = nn.Linear(residual_channels, residual_channels)
+        self.proj_drop = nn.Dropout(0.1)
+
+        self.norm = nn.LayerNorm(residual_channels)
+        self.mlp = nn.Sequential(
+            nn.Linear(residual_channels, residual_channels * 4),
+            nn.SiLU(),
+            nn.Linear(residual_channels * 4, residual_channels),
+        )
+        self.output_projection_attn = nn.Linear(residual_channels, output_channels)
+
+        torch.nn.init.trunc_normal_(self.latent, std=0.02)
+        self.apply(self.init_weights)
+
+    def init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            torch.nn.init.trunc_normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                torch.nn.init.constant_(m.bias, 0)
+
+    def forward(self, x, attn_mask=None):
+        x = super().forward(x).mT
+        B, N, C = x.shape
+
+        # Calculate mask
+        if attn_mask is not None:
+            assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool
+
+            attn_mask = attn_mask[:, None, None, :].expand(
+                B, self.num_heads, self.latent_len, N
+            )
+
+        q_latent = self.latent.expand(B, -1, -1)
+        q = (
+            self.q(q_latent)
+            .reshape(B, self.latent_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+
+        kv = (
+            self.kv(x)
+            .reshape(B, N, 2, self.num_heads, self.head_dim)
+            .permute(2, 0, 3, 1, 4)
+        )
+        k, v = kv.unbind(0)
+
+        q, k = self.q_norm(q), self.k_norm(k)
+        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+
+        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        x = x + self.mlp(self.norm(x))
+        x = self.output_projection_attn(x)
+        x = x.mean(1)
+
+        return x
+
+
+if __name__ == "__main__":
+    with autocast_exclude_mps(device_type="cpu", dtype=torch.bfloat16):
+        model = ReferenceEncoder(
+            input_channels=128,
+            output_channels=64,
+            residual_channels=384,
+            residual_layers=20,
+            dilation_cycle=4,
+            num_heads=8,
+        )
+        x = torch.randn(4, 128, 64)
+        mask = torch.ones(4, 64, dtype=torch.bool)
+        y = model(x, mask)
+        print(y.shape)
+        loss = F.mse_loss(y, torch.randn(4, 64))
+        loss.backward()
```
New file xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py (+225 lines):

```diff
@@ -0,0 +1,225 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Mish(nn.Module):
+    def forward(self, x):
+        return x * torch.tanh(F.softplus(x))
+
+
+class DiffusionEmbedding(nn.Module):
+    """Diffusion Step Embedding"""
+
+    def __init__(self, d_denoiser):
+        super(DiffusionEmbedding, self).__init__()
+        self.dim = d_denoiser
+
+    def forward(self, x):
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+        emb = x[:, None] * emb[None, :]
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
+class LinearNorm(nn.Module):
+    """LinearNorm Projection"""
+
+    def __init__(self, in_features, out_features, bias=False):
+        super(LinearNorm, self).__init__()
+        self.linear = nn.Linear(in_features, out_features, bias)
+
+        nn.init.xavier_uniform_(self.linear.weight)
+        if bias:
+            nn.init.constant_(self.linear.bias, 0.0)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return x
+
+
+class ConvNorm(nn.Module):
+    """1D Convolution"""
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size=1,
+        stride=1,
+        padding=None,
+        dilation=1,
+        bias=True,
+        w_init_gain="linear",
+    ):
+        super(ConvNorm, self).__init__()
+
+        if padding is None:
+            assert kernel_size % 2 == 1
+            padding = int(dilation * (kernel_size - 1) / 2)
+
+        self.conv = nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias,
+        )
+        nn.init.kaiming_normal_(self.conv.weight)
+
+    def forward(self, signal):
+        conv_signal = self.conv(signal)
+
+        return conv_signal
+
+
+class ResidualBlock(nn.Module):
+    """Residual Block"""
+
+    def __init__(
+        self,
+        residual_channels,
+        use_linear_bias=False,
+        dilation=1,
+        condition_channels=None,
+    ):
+        super(ResidualBlock, self).__init__()
+        self.conv_layer = ConvNorm(
+            residual_channels,
+            2 * residual_channels,
+            kernel_size=3,
+            stride=1,
+            padding=dilation,
+            dilation=dilation,
+        )
+
+        if condition_channels is not None:
+            self.diffusion_projection = LinearNorm(
+                residual_channels, residual_channels, use_linear_bias
+            )
+            self.condition_projection = ConvNorm(
+                condition_channels, 2 * residual_channels, kernel_size=1
+            )
+
+        self.output_projection = ConvNorm(
+            residual_channels, 2 * residual_channels, kernel_size=1
+        )
+
+    def forward(self, x, condition=None, diffusion_step=None):
+        y = x
+
+        if diffusion_step is not None:
+            diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
+            y = y + diffusion_step
+
+        y = self.conv_layer(y)
+
+        if condition is not None:
+            condition = self.condition_projection(condition)
+            y = y + condition
+
+        gate, filter = torch.chunk(y, 2, dim=1)
+        y = torch.sigmoid(gate) * torch.tanh(filter)
+
+        y = self.output_projection(y)
+        residual, skip = torch.chunk(y, 2, dim=1)
+
+        return (x + residual) / math.sqrt(2.0), skip
+
+
+class WaveNet(nn.Module):
+    def __init__(
+        self,
+        input_channels: Optional[int] = None,
+        output_channels: Optional[int] = None,
+        residual_channels: int = 512,
+        residual_layers: int = 20,
+        dilation_cycle: Optional[int] = 4,
+        is_diffusion: bool = False,
+        condition_channels: Optional[int] = None,
+    ):
+        super().__init__()
+
+        # Input projection
+        self.input_projection = None
+        if input_channels is not None and input_channels != residual_channels:
+            self.input_projection = ConvNorm(
+                input_channels, residual_channels, kernel_size=1
+            )
+
+        if input_channels is None:
+            input_channels = residual_channels
+
+        self.input_channels = input_channels
+
+        # Residual layers
+        self.residual_layers = nn.ModuleList(
+            [
+                ResidualBlock(
+                    residual_channels=residual_channels,
+                    use_linear_bias=False,
+                    dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
+                    condition_channels=condition_channels,
+                )
+                for i in range(residual_layers)
+            ]
+        )
+
+        # Skip projection
+        self.skip_projection = ConvNorm(
+            residual_channels, residual_channels, kernel_size=1
+        )
+
+        # Output projection
+        self.output_projection = None
+        if output_channels is not None and output_channels != residual_channels:
+            self.output_projection = ConvNorm(
+                residual_channels, output_channels, kernel_size=1
+            )
+
+        if is_diffusion:
+            self.diffusion_embedding = DiffusionEmbedding(residual_channels)
+            self.mlp = nn.Sequential(
+                LinearNorm(residual_channels, residual_channels * 4, False),
+                Mish(),
+                LinearNorm(residual_channels * 4, residual_channels, False),
+            )
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv1d, nn.Linear)):
+            nn.init.trunc_normal_(m.weight, std=0.02)
+            if getattr(m, "bias", None) is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x, t=None, condition=None):
+        if self.input_projection is not None:
+            x = self.input_projection(x)
+            x = F.silu(x)
+
+        if t is not None:
+            t = self.diffusion_embedding(t)
+            t = self.mlp(t)
+
+        skip = []
+        for layer in self.residual_layers:
+            x, skip_connection = layer(x, condition, t)
+            skip.append(skip_connection)
+
+        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
+        x = self.skip_projection(x)
+
+        if self.output_projection is not None:
+            x = F.silu(x)
+            x = self.output_projection(x)
+
+        return x
```
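A minimal sketch (not part of the diff) of driving `WaveNet` in diffusion mode: `is_diffusion=True` builds the `DiffusionEmbedding` and `Mish` MLP for the timestep, and `condition_channels` gives each `ResidualBlock` a conditioning projection. Channel sizes and shapes below are illustrative assumptions.

```python
# Sketch only: WaveNet as a conditional diffusion denoiser.
import torch

from xinference.thirdparty.fish_speech.fish_speech.models.vqgan.modules.wavenet import WaveNet

net = WaveNet(
    input_channels=128,     # projected up to residual_channels internally
    output_channels=128,
    residual_channels=256,
    residual_layers=8,
    dilation_cycle=4,
    is_diffusion=True,      # builds DiffusionEmbedding + Mish MLP for t
    condition_channels=64,  # each ResidualBlock gains a condition projection
)

x = torch.randn(4, 128, 100)       # noisy input (batch, channels, frames)
t = torch.randint(0, 1000, (4,))   # diffusion timesteps
cond = torch.randn(4, 64, 100)     # conditioning features
eps = net(x, t=t, condition=cond)  # prediction, (4, 128, 100)
print(eps.shape)
```

Note that each block's `diffusion_projection` is only created when `condition_channels` is set, so passing a timestep without conditioning would fail; `ReferenceEncoder` above sidesteps this by subclassing `WaveNet` without either.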
New file xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py (+94 lines):

```diff
@@ -0,0 +1,94 @@
+import matplotlib
+import torch
+from matplotlib import pyplot as plt
+
+matplotlib.use("Agg")
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def plot_mel(data, titles=None):
+    fig, axes = plt.subplots(len(data), 1, squeeze=False)
+
+    if titles is None:
+        titles = [None for i in range(len(data))]
+
+    plt.tight_layout()
+
+    for i in range(len(data)):
+        mel = data[i]
+
+        if isinstance(mel, torch.Tensor):
+            mel = mel.float().detach().cpu().numpy()
+
+        axes[i][0].imshow(mel, origin="lower")
+        axes[i][0].set_aspect(2.5, adjustable="box")
+        axes[i][0].set_ylim(0, mel.shape[0])
+        axes[i][0].set_title(titles[i], fontsize="medium")
+        axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
+        axes[i][0].set_anchor("W")
+
+    return fig
+
+
+def slice_segments(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, :, idx_str:idx_end]
+
+    return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+    b, d, t = x.size()
+    if x_lengths is None:
+        x_lengths = t
+    ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
+    ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
+    ret = slice_segments(x, ids_str, segment_size)
+    return ret, ids_str
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
+    n_channels_int = n_channels[0]
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+
+    return acts
+
+
+def avg_with_mask(x, mask):
+    assert mask.dtype == torch.float, "Mask should be float"
+
+    if mask.ndim == 2:
+        mask = mask.unsqueeze(1)
+
+    if mask.shape[1] == 1:
+        mask = mask.expand_as(x)
+
+    return (x * mask).sum() / mask.sum()
```
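The masking helpers compose naturally; here is a small sketch (not part of the diff, shapes illustrative) of building a boolean mask from sequence lengths and averaging a loss map over only the valid frames:

```python
# Sketch only: length mask -> masked mean over valid frames.
import torch

from xinference.thirdparty.fish_speech.fish_speech.models.vqgan.utils import (
    avg_with_mask,
    sequence_mask,
)

lengths = torch.tensor([100, 60, 80])
mask = sequence_mask(lengths, max_length=100)  # (3, 100) bool, True inside each length
loss_map = torch.randn(3, 128, 100).abs()      # e.g. per-bin reconstruction error

# avg_with_mask expects a float mask and broadcasts (B, T) -> (B, 1, T) -> (B, C, T)
loss = avg_with_mask(loss_map, mask.float())
print(loss)
```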
New file xinference/thirdparty/fish_speech/fish_speech/scheduler.py (+40 lines):

```diff
@@ -0,0 +1,40 @@
+import math
+
+
+def get_cosine_schedule_with_warmup_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int | float,
+    num_training_steps: int,
+    num_cycles: float = 0.5,
+    final_lr_ratio: float = 0.0,
+):
+    if 0 < num_warmup_steps < 1:  # float mode
+        num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+
+    progress = float(current_step - num_warmup_steps) / float(
+        max(1, num_training_steps - num_warmup_steps)
+    )
+
+    return max(
+        final_lr_ratio,
+        0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
+    )
+
+
+def get_constant_schedule_with_warmup_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int | float,
+    num_training_steps: int | None = None,
+):
+    if 0 < num_warmup_steps < 1:  # float mode
+        num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+
+    return 1.0
```
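Both functions take `current_step` first with everything else keyword-only, so they are shaped to be bound with `functools.partial` and handed to `torch.optim.lr_scheduler.LambdaLR`. A sketch of that wiring follows (not part of the diff; the optimizer and hyperparameters are illustrative):

```python
# Sketch only: bind the keyword-only schedule args, then let LambdaLR
# call the lambda with the current step.
from functools import partial

import torch

from xinference.thirdparty.fish_speech.fish_speech.scheduler import (
    get_cosine_schedule_with_warmup_lr_lambda,
)

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

lr_lambda = partial(
    get_cosine_schedule_with_warmup_lr_lambda,
    num_warmup_steps=0.1,       # float mode: warm up over 10% of training
    num_training_steps=10_000,
    final_lr_ratio=0.1,         # cosine never decays below 10% of the base lr
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for step in range(3):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())
```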