xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/core/chat_interface.py +1 -1
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +4 -1
- xinference/core/worker.py +60 -44
- xinference/model/audio/chattts.py +25 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/cosyvoice.py +4 -3
- xinference/model/audio/custom.py +4 -5
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +25 -1
- xinference/model/embedding/custom.py +4 -5
- xinference/model/flexible/core.py +5 -1
- xinference/model/image/custom.py +4 -5
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +66 -3
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +54 -9
- xinference/model/llm/llm_family.py +7 -6
- xinference/model/llm/llm_family_modelscope.json +56 -10
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/sglang/core.py +7 -1
- xinference/model/llm/transformers/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/transformers/core.py +3 -0
- xinference/model/llm/transformers/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +94 -11
- xinference/model/llm/transformers/minicpmv25.py +2 -23
- xinference/model/llm/transformers/minicpmv26.py +2 -22
- xinference/model/llm/transformers/yi_vl.py +2 -24
- xinference/model/llm/utils.py +13 -1
- xinference/model/llm/vllm/core.py +1 -34
- xinference/model/rerank/custom.py +4 -5
- xinference/model/utils.py +41 -1
- xinference/model/video/core.py +3 -1
- xinference/model/video/diffusers.py +41 -38
- xinference/model/video/model_spec.json +24 -1
- xinference/model/video/model_spec_modelscope.json +25 -1
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/matcha/__init__.py +0 -0
- xinference/thirdparty/matcha/app.py +357 -0
- xinference/thirdparty/matcha/cli.py +419 -0
- xinference/thirdparty/matcha/data/__init__.py +0 -0
- xinference/thirdparty/matcha/data/components/__init__.py +0 -0
- xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
- xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
- xinference/thirdparty/matcha/hifigan/config.py +28 -0
- xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
- xinference/thirdparty/matcha/hifigan/env.py +17 -0
- xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
- xinference/thirdparty/matcha/hifigan/models.py +368 -0
- xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
- xinference/thirdparty/matcha/models/__init__.py +0 -0
- xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
- xinference/thirdparty/matcha/models/components/__init__.py +0 -0
- xinference/thirdparty/matcha/models/components/decoder.py +443 -0
- xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
- xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
- xinference/thirdparty/matcha/models/components/transformer.py +316 -0
- xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
- xinference/thirdparty/matcha/onnx/__init__.py +0 -0
- xinference/thirdparty/matcha/onnx/export.py +181 -0
- xinference/thirdparty/matcha/onnx/infer.py +168 -0
- xinference/thirdparty/matcha/text/__init__.py +53 -0
- xinference/thirdparty/matcha/text/cleaners.py +121 -0
- xinference/thirdparty/matcha/text/numbers.py +71 -0
- xinference/thirdparty/matcha/text/symbols.py +17 -0
- xinference/thirdparty/matcha/train.py +122 -0
- xinference/thirdparty/matcha/utils/__init__.py +5 -0
- xinference/thirdparty/matcha/utils/audio.py +82 -0
- xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
- xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
- xinference/thirdparty/matcha/utils/instantiators.py +56 -0
- xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
- xinference/thirdparty/matcha/utils/model.py +90 -0
- xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
- xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
- xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
- xinference/thirdparty/matcha/utils/pylogger.py +21 -0
- xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
- xinference/thirdparty/matcha/utils/utils.py +259 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
- xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
- xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
2
|
+
|
|
3
|
+
import hydra
|
|
4
|
+
import lightning as L
|
|
5
|
+
import rootutils
|
|
6
|
+
from lightning import Callback, LightningDataModule, LightningModule, Trainer
|
|
7
|
+
from lightning.pytorch.loggers import Logger
|
|
8
|
+
from omegaconf import DictConfig
|
|
9
|
+
|
|
10
|
+
from matcha import utils
|
|
11
|
+
|
|
12
|
+
# Locate the project root (the directory holding a ".project-root" marker file)
# and prepare the environment around it:
#   * put the root on PYTHONPATH, so the project need not be installed as a
#     package before local modules are imported;
#   * export PROJECT_ROOT, which "configs/paths/default.yaml" uses as the base
#     for its paths, so file paths are identical wherever the code is run from;
#   * load environment variables from a ".env" file in the root, if present.
#
# This call can be dropped if the project is installed as a package (or the
# entry files are moved to the project root) and `root_dir` is set to "." in
# "configs/paths/default.yaml". See https://github.com/ashleve/rootutils
rootutils.setup_root(
    __file__,
    indicator=".project-root",
    pythonpath=True,
)

