xinference 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic; see the package registry's page for this release for more details.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +72 -66
- xinference/core/model.py +78 -25
- xinference/core/supervisor.py +81 -10
- xinference/core/utils.py +12 -8
- xinference/core/worker.py +32 -0
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +15 -10
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +35 -111
- xinference/model/audio/model_spec.json +19 -3
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +145 -4
- xinference/model/image/model_spec_modelscope.json +150 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +143 -0
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +148 -0
- xinference/model/llm/mlx/core.py +37 -32
- xinference/model/llm/transformers/cogagent.py +272 -0
- xinference/model/llm/transformers/core.py +2 -0
- xinference/model/llm/transformers/qwen2_vl.py +12 -1
- xinference/model/llm/utils.py +28 -3
- xinference/model/llm/vllm/core.py +48 -9
- xinference/model/llm/vllm/xavier/__init__.py +13 -0
- xinference/model/llm/vllm/xavier/allocator.py +74 -0
- xinference/model/llm/vllm/xavier/block.py +112 -0
- xinference/model/llm/vllm/xavier/block_manager.py +71 -0
- xinference/model/llm/vllm/xavier/block_tracker.py +116 -0
- xinference/model/llm/vllm/xavier/engine.py +247 -0
- xinference/model/llm/vllm/xavier/executor.py +132 -0
- xinference/model/llm/vllm/xavier/scheduler.py +422 -0
- xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
- xinference/model/llm/vllm/xavier/test/test_xavier.py +122 -0
- xinference/model/llm/vllm/xavier/transfer.py +298 -0
- xinference/model/video/diffusers.py +14 -0
- xinference/model/video/model_spec.json +15 -0
- xinference/model/video/model_spec_modelscope.json +16 -0
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- xinference/types.py +13 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
- xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
- xinference/web/ui/build/static/js/main.1eb206d1.js +3 -0
- xinference/web/ui/build/static/js/main.1eb206d1.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +67 -3
- xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
- xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
- xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
- xinference/web/ui/node_modules/i18next/package.json +129 -0
- xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
- xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
- xinference/web/ui/node_modules/react-i18next/package.json +162 -0
- xinference/web/ui/node_modules/void-elements/package.json +34 -0
- xinference/web/ui/package-lock.json +69 -3
- xinference/web/ui/package.json +2 -0
- xinference/web/ui/src/locales/en.json +186 -0
- xinference/web/ui/src/locales/zh.json +186 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/METADATA +19 -11
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/RECORD +178 -111
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
- xinference/web/ui/build/static/js/main.4eb4ee80.js +0 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
- /xinference/web/ui/build/static/js/{main.4eb4ee80.js.LICENSE.txt → main.1eb206d1.js.LICENSE.txt} +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/LICENSE +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/WHEEL +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
# Copyright 2022-2025 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import asyncio
|
|
15
|
+
import logging
|
|
16
|
+
from functools import lru_cache
|
|
17
|
+
from queue import Queue
|
|
18
|
+
from typing import Dict, List, Optional, no_type_check
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
import xoscar as xo
|
|
22
|
+
from vllm.core.scheduler import Scheduler
|
|
23
|
+
from vllm.utils import TORCH_DTYPE_TO_NUMPY_DTYPE, Device
|
|
24
|
+
from vllm.worker.cache_engine import CacheEngine
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class BufferTransferMixin:
    """Pool of staging tensors used to move KV-cache blocks through Gloo.

    Buffers are addressed by integer index; free indices are kept in a
    ``queue.Queue`` so a caller takes one with ``get_buffer_index`` and
    returns it with ``free_buffer_index``. ``init_buffer`` must run before
    any of the buffer accessors are used.
    """

    def __init__(self):
        self.num_buffer: int = 0
        self.buffers: List[torch.Tensor] = []
        # Holds the indices of currently free buffers; created by init_buffer.
        self.buffer_queue: Optional[Queue] = None
        # Max number of blocks staged per buffer (buffer_shape[0]).
        self.transfer_block_num = 0
        # Number of attention layers covered by each buffer (buffer_shape[1]).
        self.num_attn_layers = 0

    def init_buffer(
        self, num_buffer: int, buffer_shape, buffer_dtype, buffer_device, pin_memory
    ):
        """Allocate ``num_buffer`` zero-filled buffers and mark them all free.

        Args:
            num_buffer: Number of buffers in the pool.
            buffer_shape: ``(transfer_block_num, num_attn_layers, 2,
                *kv_cache_shape[2:])``. The first entry bounds how many
                blocks one buffer holds, the second how many layers are
                swapped per transfer chunk.
            buffer_dtype: Element dtype. ``torch.bfloat16`` is replaced by
                ``torch.float16`` (same element size) because buffers are
                exposed through ``Tensor.numpy()`` and the Gloo dtype is
                looked up via a NumPy dtype; NumPy has no bfloat16.
            buffer_device: Device passed to ``torch.zeros``.
            pin_memory: Whether to allocate page-locked host memory.
        """
        # (transfer_block_num, num_attn_layers, 2, *kv_cache_shape[2:])

        if buffer_dtype is torch.bfloat16:
            # NOTE(review): presumably safe because vLLM's swap_blocks copies
            # raw bytes, so bf16 bit patterns round-trip unchanged through an
            # fp16 buffer — confirm.
            buffer_dtype = torch.float16

        self.num_buffer = num_buffer
        self.transfer_block_num = buffer_shape[0]
        self.num_attn_layers = buffer_shape[1]

        self.buffers = [
            torch.zeros(
                size=buffer_shape,
                dtype=buffer_dtype,
                device=buffer_device,
                pin_memory=pin_memory,
            )
            for _ in range(self.num_buffer)
        ]

        self.buffer_queue = Queue()
        for i in range(self.num_buffer):
            self.buffer_queue.put_nowait(i)
        logger.debug(
            f"Init buffer done. "
            f"transfer_block_num: {self.transfer_block_num}, "
            f"num_buffer: {self.num_buffer}, "
            f"buffer_dtype: {buffer_dtype}, "
            f"buffer_shape: {buffer_shape}"
        )

    @no_type_check
    def get_buffer_index(self) -> int:
        """Take a free buffer index, blocking until one is available.

        NOTE(review): ``Queue.get`` blocks the calling thread; if this is
        reached from a coroutine while the pool is exhausted, the event loop
        running it is blocked too.
        """
        return self.buffer_queue.get()

    @no_type_check
    def free_buffer_index(self, index: int) -> None:
        """Return ``index`` to the pool of free buffers."""
        self.buffer_queue.put_nowait(index)

    def get_swap_buffer(self, index: int, num_blocks: int) -> torch.Tensor:
        """Return buffer ``index`` sized for ``num_blocks`` blocks.

        The contiguous prefix ``buf[:num_blocks]`` is reinterpreted (``view``,
        not a permute) as ``(num_attn_layers, 2, num_blocks, *buf.shape[3:])``
        so that ``result[i]`` is the staging area for layer ``i``. Sender and
        receiver both use this same view, so the byte layout agrees on both
        ends.
        """
        buf = self.buffers[index]
        buffer = buf[:num_blocks].view(
            self.num_attn_layers, 2, num_blocks, *buf.shape[3:]
        )
        return buffer

    def get_gloo_dtype(self, input_dtype: torch.dtype):
        """Return the Gloo dtype corresponding to ``input_dtype``.

        The lookup is memoized in the module-level ``_gloo_dtype_for`` instead
        of an ``lru_cache`` on this method, which would key on ``self`` and
        keep every instance alive for the lifetime of the cache.
        """
        return _gloo_dtype_for(input_dtype)


