minicpmo-utils 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148) hide show
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
cosyvoice/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ """
2
+ CosyVoice: Text-to-Speech with Large Language Model
3
+ """
4
+
5
+ __version__ = "0.1.0"
6
+
7
+ # Lazy import to avoid requiring all dependencies at package import time
8
def __getattr__(name):
    """Resolve ``CosyVoice`` / ``CosyVoice2`` lazily on attribute access.

    The import of ``cosyvoice.cli.cosyvoice`` is deferred until one of the
    two public class names is requested, so importing the package itself
    does not require every dependency to be installed.

    Args:
        name: attribute name being looked up on the module.

    Returns:
        The requested class object.

    Raises:
        AttributeError: if ``name`` is not one of the lazily exported names.
    """
    if name not in ('CosyVoice', 'CosyVoice2'):
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
    from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
    lazily_exported = {'CosyVoice': CosyVoice, 'CosyVoice2': CosyVoice2}
    return lazily_exported[name]
16
+
17
+ __all__ = ['CosyVoice', 'CosyVoice2']
@@ -0,0 +1,93 @@
1
+ # Copyright (c) 2020 Mobvoi Inc (Di Wu)
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import argparse
18
+ import glob
19
+
20
+ import yaml
21
+ import torch
22
+
23
+
24
def get_args():
    """Parse the command line of the checkpoint-averaging script.

    Returns:
        argparse.Namespace with ``dst_model``, ``src_path``, ``val_best``
        and ``num``. The parsed namespace is also printed to stdout.
    """
    option_table = (
        ('--dst_model', dict(required=True, help='averaged model')),
        ('--src_path', dict(required=True, help='src model path for average')),
        ('--val_best', dict(action="store_true", help='averaged model')),
        ('--num', dict(default=5, type=int, help='nums for averaged model')),
    )
    cli = argparse.ArgumentParser(description='average model')
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    parsed = cli.parse_args()
    print(parsed)
    return parsed
41
+
42
+
43
def main():
    """Average the tensors of several checkpoints and save the result.

    Checkpoint selection:
      * with ``--val_best``: every ``*.yaml`` in ``src_path`` whose basename
        does not start with ``train`` or ``init`` is read, the entries are
        sorted by ascending ``loss_dict.loss`` and the ``num`` best epochs
        are mapped to ``<src_path>/epoch_<epoch>_whole.pt``;
      * without ``--val_best``: the ``num`` most recently modified
        ``<src_path>/epoch_*_whole.pt`` files are used. (Previously this
        path crashed with ``NameError`` because no checkpoint list was
        ever built.)

    Every key except ``step`` and ``epoch`` is summed across the selected
    checkpoints, divided by ``num`` and written to ``dst_model``.

    Raises:
        ValueError: if the number of selected checkpoints differs from
            ``--num``.
    """
    args = get_args()
    num = args.num

    if args.val_best:
        val_scores = []
        yamls = [
            f for f in glob.glob('{}/*.yaml'.format(args.src_path))
            if not os.path.basename(f).startswith(('train', 'init'))
        ]
        for y in yamls:
            with open(y, 'r') as f:
                # BaseLoader yields plain strings only, hence the casts below.
                dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
            loss = float(dic_yaml['loss_dict']['loss'])
            epoch = int(dic_yaml['epoch'])
            step = int(dic_yaml['step'])
            tag = dic_yaml['tag']
            val_scores.append([epoch, step, loss, tag])
        best_scores = sorted(val_scores, key=lambda x: x[2])[:num]
        print("best val (epoch, step, loss, tag) = " + str(best_scores))
        path_list = [
            args.src_path + '/epoch_{}_whole.pt'.format(score[0])
            for score in best_scores
        ]
    else:
        # No validation ranking requested: fall back to the newest checkpoints.
        candidates = sorted(
            glob.glob('{}/epoch_*_whole.pt'.format(args.src_path)),
            key=os.path.getmtime)
        path_list = candidates[-num:] if num > 0 else []
    print(path_list)

    # An explicit raise (not assert) so the check survives ``python -O``;
    # dividing by a wrong ``num`` would silently produce a bad average.
    if num != len(path_list):
        raise ValueError(
            'expected {} checkpoints to average but found {}'.format(
                num, len(path_list)))

    avg = {}
    for path in path_list:
        print('Processing {}'.format(path))
        states = torch.load(path, map_location=torch.device('cpu'))
        for k in states.keys():
            if k in ('step', 'epoch'):
                # bookkeeping entries, not model tensors
                continue
            if k not in avg:
                avg[k] = states[k].clone()
            else:
                avg[k] += states[k]

    # average
    for k in avg.keys():
        if avg[k] is not None:
            # pytorch 1.6 use true_divide instead of /=
            avg[k] = torch.true_divide(avg[k], num)
    print('Saving to {}'.format(args.dst_model))
    torch.save(avg, args.dst_model)