# Module-level logger shared by train() and main().
log = utils.get_pylogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Train the model and optionally evaluate it on the test set.

    When ``cfg.test`` is set, testing uses the best checkpoint recorded by the
    trainer's checkpoint callback. It falls back to the current weights when
    no best checkpoint path is available or no checkpoint callback is
    configured at all.

    This function is wrapped in the ``utils.task_wrapper`` decorator, which
    controls the behaviour on failure (useful for multiruns, saving crash
    info, etc.).

    :param cfg: A DictConfig configuration composed by Hydra.
    :return: A tuple of (merged train/test metrics, dict with all instantiated objects).
    """
    # Seed torch, numpy and python.random. Compare against None so that an
    # explicit ``seed: 0`` is honoured instead of being treated as "no seed".
    seed = cfg.get("seed")
    if seed is not None:
        L.seed_everything(seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")  # pylint: disable=protected-access
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")  # pylint: disable=protected-access
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")  # pylint: disable=protected-access
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    # Copy so that metric updates made later by trainer.test() cannot mutate
    # the training snapshot through a shared dict object.
    train_metrics = dict(trainer.callback_metrics)

    if cfg.get("test"):
        log.info("Starting testing!")
        # Lightning returns None here when no ModelCheckpoint callback is
        # configured; treat that like "no best checkpoint".
        checkpoint_callback = trainer.checkpoint_callback
        ckpt_path = checkpoint_callback.best_model_path if checkpoint_callback is not None else ""
        if not ckpt_path:
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # merge train and test metrics (test values win on key collisions)
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point: run training and report the metric to optimise.

    Before training, ``utils.extras`` is applied (e.g. asking for tags when
    none are provided, printing the config tree).

    :param cfg: DictConfig configuration composed by Hydra.
    :return: The value of ``cfg.optimized_metric`` taken from the training
        metrics (as resolved by ``utils.get_metric_value``), used by
        Hydra-based hyperparameter optimisation.
    """
    utils.extras(cfg)

    metrics, _objects = train(cfg)

    return utils.get_metric_value(
        metric_dict=metrics,
        metric_name=cfg.get("optimized_metric"),
    )


if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
|
|
2
|
+
from matcha.utils.logging_utils import log_hyperparameters
|
|
3
|
+
from matcha.utils.pylogger import get_pylogger
|
|
4
|
+
from matcha.utils.rich_utils import enforce_tags, print_config_tree
|
|
5
|
+
from matcha.utils.utils import extras, get_metric_value, task_wrapper
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
import torch.utils.data
|
|
4
|
+
from librosa.filters import mel as librosa_mel_fn
|
|
5
|
+
from scipy.io.wavfile import read
|
|
6
|
+
|
|
7
|
+
# 2**15; presumably the full-scale magnitude used to normalise 16-bit PCM
# samples elsewhere (not used in this module) — confirm with callers.
MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    """Read a WAV file with scipy.

    Returns ``(samples, sampling_rate)``. Note this is the reverse of the
    ``(rate, data)`` order that ``scipy.io.wavfile.read`` uses. Samples are
    returned as read, with no normalisation.
    """
    rate, samples = read(full_path)
    return samples, rate
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """NumPy log compression: ``log(max(x, clip_val) * C)`` element-wise.

    Flooring at ``clip_val`` keeps the logarithm finite for zero inputs.
    """
    floored = np.maximum(x, clip_val)
    return np.log(floored * C)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def dynamic_range_decompression(x, C=1):
    """NumPy inverse of the log compression: ``exp(x) / C`` element-wise."""
    expanded = np.exp(x)
    return expanded / C
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Torch log compression: ``log(max(x, clip_val) * C)`` element-wise.

    ``x`` must be a tensor; flooring at ``clip_val`` keeps the log finite.
    """
    return x.clamp(min=clip_val).mul(C).log()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def dynamic_range_decompression_torch(x, C=1):
    """Torch inverse of the log compression: ``exp(x) / C`` element-wise."""
    return x.exp().div(C)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def spectral_normalize_torch(magnitudes):
    """Log-compress spectral magnitudes using the default floor (1e-5) and C=1."""
    return dynamic_range_compression_torch(magnitudes)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def spectral_de_normalize_torch(magnitudes):
    """Undo ``spectral_normalize_torch`` by exponentiating (with C=1)."""
    return dynamic_range_decompression_torch(magnitudes)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Lazily filled caches used by mel_spectrogram(): mel filterbanks keyed by
# (sampling_rate, n_fft, num_mels, fmin, fmax, device) and Hann windows keyed
# by (win_size, device).
mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed mel spectrogram of ``y``.

    Args:
        y: waveform tensor. Values are expected in [-1, 1]; a message is
            printed (not raised) otherwise. Presumably shaped
            (batch, samples), given the unsqueeze/squeeze around the padding
            — TODO confirm with callers.
        n_fft: FFT size.
        num_mels: number of mel bands.
        sampling_rate: sampling rate of ``y`` in Hz.
        hop_size: hop length between STFT frames, in samples.
        win_size: Hann window length, in samples.
        fmin: lowest frequency of the mel filterbank.
        fmax: highest frequency of the mel filterbank.
        center: forwarded to ``torch.stft``.

    Returns:
        ``log(max(mel_basis @ |STFT(y)|, 1e-5))``, with ``num_mels`` rows
        per frame.
    """
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    device_key = str(y.device)
    # The filterbank depends on every librosa argument, not only fmax. Keying
    # the cache on (fmax, device) alone silently reused a filterbank built for
    # a different sampling rate / n_fft / num_mels / fmin.
    mel_key = (sampling_rate, n_fft, num_mels, fmin, fmax, device_key)
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
    # Likewise the window must be rebuilt when win_size changes.
    window_key = (win_size, device_key)
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(win_size).to(y.device)

    # Reflect-pad by (n_fft - hop_size) / 2 on both sides before the
    # (by default non-centred) STFT.
    padding = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect")
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[window_key],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    # Magnitude from the (real, imag) pair; the small epsilon presumably keeps
    # sqrt (and its gradient) away from exactly zero.
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
r"""
|
|
2
|
+
The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
|
|
3
|
+
when needed.
|
|
4
|
+
|
|
5
|
+
Parameters from hparam.py will be used
|
|
6
|
+
"""
|
|
7
|
+
import argparse
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
import sys
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
import rootutils
|
|
14
|
+
import torch
|
|
15
|
+
from hydra import compose, initialize
|
|
16
|
+
from omegaconf import open_dict
|
|
17
|
+
from tqdm.auto import tqdm
|
|
18
|
+
|
|
19
|
+
from matcha.data.text_mel_datamodule import TextMelDataModule
|
|
20
|
+
from matcha.utils.logging_utils import pylogger
|
|
21
|
+
|
|
22
|
+
log = pylogger.get_pylogger(__name__)