@lru_cache(maxsize=None)
def _gloo_dtype_for(input_dtype: torch.dtype):
    """Map a torch dtype to its Gloo dtype via the NumPy dtype (cached).

    The key space is the finite set of torch dtypes, so an unbounded cache
    is fine.
    """
    from xoscar.collective.common import TypeMappingGloo

    return TypeMappingGloo[TORCH_DTYPE_TO_NUMPY_DTYPE[input_dtype]]
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class TransferActor(xo.StatelessActor, BufferTransferMixin):
    """Per-rank actor that copies vLLM KV-cache blocks between ranks over Gloo.

    Blocks are staged through the CPU buffers of ``BufferTransferMixin``.
    The sender swaps GPU blocks into a buffer and Gloo-sends it. The
    receiver Gloo-receives into a buffer and swaps it into its own GPU
    cache. ``setup`` must be called before any transfer.
    """

    @classmethod
    def default_uid(cls):
        # Base uid; ``recv`` looks up peers as "<default_uid>-<rank>".
        return "vllm-transfer-actor"

    def __init__(
        self,
        rank: int,
        world_size: int,
        rank_address: str,
        store_address: str,
        store_port: int,
        world_addresses: List[str],
    ):
        """Record rendezvous parameters; the Gloo context is built later.

        Args:
            rank: This actor's rank in the Gloo group.
            world_size: Total number of ranks.
            rank_address: This rank's address; only the host part (before
                ``:``) is used to create the TCP transport device.
            store_address: Host of the rendezvous TCP store.
            store_port: Port of the rendezvous TCP store (rank 0 serves it).
            world_addresses: Actor address of every rank; an address's index
                in this list is used as that peer's rank.
        """
        super().__init__()
        self._rank = rank
        self._world_size = world_size
        self._store_address = store_address
        self._rank_address = rank_address
        self._store_port = store_port
        self._world_addresses = world_addresses
        # Gloo context; set in __post_create__.
        self._context = None
        # Per-virtual-engine vLLM objects; set in setup().
        self._cache_engine: Optional[List[CacheEngine]] = None
        self._scheduler: Optional[List[Scheduler]] = None
        # Dedicated CUDA stream for GPU <-> buffer copies.
        self._swap_stream = torch.cuda.Stream()

    async def __post_create__(self):
        """Create the Gloo context and connect it to every other rank.

        Rank 0 hosts the TCP store. The store is namespaced by
        ``world_size`` through a ``PrefixStore``. ``connectFullMesh``
        presumably blocks until all ranks have joined — confirm.
        """
        from xoscar.collective import xoscar_pygloo as xp

        context = xp.rendezvous.Context(self._rank, self._world_size)

        attr = xp.transport.tcp.attr(self._rank_address.split(":")[0])
        dev = xp.transport.tcp.CreateDevice(attr)

        opt = xp.rendezvous.TCPStoreOptions()
        opt.port = self._store_port
        opt.numWorkers = self._world_size
        opt.isServer = self._rank == 0

        store = xp.rendezvous.TCPStore(self._store_address, opt)
        store = xp.rendezvous.PrefixStore(str(self._world_size), store)

        context.connectFullMesh(store, dev)
        self._context = context
        logger.debug(
            f"Rank {self._rank} arrives successfully, world addresses: {self._world_addresses}"
        )

    def setup(
        self,
        cache_engine: List[CacheEngine],
        scheduler: List[Scheduler],
        num_buffer: int,
        buffer_shape,
        buffer_dtype,
        buffer_device,
        pin_memory: bool,
    ):
        """Attach the vLLM cache engines/schedulers and allocate staging buffers.

        ``cache_engine`` and ``scheduler`` are indexed by virtual engine. The
        remaining arguments are forwarded to ``init_buffer``.
        """
        self._cache_engine = cache_engine
        self._scheduler = scheduler
        self.init_buffer(
            num_buffer, buffer_shape, buffer_dtype, buffer_device, pin_memory
        )

    def _get_cache_engine(self, virtual_engine: int) -> CacheEngine:
        return self._cache_engine[virtual_engine]  # type: ignore

    @staticmethod
    def _get_swap_block_ids(src_to_dst: Dict[int, int], is_sender: bool) -> List[int]:
        """Return this side's block ids in the order they travel on the wire.

        Both peers receive the same ``src_to_dst`` mapping. Both order the
        pairs by source block id: the sender takes the source ids and the
        receiver takes the matching destination ids. The i-th block sent
        therefore lands in its mapped destination even when the mapping is
        not monotonic. Sorting source and destination ids independently
        would misroute blocks in that case.
        """
        ordered_pairs = sorted(src_to_dst.items())
        if is_sender:
            return [src for src, _ in ordered_pairs]
        return [dst for _, dst in ordered_pairs]

    def _iter_block_chunks(self, block_ids: List[int]):
        """Yield consecutive slices of at most ``transfer_block_num`` ids.

        Each slice fits in a single staging buffer.
        """
        step = self.transfer_block_num
        for start in range(0, len(block_ids), step):
            yield block_ids[start : start + step]

    def _swap_out_to_buffer(
        self, cache_engine: CacheEngine, cpu_buf_index: int, block_ids: List[int]
    ) -> torch.Tensor:
        """Copy GPU blocks ``block_ids`` into slots 0..n-1 of buffer ``cpu_buf_index``.

        Copies run layer by layer on the dedicated swap stream. The stream
        is synchronized before the filled buffer view is returned.
        """
        num_blocks = len(block_ids)
        # Rows of (gpu_block_id, buffer_slot).
        src_to_dst = torch.tensor(
            [(block_num, idx) for idx, block_num in enumerate(block_ids)],
            device="cpu",
            dtype=torch.int64,
        ).view(-1, 2)
        cpu_buf = self.get_swap_buffer(cpu_buf_index, num_blocks)
        with torch.cuda.stream(self._swap_stream):
            for i in range(self.num_attn_layers):
                cache_engine.attn_backend.swap_blocks(
                    cache_engine.gpu_cache[i], cpu_buf[i], src_to_dst
                )
        self._swap_stream.synchronize()
        return cpu_buf

    def _swap_in_from_buffer(
        self, cache_engine: CacheEngine, cpu_buf: torch.Tensor, block_ids: List[int]
    ) -> None:
        """Copy buffer slots 0..n-1 of ``cpu_buf`` into GPU blocks ``block_ids``."""
        # Rows of (buffer_slot, gpu_block_id).
        src_to_dst = torch.tensor(
            [(idx, block_num) for idx, block_num in enumerate(block_ids)],
            device="cpu",
            dtype=torch.int64,
        ).view(-1, 2)
        with torch.cuda.stream(self._swap_stream):
            for i in range(self.num_attn_layers):
                cache_engine.attn_backend.swap_blocks(
                    cpu_buf[i], cache_engine.gpu_cache[i], src_to_dst
                )
        self._swap_stream.synchronize()

    def _incr_count_for_block_id(self, virtual_engine: int, block_ids: List[int]):
        """
        The reference count of the `block_id` involved in the transfer is incremented by 1
        to ensure it is not reclaimed.
        """
        scheduler = self._scheduler[virtual_engine]  # type: ignore
        gpu_allocator = scheduler.block_manager.block_allocator._allocators[Device.GPU]

        for _id in block_ids:
            gpu_allocator._refcounter.incr(_id)

    def _decr_count_for_block_id(self, virtual_engine: int, block_ids: List[int]):
        """
        After the transfer, the reference count is decremented by 1.
        """
        scheduler = self._scheduler[virtual_engine]  # type: ignore
        gpu_allocator = scheduler.block_manager.block_allocator._allocators[Device.GPU]

        for _id in block_ids:
            gpu_allocator._refcounter.decr(_id)

    async def do_send(
        self, virtual_engine: int, to_rank: int, src_to_dst: Dict[int, int]
    ):
        """
        Sending logic: GPU -> Buffer -> Gloo send.
        GPU -> Buffer is directly handled using the internal `swap_out` interface of vllm.

        The source blocks are sent to ``to_rank`` in chunks of at most
        ``transfer_block_num`` blocks. Their reference counts are held for
        the duration of the transfer.
        """
        from xoscar.collective import xoscar_pygloo as xp

        cache_engine = self._get_cache_engine(virtual_engine)

        block_ids = self._get_swap_block_ids(src_to_dst, is_sender=True)
        self._incr_count_for_block_id(virtual_engine, block_ids)
        cpu_buf_index = self.get_buffer_index()

        try:
            for send_block_ids in self._iter_block_chunks(block_ids):
                sendbuf = self._swap_out_to_buffer(
                    cache_engine, cpu_buf_index, send_block_ids
                )
                assert sendbuf.is_contiguous()
                # Tensor.numpy() shares storage with the buffer, so this hands
                # Gloo a raw pointer into it. The call is made synchronously
                # (not awaited).
                xp.send(
                    self._context,
                    sendbuf.numpy().ctypes.data,
                    sendbuf.numel(),
                    self.get_gloo_dtype(sendbuf.dtype),
                    to_rank,
                )
        finally:
            self._decr_count_for_block_id(virtual_engine, block_ids)
            self.free_buffer_index(cpu_buf_index)

    async def do_recv(
        self, virtual_engine: int, from_rank: int, src_to_dst: Dict[int, int]
    ):
        """
        Receiving logic: Gloo recv -> Buffer -> GPU.
        Buffer -> GPU is directly handled using the internal `swap_in` interface of vllm.

        Chunks arrive from ``from_rank`` in the same order ``do_send``
        produces them. The destination blocks' reference counts are held
        for the duration of the transfer.
        """
        from xoscar.collective import xoscar_pygloo as xp

        cache_engine = self._get_cache_engine(virtual_engine)

        block_ids = self._get_swap_block_ids(src_to_dst, is_sender=False)
        self._incr_count_for_block_id(virtual_engine, block_ids)
        cpu_buf_index = self.get_buffer_index()

        try:
            for recv_block_ids in self._iter_block_chunks(block_ids):
                recvbuf = self.get_swap_buffer(cpu_buf_index, len(recv_block_ids))
                assert recvbuf.is_contiguous()
                # Receive straight into the buffer's storage, synchronously.
                xp.recv(
                    self._context,
                    recvbuf.numpy().ctypes.data,
                    recvbuf.numel(),
                    self.get_gloo_dtype(recvbuf.dtype),
                    from_rank,
                )

                self._swap_in_from_buffer(cache_engine, recvbuf, recv_block_ids)
        finally:
            self._decr_count_for_block_id(virtual_engine, block_ids)
            self.free_buffer_index(cpu_buf_index)

    async def recv(
        self, virtual_engine: int, from_address: str, src_to_dst: Dict[int, int]
    ):
        """
        This is the external entry point for the call.
        The transfer logic is as follows:
        the receiver requests the sender to send the data directly to itself in a point-to-point manner.

        ``from_address`` must appear in ``world_addresses``; its index is
        the sender's rank (``list.index`` raises ``ValueError`` otherwise).
        The remote ``do_send`` and the local ``do_recv`` are run together
        with ``asyncio.gather``.
        """
        rank = self._world_addresses.index(from_address)
        sender_ref = await xo.actor_ref(
            address=from_address, uid=f"{TransferActor.default_uid()}-{rank}"
        )
        await asyncio.gather(
            sender_ref.do_send(virtual_engine, self._rank, src_to_dst),
            self.do_recv(virtual_engine, rank, src_to_dst),
        )
|
|
@@ -91,6 +91,20 @@ class DiffUsersVideoModel:
|
|
|
91
91
|
pipeline = self._model = CogVideoXPipeline.from_pretrained(
|
|
92
92
|
self._model_path, **kwargs
|
|
93
93
|
)
|
|
94
|
+
elif self._model_spec.model_family == "HunyuanVideo":
|
|
95
|
+
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
|
|
96
|
+
|
|
97
|
+
transformer_torch_dtype = kwargs.pop("transformer_torch_dtype")
|
|
98
|
+
if isinstance(transformer_torch_dtype, str):
|
|
99
|
+
transformer_torch_dtype = getattr(torch, transformer_torch_dtype)
|
|
100
|
+
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
|
101
|
+
self._model_path,
|
|
102
|
+
subfolder="transformer",
|
|
103
|
+
torch_dtype=transformer_torch_dtype,
|
|
104
|
+
)
|
|
105
|
+
pipeline = self._model = HunyuanVideoPipeline.from_pretrained(
|
|
106
|
+
self._model_path, transformer=transformer, **kwargs
|
|
107
|
+
)
|
|
94
108
|
else:
|
|
95
109
|
raise Exception(
|
|
96
110
|
f"Unsupported model family: {self._model_spec.model_family}"
|
|
@@ -30,5 +30,20 @@
|
|
|
30
30
|
"default_generate_config": {
|
|
31
31
|
"guidance_scale": 7
|
|
32
32
|
}
|
|
33
|
+
},
|
|
34
|
+
{
|
|
35
|
+
"model_name": "HunyuanVideo",
|
|
36
|
+
"model_family": "HunyuanVideo",
|
|
37
|
+
"model_id": "hunyuanvideo-community/HunyuanVideo",
|
|
38
|
+
"model_revision": "e8c2aaa66fe3742a32c11a6766aecbf07c56e773",
|
|
39
|
+
"model_ability": [
|
|
40
|
+
"text2video"
|
|
41
|
+
],
|
|
42
|
+
"default_model_config": {
|
|
43
|
+
"transformer_torch_dtype": "bfloat16",
|
|
44
|
+
"torch_dtype": "float16"
|
|
45
|
+
},
|
|
46
|
+
"default_generate_config": {
|
|
47
|
+
}
|
|
33
48
|
}
|
|
34
49
|
]
|
|
@@ -32,5 +32,21 @@
|
|
|
32
32
|
"default_generate_config": {
|
|
33
33
|
"guidance_scale": 7
|
|
34
34
|
}
|
|
35
|
+
},
|
|
36
|
+
{
|
|
37
|
+
"model_name": "HunyuanVideo",
|
|
38
|
+
"model_family": "HunyuanVideo",
|
|
39
|
+
"model_hub": "modelscope",
|
|
40
|
+
"model_id": "Xorbits/HunyuanVideo",
|
|
41
|
+
"model_revision": "master",
|
|
42
|
+
"model_ability": [
|
|
43
|
+
"text2video"
|
|
44
|
+
],
|
|
45
|
+
"default_model_config": {
|
|
46
|
+
"transformer_torch_dtype": "bfloat16",
|
|
47
|
+
"torch_dtype": "float16"
|
|
48
|
+
},
|
|
49
|
+
"default_generate_config": {
|
|
50
|
+
}
|
|
35
51
|
}
|
|
36
52
|
]
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
|
|
2
|
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import argparse
|
|
18
|
+
import glob
|
|
19
|
+
|
|
20
|
+
import yaml
|
|
21
|
+
import torch
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_args():
|
|
25
|
+
parser = argparse.ArgumentParser(description='average model')
|
|
26
|
+
parser.add_argument('--dst_model', required=True, help='averaged model')
|
|
27
|
+
parser.add_argument('--src_path',
|
|
28
|
+
required=True,
|
|
29
|
+
help='src model path for average')
|
|
30
|
+
parser.add_argument('--val_best',
|
|
31
|
+
action="store_true",
|
|
32
|
+
help='averaged model')
|
|
33
|
+
parser.add_argument('--num',
|
|
34
|
+
default=5,
|
|
35
|
+
type=int,
|
|
36
|
+
help='nums for averaged model')
|
|
37
|
+
|
|
38
|
+
args = parser.parse_args()
|
|
39
|
+
print(args)
|
|
40
|
+
return args
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def main():
|
|
44
|
+
args = get_args()
|
|
45
|
+
val_scores = []
|
|
46
|
+
if args.val_best:
|
|
47
|
+
yamls = glob.glob('{}/*.yaml'.format(args.src_path))
|
|
48
|
+
yamls = [
|
|
49
|
+
f for f in yamls
|
|
50
|
+
if not (os.path.basename(f).startswith('train')
|
|
51
|
+
or os.path.basename(f).startswith('init'))
|
|
52
|
+
]
|
|
53
|
+
for y in yamls:
|
|
54
|
+
with open(y, 'r') as f:
|
|
55
|
+
dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
|
|
56
|
+
loss = float(dic_yaml['loss_dict']['loss'])
|
|
57
|
+
epoch = int(dic_yaml['epoch'])
|
|
58
|
+
step = int(dic_yaml['step'])
|
|
59
|
+
tag = dic_yaml['tag']
|
|
60
|
+
val_scores += [[epoch, step, loss, tag]]
|
|
61
|
+
sorted_val_scores = sorted(val_scores,
|
|
62
|
+
key=lambda x: x[2],
|
|
63
|
+
reverse=False)
|
|
64
|
+
print("best val (epoch, step, loss, tag) = " +
|
|
65
|
+
str(sorted_val_scores[:args.num]))
|
|
66
|
+
path_list = [
|
|
67
|
+
args.src_path + '/epoch_{}_whole.pt'.format(score[0])
|
|
68
|
+
for score in sorted_val_scores[:args.num]
|
|
69
|
+
]
|
|
70
|
+
print(path_list)
|
|
71
|
+
avg = {}
|
|
72
|
+
num = args.num
|
|
73
|
+
assert num == len(path_list)
|
|
74
|
+
for path in path_list:
|
|
75
|
+
print('Processing {}'.format(path))
|
|
76
|
+
states = torch.load(path, map_location=torch.device('cpu'))
|
|
77
|
+
for k in states.keys():
|
|
78
|
+
if k not in avg.keys():
|
|
79
|
+
avg[k] = states[k].clone()
|
|
80
|
+
else:
|
|
81
|
+
avg[k] += states[k]
|
|
82
|
+
# average
|
|
83
|
+
for k in avg.keys():
|
|
84
|
+
if avg[k] is not None:
|
|
85
|
+
# pytorch 1.6 use true_divide instead of /=
|
|
86
|
+
avg[k] = torch.true_divide(avg[k], num)
|
|
87
|
+
print('Saving to {}'.format(args.dst_model))
|
|
88
|
+
torch.save(avg, args.dst_model)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
if __name__ == '__main__':
|
|
92
|
+
main()
|
|
@@ -19,12 +19,13 @@ import logging
|
|
|
19
19
|
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
|
20
20
|
import os
|
|
21
21
|
import sys
|
|
22
|
+
import torch
|
|
22
23
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
23
24
|
sys.path.append('{}/../..'.format(ROOT_DIR))
|
|
24
25
|
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
|
25
|
-
import torch
|
|
26
26
|
from cosyvoice.cli.cosyvoice import CosyVoice
|
|
27
27
|
|
|
28
|
+
|
|
28
29
|
def get_args():
|
|
29
30
|
parser = argparse.ArgumentParser(description='export your model for deployment')
|
|
30
31
|
parser.add_argument('--model_dir',
|
|
@@ -35,6 +36,7 @@ def get_args():
|
|
|
35
36
|
print(args)
|
|
36
37
|
return args
|
|
37
38
|
|
|
39
|
+
|
|
38
40
|
def main():
|
|
39
41
|
args = get_args()
|
|
40
42
|
logging.basicConfig(level=logging.DEBUG,
|
|
@@ -44,7 +46,7 @@ def main():
|
|
|
44
46
|
torch._C._jit_set_profiling_mode(False)
|
|
45
47
|
torch._C._jit_set_profiling_executor(False)
|
|
46
48
|
|
|
47
|
-
cosyvoice = CosyVoice(args.model_dir, load_jit=False,
|
|
49
|
+
cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
|
|
48
50
|
|
|
49
51
|
# 1. export llm text_encoder
|
|
50
52
|
llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
|
|
@@ -60,5 +62,13 @@ def main():
|
|
|
60
62
|
script = torch.jit.optimize_for_inference(script)
|
|
61
63
|
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
|
|
62
64
|
|
|
65
|
+
# 3. export flow encoder
|
|
66
|
+
flow_encoder = cosyvoice.model.flow.encoder
|
|
67
|
+
script = torch.jit.script(flow_encoder)
|
|
68
|
+
script = torch.jit.freeze(script)
|
|
69
|
+
script = torch.jit.optimize_for_inference(script)
|
|
70
|
+
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
|
71
|
+
|
|
72
|
+
|
|
63
73
|
if __name__ == '__main__':
|
|
64
74
|
main()
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
|
|
2
|
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from __future__ import print_function
|
|
17
|
+
|
|
18
|
+
import argparse
|
|
19
|
+
import logging
|
|
20
|
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
|
21
|
+
import os
|
|
22
|
+
import sys
|
|
23
|
+
import onnxruntime
|
|
24
|
+
import random
|
|
25
|
+
import torch
|
|
26
|
+
from tqdm import tqdm
|
|
27
|
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
28
|
+
sys.path.append('{}/../..'.format(ROOT_DIR))
|
|
29
|
+
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
|
30
|
+
from cosyvoice.cli.cosyvoice import CosyVoice
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def get_dummy_input(batch_size, seq_len, out_channels, device):
|
|
34
|
+
x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
|
35
|
+
mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
|
|
36
|
+
mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
|
37
|
+
t = torch.rand((batch_size), dtype=torch.float32, device=device)
|
|
38
|
+
spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
|
|
39
|
+
cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
|
40
|
+
return x, mask, mu, t, spks, cond
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_args():
|
|
44
|
+
parser = argparse.ArgumentParser(description='export your model for deployment')
|
|
45
|
+
parser.add_argument('--model_dir',
|
|
46
|
+
type=str,
|
|
47
|
+
default='pretrained_models/CosyVoice-300M',
|
|
48
|
+
help='local path')
|
|
49
|
+
args = parser.parse_args()
|
|
50
|
+
print(args)
|
|
51
|
+
return args
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def main():
|
|
55
|
+
args = get_args()
|
|
56
|
+
logging.basicConfig(level=logging.DEBUG,
|
|
57
|
+
format='%(asctime)s %(levelname)s %(message)s')
|
|
58
|
+
|
|
59
|
+
cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
|
|
60
|
+
|
|
61
|
+
# 1. export flow decoder estimator
|
|
62
|
+
estimator = cosyvoice.model.flow.decoder.estimator
|
|
63
|
+
|
|
64
|
+
device = cosyvoice.model.device
|
|
65
|
+
batch_size, seq_len = 1, 256
|
|
66
|
+
out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
|
|
67
|
+
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
|
68
|
+
torch.onnx.export(
|
|
69
|
+
estimator,
|
|
70
|
+
(x, mask, mu, t, spks, cond),
|
|
71
|
+
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
|
72
|
+
export_params=True,
|
|
73
|
+
opset_version=18,
|
|
74
|
+
do_constant_folding=True,
|
|
75
|
+
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
|
76
|
+
output_names=['estimator_out'],
|
|
77
|
+
dynamic_axes={
|
|
78
|
+
'x': {0: 'batch_size', 2: 'seq_len'},
|
|
79
|
+
'mask': {0: 'batch_size', 2: 'seq_len'},
|
|
80
|
+
'mu': {0: 'batch_size', 2: 'seq_len'},
|
|
81
|
+
'cond': {0: 'batch_size', 2: 'seq_len'},
|
|
82
|
+
't': {0: 'batch_size'},
|
|
83
|
+
'spks': {0: 'batch_size'},
|
|
84
|
+
'estimator_out': {0: 'batch_size', 2: 'seq_len'},
|
|
85
|
+
}
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
# 2. test computation consistency
|
|
89
|
+
option = onnxruntime.SessionOptions()
|
|
90
|
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
|
91
|
+
option.intra_op_num_threads = 1
|
|
92
|
+
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
|
93
|
+
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
|
94
|
+
sess_options=option, providers=providers)
|
|
95
|
+
|
|
96
|
+
for _ in tqdm(range(10)):
|
|
97
|
+
x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
|
|
98
|
+
output_pytorch = estimator(x, mask, mu, t, spks, cond)
|
|
99
|
+
ort_inputs = {
|
|
100
|
+
'x': x.cpu().numpy(),
|
|
101
|
+
'mask': mask.cpu().numpy(),
|
|
102
|
+
'mu': mu.cpu().numpy(),
|
|
103
|
+
't': t.cpu().numpy(),
|
|
104
|
+
'spks': spks.cpu().numpy(),
|
|
105
|
+
'cond': cond.cpu().numpy()
|
|
106
|
+
}
|
|
107
|
+
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
|
|
108
|
+
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
if __name__ == "__main__":
|
|
112
|
+
main()
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
#!/bin/bash
|
|
2
|
+
# Copyright 2024 Alibaba Inc. All Rights Reserved.
|
|
3
|
+
# download tensorrt from https://developer.nvidia.com/tensorrt/download/10x, check your system and cuda for compatibability
|
|
4
|
+
# for example for linux + cuda12.4, you can download https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
|
|
5
|
+
TRT_DIR=<YOUR_TRT_DIR>
|
|
6
|
+
MODEL_DIR=<COSYVOICE2_MODEL_DIR>
|
|
7
|
+
|
|
8
|
+
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64
|
|
9
|
+
$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw
|
|
@@ -18,16 +18,15 @@ import argparse
|
|
|
18
18
|
import logging
|
|
19
19
|
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
|
20
20
|
import os
|
|
21
|
-
|
|
22
21
|
import torch
|
|
23
22
|
from torch.utils.data import DataLoader
|
|
24
23
|
import torchaudio
|
|
25
24
|
from hyperpyyaml import load_hyperpyyaml
|
|
26
25
|
from tqdm import tqdm
|
|
27
26
|
from cosyvoice.cli.model import CosyVoiceModel
|
|
28
|
-
|
|
29
27
|
from cosyvoice.dataset.dataset import Dataset
|
|
30
28
|
|
|
29
|
+
|
|
31
30
|
def get_args():
|
|
32
31
|
parser = argparse.ArgumentParser(description='inference with your model')
|
|
33
32
|
parser.add_argument('--config', required=True, help='config file')
|
|
@@ -66,7 +65,8 @@ def main():
|
|
|
66
65
|
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
|
|
67
66
|
model.load(args.llm_model, args.flow_model, args.hifigan_model)
|
|
68
67
|
|
|
69
|
-
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
|
|
68
|
+
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
|
|
69
|
+
tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
|
|
70
70
|
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
|
|
71
71
|
|
|
72
72
|
del configs
|
|
@@ -74,13 +74,11 @@ def main():
|
|
|
74
74
|
fn = os.path.join(args.result_dir, 'wav.scp')
|
|
75
75
|
f = open(fn, 'w')
|
|
76
76
|
with torch.no_grad():
|
|
77
|
-
for
|
|
77
|
+
for _, batch in tqdm(enumerate(test_data_loader)):
|
|
78
78
|
utts = batch["utts"]
|
|
79
79
|
assert len(utts) == 1, "inference mode only support batchsize 1"
|
|
80
|
-
text = batch["text"]
|
|
81
80
|
text_token = batch["text_token"].to(device)
|
|
82
81
|
text_token_len = batch["text_token_len"].to(device)
|
|
83
|
-
tts_text = batch["tts_text"]
|
|
84
82
|
tts_index = batch["tts_index"]
|
|
85
83
|
tts_text_token = batch["tts_text_token"].to(device)
|
|
86
84
|
tts_text_token_len = batch["tts_text_token_len"].to(device)
|
|
@@ -101,7 +99,7 @@ def main():
|
|
|
101
99
|
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
|
102
100
|
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
|
|
103
101
|
tts_speeches = []
|
|
104
|
-
for model_output in model.
|
|
102
|
+
for model_output in model.tts(**model_input):
|
|
105
103
|
tts_speeches.append(model_output['tts_speech'])
|
|
106
104
|
tts_speeches = torch.concat(tts_speeches, dim=1)
|
|
107
105
|
tts_key = '{}_{}'.format(utts[0], tts_index[0])
|