90
+
91
+
92
+ if __name__ == '__main__':
93
+ main()
@@ -0,0 +1,103 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+ import sys
22
+ import torch
23
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
24
+ sys.path.append('{}/../..'.format(ROOT_DIR))
25
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
26
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
27
+ from cosyvoice.utils.file_utils import logging
28
+
29
+
30
def get_args():
    """Parse the command line of the TorchScript export script.

    Returns:
        argparse.Namespace with a single ``model_dir`` string (defaults to
        ``pretrained_models/CosyVoice-300M``). The namespace is also
        printed to stdout.
    """
    cli = argparse.ArgumentParser(description='export your model for deployment')
    model_dir_options = dict(
        type=str,
        default='pretrained_models/CosyVoice-300M',
        help='local path',
    )
    cli.add_argument('--model_dir', **model_dir_options)
    parsed = cli.parse_args()
    print(parsed)
    return parsed
39
+
40
+
41
def get_optimized_script(model, preserved_attrs=None):
    """Script, freeze and inference-optimize a module.

    Args:
        model: module passed to ``torch.jit.script``.
        preserved_attrs: optional list of attribute/method names handed to
            ``torch.jit.freeze`` as ``preserved_attrs``. ``None`` or an
            empty sequence freezes without preserving anything extra.
            (The default used to be a shared mutable ``[]``; ``None`` avoids
            that pitfall while keeping calls with ``[]`` working as before.)

    Returns:
        The result of ``torch.jit.optimize_for_inference`` applied to the
        frozen script.
    """
    script = torch.jit.script(model)
    if preserved_attrs:
        script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
    else:
        script = torch.jit.freeze(script)
    return torch.jit.optimize_for_inference(script)