def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
    """Generate the global mel mean and standard deviation for data normalisation.

    Each batch must provide ``"y"`` (mel spectrograms) and ``"y_lengths"``
    (number of valid frames per item). Sums run over every element of
    ``"y"``, so padded frames are assumed to be zero-filled (TODO: confirm
    with the collate function). The element count is
    ``sum(y_lengths) * out_channels``.

    Args:
        data_loader (torch.utils.data.DataLoader): loader yielding the batches described above.
        out_channels (int): mel spectrogram channels per frame.

    Returns:
        dict: ``{"mel_mean": float, "mel_std": float}``.

    Raises:
        ValueError: if the loader yields no mel frames (or ``out_channels`` is 0).
    """
    total_mel_sum = 0.0
    total_mel_sq_sum = 0.0
    total_mel_len = 0

    for batch in tqdm(data_loader, leave=False):
        # Accumulate in float64. float32 sums of squares over a whole dataset
        # lose precision, which the E[x^2] - E[x]^2 formula below amplifies.
        mels = batch["y"].to(torch.float64)
        mel_lengths = batch["y_lengths"]

        total_mel_len += int(torch.sum(mel_lengths))
        total_mel_sum += torch.sum(mels)
        total_mel_sq_sum += torch.sum(mels * mels)

    count = total_mel_len * out_channels
    if count == 0:
        raise ValueError("Cannot compute data statistics: the data loader yielded no mel frames")

    data_mean = total_mel_sum / count
    # Rounding can push the variance marginally below zero; clamp so sqrt
    # never returns NaN.
    variance = torch.clamp(total_mel_sq_sum / count - torch.pow(data_mean, 2), min=0.0)
    data_std = torch.sqrt(variance)

    return {"mel_mean": data_mean.item(), "mel_std": data_std.item()}
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def main():
    """Command-line entry point: compute mel statistics for a dataset config.

    Composes ``configs/data/<input-config>`` with Hydra and builds the
    training dataloader with ``data_statistics`` disabled. It then computes
    the mel mean/std and writes them as JSON to ``<input-config>`` with its
    suffix replaced by ``.json``, relative to the current working directory.
    Exits with status 1 if that file already exists and ``-f`` was not given.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-i",
        "--input-config",
        type=str,
        default="vctk.yaml",
        help="The name of the yaml config file under configs/data",
    )

    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=256,
        help="Can have increased batch size for faster computation",
    )

    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        required=False,
        help="force overwrite the file",
    )
    args = parser.parse_args()
    output_file = Path(args.input_config).with_suffix(".json")

    if os.path.exists(output_file) and not args.force:
        print("File already exists. Use -f to force overwrite")
        sys.exit(1)

    with initialize(version_base="1.3", config_path="../../configs/data"):
        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])

    root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

    with open_dict(cfg):
        del cfg["hydra"]
        del cfg["_target_"]
        # Statistics must be computed on un-normalised data.
        cfg["data_statistics"] = None
        cfg["seed"] = 1234
        cfg["batch_size"] = args.batch_size
        cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
        cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
        cfg["load_durations"] = False

    text_mel_datamodule = TextMelDataModule(**cfg)
    text_mel_datamodule.setup()
    data_loader = text_mel_datamodule.train_dataloader()
    log.info("Dataloader loaded! Now computing stats...")
    params = compute_data_statistics(data_loader, cfg["n_feats"])
    print(params)
    # Use a context manager so the file is flushed and closed deterministically
    # instead of relying on garbage collection of an anonymous handle.
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(params, f)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
r"""
|
|
2
|
+
The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
|
|
3
|
+
when needed.
|
|
4
|
+
|
|
5
|
+
Parameters from hparam.py will be used
|
|
6
|
+
"""
|
|
7
|
+
import argparse
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
import sys
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
import lightning
|
|
14
|
+
import numpy as np
|
|
15
|
+
import rootutils
|
|
16
|
+
import torch
|
|
17
|
+
from hydra import compose, initialize
|
|
18
|
+
from omegaconf import open_dict
|
|
19
|
+
from torch import nn
|
|
20
|
+
from tqdm.auto import tqdm
|
|
21
|
+
|
|
22
|
+
from matcha.cli import get_device
|
|
23
|
+
from matcha.data.text_mel_datamodule import TextMelDataModule
|
|
24
|
+
from matcha.models.matcha_tts import MatchaTTS
|
|
25
|
+
from matcha.utils.logging_utils import pylogger
|
|
26
|
+
from matcha.utils.utils import get_phoneme_durations
|
|
27
|
+
|
|
28
|
+
log = pylogger.get_pylogger(__name__)


def save_durations_to_folder(
    attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
):
    """Write the per-phoneme durations of one utterance into ``output_folder``.

    Durations are ``attn`` with singleton dimensions squeezed, summed along
    dim 1, truncated to the first ``x_length`` entries. Two files named after
    the audio file's stem are written:

    * ``<stem>.npy``: the raw durations array;
    * ``<stem>.json``: the result of ``get_phoneme_durations(durations, text)``.

    Args:
        attn: alignment of a single utterance. Presumably (1, T_text, T_mel)
            — TODO confirm against the model output.
        x_length: number of valid text positions.
        y_length: unused; kept for interface compatibility.
        filepath: path of the source audio file; only its name is used.
        output_folder: destination directory (not created here).
        text: passed through to ``get_phoneme_durations``.
    """
    durations = attn.squeeze().sum(1)[:x_length].numpy()
    durations_json = get_phoneme_durations(durations, text)
    # Derive both outputs from the file stem. `name.replace(".wav", ".npy")`
    # left non-.wav extensions in place (np.save then wrote "x.flac.npy" next
    # to "x.json") and also rewrote any ".wav" occurring earlier in the name.
    output = output_folder / Path(filepath).with_suffix(".npy").name
    with open(output.with_suffix(".json"), "w", encoding="utf-8") as f:
        json.dump(durations_json, f, indent=4, ensure_ascii=False)

    np.save(output, durations)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@torch.inference_mode()
def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
    """Run ``model`` over every batch and save each utterance's durations.

    Each batch must provide "x", "x_lengths", "y", "y_lengths", "spks" (may
    be None), "filepaths" and "x_texts". The tensors are moved to ``device``.
    The fourth value returned by ``model(...)`` (the alignment) is moved to
    the CPU and written per utterance via ``save_durations_to_folder``.

    Args:
        data_loader (torch.utils.data.DataLoader): Dataloader
        model (nn.Module): MatchaTTS model
        device (torch.device): GPU or CPU
        output_folder: directory passed to ``save_durations_to_folder``
    """
    for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"):
        text_ids = batch["x"].to(device)
        text_lengths = batch["x_lengths"].to(device)
        mels = batch["y"].to(device)
        mel_lengths = batch["y_lengths"].to(device)
        speakers = batch["spks"]
        if speakers is not None:
            speakers = speakers.to(device)

        _, _, _, alignments = model(
            x=text_ids,
            x_lengths=text_lengths,
            y=mels,
            y_lengths=mel_lengths,
            spks=speakers,
        )
        alignments = alignments.cpu()

        # Iterating a tensor walks its first (batch) dimension.
        for idx, utterance_attn in enumerate(alignments):
            save_durations_to_folder(
                utterance_attn,
                text_lengths[idx].item(),
                mel_lengths[idx].item(),
                batch["filepaths"][idx],
                output_folder,
                batch["x_texts"][idx],
            )
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def main():
    """Command-line entry point: extract durations from a trained checkpoint.

    Composes ``configs/data/<input-config>`` with Hydra and loads the
    MatchaTTS checkpoint given with ``-c``. It then writes per-utterance
    durations for every dataset split that exists (train, validation, test)
    into the output folder. The folder defaults to ``durations/`` next to the
    training filelist. Exits with status 1 if the folder already exists and
    ``-f`` was not given.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-i",
        "--input-config",
        type=str,
        default="ljspeech.yaml",
        help="The name of the yaml config file under configs/data",
    )

    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=32,
        help="Can have increased batch size for faster computation",
    )

    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        required=False,
        help="force overwrite the file",
    )
    parser.add_argument(
        "-c",
        "--checkpoint_path",
        type=str,
        required=True,
        help="Path to the checkpoint file to load the model from",
    )

    parser.add_argument(
        "-o",
        "--output-folder",
        type=str,
        default=None,
        help="Output folder to save the durations (default: a 'durations' folder next to the train filelist)",
    )

    parser.add_argument(
        "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)"
    )

    args = parser.parse_args()

    with initialize(version_base="1.3", config_path="../../configs/data"):
        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])

    root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

    with open_dict(cfg):
        del cfg["hydra"]
        del cfg["_target_"]
        cfg["seed"] = 1234
        cfg["batch_size"] = args.batch_size
        cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
        cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
        cfg["load_durations"] = False

    if args.output_folder is not None:
        output_folder = Path(args.output_folder)
    else:
        output_folder = Path(cfg["train_filelist_path"]).parent / "durations"

    print(f"Output folder set to: {output_folder}")

    if os.path.exists(output_folder) and not args.force:
        print("Folder already exists. Use -f to force overwrite")
        sys.exit(1)

    output_folder.mkdir(parents=True, exist_ok=True)

    print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}")
    print("Loading model...")
    device = get_device(args)
    model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)
    # load_from_checkpoint does not switch the module to eval mode, and
    # inference_mode() does not disable dropout. Without this, train-time
    # stochasticity would perturb the extracted alignments.
    model.eval()

    text_mel_datamodule = TextMelDataModule(**cfg)
    text_mel_datamodule.setup()

    splits = (
        ("training", text_mel_datamodule.train_dataloader),
        ("validation", text_mel_datamodule.val_dataloader),
        ("test", text_mel_datamodule.test_dataloader),
    )
    for split_name, make_dataloader in splits:
        try:
            print(f"Computing durations for {split_name} set if exists...")
            compute_durations(make_dataloader(), model, device, output_folder)
        except lightning.fabric.utilities.exceptions.MisconfigurationException:
            # The datamodule raises this when the split is not configured.
            print(f"No {split_name} set found")

    print(f"[+] Done! Durations saved to: {output_folder}")


