xinference 1.7.0.post1__py3-none-any.whl → 1.7.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +3 -4
- xinference/client/__init__.py +2 -0
- xinference/client/common.py +49 -2
- xinference/client/handlers.py +18 -0
- xinference/client/restful/async_restful_client.py +1760 -0
- xinference/client/restful/restful_client.py +74 -78
- xinference/core/media_interface.py +3 -1
- xinference/core/model.py +5 -4
- xinference/core/supervisor.py +10 -5
- xinference/core/worker.py +15 -14
- xinference/deploy/local.py +51 -9
- xinference/deploy/worker.py +5 -3
- xinference/device_utils.py +22 -3
- xinference/model/audio/fish_speech.py +23 -34
- xinference/model/audio/model_spec.json +4 -2
- xinference/model/audio/model_spec_modelscope.json +4 -2
- xinference/model/audio/utils.py +2 -2
- xinference/model/core.py +1 -0
- xinference/model/embedding/__init__.py +8 -8
- xinference/model/embedding/custom.py +6 -1
- xinference/model/embedding/embed_family.py +0 -41
- xinference/model/embedding/model_spec.json +10 -1
- xinference/model/embedding/model_spec_modelscope.json +10 -1
- xinference/model/embedding/sentence_transformers/core.py +30 -15
- xinference/model/flexible/core.py +1 -1
- xinference/model/flexible/launchers/__init__.py +2 -0
- xinference/model/flexible/launchers/image_process_launcher.py +1 -1
- xinference/model/flexible/launchers/modelscope_launcher.py +47 -0
- xinference/model/flexible/launchers/transformers_launcher.py +5 -5
- xinference/model/flexible/launchers/yolo_launcher.py +62 -0
- xinference/model/llm/__init__.py +7 -0
- xinference/model/llm/core.py +18 -1
- xinference/model/llm/llama_cpp/core.py +1 -1
- xinference/model/llm/llm_family.json +41 -1
- xinference/model/llm/llm_family.py +6 -0
- xinference/model/llm/llm_family_modelscope.json +43 -1
- xinference/model/llm/mlx/core.py +271 -18
- xinference/model/llm/mlx/distributed_models/__init__.py +13 -0
- xinference/model/llm/mlx/distributed_models/core.py +164 -0
- xinference/model/llm/mlx/distributed_models/deepseek_v3.py +75 -0
- xinference/model/llm/mlx/distributed_models/qwen2.py +82 -0
- xinference/model/llm/mlx/distributed_models/qwen3.py +82 -0
- xinference/model/llm/mlx/distributed_models/qwen3_moe.py +76 -0
- xinference/model/llm/reasoning_parser.py +12 -6
- xinference/model/llm/sglang/core.py +8 -4
- xinference/model/llm/transformers/chatglm.py +4 -1
- xinference/model/llm/transformers/core.py +4 -2
- xinference/model/llm/transformers/multimodal/cogagent.py +10 -4
- xinference/model/llm/transformers/multimodal/intern_vl.py +1 -1
- xinference/model/llm/utils.py +36 -17
- xinference/model/llm/vllm/core.py +142 -34
- xinference/model/llm/vllm/distributed_executor.py +96 -21
- xinference/model/llm/vllm/xavier/transfer.py +2 -2
- xinference/model/rerank/core.py +16 -9
- xinference/model/rerank/model_spec.json +3 -3
- xinference/model/rerank/model_spec_modelscope.json +3 -3
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.9b12b7f9.js +3 -0
- xinference/web/ui/build/static/js/main.9b12b7f9.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0fd4820d93f99509e80d8702dc3f6f8272424acab5608fa7c0e82cb1d3250a87.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f75545479c17fdfe2a00235fa4a0e9da1ae95e6b3caafba87ded92de6b0240e4.json +1 -0
- xinference/web/ui/src/locales/en.json +3 -0
- xinference/web/ui/src/locales/ja.json +3 -0
- xinference/web/ui/src/locales/ko.json +3 -0
- xinference/web/ui/src/locales/zh.json +3 -0
- {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/METADATA +4 -3
- {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/RECORD +77 -67
- xinference/web/ui/build/static/js/main.8a9e3ba0.js +0 -3
- xinference/web/ui/build/static/js/main.8a9e3ba0.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/34cfbfb7836e136ba3261cfd411cc554bf99ba24b35dcceebeaa4f008cb3c9dc.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c5c7c2cd1b863ce41adff2c4737bba06eef3a1acf28288cb83d992060f6b8923.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/cc97b49285d7717c63374766c789141a4329a04582ab32756d7e0e614d4c5c7f.json +0 -1
- /xinference/web/ui/build/static/js/{main.8a9e3ba0.js.LICENSE.txt → main.9b12b7f9.js.LICENSE.txt} +0 -0
- {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/WHEEL +0 -0
- {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/entry_points.txt +0 -0
- {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/top_level.txt +0 -0
xinference/model/llm/mlx/core.py
CHANGED
@@ -11,14 +11,32 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import asyncio
+import concurrent.futures
+import importlib
 import importlib.util
 import logging
+import pathlib
 import platform
 import sys
+import threading
 import time
 import uuid
 from dataclasses import dataclass, field
-from typing import
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    TypedDict,
+    Union,
+)
+
+import xoscar as xo

 from ....fields import max_tokens_field
 from ....types import (
@@ -29,7 +47,7 @@ from ....types import (
     CompletionUsage,
     LoRA,
 )
-from ..core import LLM
+from ..core import LLM, chat_context_var
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from ..utils import (
     DEEPSEEK_TOOL_CALL_FAMILY,
@@ -46,6 +64,10 @@ class MLXModelConfig(TypedDict, total=False):
     max_gpu_memory: str
     trust_remote_code: bool
     reasoning_content: bool
+    # distributed
+    address: Optional[str]
+    shard: Optional[int]
+    n_worker: Optional[int]


 class MLXGenerateConfig(TypedDict, total=False):
@@ -71,6 +93,8 @@ class PromptCache:


 class MLXModel(LLM):
+    _rank_to_addresses: Optional[Dict[int, str]]
+
     def __init__(
         self,
         model_uid: str,
@@ -84,10 +108,43 @@ class MLXModel(LLM):
         super().__init__(model_uid, model_family, model_spec, quantization, model_path)
         self._use_fast_tokenizer = True
         self._model_config: MLXModelConfig = self._sanitize_model_config(model_config)
+        # for distributed
+        assert model_config is not None
+        self._address = model_config.pop("address", None)
+        self._n_worker = model_config.pop("n_worker", 1)
+        self._shard = model_config.pop("shard", 0)
+        self._driver_info = model_config.pop("driver_info", None)  # type: ignore
+        self._rank_to_addresses = None
+        self._loading_thread = None
+        self._loading_error = None
+        self._all_worker_started = asyncio.Event()
         self._max_kv_size = None
         self._prompt_cache = None
         if peft_model is not None:
             raise ValueError("MLX engine has not supported lora yet")
+        # used to call async
+        self._loop = None
+
+    def set_loop(self, loop: asyncio.AbstractEventLoop):
+        # loop will be passed into ModelWrapper,
+        # to call aynsc method with asyncio.run_coroutine_threadsafe
+        self._loop = loop  # type: ignore
+
+    @property
+    def driver_info(self) -> Optional[dict]:
+        return self._driver_info
+
+    def set_shard_info(self, shard: int, address: str):
+        # set shard info to rank 0
+        if self._rank_to_addresses is None:
+            self._rank_to_addresses = {}
+        self._rank_to_addresses[shard] = address
+        if len(self._rank_to_addresses) == self._n_worker:
+            self._all_worker_started.set()
+
+    async def get_rank_addresses(self) -> Optional[Dict[int, str]]:
+        await self._all_worker_started.wait()
+        return self._rank_to_addresses

     def _sanitize_model_config(
         self, model_config: Optional[MLXModelConfig]
@@ -158,6 +215,97 @@ class MLXModel(LLM):
             tokenizer.add_eos_token(stop_token_id)
         return model, tokenizer

+    def _load_model_shard(self, **kwargs):
+        try:
+            import mlx.core as mx
+            from mlx_lm.utils import load_model, load_tokenizer
+        except ImportError:
+            error_message = "Failed to import module 'mlx_lm'"
+            installation_guide = [
+                "Please make sure 'mlx_lm' is installed. ",
+                "You can install it by `pip install mlx_lm`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        # Ensure some attributes correctly inited by model actor
+        assert (
+            self._loop is not None and self._rank_to_addresses is not None
+        ), "Service not started correctly"
+
+        tokenizer_config = dict(
+            use_fast=self._use_fast_tokenizer,
+            trust_remote_code=kwargs["trust_remote_code"],
+            revision=kwargs["revision"],
+        )
+        logger.debug(
+            "loading model with tokenizer config: %s, model config: %s, shard: %d, n_worker: %d",
+            tokenizer_config,
+            self._model_config,
+            self._shard,
+            self._n_worker,
+        )
+
+        cache_limit_gb = kwargs.get("cache_limit_gb", None)
+        if cache_limit_gb:
+            logger.debug(f"Setting cache limit to {cache_limit_gb} GB")
+            mx.metal.set_cache_limit(cache_limit_gb * 1024 * 1024 * 1024)
+
+        self._max_kv_size = kwargs.get("max_kv_size", None)
+        self._prompt_cache = PromptCache()
+
+        self._model, config = load_model(
+            pathlib.Path(self.model_path),
+            lazy=True,
+            get_model_classes=self._get_classes,
+        )
+        model = self._model.model
+        model.rank = self._shard
+        model.world_size = self._n_worker
+        model.model_uid = self.model_uid
+        model.loop = self._loop
+        model.address = self._address
+        model.rank_to_addresses = self._rank_to_addresses
+
+        # create actors and so forth
+        model.prepare()
+        # real load the partial weights
+        model.pipeline()
+        mx.eval(model.parameters())
+
+        self._tokenizer = load_tokenizer(
+            pathlib.Path(self.model_path),
+            tokenizer_config,
+            eos_token_ids=config.get("eos_token_id", None),
+        )
+
+    @staticmethod
+    def _get_classes(config: dict):
+        """
+        Retrieve the model and model args classes based on the configuration
+        that supported distributed inference.
+
+        Args:
+            config (dict): The model configuration.
+
+        Returns:
+            A tuple containing the Model class and the ModelArgs class.
+        """
+        from mlx_lm.utils import MODEL_REMAPPING
+
+        model_type = config["model_type"]
+        model_type = MODEL_REMAPPING.get(model_type, model_type)
+        try:
+            arch = importlib.import_module(
+                f"xinference.model.llm.mlx.distributed_models.{model_type}"
+            )
+        except ImportError:
+            msg = f"Model type {model_type} not supported for distributed inference."
+            logger.error(msg)
+            raise ValueError(msg)
+
+        return arch.Model, arch.ModelArgs
+
     def load(self):
         reasoning_content = self._model_config.pop("reasoning_content")
         enable_thinking = self._model_config.pop("enable_thinking", True)
@@ -172,7 +320,49 @@ class MLXModel(LLM):
         kwargs["trust_remote_code"] = self._model_config.get("trust_remote_code")
         kwargs["cache_limit_gb"] = self._model_config.pop("cache_limit_gb", None)

-
+        if self._n_worker <= 1:
+            self._model, self._tokenizer = self._load_model(**kwargs)
+        else:
+
+            def _load():
+                try:
+                    if self._shard == 0:
+                        self._driver_info = {"address": self._address}
+                        self.set_shard_info(0, self._address)
+                    else:
+                        assert self._driver_info is not None
+                        driver_address = self._driver_info["address"]
+
+                        async def wait_for_all_shards():
+                            model_ref = await xo.actor_ref(
+                                address=driver_address, uid=self.raw_model_uid
+                            )
+                            # set shard info
+                            await model_ref.set_shard_info(self._shard, self._address)
+                            # wait for all shards
+                            self._rank_to_addresses = (
+                                await model_ref.get_rank_addresses()
+                            )
+
+                        asyncio.run_coroutine_threadsafe(
+                            wait_for_all_shards(), self._loop
+                        ).result()
+
+                    self._load_model_shard(**kwargs)
+                except:
+                    logger.exception("Loading mlx shard model failed")
+                    self._loading_error = sys.exc_info()
+
+            # distributed inference
+            self._loading_thread = threading.Thread(target=_load)
+            self._loading_thread.start()
+
+    def wait_for_load(self):
+        if self._loading_thread:
+            self._loading_thread.join()
+        if self._loading_error:
+            _, err, tb = self._loading_error
+            raise err.with_traceback(tb)

     @classmethod
     def check_lib(cls) -> bool:
@@ -369,20 +559,57 @@ class MLXModel(LLM):
             )
             yield completion_chunk, completion_usage

+    def _run_non_drivers(
+        self, method: str, stream: bool, *args, **kwargs
+    ) -> Optional[concurrent.futures.Future]:
+        assert self._n_worker is not None and self._shard is not None
+        if self._n_worker == 1 or self._shard > 0:
+            # only run for distributed driver
+            return None
+
+        async def run_other_shard(shard: int):
+            assert self._rank_to_addresses is not None
+            address = self._rank_to_addresses[shard]
+            model_actor_ref = await xo.actor_ref(
+                address=address, uid=self.raw_model_uid
+            )
+            # we don't actually need to get the result from shard >= 1
+            if stream:
+                async for _ in await getattr(model_actor_ref, method)(*args, **kwargs):
+                    pass
+            else:
+                await getattr(model_actor_ref, method)(*args, **kwargs)
+
+        async def run_non_driver_shards():
+            logger.debug("Start to run non driver %s", method)
+            coros = []
+            for rank in range(1, self._n_worker):
+                coros.append(run_other_shard(rank))
+            await asyncio.gather(*coros)
+
+        assert self._loop is not None
+        return asyncio.run_coroutine_threadsafe(run_non_driver_shards(), self._loop)
+
     def generate(
         self,
         prompt: Union[str, Dict[str, Any]],
         generate_config: Optional[MLXGenerateConfig] = None,
+        from_chat: bool = False,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         def generator_wrapper(
-            prompt: Union[str, Dict[str, Any]],
+            prompt: Union[str, Dict[str, Any]],
+            generate_config: MLXGenerateConfig,
+            cb: Callable,
         ) -> Iterator[CompletionChunk]:
-
-
-
-
-
-
+            try:
+                for completion_chunk, completion_usage in self._generate_stream(
+                    prompt,
+                    generate_config,
+                ):
+                    completion_chunk["usage"] = completion_usage
+                    yield completion_chunk
+            finally:
+                cb()

         logger.debug(
             "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
@@ -394,6 +621,9 @@ class MLXModel(LLM):
         assert self._tokenizer is not None

         stream = generate_config.get("stream", False)
+        fut = self._run_non_drivers(
+            "generate", stream, prompt, generate_config=generate_config
+        )
         if not stream:
             for completion_chunk, completion_usage in self._generate_stream(
                 prompt,
@@ -408,9 +638,18 @@ class MLXModel(LLM):
                 choices=completion_chunk["choices"],
                 usage=completion_usage,
             )
-
+            try:
+                return completion
+            finally:
+                if fut:
+                    fut.result()
         else:
-
+
+            def finish_callback():
+                if fut:
+                    fut.result()
+
+            return generator_wrapper(prompt, generate_config, finish_callback)


 class MLXChatModel(MLXModel, ChatModelMixin):
@@ -452,9 +691,14 @@ class MLXChatModel(MLXModel, ChatModelMixin):
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         model_family = self.model_family.model_family or self.model_family.model_name
         tools = generate_config.pop("tools", []) if generate_config else None
-
-            self._get_chat_template_kwargs_from_generate_config(
+        chat_template_kwargs = (
+            self._get_chat_template_kwargs_from_generate_config(
+                generate_config, self.reasoning_parser
+            )
+            or {}
         )
+        chat_context_var.set(chat_template_kwargs)
+        full_context_kwargs = chat_template_kwargs.copy()
         if tools:
             if (
                 model_family in QWEN_TOOL_CALL_FAMILY
@@ -470,11 +714,11 @@ class MLXChatModel(MLXModel, ChatModelMixin):

         stream = generate_config.get("stream", False)
         if stream:
-            it = self.generate(full_prompt, generate_config)
+            it = self.generate(full_prompt, generate_config, from_chat=True)
             assert isinstance(it, Iterator)
             return self._to_chat_completion_chunks(it, self.reasoning_parser)
         else:
-            c = self.generate(full_prompt, generate_config)
+            c = self.generate(full_prompt, generate_config, from_chat=True)
             assert not isinstance(c, Iterator)
             if tools:
                 return self._post_process_completion(
@@ -518,6 +762,11 @@ class MLXVisionModel(MLXModel, ChatModelMixin):
         return load(self.model_path)

     def load(self):
+        if self._n_worker > 1:
+            raise NotImplementedError(
+                "Distributed inference is not supported for vision models"
+            )
+
         kwargs = {}
         kwargs["revision"] = self._model_config.get(
             "revision", self.model_spec.model_revision
@@ -636,10 +885,14 @@ class MLXVisionModel(MLXModel, ChatModelMixin):
         if "internvl2" not in model_family.lower():
             from qwen_vl_utils import process_vision_info

-
-            self._get_chat_template_kwargs_from_generate_config(
+        chat_template_kwargs = (
+            self._get_chat_template_kwargs_from_generate_config(
+                generate_config, self.reasoning_parser
+            )
             or {}
         )
+        chat_context_var.set(chat_template_kwargs)
+        full_context_kwargs = chat_template_kwargs.copy()
         if tools and model_family in QWEN_TOOL_CALL_FAMILY:
             full_context_kwargs["tools"] = tools
         assert self.model_family.chat_template is not None
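
The core.py hunks above wire distributed MLX inference around a simple handshake: shard 0 acts as the driver and records its own address, every other shard calls set_shard_info on the rank-0 model actor, and get_rank_addresses blocks until all n_worker shards have reported in. A minimal sketch of that handshake with plain asyncio (no xoscar actors; ShardRegistry and register_shard are illustrative names, not xinference APIs):

import asyncio
from typing import Dict


class ShardRegistry:
    """Stands in for the rank-0 driver that collects every shard's address."""

    def __init__(self, n_worker: int):
        self.n_worker = n_worker
        self.rank_to_addresses: Dict[int, str] = {}
        self._all_started = asyncio.Event()

    def set_shard_info(self, shard: int, address: str) -> None:
        # mirrors MLXModel.set_shard_info above
        self.rank_to_addresses[shard] = address
        if len(self.rank_to_addresses) == self.n_worker:
            self._all_started.set()

    async def get_rank_addresses(self) -> Dict[int, str]:
        # mirrors MLXModel.get_rank_addresses: block until every shard checked in
        await self._all_started.wait()
        return self.rank_to_addresses


async def register_shard(registry: ShardRegistry, shard: int, address: str) -> Dict[int, str]:
    # every shard (driver included) registers itself, then waits for the full map
    registry.set_shard_info(shard, address)
    return await registry.get_rank_addresses()


async def main() -> None:
    registry = ShardRegistry(n_worker=2)
    rank_maps = await asyncio.gather(
        register_shard(registry, 0, "127.0.0.1:34567"),
        register_shard(registry, 1, "127.0.0.1:34568"),
    )
    print(rank_maps[0])  # {0: '127.0.0.1:34567', 1: '127.0.0.1:34568'}


asyncio.run(main())

In the release itself the registry state lives on the rank-0 MLXModel actor and the calls travel over xoscar actor references, with model loading pushed onto a background thread and surfaced via wait_for_load.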

xinference/model/llm/mlx/distributed_models/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

xinference/model/llm/mlx/distributed_models/core.py
ADDED
@@ -0,0 +1,164 @@
+# Copyright 2022-2025 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import logging
+import os
+from typing import TYPE_CHECKING, Dict, Optional
+
+import xoscar as xo
+from xoscar.utils import lazy_import
+
+if TYPE_CHECKING:
+    import mlx.core as mx
+else:
+    mx = lazy_import("mlx.core")
+logger = logging.getLogger(__name__)
+
+DEBUG_DISTRIBUTED_MLX = bool(int(os.getenv("XINFERENCE_DEBUG_DISTRIBUTED_MLX", "0")))
+
+
+class ReceiverActor(xo.StatelessActor):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self._recv_queue = asyncio.Queue()
+
+    @classmethod
+    def gen_uid(cls, uid: str, rank: int):
+        return f"Receiver-{uid}-{rank}"
+
+    async def send(self, data: "mx.array"):
+        # no need to use async function,
+        # but make it more convenient to patch this function for test purpose
+        if not isinstance(data, mx.array):
+            data = mx.array(data)
+        self._recv_queue.put_nowait(data)
+
+    async def recv(self):
+        return await self._recv_queue.get()


+class DistributedModelMixin:
+    rank: int
+    world_size: int
+    model_uid: Optional[str]
+    address: Optional[str]
+    _receiver_ref: Optional[xo.ActorRefType[ReceiverActor]]
+    rank_to_addresses: Optional[Dict[int, str]]
+
+    layers: list
+
+    def __init__(self):
+        self.rank = 0
+        self.world_size = 1
+        self.model_uid = None
+        self.loop = None
+        self.address = None
+        # actor ref
+        self._receiver_ref = None
+        self.rank_to_addresses = None
+
+    def prepare(self):
+        coro = xo.create_actor(
+            ReceiverActor,
+            uid=ReceiverActor.gen_uid(self.model_uid, self.rank),
+            address=self.address,
+        )
+        self._receiver_ref = asyncio.run_coroutine_threadsafe(coro, self.loop).result()
+        if DEBUG_DISTRIBUTED_MLX:
+            logger.debug("Finish preparing distributed env for rank %s", self.rank)
+
+    def _send_stage_result(self, result: "mx.array"):
+        assert self.rank > 0
+        assert self.rank_to_addresses is not None
+        assert self.model_uid is not None
+        last_rank = self.rank - 1
+        if DEBUG_DISTRIBUTED_MLX:
+            logger.debug(
+                "Start to send %s partial result to rank %d", self.model_uid, last_rank
+            )
+
+        async def send():
+            receiver_ref = await xo.actor_ref(
+                uid=ReceiverActor.gen_uid(self.model_uid, last_rank),
+                address=self.rank_to_addresses[last_rank],
+            )
+            return await receiver_ref.send(result)
+
+        asyncio.run_coroutine_threadsafe(send(), self.loop).result()
+        if DEBUG_DISTRIBUTED_MLX:
+            logger.debug(
+                "Finish send %s partial result to rank %d, shape %s",
+                self.model_uid,
+                last_rank,
+                result.shape,
+            )
+
+    def _wait_prev_stage_result(self):
+        if DEBUG_DISTRIBUTED_MLX:
+            logger.debug("Wait for partial result from prev shard %d", self.rank + 1)
+        coro = self._receiver_ref.recv()
+        result = asyncio.run_coroutine_threadsafe(coro, self.loop).result()
+        if DEBUG_DISTRIBUTED_MLX:
+            logger.debug(
+                "Received partial result from prev shard %d, shape %s",
+                self.rank + 1,
+                result.shape,
+            )
+        return result
+
+    def _broadcast_result(self, result: "mx.array"):
+        if DEBUG_DISTRIBUTED_MLX:
+            logger.debug("broadcast result from driver")
+
+        async def broadcast(rank: int):
+            assert self.model_uid is not None
+            assert self.rank_to_addresses is not None
+
+            receiver = await xo.actor_ref(
+                uid=ReceiverActor.gen_uid(self.model_uid, rank),
+                address=self.rank_to_addresses[rank],
+            )
+            await receiver.send(result)
+
+        async def broadcast_all():
+            coros = []
+            for rank in range(1, self.world_size):
+                coros.append(broadcast(rank))
+            await asyncio.gather(*coros)
+
+        return asyncio.run_coroutine_threadsafe(broadcast_all(), self.loop).result()
+
+    def _get_result(self) -> "mx.array":
+        if DEBUG_DISTRIBUTED_MLX:
+            logger.debug("Get result from broadcasted data on self receiver")
+        assert self.model_uid is not None
+        coro = xo.actor_ref(
+            uid=ReceiverActor.gen_uid(self.model_uid, self.rank), address=self.address
+        )
+        ref = asyncio.run_coroutine_threadsafe(coro, self.loop).result()
+        return asyncio.run_coroutine_threadsafe(ref.recv(), loop=self.loop).result()
+
+    def pipeline(self):
+        pipeline_size, rank = self.world_size, self.rank
+        layers_per_rank = len(self.layers) // pipeline_size
+        extra = len(self.layers) - layers_per_rank * pipeline_size
+        if self.rank < extra:
+            layers_per_rank += 1
+        self.start_idx = (pipeline_size - rank - 1) * layers_per_rank
+        self.end_idx = self.start_idx + layers_per_rank
+        self.layers = self.layers[: self.end_idx]
+        self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx
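
pipeline() above splits the decoder layers into contiguous blocks, one per shard, with lower ranks owning the later layers (rank 0 keeps the tail of the network). A standalone sketch of the same arithmetic for an evenly divisible case (split_layers is an illustrative helper, not an xinference function):

def split_layers(n_layers: int, world_size: int, rank: int) -> range:
    # same arithmetic as DistributedModelMixin.pipeline() above:
    # rank r keeps layers [start_idx, start_idx + layers_per_rank)
    layers_per_rank = n_layers // world_size
    extra = n_layers - layers_per_rank * world_size
    if rank < extra:
        layers_per_rank += 1
    start_idx = (world_size - rank - 1) * layers_per_rank
    return range(start_idx, start_idx + layers_per_rank)


# e.g. a 36-layer model split across 2 workers
print(split_layers(36, 2, 0))  # range(18, 36) -> rank 0 owns the last 18 layers
print(split_layers(36, 2, 1))  # range(0, 18)  -> rank 1 owns the first 18 layers

This reversed ordering matches the __call__ implementation in deepseek_v3.py below: every rank except the highest waits for the previous stage's hidden states, and every rank except 0 sends its output down to rank - 1.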

xinference/model/llm/mlx/distributed_models/deepseek_v3.py
ADDED
@@ -0,0 +1,75 @@
+# Copyright 2022-2025 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.base import create_attention_mask
+from mlx_lm.models.deepseek_v3 import DeepseekV3Model as _DeepseekV3Model
+from mlx_lm.models.deepseek_v3 import Model as _Model
+from mlx_lm.models.deepseek_v3 import ModelArgs
+
+from .core import DistributedModelMixin
+
+
+class DeepseekV3Model(_DeepseekV3Model, DistributedModelMixin):
+    def __init__(self, *args, **kwargs):
+        _DeepseekV3Model.__init__(self, *args, **kwargs)
+        DistributedModelMixin.__init__(self)
+
+    def __call__(
+        self,
+        x: mx.array,
+        cache: Optional[Any] = None,
+        mask: Optional[mx.array] = None,
+    ) -> mx.array:
+        h = self.embed_tokens(x)
+
+        pipeline_rank = self.rank
+        pipeline_size = self.world_size
+        if mask is None:
+            mask = create_attention_mask(h, cache)
+
+        if cache is None:
+            cache = [None] * self.num_layers
+
+        # Receive from the previous process in the pipeline
+
+        if pipeline_rank < pipeline_size - 1:
+            # wait for previous result
+            h = self._wait_prev_stage_result()
+
+        for i in range(self.num_layers):
+            h = self.layers[self.start_idx + i](h, mask, cache[i])
+        mx.eval(h)
+
+        if pipeline_rank != 0:
+            # Send to the next process in the pipeline
+            self._send_stage_result(h)
+            # wait for the final result
+            h = self._get_result()
+        else:
+            self._set_result(h)
+
+        return self.norm(h)
+
+
+class Model(_Model):
+    def __init__(self, config: ModelArgs):
+        nn.Module.__init__(self)
+        self.args = config
+        self.model_type = config.model_type
+        self.model = DeepseekV3Model(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)