xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/core/chat_interface.py +1 -1
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +4 -1
- xinference/core/worker.py +60 -44
- xinference/model/audio/chattts.py +25 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/cosyvoice.py +4 -3
- xinference/model/audio/custom.py +4 -5
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +25 -1
- xinference/model/embedding/custom.py +4 -5
- xinference/model/flexible/core.py +5 -1
- xinference/model/image/custom.py +4 -5
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +66 -3
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +54 -9
- xinference/model/llm/llm_family.py +7 -6
- xinference/model/llm/llm_family_modelscope.json +56 -10
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/sglang/core.py +7 -1
- xinference/model/llm/transformers/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/transformers/core.py +3 -0
- xinference/model/llm/transformers/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +94 -11
- xinference/model/llm/transformers/minicpmv25.py +2 -23
- xinference/model/llm/transformers/minicpmv26.py +2 -22
- xinference/model/llm/transformers/yi_vl.py +2 -24
- xinference/model/llm/utils.py +13 -1
- xinference/model/llm/vllm/core.py +1 -34
- xinference/model/rerank/custom.py +4 -5
- xinference/model/utils.py +41 -1
- xinference/model/video/core.py +3 -1
- xinference/model/video/diffusers.py +41 -38
- xinference/model/video/model_spec.json +24 -1
- xinference/model/video/model_spec_modelscope.json +25 -1
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/matcha/__init__.py +0 -0
- xinference/thirdparty/matcha/app.py +357 -0
- xinference/thirdparty/matcha/cli.py +419 -0
- xinference/thirdparty/matcha/data/__init__.py +0 -0
- xinference/thirdparty/matcha/data/components/__init__.py +0 -0
- xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
- xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
- xinference/thirdparty/matcha/hifigan/config.py +28 -0
- xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
- xinference/thirdparty/matcha/hifigan/env.py +17 -0
- xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
- xinference/thirdparty/matcha/hifigan/models.py +368 -0
- xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
- xinference/thirdparty/matcha/models/__init__.py +0 -0
- xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
- xinference/thirdparty/matcha/models/components/__init__.py +0 -0
- xinference/thirdparty/matcha/models/components/decoder.py +443 -0
- xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
- xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
- xinference/thirdparty/matcha/models/components/transformer.py +316 -0
- xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
- xinference/thirdparty/matcha/onnx/__init__.py +0 -0
- xinference/thirdparty/matcha/onnx/export.py +181 -0
- xinference/thirdparty/matcha/onnx/infer.py +168 -0
- xinference/thirdparty/matcha/text/__init__.py +53 -0
- xinference/thirdparty/matcha/text/cleaners.py +121 -0
- xinference/thirdparty/matcha/text/numbers.py +71 -0
- xinference/thirdparty/matcha/text/symbols.py +17 -0
- xinference/thirdparty/matcha/train.py +122 -0
- xinference/thirdparty/matcha/utils/__init__.py +5 -0
- xinference/thirdparty/matcha/utils/audio.py +82 -0
- xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
- xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
- xinference/thirdparty/matcha/utils/instantiators.py +56 -0
- xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
- xinference/thirdparty/matcha/utils/model.py +90 -0
- xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
- xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
- xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
- xinference/thirdparty/matcha/utils/pylogger.py +21 -0
- xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
- xinference/thirdparty/matcha/utils/utils.py +259 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
- xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
- xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
""" from https://github.com/jik876/hifi-gan """
|
|
2
|
+
|
|
3
|
+
import glob
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
import matplotlib
|
|
7
|
+
import torch
|
|
8
|
+
from torch.nn.utils import weight_norm
|
|
9
|
+
|
|
10
|
+
# Switch matplotlib to the "Agg" backend before pyplot is imported below; Agg
# renders into in-memory buffers and does not need a GUI/display.
matplotlib.use("Agg")
|
|
11
|
+
import matplotlib.pylab as plt
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def plot_spectrogram(spectrogram):
    """Draw ``spectrogram`` as an image with a colorbar and return the figure.

    The data is shown with ``imshow`` using ``aspect="auto"``, ``origin="lower"``
    and no interpolation on a 10x2-inch figure. The canvas is rendered once,
    and the current pyplot figure is then closed before the figure object is
    returned.

    Args:
        spectrogram: anything ``Axes.imshow`` accepts; presumably a 2-D
            (frequency, time) array -- TODO confirm against callers.

    Returns:
        The ``matplotlib.figure.Figure`` that was drawn.
    """
    figure, axes = plt.subplots(figsize=(10, 2))
    image = axes.imshow(
        spectrogram,
        aspect="auto",
        origin="lower",
        interpolation="none",
    )
    plt.colorbar(image, ax=axes)

    # Render eagerly; presumably so callers can still read the canvas after the
    # figure is released from pyplot's figure manager on the next line.
    figure.canvas.draw()
    plt.close()

    return figure
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def init_weights(m, mean=0.0, std=0.01):
    """Re-draw the weights of convolution modules from N(mean, std), in place.

    Only modules whose class name contains ``"Conv"`` are touched; every other
    module is returned from immediately. Because it takes a single module, it
    can be used as a per-submodule callback (e.g. with ``nn.Module.apply``).

    Args:
        m: the module to (maybe) initialise.
        mean: mean of the normal distribution.
        std: standard deviation of the normal distribution.
    """
    if "Conv" not in m.__class__.__name__:
        return
    m.weight.data.normal_(mean, std)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def apply_weight_norm(m):
    """Apply ``torch.nn.utils.weight_norm`` to ``m`` if it is a convolution.

    A module qualifies when its class name contains ``"Conv"``; any other
    module is left unchanged.

    Args:
        m: the module to (maybe) reparameterise in place.
    """
    is_convolution = "Conv" in m.__class__.__name__
    if is_convolution:
        weight_norm(m)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_padding(kernel_size, dilation=1):
    """Return ``(kernel_size * dilation - dilation) / 2`` truncated to an int.

    For a stride-1 convolution with an odd ``kernel_size`` this is the
    symmetric padding that keeps the output length equal to the input length.

    Args:
        kernel_size: convolution kernel size.
        dilation: convolution dilation (default 1).

    Returns:
        The padding as an ``int`` (truncated toward zero).
    """
    receptive_extra = kernel_size * dilation - dilation
    half = receptive_extra / 2
    return int(half)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def load_checkpoint(filepath, device):
    """Load a checkpoint from ``filepath`` with ``torch.load``.

    Progress is printed before and after loading.

    Args:
        filepath: path of the checkpoint file.
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object deserialised by ``torch.load``.

    Raises:
        FileNotFoundError: if ``filepath`` is not an existing regular file.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, and in that mode the missing file already surfaced as
    # FileNotFoundError from torch.load.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Checkpoint file not found: '{filepath}'")
    print(f"Loading '{filepath}'")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def save_checkpoint(filepath, obj):
    """Write ``obj`` to ``filepath`` via ``torch.save``, printing progress.

    Args:
        filepath: destination path.
        obj: object to serialise (anything ``torch.save`` accepts).
    """
    announce = f"Saving checkpoint to {filepath}"
    print(announce)
    torch.save(obj, filepath)
    print("Complete.")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def scan_checkpoint(cp_dir, prefix):
    """Find the lexicographically greatest checkpoint file in ``cp_dir``.

    A file matches when its name is ``prefix`` followed by exactly eight more
    characters (for example a zero-padded step number such as ``g_00012345``).

    Args:
        cp_dir: directory to search.
        prefix: literal file-name prefix of the checkpoints.

    Returns:
        The path of the greatest matching file, or ``None`` if nothing matches.
    """
    # Escape the literal parts so directory names / prefixes containing glob
    # metacharacters ("[", "*", "?") are matched verbatim instead of as patterns.
    pattern = os.path.join(glob.escape(cp_dir), glob.escape(prefix) + "????????")
    candidates = glob.glob(pattern)
    if not candidates:
        return None
    return max(candidates)
|
|
File without changes
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This is a base lightning module that can be used to train a model.
|
|
3
|
+
The benefit of this abstraction is that all the logic outside of model definition can be reused for different models.
|
|
4
|
+
"""
|
|
5
|
+
import inspect
|
|
6
|
+
from abc import ABC
|
|
7
|
+
from typing import Any, Dict
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from lightning import LightningModule
|
|
11
|
+
from lightning.pytorch.utilities import grad_norm
|
|
12
|
+
|
|
13
|
+
from matcha import utils
|
|
14
|
+
from matcha.utils.utils import plot_tensor
|
|
15
|
+
|
|
16
|
+
# Module logger created through matcha's ``utils.get_pylogger`` helper; used for
# the debug messages emitted while plotting validation samples.
log = utils.get_pylogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BaseLightningClass(LightningModule, ABC):
    """Shared Lightning training / validation logic for Matcha-style models.

    From the code below, subclasses are expected to provide:

    * ``forward`` (invoked as ``self(...)`` in :meth:`get_losses`) returning at
      least ``(dur_loss, prior_loss, diff_loss, ...)``;
    * an ``out_size`` attribute;
    * ``synthesise(x, x_lengths, n_timesteps=..., spks=...)`` returning a dict
      with ``"encoder_outputs"``, ``"decoder_outputs"`` and ``"attn"``;
    * ``self.hparams.optimizer`` / ``self.hparams.scheduler`` entries consumed
      by :meth:`configure_optimizers`.
    """

    # Names of the individual loss terms produced by ``get_losses``.
    _SUB_LOSS_NAMES = ("dur_loss", "prior_loss", "diff_loss")

    def update_data_statistics(self, data_statistics):
        """Register ``mel_mean`` and ``mel_std`` as module buffers.

        Args:
            data_statistics: mapping with ``"mel_mean"`` and ``"mel_std"``
                entries, or ``None`` to use mean 0.0 and std 1.0.
        """
        if data_statistics is None:
            data_statistics = {
                "mel_mean": 0.0,
                "mel_std": 1.0,
            }

        self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
        self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))

    def configure_optimizers(self) -> Any:
        """Build the optimizer and, when configured, its LR scheduler.

        ``self.hparams.optimizer`` is called with ``params=self.parameters()``
        (presumably a partially-applied optimizer class -- TODO confirm).
        If ``self.hparams.scheduler`` is neither ``None`` nor ``{}``, its
        ``scheduler`` attribute is called with ``optimizer=`` and its
        ``lightning_args`` supply the Lightning ``interval`` and ``frequency``.

        Returns:
            A Lightning optimizer-configuration dict.
        """
        optimizer = self.hparams.optimizer(params=self.parameters())
        if self.hparams.scheduler in (None, {}):
            return {"optimizer": optimizer}

        scheduler_factory = self.hparams.scheduler.scheduler
        takes_last_epoch = "last_epoch" in inspect.signature(scheduler_factory).parameters
        scheduler = scheduler_factory(optimizer=optimizer)

        # Manage last epoch for exponential schedulers. Only schedulers that
        # accept ``last_epoch`` get it set; previously other schedulers hit an
        # unbound ``current_epoch`` here and crashed with UnboundLocalError.
        if takes_last_epoch:
            if hasattr(self, "ckpt_loaded_epoch"):
                scheduler.last_epoch = self.ckpt_loaded_epoch - 1
            else:
                scheduler.last_epoch = -1

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": self.hparams.scheduler.lightning_args.interval,
                "frequency": self.hparams.scheduler.lightning_args.frequency,
                "name": "learning_rate",
            },
        }

    def get_losses(self, batch):
        """Run the model on ``batch`` and return its three loss terms.

        Args:
            batch: mapping with ``"x"``, ``"x_lengths"``, ``"y"``,
                ``"y_lengths"``, ``"spks"`` and ``"durations"`` entries.

        Returns:
            Dict with ``"dur_loss"``, ``"prior_loss"`` and ``"diff_loss"``.
        """
        x, x_lengths = batch["x"], batch["x_lengths"]
        y, y_lengths = batch["y"], batch["y_lengths"]
        spks = batch["spks"]

        dur_loss, prior_loss, diff_loss, *_ = self(
            x=x,
            x_lengths=x_lengths,
            y=y,
            y_lengths=y_lengths,
            spks=spks,
            out_size=self.out_size,
            durations=batch["durations"],
        )
        return {
            "dur_loss": dur_loss,
            "prior_loss": prior_loss,
            "diff_loss": diff_loss,
        }

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Remember the checkpoint's epoch for :meth:`configure_optimizers`."""
        self.ckpt_loaded_epoch = checkpoint["epoch"]  # pylint: disable=attribute-defined-outside-init

    def _log_sub_losses(self, stage: str, loss_dict: Dict[str, Any]) -> None:
        """Log each loss term as ``sub_loss/{stage}_{name}``."""
        for name in self._SUB_LOSS_NAMES:
            self.log(
                f"sub_loss/{stage}_{name}",
                loss_dict[name],
                on_step=True,
                on_epoch=True,
                logger=True,
                sync_dist=True,
            )

    def training_step(self, batch: Any, batch_idx: int):
        """Compute and log the training losses for one batch.

        Returns:
            ``{"loss": total_loss, "log": loss_dict}`` where ``total_loss`` is
            the sum of the individual loss terms.
        """
        loss_dict = self.get_losses(batch)
        self.log(
            "step",
            float(self.global_step),
            on_step=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        self._log_sub_losses("train", loss_dict)

        total_loss = sum(loss_dict.values())
        self.log(
            "loss/train",
            total_loss,
            on_step=True,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )

        return {"loss": total_loss, "log": loss_dict}

    def validation_step(self, batch: Any, batch_idx: int):
        """Compute and log the validation losses; return their sum."""
        loss_dict = self.get_losses(batch)
        self._log_sub_losses("val", loss_dict)

        total_loss = sum(loss_dict.values())
        self.log(
            "loss/val",
            total_loss,
            on_step=True,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )

        return total_loss

    def _log_plot(self, tag: str, tensor: torch.Tensor) -> None:
        """Send ``plot_tensor`` of ``tensor`` to the logger as an HWC image."""
        self.logger.experiment.add_image(
            tag,
            plot_tensor(tensor.squeeze().cpu()),
            self.current_epoch,
            dataformats="HWC",
        )

    def on_validation_end(self) -> None:
        """On the global-zero process, plot samples from the first val batch.

        On epoch 0 the ground-truth ``y`` of up to two samples is plotted. Every
        epoch, up to two samples are synthesised and their encoder output,
        decoder output and alignment are plotted.
        """
        if not self.trainer.is_global_zero:
            return

        one_batch = next(iter(self.trainer.val_dataloaders))
        # Plot at most two samples; the first batch may contain fewer.
        n_samples = min(2, len(one_batch["x"]))

        if self.current_epoch == 0:
            log.debug("Plotting original samples")
            for i in range(min(2, len(one_batch["y"]))):
                y = one_batch["y"][i].unsqueeze(0).to(self.device)
                self._log_plot(f"original/{i}", y)

        log.debug("Synthesising...")
        for i in range(n_samples):
            x = one_batch["x"][i].unsqueeze(0).to(self.device)
            x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
            spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
            output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
            y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
            attn = output["attn"]
            self._log_plot(f"generated_enc/{i}", y_enc)
            self._log_plot(f"generated_dec/{i}", y_dec)
            self._log_plot(f"alignment/{i}", attn)

    def on_before_optimizer_step(self, optimizer):
        """Log the per-parameter 2-norm of the gradients before each step."""
        self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})
|
|
File without changes
|