if __name__ == "__main__":
    # Helps with generating durations for the dataset to train other architectures
    # that cannot learn to align due to limited size of dataset
    # Example usage:
    # python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model
    # This will create a folder in data/processed_data/durations/ljspeech with the durations
    main()
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
import hydra
|
|
4
|
+
from lightning import Callback
|
|
5
|
+
from lightning.pytorch.loggers import Logger
|
|
6
|
+
from omegaconf import DictConfig
|
|
7
|
+
|
|
8
|
+
from matcha.utils import pylogger
|
|
9
|
+
|
|
10
|
+
# Module-level logger obtained from the project's ``pylogger`` helper, keyed by this module's name.
log = pylogger.get_pylogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Build Lightning callbacks from a Hydra config group.

    Every entry of ``callbacks_cfg`` that is itself a ``DictConfig`` and carries
    a ``_target_`` key is handed to ``hydra.utils.instantiate``. All other
    entries are silently ignored.

    :param callbacks_cfg: A DictConfig whose values are callback configs.
        A falsy value (empty config or ``None``) is accepted.
    :raises TypeError: If ``callbacks_cfg`` is non-empty but not a DictConfig.
    :return: The instantiated callbacks in config iteration order. If
        ``callbacks_cfg`` is falsy, a warning is logged and an empty list is
        returned.
    """
    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping..")
        return []

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    # Lazily keep only the entries Hydra can instantiate (those naming a ``_target_``).
    instantiable = (
        conf
        for _, conf in callbacks_cfg.items()
        if isinstance(conf, DictConfig) and "_target_" in conf
    )

    built: List[Callback] = []
    for conf in instantiable:
        log.info(f"Instantiating callback <{conf._target_}>")  # pylint: disable=protected-access
        built.append(hydra.utils.instantiate(conf))
    return built
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Build Lightning loggers from a Hydra config group.

    Every entry of ``logger_cfg`` that is itself a ``DictConfig`` and carries
    a ``_target_`` key is handed to ``hydra.utils.instantiate``. All other
    entries are silently ignored.

    :param logger_cfg: A DictConfig whose values are logger configs.
        A falsy value (empty config or ``None``) is accepted.
    :raises TypeError: If ``logger_cfg`` is non-empty but not a DictConfig.
    :return: The instantiated loggers in config iteration order. If
        ``logger_cfg`` is falsy, a warning is logged and an empty list is
        returned.
    """
    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return []

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    # Lazily keep only the entries Hydra can instantiate (those naming a ``_target_``).
    instantiable = (
        conf
        for _, conf in logger_cfg.items()
        if isinstance(conf, DictConfig) and "_target_" in conf
    )

    built: List[Logger] = []
    for conf in instantiable:
        log.info(f"Instantiating logger <{conf._target_}>")  # pylint: disable=protected-access
        built.append(hydra.utils.instantiate(conf))
    return built
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from typing import Any, Dict
|
|
2
|
+
|
|
3
|
+
from lightning.pytorch.utilities import rank_zero_only
|
|
4
|
+
from omegaconf import OmegaConf
|
|
5
|
+
|
|
6
|
+
from matcha.utils import pylogger
|
|
7
|
+
|
|
8
|
+
# Module-level logger obtained from the project's ``pylogger`` helper, keyed by this module's name.
log = pylogger.get_pylogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
    """Send a curated subset of the run configuration to every Lightning logger.

    Besides selected config sections, the model's parameter counts (total,
    trainable and non-trainable) are recorded. The function is wrapped in
    Lightning's ``rank_zero_only``, so presumably it only does work on the
    rank-zero process of a distributed run.

    :param object_dict: A mapping that must provide:
        - ``"cfg"``: the main DictConfig. It is converted with
          ``OmegaConf.to_container`` without ``resolve=True``, so
          interpolations are kept unresolved.
        - ``"model"``: the Lightning model, whose ``parameters()`` are counted.
        - ``"trainer"``: the Lightning trainer, whose ``loggers`` receive
          the hyperparameters.

    If the trainer has no logger, a warning is logged and nothing is sent.
    """
    cfg = OmegaConf.to_container(object_dict["cfg"])
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    # Materialise the parameters once so the three counts share one pass source.
    params = list(model.parameters())
    n_total = sum(p.numel() for p in params)
    n_trainable = sum(p.numel() for p in params if p.requires_grad)
    n_frozen = sum(p.numel() for p in params if not p.requires_grad)

    hparams = {
        "model": cfg["model"],
        "model/params/total": n_total,
        "model/params/trainable": n_trainable,
        "model/params/non_trainable": n_frozen,
        "data": cfg["data"],
        "trainer": cfg["trainer"],
        # Optional sections: recorded as ``None`` when missing from the config.
        "callbacks": cfg.get("callbacks"),
        "extras": cfg.get("extras"),
        "task_name": cfg.get("task_name"),
        "tags": cfg.get("tags"),
        "ckpt_path": cfg.get("ckpt_path"),
        "seed": cfg.get("seed"),
    }

    # Every attached logger receives the same dict.
    for attached_logger in trainer.loggers:
        attached_logger.log_hyperparams(hparams)
|