xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +15 -34
- xinference/client/restful/restful_client.py +2 -2
- xinference/core/chat_interface.py +45 -10
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +8 -5
- xinference/core/scheduler.py +1 -2
- xinference/core/worker.py +49 -42
- xinference/deploy/cmdline.py +2 -2
- xinference/deploy/test/test_cmdline.py +7 -7
- xinference/model/audio/chattts.py +24 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +23 -1
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +49 -1
- xinference/model/llm/__init__.py +26 -27
- xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
- xinference/model/llm/llm_family.json +606 -1266
- xinference/model/llm/llm_family.py +16 -139
- xinference/model/llm/llm_family_modelscope.json +276 -313
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/memory.py +9 -9
- xinference/model/llm/sglang/core.py +2 -2
- xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
- xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/{pytorch → transformers}/core.py +3 -10
- xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +540 -0
- xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
- xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
- xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
- xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
- xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
- xinference/model/llm/utils.py +85 -70
- xinference/model/llm/vllm/core.py +110 -11
- xinference/model/utils.py +1 -95
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/internvl/__init__.py +0 -0
- xinference/thirdparty/internvl/conversation.py +393 -0
- xinference/thirdparty/omnilmm/model/utils.py +16 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
- xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
- xinference/locale/utils.py +0 -39
- xinference/locale/zh_CN.json +0 -26
- xinference/model/llm/ggml/tools/__init__.py +0 -15
- xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
- xinference/model/llm/ggml/tools/gguf.py +0 -884
- xinference/model/llm/pytorch/__init__.py +0 -13
- xinference/model/llm/pytorch/baichuan.py +0 -81
- xinference/model/llm/pytorch/falcon.py +0 -138
- xinference/model/llm/pytorch/intern_vl.py +0 -352
- xinference/model/llm/pytorch/vicuna.py +0 -69
- xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
- xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
- /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
- /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
- /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json

 version_json = '''
 {
- "date": "2024-08-…",
+ "date": "2024-08-23T18:14:53+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "…",
- "version": "0.14.1.post1"
+ "full-revisionid": "b5002242e04634bca7e75cac9df0cdc6c0bf407a",
+ "version": "0.14.3"
 }
 '''  # END VERSION_JSON

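The embedded JSON is versioneer's standard layout; consumers recover the metadata by parsing the string. A minimal sketch, mirroring versioneer's generated get_versions helper (shown for illustration only, not copied from the wheel):

import json

def get_versions():
    # version_json is the triple-quoted string embedded above
    return json.loads(version_json)

assert get_versions()["version"] == "0.14.3"
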
xinference/api/restful_api.py
CHANGED
@@ -1682,18 +1682,9 @@ class RESTfulAPI:

        model_family = desc.get("model_family", "")
        function_call_models = (
-            ["gorilla-openfunctions-v1"]
-            + QWEN_TOOL_CALL_FAMILY
-            + GLM4_TOOL_CALL_FAMILY
+            ["gorilla-openfunctions-v1"] + QWEN_TOOL_CALL_FAMILY + GLM4_TOOL_CALL_FAMILY
        )

-        is_qwen = desc.get("model_format") == "ggmlv3" and "qwen-chat" == model_family
-
-        if is_qwen and system_prompt is not None:
-            raise HTTPException(
-                status_code=400, detail="Qwen ggml does not have system prompt"
-            )
-
        if model_family not in function_call_models:
            if body.tools:
                raise HTTPException(
@@ -1724,18 +1715,13 @@
        iterator = None
        try:
            try:
-                …
-                    system_prompt,
-                    chat_history,
-                    kwargs,
-                    raw_params=raw_kwargs,
-                )
+                iterator = await model.chat(
+                    prompt,
+                    system_prompt,
+                    chat_history,
+                    kwargs,
+                    raw_params=raw_kwargs,
+                )
            except RuntimeError as re:
                await self._report_error_event(model_uid, str(re))
                self.handle_request_limit_error(re)
@@ -1763,18 +1749,13 @@
            return EventSourceResponse(stream_results())
        else:
            try:
-                …
-                    system_prompt,
-                    chat_history,
-                    kwargs,
-                    raw_params=raw_kwargs,
-                )
+                data = await model.chat(
+                    prompt,
+                    system_prompt,
+                    chat_history,
+                    kwargs,
+                    raw_params=raw_kwargs,
+                )
                return Response(content=data, media_type="application/json")
            except Exception as e:
                logger.error(e, exc_info=True)

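With the Qwen-ggml special case gone, both the streaming and non-streaming branches forward the same argument order (prompt, system_prompt, chat_history, kwargs) for every chat model. A hedged sketch of the equivalent client-side call; the endpoint and model uid are illustrative:

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model = client.get_model("my-chat-model")  # uid returned by launch_model
completion = model.chat(
    prompt="What is the capital of France?",
    system_prompt="You are a helpful assistant.",
    chat_history=[],
    generate_config={"max_tokens": 128},
)
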
xinference/client/restful/restful_client.py
CHANGED

@@ -426,7 +426,7 @@ class RESTfulGenerateModelHandle(RESTfulModelHandle):
        The user's message or user's input.
    generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]]
        Additional configuration for the chat generation.
-        "LlamaCppGenerateConfig" -> Configuration for ggml model
+        "LlamaCppGenerateConfig" -> Configuration for llama-cpp-python model
        "PytorchGenerateConfig" -> Configuration for pytorch model

    Returns
@@ -493,7 +493,7 @@ class RESTfulChatModelHandle(RESTfulGenerateModelHandle):
        A tool list.
    generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]]
        Additional configuration for the chat generation.
-        "LlamaCppGenerateConfig" -> configuration for ggml model
+        "LlamaCppGenerateConfig" -> configuration for llama-cpp-python model
        "PytorchGenerateConfig" -> configuration for pytorch model

    Returns

xinference/core/chat_interface.py
CHANGED

@@ -236,8 +236,8 @@ class GradioInterface:
                bot[-1][1] = history[-1]["content"]
                yield history, bot

-        def add_text(history, bot, text, image):
-            logger.debug("Add text, text: %s, image: %s", text, image)
+        def add_text(history, bot, text, image, video):
+            logger.debug("Add text, text: %s, image: %s, video: %s", text, image, video)
            if image:
                buffered = BytesIO()
                with PIL.Image.open(image) as img:
@@ -257,16 +257,47 @@
                        },
                    ],
                }
+            elif video:
+
+                def video_to_base64(video_path):
+                    with open(video_path, "rb") as video_file:
+                        encoded_string = base64.b64encode(video_file.read()).decode(
+                            "utf-8"
+                        )
+                    return encoded_string
+
+                def generate_html_video(video_path):
+                    base64_video = video_to_base64(video_path)
+                    video_format = video_path.split(".")[-1]
+                    html_code = f"""
+                    <video controls>
+                        <source src="data:video/{video_format};base64,{base64_video}" type="video/{video_format}">
+                        Your browser does not support the video tag.
+                    </video>
+                    """
+                    return html_code
+
+                display_content = f"{generate_html_video(video)}\n{text}"
+                message = {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": text},
+                        {
+                            "type": "video_url",
+                            "video_url": {"url": video},
+                        },
+                    ],
+                }
            else:
                display_content = text
                message = {"role": "user", "content": text}
            history = history + [message]
            bot = bot + [[display_content, None]]
-            return history, bot, "", None
+            return history, bot, "", None, None

        def clear_history():
            logger.debug("Clear history.")
-            return [], None, "", None
+            return [], None, "", None, None

        def update_button(text):
            return gr.update(interactive=bool(text))
@@ -309,10 +340,11 @@
            state = gr.State([])
            with gr.Row():
                chatbot = gr.Chatbot(
-                    elem_id="chatbot", label=self.model_name, height=…
+                    elem_id="chatbot", label=self.model_name, height=700, scale=7
                )
                with gr.Column(scale=3):
                    imagebox = gr.Image(type="filepath")
+                    videobox = gr.Video()
                    textbox = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and press ENTER",
@@ -340,8 +372,8 @@

            textbox.submit(
                add_text,
-                [state, chatbot, textbox, imagebox],
-                [state, chatbot, textbox, imagebox],
+                [state, chatbot, textbox, imagebox, videobox],
+                [state, chatbot, textbox, imagebox, videobox],
                queue=False,
            ).then(
                predict,
@@ -351,8 +383,8 @@

            submit_btn.click(
                add_text,
-                [state, chatbot, textbox, imagebox],
-                [state, chatbot, textbox, imagebox],
+                [state, chatbot, textbox, imagebox, videobox],
+                [state, chatbot, textbox, imagebox, videobox],
                queue=False,
            ).then(
                predict,
@@ -361,7 +393,10 @@
            )

            clear_btn.click(
-                clear_history, …
+                clear_history,
+                None,
+                [state, chatbot, textbox, imagebox, videobox],
+                queue=False,
            )

        return chat_vl_interface

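The UI embeds the clip as a base64 <video> tag only for display; the message actually sent to the model uses the video_url content part introduced above. A sketch of that message shape, with an illustrative prompt and path:

message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this clip."},
        {"type": "video_url", "video_url": {"url": "/path/to/clip.mp4"}},
    ],
}
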
xinference/core/image_interface.py
CHANGED

@@ -163,6 +163,7 @@ class ImageInterface:
        size_width: int,
        size_height: int,
        num_inference_steps: int,
+        padding_image_to_multiple: int,
    ) -> PIL.Image.Image:
        from ..client import RESTfulClient

@@ -178,6 +179,7 @@ class ImageInterface:
        num_inference_steps = (
            None if num_inference_steps == -1 else num_inference_steps  # type: ignore
        )
+        padding_image_to_multiple = None if padding_image_to_multiple == -1 else padding_image_to_multiple  # type: ignore

        bio = io.BytesIO()
        image.save(bio, format="png")
@@ -190,6 +192,7 @@ class ImageInterface:
            size=size,
            response_format="b64_json",
            num_inference_steps=num_inference_steps,
+            padding_image_to_multiple=padding_image_to_multiple,
        )

        images = []
@@ -222,9 +225,14 @@ class ImageInterface:
                n = gr.Number(label="Number of image", value=1)
                size_width = gr.Number(label="Width", value=-1)
                size_height = gr.Number(label="Height", value=-1)
+
+            with gr.Row():
                num_inference_steps = gr.Number(
                    label="Inference Step Number", value=-1
                )
+                padding_image_to_multiple = gr.Number(
+                    label="Padding image to multiple", value=-1
+                )

            with gr.Row():
                with gr.Column(scale=1):
@@ -242,6 +250,7 @@ class ImageInterface:
                    size_width,
                    size_height,
                    num_inference_steps,
+                    padding_image_to_multiple,
                ],
                outputs=output_gallery,
            )

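gr.Number cannot represent "unset", so the interface uses -1 as a sentinel and maps it to None before the request, exactly as num_inference_steps is handled above. A minimal sketch of the convention; the helper name is illustrative, not part of xinference:

def unset_if_negative(value: int):
    # -1 in the UI means "let the model decide"
    return None if value == -1 else value

assert unset_if_negative(-1) is None
assert unset_if_negative(8) == 8  # e.g. pad image dimensions to a multiple of 8
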
xinference/core/model.py
CHANGED
@@ -132,8 +132,8 @@ class ModelActor(xo.StatelessActor):

    async def __pre_destroy__(self):
        from ..model.embedding.core import EmbeddingModel
-        from ..model.llm.pytorch.core import PytorchModel as LLMPytorchModel
        from ..model.llm.sglang.core import SGLANGModel
+        from ..model.llm.transformers.core import PytorchModel as LLMPytorchModel
        from ..model.llm.vllm.core import VLLMModel as LLMVLLMModel

        if self.allow_batching():
@@ -177,8 +177,9 @@
        request_limits: Optional[int] = None,
    ):
        super().__init__()
-        from ..model.llm.pytorch.core import PytorchModel
+        from ..model.llm.lmdeploy.core import LMDeployModel
        from ..model.llm.sglang.core import SGLANGModel
+        from ..model.llm.transformers.core import PytorchModel
        from ..model.llm.vllm.core import VLLMModel

        self._worker_address = worker_address
@@ -192,7 +193,9 @@
        self._current_generator = lambda: None
        self._lock = (
            None
-            if isinstance(self._model, (PytorchModel, VLLMModel, SGLANGModel))
+            if isinstance(
+                self._model, (PytorchModel, VLLMModel, SGLANGModel, LMDeployModel)
+            )
            else asyncio.locks.Lock()
        )
        self._worker_ref = None
@@ -272,7 +275,7 @@
        return isinstance(self._model, VLLMModel)

    def allow_batching(self) -> bool:
-        from ..model.llm.pytorch.core import PytorchModel
+        from ..model.llm.transformers.core import PytorchModel

        model_ability = self._model_description.get("model_ability", [])

@@ -415,7 +418,7 @@
            ret = await asyncio.to_thread(fn, *args, **kwargs)

        if self._lock is not None and self._current_generator():
-            raise Exception("Parallel generation is not supported by ggml.")
+            raise Exception("Parallel generation is not supported by llama-cpp-python.")

        if inspect.isgenerator(ret):
            gen = self._to_generator(output_type, ret)

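The lock selection above serializes generation only for engines that cannot run requests in parallel (llama-cpp-python); vLLM, SGLang, LMDeploy, and transformers models manage their own concurrency. A hedged sketch of the pattern with illustrative names, not the actor's real API:

import asyncio

class SerializedCaller:
    def __init__(self, model, supports_parallel: bool):
        self._model = model
        # No lock for engines with their own schedulers; a lock otherwise.
        self._lock = None if supports_parallel else asyncio.Lock()

    async def call(self, fn, *args, **kwargs):
        if self._lock is None:
            return await asyncio.to_thread(fn, *args, **kwargs)
        async with self._lock:  # one llama-cpp-python call at a time
            return await asyncio.to_thread(fn, *args, **kwargs)
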
xinference/core/scheduler.py
CHANGED
@@ -24,7 +24,6 @@ import xoscar as xo

 logger = logging.getLogger(__name__)

-XINFERENCE_BATCHING_CLEAN_CACHE_INTERVAL = 5
 XINFERENCE_STREAMING_DONE_FLAG = "<XINFERENCE_STREAMING_DONE>"
 XINFERENCE_STREAMING_ERROR_FLAG = "<XINFERENCE_STREAMING_ERROR>"
 XINFERENCE_STREAMING_ABORT_FLAG = "<XINFERENCE_STREAMING_ABORT>"
@@ -359,7 +358,7 @@ class SchedulerActor(xo.StatelessActor):

    @staticmethod
    def _empty_cache():
-        from ..model.llm.pytorch.utils import empty_cache
+        from ..model.llm.transformers.utils import empty_cache

        empty_cache()

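The sentinel strings mark the end, failure, or abort of a batched stream. A hedged sketch of how a consumer of such a queue might interpret them; the real consumer lives elsewhere in xinference.core, so this is illustrative only:

XINFERENCE_STREAMING_DONE_FLAG = "<XINFERENCE_STREAMING_DONE>"
XINFERENCE_STREAMING_ERROR_FLAG = "<XINFERENCE_STREAMING_ERROR>"

async def iter_stream(queue):
    # Yield chunks until the done sentinel; surface errors as exceptions.
    while True:
        chunk = await queue.get()
        if chunk == XINFERENCE_STREAMING_DONE_FLAG:
            return
        if isinstance(chunk, str) and chunk.startswith(XINFERENCE_STREAMING_ERROR_FLAG):
            raise RuntimeError(chunk[len(XINFERENCE_STREAMING_ERROR_FLAG):])
        yield chunk
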
xinference/core/worker.py
CHANGED
@@ -39,9 +39,11 @@ from ..core.status_guard import LaunchStatus
 from ..device_utils import get_available_device_env_name, gpu_count
 from ..model.core import ModelDescription, create_model_instance
 from ..types import PeftModelConfig
+from .cache_tracker import CacheTrackerActor
 from .event import Event, EventCollectorActor, EventType
 from .metrics import launch_metrics_export_server, record_metrics
 from .resource import gather_node_info
+from .status_guard import StatusGuardActor
 from .utils import log_async, log_sync, parse_replica_model_uid, purge_dir

 logger = getLogger(__name__)
@@ -71,6 +73,15 @@ class WorkerActor(xo.StatelessActor):
        self._supervisor_ref: Optional[xo.ActorRefType] = None
        self._main_pool = main_pool
        self._main_pool.recover_sub_pool = self.recover_sub_pool
+        self._status_guard_ref: xo.ActorRefType[  # type: ignore
+            "StatusGuardActor"
+        ] = None
+        self._event_collector_ref: xo.ActorRefType[  # type: ignore
+            EventCollectorActor
+        ] = None
+        self._cache_tracker_ref: xo.ActorRefType[  # type: ignore
+            CacheTrackerActor
+        ] = None

        # internal states.
        # temporary placeholder during model launch process:
@@ -308,56 +319,50 @@
        Params:
            add_worker: By default will call supervisor.add_worker after first connect
        """
-        from .status_guard import StatusGuardActor
        from .supervisor import SupervisorActor

        if self._supervisor_ref is not None:
            return self._supervisor_ref
-        self._supervisor_ref = await xo.actor_ref(  # type: ignore
+        supervisor_ref = await xo.actor_ref(  # type: ignore
            address=self._supervisor_address, uid=SupervisorActor.uid()
        )
+        # Prevent concurrent operations leads to double initialization, check again.
+        if self._supervisor_ref is not None:
+            return self._supervisor_ref
+        self._supervisor_ref = supervisor_ref
        if add_worker and len(self._model_uid_to_model) == 0:
            # Newly started (or restarted), has no model, notify supervisor
            await self._supervisor_ref.add_worker(self.address)
            logger.info("Connected to supervisor as a fresh worker")

-        …
-        )
-        …
-        model_version_infos.update(get_llm_model_descriptions())
-        model_version_infos.update(get_embedding_model_descriptions())
-        model_version_infos.update(get_rerank_model_descriptions())
-        model_version_infos.update(get_image_model_descriptions())
-        model_version_infos.update(get_audio_model_descriptions())
-        model_version_infos.update(get_flexible_model_descriptions())
-        await self._cache_tracker_ref.record_model_version(
-            model_version_infos, self.address
-        )
+        self._status_guard_ref = await xo.actor_ref(
+            address=self._supervisor_address, uid=StatusGuardActor.uid()
+        )
+        self._event_collector_ref = await xo.actor_ref(
+            address=self._supervisor_address, uid=EventCollectorActor.uid()
+        )
+        self._cache_tracker_ref = await xo.actor_ref(
+            address=self._supervisor_address, uid=CacheTrackerActor.uid()
+        )
+        # cache_tracker is on supervisor
+        from ..model.audio import get_audio_model_descriptions
+        from ..model.embedding import get_embedding_model_descriptions
+        from ..model.flexible import get_flexible_model_descriptions
+        from ..model.image import get_image_model_descriptions
+        from ..model.llm import get_llm_model_descriptions
+        from ..model.rerank import get_rerank_model_descriptions
+
+        # record model version
+        model_version_infos: Dict[str, List[Dict]] = {}  # type: ignore
+        model_version_infos.update(get_llm_model_descriptions())
+        model_version_infos.update(get_embedding_model_descriptions())
+        model_version_infos.update(get_rerank_model_descriptions())
+        model_version_infos.update(get_image_model_descriptions())
+        model_version_infos.update(get_audio_model_descriptions())
+        model_version_infos.update(get_flexible_model_descriptions())
+        await self._cache_tracker_ref.record_model_version(
+            model_version_infos, self.address
+        )
        return self._supervisor_ref

    @staticmethod
@@ -734,7 +739,7 @@
        elif model_type == "image":
            return ["text_to_image"]
        elif model_type == "audio":
-            return […
+            return [model._model_spec.ability]
        elif model_type == "video":
            return ["text_to_video"]
        elif model_type == "flexible":
@@ -793,6 +798,7 @@
            logger.exception(e)
            raise
        try:
+            _ = await self.get_supervisor_ref()
            if self._event_collector_ref is not None:
                await self._event_collector_ref.report_event(
                    origin_uid,
@@ -830,7 +836,7 @@
            raise ValueError(
                f"PEFT adaptors cannot be applied to embedding or rerank models."
            )
-        if model_type == "LLM" and model_format in ("ggufv2", "ggmlv3"):
+        if model_type == "LLM" and model_format in ("ggufv2",):
            raise ValueError(
                f"PEFT adaptors can only be applied to pytorch-like models"
            )
@@ -914,6 +920,7 @@
            raise ValueError(f"{model_uid} is launching")
        origin_uid, _, __ = parse_replica_model_uid(model_uid)
        try:
+            _ = await self.get_supervisor_ref()
            if self._event_collector_ref is not None:
                await self._event_collector_ref.report_event(
                    origin_uid,
@@ -1081,7 +1088,7 @@
        paths.update([os.path.realpath(path) for path in paths])

        # get tensorizer path
-        from ..model.llm.pytorch.tensorizer_utils import get_tensorizer_dir
+        from ..model.llm.transformers.tensorizer_utils import get_tensorizer_dir

        tensorizer_path = get_tensorizer_dir(path)
        if os.path.isdir(tensorizer_path):

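The re-check after the await in get_supervisor_ref is the async double-checked initialization idiom: another coroutine may have finished initializing while this one was suspended on xo.actor_ref. A minimal sketch of the idiom; names here are illustrative and lookup_ref is a hypothetical stand-in:

async def get_ref(self):
    if self._ref is not None:
        return self._ref
    ref = await lookup_ref()   # suspension point: other coroutines may run here
    if self._ref is not None:  # somebody else won the race while we awaited
        return self._ref
    self._ref = ref
    return self._ref
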
xinference/deploy/cmdline.py
CHANGED
@@ -750,7 +750,7 @@ def remove_cache(
    "-f",
    default=None,
    type=str,
-    help="Specify the format of the model, e.g. pytorch, ggmlv3, etc.",
+    help="Specify the format of the model, e.g. pytorch, ggufv2, etc.",
 )
 @click.option(
    "--quantization",
@@ -1516,7 +1516,7 @@ def query_engine_by_model_name(
    "-f",
    type=str,
    required=True,
-    help="Specify the format of the model, e.g. pytorch, ggmlv3, etc.",
+    help="Specify the format of the model, e.g. pytorch, ggufv2, etc.",
 )
 @click.option(
    "--quantization",

xinference/deploy/test/test_cmdline.py
CHANGED

@@ -66,10 +66,10 @@ def test_cmdline(setup, stream, model_uid):
    replica = 1
    original_model_uid = model_uid
    model_uid = client.launch_model(
-        model_name="…",
+        model_name="qwen1.5-chat",
        model_engine="llama.cpp",
        model_uid=model_uid,
-        model_size_in_billions=…,
+        model_size_in_billions="0_5",
        quantization="q4_0",
        replica=replica,
    )
@@ -249,10 +249,10 @@ def test_rotate_logs(setup_with_file_logging):
    runner = CliRunner()
    replica = 1 if os.name == "nt" else 2
    model_uid = client.launch_model(
-        model_name="…",
+        model_name="qwen1.5-chat",
        model_engine="llama.cpp",
        model_uid=None,
-        model_size_in_billions=…,
+        model_size_in_billions="0_5",
        quantization="q4_0",
        replica=replica,
    )
@@ -288,7 +288,7 @@ def test_list_cached_models(setup):

    result = runner.invoke(
        list_cached_models,
-        ["--endpoint", endpoint, "--model_name", "…"],
+        ["--endpoint", endpoint, "--model_name", "qwen1.5-chat"],
    )
    assert "model_name" in result.stdout
    assert "model_format" in result.stdout
@@ -305,9 +305,9 @@ def test_remove_cache(setup):

    result = runner.invoke(
        remove_cache,
-        ["--endpoint", endpoint, "--model_version", "…"],
+        ["--endpoint", endpoint, "--model_version", "qwen1.5-chat"],
        input="y\n",
    )

    assert result.exit_code == 0
-    assert "Cache directory …
+    assert "Cache directory qwen1.5-chat has been deleted."

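Note the size encoding the tests switch to: model_size_in_billions takes a string with an underscore for fractional sizes, so the 0.5B Qwen1.5 model is "0_5" rather than 0.5. A hedged sketch of the same launch outside the test harness, assuming a running client:

model_uid = client.launch_model(
    model_name="qwen1.5-chat",
    model_engine="llama.cpp",
    model_size_in_billions="0_5",  # "0_5" encodes 0.5B; whole sizes stay ints, e.g. 7
    quantization="q4_0",
)
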
xinference/model/audio/chattts.py
CHANGED

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import base64
 import logging
 from io import BytesIO
 from typing import TYPE_CHECKING, Optional
@@ -61,16 +62,31 @@ class ChatTTSModel:
        import torchaudio
        import xxhash

-        …
+        rnd_spk_emb = None

-        …
+        if len(voice) > 400:
+            try:
+                assert self._model is not None
+                b = base64.b64decode(voice)
+                bio = BytesIO(b)
+                tensor = torch.load(bio, map_location="cpu")
+                rnd_spk_emb = self._model._encode_spk_emb(tensor)
+                logger.info("Speech by input speaker")
+            except Exception as e:
+                logger.info("Fallback to random speaker due to %s", e)

-        …
+        if rnd_spk_emb is None:
+            seed = xxhash.xxh32_intdigest(voice)
+
+            torch.manual_seed(seed)
+            np.random.seed(seed)
+            torch.cuda.manual_seed(seed)
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+
+            assert self._model is not None
+            rnd_spk_emb = self._model.sample_random_speaker()
+            logger.info("Speech by voice %s", voice)

        default = 5
        infer_speed = int(default * speed)
@@ -100,7 +116,6 @@ class ChatTTSModel:
            if new_last_pos != last_pos:
                out.seek(last_pos)
                encoded_bytes = out.read()
-                print(len(encoded_bytes))
                yield encoded_bytes
                last_pos = new_last_pos

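The new branch accepts a base64-encoded speaker-embedding tensor in the voice field (anything longer than 400 characters is tried as one) and falls back to a seeded random speaker otherwise. A hedged sketch of producing such a value; the function name is illustrative, but the encoding is the exact inverse of the base64.b64decode + torch.load path above:

import base64
import io

import torch

def speaker_to_voice(spk_emb: torch.Tensor) -> str:
    # Serialize the tensor and base64-encode it for the `voice` parameter.
    buf = io.BytesIO()
    torch.save(spk_emb, buf)
    return base64.b64encode(buf.getvalue()).decode("utf-8")
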
xinference/model/audio/core.py
CHANGED
@@ -21,6 +21,7 @@ from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
 from .cosyvoice import CosyVoiceModel
+from .fish_speech import FishSpeechModel
 from .funasr import FunASRModel
 from .whisper import WhisperModel

@@ -46,6 +47,7 @@ class AudioModelFamilyV1(CacheableModelSpec):
    model_id: str
    model_revision: str
    multilingual: bool
+    ability: str
    default_model_config: Optional[Dict[str, Any]]
    default_transcription_config: Optional[Dict[str, Any]]

@@ -156,13 +158,15 @@ def create_audio_model_instance(
    model_path: Optional[str] = None,
    **kwargs,
 ) -> Tuple[
-    Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel],
+    Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel],
    AudioModelDescription,
 ]:
    model_spec = match_audio(model_name, download_hub)
    if model_path is None:
        model_path = cache(model_spec)
-    model: Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel]
+    model: Union[
+        WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel
+    ]
    if model_spec.model_family == "whisper":
        model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
    elif model_spec.model_family == "funasr":
@@ -171,6 +175,8 @@ def create_audio_model_instance(
        model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
    elif model_spec.model_family == "CosyVoice":
        model = CosyVoiceModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "FishAudio":
+        model = FishSpeechModel(model_uid, model_path, model_spec, **kwargs)
    else:
        raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
    model_description = AudioModelDescription(