49
+
50
+
51
def main():
    """Export CosyVoice sub-modules as TorchScript archives.

    The model directory is first loaded as ``CosyVoice`` and, if that
    raises, as ``CosyVoice2``; ``TypeError`` is raised when both fail.
    For a non-``CosyVoice2`` model the llm text encoder, the llm and the
    flow encoder are exported; for ``CosyVoice2`` only the flow encoder.
    Each module is written twice into ``model_dir``: an fp32 archive and,
    after calling ``.half()`` on the module, an fp16 archive.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    torch._C._jit_set_fusion_strategy([('STATIC', 1)])
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)

    try:
        model = CosyVoice(args.model_dir)
    except Exception:
        try:
            model = CosyVoice2(args.model_dir)
        except Exception:
            raise TypeError('no valid model_type!')

    # (owner attribute, module attribute, preserved attrs or None,
    #  fp32 path template, fp16 path template, log message), in export order.
    flow_encoder_job = ('flow', 'encoder', None,
                        '{}/flow.encoder.fp32.zip',
                        '{}/flow.encoder.fp16.zip',
                        'successfully export flow_encoder')
    if isinstance(model, CosyVoice2):
        export_plan = [flow_encoder_job]
    else:
        export_plan = [
            ('llm', 'text_encoder', None,
             '{}/llm.text_encoder.fp32.zip',
             '{}/llm.text_encoder.fp16.zip',
             'successfully export llm_text_encoder'),
            ('llm', 'llm', ('forward_chunk',),
             '{}/llm.llm.fp32.zip',
             '{}/llm.llm.fp16.zip',
             'successfully export llm_llm'),
            flow_encoder_job,
        ]

    def _script(module, preserved):
        # Mirror the original call shapes: one positional argument, or two.
        if preserved is None:
            return get_optimized_script(module)
        return get_optimized_script(module, list(preserved))

    for owner, attr, preserved, fp32_tpl, fp16_tpl, done_msg in export_plan:
        # Resolved lazily so earlier exports finish before a later lookup runs.
        module = getattr(getattr(model.model, owner), attr)
        _script(module, preserved).save(fp32_tpl.format(args.model_dir))
        _script(module.half(), preserved).save(fp16_tpl.format(args.model_dir))
        logging.info(done_msg)
100
+
101
+
102
+ if __name__ == '__main__':
103
+ main()
@@ -0,0 +1,120 @@
1
+ # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import print_function
17
+
18
+ import argparse
19
+ import logging
20
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
21
+ import os
22
+ import sys
23
+ import onnxruntime
24
+ import random
25
+ import torch
26
+ from tqdm import tqdm
27
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
28
+ sys.path.append('{}/../..'.format(ROOT_DIR))
29
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
30
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
31
+ from cosyvoice.utils.file_utils import logging
32
+
33
+
34
def get_dummy_input(batch_size, seq_len, out_channels, device):
    """Build random float32 inputs for the flow decoder estimator.

    Args:
        batch_size: size of the leading dimension of every tensor.
        seq_len: size of the last dimension of ``x``, ``mask``, ``mu``
            and ``cond``.
        out_channels: size of the middle dimension of ``x``, ``mu`` and
            ``cond`` and of the last dimension of ``spks``.
        device: device the tensors are created on.

    Returns:
        Tuple ``(x, mask, mu, t, spks, cond)`` where ``x``, ``mu`` and
        ``cond`` are uniform random ``(batch_size, out_channels, seq_len)``
        tensors, ``mask`` is all ones with shape ``(batch_size, 1, seq_len)``,
        ``t`` is random with shape ``(batch_size,)`` and ``spks`` is random
        with shape ``(batch_size, out_channels)``.
    """
    tensor_opts = dict(dtype=torch.float32, device=device)
    feature_shape = (batch_size, out_channels, seq_len)

    # Random draws stay in the order x, mu, t, spks, cond so a seeded RNG
    # yields the same values as before.
    x = torch.rand(feature_shape, **tensor_opts)
    mask = torch.ones((batch_size, 1, seq_len), **tensor_opts)
    mu = torch.rand(feature_shape, **tensor_opts)
    t = torch.rand(batch_size, **tensor_opts)
    spks = torch.rand((batch_size, out_channels), **tensor_opts)
    cond = torch.rand(feature_shape, **tensor_opts)
    return x, mask, mu, t, spks, cond
42
+
43
+
44
def get_args():
    """Parse the command line of the ONNX export script.

    Returns:
        argparse.Namespace with a single ``model_dir`` string (defaults to
        ``pretrained_models/CosyVoice-300M``). The namespace is also
        printed to stdout.
    """
    cli = argparse.ArgumentParser(description='export your model for deployment')
    model_dir_options = dict(
        type=str,
        default='pretrained_models/CosyVoice-300M',
        help='local path',
    )
    cli.add_argument('--model_dir', **model_dir_options)
    parsed = cli.parse_args()
    print(parsed)
    return parsed
53
+
54
+
55
@torch.no_grad()
def main():
    """Export the flow decoder estimator to ONNX and check it against PyTorch.

    The model directory is loaded as ``CosyVoice`` and, if that raises, as
    ``CosyVoice2``; ``TypeError`` is raised when both fail. The estimator is
    exported to ``<model_dir>/flow.decoder.estimator.fp32.onnx`` with the
    last axis of ``x``, ``mask``, ``mu``, ``cond`` and the output marked
    dynamic. The exported graph is then run through onnxruntime on ten
    random inputs of random length and compared with the PyTorch output.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    try:
        model = CosyVoice(args.model_dir)
    except Exception:
        try:
            model = CosyVoice2(args.model_dir)
        except Exception:
            raise TypeError('no valid model_type!')

    # 1. export flow decoder estimator
    estimator = model.model.flow.decoder.estimator
    estimator.eval()

    device = model.model.device
    batch_size, seq_len = 2, 256
    out_channels = model.model.flow.decoder.estimator.out_channels
    onnx_path = '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir)
    x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
    torch.onnx.export(
        estimator,
        (x, mask, mu, t, spks, cond),
        onnx_path,
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
        output_names=['estimator_out'],
        dynamic_axes={
            'x': {2: 'seq_len'},
            'mask': {2: 'seq_len'},
            'mu': {2: 'seq_len'},
            'cond': {2: 'seq_len'},
            'estimator_out': {2: 'seq_len'},
        }
    )

    # 2. test computation consistency
    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    option.intra_op_num_threads = 1
    providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
    estimator_onnx = onnxruntime.InferenceSession(onnx_path, sess_options=option, providers=providers)

    for _ in tqdm(range(10)):
        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
        output_pytorch = estimator(x, mask, mu, t, spks, cond)
        ort_inputs = {
            'x': x.cpu().numpy(),
            'mask': mask.cpu().numpy(),
            'mu': mu.cpu().numpy(),
            't': t.cpu().numpy(),
            'spks': spks.cpu().numpy(),
            'cond': cond.cpu().numpy()
        }
        output_onnx = estimator_onnx.run(None, ort_inputs)[0]
        # torch.testing.assert_allclose is deprecated (since PyTorch 1.12);
        # assert_close with these flags keeps the legacy comparison rules
        # (NaNs compare equal, dtype is not checked).
        torch.testing.assert_close(output_pytorch,
                                   torch.from_numpy(output_onnx).to(device),
                                   rtol=1e-2, atol=1e-4,
                                   equal_nan=True, check_dtype=False)
    logging.info('successfully export estimator')
