minicpmo-utils 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
|
2
|
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
|
16
|
+
"""Utility functions for Transformer."""
|
|
17
|
+
|
|
18
|
+
import random
|
|
19
|
+
from typing import List
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
import torch
|
|
23
|
+
|
|
24
|
+
# Sentinel label id; presumably marks padded / ignored target positions
# (ESPnet convention) — TODO confirm against the callers in this package.
IGNORE_ID = -1
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def pad_list(xs: List[torch.Tensor], pad_value: float):
    """Perform padding for the list of tensors.

    All tensors must share every dimension except the first (time) one; the
    output takes dtype and device from ``xs[0]``. Any rank >= 1 is supported.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Raises:
        ValueError: If ``xs`` is empty or its tensors are 0-dimensional.

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    if not xs:
        # max() of an empty sequence would raise an unhelpful error; there is
        # also no reference tensor to take dtype/device from.
        raise ValueError("pad_list() requires at least one tensor")
    max_len = max(len(item) for item in xs)
    batchs = len(xs)
    ref = xs[0]
    if ref.ndim < 1:
        raise ValueError(f"Unsupported ndim: {ref.ndim}")
    # Trailing (non-time) dimensions are copied from the first tensor, which
    # generalizes the former hard-coded 1-D / 2-D / 3-D branches.
    pad_res = torch.zeros(batchs, max_len, *ref.shape[1:], dtype=ref.dtype, device=ref.device)
    pad_res.fill_(pad_value)
    for i, item in enumerate(xs):
        pad_res[i, :len(item)] = item
    return pad_res
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def get_padding(kernel_size, dilation=1):
    """Padding for a stride-1 dilated convolution that keeps the sequence length.

    Computes ``(kernel_size - 1) * dilation / 2`` truncated toward zero; this is
    exact "same" padding whenever ``kernel_size`` is odd.
    """
    extra_span = kernel_size * dilation - dilation
    return int(extra_span / 2)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialize convolution weights in place from N(mean, std).

    Takes a single module, so it can be used as an ``nn.Module.apply``
    callback. Any module whose class name contains "Conv" has ``m.weight``
    overwritten with normal noise; all other modules are left untouched.
    """
    is_conv = "Conv" in m.__class__.__name__
    if not is_conv:
        return
    m.weight.data.normal_(mean, std)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def fade_in_out(fade_in_mel, fade_out_mel, window):
    """Cross-fade the head of ``fade_in_mel`` with the tail of ``fade_out_mel``.

    The window is split into two halves of ``len(window) // 2`` samples. The
    first half weights the leading frames of ``fade_in_mel``, the second half
    weights the trailing frames of ``fade_out_mel``, and their sum replaces the
    head of a copy of ``fade_in_mel``. Work is done on CPU and the result is
    returned on ``fade_in_mel``'s original device; the inputs are not modified.

    Args:
        fade_in_mel (torch.Tensor): Segment whose leading frames are faded in;
            the last dimension is treated as time.
        fade_out_mel (torch.Tensor): Previous segment whose trailing frames are
            faded out.
        window (torch.Tensor): 1-D cross-fade window (any device). An odd
            length makes the two halves differ in size and the blend fails.

    Returns:
        torch.Tensor: The cross-faded copy of ``fade_in_mel`` on its original device.
    """
    device = fade_in_mel.device
    # Always take a private CPU copy so the caller's tensor is never mutated
    # (`.cpu()` alone returns the very same tensor when it is already on CPU,
    # and copies it when it is not, so one explicit copy covers both cases).
    fade_in_mel = fade_in_mel.to('cpu', copy=True)
    fade_out_mel = fade_out_mel.cpu()
    # The window may live on another device (e.g. CUDA); match the CPU operands.
    window = torch.as_tensor(window, device='cpu')
    mel_overlap_len = int(window.shape[0] / 2)
    if mel_overlap_len == 0:
        # Nothing to blend; `[..., -0:]` below would select the whole tensor.
        return fade_in_mel.to(device)
    fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
        fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
    return fade_in_mel.to(device)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def set_all_random_seed(seed):
    """Seed every RNG used for generation with the same ``seed``.

    Covers Python's ``random``, NumPy's global generator, PyTorch's default
    generator and the generators of all CUDA devices, in that order.
    """
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        apply_seed(seed)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# Copyright (c) 2019 Shigeki Karita
|
|
2
|
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
|
3
|
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
|
|
17
|
+
import math
|
|
18
|
+
import torch
|
|
19
|
+
from typing import List
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make a boolean mask that is True at padded positions.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,). A Python sequence of
            ints is also accepted and converted with ``torch.as_tensor``.
        max_len (int): Width of the mask. When <= 0, the largest value in
            ``lengths`` is used (0 for an empty batch).

    Returns:
        torch.Tensor: Bool tensor (B, max_len) on ``lengths``' device where
        ``mask[b, t]`` is True iff ``t >= lengths[b]``.

    Examples:
        >>> lengths = torch.tensor([5, 3, 2])
        >>> make_pad_mask(lengths)
        tensor([[False, False, False, False, False],
                [False, False, False,  True,  True],
                [False, False,  True,  True,  True]])
    """
    lengths = torch.as_tensor(lengths)
    batch_size = lengths.size(0)
    if max_len <= 0:
        # `.max()` raises on an empty tensor, so an empty batch gets width 0.
        max_len = lengths.max().item() if lengths.numel() > 0 else 0
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    return seq_range_expand >= seq_length_expand
|
|
49
|
+
|
|
File without changes
|
|
@@ -0,0 +1,424 @@
|
|
|
1
|
+
# Copyright (c) 2025 Tsinghua Univ. (authors: Xingchen Song)
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
""" Example Usage: see README.md
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import argparse
|
|
18
|
+
import json
|
|
19
|
+
import os
|
|
20
|
+
import random
|
|
21
|
+
import sys
|
|
22
|
+
import time
|
|
23
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
24
|
+
from datetime import datetime
|
|
25
|
+
|
|
26
|
+
import numpy as np
|
|
27
|
+
import onnxruntime
|
|
28
|
+
import s3tokenizer
|
|
29
|
+
import torch
|
|
30
|
+
import torch.distributed as dist
|
|
31
|
+
import torchaudio
|
|
32
|
+
import torchaudio.compliance.kaldi as kaldi
|
|
33
|
+
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
|
34
|
+
from tqdm import tqdm
|
|
35
|
+
|
|
36
|
+
from stepaudio2.flashcosyvoice.config import Config, CosyVoice2LLMConfig, SamplingParams
|
|
37
|
+
from stepaudio2.flashcosyvoice.cosyvoice2 import CosyVoice2
|
|
38
|
+
from stepaudio2.flashcosyvoice.utils.audio import mel_spectrogram
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def set_all_random_seed(seed):
    """Make generation reproducible by seeding all RNGs this script relies on.

    Python ``random``, NumPy, the PyTorch CPU generator and every CUDA
    device generator are all seeded with ``seed``.
    """
    seeders = [
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ]
    for seeder in seeders:
        seeder(seed)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def save_file_async(
    wav, prompt_speech_tokens, generated_speech_tokens,
    info, timing_stats
):
    """Save one generated utterance (audio plus a JSON sidecar); meant for a worker thread.

    Args:
        wav: Generated waveform tensor saved at 24 kHz via ``torchaudio.save``,
            or None when only speech tokens were produced (``--only_llm``).
        prompt_speech_tokens: Prompt speech tokens, stored in the JSON sidecar.
        generated_speech_tokens: Generated speech tokens, stored in the JSON sidecar.
        info: The data-list entry; ``info['wav']`` is the output audio path.
            Mutated: timing stats and token lists are added before it is dumped
            next to the audio with a ``.json`` extension.
        timing_stats: Per-batch timings. When audio of non-zero length is saved,
            it must contain ``dataloader_time``, ``model_inference_time`` and
            ``batch_size``, and an ``rtf`` entry is added.

    Returns:
        float: Duration in seconds of the saved audio; 0.0 when ``wav`` is None
        or when saving failed (the error is logged, not raised).
    """
    try:
        wav_path = info['wav']
        out_dir = os.path.dirname(wav_path)
        if out_dir:
            # os.makedirs('') raises, so skip it for bare file names.
            os.makedirs(out_dir, exist_ok=True)
        duration = 0.0
        if wav is not None:
            wav = wav.cpu()
            torchaudio.save(wav_path, wav, 24000)
            duration = wav.shape[-1] / 24000.0
            if duration > 0:
                # Per-sample wall time divided by the audio length it produced.
                per_item_time = (timing_stats['dataloader_time'] + timing_stats['model_inference_time']) / timing_stats['batch_size']
                timing_stats['rtf'] = per_item_time / duration
        info['timing_stats'] = timing_stats
        info['prompt_speech_tokens'] = prompt_speech_tokens
        info['generated_speech_tokens'] = generated_speech_tokens
        # Swap only the extension: str.replace('.wav', ...) also rewrote '.wav'
        # inside directory names and, for non-.wav paths, left the path
        # unchanged so the JSON overwrote the audio file.
        json_path = os.path.splitext(wav_path)[0] + '.json'
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(info, f, ensure_ascii=False, indent=4)
        return duration
    except Exception as e:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [ERROR] - Error saving audio {info.get('key', 'unknown')}: {e}")
        return 0.0
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class AudioDataset(Dataset):
    """Zero-shot TTS dataset that prepares prompt features for CosyVoice2 inference.

    ``data_list`` is a JSON-lines file, for example::

        {"key": "uttid_1", "prompt_text": "Hello, I am Xiaoming.", "text": "Hello, I am Xiaohong.", "prompt_wav": "/mnt/data/audio/00000000.wav", "wav": "/mnt/data/audio_synthetic/uttid_1.wav"}
        {"key": "uttid_2", "prompt_text": "Hello, I am Xiaohong.", "text": "Hello, I am Xiaoming.", "prompt_wav": "/mnt/data/audio/00000001.wav", "wav": "/mnt/data/audio_synthetic/uttid_2.wav"}

    Fields:
        - ``key``: the key of this sample.
        - ``prompt_text``: the text used for the prompt.
        - ``text``: the text to synthesize.
        - ``prompt_wav``: the audio used for the prompt.
        - ``wav``: the path where the generated audio is saved (we highly
          recommend pre-defining the save path before running the script).

    Lines that are blank, not valid JSON objects, lack one of ``key``,
    ``prompt_text``, ``text``, ``prompt_wav`` (or have it set to null), or
    whose ``prompt_wav`` does not exist are skipped and counted as missing.
    """

    _REQUIRED_KEYS = ('key', 'prompt_text', 'text', 'prompt_wav')

    def __init__(self, text_norm, text_tokenizer, data_list, model_config: "Config"):
        """Load and validate ``data_list`` and create the speaker-embedding ONNX session.

        Args:
            text_norm: Optional text normalizer (``do_voicegen_frd``), or None.
            text_tokenizer: Tokenizer exposing ``encode(str) -> list[int]``.
            data_list: Path of the JSON-lines data list (see class docstring).
            model_config: Model config; ``model`` is the directory holding
                ``campplus.onnx`` and ``hf_config.speech_vocab_size`` is used
                to offset text token ids.
        """
        self.datas = []
        self.text_norm = text_norm
        self.model_config = model_config

        # Only the first process on each node shows the loading progress/summary.
        show_progress = torch.distributed.get_node_local_rank() == 0
        missing = 0
        with open(data_list, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        total_lines = len(lines)
        iterator = tqdm(lines, desc='Loading data') if show_progress else lines
        for line in iterator:
            data = self._parse_line(line)
            if data is None:
                missing += 1
            else:
                self.datas.append(data)
        if show_progress:
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f'[{timestamp}] - [INFO] - Loaded {total_lines} lines, found {missing} missing lines, total valid lines == {len(self.datas)}.')

        self.text_tokenizer = text_tokenizer

        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.spk_model = onnxruntime.InferenceSession(f"{self.model_config.model}/campplus.onnx", sess_options=option,
                                                      providers=["CPUExecutionProvider"])

    @classmethod
    def _parse_line(cls, line):
        """Return the parsed entry for one data-list line, or None if it is unusable."""
        line = line.strip()
        if not line:
            return None
        try:
            data = json.loads(line)
        except json.JSONDecodeError:
            return None
        if not isinstance(data, dict):
            return None
        # Covers both "key absent" and "value is null".
        if any(data.get(k) is None for k in cls._REQUIRED_KEYS):
            return None
        # Only checked once the keys are known to exist; previously a missing
        # 'prompt_wav' raised KeyError here instead of being counted as missing.
        if not os.path.exists(data['prompt_wav']):
            return None
        return data

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, idx):
        """Build model inputs for sample ``idx``; return None if any step fails.

        The returned dict contains ``log_mel`` (s3tokenizer features of the
        prompt loaded at 16 kHz), ``spk_emb`` (flattened output of the
        ``campplus.onnx`` model), ``mel``/``mel_len`` (flow features of the
        prompt resampled to 24 kHz and their frame count),
        ``prompt_text_tokens``/``text_tokens`` (token ids shifted by
        ``speech_vocab_size + 2``), ``min_tokens``/``max_tokens`` (length
        bounds scaled by the number of text tokens) and ``info`` (the raw entry).
        """
        data = self.datas[idx]

        try:
            # 1. feature for s3tokenizer
            audio = s3tokenizer.load_audio(data['prompt_wav'], sr=16000)  # [T]
            log_mel = s3tokenizer.log_mel_spectrogram(audio)  # [num_mels, T]

            # 2. feature for speaker embedding (mean-normalized kaldi fbank)
            spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
            spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
            spk_emb = self.spk_model.run(
                None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
            )[0].flatten().tolist()

            # 3. feature for flow: mono, 24 kHz
            audio, sample_rate = torchaudio.load(data['prompt_wav'], backend='soundfile')
            audio = audio.mean(dim=0, keepdim=True)  # [1, T]
            if sample_rate != 24000:
                audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
            mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
            mel_len = mel.shape[0]

            # 4. feature for llm
            if self.text_norm is not None:
                prompt_texts = [i["text"] for i in json.loads(self.text_norm.do_voicegen_frd(data['prompt_text'].strip()))["sentences"]]
                prompt_text = ''.join(prompt_texts)
                texts = [i["text"] for i in json.loads(self.text_norm.do_voicegen_frd(data['text'].strip()))["sentences"]]
                text = ''.join(texts)
            else:
                prompt_text = data['prompt_text']
                text = data['text']
            # Text ids live after the speech vocabulary plus its 2 extra tokens.
            prompt_text_ids = self.text_tokenizer.encode(prompt_text)
            prompt_text_ids = [i + self.model_config.hf_config.speech_vocab_size + 2 for i in prompt_text_ids]
            text_ids = self.text_tokenizer.encode(text)
            text_ids = [i + self.model_config.hf_config.speech_vocab_size + 2 for i in text_ids]
            item = {
                "prompt_text_tokens": prompt_text_ids, "text_tokens": text_ids,
                "spk_emb": spk_emb, "mel": mel, "mel_len": mel_len, "log_mel": log_mel, "info": data,
                "min_tokens": len(text_ids) * self.model_config.min_token_text_ratio,
                "max_tokens": len(text_ids) * self.model_config.max_token_text_ratio,
            }
        except Exception as e:
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f"[{timestamp}] - [WARNING] - Error processing data item {data.get('key', idx)}: {e}")
            return None
        return item
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def collate_fn(batch):
    """Merge ``AudioDataset`` items into one batch, dropping items that failed (None).

    Returns a dict with padded s3tokenizer log-mels and their lengths, prompt and
    target text token lists, padded flow mels and their lengths, stacked speaker
    embeddings, one ``SamplingParams`` per item, and the raw ``infos``. When
    every item is None, the same keys are returned with empty values so the
    caller can detect it through ``len(result['infos']) == 0``.
    """
    items = [item for item in batch if item is not None]
    if not items:
        # pad_sequence raises on an empty list (and s3tokenizer.padding
        # presumably does too), which made main()'s empty-batch check
        # unreachable; return empty placeholders instead.
        return {
            "prompt_mels_for_llm": torch.empty(0),
            "prompt_mels_lens_for_llm": torch.empty(0),
            "prompt_text_tokens_for_llm": [],
            "text_tokens_for_llm": [],
            "prompt_mels_for_flow": torch.empty(0),
            "prompt_mels_lens_for_flow": torch.empty(0),
            "spk_emb_for_flow": torch.empty(0),
            "sampling_params": [],
            "infos": [],
        }

    prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(
        [item["log_mel"] for item in items])  # [B, num_mels=128, T]
    prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
        [item["mel"] for item in items], batch_first=True, padding_value=0)  # [B, T', num_mels=80]
    return {
        "prompt_mels_for_llm": prompt_mels_for_llm,
        "prompt_mels_lens_for_llm": prompt_mels_lens_for_llm,
        "prompt_text_tokens_for_llm": [item["prompt_text_tokens"] for item in items],
        "text_tokens_for_llm": [item["text_tokens"] for item in items],
        "prompt_mels_for_flow": prompt_mels_for_flow,
        "prompt_mels_lens_for_flow": torch.tensor([item["mel_len"] for item in items]),
        "spk_emb_for_flow": torch.tensor([item["spk_emb"] for item in items]),
        "sampling_params": [
            SamplingParams(min_tokens=item["min_tokens"], max_tokens=item["max_tokens"], use_ras=True)
            for item in items
        ],
        "infos": [item["info"] for item in items],
    }
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def init_distributed():
    """Pin this process to its GPU and join the NCCL process group.

    Reads ``WORLD_SIZE``, ``LOCAL_RANK`` and ``RANK`` from the environment
    (defaulting to 1, 0 and 0) and returns them as ints, in that order.
    """
    env = os.environ
    world_size, local_rank, rank = (
        int(env.get(var_name, fallback))
        for var_name, fallback in (('WORLD_SIZE', 1), ('LOCAL_RANK', 0), ('RANK', 0))
    )
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
    tqdm.write(f'[{timestamp}] - [INFO] - Inference on multiple gpus, this gpu {local_rank}, rank {rank}, world_size {world_size}')
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def get_args():
    """Parse the command-line options of the FlashCosyVoice inference script."""
    parser = argparse.ArgumentParser(description='FlashCosyVoice')
    # (flag, add_argument keyword arguments), in --help order.
    option_specs = [
        ('--model_path', dict(required=True, type=str, help='model path')),
        ('--data_list', dict(required=True, type=str, help='data list')),
        ('--batch_size_dataloader', dict(required=True, type=int,
                                         help='batch size (per-device) for dataloading')),
        ('--batch_size_flow', dict(required=True, type=int,
                                   help='batch size (per-device) for flow-matching')),
        ('--num_workers', dict(type=int, default=4, help='workers for dataloader')),
        ('--prefetch', dict(type=int, default=5, help='prefetch for dataloader')),
        ('--enable_tn', dict(action='store_true', help='enable text normalization')),
        ('--only_llm', dict(action='store_true', help='only generate speech tokens from llm')),
        ('--fp16_flow', dict(action='store_true', help='enable fp16 flow')),
        ('--seed', dict(type=int, default=1986, help='random seed for generation')),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def _build_text_norm(args):
    """Return a ttsfrd text normalizer when ``--enable_tn`` is set and usable, else None.

    ttsfrd is only used on Python 3.10; on other versions a warning is printed
    and ``args.enable_tn`` is reset to False. Raises ImportError when running
    on Python 3.10 but ttsfrd / cosyvoice_ttsfrd cannot be imported.
    """
    if not args.enable_tn:
        return None
    # Check python version, if == 3.10, use ttsfrd
    if sys.version_info.major == 3 and sys.version_info.minor == 10:
        # Check if ttsfrd is installed
        try:
            import ttsfrd
            from cosyvoice_ttsfrd import get_resource_path
        except ImportError as e:
            raise ImportError("ttsfrd is not installed, please install it first, see `https://github.com/xingchensong/CosyVoice-ttsfrd` for installation guide.") from e
        text_norm = ttsfrd.TtsFrontendEngine()
        text_norm.initialize(get_resource_path())
        text_norm.set_lang_type('pinyinvg')
        return text_norm
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
    tqdm.write(f"[{timestamp}] - [WARNING] - Only python 3.10 is supported for ttsfrd, see `https://github.com/xingchensong/CosyVoice-ttsfrd` for more info. Setting enable_tn to False...")
    # TODO: maybe we should use wetext if python version is not 3.10?
    args.enable_tn = False
    return None


def main():
    """Run distributed zero-shot TTS inference over ``--data_list``.

    Each rank reads its shard through a ``DistributedSampler``, runs
    ``CosyVoice2`` on every batch and hands the results to a thread pool that
    writes the outputs via ``save_file_async``. Local rank 0 drives the
    progress bar; every rank prints periodic RTF estimates (unless
    ``--only_llm``) and a final report.
    """
    args = get_args()
    text_norm = _build_text_norm(args)

    # An explicit check survives `python -O`, unlike an assert.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available; FlashCosyVoice inference requires a GPU.")
    world_size, local_rank, rank = init_distributed()
    config = Config(model=args.model_path, enforce_eager=True, tensor_parallel_size=1,
                    max_num_seqs=args.batch_size_dataloader,
                    hf_config=CosyVoice2LLMConfig(fp16_flow=args.fp16_flow), rank=local_rank)
    model = CosyVoice2(config)

    set_all_random_seed(args.seed)

    dataset = AudioDataset(text_norm, model.llm.tokenizer, args.data_list, config)
    sampler = DistributedSampler(dataset,
                                 num_replicas=world_size,
                                 rank=rank)
    loader_kwargs = {}
    if args.num_workers > 0:
        # DataLoader rejects prefetch_factor when it runs without worker processes.
        loader_kwargs['prefetch_factor'] = args.prefetch
    dataloader = DataLoader(dataset, batch_size=args.batch_size_dataloader, num_workers=args.num_workers, pin_memory=True,
                            sampler=sampler, shuffle=False, collate_fn=collate_fn, **loader_kwargs)
    total_steps = len(dataset)

    if local_rank == 0:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [INFO] - {args}")
        progress_bar = tqdm(total=total_steps, desc="Processing samples", unit="wav",
                            position=0, leave=True, dynamic_ncols=True)

    # os.cpu_count() may return None, and cpu_count // 8 is 0 on hosts with
    # fewer than 8 cores; ThreadPoolExecutor needs at least one worker.
    cpu_counts = os.cpu_count() or 1
    executor = ThreadPoolExecutor(max_workers=max(1, min(args.batch_size_dataloader, cpu_counts // 8)))
    pending_futures = []
    dataloader_iter = iter(dataloader)
    succeed_duration = 0.01  # avoid division by zero
    start_time = time.time()
    estimated_total_wavs = 0
    succeed_wavs = 0
    failed_wavs = 0
    last_print_time = start_time

    while True:
        try:
            dataloader_start = time.time()
            batch = next(dataloader_iter)
            dataloader_time = time.time() - dataloader_start

            if len(batch['infos']) == 0:
                timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
                tqdm.write(f"[{timestamp}] - [WARNING] - rank {rank} of {world_size}: No valid batch found, skipping this batch...")
                continue

            model_start = time.time()
            results_dict, timing_stats = model(**batch, batch_size_flow=args.batch_size_flow,
                                               only_llm=args.only_llm)
            model_time = time.time() - model_start

            if args.only_llm:
                # No waveforms in LLM-only mode. Fill the placeholders *before*
                # counting so the count cannot fail if the model omits the key.
                results_dict['generated_wavs'] = [None] * len(results_dict['prompt_speech_tokens'])

            estimated_total_wavs += len(results_dict['generated_wavs'])

            timing_stats['dataloader_time'] = dataloader_time
            timing_stats['model_inference_time'] = model_time

            for i in range(len(results_dict['generated_wavs'])):
                future = executor.submit(
                    save_file_async, results_dict['generated_wavs'][i],
                    results_dict['prompt_speech_tokens'][i],
                    results_dict['generated_speech_tokens'][i],
                    batch['infos'][i].copy(), timing_stats.copy()
                )
                pending_futures.append(future)

            # Harvest finished saves; rebuild the list in one pass instead of
            # calling list.remove() per future (quadratic on long runs).
            still_pending = []
            for future in pending_futures:
                if not future.done():
                    still_pending.append(future)
                    continue
                try:
                    duration = future.result()
                    succeed_duration += duration
                    succeed_wavs += 1
                except Exception as e:
                    failed_wavs += 1
                    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
                    tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in async save task: {e}")
            pending_futures = still_pending

            if local_rank == 0:
                # Approximate global progress assuming every rank advances in lockstep.
                update_n = world_size * len(batch["prompt_text_tokens_for_llm"])
                if progress_bar.n + update_n > progress_bar.total:
                    progress_bar.update(progress_bar.total - progress_bar.n)
                else:
                    progress_bar.update(update_n)

            current_time = time.time()
            if current_time - last_print_time >= 120 and not args.only_llm:
                elapsed_time = current_time - start_time
                avg_duration = succeed_duration / succeed_wavs if succeed_wavs > 0 else 0
                estimated_total_duration = avg_duration * estimated_total_wavs
                current_rtf = elapsed_time / estimated_total_duration if estimated_total_duration > 0.01 else 0
                timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
                tqdm.write(f"[{timestamp}] - [INFO] - rank {rank} of {world_size}: Estimated total wavs: {estimated_total_wavs} ({estimated_total_wavs - succeed_wavs} pending to save), Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Estimated total duration: {estimated_total_duration:.2f}s ({estimated_total_duration / 3600:.2f} h), Estimated RTF: {current_rtf:.5f}, Elapsed time: {elapsed_time:.2f}s")  # noqa
                last_print_time = current_time
        except StopIteration:
            break
        except Exception as e:
            failed_wavs += 1
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in main loop: {e}")
            continue

    total_time = time.time() - start_time

    if local_rank == 0:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [INFO] - Waiting for {len(pending_futures)} pending save tasks to complete...")

    for future in pending_futures:
        try:
            duration = future.result(timeout=60)
            succeed_duration += duration
            succeed_wavs += 1
        except Exception as e:
            failed_wavs += 1
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in final async save task: {e}")
    executor.shutdown(wait=True)

    if local_rank == 0:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [INFO] - All async save tasks completed.")
        progress_bar.close()

    if not args.only_llm:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [INFO] - rank {rank} of {world_size}: Final Report - Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Total duration: {succeed_duration:.2f}s ({succeed_duration / 3600:.2f} h), RTF: {total_time / succeed_duration:.5f}")  # noqa

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from transformers import AutoConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
|
|
9
|
+
class CosyVoice2LLMConfig:
|
|
10
|
+
architectures: list[str] = field(default_factory=lambda: ["Qwen2ForCausalLM"])
|
|
11
|
+
attention_dropout: float = 0.0
|
|
12
|
+
bos_token_id: int = 151643
|
|
13
|
+
eos_token_id: int = 6561 # speech eos
|
|
14
|
+
hidden_act: str = "silu"
|
|
15
|
+
hidden_size: int = 896
|
|
16
|
+
initializer_range: float = 0.02
|
|
17
|
+
intermediate_size: int = 4864
|
|
18
|
+
max_position_embeddings: int = 32768
|
|
19
|
+
max_window_layers: int = 24
|
|
20
|
+
model_type: str = "qwen2"
|
|
21
|
+
num_attention_heads: int = 14
|
|
22
|
+
num_hidden_layers: int = 24
|
|
23
|
+
num_key_value_heads: int = 2
|
|
24
|
+
head_dim: int = 64
|
|
25
|
+
rms_norm_eps: float = 1e-06
|
|
26
|
+
rope_scaling: dict | None = None
|
|
27
|
+
rope_theta: float = 1000000.0
|
|
28
|
+
sliding_window: int = 32768
|
|
29
|
+
tie_word_embeddings: bool = False
|
|
30
|
+
torch_dtype: torch.dtype = torch.bfloat16
|
|
31
|
+
transformers_version: str = "4.52.0.dev0"
|
|
32
|
+
use_cache: bool = True
|
|
33
|
+
use_sliding_window: bool = False
|
|
34
|
+
vocab_size: int = 158500 # text_vocab_size + speech_vocab_size + 2 (eos and task_id)
|
|
35
|
+
text_vocab_size: int = 151936
|
|
36
|
+
speech_vocab_size: int = 6562 # actually 6564, we only care about non-streaming inference, so cut off tokens (6562, 6563) that are only used for streaming TTS
|
|
37
|
+
lm_head_bias: bool = True
|
|
38
|
+
qkv_bias: bool = True
|
|
39
|
+
fp16_flow: bool = True
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class SamplingParams:
    """Token-sampling settings for generation.

    The first group holds the basic sampler knobs. The second group is
    labelled as ``RasSampler`` parameters (presumably repetition-aware
    sampling — confirm against the sampler implementation).
    """

    # --- basic sampling ---------------------------------------------------
    temperature: float = 1.0
    min_tokens: int = 2
    max_tokens: int = 64
    ignore_eos: bool = False
    top_k: int = 25

    # --- RasSampler parameters -------------------------------------------
    use_ras: bool = False
    win_size: int = 10
    tau_r: float = 0.1
    top_p: float = 0.8
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@dataclass
class Config:
    """Runtime settings for serving the speech LLM.

    Field names cover batching limits, the KV-cache layout and tensor
    parallelism. ``hf_config`` defaults to a fresh ``CosyVoice2LLMConfig``.
    ``__post_init__`` validates the settings and clamps ``max_model_len`` to
    the model's ``max_position_embeddings``.
    """

    model: str  # must be an existing directory (checked in __post_init__)
    max_num_batched_tokens: int = 1572864
    max_num_seqs: int = 1024
    max_model_len: int = 1536  # 15s prompt + 30s generated audio for 25hz audio tokenizer
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1
    enforce_eager: bool = False
    hf_config: CosyVoice2LLMConfig | AutoConfig = field(default_factory=CosyVoice2LLMConfig)
    eos: int = -1
    kvcache_block_size: int = 256
    num_kvcache_blocks: int = -1
    min_token_text_ratio: int = 2
    max_token_text_ratio: int = 20
    rank: int = 0

    def __post_init__(self):
        """Validate the settings and clamp ``max_model_len``.

        Checks are explicit ``raise`` statements rather than bare ``assert``
        so they still run under ``python -O``. Each check carries a message
        naming the offending value.

        Raises:
            AssertionError: if ``model`` is not a directory, if
                ``kvcache_block_size`` is not a multiple of 256, if
                ``tensor_parallel_size`` is outside [1, 8], or if
                ``max_num_batched_tokens`` is smaller than the (clamped)
                ``max_model_len``. The type is kept as ``AssertionError`` so
                existing callers are unaffected.
        """
        if not os.path.isdir(self.model):
            raise AssertionError(f"model directory does not exist: {self.model!r}")
        if self.kvcache_block_size % 256 != 0:
            raise AssertionError(
                f"kvcache_block_size must be a multiple of 256, got {self.kvcache_block_size}"
            )
        if not 1 <= self.tensor_parallel_size <= 8:
            raise AssertionError(
                f"tensor_parallel_size must be in [1, 8], got {self.tensor_parallel_size}"
            )

        # Never allow sequences longer than the model's positional range.
        # Fall back to 4096 when hf_config does not expose the attribute.
        max_pos = getattr(self.hf_config, "max_position_embeddings", 4096)
        self.max_model_len = min(self.max_model_len, max_pos)
        if self.max_num_batched_tokens < self.max_model_len:
            raise AssertionError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must be >= "
                f"max_model_len ({self.max_model_len})"
            )
|