xinference 0.10.3__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of xinference has been flagged as a potentially problematic release.

Files changed (101)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/auth_service.py +1 -1
  3. xinference/api/restful_api.py +53 -61
  4. xinference/client/restful/restful_client.py +52 -57
  5. xinference/conftest.py +1 -1
  6. xinference/core/cache_tracker.py +1 -1
  7. xinference/core/chat_interface.py +10 -4
  8. xinference/core/event.py +1 -1
  9. xinference/core/model.py +17 -6
  10. xinference/core/status_guard.py +1 -1
  11. xinference/core/supervisor.py +58 -72
  12. xinference/core/worker.py +68 -101
  13. xinference/deploy/cmdline.py +166 -1
  14. xinference/deploy/test/test_cmdline.py +2 -0
  15. xinference/deploy/utils.py +1 -1
  16. xinference/device_utils.py +29 -3
  17. xinference/fields.py +7 -1
  18. xinference/model/audio/whisper.py +88 -12
  19. xinference/model/core.py +2 -2
  20. xinference/model/image/__init__.py +29 -0
  21. xinference/model/image/core.py +6 -0
  22. xinference/model/image/custom.py +109 -0
  23. xinference/model/llm/__init__.py +92 -32
  24. xinference/model/llm/core.py +57 -102
  25. xinference/model/llm/ggml/chatglm.py +98 -13
  26. xinference/model/llm/ggml/llamacpp.py +49 -2
  27. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +2 -2
  28. xinference/model/llm/llm_family.json +438 -7
  29. xinference/model/llm/llm_family.py +45 -41
  30. xinference/model/llm/llm_family_modelscope.json +258 -5
  31. xinference/model/llm/pytorch/chatglm.py +48 -0
  32. xinference/model/llm/pytorch/core.py +23 -6
  33. xinference/model/llm/pytorch/deepseek_vl.py +115 -33
  34. xinference/model/llm/pytorch/internlm2.py +32 -1
  35. xinference/model/llm/pytorch/qwen_vl.py +94 -12
  36. xinference/model/llm/pytorch/utils.py +38 -1
  37. xinference/model/llm/pytorch/yi_vl.py +96 -51
  38. xinference/model/llm/sglang/core.py +31 -9
  39. xinference/model/llm/utils.py +54 -20
  40. xinference/model/llm/vllm/core.py +101 -7
  41. xinference/thirdparty/omnilmm/chat.py +2 -1
  42. xinference/thirdparty/omnilmm/model/omnilmm.py +2 -1
  43. xinference/types.py +11 -0
  44. xinference/web/ui/build/asset-manifest.json +6 -3
  45. xinference/web/ui/build/index.html +1 -1
  46. xinference/web/ui/build/static/css/main.54bca460.css +2 -0
  47. xinference/web/ui/build/static/css/main.54bca460.css.map +1 -0
  48. xinference/web/ui/build/static/js/main.551aa479.js +3 -0
  49. xinference/web/ui/build/static/js/{main.26fdbfbe.js.LICENSE.txt → main.551aa479.js.LICENSE.txt} +7 -0
  50. xinference/web/ui/build/static/js/main.551aa479.js.map +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/23caf6f1e52c43e983ca3bfd4189f41dbd645fa78f2dfdcd7f6b69bc41678665.json +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/29dda700ab913cf7f2cfabe450ddabfb283e96adfa3ec9d315b2fa6c63cd375c.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/2c63e940b945fd5817157e08a42b889b30d668ea4c91332f48ef2b1b9d26f520.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/4135fe8745434cbce6438d1ebfa47422e0c77d884db4edc75c8bf32ea1d50621.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/4de0a71074f9cbe1e7862750dcdd08cbc1bae7d9d9849a78b1783ca670017b3c.json +1 -0
  59. xinference/web/ui/node_modules/.cache/babel-loader/53f6c0c0afb51265cd8fb940daeb65523501879ac2a8c03a1ead22b9793c5041.json +1 -0
  60. xinference/web/ui/node_modules/.cache/babel-loader/8ccbb839002bc5bc03e0a0e7612362bf92f6ae64f87e094f8682d6a6fe4619bb.json +1 -0
  61. xinference/web/ui/node_modules/.cache/babel-loader/97ed30d6e22cf76f0733651e2c18364689a01665d0b5fe811c1b7ca3eb713c82.json +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/9c0c70f1838913aaa792a0d2260f17f90fd177b95698ed46b7bc3050eb712c1c.json +1 -0
  63. xinference/web/ui/node_modules/.cache/babel-loader/9cfd33238ca43e5bf9fc7e442690e8cc6027c73553db36de87e3597ed524ee4b.json +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/ada71518a429f821a9b1dea38bc951447f03c8db509887e0980b893acac938f3.json +1 -0
  66. xinference/web/ui/node_modules/.cache/babel-loader/b6c9558d28b5972bb8b2691c5a76a2c8814a815eb3443126da9f49f7d6a0c118.json +1 -0
  67. xinference/web/ui/node_modules/.cache/babel-loader/bb0f721c084a4d85c09201c984f02ee8437d3b6c5c38a57cb4a101f653daef1b.json +1 -0
  68. xinference/web/ui/node_modules/.package-lock.json +33 -0
  69. xinference/web/ui/node_modules/clipboard/.babelrc.json +11 -0
  70. xinference/web/ui/node_modules/clipboard/.eslintrc.json +24 -0
  71. xinference/web/ui/node_modules/clipboard/.prettierrc.json +9 -0
  72. xinference/web/ui/node_modules/clipboard/bower.json +18 -0
  73. xinference/web/ui/node_modules/clipboard/composer.json +25 -0
  74. xinference/web/ui/node_modules/clipboard/package.json +63 -0
  75. xinference/web/ui/node_modules/delegate/package.json +31 -0
  76. xinference/web/ui/node_modules/good-listener/bower.json +11 -0
  77. xinference/web/ui/node_modules/good-listener/package.json +35 -0
  78. xinference/web/ui/node_modules/select/bower.json +13 -0
  79. xinference/web/ui/node_modules/select/package.json +29 -0
  80. xinference/web/ui/node_modules/tiny-emitter/package.json +53 -0
  81. xinference/web/ui/package-lock.json +34 -0
  82. xinference/web/ui/package.json +1 -0
  83. {xinference-0.10.3.dist-info → xinference-0.11.1.dist-info}/METADATA +13 -12
  84. {xinference-0.10.3.dist-info → xinference-0.11.1.dist-info}/RECORD +88 -67
  85. xinference/client/oscar/__init__.py +0 -13
  86. xinference/client/oscar/actor_client.py +0 -611
  87. xinference/model/llm/pytorch/spec_decoding_utils.py +0 -531
  88. xinference/model/llm/pytorch/spec_model.py +0 -186
  89. xinference/web/ui/build/static/js/main.26fdbfbe.js +0 -3
  90. xinference/web/ui/build/static/js/main.26fdbfbe.js.map +0 -1
  91. xinference/web/ui/node_modules/.cache/babel-loader/1870cd6f7054d04e049e363c0a85526584fe25519378609d2838e28d7492bbf1.json +0 -1
  92. xinference/web/ui/node_modules/.cache/babel-loader/5393569d846332075b93b55656716a34f50e0a8c970be789502d7e6c49755fd7.json +0 -1
  93. xinference/web/ui/node_modules/.cache/babel-loader/63a4c48f0326d071c7772c46598215c006ae41fd3d4ff3577fe717de66ad6e89.json +0 -1
  94. xinference/web/ui/node_modules/.cache/babel-loader/de0299226173b0662b573f49e3992220f6611947073bd66ac079728a8bc8837d.json +0 -1
  95. xinference/web/ui/node_modules/.cache/babel-loader/e9b52d171223bb59fb918316297a051cdfd42dd453e8260fd918e90bc0a4ebdf.json +0 -1
  96. xinference/web/ui/node_modules/.cache/babel-loader/f4d5d1a41892a754c1ee0237450d804b20612d1b657945b59e564161ea47aa7a.json +0 -1
  97. xinference/web/ui/node_modules/.cache/babel-loader/fad4cd70de36ef6e6d5f8fd74a10ded58d964a8a91ef7681693fbb8376552da7.json +0 -1
  98. {xinference-0.10.3.dist-info → xinference-0.11.1.dist-info}/LICENSE +0 -0
  99. {xinference-0.10.3.dist-info → xinference-0.11.1.dist-info}/WHEEL +0 -0
  100. {xinference-0.10.3.dist-info → xinference-0.11.1.dist-info}/entry_points.txt +0 -0
  101. {xinference-0.10.3.dist-info → xinference-0.11.1.dist-info}/top_level.txt +0 -0
xinference/deploy/cmdline.py CHANGED
@@ -17,7 +17,7 @@ import logging
 import os
 import sys
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Sequence, Tuple, Union
 
 import click
 from xoscar.utils import get_next_port
@@ -598,6 +598,13 @@ def list_model_registrations(
     default="LLM",
     help="Specify type of model, LLM as default.",
 )
+@click.option(
+    "--model-engine",
+    "-en",
+    type=str,
+    default=None,
+    help="Specify the inference engine of the model when launching LLM.",
+)
 @click.option(
     "--model-uid",
     "-u",
@@ -691,6 +698,7 @@ def model_launch(
     endpoint: Optional[str],
     model_name: str,
     model_type: str,
+    model_engine: Optional[str],
     model_uid: str,
     size_in_billions: str,
     model_format: str,
@@ -712,6 +720,9 @@ def model_launch(
            kwargs[ctx.args[i][2:]] = handle_click_args_type(ctx.args[i + 1])
    print(f"Launch model name: {model_name} with kwargs: {kwargs}", file=sys.stderr)
 
+    if model_type == "LLM" and model_engine is None:
+        raise ValueError("--model-engine is required for LLM models.")
+
    if n_gpu.lower() == "none":
        _n_gpu: Optional[Union[int, str]] = None
    elif n_gpu == "auto":
@@ -765,6 +776,7 @@ def model_launch(
    model_uid = client.launch_model(
        model_name=model_name,
        model_type=model_type,
+        model_engine=model_engine,
        model_uid=model_uid,
        model_size_in_billions=model_size,
        model_format=model_format,
@@ -1203,5 +1215,158 @@ def cluster_login(
        f.write(access_token)
 
 
+@cli.command(name="engine", help="Query the applicable inference engine by model name.")
+@click.option(
+    "--model-name",
+    "-n",
+    type=str,
+    required=True,
+    help="The model name you want to query.",
+)
+@click.option(
+    "--model-engine",
+    "-en",
+    type=str,
+    default=None,
+    help="Specify the `model_engine` to query the corresponding combination of other parameters.",
+)
+@click.option(
+    "--model-format",
+    "-f",
+    type=str,
+    default=None,
+    help="Specify the `model_format` to query the corresponding combination of other parameters.",
+)
+@click.option(
+    "--model-size-in-billions",
+    "-s",
+    type=str,
+    default=None,
+    help="Specify the `model_size_in_billions` to query the corresponding combination of other parameters.",
+)
+@click.option(
+    "--quantization",
+    "-q",
+    type=str,
+    default=None,
+    help="Specify the `quantization` to query the corresponding combination of other parameters.",
+)
+@click.option("--endpoint", "-e", type=str, help="Xinference endpoint.")
+@click.option(
+    "--api-key",
+    "-ak",
+    default=None,
+    type=str,
+    help="Api-Key for access xinference api with authorization.",
+)
+def query_engine_by_model_name(
+    model_name: str,
+    model_engine: Optional[str],
+    model_format: Optional[str],
+    model_size_in_billions: Optional[Union[str, int]],
+    quantization: Optional[str],
+    endpoint: Optional[str],
+    api_key: Optional[str],
+):
+    from tabulate import tabulate
+
+    def match_engine_from_spell(value: str, target: Sequence[str]) -> Tuple[bool, str]:
+        """
+        For better usage experience.
+        """
+        for t in target:
+            if value.lower() == t.lower():
+                return True, t
+        return False, value
+
+    def handle_user_passed_parameters() -> List[str]:
+        user_specified_parameters = []
+        if model_engine is not None:
+            user_specified_parameters.append(f"--model-engine {model_engine}")
+        if model_format is not None:
+            user_specified_parameters.append(f"--model-format {model_format}")
+        if model_size_in_billions is not None:
+            user_specified_parameters.append(
+                f"--model-size-in-billions {model_size_in_billions}"
+            )
+        if quantization is not None:
+            user_specified_parameters.append(f"--quantization {quantization}")
+        return user_specified_parameters
+
+    user_specified_params = handle_user_passed_parameters()
+
+    endpoint = get_endpoint(endpoint)
+    client = RESTfulClient(base_url=endpoint, api_key=api_key)
+    if api_key is None:
+        client._set_token(get_stored_token(endpoint, client))
+
+    llm_engines = client.query_engine_by_model_name(model_name)
+    if model_engine is not None:
+        is_matched, model_engine = match_engine_from_spell(
+            model_engine, list(llm_engines.keys())
+        )
+        if not is_matched:
+            print(
+                f'Xinference does not support this inference engine "{model_engine}".',
+                file=sys.stderr,
+            )
+            return
+
+    table = []
+    engines = [model_engine] if model_engine is not None else list(llm_engines.keys())
+    for engine in engines:
+        params = llm_engines[engine]
+        for param in params:
+            if (
+                (model_format is None or model_format == param["model_format"])
+                and (
+                    model_size_in_billions is None
+                    or model_size_in_billions == str(param["model_size_in_billions"])
+                )
+                and (quantization is None or quantization in param["quantizations"])
+            ):
+                if quantization is not None:
+                    table.append(
+                        [
+                            model_name,
+                            engine,
+                            param["model_format"],
+                            param["model_size_in_billions"],
+                            quantization,
+                        ]
+                    )
+                else:
+                    for quant in param["quantizations"]:
+                        table.append(
+                            [
+                                model_name,
+                                engine,
+                                param["model_format"],
+                                param["model_size_in_billions"],
+                                quant,
+                            ]
+                        )
+    if len(table) == 0:
+        print(
+            f"Xinference does not support "
+            f"your provided params: {', '.join(user_specified_params)} for the model {model_name}.",
+            file=sys.stderr,
+        )
+    else:
+        print(
+            tabulate(
+                table,
+                headers=[
+                    "Name",
+                    "Engine",
+                    "Format",
+                    "Size (in billions)",
+                    "Quantization",
+                ],
+            ),
+            file=sys.stderr,
+        )
+
+
 if __name__ == "__main__":
     cli()
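
Taken together, the cmdline changes above make the inference engine an explicit choice: `xinference launch` now requires `--model-engine` for LLMs, and the new `xinference engine` subcommand reports which engines can serve a given model. Below is a minimal sketch of the same flow through the Python RESTful client; the endpoint, model name, and the engine names returned are placeholders that depend on your deployment:

from xinference.client import RESTfulClient

# Hypothetical endpoint of a running xinference supervisor.
client = RESTfulClient("http://127.0.0.1:9997")

# Ask which inference engines can serve a model
# (this is what the new `xinference engine -n orca` command calls).
engines = client.query_engine_by_model_name("orca")
print(list(engines.keys()))  # e.g. ["llama.cpp", ...], depending on installed backends

# Launching an LLM now needs an explicit engine, as the updated tests show.
model_uid = client.launch_model(
    model_name="orca",
    model_engine="llama.cpp",
    model_size_in_billions=3,
    quantization="q4_0",
)
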
xinference/deploy/test/test_cmdline.py CHANGED
@@ -65,6 +65,7 @@ def test_cmdline(setup, stream, model_uid):
     original_model_uid = model_uid
     model_uid = client.launch_model(
         model_name="orca",
+        model_engine="llama.cpp",
         model_uid=model_uid,
         model_size_in_billions=3,
         quantization="q4_0",
@@ -247,6 +248,7 @@ def test_rotate_logs(setup_with_file_logging):
     replica = 1 if os.name == "nt" else 2
     model_uid = client.launch_model(
         model_name="orca",
+        model_engine="llama.cpp",
         model_uid=None,
         model_size_in_billions=3,
         quantization="q4_0",
xinference/deploy/utils.py CHANGED
@@ -129,7 +129,7 @@ def health_check(address: str, max_attempts: int, sleep_interval: int = 3) -> bool:
        try:
            from xinference.core.supervisor import SupervisorActor
 
-            supervisor_ref: xo.ActorRefType[SupervisorActor] = await xo.actor_ref(
+            supervisor_ref: xo.ActorRefType[SupervisorActor] = await xo.actor_ref(  # type: ignore
                address=address, uid=SupervisorActor.uid()
            )
 
xinference/device_utils.py CHANGED
@@ -17,13 +17,27 @@ import os
 import torch
 from typing_extensions import Literal, Union
 
-DeviceType = Literal["cuda", "mps", "xpu", "cpu"]
+DeviceType = Literal["cuda", "mps", "xpu", "npu", "cpu"]
+DEVICE_TO_ENV_NAME = {
+    "cuda": "CUDA_VISIBLE_DEVICES",
+    "npu": "ASCEND_RT_VISIBLE_DEVICES",
+}
 
 
 def is_xpu_available() -> bool:
     return hasattr(torch, "xpu") and torch.xpu.is_available()
 
 
+def is_npu_available() -> bool:
+    try:
+        import torch
+        import torch_npu  # noqa: F401
+
+        return torch.npu.is_available()
+    except ImportError:
+        return False
+
+
 def get_available_device() -> DeviceType:
     if torch.cuda.is_available():
         return "cuda"
@@ -31,6 +45,8 @@ def get_available_device() -> DeviceType:
         return "mps"
     elif is_xpu_available():
         return "xpu"
+    elif is_npu_available():
+        return "npu"
     return "cpu"
 
 
@@ -41,6 +57,8 @@ def is_device_available(device: str) -> bool:
         return torch.backends.mps.is_available()
     elif device == "xpu":
         return is_xpu_available()
+    elif device == "npu":
+        return is_npu_available()
     elif device == "cpu":
         return True
 
@@ -59,7 +77,7 @@ def move_model_to_available_device(model):
 def get_device_preferred_dtype(device: str) -> Union[torch.dtype, None]:
     if device == "cpu":
         return torch.float32
-    elif device == "cuda" or device == "mps":
+    elif device == "cuda" or device == "mps" or device == "npu":
         return torch.float16
     elif device == "xpu":
         return torch.bfloat16
@@ -68,7 +86,7 @@ def get_device_preferred_dtype(device: str) -> Union[torch.dtype, None]:
 
 
 def is_hf_accelerate_supported(device: str) -> bool:
-    return device == "cuda" or device == "xpu"
+    return device == "cuda" or device == "xpu" or device == "npu"
 
 
 def empty_cache():
@@ -78,6 +96,12 @@ def empty_cache():
         torch.mps.empty_cache()
     if is_xpu_available():
         torch.xpu.empty_cache()
+    if is_npu_available():
+        torch.npu.empty_cache()
+
+
+def get_available_device_env_name():
+    return DEVICE_TO_ENV_NAME.get(get_available_device())
 
 
 def gpu_count():
@@ -94,5 +118,7 @@ def gpu_count():
         return min(torch.cuda.device_count(), len(cuda_visible_devices))
     elif is_xpu_available():
         return torch.xpu.device_count()
+    elif is_npu_available():
+        return torch.npu.device_count()
     else:
         return 0
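
The device_utils changes add Ascend NPU as a first-class device type (detected via the optional torch_npu package) and introduce DEVICE_TO_ENV_NAME plus get_available_device_env_name() for mapping the active device to the environment variable that scopes visible accelerators. A small usage sketch; the actual output depends on the hardware and on whether torch_npu is installed:

from xinference.device_utils import (
    get_available_device,
    get_available_device_env_name,
    gpu_count,
)

# On a CUDA machine: ("cuda", "CUDA_VISIBLE_DEVICES"); on an Ascend machine with
# torch_npu installed: ("npu", "ASCEND_RT_VISIBLE_DEVICES"); for mps/xpu/cpu the
# env name is None because no mapping is defined for those devices.
print(get_available_device(), get_available_device_env_name())
print("visible accelerators:", gpu_count())
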
xinference/fields.py CHANGED
@@ -32,7 +32,6 @@ logprobs_field = Field(
 max_tokens_field = Field(
     default=1024,
     ge=1,
-    le=32768,
     description="The maximum number of tokens to generate.",
 )
 
@@ -75,6 +74,13 @@ stream_field = Field(
     description="Whether to stream the results as they are generated. Useful for chatbots.",
 )
 
+stream_option_field = Field(
+    default={
+        "include_usage": False,
+    },
+    description="If set, an additional chunk will be streamed before the `data: [DONE]` message.",
+)
+
 top_k_field = Field(
     default=40,
     ge=0,
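
The new stream_option_field mirrors the OpenAI-style stream_options parameter: when include_usage is enabled, one extra chunk carrying token usage is emitted before the final data: [DONE] message. A hedged sketch of how this might be exercised through xinference's OpenAI-compatible endpoint (the endpoint, model uid, and the exact wiring of this field into the chat routes are assumptions, not shown in this diff):

import openai

# Assumes a locally running xinference server exposing the OpenAI-compatible /v1 API.
client = openai.OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="not-needed")

stream = client.chat.completions.create(
    model="my-llm-uid",  # placeholder: uid of a launched chat model
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
    stream_options={"include_usage": True},  # request the extra usage chunk
)
for chunk in stream:
    if chunk.usage is not None:  # the usage-only chunk arrives last
        print(chunk.usage)
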
xinference/model/audio/whisper.py CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from xinference.device_utils import (
     get_available_device,
@@ -81,12 +81,87 @@ class WhisperModel:
         audio: bytes,
         generate_kwargs: Dict,
         response_format: str,
+        temperature: float = 0,
+        timestamp_granularities: Optional[List[str]] = None,
     ):
+        if temperature != 0:
+            generate_kwargs.update({"temperature": temperature, "do_sample": True})
+
         if response_format == "json":
             logger.debug("Call whisper model with generate_kwargs: %s", generate_kwargs)
             assert callable(self._model)
             result = self._model(audio, generate_kwargs=generate_kwargs)
             return {"text": result["text"]}
+        elif response_format == "verbose_json":
+            return_timestamps: Union[bool, str] = False
+            if not timestamp_granularities:
+                return_timestamps = True
+            elif timestamp_granularities == ["segment"]:
+                return_timestamps = True
+            elif timestamp_granularities == ["word"]:
+                return_timestamps = "word"
+            else:
+                raise Exception(
+                    f"Unsupported timestamp_granularities: {timestamp_granularities}"
+                )
+            assert callable(self._model)
+            results = self._model(
+                audio,
+                generate_kwargs=generate_kwargs,
+                return_timestamps=return_timestamps,
+            )
+
+            language = generate_kwargs.get("language", "english")
+
+            if return_timestamps is True:
+                segments: List[dict] = []
+
+                def _get_chunk_segment_json(idx, text, start, end):
+                    find_start = 0
+                    if segments:
+                        find_start = segments[-1]["seek"] + len(segments[-1]["text"])
+                    return {
+                        "id": idx,
+                        "seek": results["text"].find(text, find_start),
+                        "start": start,
+                        "end": end,
+                        "text": text,
+                        "tokens": [],
+                        "temperature": temperature,
+                        # We can't provide these values.
+                        "avg_logprob": 0.0,
+                        "compression_ratio": 0.0,
+                        "no_speech_prob": 0.0,
+                    }
+
+                for idx, c in enumerate(results.get("chunks", [])):
+                    text = c["text"]
+                    start, end = c["timestamp"]
+                    segments.append(_get_chunk_segment_json(idx, text, start, end))
+
+                return {
+                    "task": "transcribe",
+                    "language": language,
+                    "duration": segments[-1]["end"] if segments else 0,
+                    "text": results["text"],
+                    "segments": segments,
+                }
+            else:
+                assert return_timestamps == "word"
+
+                words = []
+                for idx, c in enumerate(results.get("chunks", [])):
+                    text = c["text"]
+                    start, end = c["timestamp"]
+                    words.append({"word": text, "start": start, "end": end})
+
+                return {
+                    "task": "transcribe",
+                    "language": language,
+                    "duration": words[-1]["end"] if words else 0,
+                    "text": results["text"],
+                    "words": words,
+                }
         else:
             raise ValueError(f"Unsupported response format: {response_format}")
 
@@ -97,12 +172,8 @@ class WhisperModel:
         prompt: Optional[str] = None,
         response_format: str = "json",
         temperature: float = 0,
+        timestamp_granularities: Optional[List[str]] = None,
     ):
-        if temperature != 0:
-            logger.warning(
-                "Temperature for whisper transcriptions will be ignored: %s.",
-                temperature,
-            )
         if prompt is not None:
             logger.warning(
                 "Prompt for whisper transcriptions will be ignored: %s", prompt
@@ -115,30 +186,35 @@ class WhisperModel:
                 else {"task": "transcribe"}
             ),
             response_format=response_format,
+            temperature=temperature,
+            timestamp_granularities=timestamp_granularities,
         )
 
     def translations(
         self,
         audio: bytes,
+        language: Optional[str] = None,
         prompt: Optional[str] = None,
         response_format: str = "json",
         temperature: float = 0,
+        timestamp_granularities: Optional[List[str]] = None,
     ):
         if not self._model_spec.multilingual:
             raise RuntimeError(
                 f"Model {self._model_spec.model_name} is not suitable for translations."
             )
-        if temperature != 0:
-            logger.warning(
-                "Temperature for whisper transcriptions will be ignored: %s.",
-                temperature,
-            )
         if prompt is not None:
             logger.warning(
                 "Prompt for whisper transcriptions will be ignored: %s", prompt
             )
         return self._call_model(
             audio=audio,
-            generate_kwargs={"task": "translate"},
+            generate_kwargs=(
+                {"language": language, "task": "translate"}
+                if language is not None
+                else {"task": "translate"}
+            ),
             response_format=response_format,
+            temperature=temperature,
+            timestamp_granularities=timestamp_granularities,
         )
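
The whisper changes add OpenAI-style verbose_json output with segment- or word-level timestamps, forward temperature into generation instead of ignoring it, and accept a target language for translations. A hedged sketch of calling transcription through the OpenAI-compatible endpoint (the endpoint, model uid, and file name are placeholders; the response shape follows the _call_model logic above):

import openai

client = openai.OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="not-needed")

with open("sample.wav", "rb") as audio_file:
    result = client.audio.transcriptions.create(
        model="my-whisper-uid",  # placeholder: uid of a launched whisper model
        file=audio_file,
        response_format="verbose_json",
        temperature=0.2,  # now forwarded (with do_sample=True) instead of ignored
        timestamp_granularities=["word"],  # ["segment"] or omitted returns segments
    )

# For ["word"] the verbose_json payload carries a top-level "words" list;
# for ["segment"] (or no granularity) it carries "segments" instead.
print(result.words)
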
xinference/model/core.py CHANGED
@@ -50,11 +50,11 @@ def create_model_instance(
     model_uid: str,
     model_type: str,
     model_name: str,
+    model_engine: Optional[str],
     model_format: Optional[str] = None,
     model_size_in_billions: Optional[Union[int, str]] = None,
     quantization: Optional[str] = None,
     peft_model_config: Optional[PeftModelConfig] = None,
-    is_local_deployment: bool = False,
     **kwargs,
 ) -> Tuple[Any, ModelDescription]:
     from .audio.core import create_audio_model_instance
@@ -69,11 +69,11 @@ def create_model_instance(
             devices,
             model_uid,
             model_name,
+            model_engine,
             model_format,
             model_size_in_billions,
             quantization,
             peft_model_config,
-            is_local_deployment,
             **kwargs,
         )
     elif model_type == "embedding":
xinference/model/image/__init__.py CHANGED
@@ -20,12 +20,19 @@ from itertools import chain
 from .core import (
     BUILTIN_IMAGE_MODELS,
     IMAGE_MODEL_DESCRIPTIONS,
+    MODEL_NAME_TO_REVISION,
     MODELSCOPE_IMAGE_MODELS,
     ImageModelFamilyV1,
     generate_image_description,
     get_cache_status,
     get_image_model_descriptions,
 )
+from .custom import (
+    CustomImageModelFamilyV1,
+    get_user_defined_images,
+    register_image,
+    unregister_image,
+)
 
 _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
 _model_spec_modelscope_json = os.path.join(
@@ -37,6 +44,9 @@ BUILTIN_IMAGE_MODELS.update(
         for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
     )
 )
+for model_name, model_spec in BUILTIN_IMAGE_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+
 MODELSCOPE_IMAGE_MODELS.update(
     dict(
         (spec["model_name"], ImageModelFamilyV1(**spec))
@@ -45,6 +55,8 @@ MODELSCOPE_IMAGE_MODELS.update(
         )
     )
 )
+for model_name, model_spec in MODELSCOPE_IMAGE_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
 
 # register model description
 for model_name, model_spec in chain(
@@ -52,4 +64,21 @@ for model_name, model_spec in chain(
 ):
     IMAGE_MODEL_DESCRIPTIONS.update(generate_image_description(model_spec))
 
+from ...constants import XINFERENCE_MODEL_DIR
+
+user_defined_image_dir = os.path.join(XINFERENCE_MODEL_DIR, "image")
+if os.path.isdir(user_defined_image_dir):
+    for f in os.listdir(user_defined_image_dir):
+        with codecs.open(
+            os.path.join(user_defined_image_dir, f), encoding="utf-8"
+        ) as fd:
+            user_defined_image_family = CustomImageModelFamilyV1.parse_obj(
+                json.load(fd)
+            )
+            register_image(user_defined_image_family, persist=False)
+
+for ud_image in get_user_defined_images():
+    IMAGE_MODEL_DESCRIPTIONS.update(generate_image_description(ud_image))
+
 del _model_spec_json
+del _model_spec_modelscope_json
xinference/model/image/core.py CHANGED
@@ -27,6 +27,7 @@ MAX_ATTEMPTS = 3
 
 logger = logging.getLogger(__name__)
 
+MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
 IMAGE_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
 BUILTIN_IMAGE_MODELS: Dict[str, "ImageModelFamilyV1"] = {}
 MODELSCOPE_IMAGE_MODELS: Dict[str, "ImageModelFamilyV1"] = {}
@@ -119,6 +120,11 @@ def generate_image_description(
 def match_diffusion(model_name: str) -> ImageModelFamilyV1:
     from ..utils import download_from_modelscope
     from . import BUILTIN_IMAGE_MODELS, MODELSCOPE_IMAGE_MODELS
+    from .custom import get_user_defined_images
+
+    for model_spec in get_user_defined_images():
+        if model_spec.model_name == model_name:
+            return model_spec
 
     if download_from_modelscope():
         if model_name in MODELSCOPE_IMAGE_MODELS: