minicpmo-utils 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
matcha/train.py
ADDED
@@ -0,0 +1,122 @@
from typing import Any, Dict, List, Optional, Tuple

import hydra
import lightning as L
import rootutils
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from matcha import utils

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
# (so you don't need to force user to install project as a package)
# (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
# (which is used as a base for paths in "configs/paths/default.yaml")
# (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #


log = utils.get_pylogger(__name__)


@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
    training.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: A DictConfig configuration composed by Hydra.
    :return: A tuple with metrics and dict with all instantiated objects.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")  # pylint: disable=protected-access
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")  # pylint: disable=protected-access
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")  # pylint: disable=protected-access
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Main entry point for training.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Optional[float] with optimized metric value.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    utils.extras(cfg)

    # train the model
    metric_dict, _ = train(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric"))

    # return optimized metric
    return metric_value


if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
matcha/utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
from matcha.utils.logging_utils import log_hyperparameters
from matcha.utils.pylogger import get_pylogger
from matcha.utils.rich_utils import enforce_tags, print_config_tree
from matcha.utils.utils import extras, get_metric_value, task_wrapper
matcha/utils/audio.py
ADDED
@@ -0,0 +1,82 @@
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window  # pylint: disable=global-statement
    if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
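Note: the following is a minimal usage sketch of the `mel_spectrogram` helper added above, assuming the wheel and its dependencies (torch, librosa, scipy) are installed. The STFT/mel settings are illustrative values chosen for the example, not values shipped by this package.

import torch
from matcha.utils.audio import mel_spectrogram

# one second of fake mono audio in [-1, 1], shaped (batch, samples)
y = torch.rand(1, 22050) * 2 - 1

# illustrative hyperparameters; the package's own configs define the real ones
mel = mel_spectrogram(
    y, n_fft=1024, num_mels=80, sampling_rate=22050,
    hop_size=256, win_size=1024, fmin=0, fmax=8000,
)
print(mel.shape)  # (1, 80, n_frames): log-compressed mel spectrogram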
matcha/utils/generate_data_statistics.py
ADDED
@@ -0,0 +1,111 @@
r"""
The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
when needed.

Parameters from hparam.py will be used
"""
import argparse
import json
import os
import sys
from pathlib import Path

import rootutils
import torch
from hydra import compose, initialize
from omegaconf import open_dict
from tqdm.auto import tqdm

from matcha.data.text_mel_datamodule import TextMelDataModule
from matcha.utils.logging_utils import pylogger

log = pylogger.get_pylogger(__name__)


def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
    """Generate data mean and standard deviation helpful in data normalisation

    Args:
        data_loader (torch.utils.data.Dataloader): _description_
        out_channels (int): mel spectrogram channels
    """
    total_mel_sum = 0
    total_mel_sq_sum = 0
    total_mel_len = 0

    for batch in tqdm(data_loader, leave=False):
        mels = batch["y"]
        mel_lengths = batch["y_lengths"]

        total_mel_len += torch.sum(mel_lengths)
        total_mel_sum += torch.sum(mels)
        total_mel_sq_sum += torch.sum(torch.pow(mels, 2))

    data_mean = total_mel_sum / (total_mel_len * out_channels)
    data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2))

    return {"mel_mean": data_mean.item(), "mel_std": data_std.item()}


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-i",
        "--input-config",
        type=str,
        default="vctk.yaml",
        help="The name of the yaml config file under configs/data",
    )

    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default="256",
        help="Can have increased batch size for faster computation",
    )

    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        required=False,
        help="force overwrite the file",
    )
    args = parser.parse_args()
    output_file = Path(args.input_config).with_suffix(".json")

    if os.path.exists(output_file) and not args.force:
        print("File already exists. Use -f to force overwrite")
        sys.exit(1)

    with initialize(version_base="1.3", config_path="../../configs/data"):
        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])

    root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

    with open_dict(cfg):
        del cfg["hydra"]
        del cfg["_target_"]
        cfg["data_statistics"] = None
        cfg["seed"] = 1234
        cfg["batch_size"] = args.batch_size
        cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
        cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))

    text_mel_datamodule = TextMelDataModule(**cfg)
    text_mel_datamodule.setup()
    data_loader = text_mel_datamodule.train_dataloader()
    log.info("Dataloader loaded! Now computing stats...")
    params = compute_data_statistics(data_loader, cfg["n_feats"])
    print(params)
    json.dump(
        params,
        open(output_file, "w"),
    )


if __name__ == "__main__":
    main()
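Note: `compute_data_statistics` above accumulates sums of x and x squared, i.e. mean = sum(x) / N and std = sqrt(sum(x^2)/N - mean^2) with N = total frames times mel channels. A small self-contained check of that accumulation on dummy tensors (not tied to the real datamodule; shapes and frame counts are made up):

import torch

torch.manual_seed(0)
n_mels = 80
# pretend batches of mel spectrograms, shaped (batch, n_mels, frames)
batches = [torch.randn(2, n_mels, frames) for frames in (60, 70, 80, 90, 100)]

total_sum = total_sq_sum = total_len = 0
for mels in batches:
    total_len += mels.shape[0] * mels.shape[-1]   # valid frames in this batch
    total_sum += mels.sum()
    total_sq_sum += mels.pow(2).sum()

mean = total_sum / (total_len * n_mels)
std = torch.sqrt(total_sq_sum / (total_len * n_mels) - mean**2)

# agrees with a direct computation over all elements at once
flat = torch.cat([m.reshape(-1) for m in batches])
print(mean.item(), flat.mean().item())                 # same mean
print(std.item(), flat.std(unbiased=False).item())     # same (population) std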
matcha/utils/instantiators.py
ADDED
@@ -0,0 +1,56 @@
from typing import List

import hydra
from lightning import Callback
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from matcha.utils import pylogger

log = pylogger.get_pylogger(__name__)


def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config.

    :param callbacks_cfg: A DictConfig object containing callback configurations.
    :return: A list of instantiated callbacks.
    """
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping..")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")  # pylint: disable=protected-access
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks


def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config.

    :param logger_cfg: A DictConfig object containing logger configurations.
    :return: A list of instantiated loggers.
    """
    logger: List[Logger] = []

    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return logger

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")  # pylint: disable=protected-access
            logger.append(hydra.utils.instantiate(lg_conf))

    return logger
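Note: a sketch of how `instantiate_callbacks` is typically fed. Each entry is a config group whose `_target_` string `hydra.utils.instantiate` resolves to a class. The callbacks chosen here (Lightning's `ModelCheckpoint` and `LearningRateMonitor`) are only example targets, not something this wheel configures for you.

from omegaconf import OmegaConf
from matcha.utils.instantiators import instantiate_callbacks

# minimal callbacks config; "_target_" is the import path hydra instantiates
callbacks_cfg = OmegaConf.create(
    {
        "model_checkpoint": {
            "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
            "monitor": "val_loss",
            "save_top_k": 1,
        },
        "lr_monitor": {"_target_": "lightning.pytorch.callbacks.LearningRateMonitor"},
    }
)

callbacks = instantiate_callbacks(callbacks_cfg)
print(callbacks)  # [ModelCheckpoint(...), LearningRateMonitor(...)]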
matcha/utils/logging_utils.py
ADDED
@@ -0,0 +1,53 @@
from typing import Any, Dict

from lightning.pytorch.utilities import rank_zero_only
from omegaconf import OmegaConf

from matcha.utils import pylogger

log = pylogger.get_pylogger(__name__)


@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
    - Number of model parameters

    :param object_dict: A dictionary containing the following objects:
        - `"cfg"`: A DictConfig object containing the main config.
        - `"model"`: The Lightning model.
        - `"trainer"`: The Lightning trainer.
    """
    hparams = {}

    cfg = OmegaConf.to_container(object_dict["cfg"])
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams["model"] = cfg["model"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
    hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)

    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    hparams["callbacks"] = cfg.get("callbacks")
    hparams["extras"] = cfg.get("extras")

    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    hparams["ckpt_path"] = cfg.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
matcha/utils/model.py
ADDED
@@ -0,0 +1,90 @@
""" from https://github.com/jaywalnut310/glow-tts """

import numpy as np
import torch


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
    length = (length / factor).ceil() * factor
    if not torch.onnx.is_in_onnx_export():
        return length.int().item()
    else:
        return length


def convert_pad_shape(pad_shape):
    inverted_shape = pad_shape[::-1]
    pad_shape = [item for sublist in inverted_shape for item in sublist]
    return pad_shape


def generate_path(duration, mask):
    device = duration.device

    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path


def duration_loss(logw, logw_, lengths):
    loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
    return loss


def normalize(data, mu, std):
    if not isinstance(mu, (float, int)):
        if isinstance(mu, list):
            mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
        elif isinstance(mu, torch.Tensor):
            mu = mu.to(data.device)
        elif isinstance(mu, np.ndarray):
            mu = torch.from_numpy(mu).to(data.device)
        mu = mu.unsqueeze(-1)

    if not isinstance(std, (float, int)):
        if isinstance(std, list):
            std = torch.tensor(std, dtype=data.dtype, device=data.device)
        elif isinstance(std, torch.Tensor):
            std = std.to(data.device)
        elif isinstance(std, np.ndarray):
            std = torch.from_numpy(std).to(data.device)
        std = std.unsqueeze(-1)

    return (data - mu) / std


def denormalize(data, mu, std):
    if not isinstance(mu, float):
        if isinstance(mu, list):
            mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
        elif isinstance(mu, torch.Tensor):
            mu = mu.to(data.device)
        elif isinstance(mu, np.ndarray):
            mu = torch.from_numpy(mu).to(data.device)
        mu = mu.unsqueeze(-1)

    if not isinstance(std, float):
        if isinstance(std, list):
            std = torch.tensor(std, dtype=data.dtype, device=data.device)
        elif isinstance(std, torch.Tensor):
            std = std.to(data.device)
        elif isinstance(std, np.ndarray):
            std = torch.from_numpy(std).to(data.device)
        std = std.unsqueeze(-1)

    return data * std + mu
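Note: a small sketch of the alignment utilities above on toy tensors (the lengths and durations are made up): `sequence_mask` turns lengths into boolean validity masks, and `generate_path` expands per-token durations into a hard monotonic alignment between tokens and frames.

import torch
from matcha.utils.model import sequence_mask, generate_path

lengths = torch.tensor([3, 5])
print(sequence_mask(lengths))   # (2, 5) boolean mask, True on valid positions

# one sequence of 3 tokens with durations 2, 1, 3 -> a 6-frame alignment path
duration = torch.tensor([[2, 1, 3]])   # (b=1, t_x=3)
mask = torch.ones(1, 3, 6)             # (b, t_x, t_y)
path = generate_path(duration, mask)
print(path)  # rows are tokens, columns are frames; each frame maps to exactly one token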
matcha/utils/monotonic_align/__init__.py
ADDED
@@ -0,0 +1,22 @@
import numpy as np
import torch

from matcha.utils.monotonic_align.core import maximum_path_c


def maximum_path(value, mask):
    """Cython optimised version.
    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]
    """
    value = value * mask
    device = value.device
    dtype = value.dtype
    value = value.data.cpu().numpy().astype(np.float32)
    path = np.zeros_like(value).astype(np.int32)
    mask = mask.data.cpu().numpy()

    t_x_max = mask.sum(1)[:, 0].astype(np.int32)
    t_y_max = mask.sum(2)[:, 0].astype(np.int32)
    maximum_path_c(path, value, t_x_max, t_y_max)
    return torch.from_numpy(path).to(device=device, dtype=dtype)
matcha/utils/pylogger.py
ADDED
@@ -0,0 +1,21 @@
import logging

from lightning.pytorch.utilities import rank_zero_only


def get_pylogger(name: str = __name__) -> logging.Logger:
    """Initializes a multi-GPU-friendly python command line logger.

    :param name: The name of the logger, defaults to ``__name__``.

    :return: A logger object.
    """
    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for level in logging_levels:
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger
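Note: usage sketch for `get_pylogger` (the `basicConfig` call below is just a plain stdlib handler for the example, not something this package sets up): it returns an ordinary `logging.Logger` whose level methods are wrapped in Lightning's `rank_zero_only`, so in a multi-process run only rank 0 actually emits records.

import logging
from matcha.utils.pylogger import get_pylogger

logging.basicConfig(level=logging.INFO)  # example handler configuration
log = get_pylogger(__name__)
log.info("visible on rank 0 only in a distributed run")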
matcha/utils/rich_utils.py
ADDED
@@ -0,0 +1,101 @@
from pathlib import Path
from typing import Sequence

import rich
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.prompt import Prompt

from matcha.utils import pylogger

log = pylogger.get_pylogger(__name__)


@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "data",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Prints the contents of a DictConfig as a tree structure using the Rich library.

    :param cfg: A DictConfig composed by Hydra.
    :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
        "callbacks", "logger", "trainer", "paths", "extras")``.
    :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
    :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    queue = []

    # add fields from `print_order` to queue
    for field in print_order:
        _ = (
            queue.append(field)
            if field in cfg
            else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")
        )

    # add all the other fields to queue (not specified in `print_order`)
    for field in cfg:
        if field not in queue:
            queue.append(field)

    # generate config tree from queue
    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = cfg[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    # print config tree
    rich.print(tree)

    # save config tree to file
    if save_to_file:
        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
            rich.print(tree, file=file)


@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    :param cfg: A DictConfig composed by Hydra.
    :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
    """
    if not cfg.get("tags"):
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        tags = [t.strip() for t in tags.split(",") if t != ""]

        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
            rich.print(cfg.tags, file=file)