xinference 0.10.2.post1__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/oauth2/auth_service.py +1 -1
- xinference/api/restful_api.py +53 -61
- xinference/client/restful/restful_client.py +52 -57
- xinference/conftest.py +1 -1
- xinference/core/cache_tracker.py +1 -1
- xinference/core/event.py +1 -1
- xinference/core/model.py +15 -4
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +58 -72
- xinference/core/worker.py +73 -102
- xinference/deploy/cmdline.py +175 -6
- xinference/deploy/test/test_cmdline.py +2 -0
- xinference/deploy/utils.py +1 -1
- xinference/device_utils.py +29 -3
- xinference/fields.py +5 -1
- xinference/model/audio/model_spec.json +8 -1
- xinference/model/audio/whisper.py +88 -12
- xinference/model/core.py +2 -2
- xinference/model/embedding/core.py +13 -0
- xinference/model/image/__init__.py +29 -0
- xinference/model/image/core.py +6 -0
- xinference/model/image/custom.py +109 -0
- xinference/model/llm/__init__.py +92 -32
- xinference/model/llm/core.py +57 -102
- xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +2 -2
- xinference/model/llm/llm_family.json +446 -2
- xinference/model/llm/llm_family.py +45 -41
- xinference/model/llm/llm_family_modelscope.json +208 -1
- xinference/model/llm/pytorch/deepseek_vl.py +89 -33
- xinference/model/llm/pytorch/qwen_vl.py +67 -12
- xinference/model/llm/pytorch/yi_vl.py +62 -45
- xinference/model/llm/utils.py +45 -15
- xinference/model/llm/vllm/core.py +21 -4
- xinference/model/rerank/core.py +48 -20
- xinference/thirdparty/omnilmm/chat.py +2 -1
- xinference/thirdparty/omnilmm/model/omnilmm.py +2 -1
- xinference/types.py +2 -0
- xinference/web/ui/build/asset-manifest.json +6 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.54bca460.css +2 -0
- xinference/web/ui/build/static/css/main.54bca460.css.map +1 -0
- xinference/web/ui/build/static/js/main.8e44da4b.js +3 -0
- xinference/web/ui/build/static/js/{main.26fdbfbe.js.LICENSE.txt → main.8e44da4b.js.LICENSE.txt} +7 -0
- xinference/web/ui/build/static/js/main.8e44da4b.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/29dda700ab913cf7f2cfabe450ddabfb283e96adfa3ec9d315b2fa6c63cd375c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2c63e940b945fd5817157e08a42b889b30d668ea4c91332f48ef2b1b9d26f520.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4135fe8745434cbce6438d1ebfa47422e0c77d884db4edc75c8bf32ea1d50621.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4de0a71074f9cbe1e7862750dcdd08cbc1bae7d9d9849a78b1783ca670017b3c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/53f6c0c0afb51265cd8fb940daeb65523501879ac2a8c03a1ead22b9793c5041.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8ccbb839002bc5bc03e0a0e7612362bf92f6ae64f87e094f8682d6a6fe4619bb.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/97ed30d6e22cf76f0733651e2c18364689a01665d0b5fe811c1b7ca3eb713c82.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9c0c70f1838913aaa792a0d2260f17f90fd177b95698ed46b7bc3050eb712c1c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9cfd33238ca43e5bf9fc7e442690e8cc6027c73553db36de87e3597ed524ee4b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/ada71518a429f821a9b1dea38bc951447f03c8db509887e0980b893acac938f3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b6c9558d28b5972bb8b2691c5a76a2c8814a815eb3443126da9f49f7d6a0c118.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bb0f721c084a4d85c09201c984f02ee8437d3b6c5c38a57cb4a101f653daef1b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/ddaec68b88e5eff792df1e39a4b4b8b737bfc832293c015660c3c69334e3cf5c.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +33 -0
- xinference/web/ui/node_modules/clipboard/.babelrc.json +11 -0
- xinference/web/ui/node_modules/clipboard/.eslintrc.json +24 -0
- xinference/web/ui/node_modules/clipboard/.prettierrc.json +9 -0
- xinference/web/ui/node_modules/clipboard/bower.json +18 -0
- xinference/web/ui/node_modules/clipboard/composer.json +25 -0
- xinference/web/ui/node_modules/clipboard/package.json +63 -0
- xinference/web/ui/node_modules/delegate/package.json +31 -0
- xinference/web/ui/node_modules/good-listener/bower.json +11 -0
- xinference/web/ui/node_modules/good-listener/package.json +35 -0
- xinference/web/ui/node_modules/select/bower.json +13 -0
- xinference/web/ui/node_modules/select/package.json +29 -0
- xinference/web/ui/node_modules/tiny-emitter/package.json +53 -0
- xinference/web/ui/package-lock.json +34 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/METADATA +14 -13
- {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/RECORD +81 -60
- xinference/client/oscar/__init__.py +0 -13
- xinference/client/oscar/actor_client.py +0 -611
- xinference/model/llm/pytorch/spec_decoding_utils.py +0 -531
- xinference/model/llm/pytorch/spec_model.py +0 -186
- xinference/web/ui/build/static/js/main.26fdbfbe.js +0 -3
- xinference/web/ui/build/static/js/main.26fdbfbe.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/63a4c48f0326d071c7772c46598215c006ae41fd3d4ff3577fe717de66ad6e89.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/de0299226173b0662b573f49e3992220f6611947073bd66ac079728a8bc8837d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e9b52d171223bb59fb918316297a051cdfd42dd453e8260fd918e90bc0a4ebdf.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f4d5d1a41892a754c1ee0237450d804b20612d1b657945b59e564161ea47aa7a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/fad4cd70de36ef6e6d5f8fd74a10ded58d964a8a91ef7681693fbb8376552da7.json +0 -1
- {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/LICENSE +0 -0
- {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/WHEEL +0 -0
- {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.10.2.post1.dist-info → xinference-0.11.0.dist-info}/top_level.txt +0 -0
xinference/deploy/cmdline.py
CHANGED
@@ -17,7 +17,7 @@ import logging
 import os
 import sys
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Sequence, Tuple, Union
 
 import click
 from xoscar.utils import get_next_port
@@ -598,6 +598,13 @@ def list_model_registrations(
     default="LLM",
     help="Specify type of model, LLM as default.",
 )
+@click.option(
+    "--model-engine",
+    "-en",
+    type=str,
+    default=None,
+    help="Specify the inference engine of the model when launching LLM.",
+)
 @click.option(
     "--model-uid",
     "-u",
@@ -691,6 +698,7 @@ def model_launch(
     endpoint: Optional[str],
     model_name: str,
     model_type: str,
+    model_engine: Optional[str],
     model_uid: str,
     size_in_billions: str,
     model_format: str,
@@ -712,6 +720,9 @@ def model_launch(
         kwargs[ctx.args[i][2:]] = handle_click_args_type(ctx.args[i + 1])
     print(f"Launch model name: {model_name} with kwargs: {kwargs}", file=sys.stderr)
 
+    if model_type == "LLM" and model_engine is None:
+        raise ValueError("--model-engine is required for LLM models.")
+
     if n_gpu.lower() == "none":
         _n_gpu: Optional[Union[int, str]] = None
     elif n_gpu == "auto":
@@ -736,11 +747,15 @@ def model_launch(
         else []
     )
 
-    peft_model_config =
-
-
-
-
+    peft_model_config = (
+        {
+            "image_lora_load_kwargs": image_lora_load_params,
+            "image_lora_fuse_kwargs": image_lora_fuse_params,
+            "lora_list": lora_list,
+        }
+        if lora_list or image_lora_load_params or image_lora_fuse_params
+        else None
+    )
 
     _gpu_idx: Optional[List[int]] = (
         None if gpu_idx is None else [int(idx) for idx in gpu_idx.split(",")]
@@ -761,6 +776,7 @@ def model_launch(
     model_uid = client.launch_model(
         model_name=model_name,
         model_type=model_type,
+        model_engine=model_engine,
        model_uid=model_uid,
        model_size_in_billions=model_size,
        model_format=model_format,
@@ -1199,5 +1215,158 @@ def cluster_login(
         f.write(access_token)
 
 
+@cli.command(name="engine", help="Query the applicable inference engine by model name.")
+@click.option(
+    "--model-name",
+    "-n",
+    type=str,
+    required=True,
+    help="The model name you want to query.",
+)
+@click.option(
+    "--model-engine",
+    "-en",
+    type=str,
+    default=None,
+    help="Specify the `model_engine` to query the corresponding combination of other parameters.",
+)
+@click.option(
+    "--model-format",
+    "-f",
+    type=str,
+    default=None,
+    help="Specify the `model_format` to query the corresponding combination of other parameters.",
+)
+@click.option(
+    "--model-size-in-billions",
+    "-s",
+    type=str,
+    default=None,
+    help="Specify the `model_size_in_billions` to query the corresponding combination of other parameters.",
+)
+@click.option(
+    "--quantization",
+    "-q",
+    type=str,
+    default=None,
+    help="Specify the `quantization` to query the corresponding combination of other parameters.",
+)
+@click.option("--endpoint", "-e", type=str, help="Xinference endpoint.")
+@click.option(
+    "--api-key",
+    "-ak",
+    default=None,
+    type=str,
+    help="Api-Key for access xinference api with authorization.",
+)
+def query_engine_by_model_name(
+    model_name: str,
+    model_engine: Optional[str],
+    model_format: Optional[str],
+    model_size_in_billions: Optional[Union[str, int]],
+    quantization: Optional[str],
+    endpoint: Optional[str],
+    api_key: Optional[str],
+):
+    from tabulate import tabulate
+
+    def match_engine_from_spell(value: str, target: Sequence[str]) -> Tuple[bool, str]:
+        """
+        For better usage experience.
+        """
+        for t in target:
+            if value.lower() == t.lower():
+                return True, t
+        return False, value
+
+    def handle_user_passed_parameters() -> List[str]:
+        user_specified_parameters = []
+        if model_engine is not None:
+            user_specified_parameters.append(f"--model-engine {model_engine}")
+        if model_format is not None:
+            user_specified_parameters.append(f"--model-format {model_format}")
+        if model_size_in_billions is not None:
+            user_specified_parameters.append(
+                f"--model-size-in-billions {model_size_in_billions}"
+            )
+        if quantization is not None:
+            user_specified_parameters.append(f"--quantization {quantization}")
+        return user_specified_parameters
+
+    user_specified_params = handle_user_passed_parameters()
+
+    endpoint = get_endpoint(endpoint)
+    client = RESTfulClient(base_url=endpoint, api_key=api_key)
+    if api_key is None:
+        client._set_token(get_stored_token(endpoint, client))
+
+    llm_engines = client.query_engine_by_model_name(model_name)
+    if model_engine is not None:
+        is_matched, model_engine = match_engine_from_spell(
+            model_engine, list(llm_engines.keys())
+        )
+        if not is_matched:
+            print(
+                f'Xinference does not support this inference engine "{model_engine}".',
+                file=sys.stderr,
+            )
+            return
+
+    table = []
+    engines = [model_engine] if model_engine is not None else list(llm_engines.keys())
+    for engine in engines:
+        params = llm_engines[engine]
+        for param in params:
+            if (
+                (model_format is None or model_format == param["model_format"])
+                and (
+                    model_size_in_billions is None
+                    or model_size_in_billions == str(param["model_size_in_billions"])
+                )
+                and (quantization is None or quantization in param["quantizations"])
+            ):
+                if quantization is not None:
+                    table.append(
+                        [
+                            model_name,
+                            engine,
+                            param["model_format"],
+                            param["model_size_in_billions"],
+                            quantization,
+                        ]
+                    )
+                else:
+                    for quant in param["quantizations"]:
+                        table.append(
+                            [
+                                model_name,
+                                engine,
+                                param["model_format"],
+                                param["model_size_in_billions"],
+                                quant,
+                            ]
+                        )
+    if len(table) == 0:
+        print(
+            f"Xinference does not support "
+            f"your provided params: {', '.join(user_specified_params)} for the model {model_name}.",
+            file=sys.stderr,
+        )
+    else:
+        print(
+            tabulate(
+                table,
+                headers=[
+                    "Name",
+                    "Engine",
+                    "Format",
+                    "Size (in billions)",
+                    "Quantization",
+                ],
+            ),
+            file=sys.stderr,
+        )
+
+
 if __name__ == "__main__":
     cli()
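Taken together, these cmdline.py changes make the inference engine an explicit, required choice when launching an LLM, and add an `xinference engine` command (backed by the new `RESTfulClient.query_engine_by_model_name`) for discovering which engine/format/size/quantization combinations a model supports. A minimal sketch of the same flow through the Python client, reusing the values from the test suite that follows and assuming a local endpoint at http://127.0.0.1:9997:

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # assumed local endpoint

# New in 0.11.0: list the engines that can serve a model
# (CLI equivalent: `xinference engine -n orca`).
engines = client.query_engine_by_model_name("orca")
print(list(engines))

# Launching an LLM now requires model_engine; with model_type="LLM" and no
# engine, the CLI raises "--model-engine is required for LLM models."
model_uid = client.launch_model(
    model_name="orca",
    model_type="LLM",
    model_engine="llama.cpp",
    model_size_in_billions=3,
    quantization="q4_0",
)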
xinference/deploy/test/test_cmdline.py
CHANGED

@@ -65,6 +65,7 @@ def test_cmdline(setup, stream, model_uid):
     original_model_uid = model_uid
     model_uid = client.launch_model(
         model_name="orca",
+        model_engine="llama.cpp",
         model_uid=model_uid,
         model_size_in_billions=3,
         quantization="q4_0",
@@ -247,6 +248,7 @@ def test_rotate_logs(setup_with_file_logging):
     replica = 1 if os.name == "nt" else 2
     model_uid = client.launch_model(
         model_name="orca",
+        model_engine="llama.cpp",
         model_uid=None,
         model_size_in_billions=3,
         quantization="q4_0",
xinference/deploy/utils.py
CHANGED
@@ -129,7 +129,7 @@ def health_check(address: str, max_attempts: int, sleep_interval: int = 3) -> bool:
     try:
         from xinference.core.supervisor import SupervisorActor
 
-        supervisor_ref: xo.ActorRefType[SupervisorActor] = await xo.actor_ref(
+        supervisor_ref: xo.ActorRefType[SupervisorActor] = await xo.actor_ref(  # type: ignore
             address=address, uid=SupervisorActor.uid()
         )
 
xinference/device_utils.py
CHANGED
@@ -17,13 +17,27 @@ import os
 import torch
 from typing_extensions import Literal, Union
 
-DeviceType = Literal["cuda", "mps", "xpu", "cpu"]
+DeviceType = Literal["cuda", "mps", "xpu", "npu", "cpu"]
+DEVICE_TO_ENV_NAME = {
+    "cuda": "CUDA_VISIBLE_DEVICES",
+    "npu": "ASCEND_RT_VISIBLE_DEVICES",
+}
 
 
 def is_xpu_available() -> bool:
     return hasattr(torch, "xpu") and torch.xpu.is_available()
 
 
+def is_npu_available() -> bool:
+    try:
+        import torch
+        import torch_npu  # noqa: F401
+
+        return torch.npu.is_available()
+    except ImportError:
+        return False
+
+
 def get_available_device() -> DeviceType:
     if torch.cuda.is_available():
         return "cuda"
@@ -31,6 +45,8 @@ def get_available_device() -> DeviceType:
         return "mps"
     elif is_xpu_available():
         return "xpu"
+    elif is_npu_available():
+        return "npu"
     return "cpu"
 
 
@@ -41,6 +57,8 @@ def is_device_available(device: str) -> bool:
         return torch.backends.mps.is_available()
     elif device == "xpu":
         return is_xpu_available()
+    elif device == "npu":
+        return is_npu_available()
     elif device == "cpu":
         return True
 
@@ -59,7 +77,7 @@ def move_model_to_available_device(model):
 def get_device_preferred_dtype(device: str) -> Union[torch.dtype, None]:
     if device == "cpu":
         return torch.float32
-    elif device == "cuda" or device == "mps":
+    elif device == "cuda" or device == "mps" or device == "npu":
         return torch.float16
     elif device == "xpu":
         return torch.bfloat16
@@ -68,7 +86,7 @@ def get_device_preferred_dtype(device: str) -> Union[torch.dtype, None]:
 
 
 def is_hf_accelerate_supported(device: str) -> bool:
-    return device == "cuda" or device == "xpu"
+    return device == "cuda" or device == "xpu" or device == "npu"
 
 
 def empty_cache():
@@ -78,6 +96,12 @@ def empty_cache():
         torch.mps.empty_cache()
     if is_xpu_available():
         torch.xpu.empty_cache()
+    if is_npu_available():
+        torch.npu.empty_cache()
+
+
+def get_available_device_env_name():
+    return DEVICE_TO_ENV_NAME.get(get_available_device())
 
 
 def gpu_count():
@@ -94,5 +118,7 @@ def gpu_count():
         return min(torch.cuda.device_count(), len(cuda_visible_devices))
     elif is_xpu_available():
        return torch.xpu.device_count()
+    elif is_npu_available():
+        return torch.npu.device_count()
     else:
         return 0
xinference/fields.py
CHANGED
@@ -32,7 +32,6 @@ logprobs_field = Field(
 max_tokens_field = Field(
     default=1024,
     ge=1,
-    le=32768,
     description="The maximum number of tokens to generate.",
 )
 
@@ -75,6 +74,11 @@ stream_field = Field(
     description="Whether to stream the results as they are generated. Useful for chatbots.",
 )
 
+stream_option_field = Field(
+    default={},
+    description="If set, an additional chunk will be streamed before the `data: [DONE]` message.",
+)
+
 top_k_field = Field(
     default=40,
     ge=0,
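stream_option_field backs an OpenAI-style streaming option on the completion endpoints: when supplied, one extra chunk is emitted just before `data: [DONE]`. A hedged sketch against the OpenAI-compatible REST API; the request key (`stream_options`) and its `include_usage` flag are assumed from OpenAI's convention rather than shown in this diff:

import requests

resp = requests.post(
    "http://127.0.0.1:9997/v1/chat/completions",    # assumed local endpoint
    json={
        "model": "my-llm-uid",                      # hypothetical model uid
        "messages": [{"role": "user", "content": "Hi"}],
        "stream": True,
        "stream_options": {"include_usage": True},  # assumed field name, per OpenAI
    },
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode())  # the final chunk before "data: [DONE]" carries the extra data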
xinference/model/audio/model_spec.json
CHANGED

@@ -75,5 +75,12 @@
     "model_id": "BELLE-2/Belle-whisper-large-v2-zh",
     "model_revision": "ec5bd5d78598545b7585814edde86dac2002b5b9",
     "multilingual": false
+  },
+  {
+    "model_name": "Belle-whisper-large-v3-zh",
+    "model_family": "whisper",
+    "model_id": "BELLE-2/Belle-whisper-large-v3-zh",
+    "model_revision": "3bebc7247696b39f5ab9ed22db426943ac33f600",
+    "multilingual": false
   }
-]
+]
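This registers Belle-whisper-large-v3-zh alongside the existing v2 entry. A sketch of launching it through the Python client; the audio handle's `transcriptions` call is assumed to accept raw bytes as for the builtin whisper models:

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # assumed local endpoint

model_uid = client.launch_model(
    model_name="Belle-whisper-large-v3-zh",
    model_type="audio",
)
model = client.get_model(model_uid)

# The spec marks the model as non-multilingual (Chinese-only), so no language hint is needed.
with open("sample.wav", "rb") as f:
    print(model.transcriptions(f.read())["text"])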
xinference/model/audio/whisper.py
CHANGED

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from xinference.device_utils import (
     get_available_device,
@@ -81,12 +81,87 @@ class WhisperModel:
         audio: bytes,
         generate_kwargs: Dict,
         response_format: str,
+        temperature: float = 0,
+        timestamp_granularities: Optional[List[str]] = None,
     ):
+        if temperature != 0:
+            generate_kwargs.update({"temperature": temperature, "do_sample": True})
+
         if response_format == "json":
             logger.debug("Call whisper model with generate_kwargs: %s", generate_kwargs)
             assert callable(self._model)
             result = self._model(audio, generate_kwargs=generate_kwargs)
             return {"text": result["text"]}
+        elif response_format == "verbose_json":
+            return_timestamps: Union[bool, str] = False
+            if not timestamp_granularities:
+                return_timestamps = True
+            elif timestamp_granularities == ["segment"]:
+                return_timestamps = True
+            elif timestamp_granularities == ["word"]:
+                return_timestamps = "word"
+            else:
+                raise Exception(
+                    f"Unsupported timestamp_granularities: {timestamp_granularities}"
+                )
+            assert callable(self._model)
+            results = self._model(
+                audio,
+                generate_kwargs=generate_kwargs,
+                return_timestamps=return_timestamps,
+            )
+
+            language = generate_kwargs.get("language", "english")
+
+            if return_timestamps is True:
+                segments: List[dict] = []
+
+                def _get_chunk_segment_json(idx, text, start, end):
+                    find_start = 0
+                    if segments:
+                        find_start = segments[-1]["seek"] + len(segments[-1]["text"])
+                    return {
+                        "id": idx,
+                        "seek": results["text"].find(text, find_start),
+                        "start": start,
+                        "end": end,
+                        "text": text,
+                        "tokens": [],
+                        "temperature": temperature,
+                        # We can't provide these values.
+                        "avg_logprob": 0.0,
+                        "compression_ratio": 0.0,
+                        "no_speech_prob": 0.0,
+                    }
+
+                for idx, c in enumerate(results.get("chunks", [])):
+                    text = c["text"]
+                    start, end = c["timestamp"]
+                    segments.append(_get_chunk_segment_json(idx, text, start, end))
+
+                return {
+                    "task": "transcribe",
+                    "language": language,
+                    "duration": segments[-1]["end"] if segments else 0,
+                    "text": results["text"],
+                    "segments": segments,
+                }
+            else:
+                assert return_timestamps == "word"
+
+                words = []
+                for idx, c in enumerate(results.get("chunks", [])):
+                    text = c["text"]
+                    start, end = c["timestamp"]
+                    words.append({"word": text, "start": start, "end": end})
+
+                return {
+                    "task": "transcribe",
+                    "language": language,
+                    "duration": words[-1]["end"] if words else 0,
+                    "text": results["text"],
+                    "words": words,
+                }
         else:
             raise ValueError(f"Unsupported response format: {response_format}")
 
@@ -97,12 +172,8 @@ class WhisperModel:
         prompt: Optional[str] = None,
         response_format: str = "json",
         temperature: float = 0,
+        timestamp_granularities: Optional[List[str]] = None,
     ):
-        if temperature != 0:
-            logger.warning(
-                "Temperature for whisper transcriptions will be ignored: %s.",
-                temperature,
-            )
         if prompt is not None:
             logger.warning(
                 "Prompt for whisper transcriptions will be ignored: %s", prompt
@@ -115,30 +186,35 @@ class WhisperModel:
                 else {"task": "transcribe"}
             ),
             response_format=response_format,
+            temperature=temperature,
+            timestamp_granularities=timestamp_granularities,
         )
 
     def translations(
         self,
         audio: bytes,
+        language: Optional[str] = None,
         prompt: Optional[str] = None,
         response_format: str = "json",
         temperature: float = 0,
+        timestamp_granularities: Optional[List[str]] = None,
     ):
         if not self._model_spec.multilingual:
             raise RuntimeError(
                 f"Model {self._model_spec.model_name} is not suitable for translations."
             )
-        if temperature != 0:
-            logger.warning(
-                "Temperature for whisper transcriptions will be ignored: %s.",
-                temperature,
-            )
         if prompt is not None:
             logger.warning(
                 "Prompt for whisper transcriptions will be ignored: %s", prompt
             )
         return self._call_model(
             audio=audio,
-            generate_kwargs=
+            generate_kwargs=(
+                {"language": language, "task": "translate"}
+                if language is not None
+                else {"task": "translate"}
+            ),
             response_format=response_format,
+            temperature=temperature,
+            timestamp_granularities=timestamp_granularities,
         )
xinference/model/core.py
CHANGED
@@ -50,11 +50,11 @@ def create_model_instance(
     model_uid: str,
     model_type: str,
     model_name: str,
+    model_engine: Optional[str],
     model_format: Optional[str] = None,
     model_size_in_billions: Optional[Union[int, str]] = None,
     quantization: Optional[str] = None,
     peft_model_config: Optional[PeftModelConfig] = None,
-    is_local_deployment: bool = False,
     **kwargs,
 ) -> Tuple[Any, ModelDescription]:
     from .audio.core import create_audio_model_instance
@@ -69,11 +69,11 @@ def create_model_instance(
             devices,
             model_uid,
             model_name,
+            model_engine,
             model_format,
             model_size_in_billions,
             quantization,
             peft_model_config,
-            is_local_deployment,
             **kwargs,
         )
     elif model_type == "embedding":
xinference/model/embedding/core.py
CHANGED

@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import logging
+import os
 from collections import defaultdict
 from typing import Dict, List, Optional, Tuple, Union, no_type_check
 
 import numpy as np
 
+from ...device_utils import empty_cache
 from ...types import Embedding, EmbeddingData, EmbeddingUsage
 from ..core import CacheableModelSpec, ModelDescription
 from ..utils import get_cache_dir, is_model_cached
@@ -28,6 +31,10 @@ logger = logging.getLogger(__name__)
 # Init when registering all the builtin models.
 MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
 EMBEDDING_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
+EMBEDDING_EMPTY_CACHE_COUNT = int(
+    os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT", "10")
+)
+assert EMBEDDING_EMPTY_CACHE_COUNT > 0
 
 
 def get_embedding_model_descriptions():
@@ -116,6 +123,7 @@ class EmbeddingModel:
         self._model_path = model_path
         self._device = device
         self._model = None
+        self._counter = 0
 
     def load(self):
         try:
@@ -134,6 +142,11 @@ class EmbeddingModel:
         self._model = SentenceTransformer(self._model_path, device=self._device)
 
     def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
+        self._counter += 1
+        if self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0:
+            logger.debug("Empty embedding cache.")
+            gc.collect()
+            empty_cache()
         from sentence_transformers import SentenceTransformer
 
         kwargs.setdefault("normalize_embeddings", True)
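The embedding model now releases accelerator memory every N create_embedding calls, where N comes from XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT (default 10, must be positive). A minimal sketch of tuning it; the variable has to be set before xinference.model.embedding.core is imported, since it is read once at import time:

import os

# Flush caches (gc.collect() + device empty_cache()) every 50 embedding calls instead of 10.
os.environ["XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT"] = "50"

# ...then start the xinference worker/supervisor in this environment.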
xinference/model/image/__init__.py
CHANGED

@@ -20,12 +20,19 @@ from itertools import chain
 from .core import (
     BUILTIN_IMAGE_MODELS,
     IMAGE_MODEL_DESCRIPTIONS,
+    MODEL_NAME_TO_REVISION,
     MODELSCOPE_IMAGE_MODELS,
     ImageModelFamilyV1,
     generate_image_description,
     get_cache_status,
     get_image_model_descriptions,
 )
+from .custom import (
+    CustomImageModelFamilyV1,
+    get_user_defined_images,
+    register_image,
+    unregister_image,
+)
 
 _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
 _model_spec_modelscope_json = os.path.join(
@@ -37,6 +44,9 @@ BUILTIN_IMAGE_MODELS.update(
         for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
     )
 )
+for model_name, model_spec in BUILTIN_IMAGE_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+
 MODELSCOPE_IMAGE_MODELS.update(
     dict(
         (spec["model_name"], ImageModelFamilyV1(**spec))
@@ -45,6 +55,8 @@ MODELSCOPE_IMAGE_MODELS.update(
         )
     )
 )
+for model_name, model_spec in MODELSCOPE_IMAGE_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
 
 # register model description
 for model_name, model_spec in chain(
@@ -52,4 +64,21 @@ for model_name, model_spec in chain(
 ):
     IMAGE_MODEL_DESCRIPTIONS.update(generate_image_description(model_spec))
 
+from ...constants import XINFERENCE_MODEL_DIR
+
+user_defined_image_dir = os.path.join(XINFERENCE_MODEL_DIR, "image")
+if os.path.isdir(user_defined_image_dir):
+    for f in os.listdir(user_defined_image_dir):
+        with codecs.open(
+            os.path.join(user_defined_image_dir, f), encoding="utf-8"
+        ) as fd:
+            user_defined_image_family = CustomImageModelFamilyV1.parse_obj(
+                json.load(fd)
+            )
+            register_image(user_defined_image_family, persist=False)
+
+for ud_image in get_user_defined_images():
+    IMAGE_MODEL_DESCRIPTIONS.update(generate_image_description(ud_image))
+
 del _model_spec_json
+del _model_spec_modelscope_json
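Image models become user-registerable: every JSON file under XINFERENCE_MODEL_DIR/image is parsed into a CustomImageModelFamilyV1 and registered at import time, and register_image/unregister_image are exposed for programmatic use. A rough sketch of registering one from Python; the spec field names below are assumptions modeled on the builtin image spec, since custom.py itself is not excerpted here:

from xinference.model.image import CustomImageModelFamilyV1, register_image

# Hypothetical custom Stable Diffusion finetune; field names are assumed, check
# xinference/model/image/custom.py for the authoritative schema.
my_image_model = CustomImageModelFamilyV1(
    model_name="my-sd-finetune",
    model_family="stable_diffusion",
    model_uri="/path/to/my-sd-finetune",
)
register_image(my_image_model, persist=True)  # the import-time scan uses persist=False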
xinference/model/image/core.py
CHANGED
@@ -27,6 +27,7 @@ MAX_ATTEMPTS = 3
 
 logger = logging.getLogger(__name__)
 
+MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
 IMAGE_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
 BUILTIN_IMAGE_MODELS: Dict[str, "ImageModelFamilyV1"] = {}
 MODELSCOPE_IMAGE_MODELS: Dict[str, "ImageModelFamilyV1"] = {}
@@ -119,6 +120,11 @@ def generate_image_description(
 def match_diffusion(model_name: str) -> ImageModelFamilyV1:
     from ..utils import download_from_modelscope
     from . import BUILTIN_IMAGE_MODELS, MODELSCOPE_IMAGE_MODELS
+    from .custom import get_user_defined_images
+
+    for model_spec in get_user_defined_images():
+        if model_spec.model_name == model_name:
+            return model_spec
 
     if download_from_modelscope():
         if model_name in MODELSCOPE_IMAGE_MODELS: