xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of xinference might be problematic.
- xinference/_compat.py +51 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +209 -40
- xinference/client/restful/restful_client.py +7 -26
- xinference/conftest.py +1 -1
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +1 -1
- xinference/core/chat_interface.py +8 -14
- xinference/core/event.py +1 -1
- xinference/core/image_interface.py +28 -0
- xinference/core/model.py +110 -31
- xinference/core/scheduler.py +37 -37
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +17 -10
- xinference/core/utils.py +80 -22
- xinference/core/worker.py +17 -16
- xinference/deploy/cmdline.py +8 -16
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/deploy/utils.py +1 -1
- xinference/deploy/worker.py +1 -1
- xinference/model/audio/cosyvoice.py +86 -41
- xinference/model/audio/fish_speech.py +9 -9
- xinference/model/audio/model_spec.json +9 -9
- xinference/model/audio/whisper.py +4 -1
- xinference/model/embedding/core.py +52 -31
- xinference/model/image/core.py +2 -1
- xinference/model/image/model_spec.json +16 -4
- xinference/model/image/model_spec_modelscope.json +16 -4
- xinference/model/image/sdapi.py +136 -0
- xinference/model/image/stable_diffusion/core.py +164 -19
- xinference/model/llm/__init__.py +29 -11
- xinference/model/llm/llama_cpp/core.py +16 -33
- xinference/model/llm/llm_family.json +1011 -1296
- xinference/model/llm/llm_family.py +34 -53
- xinference/model/llm/llm_family_csghub.json +18 -35
- xinference/model/llm/llm_family_modelscope.json +981 -1122
- xinference/model/llm/lmdeploy/core.py +56 -88
- xinference/model/llm/mlx/core.py +46 -69
- xinference/model/llm/sglang/core.py +36 -18
- xinference/model/llm/transformers/chatglm.py +168 -306
- xinference/model/llm/transformers/cogvlm2.py +36 -63
- xinference/model/llm/transformers/cogvlm2_video.py +33 -223
- xinference/model/llm/transformers/core.py +55 -50
- xinference/model/llm/transformers/deepseek_v2.py +340 -0
- xinference/model/llm/transformers/deepseek_vl.py +53 -96
- xinference/model/llm/transformers/glm4v.py +55 -111
- xinference/model/llm/transformers/intern_vl.py +39 -70
- xinference/model/llm/transformers/internlm2.py +32 -54
- xinference/model/llm/transformers/minicpmv25.py +22 -55
- xinference/model/llm/transformers/minicpmv26.py +158 -68
- xinference/model/llm/transformers/omnilmm.py +5 -28
- xinference/model/llm/transformers/qwen2_audio.py +168 -0
- xinference/model/llm/transformers/qwen2_vl.py +234 -0
- xinference/model/llm/transformers/qwen_vl.py +34 -86
- xinference/model/llm/transformers/utils.py +32 -38
- xinference/model/llm/transformers/yi_vl.py +32 -72
- xinference/model/llm/utils.py +280 -554
- xinference/model/llm/vllm/core.py +161 -100
- xinference/model/rerank/core.py +41 -8
- xinference/model/rerank/model_spec.json +7 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -1
- xinference/model/utils.py +1 -31
- xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
- xinference/thirdparty/cosyvoice/cli/model.py +139 -26
- xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
- xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
- xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
- xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
- xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
- xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
- xinference/thirdparty/cosyvoice/utils/common.py +36 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
- xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
- xinference/thirdparty/fish_speech/tools/api.py +79 -134
- xinference/thirdparty/fish_speech/tools/commons.py +35 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
- xinference/thirdparty/fish_speech/tools/file.py +17 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
- xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
- xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
- xinference/thirdparty/fish_speech/tools/webui.py +12 -146
- xinference/thirdparty/matcha/VERSION +1 -0
- xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
- xinference/thirdparty/matcha/hifigan/README.md +101 -0
- xinference/thirdparty/omnilmm/LICENSE +201 -0
- xinference/thirdparty/whisper/__init__.py +156 -0
- xinference/thirdparty/whisper/__main__.py +3 -0
- xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
- xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
- xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
- xinference/thirdparty/whisper/audio.py +157 -0
- xinference/thirdparty/whisper/decoding.py +826 -0
- xinference/thirdparty/whisper/model.py +314 -0
- xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
- xinference/thirdparty/whisper/normalizers/basic.py +76 -0
- xinference/thirdparty/whisper/normalizers/english.json +1741 -0
- xinference/thirdparty/whisper/normalizers/english.py +550 -0
- xinference/thirdparty/whisper/timing.py +386 -0
- xinference/thirdparty/whisper/tokenizer.py +395 -0
- xinference/thirdparty/whisper/transcribe.py +605 -0
- xinference/thirdparty/whisper/triton_ops.py +109 -0
- xinference/thirdparty/whisper/utils.py +316 -0
- xinference/thirdparty/whisper/version.py +1 -0
- xinference/types.py +14 -53
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
- xinference/web/ui/build/static/js/main.754740c0.js +3 -0
- xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
- xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +37 -0
- xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
- xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
- xinference/web/ui/node_modules/nunjucks/package.json +112 -0
- xinference/web/ui/package-lock.json +38 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
- xinference/model/llm/transformers/llama_2.py +0 -108
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
- xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
- xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
- xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
- xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/core/image_interface.py
CHANGED

```diff
@@ -73,13 +73,17 @@ class ImageInterface:
         return interface
 
     def text2image_interface(self) -> "gr.Blocks":
+        from ..model.image.stable_diffusion.core import SAMPLING_METHODS
+
         def text_generate_image(
             prompt: str,
             n: int,
             size_width: int,
             size_height: int,
+            guidance_scale: int,
             num_inference_steps: int,
             negative_prompt: Optional[str] = None,
+            sampler_name: Optional[str] = None,
         ) -> PIL.Image.Image:
             from ..client import RESTfulClient
 
@@ -89,16 +93,20 @@ class ImageInterface:
             assert isinstance(model, RESTfulImageModelHandle)
 
             size = f"{int(size_width)}*{int(size_height)}"
+            guidance_scale = None if guidance_scale == -1 else guidance_scale  # type: ignore
             num_inference_steps = (
                 None if num_inference_steps == -1 else num_inference_steps  # type: ignore
             )
+            sampler_name = None if sampler_name == "default" else sampler_name
 
             response = model.text_to_image(
                 prompt=prompt,
                 n=n,
                 size=size,
                 num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
                 negative_prompt=negative_prompt,
+                sampler_name=sampler_name,
                 response_format="b64_json",
             )
 
@@ -132,9 +140,16 @@ class ImageInterface:
                     n = gr.Number(label="Number of Images", value=1)
                     size_width = gr.Number(label="Width", value=1024)
                     size_height = gr.Number(label="Height", value=1024)
+                with gr.Row():
+                    guidance_scale = gr.Number(label="Guidance scale", value=-1)
                     num_inference_steps = gr.Number(
                         label="Inference Step Number", value=-1
                     )
+                    sampler_name = gr.Dropdown(
+                        choices=SAMPLING_METHODS,
+                        value="default",
+                        label="Sampling method",
+                    )
 
             with gr.Column():
                 image_output = gr.Gallery()
@@ -146,8 +161,10 @@ class ImageInterface:
                 n,
                 size_width,
                 size_height,
+                guidance_scale,
                 num_inference_steps,
                 negative_prompt,
+                sampler_name,
             ],
             outputs=image_output,
         )
@@ -155,6 +172,8 @@ class ImageInterface:
         return text2image_vl_interface
 
     def image2image_interface(self) -> "gr.Blocks":
+        from ..model.image.stable_diffusion.core import SAMPLING_METHODS
+
         def image_generate_image(
             prompt: str,
             negative_prompt: str,
@@ -164,6 +183,7 @@ class ImageInterface:
             size_height: int,
             num_inference_steps: int,
             padding_image_to_multiple: int,
+            sampler_name: Optional[str] = None,
         ) -> PIL.Image.Image:
             from ..client import RESTfulClient
 
@@ -180,6 +200,7 @@ class ImageInterface:
                 None if num_inference_steps == -1 else num_inference_steps  # type: ignore
             )
             padding_image_to_multiple = None if padding_image_to_multiple == -1 else padding_image_to_multiple  # type: ignore
+            sampler_name = None if sampler_name == "default" else sampler_name
 
             bio = io.BytesIO()
             image.save(bio, format="png")
@@ -193,6 +214,7 @@ class ImageInterface:
                 response_format="b64_json",
                 num_inference_steps=num_inference_steps,
                 padding_image_to_multiple=padding_image_to_multiple,
+                sampler_name=sampler_name,
             )
 
             images = []
@@ -233,6 +255,11 @@ class ImageInterface:
                     padding_image_to_multiple = gr.Number(
                         label="Padding image to multiple", value=-1
                     )
+                    sampler_name = gr.Dropdown(
+                        choices=SAMPLING_METHODS,
+                        value="default",
+                        label="Sampling method",
+                    )
 
             with gr.Row():
                 with gr.Column(scale=1):
@@ -251,6 +278,7 @@ class ImageInterface:
                     size_height,
                     num_inference_steps,
                     padding_image_to_multiple,
+                    sampler_name,
                 ],
                 outputs=output_gallery,
             )
```
xinference/core/model.py
CHANGED
```diff
@@ -19,6 +19,7 @@ import json
 import os
 import time
 import types
+import uuid
 import weakref
 from asyncio.queues import Queue
 from asyncio.tasks import wait_for
@@ -65,7 +66,12 @@ except ImportError:
     OutOfMemoryError = _OutOfMemoryError
 
 
-XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = ["qwen-vl-chat", "cogvlm2", "glm-4v"]
+XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = [
+    "qwen-vl-chat",
+    "cogvlm2",
+    "glm-4v",
+    "MiniCPM-V-2.6",
+]
 
 
 def request_limit(fn):
@@ -265,7 +271,7 @@ class ModelActor(xo.StatelessActor):
 
         if self._worker_ref is None:
             self._worker_ref = await xo.actor_ref(
-                address=self._worker_address, uid=WorkerActor.uid()
+                address=self._worker_address, uid=WorkerActor.default_uid()
             )
         return self._worker_ref
 
@@ -434,23 +440,35 @@ class ModelActor(xo.StatelessActor):
         assert output_type == "binary", f"Unknown output type '{output_type}'"
         return ret
 
-    @log_async(logger=logger)
     @request_limit
     @xo.generator
+    @log_async(logger=logger)
     async def generate(self, prompt: str, *args, **kwargs):
         if self.allow_batching():
+            # not support request_id
+            kwargs.pop("request_id", None)
             return await self.handle_batching_request(
                 prompt, "generate", *args, **kwargs
             )
         else:
            kwargs.pop("raw_params", None)
             if hasattr(self._model, "generate"):
+                # not support request_id
+                kwargs.pop("request_id", None)
                 return await self._call_wrapper_json(
                     self._model.generate, prompt, *args, **kwargs
                 )
             if hasattr(self._model, "async_generate"):
+                if "request_id" not in kwargs:
+                    kwargs["request_id"] = str(uuid.uuid1())
+                else:
+                    # model only accept string
+                    kwargs["request_id"] = str(kwargs["request_id"])
                 return await self._call_wrapper_json(
-                    self._model.async_generate, prompt, *args, **kwargs
+                    self._model.async_generate,
+                    prompt,
+                    *args,
+                    **kwargs,
                 )
             raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
 
@@ -481,22 +499,27 @@ class ModelActor(xo.StatelessActor):
             yield res
 
     @staticmethod
-    def _get_stream_from_args(ability: str, *args) -> bool:
-        if ability == "chat":
-            assert args[2] is None or isinstance(args[2], dict)
-            return False if args[2] is None else args[2].get("stream", False)
-        else:
-            assert args[0] is None or isinstance(args[0], dict)
-            return False if args[0] is None else args[0].get("stream", False)
+    def _get_stream_from_args(*args) -> bool:
+        assert args[0] is None or isinstance(args[0], dict)
+        return False if args[0] is None else args[0].get("stream", False)
 
-    async def handle_batching_request(self, prompt: str, ability: str, *args, **kwargs):
-        stream = self._get_stream_from_args(ability, *args)
+    async def handle_batching_request(
+        self, prompt_or_messages: Union[str, List[Dict]], call_ability, *args, **kwargs
+    ):
+        """
+        The input parameter `prompt_or_messages`:
+        - when the model_ability is `generate`, it's `prompt`, which is str type.
+        - when the model_ability is `chat`, it's `messages`, which is List[Dict] type.
+        """
+        stream = self._get_stream_from_args(*args)
         assert self._scheduler_ref is not None
         if stream:
             assert self._scheduler_ref is not None
             queue: Queue[Any] = Queue()
             ret = self._queue_consumer(queue)
-            await self._scheduler_ref.add_request(prompt, queue, *args, **kwargs)
+            await self._scheduler_ref.add_request(
+                prompt_or_messages, queue, call_ability, *args, **kwargs
+            )
             gen = self._to_async_gen("json", ret)
             self._current_generator = weakref.ref(gen)
             return gen
@@ -505,7 +528,9 @@ class ModelActor(xo.StatelessActor):
 
         assert self._loop is not None
         future = ConcurrentFuture()
-        await self._scheduler_ref.add_request(prompt, future, *args, **kwargs)
+        await self._scheduler_ref.add_request(
+            prompt_or_messages, future, call_ability, *args, **kwargs
+        )
         fut = asyncio.wrap_future(future, loop=self._loop)
         result = await fut
         if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
@@ -514,27 +539,36 @@ class ModelActor(xo.StatelessActor):
             )
         return await asyncio.to_thread(json_dumps, result)
 
-    @log_async(logger=logger)
     @request_limit
     @xo.generator
-    async def chat(self, prompt: str, *args, **kwargs):
+    @log_async(logger=logger)
+    async def chat(self, messages: List[Dict], *args, **kwargs):
         start_time = time.time()
         response = None
         try:
             if self.allow_batching():
+                # not support request_id
+                kwargs.pop("request_id", None)
                 return await self.handle_batching_request(
-                    prompt, "chat", *args, **kwargs
+                    messages, "chat", *args, **kwargs
                 )
             else:
                 kwargs.pop("raw_params", None)
                 if hasattr(self._model, "chat"):
+                    # not support request_id
+                    kwargs.pop("request_id", None)
                     response = await self._call_wrapper_json(
-                        self._model.chat, prompt, *args, **kwargs
+                        self._model.chat, messages, *args, **kwargs
                     )
                     return response
                 if hasattr(self._model, "async_chat"):
+                    if "request_id" not in kwargs:
+                        kwargs["request_id"] = str(uuid.uuid1())
+                    else:
+                        # model only accept string
+                        kwargs["request_id"] = str(kwargs["request_id"])
                     response = await self._call_wrapper_json(
-                        self._model.async_chat, prompt, *args, **kwargs
+                        self._model.async_chat, messages, *args, **kwargs
                     )
                     return response
             raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
@@ -565,9 +599,10 @@ class ModelActor(xo.StatelessActor):
             return await self._scheduler_ref.abort_request(request_id)
         return AbortRequestMessage.NO_OP.name
 
-    @log_async(logger=logger)
     @request_limit
+    @log_async(logger=logger)
     async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "create_embedding"):
             return await self._call_wrapper_json(
                 self._model.create_embedding, input, *args, **kwargs
@@ -577,8 +612,8 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating embedding."
         )
 
-    @log_async(logger=logger)
     @request_limit
+    @log_async(logger=logger)
     async def rerank(
         self,
         documents: List[str],
@@ -590,6 +625,7 @@ class ModelActor(xo.StatelessActor):
         *args,
         **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "rerank"):
             return await self._call_wrapper_json(
                 self._model.rerank,
@@ -604,8 +640,8 @@ class ModelActor(xo.StatelessActor):
             )
         raise AttributeError(f"Model {self._model.model_spec} is not for reranking.")
 
-    @log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio"))
     @request_limit
+    @log_async(logger=logger, ignore_kwargs=["audio"])
     async def transcriptions(
         self,
         audio: bytes,
@@ -614,7 +650,9 @@ class ModelActor(xo.StatelessActor):
         response_format: str = "json",
         temperature: float = 0,
         timestamp_granularities: Optional[List[str]] = None,
+        **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "transcriptions"):
             return await self._call_wrapper_json(
                 self._model.transcriptions,
@@ -629,8 +667,8 @@ class ModelActor(xo.StatelessActor):
                 f"Model {self._model.model_spec} is not for creating transcriptions."
             )
 
-    @log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio"))
     @request_limit
+    @log_async(logger=logger, ignore_kwargs=["audio"])
     async def translations(
         self,
         audio: bytes,
@@ -639,7 +677,9 @@ class ModelActor(xo.StatelessActor):
         response_format: str = "json",
         temperature: float = 0,
         timestamp_granularities: Optional[List[str]] = None,
+        **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "translations"):
             return await self._call_wrapper_json(
                 self._model.translations,
@@ -654,12 +694,9 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating translations."
         )
 
-    @log_async(
-        logger=logger,
-        args_formatter=lambda _, kwargs: kwargs.pop("prompt_speech", None),
-    )
     @request_limit
     @xo.generator
+    @log_async(logger=logger, ignore_kwargs=["prompt_speech"])
     async def speech(
         self,
         input: str,
@@ -669,6 +706,7 @@ class ModelActor(xo.StatelessActor):
         stream: bool = False,
         **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "speech"):
             return await self._call_wrapper_binary(
                 self._model.speech,
@@ -683,8 +721,8 @@ class ModelActor(xo.StatelessActor):
                 f"Model {self._model.model_spec} is not for creating speech."
             )
 
-    @log_async(logger=logger)
     @request_limit
+    @log_async(logger=logger)
     async def text_to_image(
         self,
         prompt: str,
@@ -694,6 +732,7 @@ class ModelActor(xo.StatelessActor):
         *args,
         **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "text_to_image"):
             return await self._call_wrapper_json(
                 self._model.text_to_image,
@@ -708,6 +747,24 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating image."
         )
 
+    @request_limit
+    @log_async(logger=logger)
+    async def txt2img(
+        self,
+        **kwargs,
+    ):
+        kwargs.pop("request_id", None)
+        if hasattr(self._model, "txt2img"):
+            return await self._call_wrapper_json(
+                self._model.txt2img,
+                **kwargs,
+            )
+        raise AttributeError(f"Model {self._model.model_spec} is not for txt2img.")
+
+    @log_async(
+        logger=logger,
+        ignore_kwargs=["image"],
+    )
     async def image_to_image(
         self,
         image: "PIL.Image",
@@ -719,6 +776,7 @@ class ModelActor(xo.StatelessActor):
         *args,
         **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "image_to_image"):
             return await self._call_wrapper_json(
                 self._model.image_to_image,
@@ -735,6 +793,24 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating image."
         )
 
+    @request_limit
+    @log_async(logger=logger)
+    async def img2img(
+        self,
+        **kwargs,
+    ):
+        kwargs.pop("request_id", None)
+        if hasattr(self._model, "img2img"):
+            return await self._call_wrapper_json(
+                self._model.img2img,
+                **kwargs,
+            )
+        raise AttributeError(f"Model {self._model.model_spec} is not for img2img.")
+
+    @log_async(
+        logger=logger,
+        ignore_kwargs=["image"],
+    )
     async def inpainting(
         self,
         image: "PIL.Image",
@@ -747,6 +823,7 @@ class ModelActor(xo.StatelessActor):
         *args,
         **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "inpainting"):
             return await self._call_wrapper_json(
                 self._model.inpainting,
@@ -764,12 +841,13 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating image."
         )
 
-    @log_async(logger=logger)
     @request_limit
+    @log_async(logger=logger, ignore_kwargs=["image"])
     async def infer(
         self,
         **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "infer"):
             return await self._call_wrapper_json(
                 self._model.infer,
@@ -779,8 +857,8 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for flexible infer."
        )
 
-    @log_async(logger=logger)
     @request_limit
+    @log_async(logger=logger)
     async def text_to_video(
         self,
         prompt: str,
@@ -788,6 +866,7 @@ class ModelActor(xo.StatelessActor):
         *args,
         **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "text_to_video"):
             return await self._call_wrapper_json(
                 self._model.text_to_video,
```
xinference/core/scheduler.py
CHANGED
```diff
@@ -18,7 +18,7 @@ import logging
 import uuid
 from collections import deque
 from enum import Enum
-from typing import List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import xoscar as xo
 
@@ -37,13 +37,24 @@ class AbortRequestMessage(Enum):
 
 
 class InferenceRequest:
-    def __init__(self, prompt, future_or_queue, is_prefill, *args, **kwargs):
-        # original prompt
-        self._prompt = prompt
+    def __init__(
+        self,
+        prompt_or_messages,
+        future_or_queue,
+        is_prefill,
+        call_ability,
+        *args,
+        **kwargs,
+    ):
+        # original prompt, prompt(str) for generate model and messages(List[Dict]) for chat model
+        self._prompt = prompt_or_messages
         # full prompt that contains chat history and applies chat template
         self._full_prompt = None
         # whether the current request is in the prefill phase
         self._is_prefill = is_prefill
+        # the ability that the user calls this model for, that is `generate` / `chat` for now,
+        # which is for results formatting
+        self._call_ability = call_ability
         # full prompt tokens
         self._prompt_tokens = None
         # all new generated tokens during decode phase
@@ -88,38 +99,22 @@ class InferenceRequest:
         self._check_args()
 
     def _check_args(self):
-        # chat
-        if len(self._inference_args) == 3:
-            # system prompt
-            assert self._inference_args[0] is None or isinstance(
-                self._inference_args[0], str
-            )
-            # chat history
-            assert self._inference_args[1] is None or isinstance(
-                self._inference_args[1], list
-            )
-            # generate config
-            assert self._inference_args[2] is None or isinstance(
-                self._inference_args[2], dict
-            )
-        else:  # generate
-            assert len(self._inference_args) == 1
-            # generate config
-            assert self._inference_args[0] is None or isinstance(
-                self._inference_args[0], dict
-            )
+        assert len(self._inference_args) == 1
+        # generate config
+        assert self._inference_args[0] is None or isinstance(
+            self._inference_args[0], dict
+        )
 
     @property
     def prompt(self):
+        """
+        prompt for generate model and messages for chat model
+        """
         return self._prompt
 
     @property
-    def system_prompt(self):
-        return self._inference_args[0]
-
-    @property
-    def chat_history(self):
-        return self._inference_args[1]
+    def call_ability(self):
+        return self._call_ability
 
     @property
     def full_prompt(self):
@@ -162,11 +157,7 @@ class InferenceRequest:
 
     @property
     def generate_config(self):
-        return (
-            self._inference_args[2]
-            if len(self._inference_args) == 3
-            else self._inference_args[0]
-        )
+        return self._inference_args[0]
 
     @property
     def sanitized_generate_config(self):
@@ -423,8 +414,17 @@ class SchedulerActor(xo.StatelessActor):
 
         self._empty_cache()
 
-    async def add_request(self, prompt: str, future_or_queue, *args, **kwargs):
-        req = InferenceRequest(prompt, future_or_queue, True, *args, **kwargs)
+    async def add_request(
+        self,
+        prompt_or_messages: Union[str, List[Dict]],
+        future_or_queue,
+        call_ability,
+        *args,
+        **kwargs,
+    ):
+        req = InferenceRequest(
+            prompt_or_messages, future_or_queue, True, call_ability, *args, **kwargs
+        )
         rid = req.request_id
         if rid is not None:
             if rid in self._id_to_req:
```
xinference/core/status_guard.py
CHANGED
xinference/core/supervisor.py
CHANGED
```diff
@@ -105,7 +105,7 @@ class SupervisorActor(xo.StatelessActor):
         self._lock = asyncio.Lock()
 
     @classmethod
-    def uid(cls) -> str:
+    def default_uid(cls) -> str:
         return "supervisor"
 
     def _get_worker_ref_by_ip(
@@ -135,12 +135,12 @@ class SupervisorActor(xo.StatelessActor):
         self._status_guard_ref: xo.ActorRefType[  # type: ignore
             "StatusGuardActor"
         ] = await xo.create_actor(
-            StatusGuardActor, address=self.address, uid=StatusGuardActor.uid()
+            StatusGuardActor, address=self.address, uid=StatusGuardActor.default_uid()
         )
         self._cache_tracker_ref: xo.ActorRefType[  # type: ignore
             "CacheTrackerActor"
         ] = await xo.create_actor(
-            CacheTrackerActor, address=self.address, uid=CacheTrackerActor.uid()
+            CacheTrackerActor, address=self.address, uid=CacheTrackerActor.default_uid()
         )
 
         from .event import EventCollectorActor
@@ -148,7 +148,9 @@ class SupervisorActor(xo.StatelessActor):
         self._event_collector_ref: xo.ActorRefType[  # type: ignore
             EventCollectorActor
         ] = await xo.create_actor(
-            EventCollectorActor, address=self.address, uid=EventCollectorActor.uid()
+            EventCollectorActor,
+            address=self.address,
+            uid=EventCollectorActor.default_uid(),
         )
 
         from ..model.audio import (
@@ -308,14 +310,12 @@ class SupervisorActor(xo.StatelessActor):
     async def get_builtin_prompts() -> Dict[str, Any]:
         from ..model.llm.llm_family import BUILTIN_LLM_PROMPT_STYLE
 
-        data = {}
-        for k, v in BUILTIN_LLM_PROMPT_STYLE.items():
-            data[k] = v.dict()
-        return data
+        return {k: v for k, v in BUILTIN_LLM_PROMPT_STYLE.items()}
 
     @staticmethod
     async def get_builtin_families() -> Dict[str, List[str]]:
         from ..model.llm.llm_family import (
+            BUILTIN_LLM_FAMILIES,
             BUILTIN_LLM_MODEL_CHAT_FAMILIES,
             BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
             BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES,
@@ -325,6 +325,11 @@ class SupervisorActor(xo.StatelessActor):
             "chat": list(BUILTIN_LLM_MODEL_CHAT_FAMILIES),
             "generate": list(BUILTIN_LLM_MODEL_GENERATE_FAMILIES),
             "tools": list(BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES),
+            "vision": [
+                family.model_name
+                for family in BUILTIN_LLM_FAMILIES
+                if "vision" in family.model_ability
+            ],
         }
 
     async def get_devices_count(self) -> int:
@@ -1028,7 +1033,7 @@ class SupervisorActor(xo.StatelessActor):
         else:
             task = asyncio.create_task(_launch_model())
             ASYNC_LAUNCH_TASKS[model_uid] = task
-            task.add_done_callback(lambda _: callback_for_async_launch(model_uid))
+            task.add_done_callback(lambda _: callback_for_async_launch(model_uid))  # type: ignore
         return model_uid
 
     async def get_instance_info(
@@ -1233,7 +1238,9 @@ class SupervisorActor(xo.StatelessActor):
             worker_address not in self._worker_address_to_worker
         ), f"Worker {worker_address} exists"
 
-        worker_ref = await xo.actor_ref(address=worker_address, uid=WorkerActor.uid())
+        worker_ref = await xo.actor_ref(
+            address=worker_address, uid=WorkerActor.default_uid()
+        )
         self._worker_address_to_worker[worker_address] = worker_ref
         logger.debug("Worker %s has been added successfully", worker_address)
```
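The `uid()` to `default_uid()` rename recurs across SupervisorActor, WorkerActor, StatusGuardActor, CacheTrackerActor, and EventCollectorActor, and every `xo.create_actor` / `xo.actor_ref` call site is updated to match. A minimal sketch of resolving an actor after the rename; the address is an assumption:

```python
import xoscar as xo

from xinference.core.supervisor import SupervisorActor

async def get_supervisor_ref(address: str = "127.0.0.1:9999"):  # assumed address
    # Core actors now name their well-known instance via the default_uid()
    # classmethod; for SupervisorActor it returns "supervisor", as in the
    # hunk above.
    return await xo.actor_ref(address=address, uid=SupervisorActor.default_uid())
```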