117
+
118
+
119
+ if __name__ == "__main__":
120
+ main()
@@ -0,0 +1,126 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+ import torch
22
+ from torch.utils.data import DataLoader
23
+ import torchaudio
24
+ from hyperpyyaml import load_hyperpyyaml
25
+ from tqdm import tqdm
26
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
27
+ from cosyvoice.dataset.dataset import Dataset
28
+
29
+
30
def get_args():
    """Parse the command line of the (deprecated) inference script.

    Returns:
        argparse.Namespace holding the config path, prompt/tts data files,
        model checkpoint paths, ``gpu`` id (``-1`` by default), inference
        ``mode`` (``sft`` or ``zero_shot``) and ``result_dir``. The
        namespace is also printed to stdout.
    """
    option_table = (
        ('--config', dict(required=True, help='config file')),
        ('--prompt_data', dict(required=True, help='prompt data file')),
        ('--prompt_utt2data', dict(required=True, help='prompt data file')),
        ('--tts_text', dict(required=True, help='tts input file')),
        ('--qwen_pretrain_path', dict(required=False, help='qwen pretrain path')),
        ('--llm_model', dict(required=True, help='llm model file')),
        ('--flow_model', dict(required=True, help='flow model file')),
        ('--hifigan_model', dict(required=True, help='hifigan model file')),
        ('--gpu', dict(type=int,
                       default=-1,
                       help='gpu id for this rank, -1 for cpu')),
        ('--mode', dict(default='sft',
                        choices=['sft', 'zero_shot'],
                        help='inference mode')),
        ('--result_dir', dict(required=True, help='asr result file')),
    )
    cli = argparse.ArgumentParser(description='inference with your model')
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    parsed = cli.parse_args()
    print(parsed)
    return parsed
