xinference 1.7.0.post1__py3-none-any.whl → 1.7.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of xinference might be problematic.

Files changed (83)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +3 -4
  3. xinference/client/__init__.py +2 -0
  4. xinference/client/common.py +49 -2
  5. xinference/client/handlers.py +18 -0
  6. xinference/client/restful/async_restful_client.py +1760 -0
  7. xinference/client/restful/restful_client.py +74 -78
  8. xinference/core/media_interface.py +3 -1
  9. xinference/core/model.py +5 -4
  10. xinference/core/supervisor.py +10 -5
  11. xinference/core/worker.py +15 -14
  12. xinference/deploy/local.py +51 -9
  13. xinference/deploy/worker.py +5 -3
  14. xinference/device_utils.py +22 -3
  15. xinference/model/audio/fish_speech.py +23 -34
  16. xinference/model/audio/model_spec.json +4 -2
  17. xinference/model/audio/model_spec_modelscope.json +4 -2
  18. xinference/model/audio/utils.py +2 -2
  19. xinference/model/core.py +1 -0
  20. xinference/model/embedding/__init__.py +8 -8
  21. xinference/model/embedding/custom.py +6 -1
  22. xinference/model/embedding/embed_family.py +0 -41
  23. xinference/model/embedding/model_spec.json +10 -1
  24. xinference/model/embedding/model_spec_modelscope.json +10 -1
  25. xinference/model/embedding/sentence_transformers/core.py +30 -15
  26. xinference/model/flexible/core.py +1 -1
  27. xinference/model/flexible/launchers/__init__.py +2 -0
  28. xinference/model/flexible/launchers/image_process_launcher.py +1 -1
  29. xinference/model/flexible/launchers/modelscope_launcher.py +47 -0
  30. xinference/model/flexible/launchers/transformers_launcher.py +5 -5
  31. xinference/model/flexible/launchers/yolo_launcher.py +62 -0
  32. xinference/model/llm/__init__.py +7 -0
  33. xinference/model/llm/core.py +18 -1
  34. xinference/model/llm/llama_cpp/core.py +1 -1
  35. xinference/model/llm/llm_family.json +41 -1
  36. xinference/model/llm/llm_family.py +6 -0
  37. xinference/model/llm/llm_family_modelscope.json +43 -1
  38. xinference/model/llm/mlx/core.py +271 -18
  39. xinference/model/llm/mlx/distributed_models/__init__.py +13 -0
  40. xinference/model/llm/mlx/distributed_models/core.py +164 -0
  41. xinference/model/llm/mlx/distributed_models/deepseek_v3.py +75 -0
  42. xinference/model/llm/mlx/distributed_models/qwen2.py +82 -0
  43. xinference/model/llm/mlx/distributed_models/qwen3.py +82 -0
  44. xinference/model/llm/mlx/distributed_models/qwen3_moe.py +76 -0
  45. xinference/model/llm/reasoning_parser.py +12 -6
  46. xinference/model/llm/sglang/core.py +8 -4
  47. xinference/model/llm/transformers/chatglm.py +4 -1
  48. xinference/model/llm/transformers/core.py +4 -2
  49. xinference/model/llm/transformers/multimodal/cogagent.py +10 -4
  50. xinference/model/llm/transformers/multimodal/intern_vl.py +1 -1
  51. xinference/model/llm/utils.py +36 -17
  52. xinference/model/llm/vllm/core.py +142 -34
  53. xinference/model/llm/vllm/distributed_executor.py +96 -21
  54. xinference/model/llm/vllm/xavier/transfer.py +2 -2
  55. xinference/model/rerank/core.py +16 -9
  56. xinference/model/rerank/model_spec.json +3 -3
  57. xinference/model/rerank/model_spec_modelscope.json +3 -3
  58. xinference/web/ui/build/asset-manifest.json +3 -3
  59. xinference/web/ui/build/index.html +1 -1
  60. xinference/web/ui/build/static/js/main.9b12b7f9.js +3 -0
  61. xinference/web/ui/build/static/js/main.9b12b7f9.js.map +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/0fd4820d93f99509e80d8702dc3f6f8272424acab5608fa7c0e82cb1d3250a87.json +1 -0
  63. xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +1 -0
  66. xinference/web/ui/node_modules/.cache/babel-loader/f75545479c17fdfe2a00235fa4a0e9da1ae95e6b3caafba87ded92de6b0240e4.json +1 -0
  67. xinference/web/ui/src/locales/en.json +3 -0
  68. xinference/web/ui/src/locales/ja.json +3 -0
  69. xinference/web/ui/src/locales/ko.json +3 -0
  70. xinference/web/ui/src/locales/zh.json +3 -0
  71. {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/METADATA +4 -3
  72. {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/RECORD +77 -67
  73. xinference/web/ui/build/static/js/main.8a9e3ba0.js +0 -3
  74. xinference/web/ui/build/static/js/main.8a9e3ba0.js.map +0 -1
  75. xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +0 -1
  76. xinference/web/ui/node_modules/.cache/babel-loader/34cfbfb7836e136ba3261cfd411cc554bf99ba24b35dcceebeaa4f008cb3c9dc.json +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/c5c7c2cd1b863ce41adff2c4737bba06eef3a1acf28288cb83d992060f6b8923.json +0 -1
  78. xinference/web/ui/node_modules/.cache/babel-loader/cc97b49285d7717c63374766c789141a4329a04582ab32756d7e0e614d4c5c7f.json +0 -1
  79. /xinference/web/ui/build/static/js/{main.8a9e3ba0.js.LICENSE.txt → main.9b12b7f9.js.LICENSE.txt} +0 -0
  80. {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/WHEEL +0 -0
  81. {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/entry_points.txt +0 -0
  82. {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/licenses/LICENSE +0 -0
  83. {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/top_level.txt +0 -0
xinference/deploy/local.py CHANGED
@@ -17,6 +17,8 @@ import logging
 import multiprocessing
 import signal
 import sys
+import traceback
+from multiprocessing.connection import Connection
 from typing import Dict, Optional
 
 import xoscar as xo
@@ -25,6 +27,7 @@ from xoscar.utils import get_next_port
 from ..constants import (
     XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
     XINFERENCE_HEALTH_CHECK_INTERVAL,
+    XINFERENCE_HEALTH_CHECK_TIMEOUT,
 )
 from ..core.supervisor import SupervisorActor
 from .utils import health_check
@@ -33,11 +36,15 @@ from .worker import start_worker_components
 logger = logging.getLogger(__name__)
 
 
+READY = "ok"
+
+
 async def _start_local_cluster(
     address: str,
     metrics_exporter_host: Optional[str] = None,
     metrics_exporter_port: Optional[int] = None,
     logging_conf: Optional[Dict] = None,
+    conn: Optional[Connection] = None,
 ):
     from .utils import create_worker_actor_pool
 
@@ -59,6 +66,13 @@ async def _start_local_cluster(
             metrics_exporter_host=metrics_exporter_host,
             metrics_exporter_port=metrics_exporter_port,
         )
+        if conn:
+            try:
+                conn.send(READY)
+            except BrokenPipeError:
+                # connection may be gc collected,
+                # just ignore this error
+                pass
         await pool.join()
     except asyncio.CancelledError:
         if pool is not None:
@@ -70,22 +84,36 @@ def run(
     metrics_exporter_host: Optional[str] = None,
     metrics_exporter_port: Optional[int] = None,
     logging_conf: Optional[Dict] = None,
+    conn: Optional[Connection] = None,
 ):
     def sigterm_handler(signum, frame):
         sys.exit(0)
 
     signal.signal(signal.SIGTERM, sigterm_handler)
 
-    loop = asyncio.get_event_loop()
-    task = loop.create_task(
-        _start_local_cluster(
-            address=address,
-            metrics_exporter_host=metrics_exporter_host,
-            metrics_exporter_port=metrics_exporter_port,
-            logging_conf=logging_conf,
+    try:
+        loop = asyncio.get_event_loop()
+        task = loop.create_task(
+            _start_local_cluster(
+                address=address,
+                metrics_exporter_host=metrics_exporter_host,
+                metrics_exporter_port=metrics_exporter_port,
+                logging_conf=logging_conf,
+                conn=conn,
+            )
         )
-    )
-    loop.run_until_complete(task)
+        loop.run_until_complete(task)
+    except:
+        tb = traceback.format_exc()
+        if conn:
+            try:
+                conn.send(f"error: {tb}")
+            except BrokenPipeError:
+                # connection may be gc collected,
+                # just ignore this error
+                pass
+        # raise again in subprocess
+        raise
 
 
 def run_in_subprocess(
@@ -94,11 +122,25 @@ def run_in_subprocess(
     metrics_exporter_port: Optional[int] = None,
     logging_conf: Optional[Dict] = None,
 ) -> multiprocessing.Process:
+    parent_conn, child_conn = multiprocessing.Pipe()
     p = multiprocessing.Process(
         target=run,
         args=(address, metrics_exporter_host, metrics_exporter_port, logging_conf),
+        kwargs={"conn": child_conn},
     )
+    # Since Xoscar 0.7, we do not uses multiprocessing to create subpool any more,
+    # we should be able to use daemon here
+    p.daemon = True
    p.start()
+    if parent_conn.poll(timeout=XINFERENCE_HEALTH_CHECK_TIMEOUT):
+        msg = parent_conn.recv()
+        if msg != READY:
+            raise RuntimeError(f"Start service process failed during startup:\n{msg}")
+    else:
+        logger.info(
+            "No response from process after %s seconds", XINFERENCE_HEALTH_CHECK_TIMEOUT
+        )
+
     return p
 
 
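The run/run_in_subprocess changes above amount to a readiness handshake over a multiprocessing.Pipe: the subprocess reports "ok" (or a formatted traceback) and the parent waits for that message with a timeout before returning. A minimal standalone sketch of the same pattern, with a placeholder 10-second timeout standing in for the real XINFERENCE_HEALTH_CHECK_TIMEOUT constant:

import multiprocessing
import traceback

READY = "ok"
STARTUP_TIMEOUT = 10  # placeholder; xinference uses XINFERENCE_HEALTH_CHECK_TIMEOUT


def _child(conn):
    try:
        # ... start the actual services here ...
        conn.send(READY)  # report successful startup to the parent
    except Exception:
        conn.send(f"error: {traceback.format_exc()}")
        raise  # re-raise so the subprocess still fails visibly


def start_in_subprocess() -> multiprocessing.Process:
    parent_conn, child_conn = multiprocessing.Pipe()
    p = multiprocessing.Process(target=_child, args=(child_conn,), daemon=True)
    p.start()
    # poll() returns False if nothing arrives before the timeout
    if parent_conn.poll(timeout=STARTUP_TIMEOUT):
        msg = parent_conn.recv()
        if msg != READY:
            raise RuntimeError(f"Start service process failed during startup:\n{msg}")
    return p


if __name__ == "__main__":
    start_in_subprocess().join()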
 
xinference/deploy/worker.py CHANGED
@@ -21,7 +21,7 @@ import xoscar as xo
 from xoscar import MainActorPoolType
 
 from ..core.worker import WorkerActor
-from ..device_utils import gpu_count
+from ..device_utils import get_available_device_env_name, gpu_count
 
 logger = logging.getLogger(__name__)
 
@@ -34,8 +34,10 @@ async def start_worker_components(
     metrics_exporter_port: Optional[int],
 ):
     gpu_device_indices = []
-    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
-    if cuda_visible_devices is not None and cuda_visible_devices != "-1":
+    env_name = get_available_device_env_name()
+    cuda_visible_devices = os.environ.get(env_name) if env_name else None
+
+    if cuda_visible_devices and cuda_visible_devices != "-1":
         gpu_device_indices.extend([int(i) for i in cuda_visible_devices.split(",")])
     else:
         gpu_device_indices = list(range(gpu_count()))
xinference/device_utils.py CHANGED
@@ -17,10 +17,11 @@ from typing import Dict, Literal, Union
 
 import torch
 
-DeviceType = Literal["cuda", "mps", "xpu", "npu", "cpu"]
+DeviceType = Literal["cuda", "mps", "xpu", "npu", "mlu", "cpu"]
 DEVICE_TO_ENV_NAME = {
     "cuda": "CUDA_VISIBLE_DEVICES",
     "npu": "ASCEND_RT_VISIBLE_DEVICES",
+    "mlu": "MLU_VISIBLE_DEVICES",
 }
 
 
@@ -38,6 +39,16 @@ def is_npu_available() -> bool:
         return False
 
 
+def is_mlu_available() -> bool:
+    try:
+        import torch
+        import torch_mlu  # noqa: F401
+
+        return torch.mlu.is_available()
+    except ImportError:
+        return False
+
+
 def get_available_device() -> DeviceType:
     if torch.cuda.is_available():
         return "cuda"
@@ -47,6 +58,8 @@ def get_available_device() -> DeviceType:
         return "xpu"
     elif is_npu_available():
         return "npu"
+    elif is_mlu_available():
+        return "mlu"
     return "cpu"
 
 
@@ -59,6 +72,8 @@ def is_device_available(device: str) -> bool:
         return is_xpu_available()
     elif device == "npu":
         return is_npu_available()
+    elif device == "mlu":
+        return is_mlu_available()
     elif device == "cpu":
         return True
 
@@ -77,7 +92,7 @@ def move_model_to_available_device(model):
 def get_device_preferred_dtype(device: str) -> Union[torch.dtype, None]:
     if device == "cpu":
         return torch.float32
-    elif device == "cuda" or device == "mps" or device == "npu":
+    elif device == "cuda" or device == "mps" or device == "npu" or device == "mlu":
         return torch.float16
     elif device == "xpu":
         return torch.bfloat16
@@ -86,7 +101,7 @@ def get_device_preferred_dtype(device: str) -> Union[torch.dtype, None]:
 
 
 def is_hf_accelerate_supported(device: str) -> bool:
-    return device == "cuda" or device == "xpu" or device == "npu"
+    return device == "cuda" or device == "xpu" or device == "npu" or device == "mlu"
 
 
 def empty_cache():
@@ -98,6 +113,8 @@ def empty_cache():
         torch.xpu.empty_cache()
     if is_npu_available():
         torch.npu.empty_cache()
+    if is_mlu_available():
+        torch.mlu.empty_cache()
 
 
 def get_available_device_env_name():
@@ -120,6 +137,8 @@ def gpu_count():
         return torch.xpu.device_count()
     elif is_npu_available():
         return torch.npu.device_count()
+    elif is_mlu_available():
+        return torch.mlu.device_count()
     else:
         return 0
 
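With the MLU entries in place, Cambricon devices flow through the same helpers as CUDA and NPU, and the worker change above resolves the right visible-devices variable automatically. A small usage sketch (illustrative only; MLU detection requires torch_mlu to be installed):

import os

from xinference.device_utils import (
    get_available_device,
    get_available_device_env_name,
    gpu_count,
)

# e.g. "MLU_VISIBLE_DEVICES" on a Cambricon host, "CUDA_VISIBLE_DEVICES" on NVIDIA
env_name = get_available_device_env_name()
visible = os.environ.get(env_name) if env_name else None

if visible and visible != "-1":
    indices = [int(i) for i in visible.split(",")]
else:
    indices = list(range(gpu_count()))

print(get_available_device(), indices)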
 
xinference/model/audio/fish_speech.py CHANGED
@@ -123,9 +123,10 @@ class FishSpeechModel:
             logger.warning("Fish speech does not support setting voice: %s.", voice)
         if speed != 1.0:
             logger.warning("Fish speech does not support setting speed: %s.", speed)
-        import torchaudio
         from tools.schema import ServeReferenceAudio, ServeTTSRequest
 
+        from .utils import audio_stream_generator, audio_to_bytes
+
         prompt_speech = kwargs.get("prompt_speech")
         prompt_text = kwargs.get("prompt_text", kwargs.get("reference_text", ""))
         if prompt_speech is not None:
@@ -153,40 +154,28 @@
 
         if stream:
 
-            def _stream_generator():
-                with BytesIO() as out:
-                    writer = torchaudio.io.StreamWriter(out, format=response_format)
-                    writer.add_audio_stream(
-                        sample_rate=self._model.spec_transform.sample_rate,
-                        num_channels=1,
-                    )
-                    i = 0
-                    last_pos = 0
-                    with writer.open():
-                        for chunk in result:
-                            if chunk.code == "final":
-                                continue
-                            chunk = chunk.audio[1]
-                            if chunk is not None:
-                                chunk = chunk.reshape((chunk.shape[0], 1))
-                                trans_chunk = torch.from_numpy(chunk)
-                                writer.write_audio_chunk(i, trans_chunk)
-                                new_last_pos = out.tell()
-                                if new_last_pos != last_pos:
-                                    out.seek(last_pos)
-                                    encoded_bytes = out.read()
-                                    yield encoded_bytes
-                                    last_pos = new_last_pos
-
-            return _stream_generator()
+            def _gen_chunk():
+                for chunk in result:
+                    if chunk.code == "final":
+                        continue
+                    chunk = chunk.audio[1]
+                    if chunk is not None:
+                        yield chunk
+
+            return audio_stream_generator(
+                response_format=response_format,
+                sample_rate=self._model.spec_transform.sample_rate,
+                output_generator=_gen_chunk(),
+                output_chunk_transformer=lambda c: torch.from_numpy(
+                    c.reshape((c.shape[0], 1))
+                ),
+            )
         else:
             result = list(result)
             sample_rate, audio = result[0].audio
             audio = np.array([audio])
-
-            # Save the generated audio
-            with BytesIO() as out:
-                torchaudio.save(
-                    out, torch.from_numpy(audio), sample_rate, format=response_format
-                )
-                return out.getvalue()
+            return audio_to_bytes(
+                response_format=response_format,
+                sample_rate=sample_rate,
+                tensor=torch.from_numpy(audio),
+            )
xinference/model/audio/model_spec.json CHANGED
@@ -280,7 +280,7 @@
             "hotword": "",
             "batch_size_s": 300
         }
-    },
+    },
     {
         "model_name": "ChatTTS",
         "model_family": "ChatTTS",
@@ -329,6 +329,7 @@
         "multilingual": true,
         "virtualenv": {
             "packages": [
+                "librosa",
                 "tiktoken",
                 "lightning>=2.0.0",
                 "hydra-core>=1.3.2",
@@ -340,7 +341,8 @@
                 "HyperPyYAML",
                 "onnxruntime>=1.16.0",
                 "pyworld>=0.3.4",
-                "numpy==1.26.4",
+                "WeTextProcessing<1.0.4",
+                "#system_numpy#",
                 "#system_torch#"
             ]
         }
xinference/model/audio/model_spec_modelscope.json CHANGED
@@ -129,7 +129,7 @@
             "hotword": "",
             "batch_size_s": 300
         }
-    },
+    },
     {
         "model_name": "ChatTTS",
         "model_family": "ChatTTS",
@@ -183,6 +183,7 @@
         "multilingual": true,
         "virtualenv": {
             "packages": [
+                "librosa",
                 "tiktoken",
                 "lightning>=2.0.0",
                 "hydra-core>=1.3.2",
@@ -194,7 +195,8 @@
                 "HyperPyYAML",
                 "onnxruntime>=1.16.0",
                 "pyworld>=0.3.4",
-                "numpy==1.26.4",
+                "WeTextProcessing<1.0.4",
+                "#system_numpy#",
                 "#system_torch#"
             ]
         }
xinference/model/audio/utils.py CHANGED
@@ -14,7 +14,7 @@
 
 import io
 import logging
-import types
+import typing
 import wave
 from collections.abc import Callable
 
@@ -67,7 +67,7 @@ def ensure_sample_rate(
 def audio_stream_generator(
     response_format: str,
     sample_rate: int,
-    output_generator: types.GeneratorType,
+    output_generator: typing.Generator[typing.Any, None, None],
     output_chunk_transformer: Callable,
 ):
     import torch
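The fish_speech change above now consumes this helper directly; for reference, a hedged sketch of calling it with synthetic chunks (the sine-wave generator and file handling are illustrative, not part of xinference):

import numpy as np
import torch

from xinference.model.audio.utils import audio_stream_generator


def fake_chunks(sample_rate: int = 22050, seconds: int = 3):
    # yield a few one-second mono sine-wave chunks as float32 numpy arrays
    t = np.linspace(0, 1, sample_rate, endpoint=False, dtype=np.float32)
    for i in range(seconds):
        yield 0.2 * np.sin(2 * np.pi * 440.0 * (i + 1) * t)


stream = audio_stream_generator(
    response_format="mp3",
    sample_rate=22050,
    output_generator=fake_chunks(),
    output_chunk_transformer=lambda c: torch.from_numpy(c.reshape((c.shape[0], 1))),
)

with open("out.mp3", "wb") as f:
    for encoded_bytes in stream:  # yields encoded audio incrementally
        f.write(encoded_bytes)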
xinference/model/core.py CHANGED
@@ -170,3 +170,4 @@ class VirtualEnvSettings(BaseModel):
     extra_index_url: Optional[str] = None
     find_links: Optional[str] = None
     trusted_host: Optional[str] = None
+    no_build_isolation: Optional[bool] = None
xinference/model/embedding/__init__.py CHANGED
@@ -119,14 +119,6 @@ def _install():
             generate_embedding_description(model_spec)
         )
 
-    register_custom_model()
-
-    # register model description
-    for ud_embedding in get_user_defined_embeddings():
-        EMBEDDING_MODEL_DESCRIPTIONS.update(
-            generate_embedding_description(ud_embedding)
-        )
-
     from .flag.core import FlagEmbeddingModel
     from .sentence_transformers.core import SentenceTransformerEmbeddingModel
     from .vllm.core import VLLMEmbeddingModel
@@ -144,5 +136,13 @@ def _install():
     for model_spec in model_infos.values():
         generate_engine_config_by_model_name(model_spec)
 
+    register_custom_model()
+
+    # register model description
+    for ud_embedding in get_user_defined_embeddings():
+        EMBEDDING_MODEL_DESCRIPTIONS.update(
+            generate_embedding_description(ud_embedding)
+        )
+
     del _model_spec_json
     del _model_spec_modelscope_json
xinference/model/embedding/custom.py CHANGED
@@ -42,7 +42,11 @@ def get_user_defined_embeddings() -> List[EmbeddingModelSpec]:
 def register_embedding(model_spec: CustomEmbeddingModelSpec, persist: bool):
     from ...constants import XINFERENCE_MODEL_DIR
     from ..utils import is_valid_model_name, is_valid_model_uri
-    from . import BUILTIN_EMBEDDING_MODELS, MODELSCOPE_EMBEDDING_MODELS
+    from . import (
+        BUILTIN_EMBEDDING_MODELS,
+        MODELSCOPE_EMBEDDING_MODELS,
+        generate_engine_config_by_model_name,
+    )
 
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")
@@ -63,6 +67,7 @@ def register_embedding(model_spec: CustomEmbeddingModelSpec, persist: bool):
         )
 
     UD_EMBEDDINGS.append(model_spec)
+    generate_engine_config_by_model_name(model_spec)
 
     if persist:
         persist_path = os.path.join(
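With this fix, registering a custom embedding model also generates its engine configuration, so the model is immediately launchable with a matching engine. A hedged client-side sketch (the spec fields follow the documented custom-embedding format; the path, dimensions, and endpoint are illustrative):

import json

from xinference.client import Client

custom_spec = {
    "model_name": "my-custom-embedding",
    "dimensions": 768,
    "max_tokens": 512,
    "language": ["en"],
    "model_uri": "file:///data/models/my-custom-embedding",  # illustrative local path
}

client = Client("http://127.0.0.1:9997")
# register, then launch as usual; the engine config is generated at registration time
client.register_model(model_type="embedding", model=json.dumps(custom_spec), persist=False)
model_uid = client.launch_model(model_name="my-custom-embedding", model_type="embedding")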
xinference/model/embedding/embed_family.py CHANGED
@@ -13,11 +13,8 @@
 # limitations under the License.
 
 import logging
-from threading import Lock
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Type
 
-from ..utils import is_valid_model_name
-
 if TYPE_CHECKING:
     from .core import EmbeddingModel, EmbeddingModelSpec
 
@@ -71,44 +68,6 @@ def match_embedding(
 # { embedding model name -> { engine name -> engine params } }
 EMBEDDING_ENGINES: Dict[str, Dict[str, List[Dict[str, Type["EmbeddingModel"]]]]] = {}
 SUPPORTED_ENGINES: Dict[str, List[Type["EmbeddingModel"]]] = {}
-UD_EMBEDDING_FAMILIES_LOCK = Lock()
-# user defined embedding models
-UD_EMBEDDING_SPECS: Dict[str, "EmbeddingModelSpec"] = {}
-
-
-def register_embedding(custom_embedding_spec: "EmbeddingModelSpec", persist: bool):
-    from ..utils import is_valid_model_uri
-    from . import generate_engine_config_by_model_name
-
-    if not is_valid_model_name(custom_embedding_spec.model_name):
-        raise ValueError(f"Invalid model name {custom_embedding_spec.model_name}.")
-
-    model_uri = custom_embedding_spec.model_uri
-    if model_uri and not is_valid_model_uri(model_uri):
-        raise ValueError(f"Invalid model URI {model_uri}.")
-
-    with UD_EMBEDDING_FAMILIES_LOCK:
-        if (
-            custom_embedding_spec.model_name in BUILTIN_EMBEDDING_MODELS
-            or custom_embedding_spec.model_name in MODELSCOPE_EMBEDDING_MODELS
-            or custom_embedding_spec.model_name in UD_EMBEDDING_SPECS
-        ):
-            raise ValueError(
-                f"Model name conflicts with existing model {custom_embedding_spec.model_name}"
-            )
-
-        UD_EMBEDDING_SPECS[custom_embedding_spec.model_name] = custom_embedding_spec
-        generate_engine_config_by_model_name(custom_embedding_spec)
-
-
-# TODO: add persist feature
-def unregister_embedding(custom_embedding_spec: "EmbeddingModelSpec"):
-    with UD_EMBEDDING_FAMILIES_LOCK:
-        model_name = custom_embedding_spec.model_name
-        if model_name in UD_EMBEDDING_SPECS:
-            del UD_EMBEDDING_SPECS[model_name]
-        if model_name in EMBEDDING_ENGINES:
-            del EMBEDDING_ENGINES[model_name]
 
 
 def check_engine_by_model_name_and_engine(
xinference/model/embedding/model_spec.json CHANGED
@@ -275,6 +275,15 @@
         "dimensions": 1024,
         "max_tokens": 8192,
         "language": ["89 languages supported"],
-        "model_id": "jinaai/jina-clip-v2"
+        "model_id": "jinaai/jina-clip-v2",
+        "virtualenv": {
+            "packages": [
+                "sentence_transformers",
+                "transformers==4.51.3",
+                "xformers",
+                "flash_attn==2.7.4 ; sys_platform=='linux'"
+            ],
+            "no_build_isolation": true
+        }
     }
 ]
xinference/model/embedding/model_spec_modelscope.json CHANGED
@@ -279,6 +279,15 @@
         "max_tokens": 8192,
         "language": ["89 languages supported"],
         "model_id": "jinaai/jina-clip-v2",
-        "model_hub": "modelscope"
+        "model_hub": "modelscope",
+        "virtualenv": {
+            "packages": [
+                "sentence_transformers",
+                "transformers==4.51.3",
+                "xformers",
+                "flash_attn==2.7.3 ; sys_platform=='linux'"
+            ],
+            "no_build_isolation": true
+        }
     }
 ]
xinference/model/embedding/sentence_transformers/core.py CHANGED
@@ -90,9 +90,10 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
         elif "qwen3" in self._model_spec.model_name.lower():
             # qwen3 embedding
             flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
+            flash_attn_enabled = self._kwargs.get("enable_flash_attn", True)
             model_kwargs = {"device_map": "auto"}
             tokenizer_kwargs = {}
-            if flash_attn_installed:
+            if flash_attn_installed and flash_attn_enabled:
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_kwargs["torch_dtype"] = "bfloat16"
                 tokenizer_kwargs["padding_side"] = "left"
@@ -254,8 +255,14 @@
                 # when batching, the attention mask 1 means there is a token
                 # thus we just sum up it to get the total number of tokens
                 if "clip" in self._model_spec.model_name.lower():
-                    all_token_nums += features["input_ids"].numel()
-                    all_token_nums += features["pixel_values"].numel()
+                    if "input_ids" in features and hasattr(
+                        features["input_ids"], "numel"
+                    ):
+                        all_token_nums += features["input_ids"].numel()
+                    if "pixel_values" in features and hasattr(
+                        features["pixel_values"], "numel"
+                    ):
+                        all_token_nums += features["pixel_values"].numel()
                 else:
                     all_token_nums += features["attention_mask"].sum().item()
 
@@ -340,24 +347,32 @@
                     img = Image.open(image_data)
                     return img
 
-            objs: list[dict[str, str]] = []
-            for item in sentences:
-                if isinstance(item, dict):
-                    if item.get("text") is not None:
-                        objs.append(item["text"])
-                    elif item.get("image") is not None:
-                        if re.match(r"^data:image/.+;base64,", item["image"]):
-                            image = base64_to_image(item["image"])
-                            objs.append(image)
+            objs: list[str] = []
+            if isinstance(sentences, str):
+                objs.append(sentences)
+            else:
+                for item in sentences:
+                    if isinstance(item, dict):
+                        if item.get("text") is not None:
+                            objs.append(item["text"])
+                        elif item.get("image") is not None:
+                            if re.match(r"^data:image/.+;base64,", item["image"]):
+                                image = base64_to_image(item["image"])
+                                objs.append(image)
+                            else:
+                                objs.append(item["image"])
                         else:
-                            objs.append(item["image"])
+                            raise ValueError("Please check the input data.")
+                    elif isinstance(item, str):
+                        objs.append(item)
                     else:
-                        logger.error("Please check the input data.")
+                        raise ValueError("Please check the input data.")
+
 
             all_embeddings, all_token_nums = encode(
                 self._model,
                 objs,
                 convert_to_numpy=False,
-                **self._kwargs,
+                **kwargs,
             )
         else:
             all_embeddings, all_token_nums = encode(
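The new enable_flash_attn model kwarg lets Qwen3 embedding deployments opt out of flash_attention_2 even when flash_attn is installed (it defaults to enabled). A hedged launch sketch via the Python client; the endpoint and model name are illustrative:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model_uid = client.launch_model(
    model_name="Qwen3-Embedding-0.6B",  # illustrative qwen3 embedding model name
    model_type="embedding",
    enable_flash_attn=False,  # fall back to the default attention implementation
)
embedding_model = client.get_model(model_uid)
print(embedding_model.create_embedding("hello xinference"))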
xinference/model/flexible/core.py CHANGED
@@ -189,7 +189,7 @@ class FlexibleModel:
         Load the model.
         """
 
-    def infer(self, **kwargs):
+    def infer(self, *args, **kwargs):
         """
         Call model to inference.
         """
xinference/model/flexible/launchers/__init__.py CHANGED
@@ -13,4 +13,6 @@
 # limitations under the License.
 
 from .image_process_launcher import launcher as image_process
+from .modelscope_launcher import launcher as modelscope
 from .transformers_launcher import launcher as transformers
+from .yolo_launcher import launcher as yolo
xinference/model/flexible/launchers/image_process_launcher.py CHANGED
@@ -23,7 +23,7 @@ from ..core import FlexibleModel, FlexibleModelSpec
 
 
 class ImageRemoveBackgroundModel(FlexibleModel):
-    def infer(self, **kwargs):
+    def infer(self, *args, **kwargs):
         invert = kwargs.get("invert", False)
         b64_image: str = kwargs.get("image")  # type: ignore
         only_mask = kwargs.pop("only_mask", True)
xinference/model/flexible/launchers/modelscope_launcher.py ADDED
@@ -0,0 +1,47 @@
+# Copyright 2022-2025 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..core import FlexibleModel, FlexibleModelSpec
+
+
+class ModelScopePipelineModel(FlexibleModel):
+    def load(self):
+        # we have to move import here,
+        # modelscope cannot be compatible with datasets>3.2.0
+        # if put outside, it will just raise error
+        # when enabled virtualenv,
+        # we can make sure mdoelscope works well
+        from modelscope.pipelines import pipeline
+
+        config = dict(self.config or {})
+        if self._device:
+            config["device"] = self._device
+        self._pipeline = pipeline(model=self._model_path, **config)
+
+    def infer(self, *args, **kwargs):
+        return self._pipeline(*args, **kwargs)
+
+
+def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> FlexibleModel:
+    device = kwargs.get("device")
+    if not kwargs.get("task"):
+        raise ValueError("modelscope launcher requires `task`")
+
+    model_path = model_spec.model_uri
+    if model_path is None:
+        raise ValueError("model_path required")
+
+    return ModelScopePipelineModel(
+        model_uid=model_uid, model_path=model_path, device=device, config=kwargs
+    )
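Under the hood the launcher builds a ModelScope pipeline from the spec's model_uri plus the kwargs it receives (task, device, and so on), so the call it ends up making is roughly equivalent to this snippet; the task name and model path are illustrative:

from modelscope.pipelines import pipeline

# "task" arrives through the launcher kwargs; the model path comes from the spec's model_uri
pipe = pipeline(task="text-classification", model="/data/models/my-modelscope-model", device="cpu")
print(pipe("xinference makes model serving easy"))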