minicpmo_utils-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
matcha/train.py ADDED
@@ -0,0 +1,122 @@
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import hydra
+ import lightning as L
+ import rootutils
+ from lightning import Callback, LightningDataModule, LightningModule, Trainer
+ from lightning.pytorch.loggers import Logger
+ from omegaconf import DictConfig
+
+ from matcha import utils
+
+ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+ # ------------------------------------------------------------------------------------ #
+ # the setup_root above is equivalent to:
+ # - adding project root dir to PYTHONPATH
+ #   (so you don't need to force user to install project as a package)
+ #   (necessary before importing any local modules e.g. `from src import utils`)
+ # - setting up PROJECT_ROOT environment variable
+ #   (which is used as a base for paths in "configs/paths/default.yaml")
+ #   (this way all filepaths are the same no matter where you run the code)
+ # - loading environment variables from ".env" in root dir
+ #
+ # you can remove it if you:
+ # 1. either install project as a package or move entry files to project root dir
+ # 2. set `root_dir` to "." in "configs/paths/default.yaml"
+ #
+ # more info: https://github.com/ashleve/rootutils
+ # ------------------------------------------------------------------------------------ #
+
+
+ log = utils.get_pylogger(__name__)
+
+
+ @utils.task_wrapper
+ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
+     training.
+
+     This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+     failure. Useful for multiruns, saving info about the crash, etc.
+
+     :param cfg: A DictConfig configuration composed by Hydra.
+     :return: A tuple with metrics and dict with all instantiated objects.
+     """
+     # set seed for random number generators in pytorch, numpy and python.random
+     if cfg.get("seed"):
+         L.seed_everything(cfg.seed, workers=True)
+
+     log.info(f"Instantiating datamodule <{cfg.data._target_}>")  # pylint: disable=protected-access
+     datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+     log.info(f"Instantiating model <{cfg.model._target_}>")  # pylint: disable=protected-access
+     model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+     log.info("Instantiating callbacks...")
+     callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
+
+     log.info("Instantiating loggers...")
+     logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+
+     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")  # pylint: disable=protected-access
+     trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
+
+     object_dict = {
+         "cfg": cfg,
+         "datamodule": datamodule,
+         "model": model,
+         "callbacks": callbacks,
+         "logger": logger,
+         "trainer": trainer,
+     }
+
+     if logger:
+         log.info("Logging hyperparameters!")
+         utils.log_hyperparameters(object_dict)
+
+     if cfg.get("train"):
+         log.info("Starting training!")
+         trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
+
+     train_metrics = trainer.callback_metrics
+
+     if cfg.get("test"):
+         log.info("Starting testing!")
+         ckpt_path = trainer.checkpoint_callback.best_model_path
+         if ckpt_path == "":
+             log.warning("Best ckpt not found! Using current weights for testing...")
+             ckpt_path = None
+         trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+         log.info(f"Best ckpt path: {ckpt_path}")
+
+     test_metrics = trainer.callback_metrics
+
+     # merge train and test metrics
+     metric_dict = {**train_metrics, **test_metrics}
+
+     return metric_dict, object_dict
+
+
+ @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
+ def main(cfg: DictConfig) -> Optional[float]:
+     """Main entry point for training.
+
+     :param cfg: DictConfig configuration composed by Hydra.
+     :return: Optional[float] with optimized metric value.
+     """
+     # apply extra utilities
+     # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+     utils.extras(cfg)
+
+     # train the model
+     metric_dict, _ = train(cfg)
+
+     # safely retrieve metric value for hydra-based hyperparameter optimization
+     metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric"))
+
+     # return optimized metric
+     return metric_value
+
+
+ if __name__ == "__main__":
+     main()  # pylint: disable=no-value-for-parameter
matcha/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
+ from matcha.utils.logging_utils import log_hyperparameters
+ from matcha.utils.pylogger import get_pylogger
+ from matcha.utils.rich_utils import enforce_tags, print_config_tree
+ from matcha.utils.utils import extras, get_metric_value, task_wrapper
matcha/utils/audio.py ADDED
@@ -0,0 +1,82 @@
+ import numpy as np
+ import torch
+ import torch.utils.data
+ from librosa.filters import mel as librosa_mel_fn
+ from scipy.io.wavfile import read
+
+ MAX_WAV_VALUE = 32768.0
+
+
+ def load_wav(full_path):
+     sampling_rate, data = read(full_path)
+     return data, sampling_rate
+
+
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+     return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+ def dynamic_range_decompression(x, C=1):
+     return np.exp(x) / C
+
+
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression_torch(x, C=1):
+     return torch.exp(x) / C
+
+
+ def spectral_normalize_torch(magnitudes):
+     output = dynamic_range_compression_torch(magnitudes)
+     return output
+
+
+ def spectral_de_normalize_torch(magnitudes):
+     output = dynamic_range_decompression_torch(magnitudes)
+     return output
+
+
+ mel_basis = {}
+ hann_window = {}
+
+
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+     if torch.min(y) < -1.0:
+         print("min value is ", torch.min(y))
+     if torch.max(y) > 1.0:
+         print("max value is ", torch.max(y))
+
+     global mel_basis, hann_window  # pylint: disable=global-statement
+     if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
+         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+         mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+     y = torch.nn.functional.pad(
+         y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+     )
+     y = y.squeeze(1)
+
+     spec = torch.view_as_real(
+         torch.stft(
+             y,
+             n_fft,
+             hop_length=hop_size,
+             win_length=win_size,
+             window=hann_window[str(y.device)],
+             center=center,
+             pad_mode="reflect",
+             normalized=False,
+             onesided=True,
+             return_complex=True,
+         )
+     )
+
+     spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+     spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
+     spec = spectral_normalize_torch(spec)
+
+     return spec
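A minimal usage sketch for the mel_spectrogram function above. The parameter values (22050 Hz sample rate, 1024-point FFT, 256 hop, 80 mel bands, 8 kHz fmax) are illustrative HiFi-GAN-style defaults, not values read from this package's configs:

    import torch
    from matcha.utils.audio import mel_spectrogram

    # one second of silence as a stand-in waveform, shape (batch, samples)
    wav = torch.zeros(1, 22050)

    # illustrative HiFi-GAN-style settings; the package's own configs may differ
    mel = mel_spectrogram(
        wav, n_fft=1024, num_mels=80, sampling_rate=22050,
        hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False,
    )
    print(mel.shape)  # (1, 80, n_frames)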
matcha/utils/generate_data_statistics.py ADDED
@@ -0,0 +1,111 @@
+ r"""
+ The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
+ when needed.
+
+ Parameters from hparam.py will be used
+ """
+ import argparse
+ import json
+ import os
+ import sys
+ from pathlib import Path
+
+ import rootutils
+ import torch
+ from hydra import compose, initialize
+ from omegaconf import open_dict
+ from tqdm.auto import tqdm
+
+ from matcha.data.text_mel_datamodule import TextMelDataModule
+ from matcha.utils.logging_utils import pylogger
+
+ log = pylogger.get_pylogger(__name__)
+
+
+ def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
+     """Generate data mean and standard deviation helpful in data normalisation
+
+     Args:
+         data_loader (torch.utils.data.Dataloader): _description_
+         out_channels (int): mel spectrogram channels
+     """
+     total_mel_sum = 0
+     total_mel_sq_sum = 0
+     total_mel_len = 0
+
+     for batch in tqdm(data_loader, leave=False):
+         mels = batch["y"]
+         mel_lengths = batch["y_lengths"]
+
+         total_mel_len += torch.sum(mel_lengths)
+         total_mel_sum += torch.sum(mels)
+         total_mel_sq_sum += torch.sum(torch.pow(mels, 2))
+
+     data_mean = total_mel_sum / (total_mel_len * out_channels)
+     data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2))
+
+     return {"mel_mean": data_mean.item(), "mel_std": data_std.item()}
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "-i",
+         "--input-config",
+         type=str,
+         default="vctk.yaml",
+         help="The name of the yaml config file under configs/data",
+     )
+
+     parser.add_argument(
+         "-b",
+         "--batch-size",
+         type=int,
+         default="256",
+         help="Can have increased batch size for faster computation",
+     )
+
+     parser.add_argument(
+         "-f",
+         "--force",
+         action="store_true",
+         default=False,
+         required=False,
+         help="force overwrite the file",
+     )
+     args = parser.parse_args()
+     output_file = Path(args.input_config).with_suffix(".json")
+
+     if os.path.exists(output_file) and not args.force:
+         print("File already exists. Use -f to force overwrite")
+         sys.exit(1)
+
+     with initialize(version_base="1.3", config_path="../../configs/data"):
+         cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
+
+     root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
+
+     with open_dict(cfg):
+         del cfg["hydra"]
+         del cfg["_target_"]
+         cfg["data_statistics"] = None
+         cfg["seed"] = 1234
+         cfg["batch_size"] = args.batch_size
+         cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
+         cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
+
+     text_mel_datamodule = TextMelDataModule(**cfg)
+     text_mel_datamodule.setup()
+     data_loader = text_mel_datamodule.train_dataloader()
+     log.info("Dataloader loaded! Now computing stats...")
+     params = compute_data_statistics(data_loader, cfg["n_feats"])
+     print(params)
+     json.dump(
+         params,
+         open(output_file, "w"),
+     )
+
+
+ if __name__ == "__main__":
+     main()
matcha/utils/instantiators.py ADDED
@@ -0,0 +1,56 @@
+ from typing import List
+
+ import hydra
+ from lightning import Callback
+ from lightning.pytorch.loggers import Logger
+ from omegaconf import DictConfig
+
+ from matcha.utils import pylogger
+
+ log = pylogger.get_pylogger(__name__)
+
+
+ def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+     """Instantiates callbacks from config.
+
+     :param callbacks_cfg: A DictConfig object containing callback configurations.
+     :return: A list of instantiated callbacks.
+     """
+     callbacks: List[Callback] = []
+
+     if not callbacks_cfg:
+         log.warning("No callback configs found! Skipping..")
+         return callbacks
+
+     if not isinstance(callbacks_cfg, DictConfig):
+         raise TypeError("Callbacks config must be a DictConfig!")
+
+     for _, cb_conf in callbacks_cfg.items():
+         if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+             log.info(f"Instantiating callback <{cb_conf._target_}>")  # pylint: disable=protected-access
+             callbacks.append(hydra.utils.instantiate(cb_conf))
+
+     return callbacks
+
+
+ def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+     """Instantiates loggers from config.
+
+     :param logger_cfg: A DictConfig object containing logger configurations.
+     :return: A list of instantiated loggers.
+     """
+     logger: List[Logger] = []
+
+     if not logger_cfg:
+         log.warning("No logger configs found! Skipping...")
+         return logger
+
+     if not isinstance(logger_cfg, DictConfig):
+         raise TypeError("Logger config must be a DictConfig!")
+
+     for _, lg_conf in logger_cfg.items():
+         if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+             log.info(f"Instantiating logger <{lg_conf._target_}>")  # pylint: disable=protected-access
+             logger.append(hydra.utils.instantiate(lg_conf))
+
+     return logger
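A small sketch of how instantiate_callbacks consumes a Hydra-style config: every entry carrying a _target_ key is instantiated via hydra.utils.instantiate. The ModelCheckpoint entry below is illustrative, not taken from this package's bundled configs:

    from omegaconf import OmegaConf
    from matcha.utils.instantiators import instantiate_callbacks

    callbacks_cfg = OmegaConf.create(
        {
            "model_checkpoint": {
                "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
                "monitor": "val_loss",
                "save_top_k": 1,
            }
        }
    )
    callbacks = instantiate_callbacks(callbacks_cfg)  # -> [ModelCheckpoint(...)]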
matcha/utils/logging_utils.py ADDED
@@ -0,0 +1,53 @@
+ from typing import Any, Dict
+
+ from lightning.pytorch.utilities import rank_zero_only
+ from omegaconf import OmegaConf
+
+ from matcha.utils import pylogger
+
+ log = pylogger.get_pylogger(__name__)
+
+
+ @rank_zero_only
+ def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
+     """Controls which config parts are saved by Lightning loggers.
+
+     Additionally saves:
+     - Number of model parameters
+
+     :param object_dict: A dictionary containing the following objects:
+         - `"cfg"`: A DictConfig object containing the main config.
+         - `"model"`: The Lightning model.
+         - `"trainer"`: The Lightning trainer.
+     """
+     hparams = {}
+
+     cfg = OmegaConf.to_container(object_dict["cfg"])
+     model = object_dict["model"]
+     trainer = object_dict["trainer"]
+
+     if not trainer.logger:
+         log.warning("Logger not found! Skipping hyperparameter logging...")
+         return
+
+     hparams["model"] = cfg["model"]
+
+     # save number of model parameters
+     hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+     hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
+
+     hparams["data"] = cfg["data"]
+     hparams["trainer"] = cfg["trainer"]
+
+     hparams["callbacks"] = cfg.get("callbacks")
+     hparams["extras"] = cfg.get("extras")
+
+     hparams["task_name"] = cfg.get("task_name")
+     hparams["tags"] = cfg.get("tags")
+     hparams["ckpt_path"] = cfg.get("ckpt_path")
+     hparams["seed"] = cfg.get("seed")
+
+     # send hparams to all loggers
+     for logger in trainer.loggers:
+         logger.log_hyperparams(hparams)
matcha/utils/model.py ADDED
@@ -0,0 +1,90 @@
+ """ from https://github.com/jaywalnut310/glow-tts """
+
+ import numpy as np
+ import torch
+
+
+ def sequence_mask(length, max_length=None):
+     if max_length is None:
+         max_length = length.max()
+     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+     return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+ def fix_len_compatibility(length, num_downsamplings_in_unet=2):
+     factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
+     length = (length / factor).ceil() * factor
+     if not torch.onnx.is_in_onnx_export():
+         return length.int().item()
+     else:
+         return length
+
+
+ def convert_pad_shape(pad_shape):
+     inverted_shape = pad_shape[::-1]
+     pad_shape = [item for sublist in inverted_shape for item in sublist]
+     return pad_shape
+
+
+ def generate_path(duration, mask):
+     device = duration.device
+
+     b, t_x, t_y = mask.shape
+     cum_duration = torch.cumsum(duration, 1)
+     path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
+
+     cum_duration_flat = cum_duration.view(b * t_x)
+     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+     path = path.view(b, t_x, t_y)
+     path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+     path = path * mask
+     return path
+
+
+ def duration_loss(logw, logw_, lengths):
+     loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
+     return loss
+
+
+ def normalize(data, mu, std):
+     if not isinstance(mu, (float, int)):
+         if isinstance(mu, list):
+             mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
+         elif isinstance(mu, torch.Tensor):
+             mu = mu.to(data.device)
+         elif isinstance(mu, np.ndarray):
+             mu = torch.from_numpy(mu).to(data.device)
+         mu = mu.unsqueeze(-1)
+
+     if not isinstance(std, (float, int)):
+         if isinstance(std, list):
+             std = torch.tensor(std, dtype=data.dtype, device=data.device)
+         elif isinstance(std, torch.Tensor):
+             std = std.to(data.device)
+         elif isinstance(std, np.ndarray):
+             std = torch.from_numpy(std).to(data.device)
+         std = std.unsqueeze(-1)
+
+     return (data - mu) / std
+
+
+ def denormalize(data, mu, std):
+     if not isinstance(mu, float):
+         if isinstance(mu, list):
+             mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
+         elif isinstance(mu, torch.Tensor):
+             mu = mu.to(data.device)
+         elif isinstance(mu, np.ndarray):
+             mu = torch.from_numpy(mu).to(data.device)
+         mu = mu.unsqueeze(-1)
+
+     if not isinstance(std, float):
+         if isinstance(std, list):
+             std = torch.tensor(std, dtype=data.dtype, device=data.device)
+         elif isinstance(std, torch.Tensor):
+             std = std.to(data.device)
+         elif isinstance(std, np.ndarray):
+             std = torch.from_numpy(std).to(data.device)
+         std = std.unsqueeze(-1)
+
+     return data * std + mu
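An illustrative use of sequence_mask and generate_path from the file above: per-token durations over three input tokens are expanded into a hard monotonic alignment over six output frames. The tensor values are chosen arbitrarily for the sketch:

    import torch
    from matcha.utils.model import generate_path, sequence_mask

    durations = torch.tensor([[2, 1, 3]])                               # (b=1, t_x=3)
    x_mask = sequence_mask(torch.tensor([3]), 3).float().unsqueeze(-1)  # (1, 3, 1)
    y_mask = sequence_mask(torch.tensor([6]), 6).float().unsqueeze(1)   # (1, 1, 6)
    attn_mask = x_mask * y_mask                                         # (1, 3, 6)

    path = generate_path(durations, attn_mask)
    print(path.squeeze(0))  # each row marks which output frames that token covers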
matcha/utils/monotonic_align/__init__.py ADDED
@@ -0,0 +1,22 @@
+ import numpy as np
+ import torch
+
+ from matcha.utils.monotonic_align.core import maximum_path_c
+
+
+ def maximum_path(value, mask):
+     """Cython optimised version.
+     value: [b, t_x, t_y]
+     mask: [b, t_x, t_y]
+     """
+     value = value * mask
+     device = value.device
+     dtype = value.dtype
+     value = value.data.cpu().numpy().astype(np.float32)
+     path = np.zeros_like(value).astype(np.int32)
+     mask = mask.data.cpu().numpy()
+
+     t_x_max = mask.sum(1)[:, 0].astype(np.int32)
+     t_y_max = mask.sum(2)[:, 0].astype(np.int32)
+     maximum_path_c(path, value, t_x_max, t_y_max)
+     return torch.from_numpy(path).to(device=device, dtype=dtype)
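maximum_path above wraps the Cython routine maximum_path_c, so it only works once the core.pyx extension has been compiled (see the commented-out setup.py that follows). A hedged sketch of the expected call shape, assuming the extension is built:

    import torch
    from matcha.utils.monotonic_align import maximum_path

    # log-likelihood scores between 3 text tokens and 6 mel frames (random stand-in values)
    value = torch.randn(1, 3, 6)
    mask = torch.ones(1, 3, 6)
    path = maximum_path(value, mask)  # hard monotonic alignment, same shape as value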
matcha/utils/monotonic_align/setup.py ADDED
@@ -0,0 +1,7 @@
+ # from distutils.core import setup
+ # from Cython.Build import cythonize
+ # import numpy
+
+ # setup(name='monotonic_align',
+ #       ext_modules=cythonize("core.pyx"),
+ #       include_dirs=[numpy.get_include()])
matcha/utils/pylogger.py ADDED
@@ -0,0 +1,21 @@
+ import logging
+
+ from lightning.pytorch.utilities import rank_zero_only
+
+
+ def get_pylogger(name: str = __name__) -> logging.Logger:
+     """Initializes a multi-GPU-friendly python command line logger.
+
+     :param name: The name of the logger, defaults to ``__name__``.
+
+     :return: A logger object.
+     """
+     logger = logging.getLogger(name)
+
+     # this ensures all logging levels get marked with the rank zero decorator
+     # otherwise logs would get multiplied for each GPU process in multi-GPU setup
+     logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
+     for level in logging_levels:
+         setattr(logger, level, rank_zero_only(getattr(logger, level)))
+
+     return logger
matcha/utils/rich_utils.py ADDED
@@ -0,0 +1,101 @@
+ from pathlib import Path
+ from typing import Sequence
+
+ import rich
+ import rich.syntax
+ import rich.tree
+ from hydra.core.hydra_config import HydraConfig
+ from lightning.pytorch.utilities import rank_zero_only
+ from omegaconf import DictConfig, OmegaConf, open_dict
+ from rich.prompt import Prompt
+
+ from matcha.utils import pylogger
+
+ log = pylogger.get_pylogger(__name__)
+
+
+ @rank_zero_only
+ def print_config_tree(
+     cfg: DictConfig,
+     print_order: Sequence[str] = (
+         "data",
+         "model",
+         "callbacks",
+         "logger",
+         "trainer",
+         "paths",
+         "extras",
+     ),
+     resolve: bool = False,
+     save_to_file: bool = False,
+ ) -> None:
+     """Prints the contents of a DictConfig as a tree structure using the Rich library.
+
+     :param cfg: A DictConfig composed by Hydra.
+     :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
+         "callbacks", "logger", "trainer", "paths", "extras")``.
+     :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
+     :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
+     """
+     style = "dim"
+     tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+     queue = []
+
+     # add fields from `print_order` to queue
+     for field in print_order:
+         _ = (
+             queue.append(field)
+             if field in cfg
+             else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")
+         )
+
+     # add all the other fields to queue (not specified in `print_order`)
+     for field in cfg:
+         if field not in queue:
+             queue.append(field)
+
+     # generate config tree from queue
+     for field in queue:
+         branch = tree.add(field, style=style, guide_style=style)
+
+         config_group = cfg[field]
+         if isinstance(config_group, DictConfig):
+             branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+         else:
+             branch_content = str(config_group)
+
+         branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+     # print config tree
+     rich.print(tree)
+
+     # save config tree to file
+     if save_to_file:
+         with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+             rich.print(tree, file=file)
+
+
+ @rank_zero_only
+ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+     """Prompts user to input tags from command line if no tags are provided in config.
+
+     :param cfg: A DictConfig composed by Hydra.
+     :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
+     """
+     if not cfg.get("tags"):
+         if "id" in HydraConfig().cfg.hydra.job:
+             raise ValueError("Specify tags before launching a multirun!")
+
+         log.warning("No tags provided in config. Prompting user to input tags...")
+         tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+         tags = [t.strip() for t in tags.split(",") if t != ""]
+
+         with open_dict(cfg):
+             cfg.tags = tags
+
+         log.info(f"Tags: {cfg.tags}")
+
+     if save_to_file:
+         with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+             rich.print(cfg.tags, file=file)
+ rich.print(cfg.tags, file=file)