52
+
53
+
54
def main():
    """Synthesize speech for every item of the inference dataset.

    The config is first loaded with a ``qwen_pretrain_path`` override to
    build a ``CosyVoice2Model``; if that raises, it is reloaded without the
    override to build a ``CosyVoiceModel``; ``TypeError`` is raised when
    both fail. Each batch is synthesized in ``sft`` or ``zero_shot`` mode,
    the waveform is saved as ``<result_dir>/<utt>_<tts_index>.wav`` and a
    ``"<key> <path>"`` line is appended to ``<result_dir>/wav.scp``.

    The index file is opened with a context manager so it is closed even
    when synthesis raises part-way through (it used to leak in that case).
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    # Init cosyvoice models from configs
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    try:
        with open(args.config, 'r') as f:
            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
        model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
    except Exception:
        try:
            with open(args.config, 'r') as f:
                configs = load_hyperpyyaml(f)
            model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
        except Exception:
            raise TypeError('no valid model_type!')

    model.load(args.llm_model, args.flow_model, args.hifigan_model)

    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
                           tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

    sample_rate = configs['sample_rate']
    del configs
    os.makedirs(args.result_dir, exist_ok=True)
    fn = os.path.join(args.result_dir, 'wav.scp')
    with open(fn, 'w') as scp_file, torch.no_grad():
        for _, batch in tqdm(enumerate(test_data_loader)):
            utts = batch["utts"]
            assert len(utts) == 1, "inference mode only support batchsize 1"
            text_token = batch["text_token"].to(device)
            text_token_len = batch["text_token_len"].to(device)
            tts_index = batch["tts_index"]
            tts_text_token = batch["tts_text_token"].to(device)
            tts_text_token_len = batch["tts_text_token_len"].to(device)
            speech_token = batch["speech_token"].to(device)
            speech_token_len = batch["speech_token_len"].to(device)
            speech_feat = batch["speech_feat"].to(device)
            speech_feat_len = batch["speech_feat_len"].to(device)
            utt_embedding = batch["utt_embedding"].to(device)
            spk_embedding = batch["spk_embedding"].to(device)
            if args.mode == 'sft':
                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                               'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
            else:
                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                               'prompt_text': text_token, 'prompt_text_len': text_token_len,
                               'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                               'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                               'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                               'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
            # model.tts yields chunks; join them along dim 1 before saving.
            tts_speeches = [model_output['tts_speech'] for model_output in model.tts(**model_input)]
            tts_speeches = torch.concat(tts_speeches, dim=1)
            tts_key = '{}_{}'.format(utts[0], tts_index[0])
            tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
            torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
            scp_file.write('{} {}\n'.format(tts_key, tts_fn))
            # flush per utterance so partial results survive an interruption
            scp_file.flush()
    logging.info('Result wav.scp saved in {}'.format(fn))
122
+
123
+
124
+ if __name__ == '__main__':
125
+ logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!')
126
+ main()
cosyvoice/bin/train.py ADDED
@@ -0,0 +1,195 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+ import argparse
17
+ import datetime
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ from copy import deepcopy
21
+ import os
22
+ import torch
23
+ import torch.distributed as dist
24
+ import deepspeed
25
+
26
+ from hyperpyyaml import load_hyperpyyaml
27
+
28
+ from torch.distributed.elastic.multiprocessing.errors import record
29
+
30
+ from cosyvoice.utils.losses import DPOLoss
31
+ from cosyvoice.utils.executor import Executor
32
+ from cosyvoice.utils.train_utils import (
33
+ init_distributed,
34
+ init_dataset_and_dataloader,
35
+ init_optimizer_and_scheduler,
36
+ init_summarywriter, save_model,
37
+ wrap_cuda_model, check_modify_and_save_config)
38
+
39
+
40
def get_args():
    """Parse the command line of the training script.

    The parser is extended with DeepSpeed's own options through
    ``deepspeed.add_config_arguments`` before parsing.

    Returns:
        argparse.Namespace with the training options. ``--ddp.dist_backend``
        is stored as ``dist_backend`` and ``--deepspeed.save_states`` as
        ``save_states``.
    """
    option_table = (
        ('--train_engine', dict(default='torch_ddp',
                                choices=['torch_ddp', 'deepspeed'],
                                help='Engine for paralleled training')),
        ('--model', dict(required=True, help='model which will be trained')),
        ('--ref_model', dict(required=False, help='ref model used in dpo')),
        ('--config', dict(required=True, help='config file')),
        ('--train_data', dict(required=True, help='train data file')),
        ('--cv_data', dict(required=True, help='cv data file')),
        ('--qwen_pretrain_path', dict(required=False, help='qwen pretrain path')),
        ('--checkpoint', dict(help='checkpoint model')),
        ('--model_dir', dict(required=True, help='save model dir')),
        ('--tensorboard_dir', dict(default='tensorboard',
                                   help='tensorboard log dir')),
        ('--ddp.dist_backend', dict(dest='dist_backend',
                                    default='nccl',
                                    choices=['nccl', 'gloo'],
                                    help='distributed backend')),
        ('--num_workers', dict(default=0,
                               type=int,
                               help='num of subprocess workers for reading')),
        ('--prefetch', dict(default=100,
                            type=int,
                            help='prefetch number')),
        ('--pin_memory', dict(action='store_true',
                              default=False,
                              help='Use pinned memory buffers used for reading')),
        ('--use_amp', dict(action='store_true',
                           default=False,
                           help='Use automatic mixed precision training')),
        ('--dpo', dict(action='store_true',
                       default=False,
                       help='Use Direct Preference Optimization')),
        ('--deepspeed.save_states', dict(dest='save_states',
                                         default='model_only',
                                         choices=['model_only', 'model+optimizer'],
                                         help='save model/optimizer states')),
        ('--timeout', dict(default=60,
                           type=int,
                           help='timeout (in seconds) of cosyvoice_join.')),
    )
    cli = argparse.ArgumentParser(description='training your network')
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    cli = deepspeed.add_config_arguments(cli)
    return cli.parse_args()
94
+
95
+
96
@record
def main():
    """Run the training loop for the model selected with ``--model``.

    Steps, in order: load the hyperpyyaml config (nulling the entries of
    the models that are not being trained), initialise the distributed
    environment, datasets, tensorboard writer, model (optionally restored
    from ``--checkpoint``), optimizer/scheduler, save an ``init``
    checkpoint, optionally prepare the DPO reference model, then train
    from ``start_epoch + 1`` up to ``max_epoch``.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    # gan train has some special initialization logic
    gan = args.model == 'hifigan'

    override_dict = {
        name: None
        for name in ['llm', 'flow', 'hift', 'hifigan']
        if name != args.model
    }
    if gan:
        override_dict.pop('hift')
    try:
        with open(args.config, 'r') as f:
            configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
    except Exception:
        # Retry without the qwen_pretrain_path override.
        with open(args.config, 'r') as f:
            configs = load_hyperpyyaml(f, overrides=override_dict)
    if gan:
        configs['train_conf'] = configs['train_conf_gan']
    configs['train_conf'].update(vars(args))

    # Init env for ddp
    init_distributed(args)

    # Get dataset & dataloader
    train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
        init_dataset_and_dataloader(args, configs, gan, args.dpo)

    # Do some sanity checks and save config to args.model_dir
    configs = check_modify_and_save_config(args, configs)

    # Tensorboard summary
    writer = init_summarywriter(args)

    # load checkpoint
    if args.dpo:
        configs[args.model].forward = configs[args.model].forward_dpo
    model = configs[args.model]
    start_step, start_epoch = 0, -1
    if args.checkpoint is not None:
        if not os.path.exists(args.checkpoint):
            logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
        else:
            state_dict = torch.load(args.checkpoint, map_location='cpu')
            model.load_state_dict(state_dict, strict=False)
            if 'step' in state_dict:
                start_step = state_dict['step']
            if 'epoch' in state_dict:
                start_epoch = state_dict['epoch']

    # Dispatch model from cpu to gpu
    model = wrap_cuda_model(args, model)

    # Get optimizer & scheduler
    model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
    scheduler.set_step(start_step)
    if scheduler_d is not None:
        scheduler_d.set_step(start_step)

    # Save init checkpoints
    info_dict = deepcopy(configs['train_conf'])
    info_dict['step'] = start_step
    info_dict['epoch'] = start_epoch
    save_model(model, 'init', info_dict)

    # DPO related
    ref_model, dpo_loss = None, None
    if args.dpo:
        ref_model = deepcopy(configs[args.model])
        state_dict = torch.load(args.ref_model, map_location='cpu')
        ref_model.load_state_dict(state_dict, strict=False)
        dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
        # NOTE maybe it is not needed to wrap ref_model as ddp because its parameter is not updated
        ref_model = wrap_cuda_model(args, ref_model)

    # Get executor
    executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
    executor.step = start_step

    # Init scaler, used for pytorch amp mixed precision training
    scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
    print('start step {} start epoch {}'.format(start_step, start_epoch))

    # Start training loop
    join_timeout = datetime.timedelta(seconds=args.timeout)
    for epoch in range(start_epoch + 1, info_dict['max_epoch']):
        executor.epoch = epoch
        train_dataset.set_epoch(epoch)
        dist.barrier()
        # A fresh gloo group is created for every epoch and destroyed after it.
        group_join = dist.new_group(backend="gloo", timeout=join_timeout)
        if gan:
            executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
                                        writer, info_dict, scaler, group_join)
        else:
            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
        dist.destroy_process_group(group_join)
192
+
193
+
194
+ if __name__ == '__main__':
195
+ main()
File without changes