xinference-1.1.0-py3-none-any.whl → xinference-1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic; see the registry's advisory for more details.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +72 -66
- xinference/core/model.py +78 -25
- xinference/core/supervisor.py +81 -10
- xinference/core/utils.py +12 -8
- xinference/core/worker.py +32 -0
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +15 -10
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +35 -111
- xinference/model/audio/model_spec.json +19 -3
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +145 -4
- xinference/model/image/model_spec_modelscope.json +150 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +143 -0
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +148 -0
- xinference/model/llm/mlx/core.py +37 -32
- xinference/model/llm/transformers/cogagent.py +272 -0
- xinference/model/llm/transformers/core.py +2 -0
- xinference/model/llm/transformers/qwen2_vl.py +12 -1
- xinference/model/llm/utils.py +28 -3
- xinference/model/llm/vllm/core.py +48 -9
- xinference/model/llm/vllm/xavier/__init__.py +13 -0
- xinference/model/llm/vllm/xavier/allocator.py +74 -0
- xinference/model/llm/vllm/xavier/block.py +112 -0
- xinference/model/llm/vllm/xavier/block_manager.py +71 -0
- xinference/model/llm/vllm/xavier/block_tracker.py +116 -0
- xinference/model/llm/vllm/xavier/engine.py +247 -0
- xinference/model/llm/vllm/xavier/executor.py +132 -0
- xinference/model/llm/vllm/xavier/scheduler.py +422 -0
- xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
- xinference/model/llm/vllm/xavier/test/test_xavier.py +122 -0
- xinference/model/llm/vllm/xavier/transfer.py +298 -0
- xinference/model/video/diffusers.py +14 -0
- xinference/model/video/model_spec.json +15 -0
- xinference/model/video/model_spec_modelscope.json +16 -0
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- xinference/types.py +13 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
- xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
- xinference/web/ui/build/static/js/main.1eb206d1.js +3 -0
- xinference/web/ui/build/static/js/main.1eb206d1.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +67 -3
- xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
- xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
- xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
- xinference/web/ui/node_modules/i18next/package.json +129 -0
- xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
- xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
- xinference/web/ui/node_modules/react-i18next/package.json +162 -0
- xinference/web/ui/node_modules/void-elements/package.json +34 -0
- xinference/web/ui/package-lock.json +69 -3
- xinference/web/ui/package.json +2 -0
- xinference/web/ui/src/locales/en.json +186 -0
- xinference/web/ui/src/locales/zh.json +186 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/METADATA +19 -11
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/RECORD +178 -111
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
- xinference/web/ui/build/static/js/main.4eb4ee80.js +0 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
- /xinference/web/ui/build/static/js/{main.4eb4ee80.js.LICENSE.txt → main.1eb206d1.js.LICENSE.txt} +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/LICENSE +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/WHEEL +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/bin/train.py

@@ -18,6 +18,7 @@ import datetime
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 from copy import deepcopy
+import os
 import torch
 import torch.distributed as dist
 import deepspeed
@@ -67,13 +68,17 @@ def get_args():
                         action='store_true',
                         default=False,
                         help='Use pinned memory buffers used for reading')
+    parser.add_argument('--use_amp',
+                        action='store_true',
+                        default=False,
+                        help='Use automatic mixed precision training')
     parser.add_argument('--deepspeed.save_states',
                         dest='save_states',
                         default='model_only',
                         choices=['model_only', 'model+optimizer'],
                         help='save model/optimizer states')
     parser.add_argument('--timeout',
-                        default=
+                        default=60,
                         type=int,
                         help='timeout (in seconds) of cosyvoice_join.')
     parser = deepspeed.add_config_arguments(parser)
@@ -86,10 +91,16 @@ def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')
+    # gan train has some special initialization logic
+    gan = True if args.model == 'hifigan' else False

-    override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
+    override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
+    if gan is True:
+        override_dict.pop('hift')
     with open(args.config, 'r') as f:
         configs = load_hyperpyyaml(f, overrides=override_dict)
+    if gan is True:
+        configs['train_conf'] = configs['train_conf_gan']
     configs['train_conf'].update(vars(args))

     # Init env for ddp
@@ -97,7 +108,7 @@ def main():

     # Get dataset & dataloader
     train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
-        init_dataset_and_dataloader(args, configs)
+        init_dataset_and_dataloader(args, configs, gan)

     # Do some sanity checks and save config to arsg.model_dir
     configs = check_modify_and_save_config(args, configs)
@@ -107,30 +118,53 @@ def main():

     # load checkpoint
     model = configs[args.model]
+    start_step, start_epoch = 0, -1
     if args.checkpoint is not None:
-
+        if os.path.exists(args.checkpoint):
+            state_dict = torch.load(args.checkpoint, map_location='cpu')
+            model.load_state_dict(state_dict, strict=False)
+            if 'step' in state_dict:
+                start_step = state_dict['step']
+            if 'epoch' in state_dict:
+                start_epoch = state_dict['epoch']
+        else:
+            logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))

     # Dispatch model from cpu to gpu
     model = wrap_cuda_model(args, model)

     # Get optimizer & scheduler
-    model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
+    model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
+    scheduler.set_step(start_step)
+    if scheduler_d is not None:
+        scheduler_d.set_step(start_step)

     # Save init checkpoints
     info_dict = deepcopy(configs['train_conf'])
+    info_dict['step'] = start_step
+    info_dict['epoch'] = start_epoch
     save_model(model, 'init', info_dict)

     # Get executor
-    executor = Executor()
+    executor = Executor(gan=gan)
+    executor.step = start_step

+    # Init scaler, used for pytorch amp mixed precision training
+    scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
+    print('start step {} start epoch {}'.format(start_step, start_epoch))
     # Start training loop
-    for epoch in range(info_dict['max_epoch']):
+    for epoch in range(start_epoch + 1, info_dict['max_epoch']):
         executor.epoch = epoch
         train_dataset.set_epoch(epoch)
         dist.barrier()
         group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
-
+        if gan is True:
+            executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
+                                        writer, info_dict, scaler, group_join)
+        else:
+            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
         dist.destroy_process_group(group_join)

+
 if __name__ == '__main__':
     main()
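The new --use_amp flag above creates a torch.cuda.amp.GradScaler and hands it to the executor along with the resumed start_step/start_epoch. As a rough illustration of what that scaler changes for a single update, here is a minimal sketch of the standard PyTorch AMP pattern; it is not the exact body of Executor.train_one_epoc (see cosyvoice/utils/executor.py for that), and it assumes the model returns a dict containing a 'loss' entry.

import torch

def train_step(model, batch, optimizer, scaler=None):
    # scaler is None when --use_amp is not set; that path matches the old fp32 behaviour
    optimizer.zero_grad()
    if scaler is not None:
        with torch.cuda.amp.autocast():          # run forward in mixed precision
            loss = model(batch)['loss']
        scaler.scale(loss).backward()            # scale loss to avoid fp16 gradient underflow
        scaler.step(optimizer)                   # unscale gradients, then optimizer.step()
        scaler.update()                          # adjust the scale factor for the next step
    else:
        loss = model(batch)['loss']
        loss.backward()
        optimizer.step()
    return loss.detach()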
xinference/thirdparty/cosyvoice/cli/cosyvoice.py

@@ -13,15 +13,18 @@
 # limitations under the License.
 import os
 import time
+from tqdm import tqdm
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
+import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
-from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging

+
 class CosyVoice:

-    def __init__(self, model_dir, load_jit=True):
+    def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -35,65 +38,133 @@ class CosyVoice:
                                           '{}/spk2info.pt'.format(model_dir),
                                           instruct,
                                           configs['allowed_special'])
-        self.
+        self.sample_rate = configs['sample_rate']
+        if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
+            load_jit = False
+            fp16 = False
+            logging.warning('cpu do not support fp16 and jit, force set to False')
+        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
             self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
-
+                                '{}/llm.llm.fp16.zip'.format(model_dir),
+                                '{}/flow.encoder.fp32.zip'.format(model_dir))
+        if load_onnx:
+            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
         del configs

     def list_avaliable_spks(self):
         spks = list(self.frontend.spk2info.keys())
         return spks

-    def inference_sft(self, tts_text, spk_id, stream=False):
-        for i in self.frontend.text_normalize(tts_text, split=True):
+    def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_sft(i, spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.
-                speech_len = model_output['tts_speech'].shape[1] /
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()

-    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
-        prompt_text = self.frontend.text_normalize(prompt_text, split=False)
-        for i in self.frontend.text_normalize(tts_text, split=True):
-
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
+        prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            if len(i) < 0.5 * len(prompt_text):
+                logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
+            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.
-                speech_len = model_output['tts_speech'].shape[1] /
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()

-    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
-        if self.frontend.instruct is True:
+    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
+        if self.frontend.instruct is True and isinstance(self.model, CosyVoiceModel):
             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
-        for i in self.frontend.text_normalize(tts_text, split=True):
-            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.
-                speech_len = model_output['tts_speech'].shape[1] /
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()

-    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
+    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
+        assert isinstance(self.model, CosyVoiceModel)
         if self.frontend.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
-        instruct_text = self.frontend.text_normalize(instruct_text, split=False)
-        for i in self.frontend.text_normalize(tts_text, split=True):
+        instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.
-                speech_len = model_output['tts_speech'].shape[1] /
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
+
+    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
+        assert isinstance(self.model, CosyVoice2Model)
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
+
+    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
+        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
+        start_time = time.time()
+        for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
+            speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+            logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+            yield model_output
+            start_time = time.time()
+
+
+class CosyVoice2(CosyVoice):
+
+    def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
+        instruct = True if '-Instruct' in model_dir else False
+        self.model_dir = model_dir
+        if not os.path.exists(model_dir):
+            model_dir = snapshot_download(model_dir)
+        with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+                                          configs['feat_extractor'],
+                                          '{}/campplus.onnx'.format(model_dir),
+                                          '{}/speech_tokenizer_v2.onnx'.format(model_dir),
+                                          '{}/spk2info.pt'.format(model_dir),
+                                          instruct,
+                                          configs['allowed_special'])
+        self.sample_rate = configs['sample_rate']
+        if torch.cuda.is_available() is False and load_jit is True:
+            load_jit = False
+            logging.warning('cpu do not support jit, force set to False')
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
+        self.model.load('{}/llm.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir),
+                        '{}/hift.pt'.format(model_dir))
+        if load_jit:
+            self.model.load_jit('{}/flow.encoder.fp32.zip'.format(model_dir))
+        if load_trt is True and load_onnx is True:
+            load_onnx = False
+            logging.warning('can not set both load_trt and load_onnx to True, force set load_onnx to False')
+        if load_onnx:
+            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+        if load_trt:
+            self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
+        del configs
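The entry points added in this file (CosyVoice2, inference_instruct2, inference_vc) are generator-based. Below is a minimal usage sketch based only on the signatures visible in the hunk above; the model directory, audio file names, and texts are placeholders, it assumes the vendored load_wav helper from cosyvoice.utils.file_utils, and inside xinference this API is normally driven by the cosyvoice audio model wrapper rather than called directly.

import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav

# hypothetical local model directory; xinference resolves the real path when the model is launched
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_onnx=False, load_trt=False)
prompt_speech_16k = load_wav('prompt.wav', 16000)  # 16 kHz reference clip, at most 30 s

# inference_instruct2 yields chunks; each carries 'tts_speech' sampled at cosyvoice.sample_rate
for idx, chunk in enumerate(cosyvoice.inference_instruct2('Hello, nice to meet you.',
                                                           'Speak in a cheerful tone.',
                                                           prompt_speech_16k, stream=False)):
    torchaudio.save('instruct2_{}.wav'.format(idx), chunk['tts_speech'], cosyvoice.sample_rate)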
xinference/thirdparty/cosyvoice/cli/frontend.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
+import json
 import onnxruntime
 import torch
 import numpy as np
@@ -50,9 +51,13 @@ class CosyVoiceFrontEnd:
         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         option.intra_op_num_threads = 1
         self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
-        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
+        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
+                                                                     providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
+                                                                                "CPUExecutionProvider"])
         if os.path.exists(spk2info):
             self.spk2info = torch.load(spk2info, map_location=self.device)
+        else:
+            self.spk2info = {}
         self.instruct = instruct
         self.allowed_special = allowed_special
         self.inflect_parser = inflect.engine()
@@ -60,10 +65,9 @@ class CosyVoiceFrontEnd:
         if self.use_ttsfrd:
             self.frd = ttsfrd.TtsFrontendEngine()
             ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True,
-
-            self.frd.
-            self.frd.set_breakmodel_index(1)
+            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
+                'failed to initialize ttsfrd resource'
+            self.frd.set_lang_type('pinyinvg')
         else:
             self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
             self.en_tn_model = EnNormalizer()
@@ -75,9 +79,13 @@ class CosyVoiceFrontEnd:
         return text_token, text_token_len

     def _extract_speech_token(self, speech):
+        assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
         feat = whisper.log_mel_spectrogram(speech, n_mels=128)
-        speech_token = self.speech_tokenizer_session.run(None,
-
+        speech_token = self.speech_tokenizer_session.run(None,
+                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
+                                                          feat.detach().cpu().numpy(),
+                                                          self.speech_tokenizer_session.get_inputs()[1].name:
+                                                          np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
         speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
         speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
         return speech_token, speech_token_len
@@ -88,7 +96,8 @@ class CosyVoiceFrontEnd:
                                   dither=0,
                                   sample_frequency=16000)
         feat = feat - feat.mean(dim=0, keepdim=True)
-        embedding = self.campplus_session.run(None,
+        embedding = self.campplus_session.run(None,
+                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
         embedding = torch.tensor([embedding]).to(self.device)
         return embedding

@@ -98,32 +107,34 @@ class CosyVoiceFrontEnd:
         speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
         return speech_feat, speech_feat_len

-    def text_normalize(self, text, split=True):
+    def text_normalize(self, text, split=True, text_frontend=True):
+        if text_frontend is False:
+            return [text] if split is True else text
         text = text.strip()
         if contains_chinese(text):
             if self.use_ttsfrd:
-
+                texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+                text = ''.join(texts)
             else:
                 text = self.zh_tn_model.normalize(text)
-
-
-
-
-
-
-
-
-                                                  comma_split=False)]
+                text = text.replace("\n", "")
+                text = replace_blank(text)
+                text = replace_corner_mark(text)
+                text = text.replace(".", "。")
+                text = text.replace(" - ", ",")
+                text = remove_bracket(text)
+                text = re.sub(r'[,,、]+$', '。', text)
+            texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+                                         token_min_n=60, merge_len=20, comma_split=False))
         else:
             if self.use_ttsfrd:
-
+                texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+                text = ''.join(texts)
             else:
                 text = self.en_tn_model.normalize(text)
-
-
-
-                                                  comma_split=False)]
+                text = spell_out_number(text, self.inflect_parser)
+            texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+                                         token_min_n=60, merge_len=20, comma_split=False))
         if split is False:
             return text
         return texts
@@ -134,12 +145,17 @@ class CosyVoiceFrontEnd:
         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
         return model_input

-    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
         prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
-
-        speech_feat, speech_feat_len = self._extract_speech_feat(
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
         speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+        if resample_rate == 24000:
+            # cosyvoice2, force speech_feat % speech_token = 2
+            token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+            speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
+            speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
         embedding = self._extract_spk_embedding(prompt_speech_16k)
         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                        'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
@@ -149,8 +165,8 @@ class CosyVoiceFrontEnd:
                        'llm_embedding': embedding, 'flow_embedding': embedding}
         return model_input

-    def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
-        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
+    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
+        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
         # in cross lingual mode, we remove prompt in llm
         del model_input['prompt_text']
         del model_input['prompt_text_len']
@@ -166,3 +182,34 @@ class CosyVoiceFrontEnd:
         model_input['prompt_text'] = instruct_text_token
         model_input['prompt_text_len'] = instruct_text_token_len
         return model_input
+
+    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
+        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+        prompt_text_token, prompt_text_token_len = self._extract_text_token(instruct_text + '<|endofprompt|>')
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+        if resample_rate == 24000:
+            # cosyvoice2, force speech_feat % speech_token = 2
+            token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+            speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
+            speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
+        embedding = self._extract_spk_embedding(prompt_speech_16k)
+        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                       'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+                       'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                       'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                       'llm_embedding': embedding, 'flow_embedding': embedding}
+        return model_input
+
+    def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
+        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+        embedding = self._extract_spk_embedding(prompt_speech_16k)
+        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
+        model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
+                       'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
+                       'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
+                       'flow_embedding': embedding}
+        return model_input
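The resample_rate == 24000 branch above (the CosyVoice 2 path) truncates the prompt so that the mel-feature length is exactly twice the speech-token length. A small worked example of that arithmetic, with made-up lengths:

# illustrative numbers only: align feat frames to 2x the token count
feat_frames, token_count = 487, 241
token_len = min(feat_frames // 2, token_count)   # min(243, 241) -> 241
feat_keep = 2 * token_len                        # 482 mel frames kept
print(token_len, feat_keep)                      # 241 482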