xinference 1.7.0.post1__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (83)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +3 -4
  3. xinference/client/__init__.py +2 -0
  4. xinference/client/common.py +49 -2
  5. xinference/client/handlers.py +18 -0
  6. xinference/client/restful/async_restful_client.py +1760 -0
  7. xinference/client/restful/restful_client.py +74 -78
  8. xinference/core/media_interface.py +3 -1
  9. xinference/core/model.py +5 -4
  10. xinference/core/supervisor.py +10 -5
  11. xinference/core/worker.py +15 -14
  12. xinference/deploy/local.py +51 -9
  13. xinference/deploy/worker.py +5 -3
  14. xinference/device_utils.py +22 -3
  15. xinference/model/audio/fish_speech.py +23 -34
  16. xinference/model/audio/model_spec.json +4 -2
  17. xinference/model/audio/model_spec_modelscope.json +4 -2
  18. xinference/model/audio/utils.py +2 -2
  19. xinference/model/core.py +1 -0
  20. xinference/model/embedding/__init__.py +8 -8
  21. xinference/model/embedding/custom.py +6 -1
  22. xinference/model/embedding/embed_family.py +0 -41
  23. xinference/model/embedding/model_spec.json +10 -1
  24. xinference/model/embedding/model_spec_modelscope.json +10 -1
  25. xinference/model/embedding/sentence_transformers/core.py +30 -15
  26. xinference/model/flexible/core.py +1 -1
  27. xinference/model/flexible/launchers/__init__.py +2 -0
  28. xinference/model/flexible/launchers/image_process_launcher.py +1 -1
  29. xinference/model/flexible/launchers/modelscope_launcher.py +47 -0
  30. xinference/model/flexible/launchers/transformers_launcher.py +5 -5
  31. xinference/model/flexible/launchers/yolo_launcher.py +62 -0
  32. xinference/model/llm/__init__.py +7 -0
  33. xinference/model/llm/core.py +18 -1
  34. xinference/model/llm/llama_cpp/core.py +1 -1
  35. xinference/model/llm/llm_family.json +41 -1
  36. xinference/model/llm/llm_family.py +6 -0
  37. xinference/model/llm/llm_family_modelscope.json +43 -1
  38. xinference/model/llm/mlx/core.py +271 -18
  39. xinference/model/llm/mlx/distributed_models/__init__.py +13 -0
  40. xinference/model/llm/mlx/distributed_models/core.py +164 -0
  41. xinference/model/llm/mlx/distributed_models/deepseek_v3.py +75 -0
  42. xinference/model/llm/mlx/distributed_models/qwen2.py +82 -0
  43. xinference/model/llm/mlx/distributed_models/qwen3.py +82 -0
  44. xinference/model/llm/mlx/distributed_models/qwen3_moe.py +76 -0
  45. xinference/model/llm/reasoning_parser.py +12 -6
  46. xinference/model/llm/sglang/core.py +8 -4
  47. xinference/model/llm/transformers/chatglm.py +4 -1
  48. xinference/model/llm/transformers/core.py +4 -2
  49. xinference/model/llm/transformers/multimodal/cogagent.py +10 -4
  50. xinference/model/llm/transformers/multimodal/intern_vl.py +1 -1
  51. xinference/model/llm/utils.py +36 -17
  52. xinference/model/llm/vllm/core.py +142 -34
  53. xinference/model/llm/vllm/distributed_executor.py +96 -21
  54. xinference/model/llm/vllm/xavier/transfer.py +2 -2
  55. xinference/model/rerank/core.py +16 -9
  56. xinference/model/rerank/model_spec.json +3 -3
  57. xinference/model/rerank/model_spec_modelscope.json +3 -3
  58. xinference/web/ui/build/asset-manifest.json +3 -3
  59. xinference/web/ui/build/index.html +1 -1
  60. xinference/web/ui/build/static/js/main.9b12b7f9.js +3 -0
  61. xinference/web/ui/build/static/js/main.9b12b7f9.js.map +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/0fd4820d93f99509e80d8702dc3f6f8272424acab5608fa7c0e82cb1d3250a87.json +1 -0
  63. xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +1 -0
  66. xinference/web/ui/node_modules/.cache/babel-loader/f75545479c17fdfe2a00235fa4a0e9da1ae95e6b3caafba87ded92de6b0240e4.json +1 -0
  67. xinference/web/ui/src/locales/en.json +3 -0
  68. xinference/web/ui/src/locales/ja.json +3 -0
  69. xinference/web/ui/src/locales/ko.json +3 -0
  70. xinference/web/ui/src/locales/zh.json +3 -0
  71. {xinference-1.7.0.post1.dist-info → xinference-1.7.1.dist-info}/METADATA +4 -3
  72. {xinference-1.7.0.post1.dist-info → xinference-1.7.1.dist-info}/RECORD +77 -67
  73. xinference/web/ui/build/static/js/main.8a9e3ba0.js +0 -3
  74. xinference/web/ui/build/static/js/main.8a9e3ba0.js.map +0 -1
  75. xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +0 -1
  76. xinference/web/ui/node_modules/.cache/babel-loader/34cfbfb7836e136ba3261cfd411cc554bf99ba24b35dcceebeaa4f008cb3c9dc.json +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/c5c7c2cd1b863ce41adff2c4737bba06eef3a1acf28288cb83d992060f6b8923.json +0 -1
  78. xinference/web/ui/node_modules/.cache/babel-loader/cc97b49285d7717c63374766c789141a4329a04582ab32756d7e0e614d4c5c7f.json +0 -1
  79. /xinference/web/ui/build/static/js/{main.8a9e3ba0.js.LICENSE.txt → main.9b12b7f9.js.LICENSE.txt} +0 -0
  80. {xinference-1.7.0.post1.dist-info → xinference-1.7.1.dist-info}/WHEEL +0 -0
  81. {xinference-1.7.0.post1.dist-info → xinference-1.7.1.dist-info}/entry_points.txt +0 -0
  82. {xinference-1.7.0.post1.dist-info → xinference-1.7.1.dist-info}/licenses/LICENSE +0 -0
  83. {xinference-1.7.0.post1.dist-info → xinference-1.7.1.dist-info}/top_level.txt +0 -0

xinference/model/llm/mlx/core.py
@@ -11,14 +11,32 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+
+ import asyncio
+ import concurrent.futures
+ import importlib
  import importlib.util
  import logging
+ import pathlib
  import platform
  import sys
+ import threading
  import time
  import uuid
  from dataclasses import dataclass, field
- from typing import Any, Dict, Iterator, List, Optional, Tuple, TypedDict, Union
+ from typing import (
+     Any,
+     Callable,
+     Dict,
+     Iterator,
+     List,
+     Optional,
+     Tuple,
+     TypedDict,
+     Union,
+ )
+
+ import xoscar as xo

  from ....fields import max_tokens_field
  from ....types import (
@@ -29,7 +47,7 @@ from ....types import (
      CompletionUsage,
      LoRA,
  )
- from ..core import LLM
+ from ..core import LLM, chat_context_var
  from ..llm_family import LLMFamilyV1, LLMSpecV1
  from ..utils import (
      DEEPSEEK_TOOL_CALL_FAMILY,
@@ -46,6 +64,10 @@ class MLXModelConfig(TypedDict, total=False):
      max_gpu_memory: str
      trust_remote_code: bool
      reasoning_content: bool
+     # distributed
+     address: Optional[str]
+     shard: Optional[int]
+     n_worker: Optional[int]


  class MLXGenerateConfig(TypedDict, total=False):
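A note on the three new keys above: they describe how one MLX worker identifies itself in a multi-worker deployment. The following dict is a hypothetical illustration only; the key names come from the TypedDict in this hunk, while the address value and its format are assumptions, not taken from xinference documentation.

```python
# Hypothetical example: the model_config a launcher might pass for the second
# shard of a 2-worker MLX deployment.
model_config = {
    "trust_remote_code": True,
    "reasoning_content": False,
    # distributed fields added in this release
    "address": "192.168.1.12:37823",  # xoscar address of this worker (assumed format)
    "shard": 1,                       # this worker's pipeline rank
    "n_worker": 2,                    # total number of MLX workers
}
```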
@@ -71,6 +93,8 @@ class PromptCache:


  class MLXModel(LLM):
+     _rank_to_addresses: Optional[Dict[int, str]]
+
      def __init__(
          self,
          model_uid: str,
@@ -84,10 +108,43 @@ class MLXModel(LLM):
          super().__init__(model_uid, model_family, model_spec, quantization, model_path)
          self._use_fast_tokenizer = True
          self._model_config: MLXModelConfig = self._sanitize_model_config(model_config)
+         # for distributed
+         assert model_config is not None
+         self._address = model_config.pop("address", None)
+         self._n_worker = model_config.pop("n_worker", 1)
+         self._shard = model_config.pop("shard", 0)
+         self._driver_info = model_config.pop("driver_info", None)  # type: ignore
+         self._rank_to_addresses = None
+         self._loading_thread = None
+         self._loading_error = None
+         self._all_worker_started = asyncio.Event()
          self._max_kv_size = None
          self._prompt_cache = None
          if peft_model is not None:
              raise ValueError("MLX engine has not supported lora yet")
+         # used to call async
+         self._loop = None
+
+     def set_loop(self, loop: asyncio.AbstractEventLoop):
+         # loop will be passed into ModelWrapper,
+         # to call async methods with asyncio.run_coroutine_threadsafe
+         self._loop = loop  # type: ignore
+
+     @property
+     def driver_info(self) -> Optional[dict]:
+         return self._driver_info
+
+     def set_shard_info(self, shard: int, address: str):
+         # set shard info on rank 0
+         if self._rank_to_addresses is None:
+             self._rank_to_addresses = {}
+         self._rank_to_addresses[shard] = address
+         if len(self._rank_to_addresses) == self._n_worker:
+             self._all_worker_started.set()
+
+     async def get_rank_addresses(self) -> Optional[Dict[int, str]]:
+         await self._all_worker_started.wait()
+         return self._rank_to_addresses

      def _sanitize_model_config(
          self, model_config: Optional[MLXModelConfig]
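The `set_shard_info`/`get_rank_addresses` pair above forms a small rendezvous on the rank-0 (driver) model: each shard reports its address, and callers block until all `n_worker` addresses are known. A minimal standalone sketch of that pattern with plain asyncio (no xoscar, names are illustrative):

```python
import asyncio
from typing import Dict

# Sketch of the rank-0 rendezvous: shards register their addresses, readers
# wait on an asyncio.Event until every worker has reported in.
class ShardRegistry:
    def __init__(self, n_worker: int):
        self._n_worker = n_worker
        self._rank_to_addresses: Dict[int, str] = {}
        self._all_worker_started = asyncio.Event()

    def set_shard_info(self, shard: int, address: str) -> None:
        self._rank_to_addresses[shard] = address
        if len(self._rank_to_addresses) == self._n_worker:
            self._all_worker_started.set()

    async def get_rank_addresses(self) -> Dict[int, str]:
        await self._all_worker_started.wait()
        return self._rank_to_addresses


async def demo() -> None:
    registry = ShardRegistry(n_worker=2)
    registry.set_shard_info(0, "127.0.0.1:30001")
    # a second worker would normally call this remotely via an actor ref
    registry.set_shard_info(1, "127.0.0.1:30002")
    print(await registry.get_rank_addresses())


asyncio.run(demo())
```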
@@ -158,6 +215,97 @@ class MLXModel(LLM):
              tokenizer.add_eos_token(stop_token_id)
          return model, tokenizer

+     def _load_model_shard(self, **kwargs):
+         try:
+             import mlx.core as mx
+             from mlx_lm.utils import load_model, load_tokenizer
+         except ImportError:
+             error_message = "Failed to import module 'mlx_lm'"
+             installation_guide = [
+                 "Please make sure 'mlx_lm' is installed. ",
+                 "You can install it by `pip install mlx_lm`\n",
+             ]
+
+             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+         # Ensure some attributes are correctly initialized by the model actor
+         assert (
+             self._loop is not None and self._rank_to_addresses is not None
+         ), "Service not started correctly"
+
+         tokenizer_config = dict(
+             use_fast=self._use_fast_tokenizer,
+             trust_remote_code=kwargs["trust_remote_code"],
+             revision=kwargs["revision"],
+         )
+         logger.debug(
+             "loading model with tokenizer config: %s, model config: %s, shard: %d, n_worker: %d",
+             tokenizer_config,
+             self._model_config,
+             self._shard,
+             self._n_worker,
+         )
+
+         cache_limit_gb = kwargs.get("cache_limit_gb", None)
+         if cache_limit_gb:
+             logger.debug(f"Setting cache limit to {cache_limit_gb} GB")
+             mx.metal.set_cache_limit(cache_limit_gb * 1024 * 1024 * 1024)
+
+         self._max_kv_size = kwargs.get("max_kv_size", None)
+         self._prompt_cache = PromptCache()
+
+         self._model, config = load_model(
+             pathlib.Path(self.model_path),
+             lazy=True,
+             get_model_classes=self._get_classes,
+         )
+         model = self._model.model
+         model.rank = self._shard
+         model.world_size = self._n_worker
+         model.model_uid = self.model_uid
+         model.loop = self._loop
+         model.address = self._address
+         model.rank_to_addresses = self._rank_to_addresses
+
+         # create actors and so forth
+         model.prepare()
+         # really load the partial weights
+         model.pipeline()
+         mx.eval(model.parameters())
+
+         self._tokenizer = load_tokenizer(
+             pathlib.Path(self.model_path),
+             tokenizer_config,
+             eos_token_ids=config.get("eos_token_id", None),
+         )
+
+     @staticmethod
+     def _get_classes(config: dict):
+         """
+         Retrieve the model and model args classes based on the configuration
+         that supports distributed inference.
+
+         Args:
+             config (dict): The model configuration.
+
+         Returns:
+             A tuple containing the Model class and the ModelArgs class.
+         """
+         from mlx_lm.utils import MODEL_REMAPPING
+
+         model_type = config["model_type"]
+         model_type = MODEL_REMAPPING.get(model_type, model_type)
+         try:
+             arch = importlib.import_module(
+                 f"xinference.model.llm.mlx.distributed_models.{model_type}"
+             )
+         except ImportError:
+             msg = f"Model type {model_type} not supported for distributed inference."
+             logger.error(msg)
+             raise ValueError(msg)
+
+         return arch.Model, arch.ModelArgs
+
      def load(self):
          reasoning_content = self._model_config.pop("reasoning_content")
          enable_thinking = self._model_config.pop("enable_thinking", True)
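`_get_classes` above dispatches on the checkpoint's `model_type` (after `mlx_lm`'s `MODEL_REMAPPING`) to a `distributed_models.<model_type>` module and fails loudly when no such module exists. The sketch below shows the same dispatch idea in isolation; the base package path and remapping table here are illustrative stand-ins, not xinference's actual values.

```python
import importlib
from types import ModuleType

# Assumed example: aliases collapse onto one distributed implementation.
MODEL_REMAPPING = {"qwen2_5": "qwen2"}

def resolve_distributed_module(model_type: str, base_pkg: str) -> ModuleType:
    # Map the checkpoint's model_type to a module providing Model/ModelArgs,
    # and raise a clear error when the type has no distributed implementation.
    model_type = MODEL_REMAPPING.get(model_type, model_type)
    try:
        return importlib.import_module(f"{base_pkg}.{model_type}")
    except ImportError:
        raise ValueError(
            f"Model type {model_type} not supported for distributed inference."
        )

# Usage sketch (only meaningful where such a package is importable):
# module = resolve_distributed_module("qwen2_5", "xinference.model.llm.mlx.distributed_models")
# Model, ModelArgs = module.Model, module.ModelArgs
```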
@@ -172,7 +320,49 @@ class MLXModel(LLM):
          kwargs["trust_remote_code"] = self._model_config.get("trust_remote_code")
          kwargs["cache_limit_gb"] = self._model_config.pop("cache_limit_gb", None)

-         self._model, self._tokenizer = self._load_model(**kwargs)
+         if self._n_worker <= 1:
+             self._model, self._tokenizer = self._load_model(**kwargs)
+         else:
+
+             def _load():
+                 try:
+                     if self._shard == 0:
+                         self._driver_info = {"address": self._address}
+                         self.set_shard_info(0, self._address)
+                     else:
+                         assert self._driver_info is not None
+                         driver_address = self._driver_info["address"]
+
+                         async def wait_for_all_shards():
+                             model_ref = await xo.actor_ref(
+                                 address=driver_address, uid=self.raw_model_uid
+                             )
+                             # set shard info
+                             await model_ref.set_shard_info(self._shard, self._address)
+                             # wait for all shards
+                             self._rank_to_addresses = (
+                                 await model_ref.get_rank_addresses()
+                             )
+
+                         asyncio.run_coroutine_threadsafe(
+                             wait_for_all_shards(), self._loop
+                         ).result()
+
+                     self._load_model_shard(**kwargs)
+                 except:
+                     logger.exception("Loading mlx shard model failed")
+                     self._loading_error = sys.exc_info()
+
+             # distributed inference
+             self._loading_thread = threading.Thread(target=_load)
+             self._loading_thread.start()
+
+     def wait_for_load(self):
+         if self._loading_thread:
+             self._loading_thread.join()
+         if self._loading_error:
+             _, err, tb = self._loading_error
+             raise err.with_traceback(tb)

      @classmethod
      def check_lib(cls) -> bool:
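The distributed load path above runs in a background thread, records any failure as `sys.exc_info()`, and `wait_for_load` re-raises it with the original traceback. A self-contained sketch of that thread-plus-re-raise pattern, assuming nothing beyond the standard library:

```python
import sys
import threading

# Sketch: the worker thread stores sys.exc_info() on failure; the caller joins
# the thread and re-raises the captured exception with its original traceback.
class BackgroundLoader:
    def __init__(self, target):
        self._error = None

        def _run():
            try:
                target()
            except:  # noqa: E722 - mirrors the broad except used in the diff
                self._error = sys.exc_info()

        self._thread = threading.Thread(target=_run)
        self._thread.start()

    def wait_for_load(self):
        self._thread.join()
        if self._error:
            _, err, tb = self._error
            raise err.with_traceback(tb)


def fake_load():
    raise RuntimeError("weights missing")

loader = BackgroundLoader(fake_load)
try:
    loader.wait_for_load()
except RuntimeError as e:
    print("load failed:", e)
```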
@@ -369,20 +559,57 @@ class MLXModel(LLM):
              )
              yield completion_chunk, completion_usage

+     def _run_non_drivers(
+         self, method: str, stream: bool, *args, **kwargs
+     ) -> Optional[concurrent.futures.Future]:
+         assert self._n_worker is not None and self._shard is not None
+         if self._n_worker == 1 or self._shard > 0:
+             # only run for distributed driver
+             return None
+
+         async def run_other_shard(shard: int):
+             assert self._rank_to_addresses is not None
+             address = self._rank_to_addresses[shard]
+             model_actor_ref = await xo.actor_ref(
+                 address=address, uid=self.raw_model_uid
+             )
+             # we don't actually need to get the result from shard >= 1
+             if stream:
+                 async for _ in await getattr(model_actor_ref, method)(*args, **kwargs):
+                     pass
+             else:
+                 await getattr(model_actor_ref, method)(*args, **kwargs)
+
+         async def run_non_driver_shards():
+             logger.debug("Start to run non driver %s", method)
+             coros = []
+             for rank in range(1, self._n_worker):
+                 coros.append(run_other_shard(rank))
+             await asyncio.gather(*coros)
+
+         assert self._loop is not None
+         return asyncio.run_coroutine_threadsafe(run_non_driver_shards(), self._loop)
+
      def generate(
          self,
          prompt: Union[str, Dict[str, Any]],
          generate_config: Optional[MLXGenerateConfig] = None,
+         from_chat: bool = False,
      ) -> Union[Completion, Iterator[CompletionChunk]]:
          def generator_wrapper(
-             prompt: Union[str, Dict[str, Any]], generate_config: MLXGenerateConfig
+             prompt: Union[str, Dict[str, Any]],
+             generate_config: MLXGenerateConfig,
+             cb: Callable,
          ) -> Iterator[CompletionChunk]:
-             for completion_chunk, completion_usage in self._generate_stream(
-                 prompt,
-                 generate_config,
-             ):
-                 completion_chunk["usage"] = completion_usage
-                 yield completion_chunk
+             try:
+                 for completion_chunk, completion_usage in self._generate_stream(
+                     prompt,
+                     generate_config,
+                 ):
+                     completion_chunk["usage"] = completion_usage
+                     yield completion_chunk
+             finally:
+                 cb()

          logger.debug(
              "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
@@ -394,6 +621,9 @@ class MLXModel(LLM):
          assert self._tokenizer is not None

          stream = generate_config.get("stream", False)
+         fut = self._run_non_drivers(
+             "generate", stream, prompt, generate_config=generate_config
+         )
          if not stream:
              for completion_chunk, completion_usage in self._generate_stream(
                  prompt,
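In the hunks above, the driver fans generation out to the non-driver shards via `_run_non_drivers` (a `concurrent.futures.Future`) and, for streaming, wraps its own generator so a callback fires in `finally` once the stream is exhausted or abandoned. A small isolated sketch of that generator-with-cleanup-callback pattern (names illustrative):

```python
from typing import Callable, Iterator

# Sketch: yield chunks from an inner stream and run a completion callback in
# `finally`, whether the consumer drains the stream or abandons it early.
def with_finish_callback(chunks: Iterator[str], cb: Callable[[], None]) -> Iterator[str]:
    try:
        for chunk in chunks:
            yield chunk
    finally:
        cb()

done = []
stream = with_finish_callback(iter(["a", "b", "c"]), lambda: done.append(True))
print(list(stream))  # ['a', 'b', 'c']
print(done)          # [True] -- callback ran after the stream was drained
```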
@@ -408,9 +638,18 @@ class MLXModel(LLM):
                  choices=completion_chunk["choices"],
                  usage=completion_usage,
              )
-             return completion
+             try:
+                 return completion
+             finally:
+                 if fut:
+                     fut.result()
          else:
-             return generator_wrapper(prompt, generate_config)
+
+             def finish_callback():
+                 if fut:
+                     fut.result()
+
+             return generator_wrapper(prompt, generate_config, finish_callback)


  class MLXChatModel(MLXModel, ChatModelMixin):
@@ -452,9 +691,14 @@ class MLXChatModel(MLXModel, ChatModelMixin):
      ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
          model_family = self.model_family.model_family or self.model_family.model_name
          tools = generate_config.pop("tools", []) if generate_config else None
-         full_context_kwargs = (
-             self._get_chat_template_kwargs_from_generate_config(generate_config, self.reasoning_parser) or {}  # type: ignore
+         chat_template_kwargs = (
+             self._get_chat_template_kwargs_from_generate_config(
+                 generate_config, self.reasoning_parser
+             )
+             or {}
          )
+         chat_context_var.set(chat_template_kwargs)
+         full_context_kwargs = chat_template_kwargs.copy()
          if tools:
              if (
                  model_family in QWEN_TOOL_CALL_FAMILY
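`chat_context_var` is imported from `..core` in this diff; its definition is not shown here, but the `set()` call matches Python's `contextvars.ContextVar` API, which lets code further down the call stack (such as the reasoning parser) read the chat-template kwargs without threading them through every signature. A generic sketch of that pattern, with illustrative names and payload:

```python
import contextvars
from typing import Any, Dict

# Sketch of the ContextVar pattern: set the template kwargs near the top of the
# chat call, read them later without passing them explicitly.
chat_context: contextvars.ContextVar[Dict[str, Any]] = contextvars.ContextVar(
    "chat_context", default={}
)

def render_prompt(messages):
    kwargs = chat_context.get()  # e.g. {"enable_thinking": False}
    thinking = kwargs.get("enable_thinking", True)
    return f"<prompt thinking={thinking}>" + " ".join(messages) + "</prompt>"

chat_context.set({"enable_thinking": False})
print(render_prompt(["hello"]))
```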
@@ -470,11 +714,11 @@ class MLXChatModel(MLXModel, ChatModelMixin):

          stream = generate_config.get("stream", False)
          if stream:
-             it = self.generate(full_prompt, generate_config)
+             it = self.generate(full_prompt, generate_config, from_chat=True)
              assert isinstance(it, Iterator)
              return self._to_chat_completion_chunks(it, self.reasoning_parser)
          else:
-             c = self.generate(full_prompt, generate_config)
+             c = self.generate(full_prompt, generate_config, from_chat=True)
              assert not isinstance(c, Iterator)
              if tools:
                  return self._post_process_completion(
@@ -518,6 +762,11 @@ class MLXVisionModel(MLXModel, ChatModelMixin):
          return load(self.model_path)

      def load(self):
+         if self._n_worker > 1:
+             raise NotImplementedError(
+                 "Distributed inference is not supported for vision models"
+             )
+
          kwargs = {}
          kwargs["revision"] = self._model_config.get(
              "revision", self.model_spec.model_revision
@@ -636,10 +885,14 @@ class MLXVisionModel(MLXModel, ChatModelMixin):
          if "internvl2" not in model_family.lower():
              from qwen_vl_utils import process_vision_info

-             full_context_kwargs = (
-                 self._get_chat_template_kwargs_from_generate_config(generate_config, self.reasoning_parser)  # type: ignore
+             chat_template_kwargs = (
+                 self._get_chat_template_kwargs_from_generate_config(
+                     generate_config, self.reasoning_parser
+                 )
                  or {}
              )
+             chat_context_var.set(chat_template_kwargs)
+             full_context_kwargs = chat_template_kwargs.copy()
              if tools and model_family in QWEN_TOOL_CALL_FAMILY:
                  full_context_kwargs["tools"] = tools
              assert self.model_family.chat_template is not None
xinference/model/llm/mlx/distributed_models/__init__.py (new file)
@@ -0,0 +1,13 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
xinference/model/llm/mlx/distributed_models/core.py (new file)
@@ -0,0 +1,164 @@
+ # Copyright 2022-2025 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import asyncio
+ import logging
+ import os
+ from typing import TYPE_CHECKING, Dict, Optional
+
+ import xoscar as xo
+ from xoscar.utils import lazy_import
+
+ if TYPE_CHECKING:
+     import mlx.core as mx
+ else:
+     mx = lazy_import("mlx.core")
+ logger = logging.getLogger(__name__)
+
+ DEBUG_DISTRIBUTED_MLX = bool(int(os.getenv("XINFERENCE_DEBUG_DISTRIBUTED_MLX", "0")))
+
+
+ class ReceiverActor(xo.StatelessActor):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         self._recv_queue = asyncio.Queue()
+
+     @classmethod
+     def gen_uid(cls, uid: str, rank: int):
+         return f"Receiver-{uid}-{rank}"
+
+     async def send(self, data: "mx.array"):
+         # no need to use async function,
+         # but make it more convenient to patch this function for test purpose
+         if not isinstance(data, mx.array):
+             data = mx.array(data)
+         self._recv_queue.put_nowait(data)
+
+     async def recv(self):
+         return await self._recv_queue.get()
+
+
+ class DistributedModelMixin:
+     rank: int
+     world_size: int
+     model_uid: Optional[str]
+     address: Optional[str]
+     _receiver_ref: Optional[xo.ActorRefType[ReceiverActor]]
+     rank_to_addresses: Optional[Dict[int, str]]
+
+     layers: list
+
+     def __init__(self):
+         self.rank = 0
+         self.world_size = 1
+         self.model_uid = None
+         self.loop = None
+         self.address = None
+         # actor ref
+         self._receiver_ref = None
+         self.rank_to_addresses = None
+
+     def prepare(self):
+         coro = xo.create_actor(
+             ReceiverActor,
+             uid=ReceiverActor.gen_uid(self.model_uid, self.rank),
+             address=self.address,
+         )
+         self._receiver_ref = asyncio.run_coroutine_threadsafe(coro, self.loop).result()
+         if DEBUG_DISTRIBUTED_MLX:
+             logger.debug("Finish preparing distributed env for rank %s", self.rank)
+
+     def _send_stage_result(self, result: "mx.array"):
+         assert self.rank > 0
+         assert self.rank_to_addresses is not None
+         assert self.model_uid is not None
+         last_rank = self.rank - 1
+         if DEBUG_DISTRIBUTED_MLX:
+             logger.debug(
+                 "Start to send %s partial result to rank %d", self.model_uid, last_rank
+             )
+
+         async def send():
+             receiver_ref = await xo.actor_ref(
+                 uid=ReceiverActor.gen_uid(self.model_uid, last_rank),
+                 address=self.rank_to_addresses[last_rank],
+             )
+             return await receiver_ref.send(result)
+
+         asyncio.run_coroutine_threadsafe(send(), self.loop).result()
+         if DEBUG_DISTRIBUTED_MLX:
+             logger.debug(
+                 "Finish send %s partial result to rank %d, shape %s",
+                 self.model_uid,
+                 last_rank,
+                 result.shape,
+             )
+
+     def _wait_prev_stage_result(self):
+         if DEBUG_DISTRIBUTED_MLX:
+             logger.debug("Wait for partial result from prev shard %d", self.rank + 1)
+         coro = self._receiver_ref.recv()
+         result = asyncio.run_coroutine_threadsafe(coro, self.loop).result()
+         if DEBUG_DISTRIBUTED_MLX:
+             logger.debug(
+                 "Received partial result from prev shard %d, shape %s",
+                 self.rank + 1,
+                 result.shape,
+             )
+         return result
+
+     def _broadcast_result(self, result: "mx.array"):
+         if DEBUG_DISTRIBUTED_MLX:
+             logger.debug("broadcast result from driver")
+
+         async def broadcast(rank: int):
+             assert self.model_uid is not None
+             assert self.rank_to_addresses is not None
+
+             receiver = await xo.actor_ref(
+                 uid=ReceiverActor.gen_uid(self.model_uid, rank),
+                 address=self.rank_to_addresses[rank],
+             )
+             await receiver.send(result)
+
+         async def broadcast_all():
+             coros = []
+             for rank in range(1, self.world_size):
+                 coros.append(broadcast(rank))
+             await asyncio.gather(*coros)
+
+         return asyncio.run_coroutine_threadsafe(broadcast_all(), self.loop).result()
+
+     def _get_result(self) -> "mx.array":
+         if DEBUG_DISTRIBUTED_MLX:
+             logger.debug("Get result from broadcasted data on self receiver")
+         assert self.model_uid is not None
+         coro = xo.actor_ref(
+             uid=ReceiverActor.gen_uid(self.model_uid, self.rank), address=self.address
+         )
+         ref = asyncio.run_coroutine_threadsafe(coro, self.loop).result()
+         return asyncio.run_coroutine_threadsafe(ref.recv(), loop=self.loop).result()
+
+     def pipeline(self):
+         pipeline_size, rank = self.world_size, self.rank
+         layers_per_rank = len(self.layers) // pipeline_size
+         extra = len(self.layers) - layers_per_rank * pipeline_size
+         if self.rank < extra:
+             layers_per_rank += 1
+         self.start_idx = (pipeline_size - rank - 1) * layers_per_rank
+         self.end_idx = self.start_idx + layers_per_rank
+         self.layers = self.layers[: self.end_idx]
+         self.layers[: self.start_idx] = [None] * self.start_idx
+         self.num_layers = len(self.layers) - self.start_idx
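`pipeline()` above keeps only this rank's slice of the layer list and blanks everything before `start_idx`; note that shards are laid out in reverse rank order, so rank 0 (the driver) ends up owning the last block of layers and computes the final hidden states. A small worked example of the same arithmetic, using 8 layers and 2 workers:

```python
# Worked example of the layer split in DistributedModelMixin.pipeline().
def split_layers(n_layers: int, world_size: int, rank: int):
    layers = list(range(n_layers))
    layers_per_rank = n_layers // world_size
    extra = n_layers - layers_per_rank * world_size
    if rank < extra:
        layers_per_rank += 1
    start_idx = (world_size - rank - 1) * layers_per_rank
    end_idx = start_idx + layers_per_rank
    layers = layers[:end_idx]
    layers[:start_idx] = [None] * start_idx
    return [layer for layer in layers if layer is not None]

print(split_layers(8, 2, rank=0))  # [4, 5, 6, 7] -- driver owns the tail of the stack
print(split_layers(8, 2, rank=1))  # [0, 1, 2, 3] -- non-driver owns the head
```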
xinference/model/llm/mlx/distributed_models/deepseek_v3.py (new file)
@@ -0,0 +1,75 @@
+ # Copyright 2022-2025 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ from mlx_lm.models.base import create_attention_mask
+ from mlx_lm.models.deepseek_v3 import DeepseekV3Model as _DeepseekV3Model
+ from mlx_lm.models.deepseek_v3 import Model as _Model
+ from mlx_lm.models.deepseek_v3 import ModelArgs
+
+ from .core import DistributedModelMixin
+
+
+ class DeepseekV3Model(_DeepseekV3Model, DistributedModelMixin):
+     def __init__(self, *args, **kwargs):
+         _DeepseekV3Model.__init__(self, *args, **kwargs)
+         DistributedModelMixin.__init__(self)
+
+     def __call__(
+         self,
+         x: mx.array,
+         cache: Optional[Any] = None,
+         mask: Optional[mx.array] = None,
+     ) -> mx.array:
+         h = self.embed_tokens(x)
+
+         pipeline_rank = self.rank
+         pipeline_size = self.world_size
+         if mask is None:
+             mask = create_attention_mask(h, cache)
+
+         if cache is None:
+             cache = [None] * self.num_layers
+
+         # Receive from the previous process in the pipeline
+
+         if pipeline_rank < pipeline_size - 1:
+             # wait for previous result
+             h = self._wait_prev_stage_result()
+
+         for i in range(self.num_layers):
+             h = self.layers[self.start_idx + i](h, mask, cache[i])
+             mx.eval(h)
+
+         if pipeline_rank != 0:
+             # Send to the next process in the pipeline
+             self._send_stage_result(h)
+             # wait for the final result
+             h = self._get_result()
+         else:
+             self._set_result(h)
+
+         return self.norm(h)
+
+
+ class Model(_Model):
+     def __init__(self, config: ModelArgs):
+         nn.Module.__init__(self)
+         self.args = config
+         self.model_type = config.model_type
+         self.model = DeepseekV3Model(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)