minicpmo_utils-0.0.5.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- minicpmo_utils-0.0.5/PKG-INFO +116 -0
- minicpmo_utils-0.0.5/README.md +53 -0
- minicpmo_utils-0.0.5/pyproject.toml +108 -0
- minicpmo_utils-0.0.5/setup.cfg +4 -0
- minicpmo_utils-0.0.5/src/cosyvoice/__init__.py +17 -0
- minicpmo_utils-0.0.5/src/cosyvoice/bin/average_model.py +93 -0
- minicpmo_utils-0.0.5/src/cosyvoice/bin/export_jit.py +103 -0
- minicpmo_utils-0.0.5/src/cosyvoice/bin/export_onnx.py +120 -0
- minicpmo_utils-0.0.5/src/cosyvoice/bin/inference_deprecated.py +126 -0
- minicpmo_utils-0.0.5/src/cosyvoice/bin/train.py +195 -0
- minicpmo_utils-0.0.5/src/cosyvoice/cli/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/cosyvoice/cli/cosyvoice.py +204 -0
- minicpmo_utils-0.0.5/src/cosyvoice/cli/frontend.py +238 -0
- minicpmo_utils-0.0.5/src/cosyvoice/cli/model.py +386 -0
- minicpmo_utils-0.0.5/src/cosyvoice/dataset/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/cosyvoice/dataset/dataset.py +151 -0
- minicpmo_utils-0.0.5/src/cosyvoice/dataset/processor.py +434 -0
- minicpmo_utils-0.0.5/src/cosyvoice/flow/decoder.py +494 -0
- minicpmo_utils-0.0.5/src/cosyvoice/flow/flow.py +281 -0
- minicpmo_utils-0.0.5/src/cosyvoice/flow/flow_matching.py +227 -0
- minicpmo_utils-0.0.5/src/cosyvoice/flow/length_regulator.py +70 -0
- minicpmo_utils-0.0.5/src/cosyvoice/hifigan/discriminator.py +230 -0
- minicpmo_utils-0.0.5/src/cosyvoice/hifigan/f0_predictor.py +58 -0
- minicpmo_utils-0.0.5/src/cosyvoice/hifigan/generator.py +582 -0
- minicpmo_utils-0.0.5/src/cosyvoice/hifigan/hifigan.py +67 -0
- minicpmo_utils-0.0.5/src/cosyvoice/llm/llm.py +610 -0
- minicpmo_utils-0.0.5/src/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- minicpmo_utils-0.0.5/src/cosyvoice/tokenizer/tokenizer.py +279 -0
- minicpmo_utils-0.0.5/src/cosyvoice/transformer/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/cosyvoice/transformer/activation.py +84 -0
- minicpmo_utils-0.0.5/src/cosyvoice/transformer/attention.py +330 -0
- minicpmo_utils-0.0.5/src/cosyvoice/transformer/convolution.py +145 -0
- minicpmo_utils-0.0.5/src/cosyvoice/transformer/decoder.py +396 -0
- minicpmo_utils-0.0.5/src/cosyvoice/transformer/decoder_layer.py +132 -0
- minicpmo_utils-0.0.5/src/cosyvoice/transformer/embedding.py +302 -0
- minicpmo_utils-0.0.5/src/cosyvoice/transformer/encoder.py +474 -0
- minicpmo_utils-0.0.5/src/cosyvoice/transformer/encoder_layer.py +236 -0
- minicpmo_utils-0.0.5/src/cosyvoice/transformer/label_smoothing_loss.py +96 -0
- minicpmo_utils-0.0.5/src/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- minicpmo_utils-0.0.5/src/cosyvoice/transformer/subsampling.py +383 -0
- minicpmo_utils-0.0.5/src/cosyvoice/transformer/upsample_encoder.py +320 -0
- minicpmo_utils-0.0.5/src/cosyvoice/utils/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/cosyvoice/utils/class_utils.py +83 -0
- minicpmo_utils-0.0.5/src/cosyvoice/utils/common.py +186 -0
- minicpmo_utils-0.0.5/src/cosyvoice/utils/executor.py +176 -0
- minicpmo_utils-0.0.5/src/cosyvoice/utils/file_utils.py +130 -0
- minicpmo_utils-0.0.5/src/cosyvoice/utils/frontend_utils.py +136 -0
- minicpmo_utils-0.0.5/src/cosyvoice/utils/losses.py +57 -0
- minicpmo_utils-0.0.5/src/cosyvoice/utils/mask.py +265 -0
- minicpmo_utils-0.0.5/src/cosyvoice/utils/scheduler.py +738 -0
- minicpmo_utils-0.0.5/src/cosyvoice/utils/train_utils.py +367 -0
- minicpmo_utils-0.0.5/src/cosyvoice/vllm/cosyvoice2.py +103 -0
- minicpmo_utils-0.0.5/src/matcha/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/matcha/app.py +357 -0
- minicpmo_utils-0.0.5/src/matcha/cli.py +418 -0
- minicpmo_utils-0.0.5/src/matcha/hifigan/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/matcha/hifigan/config.py +28 -0
- minicpmo_utils-0.0.5/src/matcha/hifigan/denoiser.py +64 -0
- minicpmo_utils-0.0.5/src/matcha/hifigan/env.py +17 -0
- minicpmo_utils-0.0.5/src/matcha/hifigan/meldataset.py +217 -0
- minicpmo_utils-0.0.5/src/matcha/hifigan/models.py +368 -0
- minicpmo_utils-0.0.5/src/matcha/hifigan/xutils.py +60 -0
- minicpmo_utils-0.0.5/src/matcha/models/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/matcha/models/baselightningmodule.py +209 -0
- minicpmo_utils-0.0.5/src/matcha/models/components/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/matcha/models/components/decoder.py +443 -0
- minicpmo_utils-0.0.5/src/matcha/models/components/flow_matching.py +132 -0
- minicpmo_utils-0.0.5/src/matcha/models/components/text_encoder.py +410 -0
- minicpmo_utils-0.0.5/src/matcha/models/components/transformer.py +316 -0
- minicpmo_utils-0.0.5/src/matcha/models/matcha_tts.py +239 -0
- minicpmo_utils-0.0.5/src/matcha/onnx/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/matcha/onnx/export.py +181 -0
- minicpmo_utils-0.0.5/src/matcha/onnx/infer.py +168 -0
- minicpmo_utils-0.0.5/src/matcha/text/__init__.py +53 -0
- minicpmo_utils-0.0.5/src/matcha/text/cleaners.py +116 -0
- minicpmo_utils-0.0.5/src/matcha/text/numbers.py +71 -0
- minicpmo_utils-0.0.5/src/matcha/text/symbols.py +17 -0
- minicpmo_utils-0.0.5/src/matcha/train.py +122 -0
- minicpmo_utils-0.0.5/src/matcha/utils/__init__.py +5 -0
- minicpmo_utils-0.0.5/src/matcha/utils/audio.py +82 -0
- minicpmo_utils-0.0.5/src/matcha/utils/generate_data_statistics.py +111 -0
- minicpmo_utils-0.0.5/src/matcha/utils/instantiators.py +56 -0
- minicpmo_utils-0.0.5/src/matcha/utils/logging_utils.py +53 -0
- minicpmo_utils-0.0.5/src/matcha/utils/model.py +90 -0
- minicpmo_utils-0.0.5/src/matcha/utils/monotonic_align/__init__.py +22 -0
- minicpmo_utils-0.0.5/src/matcha/utils/monotonic_align/setup.py +7 -0
- minicpmo_utils-0.0.5/src/matcha/utils/pylogger.py +21 -0
- minicpmo_utils-0.0.5/src/matcha/utils/rich_utils.py +101 -0
- minicpmo_utils-0.0.5/src/matcha/utils/utils.py +219 -0
- minicpmo_utils-0.0.5/src/minicpmo/__init__.py +14 -0
- minicpmo_utils-0.0.5/src/minicpmo/utils.py +723 -0
- minicpmo_utils-0.0.5/src/minicpmo/version.py +2 -0
- minicpmo_utils-0.0.5/src/minicpmo_utils.egg-info/PKG-INFO +116 -0
- minicpmo_utils-0.0.5/src/minicpmo_utils.egg-info/SOURCES.txt +151 -0
- minicpmo_utils-0.0.5/src/minicpmo_utils.egg-info/dependency_links.txt +1 -0
- minicpmo_utils-0.0.5/src/minicpmo_utils.egg-info/requires.txt +55 -0
- minicpmo_utils-0.0.5/src/minicpmo_utils.egg-info/top_level.txt +5 -0
- minicpmo_utils-0.0.5/src/s3tokenizer/__init__.py +153 -0
- minicpmo_utils-0.0.5/src/s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- minicpmo_utils-0.0.5/src/s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- minicpmo_utils-0.0.5/src/s3tokenizer/assets/mel_filters.npz +0 -0
- minicpmo_utils-0.0.5/src/s3tokenizer/cli.py +183 -0
- minicpmo_utils-0.0.5/src/s3tokenizer/model.py +546 -0
- minicpmo_utils-0.0.5/src/s3tokenizer/model_v2.py +605 -0
- minicpmo_utils-0.0.5/src/s3tokenizer/utils.py +390 -0
- minicpmo_utils-0.0.5/src/stepaudio2/__init__.py +40 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/__init__.py +1 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/flow/flow.py +230 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/utils/common.py +101 -0
- minicpmo_utils-0.0.5/src/stepaudio2/cosyvoice2/utils/mask.py +49 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/cli.py +424 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/config.py +80 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/utils/context.py +28 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- minicpmo_utils-0.0.5/src/stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- minicpmo_utils-0.0.5/src/stepaudio2/stepaudio2.py +204 -0
- minicpmo_utils-0.0.5/src/stepaudio2/token2wav.py +247 -0
- minicpmo_utils-0.0.5/src/stepaudio2/utils.py +91 -0
minicpmo_utils-0.0.5/PKG-INFO
@@ -0,0 +1,116 @@
Metadata-Version: 2.4
Name: minicpmo-utils
Version: 0.0.5
Summary: Unified utilities package for MiniCPM-o: includes cosyvoice + stepaudio2 and extensible utils.
Author: MiniCPM-o Utils Maintainers
License: Apache-2.0
Keywords: minicpmo,audio,tts,utils,cosyvoice,stepaudio2
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: numpy
Requires-Dist: pillow==10.4.0
Requires-Dist: librosa==0.9.0
Requires-Dist: decord==0.6.0
Requires-Dist: moviepy==2.1.2
Requires-Dist: numba==0.61.2
Provides-Extra: tts
Requires-Dist: torch>=2.3.0; extra == "tts"
Requires-Dist: torchaudio>=2.3.0; extra == "tts"
Requires-Dist: transformers<4.53.0,>=4.51.0; extra == "tts"
Requires-Dist: onnxruntime<=1.21.0,>=1.18.0; extra == "tts"
Requires-Dist: onnx; extra == "tts"
Requires-Dist: hyperpyyaml; extra == "tts"
Requires-Dist: openai-whisper==20231117; extra == "tts"
Requires-Dist: tqdm; extra == "tts"
Requires-Dist: tiktoken; extra == "tts"
Requires-Dist: inflect; extra == "tts"
Requires-Dist: omegaconf>=2.0.6; extra == "tts"
Requires-Dist: conformer==0.3.2; extra == "tts"
Requires-Dist: einops==0.8.1; extra == "tts"
Requires-Dist: hydra-core; extra == "tts"
Requires-Dist: lightning==2.2.4; extra == "tts"
Requires-Dist: rich; extra == "tts"
Requires-Dist: gdown==5.2.0; extra == "tts"
Requires-Dist: matplotlib; extra == "tts"
Requires-Dist: wget; extra == "tts"
Requires-Dist: pyarrow; extra == "tts"
Requires-Dist: pyworld; extra == "tts"
Requires-Dist: scipy; extra == "tts"
Requires-Dist: pyyaml; extra == "tts"
Requires-Dist: regex; extra == "tts"
Requires-Dist: soundfile==0.12.1; extra == "tts"
Requires-Dist: diffusers==0.29.0; extra == "tts"
Provides-Extra: streaming
Requires-Dist: minicpmo-utils[tts]; extra == "streaming"
Provides-Extra: streaming-flash
Requires-Dist: minicpmo-utils[streaming]; extra == "streaming-flash"
Requires-Dist: flash-attn>=2.6.0; sys_platform == "linux" and extra == "streaming-flash"
Requires-Dist: triton>=2.3.0; sys_platform == "linux" and extra == "streaming-flash"
Requires-Dist: safetensors; extra == "streaming-flash"
Requires-Dist: pynvml; extra == "streaming-flash"
Requires-Dist: xxhash; extra == "streaming-flash"
Provides-Extra: gpu
Requires-Dist: onnxruntime-gpu<=1.23.2,>=1.18.0; sys_platform == "linux" and extra == "gpu"
Provides-Extra: all
Requires-Dist: minicpmo-utils[gpu,streaming,tts]; extra == "all"

## minicpmo-utils

A single-install toolkit (one PyPI distribution) that packs the repository's `cosyvoice` and `stepaudio2` into the same wheel, reserving `minicpmo` as the unified entry point for future utility extensions.

### Installation

- Install from source locally (development mode, editable; only the common dependencies are installed by default):
```bash
cd minicpmo-utils
pip install -e .
```

- To install only the cosyvoice-related (TTS) dependencies:
```bash
pip install -e .[tts]
```

- To install only the stepaudio2 / streaming dependencies:
```bash
pip install -e .[streaming]
```

- To install both the cosyvoice and stepaudio2 dependencies:
```bash
pip install -e .[tts,streaming]
```

- Build and install a wheel (recommended for distribution):
```bash
cd minicpmo-utils
python -m build  # produces dist/*.whl
pip install "dist/minicpmo_utils-0.0.5-py3-none-any.whl[tts,streaming]"
```

### Imports

After installation, the package exposes the following top-level modules for direct import:
- `import cosyvoice`
- `import stepaudio2`
- `import matcha`
- `import minicpmo`

Subpackages can also be imported through the unified entry point:
```python
from minicpmo import cosyvoice, stepaudio2, matcha
```

Common utility functions are also available through the unified utils entry point, for example:

```python
from minicpmo.utils import get_video_frame_audio_segments
```
minicpmo_utils-0.0.5/README.md
@@ -0,0 +1,53 @@
## minicpmo-utils

A single-install toolkit (one PyPI distribution) that packs the repository's `cosyvoice` and `stepaudio2` into the same wheel, reserving `minicpmo` as the unified entry point for future utility extensions.

### Installation

- Install from source locally (development mode, editable; only the common dependencies are installed by default):
```bash
cd minicpmo-utils
pip install -e .
```

- To install only the cosyvoice-related (TTS) dependencies:
```bash
pip install -e .[tts]
```

- To install only the stepaudio2 / streaming dependencies:
```bash
pip install -e .[streaming]
```

- To install both the cosyvoice and stepaudio2 dependencies:
```bash
pip install -e .[tts,streaming]
```

- Build and install a wheel (recommended for distribution):
```bash
cd minicpmo-utils
python -m build  # produces dist/*.whl
pip install "dist/minicpmo_utils-0.0.5-py3-none-any.whl[tts,streaming]"
```

### Imports

After installation, the package exposes the following top-level modules for direct import:
- `import cosyvoice`
- `import stepaudio2`
- `import matcha`
- `import minicpmo`

Subpackages can also be imported through the unified entry point:
```python
from minicpmo import cosyvoice, stepaudio2, matcha
```

Common utility functions are also available through the unified utils entry point, for example:

```python
from minicpmo.utils import get_video_frame_audio_segments
```
minicpmo_utils-0.0.5/pyproject.toml
@@ -0,0 +1,108 @@
[build-system]
requires = ["setuptools>=69", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "minicpmo-utils"
version = "0.0.5"
description = "Unified utilities package for MiniCPM-o: includes cosyvoice + stepaudio2 and extensible utils."
readme = "README.md"
requires-python = ">=3.10"
license = {text = "Apache-2.0"}
authors = [
    {name = "MiniCPM-o Utils Maintainers"},
]
keywords = ["minicpmo", "audio", "tts", "utils", "cosyvoice", "stepaudio2"]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: Apache Software License",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
]

# NOTE:
# - This is a single distribution (minicpmo-utils) that provides several top-level import packages:
#   - cosyvoice   (from ../cosyvoice/src)
#   - matcha      (from ../cosyvoice/src)
#   - stepaudio2  (from ../stepaudio2/src)
#   - s3tokenizer (from S3Tokenizer-main)
#   - minicpmo    (unified entry point for this project's extended utils: from minicpmo.utils import ...)
dependencies = [
    "numpy",
    "pillow==10.4.0",
    "librosa==0.9.0",
    "decord==0.6.0",
    "moviepy==2.1.2",
    "numba==0.61.2",
]

[project.optional-dependencies]
# cosyvoice TTS dependencies
tts = [
    "torch>=2.3.0",
    "torchaudio>=2.3.0",
    "transformers>=4.51.0,<4.53.0",  # 4.52+ has compatibility issues
    "onnxruntime>=1.18.0,<=1.21.0",
    "onnx",
    "hyperpyyaml",
    "openai-whisper==20231117",
    "tqdm",
    "tiktoken",
    "inflect",
    "omegaconf>=2.0.6",
    "conformer==0.3.2",
    "einops==0.8.1",
    "hydra-core",
    "lightning==2.2.4",
    "rich",
    "gdown==5.2.0",
    "matplotlib",
    "wget",
    "pyarrow",
    "pyworld",
    # newly added dependencies
    "scipy",
    "pyyaml",
    "regex",
    "soundfile==0.12.1",
    "diffusers==0.29.0"
]

# stepaudio2 base dependencies (token2wav, etc.)
streaming = [
    "minicpmo-utils[tts]",  # streaming depends on tts
]

# stepaudio2 Flash inference engine dependencies (needed by the flashcosyvoice.engine module)
streaming-flash = [
    "minicpmo-utils[streaming]",
    "flash-attn>=2.6.0; sys_platform == 'linux'",
    "triton>=2.3.0; sys_platform == 'linux'",
    "safetensors",
    "pynvml",
    "xxhash",
]

# onnxruntime-gpu on Linux can be heavy and is highly environment-dependent, so it is kept as an optional extra
gpu = [
    "onnxruntime-gpu>=1.18.0,<=1.23.2; sys_platform == 'linux'",
]

all = [
    "minicpmo-utils[tts,streaming,gpu]",
]

[tool.setuptools]
include-package-data = true

[tool.setuptools.packages.find]
# all code now lives under this project's src/
where = ["src"]

[tool.setuptools.package-data]
"cosyvoice.tokenizer.assets" = ["*.tiktoken"]
"s3tokenizer.assets" = ["*.wav", "*.npz"]
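The `streaming`, `streaming-flash`, and `all` extras above pull in other extras of the same distribution by self-reference (e.g. `minicpmo-utils[tts]`). A minimal sketch for checking how this resolved in an installed environment, using only the standard library (it assumes the package has already been installed, with or without extras):

```python
from importlib.metadata import PackageNotFoundError, metadata, requires

# Inspect the installed minicpmo-utils distribution and list its declared
# requirements, including the conditional 'extra == "..."' markers.
try:
    meta = metadata("minicpmo-utils")
    print(meta["Name"], meta["Version"])       # e.g. minicpmo-utils 0.0.5
    for req in requires("minicpmo-utils") or []:
        print(" ", req)                        # e.g. torch>=2.3.0; extra == "tts"
except PackageNotFoundError:
    print("minicpmo-utils is not installed in this environment")
```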
minicpmo_utils-0.0.5/src/cosyvoice/__init__.py
@@ -0,0 +1,17 @@
"""
CosyVoice: Text-to-Speech with Large Language Model
"""

__version__ = "0.1.0"

# Lazy import to avoid requiring all dependencies at package import time
def __getattr__(name):
    if name in ('CosyVoice', 'CosyVoice2'):
        from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
        if name == 'CosyVoice':
            return CosyVoice
        elif name == 'CosyVoice2':
            return CosyVoice2
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

__all__ = ['CosyVoice', 'CosyVoice2']
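The module-level `__getattr__` above (PEP 562) defers the heavy `cosyvoice.cli.cosyvoice` import until one of the exported classes is actually requested. A small usage sketch; the model directory is a placeholder path, not something shipped with this wheel:

```python
import cosyvoice

# Importing the package alone is cheap: only __version__ and __getattr__ are defined.
print(cosyvoice.__version__)

# Accessing CosyVoice2 triggers the lazy import of cosyvoice.cli.cosyvoice,
# which in turn needs the heavier [tts] dependencies (torch, etc.).
CosyVoice2 = cosyvoice.CosyVoice2
tts = CosyVoice2("pretrained_models/CosyVoice2-0.5B")  # placeholder model dir
```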
minicpmo_utils-0.0.5/src/cosyvoice/bin/average_model.py
@@ -0,0 +1,93 @@
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import glob

import yaml
import torch


def get_args():
    parser = argparse.ArgumentParser(description='average model')
    parser.add_argument('--dst_model', required=True, help='averaged model')
    parser.add_argument('--src_path',
                        required=True,
                        help='src model path for average')
    parser.add_argument('--val_best',
                        action="store_true",
                        help='averaged model')
    parser.add_argument('--num',
                        default=5,
                        type=int,
                        help='nums for averaged model')

    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    val_scores = []
    if args.val_best:
        yamls = glob.glob('{}/*.yaml'.format(args.src_path))
        yamls = [
            f for f in yamls
            if not (os.path.basename(f).startswith('train')
                    or os.path.basename(f).startswith('init'))
        ]
        for y in yamls:
            with open(y, 'r') as f:
                dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
                loss = float(dic_yaml['loss_dict']['loss'])
                epoch = int(dic_yaml['epoch'])
                step = int(dic_yaml['step'])
                tag = dic_yaml['tag']
                val_scores += [[epoch, step, loss, tag]]
        sorted_val_scores = sorted(val_scores,
                                   key=lambda x: x[2],
                                   reverse=False)
        print("best val (epoch, step, loss, tag) = " +
              str(sorted_val_scores[:args.num]))
        path_list = [
            args.src_path + '/epoch_{}_whole.pt'.format(score[0])
            for score in sorted_val_scores[:args.num]
        ]
        print(path_list)
    avg = {}
    num = args.num
    assert num == len(path_list)
    for path in path_list:
        print('Processing {}'.format(path))
        states = torch.load(path, map_location=torch.device('cpu'))
        for k in states.keys():
            if k not in ['step', 'epoch']:
                if k not in avg.keys():
                    avg[k] = states[k].clone()
                else:
                    avg[k] += states[k]
    # average
    for k in avg.keys():
        if avg[k] is not None:
            # pytorch 1.6 use true_divide instead of /=
            avg[k] = torch.true_divide(avg[k], num)
    print('Saving to {}'.format(args.dst_model))
    torch.save(avg, args.dst_model)


if __name__ == '__main__':
    main()
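The core of this script is plain state-dict averaging: sum every tensor across checkpoints (skipping bookkeeping keys such as `step`/`epoch`) and divide by the number of checkpoints. A self-contained sketch of the same idea on dummy state dicts; the helper name and the toy keys are illustrative, not part of the package:

```python
import torch


def average_state_dicts(state_dicts, skip_keys=('step', 'epoch')):
    """Element-wise average of tensors across checkpoints, as in average_model.py."""
    avg = {}
    for states in state_dicts:
        for k, v in states.items():
            if k in skip_keys:
                continue
            avg[k] = v.clone() if k not in avg else avg[k] + v
    for k in avg:
        avg[k] = torch.true_divide(avg[k], len(state_dicts))
    return avg


# Two dummy "checkpoints" with identical parameter shapes.
ckpt_a = {'w': torch.ones(2, 2), 'b': torch.zeros(2), 'epoch': 1}
ckpt_b = {'w': torch.full((2, 2), 3.0), 'b': torch.ones(2), 'epoch': 2}
print(average_state_dicts([ckpt_a, ckpt_b])['w'])  # tensor filled with 2.0
```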
minicpmo_utils-0.0.5/src/cosyvoice/bin/export_jit.py
@@ -0,0 +1,103 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import sys
import torch
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import logging


def get_args():
    parser = argparse.ArgumentParser(description='export your model for deployment')
    parser.add_argument('--model_dir',
                        type=str,
                        default='pretrained_models/CosyVoice-300M',
                        help='local path')
    args = parser.parse_args()
    print(args)
    return args


def get_optimized_script(model, preserved_attrs=[]):
    script = torch.jit.script(model)
    if preserved_attrs != []:
        script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
    else:
        script = torch.jit.freeze(script)
    script = torch.jit.optimize_for_inference(script)
    return script


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    torch._C._jit_set_fusion_strategy([('STATIC', 1)])
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)

    try:
        model = CosyVoice(args.model_dir)
    except Exception:
        try:
            model = CosyVoice2(args.model_dir)
        except Exception:
            raise TypeError('no valid model_type!')

    if not isinstance(model, CosyVoice2):
        # 1. export llm text_encoder
        llm_text_encoder = model.model.llm.text_encoder
        script = get_optimized_script(llm_text_encoder)
        script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
        script = get_optimized_script(llm_text_encoder.half())
        script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
        logging.info('successfully export llm_text_encoder')

        # 2. export llm llm
        llm_llm = model.model.llm.llm
        script = get_optimized_script(llm_llm, ['forward_chunk'])
        script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
        script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
        script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
        logging.info('successfully export llm_llm')

        # 3. export flow encoder
        flow_encoder = model.model.flow.encoder
        script = get_optimized_script(flow_encoder)
        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
        script = get_optimized_script(flow_encoder.half())
        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
        logging.info('successfully export flow_encoder')
    else:
        # 3. export flow encoder
        flow_encoder = model.model.flow.encoder
        script = get_optimized_script(flow_encoder)
        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
        script = get_optimized_script(flow_encoder.half())
        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
        logging.info('successfully export flow_encoder')


if __name__ == '__main__':
    main()
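The `get_optimized_script` helper applies the usual TorchScript deployment pipeline: `torch.jit.script`, then `torch.jit.freeze`, then `torch.jit.optimize_for_inference`. A minimal sketch of the same pipeline on a toy module; the module and file name are illustrative only, not CosyVoice components:

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Stand-in for a scriptable submodule; any TorchScript-compatible module works."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(80, 80)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


model = TinyEncoder().eval()          # freeze() requires the module to be in eval mode
script = torch.jit.script(model)      # compile to TorchScript
script = torch.jit.freeze(script)     # inline weights, drop training-only attributes
script = torch.jit.optimize_for_inference(script)

script.save('tiny_encoder.fp32.zip')  # same .zip naming style as export_jit.py
restored = torch.jit.load('tiny_encoder.fp32.zip')
print(restored(torch.randn(1, 10, 80)).shape)  # torch.Size([1, 10, 80])
```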
minicpmo_utils-0.0.5/src/cosyvoice/bin/export_onnx.py
@@ -0,0 +1,120 @@
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import sys
import onnxruntime
import random
import torch
from tqdm import tqdm
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import logging


def get_dummy_input(batch_size, seq_len, out_channels, device):
    x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
    mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
    mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
    t = torch.rand((batch_size), dtype=torch.float32, device=device)
    spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
    cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
    return x, mask, mu, t, spks, cond


def get_args():
    parser = argparse.ArgumentParser(description='export your model for deployment')
    parser.add_argument('--model_dir',
                        type=str,
                        default='pretrained_models/CosyVoice-300M',
                        help='local path')
    args = parser.parse_args()
    print(args)
    return args


@torch.no_grad()
def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    try:
        model = CosyVoice(args.model_dir)
    except Exception:
        try:
            model = CosyVoice2(args.model_dir)
        except Exception:
            raise TypeError('no valid model_type!')

    # 1. export flow decoder estimator
    estimator = model.model.flow.decoder.estimator
    estimator.eval()

    device = model.model.device
    batch_size, seq_len = 2, 256
    out_channels = model.model.flow.decoder.estimator.out_channels
    x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
    torch.onnx.export(
        estimator,
        (x, mask, mu, t, spks, cond),
        '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
        output_names=['estimator_out'],
        dynamic_axes={
            'x': {2: 'seq_len'},
            'mask': {2: 'seq_len'},
            'mu': {2: 'seq_len'},
            'cond': {2: 'seq_len'},
            'estimator_out': {2: 'seq_len'},
        }
    )

    # 2. test computation consistency
    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    option.intra_op_num_threads = 1
    providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
    estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
                                                  sess_options=option, providers=providers)

    for _ in tqdm(range(10)):
        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
        output_pytorch = estimator(x, mask, mu, t, spks, cond)
        ort_inputs = {
            'x': x.cpu().numpy(),
            'mask': mask.cpu().numpy(),
            'mu': mu.cpu().numpy(),
            't': t.cpu().numpy(),
            'spks': spks.cpu().numpy(),
            'cond': cond.cpu().numpy()
        }
        output_onnx = estimator_onnx.run(None, ort_inputs)[0]
        torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
    logging.info('successfully export estimator')


if __name__ == "__main__":
    main()
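The script follows an export-then-verify pattern: export with `torch.onnx.export` using a dynamic `seq_len` axis, then compare ONNX Runtime outputs against PyTorch on random inputs of varying length. A minimal sketch of that pattern with a toy 1-D conv model; the names and shapes are illustrative, not the estimator's real interface:

```python
import numpy as np
import onnxruntime
import torch
import torch.nn as nn


class TinyEstimator(nn.Module):
    """Stand-in model with a (batch, channels, seq_len) interface."""
    def __init__(self, channels: int = 8):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


model = TinyEstimator().eval()
dummy = torch.randn(2, 8, 64)

# Export with seq_len marked dynamic so other lengths are accepted at runtime.
torch.onnx.export(
    model, (dummy,), 'tiny_estimator.onnx',
    input_names=['x'], output_names=['y'],
    dynamic_axes={'x': {2: 'seq_len'}, 'y': {2: 'seq_len'}},
)

sess = onnxruntime.InferenceSession('tiny_estimator.onnx',
                                    providers=['CPUExecutionProvider'])

# Parity check on a different sequence length than the export dummy.
x = torch.randn(2, 8, 100)
with torch.no_grad():
    ref = model(x).numpy()
out = sess.run(None, {'x': x.numpy()})[0]
np.testing.assert_allclose(ref, out, rtol=1e-2, atol=1e-4)
print('ONNX and PyTorch outputs match')
```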