xinference 1.9.0__py3-none-any.whl → 1.10.0__py3-none-any.whl
This diff shows the content changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +415 -1
- xinference/constants.py +2 -0
- xinference/core/model.py +3 -4
- xinference/core/supervisor.py +29 -1
- xinference/core/worker.py +4 -1
- xinference/deploy/cmdline.py +2 -0
- xinference/deploy/test/test_cmdline.py +1 -1
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/cosyvoice.py +0 -1
- xinference/model/audio/kokoro.py +1 -1
- xinference/model/audio/kokoro_zh.py +124 -0
- xinference/model/audio/model_spec.json +64 -20
- xinference/model/embedding/flag/core.py +5 -0
- xinference/model/embedding/llama_cpp/core.py +22 -19
- xinference/model/embedding/sentence_transformers/core.py +19 -4
- xinference/model/embedding/vllm/core.py +40 -8
- xinference/model/image/cache_manager.py +56 -0
- xinference/model/image/core.py +9 -0
- xinference/model/image/model_spec.json +116 -9
- xinference/model/image/stable_diffusion/core.py +141 -31
- xinference/model/llm/core.py +10 -0
- xinference/model/llm/llama_cpp/core.py +42 -40
- xinference/model/llm/llm_family.json +435 -23
- xinference/model/llm/llm_family.py +1 -0
- xinference/model/llm/mlx/core.py +52 -33
- xinference/model/llm/sglang/core.py +2 -44
- xinference/model/llm/tool_parsers/__init__.py +58 -0
- xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +128 -0
- xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
- xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
- xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
- xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
- xinference/model/llm/transformers/core.py +6 -12
- xinference/model/llm/utils.py +128 -46
- xinference/model/llm/vllm/core.py +8 -61
- xinference/model/rerank/core.py +3 -0
- xinference/model/rerank/sentence_transformers/core.py +1 -1
- xinference/model/rerank/vllm/core.py +56 -6
- xinference/model/utils.py +1 -2
- xinference/model/video/model_spec.json +95 -1
- xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
- xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
- xinference/thirdparty/cosyvoice/bin/train.py +23 -3
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
- xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
- xinference/thirdparty/cosyvoice/cli/model.py +53 -75
- xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
- xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
- xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
- xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
- xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
- xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
- xinference/thirdparty/cosyvoice/utils/common.py +20 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
- xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
- xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
- xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
- xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
- xinference/types.py +105 -2
- xinference/ui/gradio/chat_interface.py +2 -0
- xinference/ui/gradio/media_interface.py +353 -7
- xinference/ui/web/ui/build/asset-manifest.json +3 -3
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
- xinference/ui/web/ui/src/locales/en.json +2 -0
- xinference/ui/web/ui/src/locales/ja.json +2 -0
- xinference/ui/web/ui/src/locales/ko.json +2 -0
- xinference/ui/web/ui/src/locales/zh.json +2 -0
- {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/METADATA +16 -12
- {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/RECORD +86 -77
- xinference/ui/web/ui/build/static/js/main.4918643a.js +0 -3
- xinference/ui/web/ui/build/static/js/main.4918643a.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/89179f8f51887b9167721860a12412549ff04f78162e921a7b6aa6532646deb2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/9dc5cfc67dd0617b0272aeef8651f1589b2155a4ff1fd72ad3166b217089b619.json +0 -1
- /xinference/ui/web/ui/build/static/js/{main.4918643a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
- {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/WHEEL +0 -0
- {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/top_level.txt +0 -0
xinference/model/audio/kokoro.py
CHANGED
```diff
@@ -81,7 +81,7 @@ class KokoroModel:
         logger.info("Launching Kokoro model with language code: %s", lang_code)
         self._model = KPipeline(
             lang_code=lang_code,
-            model=KModel(config=config_path, model=model_path),
+            model=KModel(config=config_path, model=model_path).to(self._device),
             device=self._device,
         )
```
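The one-line fix above is about device placement: passing `device=` to `KPipeline` does not relocate the weights of an already-constructed `KModel`, so they could remain on CPU. A minimal sketch of the pattern, assuming `kokoro` is installed and that `KModel` behaves like a `torch.nn.Module` (paths and device are placeholders):

```python
from kokoro import KModel, KPipeline

device = "cuda"  # whatever get_available_device() resolves to
# Without .to(device), weights may stay on CPU while inputs land on `device`.
model = KModel(config="config.json", model="kokoro-v1_0.pth").to(device)
pipeline = KPipeline(lang_code="a", model=model, device=device)
```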
xinference/model/audio/kokoro_zh.py
ADDED
```diff
@@ -0,0 +1,124 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional
+
+import numpy as np
+
+from ...device_utils import get_available_device, is_device_available
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV2
+
+logger = logging.getLogger(__name__)
+
+REPO_ID = "hexgrad/Kokoro-82M-v1.1-zh"
+
+
+class KokoroZHModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV2",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self.model_family = model_spec
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._kwargs = kwargs
+        self._en_pipeline = None
+
+    def _en_callable(self, text):
+        """
+        Fixing the issue of English words being skipped in the Chinese model.
+        from https://hf-mirror.com/hexgrad/Kokoro-82M-v1.1-zh/blob/main/samples/make_zh.py
+        """
+        if text == "Kokoro":
+            return "kˈOkəɹO"
+        elif text == "Sol":
+            return "sˈOl"
+        return next(self._en_pipeline(text)).phonemes
+
+    @property
+    def model_ability(self):
+        return self._model_spec.model_ability
+
+    def load(self):
+        if self._device is None:
+            self._device = get_available_device()
+        else:
+            if not is_device_available(self._device):
+                raise ValueError(f"Device {self._device} is not available!")
+
+        import os
+
+        from kokoro import KModel, KPipeline
+
+        self._en_pipeline = KPipeline(lang_code="a", repo_id=REPO_ID, model=False)
+
+        config_path = os.path.join(self._model_path, "config.json")
+        model_path = os.path.join(self._model_path, "kokoro-v1_1-zh.pth")
+        lang_code = self._kwargs.get("lang_code", "z")
+        logger.info("Launching Kokoro model with language code: %s", lang_code)
+
+        self._model = KPipeline(
+            lang_code=lang_code,
+            model=KModel(config=config_path, model=model_path).to(self._device),
+            repo_id=REPO_ID,
+            en_callable=self._en_callable,
+            device=self._device,
+        )
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        import soundfile
+
+        if stream:
+            raise Exception("Kokoro does not support stream mode.")
+        assert self._model is not None
+        if not voice:
+            voice = "zf_001"
+            logger.info("Auto select speaker: %s", voice)
+        elif voice.endswith(".pt"):
+            logger.info("Using custom voice pt: %s", voice)
+        else:
+            logger.info("Using voice: %s", voice)
+        logger.info("Speech kwargs: %s", kwargs)
+        generator = self._model(text=input, voice=voice, speed=speed, **kwargs)
+        results = list(generator)
+        audio = np.concatenate([r[2] for r in results])
+        # Save the generated audio
+        with BytesIO() as out:
+            with soundfile.SoundFile(
+                out,
+                "w",
+                24000,
+                1,
+                format=response_format.upper(),
+            ) as f:
+                f.write(audio)
+            return out.getvalue()
```
xinference/model/audio/model_spec.json
CHANGED
```diff
@@ -525,7 +525,8 @@
     "model_name": "ChatTTS",
     "model_family": "ChatTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "virtualenv": {
@@ -551,7 +552,8 @@
     "model_name": "CosyVoice-300M",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "model_src": {
@@ -570,7 +572,8 @@
     "model_name": "CosyVoice-300M-SFT",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
@@ -589,7 +592,8 @@
     "model_name": "CosyVoice-300M-Instruct",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
@@ -608,7 +612,9 @@
     "model_name": "CosyVoice2-0.5B",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "virtualenv": {
@@ -625,7 +631,8 @@
       "HyperPyYAML",
       "onnxruntime>=1.16.0",
       "pyworld>=0.3.4",
-      "
+      "wetext==0.0.9",
+      "transformers==4.51.3",
       "#system_numpy#",
       "#system_torch#"
     ]
@@ -646,7 +653,9 @@
     "model_name": "FishSpeech-1.5",
     "model_family": "FishAudio",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
    ],
     "multilingual": true,
     "model_src": {
@@ -665,7 +674,9 @@
     "model_name": "F5-TTS",
     "model_family": "F5-TTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "model_src": {
@@ -684,7 +695,9 @@
     "model_name": "F5-TTS-MLX",
     "model_family": "F5-TTS-MLX",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "model_src": {
@@ -699,7 +712,8 @@
     "model_name": "MeloTTS-English",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "EN",
@@ -715,7 +729,8 @@
     "model_name": "MeloTTS-English-v2",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "EN",
@@ -731,7 +746,8 @@
     "model_name": "MeloTTS-English-v3",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "EN",
@@ -747,7 +763,8 @@
     "model_name": "MeloTTS-French",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "FR",
@@ -763,7 +780,8 @@
     "model_name": "MeloTTS-Japanese",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "JP",
@@ -779,7 +797,8 @@
     "model_name": "MeloTTS-Spanish",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "ES",
@@ -795,7 +814,8 @@
     "model_name": "MeloTTS-Chinese",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "ZH",
@@ -811,7 +831,8 @@
     "model_name": "MeloTTS-Korean",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "KR",
@@ -827,7 +848,8 @@
     "model_name": "Kokoro-82M",
     "model_family": "Kokoro",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
@@ -840,13 +862,34 @@
         "model_revision": "master"
       }
     }
+  },
+  {
+    "version": 2,
+    "model_name": "Kokoro-82M-v1.1-zh",
+    "model_family": "Kokoro-zh",
+    "model_ability": [
+      "text2audio",
+      "text2audio_zero_shot"
+    ],
+    "multilingual": false,
+    "model_src": {
+      "huggingface": {
+        "model_id": "hexgrad/Kokoro-82M-v1.1-zh",
+        "model_revision": "01e7505bd6a7a2ac4975463114c3a7650a9f7218"
+      },
+      "modelscope": {
+        "model_id": "AI-ModelScope/Kokoro-82M-v1.1-zh",
+        "model_revision": "master"
+      }
+    }
   },
   {
     "version": 2,
     "model_name": "Kokoro-82M-MLX",
     "model_family": "Kokoro-MLX",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
@@ -874,7 +917,8 @@
     "model_name": "MegaTTS3",
     "model_family": "MegaTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
```
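The recurring change above tags each TTS family with the new `text2audio_zero_shot` / `text2audio_voice_cloning` abilities, so clients can discover which models accept reference audio. A hedged sketch of launching one of the listed models from the Python client (endpoint is a placeholder, and the client API is assumed to match earlier releases):

```python
from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # placeholder endpoint
uid = client.launch_model(model_name="CosyVoice2-0.5B", model_type="audio")
model = client.get_model(uid)
audio_bytes = model.speech("Hello from CosyVoice2.")  # encoded audio bytes
```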
xinference/model/embedding/flag/core.py
CHANGED
```diff
@@ -58,6 +58,11 @@ class FlagEmbeddingModel(EmbeddingModel):
         self._return_sparse = return_sparse
 
     def load(self):
+        # add truncate_dim args hint
+        if self._kwargs and "dimensions" in self._kwargs:
+            raise NotImplementedError(
+                "Flag embedder does not support dimensions argument now."
+            )
         try:
             from FlagEmbedding import BGEM3FlagModel
         except ImportError:
```
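Engines that cannot truncate embeddings now reject the new `dimensions` option at load time instead of silently ignoring it. A sketch of what a caller would see, assuming a model constructed on the Flag engine with `dimensions` set:

```python
try:
    model.load()  # model constructed with kwargs={"dimensions": 256}
except NotImplementedError as err:
    print(err)  # "Flag embedder does not support dimensions argument now."
```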
xinference/model/embedding/llama_cpp/core.py
CHANGED
```diff
@@ -22,7 +22,7 @@ import queue
 import sys
 from typing import List, Optional, Union
 
-import
+from packaging import version
 
 from ....types import Embedding
 from ..core import EmbeddingModel, EmbeddingModelFamilyV2, EmbeddingSpecV1
@@ -69,15 +69,29 @@ class XllamaCppEmbeddingModel(EmbeddingModel):
         return sys.platform.startswith("linux")
 
     def load(self):
+        # add truncate_dim args hint
+        if "dimensions" in self._kwargs:
+            raise NotImplementedError(
+                "LlamaCpp embedder does not support dimensions argument now."
+            )
         try:
             from xllamacpp import (
                 CommonParams,
                 Server,
+                __version__,
                 estimate_gpu_layers,
                 get_device_info,
                 ggml_backend_dev_type,
                 llama_pooling_type,
             )
+
+            try:
+                if version.parse(__version__) < version.parse("0.2.0"):
+                    raise RuntimeError(
+                        "Please update xllamacpp to >= 0.2.0 by `pip install -U xllamacpp`"
+                    )
+            except version.InvalidVersion:
+                pass  # If the version parse failed, we just skip the version check.
         except ImportError:
             error_message = "Failed to import module 'xllamacpp'"
             installation_guide = ["Please make sure 'xllamacpp' is installed. "]
@@ -162,7 +176,8 @@ class XllamaCppEmbeddingModel(EmbeddingModel):
             )
             logger.info("Estimate num gpu layers: %s", estimate)
             if estimate.tensor_split:
-
+                for i in range(len(estimate.tensor_split)):
+                    params.tensor_split[i] = estimate.tensor_split[i]
             else:
                 params.n_gpu_layers = estimate.layers
         except Exception as e:
@@ -190,24 +205,12 @@ class XllamaCppEmbeddingModel(EmbeddingModel):
         model_uid: Optional[str] = kwargs.pop("model_uid", None)
         if model_uid:
             data["model"] = model_uid
-        prompt_json = orjson.dumps(data)
-
-        def _error_callback(err):
-            try:
-                msg = orjson.loads(err)
-                q.put(_Error(msg))
-            except Exception as e:
-                q.put(_Error(str(e)))
-
-        def _ok_callback(ok):
-            try:
-                res = orjson.loads(ok)
-                q.put(res)
-            except Exception as e:
-                q.put(_Error(str(e)))
-
         try:
-            self._llm.handle_embeddings(
+            res = self._llm.handle_embeddings(data)
+            if res.get("code"):
+                q.put(_Error(res))
+            else:
+                q.put(res)
         except Exception as ex:
             q.put(_Error(str(ex)))
         q.put(_Done)
```
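The import block now pins a minimum `xllamacpp` version using `packaging` rather than string comparison, and deliberately ignores unparseable versions. A standalone sketch of the same gating pattern:

```python
from packaging import version

def require_min_version(current: str, minimum: str, pkg: str) -> None:
    try:
        if version.parse(current) < version.parse(minimum):
            raise RuntimeError(f"Please update {pkg} to >= {minimum}")
    except version.InvalidVersion:
        pass  # unparseable version string (e.g. a dev build): skip the check

require_min_version("0.1.9", "0.2.0", "xllamacpp")  # raises RuntimeError
```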
xinference/model/embedding/sentence_transformers/core.py
CHANGED
```diff
@@ -71,6 +71,12 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
             )
             torch_dtype = torch.float32
 
+        dimensions = self._kwargs.get("dimensions")
+        assert dimensions is None or isinstance(dimensions, int), (
+            "The `dimensions` argument must be an integer, "
+            f"but got {type(dimensions)}: {dimensions}"
+        )
+
         if (
             "gte" in self.model_family.model_name.lower()
             and "qwen2" in self.model_family.model_name.lower()
@@ -82,6 +88,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 self._model_path,
                 device=self._device,
                 model_kwargs=model_kwargs,
+                truncate_dim=dimensions,
             )
         elif "qwen3" in self.model_family.model_name.lower():
             # qwen3 embedding
@@ -106,6 +113,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 device=self._device,
                 model_kwargs=model_kwargs,
                 tokenizer_kwargs=tokenizer_kwargs,
+                truncate_dim=dimensions,
             )
         else:
             model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
@@ -114,6 +122,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 device=self._device,
                 model_kwargs=model_kwargs,
                 trust_remote_code=True,
+                truncate_dim=dimensions,
             )
 
         if hasattr(self._model, "tokenizer"):
@@ -256,10 +265,10 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                     "clip" in self.model_family.model_name.lower()
                     or "jina-embeddings-v4" in self.model_family.model_name.lower()
                 ):
-
-
-
-
+                    # support input_ids and text_input_ids
+                    for key in ["input_ids", "text_input_ids"]:
+                        if key in features and hasattr(features[key], "numel"):
+                            all_token_nums += features[key].numel()
                 if "pixel_values" in features and hasattr(
                     features["pixel_values"], "numel"
                 ):
@@ -270,6 +279,12 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 with torch.no_grad():
                     out_features = model.forward(features, **kwargs)
 
+                from sentence_transformers.util import truncate_embeddings
+
+                out_features["sentence_embedding"] = truncate_embeddings(
+                    out_features["sentence_embedding"], model.truncate_dim
+                )
+
                 if output_value == "token_embeddings":
                     embeddings = []
                     for token_emb, attention in zip(
```
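`truncate_dim` is the sentence-transformers Matryoshka mechanism: embeddings are cut to their first N dimensions, and `truncate_embeddings` applies the same cut inside the custom forward path above. A standalone sketch, assuming sentence-transformers >= 2.7 and an illustrative model name:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("BAAI/bge-m3", truncate_dim=256)  # model name illustrative
emb = model.encode(["hello world"])
print(emb.shape)  # (1, 256) instead of the model's native width
```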
xinference/model/embedding/vllm/core.py
CHANGED
```diff
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import importlib.util
+import json
 import logging
 from typing import List, Union
 
@@ -25,7 +26,6 @@ SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "gte", "Qwen3"]
 
 
 class VLLMEmbeddingModel(EmbeddingModel):
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._context_length = None
@@ -42,13 +42,31 @@ class VLLMEmbeddingModel(EmbeddingModel):
             ]
 
             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+        if self.model_family.model_name in {
+            "Qwen3-Embedding-0.6B",
+            "Qwen3-Embedding-4B",
+            "Qwen3-Embedding-8B",
+        }:
+            if "hf_overrides" not in self._kwargs:
+                self._kwargs["hf_overrides"] = {
+                    "is_matryoshka": True,
+                }
+            elif isinstance(self._kwargs["hf_overrides"], dict):
+                self._kwargs["hf_overrides"].update(
+                    is_matryoshka=True,
+                )
+            elif isinstance(self._kwargs["hf_overrides"], str):
+                self._kwargs["hf_overrides"] = json.loads(self._kwargs["hf_overrides"])
+                self._kwargs["hf_overrides"].update(
+                    is_matryoshka=True,
+                )
 
         self._model = LLM(model=self._model_path, task="embed", **self._kwargs)
         self._tokenizer = self._model.get_tokenizer()
 
     @staticmethod
     def _get_detailed_instruct(task_description: str, query: str) -> str:
-        return f"Instruct: {task_description}\nQuery:{query}"
+        return f"Instruct: {task_description}\nQuery:{query}"  # noqa: E231
 
     @cache_clean
     def create_embedding(
@@ -56,14 +74,15 @@ class VLLMEmbeddingModel(EmbeddingModel):
         sentences: Union[str, List[str]],
         **kwargs,
     ):
+        from packaging.version import Version
+        from vllm import PoolingParams
+        from vllm import __version__ as vllm_version
+
         sentences = self._fix_langchain_openai_inputs(sentences)
         model_uid = kwargs.pop("model_uid", None)
 
         normalize_embedding = kwargs.get("normalize_embedding", True)
-
-            raise ValueError(
-                "vllm embedding engine does not support setting `normalize_embedding=False`"
-            )
+        dimensions = kwargs.get("dimensions", None)
 
         assert self._model is not None
 
@@ -92,8 +111,21 @@ class VLLMEmbeddingModel(EmbeddingModel):
             sentences = truncated_sentences[0]
         else:
             sentences = truncated_sentences
-
-
+        if Version(vllm_version) > Version("0.10.1"):
+            pool_params = PoolingParams(
+                dimensions=dimensions, normalize=normalize_embedding
+            )
+        else:
+            if not normalize_embedding:
+                raise ValueError(
+                    f"vLLM version {vllm_version} does not support "
+                    f"unnormalized embeddings. "
+                    f"Please upgrade to v0.10.1 or later."
+                )
+            pool_params = PoolingParams(dimensions=dimensions)
+        outputs = self._model.embed(
+            sentences, use_tqdm=False, pooling_params=pool_params
+        )
         embedding_list = []
         all_token_nums = 0
         for index, output in enumerate(outputs):
```
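Together with the RESTful changes in this release, `dimensions` can flow in from the OpenAI-compatible embeddings endpoint down to `PoolingParams`. A hedged request sketch (endpoint and model uid are placeholders, and it assumes a vLLM >= 0.10.1 backend per the version check above):

```python
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="not-needed")
resp = client.embeddings.create(
    model="Qwen3-Embedding-0.6B",   # placeholder model uid
    input="matryoshka embeddings",
    dimensions=512,                 # truncated width requested from the server
)
print(len(resp.data[0].embedding))  # 512 if the engine honors truncation
```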
xinference/model/image/cache_manager.py
CHANGED
```diff
@@ -60,3 +60,59 @@ class ImageCacheManager(CacheManager):
             raise NotImplementedError
 
         return full_path
+
+    def cache_lightning(self, lightning_version: Optional[str] = None):
+        from ..utils import IS_NEW_HUGGINGFACE_HUB, retry_download, symlink_local_file
+        from .core import ImageModelFamilyV2
+
+        if not lightning_version:
+            return None
+
+        assert isinstance(self._model_family, ImageModelFamilyV2)
+        cache_dir = self.get_cache_dir()
+
+        if not self._model_family.lightning_model_file_name_template:
+            raise NotImplementedError(
+                f"{self._model_family.model_name} does not support lightning"
+            )
+        if lightning_version not in (self._model_family.lightning_versions or []):
+            raise ValueError(
+                f"Cannot support lightning version {lightning_version}, "
+                f"available lightning version: {self._model_family.lightning_versions}"
+            )
+
+        filename = self._model_family.lightning_model_file_name_template.format(lightning_version=lightning_version)  # type: ignore
+        full_path = os.path.join(cache_dir, filename)
+
+        if self._model_family.model_hub == "huggingface":
+            import huggingface_hub
+
+            use_symlinks = {}
+            if not IS_NEW_HUGGINGFACE_HUB:
+                use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
+            download_file_path = retry_download(
+                huggingface_hub.hf_hub_download,
+                self._model_family.model_name,
+                None,
+                self._model_family.lightning_model_id,
+                filename=filename,
+                **use_symlinks,
+            )
+            if IS_NEW_HUGGINGFACE_HUB:
+                symlink_local_file(download_file_path, cache_dir, filename)
+        elif self._model_family.model_hub == "modelscope":
+            from modelscope.hub.file_download import model_file_download
+
+            download_file_path = retry_download(
+                model_file_download,
+                self._model_family.model_name,
+                None,
+                self._model_family.lightning_model_id,
+                filename,
+                revision=self._model_family.model_revision,
+            )
+            symlink_local_file(download_file_path, cache_dir, filename)
+        else:
+            raise NotImplementedError
+
+        return full_path
```